mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Compare commits
347 Commits
Author | SHA1 | Date | |
---|---|---|---|
80fd3d3f3c | |||
41b77cd5ff | |||
6f77477a1c | |||
7cfbe5a62a | |||
68344ecac9 | |||
84dc5c5c7b | |||
691ecb1f5b | |||
90b84c650f | |||
014be0ab67 | |||
e5d9f33f7b | |||
5a87e7b3f8 | |||
f8b673dc85 | |||
cb8e0cbf35 | |||
33bd9da26c | |||
9190abd487 | |||
ff47334f22 | |||
a8c3efd98a | |||
8c6860a2c5 | |||
fa8263e6f0 | |||
e4b8cb1d34 | |||
408a800593 | |||
9e5e3f1019 | |||
98a13aa7dc | |||
4418c118db | |||
110b0bc8fe | |||
175cfe41a4 | |||
a12d54afb9 | |||
18af5348a2 | |||
b18c8e1c96 | |||
ea1e647174 | |||
af059f2cff | |||
d8e21091e7 | |||
344041fd3a | |||
588a220dd4 | |||
770d4092b9 | |||
33fe02bdff | |||
8a353bc1e3 | |||
240f4801db | |||
da50507b2d | |||
67d150ab66 | |||
40d70add76 | |||
7bd9bf3ba5 | |||
c94d607089 | |||
ad801e54d4 | |||
fb4db83911 | |||
cc229c3ea0 | |||
ca00fabd79 | |||
b361fabf81 | |||
00669200c7 | |||
fa07e82d2a | |||
3632c5cd57 | |||
daef68d3c1 | |||
ba29376fba | |||
3efd9465eb | |||
a3b11c04cb | |||
8f9e3ac795 | |||
2367f53367 | |||
8b9f0a9551 | |||
ab57976e42 | |||
3c103c89f3 | |||
0f19176944 | |||
fc09a954b5 | |||
e7eee29825 | |||
2c1ba23f61 | |||
58ef6dc6ce | |||
8faefa89fe | |||
02f59a3831 | |||
2555be3058 | |||
e174ce038f | |||
0f10faf0d4 | |||
393e32f8a7 | |||
70412464c8 | |||
30fdb9dbfd | |||
66f6013436 | |||
49b04f7db8 | |||
253dc5d43d | |||
3ccb4e6ff9 | |||
200a9d1801 | |||
b09a76ea0d | |||
8a2030e78a | |||
dfa5505ed8 | |||
f8b731b900 | |||
fd9ab0fb7d | |||
f504a5c96e | |||
afe6639b9c | |||
1f1bf15099 | |||
8fa238f100 | |||
30b6a0ee23 | |||
784878c300 | |||
b51b163400 | |||
7e13224ec8 | |||
7bc454209c | |||
cc7f6c7048 | |||
8b8d950137 | |||
24fd7f41ff | |||
7c5e458372 | |||
a5dba4b0d9 | |||
72fb1cefff | |||
a64f1c0b20 | |||
974658107d | |||
07fb5d5c19 | |||
20c75e7a7e | |||
cfcb68696c | |||
7b1b6d3235 | |||
aefba52a0a | |||
6af46f9c5f | |||
190702d011 | |||
7785e8ff79 | |||
b3beaefa04 | |||
98be81354a | |||
2a2a5eb775 | |||
4a42b15b42 | |||
f24d5e5e31 | |||
4b106bc903 | |||
135ef9066f | |||
0567f98e4a | |||
5b66baa3ec | |||
a022aaf258 | |||
94065b090a | |||
091bf9220b | |||
8d243b1fca | |||
23c412e011 | |||
66692f02aa | |||
38af1c3a81 | |||
7b4b7e3781 | |||
02a3472505 | |||
909d354a38 | |||
7801b8c42f | |||
4fd259bb89 | |||
b8b3ef9725 | |||
3a8d5dc349 | |||
358cac9674 | |||
bdc2b8069b | |||
09295ae43b | |||
80ad14d89f | |||
c674eb3168 | |||
63138640a7 | |||
d103ff0d6e | |||
94931e8ac0 | |||
b409f3aaf9 | |||
f96b7f2e11 | |||
de3be4bd30 | |||
cc12f57a5a | |||
613f11a3ac | |||
a6e2d2c5e0 | |||
ae14df97d6 | |||
a6e1ac6096 | |||
8530635540 | |||
b2b7aed030 | |||
970d45f691 | |||
19b9a22d93 | |||
c0d9990344 | |||
4ac5e307c4 | |||
2815f737fe | |||
63e96fd1ea | |||
66ab56246a | |||
20a56bc757 | |||
82925e1539 | |||
0137a0db7b | |||
b410793684 | |||
894e9f127b | |||
dd9b1c8eec | |||
8d9c566656 | |||
9db7e073a3 | |||
5f64ed5bd5 | |||
7f75f6226b | |||
6dc819fd47 | |||
0cc81e5d63 | |||
daecc54153 | |||
4c31c7f9f1 | |||
d709c5519f | |||
5d84ecef49 | |||
641d246213 | |||
2e53aa48c9 | |||
ef12631450 | |||
d9eb626b62 | |||
8033589629 | |||
124075ae7a | |||
0bde933c89 | |||
fc5c5b6bdd | |||
ff53563152 | |||
12b0d735e7 | |||
d06ee94fd3 | |||
9dbdb6cf7c | |||
7c091570fe | |||
e99f3482cc | |||
d999c9ffd6 | |||
888db8ac46 | |||
7deef2cb27 | |||
ada807af0c | |||
aa132fb9e3 | |||
98a01368b8 | |||
fc9a62dbf5 | |||
4d8bec1605 | |||
cf9dad83bc | |||
0d0a2a5c91 | |||
0cab636ab0 | |||
de097ec58a | |||
bb6f426162 | |||
663f135b3c | |||
2f2097662a | |||
458c29cfa5 | |||
4bec01d6f2 | |||
9d79ee8dc4 | |||
78dd460348 | |||
9d27d354cf | |||
e8725a1099 | |||
479d65b6e1 | |||
5d4b388dfd | |||
4956fa282b | |||
51133522b7 | |||
6d5cc8b1ff | |||
08a5bb90e2 | |||
39bdf5c4e9 | |||
f31e4205aa | |||
4d05c4ff66 | |||
7e88d2a7f1 | |||
556f6aa174 | |||
6a74048af8 | |||
2cb51bff11 | |||
851e835e0e | |||
fe04f28841 | |||
258fc006ec | |||
dcb4ee47d5 | |||
1a56f5aaf9 | |||
5fc745653a | |||
47b5a90177 | |||
81518ee1af | |||
b06d63fb34 | |||
5278a64301 | |||
4de4473c0f | |||
2c28a850ca | |||
6dada3326d | |||
2dfdc02ec8 | |||
1f19db4c6a | |||
7c150c27f2 | |||
248916c190 | |||
be8b99eed5 | |||
2ad0752582 | |||
ba1f8878dd | |||
bc524026f9 | |||
ad7c571983 | |||
8559c6a392 | |||
c7904a32f4 | |||
17f5484f5b | |||
86a372b02f | |||
2e9aa9391d | |||
0c8112cf28 | |||
019898c7be | |||
2b1ff8d196 | |||
79fb691b4d | |||
560ae17e21 | |||
2bd1ab2f1c | |||
ed43472582 | |||
6e5e9176c0 | |||
4c6bcdbc18 | |||
20e6d4fa3c | |||
8e51392910 | |||
0b1c2acd61 | |||
86ac55ab5f | |||
3e82f63c7e | |||
631f6cae19 | |||
0845a0ed84 | |||
46c8ce9fed | |||
13a9ea35b5 | |||
94e8d1b6d5 | |||
2b1dc74080 | |||
f7e558d165 | |||
d959276217 | |||
dfcf38be91 | |||
fbded1c0f2 | |||
ad2926a24c | |||
34d5cad4c9 | |||
60aa3d4893 | |||
5c2884569e | |||
a1307b9f2e | |||
f505ec64ba | |||
f22eb368a3 | |||
96ae22c7e0 | |||
f5447cdc23 | |||
c76a6bd65f | |||
6c4eeaa569 | |||
1bbd13ead7 | |||
321b939d0e | |||
8fb77e431e | |||
083a4f3faa | |||
2005411f7e | |||
ba7b1b2665 | |||
b7ffd36cc6 | |||
199ddd6623 | |||
a7207ed8cf | |||
6bb2dda3f1 | |||
c1e5cd5893 | |||
ff249a2315 | |||
c58f8c3269 | |||
ed772a7107 | |||
cb0b389b4b | |||
8892df1d97 | |||
bc5f356390 | |||
bcb85e100d | |||
1f27ddc07d | |||
7a2b606001 | |||
83ddcc5f3a | |||
55fa785561 | |||
06429028c8 | |||
8b6e322697 | |||
54a67459bf | |||
7fe5283e74 | |||
fe0391c86b | |||
25386a76ef | |||
fd30cb4d90 | |||
0266946d3d | |||
a7f91b3e01 | |||
de0b72528c | |||
2932652787 | |||
db6bc7305a | |||
a5db204629 | |||
8e2b61e19f | |||
a3faa3792a | |||
c16eba78ab | |||
1a191c4655 | |||
e36d925bce | |||
b1ba18b3d1 | |||
aff46759f9 | |||
d7b7dcc7fe | |||
889a26c5b6 | |||
b4c774896a | |||
afbe889d35 | |||
9c1e52b1ef | |||
3f5ab02da9 | |||
bf48e8a03a | |||
e52434cb99 | |||
483bdbcb9f | |||
ae421fb4ab | |||
cc295a9f0a | |||
a7e23af9c6 | |||
3de4390711 | |||
3ceee2b2b2 | |||
5c7ed24aab | |||
183c9c4799 | |||
8baf3f78a2 | |||
ac2eb16a65 | |||
4aa7bee4b9 | |||
7e5ba2795e | |||
97a6c6eea7 | |||
f0e60a4ba2 | |||
aa089e8108 |
33
.github/actions/install-frontend-deps/action.yml
vendored
33
.github/actions/install-frontend-deps/action.yml
vendored
@ -1,33 +0,0 @@
|
||||
name: install frontend dependencies
|
||||
description: Installs frontend dependencies with pnpm, with caching
|
||||
runs:
|
||||
using: 'composite'
|
||||
steps:
|
||||
- name: setup node 18
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '18'
|
||||
|
||||
- name: setup pnpm
|
||||
uses: pnpm/action-setup@v2
|
||||
with:
|
||||
version: 8
|
||||
run_install: false
|
||||
|
||||
- name: get pnpm store directory
|
||||
shell: bash
|
||||
run: |
|
||||
echo "STORE_PATH=$(pnpm store path --silent)" >> $GITHUB_ENV
|
||||
|
||||
- name: setup cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ${{ env.STORE_PATH }}
|
||||
key: ${{ runner.os }}-pnpm-store-${{ hashFiles('**/pnpm-lock.yaml') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-store-
|
||||
|
||||
- name: install frontend dependencies
|
||||
run: pnpm install --prefer-frozen-lockfile
|
||||
shell: bash
|
||||
working-directory: invokeai/frontend/web
|
28
.github/pr_labels.yml
vendored
28
.github/pr_labels.yml
vendored
@ -1,59 +1,59 @@
|
||||
root:
|
||||
Root:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: '*'
|
||||
|
||||
python-deps:
|
||||
PythonDeps:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: 'pyproject.toml'
|
||||
|
||||
python:
|
||||
Python:
|
||||
- changed-files:
|
||||
- all-globs-to-any-file:
|
||||
- 'invokeai/**'
|
||||
- '!invokeai/frontend/web/**'
|
||||
|
||||
python-tests:
|
||||
PythonTests:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: 'tests/**'
|
||||
|
||||
ci-cd:
|
||||
CICD:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: .github/**
|
||||
|
||||
docker:
|
||||
Docker:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: docker/**
|
||||
|
||||
installer:
|
||||
Installer:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: installer/**
|
||||
|
||||
docs:
|
||||
Documentation:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: docs/**
|
||||
|
||||
invocations:
|
||||
Invocations:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: 'invokeai/app/invocations/**'
|
||||
|
||||
backend:
|
||||
Backend:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: 'invokeai/backend/**'
|
||||
|
||||
api:
|
||||
Api:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: 'invokeai/app/api/**'
|
||||
|
||||
services:
|
||||
Services:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: 'invokeai/app/services/**'
|
||||
|
||||
frontend-deps:
|
||||
FrontendDeps:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- '**/*/package.json'
|
||||
- '**/*/pnpm-lock.yaml'
|
||||
|
||||
frontend:
|
||||
Frontend:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: 'invokeai/frontend/web/**'
|
||||
|
67
.github/pull_request_template.md
vendored
67
.github/pull_request_template.md
vendored
@ -1,25 +1,66 @@
|
||||
<!--Thanks for contributing!-->
|
||||
## What type of PR is this? (check all applicable)
|
||||
|
||||
## Summary
|
||||
- [ ] Refactor
|
||||
- [ ] Feature
|
||||
- [ ] Bug Fix
|
||||
- [ ] Optimization
|
||||
- [ ] Documentation Update
|
||||
- [ ] Community Node Submission
|
||||
|
||||
<!--A description of the changes in this PR. Include the kind of change (fix, feature, docs, etc), the "why" and the "how". Screenshots or videos are useful for frontend changes.-->
|
||||
|
||||
## Related Issues / Discussions
|
||||
## Have you discussed this change with the InvokeAI team?
|
||||
- [ ] Yes
|
||||
- [ ] No, because:
|
||||
|
||||
<!--List any related issues or discussions on github or discord. If this PR closes an issue, please use the "Closes #1234" format, so that the issue will be automatically closed when the PR merges.-->
|
||||
|
||||
## Have you updated all relevant documentation?
|
||||
- [ ] Yes
|
||||
- [ ] No
|
||||
|
||||
## QA Instructions
|
||||
|
||||
<!--WHEN APPLICABLE: Describe how we can test the changes in this PR.-->
|
||||
## Description
|
||||
|
||||
|
||||
## Related Tickets & Documents
|
||||
|
||||
<!--
|
||||
For pull requests that relate or close an issue, please include them
|
||||
below.
|
||||
|
||||
For example having the text: "closes #1234" would connect the current pull
|
||||
request to issue 1234. And when we merge the pull request, Github will
|
||||
automatically close the issue.
|
||||
-->
|
||||
|
||||
- Related Issue #
|
||||
- Closes #
|
||||
|
||||
## QA Instructions, Screenshots, Recordings
|
||||
|
||||
<!--
|
||||
Please provide steps on how to test changes, any hardware or
|
||||
software specifications as well as any other pertinent information.
|
||||
-->
|
||||
|
||||
## Merge Plan
|
||||
|
||||
<!--WHEN APPLICABLE: Large PRs, or PRs that touch sensitive things like DB schemas, may need some care when merging. For example, a careful rebase by the change author, timing to not interfere with a pending release, or a message to contributors on discord after merging.-->
|
||||
<!--
|
||||
A merge plan describes how this PR should be handled after it is approved.
|
||||
|
||||
## Checklist
|
||||
Example merge plans:
|
||||
- "This PR can be merged when approved"
|
||||
- "This must be squash-merged when approved"
|
||||
- "DO NOT MERGE - I will rebase and tidy commits before merging"
|
||||
- "#dev-chat on discord needs to be advised of this change when it is merged"
|
||||
|
||||
<!--If any of these are not completed or not applicable to the change, please add a note.-->
|
||||
A merge plan is particularly important for large PRs or PRs that touch the
|
||||
database in any way.
|
||||
-->
|
||||
|
||||
- [ ] The PR has a short but descriptive title
|
||||
- [ ] Tests added / updated
|
||||
- [ ] Documentation added / updated
|
||||
## Added/updated tests?
|
||||
|
||||
- [ ] Yes
|
||||
- [ ] No : _please replace this line with details on why tests
|
||||
have not been included_
|
||||
|
||||
## [optional] Are there any post deployment tasks we need to perform?
|
||||
|
2
.github/workflows/build-container.yml
vendored
2
.github/workflows/build-container.yml
vendored
@ -11,7 +11,7 @@ on:
|
||||
- 'docker/docker-entrypoint.sh'
|
||||
- 'workflows/build-container.yml'
|
||||
tags:
|
||||
- 'v*.*.*'
|
||||
- 'v*'
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
|
45
.github/workflows/build-installer.yml
vendored
45
.github/workflows/build-installer.yml
vendored
@ -1,45 +0,0 @@
|
||||
# Builds and uploads the installer and python build artifacts.
|
||||
|
||||
name: build installer
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
workflow_call:
|
||||
|
||||
jobs:
|
||||
build-installer:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 5 # expected run time: <2 min
|
||||
steps:
|
||||
- name: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: setup python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.10'
|
||||
cache: pip
|
||||
cache-dependency-path: pyproject.toml
|
||||
|
||||
- name: install pypa/build
|
||||
run: pip install --upgrade build
|
||||
|
||||
- name: setup frontend
|
||||
uses: ./.github/actions/install-frontend-deps
|
||||
|
||||
- name: create installer
|
||||
id: create_installer
|
||||
run: ./create_installer.sh
|
||||
working-directory: installer
|
||||
|
||||
- name: upload python distribution artifact
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: dist
|
||||
path: ${{ steps.create_installer.outputs.DIST_PATH }}
|
||||
|
||||
- name: upload installer artifact
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: ${{ steps.create_installer.outputs.INSTALLER_FILENAME }}
|
||||
path: ${{ steps.create_installer.outputs.INSTALLER_PATH }}
|
80
.github/workflows/frontend-checks.yml
vendored
80
.github/workflows/frontend-checks.yml
vendored
@ -1,80 +0,0 @@
|
||||
# Runs frontend code quality checks.
|
||||
#
|
||||
# Checks for changes to frontend files before running the checks.
|
||||
# If always_run is true, always runs the checks.
|
||||
|
||||
name: 'frontend checks'
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- 'main'
|
||||
pull_request:
|
||||
types:
|
||||
- 'ready_for_review'
|
||||
- 'opened'
|
||||
- 'synchronize'
|
||||
merge_group:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
always_run:
|
||||
description: 'Always run the checks'
|
||||
required: true
|
||||
type: boolean
|
||||
default: true
|
||||
workflow_call:
|
||||
inputs:
|
||||
always_run:
|
||||
description: 'Always run the checks'
|
||||
required: true
|
||||
type: boolean
|
||||
default: true
|
||||
|
||||
defaults:
|
||||
run:
|
||||
working-directory: invokeai/frontend/web
|
||||
|
||||
jobs:
|
||||
frontend-checks:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 10 # expected run time: <2 min
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: check for changed frontend files
|
||||
if: ${{ inputs.always_run != true }}
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@v42
|
||||
with:
|
||||
files_yaml: |
|
||||
frontend:
|
||||
- 'invokeai/frontend/web/**'
|
||||
|
||||
- name: install dependencies
|
||||
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || inputs.always_run == true }}
|
||||
uses: ./.github/actions/install-frontend-deps
|
||||
|
||||
- name: tsc
|
||||
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || inputs.always_run == true }}
|
||||
run: 'pnpm lint:tsc'
|
||||
shell: bash
|
||||
|
||||
- name: dpdm
|
||||
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || inputs.always_run == true }}
|
||||
run: 'pnpm lint:dpdm'
|
||||
shell: bash
|
||||
|
||||
- name: eslint
|
||||
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || inputs.always_run == true }}
|
||||
run: 'pnpm lint:eslint'
|
||||
shell: bash
|
||||
|
||||
- name: prettier
|
||||
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || inputs.always_run == true }}
|
||||
run: 'pnpm lint:prettier'
|
||||
shell: bash
|
||||
|
||||
- name: knip
|
||||
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || inputs.always_run == true }}
|
||||
run: 'pnpm lint:knip'
|
||||
shell: bash
|
60
.github/workflows/frontend-tests.yml
vendored
60
.github/workflows/frontend-tests.yml
vendored
@ -1,60 +0,0 @@
|
||||
# Runs frontend tests.
|
||||
#
|
||||
# Checks for changes to frontend files before running the tests.
|
||||
# If always_run is true, always runs the tests.
|
||||
|
||||
name: 'frontend tests'
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- 'main'
|
||||
pull_request:
|
||||
types:
|
||||
- 'ready_for_review'
|
||||
- 'opened'
|
||||
- 'synchronize'
|
||||
merge_group:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
always_run:
|
||||
description: 'Always run the tests'
|
||||
required: true
|
||||
type: boolean
|
||||
default: true
|
||||
workflow_call:
|
||||
inputs:
|
||||
always_run:
|
||||
description: 'Always run the tests'
|
||||
required: true
|
||||
type: boolean
|
||||
default: true
|
||||
|
||||
defaults:
|
||||
run:
|
||||
working-directory: invokeai/frontend/web
|
||||
|
||||
jobs:
|
||||
frontend-tests:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 10 # expected run time: <2 min
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: check for changed frontend files
|
||||
if: ${{ inputs.always_run != true }}
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@v42
|
||||
with:
|
||||
files_yaml: |
|
||||
frontend:
|
||||
- 'invokeai/frontend/web/**'
|
||||
|
||||
- name: install dependencies
|
||||
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || inputs.always_run == true }}
|
||||
uses: ./.github/actions/install-frontend-deps
|
||||
|
||||
- name: vitest
|
||||
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || inputs.always_run == true }}
|
||||
run: 'pnpm test:no-watch'
|
||||
shell: bash
|
12
.github/workflows/label-pr.yml
vendored
12
.github/workflows/label-pr.yml
vendored
@ -1,6 +1,6 @@
|
||||
name: 'label PRs'
|
||||
name: "Pull Request Labeler"
|
||||
on:
|
||||
- pull_request_target
|
||||
- pull_request_target
|
||||
|
||||
jobs:
|
||||
labeler:
|
||||
@ -9,10 +9,8 @@ jobs:
|
||||
pull-requests: write
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: checkout
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: label PRs
|
||||
uses: actions/labeler@v5
|
||||
- uses: actions/labeler@v5
|
||||
with:
|
||||
configuration-path: .github/pr_labels.yml
|
||||
configuration-path: .github/pr_labels.yml
|
45
.github/workflows/lint-frontend.yml
vendored
Normal file
45
.github/workflows/lint-frontend.yml
vendored
Normal file
@ -0,0 +1,45 @@
|
||||
name: Lint frontend
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types:
|
||||
- 'ready_for_review'
|
||||
- 'opened'
|
||||
- 'synchronize'
|
||||
push:
|
||||
branches:
|
||||
- 'main'
|
||||
merge_group:
|
||||
workflow_dispatch:
|
||||
|
||||
defaults:
|
||||
run:
|
||||
working-directory: invokeai/frontend/web
|
||||
|
||||
jobs:
|
||||
lint-frontend:
|
||||
if: github.event.pull_request.draft == false
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Setup Node 18
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '18'
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: Setup pnpm
|
||||
uses: pnpm/action-setup@v2
|
||||
with:
|
||||
version: '8.12.1'
|
||||
- name: Install dependencies
|
||||
run: 'pnpm install --prefer-frozen-lockfile'
|
||||
- name: Typescript
|
||||
run: 'pnpm run lint:tsc'
|
||||
- name: Madge
|
||||
run: 'pnpm run lint:dpdm'
|
||||
- name: ESLint
|
||||
run: 'pnpm run lint:eslint'
|
||||
- name: Prettier
|
||||
run: 'pnpm run lint:prettier'
|
||||
- name: Knip
|
||||
run: 'pnpm run lint:knip'
|
54
.github/workflows/mkdocs-material.yml
vendored
54
.github/workflows/mkdocs-material.yml
vendored
@ -1,49 +1,51 @@
|
||||
# This is a mostly a copy-paste from https://github.com/squidfunk/mkdocs-material/blob/master/docs/publishing-your-site.md
|
||||
|
||||
name: mkdocs
|
||||
|
||||
name: mkdocs-material
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
workflow_dispatch:
|
||||
- 'refs/heads/main'
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
contents: write
|
||||
|
||||
jobs:
|
||||
deploy:
|
||||
mkdocs-material:
|
||||
if: github.event.pull_request.draft == false
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
REPO_URL: '${{ github.server_url }}/${{ github.repository }}'
|
||||
REPO_NAME: '${{ github.repository }}'
|
||||
SITE_URL: 'https://${{ github.repository_owner }}.github.io/InvokeAI'
|
||||
|
||||
steps:
|
||||
- name: checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: checkout sources
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: setup python
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.10'
|
||||
cache: pip
|
||||
cache-dependency-path: pyproject.toml
|
||||
|
||||
- name: set cache id
|
||||
run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV
|
||||
- name: install requirements
|
||||
env:
|
||||
PIP_USE_PEP517: 1
|
||||
run: |
|
||||
python -m \
|
||||
pip install ".[docs]"
|
||||
|
||||
- name: use cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
key: mkdocs-material-${{ env.cache_id }}
|
||||
path: .cache
|
||||
restore-keys: |
|
||||
mkdocs-material-
|
||||
- name: confirm buildability
|
||||
run: |
|
||||
python -m \
|
||||
mkdocs build \
|
||||
--clean \
|
||||
--verbose
|
||||
|
||||
- name: install dependencies
|
||||
run: python -m pip install ".[docs]"
|
||||
|
||||
- name: build & deploy
|
||||
run: mkdocs gh-deploy --force
|
||||
- name: deploy to gh-pages
|
||||
if: ${{ github.ref == 'refs/heads/main' }}
|
||||
run: |
|
||||
python -m \
|
||||
mkdocs gh-deploy \
|
||||
--clean \
|
||||
--force
|
||||
|
67
.github/workflows/pypi-release.yml
vendored
Normal file
67
.github/workflows/pypi-release.yml
vendored
Normal file
@ -0,0 +1,67 @@
|
||||
name: PyPI Release
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
publish_package:
|
||||
description: 'Publish build on PyPi? [true/false]'
|
||||
required: true
|
||||
default: 'false'
|
||||
|
||||
jobs:
|
||||
build-and-release:
|
||||
if: github.repository == 'invoke-ai/InvokeAI'
|
||||
runs-on: ubuntu-22.04
|
||||
env:
|
||||
TWINE_USERNAME: __token__
|
||||
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
|
||||
TWINE_NON_INTERACTIVE: 1
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Node 18
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '18'
|
||||
|
||||
- name: Setup pnpm
|
||||
uses: pnpm/action-setup@v2
|
||||
with:
|
||||
version: '8.12.1'
|
||||
|
||||
- name: Install frontend dependencies
|
||||
run: pnpm install --prefer-frozen-lockfile
|
||||
working-directory: invokeai/frontend/web
|
||||
|
||||
- name: Build frontend
|
||||
run: pnpm run build
|
||||
working-directory: invokeai/frontend/web
|
||||
|
||||
- name: Install python dependencies
|
||||
run: pip install --upgrade build twine
|
||||
|
||||
- name: Build python package
|
||||
run: python3 -m build
|
||||
|
||||
- name: Upload build as workflow artifact
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: dist
|
||||
path: dist
|
||||
|
||||
- name: Check distribution
|
||||
run: twine check dist/*
|
||||
|
||||
- name: Check PyPI versions
|
||||
if: github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release/')
|
||||
run: |
|
||||
pip install --upgrade requests
|
||||
python -c "\
|
||||
import scripts.pypi_helper; \
|
||||
EXISTS=scripts.pypi_helper.local_on_pypi(); \
|
||||
print(f'PACKAGE_EXISTS={EXISTS}')" >> $GITHUB_ENV
|
||||
|
||||
- name: Publish build on PyPi
|
||||
if: env.PACKAGE_EXISTS == 'False' && env.TWINE_PASSWORD != '' && github.event.inputs.publish_package == 'true'
|
||||
run: twine upload dist/*
|
76
.github/workflows/python-checks.yml
vendored
76
.github/workflows/python-checks.yml
vendored
@ -1,76 +0,0 @@
|
||||
# Runs python code quality checks.
|
||||
#
|
||||
# Checks for changes to python files before running the checks.
|
||||
# If always_run is true, always runs the checks.
|
||||
#
|
||||
# TODO: Add mypy or pyright to the checks.
|
||||
|
||||
name: 'python checks'
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- 'main'
|
||||
pull_request:
|
||||
types:
|
||||
- 'ready_for_review'
|
||||
- 'opened'
|
||||
- 'synchronize'
|
||||
merge_group:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
always_run:
|
||||
description: 'Always run the checks'
|
||||
required: true
|
||||
type: boolean
|
||||
default: true
|
||||
workflow_call:
|
||||
inputs:
|
||||
always_run:
|
||||
description: 'Always run the checks'
|
||||
required: true
|
||||
type: boolean
|
||||
default: true
|
||||
|
||||
jobs:
|
||||
python-checks:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 5 # expected run time: <1 min
|
||||
steps:
|
||||
- name: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: check for changed python files
|
||||
if: ${{ inputs.always_run != true }}
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@v42
|
||||
with:
|
||||
files_yaml: |
|
||||
python:
|
||||
- 'pyproject.toml'
|
||||
- 'invokeai/**'
|
||||
- '!invokeai/frontend/web/**'
|
||||
- 'tests/**'
|
||||
|
||||
- name: setup python
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.10'
|
||||
cache: pip
|
||||
cache-dependency-path: pyproject.toml
|
||||
|
||||
- name: install ruff
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
|
||||
run: pip install ruff
|
||||
shell: bash
|
||||
|
||||
- name: ruff check
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
|
||||
run: ruff check --output-format=github .
|
||||
shell: bash
|
||||
|
||||
- name: ruff format
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
|
||||
run: ruff format --check .
|
||||
shell: bash
|
106
.github/workflows/python-tests.yml
vendored
106
.github/workflows/python-tests.yml
vendored
@ -1,106 +0,0 @@
|
||||
# Runs python tests on a matrix of python versions and platforms.
|
||||
#
|
||||
# Checks for changes to python files before running the tests.
|
||||
# If always_run is true, always runs the tests.
|
||||
|
||||
name: 'python tests'
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- 'main'
|
||||
pull_request:
|
||||
types:
|
||||
- 'ready_for_review'
|
||||
- 'opened'
|
||||
- 'synchronize'
|
||||
merge_group:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
always_run:
|
||||
description: 'Always run the tests'
|
||||
required: true
|
||||
type: boolean
|
||||
default: true
|
||||
workflow_call:
|
||||
inputs:
|
||||
always_run:
|
||||
description: 'Always run the tests'
|
||||
required: true
|
||||
type: boolean
|
||||
default: true
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
matrix:
|
||||
strategy:
|
||||
matrix:
|
||||
python-version:
|
||||
- '3.10'
|
||||
- '3.11'
|
||||
platform:
|
||||
- linux-cuda-11_7
|
||||
- linux-rocm-5_2
|
||||
- linux-cpu
|
||||
- macos-default
|
||||
- windows-cpu
|
||||
include:
|
||||
- platform: linux-cuda-11_7
|
||||
os: ubuntu-22.04
|
||||
github-env: $GITHUB_ENV
|
||||
- platform: linux-rocm-5_2
|
||||
os: ubuntu-22.04
|
||||
extra-index-url: 'https://download.pytorch.org/whl/rocm5.2'
|
||||
github-env: $GITHUB_ENV
|
||||
- platform: linux-cpu
|
||||
os: ubuntu-22.04
|
||||
extra-index-url: 'https://download.pytorch.org/whl/cpu'
|
||||
github-env: $GITHUB_ENV
|
||||
- platform: macos-default
|
||||
os: macOS-12
|
||||
github-env: $GITHUB_ENV
|
||||
- platform: windows-cpu
|
||||
os: windows-2022
|
||||
github-env: $env:GITHUB_ENV
|
||||
name: 'py${{ matrix.python-version }}: ${{ matrix.platform }}'
|
||||
runs-on: ${{ matrix.os }}
|
||||
timeout-minutes: 15 # expected run time: 2-6 min, depending on platform
|
||||
env:
|
||||
PIP_USE_PEP517: '1'
|
||||
steps:
|
||||
- name: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: check for changed python files
|
||||
if: ${{ inputs.always_run != true }}
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@v42
|
||||
with:
|
||||
files_yaml: |
|
||||
python:
|
||||
- 'pyproject.toml'
|
||||
- 'invokeai/**'
|
||||
- '!invokeai/frontend/web/**'
|
||||
- 'tests/**'
|
||||
|
||||
- name: setup python
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
cache: pip
|
||||
cache-dependency-path: pyproject.toml
|
||||
|
||||
- name: install dependencies
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
|
||||
env:
|
||||
PIP_EXTRA_INDEX_URL: ${{ matrix.extra-index-url }}
|
||||
run: >
|
||||
pip3 install --editable=".[test]"
|
||||
|
||||
- name: run pytest
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
|
||||
run: pytest
|
108
.github/workflows/release.yml
vendored
108
.github/workflows/release.yml
vendored
@ -1,108 +0,0 @@
|
||||
# Main release workflow. Triggered on tag push or manual trigger.
|
||||
#
|
||||
# - Runs all code checks and tests
|
||||
# - Verifies the app version matches the tag version.
|
||||
# - Builds the installer and build, uploading them as artifacts.
|
||||
# - Publishes to TestPyPI and PyPI. Both are conditional on the previous steps passing and require a manual approval.
|
||||
#
|
||||
# See docs/RELEASE.md for more information on the release process.
|
||||
|
||||
name: release
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v*'
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
check-version:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: check python version
|
||||
uses: samuelcolvin/check-python-version@v4
|
||||
id: check-python-version
|
||||
with:
|
||||
version_file_path: invokeai/version/invokeai_version.py
|
||||
|
||||
frontend-checks:
|
||||
uses: ./.github/workflows/frontend-checks.yml
|
||||
with:
|
||||
always_run: true
|
||||
|
||||
frontend-tests:
|
||||
uses: ./.github/workflows/frontend-tests.yml
|
||||
with:
|
||||
always_run: true
|
||||
|
||||
python-checks:
|
||||
uses: ./.github/workflows/python-checks.yml
|
||||
with:
|
||||
always_run: true
|
||||
|
||||
python-tests:
|
||||
uses: ./.github/workflows/python-tests.yml
|
||||
with:
|
||||
always_run: true
|
||||
|
||||
build:
|
||||
uses: ./.github/workflows/build-installer.yml
|
||||
|
||||
publish-testpypi:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 5 # expected run time: <1 min
|
||||
needs:
|
||||
[
|
||||
check-version,
|
||||
frontend-checks,
|
||||
frontend-tests,
|
||||
python-checks,
|
||||
python-tests,
|
||||
build,
|
||||
]
|
||||
environment:
|
||||
name: testpypi
|
||||
url: https://test.pypi.org/p/invokeai
|
||||
permissions:
|
||||
id-token: write
|
||||
steps:
|
||||
- name: download distribution from build job
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: dist
|
||||
path: dist/
|
||||
|
||||
- name: publish distribution to TestPyPI
|
||||
uses: pypa/gh-action-pypi-publish@release/v1
|
||||
with:
|
||||
repository-url: https://test.pypi.org/legacy/
|
||||
|
||||
publish-pypi:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 5 # expected run time: <1 min
|
||||
needs:
|
||||
[
|
||||
check-version,
|
||||
frontend-checks,
|
||||
frontend-tests,
|
||||
python-checks,
|
||||
python-tests,
|
||||
build,
|
||||
]
|
||||
environment:
|
||||
name: pypi
|
||||
url: https://pypi.org/p/invokeai
|
||||
permissions:
|
||||
id-token: write
|
||||
steps:
|
||||
- name: download distribution from build job
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: dist
|
||||
path: dist/
|
||||
|
||||
- name: publish distribution to PyPI
|
||||
uses: pypa/gh-action-pypi-publish@release/v1
|
24
.github/workflows/style-checks.yml
vendored
Normal file
24
.github/workflows/style-checks.yml
vendored
Normal file
@ -0,0 +1,24 @@
|
||||
name: style checks
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
push:
|
||||
branches: main
|
||||
|
||||
jobs:
|
||||
ruff:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.10'
|
||||
|
||||
- name: Install dependencies with pip
|
||||
run: |
|
||||
pip install ruff
|
||||
|
||||
- run: ruff check --output-format=github .
|
||||
- run: ruff format --check .
|
129
.github/workflows/test-invoke-pip.yml
vendored
Normal file
129
.github/workflows/test-invoke-pip.yml
vendored
Normal file
@ -0,0 +1,129 @@
|
||||
name: Test invoke.py pip
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- 'main'
|
||||
pull_request:
|
||||
types:
|
||||
- 'ready_for_review'
|
||||
- 'opened'
|
||||
- 'synchronize'
|
||||
merge_group:
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
matrix:
|
||||
if: github.event.pull_request.draft == false
|
||||
strategy:
|
||||
matrix:
|
||||
python-version:
|
||||
# - '3.9'
|
||||
- '3.10'
|
||||
pytorch:
|
||||
- linux-cuda-11_7
|
||||
- linux-rocm-5_2
|
||||
- linux-cpu
|
||||
- macos-default
|
||||
- windows-cpu
|
||||
include:
|
||||
- pytorch: linux-cuda-11_7
|
||||
os: ubuntu-22.04
|
||||
github-env: $GITHUB_ENV
|
||||
- pytorch: linux-rocm-5_2
|
||||
os: ubuntu-22.04
|
||||
extra-index-url: 'https://download.pytorch.org/whl/rocm5.2'
|
||||
github-env: $GITHUB_ENV
|
||||
- pytorch: linux-cpu
|
||||
os: ubuntu-22.04
|
||||
extra-index-url: 'https://download.pytorch.org/whl/cpu'
|
||||
github-env: $GITHUB_ENV
|
||||
- pytorch: macos-default
|
||||
os: macOS-12
|
||||
github-env: $GITHUB_ENV
|
||||
- pytorch: windows-cpu
|
||||
os: windows-2022
|
||||
github-env: $env:GITHUB_ENV
|
||||
name: ${{ matrix.pytorch }} on ${{ matrix.python-version }}
|
||||
runs-on: ${{ matrix.os }}
|
||||
env:
|
||||
PIP_USE_PEP517: '1'
|
||||
steps:
|
||||
- name: Checkout sources
|
||||
id: checkout-sources
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Check for changed python files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@v41
|
||||
with:
|
||||
files_yaml: |
|
||||
python:
|
||||
- 'pyproject.toml'
|
||||
- 'invokeai/**'
|
||||
- '!invokeai/frontend/web/**'
|
||||
- 'tests/**'
|
||||
|
||||
- name: set test prompt to main branch validation
|
||||
if: steps.changed-files.outputs.python_any_changed == 'true'
|
||||
run: echo "TEST_PROMPTS=tests/validate_pr_prompt.txt" >> ${{ matrix.github-env }}
|
||||
|
||||
- name: setup python
|
||||
if: steps.changed-files.outputs.python_any_changed == 'true'
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
cache: pip
|
||||
cache-dependency-path: pyproject.toml
|
||||
|
||||
- name: install invokeai
|
||||
if: steps.changed-files.outputs.python_any_changed == 'true'
|
||||
env:
|
||||
PIP_EXTRA_INDEX_URL: ${{ matrix.extra-index-url }}
|
||||
run: >
|
||||
pip3 install
|
||||
--editable=".[test]"
|
||||
|
||||
- name: run pytest
|
||||
if: steps.changed-files.outputs.python_any_changed == 'true'
|
||||
id: run-pytest
|
||||
run: pytest
|
||||
|
||||
# - name: run invokeai-configure
|
||||
# env:
|
||||
# HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGINGFACE_TOKEN }}
|
||||
# run: >
|
||||
# invokeai-configure
|
||||
# --yes
|
||||
# --default_only
|
||||
# --full-precision
|
||||
# # can't use fp16 weights without a GPU
|
||||
|
||||
# - name: run invokeai
|
||||
# id: run-invokeai
|
||||
# env:
|
||||
# # Set offline mode to make sure configure preloaded successfully.
|
||||
# HF_HUB_OFFLINE: 1
|
||||
# HF_DATASETS_OFFLINE: 1
|
||||
# TRANSFORMERS_OFFLINE: 1
|
||||
# INVOKEAI_OUTDIR: ${{ github.workspace }}/results
|
||||
# run: >
|
||||
# invokeai
|
||||
# --no-patchmatch
|
||||
# --no-nsfw_checker
|
||||
# --precision=float32
|
||||
# --always_use_cpu
|
||||
# --use_memory_db
|
||||
# --outdir ${{ env.INVOKEAI_OUTDIR }}/${{ matrix.python-version }}/${{ matrix.pytorch }}
|
||||
# --from_file ${{ env.TEST_PROMPTS }}
|
||||
|
||||
# - name: Archive results
|
||||
# env:
|
||||
# INVOKEAI_OUTDIR: ${{ github.workspace }}/results
|
||||
# uses: actions/upload-artifact@v3
|
||||
# with:
|
||||
# name: results
|
||||
# path: ${{ env.INVOKEAI_OUTDIR }}
|
@ -7,7 +7,7 @@ embeddedLanguageFormatting: auto
|
||||
overrides:
|
||||
- files: '*.md'
|
||||
options:
|
||||
proseWrap: preserve
|
||||
proseWrap: always
|
||||
printWidth: 80
|
||||
parser: markdown
|
||||
cursorOffset: -1
|
||||
|
29
Makefile
29
Makefile
@ -6,18 +6,16 @@ default: help
|
||||
help:
|
||||
@echo Developer commands:
|
||||
@echo
|
||||
@echo "ruff Run ruff, fixing any safely-fixable errors and formatting"
|
||||
@echo "ruff-unsafe Run ruff, fixing all fixable errors and formatting"
|
||||
@echo "mypy Run mypy using the config in pyproject.toml to identify type mismatches and other coding errors"
|
||||
@echo "mypy-all Run mypy ignoring the config in pyproject.tom but still ignoring missing imports"
|
||||
@echo "test Run the unit tests."
|
||||
@echo "update-config-docstring Update the app's config docstring so mkdocs can autogenerate it correctly."
|
||||
@echo "frontend-install Install the pnpm modules needed for the front end"
|
||||
@echo "frontend-build Build the frontend in order to run on localhost:9090"
|
||||
@echo "frontend-dev Run the frontend in developer mode on localhost:5173"
|
||||
@echo "frontend-typegen Generate types for the frontend from the OpenAPI schema"
|
||||
@echo "installer-zip Build the installer .zip file for the current version"
|
||||
@echo "tag-release Tag the GitHub repository with the current version (use at release time only!)"
|
||||
@echo "ruff Run ruff, fixing any safely-fixable errors and formatting"
|
||||
@echo "ruff-unsafe Run ruff, fixing all fixable errors and formatting"
|
||||
@echo "mypy Run mypy using the config in pyproject.toml to identify type mismatches and other coding errors"
|
||||
@echo "mypy-all Run mypy ignoring the config in pyproject.tom but still ignoring missing imports"
|
||||
@echo "test" Run the unit tests.
|
||||
@echo "frontend-install" Install the pnpm modules needed for the front end
|
||||
@echo "frontend-build Build the frontend in order to run on localhost:9090"
|
||||
@echo "frontend-dev Run the frontend in developer mode on localhost:5173"
|
||||
@echo "installer-zip Build the installer .zip file for the current version"
|
||||
@echo "tag-release Tag the GitHub repository with the current version (use at release time only!)"
|
||||
|
||||
# Runs ruff, fixing any safely-fixable errors and formatting
|
||||
ruff:
|
||||
@ -42,10 +40,6 @@ mypy-all:
|
||||
test:
|
||||
pytest ./tests
|
||||
|
||||
# Update config docstring
|
||||
update-config-docstring:
|
||||
python scripts/update_config_docstring.py
|
||||
|
||||
# Install the pnpm modules needed for the front end
|
||||
frontend-install:
|
||||
rm -rf invokeai/frontend/web/node_modules
|
||||
@ -59,9 +53,6 @@ frontend-build:
|
||||
frontend-dev:
|
||||
cd invokeai/frontend/web && pnpm dev
|
||||
|
||||
frontend-typegen:
|
||||
cd invokeai/frontend/web && python ../../../scripts/generate_openapi_schema.py | pnpm typegen
|
||||
|
||||
# Installer zip file
|
||||
installer-zip:
|
||||
cd installer && ./create_installer.sh
|
||||
|
@ -9,6 +9,10 @@ set -e -o pipefail
|
||||
### Set INVOKEAI_ROOT pointing to a valid runtime directory
|
||||
# Otherwise configure the runtime dir first.
|
||||
|
||||
### Configure the InvokeAI runtime directory (done by default)):
|
||||
# docker run --rm -it <this image> --configure
|
||||
# or skip with --no-configure
|
||||
|
||||
### Set the CONTAINER_UID envvar to match your user.
|
||||
# Ensures files created in the container are owned by you:
|
||||
# docker run --rm -it -v /some/path:/invokeai -e CONTAINER_UID=$(id -u) <this image>
|
||||
@ -18,6 +22,27 @@ USER_ID=${CONTAINER_UID:-1000}
|
||||
USER=ubuntu
|
||||
usermod -u ${USER_ID} ${USER} 1>/dev/null
|
||||
|
||||
configure() {
|
||||
# Configure the runtime directory
|
||||
if [[ -f ${INVOKEAI_ROOT}/invokeai.yaml ]]; then
|
||||
echo "${INVOKEAI_ROOT}/invokeai.yaml exists. InvokeAI is already configured."
|
||||
echo "To reconfigure InvokeAI, delete the above file."
|
||||
echo "======================================================================"
|
||||
else
|
||||
mkdir -p "${INVOKEAI_ROOT}"
|
||||
chown --recursive ${USER} "${INVOKEAI_ROOT}"
|
||||
gosu ${USER} invokeai-configure --yes --default_only
|
||||
fi
|
||||
}
|
||||
|
||||
## Skip attempting to configure.
|
||||
## Must be passed first, before any other args.
|
||||
if [[ $1 != "--no-configure" ]]; then
|
||||
configure
|
||||
else
|
||||
shift
|
||||
fi
|
||||
|
||||
### Set the $PUBLIC_KEY env var to enable SSH access.
|
||||
# We do not install openssh-server in the image by default to avoid bloat.
|
||||
# but it is useful to have the full SSH server e.g. on Runpod.
|
||||
|
142
docs/RELEASE.md
142
docs/RELEASE.md
@ -1,142 +0,0 @@
|
||||
# Release Process
|
||||
|
||||
The app is published in twice, in different build formats.
|
||||
|
||||
- A [PyPI] distribution. This includes both a source distribution and built distribution (a wheel). Users install with `pip install invokeai`. The updater uses this build.
|
||||
- An installer on the [InvokeAI Releases Page]. This is a zip file with install scripts and a wheel. This is only used for new installs.
|
||||
|
||||
## General Prep
|
||||
|
||||
Make a developer call-out for PRs to merge. Merge and test things out.
|
||||
|
||||
While the release workflow does not include end-to-end tests, it does pause before publishing so you can download and test the final build.
|
||||
|
||||
## Release Workflow
|
||||
|
||||
The `release.yml` workflow runs a number of jobs to handle code checks, tests, build and publish on PyPI.
|
||||
|
||||
It is triggered on **tag push**, when the tag matches `v*`. It doesn't matter if you've prepped a release branch like `release/v3.5.0` or are releasing from `main` - it works the same.
|
||||
|
||||
> Because commits are reference-counted, it is safe to create a release branch, tag it, let the workflow run, then delete the branch. So long as the tag exists, that commit will exist.
|
||||
|
||||
### Triggering the Workflow
|
||||
|
||||
Run `make tag-release` to tag the current commit and kick off the workflow.
|
||||
|
||||
The release may also be dispatched [manually].
|
||||
|
||||
### Workflow Jobs and Process
|
||||
|
||||
The workflow consists of a number of concurrently-run jobs, and two final publish jobs.
|
||||
|
||||
The publish jobs require manual approval and are only run if the other jobs succeed.
|
||||
|
||||
#### `check-version` Job
|
||||
|
||||
This job checks that the git ref matches the app version. It matches the ref against the `__version__` variable in `invokeai/version/invokeai_version.py`.
|
||||
|
||||
When the workflow is triggered by tag push, the ref is the tag. If the workflow is run manually, the ref is the target selected from the **Use workflow from** dropdown.
|
||||
|
||||
This job uses [samuelcolvin/check-python-version].
|
||||
|
||||
> Any valid [version specifier] works, so long as the tag matches the version. The release workflow works exactly the same for `RC`, `post`, `dev`, etc.
|
||||
|
||||
#### Check and Test Jobs
|
||||
|
||||
- **`python-tests`**: runs `pytest` on matrix of platforms
|
||||
- **`python-checks`**: runs `ruff` (format and lint)
|
||||
- **`frontend-tests`**: runs `vitest`
|
||||
- **`frontend-checks`**: runs `prettier` (format), `eslint` (lint), `dpdm` (circular refs), `tsc` (static type check) and `knip` (unused imports)
|
||||
|
||||
> **TODO** We should add `mypy` or `pyright` to the **`check-python`** job.
|
||||
|
||||
> **TODO** We should add an end-to-end test job that generates an image.
|
||||
|
||||
#### `build-installer` Job
|
||||
|
||||
This sets up both python and frontend dependencies and builds the python package. Internally, this runs `installer/create_installer.sh` and uploads two artifacts:
|
||||
|
||||
- **`dist`**: the python distribution, to be published on PyPI
|
||||
- **`InvokeAI-installer-${VERSION}.zip`**: the installer to be included in the GitHub release
|
||||
|
||||
#### Sanity Check & Smoke Test
|
||||
|
||||
At this point, the release workflow pauses as the remaining publish jobs require approval.
|
||||
|
||||
A maintainer should go to the **Summary** tab of the workflow, download the installer and test it. Ensure the app loads and generates.
|
||||
|
||||
> The same wheel file is bundled in the installer and in the `dist` artifact, which is uploaded to PyPI. You should end up with the exactly the same installation of the `invokeai` package from any of these methods.
|
||||
|
||||
#### PyPI Publish Jobs
|
||||
|
||||
The publish jobs will run if any of the previous jobs fail.
|
||||
|
||||
They use [GitHub environments], which are configured as [trusted publishers] on PyPI.
|
||||
|
||||
Both jobs require a maintainer to approve them from the workflow's **Summary** tab.
|
||||
|
||||
- Click the **Review deployments** button
|
||||
- Select the environment (either `testpypi` or `pypi`)
|
||||
- Click **Approve and deploy**
|
||||
|
||||
> **If the version already exists on PyPI, the publish jobs will fail.** PyPI only allows a given version to be published once - you cannot change it. If version published on PyPI has a problem, you'll need to "fail forward" by bumping the app version and publishing a followup release.
|
||||
|
||||
#### `publish-testpypi` Job
|
||||
|
||||
Publishes the distribution on the [Test PyPI] index, using the `testpypi` GitHub environment.
|
||||
|
||||
This job is not required for the production PyPI publish, but included just in case you want to test the PyPI release.
|
||||
|
||||
If approved and successful, you could try out the test release like this:
|
||||
|
||||
```sh
|
||||
# Create a new virtual environment
|
||||
python -m venv ~/.test-invokeai-dist --prompt test-invokeai-dist
|
||||
# Install the distribution from Test PyPI
|
||||
pip install --index-url https://test.pypi.org/simple/ invokeai
|
||||
# Run and test the app
|
||||
invokeai-web
|
||||
# Cleanup
|
||||
deactivate
|
||||
rm -rf ~/.test-invokeai-dist
|
||||
```
|
||||
|
||||
#### `publish-pypi` Job
|
||||
|
||||
Publishes the distribution on the production PyPI index, using the `pypi` GitHub environment.
|
||||
|
||||
## Publish the GitHub Release with installer
|
||||
|
||||
Once the release is published to PyPI, it's time to publish the GitHub release.
|
||||
|
||||
1. [Draft a new release] on GitHub, choosing the tag that triggered the release.
|
||||
2. Write the release notes, describing important changes. The **Generate release notes** button automatically inserts the changelog and new contributors, and you can copy/paste the intro from previous releases.
|
||||
3. Upload the zip file created in **`build`** job into the Assets section of the release notes. You can also upload the zip into the body of the release notes, since it can be hard for users to find the Assets section.
|
||||
4. Check the **Set as a pre-release** and **Create a discussion for this release** checkboxes at the bottom of the release page.
|
||||
5. Publish the pre-release.
|
||||
6. Announce the pre-release in Discord.
|
||||
|
||||
> **TODO** Workflows can create a GitHub release from a template and upload release assets. One popular action to handle this is [ncipollo/release-action]. A future enhancement to the release process could set this up.
|
||||
|
||||
## Manual Build
|
||||
|
||||
The `build installer` workflow can be dispatched manually. This is useful to test the installer for a given branch or tag.
|
||||
|
||||
No checks are run, it just builds.
|
||||
|
||||
## Manual Release
|
||||
|
||||
The `release` workflow can be dispatched manually. You must dispatch the workflow from the right tag, else it will fail the version check.
|
||||
|
||||
This functionality is available as a fallback in case something goes wonky. Typically, releases should be triggered via tag push as described above.
|
||||
|
||||
[InvokeAI Releases Page]: https://github.com/invoke-ai/InvokeAI/releases
|
||||
[PyPI]: https://pypi.org/
|
||||
[Draft a new release]: https://github.com/invoke-ai/InvokeAI/releases/new
|
||||
[Test PyPI]: https://test.pypi.org/
|
||||
[version specifier]: https://packaging.python.org/en/latest/specifications/version-specifiers/
|
||||
[ncipollo/release-action]: https://github.com/ncipollo/release-action
|
||||
[GitHub environments]: https://docs.github.com/en/actions/deployment/targeting-different-environments/using-environments-for-deployment
|
||||
[trusted publishers]: https://docs.pypi.org/trusted-publishers/
|
||||
[samuelcolvin/check-python-version]: https://github.com/samuelcolvin/check-python-version
|
||||
[manually]: #manual-release
|
@ -16,6 +16,11 @@ model. These are the:
|
||||
information. It is also responsible for managing the InvokeAI
|
||||
`models` directory and its contents.
|
||||
|
||||
* _ModelMetadataStore_ and _ModelMetaDataFetch_ Backend modules that
|
||||
are able to retrieve metadata from online model repositories,
|
||||
transform them into Pydantic models, and cache them to the InvokeAI
|
||||
SQL database.
|
||||
|
||||
* _DownloadQueueServiceBase_
|
||||
A multithreaded downloader responsible
|
||||
for downloading models from a remote source to disk. The download
|
||||
@ -27,6 +32,7 @@ model. These are the:
|
||||
Responsible for loading a model from disk
|
||||
into RAM and VRAM and getting it ready for inference.
|
||||
|
||||
|
||||
## Location of the Code
|
||||
|
||||
The four main services can be found in
|
||||
@ -57,21 +63,23 @@ provides the following fields:
|
||||
|----------------|-----------------|------------------|
|
||||
| `key` | str | Unique identifier for the model |
|
||||
| `name` | str | Name of the model (not unique) |
|
||||
| `model_type` | ModelType | The type of the model |
|
||||
| `model_format` | ModelFormat | The format of the model (e.g. "diffusers"); also used as a Union discriminator |
|
||||
| `base_model` | BaseModelType | The base model that the model is compatible with |
|
||||
| `model_type` | ModelType | The type of the model |
|
||||
| `model_format` | ModelFormat | The format of the model (e.g. "diffusers"); also used as a Union discriminator |
|
||||
| `base_model` | BaseModelType | The base model that the model is compatible with |
|
||||
| `path` | str | Location of model on disk |
|
||||
| `hash` | str | Hash of the model |
|
||||
| `original_hash` | str | Hash of the model when it was first installed |
|
||||
| `current_hash` | str | Most recent hash of the model's contents |
|
||||
| `description` | str | Human-readable description of the model (optional) |
|
||||
| `source` | str | Model's source URL or repo id (optional) |
|
||||
|
||||
The `key` is a unique 32-character random ID which was generated at
|
||||
install time. The `hash` field stores a hash of the model's
|
||||
install time. The `original_hash` field stores a hash of the model's
|
||||
contents at install time obtained by sampling several parts of the
|
||||
model's files using the `imohash` library. Over the course of the
|
||||
model's lifetime it may be transformed in various ways, such as
|
||||
changing its precision or converting it from a .safetensors to a
|
||||
diffusers model.
|
||||
diffusers model. When this happens, `original_hash` is unchanged, but
|
||||
`current_hash` is updated to indicate the current contents.
|
||||
|
||||
`ModelType`, `ModelFormat` and `BaseModelType` are string enums that
|
||||
are defined in `invokeai.backend.model_manager.config`. They are also
|
||||
@ -86,6 +94,7 @@ The `path` field can be absolute or relative. If relative, it is taken
|
||||
to be relative to the `models_dir` setting in the user's
|
||||
`invokeai.yaml` file.
|
||||
|
||||
|
||||
### CheckpointConfig
|
||||
|
||||
This adds support for checkpoint configurations, and adds the
|
||||
@ -165,7 +174,7 @@ store = context.services.model_manager.store
|
||||
or from elsewhere in the code by accessing
|
||||
`ApiDependencies.invoker.services.model_manager.store`.
|
||||
|
||||
### Creating a `ModelRecordService`
|
||||
### Creating a `ModelRecordService`
|
||||
|
||||
To create a new `ModelRecordService` database or open an existing one,
|
||||
you can directly create either a `ModelRecordServiceSQL` or a
|
||||
@ -208,27 +217,27 @@ for use in the InvokeAI web server. Its signature is:
|
||||
```
|
||||
def open(
|
||||
cls,
|
||||
config: InvokeAIAppConfig,
|
||||
conn: Optional[sqlite3.Connection] = None,
|
||||
lock: Optional[threading.Lock] = None
|
||||
config: InvokeAIAppConfig,
|
||||
conn: Optional[sqlite3.Connection] = None,
|
||||
lock: Optional[threading.Lock] = None
|
||||
) -> Union[ModelRecordServiceSQL, ModelRecordServiceFile]:
|
||||
```
|
||||
|
||||
The way it works is as follows:
|
||||
|
||||
1. Retrieve the value of the `model_config_db` option from the user's
|
||||
`invokeai.yaml` config file.
|
||||
`invokeai.yaml` config file.
|
||||
2. If `model_config_db` is `auto` (the default), then:
|
||||
* Use the values of `conn` and `lock` to return a `ModelRecordServiceSQL` object
|
||||
opened on the passed connection and lock.
|
||||
* Open up a new connection to `databases/invokeai.db` if `conn`
|
||||
- Use the values of `conn` and `lock` to return a `ModelRecordServiceSQL` object
|
||||
opened on the passed connection and lock.
|
||||
- Open up a new connection to `databases/invokeai.db` if `conn`
|
||||
and/or `lock` are missing (see note below).
|
||||
3. If `model_config_db` is a Path, then use `from_db_file`
|
||||
to return the appropriate type of ModelRecordService.
|
||||
4. If `model_config_db` is None, then retrieve the legacy
|
||||
`conf_path` option from `invokeai.yaml` and use the Path
|
||||
indicated there. This will default to `configs/models.yaml`.
|
||||
|
||||
|
||||
So a typical startup pattern would be:
|
||||
|
||||
```
|
||||
@ -246,7 +255,7 @@ store = ModelRecordServiceBase.open(config, db_conn, lock)
|
||||
|
||||
Configurations can be retrieved in several ways.
|
||||
|
||||
#### get_model(key) -> AnyModelConfig
|
||||
#### get_model(key) -> AnyModelConfig:
|
||||
|
||||
The basic functionality is to call the record store object's
|
||||
`get_model()` method with the desired model's unique key. It returns
|
||||
@ -263,28 +272,28 @@ print(model_conf.path)
|
||||
If the key is unrecognized, this call raises an
|
||||
`UnknownModelException`.
|
||||
|
||||
#### exists(key) -> AnyModelConfig
|
||||
#### exists(key) -> AnyModelConfig:
|
||||
|
||||
Returns True if a model with the given key exists in the databsae.
|
||||
|
||||
#### search_by_path(path) -> AnyModelConfig
|
||||
#### search_by_path(path) -> AnyModelConfig:
|
||||
|
||||
Returns the configuration of the model whose path is `path`. The path
|
||||
is matched using a simple string comparison and won't correctly match
|
||||
models referred to by different paths (e.g. using symbolic links).
|
||||
|
||||
#### search_by_name(name, base, type) -> List[AnyModelConfig]
|
||||
#### search_by_name(name, base, type) -> List[AnyModelConfig]:
|
||||
|
||||
This method searches for models that match some combination of `name`,
|
||||
`BaseType` and `ModelType`. Calling without any arguments will return
|
||||
all the models in the database.
|
||||
|
||||
#### all_models() -> List[AnyModelConfig]
|
||||
#### all_models() -> List[AnyModelConfig]:
|
||||
|
||||
Return all the model configs in the database. Exactly equivalent to
|
||||
calling `search_by_name()` with no arguments.
|
||||
|
||||
#### search_by_tag(tags) -> List[AnyModelConfig]
|
||||
#### search_by_tag(tags) -> List[AnyModelConfig]:
|
||||
|
||||
`tags` is a list of strings. This method returns a list of model
|
||||
configs that contain all of the given tags. Examples:
|
||||
@ -303,11 +312,11 @@ commercializable_models = [x for x in store.all_models() \
|
||||
if x.license.contains('allowCommercialUse=Sell')]
|
||||
```
|
||||
|
||||
#### version() -> str
|
||||
#### version() -> str:
|
||||
|
||||
Returns the version of the database, currently at `3.2`
|
||||
|
||||
#### model_info_by_name(name, base_model, model_type) -> ModelConfigBase
|
||||
#### model_info_by_name(name, base_model, model_type) -> ModelConfigBase:
|
||||
|
||||
This method exists to ease the transition from the previous version of
|
||||
the model manager, in which `get_model()` took the three arguments
|
||||
@ -328,7 +337,7 @@ model and pass its key to `get_model()`.
|
||||
Several methods allow you to create and update stored model config
|
||||
records.
|
||||
|
||||
#### add_model(key, config) -> AnyModelConfig
|
||||
#### add_model(key, config) -> AnyModelConfig:
|
||||
|
||||
Given a key and a configuration, this will add the model's
|
||||
configuration record to the database. `config` can either be a subclass of
|
||||
@ -343,7 +352,7 @@ model with the same key is already in the database, or an
|
||||
`InvalidModelConfigException` if a dict was passed and Pydantic
|
||||
experienced a parse or validation error.
|
||||
|
||||
### update_model(key, config) -> AnyModelConfig
|
||||
### update_model(key, config) -> AnyModelConfig:
|
||||
|
||||
Given a key and a configuration, this will update the model
|
||||
configuration record in the database. `config` can be either a
|
||||
@ -361,30 +370,33 @@ The `ModelInstallService` class implements the
|
||||
shop for all your model install needs. It provides the following
|
||||
functionality:
|
||||
|
||||
* Registering a model config record for a model already located on the
|
||||
- Registering a model config record for a model already located on the
|
||||
local filesystem, without moving it or changing its path.
|
||||
|
||||
* Installing a model alreadiy located on the local filesystem, by
|
||||
- Installing a model alreadiy located on the local filesystem, by
|
||||
moving it into the InvokeAI root directory under the
|
||||
`models` folder (or wherever config parameter `models_dir`
|
||||
specifies).
|
||||
|
||||
* Probing of models to determine their type, base type and other key
|
||||
|
||||
- Probing of models to determine their type, base type and other key
|
||||
information.
|
||||
|
||||
* Interface with the InvokeAI event bus to provide status updates on
|
||||
- Interface with the InvokeAI event bus to provide status updates on
|
||||
the download, installation and registration process.
|
||||
|
||||
* Downloading a model from an arbitrary URL and installing it in
|
||||
- Downloading a model from an arbitrary URL and installing it in
|
||||
`models_dir`.
|
||||
|
||||
* Special handling for HuggingFace repo_ids to recursively download
|
||||
|
||||
- Special handling for Civitai model URLs which allow the user to
|
||||
paste in a model page's URL or download link
|
||||
|
||||
- Special handling for HuggingFace repo_ids to recursively download
|
||||
the contents of the repository, paying attention to alternative
|
||||
variants such as fp16.
|
||||
|
||||
* Saving tags and other metadata about the model into the invokeai database
|
||||
- Saving tags and other metadata about the model into the invokeai database
|
||||
when fetching from a repo that provides that type of information,
|
||||
(currently only HuggingFace).
|
||||
(currently only Civitai and HuggingFace).
|
||||
|
||||
### Initializing the installer
|
||||
|
||||
@ -415,8 +427,8 @@ queue.start()
|
||||
|
||||
installer = ModelInstallService(app_config=config,
|
||||
record_store=record_store,
|
||||
download_queue=queue
|
||||
)
|
||||
download_queue=queue
|
||||
)
|
||||
installer.start()
|
||||
```
|
||||
|
||||
@ -428,8 +440,10 @@ required parameters:
|
||||
| `app_config` | InvokeAIAppConfig | InvokeAI app configuration object |
|
||||
| `record_store` | ModelRecordServiceBase | Config record storage database |
|
||||
| `download_queue` | DownloadQueueServiceBase | Download queue object |
|
||||
| `metadata_store` | Optional[ModelMetadataStore] | Metadata storage object |
|
||||
|`session` | Optional[requests.Session] | Swap in a different Session object (usually for debugging) |
|
||||
|
||||
|
||||
Once initialized, the installer will provide the following methods:
|
||||
|
||||
#### install_job = installer.heuristic_import(source, [config], [access_token])
|
||||
@ -443,15 +457,15 @@ The `source` is a string that can be any of these forms
|
||||
1. A path on the local filesystem (`C:\\users\\fred\\model.safetensors`)
|
||||
2. A Url pointing to a single downloadable model file (`https://civitai.com/models/58390/detail-tweaker-lora-lora`)
|
||||
3. A HuggingFace repo_id with any of the following formats:
|
||||
* `model/name` -- entire model
|
||||
* `model/name:fp32` -- entire model, using the fp32 variant
|
||||
* `model/name:fp16:vae` -- vae submodel, using the fp16 variant
|
||||
* `model/name::vae` -- vae submodel, using default precision
|
||||
* `model/name:fp16:path/to/model.safetensors` -- an individual model file, fp16 variant
|
||||
* `model/name::path/to/model.safetensors` -- an individual model file, default variant
|
||||
- `model/name` -- entire model
|
||||
- `model/name:fp32` -- entire model, using the fp32 variant
|
||||
- `model/name:fp16:vae` -- vae submodel, using the fp16 variant
|
||||
- `model/name::vae` -- vae submodel, using default precision
|
||||
- `model/name:fp16:path/to/model.safetensors` -- an individual model file, fp16 variant
|
||||
- `model/name::path/to/model.safetensors` -- an individual model file, default variant
|
||||
|
||||
Note that by specifying a relative path to the top of the HuggingFace
|
||||
repo, you can download and install arbitrary models files.
|
||||
repo, you can download and install arbitrary models files.
|
||||
|
||||
The variant, if not provided, will be automatically filled in with
|
||||
`fp32` if the user has requested full precision, and `fp16`
|
||||
@ -477,9 +491,9 @@ following illustrates basic usage:
|
||||
|
||||
```
|
||||
from invokeai.app.services.model_install import (
|
||||
LocalModelSource,
|
||||
HFModelSource,
|
||||
URLModelSource,
|
||||
LocalModelSource,
|
||||
HFModelSource,
|
||||
URLModelSource,
|
||||
)
|
||||
|
||||
source1 = LocalModelSource(path='/opt/models/sushi.safetensors') # a local safetensors file
|
||||
@ -499,13 +513,13 @@ for source in [source1, source2, source3, source4, source5, source6, source7]:
|
||||
source2job = installer.wait_for_installs(timeout=120)
|
||||
for source in sources:
|
||||
job = source2job[source]
|
||||
if job.complete:
|
||||
model_config = job.config_out
|
||||
model_key = model_config.key
|
||||
print(f"{source} installed as {model_key}")
|
||||
elif job.errored:
|
||||
print(f"{source}: {job.error_type}.\nStack trace:\n{job.error}")
|
||||
|
||||
if job.complete:
|
||||
model_config = job.config_out
|
||||
model_key = model_config.key
|
||||
print(f"{source} installed as {model_key}")
|
||||
elif job.errored:
|
||||
print(f"{source}: {job.error_type}.\nStack trace:\n{job.error}")
|
||||
|
||||
```
|
||||
|
||||
As shown here, the `import_model()` method accepts a variety of
|
||||
@ -514,7 +528,7 @@ HuggingFace repo_ids with and without a subfolder designation,
|
||||
Civitai model URLs and arbitrary URLs that point to checkpoint files
|
||||
(but not to folders).
|
||||
|
||||
Each call to `import_model()` return a `ModelInstallJob` job,
|
||||
Each call to `import_model()` return a `ModelInstallJob` job,
|
||||
an object which tracks the progress of the install.
|
||||
|
||||
If a remote model is requested, the model's files are downloaded in
|
||||
@ -541,7 +555,7 @@ The full list of arguments to `import_model()` is as follows:
|
||||
| `config` | Dict[str, Any] | None | Override all or a portion of model's probed attributes |
|
||||
|
||||
The next few sections describe the various types of ModelSource that
|
||||
can be passed to `import_model()`.
|
||||
can be passed to `import_model()`.
|
||||
|
||||
`config` can be used to override all or a portion of the configuration
|
||||
attributes returned by the model prober. See the section below for
|
||||
@ -552,6 +566,7 @@ details.
|
||||
This is used for a model that is located on a locally-accessible Posix
|
||||
filesystem, such as a local disk or networked fileshare.
|
||||
|
||||
|
||||
| **Argument** | **Type** | **Default** | **Description** |
|
||||
|------------------|------------------------------|-------------|-------------------------------------------|
|
||||
| `path` | str | Path | None | Path to the model file or directory |
|
||||
@ -571,7 +586,33 @@ The `AnyHttpUrl` class can be imported from `pydantic.networks`.
|
||||
|
||||
Ordinarily, no metadata is retrieved from these sources. However,
|
||||
there is special-case code in the installer that looks for HuggingFace
|
||||
and fetches the corresponding model metadata from the corresponding repo.
|
||||
and Civitai URLs and fetches the corresponding model metadata from
|
||||
the corresponding repo.
|
||||
|
||||
#### CivitaiModelSource
|
||||
|
||||
This is used for a model that is hosted by the Civitai web site.
|
||||
|
||||
| **Argument** | **Type** | **Default** | **Description** |
|
||||
|------------------|------------------------------|-------------|-------------------------------------------|
|
||||
| `version_id` | int | None | The ID of the particular version of the desired model. |
|
||||
| `access_token` | str | None | An access token needed to gain access to a subscriber's-only model. |
|
||||
|
||||
Civitai has two model IDs, both of which are integers. The `model_id`
|
||||
corresponds to a collection of model versions that may different in
|
||||
arbitrary ways, such as derivation from different checkpoint training
|
||||
steps, SFW vs NSFW generation, pruned vs non-pruned, etc. The
|
||||
`version_id` points to a specific version. Please use the latter.
|
||||
|
||||
Some Civitai models require an access token to download. These can be
|
||||
generated from the Civitai profile page of a logged-in
|
||||
account. Somewhat annoyingly, if you fail to provide the access token
|
||||
when downloading a model that needs it, Civitai generates a redirect
|
||||
to a login page rather than a 403 Forbidden error. The installer
|
||||
attempts to catch this event and issue an informative error
|
||||
message. Otherwise you will get an "unrecognized model suffix" error
|
||||
when the model prober tries to identify the type of the HTML login
|
||||
page.
|
||||
|
||||
#### HFModelSource
|
||||
|
||||
@ -584,6 +625,7 @@ HuggingFace has the most complicated `ModelSource` structure:
|
||||
| `subfolder` | Path | None | Look for the model in a subfolder of the repo. |
|
||||
| `access_token` | str | None | An access token needed to gain access to a subscriber's-only model. |
|
||||
|
||||
|
||||
The `repo_id` is the repository ID, such as `stabilityai/sdxl-turbo`.
|
||||
|
||||
The `variant` is one of the various diffusers formats that HuggingFace
|
||||
@ -619,6 +661,7 @@ in. To download these files, you must provide an
|
||||
`HfFolder.get_token()` will be called to fill it in with the cached
|
||||
one.
|
||||
|
||||
|
||||
#### Monitoring the install job process
|
||||
|
||||
When you create an install job with `import_model()`, it launches the
|
||||
@ -632,13 +675,14 @@ The `ModelInstallJob` class has the following structure:
|
||||
| `id` | `int` | Integer ID for this job |
|
||||
| `status` | `InstallStatus` | An enum of [`waiting`, `downloading`, `running`, `completed`, `error` and `cancelled`]|
|
||||
| `config_in` | `dict` | Overriding configuration values provided by the caller |
|
||||
| `config_out` | `AnyModelConfig`| After successful completion, contains the configuration record written to the database |
|
||||
| `inplace` | `boolean` | True if the caller asked to install the model in place using its local path |
|
||||
| `source` | `ModelSource` | The local path, remote URL or repo_id of the model to be installed |
|
||||
| `config_out` | `AnyModelConfig`| After successful completion, contains the configuration record written to the database |
|
||||
| `inplace` | `boolean` | True if the caller asked to install the model in place using its local path |
|
||||
| `source` | `ModelSource` | The local path, remote URL or repo_id of the model to be installed |
|
||||
| `local_path` | `Path` | If a remote model, holds the path of the model after it is downloaded; if a local model, same as `source` |
|
||||
| `error_type` | `str` | Name of the exception that led to an error status |
|
||||
| `error` | `str` | Traceback of the error |
|
||||
|
||||
|
||||
If the `event_bus` argument was provided, events will also be
|
||||
broadcast to the InvokeAI event bus. The events will appear on the bus
|
||||
as an event of type `EventServiceBase.model_event`, a timestamp and
|
||||
@ -658,13 +702,14 @@ following keys:
|
||||
| `total_bytes` | int | Total size of all the files that make up the model |
|
||||
| `parts` | List[Dict]| Information on the progress of the individual files that make up the model |
|
||||
|
||||
|
||||
The parts is a list of dictionaries that give information on each of
|
||||
the components pieces of the download. The dictionary's keys are
|
||||
`source`, `local_path`, `bytes` and `total_bytes`, and correspond to
|
||||
the like-named keys in the main event.
|
||||
|
||||
Note that downloading events will not be issued for local models, and
|
||||
that downloading events occur _before_ the running event.
|
||||
that downloading events occur *before* the running event.
|
||||
|
||||
##### `model_install_running`
|
||||
|
||||
@ -707,13 +752,14 @@ properties: `waiting`, `downloading`, `running`, `complete`, `errored`
|
||||
and `cancelled`, as well as `in_terminal_state`. The last will return
|
||||
True if the job is in the complete, errored or cancelled states.
|
||||
|
||||
|
||||
#### Model configuration and probing
|
||||
|
||||
The install service uses the `invokeai.backend.model_manager.probe`
|
||||
module during import to determine the model's type, base type, and
|
||||
other configuration parameters. Among other things, it assigns a
|
||||
default name and description for the model based on probed
|
||||
fields.
|
||||
fields.
|
||||
|
||||
When downloading remote models is implemented, additional
|
||||
configuration information, such as list of trigger terms, will be
|
||||
@ -728,11 +774,11 @@ attributes. Here is an example of setting the
|
||||
```
|
||||
install_job = installer.import_model(
|
||||
source=HFModelSource(repo_id='stabilityai/stable-diffusion-2-1',variant='fp32'),
|
||||
config=dict(
|
||||
prediction_type=SchedulerPredictionType('v_prediction')
|
||||
name='stable diffusion 2 base model',
|
||||
)
|
||||
)
|
||||
config=dict(
|
||||
prediction_type=SchedulerPredictionType('v_prediction')
|
||||
name='stable diffusion 2 base model',
|
||||
)
|
||||
)
|
||||
```
|
||||
|
||||
### Other installer methods
|
||||
@ -816,6 +862,7 @@ This method is similar to `unregister()`, but also unconditionally
|
||||
deletes the corresponding model weights file(s), regardless of whether
|
||||
they are inside or outside the InvokeAI models hierarchy.
|
||||
|
||||
|
||||
#### path = installer.download_and_cache(remote_source, [access_token], [timeout])
|
||||
|
||||
This utility routine will download the model file located at source,
|
||||
@ -906,7 +953,7 @@ following fields:
|
||||
|
||||
When you create a job, you can assign it a `priority`. If multiple
|
||||
jobs are queued, the job with the lowest priority runs first. (Don't
|
||||
blame me! The Unix developers came up with this convention.)
|
||||
blame me! The Unix developers came up with this convention.)
|
||||
|
||||
Every job has a `source` and a `destination`. `source` is a string in
|
||||
the base class, but subclassses redefine it more specifically.
|
||||
@ -927,7 +974,7 @@ is in its lifecycle. Values are defined in the string enum
|
||||
`DownloadJobStatus`, a symbol available from
|
||||
`invokeai.app.services.download_manager`. Possible values are:
|
||||
|
||||
| **Value** | **String Value** | **Description** |
|
||||
| **Value** | **String Value** | ** Description ** |
|
||||
|--------------|---------------------|-------------------|
|
||||
| `IDLE` | idle | Job created, but not submitted to the queue |
|
||||
| `ENQUEUED` | enqueued | Job is patiently waiting on the queue |
|
||||
@ -944,7 +991,7 @@ debugging and performance testing.
|
||||
|
||||
In case of an error, the Exception that caused the error will be
|
||||
placed in the `error` field, and the job's status will be set to
|
||||
`DownloadJobStatus.ERROR`.
|
||||
`DownloadJobStatus.ERROR`.
|
||||
|
||||
After an error occurs, any partially downloaded files will be deleted
|
||||
from disk, unless `preserve_partial_downloads` was set to True at job
|
||||
@ -993,11 +1040,11 @@ While a job is being downloaded, the queue will emit events at
|
||||
periodic intervals. A typical series of events during a successful
|
||||
download session will look like this:
|
||||
|
||||
* enqueued
|
||||
* running
|
||||
* running
|
||||
* running
|
||||
* completed
|
||||
- enqueued
|
||||
- running
|
||||
- running
|
||||
- running
|
||||
- completed
|
||||
|
||||
There will be a single enqueued event, followed by one or more running
|
||||
events, and finally one `completed`, `error` or `cancelled`
|
||||
@ -1006,12 +1053,12 @@ events.
|
||||
It is possible for a caller to pause download temporarily, in which
|
||||
case the events may look something like this:
|
||||
|
||||
* enqueued
|
||||
* running
|
||||
* running
|
||||
* paused
|
||||
* running
|
||||
* completed
|
||||
- enqueued
|
||||
- running
|
||||
- running
|
||||
- paused
|
||||
- running
|
||||
- completed
|
||||
|
||||
The download queue logs when downloads start and end (unless `quiet`
|
||||
is set to True at initialization time) but doesn't log any progress
|
||||
@ -1073,11 +1120,11 @@ A typical initialization sequence will look like:
|
||||
from invokeai.app.services.download_manager import DownloadQueueService
|
||||
|
||||
def log_download_event(job: DownloadJobBase):
|
||||
logger.info(f'job={job.id}: status={job.status}')
|
||||
logger.info(f'job={job.id}: status={job.status}')
|
||||
|
||||
queue = DownloadQueueService(
|
||||
event_handlers=[log_download_event]
|
||||
)
|
||||
event_handlers=[log_download_event]
|
||||
)
|
||||
```
|
||||
|
||||
Event handlers can be provided to the queue at initialization time as
|
||||
@ -1108,9 +1155,9 @@ To use the former method, follow this example:
|
||||
```
|
||||
job = DownloadJobRemoteSource(
|
||||
source='http://www.civitai.com/models/13456',
|
||||
destination='/tmp/models/',
|
||||
event_handlers=[my_handler1, my_handler2], # if desired
|
||||
)
|
||||
destination='/tmp/models/',
|
||||
event_handlers=[my_handler1, my_handler2], # if desired
|
||||
)
|
||||
queue.submit_download_job(job, start=True)
|
||||
```
|
||||
|
||||
@ -1125,13 +1172,13 @@ To have the queue create the job for you, follow this example instead:
|
||||
```
|
||||
job = queue.create_download_job(
|
||||
source='http://www.civitai.com/models/13456',
|
||||
destdir='/tmp/models/',
|
||||
filename='my_model.safetensors',
|
||||
event_handlers=[my_handler1, my_handler2], # if desired
|
||||
start=True,
|
||||
)
|
||||
destdir='/tmp/models/',
|
||||
filename='my_model.safetensors',
|
||||
event_handlers=[my_handler1, my_handler2], # if desired
|
||||
start=True,
|
||||
)
|
||||
```
|
||||
|
||||
|
||||
The `filename` argument forces the downloader to use the specified
|
||||
name for the file rather than the name provided by the remote source,
|
||||
and is equivalent to manually specifying a destination of
|
||||
@ -1140,6 +1187,7 @@ and is equivalent to manually specifying a destination of
|
||||
Here is the full list of arguments that can be provided to
|
||||
`create_download_job()`:
|
||||
|
||||
|
||||
| **Argument** | **Type** | **Default** | **Description** |
|
||||
|------------------|------------------------------|-------------|-------------------------------------------|
|
||||
| `source` | Union[str, Path, AnyHttpUrl] | | Download remote or local source |
|
||||
@ -1152,7 +1200,7 @@ Here is the full list of arguments that can be provided to
|
||||
|
||||
Internally, `create_download_job()` has a little bit of internal logic
|
||||
that looks at the type of the source and selects the right subclass of
|
||||
`DownloadJobBase` to create and enqueue.
|
||||
`DownloadJobBase` to create and enqueue.
|
||||
|
||||
**TODO**: move this logic into its own method for overriding in
|
||||
subclasses.
|
||||
@ -1218,30 +1266,51 @@ queue and have not yet reached a terminal state.
|
||||
|
||||
The modules found under `invokeai.backend.model_manager.metadata`
|
||||
provide a straightforward API for fetching model metadatda from online
|
||||
repositories. Currently only HuggingFace is supported. However, the
|
||||
modules are easily extended for additional repos, provided that they
|
||||
have defined APIs for metadata access.
|
||||
repositories. Currently two repositories are supported: HuggingFace
|
||||
and Civitai. However, the modules are easily extended for additional
|
||||
repos, provided that they have defined APIs for metadata access.
|
||||
|
||||
Metadata comprises any descriptive information that is not essential
|
||||
for getting the model to run. For example "author" is metadata, while
|
||||
"type", "base" and "format" are not. The latter fields are part of the
|
||||
model's config, as defined in `invokeai.backend.model_manager.config`.
|
||||
|
||||
### Example Usage
|
||||
### Example Usage:
|
||||
|
||||
```
|
||||
from invokeai.backend.model_manager.metadata import (
|
||||
AnyModelRepoMetadata,
|
||||
CivitaiMetadataFetch,
|
||||
CivitaiMetadata
|
||||
ModelMetadataStore,
|
||||
)
|
||||
# to access the initialized sql database
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
|
||||
hf = HuggingFaceMetadataFetch()
|
||||
civitai = CivitaiMetadataFetch()
|
||||
|
||||
# fetch the metadata
|
||||
model_metadata = hf.from_id("<repo_id>")
|
||||
model_metadata = civitai.from_url("https://civitai.com/models/215796")
|
||||
|
||||
assert isinstance(model_metadata, HuggingFaceMetadata)
|
||||
# get some common metadata fields
|
||||
author = model_metadata.author
|
||||
tags = model_metadata.tags
|
||||
|
||||
# get some Civitai-specific fields
|
||||
assert isinstance(model_metadata, CivitaiMetadata)
|
||||
|
||||
trained_words = model_metadata.trained_words
|
||||
base_model = model_metadata.base_model_trained_on
|
||||
thumbnail = model_metadata.thumbnail_url
|
||||
|
||||
# cache the metadata to the database using the key corresponding to
|
||||
# an existing model config record in the `model_config` table
|
||||
sql_cache = ModelMetadataStore(ApiDependencies.invoker.services.db)
|
||||
sql_cache.add_metadata('fb237ace520b6716adc98bcb16e8462c', model_metadata)
|
||||
|
||||
# now we can search the database by tag, author or model name
|
||||
# matches will contain a list of model keys that match the search
|
||||
matches = sql_cache.search_by_tag({"tool", "turbo"})
|
||||
```
|
||||
|
||||
### Structure of the Metadata objects
|
||||
@ -1259,6 +1328,7 @@ This is the common base class for metadata:
|
||||
| `author` | str | Model's author |
|
||||
| `tags` | Set[str] | Model tags |
|
||||
|
||||
|
||||
Note that the model config record also has a `name` field. It is
|
||||
intended that the config record version be locally customizable, while
|
||||
the metadata version is read-only. However, enforcing this is expected
|
||||
@ -1278,14 +1348,53 @@ This descends from `ModelMetadataBase` and adds the following fields:
|
||||
| `last_modified`| datetime | Date of last commit of this model to the repo |
|
||||
| `files` | List[Path] | List of the files in the model repo |
|
||||
|
||||
|
||||
#### `CivitaiMetadata`
|
||||
|
||||
This descends from `ModelMetadataBase` and adds the following fields:
|
||||
|
||||
| **Field Name** | **Type** | **Description** |
|
||||
|----------------|-----------------|------------------|
|
||||
| `type` | Literal["civitai"] | Used for the discriminated union of metadata classes|
|
||||
| `id` | int | Civitai model id |
|
||||
| `version_name` | str | Name of this version of the model (distinct from model name) |
|
||||
| `version_id` | int | Civitai model version id (distinct from model id) |
|
||||
| `created` | datetime | Date this version of the model was created |
|
||||
| `updated` | datetime | Date this version of the model was last updated |
|
||||
| `published` | datetime | Date this version of the model was published to Civitai |
|
||||
| `description` | str | Model description. Quite verbose and contains HTML tags |
|
||||
| `version_description` | str | Model version description, usually describes changes to the model |
|
||||
| `nsfw` | bool | Whether the model tends to generate NSFW content |
|
||||
| `restrictions` | LicenseRestrictions | An object that describes what is and isn't allowed with this model |
|
||||
| `trained_words`| Set[str] | Trigger words for this model, if any |
|
||||
| `download_url` | AnyHttpUrl | URL for downloading this version of the model |
|
||||
| `base_model_trained_on` | str | Name of the model that this version was trained on |
|
||||
| `thumbnail_url` | AnyHttpUrl | URL to access a representative thumbnail image of the model's output |
|
||||
| `weight_min` | int | For LoRA sliders, the minimum suggested weight to apply |
|
||||
| `weight_max` | int | For LoRA sliders, the maximum suggested weight to apply |
|
||||
|
||||
Note that `weight_min` and `weight_max` are not currently populated
|
||||
and take the default values of (-1.0, +2.0). The issue is that these
|
||||
values aren't part of the structured data but appear in the text
|
||||
description. Some regular expression or LLM coding may be able to
|
||||
extract these values.
|
||||
|
||||
Also be aware that `base_model_trained_on` is free text and doesn't
|
||||
correspond to our `ModelType` enum.
|
||||
|
||||
`CivitaiMetadata` also defines some convenience properties relating to
|
||||
licensing restrictions: `credit_required`, `allow_commercial_use`,
|
||||
`allow_derivatives` and `allow_different_license`.
|
||||
|
||||
#### `AnyModelRepoMetadata`
|
||||
|
||||
This is a discriminated Union of `HuggingFaceMetadata`.
|
||||
This is a discriminated Union of `CivitaiMetadata` and
|
||||
`HuggingFaceMetadata`.
|
||||
|
||||
### Fetching Metadata from Online Repos
|
||||
|
||||
The `HuggingFaceMetadataFetch` class will
|
||||
retrieve metadata from its corresponding repository and return
|
||||
The `HuggingFaceMetadataFetch` and `CivitaiMetadataFetch` classes will
|
||||
retrieve metadata from their corresponding repositories and return
|
||||
`AnyModelRepoMetadata` objects. Their base class
|
||||
`ModelMetadataFetchBase` is an abstract class that defines two
|
||||
methods: `from_url()` and `from_id()`. The former accepts the type of
|
||||
@ -1303,17 +1412,98 @@ provide a `requests.Session` argument. This allows you to customize
|
||||
the low-level HTTP fetch requests and is used, for instance, in the
|
||||
testing suite to avoid hitting the internet.
|
||||
|
||||
The HuggingFace fetcher subclass add additional repo-specific fetching methods:
|
||||
The HuggingFace and Civitai fetcher subclasses add additional
|
||||
repo-specific fetching methods:
|
||||
|
||||
|
||||
#### HuggingFaceMetadataFetch
|
||||
|
||||
This overrides its base class `from_json()` method to return a
|
||||
`HuggingFaceMetadata` object directly.
|
||||
|
||||
#### CivitaiMetadataFetch
|
||||
|
||||
This adds the following methods:
|
||||
|
||||
`from_civitai_modelid()` This takes the ID of a model, finds the
|
||||
default version of the model, and then retrieves the metadata for
|
||||
that version, returning a `CivitaiMetadata` object directly.
|
||||
|
||||
`from_civitai_versionid()` This takes the ID of a model version and
|
||||
retrieves its metadata. Functionally equivalent to `from_id()`, the
|
||||
only difference is that it returna a `CivitaiMetadata` object rather
|
||||
than an `AnyModelRepoMetadata`.
|
||||
|
||||
|
||||
### Metadata Storage
|
||||
|
||||
The `ModelConfigBase` stores this response in the `source_api_response` field
|
||||
as a JSON blob.
|
||||
The `ModelMetadataStore` provides a simple facility to store model
|
||||
metadata in the `invokeai.db` database. The data is stored as a JSON
|
||||
blob, with a few common fields (`name`, `author`, `tags`) broken out
|
||||
to be searchable.
|
||||
|
||||
When a metadata object is saved to the database, it is identified
|
||||
using the model key, _and this key must correspond to an existing
|
||||
model key in the model_config table_. There is a foreign key integrity
|
||||
constraint between the `model_config.id` field and the
|
||||
`model_metadata.id` field such that if you attempt to save metadata
|
||||
under an unknown key, the attempt will result in an
|
||||
`UnknownModelException`. Likewise, when a model is deleted from
|
||||
`model_config`, the deletion of the corresponding metadata record will
|
||||
be triggered.
|
||||
|
||||
Tags are stored in a normalized fashion in the tables `model_tags` and
|
||||
`tags`. Triggers keep the tag table in sync with the `model_metadata`
|
||||
table.
|
||||
|
||||
To create the storage object, initialize it with the InvokeAI
|
||||
`SqliteDatabase` object. This is often done this way:
|
||||
|
||||
```
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
metadata_store = ModelMetadataStore(ApiDependencies.invoker.services.db)
|
||||
```
|
||||
|
||||
You can then access the storage with the following methods:
|
||||
|
||||
#### `add_metadata(key, metadata)`
|
||||
|
||||
Add the metadata using a previously-defined model key.
|
||||
|
||||
There is currently no `delete_metadata()` method. The metadata will
|
||||
persist until the matching config is deleted from the `model_config`
|
||||
table.
|
||||
|
||||
#### `get_metadata(key) -> AnyModelRepoMetadata`
|
||||
|
||||
Retrieve the metadata corresponding to the model key.
|
||||
|
||||
#### `update_metadata(key, new_metadata)`
|
||||
|
||||
Update an existing metadata record with new metadata.
|
||||
|
||||
#### `search_by_tag(tags: Set[str]) -> Set[str]`
|
||||
|
||||
Given a set of tags, find models that are tagged with them. If
|
||||
multiple tags are provided then a matching model must be tagged with
|
||||
*all* the tags in the set. This method returns a set of model keys and
|
||||
is intended to be used in conjunction with the `ModelRecordService`:
|
||||
|
||||
```
|
||||
model_config_store = ApiDependencies.invoker.services.model_records
|
||||
matches = metadata_store.search_by_tag({'license:other'})
|
||||
models = [model_config_store.get(x) for x in matches]
|
||||
```
|
||||
|
||||
#### `search_by_name(name: str) -> Set[str]
|
||||
|
||||
Find all model metadata records that have the given name and return a
|
||||
set of keys to the corresponding model config objects.
|
||||
|
||||
#### `search_by_author(author: str) -> Set[str]
|
||||
|
||||
Find all model metadata records that have the given author and return
|
||||
a set of keys to the corresponding model config objects.
|
||||
|
||||
***
|
||||
|
||||
@ -1345,16 +1535,16 @@ from invokeai.app.services.model_load import ModelLoadService, ModelLoaderRegist
|
||||
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
ram_cache = ModelCache(
|
||||
max_cache_size=config.ram_cache_size, max_vram_cache_size=config.vram_cache_size, logger=logger
|
||||
max_cache_size=config.ram_cache_size, max_vram_cache_size=config.vram_cache_size, logger=logger
|
||||
)
|
||||
convert_cache = ModelConvertCache(
|
||||
cache_path=config.models_convert_cache_path, max_size=config.convert_cache_size
|
||||
cache_path=config.models_convert_cache_path, max_size=config.convert_cache_size
|
||||
)
|
||||
loader = ModelLoadService(
|
||||
app_config=config,
|
||||
ram_cache=ram_cache,
|
||||
convert_cache=convert_cache,
|
||||
registry=ModelLoaderRegistry
|
||||
app_config=config,
|
||||
ram_cache=ram_cache,
|
||||
convert_cache=convert_cache,
|
||||
registry=ModelLoaderRegistry
|
||||
)
|
||||
```
|
||||
|
||||
@ -1377,6 +1567,7 @@ The returned `LoadedModel` object contains a copy of the configuration
|
||||
record returned by the model record `get_model()` method, as well as
|
||||
the in-memory loaded model:
|
||||
|
||||
|
||||
| **Attribute Name** | **Type** | **Description** |
|
||||
|----------------|-----------------|------------------|
|
||||
| `config` | AnyModelConfig | A copy of the model's configuration record for retrieving base type, etc. |
|
||||
@ -1390,6 +1581,7 @@ return `AnyModel`, a Union `ModelMixin`, `torch.nn.Module`,
|
||||
models, `EmbeddingModelRaw` is used for LoRA and TextualInversion
|
||||
models. The others are obvious.
|
||||
|
||||
|
||||
`LoadedModel` acts as a context manager. The context loads the model
|
||||
into the execution device (e.g. VRAM on CUDA systems), locks the model
|
||||
in the execution device for the duration of the context, and returns
|
||||
@ -1398,14 +1590,14 @@ the model. Use it like this:
|
||||
```
|
||||
model_info = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
|
||||
with model_info as vae:
|
||||
image = vae.decode(latents)[0]
|
||||
image = vae.decode(latents)[0]
|
||||
```
|
||||
|
||||
`get_model_by_key()` may raise any of the following exceptions:
|
||||
|
||||
* `UnknownModelException` -- key not in database
|
||||
* `ModelNotFoundException` -- key in database but model not found at path
|
||||
* `NotImplementedException` -- the loader doesn't know how to load this type of model
|
||||
- `UnknownModelException` -- key not in database
|
||||
- `ModelNotFoundException` -- key in database but model not found at path
|
||||
- `NotImplementedException` -- the loader doesn't know how to load this type of model
|
||||
|
||||
### Emitting model loading events
|
||||
|
||||
@ -1417,15 +1609,15 @@ following payload:
|
||||
|
||||
```
|
||||
payload=dict(
|
||||
queue_id=queue_id,
|
||||
queue_item_id=queue_item_id,
|
||||
queue_batch_id=queue_batch_id,
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
model_key=model_key,
|
||||
submodel_type=submodel,
|
||||
hash=model_info.hash,
|
||||
location=str(model_info.location),
|
||||
precision=str(model_info.precision),
|
||||
queue_id=queue_id,
|
||||
queue_item_id=queue_item_id,
|
||||
queue_batch_id=queue_batch_id,
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
model_key=model_key,
|
||||
submodel_type=submodel,
|
||||
hash=model_info.hash,
|
||||
location=str(model_info.location),
|
||||
precision=str(model_info.precision),
|
||||
)
|
||||
```
|
||||
|
||||
@ -1532,7 +1724,6 @@ object, or in `context.services.model_manager` from within an
|
||||
invocation.
|
||||
|
||||
In the examples below, we have retrieved the manager using:
|
||||
|
||||
```
|
||||
mm = ApiDependencies.invoker.services.model_manager
|
||||
```
|
||||
|
@ -1,133 +0,0 @@
|
||||
# Invoke UI
|
||||
|
||||
Invoke's UI is made possible by many contributors and open-source libraries. Thank you!
|
||||
|
||||
## Dev environment
|
||||
|
||||
### Setup
|
||||
|
||||
1. Install [node] and [pnpm].
|
||||
1. Run `pnpm i` to install all packages.
|
||||
|
||||
#### Run in dev mode
|
||||
|
||||
1. From `invokeai/frontend/web/`, run `pnpm dev`.
|
||||
1. From repo root, run `python scripts/invokeai-web.py`.
|
||||
1. Point your browser to the dev server address, e.g. <http://localhost:5173/>
|
||||
|
||||
### Package scripts
|
||||
|
||||
- `dev`: run the frontend in dev mode, enabling hot reloading
|
||||
- `build`: run all checks (madge, eslint, prettier, tsc) and then build the frontend
|
||||
- `typegen`: generate types from the OpenAPI schema (see [Type generation])
|
||||
- `lint:dpdm`: check circular dependencies
|
||||
- `lint:eslint`: check code quality
|
||||
- `lint:prettier`: check code formatting
|
||||
- `lint:tsc`: check type issues
|
||||
- `lint:knip`: check for unused exports or objects (failures here are just suggestions, not hard fails)
|
||||
- `lint`: run all checks concurrently
|
||||
- `fix`: run `eslint` and `prettier`, fixing fixable issues
|
||||
|
||||
### Type generation
|
||||
|
||||
We use [openapi-typescript] to generate types from the app's OpenAPI schema.
|
||||
|
||||
The generated types are committed to the repo in [schema.ts].
|
||||
|
||||
```sh
|
||||
# from the repo root, start the server
|
||||
python scripts/invokeai-web.py
|
||||
# from invokeai/frontend/web/, run the script
|
||||
pnpm typegen
|
||||
```
|
||||
|
||||
### Localization
|
||||
|
||||
We use [i18next] for localization, but translation to languages other than English happens on our [Weblate] project.
|
||||
|
||||
Only the English source strings should be changed on this repo.
|
||||
|
||||
### VSCode
|
||||
|
||||
#### Example debugger config
|
||||
|
||||
```jsonc
|
||||
{
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"type": "chrome",
|
||||
"request": "launch",
|
||||
"name": "Invoke UI",
|
||||
"url": "http://localhost:5173",
|
||||
"webRoot": "${workspaceFolder}/invokeai/frontend/web"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
#### Remote dev
|
||||
|
||||
We've noticed an intermittent timeout issue with the VSCode remote dev port forwarding.
|
||||
|
||||
We suggest disabling the editor's port forwarding feature and doing it manually via SSH:
|
||||
|
||||
```sh
|
||||
ssh -L 9090:localhost:9090 -L 5173:localhost:5173 user@host
|
||||
```
|
||||
|
||||
## Contributing Guidelines
|
||||
|
||||
Thanks for your interest in contributing to the Invoke Web UI!
|
||||
|
||||
Please follow these guidelines when contributing.
|
||||
|
||||
### Check in before investing your time
|
||||
|
||||
Please check in before you invest your time on anything besides a trivial fix, in case it conflicts with ongoing work or isn't aligned with the vision for the app.
|
||||
|
||||
If a feature request or issue doesn't already exist for the thing you want to work on, please create one.
|
||||
|
||||
Ping `@psychedelicious` on [discord] in the `#frontend-dev` channel or in the feature request / issue you want to work on - we're happy to chat.
|
||||
|
||||
### Code conventions
|
||||
|
||||
- This is a fairly complex app with a deep component tree. Please use memoization (`useCallback`, `useMemo`, `memo`) with enthusiasm.
|
||||
- If you need to add some global, ephemeral state, please use [nanostores] if possible.
|
||||
- Be careful with your redux selectors. If they need to be parameterized, consider creating them inside a `useMemo`.
|
||||
- Feel free to use `lodash` (via `lodash-es`) to make the intent of your code clear.
|
||||
- Please add comments describing the "why", not the "how" (unless it is really arcane).
|
||||
|
||||
### Commit format
|
||||
|
||||
Please use the [conventional commits] spec for the web UI, with a scope of "ui":
|
||||
|
||||
- `chore(ui): bump deps`
|
||||
- `chore(ui): lint`
|
||||
- `feat(ui): add some cool new feature`
|
||||
- `fix(ui): fix some bug`
|
||||
|
||||
### Submitting a PR
|
||||
|
||||
- Ensure your branch is tidy. Use an interactive rebase to clean up the commit history and reword the commit messages if they are not descriptive.
|
||||
- Run `pnpm lint`. Some issues are auto-fixable with `pnpm fix`.
|
||||
- Fill out the PR form when creating the PR.
|
||||
- It doesn't need to be super detailed, but a screenshot or video is nice if you changed something visually.
|
||||
- If a section isn't relevant, delete it. There are no UI tests at this time.
|
||||
|
||||
## Other docs
|
||||
|
||||
- [Workflows - Design and Implementation]
|
||||
- [State Management]
|
||||
|
||||
[node]: https://nodejs.org/en/download/
|
||||
[pnpm]: https://github.com/pnpm/pnpm
|
||||
[discord]: https://discord.gg/ZmtBAhwWhy
|
||||
[i18next]: https://github.com/i18next/react-i18next
|
||||
[Weblate]: https://hosted.weblate.org/engage/invokeai/
|
||||
[openapi-typescript]: https://github.com/drwpow/openapi-typescript
|
||||
[Type generation]: #type-generation
|
||||
[schema.ts]: https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/frontend/web/src/services/api/schema.ts
|
||||
[conventional commits]: https://www.conventionalcommits.org/en/v1.0.0/
|
||||
[Workflows - Design and Implementation]: ./WORKFLOWS.md
|
||||
[State Management]: ./STATE_MGMT.md
|
@ -6,161 +6,259 @@ title: Configuration
|
||||
|
||||
## Intro
|
||||
|
||||
Runtime settings, including the location of files and
|
||||
directories, memory usage, and performance, are managed via the
|
||||
`invokeai.yaml` config file or environment variables. A subset
|
||||
of settings may be set via commandline arguments.
|
||||
InvokeAI has numerous runtime settings which can be used to adjust
|
||||
many aspects of its operations, including the location of files and
|
||||
directories, memory usage, and performance. These settings can be
|
||||
viewed and customized in several ways:
|
||||
|
||||
Settings sources are used in this order:
|
||||
1. By editing settings in the `invokeai.yaml` file.
|
||||
2. By setting environment variables.
|
||||
3. On the command-line, when InvokeAI is launched.
|
||||
|
||||
- CLI args
|
||||
- Environment variables
|
||||
- `invokeai.yaml` settings
|
||||
- Fallback: defaults
|
||||
In addition, the most commonly changed settings are accessible
|
||||
graphically via the `invokeai-configure` script.
|
||||
|
||||
### InvokeAI Root Directory
|
||||
### How the Configuration System Works
|
||||
|
||||
On startup, InvokeAI searches for its "root" directory. This is the directory
|
||||
that contains models, images, the database, and so on. It also contains
|
||||
a configuration file called `invokeai.yaml`.
|
||||
When InvokeAI is launched, the very first thing it needs to do is to
|
||||
find its "root" directory, which contains its configuration files,
|
||||
installed models, its database of images, and the folder(s) of
|
||||
generated images themselves. In this document, the root directory will
|
||||
be referred to as ROOT.
|
||||
|
||||
InvokeAI searches for the root directory in this order:
|
||||
#### Finding the Root Directory
|
||||
|
||||
1. The `--root <path>` CLI arg.
|
||||
2. The environment variable INVOKEAI_ROOT.
|
||||
3. The directory containing the currently active virtual environment.
|
||||
4. Fallback: a directory in the current user's home directory named `invokeai`.
|
||||
To find its root directory, InvokeAI uses the following recipe:
|
||||
|
||||
### InvokeAI Configuration File
|
||||
1. It first looks for the argument `--root <path>` on the command line
|
||||
it was launched from, and uses the indicated path if present.
|
||||
|
||||
Inside the root directory, we read settings from the `invokeai.yaml` file.
|
||||
2. Next it looks for the environment variable INVOKEAI_ROOT, and uses
|
||||
the directory path found there if present.
|
||||
|
||||
It has two sections - one for internal use and one for user settings:
|
||||
3. If neither of these are present, then InvokeAI looks for the
|
||||
folder containing the `.venv` Python virtual environment directory for
|
||||
the currently active environment. This directory is checked for files
|
||||
expected inside the InvokeAI root before it is used.
|
||||
|
||||
```yaml
|
||||
# Internal metadata - do not edit:
|
||||
schema_version: 4
|
||||
4. Finally, InvokeAI looks for a directory in the current user's home
|
||||
directory named `invokeai`.
|
||||
|
||||
# Put user settings here - see https://invoke-ai.github.io/InvokeAI/features/CONFIGURATION/:
|
||||
host: 0.0.0.0 # serve the app on your local network
|
||||
models_dir: D:\invokeai\models # store models on an external drive
|
||||
precision: float16 # always use fp16 precision
|
||||
#### Reading the InvokeAI Configuration File
|
||||
|
||||
Once the root directory has been located, InvokeAI looks for a file
|
||||
named `ROOT/invokeai.yaml`, and if present reads configuration values
|
||||
from it. The top of this file looks like this:
|
||||
|
||||
```
|
||||
InvokeAI:
|
||||
Web Server:
|
||||
host: localhost
|
||||
port: 9090
|
||||
allow_origins: []
|
||||
allow_credentials: true
|
||||
allow_methods:
|
||||
- '*'
|
||||
allow_headers:
|
||||
- '*'
|
||||
Features:
|
||||
esrgan: true
|
||||
internet_available: true
|
||||
log_tokenization: false
|
||||
patchmatch: true
|
||||
restore: true
|
||||
...
|
||||
```
|
||||
|
||||
The settings in this file will override the defaults. You only need
|
||||
to change this file if the default for a particular setting doesn't
|
||||
work for you.
|
||||
This lines in this file are used to establish default values for
|
||||
Invoke's settings. In the above fragment, the Web Server's listening
|
||||
port is set to 9090 by the `port` setting.
|
||||
|
||||
Some settings, like [Model Marketplace API Keys], require the YAML
|
||||
to be formatted correctly. Here is a [basic guide to YAML files].
|
||||
You can edit this file with a text editor such as "Notepad" (do not
|
||||
use Word or any other word processor). When editing, be careful to
|
||||
maintain the indentation, and do not add extraneous text, as syntax
|
||||
errors will prevent InvokeAI from launching. A basic guide to the
|
||||
format of YAML files can be found
|
||||
[here](https://circleci.com/blog/what-is-yaml-a-beginner-s-guide/).
|
||||
|
||||
You can fix a broken `invokeai.yaml` by deleting it and running the
|
||||
configuration script again -- option [6] in the launcher, "Re-run the
|
||||
configure script".
|
||||
|
||||
#### Custom Config File Location
|
||||
#### Reading Environment Variables
|
||||
|
||||
You can use any config file with the `--config` CLI arg. Pass in the path to the `invokeai.yaml` file you want to use.
|
||||
Next InvokeAI looks for defined environment variables in the format
|
||||
`INVOKEAI_<setting_name>`, for example `INVOKEAI_port`. Environment
|
||||
variable values take precedence over configuration file variables. On
|
||||
a Macintosh system, for example, you could change the port that the
|
||||
web server listens on by setting the environment variable this way:
|
||||
|
||||
Note that environment variables will trump any settings in the config file.
|
||||
|
||||
### Environment Variables
|
||||
|
||||
All settings may be set via environment variables by prefixing `INVOKEAI_`
|
||||
to the variable name. For example, `INVOKEAI_HOST` would set the `host`
|
||||
setting.
|
||||
|
||||
For non-primitive values, pass a JSON-encoded string:
|
||||
|
||||
```sh
|
||||
export INVOKEAI_REMOTE_API_TOKENS='[{"url_regex":"modelmarketplace", "token": "12345"}]'
|
||||
```
|
||||
export INVOKEAI_port=8000
|
||||
invokeai-web
|
||||
```
|
||||
|
||||
We suggest using `invokeai.yaml`, as it is more user-friendly.
|
||||
Please check out these
|
||||
[Macintosh](https://phoenixnap.com/kb/set-environment-variable-mac)
|
||||
and
|
||||
[Windows](https://phoenixnap.com/kb/windows-set-environment-variable)
|
||||
guides for setting temporary and permanent environment variables.
|
||||
|
||||
### CLI Args
|
||||
#### Reading the Command Line
|
||||
|
||||
A subset of settings may be specified using CLI args:
|
||||
Lastly, InvokeAI takes settings from the command line, which override
|
||||
everything else. The command-line settings have the same name as the
|
||||
corresponding configuration file settings, preceded by a `--`, for
|
||||
example `--port 8000`.
|
||||
|
||||
- `--root`: specify the root directory
|
||||
- `--config`: override the default `invokeai.yaml` file location
|
||||
If you are using the launcher (`invoke.sh` or `invoke.bat`) to launch
|
||||
InvokeAI, then just pass the command-line arguments to the launcher:
|
||||
|
||||
### All Settings
|
||||
|
||||
Following the table are additional explanations for certain settings.
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
::: invokeai.app.services.config.config_default.InvokeAIAppConfig
|
||||
options:
|
||||
heading_level: 4
|
||||
members: false
|
||||
show_docstring_description: false
|
||||
group_by_category: true
|
||||
show_category_heading: false
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
#### Model Marketplace API Keys
|
||||
|
||||
Some model marketplaces require an API key to download models. You can provide a URL pattern and appropriate token in your `invokeai.yaml` file to provide that API key.
|
||||
|
||||
The pattern can be any valid regex (you may need to surround the pattern with quotes):
|
||||
|
||||
```yaml
|
||||
remote_api_tokens:
|
||||
# Any URL containing `models.com` will automatically use `your_models_com_token`
|
||||
- url_regex: models.com
|
||||
token: your_models_com_token
|
||||
# Any URL matching this contrived regex will use `some_other_token`
|
||||
- url_regex: '^[a-z]{3}whatever.*\.com$'
|
||||
token: some_other_token
|
||||
```
|
||||
invoke.bat --port 8000 --host 0.0.0.0
|
||||
```
|
||||
|
||||
The provided token will be added as a `Bearer` token to the network requests to download the model files. As far as we know, this works for all model marketplaces that require authorization.
|
||||
The arguments will be applied when you select the web server option
|
||||
(and the other options as well).
|
||||
|
||||
#### Model Hashing
|
||||
If, on the other hand, you prefer to launch InvokeAI directly from the
|
||||
command line, you would first activate the virtual environment (known
|
||||
as the "developer's console" in the launcher), and run `invokeai-web`:
|
||||
|
||||
Models are hashed during installation, providing a stable identifier for models across all platforms. The default algorithm is `blake3`, with a multi-threaded implementation.
|
||||
|
||||
If your models are stored on a spinning hard drive, we suggest using `blake3_single`, the single-threaded implementation. The hashes are the same, but it's much faster on spinning disks.
|
||||
|
||||
```yaml
|
||||
hashing_algorithm: blake3_single
|
||||
```
|
||||
> C:\Users\Fred\invokeai\.venv\scripts\activate
|
||||
(.venv) > invokeai-web --port 8000 --host 0.0.0.0
|
||||
```
|
||||
|
||||
Model hashing is a one-time operation, but it may take a couple minutes to hash a large model collection. You may opt out of model hashing entirely by setting the algorithm to `random`.
|
||||
You can get a listing and brief instructions for each of the
|
||||
command-line options by giving the `--help` argument:
|
||||
|
||||
```yaml
|
||||
hashing_algorithm: random
|
||||
```
|
||||
(.venv) > invokeai-web --help
|
||||
usage: InvokeAI [-h] [--host HOST] [--port PORT] [--allow_origins [ALLOW_ORIGINS ...]] [--allow_credentials | --no-allow_credentials] [--allow_methods [ALLOW_METHODS ...]]
|
||||
[--allow_headers [ALLOW_HEADERS ...]] [--esrgan | --no-esrgan] [--internet_available | --no-internet_available] [--log_tokenization | --no-log_tokenization]
|
||||
[--patchmatch | --no-patchmatch] [--restore | --no-restore]
|
||||
[--always_use_cpu | --no-always_use_cpu] [--free_gpu_mem | --no-free_gpu_mem] [--max_loaded_models MAX_LOADED_MODELS] [--max_cache_size MAX_CACHE_SIZE]
|
||||
[--max_vram_cache_size MAX_VRAM_CACHE_SIZE] [--gpu_mem_reserved GPU_MEM_RESERVED] [--precision {auto,float16,float32,autocast}]
|
||||
[--sequential_guidance | --no-sequential_guidance] [--xformers_enabled | --no-xformers_enabled] [--tiled_decode | --no-tiled_decode] [--root ROOT]
|
||||
[--autoimport_dir AUTOIMPORT_DIR] [--lora_dir LORA_DIR] [--embedding_dir EMBEDDING_DIR] [--controlnet_dir CONTROLNET_DIR] [--conf_path CONF_PATH]
|
||||
[--models_dir MODELS_DIR] [--legacy_conf_dir LEGACY_CONF_DIR] [--db_dir DB_DIR] [--outdir OUTDIR] [--from_file FROM_FILE]
|
||||
[--use_memory_db | --no-use_memory_db] [--model MODEL] [--log_handlers [LOG_HANDLERS ...]] [--log_format {plain,color,syslog,legacy}]
|
||||
[--log_level {debug,info,warning,error,critical}] [--version | --no-version]
|
||||
```
|
||||
|
||||
Most common algorithms are supported, like `md5`, `sha256`, and `sha512`. These are typically much, much slower than `blake3`.
|
||||
## The Configuration Settings
|
||||
|
||||
#### Path Settings
|
||||
The configuration settings are divided into several distinct
|
||||
groups in `invokeia.yaml`:
|
||||
|
||||
### Web Server
|
||||
|
||||
| Setting | Default Value | Description |
|
||||
|---------------------|---------------|----------------------------------------------------------------------------------------------------------------------------|
|
||||
| `host` | `localhost` | Name or IP address of the network interface that the web server will listen on |
|
||||
| `port` | `9090` | Network port number that the web server will listen on |
|
||||
| `allow_origins` | `[]` | A list of host names or IP addresses that are allowed to connect to the InvokeAI API in the format `['host1','host2',...]` |
|
||||
| `allow_credentials` | `true` | Require credentials for a foreign host to access the InvokeAI API (don't change this) |
|
||||
| `allow_methods` | `*` | List of HTTP methods ("GET", "POST") that the web server is allowed to use when accessing the API |
|
||||
| `allow_headers` | `*` | List of HTTP headers that the web server will accept when accessing the API |
|
||||
| `ssl_certfile` | null | Path to an SSL certificate file, used to enable HTTPS. |
|
||||
| `ssl_keyfile` | null | Path to an SSL keyfile, if the key is not included in the certificate file. |
|
||||
|
||||
The documentation for InvokeAI's API can be accessed by browsing to the following URL: [http://localhost:9090/docs].
|
||||
|
||||
### Features
|
||||
|
||||
These configuration settings allow you to enable and disable various InvokeAI features:
|
||||
|
||||
| Setting | Default Value | Description |
|
||||
|----------|----------------|--------------|
|
||||
| `esrgan` | `true` | Activate the ESRGAN upscaling options|
|
||||
| `internet_available` | `true` | When a resource is not available locally, try to fetch it via the internet |
|
||||
| `log_tokenization` | `false` | Before each text2image generation, print a color-coded representation of the prompt to the console; this can help understand why a prompt is not working as expected |
|
||||
| `patchmatch` | `true` | Activate the "patchmatch" algorithm for improved inpainting |
|
||||
|
||||
### Generation
|
||||
|
||||
These options tune InvokeAI's memory and performance characteristics.
|
||||
|
||||
| Setting | Default Value | Description |
|
||||
|-----------------------|---------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| `sequential_guidance` | `false` | Calculate guidance in serial rather than in parallel, lowering memory requirements at the cost of some performance loss |
|
||||
| `attention_type` | `auto` | Select the type of attention to use. One of `auto`,`normal`,`xformers`,`sliced`, or `torch-sdp` |
|
||||
| `attention_slice_size` | `auto` | When "sliced" attention is selected, set the slice size. One of `auto`, `balanced`, `max` or the integers 1-8|
|
||||
| `force_tiled_decode` | `false` | Force the VAE step to decode in tiles, reducing memory consumption at the cost of performance |
|
||||
|
||||
### Device
|
||||
|
||||
These options configure the generation execution device.
|
||||
|
||||
| Setting | Default Value | Description |
|
||||
|-----------------------|---------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| `device` | `auto` | Preferred execution device. One of `auto`, `cpu`, `cuda`, `cuda:1`, `mps`. `auto` will choose the device depending on the hardware platform and the installed torch capabilities. |
|
||||
| `precision` | `auto` | Floating point precision. One of `auto`, `float16` or `float32`. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system |
|
||||
|
||||
|
||||
### Paths
|
||||
|
||||
These options set the paths of various directories and files used by
|
||||
InvokeAI. Relative paths are interpreted relative to the root directory, so
|
||||
if root is `/home/fred/invokeai` and the path is
|
||||
InvokeAI. Relative paths are interpreted relative to INVOKEAI_ROOT, so
|
||||
if INVOKEAI_ROOT is `/home/fred/invokeai` and the path is
|
||||
`autoimport/main`, then the corresponding directory will be located at
|
||||
`/home/fred/invokeai/autoimport/main`.
|
||||
|
||||
Note that the autoimport directory will be searched recursively,
|
||||
allowing you to organize the models into folders and subfolders in any
|
||||
way you wish.
|
||||
| Setting | Default Value | Description |
|
||||
|----------|----------------|--------------|
|
||||
| `autoimport_dir` | `autoimport/main` | At startup time, read and import any main model files found in this directory |
|
||||
| `lora_dir` | `autoimport/lora` | At startup time, read and import any LoRA/LyCORIS models found in this directory |
|
||||
| `embedding_dir` | `autoimport/embedding` | At startup time, read and import any textual inversion (embedding) models found in this directory |
|
||||
| `controlnet_dir` | `autoimport/controlnet` | At startup time, read and import any ControlNet models found in this directory |
|
||||
| `conf_path` | `configs/models.yaml` | Location of the `models.yaml` model configuration file |
|
||||
| `models_dir` | `models` | Location of the directory containing models installed by InvokeAI's model manager |
|
||||
| `legacy_conf_dir` | `configs/stable-diffusion` | Location of the directory containing the .yaml configuration files for legacy checkpoint models |
|
||||
| `db_dir` | `databases` | Location of the directory containing InvokeAI's image, schema and session database |
|
||||
| `outdir` | `outputs` | Location of the directory in which the gallery of generated and uploaded images will be stored |
|
||||
| `use_memory_db` | `false` | Keep database information in memory rather than on disk; this will not preserve image gallery information across restarts |
|
||||
|
||||
#### Logging
|
||||
Note that the autoimport directories will be searched recursively,
|
||||
allowing you to organize the models into folders and subfolders in any
|
||||
way you wish. In addition, while we have split up autoimport
|
||||
directories by the type of model they contain, this isn't
|
||||
necessary. You can combine different model types in the same folder
|
||||
and InvokeAI will figure out what they are. So you can easily use just
|
||||
one autoimport directory by commenting out the unneeded paths:
|
||||
|
||||
```
|
||||
Paths:
|
||||
autoimport_dir: autoimport
|
||||
# lora_dir: null
|
||||
# embedding_dir: null
|
||||
# controlnet_dir: null
|
||||
```
|
||||
|
||||
### Logging
|
||||
|
||||
These settings control the information, warning, and debugging
|
||||
messages printed to the console log while InvokeAI is running:
|
||||
|
||||
| Setting | Default Value | Description |
|
||||
|----------|----------------|--------------|
|
||||
| `log_handlers` | `console` | This controls where log messages are sent, and can be a list of one or more destinations. Values include `console`, `file`, `syslog` and `http`. These are described in more detail below |
|
||||
| `log_format` | `color` | This controls the formatting of the log messages. Values are `plain`, `color`, `legacy` and `syslog` |
|
||||
| `log_level` | `debug` | This filters messages according to the level of severity and can be one of `debug`, `info`, `warning`, `error` and `critical`. For example, setting to `warning` will display all messages at the warning level or higher, but won't display "debug" or "info" messages |
|
||||
|
||||
Several different log handler destinations are available, and multiple destinations are supported by providing a list:
|
||||
|
||||
```yaml
|
||||
log_handlers:
|
||||
- console
|
||||
- syslog=localhost
|
||||
- file=/var/log/invokeai.log
|
||||
```
|
||||
log_handlers:
|
||||
- console
|
||||
- syslog=localhost
|
||||
- file=/var/log/invokeai.log
|
||||
```
|
||||
|
||||
- `console` is the default. It prints log messages to the command-line window from which InvokeAI was launched.
|
||||
* `console` is the default. It prints log messages to the command-line window from which InvokeAI was launched.
|
||||
|
||||
- `syslog` is only available on Linux and Macintosh systems. It uses
|
||||
* `syslog` is only available on Linux and Macintosh systems. It uses
|
||||
the operating system's "syslog" facility to write log file entries
|
||||
locally or to a remote logging machine. `syslog` offers a variety
|
||||
of configuration options:
|
||||
@ -173,7 +271,7 @@ log_handlers:
|
||||
- Log to LAN-connected server "fredserver" using the facility LOG_USER and datagram packets.
|
||||
```
|
||||
|
||||
- `http` can be used to log to a remote web server. The server must be
|
||||
* `http` can be used to log to a remote web server. The server must be
|
||||
properly configured to receive and act on log messages. The option
|
||||
accepts the URL to the web server, and a `method` argument
|
||||
indicating whether the message should be submitted using the GET or
|
||||
@ -185,10 +283,7 @@ log_handlers:
|
||||
|
||||
The `log_format` option provides several alternative formats:
|
||||
|
||||
- `color` - default format providing time, date and a message, using text colors to distinguish different log severities
|
||||
- `plain` - same as above, but monochrome text only
|
||||
- `syslog` - the log level and error message only, allowing the syslog system to attach the time and date
|
||||
- `legacy` - a format similar to the one used by the legacy 2.3 InvokeAI releases.
|
||||
|
||||
[basic guide to yaml files]: https://circleci.com/blog/what-is-yaml-a-beginner-s-guide/
|
||||
[Model Marketplace API Keys]: #model-marketplace-api-keys
|
||||
* `color` - default format providing time, date and a message, using text colors to distinguish different log severities
|
||||
* `plain` - same as above, but monochrome text only
|
||||
* `syslog` - the log level and error message only, allowing the syslog system to attach the time and date
|
||||
* `legacy` - a format similar to the one used by the legacy 2.3 InvokeAI releases.
|
||||
|
@ -1,35 +0,0 @@
|
||||
---
|
||||
title: Database
|
||||
---
|
||||
|
||||
# Invoke's SQLite Database
|
||||
|
||||
Invoke uses a SQLite database to store image, workflow, model, and execution data.
|
||||
|
||||
We take great care to ensure your data is safe, by utilizing transactions and a database migration system.
|
||||
|
||||
Even so, when testing an prerelease version of the app, we strongly suggest either backing up your database or using an in-memory database. This ensures any prelease hiccups or databases schema changes will not cause problems for your data.
|
||||
|
||||
## Database Backup
|
||||
|
||||
Backing up your database is very simple. Invoke's data is stored in an `$INVOKEAI_ROOT` directory - where your `invoke.sh`/`invoke.bat` and `invokeai.yaml` files live.
|
||||
|
||||
To back up your database, copy the `invokeai.db` file from `$INVOKEAI_ROOT/databases/invokeai.db` to somewhere safe.
|
||||
|
||||
If anything comes up during prelease testing, you can simply copy your backup back into `$INVOKEAI_ROOT/databases/`.
|
||||
|
||||
## In-Memory Database
|
||||
|
||||
SQLite can run on an in-memory database. Your existing database is untouched when this mode is enabled, but your existing data won't be accessible.
|
||||
|
||||
This is very useful for testing, as there is no chance of a database change modifying your "physical" database.
|
||||
|
||||
To run Invoke with a memory database, edit your `invokeai.yaml` file, and add `use_memory_db: true` to the `Paths:` stanza:
|
||||
|
||||
```yaml
|
||||
InvokeAI:
|
||||
Development:
|
||||
use_memory_db: true
|
||||
```
|
||||
|
||||
Delete this line (or set it to `false`) to use your main database.
|
@ -122,9 +122,9 @@ experimental versions later.
|
||||
[latest release](https://github.com/invoke-ai/InvokeAI/releases/latest),
|
||||
and look for a file named:
|
||||
|
||||
- InvokeAI-installer-v4.X.X.zip
|
||||
- InvokeAI-installer-v3.X.X.zip
|
||||
|
||||
where "4.X.X" is the latest released version. The file is located
|
||||
where "3.X.X" is the latest released version. The file is located
|
||||
at the very bottom of the release page, under **Assets**.
|
||||
|
||||
4. **Unpack the installer**: Unpack the zip file into a convenient directory. This will create a new
|
||||
@ -199,7 +199,136 @@ experimental versions later.
|
||||

|
||||
</figure>
|
||||
|
||||
10. **Running InvokeAI for the first time**: The script will now exit and you'll be ready to generate some images. Look
|
||||
10. **Post-install Configuration**: After installation completes, the
|
||||
installer will launch the configuration form, which will guide you
|
||||
through the first-time process of adjusting some of InvokeAI's
|
||||
startup settings. To move around this form use ctrl-N for
|
||||
<N>ext and ctrl-P for <P>revious, or use <tab>
|
||||
and shift-<tab> to move forward and back. Once you are in a
|
||||
multi-checkbox field use the up and down cursor keys to select the
|
||||
item you want, and <space> to toggle it on and off. Within
|
||||
a directory field, pressing <tab> will provide autocomplete
|
||||
options.
|
||||
|
||||
Generally the defaults are fine, and you can come back to this screen at
|
||||
any time to tweak your system. Here are the options you can adjust:
|
||||
|
||||
- ***HuggingFace Access Token***
|
||||
InvokeAI has the ability to download embedded styles and subjects
|
||||
from the HuggingFace Concept Library on-demand. However, some of
|
||||
the concept library files are password protected. To make download
|
||||
smoother, you can set up an account at huggingface.co, obtain an
|
||||
access token, and paste it into this field. Note that you paste
|
||||
to this screen using ctrl-shift-V
|
||||
|
||||
- ***Free GPU memory after each generation***
|
||||
This is useful for low-memory machines and helps minimize the
|
||||
amount of GPU VRAM used by InvokeAI.
|
||||
|
||||
- ***Enable xformers support if available***
|
||||
If the xformers library was successfully installed, this will activate
|
||||
it to reduce memory consumption and increase rendering speed noticeably.
|
||||
Note that xformers has the side effect of generating slightly different
|
||||
images even when presented with the same seed and other settings.
|
||||
|
||||
- ***Force CPU to be used on GPU systems***
|
||||
This will use the (slow) CPU rather than the accelerated GPU. This
|
||||
can be used to generate images on systems that don't have a compatible
|
||||
GPU.
|
||||
|
||||
- ***Precision***
|
||||
This controls whether to use float32 or float16 arithmetic.
|
||||
float16 uses less memory but is also slightly less accurate.
|
||||
Ordinarily the right arithmetic is picked automatically ("auto"),
|
||||
but you may have to use float32 to get images on certain systems
|
||||
and graphics cards. The "autocast" option is deprecated and
|
||||
shouldn't be used unless you are asked to by a member of the team.
|
||||
|
||||
- **Size of the RAM cache used for fast model switching***
|
||||
This allows you to keep models in memory and switch rapidly among
|
||||
them rather than having them load from disk each time. This slider
|
||||
controls how many models to keep loaded at once. A typical SD-1 or SD-2 model
|
||||
uses 2-3 GB of memory. A typical SDXL model uses 6-7 GB. Providing more
|
||||
RAM will allow more models to be co-resident.
|
||||
|
||||
- ***Output directory for images***
|
||||
This is the path to a directory in which InvokeAI will store all its
|
||||
generated images.
|
||||
|
||||
- ***Autoimport Folder***
|
||||
This is the directory in which you can place models you have
|
||||
downloaded and wish to load into InvokeAI. You can place a variety
|
||||
of models in this directory, including diffusers folders, .ckpt files,
|
||||
.safetensors files, as well as LoRAs, ControlNet and Textual Inversion
|
||||
files (both folder and file versions). To help organize this folder,
|
||||
you can create several levels of subfolders and drop your models into
|
||||
whichever ones you want.
|
||||
|
||||
- ***LICENSE***
|
||||
|
||||
At the bottom of the screen you will see a checkbox for accepting
|
||||
the CreativeML Responsible AI Licenses. You need to accept the license
|
||||
in order to download Stable Diffusion models from the next screen.
|
||||
|
||||
_You can come back to the startup options form_ as many times as you like.
|
||||
From the `invoke.sh` or `invoke.bat` launcher, select option (6) to relaunch
|
||||
this script. On the command line, it is named `invokeai-configure`.
|
||||
|
||||
11. **Downloading Models**: After you press `[NEXT]` on the screen, you will be taken
|
||||
to another screen that prompts you to download a series of starter models. The ones
|
||||
we recommend are preselected for you, but you are encouraged to use the checkboxes to
|
||||
pick and choose.
|
||||
You will probably wish to download `autoencoder-840000` for use with models that
|
||||
were trained with an older version of the Stability VAE.
|
||||
|
||||
<figure markdown>
|
||||

|
||||
</figure>
|
||||
|
||||
Below the preselected list of starter models is a large text field which you can use
|
||||
to specify a series of models to import. You can specify models in a variety of formats,
|
||||
each separated by a space or newline. The formats accepted are:
|
||||
|
||||
- The path to a .ckpt or .safetensors file. On most systems, you can drag a file from
|
||||
the file browser to the textfield to automatically paste the path. Be sure to remove
|
||||
extraneous quotation marks and other things that come along for the ride.
|
||||
|
||||
- The path to a directory containing a combination of `.ckpt` and `.safetensors` files.
|
||||
The directory will be scanned from top to bottom (including subfolders) and any
|
||||
file that can be imported will be.
|
||||
|
||||
- A URL pointing to a `.ckpt` or `.safetensors` file. You can cut
|
||||
and paste directly from a web page, or simply drag the link from the web page
|
||||
or navigation bar. (You can also use ctrl-shift-V to paste into this field)
|
||||
The file will be downloaded and installed.
|
||||
|
||||
- The HuggingFace repository ID (repo_id) for a `diffusers` model. These IDs have
|
||||
the format _author_name/model_name_, as in `andite/anything-v4.0`
|
||||
|
||||
- The path to a local directory containing a `diffusers`
|
||||
model. These directories always have the file `model_index.json`
|
||||
at their top level.
|
||||
|
||||
_Select a directory for models to import_ You may select a local
|
||||
directory for autoimporting at startup time. If you select this
|
||||
option, the directory you choose will be scanned for new
|
||||
.ckpt/.safetensors files each time InvokeAI starts up, and any new
|
||||
files will be automatically imported and made available for your
|
||||
use.
|
||||
|
||||
_Convert imported models into diffusers_ When legacy checkpoint
|
||||
files are imported, you may select to use them unmodified (the
|
||||
default) or to convert them into `diffusers` models. The latter
|
||||
load much faster and have slightly better rendering performance,
|
||||
but not all checkpoint files can be converted. Note that Stable Diffusion
|
||||
Version 2.X files are **only** supported in `diffusers` format and will
|
||||
be converted regardless.
|
||||
|
||||
_You can come back to the model install form_ as many times as you like.
|
||||
From the `invoke.sh` or `invoke.bat` launcher, select option (5) to relaunch
|
||||
this script. On the command line, it is named `invokeai-model-install`.
|
||||
|
||||
12. **Running InvokeAI for the first time**: The script will now exit and you'll be ready to generate some images. Look
|
||||
for the directory `invokeai` installed in the location you chose at the
|
||||
beginning of the install session. Look for a shell script named `invoke.sh`
|
||||
(Linux/Mac) or `invoke.bat` (Windows). Launch the script by double-clicking
|
||||
@ -220,14 +349,14 @@ experimental versions later.
|
||||
http://localhost:9090. Click on this link to open up a browser
|
||||
and start exploring InvokeAI's features.
|
||||
|
||||
12. **InvokeAI Options**: You can configure using the `invokeai.yaml` config file.
|
||||
For example, you can change the location of the
|
||||
12. **InvokeAI Options**: You can launch InvokeAI with several different command-line arguments that
|
||||
customize its behavior. For example, you can change the location of the
|
||||
image output directory or balance memory usage vs performance. See
|
||||
[Configuration](../features/CONFIGURATION.md) for a full list of the options.
|
||||
|
||||
- To set defaults that will take effect every time you launch InvokeAI,
|
||||
use a text editor (e.g. Notepad) to exit the file
|
||||
`invokeai\invokeai.yaml`. It contains a variety of examples that you can
|
||||
`invokeai\invokeai.init`. It contains a variety of examples that you can
|
||||
follow to add and modify launch options.
|
||||
|
||||
- The launcher script also offers you an option labeled "open the developer
|
||||
@ -265,6 +394,7 @@ rm .\.venv -r -force
|
||||
python -mvenv .venv
|
||||
.\.venv\Scripts\activate
|
||||
pip install invokeai
|
||||
invokeai-configure --yes --root .
|
||||
```
|
||||
|
||||
If you see anything marked as an error during this process please stop
|
||||
@ -296,10 +426,16 @@ error messages:
|
||||
This failure mode occurs when there is a network glitch during
|
||||
downloading the very large SDXL model.
|
||||
|
||||
To address this, first go to the Model Manager and delete the
|
||||
Stable-Diffusion-XL-base-1.X model. Then, click the HuggingFace tab,
|
||||
paste the Repo ID stabilityai/stable-diffusion-xl-base-1.0 and install
|
||||
the model.
|
||||
To address this, first go to the Web Model Manager and delete the
|
||||
Stable-Diffusion-XL-base-1.X model. Then navigate to HuggingFace and
|
||||
manually download the .safetensors version of the model. The 1.0
|
||||
version is located at
|
||||
https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/tree/main
|
||||
and the file is named `sd_xl_base_1.0.safetensors`.
|
||||
|
||||
Save this file to disk and then reenter the Model Manager. Navigate to
|
||||
Import Models->Add Model, then type (or drag-and-drop) the path to the
|
||||
.safetensors file. Press "Add Model".
|
||||
|
||||
### _Package dependency conflicts_
|
||||
|
||||
@ -352,7 +488,15 @@ download models, etc), but this doesn't fix the problem.
|
||||
|
||||
This issue is often caused by a misconfigured configuration directive in the
|
||||
`invokeai\invokeai.init` initialization file that contains startup settings. The
|
||||
easiest way to fix the problem is to move the file out of the way and restart the app.
|
||||
easiest way to fix the problem is to move the file out of the way and re-run
|
||||
`invokeai-configure`. Enter the developer's console (option 3 of the launcher
|
||||
script) and run this command:
|
||||
|
||||
```cmd
|
||||
invokeai-configure --root=.
|
||||
```
|
||||
|
||||
Note the dot (.) after `--root`. It is part of the command.
|
||||
|
||||
_If none of these maneuvers fixes the problem_ then please report the problem to
|
||||
the [InvokeAI Issues](https://github.com/invoke-ai/InvokeAI/issues) section, or
|
||||
@ -421,4 +565,16 @@ This distribution is changing rapidly, and we add new features
|
||||
regularly. Releases are announced at
|
||||
http://github.com/invoke-ai/InvokeAI/releases, and at
|
||||
https://pypi.org/project/InvokeAI/ To update to the latest released
|
||||
version (recommended), download the latest release and run the installer.
|
||||
version (recommended), follow these steps:
|
||||
|
||||
1. Start the `invoke.sh`/`invoke.bat` launch script from within the
|
||||
`invokeai` root directory.
|
||||
|
||||
2. Choose menu item (10) "Update InvokeAI".
|
||||
|
||||
3. This will launch a menu that gives you the option of:
|
||||
|
||||
1. Updating to the latest official release;
|
||||
2. Updating to the bleeding-edge development version; or
|
||||
3. Manually entering the tag or branch name of a version of
|
||||
InvokeAI you wish to try out.
|
||||
|
@ -26,7 +26,7 @@ driver).
|
||||
|
||||
🖥️ **Download the latest installer .zip file here** : https://github.com/invoke-ai/InvokeAI/releases/latest
|
||||
|
||||
- *Look for the file labelled "InvokeAI-installer-v4.X.X.zip" at the bottom of the page*
|
||||
- *Look for the file labelled "InvokeAI-installer-v3.X.X.zip" at the bottom of the page*
|
||||
- If you experience issues, read through the full [installation instructions](010_INSTALL_AUTOMATED.md) to make sure you have met all of the installation requirements. If you need more help, join the [Discord](discord.gg/invoke-ai) or create an issue on [Github](https://github.com/invoke-ai/InvokeAI).
|
||||
|
||||
|
||||
|
@ -22,24 +22,6 @@ class MyInvocation(BaseInvocation):
|
||||
...
|
||||
```
|
||||
|
||||
The full API is documented below.
|
||||
|
||||
## Invocation Mixins
|
||||
|
||||
Two important mixins are provided to facilitate working with metadata and gallery boards.
|
||||
|
||||
### `WithMetadata`
|
||||
|
||||
Inherit from this class (in addition to `BaseInvocation`) to add a `metadata` input to your node. When you do this, you can access the metadata dict from `self.metadata` in the `invoke()` function.
|
||||
|
||||
The dict will be populated via the node's input, and you can add any metadata you'd like to it. When you call `context.images.save()`, if the metadata dict has any data, it be automatically embedded in the image.
|
||||
|
||||
### `WithBoard`
|
||||
|
||||
Inherit from this class (in addition to `BaseInvocation`) to add a `board` input to your node. This renders as a drop-down to select a board. The user's selection will be accessible from `self.board` in the `invoke()` function.
|
||||
|
||||
When you call `context.images.save()`, if a board was selected, the image will added to that board as it is saved.
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
::: invokeai.app.services.shared.invocation_context.InvocationContext
|
||||
options:
|
||||
|
@ -32,7 +32,6 @@ To use a community workflow, download the the `.json` node graph file and load i
|
||||
+ [Image to Character Art Image Nodes](#image-to-character-art-image-nodes)
|
||||
+ [Image Picker](#image-picker)
|
||||
+ [Image Resize Plus](#image-resize-plus)
|
||||
+ [Latent Upscale](#latent-upscale)
|
||||
+ [Load Video Frame](#load-video-frame)
|
||||
+ [Make 3D](#make-3d)
|
||||
+ [Mask Operations](#mask-operations)
|
||||
@ -291,13 +290,6 @@ View:
|
||||
</br><img src="https://raw.githubusercontent.com/VeyDlin/image-resize-plus-node/master/.readme/node.png" width="500" />
|
||||
|
||||
|
||||
--------------------------------
|
||||
### Latent Upscale
|
||||
|
||||
**Description:** This node uses a small (~2.4mb) model to upscale the latents used in a Stable Diffusion 1.5 or Stable Diffusion XL image generation, rather than the typical interpolation method, avoiding the traditional downsides of the latent upscale technique.
|
||||
|
||||
**Node Link:** [https://github.com/gogurtenjoyer/latent-upscale](https://github.com/gogurtenjoyer/latent-upscale)
|
||||
|
||||
--------------------------------
|
||||
### Load Video Frame
|
||||
|
||||
@ -354,21 +346,12 @@ See full docs here: https://github.com/skunkworxdark/Prompt-tools-nodes/edit/mai
|
||||
|
||||
**Description:** A set of nodes for Metadata. Collect Metadata from within an `iterate` node & extract metadata from an image.
|
||||
|
||||
- `Metadata Item Linked` - Allows collecting of metadata while within an iterate node with no need for a collect node or conversion to metadata node
|
||||
- `Metadata From Image` - Provides Metadata from an image
|
||||
- `Metadata To String` - Extracts a String value of a label from metadata
|
||||
- `Metadata To Integer` - Extracts an Integer value of a label from metadata
|
||||
- `Metadata To Float` - Extracts a Float value of a label from metadata
|
||||
- `Metadata To Scheduler` - Extracts a Scheduler value of a label from metadata
|
||||
- `Metadata To Bool` - Extracts Bool types from metadata
|
||||
- `Metadata To Model` - Extracts model types from metadata
|
||||
- `Metadata To SDXL Model` - Extracts SDXL model types from metadata
|
||||
- `Metadata To LoRAs` - Extracts Loras from metadata.
|
||||
- `Metadata To SDXL LoRAs` - Extracts SDXL Loras from metadata
|
||||
- `Metadata To ControlNets` - Extracts ControNets from metadata
|
||||
- `Metadata To IP-Adapters` - Extracts IP-Adapters from metadata
|
||||
- `Metadata To T2I-Adapters` - Extracts T2I-Adapters from metadata
|
||||
- `Denoise Latents + Metadata` - This is an inherited version of the existing `Denoise Latents` node but with a metadata input and output.
|
||||
- `Metadata Item Linked` - Allows collecting of metadata while within an iterate node with no need for a collect node or conversion to metadata node.
|
||||
- `Metadata From Image` - Provides Metadata from an image.
|
||||
- `Metadata To String` - Extracts a String value of a label from metadata.
|
||||
- `Metadata To Integer` - Extracts an Integer value of a label from metadata.
|
||||
- `Metadata To Float` - Extracts a Float value of a label from metadata.
|
||||
- `Metadata To Scheduler` - Extracts a Scheduler value of a label from metadata.
|
||||
|
||||
**Node Link:** https://github.com/skunkworxdark/metadata-linked-nodes
|
||||
|
||||
|
@ -19,8 +19,6 @@ their descriptions.
|
||||
| Conditioning Primitive | A conditioning tensor primitive value |
|
||||
| Content Shuffle Processor | Applies content shuffle processing to image |
|
||||
| ControlNet | Collects ControlNet info to pass to other nodes |
|
||||
| Create Denoise Mask | Converts a greyscale or transparency image into a mask for denoising. |
|
||||
| Create Gradient Mask | Creates a mask for Gradient ("soft", "differential") inpainting that gradually expands during denoising. Improves edge coherence. |
|
||||
| Denoise Latents | Denoises noisy latents to decodable images |
|
||||
| Divide Integers | Divides two numbers |
|
||||
| Dynamic Prompt | Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator |
|
||||
|
@ -2,18 +2,22 @@
|
||||
|
||||
set -e
|
||||
|
||||
BCYAN="\033[1;36m"
|
||||
BYELLOW="\033[1;33m"
|
||||
BGREEN="\033[1;32m"
|
||||
BRED="\033[1;31m"
|
||||
RED="\033[31m"
|
||||
RESET="\033[0m"
|
||||
BCYAN="\e[1;36m"
|
||||
BYELLOW="\e[1;33m"
|
||||
BGREEN="\e[1;32m"
|
||||
BRED="\e[1;31m"
|
||||
RED="\e[31m"
|
||||
RESET="\e[0m"
|
||||
|
||||
function is_bin_in_path {
|
||||
builtin type -P "$1" &>/dev/null
|
||||
}
|
||||
|
||||
function git_show {
|
||||
git show -s --format=oneline --abbrev-commit "$1" | cat
|
||||
}
|
||||
|
||||
if [[ ! -z "${VIRTUAL_ENV}" ]]; then
|
||||
if [[ -v "VIRTUAL_ENV" ]]; then
|
||||
# we can't just call 'deactivate' because this function is not exported
|
||||
# to the environment of this script from the bash process that runs the script
|
||||
echo -e "${BRED}A virtual environment is activated. Please deactivate it before proceeding.${RESET}"
|
||||
@ -22,63 +26,31 @@ fi
|
||||
|
||||
cd "$(dirname "$0")"
|
||||
|
||||
echo
|
||||
echo -e "${BYELLOW}This script must be run from the installer directory!${RESET}"
|
||||
echo "The current working directory is $(pwd)"
|
||||
read -p "If that looks right, press any key to proceed, or CTRL-C to exit..."
|
||||
echo
|
||||
|
||||
# Some machines only have `python3` in PATH, others have `python` - make an alias.
|
||||
# We can use a function to approximate an alias within a non-interactive shell.
|
||||
if ! is_bin_in_path python && is_bin_in_path python3; then
|
||||
function python {
|
||||
python3 "$@"
|
||||
}
|
||||
fi
|
||||
|
||||
VERSION=$(
|
||||
cd ..
|
||||
python3 -c "from invokeai.version import __version__ as version; print(version)"
|
||||
python -c "from invokeai.version import __version__ as version; print(version)"
|
||||
)
|
||||
VERSION="v${VERSION}"
|
||||
|
||||
if [[ ! -z ${CI} ]]; then
|
||||
echo
|
||||
echo -e "${BCYAN}CI environment detected${RESET}"
|
||||
echo
|
||||
else
|
||||
echo
|
||||
echo -e "${BYELLOW}This script must be run from the installer directory!${RESET}"
|
||||
echo "The current working directory is $(pwd)"
|
||||
read -p "If that looks right, press any key to proceed, or CTRL-C to exit..."
|
||||
echo
|
||||
fi
|
||||
PATCH=""
|
||||
VERSION="v${VERSION}${PATCH}"
|
||||
|
||||
echo -e "${BGREEN}HEAD${RESET}:"
|
||||
git_show HEAD
|
||||
echo
|
||||
|
||||
# ---------------------- FRONTEND ----------------------
|
||||
|
||||
pushd ../invokeai/frontend/web >/dev/null
|
||||
echo "Installing frontend dependencies..."
|
||||
echo
|
||||
pnpm i --frozen-lockfile
|
||||
echo
|
||||
if [[ ! -z ${CI} ]]; then
|
||||
echo "Building frontend without checks..."
|
||||
# In CI, we have already done the frontend checks and can just build
|
||||
pnpm vite build
|
||||
else
|
||||
echo "Running checks and building frontend..."
|
||||
# This runs all the frontend checks and builds
|
||||
pnpm build
|
||||
fi
|
||||
echo
|
||||
popd
|
||||
|
||||
# ---------------------- BACKEND ----------------------
|
||||
|
||||
echo
|
||||
echo "Building wheel..."
|
||||
echo
|
||||
|
||||
# install the 'build' package in the user site packages, if needed
|
||||
# could be improved by using a temporary venv, but it's tiny and harmless
|
||||
if [[ $(python3 -c 'from importlib.util import find_spec; print(find_spec("build") is None)') == "True" ]]; then
|
||||
pip install --user build
|
||||
fi
|
||||
|
||||
rm -rf ../build
|
||||
|
||||
python3 -m build --outdir dist/ ../.
|
||||
|
||||
# ----------------------
|
||||
|
||||
echo
|
||||
@ -106,28 +78,10 @@ chmod a+x InvokeAI-Installer/install.sh
|
||||
cp install.bat.in InvokeAI-Installer/install.bat
|
||||
cp WinLongPathsEnabled.reg InvokeAI-Installer/
|
||||
|
||||
FILENAME=InvokeAI-installer-$VERSION.zip
|
||||
|
||||
# Zip everything up
|
||||
zip -r ${FILENAME} InvokeAI-Installer
|
||||
zip -r InvokeAI-installer-$VERSION.zip InvokeAI-Installer
|
||||
|
||||
echo
|
||||
echo -e "${BGREEN}Built installer: ./${FILENAME}${RESET}"
|
||||
echo -e "${BGREEN}Built PyPi distribution: ./dist${RESET}"
|
||||
|
||||
# clean up, but only if we are not in a github action
|
||||
if [[ -z ${CI} ]]; then
|
||||
echo
|
||||
echo "Cleaning up intermediate build files..."
|
||||
rm -rf InvokeAI-Installer tmp ../invokeai/frontend/web/dist/
|
||||
fi
|
||||
|
||||
if [[ ! -z ${CI} ]]; then
|
||||
echo
|
||||
echo "Setting GitHub action outputs..."
|
||||
echo "INSTALLER_FILENAME=${FILENAME}" >>$GITHUB_OUTPUT
|
||||
echo "INSTALLER_PATH=installer/${FILENAME}" >>$GITHUB_OUTPUT
|
||||
echo "DIST_PATH=installer/dist/" >>$GITHUB_OUTPUT
|
||||
fi
|
||||
# clean up
|
||||
rm -rf InvokeAI-Installer tmp dist ../invokeai/frontend/web/dist/
|
||||
|
||||
exit 0
|
||||
|
@ -149,6 +149,9 @@ class Installer:
|
||||
# install the launch/update scripts into the runtime directory
|
||||
self.instance.install_user_scripts()
|
||||
|
||||
# run through the configuration flow
|
||||
self.instance.configure()
|
||||
|
||||
|
||||
class InvokeAiInstance:
|
||||
"""
|
||||
@ -239,6 +242,53 @@ class InvokeAiInstance:
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
def configure(self):
|
||||
"""
|
||||
Configure the InvokeAI runtime directory
|
||||
"""
|
||||
|
||||
auto_install = False
|
||||
# set sys.argv to a consistent state
|
||||
new_argv = [sys.argv[0]]
|
||||
for i in range(1, len(sys.argv)):
|
||||
el = sys.argv[i]
|
||||
if el in ["-r", "--root"]:
|
||||
new_argv.append(el)
|
||||
new_argv.append(sys.argv[i + 1])
|
||||
elif el in ["-y", "--yes", "--yes-to-all"]:
|
||||
auto_install = True
|
||||
sys.argv = new_argv
|
||||
|
||||
import messages
|
||||
import requests # to catch download exceptions
|
||||
|
||||
auto_install = auto_install or messages.user_wants_auto_configuration()
|
||||
if auto_install:
|
||||
sys.argv.append("--yes")
|
||||
else:
|
||||
messages.introduction()
|
||||
|
||||
from invokeai.frontend.install.invokeai_configure import invokeai_configure
|
||||
|
||||
# NOTE: currently the config script does its own arg parsing! this means the command-line switches
|
||||
# from the installer will also automatically propagate down to the config script.
|
||||
# this may change in the future with config refactoring!
|
||||
succeeded = False
|
||||
try:
|
||||
invokeai_configure()
|
||||
succeeded = True
|
||||
except requests.exceptions.ConnectionError as e:
|
||||
print(f"\nA network error was encountered during configuration and download: {str(e)}")
|
||||
except OSError as e:
|
||||
print(f"\nAn OS error was encountered during configuration and download: {str(e)}")
|
||||
except Exception as e:
|
||||
print(f"\nA problem was encountered during the configuration and download steps: {str(e)}")
|
||||
finally:
|
||||
if not succeeded:
|
||||
print('To try again, find the "invokeai" directory, run the script "invoke.sh" or "invoke.bat"')
|
||||
print("and choose option 7 to fix a broken install, optionally followed by option 5 to install models.")
|
||||
print("Alternatively you can relaunch the installer.")
|
||||
|
||||
def install_user_scripts(self):
|
||||
"""
|
||||
Copy the launch and update scripts to the runtime dir
|
||||
|
@ -8,7 +8,7 @@ import platform
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
|
||||
from prompt_toolkit import prompt
|
||||
from prompt_toolkit import HTML, prompt
|
||||
from prompt_toolkit.completion import FuzzyWordCompleter, PathCompleter
|
||||
from prompt_toolkit.validation import Validator
|
||||
from rich import box, print
|
||||
@ -98,6 +98,39 @@ def choose_version(available_releases: tuple | None = None) -> str:
|
||||
return "stable" if response == "" else response
|
||||
|
||||
|
||||
def user_wants_auto_configuration() -> bool:
|
||||
"""Prompt the user to choose between manual and auto configuration."""
|
||||
console.rule("InvokeAI Configuration Section")
|
||||
console.print(
|
||||
Panel(
|
||||
Group(
|
||||
"\n".join(
|
||||
[
|
||||
"Libraries are installed and InvokeAI will now set up its root directory and configuration. Choose between:",
|
||||
"",
|
||||
" * AUTOMATIC configuration: install reasonable defaults and a minimal set of starter models.",
|
||||
" * MANUAL configuration: manually inspect and adjust configuration options and pick from a larger set of starter models.",
|
||||
"",
|
||||
"Later you can fine tune your configuration by selecting option [6] 'Change InvokeAI startup options' from the invoke.bat/invoke.sh launcher script.",
|
||||
]
|
||||
),
|
||||
),
|
||||
box=box.MINIMAL,
|
||||
padding=(1, 1),
|
||||
)
|
||||
)
|
||||
choice = (
|
||||
prompt(
|
||||
HTML("Choose <b><a></b>utomatic or <b><m></b>anual configuration [a/m] (a): "),
|
||||
validator=Validator.from_callable(
|
||||
lambda n: n == "" or n.startswith(("a", "A", "m", "M")), error_message="Please select 'a' or 'm'"
|
||||
),
|
||||
)
|
||||
or "a"
|
||||
)
|
||||
return choice.lower().startswith("a")
|
||||
|
||||
|
||||
def confirm_install(dest: Path) -> bool:
|
||||
if dest.exists():
|
||||
print(f":stop_sign: Directory {dest} already exists!")
|
||||
@ -318,6 +351,34 @@ def windows_long_paths_registry() -> None:
|
||||
)
|
||||
|
||||
|
||||
def introduction() -> None:
|
||||
"""
|
||||
Display a banner when starting configuration of the InvokeAI application
|
||||
"""
|
||||
|
||||
console.rule()
|
||||
|
||||
console.print(
|
||||
Panel(
|
||||
title=":art: Configuring InvokeAI :art:",
|
||||
renderable=Group(
|
||||
"",
|
||||
"[b]This script will:",
|
||||
"",
|
||||
"1. Configure the InvokeAI application directory",
|
||||
"2. Help download the Stable Diffusion weight files",
|
||||
" and other large models that are needed for text to image generation",
|
||||
"3. Create initial configuration files.",
|
||||
"",
|
||||
"[i]At any point you may interrupt this program and resume later.",
|
||||
"",
|
||||
"[b]For the best user experience, please enlarge or maximize this window",
|
||||
),
|
||||
)
|
||||
)
|
||||
console.line(2)
|
||||
|
||||
|
||||
def _platform_specific_help() -> Text | None:
|
||||
if OS == "Darwin":
|
||||
text = Text.from_markup(
|
||||
|
@ -2,12 +2,12 @@
|
||||
|
||||
set -e
|
||||
|
||||
BCYAN="\033[1;36m"
|
||||
BYELLOW="\033[1;33m"
|
||||
BGREEN="\033[1;32m"
|
||||
BRED="\033[1;31m"
|
||||
RED="\033[31m"
|
||||
RESET="\033[0m"
|
||||
BCYAN="\e[1;36m"
|
||||
BYELLOW="\e[1;33m"
|
||||
BGREEN="\e[1;32m"
|
||||
BRED="\e[1;31m"
|
||||
RED="\e[31m"
|
||||
RESET="\e[0m"
|
||||
|
||||
function does_tag_exist {
|
||||
git rev-parse --quiet --verify "refs/tags/$1" >/dev/null
|
||||
@ -23,40 +23,49 @@ function git_show {
|
||||
|
||||
VERSION=$(
|
||||
cd ..
|
||||
python3 -c "from invokeai.version import __version__ as version; print(version)"
|
||||
python -c "from invokeai.version import __version__ as version; print(version)"
|
||||
)
|
||||
PATCH=""
|
||||
MAJOR_VERSION=$(echo $VERSION | sed 's/\..*$//')
|
||||
VERSION="v${VERSION}${PATCH}"
|
||||
LATEST_TAG="v${MAJOR_VERSION}-latest"
|
||||
|
||||
if does_tag_exist $VERSION; then
|
||||
echo -e "${BCYAN}${VERSION}${RESET} already exists:"
|
||||
git_show_ref tags/$VERSION
|
||||
echo
|
||||
fi
|
||||
if does_tag_exist $LATEST_TAG; then
|
||||
echo -e "${BCYAN}${LATEST_TAG}${RESET} already exists:"
|
||||
git_show_ref tags/$LATEST_TAG
|
||||
echo
|
||||
fi
|
||||
|
||||
echo -e "${BGREEN}HEAD${RESET}:"
|
||||
git_show
|
||||
echo
|
||||
|
||||
echo -e "${BGREEN}git remote -v${RESET}:"
|
||||
git remote -v
|
||||
echo
|
||||
|
||||
echo -e -n "Create tags ${BCYAN}${VERSION}${RESET} @ ${BGREEN}HEAD${RESET}, ${RED}deleting existing tags on origin remote${RESET}? "
|
||||
echo -e -n "Create tags ${BCYAN}${VERSION}${RESET} and ${BCYAN}${LATEST_TAG}${RESET} @ ${BGREEN}HEAD${RESET}, ${RED}deleting existing tags on remote${RESET}? "
|
||||
read -e -p 'y/n [n]: ' input
|
||||
RESPONSE=${input:='n'}
|
||||
if [ "$RESPONSE" == 'y' ]; then
|
||||
echo
|
||||
echo -e "Deleting ${BCYAN}${VERSION}${RESET} tag on origin remote..."
|
||||
git push origin :refs/tags/$VERSION
|
||||
echo -e "Deleting ${BCYAN}${VERSION}${RESET} tag on remote..."
|
||||
git push --delete origin $VERSION
|
||||
|
||||
echo -e "Tagging ${BGREEN}HEAD${RESET} with ${BCYAN}${VERSION}${RESET} on locally..."
|
||||
echo -e "Tagging ${BGREEN}HEAD${RESET} with ${BCYAN}${VERSION}${RESET} locally..."
|
||||
if ! git tag -fa $VERSION; then
|
||||
echo "Existing/invalid tag"
|
||||
exit -1
|
||||
fi
|
||||
|
||||
echo -e "Pushing updated tags to origin remote..."
|
||||
echo -e "Deleting ${BCYAN}${LATEST_TAG}${RESET} tag on remote..."
|
||||
git push --delete origin $LATEST_TAG
|
||||
|
||||
echo -e "Tagging ${BGREEN}HEAD${RESET} with ${BCYAN}${LATEST_TAG}${RESET} locally..."
|
||||
git tag -fa $LATEST_TAG
|
||||
|
||||
echo -e "Pushing updated tags to remote..."
|
||||
git push origin --tags
|
||||
fi
|
||||
exit 0
|
||||
|
@ -9,10 +9,15 @@ set INVOKEAI_ROOT=.
|
||||
:start
|
||||
echo Desired action:
|
||||
echo 1. Generate images with the browser-based interface
|
||||
echo 2. Open the developer console
|
||||
echo 3. Update InvokeAI (DEPRECATED - please use the installer)
|
||||
echo 4. Run the InvokeAI image database maintenance script
|
||||
echo 5. Command-line help
|
||||
echo 2. Run textual inversion training
|
||||
echo 3. Merge models (diffusers type only)
|
||||
echo 4. Download and install models
|
||||
echo 5. Change InvokeAI startup options
|
||||
echo 6. Re-run the configure script to fix a broken install or to complete a major upgrade
|
||||
echo 7. Open the developer console
|
||||
echo 8. Update InvokeAI (DEPRECATED - please use the installer)
|
||||
echo 9. Run the InvokeAI image database maintenance script
|
||||
echo 10. Command-line help
|
||||
echo Q - Quit
|
||||
set /P choice="Please enter 1-10, Q: [1] "
|
||||
if not defined choice set choice=1
|
||||
@ -20,6 +25,21 @@ IF /I "%choice%" == "1" (
|
||||
echo Starting the InvokeAI browser-based UI..
|
||||
python .venv\Scripts\invokeai-web.exe %*
|
||||
) ELSE IF /I "%choice%" == "2" (
|
||||
echo Starting textual inversion training..
|
||||
python .venv\Scripts\invokeai-ti.exe --gui
|
||||
) ELSE IF /I "%choice%" == "3" (
|
||||
echo Starting model merging script..
|
||||
python .venv\Scripts\invokeai-merge.exe --gui
|
||||
) ELSE IF /I "%choice%" == "4" (
|
||||
echo Running invokeai-model-install...
|
||||
python .venv\Scripts\invokeai-model-install.exe
|
||||
) ELSE IF /I "%choice%" == "5" (
|
||||
echo Running invokeai-configure...
|
||||
python .venv\Scripts\invokeai-configure.exe --skip-sd-weight --skip-support-models
|
||||
) ELSE IF /I "%choice%" == "6" (
|
||||
echo Running invokeai-configure...
|
||||
python .venv\Scripts\invokeai-configure.exe --yes --skip-sd-weight
|
||||
) ELSE IF /I "%choice%" == "7" (
|
||||
echo Developer Console
|
||||
echo Python command is:
|
||||
where python
|
||||
@ -31,15 +51,15 @@ IF /I "%choice%" == "1" (
|
||||
echo *************************
|
||||
echo *** Type `exit` to quit this shell and deactivate the Python virtual environment ***
|
||||
call cmd /k
|
||||
) ELSE IF /I "%choice%" == "3" (
|
||||
) ELSE IF /I "%choice%" == "8" (
|
||||
echo UPDATING FROM WITHIN THE APP IS BEING DEPRECATED.
|
||||
echo Please download the installer from https://github.com/invoke-ai/InvokeAI/releases/latest and run it to update your installation.
|
||||
timeout 4
|
||||
python -m invokeai.frontend.install.invokeai_update
|
||||
) ELSE IF /I "%choice%" == "4" (
|
||||
) ELSE IF /I "%choice%" == "9" (
|
||||
echo Running the db maintenance script...
|
||||
python .venv\Scripts\invokeai-db-maintenance.exe
|
||||
) ELSE IF /I "%choice%" == "5" (
|
||||
) ELSE IF /I "%choice%" == "10" (
|
||||
echo Displaying command line help...
|
||||
python .venv\Scripts\invokeai-web.exe --help %*
|
||||
pause
|
||||
|
@ -58,24 +58,49 @@ do_choice() {
|
||||
invokeai-web $PARAMS
|
||||
;;
|
||||
2)
|
||||
clear
|
||||
printf "Textual inversion training\n"
|
||||
invokeai-ti --gui $PARAMS
|
||||
;;
|
||||
3)
|
||||
clear
|
||||
printf "Merge models (diffusers type only)\n"
|
||||
invokeai-merge --gui $PARAMS
|
||||
;;
|
||||
4)
|
||||
clear
|
||||
printf "Download and install models\n"
|
||||
invokeai-model-install --root ${INVOKEAI_ROOT}
|
||||
;;
|
||||
5)
|
||||
clear
|
||||
printf "Change InvokeAI startup options\n"
|
||||
invokeai-configure --root ${INVOKEAI_ROOT} --skip-sd-weights --skip-support-models
|
||||
;;
|
||||
6)
|
||||
clear
|
||||
printf "Re-run the configure script to fix a broken install or to complete a major upgrade\n"
|
||||
invokeai-configure --root ${INVOKEAI_ROOT} --yes --default_only --skip-sd-weights
|
||||
;;
|
||||
7)
|
||||
clear
|
||||
printf "Open the developer console\n"
|
||||
file_name=$(basename "${BASH_SOURCE[0]}")
|
||||
bash --init-file "$file_name"
|
||||
;;
|
||||
3)
|
||||
8)
|
||||
clear
|
||||
printf "UPDATING FROM WITHIN THE APP IS BEING DEPRECATED\n"
|
||||
printf "Please download the installer from https://github.com/invoke-ai/InvokeAI/releases/latest and run it to update your installation.\n"
|
||||
sleep 4
|
||||
python -m invokeai.frontend.install.invokeai_update
|
||||
;;
|
||||
4)
|
||||
9)
|
||||
clear
|
||||
printf "Running the db maintenance script\n"
|
||||
invokeai-db-maintenance --root ${INVOKEAI_ROOT}
|
||||
;;
|
||||
5)
|
||||
10)
|
||||
clear
|
||||
printf "Command-line help\n"
|
||||
invokeai-web --help
|
||||
@ -93,10 +118,15 @@ do_choice() {
|
||||
do_dialog() {
|
||||
options=(
|
||||
1 "Generate images with a browser-based interface"
|
||||
2 "Open the developer console"
|
||||
3 "Update InvokeAI (DEPRECATED - please use the installer)"
|
||||
4 "Run the InvokeAI image database maintenance script"
|
||||
5 "Command-line help"
|
||||
2 "Textual inversion training"
|
||||
3 "Merge models (diffusers type only)"
|
||||
4 "Download and install models"
|
||||
5 "Change InvokeAI startup options"
|
||||
6 "Re-run the configure script to fix a broken install or to complete a major upgrade"
|
||||
7 "Open the developer console"
|
||||
8 "Update InvokeAI (DEPRECATED - please use the installer)"
|
||||
9 "Run the InvokeAI image database maintenance script"
|
||||
10 "Command-line help"
|
||||
)
|
||||
|
||||
choice=$(dialog --clear \
|
||||
@ -121,10 +151,15 @@ do_line_input() {
|
||||
printf " ** For a more attractive experience, please install the 'dialog' utility using your package manager. **\n\n"
|
||||
printf "What would you like to do?\n"
|
||||
printf "1: Generate images using the browser-based interface\n"
|
||||
printf "2: Open the developer console\n"
|
||||
printf "3: Update InvokeAI\n"
|
||||
printf "4: Run the InvokeAI image database maintenance script\n"
|
||||
printf "5: Command-line help\n"
|
||||
printf "2: Run textual inversion training\n"
|
||||
printf "3: Merge models (diffusers type only)\n"
|
||||
printf "4: Download and install models\n"
|
||||
printf "5: Change InvokeAI startup options\n"
|
||||
printf "6: Re-run the configure script to fix a broken install\n"
|
||||
printf "7: Open the developer console\n"
|
||||
printf "8: Update InvokeAI\n"
|
||||
printf "9: Run the InvokeAI image database maintenance script\n"
|
||||
printf "10: Command-line help\n"
|
||||
printf "Q: Quit\n\n"
|
||||
read -p "Please enter 1-10, Q: [1] " yn
|
||||
choice=${yn:='1'}
|
||||
|
11
invokeai/README
Normal file
11
invokeai/README
Normal file
@ -0,0 +1,11 @@
|
||||
Organization of the source tree:
|
||||
|
||||
app -- Home of nodes invocations and services
|
||||
assets -- Images and other data files used by InvokeAI
|
||||
backend -- Non-user facing libraries, including the rendering
|
||||
core.
|
||||
configs -- Configuration files used at install and run times
|
||||
frontend -- User-facing scripts, including the CLI and the WebUI
|
||||
version -- Current InvokeAI version string, stored
|
||||
in version/invokeai_version.py
|
||||
|
@ -25,8 +25,8 @@ from ..services.invocation_cache.invocation_cache_memory import MemoryInvocation
|
||||
from ..services.invocation_services import InvocationServices
|
||||
from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
|
||||
from ..services.invoker import Invoker
|
||||
from ..services.model_images.model_images_default import ModelImageFileStorageDisk
|
||||
from ..services.model_manager.model_manager_default import ModelManagerService
|
||||
from ..services.model_metadata import ModelMetadataStoreSQL
|
||||
from ..services.model_records import ModelRecordServiceSQL
|
||||
from ..services.names.names_default import SimpleNameService
|
||||
from ..services.session_processor.session_processor_default import DefaultSessionProcessor
|
||||
@ -64,15 +64,14 @@ class ApiDependencies:
|
||||
def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger = logger) -> None:
|
||||
logger.info(f"InvokeAI version {__version__}")
|
||||
logger.info(f"Root directory = {str(config.root_path)}")
|
||||
logger.debug(f"Internet connectivity is {config.internet_available}")
|
||||
|
||||
output_folder = config.outputs_path
|
||||
output_folder = config.output_path
|
||||
if output_folder is None:
|
||||
raise ValueError("Output folder is not set")
|
||||
|
||||
image_files = DiskImageFileStorage(f"{output_folder}/images")
|
||||
|
||||
model_images_folder = config.models_path
|
||||
|
||||
db = init_db(config=config, logger=logger, image_files=image_files)
|
||||
|
||||
configuration = config
|
||||
@ -94,10 +93,10 @@ class ApiDependencies:
|
||||
ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
|
||||
)
|
||||
download_queue_service = DownloadQueueService(event_bus=events)
|
||||
model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images")
|
||||
model_metadata_service = ModelMetadataStoreSQL(db=db)
|
||||
model_manager = ModelManagerService.build_model_manager(
|
||||
app_config=configuration,
|
||||
model_record_service=ModelRecordServiceSQL(db=db),
|
||||
model_record_service=ModelRecordServiceSQL(db=db, metadata_store=model_metadata_service),
|
||||
download_queue=download_queue_service,
|
||||
events=events,
|
||||
)
|
||||
@ -121,7 +120,6 @@ class ApiDependencies:
|
||||
images=images,
|
||||
invocation_cache=invocation_cache,
|
||||
logger=logger,
|
||||
model_images=model_images_service,
|
||||
model_manager=model_manager,
|
||||
download_queue=download_queue_service,
|
||||
names=names,
|
||||
|
@ -12,6 +12,7 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.invocations.upscale import ESRGAN_MODELS
|
||||
from invokeai.app.services.invocation_cache.invocation_cache_common import InvocationCacheStatus
|
||||
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
|
||||
from invokeai.backend.image_util.patchmatch import PatchMatch
|
||||
from invokeai.backend.image_util.safety_checker import SafetyChecker
|
||||
from invokeai.backend.util.logging import logging
|
||||
@ -113,7 +114,9 @@ async def get_config() -> AppConfig:
|
||||
if SafetyChecker.safety_checker_available():
|
||||
nsfw_methods.append("nsfw_checker")
|
||||
|
||||
watermarking_methods = ["invisible_watermark"]
|
||||
watermarking_methods = []
|
||||
if InvisibleWatermark.invisible_watermark_available():
|
||||
watermarking_methods.append("invisible_watermark")
|
||||
|
||||
return AppConfig(
|
||||
infill_methods=infill_methods,
|
||||
|
@ -1,31 +1,27 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein
|
||||
"""FastAPI route for model configuration records."""
|
||||
|
||||
import contextlib
|
||||
import io
|
||||
import pathlib
|
||||
import shutil
|
||||
import traceback
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
from hashlib import sha1
|
||||
from random import randbytes
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
import huggingface_hub
|
||||
from fastapi import Body, Path, Query, Response, UploadFile
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi import Body, Path, Query, Response
|
||||
from fastapi.routing import APIRouter
|
||||
from PIL import Image
|
||||
from pydantic import AnyHttpUrl, BaseModel, ConfigDict, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from starlette.exceptions import HTTPException
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.services.model_install import ModelInstallJob
|
||||
from invokeai.app.services.model_records import (
|
||||
DuplicateModelException,
|
||||
InvalidModelException,
|
||||
ModelRecordOrderBy,
|
||||
ModelSummary,
|
||||
UnknownModelException,
|
||||
)
|
||||
from invokeai.app.services.model_records.model_records_base import DuplicateModelException, ModelRecordChanges
|
||||
from invokeai.app.util.suppress_output import SuppressOutput
|
||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
@ -34,18 +30,14 @@ from invokeai.backend.model_manager.config import (
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
|
||||
from invokeai.backend.model_manager.metadata.metadata_base import ModelMetadataWithFiles, UnknownMetadataException
|
||||
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
from invokeai.backend.model_manager.search import ModelSearch
|
||||
from invokeai.backend.model_manager.starter_models import STARTER_MODELS, StarterModel
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
model_manager_router = APIRouter(prefix="/v2/models", tags=["model_manager"])
|
||||
|
||||
# images are immutable; set a high max-age
|
||||
IMAGE_MAX_AGE = 31536000
|
||||
|
||||
|
||||
class ModelsList(BaseModel):
|
||||
"""Return list of configs."""
|
||||
@ -55,6 +47,15 @@ class ModelsList(BaseModel):
|
||||
model_config = ConfigDict(use_enum_values=True)
|
||||
|
||||
|
||||
class ModelTagSet(BaseModel):
|
||||
"""Return tags for a set of models."""
|
||||
|
||||
key: str
|
||||
name: str
|
||||
author: str
|
||||
tags: Set[str]
|
||||
|
||||
|
||||
##############################################################################
|
||||
# These are example inputs and outputs that are used in places where Swagger
|
||||
# is unable to generate a correct example.
|
||||
@ -65,16 +66,19 @@ example_model_config = {
|
||||
"base": "sd-1",
|
||||
"type": "main",
|
||||
"format": "checkpoint",
|
||||
"config_path": "string",
|
||||
"config": "string",
|
||||
"key": "string",
|
||||
"hash": "string",
|
||||
"original_hash": "string",
|
||||
"current_hash": "string",
|
||||
"description": "string",
|
||||
"source": "string",
|
||||
"converted_at": 0,
|
||||
"last_modified": 0,
|
||||
"vae": "string",
|
||||
"variant": "normal",
|
||||
"prediction_type": "epsilon",
|
||||
"repo_variant": "fp16",
|
||||
"upcast_attention": False,
|
||||
"ztsnr_training": False,
|
||||
}
|
||||
|
||||
example_model_input = {
|
||||
@ -83,12 +87,50 @@ example_model_input = {
|
||||
"base": "sd-1",
|
||||
"type": "main",
|
||||
"format": "checkpoint",
|
||||
"config_path": "configs/stable-diffusion/v1-inference.yaml",
|
||||
"config": "configs/stable-diffusion/v1-inference.yaml",
|
||||
"description": "Model description",
|
||||
"vae": None,
|
||||
"variant": "normal",
|
||||
}
|
||||
|
||||
example_model_metadata = {
|
||||
"name": "ip_adapter_sd_image_encoder",
|
||||
"author": "InvokeAI",
|
||||
"tags": [
|
||||
"transformers",
|
||||
"safetensors",
|
||||
"clip_vision_model",
|
||||
"endpoints_compatible",
|
||||
"region:us",
|
||||
"has_space",
|
||||
"license:apache-2.0",
|
||||
],
|
||||
"files": [
|
||||
{
|
||||
"url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/README.md",
|
||||
"path": "ip_adapter_sd_image_encoder/README.md",
|
||||
"size": 628,
|
||||
"sha256": None,
|
||||
},
|
||||
{
|
||||
"url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/config.json",
|
||||
"path": "ip_adapter_sd_image_encoder/config.json",
|
||||
"size": 560,
|
||||
"sha256": None,
|
||||
},
|
||||
{
|
||||
"url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/model.safetensors",
|
||||
"path": "ip_adapter_sd_image_encoder/model.safetensors",
|
||||
"size": 2528373448,
|
||||
"sha256": "6ca9667da1ca9e0b0f75e46bb030f7e011f44f86cbfb8d5a36590fcd7507b030",
|
||||
},
|
||||
],
|
||||
"type": "huggingface",
|
||||
"id": "InvokeAI/ip_adapter_sd_image_encoder",
|
||||
"tag_dict": {"license": "apache-2.0"},
|
||||
"last_modified": "2023-09-23T17:33:25Z",
|
||||
}
|
||||
|
||||
##############################################################################
|
||||
# ROUTES
|
||||
##############################################################################
|
||||
@ -120,9 +162,6 @@ async def list_model_records(
|
||||
found_models.extend(
|
||||
record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format)
|
||||
)
|
||||
for model in found_models:
|
||||
cover_image = ApiDependencies.invoker.services.model_images.get_url(model.key)
|
||||
model.cover_image = cover_image
|
||||
return ModelsList(models=found_models)
|
||||
|
||||
|
||||
@ -166,23 +205,53 @@ async def get_model_record(
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
try:
|
||||
config: AnyModelConfig = record_store.get_model(key)
|
||||
cover_image = ApiDependencies.invoker.services.model_images.get_url(key)
|
||||
config.cover_image = cover_image
|
||||
return config
|
||||
except UnknownModelException as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
# @model_manager_router.get("/summary", operation_id="list_model_summary")
|
||||
# async def list_model_summary(
|
||||
# page: int = Query(default=0, description="The page to get"),
|
||||
# per_page: int = Query(default=10, description="The number of models per page"),
|
||||
# order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"),
|
||||
# ) -> PaginatedResults[ModelSummary]:
|
||||
# """Gets a page of model summary data."""
|
||||
# record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
# results: PaginatedResults[ModelSummary] = record_store.list_models(page=page, per_page=per_page, order_by=order_by)
|
||||
# return results
|
||||
@model_manager_router.get("/summary", operation_id="list_model_summary")
|
||||
async def list_model_summary(
|
||||
page: int = Query(default=0, description="The page to get"),
|
||||
per_page: int = Query(default=10, description="The number of models per page"),
|
||||
order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"),
|
||||
) -> PaginatedResults[ModelSummary]:
|
||||
"""Gets a page of model summary data."""
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
results: PaginatedResults[ModelSummary] = record_store.list_models(page=page, per_page=per_page, order_by=order_by)
|
||||
return results
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/i/{key}/metadata",
|
||||
operation_id="get_model_metadata",
|
||||
responses={
|
||||
200: {
|
||||
"description": "The model metadata was retrieved successfully",
|
||||
"content": {"application/json": {"example": example_model_metadata}},
|
||||
},
|
||||
400: {"description": "Bad request"},
|
||||
},
|
||||
)
|
||||
async def get_model_metadata(
|
||||
key: str = Path(description="Key of the model repo metadata to fetch."),
|
||||
) -> Optional[AnyModelRepoMetadata]:
|
||||
"""Get a model metadata object."""
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/tags",
|
||||
operation_id="list_tags",
|
||||
)
|
||||
async def list_tags() -> Set[str]:
|
||||
"""Get a unique set of all the model tags."""
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
result: Set[str] = record_store.list_tags()
|
||||
return result
|
||||
|
||||
|
||||
class FoundModel(BaseModel):
|
||||
@ -254,38 +323,17 @@ async def scan_for_models(
|
||||
return scan_results
|
||||
|
||||
|
||||
class HuggingFaceModels(BaseModel):
|
||||
urls: List[AnyHttpUrl] | None = Field(description="URLs for all checkpoint format models in the metadata")
|
||||
is_diffusers: bool = Field(description="Whether the metadata is for a Diffusers format model")
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/hugging_face",
|
||||
operation_id="get_hugging_face_models",
|
||||
responses={
|
||||
200: {"description": "Hugging Face repo scanned successfully"},
|
||||
400: {"description": "Invalid hugging face repo"},
|
||||
},
|
||||
status_code=200,
|
||||
response_model=HuggingFaceModels,
|
||||
"/tags/search",
|
||||
operation_id="search_by_metadata_tags",
|
||||
)
|
||||
async def get_hugging_face_models(
|
||||
hugging_face_repo: str = Query(description="Hugging face repo to search for models", default=None),
|
||||
) -> HuggingFaceModels:
|
||||
try:
|
||||
metadata = HuggingFaceMetadataFetch().from_id(hugging_face_repo)
|
||||
except UnknownMetadataException:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No HuggingFace repository found",
|
||||
)
|
||||
|
||||
assert isinstance(metadata, ModelMetadataWithFiles)
|
||||
|
||||
return HuggingFaceModels(
|
||||
urls=metadata.ckpt_urls,
|
||||
is_diffusers=metadata.is_diffusers,
|
||||
)
|
||||
async def search_by_metadata_tags(
|
||||
tags: Set[str] = Query(default=None, description="Tags to search for"),
|
||||
) -> ModelsList:
|
||||
"""Get a list of models."""
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
results = record_store.search_by_metadata_tag(tags)
|
||||
return ModelsList(models=results)
|
||||
|
||||
|
||||
@model_manager_router.patch(
|
||||
@ -304,13 +352,15 @@ async def get_hugging_face_models(
|
||||
)
|
||||
async def update_model_record(
|
||||
key: Annotated[str, Path(description="Unique key of model")],
|
||||
changes: Annotated[ModelRecordChanges, Body(description="Model config", example=example_model_input)],
|
||||
info: Annotated[
|
||||
AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
|
||||
],
|
||||
) -> AnyModelConfig:
|
||||
"""Update a model's config."""
|
||||
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
try:
|
||||
model_response: AnyModelConfig = record_store.update_model(key, changes=changes)
|
||||
model_response: AnyModelConfig = record_store.update_model(key, config=info)
|
||||
logger.info(f"Updated model: {key}")
|
||||
except UnknownModelException as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
@ -320,85 +370,16 @@ async def update_model_record(
|
||||
return model_response
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/i/{key}/image",
|
||||
operation_id="get_model_image",
|
||||
responses={
|
||||
200: {
|
||||
"description": "The model image was fetched successfully",
|
||||
},
|
||||
400: {"description": "Bad request"},
|
||||
404: {"description": "The model image could not be found"},
|
||||
},
|
||||
status_code=200,
|
||||
)
|
||||
async def get_model_image(
|
||||
key: str = Path(description="The name of model image file to get"),
|
||||
) -> FileResponse:
|
||||
"""Gets an image file that previews the model"""
|
||||
|
||||
try:
|
||||
path = ApiDependencies.invoker.services.model_images.get_path(key)
|
||||
|
||||
response = FileResponse(
|
||||
path,
|
||||
media_type="image/png",
|
||||
filename=key + ".png",
|
||||
content_disposition_type="inline",
|
||||
)
|
||||
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
|
||||
return response
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@model_manager_router.patch(
|
||||
"/i/{key}/image",
|
||||
operation_id="update_model_image",
|
||||
responses={
|
||||
200: {
|
||||
"description": "The model image was updated successfully",
|
||||
},
|
||||
400: {"description": "Bad request"},
|
||||
},
|
||||
status_code=200,
|
||||
)
|
||||
async def update_model_image(
|
||||
key: Annotated[str, Path(description="Unique key of model")],
|
||||
image: UploadFile,
|
||||
) -> None:
|
||||
if not image.content_type or not image.content_type.startswith("image"):
|
||||
raise HTTPException(status_code=415, detail="Not an image")
|
||||
|
||||
contents = await image.read()
|
||||
try:
|
||||
pil_image = Image.open(io.BytesIO(contents))
|
||||
|
||||
except Exception:
|
||||
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=415, detail="Failed to read image")
|
||||
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
model_images = ApiDependencies.invoker.services.model_images
|
||||
try:
|
||||
model_images.save(pil_image, key)
|
||||
logger.info(f"Updated image for model: {key}")
|
||||
except ValueError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
return
|
||||
|
||||
|
||||
@model_manager_router.delete(
|
||||
"/i/{key}",
|
||||
operation_id="delete_model",
|
||||
operation_id="del_model_record",
|
||||
responses={
|
||||
204: {"description": "Model deleted successfully"},
|
||||
404: {"description": "Model not found"},
|
||||
},
|
||||
status_code=204,
|
||||
)
|
||||
async def delete_model(
|
||||
async def del_model_record(
|
||||
key: str = Path(description="Unique key of model to remove from model registry."),
|
||||
) -> Response:
|
||||
"""
|
||||
@ -419,62 +400,42 @@ async def delete_model(
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@model_manager_router.delete(
|
||||
"/i/{key}/image",
|
||||
operation_id="delete_model_image",
|
||||
@model_manager_router.post(
|
||||
"/i/",
|
||||
operation_id="add_model_record",
|
||||
responses={
|
||||
204: {"description": "Model image deleted successfully"},
|
||||
404: {"description": "Model image not found"},
|
||||
201: {
|
||||
"description": "The model added successfully",
|
||||
"content": {"application/json": {"example": example_model_config}},
|
||||
},
|
||||
409: {"description": "There is already a model corresponding to this path or repo_id"},
|
||||
415: {"description": "Unrecognized file/folder format"},
|
||||
},
|
||||
status_code=204,
|
||||
status_code=201,
|
||||
)
|
||||
async def delete_model_image(
|
||||
key: str = Path(description="Unique key of model image to remove from model_images directory."),
|
||||
) -> None:
|
||||
async def add_model_record(
|
||||
config: Annotated[
|
||||
AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
|
||||
],
|
||||
) -> AnyModelConfig:
|
||||
"""Add a model using the configuration information appropriate for its type."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
model_images = ApiDependencies.invoker.services.model_images
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
if config.key == "<NOKEY>":
|
||||
config.key = sha1(randbytes(100)).hexdigest()
|
||||
logger.info(f"Created model {config.key} for {config.name}")
|
||||
try:
|
||||
model_images.delete(key)
|
||||
logger.info(f"Deleted model image: {key}")
|
||||
return
|
||||
except UnknownModelException as e:
|
||||
record_store.add_model(config.key, config)
|
||||
except DuplicateModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
except InvalidModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=415)
|
||||
|
||||
|
||||
# @model_manager_router.post(
|
||||
# "/i/",
|
||||
# operation_id="add_model_record",
|
||||
# responses={
|
||||
# 201: {
|
||||
# "description": "The model added successfully",
|
||||
# "content": {"application/json": {"example": example_model_config}},
|
||||
# },
|
||||
# 409: {"description": "There is already a model corresponding to this path or repo_id"},
|
||||
# 415: {"description": "Unrecognized file/folder format"},
|
||||
# },
|
||||
# status_code=201,
|
||||
# )
|
||||
# async def add_model_record(
|
||||
# config: Annotated[
|
||||
# AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
|
||||
# ],
|
||||
# ) -> AnyModelConfig:
|
||||
# """Add a model using the configuration information appropriate for its type."""
|
||||
# logger = ApiDependencies.invoker.services.logger
|
||||
# record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
# try:
|
||||
# record_store.add_model(config)
|
||||
# except DuplicateModelException as e:
|
||||
# logger.error(str(e))
|
||||
# raise HTTPException(status_code=409, detail=str(e))
|
||||
# except InvalidModelException as e:
|
||||
# logger.error(str(e))
|
||||
# raise HTTPException(status_code=415)
|
||||
|
||||
# # now fetch it out
|
||||
# result: AnyModelConfig = record_store.get_model(config.key)
|
||||
# return result
|
||||
# now fetch it out
|
||||
result: AnyModelConfig = record_store.get_model(config.key)
|
||||
return result
|
||||
|
||||
|
||||
@model_manager_router.post(
|
||||
@ -490,7 +451,6 @@ async def delete_model_image(
|
||||
)
|
||||
async def install_model(
|
||||
source: str = Query(description="Model source to install, can be a local path, repo_id, or remote URL"),
|
||||
inplace: Optional[bool] = Query(description="Whether or not to install a local model in place", default=False),
|
||||
# TODO(MM2): Can we type this?
|
||||
config: Optional[Dict[str, Any]] = Body(
|
||||
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
|
||||
@ -533,7 +493,6 @@ async def install_model(
|
||||
source=source,
|
||||
config=config,
|
||||
access_token=access_token,
|
||||
inplace=bool(inplace),
|
||||
)
|
||||
logger.info(f"Started installation of {source}")
|
||||
except UnknownModelException as e:
|
||||
@ -549,10 +508,10 @@ async def install_model(
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/install",
|
||||
operation_id="list_model_installs",
|
||||
"/import",
|
||||
operation_id="list_model_install_jobs",
|
||||
)
|
||||
async def list_model_installs() -> List[ModelInstallJob]:
|
||||
async def list_model_install_jobs() -> List[ModelInstallJob]:
|
||||
"""Return the list of model install jobs.
|
||||
|
||||
Install jobs have a numeric `id`, a `status`, and other fields that provide information on
|
||||
@ -566,8 +525,9 @@ async def list_model_installs() -> List[ModelInstallJob]:
|
||||
* "cancelled" -- Job was cancelled before completion.
|
||||
|
||||
Once completed, information about the model such as its size, base
|
||||
model and type can be retrieved from the `config_out` field. For multi-file models such as diffusers,
|
||||
information on individual files can be retrieved from `download_parts`.
|
||||
model, type, and metadata can be retrieved from the `config_out`
|
||||
field. For multi-file models such as diffusers, information on individual files
|
||||
can be retrieved from `download_parts`.
|
||||
|
||||
See the example and schema below for more information.
|
||||
"""
|
||||
@ -576,7 +536,7 @@ async def list_model_installs() -> List[ModelInstallJob]:
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/install/{id}",
|
||||
"/import/{id}",
|
||||
operation_id="get_model_install_job",
|
||||
responses={
|
||||
200: {"description": "Success"},
|
||||
@ -596,7 +556,7 @@ async def get_model_install_job(id: int = Path(description="Model install id"))
|
||||
|
||||
|
||||
@model_manager_router.delete(
|
||||
"/install/{id}",
|
||||
"/import/{id}",
|
||||
operation_id="cancel_model_install_job",
|
||||
responses={
|
||||
201: {"description": "The job was cancelled successfully"},
|
||||
@ -614,8 +574,8 @@ async def cancel_model_install_job(id: int = Path(description="Model install job
|
||||
installer.cancel_job(job)
|
||||
|
||||
|
||||
@model_manager_router.delete(
|
||||
"/install",
|
||||
@model_manager_router.patch(
|
||||
"/import",
|
||||
operation_id="prune_model_install_jobs",
|
||||
responses={
|
||||
204: {"description": "All completed and errored jobs have been pruned"},
|
||||
@ -685,7 +645,7 @@ async def convert_model(
|
||||
raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.")
|
||||
|
||||
# loading the model will convert it into a cached diffusers file
|
||||
model_manager.load.load_model(model_config, submodel_type=SubModelType.Scheduler)
|
||||
model_manager.load_model_by_config(model_config, submodel_type=SubModelType.Scheduler)
|
||||
|
||||
# Get the path of the converted model from the loader
|
||||
cache_path = loader.convert_cache.cache_path(key)
|
||||
@ -694,8 +654,7 @@ async def convert_model(
|
||||
# temporarily rename the original safetensors file so that there is no naming conflict
|
||||
original_name = model_config.name
|
||||
model_config.name = f"{original_name}.DELETE"
|
||||
changes = ModelRecordChanges(name=model_config.name)
|
||||
store.update_model(key, changes=changes)
|
||||
store.update_model(key, config=model_config)
|
||||
|
||||
# install the diffusers
|
||||
try:
|
||||
@ -704,7 +663,7 @@ async def convert_model(
|
||||
config={
|
||||
"name": original_name,
|
||||
"description": model_config.description,
|
||||
"hash": model_config.hash,
|
||||
"original_hash": model_config.original_hash,
|
||||
"source": model_config.source,
|
||||
},
|
||||
)
|
||||
@ -712,6 +671,10 @@ async def convert_model(
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
|
||||
# get the original metadata
|
||||
if orig_metadata := store.get_metadata(key):
|
||||
store.metadata_store.add_metadata(new_key, orig_metadata)
|
||||
|
||||
# delete the original safetensors file
|
||||
installer.delete(key)
|
||||
|
||||
@ -723,132 +686,66 @@ async def convert_model(
|
||||
return new_config
|
||||
|
||||
|
||||
# @model_manager_router.put(
|
||||
# "/merge",
|
||||
# operation_id="merge",
|
||||
# responses={
|
||||
# 200: {
|
||||
# "description": "Model converted successfully",
|
||||
# "content": {"application/json": {"example": example_model_config}},
|
||||
# },
|
||||
# 400: {"description": "Bad request"},
|
||||
# 404: {"description": "Model not found"},
|
||||
# 409: {"description": "There is already a model registered at this location"},
|
||||
# },
|
||||
# )
|
||||
# async def merge(
|
||||
# keys: List[str] = Body(description="Keys for two to three models to merge", min_length=2, max_length=3),
|
||||
# merged_model_name: Optional[str] = Body(description="Name of destination model", default=None),
|
||||
# alpha: float = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
|
||||
# force: bool = Body(
|
||||
# description="Force merging of models created with different versions of diffusers",
|
||||
# default=False,
|
||||
# ),
|
||||
# interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method", default=None),
|
||||
# merge_dest_directory: Optional[str] = Body(
|
||||
# description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
|
||||
# default=None,
|
||||
# ),
|
||||
# ) -> AnyModelConfig:
|
||||
# """
|
||||
# Merge diffusers models. The process is controlled by a set parameters provided in the body of the request.
|
||||
# ```
|
||||
# Argument Description [default]
|
||||
# -------- ----------------------
|
||||
# keys List of 2-3 model keys to merge together. All models must use the same base type.
|
||||
# merged_model_name Name for the merged model [Concat model names]
|
||||
# alpha Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5]
|
||||
# force If true, force the merge even if the models were generated by different versions of the diffusers library [False]
|
||||
# interp Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum]
|
||||
# merge_dest_directory Specify a directory to store the merged model in [models directory]
|
||||
# ```
|
||||
# """
|
||||
# logger = ApiDependencies.invoker.services.logger
|
||||
# try:
|
||||
# logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
|
||||
# dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
|
||||
# installer = ApiDependencies.invoker.services.model_manager.install
|
||||
# merger = ModelMerger(installer)
|
||||
# model_names = [installer.record_store.get_model(x).name for x in keys]
|
||||
# response = merger.merge_diffusion_models_and_save(
|
||||
# model_keys=keys,
|
||||
# merged_model_name=merged_model_name or "+".join(model_names),
|
||||
# alpha=alpha,
|
||||
# interp=interp,
|
||||
# force=force,
|
||||
# merge_dest_directory=dest,
|
||||
# )
|
||||
# except UnknownModelException:
|
||||
# raise HTTPException(
|
||||
# status_code=404,
|
||||
# detail=f"One or more of the models '{keys}' not found",
|
||||
# )
|
||||
# except ValueError as e:
|
||||
# raise HTTPException(status_code=400, detail=str(e))
|
||||
# return response
|
||||
|
||||
|
||||
@model_manager_router.get("/starter_models", operation_id="get_starter_models", response_model=list[StarterModel])
|
||||
async def get_starter_models() -> list[StarterModel]:
|
||||
installed_models = ApiDependencies.invoker.services.model_manager.store.search_by_attr()
|
||||
installed_model_sources = {m.source for m in installed_models}
|
||||
starter_models = deepcopy(STARTER_MODELS)
|
||||
for model in starter_models:
|
||||
if model.source in installed_model_sources:
|
||||
model.is_installed = True
|
||||
# Remove already-installed dependencies
|
||||
missing_deps: list[str] = []
|
||||
for dep in model.dependencies or []:
|
||||
if dep not in installed_model_sources:
|
||||
missing_deps.append(dep)
|
||||
model.dependencies = missing_deps
|
||||
|
||||
return starter_models
|
||||
|
||||
|
||||
class HFTokenStatus(str, Enum):
|
||||
VALID = "valid"
|
||||
INVALID = "invalid"
|
||||
UNKNOWN = "unknown"
|
||||
|
||||
|
||||
class HFTokenHelper:
|
||||
@classmethod
|
||||
def get_status(cls) -> HFTokenStatus:
|
||||
try:
|
||||
if huggingface_hub.get_token_permission(huggingface_hub.get_token()):
|
||||
# Valid token!
|
||||
return HFTokenStatus.VALID
|
||||
# No token set
|
||||
return HFTokenStatus.INVALID
|
||||
except Exception:
|
||||
return HFTokenStatus.UNKNOWN
|
||||
|
||||
@classmethod
|
||||
def set_token(cls, token: str) -> HFTokenStatus:
|
||||
with SuppressOutput(), contextlib.suppress(Exception):
|
||||
huggingface_hub.login(token=token, add_to_git_credential=False)
|
||||
return cls.get_status()
|
||||
|
||||
|
||||
@model_manager_router.get("/hf_login", operation_id="get_hf_login_status", response_model=HFTokenStatus)
|
||||
async def get_hf_login_status() -> HFTokenStatus:
|
||||
token_status = HFTokenHelper.get_status()
|
||||
|
||||
if token_status is HFTokenStatus.UNKNOWN:
|
||||
ApiDependencies.invoker.services.logger.warning("Unable to verify HF token")
|
||||
|
||||
return token_status
|
||||
|
||||
|
||||
@model_manager_router.post("/hf_login", operation_id="do_hf_login", response_model=HFTokenStatus)
|
||||
async def do_hf_login(
|
||||
token: str = Body(description="Hugging Face token to use for login", embed=True),
|
||||
) -> HFTokenStatus:
|
||||
HFTokenHelper.set_token(token)
|
||||
token_status = HFTokenHelper.get_status()
|
||||
|
||||
if token_status is HFTokenStatus.UNKNOWN:
|
||||
ApiDependencies.invoker.services.logger.warning("Unable to verify HF token")
|
||||
|
||||
return token_status
|
||||
@model_manager_router.put(
|
||||
"/merge",
|
||||
operation_id="merge",
|
||||
responses={
|
||||
200: {
|
||||
"description": "Model converted successfully",
|
||||
"content": {"application/json": {"example": example_model_config}},
|
||||
},
|
||||
400: {"description": "Bad request"},
|
||||
404: {"description": "Model not found"},
|
||||
409: {"description": "There is already a model registered at this location"},
|
||||
},
|
||||
)
|
||||
async def merge(
|
||||
keys: List[str] = Body(description="Keys for two to three models to merge", min_length=2, max_length=3),
|
||||
merged_model_name: Optional[str] = Body(description="Name of destination model", default=None),
|
||||
alpha: float = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
|
||||
force: bool = Body(
|
||||
description="Force merging of models created with different versions of diffusers",
|
||||
default=False,
|
||||
),
|
||||
interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method", default=None),
|
||||
merge_dest_directory: Optional[str] = Body(
|
||||
description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
|
||||
default=None,
|
||||
),
|
||||
) -> AnyModelConfig:
|
||||
"""
|
||||
Merge diffusers models. The process is controlled by a set parameters provided in the body of the request.
|
||||
```
|
||||
Argument Description [default]
|
||||
-------- ----------------------
|
||||
keys List of 2-3 model keys to merge together. All models must use the same base type.
|
||||
merged_model_name Name for the merged model [Concat model names]
|
||||
alpha Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5]
|
||||
force If true, force the merge even if the models were generated by different versions of the diffusers library [False]
|
||||
interp Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum]
|
||||
merge_dest_directory Specify a directory to store the merged model in [models directory]
|
||||
```
|
||||
"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
try:
|
||||
logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
|
||||
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
|
||||
installer = ApiDependencies.invoker.services.model_manager.install
|
||||
merger = ModelMerger(installer)
|
||||
model_names = [installer.record_store.get_model(x).name for x in keys]
|
||||
response = merger.merge_diffusion_models_and_save(
|
||||
model_keys=keys,
|
||||
merged_model_name=merged_model_name or "+".join(model_names),
|
||||
alpha=alpha,
|
||||
interp=interp,
|
||||
force=force,
|
||||
merge_dest_directory=dest,
|
||||
)
|
||||
except UnknownModelException:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"One or more of the models '{keys}' not found",
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
return response
|
||||
|
@ -1,59 +1,71 @@
|
||||
import asyncio
|
||||
import mimetypes
|
||||
import socket
|
||||
# parse_args() must be called before any other imports. if it is not called first, consumers of the config
|
||||
# which are imported/used before parse_args() is called will get the default config values instead of the
|
||||
# values from the command line or config file.
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
from inspect import signature
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.middleware.gzip import GZipMiddleware
|
||||
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
from fastapi.responses import HTMLResponse
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.middleware import EventHandlerASGIMiddleware
|
||||
from pydantic.json_schema import models_json_schema
|
||||
from torch.backends.mps import is_available as is_mps_available
|
||||
|
||||
# for PyCharm:
|
||||
# noinspection PyUnresolvedReferences
|
||||
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
||||
import invokeai.frontend.web as web_dir
|
||||
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
||||
from invokeai.version.invokeai_version import __version__
|
||||
|
||||
from ..backend.util.logging import InvokeAILogger
|
||||
from .api.dependencies import ApiDependencies
|
||||
from .api.routers import (
|
||||
app_info,
|
||||
board_images,
|
||||
boards,
|
||||
download_queue,
|
||||
images,
|
||||
model_manager,
|
||||
session_queue,
|
||||
utilities,
|
||||
workflows,
|
||||
)
|
||||
from .api.sockets import SocketIO
|
||||
from .invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
UIConfigBase,
|
||||
)
|
||||
from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
|
||||
from .services.config import InvokeAIAppConfig
|
||||
|
||||
app_config = get_config()
|
||||
|
||||
|
||||
if is_mps_available():
|
||||
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
|
||||
app_config = InvokeAIAppConfig.get_config()
|
||||
app_config.parse_args()
|
||||
if app_config.version:
|
||||
print(f"InvokeAI version {__version__}")
|
||||
sys.exit(0)
|
||||
|
||||
if True: # hack to make flake8 happy with imports coming after setting up the config
|
||||
import asyncio
|
||||
import mimetypes
|
||||
import socket
|
||||
from inspect import signature
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.middleware.gzip import GZipMiddleware
|
||||
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
from fastapi.responses import HTMLResponse
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.middleware import EventHandlerASGIMiddleware
|
||||
from pydantic.json_schema import models_json_schema
|
||||
from torch.backends.mps import is_available as is_mps_available
|
||||
|
||||
# for PyCharm:
|
||||
# noinspection PyUnresolvedReferences
|
||||
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
||||
import invokeai.frontend.web as web_dir
|
||||
|
||||
from ..backend.util.logging import InvokeAILogger
|
||||
from .api.dependencies import ApiDependencies
|
||||
from .api.routers import (
|
||||
app_info,
|
||||
board_images,
|
||||
boards,
|
||||
download_queue,
|
||||
images,
|
||||
model_manager,
|
||||
session_queue,
|
||||
utilities,
|
||||
workflows,
|
||||
)
|
||||
from .api.sockets import SocketIO
|
||||
from .invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
UIConfigBase,
|
||||
)
|
||||
|
||||
if is_mps_available():
|
||||
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
|
||||
|
||||
|
||||
app_config = InvokeAIAppConfig.get_config()
|
||||
app_config.parse_args()
|
||||
logger = InvokeAILogger.get_logger(config=app_config)
|
||||
# fix for windows mimetypes registry entries being borked
|
||||
# see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352
|
||||
@ -144,19 +156,17 @@ def custom_openapi() -> dict[str, Any]:
|
||||
openapi_schema["components"]["schemas"][schema_key] = output_schema
|
||||
openapi_schema["components"]["schemas"][schema_key]["class"] = "output"
|
||||
|
||||
# Some models don't end up in the schemas as standalone definitions
|
||||
additional_schemas = models_json_schema(
|
||||
# Add Node Editor UI helper schemas
|
||||
ui_config_schemas = models_json_schema(
|
||||
[
|
||||
(UIConfigBase, "serialization"),
|
||||
(InputFieldJSONSchemaExtra, "serialization"),
|
||||
(OutputFieldJSONSchemaExtra, "serialization"),
|
||||
(ModelIdentifierField, "serialization"),
|
||||
(ProgressImage, "serialization"),
|
||||
],
|
||||
ref_template="#/components/schemas/{model}",
|
||||
)
|
||||
for schema_key, schema_json in additional_schemas[1]["$defs"].items():
|
||||
openapi_schema["components"]["schemas"][schema_key] = schema_json
|
||||
for schema_key, ui_config_schema in ui_config_schemas[1]["$defs"].items():
|
||||
openapi_schema["components"]["schemas"][schema_key] = ui_config_schema
|
||||
|
||||
# Add a reference to the output type to additionalProperties of the invoker schema
|
||||
for invoker in all_invocations:
|
||||
@ -233,6 +243,10 @@ def invoke_api() -> None:
|
||||
else:
|
||||
return port
|
||||
|
||||
from invokeai.backend.install.check_root import check_invokeai_root
|
||||
|
||||
check_invokeai_root(app_config) # note, may exit with an exception if root not set up
|
||||
|
||||
if app_config.dev_reload:
|
||||
try:
|
||||
import jurigged
|
||||
|
@ -3,9 +3,9 @@ import sys
|
||||
from importlib.util import module_from_spec, spec_from_file_location
|
||||
from pathlib import Path
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
|
||||
custom_nodes_path = Path(get_config().custom_nodes_path)
|
||||
custom_nodes_path = Path(InvokeAIAppConfig.get_config().custom_nodes_path.resolve())
|
||||
custom_nodes_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
custom_nodes_init_path = str(custom_nodes_path / "__init__.py")
|
||||
|
@ -33,7 +33,7 @@ from invokeai.app.invocations.fields import (
|
||||
FieldKind,
|
||||
Input,
|
||||
)
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.metaenum import MetaEnum
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
@ -191,7 +191,7 @@ class BaseInvocation(ABC, BaseModel):
|
||||
@classmethod
|
||||
def get_invocations(cls) -> Iterable[BaseInvocation]:
|
||||
"""Gets all invocations, respecting the allowlist and denylist."""
|
||||
app_config = get_config()
|
||||
app_config = InvokeAIAppConfig.get_config()
|
||||
allowed_invocations: set[BaseInvocation] = set()
|
||||
for sc in cls._invocation_classes:
|
||||
invocation_type = sc.get_type()
|
||||
|
@ -20,7 +20,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
from invokeai.backend.util.devices import torch_dtype
|
||||
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
|
||||
from .model import CLIPField
|
||||
from .model import ClipField
|
||||
|
||||
# unconditioned: Optional[torch.Tensor]
|
||||
|
||||
@ -36,7 +36,7 @@ from .model import CLIPField
|
||||
title="Prompt",
|
||||
tags=["prompt", "compel"],
|
||||
category="conditioning",
|
||||
version="1.1.1",
|
||||
version="1.0.1",
|
||||
)
|
||||
class CompelInvocation(BaseInvocation):
|
||||
"""Parse prompt using compel package to conditioning."""
|
||||
@ -46,7 +46,7 @@ class CompelInvocation(BaseInvocation):
|
||||
description=FieldDescriptions.compel_prompt,
|
||||
ui_component=UIComponent.Textarea,
|
||||
)
|
||||
clip: CLIPField = InputField(
|
||||
clip: ClipField = InputField(
|
||||
title="CLIP",
|
||||
description=FieldDescriptions.clip,
|
||||
input=Input.Connection,
|
||||
@ -54,16 +54,16 @@ class CompelInvocation(BaseInvocation):
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
tokenizer_info = context.models.load(self.clip.tokenizer)
|
||||
tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump())
|
||||
tokenizer_model = tokenizer_info.model
|
||||
assert isinstance(tokenizer_model, CLIPTokenizer)
|
||||
text_encoder_info = context.models.load(self.clip.text_encoder)
|
||||
text_encoder_info = context.models.load(**self.clip.text_encoder.model_dump())
|
||||
text_encoder_model = text_encoder_info.model
|
||||
assert isinstance(text_encoder_model, CLIPTextModel)
|
||||
|
||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
for lora in self.clip.loras:
|
||||
lora_info = context.models.load(lora.lora)
|
||||
lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
|
||||
assert isinstance(lora_info.model, LoRAModelRaw)
|
||||
yield (lora_info.model, lora.weight)
|
||||
del lora_info
|
||||
@ -127,16 +127,16 @@ class SDXLPromptInvocationBase:
|
||||
def run_clip_compel(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
clip_field: CLIPField,
|
||||
clip_field: ClipField,
|
||||
prompt: str,
|
||||
get_pooled: bool,
|
||||
lora_prefix: str,
|
||||
zero_on_empty: bool,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]:
|
||||
tokenizer_info = context.models.load(clip_field.tokenizer)
|
||||
tokenizer_info = context.models.load(**clip_field.tokenizer.model_dump())
|
||||
tokenizer_model = tokenizer_info.model
|
||||
assert isinstance(tokenizer_model, CLIPTokenizer)
|
||||
text_encoder_info = context.models.load(clip_field.text_encoder)
|
||||
text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump())
|
||||
text_encoder_model = text_encoder_info.model
|
||||
assert isinstance(text_encoder_model, (CLIPTextModel, CLIPTextModelWithProjection))
|
||||
|
||||
@ -163,7 +163,7 @@ class SDXLPromptInvocationBase:
|
||||
|
||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
for lora in clip_field.loras:
|
||||
lora_info = context.models.load(lora.lora)
|
||||
lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
|
||||
lora_model = lora_info.model
|
||||
assert isinstance(lora_model, LoRAModelRaw)
|
||||
yield (lora_model, lora.weight)
|
||||
@ -232,7 +232,7 @@ class SDXLPromptInvocationBase:
|
||||
title="SDXL Prompt",
|
||||
tags=["sdxl", "compel", "prompt"],
|
||||
category="conditioning",
|
||||
version="1.1.1",
|
||||
version="1.0.1",
|
||||
)
|
||||
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
"""Parse prompt using compel package to conditioning."""
|
||||
@ -253,8 +253,8 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
crop_left: int = InputField(default=0, description="")
|
||||
target_width: int = InputField(default=1024, description="")
|
||||
target_height: int = InputField(default=1024, description="")
|
||||
clip: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1")
|
||||
clip2: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2")
|
||||
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1")
|
||||
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2")
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
@ -325,7 +325,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
title="SDXL Refiner Prompt",
|
||||
tags=["sdxl", "compel", "prompt"],
|
||||
category="conditioning",
|
||||
version="1.1.1",
|
||||
version="1.0.1",
|
||||
)
|
||||
class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
"""Parse prompt using compel package to conditioning."""
|
||||
@ -340,7 +340,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
|
||||
crop_top: int = InputField(default=0, description="")
|
||||
crop_left: int = InputField(default=0, description="")
|
||||
aesthetic_score: float = InputField(default=6.0, description=FieldDescriptions.sdxl_aesthetic)
|
||||
clip2: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
|
||||
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
@ -370,10 +370,10 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
|
||||
|
||||
|
||||
@invocation_output("clip_skip_output")
|
||||
class CLIPSkipInvocationOutput(BaseInvocationOutput):
|
||||
"""CLIP skip node output"""
|
||||
class ClipSkipInvocationOutput(BaseInvocationOutput):
|
||||
"""Clip skip node output"""
|
||||
|
||||
clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||
|
||||
|
||||
@invocation(
|
||||
@ -381,17 +381,17 @@ class CLIPSkipInvocationOutput(BaseInvocationOutput):
|
||||
title="CLIP Skip",
|
||||
tags=["clipskip", "clip", "skip"],
|
||||
category="conditioning",
|
||||
version="1.1.0",
|
||||
version="1.0.0",
|
||||
)
|
||||
class CLIPSkipInvocation(BaseInvocation):
|
||||
class ClipSkipInvocation(BaseInvocation):
|
||||
"""Skip layers in clip text_encoder model."""
|
||||
|
||||
clip: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP")
|
||||
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP")
|
||||
skipped_layers: int = InputField(default=0, ge=0, description=FieldDescriptions.skipped_layers)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> CLIPSkipInvocationOutput:
|
||||
def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput:
|
||||
self.clip.skipped_layers += self.skipped_layers
|
||||
return CLIPSkipInvocationOutput(
|
||||
return ClipSkipInvocationOutput(
|
||||
clip=self.clip,
|
||||
)
|
||||
|
||||
|
@ -31,11 +31,9 @@ from invokeai.app.invocations.fields import (
|
||||
Input,
|
||||
InputField,
|
||||
OutputField,
|
||||
UIType,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
@ -53,9 +51,15 @@ CONTROLNET_RESIZE_VALUES = Literal[
|
||||
]
|
||||
|
||||
|
||||
class ControlNetModelField(BaseModel):
|
||||
"""ControlNet model field"""
|
||||
|
||||
key: str = Field(description="Model config record key for the ControlNet model")
|
||||
|
||||
|
||||
class ControlField(BaseModel):
|
||||
image: ImageField = Field(description="The control image")
|
||||
control_model: ModelIdentifierField = Field(description="The ControlNet model to use")
|
||||
control_model: ControlNetModelField = Field(description="The ControlNet model to use")
|
||||
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
|
||||
begin_step_percent: float = Field(
|
||||
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
|
||||
@ -91,9 +95,7 @@ class ControlNetInvocation(BaseInvocation):
|
||||
"""Collects ControlNet info to pass to other nodes"""
|
||||
|
||||
image: ImageField = InputField(description="The control image")
|
||||
control_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.controlnet_model, input=Input.Direct, ui_type=UIType.ControlNetModel
|
||||
)
|
||||
control_model: ControlNetModelField = InputField(description=FieldDescriptions.controlnet_model, input=Input.Direct)
|
||||
control_weight: Union[float, List[float]] = InputField(
|
||||
default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"
|
||||
)
|
||||
@ -171,12 +173,11 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
title="Canny Processor",
|
||||
tags=["controlnet", "canny"],
|
||||
category="controlnet",
|
||||
version="1.3.1",
|
||||
version="1.2.1",
|
||||
)
|
||||
class CannyImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Canny edge detection for ControlNet"""
|
||||
|
||||
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||
low_threshold: int = InputField(
|
||||
default=100, ge=0, le=255, description="The low threshold of the Canny pixel gradient (0-255)"
|
||||
)
|
||||
@ -190,12 +191,7 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation):
|
||||
|
||||
def run_processor(self, image):
|
||||
canny_processor = CannyDetector()
|
||||
processed_image = canny_processor(
|
||||
image,
|
||||
self.low_threshold,
|
||||
self.high_threshold,
|
||||
image_resolution=self.image_resolution,
|
||||
)
|
||||
processed_image = canny_processor(image, self.low_threshold, self.high_threshold)
|
||||
return processed_image
|
||||
|
||||
|
||||
@ -204,7 +200,7 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation):
|
||||
title="HED (softedge) Processor",
|
||||
tags=["controlnet", "hed", "softedge"],
|
||||
category="controlnet",
|
||||
version="1.2.2",
|
||||
version="1.2.1",
|
||||
)
|
||||
class HedImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies HED edge detection to image"""
|
||||
@ -233,7 +229,7 @@ class HedImageProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Lineart Processor",
|
||||
tags=["controlnet", "lineart"],
|
||||
category="controlnet",
|
||||
version="1.2.2",
|
||||
version="1.2.1",
|
||||
)
|
||||
class LineartImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies line art processing to image"""
|
||||
@ -255,7 +251,7 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Lineart Anime Processor",
|
||||
tags=["controlnet", "lineart", "anime"],
|
||||
category="controlnet",
|
||||
version="1.2.2",
|
||||
version="1.2.1",
|
||||
)
|
||||
class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies line art anime processing to image"""
|
||||
@ -278,14 +274,13 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Midas Depth Processor",
|
||||
tags=["controlnet", "midas"],
|
||||
category="controlnet",
|
||||
version="1.2.2",
|
||||
version="1.2.1",
|
||||
)
|
||||
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies Midas depth processing to image"""
|
||||
|
||||
a_mult: float = InputField(default=2.0, ge=0, description="Midas parameter `a_mult` (a = a_mult * PI)")
|
||||
bg_th: float = InputField(default=0.1, ge=0, description="Midas parameter `bg_th`")
|
||||
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||
# depth_and_normal not supported in controlnet_aux v0.0.3
|
||||
# depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode")
|
||||
|
||||
@ -295,7 +290,6 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||
image,
|
||||
a=np.pi * self.a_mult,
|
||||
bg_th=self.bg_th,
|
||||
image_resolution=self.image_resolution,
|
||||
# dept_and_normal not supported in controlnet_aux v0.0.3
|
||||
# depth_and_normal=self.depth_and_normal,
|
||||
)
|
||||
@ -307,7 +301,7 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Normal BAE Processor",
|
||||
tags=["controlnet"],
|
||||
category="controlnet",
|
||||
version="1.2.2",
|
||||
version="1.2.1",
|
||||
)
|
||||
class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies NormalBae processing to image"""
|
||||
@ -324,7 +318,7 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
|
||||
|
||||
@invocation(
|
||||
"mlsd_image_processor", title="MLSD Processor", tags=["controlnet", "mlsd"], category="controlnet", version="1.2.2"
|
||||
"mlsd_image_processor", title="MLSD Processor", tags=["controlnet", "mlsd"], category="controlnet", version="1.2.1"
|
||||
)
|
||||
class MlsdImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies MLSD processing to image"""
|
||||
@ -347,7 +341,7 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation):
|
||||
|
||||
|
||||
@invocation(
|
||||
"pidi_image_processor", title="PIDI Processor", tags=["controlnet", "pidi"], category="controlnet", version="1.2.2"
|
||||
"pidi_image_processor", title="PIDI Processor", tags=["controlnet", "pidi"], category="controlnet", version="1.2.1"
|
||||
)
|
||||
class PidiImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies PIDI processing to image"""
|
||||
@ -374,7 +368,7 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Content Shuffle Processor",
|
||||
tags=["controlnet", "contentshuffle"],
|
||||
category="controlnet",
|
||||
version="1.2.2",
|
||||
version="1.2.1",
|
||||
)
|
||||
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies content shuffle processing to image"""
|
||||
@ -404,7 +398,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Zoe (Depth) Processor",
|
||||
tags=["controlnet", "zoe", "depth"],
|
||||
category="controlnet",
|
||||
version="1.2.2",
|
||||
version="1.2.1",
|
||||
)
|
||||
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies Zoe depth processing to image"""
|
||||
@ -420,20 +414,17 @@ class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Mediapipe Face Processor",
|
||||
tags=["controlnet", "mediapipe", "face"],
|
||||
category="controlnet",
|
||||
version="1.2.2",
|
||||
version="1.2.1",
|
||||
)
|
||||
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies mediapipe face processing to image"""
|
||||
|
||||
max_faces: int = InputField(default=1, ge=1, description="Maximum number of faces to detect")
|
||||
min_confidence: float = InputField(default=0.5, ge=0, le=1, description="Minimum confidence for face detection")
|
||||
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image):
|
||||
mediapipe_face_processor = MediapipeFaceDetector()
|
||||
processed_image = mediapipe_face_processor(
|
||||
image, max_faces=self.max_faces, min_confidence=self.min_confidence, image_resolution=self.image_resolution
|
||||
)
|
||||
processed_image = mediapipe_face_processor(image, max_faces=self.max_faces, min_confidence=self.min_confidence)
|
||||
return processed_image
|
||||
|
||||
|
||||
@ -442,7 +433,7 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Leres (Depth) Processor",
|
||||
tags=["controlnet", "leres", "depth"],
|
||||
category="controlnet",
|
||||
version="1.2.2",
|
||||
version="1.2.1",
|
||||
)
|
||||
class LeresImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies leres processing to image"""
|
||||
@ -471,7 +462,7 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Tile Resample Processor",
|
||||
tags=["controlnet", "tile"],
|
||||
category="controlnet",
|
||||
version="1.2.2",
|
||||
version="1.2.1",
|
||||
)
|
||||
class TileResamplerProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Tile resampler processor"""
|
||||
@ -511,20 +502,18 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Segment Anything Processor",
|
||||
tags=["controlnet", "segmentanything"],
|
||||
category="controlnet",
|
||||
version="1.2.2",
|
||||
version="1.2.1",
|
||||
)
|
||||
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies segment anything processing to image"""
|
||||
|
||||
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image):
|
||||
# segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
|
||||
segment_anything_processor = SamDetectorReproducibleColors.from_pretrained(
|
||||
"ybelkada/segment-anything", subfolder="checkpoints"
|
||||
)
|
||||
np_img = np.array(image, dtype=np.uint8)
|
||||
processed_image = segment_anything_processor(np_img, image_resolution=self.image_resolution)
|
||||
processed_image = segment_anything_processor(np_img)
|
||||
return processed_image
|
||||
|
||||
|
||||
@ -555,7 +544,7 @@ class SamDetectorReproducibleColors(SamDetector):
|
||||
title="Color Map Processor",
|
||||
tags=["controlnet"],
|
||||
category="controlnet",
|
||||
version="1.2.2",
|
||||
version="1.2.1",
|
||||
)
|
||||
class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Generates a color map from the provided image"""
|
||||
@ -587,7 +576,7 @@ DEPTH_ANYTHING_MODEL_SIZES = Literal["large", "base", "small"]
|
||||
title="Depth Anything Processor",
|
||||
tags=["controlnet", "depth", "depth anything"],
|
||||
category="controlnet",
|
||||
version="1.1.1",
|
||||
version="1.0.1",
|
||||
)
|
||||
class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Generates a depth map based on the Depth Anything algorithm"""
|
||||
@ -610,7 +599,7 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
|
||||
title="DW Openpose Image Processor",
|
||||
tags=["controlnet", "dwpose", "openpose"],
|
||||
category="controlnet",
|
||||
version="1.1.0",
|
||||
version="1.0.0",
|
||||
)
|
||||
class DWOpenposeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Generates an openpose pose from an image using DWPose"""
|
||||
|
@ -13,7 +13,7 @@ from .baseinvocation import BaseInvocation, invocation
|
||||
from .fields import InputField, WithBoard, WithMetadata
|
||||
|
||||
|
||||
@invocation("cv_inpaint", title="OpenCV Inpaint", tags=["opencv", "inpaint"], category="inpaint", version="1.3.1")
|
||||
@invocation("cv_inpaint", title="OpenCV Inpaint", tags=["opencv", "inpaint"], category="inpaint", version="1.2.1")
|
||||
class CvInpaintInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Simple inpaint using opencv."""
|
||||
|
||||
|
@ -435,7 +435,7 @@ def get_faces_list(
|
||||
return all_faces
|
||||
|
||||
|
||||
@invocation("face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="image", version="1.2.2")
|
||||
@invocation("face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="image", version="1.2.1")
|
||||
class FaceOffInvocation(BaseInvocation, WithMetadata):
|
||||
"""Bound, extract, and mask a face from an image using MediaPipe detection"""
|
||||
|
||||
@ -514,7 +514,7 @@ class FaceOffInvocation(BaseInvocation, WithMetadata):
|
||||
return output
|
||||
|
||||
|
||||
@invocation("face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="image", version="1.2.2")
|
||||
@invocation("face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="image", version="1.2.1")
|
||||
class FaceMaskInvocation(BaseInvocation, WithMetadata):
|
||||
"""Face mask creation using mediapipe face detection"""
|
||||
|
||||
@ -617,7 +617,7 @@ class FaceMaskInvocation(BaseInvocation, WithMetadata):
|
||||
|
||||
|
||||
@invocation(
|
||||
"face_identifier", title="FaceIdentifier", tags=["image", "face", "identifier"], category="image", version="1.2.2"
|
||||
"face_identifier", title="FaceIdentifier", tags=["image", "face", "identifier"], category="image", version="1.2.1"
|
||||
)
|
||||
class FaceIdentifierInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Outputs an image with detected face IDs printed on each face. For use with other FaceTools."""
|
||||
|
@ -39,15 +39,13 @@ class UIType(str, Enum, metaclass=MetaEnum):
|
||||
"""
|
||||
|
||||
# region Model Field Types
|
||||
MainModel = "MainModelField"
|
||||
SDXLMainModel = "SDXLMainModelField"
|
||||
SDXLRefinerModel = "SDXLRefinerModelField"
|
||||
ONNXModel = "ONNXModelField"
|
||||
VAEModel = "VAEModelField"
|
||||
VaeModel = "VAEModelField"
|
||||
LoRAModel = "LoRAModelField"
|
||||
ControlNetModel = "ControlNetModelField"
|
||||
IPAdapterModel = "IPAdapterModelField"
|
||||
T2IAdapterModel = "T2IAdapterModelField"
|
||||
# endregion
|
||||
|
||||
# region Misc Field Types
|
||||
@ -88,6 +86,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
|
||||
IntegerPolymorphic = "DEPRECATED_IntegerPolymorphic"
|
||||
LatentsPolymorphic = "DEPRECATED_LatentsPolymorphic"
|
||||
StringPolymorphic = "DEPRECATED_StringPolymorphic"
|
||||
MainModel = "DEPRECATED_MainModel"
|
||||
UNet = "DEPRECATED_UNet"
|
||||
Vae = "DEPRECATED_Vae"
|
||||
CLIP = "DEPRECATED_CLIP"
|
||||
@ -229,7 +228,7 @@ class ConditioningField(BaseModel):
|
||||
# endregion
|
||||
|
||||
|
||||
class MetadataField(RootModel[dict[str, Any]]):
|
||||
class MetadataField(RootModel):
|
||||
"""
|
||||
Pydantic model for metadata with custom root of type dict[str, Any].
|
||||
Metadata is stored without a strict schema.
|
||||
|
@ -49,7 +49,7 @@ class ShowImageInvocation(BaseInvocation):
|
||||
title="Blank Image",
|
||||
tags=["image"],
|
||||
category="image",
|
||||
version="1.2.2",
|
||||
version="1.2.1",
|
||||
)
|
||||
class BlankImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Creates a blank image and forwards it to the pipeline"""
|
||||
@ -72,7 +72,7 @@ class BlankImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
title="Crop Image",
|
||||
tags=["image", "crop"],
|
||||
category="image",
|
||||
version="1.2.2",
|
||||
version="1.2.1",
|
||||
)
|
||||
class ImageCropInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Crops an image to a specified box. The box can be outside of the image."""
|
||||
@ -143,7 +143,7 @@ class CenterPadCropInvocation(BaseInvocation):
|
||||
title="Paste Image",
|
||||
tags=["image", "paste"],
|
||||
category="image",
|
||||
version="1.2.2",
|
||||
version="1.2.1",
|
||||
)
|
||||
class ImagePasteInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Pastes an image into another image."""
|
||||
@ -190,7 +190,7 @@ class ImagePasteInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
title="Mask from Alpha",
|
||||
tags=["image", "mask"],
|
||||
category="image",
|
||||
version="1.2.2",
|
||||
version="1.2.1",
|
||||
)
|
||||
class MaskFromAlphaInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Extracts the alpha channel of an image as a mask."""
|
||||
@ -215,7 +215,7 @@ class MaskFromAlphaInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
title="Multiply Images",
|
||||
tags=["image", "multiply"],
|
||||
category="image",
|
||||
version="1.2.2",
|
||||
version="1.2.1",
|
||||
)
|
||||
class ImageMultiplyInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Multiplies two images together using `PIL.ImageChops.multiply()`."""
|
||||
@ -242,7 +242,7 @@ IMAGE_CHANNELS = Literal["A", "R", "G", "B"]
|
||||
title="Extract Image Channel",
|
||||
tags=["image", "channel"],
|
||||
category="image",
|
||||
version="1.2.2",
|
||||
version="1.2.1",
|
||||
)
|
||||
class ImageChannelInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Gets a channel from an image."""
|
||||
@ -265,7 +265,7 @@ class ImageChannelInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
title="Convert Image Mode",
|
||||
tags=["image", "convert"],
|
||||
category="image",
|
||||
version="1.2.2",
|
||||
version="1.2.1",
|
||||
)
|
||||
class ImageConvertInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Converts an image to a different mode."""
|
||||
@ -288,7 +288,7 @@ class ImageConvertInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
title="Blur Image",
|
||||
tags=["image", "blur"],
|
||||
category="image",
|
||||
version="1.2.2",
|
||||
version="1.2.1",
|
||||
)
|
||||
class ImageBlurInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Blurs an image"""
|
||||
@ -316,7 +316,7 @@ class ImageBlurInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
title="Unsharp Mask",
|
||||
tags=["image", "unsharp_mask"],
|
||||
category="image",
|
||||
version="1.2.2",
|
||||
version="1.2.1",
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
class UnsharpMaskInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
@ -385,7 +385,7 @@ PIL_RESAMPLING_MAP = {
|
||||
title="Resize Image",
|
||||
tags=["image", "resize"],
|
||||
category="image",
|
||||
version="1.2.2",
|
||||
version="1.2.1",
|
||||
)
|
||||
class ImageResizeInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Resizes an image to specific dimensions"""
|
||||
@ -415,7 +415,7 @@ class ImageResizeInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
title="Scale Image",
|
||||
tags=["image", "scale"],
|
||||
category="image",
|
||||
version="1.2.2",
|
||||
version="1.2.1",
|
||||
)
|
||||
class ImageScaleInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Scales an image by a factor"""
|
||||
@ -450,7 +450,7 @@ class ImageScaleInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
title="Lerp Image",
|
||||
tags=["image", "lerp"],
|
||||
category="image",
|
||||
version="1.2.2",
|
||||
version="1.2.1",
|
||||
)
|
||||
class ImageLerpInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Linear interpolation of all pixels of an image"""
|
||||
@ -477,7 +477,7 @@ class ImageLerpInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
title="Inverse Lerp Image",
|
||||
tags=["image", "ilerp"],
|
||||
category="image",
|
||||
version="1.2.2",
|
||||
version="1.2.1",
|
||||
)
|
||||
class ImageInverseLerpInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Inverse linear interpolation of all pixels of an image"""
|
||||
@ -504,7 +504,7 @@ class ImageInverseLerpInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
title="Blur NSFW Image",
|
||||
tags=["image", "nsfw"],
|
||||
category="image",
|
||||
version="1.2.2",
|
||||
version="1.2.1",
|
||||
)
|
||||
class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Add blur to NSFW-flagged images"""
|
||||
@ -539,7 +539,7 @@ class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
title="Add Invisible Watermark",
|
||||
tags=["image", "watermark"],
|
||||
category="image",
|
||||
version="1.2.2",
|
||||
version="1.2.1",
|
||||
)
|
||||
class ImageWatermarkInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Add an invisible watermark to an image"""
|
||||
@ -560,7 +560,7 @@ class ImageWatermarkInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
title="Mask Edge",
|
||||
tags=["image", "mask", "inpaint"],
|
||||
category="image",
|
||||
version="1.2.2",
|
||||
version="1.2.1",
|
||||
)
|
||||
class MaskEdgeInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Applies an edge mask to an image"""
|
||||
@ -599,7 +599,7 @@ class MaskEdgeInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
title="Combine Masks",
|
||||
tags=["image", "mask", "multiply"],
|
||||
category="image",
|
||||
version="1.2.2",
|
||||
version="1.2.1",
|
||||
)
|
||||
class MaskCombineInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`."""
|
||||
@ -623,7 +623,7 @@ class MaskCombineInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
title="Color Correct",
|
||||
tags=["image", "color"],
|
||||
category="image",
|
||||
version="1.2.2",
|
||||
version="1.2.1",
|
||||
)
|
||||
class ColorCorrectInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""
|
||||
@ -727,7 +727,7 @@ class ColorCorrectInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
title="Adjust Image Hue",
|
||||
tags=["image", "hue"],
|
||||
category="image",
|
||||
version="1.2.2",
|
||||
version="1.2.1",
|
||||
)
|
||||
class ImageHueAdjustmentInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Adjusts the Hue of an image."""
|
||||
@ -816,7 +816,7 @@ CHANNEL_FORMATS = {
|
||||
"value",
|
||||
],
|
||||
category="image",
|
||||
version="1.2.2",
|
||||
version="1.2.1",
|
||||
)
|
||||
class ImageChannelOffsetInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Add or subtract a value from a specific color channel of an image."""
|
||||
@ -872,7 +872,7 @@ class ImageChannelOffsetInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"value",
|
||||
],
|
||||
category="image",
|
||||
version="1.2.2",
|
||||
version="1.2.1",
|
||||
)
|
||||
class ImageChannelMultiplyInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Scale a specific color channel of an image."""
|
||||
@ -916,7 +916,7 @@ class ImageChannelMultiplyInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
title="Save Image",
|
||||
tags=["primitives", "image"],
|
||||
category="primitives",
|
||||
version="1.2.2",
|
||||
version="1.2.1",
|
||||
use_cache=False,
|
||||
)
|
||||
class SaveImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
@ -9,7 +9,6 @@ from PIL import Image, ImageOps
|
||||
from invokeai.app.invocations.fields import ColorField, ImageField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.download_with_progress import download_with_progress_bar
|
||||
from invokeai.app.util.misc import SEED_MAX
|
||||
from invokeai.backend.image_util.cv2_inpaint import cv2_inpaint
|
||||
from invokeai.backend.image_util.lama import LaMA
|
||||
@ -121,7 +120,7 @@ def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int]
|
||||
return si
|
||||
|
||||
|
||||
@invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2")
|
||||
@invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1")
|
||||
class InfillColorInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Infills transparent areas of an image with a solid color"""
|
||||
|
||||
@ -144,7 +143,7 @@ class InfillColorInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
|
||||
@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.3")
|
||||
@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2")
|
||||
class InfillTileInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Infills transparent areas of an image with tiles of the image"""
|
||||
|
||||
@ -169,7 +168,7 @@ class InfillTileInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
|
||||
@invocation(
|
||||
"infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2"
|
||||
"infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1"
|
||||
)
|
||||
class InfillPatchMatchInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Infills transparent areas of an image using the PatchMatch algorithm"""
|
||||
@ -209,7 +208,7 @@ class InfillPatchMatchInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
|
||||
@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2")
|
||||
@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1")
|
||||
class LaMaInfillInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Infills transparent areas of an image using the LaMa model"""
|
||||
|
||||
@ -218,13 +217,6 @@ class LaMaInfillInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
# Downloads the LaMa model if it doesn't already exist
|
||||
download_with_progress_bar(
|
||||
name="LaMa Inpainting Model",
|
||||
url="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
|
||||
dest_path=context.config.get().models_path / "core/misc/lama/lama.pt",
|
||||
)
|
||||
|
||||
infilled = infill_lama(image.copy())
|
||||
|
||||
image_dto = context.images.save(image=infilled)
|
||||
@ -232,7 +224,7 @@ class LaMaInfillInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
|
||||
@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2")
|
||||
@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1")
|
||||
class CV2InfillInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Infills transparent areas of an image using OpenCV Inpainting"""
|
||||
|
||||
|
@ -10,18 +10,26 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, IPAdapterConfig, ModelType
|
||||
from invokeai.backend.model_manager.config import BaseModelType, ModelType
|
||||
|
||||
|
||||
# LS: Consider moving these two classes into model.py
|
||||
class IPAdapterModelField(BaseModel):
|
||||
key: str = Field(description="Key to the IP-Adapter model")
|
||||
|
||||
|
||||
class CLIPVisionModelField(BaseModel):
|
||||
key: str = Field(description="Key to the CLIP Vision image encoder model")
|
||||
|
||||
|
||||
class IPAdapterField(BaseModel):
|
||||
image: Union[ImageField, List[ImageField]] = Field(description="The IP-Adapter image prompt(s).")
|
||||
ip_adapter_model: ModelIdentifierField = Field(description="The IP-Adapter model to use.")
|
||||
image_encoder_model: ModelIdentifierField = Field(description="The name of the CLIP image encoder model.")
|
||||
ip_adapter_model: IPAdapterModelField = Field(description="The IP-Adapter model to use.")
|
||||
image_encoder_model: CLIPVisionModelField = Field(description="The name of the CLIP image encoder model.")
|
||||
weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
|
||||
begin_step_percent: float = Field(
|
||||
default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
|
||||
@ -48,18 +56,14 @@ class IPAdapterOutput(BaseInvocationOutput):
|
||||
ip_adapter: IPAdapterField = OutputField(description=FieldDescriptions.ip_adapter, title="IP-Adapter")
|
||||
|
||||
|
||||
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.2.2")
|
||||
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.1.2")
|
||||
class IPAdapterInvocation(BaseInvocation):
|
||||
"""Collects IP-Adapter info to pass to other nodes."""
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, List[ImageField]] = InputField(description="The IP-Adapter image prompt(s).")
|
||||
ip_adapter_model: ModelIdentifierField = InputField(
|
||||
description="The IP-Adapter model.",
|
||||
title="IP-Adapter Model",
|
||||
input=Input.Direct,
|
||||
ui_order=-1,
|
||||
ui_type=UIType.IPAdapterModel,
|
||||
ip_adapter_model: IPAdapterModelField = InputField(
|
||||
description="The IP-Adapter model.", title="IP-Adapter Model", input=Input.Direct, ui_order=-1
|
||||
)
|
||||
|
||||
weight: Union[float, List[float]] = InputField(
|
||||
@ -86,35 +90,20 @@ class IPAdapterInvocation(BaseInvocation):
|
||||
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
|
||||
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
|
||||
ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
|
||||
assert isinstance(ip_adapter_info, IPAdapterConfig)
|
||||
image_encoder_model_id = ip_adapter_info.image_encoder_model_id
|
||||
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
|
||||
image_encoder_model = self._get_image_encoder(context, image_encoder_model_name)
|
||||
image_encoder_models = context.models.search_by_attrs(
|
||||
name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision
|
||||
)
|
||||
assert len(image_encoder_models) == 1
|
||||
image_encoder_model = CLIPVisionModelField(key=image_encoder_models[0].key)
|
||||
return IPAdapterOutput(
|
||||
ip_adapter=IPAdapterField(
|
||||
image=self.image,
|
||||
ip_adapter_model=self.ip_adapter_model,
|
||||
image_encoder_model=ModelIdentifierField.from_config(image_encoder_model),
|
||||
image_encoder_model=image_encoder_model,
|
||||
weight=self.weight,
|
||||
begin_step_percent=self.begin_step_percent,
|
||||
end_step_percent=self.end_step_percent,
|
||||
),
|
||||
)
|
||||
|
||||
def _get_image_encoder(self, context: InvocationContext, image_encoder_model_name: str) -> AnyModelConfig:
|
||||
found = False
|
||||
while not found:
|
||||
image_encoder_models = context.models.search_by_attrs(
|
||||
name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision
|
||||
)
|
||||
found = len(image_encoder_models) > 0
|
||||
if not found:
|
||||
context.logger.warning(
|
||||
f"The image encoder required by this IP Adapter ({image_encoder_model_name}) is not installed."
|
||||
)
|
||||
context.logger.warning("Downloading and installing now. This may take a while.")
|
||||
installer = context._services.model_manager.install
|
||||
job = installer.heuristic_import(f"InvokeAI/{image_encoder_model_name}")
|
||||
installer.wait_for_job(job, timeout=600) # wait up to 10 minutes - then raise a TimeoutException
|
||||
assert len(image_encoder_models) == 1
|
||||
return image_encoder_models[0]
|
||||
|
@ -26,7 +26,6 @@ from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||
from PIL import Image, ImageFilter
|
||||
from pydantic import field_validator
|
||||
from torchvision.transforms.functional import resize as tv_resize
|
||||
from transformers import CLIPVisionModelWithProjection
|
||||
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
|
||||
from invokeai.app.invocations.fields import (
|
||||
@ -66,6 +65,7 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
|
||||
T2IAdapterData,
|
||||
image_resized_to_grid_as_tensor,
|
||||
)
|
||||
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
||||
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||
from ...backend.util.devices import choose_precision, choose_torch_device
|
||||
from .baseinvocation import (
|
||||
@ -75,7 +75,7 @@ from .baseinvocation import (
|
||||
invocation_output,
|
||||
)
|
||||
from .controlnet_image_processors import ControlField
|
||||
from .model import ModelIdentifierField, UNetField, VAEField
|
||||
from .model import ModelInfo, UNetField, VaeField
|
||||
|
||||
if choose_torch_device() == torch.device("mps"):
|
||||
from torch import mps
|
||||
@ -113,12 +113,12 @@ class SchedulerInvocation(BaseInvocation):
|
||||
title="Create Denoise Mask",
|
||||
tags=["mask", "denoise"],
|
||||
category="latents",
|
||||
version="1.0.2",
|
||||
version="1.0.1",
|
||||
)
|
||||
class CreateDenoiseMaskInvocation(BaseInvocation):
|
||||
"""Creates mask for denoising model run."""
|
||||
|
||||
vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection, ui_order=0)
|
||||
vae: VaeField = InputField(description=FieldDescriptions.vae, input=Input.Connection, ui_order=0)
|
||||
image: Optional[ImageField] = InputField(default=None, description="Image which will be masked", ui_order=1)
|
||||
mask: ImageField = InputField(description="The mask to use when pasting", ui_order=2)
|
||||
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=3)
|
||||
@ -153,7 +153,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
if image_tensor is not None:
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
vae_info = context.models.load(**self.vae.vae.model_dump())
|
||||
|
||||
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
|
||||
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
|
||||
@ -173,16 +173,6 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation_output("gradient_mask_output")
|
||||
class GradientMaskOutput(BaseInvocationOutput):
|
||||
"""Outputs a denoise mask and an image representing the total gradient of the mask."""
|
||||
|
||||
denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run")
|
||||
expanded_mask_area: ImageField = OutputField(
|
||||
description="Image representing the total gradient area of the mask. For paste-back purposes."
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"create_gradient_mask",
|
||||
title="Create Gradient Mask",
|
||||
@ -203,53 +193,49 @@ class CreateGradientMaskInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> GradientMaskOutput:
|
||||
def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
|
||||
mask_image = context.images.get_pil(self.mask.image_name, mode="L")
|
||||
if self.edge_radius > 0:
|
||||
if self.coherence_mode == "Box Blur":
|
||||
blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius))
|
||||
else: # Gaussian Blur OR Staged
|
||||
# Gaussian Blur uses standard deviation. 1/2 radius is a good approximation
|
||||
blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2))
|
||||
if self.coherence_mode == "Box Blur":
|
||||
blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius))
|
||||
else: # Gaussian Blur OR Staged
|
||||
# Gaussian Blur uses standard deviation. 1/2 radius is a good approximation
|
||||
blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2))
|
||||
|
||||
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False)
|
||||
mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
|
||||
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False)
|
||||
|
||||
# redistribute blur so that the original edges are 0 and blur outwards to 1
|
||||
blur_tensor = (blur_tensor - 0.5) * 2
|
||||
# redistribute blur so that the edges are 0 and blur out to 1
|
||||
blur_tensor = (blur_tensor - 0.5) * 2
|
||||
|
||||
threshold = 1 - self.minimum_denoise
|
||||
|
||||
if self.coherence_mode == "Staged":
|
||||
# wherever the blur_tensor is less than fully masked, convert it to threshold
|
||||
blur_tensor = torch.where((blur_tensor < 1) & (blur_tensor > 0), threshold, blur_tensor)
|
||||
else:
|
||||
# wherever the blur_tensor is above threshold but less than 1, drop it to threshold
|
||||
blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor)
|
||||
threshold = 1 - self.minimum_denoise
|
||||
|
||||
if self.coherence_mode == "Staged":
|
||||
# wherever the blur_tensor is masked to any degree, convert it to threshold
|
||||
blur_tensor = torch.where((blur_tensor < 1), threshold, blur_tensor)
|
||||
else:
|
||||
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
|
||||
# wherever the blur_tensor is above threshold but less than 1, drop it to threshold
|
||||
blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor)
|
||||
|
||||
# multiply original mask to force actually masked regions to 0
|
||||
blur_tensor = mask_tensor * blur_tensor
|
||||
|
||||
mask_name = context.tensors.save(tensor=blur_tensor.unsqueeze(1))
|
||||
|
||||
# compute a [0, 1] mask from the blur_tensor
|
||||
expanded_mask = torch.where((blur_tensor < 1), 0, 1)
|
||||
expanded_mask_image = Image.fromarray((expanded_mask.squeeze(0).numpy() * 255).astype(np.uint8), mode="L")
|
||||
expanded_image_dto = context.images.save(expanded_mask_image)
|
||||
|
||||
return GradientMaskOutput(
|
||||
denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=None, gradient=True),
|
||||
expanded_mask_area=ImageField(image_name=expanded_image_dto.image_name),
|
||||
return DenoiseMaskOutput.build(
|
||||
mask_name=mask_name,
|
||||
masked_latents_name=None,
|
||||
gradient=True,
|
||||
)
|
||||
|
||||
|
||||
def get_scheduler(
|
||||
context: InvocationContext,
|
||||
scheduler_info: ModelIdentifierField,
|
||||
scheduler_info: ModelInfo,
|
||||
scheduler_name: str,
|
||||
seed: int,
|
||||
) -> Scheduler:
|
||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
|
||||
orig_scheduler_info = context.models.load(scheduler_info)
|
||||
orig_scheduler_info = context.models.load(**scheduler_info.model_dump())
|
||||
with orig_scheduler_info as orig_scheduler:
|
||||
scheduler_config = orig_scheduler.config
|
||||
|
||||
@ -279,7 +265,7 @@ def get_scheduler(
|
||||
title="Denoise Latents",
|
||||
tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
|
||||
category="latents",
|
||||
version="1.5.3",
|
||||
version="1.5.2",
|
||||
)
|
||||
class DenoiseLatentsInvocation(BaseInvocation):
|
||||
"""Denoises noisy latents to decodable images"""
|
||||
@ -374,6 +360,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
) -> ConditioningData:
|
||||
positive_cond_data = context.conditioning.load(self.positive_conditioning.conditioning_name)
|
||||
c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
|
||||
extra_conditioning_info = c.extra_conditioning
|
||||
|
||||
negative_cond_data = context.conditioning.load(self.negative_conditioning.conditioning_name)
|
||||
uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
|
||||
@ -383,6 +370,13 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
text_embeddings=c,
|
||||
guidance_scale=self.cfg_scale,
|
||||
guidance_rescale_multiplier=self.cfg_rescale_multiplier,
|
||||
extra=extra_conditioning_info,
|
||||
postprocessing_settings=PostprocessingSettings(
|
||||
threshold=0.0, # threshold,
|
||||
warmup=0.2, # warmup,
|
||||
h_symmetry_time_pct=None, # h_symmetry_time_pct,
|
||||
v_symmetry_time_pct=None, # v_symmetry_time_pct,
|
||||
),
|
||||
)
|
||||
|
||||
conditioning_data = conditioning_data.add_scheduler_args_if_applicable( # FIXME
|
||||
@ -455,7 +449,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
# and if weight is None, populate with default 1.0?
|
||||
controlnet_data = []
|
||||
for control_info in control_list:
|
||||
control_model = exit_stack.enter_context(context.models.load(control_info.control_model))
|
||||
control_model = exit_stack.enter_context(context.models.load(key=control_info.control_model.key))
|
||||
|
||||
# control_models.append(control_model)
|
||||
control_image_field = control_info.image
|
||||
@ -517,10 +511,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
conditioning_data.ip_adapter_conditioning = []
|
||||
for single_ip_adapter in ip_adapter:
|
||||
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
|
||||
context.models.load(single_ip_adapter.ip_adapter_model)
|
||||
context.models.load(key=single_ip_adapter.ip_adapter_model.key)
|
||||
)
|
||||
|
||||
image_encoder_model_info = context.models.load(single_ip_adapter.image_encoder_model)
|
||||
image_encoder_model_info = context.models.load(key=single_ip_adapter.image_encoder_model.key)
|
||||
|
||||
# `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
|
||||
single_ipa_image_fields = single_ip_adapter.image
|
||||
if not isinstance(single_ipa_image_fields, list):
|
||||
@ -531,7 +526,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
# TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
|
||||
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
|
||||
with image_encoder_model_info as image_encoder_model:
|
||||
assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
|
||||
# Get image embeddings from CLIP and ImageProjModel.
|
||||
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
|
||||
single_ipa_images, image_encoder_model
|
||||
@ -571,8 +565,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
t2i_adapter_data = []
|
||||
for t2i_adapter_field in t2i_adapter:
|
||||
t2i_adapter_model_config = context.models.get_config(t2i_adapter_field.t2i_adapter_model.key)
|
||||
t2i_adapter_loaded_model = context.models.load(t2i_adapter_field.t2i_adapter_model)
|
||||
t2i_adapter_model_config = context.models.get_config(key=t2i_adapter_field.t2i_adapter_model.key)
|
||||
t2i_adapter_loaded_model = context.models.load(key=t2i_adapter_field.t2i_adapter_model.key)
|
||||
image = context.images.get_pil(t2i_adapter_field.image.image_name)
|
||||
|
||||
# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
|
||||
@ -677,7 +671,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
if self.denoise_mask.masked_latents_name is not None:
|
||||
masked_latents = context.tensors.load(self.denoise_mask.masked_latents_name)
|
||||
else:
|
||||
masked_latents = torch.where(mask < 0.5, 0.0, latents)
|
||||
masked_latents = None
|
||||
|
||||
return 1 - mask, masked_latents, self.denoise_mask.gradient
|
||||
|
||||
@ -725,13 +719,12 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
for lora in self.unet.loras:
|
||||
lora_info = context.models.load(lora.lora)
|
||||
assert isinstance(lora_info.model, LoRAModelRaw)
|
||||
lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
|
||||
yield (lora_info.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
|
||||
unet_info = context.models.load(self.unet.unet)
|
||||
unet_info = context.models.load(**self.unet.unet.model_dump())
|
||||
assert isinstance(unet_info.model, UNet2DConditionModel)
|
||||
with (
|
||||
ExitStack() as exit_stack,
|
||||
@ -784,7 +777,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
denoising_end=self.denoising_end,
|
||||
)
|
||||
|
||||
result_latents = pipeline.latents_from_embeddings(
|
||||
(
|
||||
result_latents,
|
||||
result_attention_map_saver,
|
||||
) = pipeline.latents_from_embeddings(
|
||||
latents=latents,
|
||||
timesteps=timesteps,
|
||||
init_timestep=init_timestep,
|
||||
@ -816,7 +812,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
title="Latents to Image",
|
||||
tags=["latents", "image", "vae", "l2i"],
|
||||
category="latents",
|
||||
version="1.2.2",
|
||||
version="1.2.1",
|
||||
)
|
||||
class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Generates an image from latents."""
|
||||
@ -825,7 +821,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
description=FieldDescriptions.latents,
|
||||
input=Input.Connection,
|
||||
)
|
||||
vae: VAEField = InputField(
|
||||
vae: VaeField = InputField(
|
||||
description=FieldDescriptions.vae,
|
||||
input=Input.Connection,
|
||||
)
|
||||
@ -836,15 +832,15 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
assert isinstance(vae_info.model, (UNet2DConditionModel, AutoencoderKL, AutoencoderTiny))
|
||||
vae_info = context.models.load(**self.vae.vae.model_dump())
|
||||
|
||||
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
|
||||
assert isinstance(vae, torch.nn.Module)
|
||||
latents = latents.to(vae.device)
|
||||
if self.fp32:
|
||||
vae.to(dtype=torch.float32)
|
||||
|
||||
use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance(
|
||||
use_torch_2_0_or_xformers = isinstance(
|
||||
vae.decoder.mid_block.attentions[0].processor,
|
||||
(
|
||||
AttnProcessor2_0,
|
||||
@ -866,7 +862,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
vae.to(dtype=torch.float16)
|
||||
latents = latents.half()
|
||||
|
||||
if self.tiled or context.config.get().force_tiled_decode:
|
||||
if self.tiled or context.config.get().tiled_decode:
|
||||
vae.enable_tiling()
|
||||
else:
|
||||
vae.disable_tiling()
|
||||
@ -903,7 +899,7 @@ LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic",
|
||||
title="Resize Latents",
|
||||
tags=["latents", "resize"],
|
||||
category="latents",
|
||||
version="1.0.2",
|
||||
version="1.0.1",
|
||||
)
|
||||
class ResizeLatentsInvocation(BaseInvocation):
|
||||
"""Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""
|
||||
@ -953,7 +949,7 @@ class ResizeLatentsInvocation(BaseInvocation):
|
||||
title="Scale Latents",
|
||||
tags=["latents", "resize"],
|
||||
category="latents",
|
||||
version="1.0.2",
|
||||
version="1.0.1",
|
||||
)
|
||||
class ScaleLatentsInvocation(BaseInvocation):
|
||||
"""Scales latents by a given factor."""
|
||||
@ -995,7 +991,7 @@ class ScaleLatentsInvocation(BaseInvocation):
|
||||
title="Image to Latents",
|
||||
tags=["latents", "image", "vae", "i2l"],
|
||||
category="latents",
|
||||
version="1.0.2",
|
||||
version="1.0.1",
|
||||
)
|
||||
class ImageToLatentsInvocation(BaseInvocation):
|
||||
"""Encodes an image into latents."""
|
||||
@ -1003,7 +999,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
image: ImageField = InputField(
|
||||
description="The image to encode",
|
||||
)
|
||||
vae: VAEField = InputField(
|
||||
vae: VaeField = InputField(
|
||||
description=FieldDescriptions.vae,
|
||||
input=Input.Connection,
|
||||
)
|
||||
@ -1018,7 +1014,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
if upcast:
|
||||
vae.to(dtype=torch.float32)
|
||||
|
||||
use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance(
|
||||
use_torch_2_0_or_xformers = isinstance(
|
||||
vae.decoder.mid_block.attentions[0].processor,
|
||||
(
|
||||
AttnProcessor2_0,
|
||||
@ -1059,7 +1055,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
vae_info = context.models.load(**self.vae.vae.model_dump())
|
||||
|
||||
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||
if image_tensor.dim() == 3:
|
||||
@ -1094,7 +1090,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
title="Blend Latents",
|
||||
tags=["latents", "blend"],
|
||||
category="latents",
|
||||
version="1.0.2",
|
||||
version="1.0.1",
|
||||
)
|
||||
class BlendLatentsInvocation(BaseInvocation):
|
||||
"""Blend two latents using a given alpha. Latents must have same size."""
|
||||
@ -1185,7 +1181,7 @@ class BlendLatentsInvocation(BaseInvocation):
|
||||
title="Crop Latents",
|
||||
tags=["latents", "crop"],
|
||||
category="latents",
|
||||
version="1.0.2",
|
||||
version="1.0.1",
|
||||
)
|
||||
# TODO(ryand): Named `CropLatentsCoreInvocation` to prevent a conflict with custom node `CropLatentsInvocation`.
|
||||
# Currently, if the class names conflict then 'GET /openapi.json' fails.
|
||||
@ -1246,7 +1242,7 @@ class IdealSizeOutput(BaseInvocationOutput):
|
||||
"ideal_size",
|
||||
title="Ideal Size",
|
||||
tags=["latents", "math", "ideal_size"],
|
||||
version="1.0.3",
|
||||
version="1.0.2",
|
||||
)
|
||||
class IdealSizeInvocation(BaseInvocation):
|
||||
"""Calculates the ideal size for generation to avoid duplication"""
|
||||
|
@ -12,7 +12,7 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from .baseinvocation import BaseInvocation, invocation
|
||||
|
||||
|
||||
@invocation("add", title="Add Integers", tags=["math", "add"], category="math", version="1.0.1")
|
||||
@invocation("add", title="Add Integers", tags=["math", "add"], category="math", version="1.0.0")
|
||||
class AddInvocation(BaseInvocation):
|
||||
"""Adds two numbers"""
|
||||
|
||||
@ -23,7 +23,7 @@ class AddInvocation(BaseInvocation):
|
||||
return IntegerOutput(value=self.a + self.b)
|
||||
|
||||
|
||||
@invocation("sub", title="Subtract Integers", tags=["math", "subtract"], category="math", version="1.0.1")
|
||||
@invocation("sub", title="Subtract Integers", tags=["math", "subtract"], category="math", version="1.0.0")
|
||||
class SubtractInvocation(BaseInvocation):
|
||||
"""Subtracts two numbers"""
|
||||
|
||||
@ -34,7 +34,7 @@ class SubtractInvocation(BaseInvocation):
|
||||
return IntegerOutput(value=self.a - self.b)
|
||||
|
||||
|
||||
@invocation("mul", title="Multiply Integers", tags=["math", "multiply"], category="math", version="1.0.1")
|
||||
@invocation("mul", title="Multiply Integers", tags=["math", "multiply"], category="math", version="1.0.0")
|
||||
class MultiplyInvocation(BaseInvocation):
|
||||
"""Multiplies two numbers"""
|
||||
|
||||
@ -45,7 +45,7 @@ class MultiplyInvocation(BaseInvocation):
|
||||
return IntegerOutput(value=self.a * self.b)
|
||||
|
||||
|
||||
@invocation("div", title="Divide Integers", tags=["math", "divide"], category="math", version="1.0.1")
|
||||
@invocation("div", title="Divide Integers", tags=["math", "divide"], category="math", version="1.0.0")
|
||||
class DivideInvocation(BaseInvocation):
|
||||
"""Divides two numbers"""
|
||||
|
||||
@ -61,7 +61,7 @@ class DivideInvocation(BaseInvocation):
|
||||
title="Random Integer",
|
||||
tags=["math", "random"],
|
||||
category="math",
|
||||
version="1.0.1",
|
||||
version="1.0.0",
|
||||
use_cache=False,
|
||||
)
|
||||
class RandomIntInvocation(BaseInvocation):
|
||||
@ -100,7 +100,7 @@ class RandomFloatInvocation(BaseInvocation):
|
||||
title="Float To Integer",
|
||||
tags=["math", "round", "integer", "float", "convert"],
|
||||
category="math",
|
||||
version="1.0.1",
|
||||
version="1.0.0",
|
||||
)
|
||||
class FloatToIntegerInvocation(BaseInvocation):
|
||||
"""Rounds a float number to (a multiple of) an integer."""
|
||||
@ -122,7 +122,7 @@ class FloatToIntegerInvocation(BaseInvocation):
|
||||
return IntegerOutput(value=int(self.value / self.multiple) * self.multiple)
|
||||
|
||||
|
||||
@invocation("round_float", title="Round Float", tags=["math", "round"], category="math", version="1.0.1")
|
||||
@invocation("round_float", title="Round Float", tags=["math", "round"], category="math", version="1.0.0")
|
||||
class RoundInvocation(BaseInvocation):
|
||||
"""Rounds a float to a specified number of decimal places."""
|
||||
|
||||
@ -176,7 +176,7 @@ INTEGER_OPERATIONS_LABELS = {
|
||||
"max",
|
||||
],
|
||||
category="math",
|
||||
version="1.0.1",
|
||||
version="1.0.0",
|
||||
)
|
||||
class IntegerMathInvocation(BaseInvocation):
|
||||
"""Performs integer math."""
|
||||
@ -250,7 +250,7 @@ FLOAT_OPERATIONS_LABELS = {
|
||||
title="Float Math",
|
||||
tags=["math", "float", "add", "subtract", "multiply", "divide", "power", "root", "absolute value", "min", "max"],
|
||||
category="math",
|
||||
version="1.0.1",
|
||||
version="1.0.0",
|
||||
)
|
||||
class FloatMathInvocation(BaseInvocation):
|
||||
"""Performs floating point math."""
|
||||
|
@ -8,10 +8,7 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.controlnet_image_processors import (
|
||||
CONTROLNET_MODE_VALUES,
|
||||
CONTROLNET_RESIZE_VALUES,
|
||||
)
|
||||
from invokeai.app.invocations.controlnet_image_processors import ControlField
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
ImageField,
|
||||
@ -20,7 +17,9 @@ from invokeai.app.invocations.fields import (
|
||||
OutputField,
|
||||
UIType,
|
||||
)
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.invocations.ip_adapter import IPAdapterModelField
|
||||
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
|
||||
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
from ...version import __version__
|
||||
@ -34,7 +33,7 @@ class MetadataItemField(BaseModel):
|
||||
class LoRAMetadataField(BaseModel):
|
||||
"""LoRA Metadata Field"""
|
||||
|
||||
model: ModelIdentifierField = Field(description=FieldDescriptions.lora_model)
|
||||
model: LoRAModelField = Field(description=FieldDescriptions.lora_model)
|
||||
weight: float = Field(description=FieldDescriptions.lora_weight)
|
||||
|
||||
|
||||
@ -42,41 +41,16 @@ class IPAdapterMetadataField(BaseModel):
|
||||
"""IP Adapter Field, minus the CLIP Vision Encoder model"""
|
||||
|
||||
image: ImageField = Field(description="The IP-Adapter image prompt.")
|
||||
ip_adapter_model: ModelIdentifierField = Field(description="The IP-Adapter model.")
|
||||
weight: Union[float, list[float]] = Field(description="The weight given to the IP-Adapter")
|
||||
ip_adapter_model: IPAdapterModelField = Field(
|
||||
description="The IP-Adapter model.",
|
||||
)
|
||||
weight: Union[float, list[float]] = Field(
|
||||
description="The weight given to the IP-Adapter",
|
||||
)
|
||||
begin_step_percent: float = Field(description="When the IP-Adapter is first applied (% of total steps)")
|
||||
end_step_percent: float = Field(description="When the IP-Adapter is last applied (% of total steps)")
|
||||
|
||||
|
||||
class T2IAdapterMetadataField(BaseModel):
|
||||
image: ImageField = Field(description="The control image.")
|
||||
processed_image: Optional[ImageField] = Field(default=None, description="The control image, after processing.")
|
||||
t2i_adapter_model: ModelIdentifierField = Field(description="The T2I-Adapter model to use.")
|
||||
weight: Union[float, list[float]] = Field(default=1, description="The weight given to the T2I-Adapter")
|
||||
begin_step_percent: float = Field(
|
||||
default=0, ge=0, le=1, description="When the T2I-Adapter is first applied (% of total steps)"
|
||||
)
|
||||
end_step_percent: float = Field(
|
||||
default=1, ge=0, le=1, description="When the T2I-Adapter is last applied (% of total steps)"
|
||||
)
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
|
||||
|
||||
|
||||
class ControlNetMetadataField(BaseModel):
|
||||
image: ImageField = Field(description="The control image")
|
||||
processed_image: Optional[ImageField] = Field(default=None, description="The control image, after processing.")
|
||||
control_model: ModelIdentifierField = Field(description="The ControlNet model to use")
|
||||
control_weight: Union[float, list[float]] = Field(default=1, description="The weight given to the ControlNet")
|
||||
begin_step_percent: float = Field(
|
||||
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
|
||||
)
|
||||
end_step_percent: float = Field(
|
||||
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
|
||||
)
|
||||
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
|
||||
|
||||
|
||||
@invocation_output("metadata_item_output")
|
||||
class MetadataItemOutput(BaseInvocationOutput):
|
||||
"""Metadata Item Output"""
|
||||
@ -84,7 +58,7 @@ class MetadataItemOutput(BaseInvocationOutput):
|
||||
item: MetadataItemField = OutputField(description="Metadata Item")
|
||||
|
||||
|
||||
@invocation("metadata_item", title="Metadata Item", tags=["metadata"], category="metadata", version="1.0.1")
|
||||
@invocation("metadata_item", title="Metadata Item", tags=["metadata"], category="metadata", version="1.0.0")
|
||||
class MetadataItemInvocation(BaseInvocation):
|
||||
"""Used to create an arbitrary metadata item. Provide "label" and make a connection to "value" to store that data as the value."""
|
||||
|
||||
@ -100,7 +74,7 @@ class MetadataOutput(BaseInvocationOutput):
|
||||
metadata: MetadataField = OutputField(description="Metadata Dict")
|
||||
|
||||
|
||||
@invocation("metadata", title="Metadata", tags=["metadata"], category="metadata", version="1.0.1")
|
||||
@invocation("metadata", title="Metadata", tags=["metadata"], category="metadata", version="1.0.0")
|
||||
class MetadataInvocation(BaseInvocation):
|
||||
"""Takes a MetadataItem or collection of MetadataItems and outputs a MetadataDict."""
|
||||
|
||||
@ -121,7 +95,7 @@ class MetadataInvocation(BaseInvocation):
|
||||
return MetadataOutput(metadata=MetadataField.model_validate(data))
|
||||
|
||||
|
||||
@invocation("merge_metadata", title="Metadata Merge", tags=["metadata"], category="metadata", version="1.0.1")
|
||||
@invocation("merge_metadata", title="Metadata Merge", tags=["metadata"], category="metadata", version="1.0.0")
|
||||
class MergeMetadataInvocation(BaseInvocation):
|
||||
"""Merged a collection of MetadataDict into a single MetadataDict."""
|
||||
|
||||
@ -140,7 +114,7 @@ GENERATION_MODES = Literal[
|
||||
]
|
||||
|
||||
|
||||
@invocation("core_metadata", title="Core Metadata", tags=["metadata"], category="metadata", version="2.0.0")
|
||||
@invocation("core_metadata", title="Core Metadata", tags=["metadata"], category="metadata", version="1.1.1")
|
||||
class CoreMetadataInvocation(BaseInvocation):
|
||||
"""Collects core generation metadata into a MetadataField"""
|
||||
|
||||
@ -166,14 +140,14 @@ class CoreMetadataInvocation(BaseInvocation):
|
||||
default=None,
|
||||
description="The number of skipped CLIP layers",
|
||||
)
|
||||
model: Optional[ModelIdentifierField] = InputField(default=None, description="The main model used for inference")
|
||||
controlnets: Optional[list[ControlNetMetadataField]] = InputField(
|
||||
model: Optional[MainModelField] = InputField(default=None, description="The main model used for inference")
|
||||
controlnets: Optional[list[ControlField]] = InputField(
|
||||
default=None, description="The ControlNets used for inference"
|
||||
)
|
||||
ipAdapters: Optional[list[IPAdapterMetadataField]] = InputField(
|
||||
default=None, description="The IP Adapters used for inference"
|
||||
)
|
||||
t2iAdapters: Optional[list[T2IAdapterMetadataField]] = InputField(
|
||||
t2iAdapters: Optional[list[T2IAdapterField]] = InputField(
|
||||
default=None, description="The IP Adapters used for inference"
|
||||
)
|
||||
loras: Optional[list[LoRAMetadataField]] = InputField(default=None, description="The LoRAs used for inference")
|
||||
@ -185,7 +159,7 @@ class CoreMetadataInvocation(BaseInvocation):
|
||||
default=None,
|
||||
description="The name of the initial image",
|
||||
)
|
||||
vae: Optional[ModelIdentifierField] = InputField(
|
||||
vae: Optional[VAEModelField] = InputField(
|
||||
default=None,
|
||||
description="The VAE used for decoding, if the main model's default was not used",
|
||||
)
|
||||
@ -216,7 +190,7 @@ class CoreMetadataInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
# SDXL Refiner
|
||||
refiner_model: Optional[ModelIdentifierField] = InputField(
|
||||
refiner_model: Optional[MainModelField] = InputField(
|
||||
default=None,
|
||||
description="The SDXL Refiner model used",
|
||||
)
|
||||
@ -248,9 +222,10 @@ class CoreMetadataInvocation(BaseInvocation):
|
||||
def invoke(self, context: InvocationContext) -> MetadataOutput:
|
||||
"""Collects and outputs a CoreMetadata object"""
|
||||
|
||||
as_dict = self.model_dump(exclude_none=True, exclude={"id", "type", "is_intermediate", "use_cache"})
|
||||
as_dict["app_version"] = __version__
|
||||
|
||||
return MetadataOutput(metadata=MetadataField.model_validate(as_dict))
|
||||
return MetadataOutput(
|
||||
metadata=MetadataField.model_validate(
|
||||
self.model_dump(exclude_none=True, exclude={"id", "type", "is_intermediate", "use_cache"})
|
||||
)
|
||||
)
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
@ -3,11 +3,11 @@ from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.shared.models import FreeUConfig
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType
|
||||
|
||||
from ...backend.model_manager import SubModelType
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
@ -16,52 +16,33 @@ from .baseinvocation import (
|
||||
)
|
||||
|
||||
|
||||
class ModelIdentifierField(BaseModel):
|
||||
key: str = Field(description="The model's unique key")
|
||||
hash: str = Field(description="The model's BLAKE3 hash")
|
||||
name: str = Field(description="The model's name")
|
||||
base: BaseModelType = Field(description="The model's base model type")
|
||||
type: ModelType = Field(description="The model's type")
|
||||
submodel_type: Optional[SubModelType] = Field(
|
||||
description="The submodel to load, if this is a main model", default=None
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls, config: "AnyModelConfig", submodel_type: Optional[SubModelType] = None
|
||||
) -> "ModelIdentifierField":
|
||||
return cls(
|
||||
key=config.key,
|
||||
hash=config.hash,
|
||||
name=config.name,
|
||||
base=config.base,
|
||||
type=config.type,
|
||||
submodel_type=submodel_type,
|
||||
)
|
||||
class ModelInfo(BaseModel):
|
||||
key: str = Field(description="Key of model as returned by ModelRecordServiceBase.get_model()")
|
||||
submodel_type: Optional[SubModelType] = Field(default=None, description="Info to load submodel")
|
||||
|
||||
|
||||
class LoRAField(BaseModel):
|
||||
lora: ModelIdentifierField = Field(description="Info to load lora model")
|
||||
weight: float = Field(description="Weight to apply to lora model")
|
||||
class LoraInfo(ModelInfo):
|
||||
weight: float = Field(description="Lora's weight which to use when apply to model")
|
||||
|
||||
|
||||
class UNetField(BaseModel):
|
||||
unet: ModelIdentifierField = Field(description="Info to load unet submodel")
|
||||
scheduler: ModelIdentifierField = Field(description="Info to load scheduler submodel")
|
||||
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
|
||||
unet: ModelInfo = Field(description="Info to load unet submodel")
|
||||
scheduler: ModelInfo = Field(description="Info to load scheduler submodel")
|
||||
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
|
||||
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
|
||||
freeu_config: Optional[FreeUConfig] = Field(default=None, description="FreeU configuration")
|
||||
|
||||
|
||||
class CLIPField(BaseModel):
|
||||
tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel")
|
||||
text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")
|
||||
class ClipField(BaseModel):
|
||||
tokenizer: ModelInfo = Field(description="Info to load tokenizer submodel")
|
||||
text_encoder: ModelInfo = Field(description="Info to load text_encoder submodel")
|
||||
skipped_layers: int = Field(description="Number of skipped layers in text_encoder")
|
||||
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
|
||||
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
|
||||
|
||||
|
||||
class VAEField(BaseModel):
|
||||
vae: ModelIdentifierField = Field(description="Info to load vae submodel")
|
||||
class VaeField(BaseModel):
|
||||
# TODO: better naming?
|
||||
vae: ModelInfo = Field(description="Info to load vae submodel")
|
||||
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
|
||||
|
||||
|
||||
@ -76,14 +57,14 @@ class UNetOutput(BaseInvocationOutput):
|
||||
class VAEOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output a VAE field"""
|
||||
|
||||
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
|
||||
|
||||
@invocation_output("clip_output")
|
||||
class CLIPOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output a CLIP field"""
|
||||
|
||||
clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP")
|
||||
clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP")
|
||||
|
||||
|
||||
@invocation_output("model_loader_output")
|
||||
@ -93,54 +74,84 @@ class ModelLoaderOutput(UNetOutput, CLIPOutput, VAEOutput):
|
||||
pass
|
||||
|
||||
|
||||
class MainModelField(BaseModel):
|
||||
"""Main model field"""
|
||||
|
||||
key: str = Field(description="Model key")
|
||||
|
||||
|
||||
class LoRAModelField(BaseModel):
|
||||
"""LoRA model field"""
|
||||
|
||||
key: str = Field(description="LoRA model key")
|
||||
|
||||
|
||||
@invocation(
|
||||
"main_model_loader",
|
||||
title="Main Model",
|
||||
tags=["model"],
|
||||
category="model",
|
||||
version="1.0.2",
|
||||
version="1.0.1",
|
||||
)
|
||||
class MainModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads a main model, outputting its submodels."""
|
||||
|
||||
model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.main_model, input=Input.Direct, ui_type=UIType.MainModel
|
||||
)
|
||||
model: MainModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct)
|
||||
# TODO: precision?
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
|
||||
# TODO: not found exceptions
|
||||
if not context.models.exists(self.model.key):
|
||||
raise Exception(f"Unknown model {self.model.key}")
|
||||
key = self.model.key
|
||||
|
||||
unet = self.model.model_copy(update={"submodel_type": SubModelType.UNet})
|
||||
scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
|
||||
tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
|
||||
text_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
|
||||
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
|
||||
# TODO: not found exceptions
|
||||
if not context.models.exists(key):
|
||||
raise Exception(f"Unknown model {key}")
|
||||
|
||||
return ModelLoaderOutput(
|
||||
unet=UNetField(unet=unet, scheduler=scheduler, loras=[]),
|
||||
clip=CLIPField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[], skipped_layers=0),
|
||||
vae=VAEField(vae=vae),
|
||||
unet=UNetField(
|
||||
unet=ModelInfo(
|
||||
key=key,
|
||||
submodel_type=SubModelType.UNet,
|
||||
),
|
||||
scheduler=ModelInfo(
|
||||
key=key,
|
||||
submodel_type=SubModelType.Scheduler,
|
||||
),
|
||||
loras=[],
|
||||
),
|
||||
clip=ClipField(
|
||||
tokenizer=ModelInfo(
|
||||
key=key,
|
||||
submodel_type=SubModelType.Tokenizer,
|
||||
),
|
||||
text_encoder=ModelInfo(
|
||||
key=key,
|
||||
submodel_type=SubModelType.TextEncoder,
|
||||
),
|
||||
loras=[],
|
||||
skipped_layers=0,
|
||||
),
|
||||
vae=VaeField(
|
||||
vae=ModelInfo(
|
||||
key=key,
|
||||
submodel_type=SubModelType.Vae,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@invocation_output("lora_loader_output")
|
||||
class LoRALoaderOutput(BaseInvocationOutput):
|
||||
class LoraLoaderOutput(BaseInvocationOutput):
|
||||
"""Model loader output"""
|
||||
|
||||
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
|
||||
clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||
|
||||
|
||||
@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.2")
|
||||
class LoRALoaderInvocation(BaseInvocation):
|
||||
@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.1")
|
||||
class LoraLoaderInvocation(BaseInvocation):
|
||||
"""Apply selected lora to unet and text_encoder."""
|
||||
|
||||
lora: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel
|
||||
)
|
||||
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
|
||||
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
||||
unet: Optional[UNetField] = InputField(
|
||||
default=None,
|
||||
@ -148,41 +159,46 @@ class LoRALoaderInvocation(BaseInvocation):
|
||||
input=Input.Connection,
|
||||
title="UNet",
|
||||
)
|
||||
clip: Optional[CLIPField] = InputField(
|
||||
clip: Optional[ClipField] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.clip,
|
||||
input=Input.Connection,
|
||||
title="CLIP",
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LoRALoaderOutput:
|
||||
def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
|
||||
if self.lora is None:
|
||||
raise Exception("No LoRA provided")
|
||||
|
||||
lora_key = self.lora.key
|
||||
|
||||
if not context.models.exists(lora_key):
|
||||
raise Exception(f"Unkown lora: {lora_key}!")
|
||||
|
||||
if self.unet is not None and any(lora.lora.key == lora_key for lora in self.unet.loras):
|
||||
raise Exception(f'LoRA "{lora_key}" already applied to unet')
|
||||
if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
|
||||
raise Exception(f'Lora "{lora_key}" already applied to unet')
|
||||
|
||||
if self.clip is not None and any(lora.lora.key == lora_key for lora in self.clip.loras):
|
||||
raise Exception(f'LoRA "{lora_key}" already applied to clip')
|
||||
if self.clip is not None and any(lora.key == lora_key for lora in self.clip.loras):
|
||||
raise Exception(f'Lora "{lora_key}" already applied to clip')
|
||||
|
||||
output = LoRALoaderOutput()
|
||||
output = LoraLoaderOutput()
|
||||
|
||||
if self.unet is not None:
|
||||
output.unet = self.unet.model_copy(deep=True)
|
||||
output.unet = copy.deepcopy(self.unet)
|
||||
output.unet.loras.append(
|
||||
LoRAField(
|
||||
lora=self.lora,
|
||||
LoraInfo(
|
||||
key=lora_key,
|
||||
submodel_type=None,
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
|
||||
if self.clip is not None:
|
||||
output.clip = self.clip.model_copy(deep=True)
|
||||
output.clip = copy.deepcopy(self.clip)
|
||||
output.clip.loras.append(
|
||||
LoRAField(
|
||||
lora=self.lora,
|
||||
LoraInfo(
|
||||
key=lora_key,
|
||||
submodel_type=None,
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
@ -191,12 +207,12 @@ class LoRALoaderInvocation(BaseInvocation):
|
||||
|
||||
|
||||
@invocation_output("sdxl_lora_loader_output")
|
||||
class SDXLLoRALoaderOutput(BaseInvocationOutput):
|
||||
class SDXLLoraLoaderOutput(BaseInvocationOutput):
|
||||
"""SDXL LoRA Loader Output"""
|
||||
|
||||
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
|
||||
clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 1")
|
||||
clip2: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 2")
|
||||
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 1")
|
||||
clip2: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 2")
|
||||
|
||||
|
||||
@invocation(
|
||||
@ -204,14 +220,12 @@ class SDXLLoRALoaderOutput(BaseInvocationOutput):
|
||||
title="SDXL LoRA",
|
||||
tags=["lora", "model"],
|
||||
category="model",
|
||||
version="1.0.2",
|
||||
version="1.0.1",
|
||||
)
|
||||
class SDXLLoRALoaderInvocation(BaseInvocation):
|
||||
class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||
"""Apply selected lora to unet and text_encoder."""
|
||||
|
||||
lora: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel
|
||||
)
|
||||
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
|
||||
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
||||
unet: Optional[UNetField] = InputField(
|
||||
default=None,
|
||||
@ -219,59 +233,65 @@ class SDXLLoRALoaderInvocation(BaseInvocation):
|
||||
input=Input.Connection,
|
||||
title="UNet",
|
||||
)
|
||||
clip: Optional[CLIPField] = InputField(
|
||||
clip: Optional[ClipField] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.clip,
|
||||
input=Input.Connection,
|
||||
title="CLIP 1",
|
||||
)
|
||||
clip2: Optional[CLIPField] = InputField(
|
||||
clip2: Optional[ClipField] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.clip,
|
||||
input=Input.Connection,
|
||||
title="CLIP 2",
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> SDXLLoRALoaderOutput:
|
||||
def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput:
|
||||
if self.lora is None:
|
||||
raise Exception("No LoRA provided")
|
||||
|
||||
lora_key = self.lora.key
|
||||
|
||||
if not context.models.exists(lora_key):
|
||||
raise Exception(f"Unknown lora: {lora_key}!")
|
||||
|
||||
if self.unet is not None and any(lora.lora.key == lora_key for lora in self.unet.loras):
|
||||
raise Exception(f'LoRA "{lora_key}" already applied to unet')
|
||||
if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
|
||||
raise Exception(f'Lora "{lora_key}" already applied to unet')
|
||||
|
||||
if self.clip is not None and any(lora.lora.key == lora_key for lora in self.clip.loras):
|
||||
raise Exception(f'LoRA "{lora_key}" already applied to clip')
|
||||
if self.clip is not None and any(lora.key == lora_key for lora in self.clip.loras):
|
||||
raise Exception(f'Lora "{lora_key}" already applied to clip')
|
||||
|
||||
if self.clip2 is not None and any(lora.lora.key == lora_key for lora in self.clip2.loras):
|
||||
raise Exception(f'LoRA "{lora_key}" already applied to clip2')
|
||||
if self.clip2 is not None and any(lora.key == lora_key for lora in self.clip2.loras):
|
||||
raise Exception(f'Lora "{lora_key}" already applied to clip2')
|
||||
|
||||
output = SDXLLoRALoaderOutput()
|
||||
output = SDXLLoraLoaderOutput()
|
||||
|
||||
if self.unet is not None:
|
||||
output.unet = self.unet.model_copy(deep=True)
|
||||
output.unet = copy.deepcopy(self.unet)
|
||||
output.unet.loras.append(
|
||||
LoRAField(
|
||||
lora=self.lora,
|
||||
LoraInfo(
|
||||
key=lora_key,
|
||||
submodel_type=None,
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
|
||||
if self.clip is not None:
|
||||
output.clip = self.clip.model_copy(deep=True)
|
||||
output.clip = copy.deepcopy(self.clip)
|
||||
output.clip.loras.append(
|
||||
LoRAField(
|
||||
lora=self.lora,
|
||||
LoraInfo(
|
||||
key=lora_key,
|
||||
submodel_type=None,
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
|
||||
if self.clip2 is not None:
|
||||
output.clip2 = self.clip2.model_copy(deep=True)
|
||||
output.clip2 = copy.deepcopy(self.clip2)
|
||||
output.clip2.loras.append(
|
||||
LoRAField(
|
||||
lora=self.lora,
|
||||
LoraInfo(
|
||||
key=lora_key,
|
||||
submodel_type=None,
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
@ -279,12 +299,20 @@ class SDXLLoRALoaderInvocation(BaseInvocation):
|
||||
return output
|
||||
|
||||
|
||||
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.2")
|
||||
class VAELoaderInvocation(BaseInvocation):
|
||||
class VAEModelField(BaseModel):
|
||||
"""Vae model field"""
|
||||
|
||||
key: str = Field(description="Model's key")
|
||||
|
||||
|
||||
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.1")
|
||||
class VaeLoaderInvocation(BaseInvocation):
|
||||
"""Loads a VAE model, outputting a VaeLoaderOutput"""
|
||||
|
||||
vae_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.vae_model, input=Input.Direct, title="VAE", ui_type=UIType.VAEModel
|
||||
vae_model: VAEModelField = InputField(
|
||||
description=FieldDescriptions.vae_model,
|
||||
input=Input.Direct,
|
||||
title="VAE",
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> VAEOutput:
|
||||
@ -293,7 +321,7 @@ class VAELoaderInvocation(BaseInvocation):
|
||||
if not context.models.exists(key):
|
||||
raise Exception(f"Unkown vae: {key}!")
|
||||
|
||||
return VAEOutput(vae=VAEField(vae=self.vae_model))
|
||||
return VAEOutput(vae=VaeField(vae=ModelInfo(key=key)))
|
||||
|
||||
|
||||
@invocation_output("seamless_output")
|
||||
@ -301,7 +329,7 @@ class SeamlessModeOutput(BaseInvocationOutput):
|
||||
"""Modified Seamless Model output"""
|
||||
|
||||
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
|
||||
vae: Optional[VAEField] = OutputField(default=None, description=FieldDescriptions.vae, title="VAE")
|
||||
vae: Optional[VaeField] = OutputField(default=None, description=FieldDescriptions.vae, title="VAE")
|
||||
|
||||
|
||||
@invocation(
|
||||
@ -309,7 +337,7 @@ class SeamlessModeOutput(BaseInvocationOutput):
|
||||
title="Seamless",
|
||||
tags=["seamless", "model"],
|
||||
category="model",
|
||||
version="1.0.1",
|
||||
version="1.0.0",
|
||||
)
|
||||
class SeamlessModeInvocation(BaseInvocation):
|
||||
"""Applies the seamless transformation to the Model UNet and VAE."""
|
||||
@ -320,7 +348,7 @@ class SeamlessModeInvocation(BaseInvocation):
|
||||
input=Input.Connection,
|
||||
title="UNet",
|
||||
)
|
||||
vae: Optional[VAEField] = InputField(
|
||||
vae: Optional[VaeField] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.vae_model,
|
||||
input=Input.Connection,
|
||||
@ -349,7 +377,7 @@ class SeamlessModeInvocation(BaseInvocation):
|
||||
return SeamlessModeOutput(unet=unet, vae=vae)
|
||||
|
||||
|
||||
@invocation("freeu", title="FreeU", tags=["freeu"], category="unet", version="1.0.1")
|
||||
@invocation("freeu", title="FreeU", tags=["freeu"], category="unet", version="1.0.0")
|
||||
class FreeUInvocation(BaseInvocation):
|
||||
"""
|
||||
Applies FreeU to the UNet. Suggested values (b1/b2/s1/s2):
|
||||
|
@ -81,7 +81,7 @@ class NoiseOutput(BaseInvocationOutput):
|
||||
title="Noise",
|
||||
tags=["latents", "noise"],
|
||||
category="latents",
|
||||
version="1.0.2",
|
||||
version="1.0.1",
|
||||
)
|
||||
class NoiseInvocation(BaseInvocation):
|
||||
"""Generates latent noise."""
|
||||
|
@ -51,7 +51,7 @@ from .fields import InputField
|
||||
title="Float Range",
|
||||
tags=["math", "range"],
|
||||
category="math",
|
||||
version="1.0.1",
|
||||
version="1.0.0",
|
||||
)
|
||||
class FloatLinearRangeInvocation(BaseInvocation):
|
||||
"""Creates a range"""
|
||||
@ -111,7 +111,7 @@ EASING_FUNCTION_KEYS = Literal[tuple(EASING_FUNCTIONS_MAP.keys())]
|
||||
title="Step Param Easing",
|
||||
tags=["step", "easing"],
|
||||
category="step",
|
||||
version="1.0.2",
|
||||
version="1.0.1",
|
||||
)
|
||||
class StepParamEasingInvocation(BaseInvocation):
|
||||
"""Experimental per-step parameter easing for denoising steps"""
|
||||
|
@ -54,7 +54,7 @@ class BooleanCollectionOutput(BaseInvocationOutput):
|
||||
|
||||
|
||||
@invocation(
|
||||
"boolean", title="Boolean Primitive", tags=["primitives", "boolean"], category="primitives", version="1.0.1"
|
||||
"boolean", title="Boolean Primitive", tags=["primitives", "boolean"], category="primitives", version="1.0.0"
|
||||
)
|
||||
class BooleanInvocation(BaseInvocation):
|
||||
"""A boolean primitive value"""
|
||||
@ -70,7 +70,7 @@ class BooleanInvocation(BaseInvocation):
|
||||
title="Boolean Collection Primitive",
|
||||
tags=["primitives", "boolean", "collection"],
|
||||
category="primitives",
|
||||
version="1.0.2",
|
||||
version="1.0.1",
|
||||
)
|
||||
class BooleanCollectionInvocation(BaseInvocation):
|
||||
"""A collection of boolean primitive values"""
|
||||
@ -103,7 +103,7 @@ class IntegerCollectionOutput(BaseInvocationOutput):
|
||||
|
||||
|
||||
@invocation(
|
||||
"integer", title="Integer Primitive", tags=["primitives", "integer"], category="primitives", version="1.0.1"
|
||||
"integer", title="Integer Primitive", tags=["primitives", "integer"], category="primitives", version="1.0.0"
|
||||
)
|
||||
class IntegerInvocation(BaseInvocation):
|
||||
"""An integer primitive value"""
|
||||
@ -119,7 +119,7 @@ class IntegerInvocation(BaseInvocation):
|
||||
title="Integer Collection Primitive",
|
||||
tags=["primitives", "integer", "collection"],
|
||||
category="primitives",
|
||||
version="1.0.2",
|
||||
version="1.0.1",
|
||||
)
|
||||
class IntegerCollectionInvocation(BaseInvocation):
|
||||
"""A collection of integer primitive values"""
|
||||
@ -151,7 +151,7 @@ class FloatCollectionOutput(BaseInvocationOutput):
|
||||
)
|
||||
|
||||
|
||||
@invocation("float", title="Float Primitive", tags=["primitives", "float"], category="primitives", version="1.0.1")
|
||||
@invocation("float", title="Float Primitive", tags=["primitives", "float"], category="primitives", version="1.0.0")
|
||||
class FloatInvocation(BaseInvocation):
|
||||
"""A float primitive value"""
|
||||
|
||||
@ -166,7 +166,7 @@ class FloatInvocation(BaseInvocation):
|
||||
title="Float Collection Primitive",
|
||||
tags=["primitives", "float", "collection"],
|
||||
category="primitives",
|
||||
version="1.0.2",
|
||||
version="1.0.1",
|
||||
)
|
||||
class FloatCollectionInvocation(BaseInvocation):
|
||||
"""A collection of float primitive values"""
|
||||
@ -198,7 +198,7 @@ class StringCollectionOutput(BaseInvocationOutput):
|
||||
)
|
||||
|
||||
|
||||
@invocation("string", title="String Primitive", tags=["primitives", "string"], category="primitives", version="1.0.1")
|
||||
@invocation("string", title="String Primitive", tags=["primitives", "string"], category="primitives", version="1.0.0")
|
||||
class StringInvocation(BaseInvocation):
|
||||
"""A string primitive value"""
|
||||
|
||||
@ -213,7 +213,7 @@ class StringInvocation(BaseInvocation):
|
||||
title="String Collection Primitive",
|
||||
tags=["primitives", "string", "collection"],
|
||||
category="primitives",
|
||||
version="1.0.2",
|
||||
version="1.0.1",
|
||||
)
|
||||
class StringCollectionInvocation(BaseInvocation):
|
||||
"""A collection of string primitive values"""
|
||||
@ -255,7 +255,7 @@ class ImageCollectionOutput(BaseInvocationOutput):
|
||||
)
|
||||
|
||||
|
||||
@invocation("image", title="Image Primitive", tags=["primitives", "image"], category="primitives", version="1.0.2")
|
||||
@invocation("image", title="Image Primitive", tags=["primitives", "image"], category="primitives", version="1.0.1")
|
||||
class ImageInvocation(BaseInvocation):
|
||||
"""An image primitive value"""
|
||||
|
||||
@ -276,7 +276,7 @@ class ImageInvocation(BaseInvocation):
|
||||
title="Image Collection Primitive",
|
||||
tags=["primitives", "image", "collection"],
|
||||
category="primitives",
|
||||
version="1.0.1",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ImageCollectionInvocation(BaseInvocation):
|
||||
"""A collection of image primitive values"""
|
||||
@ -341,7 +341,7 @@ class LatentsCollectionOutput(BaseInvocationOutput):
|
||||
|
||||
|
||||
@invocation(
|
||||
"latents", title="Latents Primitive", tags=["primitives", "latents"], category="primitives", version="1.0.2"
|
||||
"latents", title="Latents Primitive", tags=["primitives", "latents"], category="primitives", version="1.0.1"
|
||||
)
|
||||
class LatentsInvocation(BaseInvocation):
|
||||
"""A latents tensor primitive value"""
|
||||
@ -359,7 +359,7 @@ class LatentsInvocation(BaseInvocation):
|
||||
title="Latents Collection Primitive",
|
||||
tags=["primitives", "latents", "collection"],
|
||||
category="primitives",
|
||||
version="1.0.1",
|
||||
version="1.0.0",
|
||||
)
|
||||
class LatentsCollectionInvocation(BaseInvocation):
|
||||
"""A collection of latents tensor primitive values"""
|
||||
@ -393,7 +393,7 @@ class ColorCollectionOutput(BaseInvocationOutput):
|
||||
)
|
||||
|
||||
|
||||
@invocation("color", title="Color Primitive", tags=["primitives", "color"], category="primitives", version="1.0.1")
|
||||
@invocation("color", title="Color Primitive", tags=["primitives", "color"], category="primitives", version="1.0.0")
|
||||
class ColorInvocation(BaseInvocation):
|
||||
"""A color primitive value"""
|
||||
|
||||
@ -433,7 +433,7 @@ class ConditioningCollectionOutput(BaseInvocationOutput):
|
||||
title="Conditioning Primitive",
|
||||
tags=["primitives", "conditioning"],
|
||||
category="primitives",
|
||||
version="1.0.1",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ConditioningInvocation(BaseInvocation):
|
||||
"""A conditioning tensor primitive value"""
|
||||
@ -449,7 +449,7 @@ class ConditioningInvocation(BaseInvocation):
|
||||
title="Conditioning Collection Primitive",
|
||||
tags=["primitives", "conditioning", "collection"],
|
||||
category="primitives",
|
||||
version="1.0.2",
|
||||
version="1.0.1",
|
||||
)
|
||||
class ConditioningCollectionInvocation(BaseInvocation):
|
||||
"""A collection of conditioning tensor primitive values"""
|
||||
|
@ -17,7 +17,7 @@ from .fields import InputField, UIComponent
|
||||
title="Dynamic Prompt",
|
||||
tags=["prompt", "collection"],
|
||||
category="prompt",
|
||||
version="1.0.1",
|
||||
version="1.0.0",
|
||||
use_cache=False,
|
||||
)
|
||||
class DynamicPromptInvocation(BaseInvocation):
|
||||
@ -46,7 +46,7 @@ class DynamicPromptInvocation(BaseInvocation):
|
||||
title="Prompts from File",
|
||||
tags=["prompt", "file"],
|
||||
category="prompt",
|
||||
version="1.0.2",
|
||||
version="1.0.1",
|
||||
)
|
||||
class PromptsFromFileInvocation(BaseInvocation):
|
||||
"""Loads prompts from a text file"""
|
||||
|
@ -8,7 +8,7 @@ from .baseinvocation import (
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from .model import CLIPField, ModelIdentifierField, UNetField, VAEField
|
||||
from .model import ClipField, MainModelField, ModelInfo, UNetField, VaeField
|
||||
|
||||
|
||||
@invocation_output("sdxl_model_loader_output")
|
||||
@ -16,9 +16,9 @@ class SDXLModelLoaderOutput(BaseInvocationOutput):
|
||||
"""SDXL base model loader output"""
|
||||
|
||||
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
|
||||
clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 1")
|
||||
clip2: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
|
||||
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 1")
|
||||
clip2: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
|
||||
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
|
||||
|
||||
@invocation_output("sdxl_refiner_model_loader_output")
|
||||
@ -26,15 +26,15 @@ class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
|
||||
"""SDXL refiner model loader output"""
|
||||
|
||||
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
|
||||
clip2: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
|
||||
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
clip2: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
|
||||
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
|
||||
|
||||
@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.2")
|
||||
@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.1")
|
||||
class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads an sdxl base model, outputting its submodels."""
|
||||
|
||||
model: ModelIdentifierField = InputField(
|
||||
model: MainModelField = InputField(
|
||||
description=FieldDescriptions.sdxl_main_model, input=Input.Direct, ui_type=UIType.SDXLMainModel
|
||||
)
|
||||
# TODO: precision?
|
||||
@ -46,19 +46,48 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
if not context.models.exists(model_key):
|
||||
raise Exception(f"Unknown model: {model_key}")
|
||||
|
||||
unet = self.model.model_copy(update={"submodel_type": SubModelType.UNet})
|
||||
scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
|
||||
tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
|
||||
text_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
|
||||
tokenizer2 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
|
||||
text_encoder2 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
|
||||
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
|
||||
|
||||
return SDXLModelLoaderOutput(
|
||||
unet=UNetField(unet=unet, scheduler=scheduler, loras=[]),
|
||||
clip=CLIPField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[], skipped_layers=0),
|
||||
clip2=CLIPField(tokenizer=tokenizer2, text_encoder=text_encoder2, loras=[], skipped_layers=0),
|
||||
vae=VAEField(vae=vae),
|
||||
unet=UNetField(
|
||||
unet=ModelInfo(
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.UNet,
|
||||
),
|
||||
scheduler=ModelInfo(
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.Scheduler,
|
||||
),
|
||||
loras=[],
|
||||
),
|
||||
clip=ClipField(
|
||||
tokenizer=ModelInfo(
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.Tokenizer,
|
||||
),
|
||||
text_encoder=ModelInfo(
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.TextEncoder,
|
||||
),
|
||||
loras=[],
|
||||
skipped_layers=0,
|
||||
),
|
||||
clip2=ClipField(
|
||||
tokenizer=ModelInfo(
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.Tokenizer2,
|
||||
),
|
||||
text_encoder=ModelInfo(
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.TextEncoder2,
|
||||
),
|
||||
loras=[],
|
||||
skipped_layers=0,
|
||||
),
|
||||
vae=VaeField(
|
||||
vae=ModelInfo(
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.Vae,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@ -67,13 +96,15 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
title="SDXL Refiner Model",
|
||||
tags=["model", "sdxl", "refiner"],
|
||||
category="model",
|
||||
version="1.0.2",
|
||||
version="1.0.1",
|
||||
)
|
||||
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads an sdxl refiner model, outputting its submodels."""
|
||||
|
||||
model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.sdxl_refiner_model, input=Input.Direct, ui_type=UIType.SDXLRefinerModel
|
||||
model: MainModelField = InputField(
|
||||
description=FieldDescriptions.sdxl_refiner_model,
|
||||
input=Input.Direct,
|
||||
ui_type=UIType.SDXLRefinerModel,
|
||||
)
|
||||
# TODO: precision?
|
||||
|
||||
@ -84,14 +115,34 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
||||
if not context.models.exists(model_key):
|
||||
raise Exception(f"Unknown model: {model_key}")
|
||||
|
||||
unet = self.model.model_copy(update={"submodel_type": SubModelType.UNet})
|
||||
scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
|
||||
tokenizer2 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
|
||||
text_encoder2 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
|
||||
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
|
||||
|
||||
return SDXLRefinerModelLoaderOutput(
|
||||
unet=UNetField(unet=unet, scheduler=scheduler, loras=[]),
|
||||
clip2=CLIPField(tokenizer=tokenizer2, text_encoder=text_encoder2, loras=[], skipped_layers=0),
|
||||
vae=VAEField(vae=vae),
|
||||
unet=UNetField(
|
||||
unet=ModelInfo(
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.UNet,
|
||||
),
|
||||
scheduler=ModelInfo(
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.Scheduler,
|
||||
),
|
||||
loras=[],
|
||||
),
|
||||
clip2=ClipField(
|
||||
tokenizer=ModelInfo(
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.Tokenizer2,
|
||||
),
|
||||
text_encoder=ModelInfo(
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.TextEncoder2,
|
||||
),
|
||||
loras=[],
|
||||
skipped_layers=0,
|
||||
),
|
||||
vae=VaeField(
|
||||
vae=ModelInfo(
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.Vae,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
@ -27,7 +27,7 @@ class StringPosNegOutput(BaseInvocationOutput):
|
||||
title="String Split Negative",
|
||||
tags=["string", "split", "negative"],
|
||||
category="string",
|
||||
version="1.0.1",
|
||||
version="1.0.0",
|
||||
)
|
||||
class StringSplitNegInvocation(BaseInvocation):
|
||||
"""Splits string into two strings, inside [] goes into negative string everthing else goes into positive string. Each [ and ] character is replaced with a space"""
|
||||
@ -69,7 +69,7 @@ class String2Output(BaseInvocationOutput):
|
||||
string_2: str = OutputField(description="string 2")
|
||||
|
||||
|
||||
@invocation("string_split", title="String Split", tags=["string", "split"], category="string", version="1.0.1")
|
||||
@invocation("string_split", title="String Split", tags=["string", "split"], category="string", version="1.0.0")
|
||||
class StringSplitInvocation(BaseInvocation):
|
||||
"""Splits string into two strings, based on the first occurance of the delimiter. The delimiter will be removed from the string"""
|
||||
|
||||
@ -89,7 +89,7 @@ class StringSplitInvocation(BaseInvocation):
|
||||
return String2Output(string_1=part1, string_2=part2)
|
||||
|
||||
|
||||
@invocation("string_join", title="String Join", tags=["string", "join"], category="string", version="1.0.1")
|
||||
@invocation("string_join", title="String Join", tags=["string", "join"], category="string", version="1.0.0")
|
||||
class StringJoinInvocation(BaseInvocation):
|
||||
"""Joins string left to string right"""
|
||||
|
||||
@ -100,7 +100,7 @@ class StringJoinInvocation(BaseInvocation):
|
||||
return StringOutput(value=((self.string_left or "") + (self.string_right or "")))
|
||||
|
||||
|
||||
@invocation("string_join_three", title="String Join Three", tags=["string", "join"], category="string", version="1.0.1")
|
||||
@invocation("string_join_three", title="String Join Three", tags=["string", "join"], category="string", version="1.0.0")
|
||||
class StringJoinThreeInvocation(BaseInvocation):
|
||||
"""Joins string left to string middle to string right"""
|
||||
|
||||
@ -113,7 +113,7 @@ class StringJoinThreeInvocation(BaseInvocation):
|
||||
|
||||
|
||||
@invocation(
|
||||
"string_replace", title="String Replace", tags=["string", "replace", "regex"], category="string", version="1.0.1"
|
||||
"string_replace", title="String Replace", tags=["string", "replace", "regex"], category="string", version="1.0.0"
|
||||
)
|
||||
class StringReplaceInvocation(BaseInvocation):
|
||||
"""Replaces the search string with the replace string"""
|
||||
|
@ -9,15 +9,18 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField, UIType
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField
|
||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
|
||||
class T2IAdapterModelField(BaseModel):
|
||||
key: str = Field(description="Model record key for the T2I-Adapter model")
|
||||
|
||||
|
||||
class T2IAdapterField(BaseModel):
|
||||
image: ImageField = Field(description="The T2I-Adapter image prompt.")
|
||||
t2i_adapter_model: ModelIdentifierField = Field(description="The T2I-Adapter model to use.")
|
||||
t2i_adapter_model: T2IAdapterModelField = Field(description="The T2I-Adapter model to use.")
|
||||
weight: Union[float, list[float]] = Field(default=1, description="The weight given to the T2I-Adapter")
|
||||
begin_step_percent: float = Field(
|
||||
default=0, ge=0, le=1, description="When the T2I-Adapter is first applied (% of total steps)"
|
||||
@ -45,19 +48,18 @@ class T2IAdapterOutput(BaseInvocationOutput):
|
||||
|
||||
|
||||
@invocation(
|
||||
"t2i_adapter", title="T2I-Adapter", tags=["t2i_adapter", "control"], category="t2i_adapter", version="1.0.2"
|
||||
"t2i_adapter", title="T2I-Adapter", tags=["t2i_adapter", "control"], category="t2i_adapter", version="1.0.1"
|
||||
)
|
||||
class T2IAdapterInvocation(BaseInvocation):
|
||||
"""Collects T2I-Adapter info to pass to other nodes."""
|
||||
|
||||
# Inputs
|
||||
image: ImageField = InputField(description="The IP-Adapter image prompt.")
|
||||
t2i_adapter_model: ModelIdentifierField = InputField(
|
||||
t2i_adapter_model: T2IAdapterModelField = InputField(
|
||||
description="The T2I-Adapter model.",
|
||||
title="T2I-Adapter Model",
|
||||
input=Input.Direct,
|
||||
ui_order=-1,
|
||||
ui_type=UIType.T2IAdapterModel,
|
||||
)
|
||||
weight: Union[float, list[float]] = InputField(
|
||||
default=1, ge=0, description="The weight given to the T2I-Adapter", title="Weight"
|
||||
|
@ -39,7 +39,7 @@ class CalculateImageTilesOutput(BaseInvocationOutput):
|
||||
title="Calculate Image Tiles",
|
||||
tags=["tiles"],
|
||||
category="tiles",
|
||||
version="1.0.1",
|
||||
version="1.0.0",
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
class CalculateImageTilesInvocation(BaseInvocation):
|
||||
@ -73,7 +73,7 @@ class CalculateImageTilesInvocation(BaseInvocation):
|
||||
title="Calculate Image Tiles Even Split",
|
||||
tags=["tiles"],
|
||||
category="tiles",
|
||||
version="1.1.1",
|
||||
version="1.1.0",
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
class CalculateImageTilesEvenSplitInvocation(BaseInvocation):
|
||||
@ -116,7 +116,7 @@ class CalculateImageTilesEvenSplitInvocation(BaseInvocation):
|
||||
title="Calculate Image Tiles Minimum Overlap",
|
||||
tags=["tiles"],
|
||||
category="tiles",
|
||||
version="1.0.1",
|
||||
version="1.0.0",
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
class CalculateImageTilesMinimumOverlapInvocation(BaseInvocation):
|
||||
@ -167,7 +167,7 @@ class TileToPropertiesOutput(BaseInvocationOutput):
|
||||
title="Tile to Properties",
|
||||
tags=["tiles"],
|
||||
category="tiles",
|
||||
version="1.0.1",
|
||||
version="1.0.0",
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
class TileToPropertiesInvocation(BaseInvocation):
|
||||
@ -200,7 +200,7 @@ class PairTileImageOutput(BaseInvocationOutput):
|
||||
title="Pair Tile with Image",
|
||||
tags=["tiles"],
|
||||
category="tiles",
|
||||
version="1.0.1",
|
||||
version="1.0.0",
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
class PairTileImageInvocation(BaseInvocation):
|
||||
@ -229,7 +229,7 @@ BLEND_MODES = Literal["Linear", "Seam"]
|
||||
title="Merge Tiles to Image",
|
||||
tags=["tiles"],
|
||||
category="tiles",
|
||||
version="1.1.1",
|
||||
version="1.1.0",
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
class MergeTilesToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
@ -11,7 +11,6 @@ from pydantic import ConfigDict
|
||||
from invokeai.app.invocations.fields import ImageField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.download_with_progress import download_with_progress_bar
|
||||
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
|
||||
from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN
|
||||
from invokeai.backend.util.devices import choose_torch_device
|
||||
@ -28,18 +27,11 @@ ESRGAN_MODELS = Literal[
|
||||
"RealESRGAN_x2plus.pth",
|
||||
]
|
||||
|
||||
ESRGAN_MODEL_URLS: dict[str, str] = {
|
||||
"RealESRGAN_x4plus.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
|
||||
"RealESRGAN_x4plus_anime_6B.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
|
||||
"ESRGAN_SRx4_DF2KOST_official.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
|
||||
"RealESRGAN_x2plus.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
|
||||
}
|
||||
|
||||
if choose_torch_device() == torch.device("mps"):
|
||||
from torch import mps
|
||||
|
||||
|
||||
@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.3.2")
|
||||
@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.3.1")
|
||||
class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Upscales an image using RealESRGAN."""
|
||||
|
||||
@ -53,6 +45,7 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
models_path = context.config.get().models_path
|
||||
|
||||
rrdbnet_model = None
|
||||
netscale = None
|
||||
@ -99,16 +92,11 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
context.logger.error(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
esrgan_model_path = Path(context.config.get().models_path, f"core/upscaling/realesrgan/{self.model_name}")
|
||||
|
||||
# Downloads the ESRGAN model if it doesn't already exist
|
||||
download_with_progress_bar(
|
||||
name=self.model_name, url=ESRGAN_MODEL_URLS[self.model_name], dest_path=esrgan_model_path
|
||||
)
|
||||
esrgan_model_path = Path(f"core/upscaling/realesrgan/{self.model_name}")
|
||||
|
||||
upscaler = RealESRGAN(
|
||||
scale=netscale,
|
||||
model_path=esrgan_model_path,
|
||||
model_path=models_path / esrgan_model_path,
|
||||
model=rrdbnet_model,
|
||||
half=False,
|
||||
tile=self.tile_size,
|
||||
|
@ -1,12 +0,0 @@
|
||||
"""This is a wrapper around the main app entrypoint, to allow for CLI args to be parsed before running the app."""
|
||||
|
||||
|
||||
def run_app() -> None:
|
||||
# Before doing _anything_, parse CLI args!
|
||||
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
|
||||
|
||||
InvokeAIArgs.parse_args()
|
||||
|
||||
from invokeai.app.api_app import invoke_api
|
||||
|
||||
invoke_api()
|
@ -2,6 +2,6 @@
|
||||
|
||||
from invokeai.app.services.config.config_common import PagingArgumentParser
|
||||
|
||||
from .config_default import InvokeAIAppConfig, get_config
|
||||
from .config_default import InvokeAIAppConfig, get_invokeai_config
|
||||
|
||||
__all__ = ["InvokeAIAppConfig", "get_config", "PagingArgumentParser"]
|
||||
__all__ = ["InvokeAIAppConfig", "get_invokeai_config", "PagingArgumentParser"]
|
||||
|
224
invokeai/app/services/config/config_base.py
Normal file
224
invokeai/app/services/config/config_base.py
Normal file
@ -0,0 +1,224 @@
|
||||
# Copyright (c) 2023 Lincoln Stein (https://github.com/lstein) and the InvokeAI Development Team
|
||||
|
||||
"""
|
||||
Base class for the InvokeAI configuration system.
|
||||
It defines a type of pydantic BaseSettings object that
|
||||
is able to read and write from an omegaconf-based config file,
|
||||
with overriding of settings from environment variables and/or
|
||||
the command line.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from argparse import ArgumentParser
|
||||
from pathlib import Path
|
||||
from typing import Any, ClassVar, Dict, List, Literal, Optional, Union, get_args, get_origin, get_type_hints
|
||||
|
||||
from omegaconf import DictConfig, ListConfig, OmegaConf
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
from invokeai.app.services.config.config_common import PagingArgumentParser, int_or_float_or_str
|
||||
|
||||
|
||||
class InvokeAISettings(BaseSettings):
|
||||
"""Runtime configuration settings in which default values are read from an omegaconf .yaml file."""
|
||||
|
||||
initconf: ClassVar[Optional[DictConfig]] = None
|
||||
argparse_groups: ClassVar[Dict[str, Any]] = {}
|
||||
|
||||
model_config = SettingsConfigDict(env_file_encoding="utf-8", arbitrary_types_allowed=True, case_sensitive=True)
|
||||
|
||||
def parse_args(self, argv: Optional[List[str]] = sys.argv[1:]) -> None:
|
||||
"""Call to parse command-line arguments."""
|
||||
parser = self.get_parser()
|
||||
opt, unknown_opts = parser.parse_known_args(argv)
|
||||
if len(unknown_opts) > 0:
|
||||
print("Unknown args:", unknown_opts)
|
||||
for name in self.model_fields:
|
||||
if name not in self._excluded():
|
||||
value = getattr(opt, name)
|
||||
if isinstance(value, ListConfig):
|
||||
value = list(value)
|
||||
elif isinstance(value, DictConfig):
|
||||
value = dict(value)
|
||||
setattr(self, name, value)
|
||||
|
||||
def to_yaml(self) -> str:
|
||||
"""Return a YAML string representing our settings. This can be used as the contents of `invokeai.yaml` to restore settings later."""
|
||||
cls = self.__class__
|
||||
type = get_args(get_type_hints(cls)["type"])[0]
|
||||
field_dict: Dict[str, Dict[str, Any]] = {type: {}}
|
||||
for name, field in self.model_fields.items():
|
||||
if name in cls._excluded_from_yaml():
|
||||
continue
|
||||
assert isinstance(field.json_schema_extra, dict)
|
||||
category = (
|
||||
field.json_schema_extra.get("category", "Uncategorized") if field.json_schema_extra else "Uncategorized"
|
||||
)
|
||||
value = getattr(self, name)
|
||||
assert isinstance(category, str)
|
||||
if category not in field_dict[type]:
|
||||
field_dict[type][category] = {}
|
||||
# keep paths as strings to make it easier to read
|
||||
field_dict[type][category][name] = str(value) if isinstance(value, Path) else value
|
||||
conf = OmegaConf.create(field_dict)
|
||||
return OmegaConf.to_yaml(conf)
|
||||
|
||||
@classmethod
|
||||
def add_parser_arguments(cls, parser: ArgumentParser) -> None:
|
||||
"""Dynamically create arguments for a settings parser."""
|
||||
if "type" in get_type_hints(cls):
|
||||
settings_stanza = get_args(get_type_hints(cls)["type"])[0]
|
||||
else:
|
||||
settings_stanza = "Uncategorized"
|
||||
|
||||
env_prefix = getattr(cls.model_config, "env_prefix", None)
|
||||
env_prefix = env_prefix if env_prefix is not None else settings_stanza.upper()
|
||||
|
||||
initconf = (
|
||||
cls.initconf.get(settings_stanza)
|
||||
if cls.initconf and settings_stanza in cls.initconf
|
||||
else OmegaConf.create()
|
||||
)
|
||||
|
||||
# create an upcase version of the environment in
|
||||
# order to achieve case-insensitive environment
|
||||
# variables (the way Windows does)
|
||||
upcase_environ = {}
|
||||
for key, value in os.environ.items():
|
||||
upcase_environ[key.upper()] = value
|
||||
|
||||
fields = cls.model_fields
|
||||
cls.argparse_groups = {}
|
||||
|
||||
for name, field in fields.items():
|
||||
if name not in cls._excluded():
|
||||
current_default = field.default
|
||||
|
||||
category = (
|
||||
field.json_schema_extra.get("category", "Uncategorized")
|
||||
if field.json_schema_extra
|
||||
else "Uncategorized"
|
||||
)
|
||||
env_name = env_prefix + "_" + name
|
||||
if category in initconf and name in initconf.get(category):
|
||||
field.default = initconf.get(category).get(name)
|
||||
if env_name.upper() in upcase_environ:
|
||||
field.default = upcase_environ[env_name.upper()]
|
||||
cls.add_field_argument(parser, name, field)
|
||||
|
||||
field.default = current_default
|
||||
|
||||
@classmethod
|
||||
def cmd_name(cls, command_field: str = "type") -> str:
|
||||
"""Return the category of a setting."""
|
||||
hints = get_type_hints(cls)
|
||||
if command_field in hints:
|
||||
result: str = get_args(hints[command_field])[0]
|
||||
return result
|
||||
else:
|
||||
return "Uncategorized"
|
||||
|
||||
@classmethod
|
||||
def get_parser(cls) -> ArgumentParser:
|
||||
"""Get the command-line parser for a setting."""
|
||||
parser = PagingArgumentParser(
|
||||
prog=cls.cmd_name(),
|
||||
description=cls.__doc__,
|
||||
)
|
||||
cls.add_parser_arguments(parser)
|
||||
return parser
|
||||
|
||||
@classmethod
|
||||
def _excluded(cls) -> List[str]:
|
||||
# internal fields that shouldn't be exposed as command line options
|
||||
return ["type", "initconf"]
|
||||
|
||||
@classmethod
|
||||
def _excluded_from_yaml(cls) -> List[str]:
|
||||
# combination of deprecated parameters and internal ones that shouldn't be exposed as invokeai.yaml options
|
||||
return [
|
||||
"type",
|
||||
"initconf",
|
||||
"version",
|
||||
"from_file",
|
||||
"model",
|
||||
"root",
|
||||
"max_cache_size",
|
||||
"max_vram_cache_size",
|
||||
"always_use_cpu",
|
||||
"free_gpu_mem",
|
||||
"xformers_enabled",
|
||||
"tiled_decode",
|
||||
"lora_dir",
|
||||
"embedding_dir",
|
||||
"controlnet_dir",
|
||||
"conf_path",
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def add_field_argument(cls, command_parser, name: str, field, default_override=None) -> None:
|
||||
"""Add the argparse arguments for a setting parser."""
|
||||
field_type = get_type_hints(cls).get(name)
|
||||
default = (
|
||||
default_override
|
||||
if default_override is not None
|
||||
else field.default
|
||||
if field.default_factory is None
|
||||
else field.default_factory()
|
||||
)
|
||||
if category := (field.json_schema_extra.get("category", None) if field.json_schema_extra else None):
|
||||
if category not in cls.argparse_groups:
|
||||
cls.argparse_groups[category] = command_parser.add_argument_group(category)
|
||||
argparse_group = cls.argparse_groups[category]
|
||||
else:
|
||||
argparse_group = command_parser
|
||||
|
||||
if get_origin(field_type) == Literal:
|
||||
allowed_values = get_args(field.annotation)
|
||||
allowed_types = set()
|
||||
for val in allowed_values:
|
||||
allowed_types.add(type(val))
|
||||
allowed_types_list = list(allowed_types)
|
||||
field_type = allowed_types_list[0] if len(allowed_types) == 1 else int_or_float_or_str
|
||||
|
||||
argparse_group.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=field_type,
|
||||
default=default,
|
||||
choices=allowed_values,
|
||||
help=field.description,
|
||||
)
|
||||
|
||||
elif get_origin(field_type) == Union:
|
||||
argparse_group.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=int_or_float_or_str,
|
||||
default=default,
|
||||
help=field.description,
|
||||
)
|
||||
|
||||
elif get_origin(field_type) == list:
|
||||
argparse_group.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
nargs="*",
|
||||
type=field.annotation,
|
||||
default=default,
|
||||
action=argparse.BooleanOptionalAction if field.annotation == bool else "store",
|
||||
help=field.description,
|
||||
)
|
||||
else:
|
||||
argparse_group.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=field.annotation,
|
||||
default=default,
|
||||
action=argparse.BooleanOptionalAction if field.annotation == bool else "store",
|
||||
help=field.description,
|
||||
)
|
@ -12,6 +12,7 @@ from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import pydoc
|
||||
from typing import Union
|
||||
|
||||
|
||||
class PagingArgumentParser(argparse.ArgumentParser):
|
||||
@ -23,3 +24,18 @@ class PagingArgumentParser(argparse.ArgumentParser):
|
||||
def print_help(self, file=None) -> None:
|
||||
text = self.format_help()
|
||||
pydoc.pager(text)
|
||||
|
||||
|
||||
def int_or_float_or_str(value: str) -> Union[int, float, str]:
|
||||
"""
|
||||
Workaround for argparse type checking.
|
||||
"""
|
||||
try:
|
||||
return int(value)
|
||||
except Exception as e: # noqa F841
|
||||
pass
|
||||
try:
|
||||
return float(value)
|
||||
except Exception as e: # noqa F841
|
||||
pass
|
||||
return str(value)
|
||||
|
@ -1,23 +1,183 @@
|
||||
# TODO(psyche): pydantic-settings supports YAML settings sources. If we can figure out a way to integrate the YAML
|
||||
# migration logic, we could use that for simpler config loading.
|
||||
# Copyright (c) 2023 Lincoln Stein (https://github.com/lstein) and the InvokeAI Development Team
|
||||
|
||||
"""Invokeai configuration system.
|
||||
|
||||
Arguments and fields are taken from the pydantic definition of the
|
||||
model. Defaults can be set by creating a yaml configuration file that
|
||||
has a top-level key of "InvokeAI" and subheadings for each of the
|
||||
categories returned by `invokeai --help`. The file looks like this:
|
||||
|
||||
[file: invokeai.yaml]
|
||||
|
||||
InvokeAI:
|
||||
Web Server:
|
||||
host: 127.0.0.1
|
||||
port: 9090
|
||||
allow_origins: []
|
||||
allow_credentials: true
|
||||
allow_methods:
|
||||
- '*'
|
||||
allow_headers:
|
||||
- '*'
|
||||
Features:
|
||||
esrgan: true
|
||||
internet_available: true
|
||||
log_tokenization: false
|
||||
patchmatch: true
|
||||
ignore_missing_core_models: false
|
||||
Paths:
|
||||
autoimport_dir: autoimport
|
||||
lora_dir: null
|
||||
embedding_dir: null
|
||||
controlnet_dir: null
|
||||
models_dir: models
|
||||
legacy_conf_dir: configs/stable-diffusion
|
||||
db_dir: databases
|
||||
outdir: /home/lstein/invokeai-main/outputs
|
||||
use_memory_db: false
|
||||
Logging:
|
||||
log_handlers:
|
||||
- console
|
||||
log_format: plain
|
||||
log_level: info
|
||||
Model Cache:
|
||||
ram: 13.5
|
||||
vram: 0.25
|
||||
lazy_offload: true
|
||||
log_memory_usage: false
|
||||
Device:
|
||||
device: auto
|
||||
precision: auto
|
||||
Generation:
|
||||
sequential_guidance: false
|
||||
attention_type: xformers
|
||||
attention_slice_size: auto
|
||||
force_tiled_decode: false
|
||||
|
||||
The default name of the configuration file is `invokeai.yaml`, located
|
||||
in INVOKEAI_ROOT. You can replace supersede this by providing any
|
||||
OmegaConf dictionary object initialization time:
|
||||
|
||||
omegaconf = OmegaConf.load('/tmp/init.yaml')
|
||||
conf = InvokeAIAppConfig()
|
||||
conf.parse_args(conf=omegaconf)
|
||||
|
||||
InvokeAIAppConfig.parse_args() will parse the contents of `sys.argv`
|
||||
at initialization time. You may pass a list of strings in the optional
|
||||
`argv` argument to use instead of the system argv:
|
||||
|
||||
conf.parse_args(argv=['--log_tokenization'])
|
||||
|
||||
It is also possible to set a value at initialization time. However, if
|
||||
you call parse_args() it may be overwritten.
|
||||
|
||||
conf = InvokeAIAppConfig(log_tokenization=True)
|
||||
conf.parse_args(argv=['--no-log_tokenization'])
|
||||
conf.log_tokenization
|
||||
# False
|
||||
|
||||
To avoid this, use `get_config()` to retrieve the application-wide
|
||||
configuration object. This will retain any properties set at object
|
||||
creation time:
|
||||
|
||||
conf = InvokeAIAppConfig.get_config(log_tokenization=True)
|
||||
conf.parse_args(argv=['--no-log_tokenization'])
|
||||
conf.log_tokenization
|
||||
# True
|
||||
|
||||
Any setting can be overwritten by setting an environment variable of
|
||||
form: "INVOKEAI_<setting>", as in:
|
||||
|
||||
export INVOKEAI_port=8080
|
||||
|
||||
Order of precedence (from highest):
|
||||
1) initialization options
|
||||
2) command line options
|
||||
3) environment variable options
|
||||
4) config file options
|
||||
5) pydantic defaults
|
||||
|
||||
Typical usage at the top level file:
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
# get global configuration and print its cache size
|
||||
conf = InvokeAIAppConfig.get_config()
|
||||
conf.parse_args()
|
||||
print(conf.ram_cache_size)
|
||||
|
||||
Typical usage in a backend module:
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
# get global configuration and print its cache size value
|
||||
conf = InvokeAIAppConfig.get_config()
|
||||
print(conf.ram_cache_size)
|
||||
|
||||
Computed properties:
|
||||
|
||||
The InvokeAIAppConfig object has a series of properties that
|
||||
resolve paths relative to the runtime root directory. They each return
|
||||
a Path object:
|
||||
|
||||
root_path - path to InvokeAI root
|
||||
output_path - path to default outputs directory
|
||||
conf - alias for the above
|
||||
embedding_path - path to the embeddings directory
|
||||
lora_path - path to the LoRA directory
|
||||
|
||||
In most cases, you will want to create a single InvokeAIAppConfig
|
||||
object for the entire application. The InvokeAIAppConfig.get_config() function
|
||||
does this:
|
||||
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
config.parse_args() # read values from the command line/config file
|
||||
print(config.root)
|
||||
|
||||
# Subclassing
|
||||
|
||||
If you wish to create a similar class, please subclass the
|
||||
`InvokeAISettings` class and define a Literal field named "type",
|
||||
which is set to the desired top-level name. For example, to create a
|
||||
"InvokeBatch" configuration, define like this:
|
||||
|
||||
class InvokeBatch(InvokeAISettings):
|
||||
type: Literal["InvokeBatch"] = "InvokeBatch"
|
||||
node_count : int = Field(default=1, description="Number of nodes to run on", json_schema_extra=dict(category='Resources'))
|
||||
cpu_count : int = Field(default=8, description="Number of GPUs to run on per node", json_schema_extra=dict(category='Resources'))
|
||||
|
||||
This will now read and write from the "InvokeBatch" section of the
|
||||
config file, look for environment variables named INVOKEBATCH_*, and
|
||||
accept the command-line arguments `--node_count` and `--cpu_count`. The
|
||||
two configs are kept in separate sections of the config file:
|
||||
|
||||
# invokeai.yaml
|
||||
|
||||
InvokeBatch:
|
||||
Resources:
|
||||
node_count: 1
|
||||
cpu_count: 8
|
||||
|
||||
InvokeAI:
|
||||
Paths:
|
||||
root: /home/lstein/invokeai-main
|
||||
legacy_conf_dir: configs/stable-diffusion
|
||||
outdir: outputs
|
||||
...
|
||||
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal, Optional
|
||||
from typing import Any, ClassVar, Dict, List, Literal, Optional
|
||||
|
||||
import psutil
|
||||
import yaml
|
||||
from pydantic import BaseModel, Field, PrivateAttr, field_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from pydantic import Field
|
||||
from pydantic.config import JsonDict
|
||||
from pydantic_settings import SettingsConfigDict
|
||||
|
||||
import invokeai.configs as model_configs
|
||||
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
|
||||
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
|
||||
from .config_base import InvokeAISettings
|
||||
|
||||
INIT_FILE = Path("invokeai.yaml")
|
||||
DB_FILE = Path("invokeai.db")
|
||||
@ -25,439 +185,308 @@ LEGACY_INIT_FILE = Path("invokeai.init")
|
||||
DEFAULT_RAM_CACHE = 10.0
|
||||
DEFAULT_VRAM_CACHE = 0.25
|
||||
DEFAULT_CONVERT_CACHE = 20.0
|
||||
DEVICE = Literal["auto", "cpu", "cuda", "cuda:1", "mps"]
|
||||
PRECISION = Literal["auto", "float16", "bfloat16", "float32", "autocast"]
|
||||
ATTENTION_TYPE = Literal["auto", "normal", "xformers", "sliced", "torch-sdp"]
|
||||
ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8]
|
||||
LOG_FORMAT = Literal["plain", "color", "syslog", "legacy"]
|
||||
LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"]
|
||||
CONFIG_SCHEMA_VERSION = "4.0.0"
|
||||
|
||||
|
||||
def get_default_ram_cache_size() -> float:
|
||||
"""Run a heuristic for the default RAM cache based on installed RAM."""
|
||||
class Categories(object):
|
||||
"""Category headers for configuration variable groups."""
|
||||
|
||||
# On some machines, psutil.virtual_memory().total gives a value that is slightly less than the actual RAM, so the
|
||||
# limits are set slightly lower than than what we expect the actual RAM to be.
|
||||
|
||||
GB = 1024**3
|
||||
max_ram = psutil.virtual_memory().total / GB
|
||||
|
||||
if max_ram >= 60:
|
||||
return 15.0
|
||||
if max_ram >= 30:
|
||||
return 7.5
|
||||
if max_ram >= 14:
|
||||
return 4.0
|
||||
return 2.1 # 2.1 is just large enough for sd 1.5 ;-)
|
||||
WebServer: JsonDict = {"category": "Web Server"}
|
||||
Features: JsonDict = {"category": "Features"}
|
||||
Paths: JsonDict = {"category": "Paths"}
|
||||
Logging: JsonDict = {"category": "Logging"}
|
||||
Development: JsonDict = {"category": "Development"}
|
||||
Other: JsonDict = {"category": "Other"}
|
||||
ModelCache: JsonDict = {"category": "Model Cache"}
|
||||
Device: JsonDict = {"category": "Device"}
|
||||
Generation: JsonDict = {"category": "Generation"}
|
||||
Queue: JsonDict = {"category": "Queue"}
|
||||
Nodes: JsonDict = {"category": "Nodes"}
|
||||
MemoryPerformance: JsonDict = {"category": "Memory/Performance"}
|
||||
|
||||
|
||||
class URLRegexTokenPair(BaseModel):
|
||||
url_regex: str = Field(description="Regular expression to match against the URL")
|
||||
token: str = Field(description="Token to use when the URL matches the regex")
|
||||
class InvokeAIAppConfig(InvokeAISettings):
|
||||
"""Configuration object for InvokeAI App."""
|
||||
|
||||
@field_validator("url_regex")
|
||||
@classmethod
|
||||
def validate_url_regex(cls, v: str) -> str:
|
||||
"""Validate that the value is a valid regex."""
|
||||
try:
|
||||
re.compile(v)
|
||||
except re.error as e:
|
||||
raise ValueError(f"Invalid regex: {e}")
|
||||
return v
|
||||
|
||||
|
||||
class InvokeAIAppConfig(BaseSettings):
|
||||
"""Invoke's global app configuration.
|
||||
|
||||
Typically, you won't need to interact with this class directly. Instead, use the `get_config` function from `invokeai.app.services.config` to get a singleton config object.
|
||||
|
||||
Attributes:
|
||||
host: IP address to bind to. Use `0.0.0.0` to serve to your local network.
|
||||
port: Port to bind to.
|
||||
allow_origins: Allowed CORS origins.
|
||||
allow_credentials: Allow CORS credentials.
|
||||
allow_methods: Methods allowed for CORS.
|
||||
allow_headers: Headers allowed for CORS.
|
||||
ssl_certfile: SSL certificate file for HTTPS. See https://www.uvicorn.org/settings/#https.
|
||||
ssl_keyfile: SSL key file for HTTPS. See https://www.uvicorn.org/settings/#https.
|
||||
log_tokenization: Enable logging of parsed prompt tokens.
|
||||
patchmatch: Enable patchmatch inpaint code.
|
||||
autoimport_dir: Path to a directory of models files to be imported on startup.
|
||||
models_dir: Path to the models directory.
|
||||
convert_cache_dir: Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and store on disk at this location.
|
||||
legacy_conf_dir: Path to directory of legacy checkpoint config files.
|
||||
db_dir: Path to InvokeAI databases directory.
|
||||
outputs_dir: Path to directory for outputs.
|
||||
custom_nodes_dir: Path to directory for custom nodes.
|
||||
log_handlers: Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>".
|
||||
log_format: Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style.<br>Valid values: `plain`, `color`, `syslog`, `legacy`
|
||||
log_level: Emit logging messages at this level or higher.<br>Valid values: `debug`, `info`, `warning`, `error`, `critical`
|
||||
log_sql: Log SQL queries. `log_level` must be `debug` for this to do anything. Extremely verbose.
|
||||
use_memory_db: Use in-memory database. Useful for development.
|
||||
dev_reload: Automatically reload when Python sources are changed. Does not reload node definitions.
|
||||
profile_graphs: Enable graph profiling using `cProfile`.
|
||||
profile_prefix: An optional prefix for profile output files.
|
||||
profiles_dir: Path to profiles output directory.
|
||||
ram: Maximum memory amount used by memory model cache for rapid switching (GB).
|
||||
vram: Amount of VRAM reserved for model storage (GB).
|
||||
convert_cache: Maximum size of on-disk converted models cache (GB).
|
||||
lazy_offload: Keep models in VRAM until their space is needed.
|
||||
log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
|
||||
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
|
||||
precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`, `autocast`
|
||||
sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
|
||||
attention_type: Attention type.<br>Valid values: `auto`, `normal`, `xformers`, `sliced`, `torch-sdp`
|
||||
attention_slice_size: Slice size, valid when attention_type=="sliced".<br>Valid values: `auto`, `balanced`, `max`, `1`, `2`, `3`, `4`, `5`, `6`, `7`, `8`
|
||||
force_tiled_decode: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).
|
||||
pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.
|
||||
max_queue_size: Maximum number of items in the session queue.
|
||||
allow_nodes: List of nodes to allow. Omit to allow all.
|
||||
deny_nodes: List of nodes to deny. Omit to deny none.
|
||||
node_cache_size: How many cached nodes to keep in memory.
|
||||
hashing_algorithm: Model hashing algorthim for model installs. 'blake3' is best for SSDs. 'blake3_single' is best for spinning disk HDDs. 'random' disables hashing, instead assigning a UUID to models. Useful when using a memory db to reduce model installation time, or if you don't care about storing stable hashes for models. Alternatively, any other hashlib algorithm is accepted, though these are not nearly as performant as blake3.<br>Valid values: `md5`, `sha1`, `sha224`, `sha256`, `sha384`, `sha512`, `blake2b`, `blake2s`, `sha3_224`, `sha3_256`, `sha3_384`, `sha3_512`, `shake_128`, `shake_256`, `blake3`, `blake3_single`, `random`
|
||||
remote_api_tokens: List of regular expression and token pairs used when downloading models from URLs. The download URL is tested against the regex, and if it matches, the token is provided in as a Bearer token.
|
||||
"""
|
||||
|
||||
_root: Optional[Path] = PrivateAttr(default=None)
|
||||
_config_file: Optional[Path] = PrivateAttr(default=None)
|
||||
singleton_config: ClassVar[Optional[InvokeAIAppConfig]] = None
|
||||
singleton_init: ClassVar[Optional[Dict[str, Any]]] = None
|
||||
|
||||
# fmt: off
|
||||
|
||||
# INTERNAL
|
||||
schema_version: str = Field(default=CONFIG_SCHEMA_VERSION, description="Schema version of the config file. This is not a user-configurable setting.")
|
||||
# This is only used during v3 models.yaml migration
|
||||
legacy_models_yaml_path: Optional[Path] = Field(default=None, description="Path to the legacy models.yaml file. This is not a user-configurable setting.")
|
||||
type: Literal["InvokeAI"] = "InvokeAI"
|
||||
|
||||
# WEB
|
||||
host: str = Field(default="127.0.0.1", description="IP address to bind to. Use `0.0.0.0` to serve to your local network.")
|
||||
port: int = Field(default=9090, description="Port to bind to.")
|
||||
allow_origins: list[str] = Field(default=[], description="Allowed CORS origins.")
|
||||
allow_credentials: bool = Field(default=True, description="Allow CORS credentials.")
|
||||
allow_methods: list[str] = Field(default=["*"], description="Methods allowed for CORS.")
|
||||
allow_headers: list[str] = Field(default=["*"], description="Headers allowed for CORS.")
|
||||
ssl_certfile: Optional[Path] = Field(default=None, description="SSL certificate file for HTTPS. See https://www.uvicorn.org/settings/#https.")
|
||||
ssl_keyfile: Optional[Path] = Field(default=None, description="SSL key file for HTTPS. See https://www.uvicorn.org/settings/#https.")
|
||||
host : str = Field(default="127.0.0.1", description="IP address to bind to", json_schema_extra=Categories.WebServer)
|
||||
port : int = Field(default=9090, description="Port to bind to", json_schema_extra=Categories.WebServer)
|
||||
allow_origins : List[str] = Field(default=[], description="Allowed CORS origins", json_schema_extra=Categories.WebServer)
|
||||
allow_credentials : bool = Field(default=True, description="Allow CORS credentials", json_schema_extra=Categories.WebServer)
|
||||
allow_methods : List[str] = Field(default=["*"], description="Methods allowed for CORS", json_schema_extra=Categories.WebServer)
|
||||
allow_headers : List[str] = Field(default=["*"], description="Headers allowed for CORS", json_schema_extra=Categories.WebServer)
|
||||
# SSL options correspond to https://www.uvicorn.org/settings/#https
|
||||
ssl_certfile : Optional[Path] = Field(default=None, description="SSL certificate file (for HTTPS)", json_schema_extra=Categories.WebServer)
|
||||
ssl_keyfile : Optional[Path] = Field(default=None, description="SSL key file", json_schema_extra=Categories.WebServer)
|
||||
|
||||
# MISC FEATURES
|
||||
log_tokenization: bool = Field(default=False, description="Enable logging of parsed prompt tokens.")
|
||||
patchmatch: bool = Field(default=True, description="Enable patchmatch inpaint code.")
|
||||
# FEATURES
|
||||
esrgan : bool = Field(default=True, description="Enable/disable upscaling code", json_schema_extra=Categories.Features)
|
||||
internet_available : bool = Field(default=True, description="If true, attempt to download models on the fly; otherwise only use local models", json_schema_extra=Categories.Features)
|
||||
log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", json_schema_extra=Categories.Features)
|
||||
patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", json_schema_extra=Categories.Features)
|
||||
ignore_missing_core_models : bool = Field(default=False, description='Ignore missing models in models/core/convert', json_schema_extra=Categories.Features)
|
||||
|
||||
# PATHS
|
||||
autoimport_dir: Path = Field(default=Path("autoimport"), description="Path to a directory of models files to be imported on startup.")
|
||||
models_dir: Path = Field(default=Path("models"), description="Path to the models directory.")
|
||||
convert_cache_dir: Path = Field(default=Path("models/.cache"), description="Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and store on disk at this location.")
|
||||
legacy_conf_dir: Path = Field(default=Path("configs"), description="Path to directory of legacy checkpoint config files.")
|
||||
db_dir: Path = Field(default=Path("databases"), description="Path to InvokeAI databases directory.")
|
||||
outputs_dir: Path = Field(default=Path("outputs"), description="Path to directory for outputs.")
|
||||
custom_nodes_dir: Path = Field(default=Path("nodes"), description="Path to directory for custom nodes.")
|
||||
root : Optional[Path] = Field(default=None, description='InvokeAI runtime root directory', json_schema_extra=Categories.Paths)
|
||||
autoimport_dir : Path = Field(default=Path('autoimport'), description='Path to a directory of models files to be imported on startup.', json_schema_extra=Categories.Paths)
|
||||
models_dir : Path = Field(default=Path('models'), description='Path to the models directory', json_schema_extra=Categories.Paths)
|
||||
convert_cache_dir : Path = Field(default=Path('models/.cache'), description='Path to the converted models cache directory', json_schema_extra=Categories.Paths)
|
||||
legacy_conf_dir : Path = Field(default=Path('configs/stable-diffusion'), description='Path to directory of legacy checkpoint config files', json_schema_extra=Categories.Paths)
|
||||
db_dir : Path = Field(default=Path('databases'), description='Path to InvokeAI databases directory', json_schema_extra=Categories.Paths)
|
||||
outdir : Path = Field(default=Path('outputs'), description='Default folder for output images', json_schema_extra=Categories.Paths)
|
||||
use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', json_schema_extra=Categories.Paths)
|
||||
custom_nodes_dir : Path = Field(default=Path('nodes'), description='Path to directory for custom nodes', json_schema_extra=Categories.Paths)
|
||||
from_file : Optional[Path] = Field(default=None, description='Take command input from the indicated file (command-line client only)', json_schema_extra=Categories.Paths)
|
||||
|
||||
# LOGGING
|
||||
log_handlers: list[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>".')
|
||||
log_handlers : List[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>"', json_schema_extra=Categories.Logging)
|
||||
# note - would be better to read the log_format values from logging.py, but this creates circular dependencies issues
|
||||
log_format: LOG_FORMAT = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style.')
|
||||
log_level: LOG_LEVEL = Field(default="info", description="Emit logging messages at this level or higher.")
|
||||
log_sql: bool = Field(default=False, description="Log SQL queries. `log_level` must be `debug` for this to do anything. Extremely verbose.")
|
||||
log_format : Literal['plain', 'color', 'syslog', 'legacy'] = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style', json_schema_extra=Categories.Logging)
|
||||
log_level : Literal["debug", "info", "warning", "error", "critical"] = Field(default="info", description="Emit logging messages at this level or higher", json_schema_extra=Categories.Logging)
|
||||
log_sql : bool = Field(default=False, description="Log SQL queries", json_schema_extra=Categories.Logging)
|
||||
|
||||
# Development
|
||||
use_memory_db: bool = Field(default=False, description="Use in-memory database. Useful for development.")
|
||||
dev_reload: bool = Field(default=False, description="Automatically reload when Python sources are changed. Does not reload node definitions.")
|
||||
profile_graphs: bool = Field(default=False, description="Enable graph profiling using `cProfile`.")
|
||||
profile_prefix: Optional[str] = Field(default=None, description="An optional prefix for profile output files.")
|
||||
profiles_dir: Path = Field(default=Path("profiles"), description="Path to profiles output directory.")
|
||||
dev_reload : bool = Field(default=False, description="Automatically reload when Python sources are changed.", json_schema_extra=Categories.Development)
|
||||
profile_graphs : bool = Field(default=False, description="Enable graph profiling", json_schema_extra=Categories.Development)
|
||||
profile_prefix : Optional[str] = Field(default=None, description="An optional prefix for profile output files.", json_schema_extra=Categories.Development)
|
||||
profiles_dir : Path = Field(default=Path('profiles'), description="Directory for graph profiles", json_schema_extra=Categories.Development)
|
||||
|
||||
version : bool = Field(default=False, description="Show InvokeAI version and exit", json_schema_extra=Categories.Other)
|
||||
|
||||
# CACHE
|
||||
ram: float = Field(default_factory=get_default_ram_cache_size, gt=0, description="Maximum memory amount used by memory model cache for rapid switching (GB).")
|
||||
vram: float = Field(default=DEFAULT_VRAM_CACHE, ge=0, description="Amount of VRAM reserved for model storage (GB).")
|
||||
convert_cache: float = Field(default=DEFAULT_CONVERT_CACHE, ge=0, description="Maximum size of on-disk converted models cache (GB).")
|
||||
lazy_offload: bool = Field(default=True, description="Keep models in VRAM until their space is needed.")
|
||||
log_memory_usage: bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.")
|
||||
ram : float = Field(default=DEFAULT_RAM_CACHE, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number, GB)", json_schema_extra=Categories.ModelCache, )
|
||||
vram : float = Field(default=DEFAULT_VRAM_CACHE, ge=0, description="Amount of VRAM reserved for model storage (floating point number, GB)", json_schema_extra=Categories.ModelCache, )
|
||||
convert_cache : float = Field(default=DEFAULT_CONVERT_CACHE, ge=0, description="Maximum size of on-disk converted models cache (GB)", json_schema_extra=Categories.ModelCache)
|
||||
|
||||
lazy_offload : bool = Field(default=True, description="Keep models in VRAM until their space is needed", json_schema_extra=Categories.ModelCache, )
|
||||
log_memory_usage : bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.", json_schema_extra=Categories.ModelCache)
|
||||
|
||||
# DEVICE
|
||||
device: DEVICE = Field(default="auto", description="Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.")
|
||||
precision: PRECISION = Field(default="auto", description="Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.")
|
||||
device : Literal["auto", "cpu", "cuda", "cuda:1", "mps"] = Field(default="auto", description="Generation device", json_schema_extra=Categories.Device)
|
||||
precision : Literal["auto", "float16", "bfloat16", "float32", "autocast"] = Field(default="auto", description="Floating point precision", json_schema_extra=Categories.Device)
|
||||
|
||||
# GENERATION
|
||||
sequential_guidance: bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.")
|
||||
attention_type: ATTENTION_TYPE = Field(default="auto", description="Attention type.")
|
||||
attention_slice_size: ATTENTION_SLICE_SIZE = Field(default="auto", description='Slice size, valid when attention_type=="sliced".')
|
||||
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).")
|
||||
pil_compress_level: int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.")
|
||||
max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.")
|
||||
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", json_schema_extra=Categories.Generation)
|
||||
attention_type : Literal["auto", "normal", "xformers", "sliced", "torch-sdp"] = Field(default="auto", description="Attention type", json_schema_extra=Categories.Generation)
|
||||
attention_slice_size: Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8] = Field(default="auto", description='Slice size, valid when attention_type=="sliced"', json_schema_extra=Categories.Generation)
|
||||
force_tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", json_schema_extra=Categories.Generation)
|
||||
png_compress_level : int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = fastest, largest filesize, 9 = slowest, smallest filesize", json_schema_extra=Categories.Generation)
|
||||
|
||||
# QUEUE
|
||||
max_queue_size : int = Field(default=10000, gt=0, description="Maximum number of items in the session queue", json_schema_extra=Categories.Queue)
|
||||
|
||||
# NODES
|
||||
allow_nodes: Optional[list[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.")
|
||||
deny_nodes: Optional[list[str]] = Field(default=None, description="List of nodes to deny. Omit to deny none.")
|
||||
node_cache_size: int = Field(default=512, description="How many cached nodes to keep in memory.")
|
||||
allow_nodes : Optional[List[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.", json_schema_extra=Categories.Nodes)
|
||||
deny_nodes : Optional[List[str]] = Field(default=None, description="List of nodes to deny. Omit to deny none.", json_schema_extra=Categories.Nodes)
|
||||
node_cache_size : int = Field(default=512, description="How many cached nodes to keep in memory", json_schema_extra=Categories.Nodes)
|
||||
|
||||
# MODEL INSTALL
|
||||
hashing_algorithm: HASHING_ALGORITHMS = Field(default="blake3", description="Model hashing algorthim for model installs. 'blake3' is best for SSDs. 'blake3_single' is best for spinning disk HDDs. 'random' disables hashing, instead assigning a UUID to models. Useful when using a memory db to reduce model installation time, or if you don't care about storing stable hashes for models. Alternatively, any other hashlib algorithm is accepted, though these are not nearly as performant as blake3.")
|
||||
remote_api_tokens: Optional[list[URLRegexTokenPair]] = Field(default=None, description="List of regular expression and token pairs used when downloading models from URLs. The download URL is tested against the regex, and if it matches, the token is provided in as a Bearer token.")
|
||||
# MODEL IMPORT
|
||||
civitai_api_key : Optional[str] = Field(default=os.environ.get("CIVITAI_API_KEY"), description="API key for CivitAI", json_schema_extra=Categories.Other)
|
||||
|
||||
# DEPRECATED FIELDS - STILL HERE IN ORDER TO OBTAN VALUES FROM PRE-3.1 CONFIG FILES
|
||||
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", json_schema_extra=Categories.MemoryPerformance)
|
||||
max_cache_size : Optional[float] = Field(default=None, gt=0, description="Maximum memory amount used by model cache for rapid switching", json_schema_extra=Categories.MemoryPerformance)
|
||||
max_vram_cache_size : Optional[float] = Field(default=None, ge=0, description="Amount of VRAM reserved for model storage", json_schema_extra=Categories.MemoryPerformance)
|
||||
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", json_schema_extra=Categories.MemoryPerformance)
|
||||
tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", json_schema_extra=Categories.MemoryPerformance)
|
||||
lora_dir : Optional[Path] = Field(default=None, description='Path to a directory of LoRA/LyCORIS models to be imported on startup.', json_schema_extra=Categories.Paths)
|
||||
embedding_dir : Optional[Path] = Field(default=None, description='Path to a directory of Textual Inversion embeddings to be imported on startup.', json_schema_extra=Categories.Paths)
|
||||
controlnet_dir : Optional[Path] = Field(default=None, description='Path to a directory of ControlNet embeddings to be imported on startup.', json_schema_extra=Categories.Paths)
|
||||
conf_path : Path = Field(default=Path('configs/models.yaml'), description='Path to models definition file', json_schema_extra=Categories.Paths)
|
||||
|
||||
# this is not referred to in the source code and can be removed entirely
|
||||
#free_gpu_mem : Optional[bool] = Field(default=None, description="If true, purge model from GPU after each generation.", json_schema_extra=Categories.MemoryPerformance)
|
||||
|
||||
# See InvokeAIAppConfig subclass below for CACHE and DEVICE categories
|
||||
# fmt: on
|
||||
|
||||
model_config = SettingsConfigDict(env_prefix="INVOKEAI_", env_ignore_empty=True)
|
||||
model_config = SettingsConfigDict(validate_assignment=True, env_prefix="INVOKEAI")
|
||||
|
||||
def update_config(self, config: dict[str, Any] | InvokeAIAppConfig, clobber: bool = True) -> None:
|
||||
"""Updates the config, overwriting existing values.
|
||||
|
||||
Args:
|
||||
config: A dictionary of config settings, or instance of `InvokeAIAppConfig`. If an instance of \
|
||||
`InvokeAIAppConfig`, only the explicitly set fields will be merged into the singleton config.
|
||||
clobber: If `True`, overwrite existing values. If `False`, only update fields that are not already set.
|
||||
def parse_args(
|
||||
self,
|
||||
argv: Optional[list[str]] = None,
|
||||
conf: Optional[DictConfig] = None,
|
||||
clobber: Optional[bool] = False,
|
||||
) -> None:
|
||||
"""
|
||||
Update settings with contents of init file, environment, and command-line settings.
|
||||
|
||||
if isinstance(config, dict):
|
||||
new_config = self.model_validate(config)
|
||||
:param conf: alternate Omegaconf dictionary object
|
||||
:param argv: aternate sys.argv list
|
||||
:param clobber: ovewrite any initialization parameters passed during initialization
|
||||
"""
|
||||
# Set the runtime root directory. We parse command-line switches here
|
||||
# in order to pick up the --root_dir option.
|
||||
super().parse_args(argv)
|
||||
loaded_conf = None
|
||||
if conf is None:
|
||||
try:
|
||||
loaded_conf = OmegaConf.load(self.root_dir / INIT_FILE)
|
||||
except Exception:
|
||||
pass
|
||||
if isinstance(loaded_conf, DictConfig):
|
||||
InvokeAISettings.initconf = loaded_conf
|
||||
else:
|
||||
new_config = config
|
||||
InvokeAISettings.initconf = conf
|
||||
|
||||
for field_name in new_config.model_fields_set:
|
||||
new_value = getattr(new_config, field_name)
|
||||
current_value = getattr(self, field_name)
|
||||
# parse args again in order to pick up settings in configuration file
|
||||
super().parse_args(argv)
|
||||
|
||||
if field_name in self.model_fields_set and not clobber:
|
||||
continue
|
||||
if self.singleton_init and not clobber:
|
||||
# When setting values in this way, set validate_assignment to true if you want to validate the value.
|
||||
for k, v in self.singleton_init.items():
|
||||
setattr(self, k, v)
|
||||
|
||||
if new_value != current_value:
|
||||
setattr(self, field_name, new_value)
|
||||
@classmethod
|
||||
def get_config(cls, **kwargs: Any) -> InvokeAIAppConfig:
|
||||
"""Return a singleton InvokeAIAppConfig configuration object."""
|
||||
if (
|
||||
cls.singleton_config is None
|
||||
or type(cls.singleton_config) is not cls
|
||||
or (kwargs and cls.singleton_init != kwargs)
|
||||
):
|
||||
cls.singleton_config = cls(**kwargs)
|
||||
cls.singleton_init = kwargs
|
||||
return cls.singleton_config
|
||||
|
||||
def write_file(self, dest_path: Path, as_example: bool = False) -> None:
|
||||
"""Write the current configuration to file. This will overwrite the existing file.
|
||||
@property
|
||||
def root_path(self) -> Path:
|
||||
"""Path to the runtime root directory."""
|
||||
if self.root:
|
||||
root = Path(self.root).expanduser().absolute()
|
||||
else:
|
||||
root = self.find_root().expanduser().absolute()
|
||||
self.root = root # insulate ourselves from relative paths that may change
|
||||
return root.resolve()
|
||||
|
||||
A `meta` stanza is added to the top of the file, containing metadata about the config file. This is not stored in the config object.
|
||||
|
||||
Args:
|
||||
dest_path: Path to write the config to.
|
||||
"""
|
||||
dest_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(dest_path, "w") as file:
|
||||
# Meta fields should be written in a separate stanza - skip legacy_models_yaml_path
|
||||
meta_dict = self.model_dump(mode="json", include={"schema_version"})
|
||||
|
||||
# User settings
|
||||
config_dict = self.model_dump(
|
||||
mode="json",
|
||||
exclude_unset=False if as_example else True,
|
||||
exclude_defaults=False if as_example else True,
|
||||
exclude_none=True if as_example else False,
|
||||
exclude={"schema_version", "legacy_models_yaml_path"},
|
||||
)
|
||||
|
||||
if as_example:
|
||||
file.write(
|
||||
"# This is an example file with default and example settings. Use the values here as a baseline.\n\n"
|
||||
)
|
||||
file.write("# Internal metadata - do not edit:\n")
|
||||
file.write(yaml.dump(meta_dict, sort_keys=False))
|
||||
file.write("\n")
|
||||
file.write("# Put user settings here - see https://invoke-ai.github.io/InvokeAI/features/CONFIGURATION/:\n")
|
||||
if len(config_dict) > 0:
|
||||
file.write(yaml.dump(config_dict, sort_keys=False))
|
||||
@property
|
||||
def root_dir(self) -> Path:
|
||||
"""Alias for above."""
|
||||
return self.root_path
|
||||
|
||||
def _resolve(self, partial_path: Path) -> Path:
|
||||
return (self.root_path / partial_path).resolve()
|
||||
|
||||
@property
|
||||
def root_path(self) -> Path:
|
||||
"""Path to the runtime root directory, resolved to an absolute path."""
|
||||
if self._root:
|
||||
root = Path(self._root).expanduser().absolute()
|
||||
else:
|
||||
root = self.find_root().expanduser().absolute()
|
||||
self._root = root # insulate ourselves from relative paths that may change
|
||||
return root.resolve()
|
||||
|
||||
@property
|
||||
def config_file_path(self) -> Path:
|
||||
"""Path to invokeai.yaml, resolved to an absolute path.."""
|
||||
resolved_path = self._resolve(self._config_file or INIT_FILE)
|
||||
def init_file_path(self) -> Path:
|
||||
"""Path to invokeai.yaml."""
|
||||
resolved_path = self._resolve(INIT_FILE)
|
||||
assert resolved_path is not None
|
||||
return resolved_path
|
||||
|
||||
@property
|
||||
def autoimport_path(self) -> Path:
|
||||
"""Path to the autoimports directory, resolved to an absolute path.."""
|
||||
return self._resolve(self.autoimport_dir)
|
||||
|
||||
@property
|
||||
def outputs_path(self) -> Optional[Path]:
|
||||
"""Path to the outputs directory, resolved to an absolute path.."""
|
||||
return self._resolve(self.outputs_dir)
|
||||
def output_path(self) -> Optional[Path]:
|
||||
"""Path to defaults outputs directory."""
|
||||
return self._resolve(self.outdir)
|
||||
|
||||
@property
|
||||
def db_path(self) -> Path:
|
||||
"""Path to the invokeai.db file, resolved to an absolute path.."""
|
||||
"""Path to the invokeai.db file."""
|
||||
db_dir = self._resolve(self.db_dir)
|
||||
assert db_dir is not None
|
||||
return db_dir / DB_FILE
|
||||
|
||||
@property
|
||||
def model_conf_path(self) -> Path:
|
||||
"""Path to models configuration file."""
|
||||
return self._resolve(self.conf_path)
|
||||
|
||||
@property
|
||||
def legacy_conf_path(self) -> Path:
|
||||
"""Path to directory of legacy configuration files (e.g. v1-inference.yaml), resolved to an absolute path.."""
|
||||
"""Path to directory of legacy configuration files (e.g. v1-inference.yaml)."""
|
||||
return self._resolve(self.legacy_conf_dir)
|
||||
|
||||
@property
|
||||
def models_path(self) -> Path:
|
||||
"""Path to the models directory, resolved to an absolute path.."""
|
||||
"""Path to the models directory."""
|
||||
return self._resolve(self.models_dir)
|
||||
|
||||
@property
|
||||
def convert_cache_path(self) -> Path:
|
||||
"""Path to the converted cache models directory, resolved to an absolute path.."""
|
||||
def models_convert_cache_path(self) -> Path:
|
||||
"""Path to the converted cache models directory."""
|
||||
return self._resolve(self.convert_cache_dir)
|
||||
|
||||
@property
|
||||
def custom_nodes_path(self) -> Path:
|
||||
"""Path to the custom nodes directory, resolved to an absolute path.."""
|
||||
"""Path to the custom nodes directory."""
|
||||
custom_nodes_path = self._resolve(self.custom_nodes_dir)
|
||||
assert custom_nodes_path is not None
|
||||
return custom_nodes_path
|
||||
|
||||
# the following methods support legacy calls leftover from the Globals era
|
||||
@property
|
||||
def full_precision(self) -> bool:
|
||||
"""Return true if precision set to float32."""
|
||||
return self.precision == "float32"
|
||||
|
||||
@property
|
||||
def try_patchmatch(self) -> bool:
|
||||
"""Return true if patchmatch true."""
|
||||
return self.patchmatch
|
||||
|
||||
@property
|
||||
def nsfw_checker(self) -> bool:
|
||||
"""Return value for NSFW checker. The NSFW node is always active and disabled from Web UI."""
|
||||
return True
|
||||
|
||||
@property
|
||||
def invisible_watermark(self) -> bool:
|
||||
"""Return value of invisible watermark. It is always active and disabled from Web UI."""
|
||||
return True
|
||||
|
||||
@property
|
||||
def ram_cache_size(self) -> float:
|
||||
"""Return the ram cache size using the legacy or modern setting (GB)."""
|
||||
return self.max_cache_size or self.ram
|
||||
|
||||
@property
|
||||
def vram_cache_size(self) -> float:
|
||||
"""Return the vram cache size using the legacy or modern setting (GB)."""
|
||||
return self.max_vram_cache_size or self.vram
|
||||
|
||||
@property
|
||||
def convert_cache_size(self) -> float:
|
||||
"""Return the convert cache size on disk (GB)."""
|
||||
return self.convert_cache
|
||||
|
||||
@property
|
||||
def use_cpu(self) -> bool:
|
||||
"""Return true if the device is set to CPU or the always_use_cpu flag is set."""
|
||||
return self.always_use_cpu or self.device == "cpu"
|
||||
|
||||
@property
|
||||
def disable_xformers(self) -> bool:
|
||||
"""Return true if enable_xformers is false (reversed logic) and attention type is not set to xformers."""
|
||||
disabled_in_config = not self.xformers_enabled
|
||||
return disabled_in_config and self.attention_type != "xformers"
|
||||
|
||||
@property
|
||||
def profiles_path(self) -> Path:
|
||||
"""Path to the graph profiles directory, resolved to an absolute path.."""
|
||||
"""Path to the graph profiles directory."""
|
||||
return self._resolve(self.profiles_dir)
|
||||
|
||||
@staticmethod
|
||||
def find_root() -> Path:
|
||||
"""Choose the runtime root directory when not specified on command line or init file."""
|
||||
venv = Path(os.environ.get("VIRTUAL_ENV") or ".")
|
||||
if os.environ.get("INVOKEAI_ROOT"):
|
||||
root = Path(os.environ["INVOKEAI_ROOT"])
|
||||
elif any((venv.parent / x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE]):
|
||||
root = (venv.parent).resolve()
|
||||
else:
|
||||
root = Path("~/invokeai").expanduser().resolve()
|
||||
return root
|
||||
return _find_root()
|
||||
|
||||
|
||||
def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
|
||||
"""Migrate a v3 config dictionary to a current config object.
|
||||
|
||||
Args:
|
||||
config_dict: A dictionary of settings from a v3 config file.
|
||||
|
||||
Returns:
|
||||
An instance of `InvokeAIAppConfig` with the migrated settings.
|
||||
|
||||
"""
|
||||
parsed_config_dict: dict[str, Any] = {}
|
||||
for _category_name, category_dict in config_dict["InvokeAI"].items():
|
||||
for k, v in category_dict.items():
|
||||
# `outdir` was renamed to `outputs_dir` in v4
|
||||
if k == "outdir":
|
||||
parsed_config_dict["outputs_dir"] = v
|
||||
# `max_cache_size` was renamed to `ram` some time in v3, but both names were used
|
||||
if k == "max_cache_size" and "ram" not in category_dict:
|
||||
parsed_config_dict["ram"] = v
|
||||
# `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used
|
||||
if k == "max_vram_cache_size" and "vram" not in category_dict:
|
||||
parsed_config_dict["vram"] = v
|
||||
if k == "conf_path":
|
||||
parsed_config_dict["legacy_models_yaml_path"] = v
|
||||
if k == "legacy_conf_dir":
|
||||
# The old default for this was "configs/stable-diffusion". If if the incoming config has that as the value, we won't set it.
|
||||
# Else if the path ends in "stable-diffusion", we assume the parent is the new correct path.
|
||||
# Else we do not attempt to migrate this setting
|
||||
if v != "configs/stable-diffusion":
|
||||
parsed_config_dict["legacy_conf_dir"] = v
|
||||
elif Path(v).name == "stable-diffusion":
|
||||
parsed_config_dict["legacy_conf_dir"] = str(Path(v).parent)
|
||||
elif k in InvokeAIAppConfig.model_fields:
|
||||
# skip unknown fields
|
||||
parsed_config_dict[k] = v
|
||||
config = InvokeAIAppConfig.model_validate(parsed_config_dict)
|
||||
|
||||
return config
|
||||
def get_invokeai_config(**kwargs: Any) -> InvokeAIAppConfig:
|
||||
"""Legacy function which returns InvokeAIAppConfig.get_config()."""
|
||||
return InvokeAIAppConfig.get_config(**kwargs)
|
||||
|
||||
|
||||
def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
|
||||
"""Load and migrate a config file to the latest version.
|
||||
|
||||
Args:
|
||||
config_path: Path to the config file.
|
||||
|
||||
Returns:
|
||||
An instance of `InvokeAIAppConfig` with the loaded and migrated settings.
|
||||
"""
|
||||
assert config_path.suffix == ".yaml"
|
||||
with open(config_path) as file:
|
||||
loaded_config_dict = yaml.safe_load(file)
|
||||
|
||||
assert isinstance(loaded_config_dict, dict)
|
||||
|
||||
if "InvokeAI" in loaded_config_dict:
|
||||
# This is a v3 config file, attempt to migrate it
|
||||
shutil.copy(config_path, config_path.with_suffix(".yaml.bak"))
|
||||
try:
|
||||
# This could be the wrong shape, but we will catch all exceptions below
|
||||
config = migrate_v3_config_dict(loaded_config_dict) # pyright: ignore [reportUnknownArgumentType]
|
||||
except Exception as e:
|
||||
shutil.copy(config_path.with_suffix(".yaml.bak"), config_path)
|
||||
raise RuntimeError(f"Failed to load and migrate v3 config file {config_path}: {e}") from e
|
||||
# By excluding defaults, we ensure that the new config file only contains the settings that were explicitly set
|
||||
config.write_file(config_path)
|
||||
return config
|
||||
def _find_root() -> Path:
|
||||
venv = Path(os.environ.get("VIRTUAL_ENV") or ".")
|
||||
if os.environ.get("INVOKEAI_ROOT"):
|
||||
root = Path(os.environ["INVOKEAI_ROOT"])
|
||||
elif any((venv.parent / x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE]):
|
||||
root = (venv.parent).resolve()
|
||||
else:
|
||||
# Attempt to load as a v4 config file
|
||||
try:
|
||||
# Meta is not included in the model fields, so we need to validate it separately
|
||||
config = InvokeAIAppConfig.model_validate(loaded_config_dict)
|
||||
assert (
|
||||
config.schema_version == CONFIG_SCHEMA_VERSION
|
||||
), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION}: {config.schema_version}"
|
||||
return config
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_config() -> InvokeAIAppConfig:
|
||||
"""Get the global singleton app config.
|
||||
|
||||
When first called, this function:
|
||||
- Creates a config object. `pydantic-settings` handles merging of settings from environment variables, but not the init file.
|
||||
- Retrieves any provided CLI args from the InvokeAIArgs class. It does not _parse_ the CLI args; that is done in the main entrypoint.
|
||||
- Sets the root dir, if provided via CLI args.
|
||||
- Logs in to HF if there is no valid token already.
|
||||
- Copies all legacy configs to the legacy conf dir (needed for conversion from ckpt to diffusers).
|
||||
- Reads and merges in settings from the config file if it exists, else writes out a default config file.
|
||||
|
||||
On subsequent calls, the object is returned from the cache.
|
||||
"""
|
||||
config = InvokeAIAppConfig()
|
||||
|
||||
args = InvokeAIArgs.args
|
||||
|
||||
# This flag serves as a proxy for whether the config was retrieved in the context of the full application or not.
|
||||
# If it is False, we should just return a default config and not set the root, log in to HF, etc.
|
||||
if not InvokeAIArgs.did_parse:
|
||||
return config
|
||||
|
||||
# Set CLI args
|
||||
if root := getattr(args, "root", None):
|
||||
config._root = Path(root)
|
||||
if config_file := getattr(args, "config_file", None):
|
||||
config._config_file = Path(config_file)
|
||||
|
||||
# Create the example file from a deep copy, with some extra values provided
|
||||
example_config = config.model_copy(deep=True)
|
||||
example_config.remote_api_tokens = [
|
||||
URLRegexTokenPair(url_regex="cool-models.com", token="my_secret_token"),
|
||||
URLRegexTokenPair(url_regex="nifty-models.com", token="some_other_token"),
|
||||
]
|
||||
example_config.write_file(config.config_file_path.with_suffix(".example.yaml"), as_example=True)
|
||||
|
||||
# Copy all legacy configs - We know `__path__[0]` is correct here
|
||||
configs_src = Path(model_configs.__path__[0]) # pyright: ignore [reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue]
|
||||
shutil.copytree(configs_src, config.legacy_conf_path, dirs_exist_ok=True)
|
||||
|
||||
if config.config_file_path.exists():
|
||||
incoming_config = load_and_migrate_config(config.config_file_path)
|
||||
# Clobbering here will overwrite any settings that were set via environment variables
|
||||
config.update_config(incoming_config, clobber=False)
|
||||
else:
|
||||
config.write_file(config.config_file_path)
|
||||
|
||||
return config
|
||||
root = Path("~/invokeai").expanduser().resolve()
|
||||
return root
|
||||
|
@ -1,5 +1,4 @@
|
||||
"""Init file for download queue."""
|
||||
|
||||
from .download_base import DownloadJob, DownloadJobStatus, DownloadQueueServiceBase, UnknownJobIDException
|
||||
from .download_default import DownloadQueueService, TqdmProgress
|
||||
|
||||
|
@ -224,6 +224,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
job.job_started = get_iso_timestamp()
|
||||
self._do_download(job)
|
||||
self._signal_job_complete(job)
|
||||
|
||||
except (OSError, HTTPError) as excp:
|
||||
job.error_type = excp.__class__.__name__ + f"({str(excp)})"
|
||||
job.error = traceback.format_exc()
|
||||
|
@ -12,7 +12,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
)
|
||||
from invokeai.app.util.misc import get_timestamp
|
||||
from invokeai.backend.model_manager import AnyModelConfig
|
||||
from invokeai.backend.model_manager.config import SubModelType
|
||||
|
||||
|
||||
class EventServiceBase:
|
||||
@ -81,7 +80,7 @@ class EventServiceBase:
|
||||
"graph_execution_state_id": graph_execution_state_id,
|
||||
"node_id": node_id,
|
||||
"source_node_id": source_node_id,
|
||||
"progress_image": progress_image.model_dump(mode="json") if progress_image is not None else None,
|
||||
"progress_image": progress_image.model_dump() if progress_image is not None else None,
|
||||
"step": step,
|
||||
"order": order,
|
||||
"total_steps": total_steps,
|
||||
@ -181,7 +180,6 @@ class EventServiceBase:
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
model_config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> None:
|
||||
"""Emitted when a model is requested"""
|
||||
self.__emit_queue_event(
|
||||
@ -191,8 +189,7 @@ class EventServiceBase:
|
||||
"queue_item_id": queue_item_id,
|
||||
"queue_batch_id": queue_batch_id,
|
||||
"graph_execution_state_id": graph_execution_state_id,
|
||||
"model_config": model_config.model_dump(mode="json"),
|
||||
"submodel_type": submodel_type,
|
||||
"model_config": model_config.model_dump(),
|
||||
},
|
||||
)
|
||||
|
||||
@ -203,7 +200,6 @@ class EventServiceBase:
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
model_config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> None:
|
||||
"""Emitted when a model is correctly loaded (returns model info)"""
|
||||
self.__emit_queue_event(
|
||||
@ -213,8 +209,7 @@ class EventServiceBase:
|
||||
"queue_item_id": queue_item_id,
|
||||
"queue_batch_id": queue_batch_id,
|
||||
"graph_execution_state_id": graph_execution_state_id,
|
||||
"model_config": model_config.model_dump(mode="json"),
|
||||
"submodel_type": submodel_type,
|
||||
"model_config": model_config.model_dump(),
|
||||
},
|
||||
)
|
||||
|
||||
@ -259,8 +254,8 @@ class EventServiceBase:
|
||||
"started_at": str(session_queue_item.started_at) if session_queue_item.started_at else None,
|
||||
"completed_at": str(session_queue_item.completed_at) if session_queue_item.completed_at else None,
|
||||
},
|
||||
"batch_status": batch_status.model_dump(mode="json"),
|
||||
"queue_status": queue_status.model_dump(mode="json"),
|
||||
"batch_status": batch_status.model_dump(),
|
||||
"queue_status": queue_status.model_dump(),
|
||||
},
|
||||
)
|
||||
|
||||
@ -386,17 +381,6 @@ class EventServiceBase:
|
||||
},
|
||||
)
|
||||
|
||||
def emit_model_install_downloads_done(self, source: str) -> None:
|
||||
"""
|
||||
Emit once when all parts are downloaded, but before the probing and registration start.
|
||||
|
||||
:param source: Source of the model; local path, repo_id or url
|
||||
"""
|
||||
self.__emit_model_event(
|
||||
event_name="model_install_downloads_done",
|
||||
payload={"source": source},
|
||||
)
|
||||
|
||||
def emit_model_install_running(self, source: str) -> None:
|
||||
"""
|
||||
Emit once when an install job becomes active.
|
||||
@ -421,7 +405,7 @@ class EventServiceBase:
|
||||
payload={"source": source, "total_bytes": total_bytes, "key": key, "id": id},
|
||||
)
|
||||
|
||||
def emit_model_install_cancelled(self, source: str, id: int) -> None:
|
||||
def emit_model_install_cancelled(self, source: str) -> None:
|
||||
"""
|
||||
Emit when an install job is cancelled.
|
||||
|
||||
@ -429,7 +413,7 @@ class EventServiceBase:
|
||||
"""
|
||||
self.__emit_model_event(
|
||||
event_name="model_install_cancelled",
|
||||
payload={"source": source, "id": id},
|
||||
payload={"source": source},
|
||||
)
|
||||
|
||||
def emit_model_install_error(self, source: str, error_type: str, error: str, id: int) -> None:
|
||||
|
@ -82,7 +82,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
image_path,
|
||||
"PNG",
|
||||
pnginfo=pnginfo,
|
||||
compress_level=self.__invoker.services.configuration.pil_compress_level,
|
||||
compress_level=self.__invoker.services.configuration.png_compress_level,
|
||||
)
|
||||
|
||||
thumbnail_name = get_thumbnail_name(image_name)
|
||||
|
@ -41,9 +41,8 @@ class InvocationCacheBase(ABC):
|
||||
"""Clears the cache"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def create_key(invocation: BaseInvocation) -> int:
|
||||
def create_key(self, invocation: BaseInvocation) -> int:
|
||||
"""Gets the key for the invocation's cache item"""
|
||||
pass
|
||||
|
||||
|
@ -61,7 +61,9 @@ class MemoryInvocationCache(InvocationCacheBase):
|
||||
self._delete_oldest_access(number_to_delete)
|
||||
self._cache[key] = CachedItem(
|
||||
invocation_output,
|
||||
invocation_output.model_dump_json(warnings=False, exclude_defaults=True, exclude_unset=True),
|
||||
invocation_output.model_dump_json(
|
||||
warnings=False, exclude_defaults=True, exclude_unset=True, include={"type"}
|
||||
),
|
||||
)
|
||||
|
||||
def _delete_oldest_access(self, number_to_delete: int) -> None:
|
||||
@ -79,7 +81,7 @@ class MemoryInvocationCache(InvocationCacheBase):
|
||||
with self._lock:
|
||||
return self._delete(key)
|
||||
|
||||
def clear(self) -> None:
|
||||
def clear(self, *args, **kwargs) -> None:
|
||||
with self._lock:
|
||||
if self._max_cache_size == 0:
|
||||
return
|
||||
|
@ -25,7 +25,6 @@ if TYPE_CHECKING:
|
||||
from .images.images_base import ImageServiceABC
|
||||
from .invocation_cache.invocation_cache_base import InvocationCacheBase
|
||||
from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase
|
||||
from .model_images.model_images_base import ModelImageFileStorageBase
|
||||
from .model_manager.model_manager_base import ModelManagerServiceBase
|
||||
from .names.names_base import NameServiceBase
|
||||
from .session_processor.session_processor_base import SessionProcessorBase
|
||||
@ -50,7 +49,6 @@ class InvocationServices:
|
||||
image_files: "ImageFileStorageBase",
|
||||
image_records: "ImageRecordStorageBase",
|
||||
logger: "Logger",
|
||||
model_images: "ModelImageFileStorageBase",
|
||||
model_manager: "ModelManagerServiceBase",
|
||||
download_queue: "DownloadQueueServiceBase",
|
||||
performance_statistics: "InvocationStatsServiceBase",
|
||||
@ -74,7 +72,6 @@ class InvocationServices:
|
||||
self.image_files = image_files
|
||||
self.image_records = image_records
|
||||
self.logger = logger
|
||||
self.model_images = model_images
|
||||
self.model_manager = model_manager
|
||||
self.download_queue = download_queue
|
||||
self.performance_statistics = performance_statistics
|
||||
|
@ -1,33 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
|
||||
from PIL.Image import Image as PILImageType
|
||||
|
||||
|
||||
class ModelImageFileStorageBase(ABC):
|
||||
"""Low-level service responsible for storing and retrieving image files."""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, model_key: str) -> PILImageType:
|
||||
"""Retrieves a model image as PIL Image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_path(self, model_key: str) -> Path:
|
||||
"""Gets the internal path to a model image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_url(self, model_key: str) -> str | None:
|
||||
"""Gets the URL to fetch a model image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(self, image: PILImageType, model_key: str) -> None:
|
||||
"""Saves a model image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, model_key: str) -> None:
|
||||
"""Deletes a model image."""
|
||||
pass
|
@ -1,20 +0,0 @@
|
||||
# TODO: Should these excpetions subclass existing python exceptions?
|
||||
class ModelImageFileNotFoundException(Exception):
|
||||
"""Raised when an image file is not found in storage."""
|
||||
|
||||
def __init__(self, message="Model image file not found"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ModelImageFileSaveException(Exception):
|
||||
"""Raised when an image cannot be saved."""
|
||||
|
||||
def __init__(self, message="Model image file not saved"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ModelImageFileDeleteException(Exception):
|
||||
"""Raised when an image cannot be deleted."""
|
||||
|
||||
def __init__(self, message="Model image file not deleted"):
|
||||
super().__init__(message)
|
@ -1,85 +0,0 @@
|
||||
from pathlib import Path
|
||||
|
||||
from PIL import Image
|
||||
from PIL.Image import Image as PILImageType
|
||||
from send2trash import send2trash
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
from invokeai.app.util.thumbnails import make_thumbnail
|
||||
|
||||
from .model_images_base import ModelImageFileStorageBase
|
||||
from .model_images_common import (
|
||||
ModelImageFileDeleteException,
|
||||
ModelImageFileNotFoundException,
|
||||
ModelImageFileSaveException,
|
||||
)
|
||||
|
||||
|
||||
class ModelImageFileStorageDisk(ModelImageFileStorageBase):
|
||||
"""Stores images on disk"""
|
||||
|
||||
def __init__(self, model_images_folder: Path):
|
||||
self._model_images_folder = model_images_folder
|
||||
self._validate_storage_folders()
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
|
||||
def get(self, model_key: str) -> PILImageType:
|
||||
try:
|
||||
path = self.get_path(model_key)
|
||||
|
||||
if not self._validate_path(path):
|
||||
raise ModelImageFileNotFoundException
|
||||
|
||||
return Image.open(path)
|
||||
except FileNotFoundError as e:
|
||||
raise ModelImageFileNotFoundException from e
|
||||
|
||||
def save(self, image: PILImageType, model_key: str) -> None:
|
||||
try:
|
||||
self._validate_storage_folders()
|
||||
image_path = self._model_images_folder / (model_key + ".webp")
|
||||
thumbnail = make_thumbnail(image, 256)
|
||||
thumbnail.save(image_path, format="webp")
|
||||
|
||||
except Exception as e:
|
||||
raise ModelImageFileSaveException from e
|
||||
|
||||
def get_path(self, model_key: str) -> Path:
|
||||
path = self._model_images_folder / (model_key + ".webp")
|
||||
|
||||
return path
|
||||
|
||||
def get_url(self, model_key: str) -> str | None:
|
||||
path = self.get_path(model_key)
|
||||
if not self._validate_path(path):
|
||||
return
|
||||
|
||||
url = self._invoker.services.urls.get_model_image_url(model_key)
|
||||
|
||||
# The image URL never changes, so we must add random query string to it to prevent caching
|
||||
url += f"?{uuid_string()}"
|
||||
|
||||
return url
|
||||
|
||||
def delete(self, model_key: str) -> None:
|
||||
try:
|
||||
path = self.get_path(model_key)
|
||||
|
||||
if not self._validate_path(path):
|
||||
raise ModelImageFileNotFoundException
|
||||
|
||||
send2trash(path)
|
||||
|
||||
except Exception as e:
|
||||
raise ModelImageFileDeleteException from e
|
||||
|
||||
def _validate_path(self, path: Path) -> bool:
|
||||
"""Validates the path given for an image."""
|
||||
return path.exists()
|
||||
|
||||
def _validate_storage_folders(self) -> None:
|
||||
"""Checks if the required folders exist and create them if they don't"""
|
||||
self._model_images_folder.mkdir(parents=True, exist_ok=True)
|
@ -1,6 +1,7 @@
|
||||
"""Initialization file for model install service package."""
|
||||
|
||||
from .model_install_base import (
|
||||
CivitaiModelSource,
|
||||
HFModelSource,
|
||||
InstallStatus,
|
||||
LocalModelSource,
|
||||
@ -22,4 +23,5 @@ __all__ = [
|
||||
"LocalModelSource",
|
||||
"HFModelSource",
|
||||
"URLModelSource",
|
||||
"CivitaiModelSource",
|
||||
]
|
||||
|
@ -18,16 +18,16 @@ from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase
|
||||
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
|
||||
from invokeai.backend.model_manager.config import ModelSourceType
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
|
||||
from ..model_metadata import ModelMetadataStoreBase
|
||||
|
||||
|
||||
class InstallStatus(str, Enum):
|
||||
"""State of an install job running in the background."""
|
||||
|
||||
WAITING = "waiting" # waiting to be dequeued
|
||||
DOWNLOADING = "downloading" # downloading of model files in process
|
||||
DOWNLOADS_DONE = "downloads_done" # downloading done, waiting to run
|
||||
RUNNING = "running" # being processed
|
||||
COMPLETED = "completed" # finished running
|
||||
ERROR = "error" # terminated with an error message
|
||||
@ -91,6 +91,21 @@ class LocalModelSource(StringLikeSource):
|
||||
return Path(self.path).as_posix()
|
||||
|
||||
|
||||
class CivitaiModelSource(StringLikeSource):
|
||||
"""A Civitai version id, with optional variant and access token."""
|
||||
|
||||
version_id: int
|
||||
variant: Optional[ModelRepoVariant] = None
|
||||
access_token: Optional[str] = None
|
||||
type: Literal["civitai"] = "civitai"
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return string version of repoid when string rep needed."""
|
||||
base: str = str(self.version_id)
|
||||
base += f" ({self.variant})" if self.variant else ""
|
||||
return base
|
||||
|
||||
|
||||
class HFModelSource(StringLikeSource):
|
||||
"""
|
||||
A HuggingFace repo_id with optional variant, sub-folder and access token.
|
||||
@ -114,10 +129,8 @@ class HFModelSource(StringLikeSource):
|
||||
def __str__(self) -> str:
|
||||
"""Return string version of repoid when string rep needed."""
|
||||
base: str = self.repo_id
|
||||
if self.variant:
|
||||
base += f":{self.variant or ''}"
|
||||
if self.subfolder:
|
||||
base += f":{self.subfolder}"
|
||||
base += f":{self.variant or ''}"
|
||||
base += f":{self.subfolder}" if self.subfolder else ""
|
||||
return base
|
||||
|
||||
|
||||
@ -133,13 +146,9 @@ class URLModelSource(StringLikeSource):
|
||||
return str(self.url)
|
||||
|
||||
|
||||
ModelSource = Annotated[Union[LocalModelSource, HFModelSource, URLModelSource], Field(discriminator="type")]
|
||||
|
||||
MODEL_SOURCE_TO_TYPE_MAP = {
|
||||
URLModelSource: ModelSourceType.Url,
|
||||
HFModelSource: ModelSourceType.HFRepoID,
|
||||
LocalModelSource: ModelSourceType.Path,
|
||||
}
|
||||
ModelSource = Annotated[
|
||||
Union[LocalModelSource, HFModelSource, CivitaiModelSource, URLModelSource], Field(discriminator="type")
|
||||
]
|
||||
|
||||
|
||||
class ModelInstallJob(BaseModel):
|
||||
@ -220,11 +229,6 @@ class ModelInstallJob(BaseModel):
|
||||
"""Return true if job is downloading."""
|
||||
return self.status == InstallStatus.DOWNLOADING
|
||||
|
||||
@property
|
||||
def downloads_done(self) -> bool:
|
||||
"""Return true if job's downloads ae done."""
|
||||
return self.status == InstallStatus.DOWNLOADS_DONE
|
||||
|
||||
@property
|
||||
def running(self) -> bool:
|
||||
"""Return true if job is running."""
|
||||
@ -250,6 +254,7 @@ class ModelInstallServiceBase(ABC):
|
||||
app_config: InvokeAIAppConfig,
|
||||
record_store: ModelRecordServiceBase,
|
||||
download_queue: DownloadQueueServiceBase,
|
||||
metadata_store: ModelMetadataStoreBase,
|
||||
event_bus: Optional["EventServiceBase"] = None,
|
||||
):
|
||||
"""
|
||||
@ -336,7 +341,6 @@ class ModelInstallServiceBase(ABC):
|
||||
source: str,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
access_token: Optional[str] = None,
|
||||
inplace: Optional[bool] = False,
|
||||
) -> ModelInstallJob:
|
||||
r"""Install the indicated model using heuristics to interpret user intentions.
|
||||
|
||||
@ -382,7 +386,7 @@ class ModelInstallServiceBase(ABC):
|
||||
will override corresponding autoassigned probe fields in the
|
||||
model's config record. Use it to override
|
||||
`name`, `description`, `base_type`, `model_type`, `format`,
|
||||
`prediction_type`, and/or `image_size`.
|
||||
`prediction_type`, `image_size`, and/or `ztsnr_training`.
|
||||
|
||||
This will download the model located at `source`,
|
||||
probe it, and install it into the models directory.
|
||||
|
@ -7,11 +7,11 @@ import time
|
||||
from hashlib import sha256
|
||||
from pathlib import Path
|
||||
from queue import Empty, Queue
|
||||
from random import randbytes
|
||||
from shutil import copyfile, copytree, move, rmtree
|
||||
from tempfile import mkdtemp
|
||||
from typing import Any, Dict, List, Optional, Set, Union
|
||||
|
||||
import yaml
|
||||
from huggingface_hub import HfFolder
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from requests import Session
|
||||
@ -21,30 +21,28 @@ from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
|
||||
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
CheckpointConfigBase,
|
||||
InvalidModelConfigException,
|
||||
ModelRepoVariant,
|
||||
ModelSourceType,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.hash import FastModelHash
|
||||
from invokeai.backend.model_manager.metadata import (
|
||||
AnyModelRepoMetadata,
|
||||
CivitaiMetadataFetch,
|
||||
HuggingFaceMetadataFetch,
|
||||
ModelMetadataWithFiles,
|
||||
RemoteModelFile,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata.metadata_base import HuggingFaceMetadata
|
||||
from invokeai.backend.model_manager.probe import ModelProbe
|
||||
from invokeai.backend.model_manager.search import ModelSearch
|
||||
from invokeai.backend.util import Chdir, InvokeAILogger
|
||||
from invokeai.backend.util.devices import choose_precision, choose_torch_device
|
||||
|
||||
from .model_install_base import (
|
||||
MODEL_SOURCE_TO_TYPE_MAP,
|
||||
CivitaiModelSource,
|
||||
HFModelSource,
|
||||
InstallStatus,
|
||||
LocalModelSource,
|
||||
@ -93,6 +91,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._running = False
|
||||
self._session = session
|
||||
self._next_job_id = 0
|
||||
self._metadata_store = record_store.metadata_store # for convenience
|
||||
|
||||
@property
|
||||
def app_config(self) -> InvokeAIAppConfig: # noqa D102
|
||||
@ -115,7 +114,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
raise Exception("Attempt to start the installer service twice")
|
||||
self._start_installer_thread()
|
||||
self._remove_dangling_install_dirs()
|
||||
self._migrate_yaml()
|
||||
self.sync_to_config()
|
||||
|
||||
def stop(self, invoker: Optional[Invoker] = None) -> None:
|
||||
@ -124,28 +122,15 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
if not self._running:
|
||||
raise Exception("Attempt to stop the install service before it was started")
|
||||
self._stop_event.set()
|
||||
self._clear_pending_jobs()
|
||||
with self._install_queue.mutex:
|
||||
self._install_queue.queue.clear() # get rid of pending jobs
|
||||
active_jobs = [x for x in self.list_jobs() if x.running]
|
||||
if active_jobs:
|
||||
self._logger.warning("Waiting for active install job to complete")
|
||||
self.wait_for_installs()
|
||||
self._download_cache.clear()
|
||||
self._running = False
|
||||
|
||||
def _clear_pending_jobs(self) -> None:
|
||||
for job in self.list_jobs():
|
||||
if not job.in_terminal_state:
|
||||
self._logger.warning("Cancelling job {job.id}")
|
||||
self.cancel_job(job)
|
||||
while True:
|
||||
try:
|
||||
job = self._install_queue.get(block=False)
|
||||
self._install_queue.task_done()
|
||||
except Empty:
|
||||
break
|
||||
|
||||
def _put_in_queue(self, job: ModelInstallJob) -> None:
|
||||
if self._stop_event.is_set():
|
||||
self.cancel_job(job)
|
||||
else:
|
||||
self._install_queue.put(job)
|
||||
|
||||
def register_path(
|
||||
self,
|
||||
model_path: Union[Path, str],
|
||||
@ -155,7 +140,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
config = config or {}
|
||||
if not config.get("source"):
|
||||
config["source"] = model_path.resolve().as_posix()
|
||||
config["source_type"] = ModelSourceType.Path
|
||||
return self._register(model_path, config)
|
||||
|
||||
def install_path(
|
||||
@ -165,8 +149,11 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
) -> str: # noqa D102
|
||||
model_path = Path(model_path)
|
||||
config = config or {}
|
||||
if not config.get("source"):
|
||||
config["source"] = model_path.resolve().as_posix()
|
||||
|
||||
info: AnyModelConfig = ModelProbe.probe(Path(model_path), config, hash_algo=self._app_config.hashing_algorithm)
|
||||
info: AnyModelConfig = self._probe_model(Path(model_path), config)
|
||||
old_hash = info.current_hash
|
||||
|
||||
if preferred_name := config.get("name"):
|
||||
preferred_name = Path(preferred_name).with_suffix(model_path.suffix)
|
||||
@ -180,6 +167,8 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
raise DuplicateModelException(
|
||||
f"A model named {model_path.name} is already installed at {dest_path.as_posix()}"
|
||||
) from excp
|
||||
new_hash = FastModelHash.hash(new_path)
|
||||
assert new_hash == old_hash, f"{model_path}: Model hash changed during installation, possibly corrupted."
|
||||
|
||||
return self._register(
|
||||
new_path,
|
||||
@ -192,14 +181,13 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
source: str,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
access_token: Optional[str] = None,
|
||||
inplace: Optional[bool] = False,
|
||||
) -> ModelInstallJob:
|
||||
variants = "|".join(ModelRepoVariant.__members__.values())
|
||||
hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
|
||||
source_obj: Optional[StringLikeSource] = None
|
||||
|
||||
if Path(source).exists(): # A local file or directory
|
||||
source_obj = LocalModelSource(path=Path(source), inplace=inplace)
|
||||
source_obj = LocalModelSource(path=Path(source))
|
||||
elif match := re.match(hf_repoid_re, source):
|
||||
source_obj = HFModelSource(
|
||||
repo_id=match.group(1),
|
||||
@ -208,16 +196,9 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
access_token=access_token,
|
||||
)
|
||||
elif re.match(r"^https?://[^/]+", source):
|
||||
# Pull the token from config if it exists and matches the URL
|
||||
_token = access_token
|
||||
if _token is None:
|
||||
for pair in self.app_config.remote_api_tokens or []:
|
||||
if re.search(pair.url_regex, source):
|
||||
_token = pair.token
|
||||
break
|
||||
source_obj = URLModelSource(
|
||||
url=AnyHttpUrl(source),
|
||||
access_token=_token,
|
||||
access_token=access_token,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported model source: '{source}'")
|
||||
@ -231,7 +212,9 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
|
||||
if isinstance(source, LocalModelSource):
|
||||
install_job = self._import_local_model(source, config)
|
||||
self._put_in_queue(install_job) # synchronously install
|
||||
self._install_queue.put(install_job) # synchronously install
|
||||
elif isinstance(source, CivitaiModelSource):
|
||||
install_job = self._import_from_civitai(source, config)
|
||||
elif isinstance(source, HFModelSource):
|
||||
install_job = self._import_from_hf(source, config)
|
||||
elif isinstance(source, URLModelSource):
|
||||
@ -266,6 +249,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
raise TimeoutError("Timeout exceeded")
|
||||
return job
|
||||
|
||||
# TODO: Better name? Maybe wait_for_jobs()? Maybe too easily confused with above
|
||||
def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]: # noqa D102
|
||||
"""Block until all installation jobs are done."""
|
||||
start = time.time()
|
||||
@ -291,66 +275,14 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
def sync_to_config(self) -> None:
|
||||
"""Synchronize models on disk to those in the config record store database."""
|
||||
self._scan_models_directory()
|
||||
if self._app_config.autoimport_path:
|
||||
if autoimport := self._app_config.autoimport_dir:
|
||||
self._logger.info("Scanning autoimport directory for new models")
|
||||
installed = self.scan_directory(self._app_config.autoimport_path)
|
||||
installed = self.scan_directory(self._app_config.root_path / autoimport)
|
||||
self._logger.info(f"{len(installed)} new models registered")
|
||||
self._logger.info("Model installer (re)initialized")
|
||||
|
||||
def _migrate_yaml(self) -> None:
|
||||
db_models = self.record_store.all_models()
|
||||
|
||||
legacy_models_yaml_path = (
|
||||
self._app_config.legacy_models_yaml_path or self._app_config.root_path / "configs" / "models.yaml"
|
||||
)
|
||||
|
||||
# The old path may be relative to the root path
|
||||
if not legacy_models_yaml_path.exists():
|
||||
legacy_models_yaml_path = Path(self._app_config.root_path, legacy_models_yaml_path)
|
||||
|
||||
if legacy_models_yaml_path.exists():
|
||||
legacy_models_yaml = yaml.safe_load(legacy_models_yaml_path.read_text())
|
||||
|
||||
yaml_metadata = legacy_models_yaml.pop("__metadata__")
|
||||
yaml_version = yaml_metadata.get("version")
|
||||
|
||||
if yaml_version != "3.0.0":
|
||||
raise ValueError(
|
||||
f"Attempted migration of unsupported `models.yaml` v{yaml_version}. Only v3.0.0 is supported. Exiting."
|
||||
)
|
||||
|
||||
self._logger.info(
|
||||
f"Starting one-time migration of {len(legacy_models_yaml.items())} models from {str(legacy_models_yaml_path)}. This may take a few minutes."
|
||||
)
|
||||
|
||||
if len(db_models) == 0 and len(legacy_models_yaml.items()) != 0:
|
||||
for model_key, stanza in legacy_models_yaml.items():
|
||||
_, _, model_name = str(model_key).split("/")
|
||||
model_path = Path(stanza["path"])
|
||||
if not model_path.is_absolute():
|
||||
model_path = self._app_config.models_path / model_path
|
||||
model_path = model_path.resolve()
|
||||
|
||||
config: dict[str, Any] = {}
|
||||
config["name"] = model_name
|
||||
config["description"] = stanza.get("description")
|
||||
config["config_path"] = stanza.get("config")
|
||||
|
||||
try:
|
||||
id = self.register_path(model_path=model_path, config=config)
|
||||
self._logger.info(f"Migrated {model_name} with id {id}")
|
||||
except Exception as e:
|
||||
self._logger.warning(f"Model at {model_path} could not be migrated: {e}")
|
||||
|
||||
# Rename `models.yaml` to `models.yaml.bak` to prevent re-migration
|
||||
legacy_models_yaml_path.rename(legacy_models_yaml_path.with_suffix(".yaml.bak"))
|
||||
|
||||
# Remove `legacy_models_yaml_path` from the config file - we are done with it either way
|
||||
self._app_config.legacy_models_yaml_path = None
|
||||
self._app_config.write_file(self._app_config.config_file_path)
|
||||
|
||||
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
|
||||
self._cached_model_paths = {Path(x.path).resolve() for x in self.record_store.all_models()}
|
||||
self._cached_model_paths = {Path(x.path) for x in self.record_store.all_models()}
|
||||
callback = self._scan_install if install else self._scan_register
|
||||
search = ModelSearch(on_model_found=callback)
|
||||
self._models_installed.clear()
|
||||
@ -364,7 +296,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
"""Unregister the model. Delete its files only if they are within our models directory."""
|
||||
model = self.record_store.get_model(key)
|
||||
models_dir = self.app_config.models_path
|
||||
model_path = models_dir / Path(model.path) # handle legacy relative model paths
|
||||
model_path = models_dir / model.path
|
||||
if model_path.is_relative_to(models_dir):
|
||||
self.unconditionally_delete(key)
|
||||
else:
|
||||
@ -372,11 +304,11 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
|
||||
def unconditionally_delete(self, key: str) -> None: # noqa D102
|
||||
model = self.record_store.get_model(key)
|
||||
model_path = self.app_config.models_path / model.path
|
||||
if model_path.is_dir():
|
||||
rmtree(model_path)
|
||||
path = self.app_config.models_path / model.path
|
||||
if path.is_dir():
|
||||
rmtree(path)
|
||||
else:
|
||||
model_path.unlink()
|
||||
path.unlink()
|
||||
self.unregister(key)
|
||||
|
||||
def download_and_cache(
|
||||
@ -387,7 +319,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
) -> Path:
|
||||
"""Download the model file located at source to the models cache and return its Path."""
|
||||
model_hash = sha256(str(source).encode("utf-8")).hexdigest()[0:32]
|
||||
model_path = self._app_config.convert_cache_path / model_hash
|
||||
model_path = self._app_config.models_convert_cache_path / model_hash
|
||||
|
||||
# We expect the cache directory to contain one and only one downloaded file.
|
||||
# We don't know the file's name in advance, as it is set by the download
|
||||
@ -428,6 +360,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
job = self._install_queue.get(timeout=1)
|
||||
except Empty:
|
||||
continue
|
||||
|
||||
assert job.local_path is not None
|
||||
try:
|
||||
if job.cancelled:
|
||||
@ -437,22 +370,21 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._signal_job_errored(job)
|
||||
|
||||
elif (
|
||||
job.waiting or job.downloads_done
|
||||
job.waiting or job.downloading
|
||||
): # local jobs will be in waiting state, remote jobs will be downloading state
|
||||
job.total_bytes = self._stat_size(job.local_path)
|
||||
job.bytes = job.total_bytes
|
||||
self._signal_job_running(job)
|
||||
job.config_in["source"] = str(job.source)
|
||||
job.config_in["source_type"] = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__]
|
||||
# enter the metadata, if there is any
|
||||
if isinstance(job.source_metadata, (HuggingFaceMetadata)):
|
||||
job.config_in["source_api_response"] = job.source_metadata.api_response
|
||||
|
||||
if job.inplace:
|
||||
key = self.register_path(job.local_path, job.config_in)
|
||||
else:
|
||||
key = self.install_path(job.local_path, job.config_in)
|
||||
job.config_out = self.record_store.get_model(key)
|
||||
|
||||
# enter the metadata, if there is any
|
||||
if job.source_metadata:
|
||||
self._metadata_store.add_metadata(key, job.source_metadata)
|
||||
self._signal_job_completed(job)
|
||||
|
||||
except InvalidModelConfigException as excp:
|
||||
@ -477,6 +409,8 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._install_completed_event.set()
|
||||
self._install_queue.task_done()
|
||||
|
||||
self._logger.info("Install thread exiting")
|
||||
|
||||
# --------------------------------------------------------------------------------------------
|
||||
# Internal functions that manage the models directory
|
||||
# --------------------------------------------------------------------------------------------
|
||||
@ -510,13 +444,11 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._logger.info(f"Scanning {self._app_config.models_path} for new and orphaned models")
|
||||
for cur_base_model in BaseModelType:
|
||||
for cur_model_type in ModelType:
|
||||
models_dir = self._app_config.models_path / Path(cur_base_model.value, cur_model_type.value)
|
||||
if not models_dir.exists():
|
||||
continue
|
||||
models_dir = Path(cur_base_model.value, cur_model_type.value)
|
||||
installed.update(self.scan_directory(models_dir))
|
||||
self._logger.info(f"{len(installed)} new models registered; {len(defunct_models)} unregistered")
|
||||
|
||||
def _sync_model_path(self, key: str) -> AnyModelConfig:
|
||||
def _sync_model_path(self, key: str, ignore_hash_change: bool = False) -> AnyModelConfig:
|
||||
"""
|
||||
Move model into the location indicated by its basetype, type and name.
|
||||
|
||||
@ -531,21 +463,21 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
old_path = Path(model.path)
|
||||
models_dir = self.app_config.models_path
|
||||
|
||||
try:
|
||||
old_path.relative_to(models_dir)
|
||||
return model
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
new_path = models_dir / model.base.value / model.type.value / old_path.name
|
||||
|
||||
if old_path == new_path or new_path.exists() and old_path == new_path.resolve():
|
||||
if not old_path.is_relative_to(models_dir):
|
||||
return model
|
||||
|
||||
new_path = models_dir / model.base.value / model.type.value / model.name
|
||||
self._logger.info(f"Moving {model.name} to {new_path}.")
|
||||
new_path = self._move_model(old_path, new_path)
|
||||
model.path = new_path.as_posix()
|
||||
self.record_store.update_model(key, ModelRecordChanges(path=model.path))
|
||||
new_hash = FastModelHash.hash(new_path)
|
||||
model.path = new_path.relative_to(models_dir).as_posix()
|
||||
if model.current_hash != new_hash:
|
||||
assert (
|
||||
ignore_hash_change
|
||||
), f"{model.name}: Model hash changed during installation, model is possibly corrupted"
|
||||
model.current_hash = new_hash
|
||||
self._logger.info(f"Model has new hash {model.current_hash}, but will continue to be identified by {key}")
|
||||
self.record_store.update_model(key, model)
|
||||
return model
|
||||
|
||||
def _scan_register(self, model: Path) -> bool:
|
||||
@ -597,22 +529,36 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
move(old_path, new_path)
|
||||
return new_path
|
||||
|
||||
def _probe_model(self, model_path: Path, config: Optional[Dict[str, Any]] = None) -> AnyModelConfig:
|
||||
info: AnyModelConfig = ModelProbe.probe(Path(model_path))
|
||||
if config: # used to override probe fields
|
||||
for key, value in config.items():
|
||||
setattr(info, key, value)
|
||||
return info
|
||||
|
||||
def _create_key(self) -> str:
|
||||
return sha256(randbytes(100)).hexdigest()[0:32]
|
||||
|
||||
def _register(
|
||||
self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None
|
||||
) -> str:
|
||||
config = config or {}
|
||||
key = self._create_key()
|
||||
if config and not config.get("key", None):
|
||||
config["key"] = key
|
||||
info = info or ModelProbe.probe(model_path, config)
|
||||
|
||||
info = info or ModelProbe.probe(model_path, config, hash_algo=self._app_config.hashing_algorithm)
|
||||
|
||||
model_path = model_path.resolve()
|
||||
model_path = model_path.absolute()
|
||||
if model_path.is_relative_to(self.app_config.models_path):
|
||||
model_path = model_path.relative_to(self.app_config.models_path)
|
||||
|
||||
info.path = model_path.as_posix()
|
||||
|
||||
# Checkpoints have a config file needed for conversion - resolve this to an absolute path
|
||||
if isinstance(info, CheckpointConfigBase):
|
||||
legacy_conf = (self.app_config.legacy_conf_path / info.config_path).resolve()
|
||||
info.config_path = legacy_conf.as_posix()
|
||||
self.record_store.add_model(info)
|
||||
# add 'main' specific fields
|
||||
if hasattr(info, "config"):
|
||||
# make config relative to our root
|
||||
legacy_conf = (self.app_config.root_dir / self.app_config.legacy_conf_dir / info.config).resolve()
|
||||
info.config = legacy_conf.relative_to(self.app_config.root_dir).as_posix()
|
||||
self.record_store.add_model(info.key, info)
|
||||
return info.key
|
||||
|
||||
def _next_id(self) -> int:
|
||||
@ -633,9 +579,17 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
source=source,
|
||||
config_in=config or {},
|
||||
local_path=Path(source.path),
|
||||
inplace=source.inplace or False,
|
||||
inplace=source.inplace,
|
||||
)
|
||||
|
||||
def _import_from_civitai(self, source: CivitaiModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
|
||||
if not source.access_token:
|
||||
self._logger.info("No Civitai access token provided; some models may not be downloadable.")
|
||||
metadata = CivitaiMetadataFetch(self._session).from_id(str(source.version_id))
|
||||
assert isinstance(metadata, ModelMetadataWithFiles)
|
||||
remote_files = metadata.download_urls(session=self._session)
|
||||
return self._import_remote_model(source=source, config=config, metadata=metadata, remote_files=remote_files)
|
||||
|
||||
def _import_from_hf(self, source: HFModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
|
||||
# Add user's cached access token to HuggingFace requests
|
||||
source.access_token = source.access_token or HfFolder.get_token()
|
||||
@ -658,16 +612,16 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
)
|
||||
|
||||
def _import_from_url(self, source: URLModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
|
||||
# URLs from HuggingFace will be handled specially
|
||||
# URLs from Civitai or HuggingFace will be handled specially
|
||||
url_patterns = {
|
||||
r"^https?://civitai.com/": CivitaiMetadataFetch,
|
||||
r"^https?://huggingface.co/[^/]+/[^/]+$": HuggingFaceMetadataFetch,
|
||||
}
|
||||
metadata = None
|
||||
fetcher = None
|
||||
try:
|
||||
fetcher = self.get_fetcher_from_url(str(source.url))
|
||||
except ValueError:
|
||||
pass
|
||||
kwargs: dict[str, Any] = {"session": self._session}
|
||||
if fetcher is not None:
|
||||
metadata = fetcher(**kwargs).from_url(source.url)
|
||||
for pattern, fetcher in url_patterns.items():
|
||||
if re.match(pattern, str(source.url), re.IGNORECASE):
|
||||
metadata = fetcher(self._session).from_url(source.url)
|
||||
break
|
||||
self._logger.debug(f"metadata={metadata}")
|
||||
if metadata and isinstance(metadata, ModelMetadataWithFiles):
|
||||
remote_files = metadata.download_urls(session=self._session)
|
||||
@ -682,7 +636,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
|
||||
def _import_remote_model(
|
||||
self,
|
||||
source: HFModelSource | URLModelSource,
|
||||
source: ModelSource,
|
||||
remote_files: List[RemoteModelFile],
|
||||
metadata: Optional[AnyModelRepoMetadata],
|
||||
config: Optional[Dict[str, Any]],
|
||||
@ -710,7 +664,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
# In the event that there is a subfolder specified in the source,
|
||||
# we need to remove it from the destination path in order to avoid
|
||||
# creating unwanted subfolders
|
||||
if isinstance(source, HFModelSource) and source.subfolder:
|
||||
if hasattr(source, "subfolder") and source.subfolder:
|
||||
root = Path(remote_files[0].path.parts[0])
|
||||
subfolder = root / source.subfolder
|
||||
else:
|
||||
@ -792,14 +746,14 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._logger.info(f"{download_job.source}: model download complete")
|
||||
with self._lock:
|
||||
install_job = self._download_cache[download_job.source]
|
||||
self._download_cache.pop(download_job.source, None)
|
||||
|
||||
# are there any more active jobs left in this task?
|
||||
if install_job.downloading and all(x.complete for x in install_job.download_parts):
|
||||
self._signal_job_downloads_done(install_job)
|
||||
self._put_in_queue(install_job)
|
||||
if all(x.complete for x in install_job.download_parts):
|
||||
# now enqueue job for actual installation into the models directory
|
||||
self._install_queue.put(install_job)
|
||||
|
||||
# Let other threads know that the number of downloads has changed
|
||||
self._download_cache.pop(download_job.source, None)
|
||||
self._downloads_changed_event.set()
|
||||
|
||||
def _download_error_callback(self, download_job: DownloadJob, excp: Optional[Exception] = None) -> None:
|
||||
@ -839,7 +793,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
|
||||
if all(x.in_terminal_state for x in install_job.download_parts):
|
||||
# When all parts have reached their terminal state, we finalize the job to clean up the temporary directory and other resources
|
||||
self._put_in_queue(install_job)
|
||||
self._install_queue.put(install_job)
|
||||
|
||||
# ------------------------------------------------------------------------------------------------
|
||||
# Internal methods that put events on the event bus
|
||||
@ -872,12 +826,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
id=job.id,
|
||||
)
|
||||
|
||||
def _signal_job_downloads_done(self, job: ModelInstallJob) -> None:
|
||||
job.status = InstallStatus.DOWNLOADS_DONE
|
||||
self._logger.info(f"{job.source}: all parts of this model are downloaded")
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_model_install_downloads_done(str(job.source))
|
||||
|
||||
def _signal_job_completed(self, job: ModelInstallJob) -> None:
|
||||
job.status = InstallStatus.COMPLETED
|
||||
assert job.config_out
|
||||
@ -902,10 +850,4 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
def _signal_job_cancelled(self, job: ModelInstallJob) -> None:
|
||||
self._logger.info(f"{job.source}: model installation was cancelled")
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_model_install_cancelled(str(job.source), id=job.id)
|
||||
|
||||
@staticmethod
|
||||
def get_fetcher_from_url(url: str):
|
||||
if re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()):
|
||||
return HuggingFaceMetadataFetch
|
||||
raise ValueError(f"Unsupported model source: '{url}'")
|
||||
self._event_bus.emit_model_install_cancelled(str(job.source))
|
||||
|
@ -68,7 +68,6 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
self._emit_load_event(
|
||||
context_data=context_data,
|
||||
model_config=model_config,
|
||||
submodel_type=submodel_type,
|
||||
)
|
||||
|
||||
implementation, model_config, submodel_type = self._registry.get_implementation(model_config, submodel_type) # type: ignore
|
||||
@ -83,7 +82,6 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
self._emit_load_event(
|
||||
context_data=context_data,
|
||||
model_config=model_config,
|
||||
submodel_type=submodel_type,
|
||||
loaded=True,
|
||||
)
|
||||
return loaded_model
|
||||
@ -93,7 +91,6 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
context_data: InvocationContextData,
|
||||
model_config: AnyModelConfig,
|
||||
loaded: Optional[bool] = False,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> None:
|
||||
if not self._invoker:
|
||||
return
|
||||
@ -105,7 +102,6 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
queue_batch_id=context_data.queue_item.batch_id,
|
||||
graph_execution_state_id=context_data.queue_item.session_id,
|
||||
model_config=model_config,
|
||||
submodel_type=submodel_type,
|
||||
)
|
||||
else:
|
||||
self._invoker.services.events.emit_model_load_completed(
|
||||
@ -114,5 +110,4 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
queue_batch_id=context_data.queue_item.batch_id,
|
||||
graph_execution_state_id=context_data.queue_item.session_id,
|
||||
model_config=model_config,
|
||||
submodel_type=submodel_type,
|
||||
)
|
||||
|
@ -1,11 +1,15 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from typing_extensions import Self
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContextData
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel
|
||||
|
||||
from ..config import InvokeAIAppConfig
|
||||
from ..download import DownloadQueueServiceBase
|
||||
@ -66,3 +70,32 @@ class ModelManagerServiceBase(ABC):
|
||||
@abstractmethod
|
||||
def stop(self, invoker: Invoker) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_model_by_config(
|
||||
self,
|
||||
model_config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context_data: Optional[InvocationContextData] = None,
|
||||
) -> LoadedModel:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_model_by_key(
|
||||
self,
|
||||
key: str,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context_data: Optional[InvocationContextData] = None,
|
||||
) -> LoadedModel:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_model_by_attr(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel: Optional[SubModelType] = None,
|
||||
context_data: Optional[InvocationContextData] = None,
|
||||
) -> LoadedModel:
|
||||
pass
|
||||
|
@ -1,10 +1,14 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
||||
"""Implementation of ModelManagerServiceBase."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from typing_extensions import Self
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContextData
|
||||
from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, LoadedModel, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
|
||||
from invokeai.backend.util.devices import choose_torch_device
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
@ -14,7 +18,7 @@ from ..download import DownloadQueueServiceBase
|
||||
from ..events.events_base import EventServiceBase
|
||||
from ..model_install import ModelInstallService, ModelInstallServiceBase
|
||||
from ..model_load import ModelLoadService, ModelLoadServiceBase
|
||||
from ..model_records import ModelRecordServiceBase
|
||||
from ..model_records import ModelRecordServiceBase, UnknownModelException
|
||||
from .model_manager_base import ModelManagerServiceBase
|
||||
|
||||
|
||||
@ -60,6 +64,56 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
if hasattr(service, "stop"):
|
||||
service.stop(invoker)
|
||||
|
||||
def load_model_by_config(
|
||||
self,
|
||||
model_config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context_data: Optional[InvocationContextData] = None,
|
||||
) -> LoadedModel:
|
||||
return self.load.load_model(model_config, submodel_type, context_data)
|
||||
|
||||
def load_model_by_key(
|
||||
self,
|
||||
key: str,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context_data: Optional[InvocationContextData] = None,
|
||||
) -> LoadedModel:
|
||||
config = self.store.get_model(key)
|
||||
return self.load.load_model(config, submodel_type, context_data)
|
||||
|
||||
def load_model_by_attr(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel: Optional[SubModelType] = None,
|
||||
context_data: Optional[InvocationContextData] = None,
|
||||
) -> LoadedModel:
|
||||
"""
|
||||
Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object.
|
||||
|
||||
This is provided for API compatability with the get_model() method
|
||||
in the original model manager. However, note that LoadedModel is
|
||||
not the same as the original ModelInfo that ws returned.
|
||||
|
||||
:param model_name: Name of to be fetched.
|
||||
:param base_model: Base model
|
||||
:param model_type: Type of the model
|
||||
:param submodel: For main (pipeline models), the submodel to fetch
|
||||
:param context: The invocation context.
|
||||
|
||||
Exceptions: UnknownModelException -- model with this key not known
|
||||
NotImplementedException -- a model loader was not provided at initialization time
|
||||
ValueError -- more than one model matches this combination
|
||||
"""
|
||||
configs = self.store.search_by_attr(model_name, base_model, model_type)
|
||||
if len(configs) == 0:
|
||||
raise UnknownModelException(f"{base_model}/{model_type}/{model_name}: Unknown model")
|
||||
elif len(configs) > 1:
|
||||
raise ValueError(f"{base_model}/{model_type}/{model_name}: More than one model matches.")
|
||||
else:
|
||||
return self.load.load_model(configs[0], submodel, context_data)
|
||||
|
||||
@classmethod
|
||||
def build_model_manager(
|
||||
cls,
|
||||
@ -78,12 +132,14 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
logger.setLevel(app_config.log_level.upper())
|
||||
|
||||
ram_cache = ModelCache(
|
||||
max_cache_size=app_config.ram,
|
||||
max_vram_cache_size=app_config.vram,
|
||||
max_cache_size=app_config.ram_cache_size,
|
||||
max_vram_cache_size=app_config.vram_cache_size,
|
||||
logger=logger,
|
||||
execution_device=execution_device,
|
||||
)
|
||||
convert_cache = ModelConvertCache(cache_path=app_config.convert_cache_path, max_size=app_config.convert_cache)
|
||||
convert_cache = ModelConvertCache(
|
||||
cache_path=app_config.models_convert_cache_path, max_size=app_config.convert_cache_size
|
||||
)
|
||||
loader = ModelLoadService(
|
||||
app_config=app_config,
|
||||
ram_cache=ram_cache,
|
||||
|
9
invokeai/app/services/model_metadata/__init__.py
Normal file
9
invokeai/app/services/model_metadata/__init__.py
Normal file
@ -0,0 +1,9 @@
|
||||
"""Init file for ModelMetadataStoreService module."""
|
||||
|
||||
from .metadata_store_base import ModelMetadataStoreBase
|
||||
from .metadata_store_sql import ModelMetadataStoreSQL
|
||||
|
||||
__all__ = [
|
||||
"ModelMetadataStoreBase",
|
||||
"ModelMetadataStoreSQL",
|
||||
]
|
65
invokeai/app/services/model_metadata/metadata_store_base.py
Normal file
65
invokeai/app/services/model_metadata/metadata_store_base.py
Normal file
@ -0,0 +1,65 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""
|
||||
Storage for Model Metadata
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Set, Tuple
|
||||
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
|
||||
|
||||
class ModelMetadataStoreBase(ABC):
|
||||
"""Store, search and fetch model metadata retrieved from remote repositories."""
|
||||
|
||||
@abstractmethod
|
||||
def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
|
||||
"""
|
||||
Add a block of repo metadata to a model record.
|
||||
|
||||
The model record config must already exist in the database with the
|
||||
same key. Otherwise a FOREIGN KEY constraint exception will be raised.
|
||||
|
||||
:param model_key: Existing model key in the `model_config` table
|
||||
:param metadata: ModelRepoMetadata object to store
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
|
||||
"""Retrieve the ModelRepoMetadata corresponding to model key."""
|
||||
|
||||
@abstractmethod
|
||||
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata
|
||||
"""Dump out all the metadata."""
|
||||
|
||||
@abstractmethod
|
||||
def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
|
||||
"""
|
||||
Update metadata corresponding to the model with the indicated key.
|
||||
|
||||
:param model_key: Existing model key in the `model_config` table
|
||||
:param metadata: ModelRepoMetadata object to update
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def list_tags(self) -> Set[str]:
|
||||
"""Return all tags in the tags table."""
|
||||
|
||||
@abstractmethod
|
||||
def search_by_tag(self, tags: Set[str]) -> Set[str]:
|
||||
"""Return the keys of models containing all of the listed tags."""
|
||||
|
||||
@abstractmethod
|
||||
def search_by_author(self, author: str) -> Set[str]:
|
||||
"""Return the keys of models authored by the indicated author."""
|
||||
|
||||
@abstractmethod
|
||||
def search_by_name(self, name: str) -> Set[str]:
|
||||
"""
|
||||
Return the keys of models with the indicated name.
|
||||
|
||||
Note that this is the name of the model given to it by
|
||||
the remote source. The user may have changed the local
|
||||
name. The local name will be located in the model config
|
||||
record object.
|
||||
"""
|
222
invokeai/app/services/model_metadata/metadata_store_sql.py
Normal file
222
invokeai/app/services/model_metadata/metadata_store_sql.py
Normal file
@ -0,0 +1,222 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""
|
||||
SQL Storage for Model Metadata
|
||||
"""
|
||||
|
||||
import sqlite3
|
||||
from typing import List, Optional, Set, Tuple
|
||||
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, UnknownMetadataException
|
||||
from invokeai.backend.model_manager.metadata.fetch import ModelMetadataFetchBase
|
||||
|
||||
from .metadata_store_base import ModelMetadataStoreBase
|
||||
|
||||
|
||||
class ModelMetadataStoreSQL(ModelMetadataStoreBase):
|
||||
"""Store, search and fetch model metadata retrieved from remote repositories."""
|
||||
|
||||
def __init__(self, db: SqliteDatabase):
|
||||
"""
|
||||
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
|
||||
|
||||
:param conn: sqlite3 connection object
|
||||
:param lock: threading Lock object
|
||||
"""
|
||||
super().__init__()
|
||||
self._db = db
|
||||
self._cursor = self._db.conn.cursor()
|
||||
|
||||
def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
|
||||
"""
|
||||
Add a block of repo metadata to a model record.
|
||||
|
||||
The model record config must already exist in the database with the
|
||||
same key. Otherwise a FOREIGN KEY constraint exception will be raised.
|
||||
|
||||
:param model_key: Existing model key in the `model_config` table
|
||||
:param metadata: ModelRepoMetadata object to store
|
||||
"""
|
||||
json_serialized = metadata.model_dump_json()
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT INTO model_metadata(
|
||||
id,
|
||||
metadata
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
(
|
||||
model_key,
|
||||
json_serialized,
|
||||
),
|
||||
)
|
||||
self._update_tags(model_key, metadata.tags)
|
||||
self._db.conn.commit()
|
||||
except sqlite3.IntegrityError as excp: # FOREIGN KEY error: the key was not in model_config table
|
||||
self._db.conn.rollback()
|
||||
raise UnknownMetadataException from excp
|
||||
except sqlite3.Error as excp:
|
||||
self._db.conn.rollback()
|
||||
raise excp
|
||||
|
||||
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
|
||||
"""Retrieve the ModelRepoMetadata corresponding to model key."""
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT metadata FROM model_metadata
|
||||
WHERE id=?;
|
||||
""",
|
||||
(model_key,),
|
||||
)
|
||||
rows = self._cursor.fetchone()
|
||||
if not rows:
|
||||
raise UnknownMetadataException("model metadata not found")
|
||||
return ModelMetadataFetchBase.from_json(rows[0])
|
||||
|
||||
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata
|
||||
"""Dump out all the metadata."""
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT id,metadata FROM model_metadata;
|
||||
""",
|
||||
(),
|
||||
)
|
||||
rows = self._cursor.fetchall()
|
||||
return [(x[0], ModelMetadataFetchBase.from_json(x[1])) for x in rows]
|
||||
|
||||
def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
|
||||
"""
|
||||
Update metadata corresponding to the model with the indicated key.
|
||||
|
||||
:param model_key: Existing model key in the `model_config` table
|
||||
:param metadata: ModelRepoMetadata object to update
|
||||
"""
|
||||
json_serialized = metadata.model_dump_json() # turn it into a json string.
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
UPDATE model_metadata
|
||||
SET
|
||||
metadata=?
|
||||
WHERE id=?;
|
||||
""",
|
||||
(json_serialized, model_key),
|
||||
)
|
||||
if self._cursor.rowcount == 0:
|
||||
raise UnknownMetadataException("model metadata not found")
|
||||
self._update_tags(model_key, metadata.tags)
|
||||
self._db.conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
|
||||
return self.get_metadata(model_key)
|
||||
|
||||
def list_tags(self) -> Set[str]:
|
||||
"""Return all tags in the tags table."""
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
select tag_text from tags;
|
||||
"""
|
||||
)
|
||||
return {x[0] for x in self._cursor.fetchall()}
|
||||
|
||||
def search_by_tag(self, tags: Set[str]) -> Set[str]:
|
||||
"""Return the keys of models containing all of the listed tags."""
|
||||
with self._db.lock:
|
||||
try:
|
||||
matches: Optional[Set[str]] = None
|
||||
for tag in tags:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT a.model_id FROM model_tags AS a,
|
||||
tags AS b
|
||||
WHERE a.tag_id=b.tag_id
|
||||
AND b.tag_text=?;
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
model_keys = {x[0] for x in self._cursor.fetchall()}
|
||||
if matches is None:
|
||||
matches = model_keys
|
||||
matches = matches.intersection(model_keys)
|
||||
except sqlite3.Error as e:
|
||||
raise e
|
||||
return matches if matches else set()
|
||||
|
||||
def search_by_author(self, author: str) -> Set[str]:
|
||||
"""Return the keys of models authored by the indicated author."""
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT id FROM model_metadata
|
||||
WHERE author=?;
|
||||
""",
|
||||
(author,),
|
||||
)
|
||||
return {x[0] for x in self._cursor.fetchall()}
|
||||
|
||||
def search_by_name(self, name: str) -> Set[str]:
|
||||
"""
|
||||
Return the keys of models with the indicated name.
|
||||
|
||||
Note that this is the name of the model given to it by
|
||||
the remote source. The user may have changed the local
|
||||
name. The local name will be located in the model config
|
||||
record object.
|
||||
"""
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT id FROM model_metadata
|
||||
WHERE name=?;
|
||||
""",
|
||||
(name,),
|
||||
)
|
||||
return {x[0] for x in self._cursor.fetchall()}
|
||||
|
||||
def _update_tags(self, model_key: str, tags: Set[str]) -> None:
|
||||
"""Update tags for the model referenced by model_key."""
|
||||
# remove previous tags from this model
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM model_tags
|
||||
WHERE model_id=?;
|
||||
""",
|
||||
(model_key,),
|
||||
)
|
||||
|
||||
for tag in tags:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO tags (
|
||||
tag_text
|
||||
)
|
||||
VALUES (?);
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT tag_id
|
||||
FROM tags
|
||||
WHERE tag_text = ?
|
||||
LIMIT 1;
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
tag_id = self._cursor.fetchone()[0]
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO model_tags (
|
||||
model_id,
|
||||
tag_id
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
(model_key, tag_id),
|
||||
)
|
@ -1,5 +1,4 @@
|
||||
"""Init file for model record services."""
|
||||
|
||||
from .model_records_base import ( # noqa F401
|
||||
DuplicateModelException,
|
||||
InvalidModelException,
|
||||
|
@ -6,24 +6,20 @@ Abstract base class for storing and retrieving model configuration records.
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Set, Union
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import (
|
||||
ControlAdapterDefaultSettings,
|
||||
MainModelDefaultSettings,
|
||||
ModelVariantType,
|
||||
SchedulerPredictionType,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
|
||||
from ..model_metadata import ModelMetadataStoreBase
|
||||
|
||||
|
||||
class DuplicateModelException(Exception):
|
||||
@ -64,34 +60,11 @@ class ModelSummary(BaseModel):
|
||||
tags: Set[str] = Field(description="tags associated with model")
|
||||
|
||||
|
||||
class ModelRecordChanges(BaseModelExcludeNull):
|
||||
"""A set of changes to apply to a model."""
|
||||
|
||||
# Changes applicable to all models
|
||||
name: Optional[str] = Field(description="Name of the model.", default=None)
|
||||
path: Optional[str] = Field(description="Path to the model.", default=None)
|
||||
description: Optional[str] = Field(description="Model description", default=None)
|
||||
base: Optional[BaseModelType] = Field(description="The base model.", default=None)
|
||||
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
|
||||
default_settings: Optional[MainModelDefaultSettings | ControlAdapterDefaultSettings] = Field(
|
||||
description="Default settings for this model", default=None
|
||||
)
|
||||
|
||||
# Checkpoint-specific changes
|
||||
# TODO(MM2): Should we expose these? Feels footgun-y...
|
||||
variant: Optional[ModelVariantType] = Field(description="The variant of the model.", default=None)
|
||||
prediction_type: Optional[SchedulerPredictionType] = Field(
|
||||
description="The prediction type of the model.", default=None
|
||||
)
|
||||
upcast_attention: Optional[bool] = Field(description="Whether to upcast attention.", default=None)
|
||||
config_path: Optional[str] = Field(description="Path to config file for model", default=None)
|
||||
|
||||
|
||||
class ModelRecordServiceBase(ABC):
|
||||
"""Abstract base class for storage and retrieval of model configs."""
|
||||
|
||||
@abstractmethod
|
||||
def add_model(self, config: AnyModelConfig) -> AnyModelConfig:
|
||||
def add_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
|
||||
"""
|
||||
Add a model to the database.
|
||||
|
||||
@ -115,12 +88,13 @@ class ModelRecordServiceBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
|
||||
def update_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
|
||||
"""
|
||||
Update the model, returning the updated version.
|
||||
|
||||
:param key: Unique key for the model to be updated.
|
||||
:param changes: A set of changes to apply to this model. Changes are validated before being written.
|
||||
:param key: Unique key for the model to be updated
|
||||
:param config: Model configuration record. Either a dict with the
|
||||
required fields, or a ModelConfigBase instance.
|
||||
"""
|
||||
pass
|
||||
|
||||
@ -135,15 +109,38 @@ class ModelRecordServiceBase(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def get_model_by_hash(self, hash: str) -> AnyModelConfig:
|
||||
"""
|
||||
Retrieve the configuration for the indicated model.
|
||||
def metadata_store(self) -> ModelMetadataStoreBase:
|
||||
"""Return a ModelMetadataStore initialized on the same database."""
|
||||
pass
|
||||
|
||||
:param hash: Hash of model config to be fetched.
|
||||
|
||||
Exceptions: UnknownModelException
|
||||
@abstractmethod
|
||||
def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
|
||||
"""
|
||||
Retrieve metadata (if any) from when model was downloaded from a repo.
|
||||
|
||||
:param key: Model key
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]:
|
||||
"""List metadata for all models that have it."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search_by_metadata_tag(self, tags: Set[str]) -> List[AnyModelConfig]:
|
||||
"""
|
||||
Search model metadata for ones with all listed tags and return their corresponding configs.
|
||||
|
||||
:param tags: Set of tags to search for. All tags must be present.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_tags(self) -> Set[str]:
|
||||
"""Return a unique set of all the model tags in the metadata database."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@ -220,3 +217,21 @@ class ModelRecordServiceBase(ABC):
|
||||
f"More than one model matched the search criteria: base_model='{base_model}', model_type='{model_type}', model_name='{model_name}'."
|
||||
)
|
||||
return model_configs[0]
|
||||
|
||||
def rename_model(
|
||||
self,
|
||||
key: str,
|
||||
new_name: str,
|
||||
) -> AnyModelConfig:
|
||||
"""
|
||||
Rename the indicated model. Just a special case of update_model().
|
||||
|
||||
In some implementations, renaming the model may involve changing where
|
||||
it is stored on the filesystem. So this is broken out.
|
||||
|
||||
:param key: Model key
|
||||
:param new_name: New name for model
|
||||
"""
|
||||
config = self.get_model(key)
|
||||
config.name = new_name
|
||||
return self.update_model(key, config)
|
||||
|
@ -39,11 +39,12 @@ Typical usage:
|
||||
configs = store.search_by_attr(base_model='sd-2', model_type='main')
|
||||
"""
|
||||
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
from math import ceil
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||
from invokeai.backend.model_manager.config import (
|
||||
@ -53,11 +54,12 @@ from invokeai.backend.model_manager.config import (
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, UnknownMetadataException
|
||||
|
||||
from ..model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL
|
||||
from ..shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from .model_records_base import (
|
||||
DuplicateModelException,
|
||||
ModelRecordChanges,
|
||||
ModelRecordOrderBy,
|
||||
ModelRecordServiceBase,
|
||||
ModelSummary,
|
||||
@ -68,7 +70,7 @@ from .model_records_base import (
|
||||
class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
"""Implementation of the ModelConfigStore ABC using a SQL database."""
|
||||
|
||||
def __init__(self, db: SqliteDatabase):
|
||||
def __init__(self, db: SqliteDatabase, metadata_store: ModelMetadataStoreBase):
|
||||
"""
|
||||
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
|
||||
|
||||
@ -77,13 +79,14 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
super().__init__()
|
||||
self._db = db
|
||||
self._cursor = db.conn.cursor()
|
||||
self._metadata_store = metadata_store
|
||||
|
||||
@property
|
||||
def db(self) -> SqliteDatabase:
|
||||
"""Return the underlying database."""
|
||||
return self._db
|
||||
|
||||
def add_model(self, config: AnyModelConfig) -> AnyModelConfig:
|
||||
def add_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
|
||||
"""
|
||||
Add a model to the database.
|
||||
|
||||
@ -93,19 +96,23 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
|
||||
Can raise DuplicateModelException and InvalidModelConfigException exceptions.
|
||||
"""
|
||||
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect.
|
||||
json_serialized = record.model_dump_json() # and turn it into a json string.
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT INTO models (
|
||||
INSERT INTO model_config (
|
||||
id,
|
||||
original_hash,
|
||||
config
|
||||
)
|
||||
VALUES (?,?);
|
||||
VALUES (?,?,?);
|
||||
""",
|
||||
(
|
||||
config.key,
|
||||
config.model_dump_json(),
|
||||
key,
|
||||
record.original_hash,
|
||||
json_serialized,
|
||||
),
|
||||
)
|
||||
self._db.conn.commit()
|
||||
@ -113,12 +120,12 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
except sqlite3.IntegrityError as e:
|
||||
self._db.conn.rollback()
|
||||
if "UNIQUE constraint failed" in str(e):
|
||||
if "models.path" in str(e):
|
||||
msg = f"A model with path '{config.path}' is already installed"
|
||||
elif "models.name" in str(e):
|
||||
msg = f"A model with name='{config.name}', type='{config.type}', base='{config.base}' is already installed"
|
||||
if "model_config.path" in str(e):
|
||||
msg = f"A model with path '{record.path}' is already installed"
|
||||
elif "model_config.name" in str(e):
|
||||
msg = f"A model with name='{record.name}', type='{record.type}', base='{record.base}' is already installed"
|
||||
else:
|
||||
msg = f"A model with key '{config.key}' is already installed"
|
||||
msg = f"A model with key '{key}' is already installed"
|
||||
raise DuplicateModelException(msg) from e
|
||||
else:
|
||||
raise e
|
||||
@ -126,7 +133,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
|
||||
return self.get_model(config.key)
|
||||
return self.get_model(key)
|
||||
|
||||
def del_model(self, key: str) -> None:
|
||||
"""
|
||||
@ -140,7 +147,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM models
|
||||
DELETE FROM model_config
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
@ -152,20 +159,21 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
|
||||
def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
|
||||
record = self.get_model(key)
|
||||
|
||||
# Model configs use pydantic's `validate_assignment`, so each change is validated by pydantic.
|
||||
for field_name in changes.model_fields_set:
|
||||
setattr(record, field_name, getattr(changes, field_name))
|
||||
|
||||
json_serialized = record.model_dump_json()
|
||||
def update_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
|
||||
"""
|
||||
Update the model, returning the updated version.
|
||||
|
||||
:param key: Unique key for the model to be updated
|
||||
:param config: Model configuration record. Either a dict with the
|
||||
required fields, or a ModelConfigBase instance.
|
||||
"""
|
||||
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect
|
||||
json_serialized = record.model_dump_json() # and turn it into a json string.
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
UPDATE models
|
||||
UPDATE model_config
|
||||
SET
|
||||
config=?
|
||||
WHERE id=?;
|
||||
@ -192,7 +200,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
SELECT config, strftime('%s',updated_at) FROM model_config
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
@ -203,21 +211,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
|
||||
return model
|
||||
|
||||
def get_model_by_hash(self, hash: str) -> AnyModelConfig:
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE hash=?;
|
||||
""",
|
||||
(hash,),
|
||||
)
|
||||
rows = self._cursor.fetchone()
|
||||
if not rows:
|
||||
raise UnknownModelException("model not found")
|
||||
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
|
||||
return model
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
"""
|
||||
Return True if a model with the indicated key exists in the databse.
|
||||
@ -228,7 +221,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
select count(*) FROM models
|
||||
select count(*) FROM model_config
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
@ -242,7 +235,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
base_model: Optional[BaseModelType] = None,
|
||||
model_type: Optional[ModelType] = None,
|
||||
model_format: Optional[ModelFormat] = None,
|
||||
order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default,
|
||||
) -> List[AnyModelConfig]:
|
||||
"""
|
||||
Return models matching name, base and/or type.
|
||||
@ -251,23 +243,13 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
:param base_model: Filter by base model (optional)
|
||||
:param model_type: Filter by type of model (optional)
|
||||
:param model_format: Filter by model format (e.g. "diffusers") (optional)
|
||||
:param order_by: Result order
|
||||
|
||||
If none of the optional filters are passed, will return all
|
||||
models in the database.
|
||||
"""
|
||||
|
||||
assert isinstance(order_by, ModelRecordOrderBy)
|
||||
ordering = {
|
||||
ModelRecordOrderBy.Default: "type, base, name, format",
|
||||
ModelRecordOrderBy.Type: "type",
|
||||
ModelRecordOrderBy.Base: "base",
|
||||
ModelRecordOrderBy.Name: "name",
|
||||
ModelRecordOrderBy.Format: "format",
|
||||
}
|
||||
|
||||
where_clause: list[str] = []
|
||||
bindings: list[str] = []
|
||||
results = []
|
||||
where_clause = []
|
||||
bindings = []
|
||||
if model_name:
|
||||
where_clause.append("name=?")
|
||||
bindings.append(model_name)
|
||||
@ -284,15 +266,14 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
SELECT config, strftime('%s',updated_at)
|
||||
FROM models
|
||||
{where}
|
||||
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason;
|
||||
select config, strftime('%s',updated_at) FROM model_config
|
||||
{where};
|
||||
""",
|
||||
tuple(bindings),
|
||||
)
|
||||
result = self._cursor.fetchall()
|
||||
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in result]
|
||||
results = [
|
||||
ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()
|
||||
]
|
||||
return results
|
||||
|
||||
def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:
|
||||
@ -301,7 +282,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
SELECT config, strftime('%s',updated_at) FROM model_config
|
||||
WHERE path=?;
|
||||
""",
|
||||
(str(path),),
|
||||
@ -312,13 +293,13 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
return results
|
||||
|
||||
def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
|
||||
"""Return models with the indicated hash."""
|
||||
"""Return models with the indicated original_hash."""
|
||||
results = []
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE hash=?;
|
||||
SELECT config, strftime('%s',updated_at) FROM model_config
|
||||
WHERE original_hash=?;
|
||||
""",
|
||||
(hash,),
|
||||
)
|
||||
@ -327,35 +308,83 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
]
|
||||
return results
|
||||
|
||||
@property
|
||||
def metadata_store(self) -> ModelMetadataStoreBase:
|
||||
"""Return a ModelMetadataStore initialized on the same database."""
|
||||
return self._metadata_store
|
||||
|
||||
def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
|
||||
"""
|
||||
Retrieve metadata (if any) from when model was downloaded from a repo.
|
||||
|
||||
:param key: Model key
|
||||
"""
|
||||
store = self.metadata_store
|
||||
try:
|
||||
metadata = store.get_metadata(key)
|
||||
return metadata
|
||||
except UnknownMetadataException:
|
||||
return None
|
||||
|
||||
def search_by_metadata_tag(self, tags: Set[str]) -> List[AnyModelConfig]:
|
||||
"""
|
||||
Search model metadata for ones with all listed tags and return their corresponding configs.
|
||||
|
||||
:param tags: Set of tags to search for. All tags must be present.
|
||||
"""
|
||||
store = ModelMetadataStoreSQL(self._db)
|
||||
keys = store.search_by_tag(tags)
|
||||
return [self.get_model(x) for x in keys]
|
||||
|
||||
def list_tags(self) -> Set[str]:
|
||||
"""Return a unique set of all the model tags in the metadata database."""
|
||||
store = ModelMetadataStoreSQL(self._db)
|
||||
return store.list_tags()
|
||||
|
||||
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]:
|
||||
"""List metadata for all models that have it."""
|
||||
store = ModelMetadataStoreSQL(self._db)
|
||||
return store.list_all_metadata()
|
||||
|
||||
def list_models(
|
||||
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
|
||||
) -> PaginatedResults[ModelSummary]:
|
||||
"""Return a paginated summary listing of each model in the database."""
|
||||
assert isinstance(order_by, ModelRecordOrderBy)
|
||||
ordering = {
|
||||
ModelRecordOrderBy.Default: "type, base, name, format",
|
||||
ModelRecordOrderBy.Type: "type",
|
||||
ModelRecordOrderBy.Base: "base",
|
||||
ModelRecordOrderBy.Name: "name",
|
||||
ModelRecordOrderBy.Format: "format",
|
||||
ModelRecordOrderBy.Default: "a.type, a.base, a.format, a.name",
|
||||
ModelRecordOrderBy.Type: "a.type",
|
||||
ModelRecordOrderBy.Base: "a.base",
|
||||
ModelRecordOrderBy.Name: "a.name",
|
||||
ModelRecordOrderBy.Format: "a.format",
|
||||
}
|
||||
|
||||
def _fixup(summary: Dict[str, str]) -> Dict[str, Union[str, int, Set[str]]]:
|
||||
"""Fix up results so that there are no null values."""
|
||||
result: Dict[str, Union[str, int, Set[str]]] = {}
|
||||
for key, item in summary.items():
|
||||
result[key] = item or ""
|
||||
result["tags"] = set(json.loads(summary["tags"] or "[]"))
|
||||
return result
|
||||
|
||||
# Lock so that the database isn't updated while we're doing the two queries.
|
||||
with self._db.lock:
|
||||
# query1: get the total number of model configs
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
select count(*) from models;
|
||||
select count(*) from model_config;
|
||||
""",
|
||||
(),
|
||||
)
|
||||
total = int(self._cursor.fetchone()[0])
|
||||
|
||||
# query2: fetch key fields
|
||||
# query2: fetch key fields from the join of model_config and model_metadata
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
SELECT config
|
||||
FROM models
|
||||
SELECT a.id as key, a.type, a.base, a.format, a.name,
|
||||
json_extract(a.config, '$.description') as description,
|
||||
json_extract(b.metadata, '$.tags') as tags
|
||||
FROM model_config AS a
|
||||
LEFT JOIN model_metadata AS b on a.id=b.id
|
||||
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason
|
||||
LIMIT ?
|
||||
OFFSET ?;
|
||||
@ -366,7 +395,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
),
|
||||
)
|
||||
rows = self._cursor.fetchall()
|
||||
items = [ModelSummary.model_validate(dict(x)) for x in rows]
|
||||
items = [ModelSummary.model_validate(_fixup(dict(x))) for x in rows]
|
||||
return PaginatedResults(
|
||||
page=page, pages=ceil(total / per_page), per_page=per_page, total=total, items=items
|
||||
)
|
||||
|
@ -200,7 +200,6 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
self._invoker.services.logger.error(
|
||||
f"Error while invoking session {self._queue_item.session_id}, invocation {self._invocation.id} ({self._invocation.get_type()}):\n{e}"
|
||||
)
|
||||
self._invoker.services.logger.error(error)
|
||||
|
||||
# Send error event
|
||||
self._invoker.services.events.emit_invocation_error(
|
||||
|
@ -151,7 +151,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
|
||||
# TODO: how does this work in a multi-user scenario?
|
||||
current_queue_size = self._get_current_queue_size(queue_id)
|
||||
max_queue_size = self.__invoker.services.configuration.max_queue_size
|
||||
max_queue_size = self.__invoker.services.configuration.get_config().max_queue_size
|
||||
max_new_queue_items = max_queue_size - current_queue_size
|
||||
|
||||
priority = 0
|
||||
|
@ -1,7 +1,7 @@
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from PIL.Image import Image
|
||||
from torch import Tensor
|
||||
@ -13,16 +13,15 @@ from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||
from invokeai.app.services.images.images_common import ImageDTO
|
||||
from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.model_records.model_records_base import UnknownModelException
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel
|
||||
from invokeai.backend.model_manager.metadata.metadata_base import AnyModelRepoMetadata
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
|
||||
|
||||
"""
|
||||
@ -300,27 +299,22 @@ class ConditioningInterface(InvocationContextInterface):
|
||||
|
||||
|
||||
class ModelsInterface(InvocationContextInterface):
|
||||
def exists(self, identifier: Union[str, "ModelIdentifierField"]) -> bool:
|
||||
def exists(self, key: str) -> bool:
|
||||
"""Checks if a model exists.
|
||||
|
||||
Args:
|
||||
identifier: The key or ModelField representing the model.
|
||||
key: The key of the model.
|
||||
|
||||
Returns:
|
||||
True if the model exists, False if not.
|
||||
"""
|
||||
if isinstance(identifier, str):
|
||||
return self._services.model_manager.store.exists(identifier)
|
||||
return self._services.model_manager.store.exists(key)
|
||||
|
||||
return self._services.model_manager.store.exists(identifier.key)
|
||||
|
||||
def load(
|
||||
self, identifier: Union[str, "ModelIdentifierField"], submodel_type: Optional[SubModelType] = None
|
||||
) -> LoadedModel:
|
||||
def load(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||
"""Loads a model.
|
||||
|
||||
Args:
|
||||
identifier: The key or ModelField representing the model.
|
||||
key: The key of the model.
|
||||
submodel_type: The submodel of the model to get.
|
||||
|
||||
Returns:
|
||||
@ -330,13 +324,9 @@ class ModelsInterface(InvocationContextInterface):
|
||||
# The model manager emits events as it loads the model. It needs the context data to build
|
||||
# the event payloads.
|
||||
|
||||
if isinstance(identifier, str):
|
||||
model = self._services.model_manager.store.get_model(identifier)
|
||||
return self._services.model_manager.load.load_model(model, submodel_type, self._data)
|
||||
else:
|
||||
_submodel_type = submodel_type or identifier.submodel_type
|
||||
model = self._services.model_manager.store.get_model(identifier.key)
|
||||
return self._services.model_manager.load.load_model(model, _submodel_type, self._data)
|
||||
return self._services.model_manager.load_model_by_key(
|
||||
key=key, submodel_type=submodel_type, context_data=self._data
|
||||
)
|
||||
|
||||
def load_by_attrs(
|
||||
self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
|
||||
@ -353,29 +343,35 @@ class ModelsInterface(InvocationContextInterface):
|
||||
Returns:
|
||||
An object representing the loaded model.
|
||||
"""
|
||||
return self._services.model_manager.load_model_by_attr(
|
||||
model_name=name,
|
||||
base_model=base,
|
||||
model_type=type,
|
||||
submodel=submodel_type,
|
||||
context_data=self._data,
|
||||
)
|
||||
|
||||
configs = self._services.model_manager.store.search_by_attr(model_name=name, base_model=base, model_type=type)
|
||||
if len(configs) == 0:
|
||||
raise UnknownModelException(f"No model found with name {name}, base {base}, and type {type}")
|
||||
|
||||
if len(configs) > 1:
|
||||
raise ValueError(f"More than one model found with name {name}, base {base}, and type {type}")
|
||||
|
||||
return self._services.model_manager.load.load_model(configs[0], submodel_type, self._data)
|
||||
|
||||
def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig:
|
||||
def get_config(self, key: str) -> AnyModelConfig:
|
||||
"""Gets a model's config.
|
||||
|
||||
Args:
|
||||
identifier: The key or ModelField representing the model.
|
||||
key: The key of the model.
|
||||
|
||||
Returns:
|
||||
The model's config.
|
||||
"""
|
||||
if isinstance(identifier, str):
|
||||
return self._services.model_manager.store.get_model(identifier)
|
||||
return self._services.model_manager.store.get_model(key=key)
|
||||
|
||||
return self._services.model_manager.store.get_model(identifier.key)
|
||||
def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
|
||||
"""Gets a model's metadata, if it has any.
|
||||
|
||||
Args:
|
||||
key: The key of the model.
|
||||
|
||||
Returns:
|
||||
The model's metadata, if it has any.
|
||||
"""
|
||||
return self._services.model_manager.store.get_metadata(key=key)
|
||||
|
||||
def search_by_path(self, path: Path) -> list[AnyModelConfig]:
|
||||
"""Searches for models by path.
|
||||
@ -423,7 +419,7 @@ class ConfigInterface(InvocationContextInterface):
|
||||
The app's config.
|
||||
"""
|
||||
|
||||
return self._services.configuration
|
||||
return self._services.configuration.get_config()
|
||||
|
||||
|
||||
class UtilInterface(InvocationContextInterface):
|
||||
|
@ -9,7 +9,6 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_3 import
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_4 import build_migration_4
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_5 import build_migration_5
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_6 import build_migration_6
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_7 import build_migration_7
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
|
||||
|
||||
|
||||
@ -36,7 +35,6 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
|
||||
migrator.register_migration(build_migration_4())
|
||||
migrator.register_migration(build_migration_5())
|
||||
migrator.register_migration(build_migration_6())
|
||||
migrator.register_migration(build_migration_7())
|
||||
migrator.run_migrations()
|
||||
|
||||
return db
|
||||
|
@ -4,6 +4,8 @@ from logging import Logger
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
|
||||
|
||||
from .util.migrate_yaml_config_1 import MigrateModelYamlToDb1
|
||||
|
||||
|
||||
class Migration3Callback:
|
||||
def __init__(self, app_config: InvokeAIAppConfig, logger: Logger) -> None:
|
||||
@ -13,6 +15,7 @@ class Migration3Callback:
|
||||
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
||||
self._drop_model_manager_metadata(cursor)
|
||||
self._recreate_model_config(cursor)
|
||||
self._migrate_model_config_records(cursor)
|
||||
|
||||
def _drop_model_manager_metadata(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""Drops the `model_manager_metadata` table."""
|
||||
@ -52,6 +55,12 @@ class Migration3Callback:
|
||||
"""
|
||||
)
|
||||
|
||||
def _migrate_model_config_records(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""After updating the model config table, we repopulate it."""
|
||||
self._logger.info("Migrating model config records from models.yaml to database")
|
||||
model_record_migrator = MigrateModelYamlToDb1(self._app_config, self._logger, cursor)
|
||||
model_record_migrator.migrate()
|
||||
|
||||
|
||||
def build_migration_3(app_config: InvokeAIAppConfig, logger: Logger) -> Migration:
|
||||
"""
|
||||
|
@ -1,88 +0,0 @@
|
||||
import sqlite3
|
||||
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
|
||||
|
||||
|
||||
class Migration7Callback:
|
||||
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
||||
self._create_models_table(cursor)
|
||||
self._drop_old_models_tables(cursor)
|
||||
|
||||
def _drop_old_models_tables(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""Drops the old model_records, model_metadata, model_tags and tags tables."""
|
||||
|
||||
tables = ["model_records", "model_metadata", "model_tags", "tags"]
|
||||
|
||||
for table in tables:
|
||||
cursor.execute(f"DROP TABLE IF EXISTS {table};")
|
||||
|
||||
def _create_models_table(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""Creates the v4.0.0 models table."""
|
||||
|
||||
tables = [
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS models (
|
||||
id TEXT NOT NULL PRIMARY KEY,
|
||||
hash TEXT GENERATED ALWAYS as (json_extract(config, '$.hash')) VIRTUAL NOT NULL,
|
||||
base TEXT GENERATED ALWAYS as (json_extract(config, '$.base')) VIRTUAL NOT NULL,
|
||||
type TEXT GENERATED ALWAYS as (json_extract(config, '$.type')) VIRTUAL NOT NULL,
|
||||
path TEXT GENERATED ALWAYS as (json_extract(config, '$.path')) VIRTUAL NOT NULL,
|
||||
format TEXT GENERATED ALWAYS as (json_extract(config, '$.format')) VIRTUAL NOT NULL,
|
||||
name TEXT GENERATED ALWAYS as (json_extract(config, '$.name')) VIRTUAL NOT NULL,
|
||||
description TEXT GENERATED ALWAYS as (json_extract(config, '$.description')) VIRTUAL,
|
||||
source TEXT GENERATED ALWAYS as (json_extract(config, '$.source')) VIRTUAL NOT NULL,
|
||||
source_type TEXT GENERATED ALWAYS as (json_extract(config, '$.source_type')) VIRTUAL NOT NULL,
|
||||
source_api_response TEXT GENERATED ALWAYS as (json_extract(config, '$.source_api_response')) VIRTUAL,
|
||||
trigger_phrases TEXT GENERATED ALWAYS as (json_extract(config, '$.trigger_phrases')) VIRTUAL,
|
||||
-- Serialized JSON representation of the whole config object, which will contain additional fields from subclasses
|
||||
config TEXT NOT NULL,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Updated via trigger
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- unique constraint on combo of name, base and type
|
||||
UNIQUE(name, base, type)
|
||||
);
|
||||
"""
|
||||
]
|
||||
|
||||
# Add trigger for `updated_at`.
|
||||
triggers = [
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS models_updated_at
|
||||
AFTER UPDATE
|
||||
ON models FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE models SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE id = old.id;
|
||||
END;
|
||||
"""
|
||||
]
|
||||
|
||||
# Add indexes for searchable fields
|
||||
indices = [
|
||||
"CREATE INDEX IF NOT EXISTS base_index ON models(base);",
|
||||
"CREATE INDEX IF NOT EXISTS type_index ON models(type);",
|
||||
"CREATE INDEX IF NOT EXISTS name_index ON models(name);",
|
||||
"CREATE UNIQUE INDEX IF NOT EXISTS path_index ON models(path);",
|
||||
]
|
||||
|
||||
for stmt in tables + indices + triggers:
|
||||
cursor.execute(stmt)
|
||||
|
||||
|
||||
def build_migration_7() -> Migration:
|
||||
"""
|
||||
Build the migration from database version 6 to 7.
|
||||
|
||||
This migration does the following:
|
||||
- Adds the new models table
|
||||
- Drops the old model_records, model_metadata, model_tags and tags tables.
|
||||
- TODO(MM2): Migrates model names and descriptions from `models.yaml` to the new table (?).
|
||||
"""
|
||||
migration_7 = Migration(
|
||||
from_version=6,
|
||||
to_version=7,
|
||||
callback=Migration7Callback(),
|
||||
)
|
||||
|
||||
return migration_7
|
@ -0,0 +1,151 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein
|
||||
"""Migrate from the InvokeAI v2 models.yaml format to the v3 sqlite format."""
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
from hashlib import sha1
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.model_records import (
|
||||
DuplicateModelException,
|
||||
UnknownModelException,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelConfigFactory,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.hash import FastModelHash
|
||||
|
||||
ModelsValidator = TypeAdapter(AnyModelConfig)
|
||||
|
||||
|
||||
class MigrateModelYamlToDb1:
|
||||
"""
|
||||
Migrate the InvokeAI models.yaml format (VERSION 3.0.0) to SQL3 database format (VERSION 3.5.0).
|
||||
|
||||
The class has one externally useful method, migrate(), which scans the
|
||||
currently models.yaml file and imports all its entries into invokeai.db.
|
||||
|
||||
Use this way:
|
||||
|
||||
from invokeai.backend.model_manager/migrate_to_db import MigrateModelYamlToDb
|
||||
MigrateModelYamlToDb().migrate()
|
||||
|
||||
"""
|
||||
|
||||
config: InvokeAIAppConfig
|
||||
logger: Logger
|
||||
cursor: sqlite3.Cursor
|
||||
|
||||
def __init__(self, config: InvokeAIAppConfig, logger: Logger, cursor: sqlite3.Cursor = None) -> None:
|
||||
self.config = config
|
||||
self.logger = logger
|
||||
self.cursor = cursor
|
||||
|
||||
def get_yaml(self) -> DictConfig:
|
||||
"""Fetch the models.yaml DictConfig for this installation."""
|
||||
yaml_path = self.config.model_conf_path
|
||||
omegaconf = OmegaConf.load(yaml_path)
|
||||
assert isinstance(omegaconf, DictConfig)
|
||||
return omegaconf
|
||||
|
||||
def migrate(self) -> None:
|
||||
"""Do the migration from models.yaml to invokeai.db."""
|
||||
try:
|
||||
yaml = self.get_yaml()
|
||||
except OSError:
|
||||
return
|
||||
|
||||
for model_key, stanza in yaml.items():
|
||||
if model_key == "__metadata__":
|
||||
assert (
|
||||
stanza["version"] == "3.0.0"
|
||||
), f"This script works on version 3.0.0 yaml files, but your configuration points to a {stanza['version']} version"
|
||||
continue
|
||||
|
||||
base_type, model_type, model_name = str(model_key).split("/")
|
||||
try:
|
||||
hash = FastModelHash.hash(self.config.models_path / stanza.path)
|
||||
except OSError:
|
||||
self.logger.warning(f"The model at {stanza.path} is not a valid file or directory. Skipping migration.")
|
||||
continue
|
||||
|
||||
assert isinstance(model_key, str)
|
||||
new_key = sha1(model_key.encode("utf-8")).hexdigest()
|
||||
|
||||
stanza["base"] = BaseModelType(base_type)
|
||||
stanza["type"] = ModelType(model_type)
|
||||
stanza["name"] = model_name
|
||||
stanza["original_hash"] = hash
|
||||
stanza["current_hash"] = hash
|
||||
|
||||
new_config: AnyModelConfig = ModelsValidator.validate_python(stanza) # type: ignore # see https://github.com/pydantic/pydantic/discussions/7094
|
||||
|
||||
try:
|
||||
if original_record := self._search_by_path(stanza.path):
|
||||
key = original_record.key
|
||||
self.logger.info(f"Updating model {model_name} with information from models.yaml using key {key}")
|
||||
self._update_model(key, new_config)
|
||||
else:
|
||||
self.logger.info(f"Adding model {model_name} with key {model_key}")
|
||||
self._add_model(new_key, new_config)
|
||||
except DuplicateModelException:
|
||||
self.logger.warning(f"Model {model_name} is already in the database")
|
||||
except UnknownModelException:
|
||||
self.logger.warning(f"Model at {stanza.path} could not be found in database")
|
||||
|
||||
def _search_by_path(self, path: Path) -> Optional[AnyModelConfig]:
|
||||
self.cursor.execute(
|
||||
"""--sql
|
||||
SELECT config FROM model_config
|
||||
WHERE path=?;
|
||||
""",
|
||||
(str(path),),
|
||||
)
|
||||
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self.cursor.fetchall()]
|
||||
return results[0] if results else None
|
||||
|
||||
def _update_model(self, key: str, config: AnyModelConfig) -> None:
|
||||
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect
|
||||
json_serialized = record.model_dump_json() # and turn it into a json string.
|
||||
self.cursor.execute(
|
||||
"""--sql
|
||||
UPDATE model_config
|
||||
SET
|
||||
config=?
|
||||
WHERE id=?;
|
||||
""",
|
||||
(json_serialized, key),
|
||||
)
|
||||
if self.cursor.rowcount == 0:
|
||||
raise UnknownModelException("model not found")
|
||||
|
||||
def _add_model(self, key: str, config: AnyModelConfig) -> None:
|
||||
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect.
|
||||
json_serialized = record.model_dump_json() # and turn it into a json string.
|
||||
try:
|
||||
self.cursor.execute(
|
||||
"""--sql
|
||||
INSERT INTO model_config (
|
||||
id,
|
||||
original_hash,
|
||||
config
|
||||
)
|
||||
VALUES (?,?,?);
|
||||
""",
|
||||
(
|
||||
key,
|
||||
record.original_hash,
|
||||
json_serialized,
|
||||
),
|
||||
)
|
||||
except sqlite3.IntegrityError as exc:
|
||||
raise DuplicateModelException(f"{record.name}: model is already in database") from exc
|
@ -17,7 +17,8 @@ class MigrateCallback(Protocol):
|
||||
See :class:`Migration` for an example.
|
||||
"""
|
||||
|
||||
def __call__(self, cursor: sqlite3.Cursor) -> None: ...
|
||||
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
||||
...
|
||||
|
||||
|
||||
class MigrationError(RuntimeError):
|
||||
|
@ -8,8 +8,3 @@ class UrlServiceBase(ABC):
|
||||
def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
|
||||
"""Gets the URL for an image or thumbnail."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_model_image_url(self, model_key: str) -> str:
|
||||
"""Gets the URL for a model image"""
|
||||
pass
|
||||
|
@ -4,9 +4,8 @@ from .urls_base import UrlServiceBase
|
||||
|
||||
|
||||
class LocalUrlService(UrlServiceBase):
|
||||
def __init__(self, base_url: str = "api/v1", base_url_v2: str = "api/v2"):
|
||||
def __init__(self, base_url: str = "api/v1"):
|
||||
self._base_url = base_url
|
||||
self._base_url_v2 = base_url_v2
|
||||
|
||||
def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
|
||||
image_basename = os.path.basename(image_name)
|
||||
@ -16,6 +15,3 @@ class LocalUrlService(UrlServiceBase):
|
||||
return f"{self._base_url}/images/i/{image_basename}/thumbnail"
|
||||
|
||||
return f"{self._base_url}/images/i/{image_basename}/full"
|
||||
|
||||
def get_model_image_url(self, model_key: str) -> str:
|
||||
return f"{self._base_url_v2}/models/i/{model_key}/image"
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user