mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Compare commits
546 Commits
next-test-
...
speed-up-h
Author | SHA1 | Date | |
---|---|---|---|
cb99a5f0e4 | |||
3bb50926af | |||
bb3f1b9ca6 | |||
119d26e102 | |||
e1c16c33a4 | |||
ad70cdfe87 | |||
549d461107 | |||
cab3748010 | |||
779b3e0e8e | |||
9b48029bc9 | |||
347f1fd0b7 | |||
4af5a09a68 | |||
8df02623f2 | |||
aa88fadc30 | |||
8411029d93 | |||
239b1e8cc7 | |||
8a68355926 | |||
86aef9f31d | |||
2f6964bfa5 | |||
c1cdfd132b | |||
f6bfe5e6f2 | |||
b5a8455b5f | |||
645ef081ea | |||
e68d7fa6d7 | |||
c5ab1c7ad6 | |||
5a561cab78 | |||
132790eebe | |||
c57f6ee885 | |||
d4a2ea68fc | |||
528ac5dd25 | |||
afd9ae7712 | |||
4eefed12f0 | |||
4301a3d6fd | |||
99c0662e3f | |||
cdc0d0c182 | |||
a00369a67a | |||
4f096ac3ba | |||
f5e3341465 | |||
474852ef7e | |||
b1d72d411e | |||
46614ee28f | |||
b019f9bb8b | |||
b857692073 | |||
90fb7a1a59 | |||
56fcf6af78 | |||
c4fe7e697b | |||
2fd483dfc8 | |||
b9a9507422 | |||
f2744fd7d1 | |||
fe6e879d38 | |||
d3ab08fe10 | |||
b0615bdfd4 | |||
bab20467fb | |||
e24624109e | |||
458e7185b8 | |||
a95128f5f2 | |||
46f32c5e3c | |||
e30cb4b52f | |||
ba1f6bf926 | |||
4a9cca6c2d | |||
b0275700b3 | |||
8319aca5f9 | |||
51a604f907 | |||
7515d73628 | |||
2c453aa531 | |||
2cca6e4c76 | |||
ef171e890a | |||
caafbf2f0d | |||
2db5eaf907 | |||
f234bf6256 | |||
cfa78b4052 | |||
ba1dd4b02b | |||
bcf58cac59 | |||
e866d90ab2 | |||
e8797787cf | |||
0082ecb22b | |||
656839fcd1 | |||
99407c899f | |||
48119d9010 | |||
7c9128b253 | |||
4f9bb00275 | |||
78895b3e80 | |||
3030a34b88 | |||
58fa9c2fac | |||
a8b6635050 | |||
6829610a71 | |||
5551cf8ac4 | |||
37b969d339 | |||
c953e61294 | |||
93dd3c848e | |||
02bde7bb75 | |||
3391c19926 | |||
0f60b1ced4 | |||
44c40d7d1a | |||
0b9a212363 | |||
c3aa985c93 | |||
7cb0da1f66 | |||
3534366146 | |||
f2b5f8753f | |||
f13f5984c0 | |||
94e1e64296 | |||
2411bf53c0 | |||
9378e47a06 | |||
4471ea8ad1 | |||
2c835fd550 | |||
61b737bb9f | |||
a8cd3dfc99 | |||
0cce582f2f | |||
4347d1c7f7 | |||
bd4fd9693d | |||
9b40c28144 | |||
16a5d718bf | |||
76cbc745e1 | |||
0a614943f6 | |||
e426096d32 | |||
c561cd751f | |||
af9298f0ef | |||
5b74117836 | |||
38474c9797 | |||
b880a31039 | |||
dd31bc4586 | |||
316573df2d | |||
8b34f5298c | |||
893bcd16fc | |||
f6028a4c61 | |||
264aee3ffa | |||
4deb60f365 | |||
f2d5fb176f | |||
94005b5501 | |||
02dc1a8780 | |||
ef958568ac | |||
48e323d887 | |||
735857479d | |||
2f372d9b18 | |||
554d175792 | |||
ae99428883 | |||
863ce00712 | |||
86982f3059 | |||
ec8ed530a7 | |||
982076d7d7 | |||
2e4672f931 | |||
908e915a71 | |||
a72056e0df | |||
d8d7ddf43a | |||
cc45007dc4 | |||
73bec56c59 | |||
f8b54930f0 | |||
51cc9f9466 | |||
d2ad465e96 | |||
09037b7cd4 | |||
b2a850b5ea | |||
3ba5c2b0b4 | |||
06fc6ccfe5 | |||
0c6b0cfdab | |||
eef3373799 | |||
6935830f99 | |||
7651eeea8d | |||
204e7d383b | |||
9bc4e7a593 | |||
ad96857e0f | |||
8fb297e5f6 | |||
0832e1818e | |||
26d4d93e64 | |||
77f39aa935 | |||
6aae940834 | |||
be8dcad1da | |||
5f2e493244 | |||
c60c9825cb | |||
6f368395df | |||
ea4d071503 | |||
b95e5d0730 | |||
99ee8f9099 | |||
50e58ff323 | |||
b5c12985e7 | |||
a865277667 | |||
b2b65a9012 | |||
9fe579dd99 | |||
a0313ba634 | |||
3a2afe1d15 | |||
813a086cfe | |||
e18533e3b5 | |||
dd9daf8efb | |||
ad86b29798 | |||
8b03af391a | |||
bbbd18f119 | |||
c074beff7c | |||
0b07e2aad4 | |||
753919c6d7 | |||
2f26768d19 | |||
ae19971f65 | |||
e364ce1d4e | |||
0b0128647b | |||
39725e9560 | |||
0305e90287 | |||
ae34bcfbc0 | |||
01898d766f | |||
e7afae0159 | |||
f16e64084b | |||
8992d89817 | |||
0fc2f90824 | |||
c670dacc29 | |||
f475b78734 | |||
ca9b815c89 | |||
8efd4284e9 | |||
5922cee541 | |||
94e3857110 | |||
4b4b940461 | |||
574d6538b9 | |||
3141c6efd5 | |||
9cf2897064 | |||
bcf742ef87 | |||
f6c068afdd | |||
7d2e840590 | |||
f0b3485ce9 | |||
37608cdea2 | |||
aafa464707 | |||
1176c549c0 | |||
d90210fea6 | |||
d99bec8b1a | |||
b661d93bd8 | |||
dc64089c9d | |||
a6f6fe581e | |||
12e859835b | |||
b218282149 | |||
80065858ed | |||
aaeef03593 | |||
97ecd99b9c | |||
202e739404 | |||
10d36b4045 | |||
8f93ae8d7c | |||
506fa55f18 | |||
4c19d5cee4 | |||
afa7043dcd | |||
32b8478974 | |||
d23f2de9d7 | |||
9abfb02bf0 | |||
7b4ef5926d | |||
6c5be9e89c | |||
80697a71de | |||
a253047d8e | |||
7176c5d9d6 | |||
0b54bfb7c5 | |||
24daacecf2 | |||
7326c78ab5 | |||
04545e792c | |||
e6de915c34 | |||
71ceab9094 | |||
ff00ed8e80 | |||
ce3f9037cd | |||
d1f4cde8c7 | |||
90327cb521 | |||
4d5458648b | |||
8d8f1abd50 | |||
e20a506e40 | |||
77b8eed51b | |||
c954cd4c8d | |||
630d3615ca | |||
c80c0f0fb9 | |||
37d66488c5 | |||
371e3cc260 | |||
d22738723d | |||
fbd9ffdc5a | |||
04c060a89d | |||
6f591b324b | |||
82249cc634 | |||
cc82ce820a | |||
8e1fbd6ed1 | |||
68d79c002d | |||
8f6c2a8b92 | |||
ea7b7bcf40 | |||
1456c997fb | |||
7fce234646 | |||
9e02384674 | |||
531d6f40f4 | |||
98d60e7db5 | |||
1436a5f295 | |||
e22c4987bf | |||
4420392241 | |||
1d410e6346 | |||
c98668e7f5 | |||
740dbc0c32 | |||
97181d159f | |||
65b0d3d436 | |||
baf1194cae | |||
9b1f63379a | |||
c3f4e87a6e | |||
26a209a00d | |||
625c86ba9a | |||
53f0090197 | |||
5496699d6c | |||
b5ce28e60b | |||
816fb53a14 | |||
793c7ec832 | |||
62c67d7c4b | |||
7c41b3439a | |||
cdd2f18bbd | |||
e7d7b37896 | |||
57a402053e | |||
9ae09e9a7c | |||
5a12886dbb | |||
5b7633f3c6 | |||
68f24d9f0d | |||
ea364bdf82 | |||
18904f79ef | |||
782d15af13 | |||
86e2b39f0d | |||
20576deae8 | |||
0a69779df9 | |||
6b68971f38 | |||
c46eb72d45 | |||
87ce74e05d | |||
c7d462b222 | |||
9068400433 | |||
55f3c6e721 | |||
c778ab8db4 | |||
65b91356d0 | |||
de9287a3e4 | |||
008716040b | |||
abc569c2dd | |||
3ed2963f43 | |||
79b16596b5 | |||
239ecfaf79 | |||
0d9fbe5e04 | |||
cc41e8912c | |||
1cec0bb179 | |||
65dd4f4abc | |||
5bb3aeaccd | |||
30a374a70f | |||
07dde92664 | |||
06cc57d82a | |||
f7fc20459a | |||
9269bdd233 | |||
97cfcd2eef | |||
571a86a965 | |||
dbd929df05 | |||
b59d23d608 | |||
9d9b417432 | |||
34f3a39cc9 | |||
e3c23baae9 | |||
6a923cce70 | |||
c0f0f2f39e | |||
64908eda55 | |||
a37b60db13 | |||
9e296f6916 | |||
ab94484c6c | |||
5cba55d670 | |||
cbb997e7d0 | |||
98441ad08d | |||
80c67dd6e0 | |||
38af234108 | |||
2291122c2b | |||
bf3b10cb1c | |||
7f8f182a00 | |||
e51867756a | |||
a8d7cf4e97 | |||
037cac8154 | |||
0ab9fe6987 | |||
b5a9ed351d | |||
5f4b406cfe | |||
f15aa562c2 | |||
d0f3571e59 | |||
b5ca1643a6 | |||
39c01a833d | |||
79eb871683 | |||
7544b350f3 | |||
284ba041bd | |||
7d91426d8f | |||
db812133e7 | |||
795fbf0e81 | |||
7114d64b86 | |||
c43ea9f25c | |||
52b0deb179 | |||
7ecc18938b | |||
56d2d220a8 | |||
f1967c3393 | |||
812e24cbd2 | |||
8afe328af0 | |||
e771c5f467 | |||
e7e3045a8a | |||
f870f810d5 | |||
a793103d7a | |||
7e5a85496e | |||
ca7e928710 | |||
5b133ad198 | |||
89fa36a818 | |||
e3f9da29ba | |||
763debdeeb | |||
8bf9fd34ad | |||
0b0cb0ccc6 | |||
fa39523b11 | |||
16676feea8 | |||
0788a27a80 | |||
d53a2a2d4e | |||
ccfe6b6bef | |||
fdac0c3c9b | |||
18adcc1dd2 | |||
86c50f2d5b | |||
3cfac8b843 | |||
0788b6ecee | |||
317d076a1a | |||
725c03cf87 | |||
da9991e361 | |||
67daa127e3 | |||
7e71effa17 | |||
e93bd15392 | |||
0b81703c9f | |||
641d235102 | |||
b79ae3a101 | |||
731860c332 | |||
af2117dc0c | |||
1242cb4f85 | |||
cd070d8be9 | |||
56ac2104e3 | |||
965867151b | |||
2d007ce532 | |||
92394ab751 | |||
43d94c8108 | |||
fc20822595 | |||
5a3195f757 | |||
5d612ec095 | |||
996eb96b4e | |||
f1597bd6da | |||
e50b76571a | |||
db363b5178 | |||
dab939f7d1 | |||
6df3c450e8 | |||
b7ba65fef4 | |||
fc107ed711 | |||
cb804e75ed | |||
7996d43af9 | |||
fab30b5a11 | |||
651ac56b2c | |||
68f53460f0 | |||
c80987eb8a | |||
539570cc7a | |||
88d6de4101 | |||
4c6e34b216 | |||
262cbaacdd | |||
35e8a33dfd | |||
b0835db47d | |||
3e330d7d9d | |||
ff6e94f828 | |||
a2cc4047f9 | |||
4027e845d4 | |||
a23dedd2ee | |||
7956602b19 | |||
8db01ab1b3 | |||
db340bc253 | |||
78ef946e01 | |||
5745ce9c7d | |||
0d3addc69b | |||
67eb715093 | |||
8ba5360269 | |||
b8e875bb73 | |||
010c4eae65 | |||
95453a22b1 | |||
30db708c4f | |||
fe27af461a | |||
f8525837b2 | |||
5fbfed30ac | |||
7a2159beeb | |||
25f64d5b19 | |||
b845e890d1 | |||
6d31bc5326 | |||
0f8af643d1 | |||
e0694a2856 | |||
e5d8921cf2 | |||
fece935438 | |||
11f64dab38 | |||
670f2f75e9 | |||
66d0ec3f6c | |||
6087ace4f1 | |||
a9b1aad3d7 | |||
9edb995647 | |||
091f4cb583 | |||
1655061c96 | |||
220baae793 | |||
e08f16763b | |||
6d25789705 | |||
aff44c0e58 | |||
34d23366f4 | |||
23de78ec9f | |||
507aeac8a5 | |||
9f382419dc | |||
73d871116c | |||
ab58d34f9b | |||
9cda62c2a7 | |||
a50c7c1cd7 | |||
ca09bd63a3 | |||
c96f50cc9a | |||
de63e888d6 | |||
5dd158a2d4 | |||
0710fb3fb0 | |||
31db62ba99 | |||
322a60f48f | |||
b386b1b8af | |||
70034d26e2 | |||
d60f1965d1 | |||
7fbdfbf9e5 | |||
e137071543 | |||
5d2f70b3ef | |||
47d05fdd81 | |||
958b80acdd | |||
5730ae9b96 | |||
60e2eff94d | |||
dcafbb9988 | |||
cc8d713c57 | |||
59c77832d8 | |||
cbf22d8a80 | |||
e11af7de9b | |||
95dd5aad16 | |||
4ce21087d3 | |||
281c334531 | |||
282b483d14 | |||
a466f7a94b | |||
05fb485d33 | |||
6452c706e1 | |||
f612a96afd | |||
9af0553652 | |||
1616974b48 | |||
ef27283569 | |||
a79a450e9d | |||
8637c40661 | |||
9bc2d09889 | |||
3d98446d5d | |||
992b02aa65 | |||
63ab5ff5a2 | |||
9a8a9c5848 | |||
1a3ffb6e94 | |||
3a09bceea4 | |||
2ec6b51d8b | |||
34b0ea20dc | |||
9986fce1a6 | |||
228f1d7f62 | |||
01a6378dc1 | |||
e01769294f | |||
16aa261e28 | |||
1dabf18d14 | |||
115d92b1ae | |||
f0d4c71960 | |||
3e48edda6f | |||
716b584f03 | |||
d43b843c23 | |||
f36b5990ed | |||
5706237ec7 | |||
163b22a7b3 |
33
.github/actions/install-frontend-deps/action.yml
vendored
Normal file
33
.github/actions/install-frontend-deps/action.yml
vendored
Normal file
@ -0,0 +1,33 @@
|
||||
name: install frontend dependencies
|
||||
description: Installs frontend dependencies with pnpm, with caching
|
||||
runs:
|
||||
using: 'composite'
|
||||
steps:
|
||||
- name: setup node 18
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '18'
|
||||
|
||||
- name: setup pnpm
|
||||
uses: pnpm/action-setup@v2
|
||||
with:
|
||||
version: 8
|
||||
run_install: false
|
||||
|
||||
- name: get pnpm store directory
|
||||
shell: bash
|
||||
run: |
|
||||
echo "STORE_PATH=$(pnpm store path --silent)" >> $GITHUB_ENV
|
||||
|
||||
- name: setup cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ${{ env.STORE_PATH }}
|
||||
key: ${{ runner.os }}-pnpm-store-${{ hashFiles('**/pnpm-lock.yaml') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pnpm-store-
|
||||
|
||||
- name: install frontend dependencies
|
||||
run: pnpm install --prefer-frozen-lockfile
|
||||
shell: bash
|
||||
working-directory: invokeai/frontend/web
|
28
.github/pr_labels.yml
vendored
28
.github/pr_labels.yml
vendored
@ -1,59 +1,59 @@
|
||||
Root:
|
||||
root:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: '*'
|
||||
|
||||
PythonDeps:
|
||||
python-deps:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: 'pyproject.toml'
|
||||
|
||||
Python:
|
||||
python:
|
||||
- changed-files:
|
||||
- all-globs-to-any-file:
|
||||
- 'invokeai/**'
|
||||
- '!invokeai/frontend/web/**'
|
||||
|
||||
PythonTests:
|
||||
python-tests:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: 'tests/**'
|
||||
|
||||
CICD:
|
||||
ci-cd:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: .github/**
|
||||
|
||||
Docker:
|
||||
docker:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: docker/**
|
||||
|
||||
Installer:
|
||||
installer:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: installer/**
|
||||
|
||||
Documentation:
|
||||
docs:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: docs/**
|
||||
|
||||
Invocations:
|
||||
invocations:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: 'invokeai/app/invocations/**'
|
||||
|
||||
Backend:
|
||||
backend:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: 'invokeai/backend/**'
|
||||
|
||||
Api:
|
||||
api:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: 'invokeai/app/api/**'
|
||||
|
||||
Services:
|
||||
services:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: 'invokeai/app/services/**'
|
||||
|
||||
FrontendDeps:
|
||||
frontend-deps:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- '**/*/package.json'
|
||||
- '**/*/pnpm-lock.yaml'
|
||||
|
||||
Frontend:
|
||||
frontend:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: 'invokeai/frontend/web/**'
|
||||
|
2
.github/workflows/build-container.yml
vendored
2
.github/workflows/build-container.yml
vendored
@ -11,7 +11,7 @@ on:
|
||||
- 'docker/docker-entrypoint.sh'
|
||||
- 'workflows/build-container.yml'
|
||||
tags:
|
||||
- 'v*'
|
||||
- 'v*.*.*'
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
|
45
.github/workflows/build-installer.yml
vendored
Normal file
45
.github/workflows/build-installer.yml
vendored
Normal file
@ -0,0 +1,45 @@
|
||||
# Builds and uploads the installer and python build artifacts.
|
||||
|
||||
name: build installer
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
workflow_call:
|
||||
|
||||
jobs:
|
||||
build-installer:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 5 # expected run time: <2 min
|
||||
steps:
|
||||
- name: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: setup python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.10'
|
||||
cache: pip
|
||||
cache-dependency-path: pyproject.toml
|
||||
|
||||
- name: install pypa/build
|
||||
run: pip install --upgrade build
|
||||
|
||||
- name: setup frontend
|
||||
uses: ./.github/actions/install-frontend-deps
|
||||
|
||||
- name: create installer
|
||||
id: create_installer
|
||||
run: ./create_installer.sh
|
||||
working-directory: installer
|
||||
|
||||
- name: upload python distribution artifact
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: dist
|
||||
path: ${{ steps.create_installer.outputs.DIST_PATH }}
|
||||
|
||||
- name: upload installer artifact
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: ${{ steps.create_installer.outputs.INSTALLER_FILENAME }}
|
||||
path: ${{ steps.create_installer.outputs.INSTALLER_PATH }}
|
68
.github/workflows/frontend-checks.yml
vendored
Normal file
68
.github/workflows/frontend-checks.yml
vendored
Normal file
@ -0,0 +1,68 @@
|
||||
# Runs frontend code quality checks.
|
||||
#
|
||||
# Checks for changes to frontend files before running the checks.
|
||||
# When manually triggered or when called from another workflow, always runs the checks.
|
||||
|
||||
name: 'frontend checks'
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- 'main'
|
||||
pull_request:
|
||||
types:
|
||||
- 'ready_for_review'
|
||||
- 'opened'
|
||||
- 'synchronize'
|
||||
merge_group:
|
||||
workflow_dispatch:
|
||||
workflow_call:
|
||||
|
||||
defaults:
|
||||
run:
|
||||
working-directory: invokeai/frontend/web
|
||||
|
||||
jobs:
|
||||
frontend-checks:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 10 # expected run time: <2 min
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: check for changed frontend files
|
||||
if: ${{ github.event_name != 'workflow_dispatch' && github.event_name != 'workflow_call' }}
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@v42
|
||||
with:
|
||||
files_yaml: |
|
||||
frontend:
|
||||
- 'invokeai/frontend/web/**'
|
||||
|
||||
- name: install dependencies
|
||||
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
|
||||
uses: ./.github/actions/install-frontend-deps
|
||||
|
||||
- name: tsc
|
||||
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
|
||||
run: 'pnpm lint:tsc'
|
||||
shell: bash
|
||||
|
||||
- name: dpdm
|
||||
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
|
||||
run: 'pnpm lint:dpdm'
|
||||
shell: bash
|
||||
|
||||
- name: eslint
|
||||
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
|
||||
run: 'pnpm lint:eslint'
|
||||
shell: bash
|
||||
|
||||
- name: prettier
|
||||
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
|
||||
run: 'pnpm lint:prettier'
|
||||
shell: bash
|
||||
|
||||
- name: knip
|
||||
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
|
||||
run: 'pnpm lint:knip'
|
||||
shell: bash
|
48
.github/workflows/frontend-tests.yml
vendored
Normal file
48
.github/workflows/frontend-tests.yml
vendored
Normal file
@ -0,0 +1,48 @@
|
||||
# Runs frontend tests.
|
||||
#
|
||||
# Checks for changes to frontend files before running the tests.
|
||||
# When manually triggered or called from another workflow, always runs the tests.
|
||||
|
||||
name: 'frontend tests'
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- 'main'
|
||||
pull_request:
|
||||
types:
|
||||
- 'ready_for_review'
|
||||
- 'opened'
|
||||
- 'synchronize'
|
||||
merge_group:
|
||||
workflow_dispatch:
|
||||
workflow_call:
|
||||
|
||||
defaults:
|
||||
run:
|
||||
working-directory: invokeai/frontend/web
|
||||
|
||||
jobs:
|
||||
frontend-tests:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 10 # expected run time: <2 min
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: check for changed frontend files
|
||||
if: ${{ github.event_name != 'workflow_dispatch' && github.event_name != 'workflow_call' }}
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@v42
|
||||
with:
|
||||
files_yaml: |
|
||||
frontend:
|
||||
- 'invokeai/frontend/web/**'
|
||||
|
||||
- name: install dependencies
|
||||
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
|
||||
uses: ./.github/actions/install-frontend-deps
|
||||
|
||||
- name: vitest
|
||||
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
|
||||
run: 'pnpm test:no-watch'
|
||||
shell: bash
|
10
.github/workflows/label-pr.yml
vendored
10
.github/workflows/label-pr.yml
vendored
@ -1,6 +1,6 @@
|
||||
name: "Pull Request Labeler"
|
||||
name: 'label PRs'
|
||||
on:
|
||||
- pull_request_target
|
||||
- pull_request_target
|
||||
|
||||
jobs:
|
||||
labeler:
|
||||
@ -9,8 +9,10 @@ jobs:
|
||||
pull-requests: write
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
- name: checkout
|
||||
uses: actions/checkout@v4
|
||||
- uses: actions/labeler@v5
|
||||
|
||||
- name: label PRs
|
||||
uses: actions/labeler@v5
|
||||
with:
|
||||
configuration-path: .github/pr_labels.yml
|
45
.github/workflows/lint-frontend.yml
vendored
45
.github/workflows/lint-frontend.yml
vendored
@ -1,45 +0,0 @@
|
||||
name: Lint frontend
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types:
|
||||
- 'ready_for_review'
|
||||
- 'opened'
|
||||
- 'synchronize'
|
||||
push:
|
||||
branches:
|
||||
- 'main'
|
||||
merge_group:
|
||||
workflow_dispatch:
|
||||
|
||||
defaults:
|
||||
run:
|
||||
working-directory: invokeai/frontend/web
|
||||
|
||||
jobs:
|
||||
lint-frontend:
|
||||
if: github.event.pull_request.draft == false
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Setup Node 18
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '18'
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: Setup pnpm
|
||||
uses: pnpm/action-setup@v2
|
||||
with:
|
||||
version: '8.12.1'
|
||||
- name: Install dependencies
|
||||
run: 'pnpm install --prefer-frozen-lockfile'
|
||||
- name: Typescript
|
||||
run: 'pnpm run lint:tsc'
|
||||
- name: Madge
|
||||
run: 'pnpm run lint:dpdm'
|
||||
- name: ESLint
|
||||
run: 'pnpm run lint:eslint'
|
||||
- name: Prettier
|
||||
run: 'pnpm run lint:prettier'
|
||||
- name: Knip
|
||||
run: 'pnpm run lint:knip'
|
52
.github/workflows/mkdocs-material.yml
vendored
52
.github/workflows/mkdocs-material.yml
vendored
@ -1,51 +1,49 @@
|
||||
name: mkdocs-material
|
||||
# This is a mostly a copy-paste from https://github.com/squidfunk/mkdocs-material/blob/master/docs/publishing-your-site.md
|
||||
|
||||
name: mkdocs
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- 'refs/heads/main'
|
||||
- main
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
jobs:
|
||||
mkdocs-material:
|
||||
deploy:
|
||||
if: github.event.pull_request.draft == false
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
REPO_URL: '${{ github.server_url }}/${{ github.repository }}'
|
||||
REPO_NAME: '${{ github.repository }}'
|
||||
SITE_URL: 'https://${{ github.repository_owner }}.github.io/InvokeAI'
|
||||
|
||||
steps:
|
||||
- name: checkout sources
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: setup python
|
||||
uses: actions/setup-python@v4
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.10'
|
||||
cache: pip
|
||||
cache-dependency-path: pyproject.toml
|
||||
|
||||
- name: install requirements
|
||||
env:
|
||||
PIP_USE_PEP517: 1
|
||||
run: |
|
||||
python -m \
|
||||
pip install ".[docs]"
|
||||
- name: set cache id
|
||||
run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV
|
||||
|
||||
- name: confirm buildability
|
||||
run: |
|
||||
python -m \
|
||||
mkdocs build \
|
||||
--clean \
|
||||
--verbose
|
||||
- name: use cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
key: mkdocs-material-${{ env.cache_id }}
|
||||
path: .cache
|
||||
restore-keys: |
|
||||
mkdocs-material-
|
||||
|
||||
- name: deploy to gh-pages
|
||||
if: ${{ github.ref == 'refs/heads/main' }}
|
||||
run: |
|
||||
python -m \
|
||||
mkdocs gh-deploy \
|
||||
--clean \
|
||||
--force
|
||||
- name: install dependencies
|
||||
run: python -m pip install ".[docs]"
|
||||
|
||||
- name: build & deploy
|
||||
run: mkdocs gh-deploy --force
|
||||
|
67
.github/workflows/pypi-release.yml
vendored
67
.github/workflows/pypi-release.yml
vendored
@ -1,67 +0,0 @@
|
||||
name: PyPI Release
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
publish_package:
|
||||
description: 'Publish build on PyPi? [true/false]'
|
||||
required: true
|
||||
default: 'false'
|
||||
|
||||
jobs:
|
||||
build-and-release:
|
||||
if: github.repository == 'invoke-ai/InvokeAI'
|
||||
runs-on: ubuntu-22.04
|
||||
env:
|
||||
TWINE_USERNAME: __token__
|
||||
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
|
||||
TWINE_NON_INTERACTIVE: 1
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Node 18
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '18'
|
||||
|
||||
- name: Setup pnpm
|
||||
uses: pnpm/action-setup@v2
|
||||
with:
|
||||
version: '8.12.1'
|
||||
|
||||
- name: Install frontend dependencies
|
||||
run: pnpm install --prefer-frozen-lockfile
|
||||
working-directory: invokeai/frontend/web
|
||||
|
||||
- name: Build frontend
|
||||
run: pnpm run build
|
||||
working-directory: invokeai/frontend/web
|
||||
|
||||
- name: Install python dependencies
|
||||
run: pip install --upgrade build twine
|
||||
|
||||
- name: Build python package
|
||||
run: python3 -m build
|
||||
|
||||
- name: Upload build as workflow artifact
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: dist
|
||||
path: dist
|
||||
|
||||
- name: Check distribution
|
||||
run: twine check dist/*
|
||||
|
||||
- name: Check PyPI versions
|
||||
if: github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release/')
|
||||
run: |
|
||||
pip install --upgrade requests
|
||||
python -c "\
|
||||
import scripts.pypi_helper; \
|
||||
EXISTS=scripts.pypi_helper.local_on_pypi(); \
|
||||
print(f'PACKAGE_EXISTS={EXISTS}')" >> $GITHUB_ENV
|
||||
|
||||
- name: Publish build on PyPi
|
||||
if: env.PACKAGE_EXISTS == 'False' && env.TWINE_PASSWORD != '' && github.event.inputs.publish_package == 'true'
|
||||
run: twine upload dist/*
|
64
.github/workflows/python-checks.yml
vendored
Normal file
64
.github/workflows/python-checks.yml
vendored
Normal file
@ -0,0 +1,64 @@
|
||||
# Runs python code quality checks.
|
||||
#
|
||||
# Checks for changes to python files before running the checks.
|
||||
# When manually triggered or called from another workflow, always runs the tests.
|
||||
#
|
||||
# TODO: Add mypy or pyright to the checks.
|
||||
|
||||
name: 'python checks'
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- 'main'
|
||||
pull_request:
|
||||
types:
|
||||
- 'ready_for_review'
|
||||
- 'opened'
|
||||
- 'synchronize'
|
||||
merge_group:
|
||||
workflow_dispatch:
|
||||
workflow_call:
|
||||
|
||||
jobs:
|
||||
python-checks:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 5 # expected run time: <1 min
|
||||
steps:
|
||||
- name: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: check for changed python files
|
||||
if: ${{ github.event_name != 'workflow_dispatch' && github.event_name != 'workflow_call' }}
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@v42
|
||||
with:
|
||||
files_yaml: |
|
||||
python:
|
||||
- 'pyproject.toml'
|
||||
- 'invokeai/**'
|
||||
- '!invokeai/frontend/web/**'
|
||||
- 'tests/**'
|
||||
|
||||
- name: setup python
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.10'
|
||||
cache: pip
|
||||
cache-dependency-path: pyproject.toml
|
||||
|
||||
- name: install ruff
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
|
||||
run: pip install ruff
|
||||
shell: bash
|
||||
|
||||
- name: ruff check
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
|
||||
run: ruff check --output-format=github .
|
||||
shell: bash
|
||||
|
||||
- name: ruff format
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
|
||||
run: ruff format --check .
|
||||
shell: bash
|
94
.github/workflows/python-tests.yml
vendored
Normal file
94
.github/workflows/python-tests.yml
vendored
Normal file
@ -0,0 +1,94 @@
|
||||
# Runs python tests on a matrix of python versions and platforms.
|
||||
#
|
||||
# Checks for changes to python files before running the tests.
|
||||
# When manually triggered or called from another workflow, always runs the tests.
|
||||
|
||||
name: 'python tests'
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- 'main'
|
||||
pull_request:
|
||||
types:
|
||||
- 'ready_for_review'
|
||||
- 'opened'
|
||||
- 'synchronize'
|
||||
merge_group:
|
||||
workflow_dispatch:
|
||||
workflow_call:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
matrix:
|
||||
strategy:
|
||||
matrix:
|
||||
python-version:
|
||||
- '3.10'
|
||||
- '3.11'
|
||||
platform:
|
||||
- linux-cuda-11_7
|
||||
- linux-rocm-5_2
|
||||
- linux-cpu
|
||||
- macos-default
|
||||
- windows-cpu
|
||||
include:
|
||||
- platform: linux-cuda-11_7
|
||||
os: ubuntu-22.04
|
||||
github-env: $GITHUB_ENV
|
||||
- platform: linux-rocm-5_2
|
||||
os: ubuntu-22.04
|
||||
extra-index-url: 'https://download.pytorch.org/whl/rocm5.2'
|
||||
github-env: $GITHUB_ENV
|
||||
- platform: linux-cpu
|
||||
os: ubuntu-22.04
|
||||
extra-index-url: 'https://download.pytorch.org/whl/cpu'
|
||||
github-env: $GITHUB_ENV
|
||||
- platform: macos-default
|
||||
os: macOS-12
|
||||
github-env: $GITHUB_ENV
|
||||
- platform: windows-cpu
|
||||
os: windows-2022
|
||||
github-env: $env:GITHUB_ENV
|
||||
name: 'py${{ matrix.python-version }}: ${{ matrix.platform }}'
|
||||
runs-on: ${{ matrix.os }}
|
||||
timeout-minutes: 15 # expected run time: 2-6 min, depending on platform
|
||||
env:
|
||||
PIP_USE_PEP517: '1'
|
||||
steps:
|
||||
- name: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: check for changed python files
|
||||
if: ${{ github.event_name != 'workflow_dispatch' && github.event_name != 'workflow_call' }}
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@v42
|
||||
with:
|
||||
files_yaml: |
|
||||
python:
|
||||
- 'pyproject.toml'
|
||||
- 'invokeai/**'
|
||||
- '!invokeai/frontend/web/**'
|
||||
- 'tests/**'
|
||||
|
||||
- name: setup python
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
cache: pip
|
||||
cache-dependency-path: pyproject.toml
|
||||
|
||||
- name: install dependencies
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
|
||||
env:
|
||||
PIP_EXTRA_INDEX_URL: ${{ matrix.extra-index-url }}
|
||||
run: >
|
||||
pip3 install --editable=".[test]"
|
||||
|
||||
- name: run pytest
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
|
||||
run: pytest
|
96
.github/workflows/release.yml
vendored
Normal file
96
.github/workflows/release.yml
vendored
Normal file
@ -0,0 +1,96 @@
|
||||
# Main release workflow. Triggered on tag push or manual trigger.
|
||||
#
|
||||
# - Runs all code checks and tests
|
||||
# - Verifies the app version matches the tag version.
|
||||
# - Builds the installer and build, uploading them as artifacts.
|
||||
# - Publishes to TestPyPI and PyPI. Both are conditional on the previous steps passing and require a manual approval.
|
||||
#
|
||||
# See docs/RELEASE.md for more information on the release process.
|
||||
|
||||
name: release
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v*'
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
check-version:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: check python version
|
||||
uses: samuelcolvin/check-python-version@v4
|
||||
id: check-python-version
|
||||
with:
|
||||
version_file_path: invokeai/version/invokeai_version.py
|
||||
|
||||
frontend-checks:
|
||||
uses: ./.github/workflows/frontend-checks.yml
|
||||
|
||||
frontend-tests:
|
||||
uses: ./.github/workflows/frontend-tests.yml
|
||||
|
||||
python-checks:
|
||||
uses: ./.github/workflows/python-checks.yml
|
||||
|
||||
python-tests:
|
||||
uses: ./.github/workflows/python-tests.yml
|
||||
|
||||
build:
|
||||
uses: ./.github/workflows/build-installer.yml
|
||||
|
||||
publish-testpypi:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 5 # expected run time: <1 min
|
||||
needs:
|
||||
[
|
||||
check-version,
|
||||
frontend-checks,
|
||||
frontend-tests,
|
||||
python-checks,
|
||||
python-tests,
|
||||
build,
|
||||
]
|
||||
environment:
|
||||
name: testpypi
|
||||
url: https://test.pypi.org/p/invokeai
|
||||
steps:
|
||||
- name: download distribution from build job
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: dist
|
||||
path: dist/
|
||||
|
||||
- name: publish distribution to TestPyPI
|
||||
uses: pypa/gh-action-pypi-publish@release/v1
|
||||
with:
|
||||
repository-url: https://test.pypi.org/legacy/
|
||||
|
||||
publish-pypi:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 5 # expected run time: <1 min
|
||||
needs:
|
||||
[
|
||||
check-version,
|
||||
frontend-checks,
|
||||
frontend-tests,
|
||||
python-checks,
|
||||
python-tests,
|
||||
build,
|
||||
]
|
||||
environment:
|
||||
name: pypi
|
||||
url: https://pypi.org/p/invokeai
|
||||
steps:
|
||||
- name: download distribution from build job
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: dist
|
||||
path: dist/
|
||||
|
||||
- name: publish distribution to PyPI
|
||||
uses: pypa/gh-action-pypi-publish@release/v1
|
24
.github/workflows/style-checks.yml
vendored
24
.github/workflows/style-checks.yml
vendored
@ -1,24 +0,0 @@
|
||||
name: style checks
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
push:
|
||||
branches: main
|
||||
|
||||
jobs:
|
||||
ruff:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.10'
|
||||
|
||||
- name: Install dependencies with pip
|
||||
run: |
|
||||
pip install ruff
|
||||
|
||||
- run: ruff check --output-format=github .
|
||||
- run: ruff format --check .
|
129
.github/workflows/test-invoke-pip.yml
vendored
129
.github/workflows/test-invoke-pip.yml
vendored
@ -1,129 +0,0 @@
|
||||
name: Test invoke.py pip
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- 'main'
|
||||
pull_request:
|
||||
types:
|
||||
- 'ready_for_review'
|
||||
- 'opened'
|
||||
- 'synchronize'
|
||||
merge_group:
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
matrix:
|
||||
if: github.event.pull_request.draft == false
|
||||
strategy:
|
||||
matrix:
|
||||
python-version:
|
||||
# - '3.9'
|
||||
- '3.10'
|
||||
pytorch:
|
||||
- linux-cuda-11_7
|
||||
- linux-rocm-5_2
|
||||
- linux-cpu
|
||||
- macos-default
|
||||
- windows-cpu
|
||||
include:
|
||||
- pytorch: linux-cuda-11_7
|
||||
os: ubuntu-22.04
|
||||
github-env: $GITHUB_ENV
|
||||
- pytorch: linux-rocm-5_2
|
||||
os: ubuntu-22.04
|
||||
extra-index-url: 'https://download.pytorch.org/whl/rocm5.2'
|
||||
github-env: $GITHUB_ENV
|
||||
- pytorch: linux-cpu
|
||||
os: ubuntu-22.04
|
||||
extra-index-url: 'https://download.pytorch.org/whl/cpu'
|
||||
github-env: $GITHUB_ENV
|
||||
- pytorch: macos-default
|
||||
os: macOS-12
|
||||
github-env: $GITHUB_ENV
|
||||
- pytorch: windows-cpu
|
||||
os: windows-2022
|
||||
github-env: $env:GITHUB_ENV
|
||||
name: ${{ matrix.pytorch }} on ${{ matrix.python-version }}
|
||||
runs-on: ${{ matrix.os }}
|
||||
env:
|
||||
PIP_USE_PEP517: '1'
|
||||
steps:
|
||||
- name: Checkout sources
|
||||
id: checkout-sources
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Check for changed python files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@v41
|
||||
with:
|
||||
files_yaml: |
|
||||
python:
|
||||
- 'pyproject.toml'
|
||||
- 'invokeai/**'
|
||||
- '!invokeai/frontend/web/**'
|
||||
- 'tests/**'
|
||||
|
||||
- name: set test prompt to main branch validation
|
||||
if: steps.changed-files.outputs.python_any_changed == 'true'
|
||||
run: echo "TEST_PROMPTS=tests/validate_pr_prompt.txt" >> ${{ matrix.github-env }}
|
||||
|
||||
- name: setup python
|
||||
if: steps.changed-files.outputs.python_any_changed == 'true'
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
cache: pip
|
||||
cache-dependency-path: pyproject.toml
|
||||
|
||||
- name: install invokeai
|
||||
if: steps.changed-files.outputs.python_any_changed == 'true'
|
||||
env:
|
||||
PIP_EXTRA_INDEX_URL: ${{ matrix.extra-index-url }}
|
||||
run: >
|
||||
pip3 install
|
||||
--editable=".[test]"
|
||||
|
||||
- name: run pytest
|
||||
if: steps.changed-files.outputs.python_any_changed == 'true'
|
||||
id: run-pytest
|
||||
run: pytest
|
||||
|
||||
# - name: run invokeai-configure
|
||||
# env:
|
||||
# HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGINGFACE_TOKEN }}
|
||||
# run: >
|
||||
# invokeai-configure
|
||||
# --yes
|
||||
# --default_only
|
||||
# --full-precision
|
||||
# # can't use fp16 weights without a GPU
|
||||
|
||||
# - name: run invokeai
|
||||
# id: run-invokeai
|
||||
# env:
|
||||
# # Set offline mode to make sure configure preloaded successfully.
|
||||
# HF_HUB_OFFLINE: 1
|
||||
# HF_DATASETS_OFFLINE: 1
|
||||
# TRANSFORMERS_OFFLINE: 1
|
||||
# INVOKEAI_OUTDIR: ${{ github.workspace }}/results
|
||||
# run: >
|
||||
# invokeai
|
||||
# --no-patchmatch
|
||||
# --no-nsfw_checker
|
||||
# --precision=float32
|
||||
# --always_use_cpu
|
||||
# --use_memory_db
|
||||
# --outdir ${{ env.INVOKEAI_OUTDIR }}/${{ matrix.python-version }}/${{ matrix.pytorch }}
|
||||
# --from_file ${{ env.TEST_PROMPTS }}
|
||||
|
||||
# - name: Archive results
|
||||
# env:
|
||||
# INVOKEAI_OUTDIR: ${{ github.workspace }}/results
|
||||
# uses: actions/upload-artifact@v3
|
||||
# with:
|
||||
# name: results
|
||||
# path: ${{ env.INVOKEAI_OUTDIR }}
|
@ -7,7 +7,7 @@ embeddedLanguageFormatting: auto
|
||||
overrides:
|
||||
- files: '*.md'
|
||||
options:
|
||||
proseWrap: always
|
||||
proseWrap: preserve
|
||||
printWidth: 80
|
||||
parser: markdown
|
||||
cursorOffset: -1
|
||||
|
8
Makefile
8
Makefile
@ -10,10 +10,11 @@ help:
|
||||
@echo "ruff-unsafe Run ruff, fixing all fixable errors and formatting"
|
||||
@echo "mypy Run mypy using the config in pyproject.toml to identify type mismatches and other coding errors"
|
||||
@echo "mypy-all Run mypy ignoring the config in pyproject.tom but still ignoring missing imports"
|
||||
@echo "test" Run the unit tests.
|
||||
@echo "frontend-install" Install the pnpm modules needed for the front end
|
||||
@echo "test Run the unit tests."
|
||||
@echo "frontend-install Install the pnpm modules needed for the front end"
|
||||
@echo "frontend-build Build the frontend in order to run on localhost:9090"
|
||||
@echo "frontend-dev Run the frontend in developer mode on localhost:5173"
|
||||
@echo "frontend-typegen Generate types for the frontend from the OpenAPI schema"
|
||||
@echo "installer-zip Build the installer .zip file for the current version"
|
||||
@echo "tag-release Tag the GitHub repository with the current version (use at release time only!)"
|
||||
|
||||
@ -53,6 +54,9 @@ frontend-build:
|
||||
frontend-dev:
|
||||
cd invokeai/frontend/web && pnpm dev
|
||||
|
||||
frontend-typegen:
|
||||
cd invokeai/frontend/web && python ../../../scripts/generate_openapi_schema.py | pnpm typegen
|
||||
|
||||
# Installer zip file
|
||||
installer-zip:
|
||||
cd installer && ./create_installer.sh
|
||||
|
142
docs/RELEASE.md
Normal file
142
docs/RELEASE.md
Normal file
@ -0,0 +1,142 @@
|
||||
# Release Process
|
||||
|
||||
The app is published in twice, in different build formats.
|
||||
|
||||
- A [PyPI] distribution. This includes both a source distribution and built distribution (a wheel). Users install with `pip install invokeai`. The updater uses this build.
|
||||
- An installer on the [InvokeAI Releases Page]. This is a zip file with install scripts and a wheel. This is only used for new installs.
|
||||
|
||||
## General Prep
|
||||
|
||||
Make a developer call-out for PRs to merge. Merge and test things out.
|
||||
|
||||
While the release workflow does not include end-to-end tests, it does pause before publishing so you can download and test the final build.
|
||||
|
||||
## Release Workflow
|
||||
|
||||
The `release.yml` workflow runs a number of jobs to handle code checks, tests, build and publish on PyPI.
|
||||
|
||||
It is triggered on **tag push**, when the tag matches `v*`. It doesn't matter if you've prepped a release branch like `release/v3.5.0` or are releasing from `main` - it works the same.
|
||||
|
||||
> Because commits are reference-counted, it is safe to create a release branch, tag it, let the workflow run, then delete the branch. So long as the tag exists, that commit will exist.
|
||||
|
||||
### Triggering the Workflow
|
||||
|
||||
Run `make tag-release` to tag the current commit and kick off the workflow.
|
||||
|
||||
The release may also be dispatched [manually].
|
||||
|
||||
### Workflow Jobs and Process
|
||||
|
||||
The workflow consists of a number of concurrently-run jobs, and two final publish jobs.
|
||||
|
||||
The publish jobs require manual approval and are only run if the other jobs succeed.
|
||||
|
||||
#### `check-version` Job
|
||||
|
||||
This job checks that the git ref matches the app version. It matches the ref against the `__version__` variable in `invokeai/version/invokeai_version.py`.
|
||||
|
||||
When the workflow is triggered by tag push, the ref is the tag. If the workflow is run manually, the ref is the target selected from the **Use workflow from** dropdown.
|
||||
|
||||
This job uses [samuelcolvin/check-python-version].
|
||||
|
||||
> Any valid [version specifier] works, so long as the tag matches the version. The release workflow works exactly the same for `RC`, `post`, `dev`, etc.
|
||||
|
||||
#### Check and Test Jobs
|
||||
|
||||
- **`python-tests`**: runs `pytest` on matrix of platforms
|
||||
- **`python-checks`**: runs `ruff` (format and lint)
|
||||
- **`frontend-tests`**: runs `vitest`
|
||||
- **`frontend-checks`**: runs `prettier` (format), `eslint` (lint), `dpdm` (circular refs), `tsc` (static type check) and `knip` (unused imports)
|
||||
|
||||
> **TODO** We should add `mypy` or `pyright` to the **`check-python`** job.
|
||||
|
||||
> **TODO** We should add an end-to-end test job that generates an image.
|
||||
|
||||
#### `build-installer` Job
|
||||
|
||||
This sets up both python and frontend dependencies and builds the python package. Internally, this runs `installer/create_installer.sh` and uploads two artifacts:
|
||||
|
||||
- **`dist`**: the python distribution, to be published on PyPI
|
||||
- **`InvokeAI-installer-${VERSION}.zip`**: the installer to be included in the GitHub release
|
||||
|
||||
#### Sanity Check & Smoke Test
|
||||
|
||||
At this point, the release workflow pauses as the remaining publish jobs require approval.
|
||||
|
||||
A maintainer should go to the **Summary** tab of the workflow, download the installer and test it. Ensure the app loads and generates.
|
||||
|
||||
> The same wheel file is bundled in the installer and in the `dist` artifact, which is uploaded to PyPI. You should end up with the exactly the same installation of the `invokeai` package from any of these methods.
|
||||
|
||||
#### PyPI Publish Jobs
|
||||
|
||||
The publish jobs will run if any of the previous jobs fail.
|
||||
|
||||
They use [GitHub environments], which are configured as [trusted publishers] on PyPI.
|
||||
|
||||
Both jobs require a maintainer to approve them from the workflow's **Summary** tab.
|
||||
|
||||
- Click the **Review deployments** button
|
||||
- Select the environment (either `testpypi` or `pypi`)
|
||||
- Click **Approve and deploy**
|
||||
|
||||
> **If the version already exists on PyPI, the publish jobs will fail.** PyPI only allows a given version to be published once - you cannot change it. If version published on PyPI has a problem, you'll need to "fail forward" by bumping the app version and publishing a followup release.
|
||||
|
||||
#### `publish-testpypi` Job
|
||||
|
||||
Publishes the distribution on the [Test PyPI] index, using the `testpypi` GitHub environment.
|
||||
|
||||
This job is not required for the production PyPI publish, but included just in case you want to test the PyPI release.
|
||||
|
||||
If approved and successful, you could try out the test release like this:
|
||||
|
||||
```sh
|
||||
# Create a new virtual environment
|
||||
python -m venv ~/.test-invokeai-dist --prompt test-invokeai-dist
|
||||
# Install the distribution from Test PyPI
|
||||
pip install --index-url https://test.pypi.org/simple/ invokeai
|
||||
# Run and test the app
|
||||
invokeai-web
|
||||
# Cleanup
|
||||
deactivate
|
||||
rm -rf ~/.test-invokeai-dist
|
||||
```
|
||||
|
||||
#### `publish-pypi` Job
|
||||
|
||||
Publishes the distribution on the production PyPI index, using the `pypi` GitHub environment.
|
||||
|
||||
## Publish the GitHub Release with installer
|
||||
|
||||
Once the release is published to PyPI, it's time to publish the GitHub release.
|
||||
|
||||
1. [Draft a new release] on GitHub, choosing the tag that triggered the release.
|
||||
2. Write the release notes, describing important changes. The **Generate release notes** button automatically inserts the changelog and new contributors, and you can copy/paste the intro from previous releases.
|
||||
3. Upload the zip file created in **`build`** job into the Assets section of the release notes. You can also upload the zip into the body of the release notes, since it can be hard for users to find the Assets section.
|
||||
4. Check the **Set as a pre-release** and **Create a discussion for this release** checkboxes at the bottom of the release page.
|
||||
5. Publish the pre-release.
|
||||
6. Announce the pre-release in Discord.
|
||||
|
||||
> **TODO** Workflows can create a GitHub release from a template and upload release assets. One popular action to handle this is [ncipollo/release-action]. A future enhancement to the release process could set this up.
|
||||
|
||||
## Manual Build
|
||||
|
||||
The `build installer` workflow can be dispatched manually. This is useful to test the installer for a given branch or tag.
|
||||
|
||||
No checks are run, it just builds.
|
||||
|
||||
## Manual Release
|
||||
|
||||
The `release` workflow can be dispatched manually. You must dispatch the workflow from the right tag, else it will fail the version check.
|
||||
|
||||
This functionality is available as a fallback in case something goes wonky. Typically, releases should be triggered via tag push as described above.
|
||||
|
||||
[InvokeAI Releases Page]: https://github.com/invoke-ai/InvokeAI/releases
|
||||
[PyPI]: https://pypi.org/
|
||||
[Draft a new release]: https://github.com/invoke-ai/InvokeAI/releases/new
|
||||
[Test PyPI]: https://test.pypi.org/
|
||||
[version specifier]: https://packaging.python.org/en/latest/specifications/version-specifiers/
|
||||
[ncipollo/release-action]: https://github.com/ncipollo/release-action
|
||||
[GitHub environments]: https://docs.github.com/en/actions/deployment/targeting-different-environments/using-environments-for-deployment
|
||||
[trusted publishers]: https://docs.pypi.org/trusted-publishers/
|
||||
[samuelcolvin/check-python-version]: https://github.com/samuelcolvin/check-python-version
|
||||
[manually]: #manual-release
|
@ -32,7 +32,6 @@ model. These are the:
|
||||
Responsible for loading a model from disk
|
||||
into RAM and VRAM and getting it ready for inference.
|
||||
|
||||
|
||||
## Location of the Code
|
||||
|
||||
The four main services can be found in
|
||||
@ -67,19 +66,17 @@ provides the following fields:
|
||||
| `model_format` | ModelFormat | The format of the model (e.g. "diffusers"); also used as a Union discriminator |
|
||||
| `base_model` | BaseModelType | The base model that the model is compatible with |
|
||||
| `path` | str | Location of model on disk |
|
||||
| `original_hash` | str | Hash of the model when it was first installed |
|
||||
| `current_hash` | str | Most recent hash of the model's contents |
|
||||
| `hash` | str | Hash of the model |
|
||||
| `description` | str | Human-readable description of the model (optional) |
|
||||
| `source` | str | Model's source URL or repo id (optional) |
|
||||
|
||||
The `key` is a unique 32-character random ID which was generated at
|
||||
install time. The `original_hash` field stores a hash of the model's
|
||||
install time. The `hash` field stores a hash of the model's
|
||||
contents at install time obtained by sampling several parts of the
|
||||
model's files using the `imohash` library. Over the course of the
|
||||
model's lifetime it may be transformed in various ways, such as
|
||||
changing its precision or converting it from a .safetensors to a
|
||||
diffusers model. When this happens, `original_hash` is unchanged, but
|
||||
`current_hash` is updated to indicate the current contents.
|
||||
diffusers model.
|
||||
|
||||
`ModelType`, `ModelFormat` and `BaseModelType` are string enums that
|
||||
are defined in `invokeai.backend.model_manager.config`. They are also
|
||||
@ -94,7 +91,6 @@ The `path` field can be absolute or relative. If relative, it is taken
|
||||
to be relative to the `models_dir` setting in the user's
|
||||
`invokeai.yaml` file.
|
||||
|
||||
|
||||
### CheckpointConfig
|
||||
|
||||
This adds support for checkpoint configurations, and adds the
|
||||
@ -228,9 +224,9 @@ The way it works is as follows:
|
||||
1. Retrieve the value of the `model_config_db` option from the user's
|
||||
`invokeai.yaml` config file.
|
||||
2. If `model_config_db` is `auto` (the default), then:
|
||||
- Use the values of `conn` and `lock` to return a `ModelRecordServiceSQL` object
|
||||
* Use the values of `conn` and `lock` to return a `ModelRecordServiceSQL` object
|
||||
opened on the passed connection and lock.
|
||||
- Open up a new connection to `databases/invokeai.db` if `conn`
|
||||
* Open up a new connection to `databases/invokeai.db` if `conn`
|
||||
and/or `lock` are missing (see note below).
|
||||
3. If `model_config_db` is a Path, then use `from_db_file`
|
||||
to return the appropriate type of ModelRecordService.
|
||||
@ -255,7 +251,7 @@ store = ModelRecordServiceBase.open(config, db_conn, lock)
|
||||
|
||||
Configurations can be retrieved in several ways.
|
||||
|
||||
#### get_model(key) -> AnyModelConfig:
|
||||
#### get_model(key) -> AnyModelConfig
|
||||
|
||||
The basic functionality is to call the record store object's
|
||||
`get_model()` method with the desired model's unique key. It returns
|
||||
@ -272,28 +268,28 @@ print(model_conf.path)
|
||||
If the key is unrecognized, this call raises an
|
||||
`UnknownModelException`.
|
||||
|
||||
#### exists(key) -> AnyModelConfig:
|
||||
#### exists(key) -> AnyModelConfig
|
||||
|
||||
Returns True if a model with the given key exists in the databsae.
|
||||
|
||||
#### search_by_path(path) -> AnyModelConfig:
|
||||
#### search_by_path(path) -> AnyModelConfig
|
||||
|
||||
Returns the configuration of the model whose path is `path`. The path
|
||||
is matched using a simple string comparison and won't correctly match
|
||||
models referred to by different paths (e.g. using symbolic links).
|
||||
|
||||
#### search_by_name(name, base, type) -> List[AnyModelConfig]:
|
||||
#### search_by_name(name, base, type) -> List[AnyModelConfig]
|
||||
|
||||
This method searches for models that match some combination of `name`,
|
||||
`BaseType` and `ModelType`. Calling without any arguments will return
|
||||
all the models in the database.
|
||||
|
||||
#### all_models() -> List[AnyModelConfig]:
|
||||
#### all_models() -> List[AnyModelConfig]
|
||||
|
||||
Return all the model configs in the database. Exactly equivalent to
|
||||
calling `search_by_name()` with no arguments.
|
||||
|
||||
#### search_by_tag(tags) -> List[AnyModelConfig]:
|
||||
#### search_by_tag(tags) -> List[AnyModelConfig]
|
||||
|
||||
`tags` is a list of strings. This method returns a list of model
|
||||
configs that contain all of the given tags. Examples:
|
||||
@ -312,11 +308,11 @@ commercializable_models = [x for x in store.all_models() \
|
||||
if x.license.contains('allowCommercialUse=Sell')]
|
||||
```
|
||||
|
||||
#### version() -> str:
|
||||
#### version() -> str
|
||||
|
||||
Returns the version of the database, currently at `3.2`
|
||||
|
||||
#### model_info_by_name(name, base_model, model_type) -> ModelConfigBase:
|
||||
#### model_info_by_name(name, base_model, model_type) -> ModelConfigBase
|
||||
|
||||
This method exists to ease the transition from the previous version of
|
||||
the model manager, in which `get_model()` took the three arguments
|
||||
@ -337,7 +333,7 @@ model and pass its key to `get_model()`.
|
||||
Several methods allow you to create and update stored model config
|
||||
records.
|
||||
|
||||
#### add_model(key, config) -> AnyModelConfig:
|
||||
#### add_model(key, config) -> AnyModelConfig
|
||||
|
||||
Given a key and a configuration, this will add the model's
|
||||
configuration record to the database. `config` can either be a subclass of
|
||||
@ -352,7 +348,7 @@ model with the same key is already in the database, or an
|
||||
`InvalidModelConfigException` if a dict was passed and Pydantic
|
||||
experienced a parse or validation error.
|
||||
|
||||
### update_model(key, config) -> AnyModelConfig:
|
||||
### update_model(key, config) -> AnyModelConfig
|
||||
|
||||
Given a key and a configuration, this will update the model
|
||||
configuration record in the database. `config` can be either a
|
||||
@ -370,31 +366,31 @@ The `ModelInstallService` class implements the
|
||||
shop for all your model install needs. It provides the following
|
||||
functionality:
|
||||
|
||||
- Registering a model config record for a model already located on the
|
||||
* Registering a model config record for a model already located on the
|
||||
local filesystem, without moving it or changing its path.
|
||||
|
||||
- Installing a model alreadiy located on the local filesystem, by
|
||||
* Installing a model alreadiy located on the local filesystem, by
|
||||
moving it into the InvokeAI root directory under the
|
||||
`models` folder (or wherever config parameter `models_dir`
|
||||
specifies).
|
||||
|
||||
- Probing of models to determine their type, base type and other key
|
||||
* Probing of models to determine their type, base type and other key
|
||||
information.
|
||||
|
||||
- Interface with the InvokeAI event bus to provide status updates on
|
||||
* Interface with the InvokeAI event bus to provide status updates on
|
||||
the download, installation and registration process.
|
||||
|
||||
- Downloading a model from an arbitrary URL and installing it in
|
||||
* Downloading a model from an arbitrary URL and installing it in
|
||||
`models_dir`.
|
||||
|
||||
- Special handling for Civitai model URLs which allow the user to
|
||||
* Special handling for Civitai model URLs which allow the user to
|
||||
paste in a model page's URL or download link
|
||||
|
||||
- Special handling for HuggingFace repo_ids to recursively download
|
||||
* Special handling for HuggingFace repo_ids to recursively download
|
||||
the contents of the repository, paying attention to alternative
|
||||
variants such as fp16.
|
||||
|
||||
- Saving tags and other metadata about the model into the invokeai database
|
||||
* Saving tags and other metadata about the model into the invokeai database
|
||||
when fetching from a repo that provides that type of information,
|
||||
(currently only Civitai and HuggingFace).
|
||||
|
||||
@ -443,7 +439,6 @@ required parameters:
|
||||
| `metadata_store` | Optional[ModelMetadataStore] | Metadata storage object |
|
||||
|`session` | Optional[requests.Session] | Swap in a different Session object (usually for debugging) |
|
||||
|
||||
|
||||
Once initialized, the installer will provide the following methods:
|
||||
|
||||
#### install_job = installer.heuristic_import(source, [config], [access_token])
|
||||
@ -457,12 +452,12 @@ The `source` is a string that can be any of these forms
|
||||
1. A path on the local filesystem (`C:\\users\\fred\\model.safetensors`)
|
||||
2. A Url pointing to a single downloadable model file (`https://civitai.com/models/58390/detail-tweaker-lora-lora`)
|
||||
3. A HuggingFace repo_id with any of the following formats:
|
||||
- `model/name` -- entire model
|
||||
- `model/name:fp32` -- entire model, using the fp32 variant
|
||||
- `model/name:fp16:vae` -- vae submodel, using the fp16 variant
|
||||
- `model/name::vae` -- vae submodel, using default precision
|
||||
- `model/name:fp16:path/to/model.safetensors` -- an individual model file, fp16 variant
|
||||
- `model/name::path/to/model.safetensors` -- an individual model file, default variant
|
||||
* `model/name` -- entire model
|
||||
* `model/name:fp32` -- entire model, using the fp32 variant
|
||||
* `model/name:fp16:vae` -- vae submodel, using the fp16 variant
|
||||
* `model/name::vae` -- vae submodel, using default precision
|
||||
* `model/name:fp16:path/to/model.safetensors` -- an individual model file, fp16 variant
|
||||
* `model/name::path/to/model.safetensors` -- an individual model file, default variant
|
||||
|
||||
Note that by specifying a relative path to the top of the HuggingFace
|
||||
repo, you can download and install arbitrary models files.
|
||||
@ -566,7 +561,6 @@ details.
|
||||
This is used for a model that is located on a locally-accessible Posix
|
||||
filesystem, such as a local disk or networked fileshare.
|
||||
|
||||
|
||||
| **Argument** | **Type** | **Default** | **Description** |
|
||||
|------------------|------------------------------|-------------|-------------------------------------------|
|
||||
| `path` | str | Path | None | Path to the model file or directory |
|
||||
@ -625,7 +619,6 @@ HuggingFace has the most complicated `ModelSource` structure:
|
||||
| `subfolder` | Path | None | Look for the model in a subfolder of the repo. |
|
||||
| `access_token` | str | None | An access token needed to gain access to a subscriber's-only model. |
|
||||
|
||||
|
||||
The `repo_id` is the repository ID, such as `stabilityai/sdxl-turbo`.
|
||||
|
||||
The `variant` is one of the various diffusers formats that HuggingFace
|
||||
@ -661,7 +654,6 @@ in. To download these files, you must provide an
|
||||
`HfFolder.get_token()` will be called to fill it in with the cached
|
||||
one.
|
||||
|
||||
|
||||
#### Monitoring the install job process
|
||||
|
||||
When you create an install job with `import_model()`, it launches the
|
||||
@ -682,7 +674,6 @@ The `ModelInstallJob` class has the following structure:
|
||||
| `error_type` | `str` | Name of the exception that led to an error status |
|
||||
| `error` | `str` | Traceback of the error |
|
||||
|
||||
|
||||
If the `event_bus` argument was provided, events will also be
|
||||
broadcast to the InvokeAI event bus. The events will appear on the bus
|
||||
as an event of type `EventServiceBase.model_event`, a timestamp and
|
||||
@ -702,14 +693,13 @@ following keys:
|
||||
| `total_bytes` | int | Total size of all the files that make up the model |
|
||||
| `parts` | List[Dict]| Information on the progress of the individual files that make up the model |
|
||||
|
||||
|
||||
The parts is a list of dictionaries that give information on each of
|
||||
the components pieces of the download. The dictionary's keys are
|
||||
`source`, `local_path`, `bytes` and `total_bytes`, and correspond to
|
||||
the like-named keys in the main event.
|
||||
|
||||
Note that downloading events will not be issued for local models, and
|
||||
that downloading events occur *before* the running event.
|
||||
that downloading events occur _before_ the running event.
|
||||
|
||||
##### `model_install_running`
|
||||
|
||||
@ -752,7 +742,6 @@ properties: `waiting`, `downloading`, `running`, `complete`, `errored`
|
||||
and `cancelled`, as well as `in_terminal_state`. The last will return
|
||||
True if the job is in the complete, errored or cancelled states.
|
||||
|
||||
|
||||
#### Model configuration and probing
|
||||
|
||||
The install service uses the `invokeai.backend.model_manager.probe`
|
||||
@ -862,7 +851,6 @@ This method is similar to `unregister()`, but also unconditionally
|
||||
deletes the corresponding model weights file(s), regardless of whether
|
||||
they are inside or outside the InvokeAI models hierarchy.
|
||||
|
||||
|
||||
#### path = installer.download_and_cache(remote_source, [access_token], [timeout])
|
||||
|
||||
This utility routine will download the model file located at source,
|
||||
@ -974,7 +962,7 @@ is in its lifecycle. Values are defined in the string enum
|
||||
`DownloadJobStatus`, a symbol available from
|
||||
`invokeai.app.services.download_manager`. Possible values are:
|
||||
|
||||
| **Value** | **String Value** | ** Description ** |
|
||||
| **Value** | **String Value** | **Description** |
|
||||
|--------------|---------------------|-------------------|
|
||||
| `IDLE` | idle | Job created, but not submitted to the queue |
|
||||
| `ENQUEUED` | enqueued | Job is patiently waiting on the queue |
|
||||
@ -1040,11 +1028,11 @@ While a job is being downloaded, the queue will emit events at
|
||||
periodic intervals. A typical series of events during a successful
|
||||
download session will look like this:
|
||||
|
||||
- enqueued
|
||||
- running
|
||||
- running
|
||||
- running
|
||||
- completed
|
||||
* enqueued
|
||||
* running
|
||||
* running
|
||||
* running
|
||||
* completed
|
||||
|
||||
There will be a single enqueued event, followed by one or more running
|
||||
events, and finally one `completed`, `error` or `cancelled`
|
||||
@ -1053,12 +1041,12 @@ events.
|
||||
It is possible for a caller to pause download temporarily, in which
|
||||
case the events may look something like this:
|
||||
|
||||
- enqueued
|
||||
- running
|
||||
- running
|
||||
- paused
|
||||
- running
|
||||
- completed
|
||||
* enqueued
|
||||
* running
|
||||
* running
|
||||
* paused
|
||||
* running
|
||||
* completed
|
||||
|
||||
The download queue logs when downloads start and end (unless `quiet`
|
||||
is set to True at initialization time) but doesn't log any progress
|
||||
@ -1187,7 +1175,6 @@ and is equivalent to manually specifying a destination of
|
||||
Here is the full list of arguments that can be provided to
|
||||
`create_download_job()`:
|
||||
|
||||
|
||||
| **Argument** | **Type** | **Default** | **Description** |
|
||||
|------------------|------------------------------|-------------|-------------------------------------------|
|
||||
| `source` | Union[str, Path, AnyHttpUrl] | | Download remote or local source |
|
||||
@ -1275,7 +1262,7 @@ for getting the model to run. For example "author" is metadata, while
|
||||
"type", "base" and "format" are not. The latter fields are part of the
|
||||
model's config, as defined in `invokeai.backend.model_manager.config`.
|
||||
|
||||
### Example Usage:
|
||||
### Example Usage
|
||||
|
||||
```
|
||||
from invokeai.backend.model_manager.metadata import (
|
||||
@ -1328,7 +1315,6 @@ This is the common base class for metadata:
|
||||
| `author` | str | Model's author |
|
||||
| `tags` | Set[str] | Model tags |
|
||||
|
||||
|
||||
Note that the model config record also has a `name` field. It is
|
||||
intended that the config record version be locally customizable, while
|
||||
the metadata version is read-only. However, enforcing this is expected
|
||||
@ -1348,7 +1334,6 @@ This descends from `ModelMetadataBase` and adds the following fields:
|
||||
| `last_modified`| datetime | Date of last commit of this model to the repo |
|
||||
| `files` | List[Path] | List of the files in the model repo |
|
||||
|
||||
|
||||
#### `CivitaiMetadata`
|
||||
|
||||
This descends from `ModelMetadataBase` and adds the following fields:
|
||||
@ -1415,7 +1400,6 @@ testing suite to avoid hitting the internet.
|
||||
The HuggingFace and Civitai fetcher subclasses add additional
|
||||
repo-specific fetching methods:
|
||||
|
||||
|
||||
#### HuggingFaceMetadataFetch
|
||||
|
||||
This overrides its base class `from_json()` method to return a
|
||||
@ -1434,7 +1418,6 @@ retrieves its metadata. Functionally equivalent to `from_id()`, the
|
||||
only difference is that it returna a `CivitaiMetadata` object rather
|
||||
than an `AnyModelRepoMetadata`.
|
||||
|
||||
|
||||
### Metadata Storage
|
||||
|
||||
The `ModelMetadataStore` provides a simple facility to store model
|
||||
@ -1567,7 +1550,6 @@ The returned `LoadedModel` object contains a copy of the configuration
|
||||
record returned by the model record `get_model()` method, as well as
|
||||
the in-memory loaded model:
|
||||
|
||||
|
||||
| **Attribute Name** | **Type** | **Description** |
|
||||
|----------------|-----------------|------------------|
|
||||
| `config` | AnyModelConfig | A copy of the model's configuration record for retrieving base type, etc. |
|
||||
@ -1581,7 +1563,6 @@ return `AnyModel`, a Union `ModelMixin`, `torch.nn.Module`,
|
||||
models, `EmbeddingModelRaw` is used for LoRA and TextualInversion
|
||||
models. The others are obvious.
|
||||
|
||||
|
||||
`LoadedModel` acts as a context manager. The context loads the model
|
||||
into the execution device (e.g. VRAM on CUDA systems), locks the model
|
||||
in the execution device for the duration of the context, and returns
|
||||
@ -1595,9 +1576,9 @@ with model_info as vae:
|
||||
|
||||
`get_model_by_key()` may raise any of the following exceptions:
|
||||
|
||||
- `UnknownModelException` -- key not in database
|
||||
- `ModelNotFoundException` -- key in database but model not found at path
|
||||
- `NotImplementedException` -- the loader doesn't know how to load this type of model
|
||||
* `UnknownModelException` -- key not in database
|
||||
* `ModelNotFoundException` -- key in database but model not found at path
|
||||
* `NotImplementedException` -- the loader doesn't know how to load this type of model
|
||||
|
||||
### Emitting model loading events
|
||||
|
||||
@ -1724,6 +1705,7 @@ object, or in `context.services.model_manager` from within an
|
||||
invocation.
|
||||
|
||||
In the examples below, we have retrieved the manager using:
|
||||
|
||||
```
|
||||
mm = ApiDependencies.invoker.services.model_manager
|
||||
```
|
||||
|
45
docs/nodes/INVOCATION_API.md
Normal file
45
docs/nodes/INVOCATION_API.md
Normal file
@ -0,0 +1,45 @@
|
||||
# Invocation API
|
||||
|
||||
Each invocation's `invoke` method is provided a single arg - the Invocation
|
||||
Context.
|
||||
|
||||
This object provides access to various methods, used to interact with the
|
||||
application. Loading and saving images, logging messages, etc.
|
||||
|
||||
!!! warning ""
|
||||
|
||||
This API may shift slightly until the release of v4.0.0 as we work through a few final updates to the Model Manager.
|
||||
|
||||
```py
|
||||
class MyInvocation(BaseInvocation):
|
||||
...
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image_pil = context.images.get_pil(image_name)
|
||||
# Do something to the image
|
||||
image_dto = context.images.save(image_pil)
|
||||
# Log a message
|
||||
context.logger.info(f"Did something cool, image saved!")
|
||||
...
|
||||
```
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
::: invokeai.app.services.shared.invocation_context.InvocationContext
|
||||
options:
|
||||
members: false
|
||||
|
||||
::: invokeai.app.services.shared.invocation_context.ImagesInterface
|
||||
|
||||
::: invokeai.app.services.shared.invocation_context.TensorsInterface
|
||||
|
||||
::: invokeai.app.services.shared.invocation_context.ConditioningInterface
|
||||
|
||||
::: invokeai.app.services.shared.invocation_context.ModelsInterface
|
||||
|
||||
::: invokeai.app.services.shared.invocation_context.LoggerInterface
|
||||
|
||||
::: invokeai.app.services.shared.invocation_context.ConfigInterface
|
||||
|
||||
::: invokeai.app.services.shared.invocation_context.UtilInterface
|
||||
|
||||
::: invokeai.app.services.shared.invocation_context.BoardsInterface
|
||||
<!-- prettier-ignore-end -->
|
148
docs/nodes/NODES_MIGRATION_V3_V4.md
Normal file
148
docs/nodes/NODES_MIGRATION_V3_V4.md
Normal file
@ -0,0 +1,148 @@
|
||||
# Invoke v4.0.0 Nodes API Migration guide
|
||||
|
||||
Invoke v4.0.0 is versioned as such due to breaking changes to the API utilized
|
||||
by nodes, both core and custom.
|
||||
|
||||
## Motivation
|
||||
|
||||
Prior to v4.0.0, the `invokeai` python package has not be set up to be utilized
|
||||
as a library. That is to say, it didn't have any explicitly public API, and node
|
||||
authors had to work with the unstable internal application API.
|
||||
|
||||
v4.0.0 introduces a stable public API for nodes.
|
||||
|
||||
## Changes
|
||||
|
||||
There are two node-author-facing changes:
|
||||
|
||||
1. Import Paths
|
||||
1. Invocation Context API
|
||||
|
||||
### Import Paths
|
||||
|
||||
All public objects are now exported from `invokeai.invocation_api`:
|
||||
|
||||
```py
|
||||
# Old
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
invocation,
|
||||
)
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
|
||||
# New
|
||||
from invokeai.invocation_api import (
|
||||
BaseInvocation,
|
||||
ImageField,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
invocation,
|
||||
)
|
||||
```
|
||||
|
||||
It's possible that we've missed some classes you need in your node. Please let
|
||||
us know if that's the case.
|
||||
|
||||
### Invocation Context API
|
||||
|
||||
Most nodes utilize the Invocation Context, an object that is passed to the
|
||||
`invoke` that provides access to data and services a node may need.
|
||||
|
||||
Until now, that object and the services it exposed were internal. Exposing them
|
||||
to nodes means that changes to our internal implementation could break nodes.
|
||||
The methods on the services are also often fairly complicated and allowed nodes
|
||||
to footgun.
|
||||
|
||||
In v4.0.0, this object has been refactored to be much simpler.
|
||||
|
||||
See [INVOCATION_API](./INVOCATION_API.md) for full details of the API.
|
||||
|
||||
!!! warning ""
|
||||
|
||||
This API may shift slightly until the release of v4.0.0 as we work through a few final updates to the Model Manager.
|
||||
|
||||
#### Improved Service Methods
|
||||
|
||||
The biggest offender was the image save method:
|
||||
|
||||
```py
|
||||
# Old
|
||||
image_dto = context.services.images.create(
|
||||
image=image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
|
||||
# New
|
||||
image_dto = context.images.save(image=image)
|
||||
```
|
||||
|
||||
Other methods are simplified, or enhanced with additional functionality:
|
||||
|
||||
```py
|
||||
# Old
|
||||
image = context.services.images.get_pil_image(image_name)
|
||||
|
||||
# New
|
||||
image = context.images.get_pil(image_name)
|
||||
image_cmyk = context.images.get_pil(image_name, "CMYK")
|
||||
```
|
||||
|
||||
We also had some typing issues around tensors:
|
||||
|
||||
```py
|
||||
# Old
|
||||
# `latents` typed as `torch.Tensor`, but could be `ConditioningFieldData`
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
# `data` typed as `torch.Tenssor,` but could be `ConditioningFieldData`
|
||||
context.services.latents.save(latents_name, data)
|
||||
|
||||
# New - separate methods for tensors and conditioning data w/ correct typing
|
||||
# Also, the service generates the names
|
||||
tensor_name = context.tensors.save(tensor)
|
||||
tensor = context.tensors.load(tensor_name)
|
||||
# For conditioning
|
||||
cond_name = context.conditioning.save(cond_data)
|
||||
cond_data = context.conditioning.load(cond_name)
|
||||
```
|
||||
|
||||
#### Output Construction
|
||||
|
||||
Core Outputs have builder functions right on them - no need to manually
|
||||
construct these objects, or use an extra utility:
|
||||
|
||||
```py
|
||||
# Old
|
||||
image_output = ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
latents_output = build_latents_output(latents_name=name, latents=latents, seed=None)
|
||||
noise_output = NoiseOutput(
|
||||
noise=LatentsField(latents_name=latents_name, seed=seed),
|
||||
width=latents.size()[3] * 8,
|
||||
height=latents.size()[2] * 8,
|
||||
)
|
||||
cond_output = ConditioningOutput(
|
||||
conditioning=ConditioningField(
|
||||
conditioning_name=conditioning_name,
|
||||
),
|
||||
)
|
||||
|
||||
# New
|
||||
image_output = ImageOutput.build(image_dto)
|
||||
latents_output = LatentsOutput.build(latents_name=name, latents=noise, seed=self.seed)
|
||||
noise_output = NoiseOutput.build(latents_name=name, latents=noise, seed=self.seed)
|
||||
cond_output = ConditioningOutput.build(conditioning_name)
|
||||
```
|
||||
|
||||
You can still create the objects using constructors if you want, but we suggest
|
||||
using the builder methods.
|
@ -32,6 +32,7 @@ To use a community workflow, download the the `.json` node graph file and load i
|
||||
+ [Image to Character Art Image Nodes](#image-to-character-art-image-nodes)
|
||||
+ [Image Picker](#image-picker)
|
||||
+ [Image Resize Plus](#image-resize-plus)
|
||||
+ [Latent Upscale](#latent-upscale)
|
||||
+ [Load Video Frame](#load-video-frame)
|
||||
+ [Make 3D](#make-3d)
|
||||
+ [Mask Operations](#mask-operations)
|
||||
@ -290,6 +291,13 @@ View:
|
||||
</br><img src="https://raw.githubusercontent.com/VeyDlin/image-resize-plus-node/master/.readme/node.png" width="500" />
|
||||
|
||||
|
||||
--------------------------------
|
||||
### Latent Upscale
|
||||
|
||||
**Description:** This node uses a small (~2.4mb) model to upscale the latents used in a Stable Diffusion 1.5 or Stable Diffusion XL image generation, rather than the typical interpolation method, avoiding the traditional downsides of the latent upscale technique.
|
||||
|
||||
**Node Link:** [https://github.com/gogurtenjoyer/latent-upscale](https://github.com/gogurtenjoyer/latent-upscale)
|
||||
|
||||
--------------------------------
|
||||
### Load Video Frame
|
||||
|
||||
@ -346,12 +354,21 @@ See full docs here: https://github.com/skunkworxdark/Prompt-tools-nodes/edit/mai
|
||||
|
||||
**Description:** A set of nodes for Metadata. Collect Metadata from within an `iterate` node & extract metadata from an image.
|
||||
|
||||
- `Metadata Item Linked` - Allows collecting of metadata while within an iterate node with no need for a collect node or conversion to metadata node.
|
||||
- `Metadata From Image` - Provides Metadata from an image.
|
||||
- `Metadata To String` - Extracts a String value of a label from metadata.
|
||||
- `Metadata To Integer` - Extracts an Integer value of a label from metadata.
|
||||
- `Metadata To Float` - Extracts a Float value of a label from metadata.
|
||||
- `Metadata To Scheduler` - Extracts a Scheduler value of a label from metadata.
|
||||
- `Metadata Item Linked` - Allows collecting of metadata while within an iterate node with no need for a collect node or conversion to metadata node
|
||||
- `Metadata From Image` - Provides Metadata from an image
|
||||
- `Metadata To String` - Extracts a String value of a label from metadata
|
||||
- `Metadata To Integer` - Extracts an Integer value of a label from metadata
|
||||
- `Metadata To Float` - Extracts a Float value of a label from metadata
|
||||
- `Metadata To Scheduler` - Extracts a Scheduler value of a label from metadata
|
||||
- `Metadata To Bool` - Extracts Bool types from metadata
|
||||
- `Metadata To Model` - Extracts model types from metadata
|
||||
- `Metadata To SDXL Model` - Extracts SDXL model types from metadata
|
||||
- `Metadata To LoRAs` - Extracts Loras from metadata.
|
||||
- `Metadata To SDXL LoRAs` - Extracts SDXL Loras from metadata
|
||||
- `Metadata To ControlNets` - Extracts ControNets from metadata
|
||||
- `Metadata To IP-Adapters` - Extracts IP-Adapters from metadata
|
||||
- `Metadata To T2I-Adapters` - Extracts T2I-Adapters from metadata
|
||||
- `Denoise Latents + Metadata` - This is an inherited version of the existing `Denoise Latents` node but with a metadata input and output.
|
||||
|
||||
**Node Link:** https://github.com/skunkworxdark/metadata-linked-nodes
|
||||
|
||||
|
@ -19,6 +19,8 @@ their descriptions.
|
||||
| Conditioning Primitive | A conditioning tensor primitive value |
|
||||
| Content Shuffle Processor | Applies content shuffle processing to image |
|
||||
| ControlNet | Collects ControlNet info to pass to other nodes |
|
||||
| Create Denoise Mask | Converts a greyscale or transparency image into a mask for denoising. |
|
||||
| Create Gradient Mask | Creates a mask for Gradient ("soft", "differential") inpainting that gradually expands during denoising. Improves edge coherence. |
|
||||
| Denoise Latents | Denoises noisy latents to decodable images |
|
||||
| Divide Integers | Divides two numbers |
|
||||
| Dynamic Prompt | Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator |
|
||||
|
@ -1,5 +0,0 @@
|
||||
mkdocs
|
||||
mkdocs-material>=8, <9
|
||||
mkdocs-git-revision-date-localized-plugin
|
||||
mkdocs-redirects==1.2.0
|
||||
|
@ -1,5 +0,0 @@
|
||||
:root {
|
||||
--md-primary-fg-color: #35A4DB;
|
||||
--md-primary-fg-color--light: #35A4DB;
|
||||
--md-primary-fg-color--dark: #35A4DB;
|
||||
}
|
@ -2,22 +2,18 @@
|
||||
|
||||
set -e
|
||||
|
||||
BCYAN="\e[1;36m"
|
||||
BYELLOW="\e[1;33m"
|
||||
BGREEN="\e[1;32m"
|
||||
BRED="\e[1;31m"
|
||||
RED="\e[31m"
|
||||
RESET="\e[0m"
|
||||
|
||||
function is_bin_in_path {
|
||||
builtin type -P "$1" &>/dev/null
|
||||
}
|
||||
BCYAN="\033[1;36m"
|
||||
BYELLOW="\033[1;33m"
|
||||
BGREEN="\033[1;32m"
|
||||
BRED="\033[1;31m"
|
||||
RED="\033[31m"
|
||||
RESET="\033[0m"
|
||||
|
||||
function git_show {
|
||||
git show -s --format=oneline --abbrev-commit "$1" | cat
|
||||
}
|
||||
|
||||
if [[ -v "VIRTUAL_ENV" ]]; then
|
||||
if [[ ! -z "${VIRTUAL_ENV}" ]]; then
|
||||
# we can't just call 'deactivate' because this function is not exported
|
||||
# to the environment of this script from the bash process that runs the script
|
||||
echo -e "${BRED}A virtual environment is activated. Please deactivate it before proceeding.${RESET}"
|
||||
@ -26,31 +22,63 @@ fi
|
||||
|
||||
cd "$(dirname "$0")"
|
||||
|
||||
echo
|
||||
echo -e "${BYELLOW}This script must be run from the installer directory!${RESET}"
|
||||
echo "The current working directory is $(pwd)"
|
||||
read -p "If that looks right, press any key to proceed, or CTRL-C to exit..."
|
||||
echo
|
||||
|
||||
# Some machines only have `python3` in PATH, others have `python` - make an alias.
|
||||
# We can use a function to approximate an alias within a non-interactive shell.
|
||||
if ! is_bin_in_path python && is_bin_in_path python3; then
|
||||
function python {
|
||||
python3 "$@"
|
||||
}
|
||||
fi
|
||||
|
||||
VERSION=$(
|
||||
cd ..
|
||||
python -c "from invokeai.version import __version__ as version; print(version)"
|
||||
python3 -c "from invokeai.version import __version__ as version; print(version)"
|
||||
)
|
||||
PATCH=""
|
||||
VERSION="v${VERSION}${PATCH}"
|
||||
VERSION="v${VERSION}"
|
||||
|
||||
if [[ ! -z ${CI} ]]; then
|
||||
echo
|
||||
echo -e "${BCYAN}CI environment detected${RESET}"
|
||||
echo
|
||||
else
|
||||
echo
|
||||
echo -e "${BYELLOW}This script must be run from the installer directory!${RESET}"
|
||||
echo "The current working directory is $(pwd)"
|
||||
read -p "If that looks right, press any key to proceed, or CTRL-C to exit..."
|
||||
echo
|
||||
fi
|
||||
|
||||
echo -e "${BGREEN}HEAD${RESET}:"
|
||||
git_show HEAD
|
||||
echo
|
||||
|
||||
# ---------------------- FRONTEND ----------------------
|
||||
|
||||
pushd ../invokeai/frontend/web >/dev/null
|
||||
echo "Installing frontend dependencies..."
|
||||
echo
|
||||
pnpm i --frozen-lockfile
|
||||
echo
|
||||
if [[ ! -z ${CI} ]]; then
|
||||
echo "Building frontend without checks..."
|
||||
# In CI, we have already done the frontend checks and can just build
|
||||
pnpm vite build
|
||||
else
|
||||
echo "Running checks and building frontend..."
|
||||
# This runs all the frontend checks and builds
|
||||
pnpm build
|
||||
fi
|
||||
echo
|
||||
popd
|
||||
|
||||
# ---------------------- BACKEND ----------------------
|
||||
|
||||
echo
|
||||
echo "Building wheel..."
|
||||
echo
|
||||
|
||||
# install the 'build' package in the user site packages, if needed
|
||||
# could be improved by using a temporary venv, but it's tiny and harmless
|
||||
if [[ $(python3 -c 'from importlib.util import find_spec; print(find_spec("build") is None)') == "True" ]]; then
|
||||
pip install --user build
|
||||
fi
|
||||
|
||||
rm -rf ../build
|
||||
|
||||
python3 -m build --outdir dist/ ../.
|
||||
|
||||
# ----------------------
|
||||
|
||||
echo
|
||||
@ -78,10 +106,28 @@ chmod a+x InvokeAI-Installer/install.sh
|
||||
cp install.bat.in InvokeAI-Installer/install.bat
|
||||
cp WinLongPathsEnabled.reg InvokeAI-Installer/
|
||||
|
||||
# Zip everything up
|
||||
zip -r InvokeAI-installer-$VERSION.zip InvokeAI-Installer
|
||||
FILENAME=InvokeAI-installer-$VERSION.zip
|
||||
|
||||
# clean up
|
||||
rm -rf InvokeAI-Installer tmp dist ../invokeai/frontend/web/dist/
|
||||
# Zip everything up
|
||||
zip -r ${FILENAME} InvokeAI-Installer
|
||||
|
||||
echo
|
||||
echo -e "${BGREEN}Built installer: ./${FILENAME}${RESET}"
|
||||
echo -e "${BGREEN}Built PyPi distribution: ./dist${RESET}"
|
||||
|
||||
# clean up, but only if we are not in a github action
|
||||
if [[ -z ${CI} ]]; then
|
||||
echo
|
||||
echo "Cleaning up intermediate build files..."
|
||||
rm -rf InvokeAI-Installer tmp ../invokeai/frontend/web/dist/
|
||||
fi
|
||||
|
||||
if [[ ! -z ${CI} ]]; then
|
||||
echo
|
||||
echo "Setting GitHub action outputs..."
|
||||
echo "INSTALLER_FILENAME=${FILENAME}" >>$GITHUB_OUTPUT
|
||||
echo "INSTALLER_PATH=installer/${FILENAME}" >>$GITHUB_OUTPUT
|
||||
echo "DIST_PATH=installer/dist/" >>$GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
exit 0
|
||||
|
@ -2,12 +2,12 @@
|
||||
|
||||
set -e
|
||||
|
||||
BCYAN="\e[1;36m"
|
||||
BYELLOW="\e[1;33m"
|
||||
BGREEN="\e[1;32m"
|
||||
BRED="\e[1;31m"
|
||||
RED="\e[31m"
|
||||
RESET="\e[0m"
|
||||
BCYAN="\033[1;36m"
|
||||
BYELLOW="\033[1;33m"
|
||||
BGREEN="\033[1;32m"
|
||||
BRED="\033[1;31m"
|
||||
RED="\033[31m"
|
||||
RESET="\033[0m"
|
||||
|
||||
function does_tag_exist {
|
||||
git rev-parse --quiet --verify "refs/tags/$1" >/dev/null
|
||||
@ -23,49 +23,40 @@ function git_show {
|
||||
|
||||
VERSION=$(
|
||||
cd ..
|
||||
python -c "from invokeai.version import __version__ as version; print(version)"
|
||||
python3 -c "from invokeai.version import __version__ as version; print(version)"
|
||||
)
|
||||
PATCH=""
|
||||
MAJOR_VERSION=$(echo $VERSION | sed 's/\..*$//')
|
||||
VERSION="v${VERSION}${PATCH}"
|
||||
LATEST_TAG="v${MAJOR_VERSION}-latest"
|
||||
|
||||
if does_tag_exist $VERSION; then
|
||||
echo -e "${BCYAN}${VERSION}${RESET} already exists:"
|
||||
git_show_ref tags/$VERSION
|
||||
echo
|
||||
fi
|
||||
if does_tag_exist $LATEST_TAG; then
|
||||
echo -e "${BCYAN}${LATEST_TAG}${RESET} already exists:"
|
||||
git_show_ref tags/$LATEST_TAG
|
||||
echo
|
||||
fi
|
||||
|
||||
echo -e "${BGREEN}HEAD${RESET}:"
|
||||
git_show
|
||||
echo
|
||||
|
||||
echo -e -n "Create tags ${BCYAN}${VERSION}${RESET} and ${BCYAN}${LATEST_TAG}${RESET} @ ${BGREEN}HEAD${RESET}, ${RED}deleting existing tags on remote${RESET}? "
|
||||
echo -e "${BGREEN}git remote -v${RESET}:"
|
||||
git remote -v
|
||||
echo
|
||||
|
||||
echo -e -n "Create tags ${BCYAN}${VERSION}${RESET} @ ${BGREEN}HEAD${RESET}, ${RED}deleting existing tags on origin remote${RESET}? "
|
||||
read -e -p 'y/n [n]: ' input
|
||||
RESPONSE=${input:='n'}
|
||||
if [ "$RESPONSE" == 'y' ]; then
|
||||
echo
|
||||
echo -e "Deleting ${BCYAN}${VERSION}${RESET} tag on remote..."
|
||||
git push --delete origin $VERSION
|
||||
echo -e "Deleting ${BCYAN}${VERSION}${RESET} tag on origin remote..."
|
||||
git push origin :refs/tags/$VERSION
|
||||
|
||||
echo -e "Tagging ${BGREEN}HEAD${RESET} with ${BCYAN}${VERSION}${RESET} locally..."
|
||||
echo -e "Tagging ${BGREEN}HEAD${RESET} with ${BCYAN}${VERSION}${RESET} on locally..."
|
||||
if ! git tag -fa $VERSION; then
|
||||
echo "Existing/invalid tag"
|
||||
exit -1
|
||||
fi
|
||||
|
||||
echo -e "Deleting ${BCYAN}${LATEST_TAG}${RESET} tag on remote..."
|
||||
git push --delete origin $LATEST_TAG
|
||||
|
||||
echo -e "Tagging ${BGREEN}HEAD${RESET} with ${BCYAN}${LATEST_TAG}${RESET} locally..."
|
||||
git tag -fa $LATEST_TAG
|
||||
|
||||
echo -e "Pushing updated tags to remote..."
|
||||
echo -e "Pushing updated tags to origin remote..."
|
||||
git push origin --tags
|
||||
fi
|
||||
exit 0
|
||||
|
0
invokeai/app/__init__.py
Normal file
0
invokeai/app/__init__.py
Normal file
@ -25,8 +25,8 @@ from ..services.invocation_cache.invocation_cache_memory import MemoryInvocation
|
||||
from ..services.invocation_services import InvocationServices
|
||||
from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
|
||||
from ..services.invoker import Invoker
|
||||
from ..services.model_images.model_images_default import ModelImageFileStorageDisk
|
||||
from ..services.model_manager.model_manager_default import ModelManagerService
|
||||
from ..services.model_metadata import ModelMetadataStoreSQL
|
||||
from ..services.model_records import ModelRecordServiceSQL
|
||||
from ..services.names.names_default import SimpleNameService
|
||||
from ..services.session_processor.session_processor_default import DefaultSessionProcessor
|
||||
@ -72,6 +72,8 @@ class ApiDependencies:
|
||||
|
||||
image_files = DiskImageFileStorage(f"{output_folder}/images")
|
||||
|
||||
model_images_folder = config.models_path
|
||||
|
||||
db = init_db(config=config, logger=logger, image_files=image_files)
|
||||
|
||||
configuration = config
|
||||
@ -93,10 +95,10 @@ class ApiDependencies:
|
||||
ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
|
||||
)
|
||||
download_queue_service = DownloadQueueService(event_bus=events)
|
||||
model_metadata_service = ModelMetadataStoreSQL(db=db)
|
||||
model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images")
|
||||
model_manager = ModelManagerService.build_model_manager(
|
||||
app_config=configuration,
|
||||
model_record_service=ModelRecordServiceSQL(db=db, metadata_store=model_metadata_service),
|
||||
model_record_service=ModelRecordServiceSQL(db=db),
|
||||
download_queue=download_queue_service,
|
||||
events=events,
|
||||
)
|
||||
@ -120,6 +122,7 @@ class ApiDependencies:
|
||||
images=images,
|
||||
invocation_cache=invocation_cache,
|
||||
logger=logger,
|
||||
model_images=model_images_service,
|
||||
model_manager=model_manager,
|
||||
download_queue=download_queue_service,
|
||||
names=names,
|
||||
|
@ -1,27 +1,26 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein
|
||||
"""FastAPI route for model configuration records."""
|
||||
|
||||
import io
|
||||
import pathlib
|
||||
import shutil
|
||||
from hashlib import sha1
|
||||
from random import randbytes
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
import traceback
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import Body, Path, Query, Response
|
||||
from fastapi import Body, Path, Query, Response, UploadFile
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.routing import APIRouter
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from starlette.exceptions import HTTPException
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.services.model_install import ModelInstallJob
|
||||
from invokeai.app.services.model_records import (
|
||||
DuplicateModelException,
|
||||
InvalidModelException,
|
||||
ModelRecordOrderBy,
|
||||
ModelSummary,
|
||||
UnknownModelException,
|
||||
)
|
||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||
from invokeai.app.services.model_records.model_records_base import DuplicateModelException, ModelRecordChanges
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
@ -30,14 +29,15 @@ from invokeai.backend.model_manager.config import (
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
from invokeai.backend.model_manager.search import ModelSearch
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
model_manager_router = APIRouter(prefix="/v2/models", tags=["model_manager"])
|
||||
|
||||
# images are immutable; set a high max-age
|
||||
IMAGE_MAX_AGE = 31536000
|
||||
|
||||
|
||||
class ModelsList(BaseModel):
|
||||
"""Return list of configs."""
|
||||
@ -47,15 +47,6 @@ class ModelsList(BaseModel):
|
||||
model_config = ConfigDict(use_enum_values=True)
|
||||
|
||||
|
||||
class ModelTagSet(BaseModel):
|
||||
"""Return tags for a set of models."""
|
||||
|
||||
key: str
|
||||
name: str
|
||||
author: str
|
||||
tags: Set[str]
|
||||
|
||||
|
||||
##############################################################################
|
||||
# These are example inputs and outputs that are used in places where Swagger
|
||||
# is unable to generate a correct example.
|
||||
@ -66,19 +57,16 @@ example_model_config = {
|
||||
"base": "sd-1",
|
||||
"type": "main",
|
||||
"format": "checkpoint",
|
||||
"config": "string",
|
||||
"config_path": "string",
|
||||
"key": "string",
|
||||
"original_hash": "string",
|
||||
"current_hash": "string",
|
||||
"hash": "string",
|
||||
"description": "string",
|
||||
"source": "string",
|
||||
"last_modified": 0,
|
||||
"vae": "string",
|
||||
"converted_at": 0,
|
||||
"variant": "normal",
|
||||
"prediction_type": "epsilon",
|
||||
"repo_variant": "fp16",
|
||||
"upcast_attention": False,
|
||||
"ztsnr_training": False,
|
||||
}
|
||||
|
||||
example_model_input = {
|
||||
@ -87,50 +75,12 @@ example_model_input = {
|
||||
"base": "sd-1",
|
||||
"type": "main",
|
||||
"format": "checkpoint",
|
||||
"config": "configs/stable-diffusion/v1-inference.yaml",
|
||||
"config_path": "configs/stable-diffusion/v1-inference.yaml",
|
||||
"description": "Model description",
|
||||
"vae": None,
|
||||
"variant": "normal",
|
||||
}
|
||||
|
||||
example_model_metadata = {
|
||||
"name": "ip_adapter_sd_image_encoder",
|
||||
"author": "InvokeAI",
|
||||
"tags": [
|
||||
"transformers",
|
||||
"safetensors",
|
||||
"clip_vision_model",
|
||||
"endpoints_compatible",
|
||||
"region:us",
|
||||
"has_space",
|
||||
"license:apache-2.0",
|
||||
],
|
||||
"files": [
|
||||
{
|
||||
"url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/README.md",
|
||||
"path": "ip_adapter_sd_image_encoder/README.md",
|
||||
"size": 628,
|
||||
"sha256": None,
|
||||
},
|
||||
{
|
||||
"url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/config.json",
|
||||
"path": "ip_adapter_sd_image_encoder/config.json",
|
||||
"size": 560,
|
||||
"sha256": None,
|
||||
},
|
||||
{
|
||||
"url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/model.safetensors",
|
||||
"path": "ip_adapter_sd_image_encoder/model.safetensors",
|
||||
"size": 2528373448,
|
||||
"sha256": "6ca9667da1ca9e0b0f75e46bb030f7e011f44f86cbfb8d5a36590fcd7507b030",
|
||||
},
|
||||
],
|
||||
"type": "huggingface",
|
||||
"id": "InvokeAI/ip_adapter_sd_image_encoder",
|
||||
"tag_dict": {"license": "apache-2.0"},
|
||||
"last_modified": "2023-09-23T17:33:25Z",
|
||||
}
|
||||
|
||||
##############################################################################
|
||||
# ROUTES
|
||||
##############################################################################
|
||||
@ -162,6 +112,9 @@ async def list_model_records(
|
||||
found_models.extend(
|
||||
record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format)
|
||||
)
|
||||
for model in found_models:
|
||||
cover_image = ApiDependencies.invoker.services.model_images.get_url(model.key)
|
||||
model.cover_image = cover_image
|
||||
return ModelsList(models=found_models)
|
||||
|
||||
|
||||
@ -205,53 +158,23 @@ async def get_model_record(
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
try:
|
||||
config: AnyModelConfig = record_store.get_model(key)
|
||||
cover_image = ApiDependencies.invoker.services.model_images.get_url(key)
|
||||
config.cover_image = cover_image
|
||||
return config
|
||||
except UnknownModelException as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@model_manager_router.get("/summary", operation_id="list_model_summary")
|
||||
async def list_model_summary(
|
||||
page: int = Query(default=0, description="The page to get"),
|
||||
per_page: int = Query(default=10, description="The number of models per page"),
|
||||
order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"),
|
||||
) -> PaginatedResults[ModelSummary]:
|
||||
"""Gets a page of model summary data."""
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
results: PaginatedResults[ModelSummary] = record_store.list_models(page=page, per_page=per_page, order_by=order_by)
|
||||
return results
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/i/{key}/metadata",
|
||||
operation_id="get_model_metadata",
|
||||
responses={
|
||||
200: {
|
||||
"description": "The model metadata was retrieved successfully",
|
||||
"content": {"application/json": {"example": example_model_metadata}},
|
||||
},
|
||||
400: {"description": "Bad request"},
|
||||
},
|
||||
)
|
||||
async def get_model_metadata(
|
||||
key: str = Path(description="Key of the model repo metadata to fetch."),
|
||||
) -> Optional[AnyModelRepoMetadata]:
|
||||
"""Get a model metadata object."""
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/tags",
|
||||
operation_id="list_tags",
|
||||
)
|
||||
async def list_tags() -> Set[str]:
|
||||
"""Get a unique set of all the model tags."""
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
result: Set[str] = record_store.list_tags()
|
||||
return result
|
||||
# @model_manager_router.get("/summary", operation_id="list_model_summary")
|
||||
# async def list_model_summary(
|
||||
# page: int = Query(default=0, description="The page to get"),
|
||||
# per_page: int = Query(default=10, description="The number of models per page"),
|
||||
# order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"),
|
||||
# ) -> PaginatedResults[ModelSummary]:
|
||||
# """Gets a page of model summary data."""
|
||||
# record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
# results: PaginatedResults[ModelSummary] = record_store.list_models(page=page, per_page=per_page, order_by=order_by)
|
||||
# return results
|
||||
|
||||
|
||||
class FoundModel(BaseModel):
|
||||
@ -323,19 +246,6 @@ async def scan_for_models(
|
||||
return scan_results
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/tags/search",
|
||||
operation_id="search_by_metadata_tags",
|
||||
)
|
||||
async def search_by_metadata_tags(
|
||||
tags: Set[str] = Query(default=None, description="Tags to search for"),
|
||||
) -> ModelsList:
|
||||
"""Get a list of models."""
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
results = record_store.search_by_metadata_tag(tags)
|
||||
return ModelsList(models=results)
|
||||
|
||||
|
||||
@model_manager_router.patch(
|
||||
"/i/{key}",
|
||||
operation_id="update_model_record",
|
||||
@ -352,15 +262,13 @@ async def search_by_metadata_tags(
|
||||
)
|
||||
async def update_model_record(
|
||||
key: Annotated[str, Path(description="Unique key of model")],
|
||||
info: Annotated[
|
||||
AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
|
||||
],
|
||||
changes: Annotated[ModelRecordChanges, Body(description="Model config", example=example_model_input)],
|
||||
) -> AnyModelConfig:
|
||||
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
|
||||
"""Update a model's config."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
try:
|
||||
model_response: AnyModelConfig = record_store.update_model(key, config=info)
|
||||
model_response: AnyModelConfig = record_store.update_model(key, changes=changes)
|
||||
logger.info(f"Updated model: {key}")
|
||||
except UnknownModelException as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
@ -370,16 +278,85 @@ async def update_model_record(
|
||||
return model_response
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/i/{key}/image",
|
||||
operation_id="get_model_image",
|
||||
responses={
|
||||
200: {
|
||||
"description": "The model image was fetched successfully",
|
||||
},
|
||||
400: {"description": "Bad request"},
|
||||
404: {"description": "The model image could not be found"},
|
||||
},
|
||||
status_code=200,
|
||||
)
|
||||
async def get_model_image(
|
||||
key: str = Path(description="The name of model image file to get"),
|
||||
) -> FileResponse:
|
||||
"""Gets an image file that previews the model"""
|
||||
|
||||
try:
|
||||
path = ApiDependencies.invoker.services.model_images.get_path(key)
|
||||
|
||||
response = FileResponse(
|
||||
path,
|
||||
media_type="image/png",
|
||||
filename=key + ".png",
|
||||
content_disposition_type="inline",
|
||||
)
|
||||
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
|
||||
return response
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@model_manager_router.patch(
|
||||
"/i/{key}/image",
|
||||
operation_id="update_model_image",
|
||||
responses={
|
||||
200: {
|
||||
"description": "The model image was updated successfully",
|
||||
},
|
||||
400: {"description": "Bad request"},
|
||||
},
|
||||
status_code=200,
|
||||
)
|
||||
async def update_model_image(
|
||||
key: Annotated[str, Path(description="Unique key of model")],
|
||||
image: UploadFile,
|
||||
) -> None:
|
||||
if not image.content_type or not image.content_type.startswith("image"):
|
||||
raise HTTPException(status_code=415, detail="Not an image")
|
||||
|
||||
contents = await image.read()
|
||||
try:
|
||||
pil_image = Image.open(io.BytesIO(contents))
|
||||
|
||||
except Exception:
|
||||
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=415, detail="Failed to read image")
|
||||
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
model_images = ApiDependencies.invoker.services.model_images
|
||||
try:
|
||||
model_images.save(pil_image, key)
|
||||
logger.info(f"Updated image for model: {key}")
|
||||
except ValueError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
return
|
||||
|
||||
|
||||
@model_manager_router.delete(
|
||||
"/i/{key}",
|
||||
operation_id="del_model_record",
|
||||
operation_id="delete_model",
|
||||
responses={
|
||||
204: {"description": "Model deleted successfully"},
|
||||
404: {"description": "Model not found"},
|
||||
},
|
||||
status_code=204,
|
||||
)
|
||||
async def del_model_record(
|
||||
async def delete_model(
|
||||
key: str = Path(description="Unique key of model to remove from model registry."),
|
||||
) -> Response:
|
||||
"""
|
||||
@ -400,42 +377,62 @@ async def del_model_record(
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@model_manager_router.post(
|
||||
"/i/",
|
||||
operation_id="add_model_record",
|
||||
@model_manager_router.delete(
|
||||
"/i/{key}/image",
|
||||
operation_id="delete_model_image",
|
||||
responses={
|
||||
201: {
|
||||
"description": "The model added successfully",
|
||||
"content": {"application/json": {"example": example_model_config}},
|
||||
204: {"description": "Model image deleted successfully"},
|
||||
404: {"description": "Model image not found"},
|
||||
},
|
||||
409: {"description": "There is already a model corresponding to this path or repo_id"},
|
||||
415: {"description": "Unrecognized file/folder format"},
|
||||
},
|
||||
status_code=201,
|
||||
status_code=204,
|
||||
)
|
||||
async def add_model_record(
|
||||
config: Annotated[
|
||||
AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
|
||||
],
|
||||
) -> AnyModelConfig:
|
||||
"""Add a model using the configuration information appropriate for its type."""
|
||||
async def delete_model_image(
|
||||
key: str = Path(description="Unique key of model image to remove from model_images directory."),
|
||||
) -> None:
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
if config.key == "<NOKEY>":
|
||||
config.key = sha1(randbytes(100)).hexdigest()
|
||||
logger.info(f"Created model {config.key} for {config.name}")
|
||||
model_images = ApiDependencies.invoker.services.model_images
|
||||
try:
|
||||
record_store.add_model(config.key, config)
|
||||
except DuplicateModelException as e:
|
||||
model_images.delete(key)
|
||||
logger.info(f"Deleted model image: {key}")
|
||||
return
|
||||
except UnknownModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
except InvalidModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=415)
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
# now fetch it out
|
||||
result: AnyModelConfig = record_store.get_model(config.key)
|
||||
return result
|
||||
|
||||
# @model_manager_router.post(
|
||||
# "/i/",
|
||||
# operation_id="add_model_record",
|
||||
# responses={
|
||||
# 201: {
|
||||
# "description": "The model added successfully",
|
||||
# "content": {"application/json": {"example": example_model_config}},
|
||||
# },
|
||||
# 409: {"description": "There is already a model corresponding to this path or repo_id"},
|
||||
# 415: {"description": "Unrecognized file/folder format"},
|
||||
# },
|
||||
# status_code=201,
|
||||
# )
|
||||
# async def add_model_record(
|
||||
# config: Annotated[
|
||||
# AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
|
||||
# ],
|
||||
# ) -> AnyModelConfig:
|
||||
# """Add a model using the configuration information appropriate for its type."""
|
||||
# logger = ApiDependencies.invoker.services.logger
|
||||
# record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
# try:
|
||||
# record_store.add_model(config)
|
||||
# except DuplicateModelException as e:
|
||||
# logger.error(str(e))
|
||||
# raise HTTPException(status_code=409, detail=str(e))
|
||||
# except InvalidModelException as e:
|
||||
# logger.error(str(e))
|
||||
# raise HTTPException(status_code=415)
|
||||
|
||||
# # now fetch it out
|
||||
# result: AnyModelConfig = record_store.get_model(config.key)
|
||||
# return result
|
||||
|
||||
|
||||
@model_manager_router.post(
|
||||
@ -451,6 +448,7 @@ async def add_model_record(
|
||||
)
|
||||
async def install_model(
|
||||
source: str = Query(description="Model source to install, can be a local path, repo_id, or remote URL"),
|
||||
inplace: Optional[bool] = Query(description="Whether or not to install a local model in place", default=False),
|
||||
# TODO(MM2): Can we type this?
|
||||
config: Optional[Dict[str, Any]] = Body(
|
||||
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
|
||||
@ -493,6 +491,7 @@ async def install_model(
|
||||
source=source,
|
||||
config=config,
|
||||
access_token=access_token,
|
||||
inplace=bool(inplace),
|
||||
)
|
||||
logger.info(f"Started installation of {source}")
|
||||
except UnknownModelException as e:
|
||||
@ -508,10 +507,10 @@ async def install_model(
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/import",
|
||||
operation_id="list_model_install_jobs",
|
||||
"/install",
|
||||
operation_id="list_model_installs",
|
||||
)
|
||||
async def list_model_install_jobs() -> List[ModelInstallJob]:
|
||||
async def list_model_installs() -> List[ModelInstallJob]:
|
||||
"""Return the list of model install jobs.
|
||||
|
||||
Install jobs have a numeric `id`, a `status`, and other fields that provide information on
|
||||
@ -525,9 +524,8 @@ async def list_model_install_jobs() -> List[ModelInstallJob]:
|
||||
* "cancelled" -- Job was cancelled before completion.
|
||||
|
||||
Once completed, information about the model such as its size, base
|
||||
model, type, and metadata can be retrieved from the `config_out`
|
||||
field. For multi-file models such as diffusers, information on individual files
|
||||
can be retrieved from `download_parts`.
|
||||
model and type can be retrieved from the `config_out` field. For multi-file models such as diffusers,
|
||||
information on individual files can be retrieved from `download_parts`.
|
||||
|
||||
See the example and schema below for more information.
|
||||
"""
|
||||
@ -536,7 +534,7 @@ async def list_model_install_jobs() -> List[ModelInstallJob]:
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/import/{id}",
|
||||
"/install/{id}",
|
||||
operation_id="get_model_install_job",
|
||||
responses={
|
||||
200: {"description": "Success"},
|
||||
@ -556,7 +554,7 @@ async def get_model_install_job(id: int = Path(description="Model install id"))
|
||||
|
||||
|
||||
@model_manager_router.delete(
|
||||
"/import/{id}",
|
||||
"/install/{id}",
|
||||
operation_id="cancel_model_install_job",
|
||||
responses={
|
||||
201: {"description": "The job was cancelled successfully"},
|
||||
@ -574,8 +572,8 @@ async def cancel_model_install_job(id: int = Path(description="Model install job
|
||||
installer.cancel_job(job)
|
||||
|
||||
|
||||
@model_manager_router.patch(
|
||||
"/import",
|
||||
@model_manager_router.delete(
|
||||
"/install",
|
||||
operation_id="prune_model_install_jobs",
|
||||
responses={
|
||||
204: {"description": "All completed and errored jobs have been pruned"},
|
||||
@ -645,7 +643,7 @@ async def convert_model(
|
||||
raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.")
|
||||
|
||||
# loading the model will convert it into a cached diffusers file
|
||||
model_manager.load_model_by_config(model_config, submodel_type=SubModelType.Scheduler)
|
||||
model_manager.load.load_model(model_config, submodel_type=SubModelType.Scheduler)
|
||||
|
||||
# Get the path of the converted model from the loader
|
||||
cache_path = loader.convert_cache.cache_path(key)
|
||||
@ -654,7 +652,8 @@ async def convert_model(
|
||||
# temporarily rename the original safetensors file so that there is no naming conflict
|
||||
original_name = model_config.name
|
||||
model_config.name = f"{original_name}.DELETE"
|
||||
store.update_model(key, config=model_config)
|
||||
changes = ModelRecordChanges(name=model_config.name)
|
||||
store.update_model(key, changes=changes)
|
||||
|
||||
# install the diffusers
|
||||
try:
|
||||
@ -663,7 +662,7 @@ async def convert_model(
|
||||
config={
|
||||
"name": original_name,
|
||||
"description": model_config.description,
|
||||
"original_hash": model_config.original_hash,
|
||||
"hash": model_config.hash,
|
||||
"source": model_config.source,
|
||||
},
|
||||
)
|
||||
@ -671,10 +670,6 @@ async def convert_model(
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
|
||||
# get the original metadata
|
||||
if orig_metadata := store.get_metadata(key):
|
||||
store.metadata_store.add_metadata(new_key, orig_metadata)
|
||||
|
||||
# delete the original safetensors file
|
||||
installer.delete(key)
|
||||
|
||||
@ -686,66 +681,66 @@ async def convert_model(
|
||||
return new_config
|
||||
|
||||
|
||||
@model_manager_router.put(
|
||||
"/merge",
|
||||
operation_id="merge",
|
||||
responses={
|
||||
200: {
|
||||
"description": "Model converted successfully",
|
||||
"content": {"application/json": {"example": example_model_config}},
|
||||
},
|
||||
400: {"description": "Bad request"},
|
||||
404: {"description": "Model not found"},
|
||||
409: {"description": "There is already a model registered at this location"},
|
||||
},
|
||||
)
|
||||
async def merge(
|
||||
keys: List[str] = Body(description="Keys for two to three models to merge", min_length=2, max_length=3),
|
||||
merged_model_name: Optional[str] = Body(description="Name of destination model", default=None),
|
||||
alpha: float = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
|
||||
force: bool = Body(
|
||||
description="Force merging of models created with different versions of diffusers",
|
||||
default=False,
|
||||
),
|
||||
interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method", default=None),
|
||||
merge_dest_directory: Optional[str] = Body(
|
||||
description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
|
||||
default=None,
|
||||
),
|
||||
) -> AnyModelConfig:
|
||||
"""
|
||||
Merge diffusers models. The process is controlled by a set parameters provided in the body of the request.
|
||||
```
|
||||
Argument Description [default]
|
||||
-------- ----------------------
|
||||
keys List of 2-3 model keys to merge together. All models must use the same base type.
|
||||
merged_model_name Name for the merged model [Concat model names]
|
||||
alpha Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5]
|
||||
force If true, force the merge even if the models were generated by different versions of the diffusers library [False]
|
||||
interp Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum]
|
||||
merge_dest_directory Specify a directory to store the merged model in [models directory]
|
||||
```
|
||||
"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
try:
|
||||
logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
|
||||
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
|
||||
installer = ApiDependencies.invoker.services.model_manager.install
|
||||
merger = ModelMerger(installer)
|
||||
model_names = [installer.record_store.get_model(x).name for x in keys]
|
||||
response = merger.merge_diffusion_models_and_save(
|
||||
model_keys=keys,
|
||||
merged_model_name=merged_model_name or "+".join(model_names),
|
||||
alpha=alpha,
|
||||
interp=interp,
|
||||
force=force,
|
||||
merge_dest_directory=dest,
|
||||
)
|
||||
except UnknownModelException:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"One or more of the models '{keys}' not found",
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
return response
|
||||
# @model_manager_router.put(
|
||||
# "/merge",
|
||||
# operation_id="merge",
|
||||
# responses={
|
||||
# 200: {
|
||||
# "description": "Model converted successfully",
|
||||
# "content": {"application/json": {"example": example_model_config}},
|
||||
# },
|
||||
# 400: {"description": "Bad request"},
|
||||
# 404: {"description": "Model not found"},
|
||||
# 409: {"description": "There is already a model registered at this location"},
|
||||
# },
|
||||
# )
|
||||
# async def merge(
|
||||
# keys: List[str] = Body(description="Keys for two to three models to merge", min_length=2, max_length=3),
|
||||
# merged_model_name: Optional[str] = Body(description="Name of destination model", default=None),
|
||||
# alpha: float = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
|
||||
# force: bool = Body(
|
||||
# description="Force merging of models created with different versions of diffusers",
|
||||
# default=False,
|
||||
# ),
|
||||
# interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method", default=None),
|
||||
# merge_dest_directory: Optional[str] = Body(
|
||||
# description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
|
||||
# default=None,
|
||||
# ),
|
||||
# ) -> AnyModelConfig:
|
||||
# """
|
||||
# Merge diffusers models. The process is controlled by a set parameters provided in the body of the request.
|
||||
# ```
|
||||
# Argument Description [default]
|
||||
# -------- ----------------------
|
||||
# keys List of 2-3 model keys to merge together. All models must use the same base type.
|
||||
# merged_model_name Name for the merged model [Concat model names]
|
||||
# alpha Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5]
|
||||
# force If true, force the merge even if the models were generated by different versions of the diffusers library [False]
|
||||
# interp Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum]
|
||||
# merge_dest_directory Specify a directory to store the merged model in [models directory]
|
||||
# ```
|
||||
# """
|
||||
# logger = ApiDependencies.invoker.services.logger
|
||||
# try:
|
||||
# logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
|
||||
# dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
|
||||
# installer = ApiDependencies.invoker.services.model_manager.install
|
||||
# merger = ModelMerger(installer)
|
||||
# model_names = [installer.record_store.get_model(x).name for x in keys]
|
||||
# response = merger.merge_diffusion_models_and_save(
|
||||
# model_keys=keys,
|
||||
# merged_model_name=merged_model_name or "+".join(model_names),
|
||||
# alpha=alpha,
|
||||
# interp=interp,
|
||||
# force=force,
|
||||
# merge_dest_directory=dest,
|
||||
# )
|
||||
# except UnknownModelException:
|
||||
# raise HTTPException(
|
||||
# status_code=404,
|
||||
# detail=f"One or more of the models '{keys}' not found",
|
||||
# )
|
||||
# except ValueError as e:
|
||||
# raise HTTPException(status_code=400, detail=str(e))
|
||||
# return response
|
||||
|
@ -3,10 +3,8 @@
|
||||
# values from the command line or config file.
|
||||
import sys
|
||||
|
||||
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
|
||||
from invokeai.version.invokeai_version import __version__
|
||||
|
||||
from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
|
||||
from .services.config import InvokeAIAppConfig
|
||||
|
||||
app_config = InvokeAIAppConfig.get_config()
|
||||
@ -19,6 +17,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
|
||||
import asyncio
|
||||
import mimetypes
|
||||
import socket
|
||||
from contextlib import asynccontextmanager
|
||||
from inspect import signature
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
@ -39,6 +38,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
|
||||
# noinspection PyUnresolvedReferences
|
||||
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
||||
import invokeai.frontend.web as web_dir
|
||||
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
|
||||
|
||||
from ..backend.util.logging import InvokeAILogger
|
||||
from .api.dependencies import ApiDependencies
|
||||
@ -58,6 +58,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
|
||||
BaseInvocation,
|
||||
UIConfigBase,
|
||||
)
|
||||
from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
|
||||
|
||||
if is_mps_available():
|
||||
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
|
||||
@ -71,9 +72,25 @@ logger = InvokeAILogger.get_logger(config=app_config)
|
||||
mimetypes.add_type("application/javascript", ".js")
|
||||
mimetypes.add_type("text/css", ".css")
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# Add startup event to load dependencies
|
||||
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger)
|
||||
yield
|
||||
# Shut down threads
|
||||
ApiDependencies.shutdown()
|
||||
|
||||
|
||||
# Create the app
|
||||
# TODO: create this all in a method so configuration/etc. can be passed in?
|
||||
app = FastAPI(title="Invoke - Community Edition", docs_url=None, redoc_url=None, separate_input_output_schemas=False)
|
||||
app = FastAPI(
|
||||
title="Invoke - Community Edition",
|
||||
docs_url=None,
|
||||
redoc_url=None,
|
||||
separate_input_output_schemas=False,
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# Add event handler
|
||||
event_handler_id: int = id(app)
|
||||
@ -96,18 +113,6 @@ app.add_middleware(
|
||||
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
||||
|
||||
|
||||
# Add startup event to load dependencies
|
||||
@app.on_event("startup")
|
||||
async def startup_event() -> None:
|
||||
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger)
|
||||
|
||||
|
||||
# Shut down threads
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_event() -> None:
|
||||
ApiDependencies.shutdown()
|
||||
|
||||
|
||||
# Include all routers
|
||||
app.include_router(utilities.utilities_router, prefix="/api")
|
||||
app.include_router(model_manager.model_manager_router, prefix="/api")
|
||||
|
@ -1,17 +1,11 @@
|
||||
from typing import Iterator, List, Optional, Tuple, Union
|
||||
from typing import Iterator, List, Optional, Tuple, Union, cast
|
||||
|
||||
import torch
|
||||
from compel import Compel, ReturnedEmbeddingsType
|
||||
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
|
||||
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
InputField,
|
||||
OutputField,
|
||||
UIComponent,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIComponent
|
||||
from invokeai.app.invocations.primitives import ConditioningOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.ti_utils import generate_ti_list
|
||||
@ -25,13 +19,8 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
)
|
||||
from invokeai.backend.util.devices import torch_dtype
|
||||
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from .model import ClipField
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
|
||||
from .model import CLIPField
|
||||
|
||||
# unconditioned: Optional[torch.Tensor]
|
||||
|
||||
@ -57,7 +46,7 @@ class CompelInvocation(BaseInvocation):
|
||||
description=FieldDescriptions.compel_prompt,
|
||||
ui_component=UIComponent.Textarea,
|
||||
)
|
||||
clip: ClipField = InputField(
|
||||
clip: CLIPField = InputField(
|
||||
title="CLIP",
|
||||
description=FieldDescriptions.clip,
|
||||
input=Input.Connection,
|
||||
@ -65,16 +54,16 @@ class CompelInvocation(BaseInvocation):
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump())
|
||||
tokenizer_info = context.models.load(self.clip.tokenizer)
|
||||
tokenizer_model = tokenizer_info.model
|
||||
assert isinstance(tokenizer_model, CLIPTokenizer)
|
||||
text_encoder_info = context.models.load(**self.clip.text_encoder.model_dump())
|
||||
text_encoder_info = context.models.load(self.clip.text_encoder)
|
||||
text_encoder_model = text_encoder_info.model
|
||||
assert isinstance(text_encoder_model, CLIPTextModel)
|
||||
|
||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
for lora in self.clip.loras:
|
||||
lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
|
||||
lora_info = context.models.load(lora.lora)
|
||||
assert isinstance(lora_info.model, LoRAModelRaw)
|
||||
yield (lora_info.model, lora.weight)
|
||||
del lora_info
|
||||
@ -138,18 +127,18 @@ class SDXLPromptInvocationBase:
|
||||
def run_clip_compel(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
clip_field: ClipField,
|
||||
clip_field: CLIPField,
|
||||
prompt: str,
|
||||
get_pooled: bool,
|
||||
lora_prefix: str,
|
||||
zero_on_empty: bool,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]:
|
||||
tokenizer_info = context.models.load(**clip_field.tokenizer.model_dump())
|
||||
tokenizer_info = context.models.load(clip_field.tokenizer)
|
||||
tokenizer_model = tokenizer_info.model
|
||||
assert isinstance(tokenizer_model, CLIPTokenizer)
|
||||
text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump())
|
||||
text_encoder_info = context.models.load(clip_field.text_encoder)
|
||||
text_encoder_model = text_encoder_info.model
|
||||
assert isinstance(text_encoder_model, CLIPTextModel)
|
||||
assert isinstance(text_encoder_model, (CLIPTextModel, CLIPTextModelWithProjection))
|
||||
|
||||
# return zero on empty
|
||||
if prompt == "" and zero_on_empty:
|
||||
@ -174,7 +163,7 @@ class SDXLPromptInvocationBase:
|
||||
|
||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
for lora in clip_field.loras:
|
||||
lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
|
||||
lora_info = context.models.load(lora.lora)
|
||||
lora_model = lora_info.model
|
||||
assert isinstance(lora_model, LoRAModelRaw)
|
||||
yield (lora_model, lora.weight)
|
||||
@ -196,7 +185,8 @@ class SDXLPromptInvocationBase:
|
||||
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
||||
ModelPatcher.apply_clip_skip(text_encoder_model, clip_field.skipped_layers),
|
||||
):
|
||||
assert isinstance(text_encoder, CLIPTextModel)
|
||||
assert isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
|
||||
text_encoder = cast(CLIPTextModel, text_encoder)
|
||||
compel = Compel(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
@ -263,8 +253,8 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
crop_left: int = InputField(default=0, description="")
|
||||
target_width: int = InputField(default=1024, description="")
|
||||
target_height: int = InputField(default=1024, description="")
|
||||
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1")
|
||||
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2")
|
||||
clip: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1")
|
||||
clip2: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2")
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
@ -350,7 +340,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
|
||||
crop_top: int = InputField(default=0, description="")
|
||||
crop_left: int = InputField(default=0, description="")
|
||||
aesthetic_score: float = InputField(default=6.0, description=FieldDescriptions.sdxl_aesthetic)
|
||||
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
|
||||
clip2: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
@ -380,10 +370,10 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
|
||||
|
||||
|
||||
@invocation_output("clip_skip_output")
|
||||
class ClipSkipInvocationOutput(BaseInvocationOutput):
|
||||
"""Clip skip node output"""
|
||||
class CLIPSkipInvocationOutput(BaseInvocationOutput):
|
||||
"""CLIP skip node output"""
|
||||
|
||||
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||
clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||
|
||||
|
||||
@invocation(
|
||||
@ -393,15 +383,15 @@ class ClipSkipInvocationOutput(BaseInvocationOutput):
|
||||
category="conditioning",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ClipSkipInvocation(BaseInvocation):
|
||||
class CLIPSkipInvocation(BaseInvocation):
|
||||
"""Skip layers in clip text_encoder model."""
|
||||
|
||||
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP")
|
||||
clip: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP")
|
||||
skipped_layers: int = InputField(default=0, ge=0, description=FieldDescriptions.skipped_layers)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput:
|
||||
def invoke(self, context: InvocationContext) -> CLIPSkipInvocationOutput:
|
||||
self.clip.skipped_layers += self.skipped_layers
|
||||
return ClipSkipInvocationOutput(
|
||||
return CLIPSkipInvocationOutput(
|
||||
clip=self.clip,
|
||||
)
|
||||
|
||||
|
@ -34,6 +34,7 @@ from invokeai.app.invocations.fields import (
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.model import ModelField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
@ -51,15 +52,9 @@ CONTROLNET_RESIZE_VALUES = Literal[
|
||||
]
|
||||
|
||||
|
||||
class ControlNetModelField(BaseModel):
|
||||
"""ControlNet model field"""
|
||||
|
||||
key: str = Field(description="Model config record key for the ControlNet model")
|
||||
|
||||
|
||||
class ControlField(BaseModel):
|
||||
image: ImageField = Field(description="The control image")
|
||||
control_model: ControlNetModelField = Field(description="The ControlNet model to use")
|
||||
control_model: ModelField = Field(description="The ControlNet model to use")
|
||||
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
|
||||
begin_step_percent: float = Field(
|
||||
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
|
||||
@ -95,7 +90,7 @@ class ControlNetInvocation(BaseInvocation):
|
||||
"""Collects ControlNet info to pass to other nodes"""
|
||||
|
||||
image: ImageField = InputField(description="The control image")
|
||||
control_model: ControlNetModelField = InputField(description=FieldDescriptions.controlnet_model, input=Input.Direct)
|
||||
control_model: ModelField = InputField(description=FieldDescriptions.controlnet_model, input=Input.Direct)
|
||||
control_weight: Union[float, List[float]] = InputField(
|
||||
default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"
|
||||
)
|
||||
|
@ -228,7 +228,7 @@ class ConditioningField(BaseModel):
|
||||
# endregion
|
||||
|
||||
|
||||
class MetadataField(RootModel):
|
||||
class MetadataField(RootModel[dict[str, Any]]):
|
||||
"""
|
||||
Pydantic model for metadata with custom root of type dict[str, Any].
|
||||
Metadata is stored without a strict schema.
|
||||
|
@ -11,25 +11,17 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
|
||||
from invokeai.app.invocations.model import ModelField
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.config import BaseModelType, ModelType
|
||||
|
||||
|
||||
# LS: Consider moving these two classes into model.py
|
||||
class IPAdapterModelField(BaseModel):
|
||||
key: str = Field(description="Key to the IP-Adapter model")
|
||||
|
||||
|
||||
class CLIPVisionModelField(BaseModel):
|
||||
key: str = Field(description="Key to the CLIP Vision image encoder model")
|
||||
from invokeai.backend.model_manager.config import BaseModelType, IPAdapterConfig, ModelType
|
||||
|
||||
|
||||
class IPAdapterField(BaseModel):
|
||||
image: Union[ImageField, List[ImageField]] = Field(description="The IP-Adapter image prompt(s).")
|
||||
ip_adapter_model: IPAdapterModelField = Field(description="The IP-Adapter model to use.")
|
||||
image_encoder_model: CLIPVisionModelField = Field(description="The name of the CLIP image encoder model.")
|
||||
ip_adapter_model: ModelField = Field(description="The IP-Adapter model to use.")
|
||||
image_encoder_model: ModelField = Field(description="The name of the CLIP image encoder model.")
|
||||
weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
|
||||
begin_step_percent: float = Field(
|
||||
default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
|
||||
@ -62,7 +54,7 @@ class IPAdapterInvocation(BaseInvocation):
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, List[ImageField]] = InputField(description="The IP-Adapter image prompt(s).")
|
||||
ip_adapter_model: IPAdapterModelField = InputField(
|
||||
ip_adapter_model: ModelField = InputField(
|
||||
description="The IP-Adapter model.", title="IP-Adapter Model", input=Input.Direct, ui_order=-1
|
||||
)
|
||||
|
||||
@ -90,18 +82,18 @@ class IPAdapterInvocation(BaseInvocation):
|
||||
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
|
||||
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
|
||||
ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
|
||||
assert isinstance(ip_adapter_info, IPAdapterConfig)
|
||||
image_encoder_model_id = ip_adapter_info.image_encoder_model_id
|
||||
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
|
||||
image_encoder_models = context.models.search_by_attrs(
|
||||
model_name=image_encoder_model_name, base_model=BaseModelType.Any, model_type=ModelType.CLIPVision
|
||||
name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision
|
||||
)
|
||||
assert len(image_encoder_models) == 1
|
||||
image_encoder_model = CLIPVisionModelField(key=image_encoder_models[0].key)
|
||||
return IPAdapterOutput(
|
||||
ip_adapter=IPAdapterField(
|
||||
image=self.image,
|
||||
ip_adapter_model=self.ip_adapter_model,
|
||||
image_encoder_model=image_encoder_model,
|
||||
image_encoder_model=ModelField(key=image_encoder_models[0].key),
|
||||
weight=self.weight,
|
||||
begin_step_percent=self.begin_step_percent,
|
||||
end_step_percent=self.end_step_percent,
|
||||
|
@ -26,6 +26,7 @@ from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||
from PIL import Image, ImageFilter
|
||||
from pydantic import field_validator
|
||||
from torchvision.transforms.functional import resize as tv_resize
|
||||
from transformers import CLIPVisionModelWithProjection
|
||||
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
|
||||
from invokeai.app.invocations.fields import (
|
||||
@ -75,7 +76,7 @@ from .baseinvocation import (
|
||||
invocation_output,
|
||||
)
|
||||
from .controlnet_image_processors import ControlField
|
||||
from .model import ModelInfo, UNetField, VaeField
|
||||
from .model import ModelField, UNetField, VAEField
|
||||
|
||||
if choose_torch_device() == torch.device("mps"):
|
||||
from torch import mps
|
||||
@ -118,7 +119,7 @@ class SchedulerInvocation(BaseInvocation):
|
||||
class CreateDenoiseMaskInvocation(BaseInvocation):
|
||||
"""Creates mask for denoising model run."""
|
||||
|
||||
vae: VaeField = InputField(description=FieldDescriptions.vae, input=Input.Connection, ui_order=0)
|
||||
vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection, ui_order=0)
|
||||
image: Optional[ImageField] = InputField(default=None, description="Image which will be masked", ui_order=1)
|
||||
mask: ImageField = InputField(description="The mask to use when pasting", ui_order=2)
|
||||
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=3)
|
||||
@ -153,7 +154,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
if image_tensor is not None:
|
||||
vae_info = context.models.load(**self.vae.vae.model_dump())
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
|
||||
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
|
||||
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
|
||||
@ -173,6 +174,16 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation_output("gradient_mask_output")
|
||||
class GradientMaskOutput(BaseInvocationOutput):
|
||||
"""Outputs a denoise mask and an image representing the total gradient of the mask."""
|
||||
|
||||
denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run")
|
||||
expanded_mask_area: ImageField = OutputField(
|
||||
description="Image representing the total gradient area of the mask. For paste-back purposes."
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"create_gradient_mask",
|
||||
title="Create Gradient Mask",
|
||||
@ -193,49 +204,53 @@ class CreateGradientMaskInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
|
||||
def invoke(self, context: InvocationContext) -> GradientMaskOutput:
|
||||
mask_image = context.images.get_pil(self.mask.image_name, mode="L")
|
||||
if self.edge_radius > 0:
|
||||
if self.coherence_mode == "Box Blur":
|
||||
blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius))
|
||||
else: # Gaussian Blur OR Staged
|
||||
# Gaussian Blur uses standard deviation. 1/2 radius is a good approximation
|
||||
blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2))
|
||||
|
||||
mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
|
||||
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False)
|
||||
|
||||
# redistribute blur so that the edges are 0 and blur out to 1
|
||||
# redistribute blur so that the original edges are 0 and blur outwards to 1
|
||||
blur_tensor = (blur_tensor - 0.5) * 2
|
||||
|
||||
threshold = 1 - self.minimum_denoise
|
||||
|
||||
if self.coherence_mode == "Staged":
|
||||
# wherever the blur_tensor is masked to any degree, convert it to threshold
|
||||
blur_tensor = torch.where((blur_tensor < 1), threshold, blur_tensor)
|
||||
# wherever the blur_tensor is less than fully masked, convert it to threshold
|
||||
blur_tensor = torch.where((blur_tensor < 1) & (blur_tensor > 0), threshold, blur_tensor)
|
||||
else:
|
||||
# wherever the blur_tensor is above threshold but less than 1, drop it to threshold
|
||||
blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor)
|
||||
|
||||
# multiply original mask to force actually masked regions to 0
|
||||
blur_tensor = mask_tensor * blur_tensor
|
||||
else:
|
||||
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
|
||||
|
||||
mask_name = context.tensors.save(tensor=blur_tensor.unsqueeze(1))
|
||||
|
||||
return DenoiseMaskOutput.build(
|
||||
mask_name=mask_name,
|
||||
masked_latents_name=None,
|
||||
gradient=True,
|
||||
# compute a [0, 1] mask from the blur_tensor
|
||||
expanded_mask = torch.where((blur_tensor < 1), 0, 1)
|
||||
expanded_mask_image = Image.fromarray((expanded_mask.squeeze(0).numpy() * 255).astype(np.uint8), mode="L")
|
||||
expanded_image_dto = context.images.save(expanded_mask_image)
|
||||
|
||||
return GradientMaskOutput(
|
||||
denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=None, gradient=True),
|
||||
expanded_mask_area=ImageField(image_name=expanded_image_dto.image_name),
|
||||
)
|
||||
|
||||
|
||||
def get_scheduler(
|
||||
context: InvocationContext,
|
||||
scheduler_info: ModelInfo,
|
||||
scheduler_info: ModelField,
|
||||
scheduler_name: str,
|
||||
seed: int,
|
||||
) -> Scheduler:
|
||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
|
||||
orig_scheduler_info = context.models.load(**scheduler_info.model_dump())
|
||||
orig_scheduler_info = context.models.load(scheduler_info)
|
||||
with orig_scheduler_info as orig_scheduler:
|
||||
scheduler_config = orig_scheduler.config
|
||||
|
||||
@ -360,7 +375,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
) -> ConditioningData:
|
||||
positive_cond_data = context.conditioning.load(self.positive_conditioning.conditioning_name)
|
||||
c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
|
||||
extra_conditioning_info = c.extra_conditioning
|
||||
|
||||
negative_cond_data = context.conditioning.load(self.negative_conditioning.conditioning_name)
|
||||
uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
|
||||
@ -370,7 +384,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
text_embeddings=c,
|
||||
guidance_scale=self.cfg_scale,
|
||||
guidance_rescale_multiplier=self.cfg_rescale_multiplier,
|
||||
extra=extra_conditioning_info,
|
||||
postprocessing_settings=PostprocessingSettings(
|
||||
threshold=0.0, # threshold,
|
||||
warmup=0.2, # warmup,
|
||||
@ -449,7 +462,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
# and if weight is None, populate with default 1.0?
|
||||
controlnet_data = []
|
||||
for control_info in control_list:
|
||||
control_model = exit_stack.enter_context(context.models.load(key=control_info.control_model.key))
|
||||
control_model = exit_stack.enter_context(context.models.load(control_info.control_model))
|
||||
|
||||
# control_models.append(control_model)
|
||||
control_image_field = control_info.image
|
||||
@ -511,11 +524,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
conditioning_data.ip_adapter_conditioning = []
|
||||
for single_ip_adapter in ip_adapter:
|
||||
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
|
||||
context.models.load(key=single_ip_adapter.ip_adapter_model.key)
|
||||
context.models.load(single_ip_adapter.ip_adapter_model)
|
||||
)
|
||||
|
||||
image_encoder_model_info = context.models.load(key=single_ip_adapter.image_encoder_model.key)
|
||||
|
||||
image_encoder_model_info = context.models.load(single_ip_adapter.image_encoder_model)
|
||||
# `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
|
||||
single_ipa_image_fields = single_ip_adapter.image
|
||||
if not isinstance(single_ipa_image_fields, list):
|
||||
@ -526,6 +538,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
# TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
|
||||
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
|
||||
with image_encoder_model_info as image_encoder_model:
|
||||
assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
|
||||
# Get image embeddings from CLIP and ImageProjModel.
|
||||
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
|
||||
single_ipa_images, image_encoder_model
|
||||
@ -565,8 +578,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
t2i_adapter_data = []
|
||||
for t2i_adapter_field in t2i_adapter:
|
||||
t2i_adapter_model_config = context.models.get_config(key=t2i_adapter_field.t2i_adapter_model.key)
|
||||
t2i_adapter_loaded_model = context.models.load(key=t2i_adapter_field.t2i_adapter_model.key)
|
||||
t2i_adapter_model_config = context.models.get_config(t2i_adapter_field.t2i_adapter_model.key)
|
||||
t2i_adapter_loaded_model = context.models.load(t2i_adapter_field.t2i_adapter_model)
|
||||
image = context.images.get_pil(t2i_adapter_field.image.image_name)
|
||||
|
||||
# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
|
||||
@ -719,12 +732,13 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
for lora in self.unet.loras:
|
||||
lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
|
||||
lora_info = context.models.load(lora.lora)
|
||||
assert isinstance(lora_info.model, LoRAModelRaw)
|
||||
yield (lora_info.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
|
||||
unet_info = context.models.load(**self.unet.unet.model_dump())
|
||||
unet_info = context.models.load(self.unet.unet)
|
||||
assert isinstance(unet_info.model, UNet2DConditionModel)
|
||||
with (
|
||||
ExitStack() as exit_stack,
|
||||
@ -777,10 +791,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
denoising_end=self.denoising_end,
|
||||
)
|
||||
|
||||
(
|
||||
result_latents,
|
||||
result_attention_map_saver,
|
||||
) = pipeline.latents_from_embeddings(
|
||||
result_latents = pipeline.latents_from_embeddings(
|
||||
latents=latents,
|
||||
timesteps=timesteps,
|
||||
init_timestep=init_timestep,
|
||||
@ -821,7 +832,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
description=FieldDescriptions.latents,
|
||||
input=Input.Connection,
|
||||
)
|
||||
vae: VaeField = InputField(
|
||||
vae: VAEField = InputField(
|
||||
description=FieldDescriptions.vae,
|
||||
input=Input.Connection,
|
||||
)
|
||||
@ -832,8 +843,8 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
|
||||
vae_info = context.models.load(**self.vae.vae.model_dump())
|
||||
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
assert isinstance(vae_info.model, (UNet2DConditionModel, AutoencoderKL))
|
||||
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
|
||||
assert isinstance(vae, torch.nn.Module)
|
||||
latents = latents.to(vae.device)
|
||||
@ -999,7 +1010,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
image: ImageField = InputField(
|
||||
description="The image to encode",
|
||||
)
|
||||
vae: VaeField = InputField(
|
||||
vae: VAEField = InputField(
|
||||
description=FieldDescriptions.vae,
|
||||
input=Input.Connection,
|
||||
)
|
||||
@ -1055,7 +1066,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
vae_info = context.models.load(**self.vae.vae.model_dump())
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
|
||||
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||
if image_tensor.dim() == 3:
|
||||
|
@ -8,7 +8,10 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.controlnet_image_processors import ControlField
|
||||
from invokeai.app.invocations.controlnet_image_processors import (
|
||||
CONTROLNET_MODE_VALUES,
|
||||
CONTROLNET_RESIZE_VALUES,
|
||||
)
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
ImageField,
|
||||
@ -17,10 +20,8 @@ from invokeai.app.invocations.fields import (
|
||||
OutputField,
|
||||
UIType,
|
||||
)
|
||||
from invokeai.app.invocations.ip_adapter import IPAdapterModelField
|
||||
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
|
||||
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.config import BaseModelType, ModelType
|
||||
|
||||
from ...version import __version__
|
||||
|
||||
@ -30,10 +31,20 @@ class MetadataItemField(BaseModel):
|
||||
value: Any = Field(description=FieldDescriptions.metadata_item_value)
|
||||
|
||||
|
||||
class ModelMetadataField(BaseModel):
|
||||
"""Model Metadata Field"""
|
||||
|
||||
key: str
|
||||
hash: str
|
||||
name: str
|
||||
base: BaseModelType
|
||||
type: ModelType
|
||||
|
||||
|
||||
class LoRAMetadataField(BaseModel):
|
||||
"""LoRA Metadata Field"""
|
||||
|
||||
model: LoRAModelField = Field(description=FieldDescriptions.lora_model)
|
||||
model: ModelMetadataField = Field(description=FieldDescriptions.lora_model)
|
||||
weight: float = Field(description=FieldDescriptions.lora_weight)
|
||||
|
||||
|
||||
@ -41,7 +52,7 @@ class IPAdapterMetadataField(BaseModel):
|
||||
"""IP Adapter Field, minus the CLIP Vision Encoder model"""
|
||||
|
||||
image: ImageField = Field(description="The IP-Adapter image prompt.")
|
||||
ip_adapter_model: IPAdapterModelField = Field(
|
||||
ip_adapter_model: ModelMetadataField = Field(
|
||||
description="The IP-Adapter model.",
|
||||
)
|
||||
weight: Union[float, list[float]] = Field(
|
||||
@ -51,6 +62,33 @@ class IPAdapterMetadataField(BaseModel):
|
||||
end_step_percent: float = Field(description="When the IP-Adapter is last applied (% of total steps)")
|
||||
|
||||
|
||||
class T2IAdapterMetadataField(BaseModel):
|
||||
image: ImageField = Field(description="The T2I-Adapter image prompt.")
|
||||
t2i_adapter_model: ModelMetadataField = Field(description="The T2I-Adapter model to use.")
|
||||
weight: Union[float, list[float]] = Field(default=1, description="The weight given to the T2I-Adapter")
|
||||
begin_step_percent: float = Field(
|
||||
default=0, ge=0, le=1, description="When the T2I-Adapter is first applied (% of total steps)"
|
||||
)
|
||||
end_step_percent: float = Field(
|
||||
default=1, ge=0, le=1, description="When the T2I-Adapter is last applied (% of total steps)"
|
||||
)
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
|
||||
|
||||
|
||||
class ControlNetMetadataField(BaseModel):
|
||||
image: ImageField = Field(description="The control image")
|
||||
control_model: ModelMetadataField = Field(description="The ControlNet model to use")
|
||||
control_weight: Union[float, list[float]] = Field(default=1, description="The weight given to the ControlNet")
|
||||
begin_step_percent: float = Field(
|
||||
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
|
||||
)
|
||||
end_step_percent: float = Field(
|
||||
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
|
||||
)
|
||||
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
|
||||
|
||||
|
||||
@invocation_output("metadata_item_output")
|
||||
class MetadataItemOutput(BaseInvocationOutput):
|
||||
"""Metadata Item Output"""
|
||||
@ -140,14 +178,14 @@ class CoreMetadataInvocation(BaseInvocation):
|
||||
default=None,
|
||||
description="The number of skipped CLIP layers",
|
||||
)
|
||||
model: Optional[MainModelField] = InputField(default=None, description="The main model used for inference")
|
||||
controlnets: Optional[list[ControlField]] = InputField(
|
||||
model: Optional[ModelMetadataField] = InputField(default=None, description="The main model used for inference")
|
||||
controlnets: Optional[list[ControlNetMetadataField]] = InputField(
|
||||
default=None, description="The ControlNets used for inference"
|
||||
)
|
||||
ipAdapters: Optional[list[IPAdapterMetadataField]] = InputField(
|
||||
default=None, description="The IP Adapters used for inference"
|
||||
)
|
||||
t2iAdapters: Optional[list[T2IAdapterField]] = InputField(
|
||||
t2iAdapters: Optional[list[T2IAdapterMetadataField]] = InputField(
|
||||
default=None, description="The IP Adapters used for inference"
|
||||
)
|
||||
loras: Optional[list[LoRAMetadataField]] = InputField(default=None, description="The LoRAs used for inference")
|
||||
@ -159,7 +197,7 @@ class CoreMetadataInvocation(BaseInvocation):
|
||||
default=None,
|
||||
description="The name of the initial image",
|
||||
)
|
||||
vae: Optional[VAEModelField] = InputField(
|
||||
vae: Optional[ModelMetadataField] = InputField(
|
||||
default=None,
|
||||
description="The VAE used for decoding, if the main model's default was not used",
|
||||
)
|
||||
@ -190,7 +228,7 @@ class CoreMetadataInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
# SDXL Refiner
|
||||
refiner_model: Optional[MainModelField] = InputField(
|
||||
refiner_model: Optional[ModelMetadataField] = InputField(
|
||||
default=None,
|
||||
description="The SDXL Refiner model used",
|
||||
)
|
||||
@ -222,10 +260,9 @@ class CoreMetadataInvocation(BaseInvocation):
|
||||
def invoke(self, context: InvocationContext) -> MetadataOutput:
|
||||
"""Collects and outputs a CoreMetadata object"""
|
||||
|
||||
return MetadataOutput(
|
||||
metadata=MetadataField.model_validate(
|
||||
self.model_dump(exclude_none=True, exclude={"id", "type", "is_intermediate", "use_cache"})
|
||||
)
|
||||
)
|
||||
as_dict = self.model_dump(exclude_none=True, exclude={"id", "type", "is_intermediate", "use_cache"})
|
||||
as_dict["app_version"] = __version__
|
||||
|
||||
return MetadataOutput(metadata=MetadataField.model_validate(as_dict))
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
@ -6,8 +6,8 @@ from pydantic import BaseModel, Field
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.shared.models import FreeUConfig
|
||||
from invokeai.backend.model_manager.config import SubModelType
|
||||
|
||||
from ...backend.model_manager import SubModelType
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
@ -16,33 +16,33 @@ from .baseinvocation import (
|
||||
)
|
||||
|
||||
|
||||
class ModelInfo(BaseModel):
|
||||
key: str = Field(description="Key of model as returned by ModelRecordServiceBase.get_model()")
|
||||
submodel_type: Optional[SubModelType] = Field(default=None, description="Info to load submodel")
|
||||
class ModelField(BaseModel):
|
||||
key: str = Field(description="Key of the model")
|
||||
submodel_type: Optional[SubModelType] = Field(description="Submodel type", default=None)
|
||||
|
||||
|
||||
class LoraInfo(ModelInfo):
|
||||
weight: float = Field(description="Lora's weight which to use when apply to model")
|
||||
class LoRAField(BaseModel):
|
||||
lora: ModelField = Field(description="Info to load lora model")
|
||||
weight: float = Field(description="Weight to apply to lora model")
|
||||
|
||||
|
||||
class UNetField(BaseModel):
|
||||
unet: ModelInfo = Field(description="Info to load unet submodel")
|
||||
scheduler: ModelInfo = Field(description="Info to load scheduler submodel")
|
||||
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
|
||||
unet: ModelField = Field(description="Info to load unet submodel")
|
||||
scheduler: ModelField = Field(description="Info to load scheduler submodel")
|
||||
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
|
||||
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
|
||||
freeu_config: Optional[FreeUConfig] = Field(default=None, description="FreeU configuration")
|
||||
|
||||
|
||||
class ClipField(BaseModel):
|
||||
tokenizer: ModelInfo = Field(description="Info to load tokenizer submodel")
|
||||
text_encoder: ModelInfo = Field(description="Info to load text_encoder submodel")
|
||||
class CLIPField(BaseModel):
|
||||
tokenizer: ModelField = Field(description="Info to load tokenizer submodel")
|
||||
text_encoder: ModelField = Field(description="Info to load text_encoder submodel")
|
||||
skipped_layers: int = Field(description="Number of skipped layers in text_encoder")
|
||||
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
|
||||
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
|
||||
|
||||
|
||||
class VaeField(BaseModel):
|
||||
# TODO: better naming?
|
||||
vae: ModelInfo = Field(description="Info to load vae submodel")
|
||||
class VAEField(BaseModel):
|
||||
vae: ModelField = Field(description="Info to load vae submodel")
|
||||
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
|
||||
|
||||
|
||||
@ -57,14 +57,14 @@ class UNetOutput(BaseInvocationOutput):
|
||||
class VAEOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output a VAE field"""
|
||||
|
||||
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
|
||||
|
||||
@invocation_output("clip_output")
|
||||
class CLIPOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output a CLIP field"""
|
||||
|
||||
clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP")
|
||||
clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP")
|
||||
|
||||
|
||||
@invocation_output("model_loader_output")
|
||||
@ -74,18 +74,6 @@ class ModelLoaderOutput(UNetOutput, CLIPOutput, VAEOutput):
|
||||
pass
|
||||
|
||||
|
||||
class MainModelField(BaseModel):
|
||||
"""Main model field"""
|
||||
|
||||
key: str = Field(description="Model key")
|
||||
|
||||
|
||||
class LoRAModelField(BaseModel):
|
||||
"""LoRA model field"""
|
||||
|
||||
key: str = Field(description="LoRA model key")
|
||||
|
||||
|
||||
@invocation(
|
||||
"main_model_loader",
|
||||
title="Main Model",
|
||||
@ -96,62 +84,40 @@ class LoRAModelField(BaseModel):
|
||||
class MainModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads a main model, outputting its submodels."""
|
||||
|
||||
model: MainModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct)
|
||||
model: ModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct)
|
||||
# TODO: precision?
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
|
||||
key = self.model.key
|
||||
|
||||
# TODO: not found exceptions
|
||||
if not context.models.exists(key):
|
||||
raise Exception(f"Unknown model {key}")
|
||||
if not context.models.exists(self.model.key):
|
||||
raise Exception(f"Unknown model {self.model.key}")
|
||||
|
||||
unet = self.model.model_copy(update={"submodel_type": SubModelType.UNet})
|
||||
scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
|
||||
tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
|
||||
text_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
|
||||
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
|
||||
|
||||
return ModelLoaderOutput(
|
||||
unet=UNetField(
|
||||
unet=ModelInfo(
|
||||
key=key,
|
||||
submodel_type=SubModelType.UNet,
|
||||
),
|
||||
scheduler=ModelInfo(
|
||||
key=key,
|
||||
submodel_type=SubModelType.Scheduler,
|
||||
),
|
||||
loras=[],
|
||||
),
|
||||
clip=ClipField(
|
||||
tokenizer=ModelInfo(
|
||||
key=key,
|
||||
submodel_type=SubModelType.Tokenizer,
|
||||
),
|
||||
text_encoder=ModelInfo(
|
||||
key=key,
|
||||
submodel_type=SubModelType.TextEncoder,
|
||||
),
|
||||
loras=[],
|
||||
skipped_layers=0,
|
||||
),
|
||||
vae=VaeField(
|
||||
vae=ModelInfo(
|
||||
key=key,
|
||||
submodel_type=SubModelType.Vae,
|
||||
),
|
||||
),
|
||||
unet=UNetField(unet=unet, scheduler=scheduler, loras=[]),
|
||||
clip=CLIPField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[], skipped_layers=0),
|
||||
vae=VAEField(vae=vae),
|
||||
)
|
||||
|
||||
|
||||
@invocation_output("lora_loader_output")
|
||||
class LoraLoaderOutput(BaseInvocationOutput):
|
||||
class LoRALoaderOutput(BaseInvocationOutput):
|
||||
"""Model loader output"""
|
||||
|
||||
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
|
||||
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||
clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||
|
||||
|
||||
@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.1")
|
||||
class LoraLoaderInvocation(BaseInvocation):
|
||||
class LoRALoaderInvocation(BaseInvocation):
|
||||
"""Apply selected lora to unet and text_encoder."""
|
||||
|
||||
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
|
||||
lora: ModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
|
||||
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
||||
unet: Optional[UNetField] = InputField(
|
||||
default=None,
|
||||
@ -159,46 +125,41 @@ class LoraLoaderInvocation(BaseInvocation):
|
||||
input=Input.Connection,
|
||||
title="UNet",
|
||||
)
|
||||
clip: Optional[ClipField] = InputField(
|
||||
clip: Optional[CLIPField] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.clip,
|
||||
input=Input.Connection,
|
||||
title="CLIP",
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
|
||||
if self.lora is None:
|
||||
raise Exception("No LoRA provided")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LoRALoaderOutput:
|
||||
lora_key = self.lora.key
|
||||
|
||||
if not context.models.exists(lora_key):
|
||||
raise Exception(f"Unkown lora: {lora_key}!")
|
||||
|
||||
if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
|
||||
raise Exception(f'Lora "{lora_key}" already applied to unet')
|
||||
if self.unet is not None and any(lora.lora.key == lora_key for lora in self.unet.loras):
|
||||
raise Exception(f'LoRA "{lora_key}" already applied to unet')
|
||||
|
||||
if self.clip is not None and any(lora.key == lora_key for lora in self.clip.loras):
|
||||
raise Exception(f'Lora "{lora_key}" already applied to clip')
|
||||
if self.clip is not None and any(lora.lora.key == lora_key for lora in self.clip.loras):
|
||||
raise Exception(f'LoRA "{lora_key}" already applied to clip')
|
||||
|
||||
output = LoraLoaderOutput()
|
||||
output = LoRALoaderOutput()
|
||||
|
||||
if self.unet is not None:
|
||||
output.unet = copy.deepcopy(self.unet)
|
||||
output.unet = self.unet.model_copy(deep=True)
|
||||
output.unet.loras.append(
|
||||
LoraInfo(
|
||||
key=lora_key,
|
||||
submodel_type=None,
|
||||
LoRAField(
|
||||
lora=self.lora,
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
|
||||
if self.clip is not None:
|
||||
output.clip = copy.deepcopy(self.clip)
|
||||
output.clip = self.clip.model_copy(deep=True)
|
||||
output.clip.loras.append(
|
||||
LoraInfo(
|
||||
key=lora_key,
|
||||
submodel_type=None,
|
||||
LoRAField(
|
||||
lora=self.lora,
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
@ -207,12 +168,12 @@ class LoraLoaderInvocation(BaseInvocation):
|
||||
|
||||
|
||||
@invocation_output("sdxl_lora_loader_output")
|
||||
class SDXLLoraLoaderOutput(BaseInvocationOutput):
|
||||
class SDXLLoRALoaderOutput(BaseInvocationOutput):
|
||||
"""SDXL LoRA Loader Output"""
|
||||
|
||||
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
|
||||
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 1")
|
||||
clip2: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 2")
|
||||
clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 1")
|
||||
clip2: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 2")
|
||||
|
||||
|
||||
@invocation(
|
||||
@ -222,10 +183,10 @@ class SDXLLoraLoaderOutput(BaseInvocationOutput):
|
||||
category="model",
|
||||
version="1.0.1",
|
||||
)
|
||||
class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||
class SDXLLoRALoaderInvocation(BaseInvocation):
|
||||
"""Apply selected lora to unet and text_encoder."""
|
||||
|
||||
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
|
||||
lora: ModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
|
||||
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
||||
unet: Optional[UNetField] = InputField(
|
||||
default=None,
|
||||
@ -233,65 +194,59 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||
input=Input.Connection,
|
||||
title="UNet",
|
||||
)
|
||||
clip: Optional[ClipField] = InputField(
|
||||
clip: Optional[CLIPField] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.clip,
|
||||
input=Input.Connection,
|
||||
title="CLIP 1",
|
||||
)
|
||||
clip2: Optional[ClipField] = InputField(
|
||||
clip2: Optional[CLIPField] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.clip,
|
||||
input=Input.Connection,
|
||||
title="CLIP 2",
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput:
|
||||
if self.lora is None:
|
||||
raise Exception("No LoRA provided")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> SDXLLoRALoaderOutput:
|
||||
lora_key = self.lora.key
|
||||
|
||||
if not context.models.exists(lora_key):
|
||||
raise Exception(f"Unknown lora: {lora_key}!")
|
||||
|
||||
if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
|
||||
raise Exception(f'Lora "{lora_key}" already applied to unet')
|
||||
if self.unet is not None and any(lora.lora.key == lora_key for lora in self.unet.loras):
|
||||
raise Exception(f'LoRA "{lora_key}" already applied to unet')
|
||||
|
||||
if self.clip is not None and any(lora.key == lora_key for lora in self.clip.loras):
|
||||
raise Exception(f'Lora "{lora_key}" already applied to clip')
|
||||
if self.clip is not None and any(lora.lora.key == lora_key for lora in self.clip.loras):
|
||||
raise Exception(f'LoRA "{lora_key}" already applied to clip')
|
||||
|
||||
if self.clip2 is not None and any(lora.key == lora_key for lora in self.clip2.loras):
|
||||
raise Exception(f'Lora "{lora_key}" already applied to clip2')
|
||||
if self.clip2 is not None and any(lora.lora.key == lora_key for lora in self.clip2.loras):
|
||||
raise Exception(f'LoRA "{lora_key}" already applied to clip2')
|
||||
|
||||
output = SDXLLoraLoaderOutput()
|
||||
output = SDXLLoRALoaderOutput()
|
||||
|
||||
if self.unet is not None:
|
||||
output.unet = copy.deepcopy(self.unet)
|
||||
output.unet = self.unet.model_copy(deep=True)
|
||||
output.unet.loras.append(
|
||||
LoraInfo(
|
||||
key=lora_key,
|
||||
submodel_type=None,
|
||||
LoRAField(
|
||||
lora=self.lora,
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
|
||||
if self.clip is not None:
|
||||
output.clip = copy.deepcopy(self.clip)
|
||||
output.clip = self.clip.model_copy(deep=True)
|
||||
output.clip.loras.append(
|
||||
LoraInfo(
|
||||
key=lora_key,
|
||||
submodel_type=None,
|
||||
LoRAField(
|
||||
lora=self.lora,
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
|
||||
if self.clip2 is not None:
|
||||
output.clip2 = copy.deepcopy(self.clip2)
|
||||
output.clip2 = self.clip2.model_copy(deep=True)
|
||||
output.clip2.loras.append(
|
||||
LoraInfo(
|
||||
key=lora_key,
|
||||
submodel_type=None,
|
||||
LoRAField(
|
||||
lora=self.lora,
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
@ -299,17 +254,11 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||
return output
|
||||
|
||||
|
||||
class VAEModelField(BaseModel):
|
||||
"""Vae model field"""
|
||||
|
||||
key: str = Field(description="Model's key")
|
||||
|
||||
|
||||
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.1")
|
||||
class VaeLoaderInvocation(BaseInvocation):
|
||||
class VAELoaderInvocation(BaseInvocation):
|
||||
"""Loads a VAE model, outputting a VaeLoaderOutput"""
|
||||
|
||||
vae_model: VAEModelField = InputField(
|
||||
vae_model: ModelField = InputField(
|
||||
description=FieldDescriptions.vae_model,
|
||||
input=Input.Direct,
|
||||
title="VAE",
|
||||
@ -321,7 +270,7 @@ class VaeLoaderInvocation(BaseInvocation):
|
||||
if not context.models.exists(key):
|
||||
raise Exception(f"Unkown vae: {key}!")
|
||||
|
||||
return VAEOutput(vae=VaeField(vae=ModelInfo(key=key)))
|
||||
return VAEOutput(vae=VAEField(vae=self.vae_model))
|
||||
|
||||
|
||||
@invocation_output("seamless_output")
|
||||
@ -329,7 +278,7 @@ class SeamlessModeOutput(BaseInvocationOutput):
|
||||
"""Modified Seamless Model output"""
|
||||
|
||||
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
|
||||
vae: Optional[VaeField] = OutputField(default=None, description=FieldDescriptions.vae, title="VAE")
|
||||
vae: Optional[VAEField] = OutputField(default=None, description=FieldDescriptions.vae, title="VAE")
|
||||
|
||||
|
||||
@invocation(
|
||||
@ -348,7 +297,7 @@ class SeamlessModeInvocation(BaseInvocation):
|
||||
input=Input.Connection,
|
||||
title="UNet",
|
||||
)
|
||||
vae: Optional[VaeField] = InputField(
|
||||
vae: Optional[VAEField] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.vae_model,
|
||||
input=Input.Connection,
|
||||
|
@ -8,7 +8,7 @@ from .baseinvocation import (
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from .model import ClipField, MainModelField, ModelInfo, UNetField, VaeField
|
||||
from .model import CLIPField, ModelField, UNetField, VAEField
|
||||
|
||||
|
||||
@invocation_output("sdxl_model_loader_output")
|
||||
@ -16,9 +16,9 @@ class SDXLModelLoaderOutput(BaseInvocationOutput):
|
||||
"""SDXL base model loader output"""
|
||||
|
||||
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
|
||||
clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 1")
|
||||
clip2: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
|
||||
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 1")
|
||||
clip2: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
|
||||
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
|
||||
|
||||
@invocation_output("sdxl_refiner_model_loader_output")
|
||||
@ -26,15 +26,15 @@ class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
|
||||
"""SDXL refiner model loader output"""
|
||||
|
||||
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
|
||||
clip2: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
|
||||
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
clip2: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
|
||||
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
|
||||
|
||||
@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.1")
|
||||
class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads an sdxl base model, outputting its submodels."""
|
||||
|
||||
model: MainModelField = InputField(
|
||||
model: ModelField = InputField(
|
||||
description=FieldDescriptions.sdxl_main_model, input=Input.Direct, ui_type=UIType.SDXLMainModel
|
||||
)
|
||||
# TODO: precision?
|
||||
@ -46,48 +46,19 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
if not context.models.exists(model_key):
|
||||
raise Exception(f"Unknown model: {model_key}")
|
||||
|
||||
unet = self.model.model_copy(update={"submodel_type": SubModelType.UNet})
|
||||
scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
|
||||
tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
|
||||
text_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
|
||||
tokenizer2 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
|
||||
text_encoder2 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
|
||||
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
|
||||
|
||||
return SDXLModelLoaderOutput(
|
||||
unet=UNetField(
|
||||
unet=ModelInfo(
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.UNet,
|
||||
),
|
||||
scheduler=ModelInfo(
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.Scheduler,
|
||||
),
|
||||
loras=[],
|
||||
),
|
||||
clip=ClipField(
|
||||
tokenizer=ModelInfo(
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.Tokenizer,
|
||||
),
|
||||
text_encoder=ModelInfo(
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.TextEncoder,
|
||||
),
|
||||
loras=[],
|
||||
skipped_layers=0,
|
||||
),
|
||||
clip2=ClipField(
|
||||
tokenizer=ModelInfo(
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.Tokenizer2,
|
||||
),
|
||||
text_encoder=ModelInfo(
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.TextEncoder2,
|
||||
),
|
||||
loras=[],
|
||||
skipped_layers=0,
|
||||
),
|
||||
vae=VaeField(
|
||||
vae=ModelInfo(
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.Vae,
|
||||
),
|
||||
),
|
||||
unet=UNetField(unet=unet, scheduler=scheduler, loras=[]),
|
||||
clip=CLIPField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[], skipped_layers=0),
|
||||
clip2=CLIPField(tokenizer=tokenizer2, text_encoder=text_encoder2, loras=[], skipped_layers=0),
|
||||
vae=VAEField(vae=vae),
|
||||
)
|
||||
|
||||
|
||||
@ -101,10 +72,8 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads an sdxl refiner model, outputting its submodels."""
|
||||
|
||||
model: MainModelField = InputField(
|
||||
description=FieldDescriptions.sdxl_refiner_model,
|
||||
input=Input.Direct,
|
||||
ui_type=UIType.SDXLRefinerModel,
|
||||
model: ModelField = InputField(
|
||||
description=FieldDescriptions.sdxl_refiner_model, input=Input.Direct, ui_type=UIType.SDXLRefinerModel
|
||||
)
|
||||
# TODO: precision?
|
||||
|
||||
@ -115,34 +84,14 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
||||
if not context.models.exists(model_key):
|
||||
raise Exception(f"Unknown model: {model_key}")
|
||||
|
||||
unet = self.model.model_copy(update={"submodel_type": SubModelType.UNet})
|
||||
scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
|
||||
tokenizer2 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
|
||||
text_encoder2 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
|
||||
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
|
||||
|
||||
return SDXLRefinerModelLoaderOutput(
|
||||
unet=UNetField(
|
||||
unet=ModelInfo(
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.UNet,
|
||||
),
|
||||
scheduler=ModelInfo(
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.Scheduler,
|
||||
),
|
||||
loras=[],
|
||||
),
|
||||
clip2=ClipField(
|
||||
tokenizer=ModelInfo(
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.Tokenizer2,
|
||||
),
|
||||
text_encoder=ModelInfo(
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.TextEncoder2,
|
||||
),
|
||||
loras=[],
|
||||
skipped_layers=0,
|
||||
),
|
||||
vae=VaeField(
|
||||
vae=ModelInfo(
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.Vae,
|
||||
),
|
||||
),
|
||||
unet=UNetField(unet=unet, scheduler=scheduler, loras=[]),
|
||||
clip2=CLIPField(tokenizer=tokenizer2, text_encoder=text_encoder2, loras=[], skipped_layers=0),
|
||||
vae=VAEField(vae=vae),
|
||||
)
|
||||
|
@ -10,17 +10,14 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
)
|
||||
from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField
|
||||
from invokeai.app.invocations.model import ModelField
|
||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
|
||||
class T2IAdapterModelField(BaseModel):
|
||||
key: str = Field(description="Model record key for the T2I-Adapter model")
|
||||
|
||||
|
||||
class T2IAdapterField(BaseModel):
|
||||
image: ImageField = Field(description="The T2I-Adapter image prompt.")
|
||||
t2i_adapter_model: T2IAdapterModelField = Field(description="The T2I-Adapter model to use.")
|
||||
t2i_adapter_model: ModelField = Field(description="The T2I-Adapter model to use.")
|
||||
weight: Union[float, list[float]] = Field(default=1, description="The weight given to the T2I-Adapter")
|
||||
begin_step_percent: float = Field(
|
||||
default=0, ge=0, le=1, description="When the T2I-Adapter is first applied (% of total steps)"
|
||||
@ -55,7 +52,7 @@ class T2IAdapterInvocation(BaseInvocation):
|
||||
|
||||
# Inputs
|
||||
image: ImageField = InputField(description="The IP-Adapter image prompt.")
|
||||
t2i_adapter_model: T2IAdapterModelField = InputField(
|
||||
t2i_adapter_model: ModelField = InputField(
|
||||
description="The T2I-Adapter model.",
|
||||
title="T2I-Adapter Model",
|
||||
input=Input.Direct,
|
||||
|
@ -166,6 +166,7 @@ two configs are kept in separate sections of the config file:
|
||||
...
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
@ -255,6 +256,7 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
profile_graphs : bool = Field(default=False, description="Enable graph profiling", json_schema_extra=Categories.Development)
|
||||
profile_prefix : Optional[str] = Field(default=None, description="An optional prefix for profile output files.", json_schema_extra=Categories.Development)
|
||||
profiles_dir : Path = Field(default=Path('profiles'), description="Directory for graph profiles", json_schema_extra=Categories.Development)
|
||||
skip_model_hash : bool = Field(default=False, description="Skip model hashing, instead assigning a UUID to models. Useful when using a memory db to reduce startup time.", json_schema_extra=Categories.Development)
|
||||
|
||||
version : bool = Field(default=False, description="Show InvokeAI version and exit", json_schema_extra=Categories.Other)
|
||||
|
||||
|
@ -1,4 +1,5 @@
|
||||
"""Init file for download queue."""
|
||||
|
||||
from .download_base import DownloadJob, DownloadJobStatus, DownloadQueueServiceBase, UnknownJobIDException
|
||||
from .download_default import DownloadQueueService, TqdmProgress
|
||||
|
||||
|
@ -224,20 +224,13 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
job.job_started = get_iso_timestamp()
|
||||
self._do_download(job)
|
||||
self._signal_job_complete(job)
|
||||
|
||||
except (OSError, HTTPError) as excp:
|
||||
job.error_type = excp.__class__.__name__ + f"({str(excp)})"
|
||||
job.error = traceback.format_exc()
|
||||
try:
|
||||
self._signal_job_error(job, excp)
|
||||
except:
|
||||
pass
|
||||
except DownloadJobCancelledException:
|
||||
try:
|
||||
self._signal_job_cancelled(job)
|
||||
self._cleanup_cancelled_job(job)
|
||||
except:
|
||||
pass
|
||||
|
||||
finally:
|
||||
job.job_ended = get_iso_timestamp()
|
||||
|
@ -41,8 +41,9 @@ class InvocationCacheBase(ABC):
|
||||
"""Clears the cache"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def create_key(self, invocation: BaseInvocation) -> int:
|
||||
def create_key(invocation: BaseInvocation) -> int:
|
||||
"""Gets the key for the invocation's cache item"""
|
||||
pass
|
||||
|
||||
|
@ -61,9 +61,7 @@ class MemoryInvocationCache(InvocationCacheBase):
|
||||
self._delete_oldest_access(number_to_delete)
|
||||
self._cache[key] = CachedItem(
|
||||
invocation_output,
|
||||
invocation_output.model_dump_json(
|
||||
warnings=False, exclude_defaults=True, exclude_unset=True, include={"type"}
|
||||
),
|
||||
invocation_output.model_dump_json(warnings=False, exclude_defaults=True, exclude_unset=True),
|
||||
)
|
||||
|
||||
def _delete_oldest_access(self, number_to_delete: int) -> None:
|
||||
@ -81,7 +79,7 @@ class MemoryInvocationCache(InvocationCacheBase):
|
||||
with self._lock:
|
||||
return self._delete(key)
|
||||
|
||||
def clear(self, *args, **kwargs) -> None:
|
||||
def clear(self) -> None:
|
||||
with self._lock:
|
||||
if self._max_cache_size == 0:
|
||||
return
|
||||
|
@ -25,6 +25,7 @@ if TYPE_CHECKING:
|
||||
from .images.images_base import ImageServiceABC
|
||||
from .invocation_cache.invocation_cache_base import InvocationCacheBase
|
||||
from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase
|
||||
from .model_images.model_images_base import ModelImageFileStorageBase
|
||||
from .model_manager.model_manager_base import ModelManagerServiceBase
|
||||
from .names.names_base import NameServiceBase
|
||||
from .session_processor.session_processor_base import SessionProcessorBase
|
||||
@ -49,6 +50,7 @@ class InvocationServices:
|
||||
image_files: "ImageFileStorageBase",
|
||||
image_records: "ImageRecordStorageBase",
|
||||
logger: "Logger",
|
||||
model_images: "ModelImageFileStorageBase",
|
||||
model_manager: "ModelManagerServiceBase",
|
||||
download_queue: "DownloadQueueServiceBase",
|
||||
performance_statistics: "InvocationStatsServiceBase",
|
||||
@ -72,6 +74,7 @@ class InvocationServices:
|
||||
self.image_files = image_files
|
||||
self.image_records = image_records
|
||||
self.logger = logger
|
||||
self.model_images = model_images
|
||||
self.model_manager = model_manager
|
||||
self.download_queue = download_queue
|
||||
self.performance_statistics = performance_statistics
|
||||
|
33
invokeai/app/services/model_images/model_images_base.py
Normal file
33
invokeai/app/services/model_images/model_images_base.py
Normal file
@ -0,0 +1,33 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
|
||||
from PIL.Image import Image as PILImageType
|
||||
|
||||
|
||||
class ModelImageFileStorageBase(ABC):
|
||||
"""Low-level service responsible for storing and retrieving image files."""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, model_key: str) -> PILImageType:
|
||||
"""Retrieves a model image as PIL Image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_path(self, model_key: str) -> Path:
|
||||
"""Gets the internal path to a model image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_url(self, model_key: str) -> str | None:
|
||||
"""Gets the URL to fetch a model image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(self, image: PILImageType, model_key: str) -> None:
|
||||
"""Saves a model image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, model_key: str) -> None:
|
||||
"""Deletes a model image."""
|
||||
pass
|
20
invokeai/app/services/model_images/model_images_common.py
Normal file
20
invokeai/app/services/model_images/model_images_common.py
Normal file
@ -0,0 +1,20 @@
|
||||
# TODO: Should these excpetions subclass existing python exceptions?
|
||||
class ModelImageFileNotFoundException(Exception):
|
||||
"""Raised when an image file is not found in storage."""
|
||||
|
||||
def __init__(self, message="Model image file not found"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ModelImageFileSaveException(Exception):
|
||||
"""Raised when an image cannot be saved."""
|
||||
|
||||
def __init__(self, message="Model image file not saved"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ModelImageFileDeleteException(Exception):
|
||||
"""Raised when an image cannot be deleted."""
|
||||
|
||||
def __init__(self, message="Model image file not deleted"):
|
||||
super().__init__(message)
|
79
invokeai/app/services/model_images/model_images_default.py
Normal file
79
invokeai/app/services/model_images/model_images_default.py
Normal file
@ -0,0 +1,79 @@
|
||||
from pathlib import Path
|
||||
|
||||
from PIL import Image
|
||||
from PIL.Image import Image as PILImageType
|
||||
from send2trash import send2trash
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.util.thumbnails import make_thumbnail
|
||||
|
||||
from .model_images_base import ModelImageFileStorageBase
|
||||
from .model_images_common import (
|
||||
ModelImageFileDeleteException,
|
||||
ModelImageFileNotFoundException,
|
||||
ModelImageFileSaveException,
|
||||
)
|
||||
|
||||
|
||||
class ModelImageFileStorageDisk(ModelImageFileStorageBase):
|
||||
"""Stores images on disk"""
|
||||
|
||||
def __init__(self, model_images_folder: Path):
|
||||
self._model_images_folder = model_images_folder
|
||||
self._validate_storage_folders()
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
|
||||
def get(self, model_key: str) -> PILImageType:
|
||||
try:
|
||||
path = self.get_path(model_key)
|
||||
|
||||
if not self._validate_path(path):
|
||||
raise ModelImageFileNotFoundException
|
||||
|
||||
return Image.open(path)
|
||||
except FileNotFoundError as e:
|
||||
raise ModelImageFileNotFoundException from e
|
||||
|
||||
def save(self, image: PILImageType, model_key: str) -> None:
|
||||
try:
|
||||
self._validate_storage_folders()
|
||||
image_path = self._model_images_folder / (model_key + ".webp")
|
||||
thumbnail = make_thumbnail(image, 256)
|
||||
thumbnail.save(image_path, format="webp")
|
||||
|
||||
except Exception as e:
|
||||
raise ModelImageFileSaveException from e
|
||||
|
||||
def get_path(self, model_key: str) -> Path:
|
||||
path = self._model_images_folder / (model_key + ".webp")
|
||||
|
||||
return path
|
||||
|
||||
def get_url(self, model_key: str) -> str | None:
|
||||
path = self.get_path(model_key)
|
||||
if not self._validate_path(path):
|
||||
return
|
||||
|
||||
return self._invoker.services.urls.get_model_image_url(model_key)
|
||||
|
||||
def delete(self, model_key: str) -> None:
|
||||
try:
|
||||
path = self.get_path(model_key)
|
||||
|
||||
if not self._validate_path(path):
|
||||
raise ModelImageFileNotFoundException
|
||||
|
||||
send2trash(path)
|
||||
|
||||
except Exception as e:
|
||||
raise ModelImageFileDeleteException from e
|
||||
|
||||
def _validate_path(self, path: Path) -> bool:
|
||||
"""Validates the path given for an image."""
|
||||
return path.exists()
|
||||
|
||||
def _validate_storage_folders(self) -> None:
|
||||
"""Checks if the required folders exist and create them if they don't"""
|
||||
self._model_images_folder.mkdir(parents=True, exist_ok=True)
|
@ -18,16 +18,16 @@ from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase
|
||||
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
|
||||
from invokeai.backend.model_manager.config import ModelSourceType
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
|
||||
from ..model_metadata import ModelMetadataStoreBase
|
||||
|
||||
|
||||
class InstallStatus(str, Enum):
|
||||
"""State of an install job running in the background."""
|
||||
|
||||
WAITING = "waiting" # waiting to be dequeued
|
||||
DOWNLOADING = "downloading" # downloading of model files in process
|
||||
DOWNLOADS_DONE = "downloads_done" # downloading done, waiting to run
|
||||
RUNNING = "running" # being processed
|
||||
COMPLETED = "completed" # finished running
|
||||
ERROR = "error" # terminated with an error message
|
||||
@ -150,6 +150,13 @@ ModelSource = Annotated[
|
||||
Union[LocalModelSource, HFModelSource, CivitaiModelSource, URLModelSource], Field(discriminator="type")
|
||||
]
|
||||
|
||||
MODEL_SOURCE_TO_TYPE_MAP = {
|
||||
URLModelSource: ModelSourceType.Url,
|
||||
HFModelSource: ModelSourceType.HFRepoID,
|
||||
CivitaiModelSource: ModelSourceType.CivitAI,
|
||||
LocalModelSource: ModelSourceType.Path,
|
||||
}
|
||||
|
||||
|
||||
class ModelInstallJob(BaseModel):
|
||||
"""Object that tracks the current status of an install request."""
|
||||
@ -229,6 +236,11 @@ class ModelInstallJob(BaseModel):
|
||||
"""Return true if job is downloading."""
|
||||
return self.status == InstallStatus.DOWNLOADING
|
||||
|
||||
@property
|
||||
def downloads_done(self) -> bool:
|
||||
"""Return true if job's downloads ae done."""
|
||||
return self.status == InstallStatus.DOWNLOADS_DONE
|
||||
|
||||
@property
|
||||
def running(self) -> bool:
|
||||
"""Return true if job is running."""
|
||||
@ -254,7 +266,6 @@ class ModelInstallServiceBase(ABC):
|
||||
app_config: InvokeAIAppConfig,
|
||||
record_store: ModelRecordServiceBase,
|
||||
download_queue: DownloadQueueServiceBase,
|
||||
metadata_store: ModelMetadataStoreBase,
|
||||
event_bus: Optional["EventServiceBase"] = None,
|
||||
):
|
||||
"""
|
||||
@ -341,6 +352,7 @@ class ModelInstallServiceBase(ABC):
|
||||
source: str,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
access_token: Optional[str] = None,
|
||||
inplace: Optional[bool] = False,
|
||||
) -> ModelInstallJob:
|
||||
r"""Install the indicated model using heuristics to interpret user intentions.
|
||||
|
||||
@ -386,7 +398,7 @@ class ModelInstallServiceBase(ABC):
|
||||
will override corresponding autoassigned probe fields in the
|
||||
model's config record. Use it to override
|
||||
`name`, `description`, `base_type`, `model_type`, `format`,
|
||||
`prediction_type`, `image_size`, and/or `ztsnr_training`.
|
||||
`prediction_type`, and/or `image_size`.
|
||||
|
||||
This will download the model located at `source`,
|
||||
probe it, and install it into the models directory.
|
||||
|
@ -4,10 +4,10 @@ import os
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from hashlib import sha256
|
||||
from pathlib import Path
|
||||
from queue import Empty, Queue
|
||||
from random import randbytes
|
||||
from shutil import copyfile, copytree, move, rmtree
|
||||
from tempfile import mkdtemp
|
||||
from typing import Any, Dict, List, Optional, Set, Union
|
||||
@ -21,14 +21,17 @@ from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
|
||||
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
CheckpointConfigBase,
|
||||
InvalidModelConfigException,
|
||||
ModelRepoVariant,
|
||||
ModelSourceType,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.hash import FastModelHash
|
||||
from invokeai.backend.model_manager.metadata import (
|
||||
AnyModelRepoMetadata,
|
||||
CivitaiMetadataFetch,
|
||||
@ -36,12 +39,14 @@ from invokeai.backend.model_manager.metadata import (
|
||||
ModelMetadataWithFiles,
|
||||
RemoteModelFile,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata.metadata_base import CivitaiMetadata, HuggingFaceMetadata
|
||||
from invokeai.backend.model_manager.probe import ModelProbe
|
||||
from invokeai.backend.model_manager.search import ModelSearch
|
||||
from invokeai.backend.util import Chdir, InvokeAILogger
|
||||
from invokeai.backend.util.devices import choose_precision, choose_torch_device
|
||||
|
||||
from .model_install_base import (
|
||||
MODEL_SOURCE_TO_TYPE_MAP,
|
||||
CivitaiModelSource,
|
||||
HFModelSource,
|
||||
InstallStatus,
|
||||
@ -91,7 +96,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._running = False
|
||||
self._session = session
|
||||
self._next_job_id = 0
|
||||
self._metadata_store = record_store.metadata_store # for convenience
|
||||
|
||||
@property
|
||||
def app_config(self) -> InvokeAIAppConfig: # noqa D102
|
||||
@ -140,6 +144,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
config = config or {}
|
||||
if not config.get("source"):
|
||||
config["source"] = model_path.resolve().as_posix()
|
||||
config["source_type"] = ModelSourceType.Path
|
||||
return self._register(model_path, config)
|
||||
|
||||
def install_path(
|
||||
@ -149,11 +154,11 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
) -> str: # noqa D102
|
||||
model_path = Path(model_path)
|
||||
config = config or {}
|
||||
if not config.get("source"):
|
||||
config["source"] = model_path.resolve().as_posix()
|
||||
|
||||
info: AnyModelConfig = self._probe_model(Path(model_path), config)
|
||||
old_hash = info.current_hash
|
||||
if self._app_config.skip_model_hash:
|
||||
config["hash"] = uuid_string()
|
||||
|
||||
info: AnyModelConfig = ModelProbe.probe(Path(model_path), config)
|
||||
|
||||
if preferred_name := config.get("name"):
|
||||
preferred_name = Path(preferred_name).with_suffix(model_path.suffix)
|
||||
@ -167,8 +172,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
raise DuplicateModelException(
|
||||
f"A model named {model_path.name} is already installed at {dest_path.as_posix()}"
|
||||
) from excp
|
||||
new_hash = FastModelHash.hash(new_path)
|
||||
assert new_hash == old_hash, f"{model_path}: Model hash changed during installation, possibly corrupted."
|
||||
|
||||
return self._register(
|
||||
new_path,
|
||||
@ -181,13 +184,14 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
source: str,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
access_token: Optional[str] = None,
|
||||
inplace: Optional[bool] = False,
|
||||
) -> ModelInstallJob:
|
||||
variants = "|".join(ModelRepoVariant.__members__.values())
|
||||
hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
|
||||
source_obj: Optional[StringLikeSource] = None
|
||||
|
||||
if Path(source).exists(): # A local file or directory
|
||||
source_obj = LocalModelSource(path=Path(source))
|
||||
source_obj = LocalModelSource(path=Path(source), inplace=inplace)
|
||||
elif match := re.match(hf_repoid_re, source):
|
||||
source_obj = HFModelSource(
|
||||
repo_id=match.group(1),
|
||||
@ -277,14 +281,20 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._scan_models_directory()
|
||||
if autoimport := self._app_config.autoimport_dir:
|
||||
self._logger.info("Scanning autoimport directory for new models")
|
||||
installed = self.scan_directory(self._app_config.root_path / autoimport)
|
||||
installed: List[str] = []
|
||||
# Use ThreadPoolExecutor to scan dirs in parallel
|
||||
with ThreadPoolExecutor() as executor:
|
||||
future_models = [executor.submit(self.scan_directory, self._app_config.root_path / autoimport / cur_model_type.value) for cur_model_type in ModelType]
|
||||
[installed.extend(models.result()) for models in as_completed(future_models)]
|
||||
self._logger.info(f"{len(installed)} new models registered")
|
||||
self._logger.info("Model installer (re)initialized")
|
||||
|
||||
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
|
||||
self._cached_model_paths = {Path(x.path) for x in self.record_store.all_models()}
|
||||
self._cached_model_paths = {Path(x.path).absolute() for x in self.record_store.all_models()}
|
||||
if len([entry for entry in os.scandir(scan_dir) if not entry.name.startswith(".")]) == 0:
|
||||
return []
|
||||
callback = self._scan_install if install else self._scan_register
|
||||
search = ModelSearch(on_model_found=callback)
|
||||
search = ModelSearch(on_model_found=callback, config=self._app_config)
|
||||
self._models_installed.clear()
|
||||
search.search(scan_dir)
|
||||
return list(self._models_installed)
|
||||
@ -370,21 +380,24 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._signal_job_errored(job)
|
||||
|
||||
elif (
|
||||
job.waiting or job.downloading
|
||||
job.waiting or job.downloads_done
|
||||
): # local jobs will be in waiting state, remote jobs will be downloading state
|
||||
job.total_bytes = self._stat_size(job.local_path)
|
||||
job.bytes = job.total_bytes
|
||||
self._signal_job_running(job)
|
||||
job.config_in["source"] = str(job.source)
|
||||
job.config_in["source_type"] = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__]
|
||||
# enter the metadata, if there is any
|
||||
if isinstance(job.source_metadata, (CivitaiMetadata, HuggingFaceMetadata)):
|
||||
job.config_in["source_api_response"] = job.source_metadata.api_response
|
||||
if isinstance(job.source_metadata, CivitaiMetadata) and job.source_metadata.trigger_phrases:
|
||||
job.config_in["trigger_phrases"] = job.source_metadata.trigger_phrases
|
||||
|
||||
if job.inplace:
|
||||
key = self.register_path(job.local_path, job.config_in)
|
||||
else:
|
||||
key = self.install_path(job.local_path, job.config_in)
|
||||
job.config_out = self.record_store.get_model(key)
|
||||
|
||||
# enter the metadata, if there is any
|
||||
if job.source_metadata:
|
||||
self._metadata_store.add_metadata(key, job.source_metadata)
|
||||
self._signal_job_completed(job)
|
||||
|
||||
except InvalidModelConfigException as excp:
|
||||
@ -442,13 +455,13 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self.unregister(key)
|
||||
|
||||
self._logger.info(f"Scanning {self._app_config.models_path} for new and orphaned models")
|
||||
for cur_base_model in BaseModelType:
|
||||
for cur_model_type in ModelType:
|
||||
models_dir = Path(cur_base_model.value, cur_model_type.value)
|
||||
installed.update(self.scan_directory(models_dir))
|
||||
# Use ThreadPoolExecutor to scan dirs in parallel
|
||||
with ThreadPoolExecutor() as executor:
|
||||
future_models = [executor.submit(self.scan_directory, Path(cur_base_model.value, cur_model_type.value)) for cur_base_model in BaseModelType for cur_model_type in ModelType]
|
||||
[installed.update(models.result()) for models in as_completed(future_models)]
|
||||
self._logger.info(f"{len(installed)} new models registered; {len(defunct_models)} unregistered")
|
||||
|
||||
def _sync_model_path(self, key: str, ignore_hash_change: bool = False) -> AnyModelConfig:
|
||||
def _sync_model_path(self, key: str) -> AnyModelConfig:
|
||||
"""
|
||||
Move model into the location indicated by its basetype, type and name.
|
||||
|
||||
@ -469,15 +482,8 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
new_path = models_dir / model.base.value / model.type.value / model.name
|
||||
self._logger.info(f"Moving {model.name} to {new_path}.")
|
||||
new_path = self._move_model(old_path, new_path)
|
||||
new_hash = FastModelHash.hash(new_path)
|
||||
model.path = new_path.relative_to(models_dir).as_posix()
|
||||
if model.current_hash != new_hash:
|
||||
assert (
|
||||
ignore_hash_change
|
||||
), f"{model.name}: Model hash changed during installation, model is possibly corrupted"
|
||||
model.current_hash = new_hash
|
||||
self._logger.info(f"Model has new hash {model.current_hash}, but will continue to be identified by {key}")
|
||||
self.record_store.update_model(key, model)
|
||||
self.record_store.update_model(key, ModelRecordChanges(path=model.path))
|
||||
return model
|
||||
|
||||
def _scan_register(self, model: Path) -> bool:
|
||||
@ -529,22 +535,14 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
move(old_path, new_path)
|
||||
return new_path
|
||||
|
||||
def _probe_model(self, model_path: Path, config: Optional[Dict[str, Any]] = None) -> AnyModelConfig:
|
||||
info: AnyModelConfig = ModelProbe.probe(Path(model_path))
|
||||
if config: # used to override probe fields
|
||||
for key, value in config.items():
|
||||
setattr(info, key, value)
|
||||
return info
|
||||
|
||||
def _create_key(self) -> str:
|
||||
return sha256(randbytes(100)).hexdigest()[0:32]
|
||||
|
||||
def _register(
|
||||
self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None
|
||||
) -> str:
|
||||
key = self._create_key()
|
||||
if config and not config.get("key", None):
|
||||
config["key"] = key
|
||||
config = config or {}
|
||||
|
||||
if self._app_config.skip_model_hash:
|
||||
config["hash"] = uuid_string()
|
||||
|
||||
info = info or ModelProbe.probe(model_path, config)
|
||||
|
||||
model_path = model_path.absolute()
|
||||
@ -554,11 +552,11 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
info.path = model_path.as_posix()
|
||||
|
||||
# add 'main' specific fields
|
||||
if hasattr(info, "config"):
|
||||
if isinstance(info, CheckpointConfigBase):
|
||||
# make config relative to our root
|
||||
legacy_conf = (self.app_config.root_dir / self.app_config.legacy_conf_dir / info.config).resolve()
|
||||
info.config = legacy_conf.relative_to(self.app_config.root_dir).as_posix()
|
||||
self.record_store.add_model(info.key, info)
|
||||
legacy_conf = (self.app_config.root_dir / self.app_config.legacy_conf_dir / info.config_path).resolve()
|
||||
info.config_path = legacy_conf.relative_to(self.app_config.root_dir).as_posix()
|
||||
self.record_store.add_model(info)
|
||||
return info.key
|
||||
|
||||
def _next_id(self) -> int:
|
||||
@ -579,13 +577,15 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
source=source,
|
||||
config_in=config or {},
|
||||
local_path=Path(source.path),
|
||||
inplace=source.inplace,
|
||||
inplace=source.inplace or False,
|
||||
)
|
||||
|
||||
def _import_from_civitai(self, source: CivitaiModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
|
||||
if not source.access_token:
|
||||
self._logger.info("No Civitai access token provided; some models may not be downloadable.")
|
||||
metadata = CivitaiMetadataFetch(self._session).from_id(str(source.version_id))
|
||||
metadata = CivitaiMetadataFetch(self._session, self.app_config.get_config().civitai_api_key).from_id(
|
||||
str(source.version_id)
|
||||
)
|
||||
assert isinstance(metadata, ModelMetadataWithFiles)
|
||||
remote_files = metadata.download_urls(session=self._session)
|
||||
return self._import_remote_model(source=source, config=config, metadata=metadata, remote_files=remote_files)
|
||||
@ -613,15 +613,17 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
|
||||
def _import_from_url(self, source: URLModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
|
||||
# URLs from Civitai or HuggingFace will be handled specially
|
||||
url_patterns = {
|
||||
r"^https?://civitai.com/": CivitaiMetadataFetch,
|
||||
r"^https?://huggingface.co/[^/]+/[^/]+$": HuggingFaceMetadataFetch,
|
||||
}
|
||||
metadata = None
|
||||
for pattern, fetcher in url_patterns.items():
|
||||
if re.match(pattern, str(source.url), re.IGNORECASE):
|
||||
metadata = fetcher(self._session).from_url(source.url)
|
||||
break
|
||||
fetcher = None
|
||||
try:
|
||||
fetcher = self.get_fetcher_from_url(str(source.url))
|
||||
except ValueError:
|
||||
pass
|
||||
kwargs: dict[str, Any] = {"session": self._session}
|
||||
if fetcher is CivitaiMetadataFetch:
|
||||
kwargs["api_key"] = self._app_config.get_config().civitai_api_key
|
||||
if fetcher is not None:
|
||||
metadata = fetcher(**kwargs).from_url(source.url)
|
||||
self._logger.debug(f"metadata={metadata}")
|
||||
if metadata and isinstance(metadata, ModelMetadataWithFiles):
|
||||
remote_files = metadata.download_urls(session=self._session)
|
||||
@ -636,7 +638,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
|
||||
def _import_remote_model(
|
||||
self,
|
||||
source: ModelSource,
|
||||
source: HFModelSource | CivitaiModelSource | URLModelSource,
|
||||
remote_files: List[RemoteModelFile],
|
||||
metadata: Optional[AnyModelRepoMetadata],
|
||||
config: Optional[Dict[str, Any]],
|
||||
@ -664,7 +666,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
# In the event that there is a subfolder specified in the source,
|
||||
# we need to remove it from the destination path in order to avoid
|
||||
# creating unwanted subfolders
|
||||
if hasattr(source, "subfolder") and source.subfolder:
|
||||
if isinstance(source, HFModelSource) and source.subfolder:
|
||||
root = Path(remote_files[0].path.parts[0])
|
||||
subfolder = root / source.subfolder
|
||||
else:
|
||||
@ -749,8 +751,8 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._download_cache.pop(download_job.source, None)
|
||||
|
||||
# are there any more active jobs left in this task?
|
||||
if all(x.complete for x in install_job.download_parts):
|
||||
# now enqueue job for actual installation into the models directory
|
||||
if install_job.downloading and all(x.complete for x in install_job.download_parts):
|
||||
install_job.status = InstallStatus.DOWNLOADS_DONE
|
||||
self._install_queue.put(install_job)
|
||||
|
||||
# Let other threads know that the number of downloads has changed
|
||||
@ -851,3 +853,11 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._logger.info(f"{job.source}: model installation was cancelled")
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_model_install_cancelled(str(job.source))
|
||||
|
||||
@staticmethod
|
||||
def get_fetcher_from_url(url: str):
|
||||
if re.match(r"^https?://civitai.com/", url.lower()):
|
||||
return CivitaiMetadataFetch
|
||||
elif re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()):
|
||||
return HuggingFaceMetadataFetch
|
||||
raise ValueError(f"Unsupported model source: '{url}'")
|
||||
|
@ -1,15 +1,11 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from typing_extensions import Self
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContextData
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel
|
||||
|
||||
from ..config import InvokeAIAppConfig
|
||||
from ..download import DownloadQueueServiceBase
|
||||
@ -70,32 +66,3 @@ class ModelManagerServiceBase(ABC):
|
||||
@abstractmethod
|
||||
def stop(self, invoker: Invoker) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_model_by_config(
|
||||
self,
|
||||
model_config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context_data: Optional[InvocationContextData] = None,
|
||||
) -> LoadedModel:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_model_by_key(
|
||||
self,
|
||||
key: str,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context_data: Optional[InvocationContextData] = None,
|
||||
) -> LoadedModel:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_model_by_attr(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel: Optional[SubModelType] = None,
|
||||
context_data: Optional[InvocationContextData] = None,
|
||||
) -> LoadedModel:
|
||||
pass
|
||||
|
@ -1,14 +1,10 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
||||
"""Implementation of ModelManagerServiceBase."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from typing_extensions import Self
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContextData
|
||||
from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, LoadedModel, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
|
||||
from invokeai.backend.util.devices import choose_torch_device
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
@ -18,7 +14,7 @@ from ..download import DownloadQueueServiceBase
|
||||
from ..events.events_base import EventServiceBase
|
||||
from ..model_install import ModelInstallService, ModelInstallServiceBase
|
||||
from ..model_load import ModelLoadService, ModelLoadServiceBase
|
||||
from ..model_records import ModelRecordServiceBase, UnknownModelException
|
||||
from ..model_records import ModelRecordServiceBase
|
||||
from .model_manager_base import ModelManagerServiceBase
|
||||
|
||||
|
||||
@ -64,56 +60,6 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
if hasattr(service, "stop"):
|
||||
service.stop(invoker)
|
||||
|
||||
def load_model_by_config(
|
||||
self,
|
||||
model_config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context_data: Optional[InvocationContextData] = None,
|
||||
) -> LoadedModel:
|
||||
return self.load.load_model(model_config, submodel_type, context_data)
|
||||
|
||||
def load_model_by_key(
|
||||
self,
|
||||
key: str,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context_data: Optional[InvocationContextData] = None,
|
||||
) -> LoadedModel:
|
||||
config = self.store.get_model(key)
|
||||
return self.load.load_model(config, submodel_type, context_data)
|
||||
|
||||
def load_model_by_attr(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel: Optional[SubModelType] = None,
|
||||
context_data: Optional[InvocationContextData] = None,
|
||||
) -> LoadedModel:
|
||||
"""
|
||||
Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object.
|
||||
|
||||
This is provided for API compatability with the get_model() method
|
||||
in the original model manager. However, note that LoadedModel is
|
||||
not the same as the original ModelInfo that ws returned.
|
||||
|
||||
:param model_name: Name of to be fetched.
|
||||
:param base_model: Base model
|
||||
:param model_type: Type of the model
|
||||
:param submodel: For main (pipeline models), the submodel to fetch
|
||||
:param context: The invocation context.
|
||||
|
||||
Exceptions: UnknownModelException -- model with this key not known
|
||||
NotImplementedException -- a model loader was not provided at initialization time
|
||||
ValueError -- more than one model matches this combination
|
||||
"""
|
||||
configs = self.store.search_by_attr(model_name, base_model, model_type)
|
||||
if len(configs) == 0:
|
||||
raise UnknownModelException(f"{base_model}/{model_type}/{model_name}: Unknown model")
|
||||
elif len(configs) > 1:
|
||||
raise ValueError(f"{base_model}/{model_type}/{model_name}: More than one model matches.")
|
||||
else:
|
||||
return self.load.load_model(configs[0], submodel, context_data)
|
||||
|
||||
@classmethod
|
||||
def build_model_manager(
|
||||
cls,
|
||||
|
@ -1,9 +0,0 @@
|
||||
"""Init file for ModelMetadataStoreService module."""
|
||||
|
||||
from .metadata_store_base import ModelMetadataStoreBase
|
||||
from .metadata_store_sql import ModelMetadataStoreSQL
|
||||
|
||||
__all__ = [
|
||||
"ModelMetadataStoreBase",
|
||||
"ModelMetadataStoreSQL",
|
||||
]
|
@ -1,65 +0,0 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""
|
||||
Storage for Model Metadata
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Set, Tuple
|
||||
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
|
||||
|
||||
class ModelMetadataStoreBase(ABC):
|
||||
"""Store, search and fetch model metadata retrieved from remote repositories."""
|
||||
|
||||
@abstractmethod
|
||||
def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
|
||||
"""
|
||||
Add a block of repo metadata to a model record.
|
||||
|
||||
The model record config must already exist in the database with the
|
||||
same key. Otherwise a FOREIGN KEY constraint exception will be raised.
|
||||
|
||||
:param model_key: Existing model key in the `model_config` table
|
||||
:param metadata: ModelRepoMetadata object to store
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
|
||||
"""Retrieve the ModelRepoMetadata corresponding to model key."""
|
||||
|
||||
@abstractmethod
|
||||
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata
|
||||
"""Dump out all the metadata."""
|
||||
|
||||
@abstractmethod
|
||||
def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
|
||||
"""
|
||||
Update metadata corresponding to the model with the indicated key.
|
||||
|
||||
:param model_key: Existing model key in the `model_config` table
|
||||
:param metadata: ModelRepoMetadata object to update
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def list_tags(self) -> Set[str]:
|
||||
"""Return all tags in the tags table."""
|
||||
|
||||
@abstractmethod
|
||||
def search_by_tag(self, tags: Set[str]) -> Set[str]:
|
||||
"""Return the keys of models containing all of the listed tags."""
|
||||
|
||||
@abstractmethod
|
||||
def search_by_author(self, author: str) -> Set[str]:
|
||||
"""Return the keys of models authored by the indicated author."""
|
||||
|
||||
@abstractmethod
|
||||
def search_by_name(self, name: str) -> Set[str]:
|
||||
"""
|
||||
Return the keys of models with the indicated name.
|
||||
|
||||
Note that this is the name of the model given to it by
|
||||
the remote source. The user may have changed the local
|
||||
name. The local name will be located in the model config
|
||||
record object.
|
||||
"""
|
@ -1,222 +0,0 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""
|
||||
SQL Storage for Model Metadata
|
||||
"""
|
||||
|
||||
import sqlite3
|
||||
from typing import List, Optional, Set, Tuple
|
||||
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, UnknownMetadataException
|
||||
from invokeai.backend.model_manager.metadata.fetch import ModelMetadataFetchBase
|
||||
|
||||
from .metadata_store_base import ModelMetadataStoreBase
|
||||
|
||||
|
||||
class ModelMetadataStoreSQL(ModelMetadataStoreBase):
|
||||
"""Store, search and fetch model metadata retrieved from remote repositories."""
|
||||
|
||||
def __init__(self, db: SqliteDatabase):
|
||||
"""
|
||||
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
|
||||
|
||||
:param conn: sqlite3 connection object
|
||||
:param lock: threading Lock object
|
||||
"""
|
||||
super().__init__()
|
||||
self._db = db
|
||||
self._cursor = self._db.conn.cursor()
|
||||
|
||||
def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
|
||||
"""
|
||||
Add a block of repo metadata to a model record.
|
||||
|
||||
The model record config must already exist in the database with the
|
||||
same key. Otherwise a FOREIGN KEY constraint exception will be raised.
|
||||
|
||||
:param model_key: Existing model key in the `model_config` table
|
||||
:param metadata: ModelRepoMetadata object to store
|
||||
"""
|
||||
json_serialized = metadata.model_dump_json()
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT INTO model_metadata(
|
||||
id,
|
||||
metadata
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
(
|
||||
model_key,
|
||||
json_serialized,
|
||||
),
|
||||
)
|
||||
self._update_tags(model_key, metadata.tags)
|
||||
self._db.conn.commit()
|
||||
except sqlite3.IntegrityError as excp: # FOREIGN KEY error: the key was not in model_config table
|
||||
self._db.conn.rollback()
|
||||
raise UnknownMetadataException from excp
|
||||
except sqlite3.Error as excp:
|
||||
self._db.conn.rollback()
|
||||
raise excp
|
||||
|
||||
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
|
||||
"""Retrieve the ModelRepoMetadata corresponding to model key."""
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT metadata FROM model_metadata
|
||||
WHERE id=?;
|
||||
""",
|
||||
(model_key,),
|
||||
)
|
||||
rows = self._cursor.fetchone()
|
||||
if not rows:
|
||||
raise UnknownMetadataException("model metadata not found")
|
||||
return ModelMetadataFetchBase.from_json(rows[0])
|
||||
|
||||
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata
|
||||
"""Dump out all the metadata."""
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT id,metadata FROM model_metadata;
|
||||
""",
|
||||
(),
|
||||
)
|
||||
rows = self._cursor.fetchall()
|
||||
return [(x[0], ModelMetadataFetchBase.from_json(x[1])) for x in rows]
|
||||
|
||||
def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
|
||||
"""
|
||||
Update metadata corresponding to the model with the indicated key.
|
||||
|
||||
:param model_key: Existing model key in the `model_config` table
|
||||
:param metadata: ModelRepoMetadata object to update
|
||||
"""
|
||||
json_serialized = metadata.model_dump_json() # turn it into a json string.
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
UPDATE model_metadata
|
||||
SET
|
||||
metadata=?
|
||||
WHERE id=?;
|
||||
""",
|
||||
(json_serialized, model_key),
|
||||
)
|
||||
if self._cursor.rowcount == 0:
|
||||
raise UnknownMetadataException("model metadata not found")
|
||||
self._update_tags(model_key, metadata.tags)
|
||||
self._db.conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
|
||||
return self.get_metadata(model_key)
|
||||
|
||||
def list_tags(self) -> Set[str]:
|
||||
"""Return all tags in the tags table."""
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
select tag_text from tags;
|
||||
"""
|
||||
)
|
||||
return {x[0] for x in self._cursor.fetchall()}
|
||||
|
||||
def search_by_tag(self, tags: Set[str]) -> Set[str]:
|
||||
"""Return the keys of models containing all of the listed tags."""
|
||||
with self._db.lock:
|
||||
try:
|
||||
matches: Optional[Set[str]] = None
|
||||
for tag in tags:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT a.model_id FROM model_tags AS a,
|
||||
tags AS b
|
||||
WHERE a.tag_id=b.tag_id
|
||||
AND b.tag_text=?;
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
model_keys = {x[0] for x in self._cursor.fetchall()}
|
||||
if matches is None:
|
||||
matches = model_keys
|
||||
matches = matches.intersection(model_keys)
|
||||
except sqlite3.Error as e:
|
||||
raise e
|
||||
return matches if matches else set()
|
||||
|
||||
def search_by_author(self, author: str) -> Set[str]:
|
||||
"""Return the keys of models authored by the indicated author."""
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT id FROM model_metadata
|
||||
WHERE author=?;
|
||||
""",
|
||||
(author,),
|
||||
)
|
||||
return {x[0] for x in self._cursor.fetchall()}
|
||||
|
||||
def search_by_name(self, name: str) -> Set[str]:
|
||||
"""
|
||||
Return the keys of models with the indicated name.
|
||||
|
||||
Note that this is the name of the model given to it by
|
||||
the remote source. The user may have changed the local
|
||||
name. The local name will be located in the model config
|
||||
record object.
|
||||
"""
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT id FROM model_metadata
|
||||
WHERE name=?;
|
||||
""",
|
||||
(name,),
|
||||
)
|
||||
return {x[0] for x in self._cursor.fetchall()}
|
||||
|
||||
def _update_tags(self, model_key: str, tags: Set[str]) -> None:
|
||||
"""Update tags for the model referenced by model_key."""
|
||||
# remove previous tags from this model
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM model_tags
|
||||
WHERE model_id=?;
|
||||
""",
|
||||
(model_key,),
|
||||
)
|
||||
|
||||
for tag in tags:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO tags (
|
||||
tag_text
|
||||
)
|
||||
VALUES (?);
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT tag_id
|
||||
FROM tags
|
||||
WHERE tag_text = ?
|
||||
LIMIT 1;
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
tag_id = self._cursor.fetchone()[0]
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO model_tags (
|
||||
model_id,
|
||||
tag_id
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
(model_key, tag_id),
|
||||
)
|
@ -1,4 +1,5 @@
|
||||
"""Init file for model record services."""
|
||||
|
||||
from .model_records_base import ( # noqa F401
|
||||
DuplicateModelException,
|
||||
InvalidModelException,
|
||||
|
@ -6,20 +6,19 @@ Abstract base class for storing and retrieving model configuration records.
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
||||
from typing import List, Optional, Set, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
|
||||
from ..model_metadata import ModelMetadataStoreBase
|
||||
from invokeai.backend.model_manager.config import ModelDefaultSettings, ModelVariantType, SchedulerPredictionType
|
||||
|
||||
|
||||
class DuplicateModelException(Exception):
|
||||
@ -60,11 +59,34 @@ class ModelSummary(BaseModel):
|
||||
tags: Set[str] = Field(description="tags associated with model")
|
||||
|
||||
|
||||
class ModelRecordChanges(BaseModelExcludeNull):
|
||||
"""A set of changes to apply to a model."""
|
||||
|
||||
# Changes applicable to all models
|
||||
name: Optional[str] = Field(description="Name of the model.", default=None)
|
||||
path: Optional[str] = Field(description="Path to the model.", default=None)
|
||||
description: Optional[str] = Field(description="Model description", default=None)
|
||||
base: Optional[BaseModelType] = Field(description="The base model.", default=None)
|
||||
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
|
||||
default_settings: Optional[ModelDefaultSettings] = Field(
|
||||
description="Default settings for this model", default=None
|
||||
)
|
||||
|
||||
# Checkpoint-specific changes
|
||||
# TODO(MM2): Should we expose these? Feels footgun-y...
|
||||
variant: Optional[ModelVariantType] = Field(description="The variant of the model.", default=None)
|
||||
prediction_type: Optional[SchedulerPredictionType] = Field(
|
||||
description="The prediction type of the model.", default=None
|
||||
)
|
||||
upcast_attention: Optional[bool] = Field(description="Whether to upcast attention.", default=None)
|
||||
config_path: Optional[str] = Field(description="Path to config file for model", default=None)
|
||||
|
||||
|
||||
class ModelRecordServiceBase(ABC):
|
||||
"""Abstract base class for storage and retrieval of model configs."""
|
||||
|
||||
@abstractmethod
|
||||
def add_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
|
||||
def add_model(self, config: AnyModelConfig) -> AnyModelConfig:
|
||||
"""
|
||||
Add a model to the database.
|
||||
|
||||
@ -88,13 +110,12 @@ class ModelRecordServiceBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
|
||||
def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
|
||||
"""
|
||||
Update the model, returning the updated version.
|
||||
|
||||
:param key: Unique key for the model to be updated
|
||||
:param config: Model configuration record. Either a dict with the
|
||||
required fields, or a ModelConfigBase instance.
|
||||
:param key: Unique key for the model to be updated.
|
||||
:param changes: A set of changes to apply to this model. Changes are validated before being written.
|
||||
"""
|
||||
pass
|
||||
|
||||
@ -109,40 +130,17 @@ class ModelRecordServiceBase(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def metadata_store(self) -> ModelMetadataStoreBase:
|
||||
"""Return a ModelMetadataStore initialized on the same database."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
|
||||
def get_model_by_hash(self, hash: str) -> AnyModelConfig:
|
||||
"""
|
||||
Retrieve metadata (if any) from when model was downloaded from a repo.
|
||||
Retrieve the configuration for the indicated model.
|
||||
|
||||
:param key: Model key
|
||||
:param hash: Hash of model config to be fetched.
|
||||
|
||||
Exceptions: UnknownModelException
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]:
|
||||
"""List metadata for all models that have it."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search_by_metadata_tag(self, tags: Set[str]) -> List[AnyModelConfig]:
|
||||
"""
|
||||
Search model metadata for ones with all listed tags and return their corresponding configs.
|
||||
|
||||
:param tags: Set of tags to search for. All tags must be present.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_tags(self) -> Set[str]:
|
||||
"""Return a unique set of all the model tags in the metadata database."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_models(
|
||||
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
|
||||
@ -217,21 +215,3 @@ class ModelRecordServiceBase(ABC):
|
||||
f"More than one model matched the search criteria: base_model='{base_model}', model_type='{model_type}', model_name='{model_name}'."
|
||||
)
|
||||
return model_configs[0]
|
||||
|
||||
def rename_model(
|
||||
self,
|
||||
key: str,
|
||||
new_name: str,
|
||||
) -> AnyModelConfig:
|
||||
"""
|
||||
Rename the indicated model. Just a special case of update_model().
|
||||
|
||||
In some implementations, renaming the model may involve changing where
|
||||
it is stored on the filesystem. So this is broken out.
|
||||
|
||||
:param key: Model key
|
||||
:param new_name: New name for model
|
||||
"""
|
||||
config = self.get_model(key)
|
||||
config.name = new_name
|
||||
return self.update_model(key, config)
|
||||
|
@ -39,12 +39,11 @@ Typical usage:
|
||||
configs = store.search_by_attr(base_model='sd-2', model_type='main')
|
||||
"""
|
||||
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
from math import ceil
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||
from invokeai.backend.model_manager.config import (
|
||||
@ -54,12 +53,11 @@ from invokeai.backend.model_manager.config import (
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, UnknownMetadataException
|
||||
|
||||
from ..model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL
|
||||
from ..shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from .model_records_base import (
|
||||
DuplicateModelException,
|
||||
ModelRecordChanges,
|
||||
ModelRecordOrderBy,
|
||||
ModelRecordServiceBase,
|
||||
ModelSummary,
|
||||
@ -70,7 +68,7 @@ from .model_records_base import (
|
||||
class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
"""Implementation of the ModelConfigStore ABC using a SQL database."""
|
||||
|
||||
def __init__(self, db: SqliteDatabase, metadata_store: ModelMetadataStoreBase):
|
||||
def __init__(self, db: SqliteDatabase):
|
||||
"""
|
||||
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
|
||||
|
||||
@ -79,14 +77,13 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
super().__init__()
|
||||
self._db = db
|
||||
self._cursor = db.conn.cursor()
|
||||
self._metadata_store = metadata_store
|
||||
|
||||
@property
|
||||
def db(self) -> SqliteDatabase:
|
||||
"""Return the underlying database."""
|
||||
return self._db
|
||||
|
||||
def add_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
|
||||
def add_model(self, config: AnyModelConfig) -> AnyModelConfig:
|
||||
"""
|
||||
Add a model to the database.
|
||||
|
||||
@ -96,23 +93,19 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
|
||||
Can raise DuplicateModelException and InvalidModelConfigException exceptions.
|
||||
"""
|
||||
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect.
|
||||
json_serialized = record.model_dump_json() # and turn it into a json string.
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT INTO model_config (
|
||||
INSERT INTO models (
|
||||
id,
|
||||
original_hash,
|
||||
config
|
||||
)
|
||||
VALUES (?,?,?);
|
||||
VALUES (?,?);
|
||||
""",
|
||||
(
|
||||
key,
|
||||
record.original_hash,
|
||||
json_serialized,
|
||||
config.key,
|
||||
config.model_dump_json(),
|
||||
),
|
||||
)
|
||||
self._db.conn.commit()
|
||||
@ -120,12 +113,12 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
except sqlite3.IntegrityError as e:
|
||||
self._db.conn.rollback()
|
||||
if "UNIQUE constraint failed" in str(e):
|
||||
if "model_config.path" in str(e):
|
||||
msg = f"A model with path '{record.path}' is already installed"
|
||||
elif "model_config.name" in str(e):
|
||||
msg = f"A model with name='{record.name}', type='{record.type}', base='{record.base}' is already installed"
|
||||
if "models.path" in str(e):
|
||||
msg = f"A model with path '{config.path}' is already installed"
|
||||
elif "models.name" in str(e):
|
||||
msg = f"A model with name='{config.name}', type='{config.type}', base='{config.base}' is already installed"
|
||||
else:
|
||||
msg = f"A model with key '{key}' is already installed"
|
||||
msg = f"A model with key '{config.key}' is already installed"
|
||||
raise DuplicateModelException(msg) from e
|
||||
else:
|
||||
raise e
|
||||
@ -133,7 +126,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
|
||||
return self.get_model(key)
|
||||
return self.get_model(config.key)
|
||||
|
||||
def del_model(self, key: str) -> None:
|
||||
"""
|
||||
@ -147,7 +140,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM model_config
|
||||
DELETE FROM models
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
@ -159,21 +152,20 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
|
||||
def update_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
|
||||
"""
|
||||
Update the model, returning the updated version.
|
||||
def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
|
||||
record = self.get_model(key)
|
||||
|
||||
# Model configs use pydantic's `validate_assignment`, so each change is validated by pydantic.
|
||||
for field_name in changes.model_fields_set:
|
||||
setattr(record, field_name, getattr(changes, field_name))
|
||||
|
||||
json_serialized = record.model_dump_json()
|
||||
|
||||
:param key: Unique key for the model to be updated
|
||||
:param config: Model configuration record. Either a dict with the
|
||||
required fields, or a ModelConfigBase instance.
|
||||
"""
|
||||
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect
|
||||
json_serialized = record.model_dump_json() # and turn it into a json string.
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
UPDATE model_config
|
||||
UPDATE models
|
||||
SET
|
||||
config=?
|
||||
WHERE id=?;
|
||||
@ -200,7 +192,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM model_config
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
@ -211,6 +203,21 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
|
||||
return model
|
||||
|
||||
def get_model_by_hash(self, hash: str) -> AnyModelConfig:
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE hash=?;
|
||||
""",
|
||||
(hash,),
|
||||
)
|
||||
rows = self._cursor.fetchone()
|
||||
if not rows:
|
||||
raise UnknownModelException("model not found")
|
||||
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
|
||||
return model
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
"""
|
||||
Return True if a model with the indicated key exists in the databse.
|
||||
@ -221,7 +228,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
select count(*) FROM model_config
|
||||
select count(*) FROM models
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
@ -247,9 +254,8 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
If none of the optional filters are passed, will return all
|
||||
models in the database.
|
||||
"""
|
||||
results = []
|
||||
where_clause = []
|
||||
bindings = []
|
||||
where_clause: list[str] = []
|
||||
bindings: list[str] = []
|
||||
if model_name:
|
||||
where_clause.append("name=?")
|
||||
bindings.append(model_name)
|
||||
@ -266,14 +272,13 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
select config, strftime('%s',updated_at) FROM model_config
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
{where};
|
||||
""",
|
||||
tuple(bindings),
|
||||
)
|
||||
results = [
|
||||
ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()
|
||||
]
|
||||
result = self._cursor.fetchall()
|
||||
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in result]
|
||||
return results
|
||||
|
||||
def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:
|
||||
@ -282,7 +287,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM model_config
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE path=?;
|
||||
""",
|
||||
(str(path),),
|
||||
@ -293,13 +298,13 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
return results
|
||||
|
||||
def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
|
||||
"""Return models with the indicated original_hash."""
|
||||
"""Return models with the indicated hash."""
|
||||
results = []
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM model_config
|
||||
WHERE original_hash=?;
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE hash=?;
|
||||
""",
|
||||
(hash,),
|
||||
)
|
||||
@ -308,83 +313,35 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
]
|
||||
return results
|
||||
|
||||
@property
|
||||
def metadata_store(self) -> ModelMetadataStoreBase:
|
||||
"""Return a ModelMetadataStore initialized on the same database."""
|
||||
return self._metadata_store
|
||||
|
||||
def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
|
||||
"""
|
||||
Retrieve metadata (if any) from when model was downloaded from a repo.
|
||||
|
||||
:param key: Model key
|
||||
"""
|
||||
store = self.metadata_store
|
||||
try:
|
||||
metadata = store.get_metadata(key)
|
||||
return metadata
|
||||
except UnknownMetadataException:
|
||||
return None
|
||||
|
||||
def search_by_metadata_tag(self, tags: Set[str]) -> List[AnyModelConfig]:
|
||||
"""
|
||||
Search model metadata for ones with all listed tags and return their corresponding configs.
|
||||
|
||||
:param tags: Set of tags to search for. All tags must be present.
|
||||
"""
|
||||
store = ModelMetadataStoreSQL(self._db)
|
||||
keys = store.search_by_tag(tags)
|
||||
return [self.get_model(x) for x in keys]
|
||||
|
||||
def list_tags(self) -> Set[str]:
|
||||
"""Return a unique set of all the model tags in the metadata database."""
|
||||
store = ModelMetadataStoreSQL(self._db)
|
||||
return store.list_tags()
|
||||
|
||||
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]:
|
||||
"""List metadata for all models that have it."""
|
||||
store = ModelMetadataStoreSQL(self._db)
|
||||
return store.list_all_metadata()
|
||||
|
||||
def list_models(
|
||||
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
|
||||
) -> PaginatedResults[ModelSummary]:
|
||||
"""Return a paginated summary listing of each model in the database."""
|
||||
assert isinstance(order_by, ModelRecordOrderBy)
|
||||
ordering = {
|
||||
ModelRecordOrderBy.Default: "a.type, a.base, a.format, a.name",
|
||||
ModelRecordOrderBy.Type: "a.type",
|
||||
ModelRecordOrderBy.Base: "a.base",
|
||||
ModelRecordOrderBy.Name: "a.name",
|
||||
ModelRecordOrderBy.Format: "a.format",
|
||||
ModelRecordOrderBy.Default: "type, base, format, name",
|
||||
ModelRecordOrderBy.Type: "type",
|
||||
ModelRecordOrderBy.Base: "base",
|
||||
ModelRecordOrderBy.Name: "name",
|
||||
ModelRecordOrderBy.Format: "format",
|
||||
}
|
||||
|
||||
def _fixup(summary: Dict[str, str]) -> Dict[str, Union[str, int, Set[str]]]:
|
||||
"""Fix up results so that there are no null values."""
|
||||
result: Dict[str, Union[str, int, Set[str]]] = {}
|
||||
for key, item in summary.items():
|
||||
result[key] = item or ""
|
||||
result["tags"] = set(json.loads(summary["tags"] or "[]"))
|
||||
return result
|
||||
|
||||
# Lock so that the database isn't updated while we're doing the two queries.
|
||||
with self._db.lock:
|
||||
# query1: get the total number of model configs
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
select count(*) from model_config;
|
||||
select count(*) from models;
|
||||
""",
|
||||
(),
|
||||
)
|
||||
total = int(self._cursor.fetchone()[0])
|
||||
|
||||
# query2: fetch key fields from the join of model_config and model_metadata
|
||||
# query2: fetch key fields
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
SELECT a.id as key, a.type, a.base, a.format, a.name,
|
||||
json_extract(a.config, '$.description') as description,
|
||||
json_extract(b.metadata, '$.tags') as tags
|
||||
FROM model_config AS a
|
||||
LEFT JOIN model_metadata AS b on a.id=b.id
|
||||
SELECT config
|
||||
FROM models
|
||||
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason
|
||||
LIMIT ?
|
||||
OFFSET ?;
|
||||
@ -395,7 +352,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
),
|
||||
)
|
||||
rows = self._cursor.fetchall()
|
||||
items = [ModelSummary.model_validate(_fixup(dict(x))) for x in rows]
|
||||
items = [ModelSummary.model_validate(dict(x)) for x in rows]
|
||||
return PaginatedResults(
|
||||
page=page, pages=ceil(total / per_page), per_page=per_page, total=total, items=items
|
||||
)
|
||||
|
@ -200,6 +200,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
self._invoker.services.logger.error(
|
||||
f"Error while invoking session {self._queue_item.session_id}, invocation {self._invocation.id} ({self._invocation.get_type()}):\n{e}"
|
||||
)
|
||||
self._invoker.services.logger.error(error)
|
||||
|
||||
# Send error event
|
||||
self._invoker.services.events.emit_invocation_error(
|
||||
|
@ -1,7 +1,7 @@
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
from PIL.Image import Image
|
||||
from torch import Tensor
|
||||
@ -13,15 +13,16 @@ from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||
from invokeai.app.services.images.images_common import ImageDTO
|
||||
from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.model_records.model_records_base import UnknownModelException
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel
|
||||
from invokeai.backend.model_manager.metadata.metadata_base import AnyModelRepoMetadata
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||
from invokeai.app.invocations.model import ModelField
|
||||
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
|
||||
|
||||
"""
|
||||
@ -65,75 +66,86 @@ class InvocationContextInterface:
|
||||
|
||||
class BoardsInterface(InvocationContextInterface):
|
||||
def create(self, board_name: str) -> BoardDTO:
|
||||
"""
|
||||
Creates a board.
|
||||
"""Creates a board.
|
||||
|
||||
:param board_name: The name of the board to create.
|
||||
Args:
|
||||
board_name: The name of the board to create.
|
||||
|
||||
Returns:
|
||||
The created board DTO.
|
||||
"""
|
||||
return self._services.boards.create(board_name)
|
||||
|
||||
def get_dto(self, board_id: str) -> BoardDTO:
|
||||
"""
|
||||
Gets a board DTO.
|
||||
"""Gets a board DTO.
|
||||
|
||||
:param board_id: The ID of the board to get.
|
||||
Args:
|
||||
board_id: The ID of the board to get.
|
||||
|
||||
Returns:
|
||||
The board DTO.
|
||||
"""
|
||||
return self._services.boards.get_dto(board_id)
|
||||
|
||||
def get_all(self) -> list[BoardDTO]:
|
||||
"""
|
||||
Gets all boards.
|
||||
"""Gets all boards.
|
||||
|
||||
Returns:
|
||||
A list of all boards.
|
||||
"""
|
||||
return self._services.boards.get_all()
|
||||
|
||||
def add_image_to_board(self, board_id: str, image_name: str) -> None:
|
||||
"""
|
||||
Adds an image to a board.
|
||||
"""Adds an image to a board.
|
||||
|
||||
:param board_id: The ID of the board to add the image to.
|
||||
:param image_name: The name of the image to add to the board.
|
||||
Args:
|
||||
board_id: The ID of the board to add the image to.
|
||||
image_name: The name of the image to add to the board.
|
||||
"""
|
||||
return self._services.board_images.add_image_to_board(board_id, image_name)
|
||||
|
||||
def get_all_image_names_for_board(self, board_id: str) -> list[str]:
|
||||
"""
|
||||
Gets all image names for a board.
|
||||
"""Gets all image names for a board.
|
||||
|
||||
:param board_id: The ID of the board to get the image names for.
|
||||
Args:
|
||||
board_id: The ID of the board to get the image names for.
|
||||
|
||||
Returns:
|
||||
A list of all image names for the board.
|
||||
"""
|
||||
return self._services.board_images.get_all_board_image_names_for_board(board_id)
|
||||
|
||||
|
||||
class LoggerInterface(InvocationContextInterface):
|
||||
def debug(self, message: str) -> None:
|
||||
"""
|
||||
Logs a debug message.
|
||||
"""Logs a debug message.
|
||||
|
||||
:param message: The message to log.
|
||||
Args:
|
||||
message: The message to log.
|
||||
"""
|
||||
self._services.logger.debug(message)
|
||||
|
||||
def info(self, message: str) -> None:
|
||||
"""
|
||||
Logs an info message.
|
||||
"""Logs an info message.
|
||||
|
||||
:param message: The message to log.
|
||||
Args:
|
||||
message: The message to log.
|
||||
"""
|
||||
self._services.logger.info(message)
|
||||
|
||||
def warning(self, message: str) -> None:
|
||||
"""
|
||||
Logs a warning message.
|
||||
"""Logs a warning message.
|
||||
|
||||
:param message: The message to log.
|
||||
Args:
|
||||
message: The message to log.
|
||||
"""
|
||||
self._services.logger.warning(message)
|
||||
|
||||
def error(self, message: str) -> None:
|
||||
"""
|
||||
Logs an error message.
|
||||
"""Logs an error message.
|
||||
|
||||
:param message: The message to log.
|
||||
Args:
|
||||
message: The message to log.
|
||||
"""
|
||||
self._services.logger.error(message)
|
||||
|
||||
@ -146,20 +158,23 @@ class ImagesInterface(InvocationContextInterface):
|
||||
image_category: ImageCategory = ImageCategory.GENERAL,
|
||||
metadata: Optional[MetadataField] = None,
|
||||
) -> ImageDTO:
|
||||
"""
|
||||
Saves an image, returning its DTO.
|
||||
"""Saves an image, returning its DTO.
|
||||
|
||||
If the current queue item has a workflow or metadata, it is automatically saved with the image.
|
||||
|
||||
:param image: The image to save, as a PIL image.
|
||||
:param board_id: The board ID to add the image to, if it should be added. It the invocation \
|
||||
Args:
|
||||
image: The image to save, as a PIL image.
|
||||
board_id: The board ID to add the image to, if it should be added. It the invocation \
|
||||
inherits from `WithBoard`, that board will be used automatically. **Use this only if \
|
||||
you want to override or provide a board manually!**
|
||||
:param image_category: The category of the image. Only the GENERAL category is added \
|
||||
image_category: The category of the image. Only the GENERAL category is added \
|
||||
to the gallery.
|
||||
:param metadata: The metadata to save with the image, if it should have any. If the \
|
||||
metadata: The metadata to save with the image, if it should have any. If the \
|
||||
invocation inherits from `WithMetadata`, that metadata will be used automatically. \
|
||||
**Use this only if you want to override or provide metadata manually!**
|
||||
|
||||
Returns:
|
||||
The saved image DTO.
|
||||
"""
|
||||
|
||||
# If `metadata` is provided directly, use that. Else, use the metadata provided by `WithMetadata`, falling back to None.
|
||||
@ -189,11 +204,14 @@ class ImagesInterface(InvocationContextInterface):
|
||||
)
|
||||
|
||||
def get_pil(self, image_name: str, mode: IMAGE_MODES | None = None) -> Image:
|
||||
"""
|
||||
Gets an image as a PIL Image object.
|
||||
"""Gets an image as a PIL Image object.
|
||||
|
||||
:param image_name: The name of the image to get.
|
||||
:param mode: The color mode to convert the image to. If None, the original mode is used.
|
||||
Args:
|
||||
image_name: The name of the image to get.
|
||||
mode: The color mode to convert the image to. If None, the original mode is used.
|
||||
|
||||
Returns:
|
||||
The image as a PIL Image object.
|
||||
"""
|
||||
image = self._services.images.get_pil_image(image_name)
|
||||
if mode and mode != image.mode:
|
||||
@ -206,158 +224,202 @@ class ImagesInterface(InvocationContextInterface):
|
||||
return image
|
||||
|
||||
def get_metadata(self, image_name: str) -> Optional[MetadataField]:
|
||||
"""
|
||||
Gets an image's metadata, if it has any.
|
||||
"""Gets an image's metadata, if it has any.
|
||||
|
||||
:param image_name: The name of the image to get the metadata for.
|
||||
Args:
|
||||
image_name: The name of the image to get the metadata for.
|
||||
|
||||
Returns:
|
||||
The image's metadata, if it has any.
|
||||
"""
|
||||
return self._services.images.get_metadata(image_name)
|
||||
|
||||
def get_dto(self, image_name: str) -> ImageDTO:
|
||||
"""
|
||||
Gets an image as an ImageDTO object.
|
||||
"""Gets an image as an ImageDTO object.
|
||||
|
||||
:param image_name: The name of the image to get.
|
||||
Args:
|
||||
image_name: The name of the image to get.
|
||||
|
||||
Returns:
|
||||
The image as an ImageDTO object.
|
||||
"""
|
||||
return self._services.images.get_dto(image_name)
|
||||
|
||||
|
||||
class TensorsInterface(InvocationContextInterface):
|
||||
def save(self, tensor: Tensor) -> str:
|
||||
"""
|
||||
Saves a tensor, returning its name.
|
||||
"""Saves a tensor, returning its name.
|
||||
|
||||
:param tensor: The tensor to save.
|
||||
Args:
|
||||
tensor: The tensor to save.
|
||||
|
||||
Returns:
|
||||
The name of the saved tensor.
|
||||
"""
|
||||
|
||||
name = self._services.tensors.save(obj=tensor)
|
||||
return name
|
||||
|
||||
def load(self, name: str) -> Tensor:
|
||||
"""
|
||||
Loads a tensor by name.
|
||||
"""Loads a tensor by name.
|
||||
|
||||
:param name: The name of the tensor to load.
|
||||
Args:
|
||||
name: The name of the tensor to load.
|
||||
|
||||
Returns:
|
||||
The loaded tensor.
|
||||
"""
|
||||
return self._services.tensors.load(name)
|
||||
|
||||
|
||||
class ConditioningInterface(InvocationContextInterface):
|
||||
def save(self, conditioning_data: ConditioningFieldData) -> str:
|
||||
"""
|
||||
Saves a conditioning data object, returning its name.
|
||||
"""Saves a conditioning data object, returning its name.
|
||||
|
||||
:param conditioning_data: The conditioning data to save.
|
||||
Args:
|
||||
conditioning_data: The conditioning data to save.
|
||||
|
||||
Returns:
|
||||
The name of the saved conditioning data.
|
||||
"""
|
||||
|
||||
name = self._services.conditioning.save(obj=conditioning_data)
|
||||
return name
|
||||
|
||||
def load(self, name: str) -> ConditioningFieldData:
|
||||
"""
|
||||
Loads conditioning data by name.
|
||||
"""Loads conditioning data by name.
|
||||
|
||||
:param name: The name of the conditioning data to load.
|
||||
Args:
|
||||
name: The name of the conditioning data to load.
|
||||
|
||||
Returns:
|
||||
The loaded conditioning data.
|
||||
"""
|
||||
|
||||
return self._services.conditioning.load(name)
|
||||
|
||||
|
||||
class ModelsInterface(InvocationContextInterface):
|
||||
def exists(self, key: str) -> bool:
|
||||
"""
|
||||
Checks if a model exists.
|
||||
def exists(self, identifier: Union[str, "ModelField"]) -> bool:
|
||||
"""Checks if a model exists.
|
||||
|
||||
:param key: The key of the model.
|
||||
"""
|
||||
return self._services.model_manager.store.exists(key)
|
||||
Args:
|
||||
identifier: The key or ModelField representing the model.
|
||||
|
||||
def load(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||
Returns:
|
||||
True if the model exists, False if not.
|
||||
"""
|
||||
Loads a model.
|
||||
if isinstance(identifier, str):
|
||||
return self._services.model_manager.store.exists(identifier)
|
||||
|
||||
:param key: The key of the model.
|
||||
:param submodel_type: The submodel of the model to get.
|
||||
:returns: An object representing the loaded model.
|
||||
return self._services.model_manager.store.exists(identifier.key)
|
||||
|
||||
def load(self, identifier: Union[str, "ModelField"], submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||
"""Loads a model.
|
||||
|
||||
Args:
|
||||
identifier: The key or ModelField representing the model.
|
||||
submodel_type: The submodel of the model to get.
|
||||
|
||||
Returns:
|
||||
An object representing the loaded model.
|
||||
"""
|
||||
|
||||
# The model manager emits events as it loads the model. It needs the context data to build
|
||||
# the event payloads.
|
||||
|
||||
return self._services.model_manager.load_model_by_key(
|
||||
key=key, submodel_type=submodel_type, context_data=self._data
|
||||
)
|
||||
if isinstance(identifier, str):
|
||||
model = self._services.model_manager.store.get_model(identifier)
|
||||
return self._services.model_manager.load.load_model(model, submodel_type, self._data)
|
||||
else:
|
||||
_submodel_type = submodel_type or identifier.submodel_type
|
||||
model = self._services.model_manager.store.get_model(identifier.key)
|
||||
return self._services.model_manager.load.load_model(model, _submodel_type, self._data)
|
||||
|
||||
def load_by_attrs(
|
||||
self, model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None
|
||||
self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
|
||||
) -> LoadedModel:
|
||||
"""
|
||||
Loads a model by its attributes.
|
||||
"""Loads a model by its attributes.
|
||||
|
||||
:param model_name: Name of to be fetched.
|
||||
:param base_model: Base model
|
||||
:param model_type: Type of the model
|
||||
:param submodel: For main (pipeline models), the submodel to fetch
|
||||
"""
|
||||
return self._services.model_manager.load_model_by_attr(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=submodel,
|
||||
context_data=self._data,
|
||||
)
|
||||
Args:
|
||||
name: Name of the model.
|
||||
base: The models' base type, e.g. `BaseModelType.StableDiffusion1`, `BaseModelType.StableDiffusionXL`, etc.
|
||||
type: Type of the model, e.g. `ModelType.Main`, `ModelType.Vae`, etc.
|
||||
submodel_type: The type of submodel to load, e.g. `SubModelType.UNet`, `SubModelType.TextEncoder`, etc. Only main
|
||||
models have submodels.
|
||||
|
||||
def get_config(self, key: str) -> AnyModelConfig:
|
||||
Returns:
|
||||
An object representing the loaded model.
|
||||
"""
|
||||
Gets a model's info, an dict-like object.
|
||||
|
||||
:param key: The key of the model.
|
||||
"""
|
||||
return self._services.model_manager.store.get_model(key=key)
|
||||
configs = self._services.model_manager.store.search_by_attr(model_name=name, base_model=base, model_type=type)
|
||||
if len(configs) == 0:
|
||||
raise UnknownModelException(f"No model found with name {name}, base {base}, and type {type}")
|
||||
|
||||
def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
|
||||
"""
|
||||
Gets a model's metadata, if it has any.
|
||||
if len(configs) > 1:
|
||||
raise ValueError(f"More than one model found with name {name}, base {base}, and type {type}")
|
||||
|
||||
:param key: The key of the model.
|
||||
return self._services.model_manager.load.load_model(configs[0], submodel_type, self._data)
|
||||
|
||||
def get_config(self, identifier: Union[str, "ModelField"]) -> AnyModelConfig:
|
||||
"""Gets a model's config.
|
||||
|
||||
Args:
|
||||
identifier: The key or ModelField representing the model.
|
||||
|
||||
Returns:
|
||||
The model's config.
|
||||
"""
|
||||
return self._services.model_manager.store.get_metadata(key=key)
|
||||
if isinstance(identifier, str):
|
||||
return self._services.model_manager.store.get_model(identifier)
|
||||
|
||||
return self._services.model_manager.store.get_model(identifier.key)
|
||||
|
||||
def search_by_path(self, path: Path) -> list[AnyModelConfig]:
|
||||
"""
|
||||
Searches for models by path.
|
||||
"""Searches for models by path.
|
||||
|
||||
:param path: The path to search for.
|
||||
Args:
|
||||
path: The path to search for.
|
||||
|
||||
Returns:
|
||||
A list of models that match the path.
|
||||
"""
|
||||
return self._services.model_manager.store.search_by_path(path)
|
||||
|
||||
def search_by_attrs(
|
||||
self,
|
||||
model_name: Optional[str] = None,
|
||||
base_model: Optional[BaseModelType] = None,
|
||||
model_type: Optional[ModelType] = None,
|
||||
model_format: Optional[ModelFormat] = None,
|
||||
name: Optional[str] = None,
|
||||
base: Optional[BaseModelType] = None,
|
||||
type: Optional[ModelType] = None,
|
||||
format: Optional[ModelFormat] = None,
|
||||
) -> list[AnyModelConfig]:
|
||||
"""
|
||||
Searches for models by attributes.
|
||||
"""Searches for models by attributes.
|
||||
|
||||
:param model_name: Name of to be fetched.
|
||||
:param base_model: Base model
|
||||
:param model_type: Type of the model
|
||||
:param submodel: For main (pipeline models), the submodel to fetch
|
||||
Args:
|
||||
name: The name to search for (exact match).
|
||||
base: The base to search for, e.g. `BaseModelType.StableDiffusion1`, `BaseModelType.StableDiffusionXL`, etc.
|
||||
type: Type type of model to search for, e.g. `ModelType.Main`, `ModelType.Vae`, etc.
|
||||
format: The format of model to search for, e.g. `ModelFormat.Checkpoint`, `ModelFormat.Diffusers`, etc.
|
||||
|
||||
Returns:
|
||||
A list of models that match the attributes.
|
||||
"""
|
||||
|
||||
return self._services.model_manager.store.search_by_attr(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
model_format=model_format,
|
||||
model_name=name,
|
||||
base_model=base,
|
||||
model_type=type,
|
||||
model_format=format,
|
||||
)
|
||||
|
||||
|
||||
class ConfigInterface(InvocationContextInterface):
|
||||
def get(self) -> InvokeAIAppConfig:
|
||||
"""Gets the app's config."""
|
||||
"""Gets the app's config.
|
||||
|
||||
Returns:
|
||||
The app's config.
|
||||
"""
|
||||
|
||||
return self._services.configuration.get_config()
|
||||
|
||||
@ -370,7 +432,11 @@ class UtilInterface(InvocationContextInterface):
|
||||
self._cancel_event = cancel_event
|
||||
|
||||
def is_canceled(self) -> bool:
|
||||
"""Checks if the current invocation has been canceled."""
|
||||
"""Checks if the current session has been canceled.
|
||||
|
||||
Returns:
|
||||
True if the current session has been canceled, False if not.
|
||||
"""
|
||||
return self._cancel_event.is_set()
|
||||
|
||||
def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_model: BaseModelType) -> None:
|
||||
@ -380,8 +446,9 @@ class UtilInterface(InvocationContextInterface):
|
||||
|
||||
This should be called after each denoising step.
|
||||
|
||||
:param intermediate_state: The intermediate state of the diffusion pipeline.
|
||||
:param base_model: The base model for the current denoising step.
|
||||
Args:
|
||||
intermediate_state: The intermediate state of the diffusion pipeline.
|
||||
base_model: The base model for the current denoising step.
|
||||
"""
|
||||
|
||||
stable_diffusion_step_callback(
|
||||
@ -394,8 +461,17 @@ class UtilInterface(InvocationContextInterface):
|
||||
|
||||
|
||||
class InvocationContext:
|
||||
"""
|
||||
The `InvocationContext` provides access to various services and data for the current invocation.
|
||||
"""Provides access to various services and data for the current invocation.
|
||||
|
||||
Attributes:
|
||||
images (ImagesInterface): Methods to save, get and update images and their metadata.
|
||||
tensors (TensorsInterface): Methods to save and get tensors, including image, noise, masks, and masked images.
|
||||
conditioning (ConditioningInterface): Methods to save and get conditioning data.
|
||||
models (ModelsInterface): Methods to check if a model exists, get a model, and get a model's info.
|
||||
logger (LoggerInterface): The app logger.
|
||||
config (ConfigInterface): The app config.
|
||||
util (UtilInterface): Utility methods, including a method to check if an invocation was canceled and step callbacks.
|
||||
boards (BoardsInterface): Methods to interact with boards.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -438,11 +514,14 @@ def build_invocation_context(
|
||||
data: InvocationContextData,
|
||||
cancel_event: threading.Event,
|
||||
) -> InvocationContext:
|
||||
"""
|
||||
Builds the invocation context for a specific invocation execution.
|
||||
"""Builds the invocation context for a specific invocation execution.
|
||||
|
||||
:param services: The invocation services to wrap.
|
||||
:param data: The invocation context data.
|
||||
Args:
|
||||
services: The invocation services to wrap.
|
||||
data: The invocation context data.
|
||||
|
||||
Returns:
|
||||
The invocation context.
|
||||
"""
|
||||
|
||||
logger = LoggerInterface(services=services, data=data)
|
||||
|
@ -9,6 +9,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_3 import
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_4 import build_migration_4
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_5 import build_migration_5
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_6 import build_migration_6
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_7 import build_migration_7
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
|
||||
|
||||
|
||||
@ -35,6 +36,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
|
||||
migrator.register_migration(build_migration_4())
|
||||
migrator.register_migration(build_migration_5())
|
||||
migrator.register_migration(build_migration_6())
|
||||
migrator.register_migration(build_migration_7())
|
||||
migrator.run_migrations()
|
||||
|
||||
return db
|
||||
|
@ -0,0 +1,88 @@
|
||||
import sqlite3
|
||||
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
|
||||
|
||||
|
||||
class Migration7Callback:
|
||||
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
||||
self._create_models_table(cursor)
|
||||
self._drop_old_models_tables(cursor)
|
||||
|
||||
def _drop_old_models_tables(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""Drops the old model_records, model_metadata, model_tags and tags tables."""
|
||||
|
||||
tables = ["model_records", "model_metadata", "model_tags", "tags"]
|
||||
|
||||
for table in tables:
|
||||
cursor.execute(f"DROP TABLE IF EXISTS {table};")
|
||||
|
||||
def _create_models_table(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""Creates the v4.0.0 models table."""
|
||||
|
||||
tables = [
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS models (
|
||||
id TEXT NOT NULL PRIMARY KEY,
|
||||
hash TEXT GENERATED ALWAYS as (json_extract(config, '$.hash')) VIRTUAL NOT NULL,
|
||||
base TEXT GENERATED ALWAYS as (json_extract(config, '$.base')) VIRTUAL NOT NULL,
|
||||
type TEXT GENERATED ALWAYS as (json_extract(config, '$.type')) VIRTUAL NOT NULL,
|
||||
path TEXT GENERATED ALWAYS as (json_extract(config, '$.path')) VIRTUAL NOT NULL,
|
||||
format TEXT GENERATED ALWAYS as (json_extract(config, '$.format')) VIRTUAL NOT NULL,
|
||||
name TEXT GENERATED ALWAYS as (json_extract(config, '$.name')) VIRTUAL NOT NULL,
|
||||
description TEXT GENERATED ALWAYS as (json_extract(config, '$.description')) VIRTUAL,
|
||||
source TEXT GENERATED ALWAYS as (json_extract(config, '$.source')) VIRTUAL NOT NULL,
|
||||
source_type TEXT GENERATED ALWAYS as (json_extract(config, '$.source_type')) VIRTUAL NOT NULL,
|
||||
source_api_response TEXT GENERATED ALWAYS as (json_extract(config, '$.source_api_response')) VIRTUAL,
|
||||
trigger_phrases TEXT GENERATED ALWAYS as (json_extract(config, '$.trigger_phrases')) VIRTUAL,
|
||||
-- Serialized JSON representation of the whole config object, which will contain additional fields from subclasses
|
||||
config TEXT NOT NULL,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Updated via trigger
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- unique constraint on combo of name, base and type
|
||||
UNIQUE(name, base, type)
|
||||
);
|
||||
"""
|
||||
]
|
||||
|
||||
# Add trigger for `updated_at`.
|
||||
triggers = [
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS models_updated_at
|
||||
AFTER UPDATE
|
||||
ON models FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE models SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE id = old.id;
|
||||
END;
|
||||
"""
|
||||
]
|
||||
|
||||
# Add indexes for searchable fields
|
||||
indices = [
|
||||
"CREATE INDEX IF NOT EXISTS base_index ON models(base);",
|
||||
"CREATE INDEX IF NOT EXISTS type_index ON models(type);",
|
||||
"CREATE INDEX IF NOT EXISTS name_index ON models(name);",
|
||||
"CREATE UNIQUE INDEX IF NOT EXISTS path_index ON models(path);",
|
||||
]
|
||||
|
||||
for stmt in tables + indices + triggers:
|
||||
cursor.execute(stmt)
|
||||
|
||||
|
||||
def build_migration_7() -> Migration:
|
||||
"""
|
||||
Build the migration from database version 6 to 7.
|
||||
|
||||
This migration does the following:
|
||||
- Adds the new models table
|
||||
- Drops the old model_records, model_metadata, model_tags and tags tables.
|
||||
- TODO(MM2): Migrates model names and descriptions from `models.yaml` to the new table (?).
|
||||
"""
|
||||
migration_7 = Migration(
|
||||
from_version=6,
|
||||
to_version=7,
|
||||
callback=Migration7Callback(),
|
||||
)
|
||||
|
||||
return migration_7
|
@ -3,7 +3,6 @@
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
from hashlib import sha1
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
@ -22,7 +21,7 @@ from invokeai.backend.model_manager.config import (
|
||||
ModelConfigFactory,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.hash import FastModelHash
|
||||
from invokeai.backend.model_manager.hash import ModelHash
|
||||
|
||||
ModelsValidator = TypeAdapter(AnyModelConfig)
|
||||
|
||||
@ -73,19 +72,27 @@ class MigrateModelYamlToDb1:
|
||||
|
||||
base_type, model_type, model_name = str(model_key).split("/")
|
||||
try:
|
||||
hash = FastModelHash.hash(self.config.models_path / stanza.path)
|
||||
hash = ModelHash().hash(self.config.models_path / stanza.path)
|
||||
except OSError:
|
||||
self.logger.warning(f"The model at {stanza.path} is not a valid file or directory. Skipping migration.")
|
||||
continue
|
||||
|
||||
assert isinstance(model_key, str)
|
||||
new_key = sha1(model_key.encode("utf-8")).hexdigest()
|
||||
|
||||
stanza["base"] = BaseModelType(base_type)
|
||||
stanza["type"] = ModelType(model_type)
|
||||
stanza["name"] = model_name
|
||||
stanza["original_hash"] = hash
|
||||
stanza["current_hash"] = hash
|
||||
new_key = hash # deterministic key assignment
|
||||
|
||||
# special case for ip adapters, which need the new `image_encoder_model_id` field
|
||||
if stanza["type"] == ModelType.IPAdapter:
|
||||
try:
|
||||
stanza["image_encoder_model_id"] = self._get_image_encoder_model_id(
|
||||
self.config.models_path / stanza.path
|
||||
)
|
||||
except OSError:
|
||||
self.logger.warning(f"Could not determine image encoder for {stanza.path}. Skipping.")
|
||||
continue
|
||||
|
||||
new_config: AnyModelConfig = ModelsValidator.validate_python(stanza) # type: ignore # see https://github.com/pydantic/pydantic/discussions/7094
|
||||
|
||||
@ -95,7 +102,7 @@ class MigrateModelYamlToDb1:
|
||||
self.logger.info(f"Updating model {model_name} with information from models.yaml using key {key}")
|
||||
self._update_model(key, new_config)
|
||||
else:
|
||||
self.logger.info(f"Adding model {model_name} with key {model_key}")
|
||||
self.logger.info(f"Adding model {model_name} with key {new_key}")
|
||||
self._add_model(new_key, new_config)
|
||||
except DuplicateModelException:
|
||||
self.logger.warning(f"Model {model_name} is already in the database")
|
||||
@ -143,9 +150,14 @@ class MigrateModelYamlToDb1:
|
||||
""",
|
||||
(
|
||||
key,
|
||||
record.original_hash,
|
||||
record.hash,
|
||||
json_serialized,
|
||||
),
|
||||
)
|
||||
except sqlite3.IntegrityError as exc:
|
||||
raise DuplicateModelException(f"{record.name}: model is already in database") from exc
|
||||
|
||||
def _get_image_encoder_model_id(self, model_path: Path) -> str:
|
||||
with open(model_path / "image_encoder.txt") as f:
|
||||
encoder = f.read()
|
||||
return encoder.strip()
|
||||
|
@ -17,8 +17,7 @@ class MigrateCallback(Protocol):
|
||||
See :class:`Migration` for an example.
|
||||
"""
|
||||
|
||||
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
||||
...
|
||||
def __call__(self, cursor: sqlite3.Cursor) -> None: ...
|
||||
|
||||
|
||||
class MigrationError(RuntimeError):
|
||||
|
@ -8,3 +8,8 @@ class UrlServiceBase(ABC):
|
||||
def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
|
||||
"""Gets the URL for an image or thumbnail."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_model_image_url(self, model_key: str) -> str:
|
||||
"""Gets the URL for a model image"""
|
||||
pass
|
||||
|
@ -4,8 +4,9 @@ from .urls_base import UrlServiceBase
|
||||
|
||||
|
||||
class LocalUrlService(UrlServiceBase):
|
||||
def __init__(self, base_url: str = "api/v1"):
|
||||
def __init__(self, base_url: str = "api/v1", base_url_v2: str = "api/v2"):
|
||||
self._base_url = base_url
|
||||
self._base_url_v2 = base_url_v2
|
||||
|
||||
def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
|
||||
image_basename = os.path.basename(image_name)
|
||||
@ -15,3 +16,6 @@ class LocalUrlService(UrlServiceBase):
|
||||
return f"{self._base_url}/images/i/{image_basename}/thumbnail"
|
||||
|
||||
return f"{self._base_url}/images/i/{image_basename}/full"
|
||||
|
||||
def get_model_image_url(self, model_key: str) -> str:
|
||||
return f"{self._base_url_v2}/models/i/{model_key}/image"
|
||||
|
@ -1,55 +0,0 @@
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
from invokeai.app.services.shared.graph import Edge
|
||||
|
||||
|
||||
def get_metadata_graph_from_raw_session(session_raw: str) -> Optional[dict]:
|
||||
"""
|
||||
Parses raw session string, returning a dict of the graph.
|
||||
|
||||
Only the general graph shape is validated; none of the fields are validated.
|
||||
|
||||
Any `metadata_accumulator` nodes and edges are removed.
|
||||
|
||||
Any validation failure will return None.
|
||||
"""
|
||||
|
||||
graph = json.loads(session_raw).get("graph", None)
|
||||
|
||||
# sanity check make sure the graph is at least reasonably shaped
|
||||
if (
|
||||
not isinstance(graph, dict)
|
||||
or "nodes" not in graph
|
||||
or not isinstance(graph["nodes"], dict)
|
||||
or "edges" not in graph
|
||||
or not isinstance(graph["edges"], list)
|
||||
):
|
||||
# something has gone terribly awry, return an empty dict
|
||||
return None
|
||||
|
||||
try:
|
||||
# delete the `metadata_accumulator` node
|
||||
del graph["nodes"]["metadata_accumulator"]
|
||||
except KeyError:
|
||||
# no accumulator node, all good
|
||||
pass
|
||||
|
||||
# delete any edges to or from it
|
||||
for i, edge in enumerate(graph["edges"]):
|
||||
try:
|
||||
# try to parse the edge
|
||||
Edge(**edge)
|
||||
except ValidationError:
|
||||
# something has gone terribly awry, return an empty dict
|
||||
return None
|
||||
|
||||
if (
|
||||
edge["source"]["node_id"] == "metadata_accumulator"
|
||||
or edge["destination"]["node_id"] == "metadata_accumulator"
|
||||
):
|
||||
del graph["edges"][i]
|
||||
|
||||
return graph
|
@ -22,7 +22,7 @@ def generate_ti_list(
|
||||
for trigger in extract_ti_triggers_from_prompt(prompt):
|
||||
name_or_key = trigger[1:-1]
|
||||
try:
|
||||
loaded_model = context.models.load(key=name_or_key)
|
||||
loaded_model = context.models.load(name_or_key)
|
||||
model = loaded_model.model
|
||||
assert isinstance(model, TextualInversionModelRaw)
|
||||
assert loaded_model.config.base == base
|
||||
@ -30,7 +30,7 @@ def generate_ti_list(
|
||||
except UnknownModelException:
|
||||
try:
|
||||
loaded_model = context.models.load_by_attrs(
|
||||
model_name=name_or_key, base_model=base, model_type=ModelType.TextualInversion
|
||||
name=name_or_key, base=base, type=ModelType.TextualInversion
|
||||
)
|
||||
model = loaded_model.model
|
||||
assert isinstance(model, TextualInversionModelRaw)
|
||||
|
@ -1,6 +1,7 @@
|
||||
"""
|
||||
Initialization file for invokeai.backend.image_util methods.
|
||||
"""
|
||||
|
||||
from .patchmatch import PatchMatch # noqa: F401
|
||||
from .pngwriter import PngWriter, PromptFormatter, retrieve_metadata, write_metadata # noqa: F401
|
||||
from .seamless import configure_model_padding # noqa: F401
|
||||
|
@ -3,6 +3,7 @@ This module defines a singleton object, "invisible_watermark" that
|
||||
wraps the invisible watermark model. It respects the global "invisible_watermark"
|
||||
configuration variable, that allows the watermarking to be supressed.
|
||||
"""
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from imwatermark import WatermarkEncoder
|
||||
|
@ -4,6 +4,7 @@ wraps the actual patchmatch object. It respects the global
|
||||
"try_patchmatch" attribute, so that patchmatch loading can
|
||||
be suppressed or deferred
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
@ -6,6 +6,7 @@ PngWriter -- Converts Images generated by T2I into PNGs, finds
|
||||
|
||||
Exports function retrieve_metadata(path)
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
|
@ -3,6 +3,7 @@ This module defines a singleton object, "safety_checker" that
|
||||
wraps the safety_checker model. It respects the global "nsfw_checker"
|
||||
configuration variable, that allows the checker to be supressed.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
"""
|
||||
Check that the invokeai_root is correctly configured and exit if not.
|
||||
"""
|
||||
|
||||
import sys
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
@ -1,4 +1,5 @@
|
||||
"""Utility (backend) functions used by model_install.py"""
|
||||
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
@ -18,7 +19,6 @@ from invokeai.app.services.model_install import (
|
||||
ModelInstallService,
|
||||
ModelInstallServiceBase,
|
||||
)
|
||||
from invokeai.app.services.model_metadata import ModelMetadataStoreSQL
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
|
||||
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
|
||||
from invokeai.backend.model_manager import (
|
||||
@ -38,7 +38,7 @@ def initialize_record_store(app_config: InvokeAIAppConfig) -> ModelRecordService
|
||||
logger = InvokeAILogger.get_logger(config=app_config)
|
||||
image_files = DiskImageFileStorage(f"{app_config.output_path}/images")
|
||||
db = init_db(config=app_config, logger=logger, image_files=image_files)
|
||||
obj: ModelRecordServiceBase = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
|
||||
obj: ModelRecordServiceBase = ModelRecordServiceSQL(db)
|
||||
return obj
|
||||
|
||||
|
||||
|
@ -1,4 +1,5 @@
|
||||
"""Re-export frequently-used symbols from the Model Manager backend."""
|
||||
|
||||
from .config import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
|
@ -19,15 +19,19 @@ Typical usage:
|
||||
Validation errors will raise an InvalidModelConfigException error.
|
||||
|
||||
"""
|
||||
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import Literal, Optional, Type, Union
|
||||
|
||||
import torch
|
||||
from diffusers import ModelMixin
|
||||
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
|
||||
from typing_extensions import Annotated, Any, Dict
|
||||
|
||||
from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
|
||||
from ..raw_model import RawModel
|
||||
|
||||
# ModelMixin is the base class for all diffusers and transformers models
|
||||
@ -55,8 +59,8 @@ class ModelType(str, Enum):
|
||||
|
||||
ONNX = "onnx"
|
||||
Main = "main"
|
||||
Vae = "vae"
|
||||
Lora = "lora"
|
||||
VAE = "vae"
|
||||
LoRA = "lora"
|
||||
ControlNet = "controlnet" # used by model_probe
|
||||
TextualInversion = "embedding"
|
||||
IPAdapter = "ip_adapter"
|
||||
@ -72,9 +76,9 @@ class SubModelType(str, Enum):
|
||||
TextEncoder2 = "text_encoder_2"
|
||||
Tokenizer = "tokenizer"
|
||||
Tokenizer2 = "tokenizer_2"
|
||||
Vae = "vae"
|
||||
VaeDecoder = "vae_decoder"
|
||||
VaeEncoder = "vae_encoder"
|
||||
VAE = "vae"
|
||||
VAEDecoder = "vae_decoder"
|
||||
VAEEncoder = "vae_encoder"
|
||||
Scheduler = "scheduler"
|
||||
SafetyChecker = "safety_checker"
|
||||
|
||||
@ -92,8 +96,8 @@ class ModelFormat(str, Enum):
|
||||
|
||||
Diffusers = "diffusers"
|
||||
Checkpoint = "checkpoint"
|
||||
Lycoris = "lycoris"
|
||||
Onnx = "onnx"
|
||||
LyCORIS = "lycoris"
|
||||
ONNX = "onnx"
|
||||
Olive = "olive"
|
||||
EmbeddingFile = "embedding_file"
|
||||
EmbeddingFolder = "embedding_folder"
|
||||
@ -111,128 +115,188 @@ class SchedulerPredictionType(str, Enum):
|
||||
class ModelRepoVariant(str, Enum):
|
||||
"""Various hugging face variants on the diffusers format."""
|
||||
|
||||
DEFAULT = "" # model files without "fp16" or other qualifier - empty str
|
||||
Default = "" # model files without "fp16" or other qualifier - empty str
|
||||
FP16 = "fp16"
|
||||
FP32 = "fp32"
|
||||
ONNX = "onnx"
|
||||
OPENVINO = "openvino"
|
||||
FLAX = "flax"
|
||||
OpenVINO = "openvino"
|
||||
Flax = "flax"
|
||||
|
||||
|
||||
class ModelSourceType(str, Enum):
|
||||
"""Model source type."""
|
||||
|
||||
Path = "path"
|
||||
Url = "url"
|
||||
HFRepoID = "hf_repo_id"
|
||||
CivitAI = "civitai"
|
||||
|
||||
|
||||
class ModelDefaultSettings(BaseModel):
|
||||
vae: str | None
|
||||
vae_precision: str | None
|
||||
scheduler: SCHEDULER_NAME_VALUES | None
|
||||
steps: int | None
|
||||
cfg_scale: float | None
|
||||
cfg_rescale_multiplier: float | None
|
||||
|
||||
|
||||
class ModelConfigBase(BaseModel):
|
||||
"""Base class for model configuration information."""
|
||||
|
||||
path: str = Field(description="filesystem path to the model file or directory")
|
||||
name: str = Field(description="model name")
|
||||
base: BaseModelType = Field(description="base model")
|
||||
type: ModelType = Field(description="type of the model")
|
||||
format: ModelFormat = Field(description="model format")
|
||||
key: str = Field(description="unique key for model", default="<NOKEY>")
|
||||
original_hash: Optional[str] = Field(
|
||||
description="original fasthash of model contents", default=None
|
||||
) # this is assigned at install time and will not change
|
||||
current_hash: Optional[str] = Field(
|
||||
description="current fasthash of model contents", default=None
|
||||
) # if model is converted or otherwise modified, this will hold updated hash
|
||||
description: Optional[str] = Field(description="human readable description of the model", default=None)
|
||||
source: Optional[str] = Field(description="model original source (path, URL or repo_id)", default=None)
|
||||
last_modified: Optional[float] = Field(description="timestamp for modification time", default_factory=time.time)
|
||||
key: str = Field(description="A unique key for this model.", default_factory=uuid_string)
|
||||
hash: str = Field(description="The hash of the model file(s).")
|
||||
path: str = Field(
|
||||
description="Path to the model on the filesystem. Relative paths are relative to the Invoke root directory."
|
||||
)
|
||||
name: str = Field(description="Name of the model.")
|
||||
base: BaseModelType = Field(description="The base model.")
|
||||
description: Optional[str] = Field(description="Model description", default=None)
|
||||
source: str = Field(description="The original source of the model (path, URL or repo_id).")
|
||||
source_type: ModelSourceType = Field(description="The type of source")
|
||||
source_api_response: Optional[str] = Field(
|
||||
description="The original API response from the source, as stringified JSON.", default=None
|
||||
)
|
||||
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
|
||||
default_settings: Optional[ModelDefaultSettings] = Field(
|
||||
description="Default settings for this model", default=None
|
||||
)
|
||||
cover_image: Optional[str] = Field(description="Url for image to preview model", default=None)
|
||||
|
||||
@staticmethod
|
||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
||||
schema["required"].extend(
|
||||
["key", "base", "type", "format", "original_hash", "current_hash", "source", "last_modified"]
|
||||
)
|
||||
schema["required"].extend(["key", "type", "format"])
|
||||
|
||||
model_config = ConfigDict(
|
||||
use_enum_values=False,
|
||||
validate_assignment=True,
|
||||
json_schema_extra=json_schema_extra,
|
||||
)
|
||||
|
||||
def update(self, attributes: Dict[str, Any]) -> None:
|
||||
"""Update the object with fields in dict."""
|
||||
for key, value in attributes.items():
|
||||
setattr(self, key, value) # may raise a validation error
|
||||
model_config = ConfigDict(validate_assignment=True, json_schema_extra=json_schema_extra)
|
||||
|
||||
|
||||
class _CheckpointConfig(ModelConfigBase):
|
||||
class CheckpointConfigBase(ModelConfigBase):
|
||||
"""Model config for checkpoint-style models."""
|
||||
|
||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||
config: str = Field(description="path to the checkpoint model config file")
|
||||
config_path: str = Field(description="path to the checkpoint model config file")
|
||||
converted_at: Optional[float] = Field(
|
||||
description="When this model was last converted to diffusers", default_factory=time.time
|
||||
)
|
||||
|
||||
|
||||
class _DiffusersConfig(ModelConfigBase):
|
||||
class DiffusersConfigBase(ModelConfigBase):
|
||||
"""Model config for diffusers-style models."""
|
||||
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.DEFAULT
|
||||
repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.Default
|
||||
|
||||
|
||||
class LoRAConfig(ModelConfigBase):
|
||||
class LoRALyCORISConfig(ModelConfigBase):
|
||||
"""Model config for LoRA/Lycoris models."""
|
||||
|
||||
type: Literal[ModelType.Lora] = ModelType.Lora
|
||||
format: Literal[ModelFormat.Lycoris, ModelFormat.Diffusers]
|
||||
type: Literal[ModelType.LoRA] = ModelType.LoRA
|
||||
format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.LoRA.value}.{ModelFormat.LyCORIS.value}")
|
||||
|
||||
|
||||
class VaeCheckpointConfig(ModelConfigBase):
|
||||
class LoRADiffusersConfig(ModelConfigBase):
|
||||
"""Model config for LoRA/Diffusers models."""
|
||||
|
||||
type: Literal[ModelType.LoRA] = ModelType.LoRA
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.LoRA.value}.{ModelFormat.Diffusers.value}")
|
||||
|
||||
|
||||
class VAECheckpointConfig(CheckpointConfigBase):
|
||||
"""Model config for standalone VAE models."""
|
||||
|
||||
type: Literal[ModelType.Vae] = ModelType.Vae
|
||||
type: Literal[ModelType.VAE] = ModelType.VAE
|
||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.VAE.value}.{ModelFormat.Checkpoint.value}")
|
||||
|
||||
class VaeDiffusersConfig(ModelConfigBase):
|
||||
|
||||
class VAEDiffusersConfig(ModelConfigBase):
|
||||
"""Model config for standalone VAE models (diffusers version)."""
|
||||
|
||||
type: Literal[ModelType.Vae] = ModelType.Vae
|
||||
type: Literal[ModelType.VAE] = ModelType.VAE
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.VAE.value}.{ModelFormat.Diffusers.value}")
|
||||
|
||||
class ControlNetDiffusersConfig(_DiffusersConfig):
|
||||
|
||||
class ControlNetDiffusersConfig(DiffusersConfigBase):
|
||||
"""Model config for ControlNet models (diffusers version)."""
|
||||
|
||||
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.ControlNet.value}.{ModelFormat.Diffusers.value}")
|
||||
|
||||
class ControlNetCheckpointConfig(_CheckpointConfig):
|
||||
|
||||
class ControlNetCheckpointConfig(CheckpointConfigBase):
|
||||
"""Model config for ControlNet models (diffusers version)."""
|
||||
|
||||
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
|
||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.ControlNet.value}.{ModelFormat.Checkpoint.value}")
|
||||
|
||||
class TextualInversionConfig(ModelConfigBase):
|
||||
|
||||
class TextualInversionFileConfig(ModelConfigBase):
|
||||
"""Model config for textual inversion embeddings."""
|
||||
|
||||
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
|
||||
format: Literal[ModelFormat.EmbeddingFile, ModelFormat.EmbeddingFolder]
|
||||
format: Literal[ModelFormat.EmbeddingFile] = ModelFormat.EmbeddingFile
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFile.value}")
|
||||
|
||||
|
||||
class _MainConfig(ModelConfigBase):
|
||||
"""Model config for main models."""
|
||||
class TextualInversionFolderConfig(ModelConfigBase):
|
||||
"""Model config for textual inversion embeddings."""
|
||||
|
||||
vae: Optional[str] = Field(default=None)
|
||||
variant: ModelVariantType = ModelVariantType.Normal
|
||||
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
|
||||
upcast_attention: bool = False
|
||||
ztsnr_training: bool = False
|
||||
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
|
||||
format: Literal[ModelFormat.EmbeddingFolder] = ModelFormat.EmbeddingFolder
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFolder.value}")
|
||||
|
||||
|
||||
class MainCheckpointConfig(_CheckpointConfig, _MainConfig):
|
||||
class MainCheckpointConfig(CheckpointConfigBase):
|
||||
"""Model config for main checkpoint models."""
|
||||
|
||||
type: Literal[ModelType.Main] = ModelType.Main
|
||||
variant: ModelVariantType = ModelVariantType.Normal
|
||||
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
|
||||
upcast_attention: bool = False
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.Main.value}.{ModelFormat.Checkpoint.value}")
|
||||
|
||||
|
||||
class MainDiffusersConfig(_DiffusersConfig, _MainConfig):
|
||||
class MainDiffusersConfig(DiffusersConfigBase):
|
||||
"""Model config for main diffusers models."""
|
||||
|
||||
type: Literal[ModelType.Main] = ModelType.Main
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.Main.value}.{ModelFormat.Diffusers.value}")
|
||||
|
||||
|
||||
class IPAdapterConfig(ModelConfigBase):
|
||||
"""Model config for IP Adaptor format models."""
|
||||
@ -241,65 +305,76 @@ class IPAdapterConfig(ModelConfigBase):
|
||||
image_encoder_model_id: str
|
||||
format: Literal[ModelFormat.InvokeAI]
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.IPAdapter.value}.{ModelFormat.InvokeAI.value}")
|
||||
|
||||
|
||||
class CLIPVisionDiffusersConfig(ModelConfigBase):
|
||||
"""Model config for ClipVision."""
|
||||
"""Model config for CLIPVision."""
|
||||
|
||||
type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision
|
||||
format: Literal[ModelFormat.Diffusers]
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.CLIPVision.value}.{ModelFormat.Diffusers.value}")
|
||||
|
||||
class T2IConfig(ModelConfigBase):
|
||||
|
||||
class T2IAdapterConfig(ModelConfigBase):
|
||||
"""Model config for T2I."""
|
||||
|
||||
type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter
|
||||
format: Literal[ModelFormat.Diffusers]
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.T2IAdapter.value}.{ModelFormat.Diffusers.value}")
|
||||
|
||||
_ControlNetConfig = Annotated[
|
||||
Union[ControlNetDiffusersConfig, ControlNetCheckpointConfig],
|
||||
Field(discriminator="format"),
|
||||
]
|
||||
_VaeConfig = Annotated[Union[VaeDiffusersConfig, VaeCheckpointConfig], Field(discriminator="format")]
|
||||
_MainModelConfig = Annotated[Union[MainDiffusersConfig, MainCheckpointConfig], Field(discriminator="format")]
|
||||
|
||||
AnyModelConfig = Union[
|
||||
_MainModelConfig,
|
||||
_VaeConfig,
|
||||
_ControlNetConfig,
|
||||
# ModelConfigBase,
|
||||
LoRAConfig,
|
||||
TextualInversionConfig,
|
||||
IPAdapterConfig,
|
||||
CLIPVisionDiffusersConfig,
|
||||
T2IConfig,
|
||||
def get_model_discriminator_value(v: Any) -> str:
|
||||
"""
|
||||
Computes the discriminator value for a model config.
|
||||
https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-callable-discriminator
|
||||
"""
|
||||
format_ = None
|
||||
type_ = None
|
||||
if isinstance(v, dict):
|
||||
format_ = v.get("format")
|
||||
if isinstance(format_, Enum):
|
||||
format_ = format_.value
|
||||
type_ = v.get("type")
|
||||
if isinstance(type_, Enum):
|
||||
type_ = type_.value
|
||||
else:
|
||||
format_ = v.format.value
|
||||
type_ = v.type.value
|
||||
v = f"{type_}.{format_}"
|
||||
return v
|
||||
|
||||
|
||||
AnyModelConfig = Annotated[
|
||||
Union[
|
||||
Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
|
||||
Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()],
|
||||
Annotated[VAEDiffusersConfig, VAEDiffusersConfig.get_tag()],
|
||||
Annotated[VAECheckpointConfig, VAECheckpointConfig.get_tag()],
|
||||
Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
|
||||
Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()],
|
||||
Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()],
|
||||
Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
|
||||
Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()],
|
||||
Annotated[TextualInversionFolderConfig, TextualInversionFolderConfig.get_tag()],
|
||||
Annotated[IPAdapterConfig, IPAdapterConfig.get_tag()],
|
||||
Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
|
||||
Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
|
||||
],
|
||||
Discriminator(get_model_discriminator_value),
|
||||
]
|
||||
|
||||
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
|
||||
|
||||
|
||||
# IMPLEMENTATION NOTE:
|
||||
# The preferred alternative to the above is a discriminated Union as shown
|
||||
# below. However, it breaks FastAPI when used as the input Body parameter in a route.
|
||||
# This is a known issue. Please see:
|
||||
# https://github.com/tiangolo/fastapi/discussions/9761 and
|
||||
# https://github.com/tiangolo/fastapi/discussions/9287
|
||||
# AnyModelConfig = Annotated[
|
||||
# Union[
|
||||
# _MainModelConfig,
|
||||
# _ONNXConfig,
|
||||
# _VaeConfig,
|
||||
# _ControlNetConfig,
|
||||
# LoRAConfig,
|
||||
# TextualInversionConfig,
|
||||
# IPAdapterConfig,
|
||||
# CLIPVisionDiffusersConfig,
|
||||
# T2IConfig,
|
||||
# ],
|
||||
# Field(discriminator="type"),
|
||||
# ]
|
||||
|
||||
|
||||
class ModelConfigFactory(object):
|
||||
"""Class for parsing config dicts into StableDiffusion Config obects."""
|
||||
|
||||
@ -331,6 +406,6 @@ class ModelConfigFactory(object):
|
||||
assert model is not None
|
||||
if key:
|
||||
model.key = key
|
||||
if timestamp:
|
||||
model.last_modified = timestamp
|
||||
if isinstance(model, CheckpointConfigBase) and timestamp is not None:
|
||||
model.converted_at = timestamp
|
||||
return model # type: ignore
|
||||
|
@ -15,7 +15,7 @@
|
||||
#
|
||||
# Adapted for use in InvokeAI by Lincoln Stein, July 2023
|
||||
#
|
||||
""" Conversion script for the Stable Diffusion checkpoints."""
|
||||
"""Conversion script for the Stable Diffusion checkpoints."""
|
||||
|
||||
import re
|
||||
from contextlib import nullcontext
|
||||
|
@ -11,56 +11,178 @@ from invokeai.backend.model_managre.model_hash import FastModelHash
|
||||
import hashlib
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, Union
|
||||
|
||||
from imohash import hashfile
|
||||
from typing import Callable, Literal, Optional, Union
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
|
||||
class FastModelHash(object):
|
||||
"""FastModelHash obect provides one public class method, hash()."""
|
||||
from blake3 import blake3
|
||||
|
||||
@classmethod
|
||||
def hash(cls, model_location: Union[str, Path]) -> str:
|
||||
MODEL_FILE_EXTENSIONS = (".ckpt", ".safetensors", ".bin", ".pt", ".pth")
|
||||
|
||||
ALGORITHM = Literal[
|
||||
"md5",
|
||||
"sha1",
|
||||
"sha224",
|
||||
"sha256",
|
||||
"sha384",
|
||||
"sha512",
|
||||
"blake2b",
|
||||
"blake2s",
|
||||
"sha3_224",
|
||||
"sha3_256",
|
||||
"sha3_384",
|
||||
"sha3_512",
|
||||
"shake_128",
|
||||
"shake_256",
|
||||
"blake3",
|
||||
]
|
||||
|
||||
|
||||
class ModelHash:
|
||||
"""
|
||||
Return hexdigest string for model located at model_location.
|
||||
Creates a hash of a model using a specified algorithm.
|
||||
|
||||
:param model_location: Path to the model
|
||||
Args:
|
||||
algorithm: Hashing algorithm to use. Defaults to BLAKE3.
|
||||
file_filter: A function that takes a file name and returns True if the file should be included in the hash.
|
||||
|
||||
If the model is a single file, it is hashed directly using the provided algorithm.
|
||||
|
||||
If the model is a directory, each model weights file in the directory is hashed using the provided algorithm.
|
||||
|
||||
Only files with the following extensions are hashed: .ckpt, .safetensors, .bin, .pt, .pth
|
||||
|
||||
The final hash is computed by hashing the hashes of all model files in the directory using BLAKE3, ensuring
|
||||
that directory hashes are never weaker than the file hashes.
|
||||
|
||||
Usage:
|
||||
```py
|
||||
# BLAKE3 hash
|
||||
ModelHash().hash("path/to/some/model.safetensors")
|
||||
# MD5
|
||||
ModelHash("md5").hash("path/to/model/dir/")
|
||||
```
|
||||
"""
|
||||
model_location = Path(model_location)
|
||||
if model_location.is_file():
|
||||
return cls._hash_file(model_location)
|
||||
elif model_location.is_dir():
|
||||
return cls._hash_dir(model_location)
|
||||
|
||||
def __init__(self, algorithm: ALGORITHM = "blake3", file_filter: Optional[Callable[[str], bool]] = None) -> None:
|
||||
if algorithm == "blake3":
|
||||
self._hash_file = self._blake3
|
||||
elif algorithm in hashlib.algorithms_available:
|
||||
self._hash_file = self._get_hashlib(algorithm)
|
||||
else:
|
||||
raise OSError(f"Not a valid file or directory: {model_location}")
|
||||
raise ValueError(f"Algorithm {algorithm} not available")
|
||||
|
||||
@classmethod
|
||||
def _hash_file(cls, model_location: Union[str, Path]) -> str:
|
||||
self._file_filter = file_filter or self._default_file_filter
|
||||
|
||||
def hash(self, model_path: Union[str, Path]) -> str:
|
||||
"""
|
||||
Fasthash a single file and return its hexdigest.
|
||||
Return hexdigest of hash of model located at model_path using the algorithm provided at class instantiation.
|
||||
|
||||
:param model_location: Path to the model file
|
||||
If model_path is a directory, the hash is computed by hashing the hashes of all model files in the
|
||||
directory. The final composite hash is always computed using BLAKE3.
|
||||
|
||||
Args:
|
||||
model_path: Path to the model
|
||||
|
||||
Returns:
|
||||
str: Hexdigest of the hash of the model
|
||||
"""
|
||||
# we return md5 hash of the filehash to make it shorter
|
||||
# cryptographic security not needed here
|
||||
return hashlib.md5(hashfile(model_location)).hexdigest()
|
||||
|
||||
@classmethod
|
||||
def _hash_dir(cls, model_location: Union[str, Path]) -> str:
|
||||
components: Dict[str, str] = {}
|
||||
model_path = Path(model_path)
|
||||
if model_path.is_file():
|
||||
return self._hash_file(model_path)
|
||||
elif model_path.is_dir():
|
||||
return self._hash_dir(model_path)
|
||||
else:
|
||||
raise OSError(f"Not a valid file or directory: {model_path}")
|
||||
|
||||
for root, _dirs, files in os.walk(model_location):
|
||||
for file in files:
|
||||
# only tally tensor files because diffusers config files change slightly
|
||||
# depending on how the model was downloaded/converted.
|
||||
if not file.endswith((".ckpt", ".safetensors", ".bin", ".pt", ".pth")):
|
||||
continue
|
||||
path = (Path(root) / file).as_posix()
|
||||
fast_hash = cls._hash_file(path)
|
||||
components.update({path: fast_hash})
|
||||
def _hash_dir(self, dir: Path) -> str:
|
||||
"""Compute the hash for all files in a directory and return a hexdigest.
|
||||
|
||||
# hash all the model hashes together, using alphabetic file order
|
||||
md5 = hashlib.md5()
|
||||
for _path, fast_hash in sorted(components.items()):
|
||||
md5.update(fast_hash.encode("utf-8"))
|
||||
return md5.hexdigest()
|
||||
Args:
|
||||
dir: Path to the directory
|
||||
|
||||
Returns:
|
||||
str: Hexdigest of the hash of the directory
|
||||
"""
|
||||
model_component_paths = self._get_file_paths(dir, self._file_filter)
|
||||
|
||||
# Use ThreadPoolExecutor to hash files in parallel
|
||||
with ThreadPoolExecutor(min(((os.cpu_count() or 1) + 4), len(model_component_paths))) as executor:
|
||||
future_to_component = {executor.submit(self._hash_file, component): component for component in sorted(model_component_paths)}
|
||||
component_hashes = [future.result() for future in as_completed(future_to_component)]
|
||||
|
||||
# BLAKE3 to hash the hashes
|
||||
composite_hasher = blake3()
|
||||
component_hashes.sort()
|
||||
for h in component_hashes:
|
||||
composite_hasher.update(h.encode("utf-8"))
|
||||
return composite_hasher.hexdigest()
|
||||
|
||||
@staticmethod
|
||||
def _get_file_paths(model_path: Path, file_filter: Callable[[str], bool]) -> list[Path]:
|
||||
"""Return a list of all model files in the directory.
|
||||
|
||||
Args:
|
||||
model_path: Path to the model
|
||||
file_filter: Function that takes a file name and returns True if the file should be included in the list.
|
||||
|
||||
Returns:
|
||||
List of all model files in the directory
|
||||
"""
|
||||
|
||||
files: list[Path] = []
|
||||
entries = [entry for entry in os.scandir(model_path.as_posix()) if not entry.name.startswith(".")]
|
||||
dirs = [entry for entry in entries if entry.is_dir()]
|
||||
file_paths = [entry.path for entry in entries if entry.is_file() and file_filter(entry.path)]
|
||||
files.extend([Path(file) for file in file_paths])
|
||||
for dir in dirs:
|
||||
files.extend(ModelHash._get_file_paths(Path(dir.path), file_filter))
|
||||
return files
|
||||
|
||||
@staticmethod
|
||||
def _blake3(file_path: Path) -> str:
|
||||
"""Hashes a file using BLAKE3
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to hash
|
||||
|
||||
Returns:
|
||||
Hexdigest of the hash of the file
|
||||
"""
|
||||
file_hasher = blake3(max_threads=blake3.AUTO)
|
||||
file_hasher.update_mmap(file_path)
|
||||
return file_hasher.hexdigest()
|
||||
|
||||
@staticmethod
|
||||
def _get_hashlib(algorithm: ALGORITHM) -> Callable[[Path], str]:
|
||||
"""Factory function that returns a function to hash a file with the given algorithm.
|
||||
|
||||
Args:
|
||||
algorithm: Hashing algorithm to use
|
||||
|
||||
Returns:
|
||||
A function that hashes a file using the given algorithm
|
||||
"""
|
||||
|
||||
def hashlib_hasher(file_path: Path) -> str:
|
||||
"""Hashes a file using a hashlib algorithm."""
|
||||
hasher = hashlib.new(algorithm)
|
||||
with open(file_path, "rb") as f:
|
||||
for chunk in iter(lambda: f.read(8 * 1024), b""):
|
||||
hasher.update(chunk)
|
||||
return hasher.hexdigest()
|
||||
|
||||
return hashlib_hasher
|
||||
|
||||
@staticmethod
|
||||
def _default_file_filter(file_path: str) -> bool:
|
||||
"""A default file filter that only includes files with the following extensions: .ckpt, .safetensors, .bin, .pt, .pth
|
||||
|
||||
Args:
|
||||
file_path: Path to the file
|
||||
|
||||
Returns:
|
||||
True if the file matches the given extensions, otherwise False
|
||||
"""
|
||||
return file_path.endswith(MODEL_FILE_EXTENSIONS)
|
||||
|
@ -2,6 +2,7 @@
|
||||
"""
|
||||
Init file for the model loader.
|
||||
"""
|
||||
|
||||
from importlib import import_module
|
||||
from pathlib import Path
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
"""
|
||||
Disk-based converted model cache.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
|
||||
|
@ -13,6 +13,7 @@ from invokeai.backend.model_manager import (
|
||||
ModelRepoVariant,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import DiffusersConfigBase, ModelType
|
||||
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
|
||||
@ -50,7 +51,7 @@ class ModelLoader(ModelLoaderBase):
|
||||
:param submodel_type: an ModelType enum indicating the portion of
|
||||
the model to retrieve (e.g. ModelType.Vae)
|
||||
"""
|
||||
if model_config.type == "main" and not submodel_type:
|
||||
if model_config.type is ModelType.Main and not submodel_type:
|
||||
raise InvalidModelConfigException("submodel_type is required when loading a main model")
|
||||
|
||||
model_path, model_config, submodel_type = self._get_model_path(model_config, submodel_type)
|
||||
@ -80,7 +81,7 @@ class ModelLoader(ModelLoaderBase):
|
||||
self._convert_cache.make_room(self.get_size_fs(config, model_path, submodel_type))
|
||||
return self._convert_model(config, model_path, cache_path)
|
||||
|
||||
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, cache_path: Path) -> bool:
|
||||
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
|
||||
return False
|
||||
|
||||
def _load_if_needed(
|
||||
@ -119,7 +120,7 @@ class ModelLoader(ModelLoaderBase):
|
||||
return calc_model_size_by_fs(
|
||||
model_path=model_path,
|
||||
subfolder=submodel_type.value if submodel_type else None,
|
||||
variant=config.repo_variant if hasattr(config, "repo_variant") else None,
|
||||
variant=config.repo_variant if isinstance(config, DiffusersConfigBase) else None,
|
||||
)
|
||||
|
||||
# This needs to be implemented in subclasses that handle checkpoints
|
||||
|
@ -14,10 +14,9 @@ Use like this:
|
||||
).load_model(model_config, submodel_type)
|
||||
|
||||
"""
|
||||
import hashlib
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, Optional, Tuple, Type
|
||||
from typing import Callable, Dict, Optional, Tuple, Type, TypeVar
|
||||
|
||||
from ..config import (
|
||||
AnyModelConfig,
|
||||
@ -26,8 +25,6 @@ from ..config import (
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
VaeCheckpointConfig,
|
||||
VaeDiffusersConfig,
|
||||
)
|
||||
from . import ModelLoaderBase
|
||||
|
||||
@ -60,6 +57,9 @@ class ModelLoaderRegistryBase(ABC):
|
||||
"""
|
||||
|
||||
|
||||
TModelLoader = TypeVar("TModelLoader", bound=ModelLoaderBase)
|
||||
|
||||
|
||||
class ModelLoaderRegistry:
|
||||
"""
|
||||
This class allows model loaders to register their type, base and format.
|
||||
@ -70,10 +70,10 @@ class ModelLoaderRegistry:
|
||||
@classmethod
|
||||
def register(
|
||||
cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any
|
||||
) -> Callable[[Type[ModelLoaderBase]], Type[ModelLoaderBase]]:
|
||||
) -> Callable[[Type[TModelLoader]], Type[TModelLoader]]:
|
||||
"""Define a decorator which registers the subclass of loader."""
|
||||
|
||||
def decorator(subclass: Type[ModelLoaderBase]) -> Type[ModelLoaderBase]:
|
||||
def decorator(subclass: Type[TModelLoader]) -> Type[TModelLoader]:
|
||||
key = cls._to_registry_key(base, type, format)
|
||||
if key in cls._registry:
|
||||
raise Exception(
|
||||
@ -89,33 +89,15 @@ class ModelLoaderRegistry:
|
||||
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
|
||||
) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]:
|
||||
"""Get subclass of ModelLoaderBase registered to handle base and type."""
|
||||
# We have to handle VAE overrides here because this will change the model type and the corresponding implementation returned
|
||||
conf2, submodel_type = cls._handle_subtype_overrides(config, submodel_type)
|
||||
|
||||
key1 = cls._to_registry_key(conf2.base, conf2.type, conf2.format) # for a specific base type
|
||||
key2 = cls._to_registry_key(BaseModelType.Any, conf2.type, conf2.format) # with wildcard Any
|
||||
key1 = cls._to_registry_key(config.base, config.type, config.format) # for a specific base type
|
||||
key2 = cls._to_registry_key(BaseModelType.Any, config.type, config.format) # with wildcard Any
|
||||
implementation = cls._registry.get(key1) or cls._registry.get(key2)
|
||||
if not implementation:
|
||||
raise NotImplementedError(
|
||||
f"No subclass of LoadedModel is registered for base={config.base}, type={config.type}, format={config.format}"
|
||||
)
|
||||
return implementation, conf2, submodel_type
|
||||
|
||||
@classmethod
|
||||
def _handle_subtype_overrides(
|
||||
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
|
||||
) -> Tuple[ModelConfigBase, Optional[SubModelType]]:
|
||||
if submodel_type == SubModelType.Vae and hasattr(config, "vae") and config.vae is not None:
|
||||
model_path = Path(config.vae)
|
||||
config_class = (
|
||||
VaeCheckpointConfig if model_path.suffix in [".pt", ".safetensors", ".ckpt"] else VaeDiffusersConfig
|
||||
)
|
||||
hash = hashlib.md5(model_path.as_posix().encode("utf-8")).hexdigest()
|
||||
new_conf = config_class(path=model_path.as_posix(), name=model_path.stem, base=config.base, key=hash)
|
||||
submodel_type = None
|
||||
else:
|
||||
new_conf = config
|
||||
return new_conf, submodel_type
|
||||
return implementation, config, submodel_type
|
||||
|
||||
@staticmethod
|
||||
def _to_registry_key(base: BaseModelType, type: ModelType, format: ModelFormat) -> str:
|
||||
|
@ -3,8 +3,8 @@
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
from safetensors.torch import load_file as safetensors_load_file
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModelConfig,
|
||||
@ -12,6 +12,7 @@ from invokeai.backend.model_manager import (
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import CheckpointConfigBase
|
||||
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_controlnet_to_diffusers
|
||||
|
||||
from .. import ModelLoaderRegistry
|
||||
@ -20,15 +21,15 @@ from .generic_diffusers import GenericDiffusersLoader
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Checkpoint)
|
||||
class ControlnetLoader(GenericDiffusersLoader):
|
||||
class ControlNetLoader(GenericDiffusersLoader):
|
||||
"""Class to load ControlNet models."""
|
||||
|
||||
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
|
||||
if config.format != ModelFormat.Checkpoint:
|
||||
if not isinstance(config, CheckpointConfigBase):
|
||||
return False
|
||||
elif (
|
||||
dest_path.exists()
|
||||
and (dest_path / "config.json").stat().st_mtime >= (config.last_modified or 0.0)
|
||||
and (dest_path / "config.json").stat().st_mtime >= (config.converted_at or 0.0)
|
||||
and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime
|
||||
):
|
||||
return False
|
||||
@ -37,13 +38,13 @@ class ControlnetLoader(GenericDiffusersLoader):
|
||||
|
||||
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
|
||||
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
|
||||
raise Exception(f"Vae conversion not supported for model type: {config.base}")
|
||||
raise Exception(f"ControlNet conversion not supported for model type: {config.base}")
|
||||
else:
|
||||
assert hasattr(config, "config")
|
||||
config_file = config.config
|
||||
assert isinstance(config, CheckpointConfigBase)
|
||||
config_file = config.config_path
|
||||
|
||||
if model_path.suffix == ".safetensors":
|
||||
checkpoint = safetensors.torch.load_file(model_path, device="cpu")
|
||||
checkpoint = safetensors_load_file(model_path, device="cpu")
|
||||
else:
|
||||
checkpoint = torch.load(model_path, map_location="cpu")
|
||||
|
||||
|
@ -3,9 +3,10 @@
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from diffusers import ConfigMixin, ModelMixin
|
||||
from diffusers.configuration_utils import ConfigMixin
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
@ -41,6 +42,7 @@ class GenericDiffusersLoader(ModelLoader):
|
||||
# TO DO: Add exception handling
|
||||
def get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelType] = None) -> ModelMixin:
|
||||
"""Given the model path and submodel, returns the diffusers ModelMixin subclass needed to load."""
|
||||
result = None
|
||||
if submodel_type:
|
||||
try:
|
||||
config = self._load_diffusers_config(model_path, config_name="model_index.json")
|
||||
@ -64,6 +66,7 @@ class GenericDiffusersLoader(ModelLoader):
|
||||
raise InvalidModelConfigException("Unable to decifer Load Class based on given config.json")
|
||||
except KeyError as e:
|
||||
raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e
|
||||
assert result is not None
|
||||
return result
|
||||
|
||||
# TO DO: Add exception handling
|
||||
@ -75,7 +78,7 @@ class GenericDiffusersLoader(ModelLoader):
|
||||
result: ModelMixin = getattr(res_type, class_name)
|
||||
return result
|
||||
|
||||
def _load_diffusers_config(self, model_path: Path, config_name: str = "config.json") -> Dict[str, Any]:
|
||||
def _load_diffusers_config(self, model_path: Path, config_name: str = "config.json") -> dict[str, Any]:
|
||||
return ConfigLoader.load_config(model_path, config_name=config_name)
|
||||
|
||||
|
||||
@ -83,8 +86,8 @@ class ConfigLoader(ConfigMixin):
|
||||
"""Subclass of ConfigMixin for loading diffusers configuration files."""
|
||||
|
||||
@classmethod
|
||||
def load_config(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]:
|
||||
def load_config(cls, *args: Any, **kwargs: Any) -> dict[str, Any]: # pyright: ignore [reportIncompatibleMethodOverride]
|
||||
"""Load a diffusrs ConfigMixin configuration."""
|
||||
cls.config_name = kwargs.pop("config_name")
|
||||
# Diffusers doesn't provide typing info
|
||||
# TODO(psyche): the types on this diffusers method are not correct
|
||||
return super().load_config(*args, **kwargs) # type: ignore
|
||||
|
@ -31,7 +31,7 @@ class IPAdapterInvokeAILoader(ModelLoader):
|
||||
if submodel_type is not None:
|
||||
raise ValueError("There are no submodels in an IP-Adapter model.")
|
||||
model = build_ip_adapter(
|
||||
ip_adapter_ckpt_path=model_path / "ip_adapter.bin",
|
||||
ip_adapter_ckpt_path=str(model_path / "ip_adapter.bin"),
|
||||
device=torch.device("cpu"),
|
||||
dtype=self._torch_dtype,
|
||||
)
|
||||
|
@ -1,7 +1,6 @@
|
||||
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""Class for LoRA model loading in InvokeAI."""
|
||||
|
||||
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
@ -23,9 +22,9 @@ from invokeai.backend.model_manager.load.model_cache.model_cache_base import Mod
|
||||
from .. import ModelLoader, ModelLoaderRegistry
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Lycoris)
|
||||
class LoraLoader(ModelLoader):
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.LyCORIS)
|
||||
class LoRALoader(ModelLoader):
|
||||
"""Class to load LoRA models."""
|
||||
|
||||
# We cheat a little bit to get access to the model base
|
||||
|
@ -18,7 +18,7 @@ from .. import ModelLoaderRegistry
|
||||
from .generic_diffusers import GenericDiffusersLoader
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Onnx)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.ONNX)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Olive)
|
||||
class OnnyxDiffusersModel(GenericDiffusersLoader):
|
||||
"""Class to load onnx models."""
|
||||
|
@ -1,11 +1,11 @@
|
||||
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""Class for StableDiffusion model loading in InvokeAI."""
|
||||
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from diffusers import StableDiffusionInpaintPipeline, StableDiffusionPipeline
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
@ -17,7 +17,7 @@ from invokeai.backend.model_manager import (
|
||||
ModelVariantType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import MainCheckpointConfig
|
||||
from invokeai.backend.model_manager.config import CheckpointConfigBase, MainCheckpointConfig
|
||||
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
|
||||
|
||||
from .. import ModelLoaderRegistry
|
||||
@ -55,11 +55,11 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
|
||||
return result
|
||||
|
||||
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
|
||||
if config.format != ModelFormat.Checkpoint:
|
||||
if not isinstance(config, CheckpointConfigBase):
|
||||
return False
|
||||
elif (
|
||||
dest_path.exists()
|
||||
and (dest_path / "model_index.json").stat().st_mtime >= (config.last_modified or 0.0)
|
||||
and (dest_path / "model_index.json").stat().st_mtime >= (config.converted_at or 0.0)
|
||||
and (dest_path / "model_index.json").stat().st_mtime >= model_path.stat().st_mtime
|
||||
):
|
||||
return False
|
||||
@ -74,7 +74,7 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
|
||||
StableDiffusionInpaintPipeline if variant == ModelVariantType.Inpaint else StableDiffusionPipeline
|
||||
)
|
||||
|
||||
config_file = config.config
|
||||
config_file = config.config_path
|
||||
|
||||
self._logger.info(f"Converting {model_path} to diffusers format")
|
||||
convert_ckpt_to_diffusers(
|
||||
|
@ -1,7 +1,6 @@
|
||||
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""Class for TI model loading in InvokeAI."""
|
||||
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
|
||||
|
@ -3,9 +3,9 @@
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from safetensors.torch import load_file as safetensors_load_file
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModelConfig,
|
||||
@ -13,24 +13,25 @@ from invokeai.backend.model_manager import (
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import CheckpointConfigBase
|
||||
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
|
||||
|
||||
from .. import ModelLoaderRegistry
|
||||
from .generic_diffusers import GenericDiffusersLoader
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Vae, format=ModelFormat.Checkpoint)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Vae, format=ModelFormat.Checkpoint)
|
||||
class VaeLoader(GenericDiffusersLoader):
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.VAE, format=ModelFormat.Checkpoint)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.VAE, format=ModelFormat.Checkpoint)
|
||||
class VAELoader(GenericDiffusersLoader):
|
||||
"""Class to load VAE models."""
|
||||
|
||||
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
|
||||
if config.format != ModelFormat.Checkpoint:
|
||||
if not isinstance(config, CheckpointConfigBase):
|
||||
return False
|
||||
elif (
|
||||
dest_path.exists()
|
||||
and (dest_path / "config.json").stat().st_mtime >= (config.last_modified or 0.0)
|
||||
and (dest_path / "config.json").stat().st_mtime >= (config.converted_at or 0.0)
|
||||
and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime
|
||||
):
|
||||
return False
|
||||
@ -38,16 +39,15 @@ class VaeLoader(GenericDiffusersLoader):
|
||||
return True
|
||||
|
||||
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
|
||||
# TO DO: check whether sdxl VAE models convert.
|
||||
# TODO(MM2): check whether sdxl VAE models convert.
|
||||
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
|
||||
raise Exception(f"Vae conversion not supported for model type: {config.base}")
|
||||
raise Exception(f"VAE conversion not supported for model type: {config.base}")
|
||||
else:
|
||||
config_file = (
|
||||
"v1-inference.yaml" if config.base == BaseModelType.StableDiffusion1 else "v2-inference-v.yaml"
|
||||
)
|
||||
assert isinstance(config, CheckpointConfigBase)
|
||||
config_file = config.config_path
|
||||
|
||||
if model_path.suffix == ".safetensors":
|
||||
checkpoint = safetensors.torch.load_file(model_path, device="cpu")
|
||||
checkpoint = safetensors_load_file(model_path, device="cpu")
|
||||
else:
|
||||
checkpoint = torch.load(model_path, map_location="cpu")
|
||||
|
||||
@ -55,7 +55,7 @@ class VaeLoader(GenericDiffusersLoader):
|
||||
if "state_dict" in checkpoint:
|
||||
checkpoint = checkpoint["state_dict"]
|
||||
|
||||
ckpt_config = OmegaConf.load(self._app_config.legacy_conf_path / config_file)
|
||||
ckpt_config = OmegaConf.load(self._app_config.root_path / config_file)
|
||||
assert isinstance(ckpt_config, DictConfig)
|
||||
|
||||
vae_model = convert_ldm_vae_to_diffusers(
|
||||
|
@ -16,6 +16,7 @@ from diffusers import AutoPipelineForText2Image
|
||||
from diffusers.utils import logging as dlogging
|
||||
|
||||
from invokeai.app.services.model_install import ModelInstallServiceBase
|
||||
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
|
||||
from invokeai.backend.util.devices import choose_torch_device, torch_dtype
|
||||
|
||||
from . import (
|
||||
@ -117,7 +118,6 @@ class ModelMerger(object):
|
||||
config = self._installer.app_config
|
||||
store = self._installer.record_store
|
||||
base_models: Set[BaseModelType] = set()
|
||||
vae = None
|
||||
variant = None if self._installer.app_config.full_precision else "fp16"
|
||||
|
||||
assert (
|
||||
@ -134,10 +134,6 @@ class ModelMerger(object):
|
||||
"normal"
|
||||
), f"{info.name} ({info.key}) is a {info.variant} model, which cannot currently be merged"
|
||||
|
||||
# pick up the first model's vae
|
||||
if key == model_keys[0]:
|
||||
vae = info.vae
|
||||
|
||||
# tally base models used
|
||||
base_models.add(info.base)
|
||||
model_paths.extend([config.models_path / info.path])
|
||||
@ -163,12 +159,10 @@ class ModelMerger(object):
|
||||
|
||||
# update model's config
|
||||
model_config = self._installer.record_store.get_model(key)
|
||||
model_config.update(
|
||||
{
|
||||
"name": merged_model_name,
|
||||
"description": f"Merge of models {', '.join(model_names)}",
|
||||
"vae": vae,
|
||||
}
|
||||
model_config.name = merged_model_name
|
||||
model_config.description = f"Merge of models {', '.join(model_names)}"
|
||||
|
||||
self._installer.record_store.update_model(
|
||||
key, ModelRecordChanges(name=model_config.name, description=model_config.description)
|
||||
)
|
||||
self._installer.record_store.update_model(key, model_config)
|
||||
return model_config
|
||||
|
@ -18,15 +18,14 @@ assert isinstance(data, CivitaiMetadata)
|
||||
if data.allow_commercial_use:
|
||||
print("Commercial use of this model is allowed")
|
||||
"""
|
||||
|
||||
from .fetch import CivitaiMetadataFetch, HuggingFaceMetadataFetch, ModelMetadataFetchBase
|
||||
from .metadata_base import (
|
||||
AnyModelRepoMetadata,
|
||||
AnyModelRepoMetadataValidator,
|
||||
BaseMetadata,
|
||||
CivitaiMetadata,
|
||||
CommercialUsage,
|
||||
HuggingFaceMetadata,
|
||||
LicenseRestrictions,
|
||||
ModelMetadataWithFiles,
|
||||
RemoteModelFile,
|
||||
UnknownMetadataException,
|
||||
@ -37,10 +36,8 @@ __all__ = [
|
||||
"AnyModelRepoMetadataValidator",
|
||||
"CivitaiMetadata",
|
||||
"CivitaiMetadataFetch",
|
||||
"CommercialUsage",
|
||||
"HuggingFaceMetadata",
|
||||
"HuggingFaceMetadataFetch",
|
||||
"LicenseRestrictions",
|
||||
"ModelMetadataFetchBase",
|
||||
"BaseMetadata",
|
||||
"ModelMetadataWithFiles",
|
||||
|
@ -23,22 +23,21 @@ metadata = fetcher.from_url("https://civitai.com/models/206883/split")
|
||||
print(metadata.trained_words)
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import requests
|
||||
from pydantic import TypeAdapter, ValidationError
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from requests.sessions import Session
|
||||
|
||||
from invokeai.backend.model_manager import ModelRepoVariant
|
||||
from invokeai.backend.model_manager.config import ModelRepoVariant
|
||||
|
||||
from ..metadata_base import (
|
||||
AnyModelRepoMetadata,
|
||||
CivitaiMetadata,
|
||||
CommercialUsage,
|
||||
LicenseRestrictions,
|
||||
RemoteModelFile,
|
||||
UnknownMetadataException,
|
||||
)
|
||||
@ -52,10 +51,13 @@ CIVITAI_VERSION_ENDPOINT = "https://civitai.com/api/v1/model-versions/"
|
||||
CIVITAI_MODEL_ENDPOINT = "https://civitai.com/api/v1/models/"
|
||||
|
||||
|
||||
StringSetAdapter = TypeAdapter(set[str])
|
||||
|
||||
|
||||
class CivitaiMetadataFetch(ModelMetadataFetchBase):
|
||||
"""Fetch model metadata from Civitai."""
|
||||
|
||||
def __init__(self, session: Optional[Session] = None):
|
||||
def __init__(self, session: Optional[Session] = None, api_key: Optional[str] = None):
|
||||
"""
|
||||
Initialize the fetcher with an optional requests.sessions.Session object.
|
||||
|
||||
@ -63,6 +65,7 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
|
||||
this module without an internet connection.
|
||||
"""
|
||||
self._requests = session or requests.Session()
|
||||
self._api_key = api_key
|
||||
|
||||
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
|
||||
"""
|
||||
@ -102,22 +105,21 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
|
||||
May raise an `UnknownMetadataException`.
|
||||
"""
|
||||
model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
|
||||
model_json = self._requests.get(model_url).json()
|
||||
return self._from_model_json(model_json)
|
||||
model_json = self._requests.get(self._get_url_with_api_key(model_url)).json()
|
||||
return self._from_api_response(model_json)
|
||||
|
||||
def _from_model_json(self, model_json: Dict[str, Any], version_id: Optional[int] = None) -> CivitaiMetadata:
|
||||
def _from_api_response(self, api_response: dict[str, Any], version_id: Optional[int] = None) -> CivitaiMetadata:
|
||||
try:
|
||||
version_id = version_id or model_json["modelVersions"][0]["id"]
|
||||
version_id = version_id or api_response["modelVersions"][0]["id"]
|
||||
except TypeError as excp:
|
||||
raise UnknownMetadataException from excp
|
||||
|
||||
# loop till we find the section containing the version requested
|
||||
version_sections = [x for x in model_json["modelVersions"] if x["id"] == version_id]
|
||||
version_sections = [x for x in api_response["modelVersions"] if x["id"] == version_id]
|
||||
if not version_sections:
|
||||
raise UnknownMetadataException(f"Version {version_id} not found in model metadata")
|
||||
|
||||
version_json = version_sections[0]
|
||||
safe_thumbnails = [x["url"] for x in version_json["images"] if x["nsfw"] == "None"]
|
||||
|
||||
# Civitai has one "primary" file plus others such as VAEs. We only fetch the primary.
|
||||
primary = [x for x in version_json["files"] if x.get("primary")]
|
||||
@ -134,36 +136,23 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
|
||||
url = url + f"?type={primary_file['type']}{metadata_string}"
|
||||
model_files = [
|
||||
RemoteModelFile(
|
||||
url=url,
|
||||
url=self._get_url_with_api_key(url),
|
||||
path=Path(primary_file["name"]),
|
||||
size=int(primary_file["sizeKB"] * 1024),
|
||||
sha256=primary_file["hashes"]["SHA256"],
|
||||
)
|
||||
]
|
||||
|
||||
try:
|
||||
trigger_phrases = StringSetAdapter.validate_python(version_json.get("trainedWords"))
|
||||
except ValidationError:
|
||||
trigger_phrases: set[str] = set()
|
||||
|
||||
return CivitaiMetadata(
|
||||
id=model_json["id"],
|
||||
name=version_json["name"],
|
||||
version_id=version_json["id"],
|
||||
version_name=version_json["name"],
|
||||
created=datetime.fromisoformat(_fix_timezone(version_json["createdAt"])),
|
||||
updated=datetime.fromisoformat(_fix_timezone(version_json["updatedAt"])),
|
||||
published=datetime.fromisoformat(_fix_timezone(version_json["publishedAt"])),
|
||||
base_model_trained_on=version_json["baseModel"], # note - need a dictionary to turn into a BaseModelType
|
||||
files=model_files,
|
||||
download_url=version_json["downloadUrl"],
|
||||
thumbnail_url=safe_thumbnails[0] if safe_thumbnails else None,
|
||||
author=model_json["creator"]["username"],
|
||||
description=model_json["description"],
|
||||
version_description=version_json["description"] or "",
|
||||
tags=model_json["tags"],
|
||||
trained_words=version_json["trainedWords"],
|
||||
nsfw=model_json["nsfw"],
|
||||
restrictions=LicenseRestrictions(
|
||||
AllowNoCredit=model_json["allowNoCredit"],
|
||||
AllowCommercialUse={CommercialUsage(x) for x in model_json["allowCommercialUse"]},
|
||||
AllowDerivatives=model_json["allowDerivatives"],
|
||||
AllowDifferentLicense=model_json["allowDifferentLicense"],
|
||||
),
|
||||
trigger_phrases=trigger_phrases,
|
||||
api_response=json.dumps(version_json),
|
||||
)
|
||||
|
||||
def from_civitai_versionid(self, version_id: int, model_id: Optional[int] = None) -> CivitaiMetadata:
|
||||
@ -174,14 +163,14 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
|
||||
"""
|
||||
if model_id is None:
|
||||
version_url = CIVITAI_VERSION_ENDPOINT + str(version_id)
|
||||
version = self._requests.get(version_url).json()
|
||||
version = self._requests.get(self._get_url_with_api_key(version_url)).json()
|
||||
if error := version.get("error"):
|
||||
raise UnknownMetadataException(error)
|
||||
model_id = version["modelId"]
|
||||
|
||||
model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
|
||||
model_json = self._requests.get(model_url).json()
|
||||
return self._from_model_json(model_json, version_id)
|
||||
model_json = self._requests.get(self._get_url_with_api_key(model_url)).json()
|
||||
return self._from_api_response(model_json, version_id)
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, json: str) -> CivitaiMetadata:
|
||||
@ -189,6 +178,11 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
|
||||
metadata = CivitaiMetadata.model_validate_json(json)
|
||||
return metadata
|
||||
|
||||
def _get_url_with_api_key(self, url: str) -> str:
|
||||
if not self._api_key:
|
||||
return url
|
||||
|
||||
def _fix_timezone(date: str) -> str:
|
||||
return re.sub(r"Z$", "+00:00", date)
|
||||
if "?" in url:
|
||||
return f"{url}&token={self._api_key}"
|
||||
|
||||
return f"{url}?token={self._api_key}"
|
||||
|
@ -13,6 +13,7 @@ metadata = fetcher.from_url("https://huggingface.co/stabilityai/sdxl-turbo")
|
||||
print(metadata.tags)
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
@ -23,7 +24,7 @@ from huggingface_hub.utils._errors import RepositoryNotFoundError, RevisionNotFo
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from requests.sessions import Session
|
||||
|
||||
from invokeai.backend.model_manager import ModelRepoVariant
|
||||
from invokeai.backend.model_manager.config import ModelRepoVariant
|
||||
|
||||
from ..metadata_base import (
|
||||
AnyModelRepoMetadata,
|
||||
@ -60,6 +61,7 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
|
||||
# Little loop which tries fetching a revision corresponding to the selected variant.
|
||||
# If not available, then set variant to None and get the default.
|
||||
# If this too fails, raise exception.
|
||||
|
||||
model_info = None
|
||||
while not model_info:
|
||||
try:
|
||||
@ -72,23 +74,24 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
|
||||
else:
|
||||
variant = None
|
||||
|
||||
files: list[RemoteModelFile] = []
|
||||
|
||||
_, name = id.split("/")
|
||||
return HuggingFaceMetadata(
|
||||
id=model_info.id,
|
||||
author=model_info.author,
|
||||
name=name,
|
||||
last_modified=model_info.last_modified,
|
||||
tag_dict=model_info.card_data.to_dict() if model_info.card_data else {},
|
||||
tags=model_info.tags,
|
||||
files=[
|
||||
|
||||
for s in model_info.siblings or []:
|
||||
assert s.rfilename is not None
|
||||
assert s.size is not None
|
||||
files.append(
|
||||
RemoteModelFile(
|
||||
url=hf_hub_url(id, x.rfilename, revision=variant),
|
||||
path=Path(name, x.rfilename),
|
||||
size=x.size,
|
||||
sha256=x.lfs.get("sha256") if x.lfs else None,
|
||||
url=hf_hub_url(id, s.rfilename, revision=variant),
|
||||
path=Path(name, s.rfilename),
|
||||
size=s.size,
|
||||
sha256=s.lfs.get("sha256") if s.lfs else None,
|
||||
)
|
||||
for x in model_info.siblings
|
||||
],
|
||||
)
|
||||
|
||||
return HuggingFaceMetadata(
|
||||
id=model_info.id, name=name, files=files, api_response=json.dumps(model_info.__dict__, default=str)
|
||||
)
|
||||
|
||||
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
|
||||
|
@ -14,10 +14,8 @@ versions of these fields are intended to be kept in sync with the
|
||||
remote repo.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
from huggingface_hub import configure_http_backend, hf_hub_url
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
@ -34,31 +32,6 @@ class UnknownMetadataException(Exception):
|
||||
"""Raised when no metadata is available for a model."""
|
||||
|
||||
|
||||
class CommercialUsage(str, Enum):
|
||||
"""Type of commercial usage allowed."""
|
||||
|
||||
No = "None"
|
||||
Image = "Image"
|
||||
Rent = "Rent"
|
||||
RentCivit = "RentCivit"
|
||||
Sell = "Sell"
|
||||
|
||||
|
||||
class LicenseRestrictions(BaseModel):
|
||||
"""Broad categories of licensing restrictions."""
|
||||
|
||||
AllowNoCredit: bool = Field(
|
||||
description="if true, model can be redistributed without crediting author", default=False
|
||||
)
|
||||
AllowDerivatives: bool = Field(description="if true, derivatives of this model can be redistributed", default=False)
|
||||
AllowDifferentLicense: bool = Field(
|
||||
description="if true, derivatives of this model be redistributed under a different license", default=False
|
||||
)
|
||||
AllowCommercialUse: Optional[Set[CommercialUsage] | CommercialUsage] = Field(
|
||||
description="Type of commercial use allowed if no commercial use is allowed.", default=None
|
||||
)
|
||||
|
||||
|
||||
class RemoteModelFile(BaseModel):
|
||||
"""Information about a downloadable file that forms part of a model."""
|
||||
|
||||
@ -72,8 +45,6 @@ class ModelMetadataBase(BaseModel):
|
||||
"""Base class for model metadata information."""
|
||||
|
||||
name: str = Field(description="model's name")
|
||||
author: str = Field(description="model's author")
|
||||
tags: Set[str] = Field(description="tags provided by model source")
|
||||
|
||||
|
||||
class BaseMetadata(ModelMetadataBase):
|
||||
@ -111,60 +82,16 @@ class CivitaiMetadata(ModelMetadataWithFiles):
|
||||
"""Extended metadata fields provided by Civitai."""
|
||||
|
||||
type: Literal["civitai"] = "civitai"
|
||||
id: int = Field(description="Civitai version identifier")
|
||||
version_name: str = Field(description="Version identifier, such as 'V2-alpha'")
|
||||
version_id: int = Field(description="Civitai model version identifier")
|
||||
created: datetime = Field(description="date the model was created")
|
||||
updated: datetime = Field(description="date the model was last modified")
|
||||
published: datetime = Field(description="date the model was published to Civitai")
|
||||
description: str = Field(description="text description of model; may contain HTML")
|
||||
version_description: str = Field(
|
||||
description="text description of the model's reversion; usually change history; may contain HTML"
|
||||
)
|
||||
nsfw: bool = Field(description="whether the model tends to generate NSFW content", default=False)
|
||||
restrictions: LicenseRestrictions = Field(description="license terms", default_factory=LicenseRestrictions)
|
||||
trained_words: Set[str] = Field(description="words to trigger the model", default_factory=set)
|
||||
download_url: AnyHttpUrl = Field(description="download URL for this model")
|
||||
base_model_trained_on: str = Field(description="base model on which this model was trained (currently not an enum)")
|
||||
thumbnail_url: Optional[AnyHttpUrl] = Field(description="a thumbnail image for this model", default=None)
|
||||
weight_minmax: Tuple[float, float] = Field(
|
||||
description="minimum and maximum slider values for a LoRA or other secondary model", default=(-1.0, +2.0)
|
||||
) # note: For future use
|
||||
|
||||
@property
|
||||
def credit_required(self) -> bool:
|
||||
"""Return True if you must give credit for derivatives of this model and images generated from it."""
|
||||
return not self.restrictions.AllowNoCredit
|
||||
|
||||
@property
|
||||
def allow_commercial_use(self) -> bool:
|
||||
"""Return True if commercial use is allowed."""
|
||||
if self.restrictions.AllowCommercialUse is None:
|
||||
return False
|
||||
else:
|
||||
# accommodate schema change
|
||||
acu = self.restrictions.AllowCommercialUse
|
||||
commercial_usage = acu if isinstance(acu, set) else {acu}
|
||||
return CommercialUsage.No not in commercial_usage
|
||||
|
||||
@property
|
||||
def allow_derivatives(self) -> bool:
|
||||
"""Return True if derivatives of this model can be redistributed."""
|
||||
return self.restrictions.AllowDerivatives
|
||||
|
||||
@property
|
||||
def allow_different_license(self) -> bool:
|
||||
"""Return true if derivatives of this model can use a different license."""
|
||||
return self.restrictions.AllowDifferentLicense
|
||||
trigger_phrases: set[str] = Field(description="Trigger phrases extracted from the API response")
|
||||
api_response: Optional[str] = Field(description="Response from the Civitai API as stringified JSON", default=None)
|
||||
|
||||
|
||||
class HuggingFaceMetadata(ModelMetadataWithFiles):
|
||||
"""Extended metadata fields provided by HuggingFace."""
|
||||
|
||||
type: Literal["huggingface"] = "huggingface"
|
||||
id: str = Field(description="huggingface model id")
|
||||
tag_dict: Dict[str, Any]
|
||||
last_modified: datetime = Field(description="date of last commit to repo")
|
||||
id: str = Field(description="The HF model id")
|
||||
api_response: Optional[str] = Field(description="Response from the HF API as stringified JSON", default=None)
|
||||
|
||||
def download_urls(
|
||||
self,
|
||||
@ -193,7 +120,7 @@ class HuggingFaceMetadata(ModelMetadataWithFiles):
|
||||
# the next step reads model_index.json to determine which subdirectories belong
|
||||
# to the model
|
||||
if Path(f"{prefix}model_index.json") in paths:
|
||||
url = hf_hub_url(self.id, filename="model_index.json", subfolder=subfolder)
|
||||
url = hf_hub_url(self.id, filename="model_index.json", subfolder=str(subfolder) if subfolder else None)
|
||||
resp = session.get(url)
|
||||
resp.raise_for_status()
|
||||
submodels = resp.json()
|
||||
|
@ -1,221 +0,0 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""
|
||||
SQL Storage for Model Metadata
|
||||
"""
|
||||
|
||||
import sqlite3
|
||||
from typing import List, Optional, Set, Tuple
|
||||
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
|
||||
from .fetch import ModelMetadataFetchBase
|
||||
from .metadata_base import AnyModelRepoMetadata, UnknownMetadataException
|
||||
|
||||
|
||||
class ModelMetadataStore:
|
||||
"""Store, search and fetch model metadata retrieved from remote repositories."""
|
||||
|
||||
def __init__(self, db: SqliteDatabase):
|
||||
"""
|
||||
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
|
||||
|
||||
:param conn: sqlite3 connection object
|
||||
:param lock: threading Lock object
|
||||
"""
|
||||
super().__init__()
|
||||
self._db = db
|
||||
self._cursor = self._db.conn.cursor()
|
||||
|
||||
def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
|
||||
"""
|
||||
Add a block of repo metadata to a model record.
|
||||
|
||||
The model record config must already exist in the database with the
|
||||
same key. Otherwise a FOREIGN KEY constraint exception will be raised.
|
||||
|
||||
:param model_key: Existing model key in the `model_config` table
|
||||
:param metadata: ModelRepoMetadata object to store
|
||||
"""
|
||||
json_serialized = metadata.model_dump_json()
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT INTO model_metadata(
|
||||
id,
|
||||
metadata
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
(
|
||||
model_key,
|
||||
json_serialized,
|
||||
),
|
||||
)
|
||||
self._update_tags(model_key, metadata.tags)
|
||||
self._db.conn.commit()
|
||||
except sqlite3.IntegrityError as excp: # FOREIGN KEY error: the key was not in model_config table
|
||||
self._db.conn.rollback()
|
||||
raise UnknownMetadataException from excp
|
||||
except sqlite3.Error as excp:
|
||||
self._db.conn.rollback()
|
||||
raise excp
|
||||
|
||||
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
|
||||
"""Retrieve the ModelRepoMetadata corresponding to model key."""
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT metadata FROM model_metadata
|
||||
WHERE id=?;
|
||||
""",
|
||||
(model_key,),
|
||||
)
|
||||
rows = self._cursor.fetchone()
|
||||
if not rows:
|
||||
raise UnknownMetadataException("model metadata not found")
|
||||
return ModelMetadataFetchBase.from_json(rows[0])
|
||||
|
||||
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata
|
||||
"""Dump out all the metadata."""
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT id,metadata FROM model_metadata;
|
||||
""",
|
||||
(),
|
||||
)
|
||||
rows = self._cursor.fetchall()
|
||||
return [(x[0], ModelMetadataFetchBase.from_json(x[1])) for x in rows]
|
||||
|
||||
def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
|
||||
"""
|
||||
Update metadata corresponding to the model with the indicated key.
|
||||
|
||||
:param model_key: Existing model key in the `model_config` table
|
||||
:param metadata: ModelRepoMetadata object to update
|
||||
"""
|
||||
json_serialized = metadata.model_dump_json() # turn it into a json string.
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
UPDATE model_metadata
|
||||
SET
|
||||
metadata=?
|
||||
WHERE id=?;
|
||||
""",
|
||||
(json_serialized, model_key),
|
||||
)
|
||||
if self._cursor.rowcount == 0:
|
||||
raise UnknownMetadataException("model metadata not found")
|
||||
self._update_tags(model_key, metadata.tags)
|
||||
self._db.conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
|
||||
return self.get_metadata(model_key)
|
||||
|
||||
def list_tags(self) -> Set[str]:
|
||||
"""Return all tags in the tags table."""
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
select tag_text from tags;
|
||||
"""
|
||||
)
|
||||
return {x[0] for x in self._cursor.fetchall()}
|
||||
|
||||
def search_by_tag(self, tags: Set[str]) -> Set[str]:
|
||||
"""Return the keys of models containing all of the listed tags."""
|
||||
with self._db.lock:
|
||||
try:
|
||||
matches: Optional[Set[str]] = None
|
||||
for tag in tags:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT a.model_id FROM model_tags AS a,
|
||||
tags AS b
|
||||
WHERE a.tag_id=b.tag_id
|
||||
AND b.tag_text=?;
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
model_keys = {x[0] for x in self._cursor.fetchall()}
|
||||
if matches is None:
|
||||
matches = model_keys
|
||||
matches = matches.intersection(model_keys)
|
||||
except sqlite3.Error as e:
|
||||
raise e
|
||||
return matches if matches else set()
|
||||
|
||||
def search_by_author(self, author: str) -> Set[str]:
|
||||
"""Return the keys of models authored by the indicated author."""
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT id FROM model_metadata
|
||||
WHERE author=?;
|
||||
""",
|
||||
(author,),
|
||||
)
|
||||
return {x[0] for x in self._cursor.fetchall()}
|
||||
|
||||
def search_by_name(self, name: str) -> Set[str]:
|
||||
"""
|
||||
Return the keys of models with the indicated name.
|
||||
|
||||
Note that this is the name of the model given to it by
|
||||
the remote source. The user may have changed the local
|
||||
name. The local name will be located in the model config
|
||||
record object.
|
||||
"""
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT id FROM model_metadata
|
||||
WHERE name=?;
|
||||
""",
|
||||
(name,),
|
||||
)
|
||||
return {x[0] for x in self._cursor.fetchall()}
|
||||
|
||||
def _update_tags(self, model_key: str, tags: Set[str]) -> None:
|
||||
"""Update tags for the model referenced by model_key."""
|
||||
# remove previous tags from this model
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM model_tags
|
||||
WHERE model_id=?;
|
||||
""",
|
||||
(model_key,),
|
||||
)
|
||||
|
||||
for tag in tags:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO tags (
|
||||
tag_text
|
||||
)
|
||||
VALUES (?);
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT tag_id
|
||||
FROM tags
|
||||
WHERE tag_text = ?
|
||||
LIMIT 1;
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
tag_id = self._cursor.fetchone()[0]
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO model_tags (
|
||||
model_id,
|
||||
tag_id
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
(model_key, tag_id),
|
||||
)
|
@ -8,6 +8,7 @@ import torch
|
||||
from picklescan.scanner import scan_file_path
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
from invokeai.backend.util.util import SilenceWarnings
|
||||
|
||||
from .config import (
|
||||
@ -17,11 +18,12 @@ from .config import (
|
||||
ModelConfigFactory,
|
||||
ModelFormat,
|
||||
ModelRepoVariant,
|
||||
ModelSourceType,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
SchedulerPredictionType,
|
||||
)
|
||||
from .hash import FastModelHash
|
||||
from .hash import ModelHash
|
||||
from .util.model_util import lora_token_vector_length, read_checkpoint_meta
|
||||
|
||||
CkptType = Dict[str, Any]
|
||||
@ -82,6 +84,9 @@ class ProbeBase(object):
|
||||
|
||||
|
||||
class ModelProbe(object):
|
||||
|
||||
hasher = ModelHash()
|
||||
|
||||
PROBES: Dict[str, Dict[ModelType, type[ProbeBase]]] = {
|
||||
"diffusers": {},
|
||||
"checkpoint": {},
|
||||
@ -95,8 +100,8 @@ class ModelProbe(object):
|
||||
"StableDiffusionXLImg2ImgPipeline": ModelType.Main,
|
||||
"StableDiffusionXLInpaintPipeline": ModelType.Main,
|
||||
"LatentConsistencyModelPipeline": ModelType.Main,
|
||||
"AutoencoderKL": ModelType.Vae,
|
||||
"AutoencoderTiny": ModelType.Vae,
|
||||
"AutoencoderKL": ModelType.VAE,
|
||||
"AutoencoderTiny": ModelType.VAE,
|
||||
"ControlNetModel": ModelType.ControlNet,
|
||||
"CLIPVisionModelWithProjection": ModelType.CLIPVision,
|
||||
"T2IAdapter": ModelType.T2IAdapter,
|
||||
@ -108,14 +113,6 @@ class ModelProbe(object):
|
||||
) -> None:
|
||||
cls.PROBES[format][model_type] = probe_class
|
||||
|
||||
@classmethod
|
||||
def heuristic_probe(
|
||||
cls,
|
||||
model_path: Path,
|
||||
fields: Optional[Dict[str, Any]] = None,
|
||||
) -> AnyModelConfig:
|
||||
return cls.probe(model_path, fields)
|
||||
|
||||
@classmethod
|
||||
def probe(
|
||||
cls,
|
||||
@ -137,19 +134,21 @@ class ModelProbe(object):
|
||||
format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint
|
||||
model_info = None
|
||||
model_type = None
|
||||
if format_type == "diffusers":
|
||||
if format_type is ModelFormat.Diffusers:
|
||||
model_type = cls.get_model_type_from_folder(model_path)
|
||||
else:
|
||||
model_type = cls.get_model_type_from_checkpoint(model_path)
|
||||
format_type = ModelFormat.Onnx if model_type == ModelType.ONNX else format_type
|
||||
format_type = ModelFormat.ONNX if model_type == ModelType.ONNX else format_type
|
||||
|
||||
probe_class = cls.PROBES[format_type].get(model_type)
|
||||
if not probe_class:
|
||||
raise InvalidModelConfigException(f"Unhandled combination of {format_type} and {model_type}")
|
||||
|
||||
hash = FastModelHash.hash(model_path)
|
||||
probe = probe_class(model_path)
|
||||
|
||||
fields["source_type"] = fields.get("source_type") or ModelSourceType.Path
|
||||
fields["source"] = fields.get("source") or model_path.as_posix()
|
||||
fields["key"] = fields.get("key", uuid_string())
|
||||
fields["path"] = model_path.as_posix()
|
||||
fields["type"] = fields.get("type") or model_type
|
||||
fields["base"] = fields.get("base") or probe.get_base_type()
|
||||
@ -161,15 +160,17 @@ class ModelProbe(object):
|
||||
fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
|
||||
)
|
||||
fields["format"] = fields.get("format") or probe.get_format()
|
||||
fields["original_hash"] = fields.get("original_hash") or hash
|
||||
fields["current_hash"] = fields.get("current_hash") or hash
|
||||
fields["hash"] = fields.get("hash") or cls.hasher.hash(model_path)
|
||||
|
||||
if format_type == ModelFormat.Diffusers and hasattr(probe, "get_repo_variant"):
|
||||
if format_type == ModelFormat.Diffusers and isinstance(probe, FolderProbeBase):
|
||||
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()
|
||||
|
||||
# additional fields needed for main and controlnet models
|
||||
if fields["type"] in [ModelType.Main, ModelType.ControlNet] and fields["format"] == ModelFormat.Checkpoint:
|
||||
fields["config"] = cls._get_checkpoint_config_path(
|
||||
if (
|
||||
fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE]
|
||||
and fields["format"] is ModelFormat.Checkpoint
|
||||
):
|
||||
fields["config_path"] = cls._get_checkpoint_config_path(
|
||||
model_path,
|
||||
model_type=fields["type"],
|
||||
base_type=fields["base"],
|
||||
@ -179,7 +180,7 @@ class ModelProbe(object):
|
||||
|
||||
# additional fields needed for main non-checkpoint models
|
||||
elif fields["type"] == ModelType.Main and fields["format"] in [
|
||||
ModelFormat.Onnx,
|
||||
ModelFormat.ONNX,
|
||||
ModelFormat.Olive,
|
||||
ModelFormat.Diffusers,
|
||||
]:
|
||||
@ -188,7 +189,7 @@ class ModelProbe(object):
|
||||
and fields["prediction_type"] == SchedulerPredictionType.VPrediction
|
||||
)
|
||||
|
||||
model_info = ModelConfigFactory.make_config(fields, key=fields.get("key", None))
|
||||
model_info = ModelConfigFactory.make_config(fields) # , key=fields.get("key", None))
|
||||
return model_info
|
||||
|
||||
@classmethod
|
||||
@ -213,11 +214,11 @@ class ModelProbe(object):
|
||||
if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}):
|
||||
return ModelType.Main
|
||||
elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}):
|
||||
return ModelType.Vae
|
||||
return ModelType.VAE
|
||||
elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
|
||||
return ModelType.Lora
|
||||
return ModelType.LoRA
|
||||
elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}):
|
||||
return ModelType.Lora
|
||||
return ModelType.LoRA
|
||||
elif any(key.startswith(v) for v in {"control_model", "input_blocks"}):
|
||||
return ModelType.ControlNet
|
||||
elif key in {"emb_params", "string_to_param"}:
|
||||
@ -239,7 +240,7 @@ class ModelProbe(object):
|
||||
if (folder_path / f"learned_embeds.{suffix}").exists():
|
||||
return ModelType.TextualInversion
|
||||
if (folder_path / f"pytorch_lora_weights.{suffix}").exists():
|
||||
return ModelType.Lora
|
||||
return ModelType.LoRA
|
||||
if (folder_path / "unet/model.onnx").exists():
|
||||
return ModelType.ONNX
|
||||
if (folder_path / "image_encoder.txt").exists():
|
||||
@ -285,13 +286,21 @@ class ModelProbe(object):
|
||||
if possible_conf.exists():
|
||||
return possible_conf.absolute()
|
||||
|
||||
if model_type == ModelType.Main:
|
||||
if model_type is ModelType.Main:
|
||||
config_file = LEGACY_CONFIGS[base_type][variant_type]
|
||||
if isinstance(config_file, dict): # need another tier for sd-2.x models
|
||||
config_file = config_file[prediction_type]
|
||||
elif model_type == ModelType.ControlNet:
|
||||
elif model_type is ModelType.ControlNet:
|
||||
config_file = (
|
||||
"../controlnet/cldm_v15.yaml" if base_type == BaseModelType("sd-1") else "../controlnet/cldm_v21.yaml"
|
||||
"../controlnet/cldm_v15.yaml"
|
||||
if base_type is BaseModelType.StableDiffusion1
|
||||
else "../controlnet/cldm_v21.yaml"
|
||||
)
|
||||
elif model_type is ModelType.VAE:
|
||||
config_file = (
|
||||
"../stable-diffusion/v1-inference.yaml"
|
||||
if base_type is BaseModelType.StableDiffusion1
|
||||
else "../stable-diffusion/v2-inference.yaml"
|
||||
)
|
||||
else:
|
||||
raise InvalidModelConfigException(
|
||||
@ -497,12 +506,12 @@ class FolderProbeBase(ProbeBase):
|
||||
if ".fp16" in x.suffixes:
|
||||
return ModelRepoVariant.FP16
|
||||
if "openvino_model" in x.name:
|
||||
return ModelRepoVariant.OPENVINO
|
||||
return ModelRepoVariant.OpenVINO
|
||||
if "flax_model" in x.name:
|
||||
return ModelRepoVariant.FLAX
|
||||
return ModelRepoVariant.Flax
|
||||
if x.suffix == ".onnx":
|
||||
return ModelRepoVariant.ONNX
|
||||
return ModelRepoVariant.DEFAULT
|
||||
return ModelRepoVariant.Default
|
||||
|
||||
|
||||
class PipelineFolderProbe(FolderProbeBase):
|
||||
@ -708,8 +717,8 @@ class T2IAdapterFolderProbe(FolderProbeBase):
|
||||
|
||||
############## register probe classes ######
|
||||
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.Lora, LoRAFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.VAE, VaeFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.LoRA, LoRAFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
|
||||
@ -717,8 +726,8 @@ ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderPro
|
||||
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
|
||||
|
||||
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.VAE, VaeCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.LoRA, LoRACheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
|
||||
|
@ -28,6 +28,7 @@ from typing import Callable, Optional, Set, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
default_logger: Logger = InvokeAILogger.get_logger()
|
||||
@ -117,13 +118,10 @@ class ModelSearch(ModelSearchBase):
|
||||
"""
|
||||
|
||||
models_found: Set[Path] = Field(default_factory=set)
|
||||
scanned_dirs: Set[Path] = Field(default_factory=set)
|
||||
pruned_paths: Set[Path] = Field(default_factory=set)
|
||||
config: InvokeAIAppConfig = InvokeAIAppConfig.get_config()
|
||||
|
||||
def search_started(self) -> None:
|
||||
self.models_found = set()
|
||||
self.scanned_dirs = set()
|
||||
self.pruned_paths = set()
|
||||
if self.on_search_started:
|
||||
self.on_search_started(self._directory)
|
||||
|
||||
@ -139,29 +137,28 @@ class ModelSearch(ModelSearchBase):
|
||||
|
||||
def search(self, directory: Union[Path, str]) -> Set[Path]:
|
||||
self._directory = Path(directory)
|
||||
if not self._directory.is_absolute():
|
||||
self._directory = self.config.models_path / self._directory
|
||||
self.stats = SearchStats() # zero out
|
||||
self.search_started() # This will initialize _models_found to empty
|
||||
self._walk_directory(directory)
|
||||
self._walk_directory(self._directory)
|
||||
self.search_completed()
|
||||
return self.models_found
|
||||
|
||||
def _walk_directory(self, path: Union[Path, str]) -> None:
|
||||
for root, dirs, files in os.walk(path, followlinks=True):
|
||||
# don't descend into directories that start with a "."
|
||||
# to avoid the Mac .DS_STORE issue.
|
||||
if str(Path(root).name).startswith("."):
|
||||
self.pruned_paths.add(Path(root))
|
||||
if any(Path(root).is_relative_to(x) for x in self.pruned_paths):
|
||||
continue
|
||||
|
||||
self.stats.items_scanned += len(dirs) + len(files)
|
||||
for d in dirs:
|
||||
path = Path(root) / d
|
||||
if path.parent in self.scanned_dirs:
|
||||
self.scanned_dirs.add(path)
|
||||
continue
|
||||
def _walk_directory(self, path: Union[Path, str], max_depth: int = 20) -> None:
|
||||
absolute_path = Path(path)
|
||||
if (
|
||||
len(absolute_path.parts) - len(self._directory.parts) > max_depth
|
||||
or not absolute_path.exists()
|
||||
or absolute_path.parent in self.models_found
|
||||
):
|
||||
return
|
||||
entries = os.scandir(absolute_path.as_posix())
|
||||
entries = [entry for entry in entries if not entry.name.startswith(".")]
|
||||
dirs = [entry for entry in entries if entry.is_dir()]
|
||||
file_names = [entry.name for entry in entries if entry.is_file()]
|
||||
if any(
|
||||
(path / x).exists()
|
||||
x in file_names
|
||||
for x in [
|
||||
"config.json",
|
||||
"model_index.json",
|
||||
@ -170,22 +167,23 @@ class ModelSearch(ModelSearchBase):
|
||||
"image_encoder.txt",
|
||||
]
|
||||
):
|
||||
self.scanned_dirs.add(path)
|
||||
try:
|
||||
self.model_found(path)
|
||||
self.model_found(absolute_path)
|
||||
return
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
except Exception as e:
|
||||
self.logger.warning(str(e))
|
||||
return
|
||||
|
||||
for n in file_names:
|
||||
if n.endswith((".ckpt", ".bin", ".pth", ".safetensors", ".pt")):
|
||||
try:
|
||||
self.model_found(absolute_path / n)
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
except Exception as e:
|
||||
self.logger.warning(str(e))
|
||||
|
||||
for f in files:
|
||||
path = Path(root) / f
|
||||
if path.parent in self.scanned_dirs:
|
||||
continue
|
||||
if path.suffix in {".ckpt", ".bin", ".pth", ".safetensors", ".pt"}:
|
||||
try:
|
||||
self.model_found(path)
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
except Exception as e:
|
||||
self.logger.warning(str(e))
|
||||
for d in dirs:
|
||||
self._walk_directory(absolute_path / d)
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user