Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

Compare commits: lstein/fea ... v4.2.9.dev (682 commits)
SHA1 | Author | Date
---|---|---
5db7b48cd8 | |||
ea014c66ac | |||
25918c28aa | |||
0c60469401 | |||
f1aa50f447 | |||
a413b261f0 | |||
4a1a6639f6 | |||
201c370ca1 | |||
d070c7c726 | |||
e38e20a992 | |||
39a94ec70e | |||
c7bfae2d1e | |||
e7944c427d | |||
48ed4e120d | |||
a5b038a1b1 | |||
dc752c98b0 | |||
85a47cc6fe | |||
6450f42cfa | |||
3876f71ff4 | |||
cf819e8eab | |||
2217fb8485 | |||
43652e830a | |||
a3417bf81d | |||
06637161e3 | |||
c4f4b16a36 | |||
3001718f9f | |||
ac16fa65a3 | |||
bc0b5335ff | |||
e91c7c5a30 | |||
74791cc490 | |||
68409c6a0f | |||
85613b220c | |||
80085ad854 | |||
b6bfa65104 | |||
0bfa033089 | |||
6f0974b5bc | |||
c3b53fc4f6 | |||
8f59a32d81 | |||
3b4f20f433 | |||
73da6e9628 | |||
2fc482141d | |||
a41ec5f3fc | |||
c906225d03 | |||
2985ea3716 | |||
4f151c6c6f | |||
c4f5252c1a | |||
77c13f2cf3 | |||
3270d36fca | |||
b6b30ff01f | |||
aa9bfdff35 | |||
80308cc3b8 | |||
f6db73bf1f | |||
ef9f61a39f | |||
a1a0881133 | |||
9956919ab6 | |||
abc07f57d6 | |||
1a1cae79f1 | |||
bcfafe7b06 | |||
34e8ced592 | |||
1fdada65b6 | |||
433f3e1971 | |||
a60e23f825 | |||
f69de3148e | |||
cbcd36ef54 | |||
aa76134340 | |||
55758acae8 | |||
196e43b5e5 | |||
38b9828441 | |||
0048a7077e | |||
527a39a3ad | |||
30ce4c55c7 | |||
ca082d4288 | |||
5e59a4f43a | |||
9f86605049 | |||
79058a7894 | |||
bb3ad8c2f1 | |||
799688514b | |||
b7344b0df2 | |||
7e382c5f3f | |||
9cf357e184 | |||
95b6c773d4 | |||
89d8c5ba00 | |||
59580cf6ed | |||
2b0c084f5b | |||
4d896073ff | |||
9f69503a80 | |||
0311e852a0 | |||
7003a3d546 | |||
dc73072e27 | |||
e549c44ad7 | |||
45a4231cbe | |||
81f046ebac | |||
6ef6c593c4 | |||
5b53eefef7 | |||
9a9919c0af | |||
10661b33d4 | |||
52193d604d | |||
2568441e6a | |||
1a14860b3b | |||
9ff7647ec5 | |||
b49106e8fe | |||
906d0902a3 | |||
fbde6f5a7f | |||
b388268987 | |||
3b4164bd62 | |||
b7fc6fe573 | |||
2954a19d27 | |||
aa45ce7fbd | |||
77e5078e4a | |||
603cc7bf2e | |||
cd517a102d | |||
9a442918b5 | |||
f9c03d85a5 | |||
10d07c71c4 | |||
cd05a78219 | |||
f8ee572abc | |||
d918654509 | |||
582e30c542 | |||
34a6555301 | |||
fff860090b | |||
f4971197c1 | |||
621d5e0462 | |||
0b68a69a6c | |||
9a599ce595 | |||
1467ba276f | |||
708facf707 | |||
9c6c6adb1f | |||
c335b8581c | |||
f1348e45bd | |||
ce6cf9b079 | |||
13ec80736a | |||
c9690a4b21 | |||
489e875a6e | |||
8651396048 | |||
2bab5a6179 | |||
006f06b615 | |||
d603923d1b | |||
86878e855b | |||
35de60a8fa | |||
2c444a1941 | |||
3dfef01889 | |||
a845a2daa5 | |||
df41f4fbce | |||
76482da6f5 | |||
8205abbbbf | |||
926873de26 | |||
00cb1903ba | |||
58ba38b9c7 | |||
2f6a5617f9 | |||
e0d84743be | |||
ee7c62acc4 | |||
daf3e58bd9 | |||
c5b9209057 | |||
2a4d6d98e2 | |||
cfdf59d906 | |||
f91ce1a47c | |||
4af2888168 | |||
8471c6fe86 | |||
fe65a5a2db | |||
0df26e967c | |||
d4822b305e | |||
8df5447563 | |||
7b5a43df9b | |||
61ef630175 | |||
4eda2ef555 | |||
57f4489520 | |||
fb6cf9e3da | |||
f776326cff | |||
5be32d5733 | |||
1f73435241 | |||
3251a00631 | |||
49c4ad1dd7 | |||
5857e95c4a | |||
85be2532c6 | |||
8b81a00def | |||
8544595c27 | |||
a6a5d1470c | |||
febcc12ec9 | |||
ab64078b76 | |||
0ff3459b07 | |||
2abd7c9bfe | |||
8e5330bdc9 | |||
1ecec4ea3a | |||
700dbe69f3 | |||
af7ba3b7e4 | |||
1f7144d62e | |||
4e389e415b | |||
2db29fb6ab | |||
79653fcff5 | |||
9e39180fbc | |||
dc0f832d8f | |||
0dcd6aa5d9 | |||
9f4a8f11f8 | |||
6f9579d6ec | |||
2b2aabb234 | |||
a79b9633ab | |||
7b628c908b | |||
181703a709 | |||
c439e3c204 | |||
11e81eb456 | |||
3e24bf640e | |||
a17664fb75 | |||
5aed23dc91 | |||
7f389716d0 | |||
eca13b674a | |||
7c982a1bdf | |||
bdfe6870fd | |||
7eebbc0dd9 | |||
6ae46d7c8b | |||
8aae30566e | |||
4b7c3e221c | |||
4015795b7f | |||
132dd61d8d | |||
6f2b548dd1 | |||
49dd316f17 | |||
e0a8bb149d | |||
020b6db34b | |||
f6d2f0bf8c | |||
1280cce803 | |||
82463d12e2 | |||
ba6c1b84e4 | |||
92b6d3198a | |||
693ae1af50 | |||
3873a3096c | |||
4a9f6ab5ef | |||
2fc29c5125 | |||
5fbc876cfd | |||
89b0673ac9 | |||
3179a16189 | |||
3ce216b391 | |||
8a3a94e21a | |||
f61af188f9 | |||
73fc52bfed | |||
4eeff4eef8 | |||
cff2c43030 | |||
504b1f2425 | |||
eff5b56990 | |||
0e673f1a18 | |||
4fea22aea4 | |||
c2478c9ac3 | |||
ed8243825e | |||
c7ac2b5278 | |||
a783003556 | |||
44ba1c6113 | |||
a02d67fcc6 | |||
14c8f7c4f5 | |||
a487ecb50f | |||
fc55862823 | |||
e5400601d6 | |||
735f9f1483 | |||
d139db0a0f | |||
a35bb450b1 | |||
26c01dfa48 | |||
2b1839374a | |||
59f5f18e1d | |||
e41fcb081c | |||
3032042b35 | |||
544db61044 | |||
67a3aa6dff | |||
34f4468b20 | |||
d99ae58001 | |||
d60ec53762 | |||
0ece9361d5 | |||
69987a2f00 | |||
7346bfccb9 | |||
b705083ce2 | |||
4b3c82df6f | |||
46290205d5 | |||
7b15585b80 | |||
cbfeeb9079 | |||
fdac20b43e | |||
ae2312104e | |||
74b6674af6 | |||
81adce3238 | |||
48b7f460a8 | |||
38a8232341 | |||
39a51c4f4c | |||
220bfeb37d | |||
e766279950 | |||
db82406525 | |||
2d44f332d9 | |||
ea1526689a | |||
200338ed72 | |||
88003a61bd | |||
b82089b30b | |||
efd780d395 | |||
3f90f783de | |||
4f5b755117 | |||
a455f12581 | |||
de516db383 | |||
6781575293 | |||
65b0e40fc8 | |||
19e78a07b7 | |||
e3b60dda07 | |||
a9696d3193 | |||
336d72873f | |||
111a380bce | |||
012a8351af | |||
deeb80ea9b | |||
1e97a917d6 | |||
a15ba925db | |||
6e5a968aad | |||
915edaa02f | |||
f6b0fa7c18 | |||
3f1fba0f35 | |||
d9fa85a4c6 | |||
022bb8649c | |||
673bc33a01 | |||
579a64928d | |||
5316df7d7d | |||
62fe61dc30 | |||
ea819a4a2f | |||
868a25dae2 | |||
3a25d00cb6 | |||
6301b74d87 | |||
b43b90c299 | |||
0a43444ab3 | |||
75ea4d2155 | |||
9d281388e0 | |||
178e1cc50b | |||
6eafe53b2d | |||
7514b2b7d4 | |||
51dee2dba2 | |||
a81c3d841c | |||
154487f966 | |||
588daafcf5 | |||
ab00097aed | |||
5aefae71a2 | |||
0edd598970 | |||
0bb485031f | |||
01ffd86367 | |||
e2e02f31b6 | |||
0008617348 | |||
f8f21c0edd | |||
010916158b | |||
df0ba004ca | |||
038b29e15b | |||
763ab73923 | |||
9a060b4437 | |||
2307a30892 | |||
2006f84f6e | |||
23b15fef6a | |||
41e72f929d | |||
29fc49bb3b | |||
6f65b6a40f | |||
1e9b22e3a4 | |||
3dc2c723c3 | |||
6f007cbd48 | |||
9a402dd10e | |||
4249d0e13b | |||
f6050bad67 | |||
d3a4b7b51b | |||
3f5f9ac764 | |||
e7c1299a7f | |||
56a3918a1e | |||
dabf7718cf | |||
ef22c29288 | |||
74f06074f7 | |||
b97bf52faa | |||
0a77f5cec8 | |||
f8e92f7b73 | |||
530c6e3a59 | |||
11b95cfaf4 | |||
707c005a26 | |||
19fa8e7e33 | |||
b252ded366 | |||
84305d4e73 | |||
1150b41e14 | |||
88d8ccb34b | |||
4111b3f1aa | |||
36862be2aa | |||
425665e0d9 | |||
9abd604f69 | |||
59bdc288b5 | |||
eb37d2958e | |||
2a9738a341 | |||
6aac1cf33a | |||
9ca4d072ab | |||
7aaf14c26b | |||
cf598ca175 | |||
a722790afc | |||
320151a040 | |||
c090f511c3 | |||
86dd1475b3 | |||
0b71ac258c | |||
54e1eae509 | |||
bf57b2dc77 | |||
de3c27b44f | |||
05717fea93 | |||
191584d229 | |||
6069169e6b | |||
07438587f3 | |||
913e36d6fd | |||
139004b976 | |||
ef4269d585 | |||
954cb129a4 | |||
02c4b28de5 | |||
febea88b58 | |||
40ccfac514 | |||
831fb814cc | |||
bf166fdd61 | |||
384bde3539 | |||
6f1d238d0a | |||
ac524153a7 | |||
2cad2b15cf | |||
fd63e202fe | |||
a0250e47e3 | |||
7d8ece45bb | |||
ffb8f053da | |||
fb46f457f9 | |||
6d4f4152a7 | |||
d3d0ac7327 | |||
f57df46995 | |||
a747171745 | |||
e9ae9e80d4 | |||
3b9a59b98d | |||
8a381e7f74 | |||
61513fc800 | |||
1d26c49e92 | |||
77be9836d2 | |||
4427960acb | |||
84aa4fb7bc | |||
01df96cbe0 | |||
1ac0634f57 | |||
8e7d3634b1 | |||
fadafe5c77 | |||
b2ea1f6690 | |||
0a03c1f882 | |||
ce8b490ed8 | |||
86eccba80d | |||
97453e7c6c | |||
9e1084b701 | |||
41aec81f3f | |||
a5741a0551 | |||
72f73c231a | |||
b8fcaa274e | |||
a33bbf48bb | |||
98d9490fb9 | |||
51c643c4f8 | |||
2d7370ca6c | |||
0d68141387 | |||
e88a8c6639 | |||
2d04bb286e | |||
6d9ba24c32 | |||
7645a1c86e | |||
dc284b9a48 | |||
8a38332d44 | |||
2407d7d148 | |||
23275cf607 | |||
8a0e02d335 | |||
913873623c | |||
5d81e7dd4d | |||
418650fdf3 | |||
7c24e56d9f | |||
f6faed46c3 | |||
1d393eecf1 | |||
03a72240c0 | |||
acf62450fb | |||
d0269310cf | |||
43618a74e7 | |||
56fed637ec | |||
944ae4a604 | |||
f3da609102 | |||
11c8a8cf72 | |||
656978676f | |||
d216e6519f | |||
0a2bcae0e3 | |||
31cf244420 | |||
d761effac1 | |||
5efaf2c661 | |||
b79161fbec | |||
89a2f2134b | |||
343c3b19b1 | |||
e6ec646b2c | |||
b557fe4e40 | |||
45908dfbd2 | |||
1cf0673a22 | |||
189847f8a5 | |||
b27929f702 | |||
b1a6a9835d | |||
5564a16d4b | |||
b2aa447d50 | |||
8666f9a848 | |||
e7dc7c4917 | |||
513d0f1e5c | |||
c6ce7618cf | |||
9b78b8dc91 | |||
3bae233f40 | |||
0a305c4291 | |||
54fe4ddf3e | |||
aa877b981c | |||
b8ec4348a5 | |||
0e3e27668c | |||
e8c8025119 | |||
6a72eda5d2 | |||
35c8c54466 | |||
2eda45ca5b | |||
ecd6e7960c | |||
0d3a61cbdb | |||
6737c275d7 | |||
201a3e5838 | |||
85cb239219 | |||
002f45e383 | |||
470e5ba290 | |||
b6722b3a10 | |||
6a624916ca | |||
65653e0932 | |||
a7e59d6697 | |||
f1356105c1 | |||
8acc6379fb | |||
d13014c5d9 | |||
db11d8ba90 | |||
46a4ee2360 | |||
a19c053b88 | |||
5b47a32d31 | |||
32ae9efb6a | |||
b81bee551a | |||
8c9dffd082 | |||
cdfae643e4 | |||
cb93108206 | |||
c11343dc1c | |||
3733c6f89d | |||
3c23a0eac0 | |||
7fd69ab1f1 | |||
3f511774de | |||
3600100879 | |||
8fae372103 | |||
2fce1fe04c | |||
22f3e975c7 | |||
c6ced5a210 | |||
5d66e85205 | |||
649f163bf7 | |||
4c821bd930 | |||
6359de4e25 | |||
2c610c8cd4 | |||
7efe8a249b | |||
bd421d184e | |||
78f5844ba0 | |||
f6bb4d5051 | |||
56642c3e87 | |||
9fd8678d3d | |||
8b89518fd6 | |||
50085b40bb | |||
cff382715a | |||
54d54d1bf2 | |||
e84ea68282 | |||
160dd36782 | |||
65bb46bcca | |||
2d185fb766 | |||
2ba9b02932 | |||
849da67cc7 | |||
3ea6c9666e | |||
cf633e4ef2 | |||
bbf934d980 | |||
620f733110 | |||
67928609a3 | |||
5f15afb7db | |||
635d2f480d | |||
70c278c810 | |||
56b9906e2e | |||
a808ce81fd | |||
83f82c5ddf | |||
101de8c25d | |||
3339a4baf0 | |||
dff4a88baa | |||
a21f6c4964 | |||
97562504b7 | |||
75d8ac378c | |||
b9dd354e2b | |||
33c2fbd201 | |||
5063be92bf | |||
1047584b3e | |||
6764dcfdaa | |||
012864ceb1 | |||
a0bf20bcee | |||
14ab339b33 | |||
25c91efbb6 | |||
1c1f2c6664 | |||
d7c22b3bf7 | |||
185f2a395f | |||
0c5649491e | |||
94aba5892a | |||
ef093dde29 | |||
34451e5f27 | |||
1f9bdd1a9a | |||
c27d59baf7 | |||
f130ddec7c | |||
a0a259eef1 | |||
b66f19d4d1 | |||
4105a78b83 | |||
19a68afb3a | |||
fd68a2475b | |||
28ff7ba830 | |||
5d0b248fdb | |||
01a4e0f6ef | |||
91e0731506 | |||
d1f904d41f | |||
269388c9f4 | |||
b8486379ce | |||
400eb94d3b | |||
e210c96485 | |||
5f567f41f4 | |||
5fed573a29 | |||
cfac7c8189 | |||
1787de6836 | |||
ac96f187bd | |||
72398350b4 | |||
df9445c351 | |||
87b7a2e39b | |||
f7e46622a1 | |||
71f18353a9 | |||
4228de707b | |||
b6a05629ef | |||
fbaa820643 | |||
db2a2d5e38 | |||
8ba6e6b1f8 | |||
57168d719b | |||
dee6d2c98e | |||
e49105ece5 | |||
0c5e11f521 | |||
a63f842a13 | |||
4bd7fda694 | |||
81f0886d6f | |||
2eb87f3306 | |||
723f3ab0a9 | |||
1bd90e0fd4 | |||
436f18ff55 | |||
cde9696214 | |||
2d9042fb93 | |||
9ed53af520 | |||
56fda669fd | |||
1d8545a76c | |||
5f59a828f9 | |||
1fa6bddc89 | |||
d3a5ca5247 | |||
f01f56a98e | |||
99b0f79784 | |||
e1eb104345 | |||
5c2f95ef50 | |||
b63df9bab9 | |||
a52c899c6d | |||
eeabb7ebe5 | |||
8b1cef978c | |||
152da482cd | |||
3cf0365a35 | |||
5870742bb9 | |||
01d8c62c57 | |||
55a242b2d6 | |||
45263b339f | |||
3319491861 | |||
e687afac90 | |||
b39031ea53 | |||
0b77511271 | |||
c99cd989c1 | |||
317fdadb21 | |||
4e294f9e3e | |||
526e0f30a0 | |||
231e5ec94a | |||
e5bb6f9693 | |||
da7dee44c6 | |||
83144f4fe3 | |||
c451f52ea3 | |||
8a2c78f2e1 | |||
bcc78bde9b | |||
054bb6fe0a | |||
4f4aa6d92e | |||
eac51ac6f5 | |||
9f349a7c0a | |||
918afa5b15 | |||
eb1113f95c | |||
4f4ba7b462 | |||
2298be0e6b | |||
63494dfca7 | |||
36a1d39454 | |||
a6f6d5c400 | |||
e85f221aca | |||
d4797e37dc | |||
3e7923d072 | |||
a85d69ce3d | |||
96db006c99 | |||
8ca57d03d8 | |||
6c404ce5f8 |
.github/workflows/python-tests.yml (vendored, 2 lines changed)
@@ -60,7 +60,7 @@ jobs:
     extra-index-url: 'https://download.pytorch.org/whl/cpu'
     github-env: $GITHUB_ENV
   - platform: macos-default
-    os: macOS-12
+    os: macOS-14
     github-env: $GITHUB_ENV
   - platform: windows-cpu
     os: windows-2022
@@ -1,20 +1,22 @@
 # Invoke in Docker
 
-- Ensure that Docker can use the GPU on your system
-- This documentation assumes Linux, but should work similarly under Windows with WSL2
+First things first:
+
+- Ensure that Docker can use your [NVIDIA][nvidia docker docs] or [AMD][amd docker docs] GPU.
+- This document assumes a Linux system, but should work similarly under Windows with WSL2.
 - We don't recommend running Invoke in Docker on macOS at this time. It works, but very slowly.
 
-## Quickstart :lightning:
+## Quickstart
 
-No `docker compose`, no persistence, just a simple one-liner using the official images:
+No `docker compose`, no persistence, single command, using the official images:
 
-**CUDA:**
+**CUDA (NVIDIA GPU):**
 
 ```bash
 docker run --runtime=nvidia --gpus=all --publish 9090:9090 ghcr.io/invoke-ai/invokeai
 ```
 
-**ROCm:**
+**ROCm (AMD GPU):**
 
 ```bash
 docker run --device /dev/kfd --device /dev/dri --publish 9090:9090 ghcr.io/invoke-ai/invokeai:main-rocm
@@ -22,12 +24,20 @@ docker run --device /dev/kfd --device /dev/dri --publish 9090:9090 ghcr.io/invok
 
 Open `http://localhost:9090` in your browser once the container finishes booting, install some models, and generate away!
 
-> [!TIP]
-> To persist your data (including downloaded models) outside of the container, add a `--volume/-v` flag to the above command, e.g.: `docker run --volume /some/local/path:/invokeai <...the rest of the command>`
+### Data persistence
+
+To persist your generated images and downloaded models outside of the container, add a `--volume/-v` flag to the above command, e.g.:
+
+```bash
+docker run --volume /some/local/path:/invokeai {...etc...}
+```
+
+`/some/local/path/invokeai` will contain all your data.
+It can *usually* be reused between different installs of Invoke. Tread with caution and read the release notes!
 
 ## Customize the container
 
-We ship the `run.sh` script, which is a convenient wrapper around `docker compose` for cases where custom image build args are needed. Alternatively, the familiar `docker compose` commands work just as well.
+The included `run.sh` script is a convenience wrapper around `docker compose`. It can be helpful for passing additional build arguments to `docker compose`. Alternatively, the familiar `docker compose` commands work just as well.
 
 ```bash
 cd docker
@@ -38,11 +48,14 @@ cp .env.sample .env
 
 It will take a few minutes to build the image the first time. Once the application starts up, open `http://localhost:9090` in your browser to invoke!
 
+>[!TIP]
+>When using the `run.sh` script, the container will continue running after Ctrl+C. To shut it down, use the `docker compose down` command.
+
 ## Docker setup in detail
 
 #### Linux
 
-1. Ensure builkit is enabled in the Docker daemon settings (`/etc/docker/daemon.json`)
+1. Ensure buildkit is enabled in the Docker daemon settings (`/etc/docker/daemon.json`)
 2. Install the `docker compose` plugin using your package manager, or follow a [tutorial](https://docs.docker.com/compose/install/linux/#install-using-the-repository).
    - The deprecated `docker-compose` (hyphenated) CLI probably won't work. Update to a recent version.
 3. Ensure docker daemon is able to access the GPU.
@@ -98,25 +111,7 @@ GPU_DRIVER=cuda
 
 Any environment variables supported by InvokeAI can be set here. See the [Configuration docs](https://invoke-ai.github.io/InvokeAI/features/CONFIGURATION/) for further detail.
 
-## Even More Customizing!
-
-See the `docker-compose.yml` file. The `command` instruction can be uncommented and used to run arbitrary startup commands. Some examples below.
-
-### Reconfigure the runtime directory
-
-Can be used to download additional models from the supported model list
-
-In conjunction with `INVOKEAI_ROOT` can be also used to initialize a runtime directory
-
-```yaml
-command:
-  - invokeai-configure
-  - --yes
-```
-
-Or install models:
-
-```yaml
-command:
-  - invokeai-model-install
-```
+---
+
+[nvidia docker docs]: https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html
+[amd docker docs]: https://rocm.docs.amd.com/projects/install-on-linux/en/latest/how-to/docker.html
@@ -11,6 +11,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
     Batch,
     BatchStatus,
     CancelByBatchIDsResult,
+    CancelByOriginResult,
     ClearResult,
     EnqueueBatchResult,
     PruneResult,
@@ -105,6 +106,19 @@ async def cancel_by_batch_ids(
     return ApiDependencies.invoker.services.session_queue.cancel_by_batch_ids(queue_id=queue_id, batch_ids=batch_ids)
 
 
+@session_queue_router.put(
+    "/{queue_id}/cancel_by_origin",
+    operation_id="cancel_by_origin",
+    responses={200: {"model": CancelByBatchIDsResult}},
+)
+async def cancel_by_origin(
+    queue_id: str = Path(description="The queue id to perform this operation on"),
+    origin: str = Query(description="The origin to cancel all queue items for"),
+) -> CancelByOriginResult:
+    """Immediately cancels all queue items with the given origin"""
+    return ApiDependencies.invoker.services.session_queue.cancel_by_origin(queue_id=queue_id, origin=origin)
+
+
 @session_queue_router.put(
     "/{queue_id}/clear",
     operation_id="clear",
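For reference, the new `cancel_by_origin` route can be exercised with a plain HTTP PUT. Below is a minimal client sketch using `requests`; it assumes the server listens on port 9090 and that the session queue router is mounted under `/api/v1/queue` (both assumptions, not shown in this diff), and the `origin` value is whatever tag was attached to the batch when it was enqueued.

```python
# Minimal sketch of calling the new cancel_by_origin endpoint with `requests`.
# Assumptions: the InvokeAI server runs on localhost:9090 and the session queue
# router is mounted under /api/v1/queue -- neither is shown in this diff.
import requests

BASE_URL = "http://localhost:9090/api/v1/queue"  # assumed mount point
queue_id = "default"  # the queue to operate on
origin = "canvas"     # hypothetical origin tag attached at enqueue time

# PUT /{queue_id}/cancel_by_origin?origin=<origin>
response = requests.put(
    f"{BASE_URL}/{queue_id}/cancel_by_origin",
    params={"origin": origin},
    timeout=10,
)
response.raise_for_status()

# The endpoint returns a CancelByOriginResult; its JSON body describes the
# outcome of the cancellation.
print(response.json())
```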
@ -26,13 +26,10 @@ from invokeai.app.services.style_preset_records.style_preset_records_common impo
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class StylePresetUpdateFormData(BaseModel):
|
class StylePresetFormData(BaseModel):
|
||||||
name: str = Field(description="Preset name")
|
name: str = Field(description="Preset name")
|
||||||
positive_prompt: str = Field(description="Positive prompt")
|
positive_prompt: str = Field(description="Positive prompt")
|
||||||
negative_prompt: str = Field(description="Negative prompt")
|
negative_prompt: str = Field(description="Negative prompt")
|
||||||
|
|
||||||
|
|
||||||
class StylePresetCreateFormData(StylePresetUpdateFormData):
|
|
||||||
type: PresetType = Field(description="Preset type")
|
type: PresetType = Field(description="Preset type")
|
||||||
|
|
||||||
|
|
||||||
@ -95,9 +92,10 @@ async def update_style_preset(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
parsed_data = json.loads(data)
|
parsed_data = json.loads(data)
|
||||||
validated_data = StylePresetUpdateFormData(**parsed_data)
|
validated_data = StylePresetFormData(**parsed_data)
|
||||||
|
|
||||||
name = validated_data.name
|
name = validated_data.name
|
||||||
|
type = validated_data.type
|
||||||
positive_prompt = validated_data.positive_prompt
|
positive_prompt = validated_data.positive_prompt
|
||||||
negative_prompt = validated_data.negative_prompt
|
negative_prompt = validated_data.negative_prompt
|
||||||
|
|
||||||
@ -105,7 +103,7 @@ async def update_style_preset(
|
|||||||
raise HTTPException(status_code=400, detail="Invalid preset data")
|
raise HTTPException(status_code=400, detail="Invalid preset data")
|
||||||
|
|
||||||
preset_data = PresetData(positive_prompt=positive_prompt, negative_prompt=negative_prompt)
|
preset_data = PresetData(positive_prompt=positive_prompt, negative_prompt=negative_prompt)
|
||||||
changes = StylePresetChanges(name=name, preset_data=preset_data)
|
changes = StylePresetChanges(name=name, preset_data=preset_data, type=type)
|
||||||
|
|
||||||
style_preset_image = ApiDependencies.invoker.services.style_preset_image_files.get_url(style_preset_id)
|
style_preset_image = ApiDependencies.invoker.services.style_preset_image_files.get_url(style_preset_id)
|
||||||
style_preset = ApiDependencies.invoker.services.style_preset_records.update(
|
style_preset = ApiDependencies.invoker.services.style_preset_records.update(
|
||||||
@ -145,7 +143,7 @@ async def create_style_preset(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
parsed_data = json.loads(data)
|
parsed_data = json.loads(data)
|
||||||
validated_data = StylePresetCreateFormData(**parsed_data)
|
validated_data = StylePresetFormData(**parsed_data)
|
||||||
|
|
||||||
name = validated_data.name
|
name = validated_data.name
|
||||||
type = validated_data.type
|
type = validated_data.type
|
||||||
|
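The style preset changes above collapse the separate create/update form models into a single `StylePresetFormData` that carries the preset `type` on update as well as create. A minimal sketch of that shared validation step using plain Pydantic follows; the field names are taken from the diff, while the `PresetType` members below are placeholders rather than InvokeAI's real values.

```python
# Sketch of the consolidated form-data validation used by both the create and
# update style-preset routes. Field names come from the diff; the PresetType
# members below are placeholders, not necessarily InvokeAI's real values.
import json
from enum import Enum

from pydantic import BaseModel, Field


class PresetType(str, Enum):
    user = "user"        # assumed member
    default = "default"  # assumed member


class StylePresetFormData(BaseModel):
    name: str = Field(description="Preset name")
    positive_prompt: str = Field(description="Positive prompt")
    negative_prompt: str = Field(description="Negative prompt")
    type: PresetType = Field(description="Preset type")


# Both routes receive the payload as a JSON string in a multipart field named
# "data" and validate it the same way:
raw = '{"name": "Moody portrait", "positive_prompt": "dramatic lighting", "negative_prompt": "blurry", "type": "user"}'
validated = StylePresetFormData(**json.loads(raw))
print(validated.name, validated.type)
```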
@ -20,7 +20,6 @@ from typing import (
|
|||||||
Type,
|
Type,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
cast,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
import semver
|
import semver
|
||||||
@ -80,7 +79,7 @@ class UIConfigBase(BaseModel):
|
|||||||
version: str = Field(
|
version: str = Field(
|
||||||
description='The node\'s version. Should be a valid semver string e.g. "1.0.0" or "3.8.13".',
|
description='The node\'s version. Should be a valid semver string e.g. "1.0.0" or "3.8.13".',
|
||||||
)
|
)
|
||||||
node_pack: Optional[str] = Field(default=None, description="Whether or not this is a custom node")
|
node_pack: str = Field(description="The node pack that this node belongs to, will be 'invokeai' for built-in nodes")
|
||||||
classification: Classification = Field(default=Classification.Stable, description="The node's classification")
|
classification: Classification = Field(default=Classification.Stable, description="The node's classification")
|
||||||
|
|
||||||
model_config = ConfigDict(
|
model_config = ConfigDict(
|
||||||
@ -230,18 +229,16 @@ class BaseInvocation(ABC, BaseModel):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocation]) -> None:
|
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocation]) -> None:
|
||||||
"""Adds various UI-facing attributes to the invocation's OpenAPI schema."""
|
"""Adds various UI-facing attributes to the invocation's OpenAPI schema."""
|
||||||
uiconfig = cast(UIConfigBase | None, getattr(model_class, "UIConfig", None))
|
if title := model_class.UIConfig.title:
|
||||||
if uiconfig is not None:
|
schema["title"] = title
|
||||||
if uiconfig.title is not None:
|
if tags := model_class.UIConfig.tags:
|
||||||
schema["title"] = uiconfig.title
|
schema["tags"] = tags
|
||||||
if uiconfig.tags is not None:
|
if category := model_class.UIConfig.category:
|
||||||
schema["tags"] = uiconfig.tags
|
schema["category"] = category
|
||||||
if uiconfig.category is not None:
|
if node_pack := model_class.UIConfig.node_pack:
|
||||||
schema["category"] = uiconfig.category
|
schema["node_pack"] = node_pack
|
||||||
if uiconfig.node_pack is not None:
|
schema["classification"] = model_class.UIConfig.classification
|
||||||
schema["node_pack"] = uiconfig.node_pack
|
schema["version"] = model_class.UIConfig.version
|
||||||
schema["classification"] = uiconfig.classification
|
|
||||||
schema["version"] = uiconfig.version
|
|
||||||
if "required" not in schema or not isinstance(schema["required"], list):
|
if "required" not in schema or not isinstance(schema["required"], list):
|
||||||
schema["required"] = []
|
schema["required"] = []
|
||||||
schema["class"] = "invocation"
|
schema["class"] = "invocation"
|
||||||
@ -312,7 +309,7 @@ class BaseInvocation(ABC, BaseModel):
|
|||||||
json_schema_extra={"field_kind": FieldKind.NodeAttribute},
|
json_schema_extra={"field_kind": FieldKind.NodeAttribute},
|
||||||
)
|
)
|
||||||
|
|
||||||
UIConfig: ClassVar[Type[UIConfigBase]]
|
UIConfig: ClassVar[UIConfigBase]
|
||||||
|
|
||||||
model_config = ConfigDict(
|
model_config = ConfigDict(
|
||||||
protected_namespaces=(),
|
protected_namespaces=(),
|
||||||
@ -441,30 +438,25 @@ def invocation(
|
|||||||
validate_fields(cls.model_fields, invocation_type)
|
validate_fields(cls.model_fields, invocation_type)
|
||||||
|
|
||||||
# Add OpenAPI schema extras
|
# Add OpenAPI schema extras
|
||||||
uiconfig_name = cls.__qualname__ + ".UIConfig"
|
uiconfig: dict[str, Any] = {}
|
||||||
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconfig_name:
|
uiconfig["title"] = title
|
||||||
cls.UIConfig = type(uiconfig_name, (UIConfigBase,), {})
|
uiconfig["tags"] = tags
|
||||||
cls.UIConfig.title = title
|
uiconfig["category"] = category
|
||||||
cls.UIConfig.tags = tags
|
uiconfig["classification"] = classification
|
||||||
cls.UIConfig.category = category
|
# The node pack is the module name - will be "invokeai" for built-in nodes
|
||||||
cls.UIConfig.classification = classification
|
uiconfig["node_pack"] = cls.__module__.split(".")[0]
|
||||||
|
|
||||||
# Grab the node pack's name from the module name, if it's a custom node
|
|
||||||
is_custom_node = cls.__module__.rsplit(".", 1)[0] == "invokeai.app.invocations"
|
|
||||||
if is_custom_node:
|
|
||||||
cls.UIConfig.node_pack = cls.__module__.split(".")[0]
|
|
||||||
else:
|
|
||||||
cls.UIConfig.node_pack = None
|
|
||||||
|
|
||||||
if version is not None:
|
if version is not None:
|
||||||
try:
|
try:
|
||||||
semver.Version.parse(version)
|
semver.Version.parse(version)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise InvalidVersionError(f'Invalid version string for node "{invocation_type}": "{version}"') from e
|
raise InvalidVersionError(f'Invalid version string for node "{invocation_type}": "{version}"') from e
|
||||||
cls.UIConfig.version = version
|
uiconfig["version"] = version
|
||||||
else:
|
else:
|
||||||
logger.warn(f'No version specified for node "{invocation_type}", using "1.0.0"')
|
logger.warn(f'No version specified for node "{invocation_type}", using "1.0.0"')
|
||||||
cls.UIConfig.version = "1.0.0"
|
uiconfig["version"] = "1.0.0"
|
||||||
|
|
||||||
|
cls.UIConfig = UIConfigBase(**uiconfig)
|
||||||
|
|
||||||
if use_cache is not None:
|
if use_cache is not None:
|
||||||
cls.model_fields["use_cache"].default = use_cache
|
cls.model_fields["use_cache"].default = use_cache
|
||||||
|
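To make the `UIConfig` refactor above concrete: the `@invocation` decorator no longer synthesizes a per-class `UIConfigBase` subclass and mutates its attributes one by one; it now collects the UI metadata into a dict and attaches a single `UIConfigBase` instance, with `node_pack` always derived from the defining module. The following is a simplified stand-in, not the real InvokeAI classes, showing the new shape.

```python
# Simplified stand-in for the UIConfig refactor in baseinvocation.py: the decorator
# builds a dict of UI metadata and attaches a single UIConfigBase *instance*, and
# node_pack is the top-level package of the defining module (it would be "invokeai"
# for built-in nodes; in this standalone script it is "__main__").
from typing import Any, ClassVar, Optional

from pydantic import BaseModel


class UIConfigBase(BaseModel):
    title: Optional[str] = None
    tags: Optional[list[str]] = None
    category: Optional[str] = None
    node_pack: str = "invokeai"
    version: str = "1.0.0"


def invocation(
    title: Optional[str] = None,
    tags: Optional[list[str]] = None,
    category: Optional[str] = None,
    version: Optional[str] = None,
):
    def wrapper(cls: type) -> type:
        uiconfig: dict[str, Any] = {
            "title": title,
            "tags": tags,
            "category": category,
            # The node pack is the module name -- "invokeai" for built-in nodes.
            "node_pack": cls.__module__.split(".")[0],
            "version": version or "1.0.0",
        }
        cls.UIConfig = UIConfigBase(**uiconfig)
        return cls

    return wrapper


@invocation(title="Example Node", tags=["example"], category="util", version="1.2.0")
class ExampleInvocation:
    UIConfig: ClassVar[UIConfigBase]


print(ExampleInvocation.UIConfig.title, ExampleInvocation.UIConfig.version)
```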
@ -40,6 +40,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
|
|||||||
|
|
||||||
# region Model Field Types
|
# region Model Field Types
|
||||||
MainModel = "MainModelField"
|
MainModel = "MainModelField"
|
||||||
|
FluxMainModel = "FluxMainModelField"
|
||||||
SDXLMainModel = "SDXLMainModelField"
|
SDXLMainModel = "SDXLMainModelField"
|
||||||
SDXLRefinerModel = "SDXLRefinerModelField"
|
SDXLRefinerModel = "SDXLRefinerModelField"
|
||||||
ONNXModel = "ONNXModelField"
|
ONNXModel = "ONNXModelField"
|
||||||
@ -48,6 +49,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
|
|||||||
ControlNetModel = "ControlNetModelField"
|
ControlNetModel = "ControlNetModelField"
|
||||||
IPAdapterModel = "IPAdapterModelField"
|
IPAdapterModel = "IPAdapterModelField"
|
||||||
T2IAdapterModel = "T2IAdapterModelField"
|
T2IAdapterModel = "T2IAdapterModelField"
|
||||||
|
T5EncoderModel = "T5EncoderModelField"
|
||||||
SpandrelImageToImageModel = "SpandrelImageToImageModelField"
|
SpandrelImageToImageModel = "SpandrelImageToImageModelField"
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
@ -125,13 +127,16 @@ class FieldDescriptions:
|
|||||||
negative_cond = "Negative conditioning tensor"
|
negative_cond = "Negative conditioning tensor"
|
||||||
noise = "Noise tensor"
|
noise = "Noise tensor"
|
||||||
clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
|
clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
|
||||||
|
t5_encoder = "T5 tokenizer and text encoder"
|
||||||
unet = "UNet (scheduler, LoRAs)"
|
unet = "UNet (scheduler, LoRAs)"
|
||||||
|
transformer = "Transformer"
|
||||||
vae = "VAE"
|
vae = "VAE"
|
||||||
cond = "Conditioning tensor"
|
cond = "Conditioning tensor"
|
||||||
controlnet_model = "ControlNet model to load"
|
controlnet_model = "ControlNet model to load"
|
||||||
vae_model = "VAE model to load"
|
vae_model = "VAE model to load"
|
||||||
lora_model = "LoRA model to load"
|
lora_model = "LoRA model to load"
|
||||||
main_model = "Main model (UNet, VAE, CLIP) to load"
|
main_model = "Main model (UNet, VAE, CLIP) to load"
|
||||||
|
flux_model = "Flux model (Transformer) to load"
|
||||||
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
|
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
|
||||||
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
|
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
|
||||||
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
|
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
|
||||||
@ -231,6 +236,12 @@ class ColorField(BaseModel):
|
|||||||
return (self.r, self.g, self.b, self.a)
|
return (self.r, self.g, self.b, self.a)
|
||||||
|
|
||||||
|
|
||||||
|
class FluxConditioningField(BaseModel):
|
||||||
|
"""A conditioning tensor primitive value"""
|
||||||
|
|
||||||
|
conditioning_name: str = Field(description="The name of conditioning tensor")
|
||||||
|
|
||||||
|
|
||||||
class ConditioningField(BaseModel):
|
class ConditioningField(BaseModel):
|
||||||
"""A conditioning tensor primitive value"""
|
"""A conditioning tensor primitive value"""
|
||||||
|
|
||||||
invokeai/app/invocations/flux_text_encoder.py (new file, 86 lines)
@ -0,0 +1,86 @@
|
|||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
|
||||||
|
|
||||||
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||||
|
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
|
||||||
|
from invokeai.app.invocations.model import CLIPField, T5EncoderField
|
||||||
|
from invokeai.app.invocations.primitives import FluxConditioningOutput
|
||||||
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
|
from invokeai.backend.flux.modules.conditioner import HFEncoder
|
||||||
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
|
||||||
|
|
||||||
|
|
||||||
|
@invocation(
|
||||||
|
"flux_text_encoder",
|
||||||
|
title="FLUX Text Encoding",
|
||||||
|
tags=["prompt", "conditioning", "flux"],
|
||||||
|
category="conditioning",
|
||||||
|
version="1.0.0",
|
||||||
|
classification=Classification.Prototype,
|
||||||
|
)
|
||||||
|
class FluxTextEncoderInvocation(BaseInvocation):
|
||||||
|
"""Encodes and preps a prompt for a flux image."""
|
||||||
|
|
||||||
|
clip: CLIPField = InputField(
|
||||||
|
title="CLIP",
|
||||||
|
description=FieldDescriptions.clip,
|
||||||
|
input=Input.Connection,
|
||||||
|
)
|
||||||
|
t5_encoder: T5EncoderField = InputField(
|
||||||
|
title="T5Encoder",
|
||||||
|
description=FieldDescriptions.t5_encoder,
|
||||||
|
input=Input.Connection,
|
||||||
|
)
|
||||||
|
t5_max_seq_len: Literal[256, 512] = InputField(
|
||||||
|
description="Max sequence length for the T5 encoder. Expected to be 256 for FLUX schnell models and 512 for FLUX dev models."
|
||||||
|
)
|
||||||
|
prompt: str = InputField(description="Text prompt to encode.")
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def invoke(self, context: InvocationContext) -> FluxConditioningOutput:
|
||||||
|
t5_embeddings, clip_embeddings = self._encode_prompt(context)
|
||||||
|
conditioning_data = ConditioningFieldData(
|
||||||
|
conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)]
|
||||||
|
)
|
||||||
|
|
||||||
|
conditioning_name = context.conditioning.save(conditioning_data)
|
||||||
|
return FluxConditioningOutput.build(conditioning_name)
|
||||||
|
|
||||||
|
def _encode_prompt(self, context: InvocationContext) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
# Load CLIP.
|
||||||
|
clip_tokenizer_info = context.models.load(self.clip.tokenizer)
|
||||||
|
clip_text_encoder_info = context.models.load(self.clip.text_encoder)
|
||||||
|
|
||||||
|
# Load T5.
|
||||||
|
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
|
||||||
|
t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)
|
||||||
|
|
||||||
|
prompt = [self.prompt]
|
||||||
|
|
||||||
|
with (
|
||||||
|
t5_text_encoder_info as t5_text_encoder,
|
||||||
|
t5_tokenizer_info as t5_tokenizer,
|
||||||
|
):
|
||||||
|
assert isinstance(t5_text_encoder, T5EncoderModel)
|
||||||
|
assert isinstance(t5_tokenizer, T5Tokenizer)
|
||||||
|
|
||||||
|
t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len)
|
||||||
|
|
||||||
|
prompt_embeds = t5_encoder(prompt)
|
||||||
|
|
||||||
|
with (
|
||||||
|
clip_text_encoder_info as clip_text_encoder,
|
||||||
|
clip_tokenizer_info as clip_tokenizer,
|
||||||
|
):
|
||||||
|
assert isinstance(clip_text_encoder, CLIPTextModel)
|
||||||
|
assert isinstance(clip_tokenizer, CLIPTokenizer)
|
||||||
|
|
||||||
|
clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True, 77)
|
||||||
|
|
||||||
|
pooled_prompt_embeds = clip_encoder(prompt)
|
||||||
|
|
||||||
|
assert isinstance(prompt_embeds, torch.Tensor)
|
||||||
|
assert isinstance(pooled_prompt_embeds, torch.Tensor)
|
||||||
|
return prompt_embeds, pooled_prompt_embeds
|
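The invocation above boils down to running one prompt through two text encoders: T5 yields per-token embeddings (sequence length 256 for FLUX schnell, 512 for FLUX dev) and CLIP yields a pooled embedding. Below is a standalone sketch of that flow with Hugging Face `transformers`; the checkpoint names are placeholders, since InvokeAI loads these submodels through its model manager rather than `from_pretrained`.

```python
# Standalone sketch of FLUX-style dual text encoding: per-token T5 embeddings plus
# a pooled CLIP embedding for one prompt. Model names are placeholders; InvokeAI
# loads the real submodels through its model manager rather than from_pretrained.
import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer

prompt = ["a photo of a red fox in the snow"]
t5_max_seq_len = 256  # 256 for FLUX schnell, 512 for FLUX dev

t5_tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-small")    # placeholder checkpoint
t5_encoder = T5EncoderModel.from_pretrained("google/t5-v1_1-small")   # placeholder checkpoint
clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
clip_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

with torch.no_grad():
    # T5: pad/truncate to the fixed sequence length and keep the full hidden states.
    t5_inputs = t5_tokenizer(
        prompt, max_length=t5_max_seq_len, padding="max_length", truncation=True, return_tensors="pt"
    )
    t5_embeds = t5_encoder(t5_inputs.input_ids).last_hidden_state  # [1, t5_max_seq_len, dim]

    # CLIP: fixed 77-token context, keep only the pooled output.
    clip_inputs = clip_tokenizer(
        prompt, max_length=77, padding="max_length", truncation=True, return_tensors="pt"
    )
    clip_pooled = clip_encoder(clip_inputs.input_ids).pooler_output  # [1, clip_dim]

print(t5_embeds.shape, clip_pooled.shape)
```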
invokeai/app/invocations/flux_text_to_image.py (new file, 172 lines)
@ -0,0 +1,172 @@
|
|||||||
|
import torch
|
||||||
|
from einops import rearrange
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||||
|
from invokeai.app.invocations.fields import (
|
||||||
|
FieldDescriptions,
|
||||||
|
FluxConditioningField,
|
||||||
|
Input,
|
||||||
|
InputField,
|
||||||
|
WithBoard,
|
||||||
|
WithMetadata,
|
||||||
|
)
|
||||||
|
from invokeai.app.invocations.model import TransformerField, VAEField
|
||||||
|
from invokeai.app.invocations.primitives import ImageOutput
|
||||||
|
from invokeai.app.services.session_processor.session_processor_common import CanceledException
|
||||||
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
|
from invokeai.backend.flux.model import Flux
|
||||||
|
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
|
||||||
|
from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, prepare_latent_img_patches, unpack
|
||||||
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
|
||||||
|
from invokeai.backend.util.devices import TorchDevice
|
||||||
|
|
||||||
|
|
||||||
|
@invocation(
|
||||||
|
"flux_text_to_image",
|
||||||
|
title="FLUX Text to Image",
|
||||||
|
tags=["image", "flux"],
|
||||||
|
category="image",
|
||||||
|
version="1.0.0",
|
||||||
|
classification=Classification.Prototype,
|
||||||
|
)
|
||||||
|
class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||||
|
"""Text-to-image generation using a FLUX model."""
|
||||||
|
|
||||||
|
transformer: TransformerField = InputField(
|
||||||
|
description=FieldDescriptions.flux_model,
|
||||||
|
input=Input.Connection,
|
||||||
|
title="Transformer",
|
||||||
|
)
|
||||||
|
vae: VAEField = InputField(
|
||||||
|
description=FieldDescriptions.vae,
|
||||||
|
input=Input.Connection,
|
||||||
|
)
|
||||||
|
positive_text_conditioning: FluxConditioningField = InputField(
|
||||||
|
description=FieldDescriptions.positive_cond, input=Input.Connection
|
||||||
|
)
|
||||||
|
width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
|
||||||
|
height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
|
||||||
|
num_steps: int = InputField(
|
||||||
|
default=4, description="Number of diffusion steps. Recommend values are schnell: 4, dev: 50."
|
||||||
|
)
|
||||||
|
guidance: float = InputField(
|
||||||
|
default=4.0,
|
||||||
|
description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. FLUX dev only, ignored for schnell.",
|
||||||
|
)
|
||||||
|
seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
|
# Load the conditioning data.
|
||||||
|
cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
|
||||||
|
assert len(cond_data.conditionings) == 1
|
||||||
|
flux_conditioning = cond_data.conditionings[0]
|
||||||
|
assert isinstance(flux_conditioning, FLUXConditioningInfo)
|
||||||
|
|
||||||
|
latents = self._run_diffusion(context, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds)
|
||||||
|
image = self._run_vae_decoding(context, latents)
|
||||||
|
image_dto = context.images.save(image=image)
|
||||||
|
return ImageOutput.build(image_dto)
|
||||||
|
|
||||||
|
def _run_diffusion(
|
||||||
|
self,
|
||||||
|
context: InvocationContext,
|
||||||
|
clip_embeddings: torch.Tensor,
|
||||||
|
t5_embeddings: torch.Tensor,
|
||||||
|
):
|
||||||
|
transformer_info = context.models.load(self.transformer.transformer)
|
||||||
|
inference_dtype = torch.bfloat16
|
||||||
|
|
||||||
|
# Prepare input noise.
|
||||||
|
x = get_noise(
|
||||||
|
num_samples=1,
|
||||||
|
height=self.height,
|
||||||
|
width=self.width,
|
||||||
|
device=TorchDevice.choose_torch_device(),
|
||||||
|
dtype=inference_dtype,
|
||||||
|
seed=self.seed,
|
||||||
|
)
|
||||||
|
|
||||||
|
img, img_ids = prepare_latent_img_patches(x)
|
||||||
|
|
||||||
|
is_schnell = "schnell" in transformer_info.config.config_path
|
||||||
|
|
||||||
|
timesteps = get_schedule(
|
||||||
|
num_steps=self.num_steps,
|
||||||
|
image_seq_len=img.shape[1],
|
||||||
|
shift=not is_schnell,
|
||||||
|
)
|
||||||
|
|
||||||
|
bs, t5_seq_len, _ = t5_embeddings.shape
|
||||||
|
txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())
|
||||||
|
|
||||||
|
# HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from
|
||||||
|
# disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems
|
||||||
|
# if the cache is not empty.
|
||||||
|
context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)
|
||||||
|
|
||||||
|
with transformer_info as transformer:
|
||||||
|
assert isinstance(transformer, Flux)
|
||||||
|
|
||||||
|
def step_callback() -> None:
|
||||||
|
if context.util.is_canceled():
|
||||||
|
raise CanceledException
|
||||||
|
|
||||||
|
# TODO: Make this look like the image before re-enabling
|
||||||
|
# latent_image = unpack(img.float(), self.height, self.width)
|
||||||
|
# latent_image = latent_image.squeeze() # Remove unnecessary dimensions
|
||||||
|
# flattened_tensor = latent_image.reshape(-1) # Flatten to shape [48*128*128]
|
||||||
|
|
||||||
|
# # Create a new tensor of the required shape [255, 255, 3]
|
||||||
|
# latent_image = flattened_tensor[: 255 * 255 * 3].reshape(255, 255, 3) # Reshape to RGB format
|
||||||
|
|
||||||
|
# # Convert to a NumPy array and then to a PIL Image
|
||||||
|
# image = Image.fromarray(latent_image.cpu().numpy().astype(np.uint8))
|
||||||
|
|
||||||
|
# (width, height) = image.size
|
||||||
|
# width *= 8
|
||||||
|
# height *= 8
|
||||||
|
|
||||||
|
# dataURL = image_to_dataURL(image, image_format="JPEG")
|
||||||
|
|
||||||
|
# # TODO: move this whole function to invocation context to properly reference these variables
|
||||||
|
# context._services.events.emit_invocation_denoise_progress(
|
||||||
|
# context._data.queue_item,
|
||||||
|
# context._data.invocation,
|
||||||
|
# state,
|
||||||
|
# ProgressImage(dataURL=dataURL, width=width, height=height),
|
||||||
|
# )
|
||||||
|
|
||||||
|
x = denoise(
|
||||||
|
model=transformer,
|
||||||
|
img=img,
|
||||||
|
img_ids=img_ids,
|
||||||
|
txt=t5_embeddings,
|
||||||
|
txt_ids=txt_ids,
|
||||||
|
vec=clip_embeddings,
|
||||||
|
timesteps=timesteps,
|
||||||
|
step_callback=step_callback,
|
||||||
|
guidance=self.guidance,
|
||||||
|
)
|
||||||
|
|
||||||
|
x = unpack(x.float(), self.height, self.width)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def _run_vae_decoding(
|
||||||
|
self,
|
||||||
|
context: InvocationContext,
|
||||||
|
latents: torch.Tensor,
|
||||||
|
) -> Image.Image:
|
||||||
|
vae_info = context.models.load(self.vae.vae)
|
||||||
|
with vae_info as vae:
|
||||||
|
assert isinstance(vae, AutoEncoder)
|
||||||
|
latents = latents.to(dtype=TorchDevice.choose_torch_dtype())
|
||||||
|
img = vae.decode(latents)
|
||||||
|
|
||||||
|
img = img.clamp(-1, 1)
|
||||||
|
img = rearrange(img[0], "c h w -> h w c")
|
||||||
|
img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
|
||||||
|
|
||||||
|
return img_pil
|
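The tail of `_run_vae_decoding` above is a common conversion: the decoder returns a `[batch, channels, height, width]` tensor in roughly [-1, 1], which is clamped, reordered to HWC, and rescaled to 8-bit RGB for PIL. A self-contained sketch of just that step, with a random tensor standing in for real decoder output:

```python
# Sketch of the tensor-to-PIL conversion at the end of _run_vae_decoding: clamp a
# [-1, 1] image tensor, reorder channels to HWC, rescale to 0-255, and wrap in PIL.
# A random tensor stands in for the real VAE decoder output.
import torch
from einops import rearrange
from PIL import Image

decoded = torch.rand(1, 3, 64, 64) * 2.0 - 1.0  # stand-in for vae.decode(latents)

img = decoded.clamp(-1, 1)
img = rearrange(img[0], "c h w -> h w c")  # CHW -> HWC for PIL
img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())  # [-1, 1] -> [0, 255] uint8

img_pil.save("decoded.png")
```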
@ -6,13 +6,19 @@ import cv2
|
|||||||
import numpy
|
import numpy
|
||||||
from PIL import Image, ImageChops, ImageFilter, ImageOps
|
from PIL import Image, ImageChops, ImageFilter, ImageOps
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
from invokeai.app.invocations.baseinvocation import (
|
||||||
|
BaseInvocation,
|
||||||
|
Classification,
|
||||||
|
invocation,
|
||||||
|
invocation_output,
|
||||||
|
)
|
||||||
from invokeai.app.invocations.constants import IMAGE_MODES
|
from invokeai.app.invocations.constants import IMAGE_MODES
|
||||||
from invokeai.app.invocations.fields import (
|
from invokeai.app.invocations.fields import (
|
||||||
ColorField,
|
ColorField,
|
||||||
FieldDescriptions,
|
FieldDescriptions,
|
||||||
ImageField,
|
ImageField,
|
||||||
InputField,
|
InputField,
|
||||||
|
OutputField,
|
||||||
WithBoard,
|
WithBoard,
|
||||||
WithMetadata,
|
WithMetadata,
|
||||||
)
|
)
|
||||||
@ -1007,3 +1013,62 @@ class MaskFromIDInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
image_dto = context.images.save(image=mask, image_category=ImageCategory.MASK)
|
image_dto = context.images.save(image=mask, image_category=ImageCategory.MASK)
|
||||||
|
|
||||||
return ImageOutput.build(image_dto)
|
return ImageOutput.build(image_dto)
|
||||||
|
|
||||||
|
|
||||||
|
@invocation_output("canvas_v2_mask_and_crop_output")
|
||||||
|
class CanvasV2MaskAndCropOutput(ImageOutput):
|
||||||
|
offset_x: int = OutputField(description="The x offset of the image, after cropping")
|
||||||
|
offset_y: int = OutputField(description="The y offset of the image, after cropping")
|
||||||
|
|
||||||
|
|
||||||
|
@invocation(
|
||||||
|
"canvas_v2_mask_and_crop",
|
||||||
|
title="Canvas V2 Mask and Crop",
|
||||||
|
tags=["image", "mask", "id"],
|
||||||
|
category="image",
|
||||||
|
version="1.0.0",
|
||||||
|
classification=Classification.Prototype,
|
||||||
|
)
|
||||||
|
class CanvasV2MaskAndCropInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||||
|
"""Handles Canvas V2 image output masking and cropping"""
|
||||||
|
|
||||||
|
source_image: ImageField | None = InputField(
|
||||||
|
default=None,
|
||||||
|
description="The source image onto which the masked generated image is pasted. If omitted, the masked generated image is returned with transparency.",
|
||||||
|
)
|
||||||
|
generated_image: ImageField = InputField(description="The image to apply the mask to")
|
||||||
|
mask: ImageField = InputField(description="The mask to apply")
|
||||||
|
mask_blur: int = InputField(default=0, ge=0, description="The amount to blur the mask by")
|
||||||
|
|
||||||
|
def _prepare_mask(self, mask: Image.Image) -> Image.Image:
|
||||||
|
mask_array = numpy.array(mask)
|
||||||
|
kernel = numpy.ones((self.mask_blur, self.mask_blur), numpy.uint8)
|
||||||
|
dilated_mask_array = cv2.erode(mask_array, kernel, iterations=3)
|
||||||
|
dilated_mask = Image.fromarray(dilated_mask_array)
|
||||||
|
if self.mask_blur > 0:
|
||||||
|
mask = dilated_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))
|
||||||
|
return ImageOps.invert(mask.convert("L"))
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> CanvasV2MaskAndCropOutput:
|
||||||
|
mask = self._prepare_mask(context.images.get_pil(self.mask.image_name))
|
||||||
|
|
||||||
|
if self.source_image:
|
||||||
|
generated_image = context.images.get_pil(self.generated_image.image_name)
|
||||||
|
source_image = context.images.get_pil(self.source_image.image_name)
|
||||||
|
source_image.paste(generated_image, (0, 0), mask)
|
||||||
|
image_dto = context.images.save(image=source_image)
|
||||||
|
else:
|
||||||
|
generated_image = context.images.get_pil(self.generated_image.image_name)
|
||||||
|
generated_image.putalpha(mask)
|
||||||
|
image_dto = context.images.save(image=generated_image)
|
||||||
|
|
||||||
|
# bbox = image.getbbox()
|
||||||
|
# image = image.crop(bbox)
|
||||||
|
|
||||||
|
return CanvasV2MaskAndCropOutput(
|
||||||
|
image=ImageField(image_name=image_dto.image_name),
|
||||||
|
offset_x=0,
|
||||||
|
offset_y=0,
|
||||||
|
width=image_dto.width,
|
||||||
|
height=image_dto.height,
|
||||||
|
)
|
||||||
|
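Stripped of the invocation plumbing, the masking above is ordinary OpenCV/PIL work: erode the mask, optionally feather it with a Gaussian blur, invert it, and composite the generated image onto the source image through it. A standalone sketch with synthetic images in place of the canvas inputs, assuming a positive `mask_blur`:

```python
# Standalone sketch of the mask preparation used by CanvasV2MaskAndCropInvocation:
# erode the mask, blur it, invert it, then composite the generated image onto a
# source image through that mask. Synthetic images stand in for the canvas inputs.
import cv2
import numpy
from PIL import Image, ImageFilter, ImageOps

mask_blur = 8  # assume a positive blur radius for this sketch

# White square on a black background as a stand-in mask.
mask = Image.new("L", (256, 256), 0)
mask.paste(255, (64, 64, 192, 192))

# Shrink the mask slightly so the feathered edge stays inside the original region.
mask_array = numpy.array(mask)
kernel = numpy.ones((mask_blur, mask_blur), numpy.uint8)
eroded = Image.fromarray(cv2.erode(mask_array, kernel, iterations=3))

# Feather and invert, matching the diff's _prepare_mask.
prepared = ImageOps.invert(eroded.filter(ImageFilter.GaussianBlur(mask_blur)).convert("L"))

# Composite: paste the "generated" image onto the "source" image through the mask.
source_image = Image.new("RGB", (256, 256), (40, 40, 40))
generated_image = Image.new("RGB", (256, 256), (200, 80, 80))
source_image.paste(generated_image, (0, 0), prepared)
source_image.save("composited.png")
```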
@ -1,5 +1,5 @@
|
|||||||
import copy
|
import copy
|
||||||
from typing import List, Optional
|
from typing import List, Literal, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
@ -13,7 +13,14 @@ from invokeai.app.invocations.baseinvocation import (
|
|||||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
|
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
|
||||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
from invokeai.app.shared.models import FreeUConfig
|
from invokeai.app.shared.models import FreeUConfig
|
||||||
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType
|
from invokeai.backend.flux.util import max_seq_lengths
|
||||||
|
from invokeai.backend.model_manager.config import (
|
||||||
|
AnyModelConfig,
|
||||||
|
BaseModelType,
|
||||||
|
CheckpointConfigBase,
|
||||||
|
ModelType,
|
||||||
|
SubModelType,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ModelIdentifierField(BaseModel):
|
class ModelIdentifierField(BaseModel):
|
||||||
@ -60,6 +67,15 @@ class CLIPField(BaseModel):
|
|||||||
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
|
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerField(BaseModel):
|
||||||
|
transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
|
||||||
|
|
||||||
|
|
||||||
|
class T5EncoderField(BaseModel):
|
||||||
|
tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel")
|
||||||
|
text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")
|
||||||
|
|
||||||
|
|
||||||
class VAEField(BaseModel):
|
class VAEField(BaseModel):
|
||||||
vae: ModelIdentifierField = Field(description="Info to load vae submodel")
|
vae: ModelIdentifierField = Field(description="Info to load vae submodel")
|
||||||
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
|
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
|
||||||
@ -122,6 +138,112 @@ class ModelIdentifierInvocation(BaseInvocation):
|
|||||||
return ModelIdentifierOutput(model=self.model)
|
return ModelIdentifierOutput(model=self.model)
|
||||||
|
|
||||||
|
|
||||||
|
@invocation_output("flux_model_loader_output")
|
||||||
|
class FluxModelLoaderOutput(BaseInvocationOutput):
|
||||||
|
"""Flux base model loader output"""
|
||||||
|
|
||||||
|
transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
|
||||||
|
clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP")
|
||||||
|
t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5_encoder, title="T5 Encoder")
|
||||||
|
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||||
|
max_seq_len: Literal[256, 512] = OutputField(
|
||||||
|
description="The max sequence length to used for the T5 encoder. (256 for schnell transformer, 512 for dev transformer)",
|
||||||
|
title="Max Seq Length",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@invocation(
|
||||||
|
"flux_model_loader",
|
||||||
|
title="Flux Main Model",
|
||||||
|
tags=["model", "flux"],
|
||||||
|
category="model",
|
||||||
|
version="1.0.3",
|
||||||
|
classification=Classification.Prototype,
|
||||||
|
)
|
||||||
|
class FluxModelLoaderInvocation(BaseInvocation):
|
||||||
|
"""Loads a flux base model, outputting its submodels."""
|
||||||
|
|
||||||
|
model: ModelIdentifierField = InputField(
|
||||||
|
description=FieldDescriptions.flux_model,
|
||||||
|
ui_type=UIType.FluxMainModel,
|
||||||
|
input=Input.Direct,
|
||||||
|
)
|
||||||
|
|
||||||
|
t5_encoder: ModelIdentifierField = InputField(
|
||||||
|
description=FieldDescriptions.t5_encoder,
|
||||||
|
ui_type=UIType.T5EncoderModel,
|
||||||
|
input=Input.Direct,
|
||||||
|
)
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
|
||||||
|
model_key = self.model.key
|
||||||
|
|
||||||
|
if not context.models.exists(model_key):
|
||||||
|
raise ValueError(f"Unknown model: {model_key}")
|
||||||
|
transformer = self._get_model(context, SubModelType.Transformer)
|
||||||
|
tokenizer = self._get_model(context, SubModelType.Tokenizer)
|
||||||
|
tokenizer2 = self._get_model(context, SubModelType.Tokenizer2)
|
||||||
|
clip_encoder = self._get_model(context, SubModelType.TextEncoder)
|
||||||
|
t5_encoder = self._get_model(context, SubModelType.TextEncoder2)
|
||||||
|
vae = self._get_model(context, SubModelType.VAE)
|
||||||
|
transformer_config = context.models.get_config(transformer)
|
||||||
|
assert isinstance(transformer_config, CheckpointConfigBase)
|
||||||
|
|
||||||
|
return FluxModelLoaderOutput(
|
||||||
|
transformer=TransformerField(transformer=transformer),
|
||||||
|
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
|
||||||
|
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
|
||||||
|
vae=VAEField(vae=vae),
|
||||||
|
max_seq_len=max_seq_lengths[transformer_config.config_path],
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_model(self, context: InvocationContext, submodel: SubModelType) -> ModelIdentifierField:
|
||||||
|
match submodel:
|
||||||
|
case SubModelType.Transformer:
|
||||||
|
return self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
|
||||||
|
case SubModelType.VAE:
|
||||||
|
return self._pull_model_from_mm(
|
||||||
|
context,
|
||||||
|
SubModelType.VAE,
|
||||||
|
"FLUX.1-schnell_ae",
|
||||||
|
ModelType.VAE,
|
||||||
|
BaseModelType.Flux,
|
||||||
|
)
|
||||||
|
case submodel if submodel in [SubModelType.Tokenizer, SubModelType.TextEncoder]:
|
||||||
|
return self._pull_model_from_mm(
|
||||||
|
context,
|
||||||
|
submodel,
|
||||||
|
"clip-vit-large-patch14",
|
||||||
|
ModelType.CLIPEmbed,
|
||||||
|
BaseModelType.Any,
|
||||||
|
)
|
||||||
|
case submodel if submodel in [SubModelType.Tokenizer2, SubModelType.TextEncoder2]:
|
||||||
|
return self._pull_model_from_mm(
|
||||||
|
context,
|
||||||
|
submodel,
|
||||||
|
self.t5_encoder.name,
|
||||||
|
ModelType.T5Encoder,
|
||||||
|
BaseModelType.Any,
|
||||||
|
)
|
||||||
|
case _:
|
||||||
|
raise Exception(f"{submodel.value} is not a supported submodule for a flux model")
|
||||||
|
|
||||||
|
def _pull_model_from_mm(
|
||||||
|
self,
|
||||||
|
context: InvocationContext,
|
||||||
|
submodel: SubModelType,
|
||||||
|
name: str,
|
||||||
|
type: ModelType,
|
||||||
|
base: BaseModelType,
|
||||||
|
):
|
||||||
|
if models := context.models.search_by_attrs(name=name, base=base, type=type):
|
||||||
|
if len(models) != 1:
|
||||||
|
raise Exception(f"Multiple models detected for selected model with name {name}")
|
||||||
|
return ModelIdentifierField.from_config(models[0]).model_copy(update={"submodel_type": submodel})
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Please install the {base}:{type} model named {name} via starter models")
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
"main_model_loader",
|
"main_model_loader",
|
||||||
title="Main Model",
|
title="Main Model",
|
||||||
|
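The `max_seq_len` output of the FLUX model loader above is looked up from `max_seq_lengths` in `invokeai.backend.flux.util`, keyed by the transformer's `config_path`; that table itself is not shown in this diff. A sketch of the lookup with assumed keys (the real keys are whatever config paths the model manager records, with schnell configs mapping to 256 and dev configs to 512):

```python
# Sketch of the config_path -> T5 max sequence length lookup performed by
# FluxModelLoaderInvocation. The keys below are assumptions; the real table lives
# in invokeai.backend.flux.util and is keyed by the transformer's config_path.
from typing import Dict

max_seq_lengths: Dict[str, int] = {
    "flux-schnell": 256,  # assumed key: FLUX schnell transformer config
    "flux-dev": 512,      # assumed key: FLUX dev transformer config
}


def pick_max_seq_len(config_path: str) -> int:
    """Return the T5 max sequence length for a FLUX transformer config path."""
    return max_seq_lengths[config_path]


print(pick_max_seq_len("flux-schnell"))  # 256
print(pick_max_seq_len("flux-dev"))      # 512
```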
@@ -12,6 +12,7 @@ from invokeai.app.invocations.fields import (
     ConditioningField,
     DenoiseMaskField,
     FieldDescriptions,
+    FluxConditioningField,
     ImageField,
     Input,
     InputField,
@@ -414,6 +415,17 @@ class MaskOutput(BaseInvocationOutput):
     height: int = OutputField(description="The height of the mask in pixels.")
 
 
+@invocation_output("flux_conditioning_output")
+class FluxConditioningOutput(BaseInvocationOutput):
+    """Base class for nodes that output a single conditioning tensor"""
+
+    conditioning: FluxConditioningField = OutputField(description=FieldDescriptions.cond)
+
+    @classmethod
+    def build(cls, conditioning_name: str) -> "FluxConditioningOutput":
+        return cls(conditioning=FluxConditioningField(conditioning_name=conditioning_name))
+
+
 @invocation_output("conditioning_output")
 class ConditioningOutput(BaseInvocationOutput):
     """Base class for nodes that output a single conditioning tensor"""
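The new `FluxConditioningOutput`/`FluxConditioningField` pair means the FLUX encoder node and the text-to-image node never pass tensors directly: the encoder saves a conditioning object and forwards only its name, and the consumer loads it back by that name (`context.conditioning.save` / `context.conditioning.load` in the invocations above). A toy sketch of that hand-off, with a plain dict standing in for InvokeAI's conditioning storage:

```python
# Toy sketch of the name-based conditioning hand-off that FluxConditioningField
# enables: the producer saves a conditioning payload and passes only its name;
# the consumer loads the payload back by name. A dict stands in for InvokeAI's
# conditioning storage service, and the class below is a stand-in as well.
import uuid
from dataclasses import dataclass

import torch

_storage: dict[str, object] = {}


@dataclass
class FluxConditioningInfo:
    clip_embeds: torch.Tensor
    t5_embeds: torch.Tensor


def save_conditioning(data: FluxConditioningInfo) -> str:
    """Store the conditioning and return the name used to retrieve it."""
    name = uuid.uuid4().hex
    _storage[name] = data
    return name


def load_conditioning(name: str) -> FluxConditioningInfo:
    return _storage[name]


# Producer side (what FluxTextEncoderInvocation does via context.conditioning.save):
conditioning_name = save_conditioning(
    FluxConditioningInfo(clip_embeds=torch.zeros(1, 768), t5_embeds=torch.zeros(1, 256, 4096))
)

# Consumer side (what FluxTextToImageInvocation does via context.conditioning.load):
cond = load_conditioning(conditioning_name)
print(cond.t5_embeds.shape)
```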
@@ -88,6 +88,7 @@ class QueueItemEventBase(QueueEventBase):

    item_id: int = Field(description="The ID of the queue item")
    batch_id: str = Field(description="The ID of the queue batch")
+   origin: str | None = Field(default=None, description="The origin of the batch")


class InvocationEventBase(QueueItemEventBase):
@@ -95,8 +96,6 @@ class InvocationEventBase(QueueItemEventBase):

    session_id: str = Field(description="The ID of the session (aka graph execution state)")
    queue_id: str = Field(description="The ID of the queue")
-   item_id: int = Field(description="The ID of the queue item")
-   batch_id: str = Field(description="The ID of the queue batch")
    session_id: str = Field(description="The ID of the session (aka graph execution state)")
    invocation: AnyInvocation = Field(description="The ID of the invocation")
    invocation_source_id: str = Field(description="The ID of the prepared invocation's source node")
@@ -114,6 +113,7 @@ class InvocationStartedEvent(InvocationEventBase):
            queue_id=queue_item.queue_id,
            item_id=queue_item.item_id,
            batch_id=queue_item.batch_id,
+           origin=queue_item.origin,
            session_id=queue_item.session_id,
            invocation=invocation,
            invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
@@ -147,6 +147,7 @@ class InvocationDenoiseProgressEvent(InvocationEventBase):
            queue_id=queue_item.queue_id,
            item_id=queue_item.item_id,
            batch_id=queue_item.batch_id,
+           origin=queue_item.origin,
            session_id=queue_item.session_id,
            invocation=invocation,
            invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
@@ -184,6 +185,7 @@ class InvocationCompleteEvent(InvocationEventBase):
            queue_id=queue_item.queue_id,
            item_id=queue_item.item_id,
            batch_id=queue_item.batch_id,
+           origin=queue_item.origin,
            session_id=queue_item.session_id,
            invocation=invocation,
            invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
@@ -216,6 +218,7 @@ class InvocationErrorEvent(InvocationEventBase):
            queue_id=queue_item.queue_id,
            item_id=queue_item.item_id,
            batch_id=queue_item.batch_id,
+           origin=queue_item.origin,
            session_id=queue_item.session_id,
            invocation=invocation,
            invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
@@ -253,6 +256,7 @@ class QueueItemStatusChangedEvent(QueueItemEventBase):
            queue_id=queue_item.queue_id,
            item_id=queue_item.item_id,
            batch_id=queue_item.batch_id,
+           origin=queue_item.origin,
            session_id=queue_item.session_id,
            status=queue_item.status,
            error_type=queue_item.error_type,
@@ -279,12 +283,14 @@ class BatchEnqueuedEvent(QueueEventBase):
        description="The number of invocations initially requested to be enqueued (may be less than enqueued if queue was full)"
    )
    priority: int = Field(description="The priority of the batch")
+   origin: str | None = Field(default=None, description="The origin of the batch")

    @classmethod
    def build(cls, enqueue_result: EnqueueBatchResult) -> "BatchEnqueuedEvent":
        return cls(
            queue_id=enqueue_result.queue_id,
            batch_id=enqueue_result.batch.batch_id,
+           origin=enqueue_result.batch.origin,
            enqueued=enqueue_result.enqueued,
            requested=enqueue_result.requested,
            priority=enqueue_result.priority,
@@ -783,8 +783,9 @@ class ModelInstallService(ModelInstallServiceBase):
        # So what we do is to synthesize a folder named "sdxl-turbo_vae" here.
        if subfolder:
            top = Path(remote_files[0].path.parts[0])  # e.g. "sdxl-turbo/"
-           path_to_remove = top / subfolder.parts[-1]  # sdxl-turbo/vae/
-           path_to_add = Path(f"{top}_{subfolder}")
+           path_to_remove = top / subfolder  # sdxl-turbo/vae/
+           subfolder_rename = subfolder.name.replace("/", "_").replace("\\", "_")
+           path_to_add = Path(f"{top}_{subfolder_rename}")
        else:
            path_to_remove = Path(".")
            path_to_add = Path(".")
@@ -77,6 +77,7 @@ class ModelRecordChanges(BaseModelExcludeNull):
    type: Optional[ModelType] = Field(description="Type of model", default=None)
    key: Optional[str] = Field(description="Database ID for this model", default=None)
    hash: Optional[str] = Field(description="hash of model file", default=None)
+   format: Optional[str] = Field(description="format of model file", default=None)
    trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
    default_settings: Optional[MainModelDefaultSettings | ControlAdapterDefaultSettings] = Field(
        description="Default settings for this model", default=None
@@ -6,6 +6,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
    Batch,
    BatchStatus,
    CancelByBatchIDsResult,
+   CancelByOriginResult,
    CancelByQueueIDResult,
    ClearResult,
    EnqueueBatchResult,
@@ -95,6 +96,11 @@ class SessionQueueBase(ABC):
        """Cancels all queue items with matching batch IDs"""
        pass

+   @abstractmethod
+   def cancel_by_origin(self, queue_id: str, origin: str) -> CancelByOriginResult:
+       """Cancels all queue items with the given batch origin"""
+       pass
+
    @abstractmethod
    def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
        """Cancels all queue items with matching queue ID"""
@@ -77,6 +77,7 @@ BatchDataCollection: TypeAlias = list[list[BatchDatum]]

class Batch(BaseModel):
    batch_id: str = Field(default_factory=uuid_string, description="The ID of the batch")
+   origin: str | None = Field(default=None, description="The origin of this batch.")
    data: Optional[BatchDataCollection] = Field(default=None, description="The batch data collection.")
    graph: Graph = Field(description="The graph to initialize the session with")
    workflow: Optional[WorkflowWithoutID] = Field(
@@ -195,6 +196,7 @@ class SessionQueueItemWithoutGraph(BaseModel):
    status: QUEUE_ITEM_STATUS = Field(default="pending", description="The status of this queue item")
    priority: int = Field(default=0, description="The priority of this queue item")
    batch_id: str = Field(description="The ID of the batch associated with this queue item")
+   origin: str | None = Field(default=None, description="The origin of this queue item. ")
    session_id: str = Field(
        description="The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed."
    )
@@ -294,6 +296,7 @@ class SessionQueueStatus(BaseModel):
class BatchStatus(BaseModel):
    queue_id: str = Field(..., description="The ID of the queue")
    batch_id: str = Field(..., description="The ID of the batch")
+   origin: str | None = Field(..., description="The origin of the batch")
    pending: int = Field(..., description="Number of queue items with status 'pending'")
    in_progress: int = Field(..., description="Number of queue items with status 'in_progress'")
    completed: int = Field(..., description="Number of queue items with status 'complete'")
@@ -328,6 +331,12 @@ class CancelByBatchIDsResult(BaseModel):
    canceled: int = Field(..., description="Number of queue items canceled")


+class CancelByOriginResult(BaseModel):
+    """Result of canceling by list of batch ids"""
+
+    canceled: int = Field(..., description="Number of queue items canceled")
+
+
class CancelByQueueIDResult(CancelByBatchIDsResult):
    """Result of canceling by queue id"""

@@ -433,6 +442,7 @@ class SessionQueueValueToInsert(NamedTuple):
    field_values: Optional[str]  # field_values json
    priority: int  # priority
    workflow: Optional[str]  # workflow json
+   origin: str | None


ValuesToInsert: TypeAlias = list[SessionQueueValueToInsert]
@@ -453,6 +463,7 @@ def prepare_values_to_insert(queue_id: str, batch: Batch, priority: int, max_new
                json.dumps(field_values, default=to_jsonable_python) if field_values else None,  # field_values (json)
                priority,  # priority
                json.dumps(workflow, default=to_jsonable_python) if workflow else None,  # workflow (json)
+               batch.origin,  # origin
            )
        )
    return values_to_insert
@@ -10,6 +10,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
    Batch,
    BatchStatus,
    CancelByBatchIDsResult,
+   CancelByOriginResult,
    CancelByQueueIDResult,
    ClearResult,
    EnqueueBatchResult,
@@ -127,8 +128,8 @@ class SqliteSessionQueue(SessionQueueBase):

        self.__cursor.executemany(
            """--sql
-           INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow)
-           VALUES (?, ?, ?, ?, ?, ?, ?)
+           INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin)
+           VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            """,
            values_to_insert,
        )
@@ -417,11 +418,7 @@ class SqliteSessionQueue(SessionQueueBase):
            )
            self.__conn.commit()
            if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
-               batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
-               queue_status = self.get_queue_status(queue_id=queue_id)
-               self.__invoker.services.events.emit_queue_item_status_changed(
-                   current_queue_item, batch_status, queue_status
-               )
+               self._set_queue_item_status(current_queue_item.item_id, "canceled")
        except Exception:
            self.__conn.rollback()
            raise
@@ -429,6 +426,46 @@ class SqliteSessionQueue(SessionQueueBase):
            self.__lock.release()
        return CancelByBatchIDsResult(canceled=count)

+   def cancel_by_origin(self, queue_id: str, origin: str) -> CancelByOriginResult:
+       try:
+           current_queue_item = self.get_current(queue_id)
+           self.__lock.acquire()
+           where = """--sql
+               WHERE
+                 queue_id == ?
+                 AND origin == ?
+                 AND status != 'canceled'
+                 AND status != 'completed'
+                 AND status != 'failed'
+               """
+           params = (queue_id, origin)
+           self.__cursor.execute(
+               f"""--sql
+               SELECT COUNT(*)
+               FROM session_queue
+               {where};
+               """,
+               params,
+           )
+           count = self.__cursor.fetchone()[0]
+           self.__cursor.execute(
+               f"""--sql
+               UPDATE session_queue
+               SET status = 'canceled'
+               {where};
+               """,
+               params,
+           )
+           self.__conn.commit()
+           if current_queue_item is not None and current_queue_item.origin == origin:
+               self._set_queue_item_status(current_queue_item.item_id, "canceled")
+       except Exception:
+           self.__conn.rollback()
+           raise
+       finally:
+           self.__lock.release()
+       return CancelByOriginResult(canceled=count)
+
    def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
        try:
            current_queue_item = self.get_current(queue_id)
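For reference, the WHERE clause added in cancel_by_origin only touches queue items that are still cancelable; rows already in a terminal state ('canceled', 'completed', 'failed') are left alone. A small standalone sqlite3 sketch of that behavior, using a toy table rather than the real session_queue schema:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE session_queue (item_id INTEGER, origin TEXT, status TEXT)")
conn.executemany(
    "INSERT INTO session_queue VALUES (?, ?, ?)",
    [(1, "api", "pending"), (2, "api", "completed"), (3, "ui", "pending")],
)

where = "WHERE origin == ? AND status != 'canceled' AND status != 'completed' AND status != 'failed'"
count = conn.execute(f"SELECT COUNT(*) FROM session_queue {where}", ("api",)).fetchone()[0]
conn.execute(f"UPDATE session_queue SET status = 'canceled' {where}", ("api",))
conn.commit()

assert count == 1  # only the pending 'api' item is canceled; the completed one is untouched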
@@ -541,7 +578,8 @@ class SqliteSessionQueue(SessionQueueBase):
                started_at,
                session_id,
                batch_id,
-               queue_id
+               queue_id,
+               origin
            FROM session_queue
            WHERE queue_id = ?
            """
@@ -621,7 +659,7 @@ class SqliteSessionQueue(SessionQueueBase):
            self.__lock.acquire()
            self.__cursor.execute(
                """--sql
-               SELECT status, count(*)
+               SELECT status, count(*), origin
                FROM session_queue
                WHERE
                    queue_id = ?
@@ -633,6 +671,7 @@ class SqliteSessionQueue(SessionQueueBase):
            result = cast(list[sqlite3.Row], self.__cursor.fetchall())
            total = sum(row[1] for row in result)
            counts: dict[str, int] = {row[0]: row[1] for row in result}
+           origin = result[0]["origin"] if result else None
        except Exception:
            self.__conn.rollback()
            raise
@@ -641,6 +680,7 @@ class SqliteSessionQueue(SessionQueueBase):

        return BatchStatus(
            batch_id=batch_id,
+           origin=origin,
            queue_id=queue_id,
            pending=counts.get("pending", 0),
            in_progress=counts.get("in_progress", 0),
@@ -17,6 +17,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_11 import
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_12 import build_migration_12
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_13 import build_migration_13
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_14 import build_migration_14
+from invokeai.app.services.shared.sqlite_migrator.migrations.migration_15 import build_migration_15
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator


@@ -51,6 +52,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
    migrator.register_migration(build_migration_12(app_config=config))
    migrator.register_migration(build_migration_13())
    migrator.register_migration(build_migration_14())
+   migrator.register_migration(build_migration_15())
    migrator.run_migrations()

    return db
@@ -0,0 +1,31 @@
import sqlite3

from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration


class Migration15Callback:
    def __call__(self, cursor: sqlite3.Cursor) -> None:
        self._add_origin_col(cursor)

    def _add_origin_col(self, cursor: sqlite3.Cursor) -> None:
        """
        - Adds `origin` column to the session queue table.
        """

        cursor.execute("ALTER TABLE session_queue ADD COLUMN origin TEXT;")


def build_migration_15() -> Migration:
    """
    Build the migration from database version 14 to 15.

    This migration does the following:
    - Adds `origin` column to the session queue table.
    """
    migration_15 = Migration(
        from_version=14,
        to_version=15,
        callback=Migration15Callback(),
    )

    return migration_15
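A quick way to sanity-check what migration 15 does: the callback is a single ALTER TABLE, so once it runs the new column is immediately usable. An in-memory sketch with a toy table (the real session_queue table has many more columns):

import sqlite3

conn = sqlite3.connect(":memory:")
cursor = conn.cursor()
cursor.execute("CREATE TABLE session_queue (item_id INTEGER PRIMARY KEY, status TEXT)")

# Same statement the Migration15Callback issues:
cursor.execute("ALTER TABLE session_queue ADD COLUMN origin TEXT;")

cursor.execute("INSERT INTO session_queue (status, origin) VALUES ('pending', 'api')")
assert cursor.execute("SELECT origin FROM session_queue").fetchone()[0] == "api"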
@@ -32,6 +32,7 @@ class PresetType(str, Enum, metaclass=MetaEnum):
class StylePresetChanges(BaseModel, extra="forbid"):
    name: Optional[str] = Field(default=None, description="The style preset's new name.")
    preset_data: Optional[PresetData] = Field(default=None, description="The updated data for style preset.")
+   type: Optional[PresetType] = Field(description="The updated type of the style preset")


class StylePresetWithoutId(BaseModel):
@@ -0,0 +1,266 @@
{
  "name": "FLUX Text to Image",
  "author": "InvokeAI",
  "description": "A simple text-to-image workflow using FLUX dev or schnell models. Prerequisite model downloads: T5 Encoder, CLIP-L Encoder, and FLUX VAE. Quantized and un-quantized versions can be found in the starter models tab within your Model Manager. We recommend 4 steps for FLUX schnell models and 30 steps for FLUX dev models.",
  "version": "1.0.0",
  "contact": "",
  "tags": "text2image, flux",
  "notes": "Prerequisite model downloads: T5 Encoder, CLIP-L Encoder, and FLUX VAE. Quantized and un-quantized versions can be found in the starter models tab within your Model Manager. We recommend 4 steps for FLUX schnell models and 30 steps for FLUX dev models.",
  "exposedFields": [
    { "nodeId": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a", "fieldName": "model" },
    { "nodeId": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c", "fieldName": "prompt" },
    { "nodeId": "159bdf1b-79e7-4174-b86e-d40e646964c8", "fieldName": "num_steps" },
    { "nodeId": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a", "fieldName": "t5_encoder" }
  ],
  "meta": { "version": "3.0.0", "category": "default" },
  "nodes": [
    {
      "id": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
      "type": "invocation",
      "data": {
        "id": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
        "type": "flux_model_loader",
        "version": "1.0.3",
        "label": "",
        "notes": "",
        "isOpen": true,
        "isIntermediate": true,
        "useCache": false,
        "inputs": {
          "model": {
            "name": "model",
            "label": "Model (Starter Models can be found in Model Manager)",
            "value": {
              "key": "f04a7a2f-c74d-4538-8d5e-879a53501662",
              "hash": "random:4875da7a9508444ffa706f61961c260d0c6729f6181a86b31fad06df1277b850",
              "name": "FLUX Dev (Quantized)",
              "base": "flux",
              "type": "main"
            }
          },
          "t5_encoder": {
            "name": "t5_encoder",
            "label": "T 5 Encoder (Starter Models can be found in Model Manager)",
            "value": {
              "key": "20dcd9ec-5fbb-4012-8401-049e707da5e5",
              "hash": "random:f986be43ff3502169e4adbdcee158afb0e0a65a1edc4cab16ae59963630cfd8f",
              "name": "t5_bnb_int8_quantized_encoder",
              "base": "any",
              "type": "t5_encoder"
            }
          }
        }
      },
      "position": { "x": 337.09365228062825, "y": 40.63469521079861 }
    },
    {
      "id": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
      "type": "invocation",
      "data": {
        "id": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
        "type": "flux_text_encoder",
        "version": "1.0.0",
        "label": "",
        "notes": "",
        "isOpen": true,
        "isIntermediate": true,
        "useCache": true,
        "inputs": {
          "clip": { "name": "clip", "label": "" },
          "t5_encoder": { "name": "t5_encoder", "label": "" },
          "t5_max_seq_len": { "name": "t5_max_seq_len", "label": "T5 Max Seq Len", "value": 256 },
          "prompt": { "name": "prompt", "label": "", "value": "a cat" }
        }
      },
      "position": { "x": 824.1970602278849, "y": 146.98251001061735 }
    },
    {
      "id": "4754c534-a5f3-4ad0-9382-7887985e668c",
      "type": "invocation",
      "data": {
        "id": "4754c534-a5f3-4ad0-9382-7887985e668c",
        "type": "rand_int",
        "version": "1.0.1",
        "label": "",
        "notes": "",
        "isOpen": true,
        "isIntermediate": true,
        "useCache": false,
        "inputs": {
          "low": { "name": "low", "label": "", "value": 0 },
          "high": { "name": "high", "label": "", "value": 2147483647 }
        }
      },
      "position": { "x": 822.9899179655476, "y": 360.9657214885052 }
    },
    {
      "id": "159bdf1b-79e7-4174-b86e-d40e646964c8",
      "type": "invocation",
      "data": {
        "id": "159bdf1b-79e7-4174-b86e-d40e646964c8",
        "type": "flux_text_to_image",
        "version": "1.0.0",
        "label": "",
        "notes": "",
        "isOpen": true,
        "isIntermediate": false,
        "useCache": true,
        "inputs": {
          "board": { "name": "board", "label": "" },
          "metadata": { "name": "metadata", "label": "" },
          "transformer": { "name": "transformer", "label": "" },
          "vae": { "name": "vae", "label": "" },
          "positive_text_conditioning": { "name": "positive_text_conditioning", "label": "" },
          "width": { "name": "width", "label": "", "value": 1024 },
          "height": { "name": "height", "label": "", "value": 1024 },
          "num_steps": { "name": "num_steps", "label": "Steps (Recommend 30 for Dev, 4 for Schnell)", "value": 30 },
          "guidance": { "name": "guidance", "label": "", "value": 4 },
          "seed": { "name": "seed", "label": "", "value": 0 }
        }
      },
      "position": { "x": 1216.3900791301849, "y": 5.500841807102248 }
    }
  ],
  "edges": [
    { "id": "reactflow__edge-4f0207c2-ff40-41fd-b047-ad33fbb1c33amax_seq_len-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_max_seq_len", "type": "default", "source": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a", "target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c", "sourceHandle": "max_seq_len", "targetHandle": "t5_max_seq_len" },
    { "id": "reactflow__edge-4f0207c2-ff40-41fd-b047-ad33fbb1c33avae-159bdf1b-79e7-4174-b86e-d40e646964c8vae", "type": "default", "source": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a", "target": "159bdf1b-79e7-4174-b86e-d40e646964c8", "sourceHandle": "vae", "targetHandle": "vae" },
    { "id": "reactflow__edge-4f0207c2-ff40-41fd-b047-ad33fbb1c33atransformer-159bdf1b-79e7-4174-b86e-d40e646964c8transformer", "type": "default", "source": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a", "target": "159bdf1b-79e7-4174-b86e-d40e646964c8", "sourceHandle": "transformer", "targetHandle": "transformer" },
    { "id": "reactflow__edge-4f0207c2-ff40-41fd-b047-ad33fbb1c33at5_encoder-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_encoder", "type": "default", "source": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a", "target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c", "sourceHandle": "t5_encoder", "targetHandle": "t5_encoder" },
    { "id": "reactflow__edge-4f0207c2-ff40-41fd-b047-ad33fbb1c33aclip-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cclip", "type": "default", "source": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a", "target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c", "sourceHandle": "clip", "targetHandle": "clip" },
    { "id": "reactflow__edge-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cconditioning-159bdf1b-79e7-4174-b86e-d40e646964c8positive_text_conditioning", "type": "default", "source": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c", "target": "159bdf1b-79e7-4174-b86e-d40e646964c8", "sourceHandle": "conditioning", "targetHandle": "positive_text_conditioning" },
    { "id": "reactflow__edge-4754c534-a5f3-4ad0-9382-7887985e668cvalue-159bdf1b-79e7-4174-b86e-d40e646964c8seed", "type": "default", "source": "4754c534-a5f3-4ad0-9382-7887985e668c", "target": "159bdf1b-79e7-4174-b86e-d40e646964c8", "sourceHandle": "value", "targetHandle": "seed" }
  ]
}
invokeai/backend/flux/math.py (new file, 32 lines)
@@ -0,0 +1,32 @@
# Initially pulled from https://github.com/black-forest-labs/flux

import torch
from einops import rearrange
from torch import Tensor


def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
    q, k = apply_rope(q, k, pe)

    x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    x = rearrange(x, "B H L D -> B L (H D)")

    return x


def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    assert dim % 2 == 0
    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    omega = 1.0 / (theta**scale)
    out = torch.einsum("...n,d->...nd", pos, omega)
    out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
    return out.float()


def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
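The rotary helpers above are shape-sensitive. A minimal smoke test with tiny made-up dimensions, assuming the module is importable at the path shown in this diff:

import torch

from invokeai.backend.flux.math import attention, rope

B, H, L, D = 1, 2, 4, 8  # batch, heads, sequence length, head dim (D must be even)
q = torch.randn(B, H, L, D)
k = torch.randn(B, H, L, D)
v = torch.randn(B, H, L, D)

pos = torch.arange(L, dtype=torch.float64)[None, :]  # (B, L) token positions
pe = rope(pos, D, 10_000).unsqueeze(1)  # (B, 1, L, D/2, 2, 2), matching what EmbedND produces

out = attention(q, k, v, pe=pe)
assert out.shape == (B, L, H * D)  # heads are merged back into the feature dimension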
invokeai/backend/flux/model.py (new file, 117 lines)
@@ -0,0 +1,117 @@
# Initially pulled from https://github.com/black-forest-labs/flux

from dataclasses import dataclass

import torch
from torch import Tensor, nn

from invokeai.backend.flux.modules.layers import (
    DoubleStreamBlock,
    EmbedND,
    LastLayer,
    MLPEmbedder,
    SingleStreamBlock,
    timestep_embedding,
)


@dataclass
class FluxParams:
    in_channels: int
    vec_in_dim: int
    context_in_dim: int
    hidden_size: int
    mlp_ratio: float
    num_heads: int
    depth: int
    depth_single_blocks: int
    axes_dim: list[int]
    theta: int
    qkv_bias: bool
    guidance_embed: bool


class Flux(nn.Module):
    """
    Transformer model for flow matching on sequences.
    """

    def __init__(self, params: FluxParams):
        super().__init__()

        self.params = params
        self.in_channels = params.in_channels
        self.out_channels = self.in_channels
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
        self.guidance_in = (
            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
        )
        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_bias=params.qkv_bias,
                )
                for _ in range(params.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
                for _ in range(params.depth_single_blocks)
            ]
        )

        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)

    def forward(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor | None = None,
    ) -> Tensor:
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # running on sequences img
        img = self.img_in(img)
        vec = self.time_in(timestep_embedding(timesteps, 256))
        if self.params.guidance_embed:
            if guidance is None:
                raise ValueError("Didn't get guidance strength for guidance distilled model.")
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        vec = vec + self.vector_in(y)
        txt = self.txt_in(txt)

        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        for block in self.double_blocks:
            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)

        img = torch.cat((txt, img), 1)
        for block in self.single_blocks:
            img = block(img, vec=vec, pe=pe)
        img = img[:, txt.shape[1] :, ...]

        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
        return img
invokeai/backend/flux/modules/autoencoder.py (new file, 310 lines)
@@ -0,0 +1,310 @@
|
|||||||
|
# Initially pulled from https://github.com/black-forest-labs/flux
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from einops import rearrange
|
||||||
|
from torch import Tensor, nn
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AutoEncoderParams:
|
||||||
|
resolution: int
|
||||||
|
in_channels: int
|
||||||
|
ch: int
|
||||||
|
out_ch: int
|
||||||
|
ch_mult: list[int]
|
||||||
|
num_res_blocks: int
|
||||||
|
z_channels: int
|
||||||
|
scale_factor: float
|
||||||
|
shift_factor: float
|
||||||
|
|
||||||
|
|
||||||
|
class AttnBlock(nn.Module):
|
||||||
|
def __init__(self, in_channels: int):
|
||||||
|
super().__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
|
||||||
|
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||||
|
|
||||||
|
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
||||||
|
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
||||||
|
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
||||||
|
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
||||||
|
|
||||||
|
def attention(self, h_: Tensor) -> Tensor:
|
||||||
|
h_ = self.norm(h_)
|
||||||
|
q = self.q(h_)
|
||||||
|
k = self.k(h_)
|
||||||
|
v = self.v(h_)
|
||||||
|
|
||||||
|
b, c, h, w = q.shape
|
||||||
|
q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
|
||||||
|
k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
|
||||||
|
v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
|
||||||
|
h_ = nn.functional.scaled_dot_product_attention(q, k, v)
|
||||||
|
|
||||||
|
return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
return x + self.proj_out(self.attention(x))
|
||||||
|
|
||||||
|
|
||||||
|
class ResnetBlock(nn.Module):
|
||||||
|
def __init__(self, in_channels: int, out_channels: int):
|
||||||
|
super().__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
out_channels = in_channels if out_channels is None else out_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
|
||||||
|
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||||
|
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||||
|
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
|
||||||
|
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||||
|
if self.in_channels != self.out_channels:
|
||||||
|
self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
h = x
|
||||||
|
h = self.norm1(h)
|
||||||
|
h = torch.nn.functional.silu(h)
|
||||||
|
h = self.conv1(h)
|
||||||
|
|
||||||
|
h = self.norm2(h)
|
||||||
|
h = torch.nn.functional.silu(h)
|
||||||
|
h = self.conv2(h)
|
||||||
|
|
||||||
|
if self.in_channels != self.out_channels:
|
||||||
|
x = self.nin_shortcut(x)
|
||||||
|
|
||||||
|
return x + h
|
||||||
|
|
||||||
|
|
||||||
|
class Downsample(nn.Module):
|
||||||
|
def __init__(self, in_channels: int):
|
||||||
|
super().__init__()
|
||||||
|
# no asymmetric padding in torch conv, must do it ourselves
|
||||||
|
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
||||||
|
|
||||||
|
def forward(self, x: Tensor):
|
||||||
|
pad = (0, 1, 0, 1)
|
||||||
|
x = nn.functional.pad(x, pad, mode="constant", value=0)
|
||||||
|
x = self.conv(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Upsample(nn.Module):
|
||||||
|
def __init__(self, in_channels: int):
|
||||||
|
super().__init__()
|
||||||
|
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
||||||
|
|
||||||
|
def forward(self, x: Tensor):
|
||||||
|
x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||||
|
x = self.conv(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Encoder(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
resolution: int,
|
||||||
|
in_channels: int,
|
||||||
|
ch: int,
|
||||||
|
ch_mult: list[int],
|
||||||
|
num_res_blocks: int,
|
||||||
|
z_channels: int,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.ch = ch
|
||||||
|
self.num_resolutions = len(ch_mult)
|
||||||
|
self.num_res_blocks = num_res_blocks
|
||||||
|
self.resolution = resolution
|
||||||
|
self.in_channels = in_channels
|
||||||
|
# downsampling
|
||||||
|
self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
|
||||||
|
|
||||||
|
curr_res = resolution
|
||||||
|
in_ch_mult = (1,) + tuple(ch_mult)
|
||||||
|
self.in_ch_mult = in_ch_mult
|
||||||
|
self.down = nn.ModuleList()
|
||||||
|
block_in = self.ch
|
||||||
|
for i_level in range(self.num_resolutions):
|
||||||
|
block = nn.ModuleList()
|
||||||
|
attn = nn.ModuleList()
|
||||||
|
block_in = ch * in_ch_mult[i_level]
|
||||||
|
block_out = ch * ch_mult[i_level]
|
||||||
|
for _ in range(self.num_res_blocks):
|
||||||
|
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
||||||
|
block_in = block_out
|
||||||
|
down = nn.Module()
|
||||||
|
down.block = block
|
||||||
|
down.attn = attn
|
||||||
|
if i_level != self.num_resolutions - 1:
|
||||||
|
down.downsample = Downsample(block_in)
|
||||||
|
curr_res = curr_res // 2
|
||||||
|
self.down.append(down)
|
||||||
|
|
||||||
|
# middle
|
||||||
|
self.mid = nn.Module()
|
||||||
|
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
||||||
|
self.mid.attn_1 = AttnBlock(block_in)
|
||||||
|
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
||||||
|
|
||||||
|
# end
|
||||||
|
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
|
||||||
|
self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
# downsampling
|
||||||
|
hs = [self.conv_in(x)]
|
||||||
|
for i_level in range(self.num_resolutions):
|
||||||
|
for i_block in range(self.num_res_blocks):
|
||||||
|
h = self.down[i_level].block[i_block](hs[-1])
|
||||||
|
if len(self.down[i_level].attn) > 0:
|
||||||
|
h = self.down[i_level].attn[i_block](h)
|
||||||
|
hs.append(h)
|
||||||
|
if i_level != self.num_resolutions - 1:
|
||||||
|
hs.append(self.down[i_level].downsample(hs[-1]))
|
||||||
|
|
||||||
|
# middle
|
||||||
|
h = hs[-1]
|
||||||
|
h = self.mid.block_1(h)
|
||||||
|
h = self.mid.attn_1(h)
|
||||||
|
h = self.mid.block_2(h)
|
||||||
|
# end
|
||||||
|
h = self.norm_out(h)
|
||||||
|
h = torch.nn.functional.silu(h)
|
||||||
|
h = self.conv_out(h)
|
||||||
|
return h
|
||||||
|
|
||||||
|
|
||||||
|
class Decoder(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
ch: int,
|
||||||
|
out_ch: int,
|
||||||
|
ch_mult: list[int],
|
||||||
|
num_res_blocks: int,
|
||||||
|
in_channels: int,
|
||||||
|
resolution: int,
|
||||||
|
z_channels: int,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.ch = ch
|
||||||
|
self.num_resolutions = len(ch_mult)
|
||||||
|
self.num_res_blocks = num_res_blocks
|
||||||
|
self.resolution = resolution
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.ffactor = 2 ** (self.num_resolutions - 1)
|
||||||
|
|
||||||
|
# compute in_ch_mult, block_in and curr_res at lowest res
|
||||||
|
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||||
|
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||||
|
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||||
|
|
||||||
|
# z to block_in
|
||||||
|
self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
|
||||||
|
|
||||||
|
# middle
|
||||||
|
self.mid = nn.Module()
|
||||||
|
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
||||||
|
self.mid.attn_1 = AttnBlock(block_in)
|
||||||
|
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
||||||
|
|
||||||
|
# upsampling
|
||||||
|
self.up = nn.ModuleList()
|
||||||
|
for i_level in reversed(range(self.num_resolutions)):
|
||||||
|
block = nn.ModuleList()
|
||||||
|
attn = nn.ModuleList()
|
||||||
|
block_out = ch * ch_mult[i_level]
|
||||||
|
for _ in range(self.num_res_blocks + 1):
|
||||||
|
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
||||||
|
block_in = block_out
|
||||||
|
up = nn.Module()
|
||||||
|
up.block = block
|
||||||
|
up.attn = attn
|
||||||
|
if i_level != 0:
|
||||||
|
up.upsample = Upsample(block_in)
|
||||||
|
curr_res = curr_res * 2
|
||||||
|
self.up.insert(0, up) # prepend to get consistent order
|
||||||
|
|
||||||
|
# end
|
||||||
|
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
|
||||||
|
self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
||||||
|
|
||||||
|
def forward(self, z: Tensor) -> Tensor:
|
||||||
|
# z to block_in
|
||||||
|
h = self.conv_in(z)
|
||||||
|
|
||||||
|
# middle
|
||||||
|
h = self.mid.block_1(h)
|
||||||
|
h = self.mid.attn_1(h)
|
||||||
|
h = self.mid.block_2(h)
|
||||||
|
|
||||||
|
# upsampling
|
||||||
|
for i_level in reversed(range(self.num_resolutions)):
|
||||||
|
for i_block in range(self.num_res_blocks + 1):
|
||||||
|
h = self.up[i_level].block[i_block](h)
|
||||||
|
if len(self.up[i_level].attn) > 0:
|
||||||
|
h = self.up[i_level].attn[i_block](h)
|
||||||
|
if i_level != 0:
|
||||||
|
h = self.up[i_level].upsample(h)
|
||||||
|
|
||||||
|
# end
|
||||||
|
h = self.norm_out(h)
|
||||||
|
h = torch.nn.functional.silu(h)
|
||||||
|
h = self.conv_out(h)
|
||||||
|
return h
|
||||||
|
|
||||||
|
|
||||||
|
class DiagonalGaussian(nn.Module):
|
||||||
|
def __init__(self, sample: bool = True, chunk_dim: int = 1):
|
||||||
|
super().__init__()
|
||||||
|
self.sample = sample
|
||||||
|
self.chunk_dim = chunk_dim
|
||||||
|
|
||||||
|
def forward(self, z: Tensor) -> Tensor:
|
||||||
|
mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
|
||||||
|
if self.sample:
|
||||||
|
std = torch.exp(0.5 * logvar)
|
||||||
|
return mean + std * torch.randn_like(mean)
|
||||||
|
else:
|
||||||
|
return mean
|
||||||
|
|
||||||
|
|
||||||
|
class AutoEncoder(nn.Module):
|
||||||
|
def __init__(self, params: AutoEncoderParams):
|
||||||
|
super().__init__()
|
||||||
|
self.encoder = Encoder(
|
||||||
|
resolution=params.resolution,
|
||||||
|
in_channels=params.in_channels,
|
||||||
|
ch=params.ch,
|
||||||
|
ch_mult=params.ch_mult,
|
||||||
|
num_res_blocks=params.num_res_blocks,
|
||||||
|
z_channels=params.z_channels,
|
||||||
|
)
|
||||||
|
self.decoder = Decoder(
|
||||||
|
resolution=params.resolution,
|
||||||
|
in_channels=params.in_channels,
|
||||||
|
ch=params.ch,
|
||||||
|
out_ch=params.out_ch,
|
||||||
|
ch_mult=params.ch_mult,
|
||||||
|
num_res_blocks=params.num_res_blocks,
|
||||||
|
z_channels=params.z_channels,
|
||||||
|
)
|
||||||
|
self.reg = DiagonalGaussian()
|
||||||
|
|
||||||
|
self.scale_factor = params.scale_factor
|
||||||
|
self.shift_factor = params.shift_factor
|
||||||
|
|
||||||
|
def encode(self, x: Tensor) -> Tensor:
|
||||||
|
z = self.reg(self.encoder(x))
|
||||||
|
z = self.scale_factor * (z - self.shift_factor)
|
||||||
|
return z
|
||||||
|
|
||||||
|
def decode(self, z: Tensor) -> Tensor:
|
||||||
|
z = z / self.scale_factor + self.shift_factor
|
||||||
|
return self.decoder(z)
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
return self.decode(self.encode(x))
|
invokeai/backend/flux/modules/conditioner.py (new file, 33 lines)
@@ -0,0 +1,33 @@
# Initially pulled from https://github.com/black-forest-labs/flux

from torch import Tensor, nn
from transformers import PreTrainedModel, PreTrainedTokenizer


class HFEncoder(nn.Module):
    def __init__(self, encoder: PreTrainedModel, tokenizer: PreTrainedTokenizer, is_clip: bool, max_length: int):
        super().__init__()
        self.max_length = max_length
        self.is_clip = is_clip
        self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
        self.tokenizer = tokenizer
        self.hf_module = encoder
        self.hf_module = self.hf_module.eval().requires_grad_(False)

    def forward(self, text: list[str]) -> Tensor:
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=False,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )

        outputs = self.hf_module(
            input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
            attention_mask=None,
            output_hidden_states=False,
        )
        return outputs[self.output_key]
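A hedged usage sketch for HFEncoder: any Hugging Face text encoder/tokenizer pair should work, and the CLIP checkpoint name below is illustrative rather than what InvokeAI's model manager necessarily loads at runtime.

from transformers import CLIPTextModel, CLIPTokenizer

from invokeai.backend.flux.modules.conditioner import HFEncoder

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

clip_embedder = HFEncoder(encoder, tokenizer, is_clip=True, max_length=77)
pooled = clip_embedder(["a cat"])  # pooler_output for CLIP; shape (1, 768) for CLIP-L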
invokeai/backend/flux/modules/layers.py (new file, 253 lines)
@@ -0,0 +1,253 @@
|
|||||||
|
# Initially pulled from https://github.com/black-forest-labs/flux
|
||||||
|
|
||||||
|
import math
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from einops import rearrange
|
||||||
|
from torch import Tensor, nn
|
||||||
|
|
||||||
|
from invokeai.backend.flux.math import attention, rope
|
||||||
|
|
||||||
|
|
||||||
|
class EmbedND(nn.Module):
|
||||||
|
def __init__(self, dim: int, theta: int, axes_dim: list[int]):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.theta = theta
|
||||||
|
self.axes_dim = axes_dim
|
||||||
|
|
||||||
|
def forward(self, ids: Tensor) -> Tensor:
|
||||||
|
n_axes = ids.shape[-1]
|
||||||
|
emb = torch.cat(
|
||||||
|
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
|
||||||
|
dim=-3,
|
||||||
|
)
|
||||||
|
|
||||||
|
return emb.unsqueeze(1)
|
||||||
|
|
||||||
|
|
||||||
|
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
|
||||||
|
"""
|
||||||
|
Create sinusoidal timestep embeddings.
|
||||||
|
:param t: a 1-D Tensor of N indices, one per batch element.
|
||||||
|
These may be fractional.
|
||||||
|
:param dim: the dimension of the output.
|
||||||
|
:param max_period: controls the minimum frequency of the embeddings.
|
||||||
|
:return: an (N, D) Tensor of positional embeddings.
|
||||||
|
"""
|
||||||
|
t = time_factor * t
|
||||||
|
half = dim // 2
|
||||||
|
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device)
|
||||||
|
|
||||||
|
args = t[:, None].float() * freqs[None]
|
||||||
|
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||||
|
if dim % 2:
|
||||||
|
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||||
|
if torch.is_floating_point(t):
|
||||||
|
embedding = embedding.to(t)
|
||||||
|
return embedding
|
||||||
|
|
||||||
|
|
||||||
|
class MLPEmbedder(nn.Module):
|
||||||
|
def __init__(self, in_dim: int, hidden_dim: int):
|
||||||
|
super().__init__()
|
||||||
|
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
|
||||||
|
self.silu = nn.SiLU()
|
||||||
|
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
return self.out_layer(self.silu(self.in_layer(x)))
|
||||||
|
|
||||||
|
|
||||||
|
class RMSNorm(torch.nn.Module):
|
||||||
|
def __init__(self, dim: int):
|
||||||
|
super().__init__()
|
||||||
|
self.scale = nn.Parameter(torch.ones(dim))
|
||||||
|
|
||||||
|
def forward(self, x: Tensor):
|
||||||
|
x_dtype = x.dtype
|
||||||
|
x = x.float()
|
||||||
|
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
|
||||||
|
return (x * rrms).to(dtype=x_dtype) * self.scale
|
||||||
|
|
||||||
|
|
||||||
|
class QKNorm(torch.nn.Module):
|
||||||
|
def __init__(self, dim: int):
|
||||||
|
super().__init__()
|
||||||
|
self.query_norm = RMSNorm(dim)
|
||||||
|
self.key_norm = RMSNorm(dim)
|
||||||
|
|
||||||
|
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
|
||||||
|
q = self.query_norm(q)
|
||||||
|
k = self.key_norm(k)
|
||||||
|
return q.to(v), k.to(v)
|
||||||
|
|
||||||
|
|
||||||
|
class SelfAttention(nn.Module):
|
||||||
|
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
|
||||||
|
super().__init__()
|
||||||
|
self.num_heads = num_heads
|
||||||
|
head_dim = dim // num_heads
|
||||||
|
|
||||||
|
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||||
|
self.norm = QKNorm(head_dim)
|
||||||
|
self.proj = nn.Linear(dim, dim)
|
||||||
|
|
||||||
|
def forward(self, x: Tensor, pe: Tensor) -> Tensor:
|
||||||
|
qkv = self.qkv(x)
|
||||||
|
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||||
|
q, k = self.norm(q, k, v)
|
||||||
|
x = attention(q, k, v, pe=pe)
|
||||||
|
x = self.proj(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModulationOut:
|
||||||
|
shift: Tensor
|
||||||
|
scale: Tensor
|
||||||
|
gate: Tensor
|
||||||
|
|
||||||
|
|
||||||
|
class Modulation(nn.Module):
|
||||||
|
def __init__(self, dim: int, double: bool):
|
||||||
|
super().__init__()
|
||||||
|
self.is_double = double
|
||||||
|
self.multiplier = 6 if double else 3
|
||||||
|
self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
|
||||||
|
|
||||||
|
def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
|
||||||
|
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
|
||||||
|
|
||||||
|
return (
|
||||||
|
ModulationOut(*out[:3]),
|
||||||
|
ModulationOut(*out[3:]) if self.is_double else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DoubleStreamBlock(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
        super().__init__()

        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.img_mod = Modulation(hidden_size, double=True)
        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)

        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

        self.txt_mod = Modulation(hidden_size, double=True)
        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)

        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
        img_mod1, img_mod2 = self.img_mod(vec)
        txt_mod1, txt_mod2 = self.txt_mod(vec)

        # prepare image for attention
        img_modulated = self.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = self.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        # run actual attention
        q = torch.cat((txt_q, img_q), dim=2)
        k = torch.cat((txt_k, img_k), dim=2)
        v = torch.cat((txt_v, img_v), dim=2)

        attn = attention(q, k, v, pe=pe)
        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]

        # calculate the img blocks
        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
        img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)

        # calculate the txt blocks
        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
        txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
        return img, txt


class SingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and adapted modulation interface.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float | None = None,
    ):
        super().__init__()
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # qkv and mlp_in
        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
        # proj and mlp_out
        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)

        self.norm = QKNorm(head_dim)

        self.hidden_size = hidden_size
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        self.mlp_act = nn.GELU(approximate="tanh")
        self.modulation = Modulation(hidden_size, double=False)

    def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
        mod, _ = self.modulation(vec)
        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
        qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)

        # compute attention
        attn = attention(q, k, v, pe=pe)
        # compute activation in mlp stream, cat again and run second linear layer
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        return x + mod.gate * output
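Editor's note: the "parallel linear layers" mentioned in the SingleStreamBlock docstring mean that the QKV projection and the MLP up-projection share a single fused matmul (linear1) whose output is split afterwards. A small hedged sketch of that bookkeeping (sizes are illustrative, not taken from the diff):

# Sketch of the fused projection used by SingleStreamBlock (assumption: toy sizes).
import torch

hidden_size, mlp_ratio, seq_len = 64, 4.0, 10
mlp_hidden_dim = int(hidden_size * mlp_ratio)
linear1 = torch.nn.Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim)

x = torch.randn(1, seq_len, hidden_size)
qkv, mlp_in = torch.split(linear1(x), [3 * hidden_size, mlp_hidden_dim], dim=-1)
assert qkv.shape == (1, seq_len, 3 * hidden_size)      # split further into q, k, v heads
assert mlp_in.shape == (1, seq_len, mlp_hidden_dim)    # fed through the GELU "mlp stream"
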
class LastLayer(nn.Module):
    def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))

    def forward(self, x: Tensor, vec: Tensor) -> Tensor:
        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
        x = self.linear(x)
        return x

176  invokeai/backend/flux/sampling.py  Normal file
@ -0,0 +1,176 @@
# Initially pulled from https://github.com/black-forest-labs/flux

import math
from typing import Callable

import torch
from einops import rearrange, repeat
from torch import Tensor
from tqdm import tqdm

from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.modules.conditioner import HFEncoder


def get_noise(
    num_samples: int,
    height: int,
    width: int,
    device: torch.device,
    dtype: torch.dtype,
    seed: int,
):
    # We always generate noise on the same device and dtype then cast to ensure consistency across devices/dtypes.
    rand_device = "cpu"
    rand_dtype = torch.float16
    return torch.randn(
        num_samples,
        16,
        # allow for packing
        2 * math.ceil(height / 16),
        2 * math.ceil(width / 16),
        device=rand_device,
        dtype=rand_dtype,
        generator=torch.Generator(device=rand_device).manual_seed(seed),
    ).to(device=device, dtype=dtype)
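Editor's note: the comment inside get_noise is about seed reproducibility; because sampling always happens on the CPU in float16 and only the final cast depends on the caller, the same seed yields the same latent regardless of the target device. A hedged usage sketch (argument values are invented):

# Illustrative call (not part of the diff): identical seeds produce identical noise.
noise_a = get_noise(num_samples=1, height=1024, width=1024, device=torch.device("cpu"), dtype=torch.bfloat16, seed=42)
noise_b = get_noise(num_samples=1, height=1024, width=1024, device=torch.device("cpu"), dtype=torch.bfloat16, seed=42)
assert torch.equal(noise_a, noise_b)
# The latent shape is (num_samples, 16, 2 * ceil(H / 16), 2 * ceil(W / 16)), here (1, 16, 128, 128).
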
def prepare(t5: HFEncoder, clip: HFEncoder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
    bs, c, h, w = img.shape
    if bs == 1 and not isinstance(prompt, str):
        bs = len(prompt)

    img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    if img.shape[0] == 1 and bs > 1:
        img = repeat(img, "1 ... -> bs ...", bs=bs)

    img_ids = torch.zeros(h // 2, w // 2, 3)
    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

    if isinstance(prompt, str):
        prompt = [prompt]
    txt = t5(prompt)
    if txt.shape[0] == 1 and bs > 1:
        txt = repeat(txt, "1 ... -> bs ...", bs=bs)
    txt_ids = torch.zeros(bs, txt.shape[1], 3)

    vec = clip(prompt)
    if vec.shape[0] == 1 and bs > 1:
        vec = repeat(vec, "1 ... -> bs ...", bs=bs)

    return {
        "img": img,
        "img_ids": img_ids.to(img.device),
        "txt": txt.to(img.device),
        "txt_ids": txt_ids.to(img.device),
        "vec": vec.to(img.device),
    }


def time_shift(mu: float, sigma: float, t: Tensor):
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)


def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
    m = (y2 - y1) / (x2 - x1)
    b = y1 - m * x1
    return lambda x: m * x + b


def get_schedule(
    num_steps: int,
    image_seq_len: int,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    shift: bool = True,
) -> list[float]:
    # extra step for zero
    timesteps = torch.linspace(1, 0, num_steps + 1)

    # shifting the schedule to favor high timesteps for higher signal images
    if shift:
        # estimate mu based on linear estimation between two points
        mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
        timesteps = time_shift(mu, 1.0, timesteps)

    return timesteps.tolist()
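Editor's note: get_schedule maps a uniform grid of timesteps through time_shift, which applies t -> exp(mu) / (exp(mu) + (1/t - 1)**sigma); a larger image_seq_len gives a larger mu, which pushes more of the schedule toward t = 1. A small hedged example (the exact shifted values depend on mu and are approximate here):

# Illustrative comparison (not part of the diff): the shifted schedule spends more steps near t = 1.
unshifted = get_schedule(num_steps=4, image_seq_len=4096, shift=False)
shifted = get_schedule(num_steps=4, image_seq_len=4096, shift=True)
print(unshifted)  # [1.0, 0.75, 0.5, 0.25, 0.0]
print(shifted)    # intermediate values sit above the unshifted ones, e.g. roughly 0.90 instead of 0.75
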
def denoise(
    model: Flux,
    # model input
    img: Tensor,
    img_ids: Tensor,
    txt: Tensor,
    txt_ids: Tensor,
    vec: Tensor,
    # sampling parameters
    timesteps: list[float],
    step_callback: Callable[[], None],
    guidance: float = 4.0,
):
    dtype = model.txt_in.bias.dtype

    # TODO(ryand): This shouldn't be necessary if we manage the dtypes properly in the caller.
    img = img.to(dtype=dtype)
    img_ids = img_ids.to(dtype=dtype)
    txt = txt.to(dtype=dtype)
    txt_ids = txt_ids.to(dtype=dtype)
    vec = vec.to(dtype=dtype)

    # this is ignored for schnell
    guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
    for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
        t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
        pred = model(
            img=img,
            img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            y=vec,
            timesteps=t_vec,
            guidance=guidance_vec,
        )

        img = img + (t_prev - t_curr) * pred
        step_callback()

    return img
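Editor's note: the update img = img + (t_prev - t_curr) * pred can be read as a plain Euler step on the rectified-flow ODE, treating the model output as the velocity at t_curr. A hedged end-to-end sketch of how the functions in this file would fit together; t5, clip, and flux_model are assumed, pre-loaded objects, and unpack is defined just below:

# Hypothetical wiring of the functions above (not from the diff); t5, clip and flux_model are assumed objects.
x = get_noise(num_samples=1, height=1024, width=1024, device=torch.device("cuda"), dtype=torch.bfloat16, seed=0)
inputs = prepare(t5=t5, clip=clip, img=x, prompt="a photo of a forest")
timesteps = get_schedule(num_steps=4, image_seq_len=inputs["img"].shape[1], shift=True)
x = denoise(model=flux_model, **inputs, timesteps=timesteps, step_callback=lambda: None, guidance=4.0)
latents = unpack(x.float(), height=1024, width=1024)  # back to (1, 16, 128, 128) for the VAE decoder
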
def unpack(x: Tensor, height: int, width: int) -> Tensor:
    return rearrange(
        x,
        "b (h w) (c ph pw) -> b c (h ph) (w pw)",
        h=math.ceil(height / 16),
        w=math.ceil(width / 16),
        ph=2,
        pw=2,
    )


def prepare_latent_img_patches(latent_img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Convert an input image in latent space to patches for diffusion.

    This implementation was extracted from:
    https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/sampling.py#L32

    Returns:
        tuple[Tensor, Tensor]: (img, img_ids), as defined in the original flux repo.
    """
    bs, c, h, w = latent_img.shape

    # Pixel unshuffle with a scale of 2, and flatten the height/width dimensions to get an array of patches.
    img = rearrange(latent_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    if img.shape[0] == 1 and bs > 1:
        img = repeat(img, "1 ... -> bs ...", bs=bs)

    # Generate patch position ids.
    img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device)
    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device)[:, None]
    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device)[None, :]
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

    return img, img_ids
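Editor's note: prepare_latent_img_patches performs a 2x2 pixel-unshuffle of the latent and builds per-patch (batch, row, column) position ids, while unpack above is its inverse. A quick shape sanity check with an illustrative latent size (not part of the diff):

# 1024x1024 image -> 128x128 16-channel latent -> 64x64 packed patches.
latent = torch.randn(1, 16, 128, 128)
img, img_ids = prepare_latent_img_patches(latent)
assert img.shape == (1, 64 * 64, 16 * 2 * 2)   # (batch, num_patches, packed channels)
assert img_ids.shape == (1, 64 * 64, 3)        # (batch, num_patches, [0, row, col])
restored = unpack(img, height=1024, width=1024)
assert restored.shape == latent.shape
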
71  invokeai/backend/flux/util.py  Normal file
@ -0,0 +1,71 @@

# Initially pulled from https://github.com/black-forest-labs/flux

from dataclasses import dataclass
from typing import Dict, Literal

from invokeai.backend.flux.model import FluxParams
from invokeai.backend.flux.modules.autoencoder import AutoEncoderParams


@dataclass
class ModelSpec:
    params: FluxParams
    ae_params: AutoEncoderParams
    ckpt_path: str | None
    ae_path: str | None
    repo_id: str | None
    repo_flow: str | None
    repo_ae: str | None


max_seq_lengths: Dict[str, Literal[256, 512]] = {
    "flux-dev": 512,
    "flux-schnell": 256,
}


ae_params = {
    "flux": AutoEncoderParams(
        resolution=256,
        in_channels=3,
        ch=128,
        out_ch=3,
        ch_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        z_channels=16,
        scale_factor=0.3611,
        shift_factor=0.1159,
    )
}


params = {
    "flux-dev": FluxParams(
        in_channels=64,
        vec_in_dim=768,
        context_in_dim=4096,
        hidden_size=3072,
        mlp_ratio=4.0,
        num_heads=24,
        depth=19,
        depth_single_blocks=38,
        axes_dim=[16, 56, 56],
        theta=10_000,
        qkv_bias=True,
        guidance_embed=True,
    ),
    "flux-schnell": FluxParams(
        in_channels=64,
        vec_in_dim=768,
        context_in_dim=4096,
        hidden_size=3072,
        mlp_ratio=4.0,
        num_heads=24,
        depth=19,
        depth_single_blocks=38,
        axes_dim=[16, 56, 56],
        theta=10_000,
        qkv_bias=True,
        guidance_embed=False,
    ),
}
@ -52,6 +52,7 @@ class BaseModelType(str, Enum):
    StableDiffusion2 = "sd-2"
    StableDiffusionXL = "sdxl"
    StableDiffusionXLRefiner = "sdxl-refiner"
+   Flux = "flux"
    # Kandinsky2_1 = "kandinsky-2.1"


@ -66,7 +67,9 @@ class ModelType(str, Enum):
    TextualInversion = "embedding"
    IPAdapter = "ip_adapter"
    CLIPVision = "clip_vision"
+   CLIPEmbed = "clip_embed"
    T2IAdapter = "t2i_adapter"
+   T5Encoder = "t5_encoder"
    SpandrelImageToImage = "spandrel_image_to_image"


@ -74,6 +77,7 @@ class SubModelType(str, Enum):
    """Submodel type."""

    UNet = "unet"
+   Transformer = "transformer"
    TextEncoder = "text_encoder"
    TextEncoder2 = "text_encoder_2"
    Tokenizer = "tokenizer"

@ -104,6 +108,9 @@ class ModelFormat(str, Enum):
    EmbeddingFile = "embedding_file"
    EmbeddingFolder = "embedding_folder"
    InvokeAI = "invokeai"
+   T5Encoder = "t5_encoder"
+   BnbQuantizedLlmInt8b = "bnb_quantized_int8b"
+   BnbQuantizednf4b = "bnb_quantized_nf4b"


class SchedulerPredictionType(str, Enum):

@ -186,7 +193,9 @@ class ModelConfigBase(BaseModel):
class CheckpointConfigBase(ModelConfigBase):
    """Model config for checkpoint-style models."""

-   format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
+   format: Literal[ModelFormat.Checkpoint, ModelFormat.BnbQuantizednf4b] = Field(
+       description="Format of the provided checkpoint model", default=ModelFormat.Checkpoint
+   )
    config_path: str = Field(description="path to the checkpoint model config file")
    converted_at: Optional[float] = Field(
        description="When this model was last converted to diffusers", default_factory=time.time

@ -205,6 +214,26 @@ class LoRAConfigBase(ModelConfigBase):
    trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)


+class T5EncoderConfigBase(ModelConfigBase):
+    type: Literal[ModelType.T5Encoder] = ModelType.T5Encoder
+
+
+class T5EncoderConfig(T5EncoderConfigBase):
+    format: Literal[ModelFormat.T5Encoder] = ModelFormat.T5Encoder
+
+    @staticmethod
+    def get_tag() -> Tag:
+        return Tag(f"{ModelType.T5Encoder.value}.{ModelFormat.T5Encoder.value}")
+
+
+class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase):
+    format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = ModelFormat.BnbQuantizedLlmInt8b
+
+    @staticmethod
+    def get_tag() -> Tag:
+        return Tag(f"{ModelType.T5Encoder.value}.{ModelFormat.BnbQuantizedLlmInt8b.value}")
+
+
class LoRALyCORISConfig(LoRAConfigBase):
    """Model config for LoRA/Lycoris models."""


@ -229,7 +258,6 @@ class VAECheckpointConfig(CheckpointConfigBase):
    """Model config for standalone VAE models."""

    type: Literal[ModelType.VAE] = ModelType.VAE
-   format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint

    @staticmethod
    def get_tag() -> Tag:

@ -268,7 +296,6 @@ class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase)
    """Model config for ControlNet models (diffusers version)."""

    type: Literal[ModelType.ControlNet] = ModelType.ControlNet
-   format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint

    @staticmethod
    def get_tag() -> Tag:

@ -317,6 +344,21 @@ class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase):
        return Tag(f"{ModelType.Main.value}.{ModelFormat.Checkpoint.value}")


+class MainBnbQuantized4bCheckpointConfig(CheckpointConfigBase, MainConfigBase):
+    """Model config for main checkpoint models."""
+
+    prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
+    upcast_attention: bool = False
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.format = ModelFormat.BnbQuantizednf4b
+
+    @staticmethod
+    def get_tag() -> Tag:
+        return Tag(f"{ModelType.Main.value}.{ModelFormat.BnbQuantizednf4b.value}")
+
+
class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase):
    """Model config for main diffusers models."""


@ -350,6 +392,17 @@ class IPAdapterCheckpointConfig(IPAdapterBaseConfig):
        return Tag(f"{ModelType.IPAdapter.value}.{ModelFormat.Checkpoint.value}")


+class CLIPEmbedDiffusersConfig(DiffusersConfigBase):
+    """Model config for Clip Embeddings."""
+
+    type: Literal[ModelType.CLIPEmbed] = ModelType.CLIPEmbed
+    format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
+
+    @staticmethod
+    def get_tag() -> Tag:
+        return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}")
+
+
class CLIPVisionDiffusersConfig(DiffusersConfigBase):
    """Model config for CLIPVision."""


@ -408,12 +461,15 @@ AnyModelConfig = Annotated[
    Union[
        Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
        Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()],
+       Annotated[MainBnbQuantized4bCheckpointConfig, MainBnbQuantized4bCheckpointConfig.get_tag()],
        Annotated[VAEDiffusersConfig, VAEDiffusersConfig.get_tag()],
        Annotated[VAECheckpointConfig, VAECheckpointConfig.get_tag()],
        Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
        Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()],
        Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()],
        Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
+       Annotated[T5EncoderConfig, T5EncoderConfig.get_tag()],
+       Annotated[T5EncoderBnbQuantizedLlmInt8bConfig, T5EncoderBnbQuantizedLlmInt8bConfig.get_tag()],
        Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()],
        Annotated[TextualInversionFolderConfig, TextualInversionFolderConfig.get_tag()],
        Annotated[IPAdapterInvokeAIConfig, IPAdapterInvokeAIConfig.get_tag()],

@ -421,6 +477,7 @@ AnyModelConfig = Annotated[
        Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
        Annotated[SpandrelImageToImageConfig, SpandrelImageToImageConfig.get_tag()],
        Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
+       Annotated[CLIPEmbedDiffusersConfig, CLIPEmbedDiffusersConfig.get_tag()],
    ],
    Discriminator(get_model_discriminator_value),
]
234  invokeai/backend/model_manager/load/model_loaders/flux.py  Normal file
@ -0,0 +1,234 @@

# Copyright (c) 2024, Brandon W. Rising and the InvokeAI Development Team
"""Class for Flux model loading in InvokeAI."""

from pathlib import Path
from typing import Optional

import accelerate
import torch
from safetensors.torch import load_file
from transformers import AutoConfig, AutoModelForTextEncoding, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer

from invokeai.app.services.config.config_default import get_config
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.flux.util import ae_params, params
from invokeai.backend.model_manager import (
    AnyModel,
    AnyModelConfig,
    BaseModelType,
    ModelFormat,
    ModelType,
    SubModelType,
)
from invokeai.backend.model_manager.config import (
    CheckpointConfigBase,
    CLIPEmbedDiffusersConfig,
    MainBnbQuantized4bCheckpointConfig,
    MainCheckpointConfig,
    T5EncoderBnbQuantizedLlmInt8bConfig,
    T5EncoderConfig,
    VAECheckpointConfig,
)
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.util.silence_warnings import SilenceWarnings

try:
    from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8
    from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4

    bnb_available = True
except ImportError:
    bnb_available = False

app_config = get_config()


@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.VAE, format=ModelFormat.Checkpoint)
class FluxVAELoader(ModelLoader):
    """Class to load VAE models."""

    def _load_model(
        self,
        config: AnyModelConfig,
        submodel_type: Optional[SubModelType] = None,
    ) -> AnyModel:
        if not isinstance(config, VAECheckpointConfig):
            raise ValueError("Only VAECheckpointConfig models are currently supported here.")
        model_path = Path(config.path)

        with SilenceWarnings():
            model = AutoEncoder(ae_params[config.config_path])
            sd = load_file(model_path)
            model.load_state_dict(sd, assign=True)
            model.to(dtype=self._torch_dtype)

        return model


@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPEmbed, format=ModelFormat.Diffusers)
class ClipCheckpointModel(ModelLoader):
    """Class to load main models."""

    def _load_model(
        self,
        config: AnyModelConfig,
        submodel_type: Optional[SubModelType] = None,
    ) -> AnyModel:
        if not isinstance(config, CLIPEmbedDiffusersConfig):
            raise ValueError("Only CLIPEmbedDiffusersConfig models are currently supported here.")

        match submodel_type:
            case SubModelType.Tokenizer:
                return CLIPTokenizer.from_pretrained(Path(config.path) / "tokenizer")
            case SubModelType.TextEncoder:
                return CLIPTextModel.from_pretrained(Path(config.path) / "text_encoder")

        raise ValueError(
            f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
        )


@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.BnbQuantizedLlmInt8b)
class BnbQuantizedLlmInt8bCheckpointModel(ModelLoader):
    """Class to load main models."""

    def _load_model(
        self,
        config: AnyModelConfig,
        submodel_type: Optional[SubModelType] = None,
    ) -> AnyModel:
        if not isinstance(config, T5EncoderBnbQuantizedLlmInt8bConfig):
            raise ValueError("Only T5EncoderBnbQuantizedLlmInt8bConfig models are currently supported here.")
        if not bnb_available:
            raise ImportError(
                "The bnb modules are not available. Please install bitsandbytes if available on your platform."
            )
        match submodel_type:
            case SubModelType.Tokenizer2:
                return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
            case SubModelType.TextEncoder2:
                te2_model_path = Path(config.path) / "text_encoder_2"
                model_config = AutoConfig.from_pretrained(te2_model_path)
                with accelerate.init_empty_weights():
                    model = AutoModelForTextEncoding.from_config(model_config)
                    model = quantize_model_llm_int8(model, modules_to_not_convert=set())

                state_dict_path = te2_model_path / "bnb_llm_int8_model.safetensors"
                state_dict = load_file(state_dict_path)
                self._load_state_dict_into_t5(model, state_dict)

                return model

        raise ValueError(
            f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
        )

    @classmethod
    def _load_state_dict_into_t5(cls, model: T5EncoderModel, state_dict: dict[str, torch.Tensor]):
        # There is a shared reference to a single weight tensor in the model.
        # Both "encoder.embed_tokens.weight" and "shared.weight" refer to the same tensor, so only the latter should
        # be present in the state_dict.
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False, assign=True)
        assert len(unexpected_keys) == 0
        assert set(missing_keys) == {"encoder.embed_tokens.weight"}
        # Assert that the layers we expect to be shared are actually shared.
        assert model.encoder.embed_tokens.weight is model.shared.weight
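Editor's note: the comments in _load_state_dict_into_t5 rely on T5 tying its input embedding to model.shared, so loading only "shared.weight" populates both names. A hedged illustration of that tying being verified, where model is an assumed, already-loaded T5EncoderModel instance:

# Hypothetical check (not part of the diff): the embedding table and the shared weight are one tensor.
embed = model.encoder.embed_tokens.weight
shared = model.shared.weight
assert embed.data_ptr() == shared.data_ptr()  # same underlying storage, so one state-dict entry fills both
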

@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder)
class T5EncoderCheckpointModel(ModelLoader):
    """Class to load main models."""

    def _load_model(
        self,
        config: AnyModelConfig,
        submodel_type: Optional[SubModelType] = None,
    ) -> AnyModel:
        if not isinstance(config, T5EncoderConfig):
            raise ValueError("Only T5EncoderConfig models are currently supported here.")

        match submodel_type:
            case SubModelType.Tokenizer2:
                return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
            case SubModelType.TextEncoder2:
                return T5EncoderModel.from_pretrained(Path(config.path) / "text_encoder_2")

        raise ValueError(
            f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
        )


@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.Checkpoint)
class FluxCheckpointModel(ModelLoader):
    """Class to load main models."""

    def _load_model(
        self,
        config: AnyModelConfig,
        submodel_type: Optional[SubModelType] = None,
    ) -> AnyModel:
        if not isinstance(config, CheckpointConfigBase):
            raise ValueError("Only CheckpointConfigBase models are currently supported here.")

        match submodel_type:
            case SubModelType.Transformer:
                return self._load_from_singlefile(config)

        raise ValueError(
            f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
        )

    def _load_from_singlefile(
        self,
        config: AnyModelConfig,
    ) -> AnyModel:
        assert isinstance(config, MainCheckpointConfig)
        model_path = Path(config.path)

        with SilenceWarnings():
            model = Flux(params[config.config_path])
            sd = load_file(model_path)
            model.load_state_dict(sd, assign=True)
        return model


@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.BnbQuantizednf4b)
class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
    """Class to load main models."""

    def _load_model(
        self,
        config: AnyModelConfig,
        submodel_type: Optional[SubModelType] = None,
    ) -> AnyModel:
        if not isinstance(config, CheckpointConfigBase):
            raise ValueError("Only CheckpointConfigBase models are currently supported here.")

        match submodel_type:
            case SubModelType.Transformer:
                return self._load_from_singlefile(config)

        raise ValueError(
            f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
        )

    def _load_from_singlefile(
        self,
        config: AnyModelConfig,
    ) -> AnyModel:
        assert isinstance(config, MainBnbQuantized4bCheckpointConfig)
        if not bnb_available:
            raise ImportError(
                "The bnb modules are not available. Please install bitsandbytes if available on your platform."
            )
        model_path = Path(config.path)

        with SilenceWarnings():
            with accelerate.init_empty_weights():
                model = Flux(params[config.config_path])
                model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
            sd = load_file(model_path)
            model.load_state_dict(sd, assign=True)
        return model
@ -78,7 +78,12 @@ class GenericDiffusersLoader(ModelLoader):

    # TO DO: Add exception handling
    def _hf_definition_to_type(self, module: str, class_name: str) -> ModelMixin:  # fix with correct type
-       if module in ["diffusers", "transformers"]:
+       if module in [
+           "diffusers",
+           "transformers",
+           "invokeai.backend.quantization.fast_quantized_transformers_model",
+           "invokeai.backend.quantization.fast_quantized_diffusion_model",
+       ]:
            res_type = sys.modules[module]
        else:
            res_type = sys.modules["diffusers"].pipelines
@ -36,8 +36,18 @@ VARIANT_TO_IN_CHANNEL_MAP = {
}


-@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Diffusers)
-@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Checkpoint)
+@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Main, format=ModelFormat.Diffusers)
+@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Main, format=ModelFormat.Diffusers)
+@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.Main, format=ModelFormat.Diffusers)
+@ModelLoaderRegistry.register(
+    base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Diffusers
+)
+@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Main, format=ModelFormat.Checkpoint)
+@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Main, format=ModelFormat.Checkpoint)
+@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.Main, format=ModelFormat.Checkpoint)
+@ModelLoaderRegistry.register(
+    base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Checkpoint
+)
class StableDiffusionDiffusersModel(GenericDiffusersLoader):
    """Class to load main models."""

@ -9,7 +9,7 @@ from typing import Optional
import torch
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers.scheduling_utils import SchedulerMixin
-from transformers import CLIPTokenizer
+from transformers import CLIPTokenizer, T5Tokenizer, T5TokenizerFast

from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
@ -50,6 +50,17 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
        ),
    ):
        return model.calc_size()
+   elif isinstance(
+       model,
+       (
+           T5TokenizerFast,
+           T5Tokenizer,
+       ),
+   ):
+       # HACK(ryand): len(model) just returns the vocabulary size, so this is blatantly wrong. It should be small
+       # relative to the text encoder that it's used with, so shouldn't matter too much, but we should fix this at some
+       # point.
+       return len(model)
    else:
        # TODO(ryand): Promote this from a log to an exception once we are confident that we are handling all of the
        # supported model types.
@ -95,6 +95,7 @@ class ModelProbe(object):
    }

    CLASS2TYPE = {
+       "FluxPipeline": ModelType.Main,
        "StableDiffusionPipeline": ModelType.Main,
        "StableDiffusionInpaintPipeline": ModelType.Main,
        "StableDiffusionXLPipeline": ModelType.Main,

@ -106,6 +107,7 @@ class ModelProbe(object):
        "ControlNetModel": ModelType.ControlNet,
        "CLIPVisionModelWithProjection": ModelType.CLIPVision,
        "T2IAdapter": ModelType.T2IAdapter,
+       "CLIPModel": ModelType.CLIPEmbed,
    }

    @classmethod

@ -161,7 +163,7 @@ class ModelProbe(object):
        fields["description"] = (
            fields.get("description") or f"{fields['base'].value} {model_type.value} model {fields['name']}"
        )
-       fields["format"] = fields.get("format") or probe.get_format()
+       fields["format"] = ModelFormat(fields.get("format")) if "format" in fields else probe.get_format()
        fields["hash"] = fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path)

        fields["default_settings"] = fields.get("default_settings")

@ -176,10 +178,10 @@ class ModelProbe(object):
        fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()

        # additional fields needed for main and controlnet models
-       if (
-           fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE]
-           and fields["format"] is ModelFormat.Checkpoint
-       ):
+       if fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE] and fields["format"] in [
+           ModelFormat.Checkpoint,
+           ModelFormat.BnbQuantizednf4b,
+       ]:
            ckpt_config_path = cls._get_checkpoint_config_path(
                model_path,
                model_type=fields["type"],

@ -222,7 +224,8 @@ class ModelProbe(object):
        ckpt = ckpt.get("state_dict", ckpt)

        for key in [str(k) for k in ckpt.keys()]:
-           if key.startswith(("cond_stage_model.", "first_stage_model.", "model.diffusion_model.")):
+           if key.startswith(("cond_stage_model.", "first_stage_model.", "model.diffusion_model.", "double_blocks.")):
+               # Keys starting with double_blocks are associated with Flux models
                return ModelType.Main
            elif key.startswith(("encoder.conv_in", "decoder.conv_in")):
                return ModelType.VAE

@ -321,6 +324,23 @@ class ModelProbe(object):
            return possible_conf.absolute()

        if model_type is ModelType.Main:
-           config_file = LEGACY_CONFIGS[base_type][variant_type]
-           if isinstance(config_file, dict):  # need another tier for sd-2.x models
-               config_file = config_file[prediction_type]
+           if base_type == BaseModelType.Flux:
+               # TODO: Decide between dev/schnell
+               checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)
+               state_dict = checkpoint.get("state_dict") or checkpoint
+               if "guidance_in.out_layer.weight" in state_dict:
+                   # For flux, this is a key in invokeai.backend.flux.util.params
+                   # Due to model type and format being the discriminator for model configs this
+                   # is used rather than attempting to support flux with separate model types and format
+                   # If changed in the future, please fix me
+                   config_file = "flux-dev"
+               else:
+                   # For flux, this is a key in invokeai.backend.flux.util.params
+                   # Due to model type and format being the discriminator for model configs this
+                   # is used rather than attempting to support flux with separate model types and format
+                   # If changed in the future, please fix me
+                   config_file = "flux-schnell"
+           else:
+               config_file = LEGACY_CONFIGS[base_type][variant_type]
+               if isinstance(config_file, dict):  # need another tier for sd-2.x models
+                   config_file = config_file[prediction_type]

@ -333,7 +353,13 @@ class ModelProbe(object):
            )
        elif model_type is ModelType.VAE:
            config_file = (
-               "stable-diffusion/v1-inference.yaml"
+               # For flux, this is a key in invokeai.backend.flux.util.ae_params
+               # Due to model type and format being the discriminator for model configs this
+               # is used rather than attempting to support flux with separate model types and format
+               # If changed in the future, please fix me
+               "flux"
+               if base_type is BaseModelType.Flux
+               else "stable-diffusion/v1-inference.yaml"
                if base_type is BaseModelType.StableDiffusion1
                else "stable-diffusion/sd_xl_base.yaml"
                if base_type is BaseModelType.StableDiffusionXL

@ -416,11 +442,15 @@ class CheckpointProbeBase(ProbeBase):
        self.checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)

    def get_format(self) -> ModelFormat:
+       state_dict = self.checkpoint.get("state_dict") or self.checkpoint
+       if "double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4" in state_dict:
+           return ModelFormat.BnbQuantizednf4b
        return ModelFormat("checkpoint")

    def get_variant_type(self) -> ModelVariantType:
        model_type = ModelProbe.get_model_type_from_checkpoint(self.model_path, self.checkpoint)
-       if model_type != ModelType.Main:
+       base_type = self.get_base_type()
+       if model_type != ModelType.Main or base_type == BaseModelType.Flux:
            return ModelVariantType.Normal
        state_dict = self.checkpoint.get("state_dict") or self.checkpoint
        in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]

@ -440,6 +470,8 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
    def get_base_type(self) -> BaseModelType:
        checkpoint = self.checkpoint
        state_dict = self.checkpoint.get("state_dict") or checkpoint
+       if "double_blocks.0.img_attn.norm.key_norm.scale" in state_dict:
+           return BaseModelType.Flux
        key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
        if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
            return BaseModelType.StableDiffusion1

@ -482,6 +514,7 @@ class VaeCheckpointProbe(CheckpointProbeBase):
            (r"xl", BaseModelType.StableDiffusionXL),
            (r"sd2", BaseModelType.StableDiffusion2),
            (r"vae", BaseModelType.StableDiffusion1),
+           (r"FLUX.1-schnell_ae", BaseModelType.Flux),
        ]:
            if re.search(regexp, self.model_path.name, re.IGNORECASE):
                return basetype

@ -713,6 +746,11 @@ class TextualInversionFolderProbe(FolderProbeBase):
        return TextualInversionCheckpointProbe(path).get_base_type()


+class T5EncoderFolderProbe(FolderProbeBase):
+    def get_format(self) -> ModelFormat:
+        return ModelFormat.T5Encoder
+
+
class ONNXFolderProbe(PipelineFolderProbe):
    def get_base_type(self) -> BaseModelType:
        # Due to the way the installer is set up, the configuration file for safetensors

@ -805,6 +843,11 @@ class CLIPVisionFolderProbe(FolderProbeBase):
        return BaseModelType.Any


+class CLIPEmbedFolderProbe(FolderProbeBase):
+    def get_base_type(self) -> BaseModelType:
+        return BaseModelType.Any
+
+
class SpandrelImageToImageFolderProbe(FolderProbeBase):
    def get_base_type(self) -> BaseModelType:
        raise NotImplementedError()

@ -835,8 +878,10 @@ ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.VAE, VaeFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.LoRA, LoRAFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
+ModelProbe.register_probe("diffusers", ModelType.T5Encoder, T5EncoderFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
+ModelProbe.register_probe("diffusers", ModelType.CLIPEmbed, CLIPEmbedFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.SpandrelImageToImage, SpandrelImageToImageFolderProbe)
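Editor's note: the probe changes above identify FLUX checkpoints from characteristic state-dict keys: a "double_blocks." prefix marks a FLUX main model, "guidance_in.out_layer.weight" marks the guidance-distilled dev variant, and the "...bitsandbytes__nf4" quant-state key marks an NF4-quantized file. A hedged standalone sketch of the same checks (the file path is invented):

# Illustrative version of the probe checks (not part of the diff).
from safetensors.torch import load_file

sd = load_file("flux1-dev.safetensors")  # hypothetical path
is_flux_main = any(k.startswith("double_blocks.") for k in sd)
is_dev_variant = "guidance_in.out_layer.weight" in sd
is_bnb_nf4 = "double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4" in sd
print(is_flux_main, is_dev_variant, is_bnb_nf4)
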
@ -2,7 +2,7 @@ from typing import Optional
|
|||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from invokeai.backend.model_manager.config import BaseModelType, ModelType
|
from invokeai.backend.model_manager.config import BaseModelType, ModelFormat, ModelType
|
||||||
|
|
||||||
|
|
||||||
class StarterModelWithoutDependencies(BaseModel):
|
class StarterModelWithoutDependencies(BaseModel):
|
||||||
@ -11,6 +11,7 @@ class StarterModelWithoutDependencies(BaseModel):
|
|||||||
name: str
|
name: str
|
||||||
base: BaseModelType
|
base: BaseModelType
|
||||||
type: ModelType
|
type: ModelType
|
||||||
|
format: Optional[ModelFormat] = None
|
||||||
is_installed: bool = False
|
is_installed: bool = False
|
||||||
|
|
||||||
|
|
||||||
@ -51,10 +52,76 @@ cyberrealistic_negative = StarterModel(
|
|||||||
type=ModelType.TextualInversion,
|
type=ModelType.TextualInversion,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
t5_base_encoder = StarterModel(
|
||||||
|
name="t5_base_encoder",
|
||||||
|
base=BaseModelType.Any,
|
||||||
|
source="InvokeAI/t5-v1_1-xxl::bfloat16",
|
||||||
|
description="T5-XXL text encoder (used in FLUX pipelines). ~8GB",
|
||||||
|
type=ModelType.T5Encoder,
|
||||||
|
)
|
||||||
|
|
||||||
|
t5_8b_quantized_encoder = StarterModel(
|
||||||
|
name="t5_bnb_int8_quantized_encoder",
|
||||||
|
base=BaseModelType.Any,
|
||||||
|
source="InvokeAI/t5-v1_1-xxl::bnb_llm_int8",
|
||||||
|
description="T5-XXL text encoder with bitsandbytes LLM.int8() quantization (used in FLUX pipelines). ~5GB",
|
||||||
|
type=ModelType.T5Encoder,
|
||||||
|
format=ModelFormat.BnbQuantizedLlmInt8b,
|
||||||
|
)
|
||||||
|
|
||||||
|
clip_l_encoder = StarterModel(
|
||||||
|
name="clip-vit-large-patch14",
|
||||||
|
base=BaseModelType.Any,
|
||||||
|
source="InvokeAI/clip-vit-large-patch14-text-encoder::bfloat16",
|
||||||
|
description="CLIP-L text encoder (used in FLUX pipelines). ~250MB",
|
||||||
|
type=ModelType.CLIPEmbed,
|
||||||
|
)
|
||||||
|
|
||||||
|
flux_vae = StarterModel(
|
||||||
|
name="FLUX.1-schnell_ae",
|
||||||
|
base=BaseModelType.Flux,
|
||||||
|
source="black-forest-labs/FLUX.1-schnell::ae.safetensors",
|
||||||
|
description="FLUX VAE compatible with both schnell and dev variants.",
|
||||||
|
type=ModelType.VAE,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# List of starter models, displayed on the frontend.
|
# List of starter models, displayed on the frontend.
|
||||||
# The order/sort of this list is not changed by the frontend - set it how you want it here.
|
# The order/sort of this list is not changed by the frontend - set it how you want it here.
|
||||||
STARTER_MODELS: list[StarterModel] = [
|
STARTER_MODELS: list[StarterModel] = [
|
||||||
# region: Main
|
# region: Main
|
||||||
|
StarterModel(
|
||||||
|
name="FLUX Schnell (Quantized)",
|
||||||
|
base=BaseModelType.Flux,
|
||||||
|
source="InvokeAI/flux_schnell::transformer/bnb_nf4/flux1-schnell-bnb_nf4.safetensors",
|
||||||
|
description="FLUX schnell transformer quantized to bitsandbytes NF4 format. Total size with dependencies: ~12GB",
|
||||||
|
type=ModelType.Main,
|
||||||
|
dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder],
|
||||||
|
),
|
||||||
|
StarterModel(
|
||||||
|
name="FLUX Dev (Quantized)",
|
||||||
|
base=BaseModelType.Flux,
|
||||||
|
source="InvokeAI/flux_dev::transformer/bnb_nf4/flux1-dev-bnb_nf4.safetensors",
|
||||||
|
description="FLUX dev transformer quantized to bitsandbytes NF4 format. Total size with dependencies: ~12GB",
|
||||||
|
type=ModelType.Main,
|
||||||
|
dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder],
|
||||||
|
),
|
||||||
|
StarterModel(
|
||||||
|
name="FLUX Schnell",
|
||||||
|
base=BaseModelType.Flux,
|
||||||
|
source="InvokeAI/flux_schnell::transformer/base/flux1-schnell.safetensors",
|
||||||
|
description="FLUX schnell transformer in bfloat16. Total size with dependencies: ~33GB",
|
||||||
|
type=ModelType.Main,
|
||||||
|
dependencies=[t5_base_encoder, flux_vae, clip_l_encoder],
|
||||||
|
),
|
||||||
|
StarterModel(
|
||||||
|
name="FLUX Dev",
|
||||||
|
base=BaseModelType.Flux,
|
||||||
|
source="InvokeAI/flux_dev::transformer/base/flux1-dev.safetensors",
|
||||||
|
description="FLUX dev transformer in bfloat16. Total size with dependencies: ~33GB",
|
||||||
|
type=ModelType.Main,
|
||||||
|
dependencies=[t5_base_encoder, flux_vae, clip_l_encoder],
|
||||||
|
),
|
||||||
StarterModel(
|
StarterModel(
|
||||||
name="CyberRealistic v4.1",
|
name="CyberRealistic v4.1",
|
||||||
base=BaseModelType.StableDiffusion1,
|
base=BaseModelType.StableDiffusion1,
|
||||||
@@ -125,6 +192,7 @@ STARTER_MODELS: list[StarterModel] = [
    # endregion
    # region VAE
    sdxl_fp16_vae_fix,
+   flux_vae,
    # endregion
    # region LoRA
    StarterModel(
@@ -450,6 +518,11 @@ STARTER_MODELS: list[StarterModel] = [
        type=ModelType.SpandrelImageToImage,
    ),
    # endregion
+   # region TextEncoders
+   t5_base_encoder,
+   t5_8b_quantized_encoder,
+   clip_l_encoder,
+   # endregion
]

assert len(STARTER_MODELS) == len({m.source for m in STARTER_MODELS}), "Duplicate starter models"
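The `source` strings in the starter-model entries above appear to pack two pieces of information: a Hugging Face repo id and, after the `::` separator, either a file path inside that repo or a variant label such as `bfloat16`. A purely illustrative split of one of the sources listed above (the parsing shown here is not the application's own loader):

```python
# Illustrative only: split a starter-model source string into its repo id and the part after "::".
source = "black-forest-labs/FLUX.1-schnell::ae.safetensors"
repo_id, _, subpath_or_variant = source.partition("::")
assert repo_id == "black-forest-labs/FLUX.1-schnell"
assert subpath_or_variant == "ae.safetensors"
```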
@@ -54,6 +54,7 @@ def filter_files(
            "lora_weights.safetensors",
            "weights.pb",
            "onnx_data",
+           "spiece.model",  # Added for `black-forest-labs/FLUX.1-schnell`.
        )
    ):
        paths.append(file)
@@ -62,13 +63,13 @@ def filter_files(
        # downloading random checkpoints that might also be in the repo. However there is no guarantee
        # that a checkpoint doesn't contain "model" in its name, and no guarantee that future diffusers models
        # will adhere to this naming convention, so this is an area to be careful of.
-       elif re.search(r"model(\.[^.]+)?\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$", file.name):
+       elif re.search(r"model.*\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$", file.name):
            paths.append(file)

    # limit search to subfolder if requested
    if subfolder:
        subfolder = root / subfolder
-       paths = [x for x in paths if x.parent == Path(subfolder)]
+       paths = [x for x in paths if Path(subfolder) in x.parents]

    # _filter_by_variant uniquifies the paths and returns a set
    return sorted(_filter_by_variant(paths, variant))
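Two behavior changes are visible in this hunk: the filename regex is broadened so that sharded checkpoint names also match, and the subfolder filter now keeps files in nested directories instead of only direct children. A small self-contained check of both (not part of the diff; the file names are made up, though the shard naming scheme matches what transformers-style repos such as `black-forest-labs/FLUX.1-schnell` use):

```python
import re
from pathlib import Path

old_pattern = r"model(\.[^.]+)?\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$"
new_pattern = r"model.*\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$"

# Sharded weights only match the broadened pattern.
shard = "model-00001-of-00002.safetensors"
assert re.search(new_pattern, shard) is not None
assert re.search(old_pattern, shard) is None

root = Path("repo")
subfolder = root / "text_encoder_2"
direct_file = subfolder / "model.safetensors"
nested_file = subfolder / "nested" / "model-00001-of-00002.safetensors"

# Old check (parent equality) keeps only direct children; new check (x.parents) keeps nested files too.
assert direct_file.parent == subfolder and subfolder in direct_file.parents
assert nested_file.parent != subfolder and subfolder in nested_file.parents
```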
@@ -97,7 +98,9 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path]:
        if variant == ModelRepoVariant.Flax:
            result.add(path)

-       elif path.suffix in [".json", ".txt"]:
+       # Note: '.model' was added to support:
+       # https://huggingface.co/black-forest-labs/FLUX.1-schnell/blob/768d12a373ed5cc9ef9a9dea7504dc09fcc14842/tokenizer_2/spiece.model
+       elif path.suffix in [".json", ".txt", ".model"]:
            result.add(path)

        elif variant in [
@@ -140,6 +143,23 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path]:
                continue

        for candidate_list in subfolder_weights.values():
+           # Check if at least one of the files has the explicit fp16 variant.
+           at_least_one_fp16 = False
+           for candidate in candidate_list:
+               if len(candidate.path.suffixes) == 2 and candidate.path.suffixes[0] == ".fp16":
+                   at_least_one_fp16 = True
+                   break
+
+           if not at_least_one_fp16:
+               # If none of the candidates in this candidate_list have the explicit fp16 variant label, then this
+               # candidate_list probably doesn't adhere to the variant naming convention that we expected. In this case,
+               # we'll simply keep all the candidates. An example of a model that hits this case is
+               # `black-forest-labs/FLUX.1-schnell` (as of commit 012d2fd).
+               for candidate in candidate_list:
+                   result.add(candidate.path)
+
+           # The candidate_list seems to have the expected variant naming convention. We'll select the highest scoring
+           # candidate.
            highest_score_candidate = max(candidate_list, key=lambda candidate: candidate.score)
            if highest_score_candidate:
                result.add(highest_score_candidate.path)
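The fp16 check above relies on `Path.suffixes`: only a name of the form `*.fp16.<ext>` yields exactly two suffixes with `.fp16` first. Files in repos that do not follow the variant naming convention (the comment above cites `black-forest-labs/FLUX.1-schnell`) fail this check, which is why all candidates are kept in that case. A quick illustration with made-up names:

```python
from pathlib import Path

assert Path("diffusion_pytorch_model.fp16.safetensors").suffixes == [".fp16", ".safetensors"]
assert Path("diffusion_pytorch_model.safetensors").suffixes == [".safetensors"]
# Shard names produce a single suffix as well, so they never carry the explicit fp16 label.
assert Path("model-00001-of-00002.safetensors").suffixes == [".safetensors"]
```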
invokeai/backend/quantization/__init__.py (new file, empty)

invokeai/backend/quantization/bnb_llm_int8.py (new file, 125 lines)
@@ -0,0 +1,125 @@
import bitsandbytes as bnb
import torch

# This file contains utils for working with models that use bitsandbytes LLM.int8() quantization.
# The utils in this file are partially inspired by:
# https://github.com/Lightning-AI/pytorch-lightning/blob/1551a16b94f5234a4a78801098f64d0732ef5cb5/src/lightning/fabric/plugins/precision/bitsandbytes.py

# NOTE(ryand): All of the custom state_dict manipulation logic in this file is pretty hacky. This could be made much
# cleaner by re-implementing bnb.nn.Linear8bitLt with proper use of buffers and less magic. But, for now, we try to
# stick close to the bitsandbytes classes to make interoperability easier with other models that might use bitsandbytes.


class InvokeInt8Params(bnb.nn.Int8Params):
    """We override cuda() to avoid re-quantizing the weights in the following cases:
    - We loaded quantized weights from a state_dict on the cpu, and then moved the model to the gpu.
    - We are moving the model back-and-forth between the cpu and gpu.
    """

    def cuda(self, device):
        if self.has_fp16_weights:
            return super().cuda(device)
        elif self.CB is not None and self.SCB is not None:
            self.data = self.data.cuda()
            self.CB = self.data
            self.SCB = self.SCB.cuda()
        else:
            # we store the 8-bit rows-major weight
            # we convert this weight to the turning/ampere weight during the first inference pass
            B = self.data.contiguous().half().cuda(device)
            CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
            del CBt
            del SCBt
            self.data = CB
            self.CB = CB
            self.SCB = SCB

        return self


class InvokeLinear8bitLt(bnb.nn.Linear8bitLt):
    def _load_from_state_dict(
        self,
        state_dict: dict[str, torch.Tensor],
        prefix: str,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        weight = state_dict.pop(prefix + "weight")
        bias = state_dict.pop(prefix + "bias", None)

        # See `bnb.nn.Linear8bitLt._save_to_state_dict()` for the serialization logic of SCB and weight_format.
        scb = state_dict.pop(prefix + "SCB", None)
        # weight_format is unused, but we pop it so we can validate that there are no unexpected keys.
        _weight_format = state_dict.pop(prefix + "weight_format", None)

        # TODO(ryand): Technically, we should be using `strict`, `missing_keys`, `unexpected_keys`, and `error_msgs`
        # rather than raising an exception to correctly implement this API.
        assert len(state_dict) == 0

        if scb is not None:
            # We are loading a pre-quantized state dict.
            self.weight = InvokeInt8Params(
                data=weight,
                requires_grad=self.weight.requires_grad,
                has_fp16_weights=False,
                # Note: After quantization, CB is the same as weight.
                CB=weight,
                SCB=scb,
            )
            self.bias = bias if bias is None else torch.nn.Parameter(bias)
        else:
            # We are loading a non-quantized state dict.

            # We could simply call the `super()._load_from_state_dict()` method here, but then we wouldn't be able to
            # load from a state_dict into a model on the "meta" device. Attempting to load into a model on the "meta"
            # device requires setting `assign=True`, doing this with the default `super()._load_from_state_dict()`
            # implementation causes `Params4Bit` to be replaced by a `torch.nn.Parameter`. By initializing a new
            # `Params4bit` object, we work around this issue. It's a bit hacky, but it gets the job done.
            self.weight = InvokeInt8Params(
                data=weight,
                requires_grad=self.weight.requires_grad,
                has_fp16_weights=False,
                CB=None,
                SCB=None,
            )
            self.bias = bias if bias is None else torch.nn.Parameter(bias)


def _convert_linear_layers_to_llm_8bit(
    module: torch.nn.Module, ignore_modules: set[str], outlier_threshold: float, prefix: str = ""
) -> None:
    """Convert all linear layers in the module to bnb.nn.Linear8bitLt layers."""
    for name, child in module.named_children():
        fullname = f"{prefix}.{name}" if prefix else name
        if isinstance(child, torch.nn.Linear) and not any(fullname.startswith(s) for s in ignore_modules):
            has_bias = child.bias is not None
            replacement = InvokeLinear8bitLt(
                child.in_features,
                child.out_features,
                bias=has_bias,
                has_fp16_weights=False,
                threshold=outlier_threshold,
            )
            replacement.weight.data = child.weight.data
            if has_bias:
                replacement.bias.data = child.bias.data
            replacement.requires_grad_(False)
            module.__setattr__(name, replacement)
        else:
            _convert_linear_layers_to_llm_8bit(
                child, ignore_modules, outlier_threshold=outlier_threshold, prefix=fullname
            )


def quantize_model_llm_int8(model: torch.nn.Module, modules_to_not_convert: set[str], outlier_threshold: float = 6.0):
    """Apply bitsandbytes LLM.8bit() quantization to the model."""
    _convert_linear_layers_to_llm_8bit(
        module=model, ignore_modules=modules_to_not_convert, outlier_threshold=outlier_threshold
    )

    return model
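A minimal usage sketch of `quantize_model_llm_int8()`, mirroring the loading scripts that appear later in this diff (the checkpoint path is a placeholder; the `Flux` model class and `params` config are the modules imported below):

```python
import accelerate
from safetensors.torch import load_file

from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.util import params
from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8

# Build the model and swap in LLM.int8() linear layers while everything is still on the meta device.
with accelerate.init_empty_weights():
    model = Flux(params["flux-schnell"])
    model = quantize_model_llm_int8(model, modules_to_not_convert=set())

# Works with either a pre-quantized or an unquantized state dict; in the unquantized case the
# weights are quantized when the model is moved to CUDA.
state_dict = load_file("/path/to/flux1-schnell.safetensors")  # placeholder path
model.load_state_dict(state_dict, strict=True, assign=True)
model = model.to("cuda")
```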
invokeai/backend/quantization/bnb_nf4.py (new file, 156 lines)
@@ -0,0 +1,156 @@
import bitsandbytes as bnb
import torch

# This file contains utils for working with models that use bitsandbytes NF4 quantization.
# The utils in this file are partially inspired by:
# https://github.com/Lightning-AI/pytorch-lightning/blob/1551a16b94f5234a4a78801098f64d0732ef5cb5/src/lightning/fabric/plugins/precision/bitsandbytes.py

# NOTE(ryand): All of the custom state_dict manipulation logic in this file is pretty hacky. This could be made much
# cleaner by re-implementing bnb.nn.LinearNF4 with proper use of buffers and less magic. But, for now, we try to stick
# close to the bitsandbytes classes to make interoperability easier with other models that might use bitsandbytes.


class InvokeLinearNF4(bnb.nn.LinearNF4):
    """A class that extends `bnb.nn.LinearNF4` to add the following functionality:
    - Ability to load Linear NF4 layers from a pre-quantized state_dict.
    - Ability to load Linear NF4 layers from a state_dict when the model is on the "meta" device.
    """

    def _load_from_state_dict(
        self,
        state_dict: dict[str, torch.Tensor],
        prefix: str,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """This method is based on the logic in the bitsandbytes serialization unit tests for `Linear4bit`:
        https://github.com/bitsandbytes-foundation/bitsandbytes/blob/6d714a5cce3db5bd7f577bc447becc7a92d5ccc7/tests/test_linear4bit.py#L52-L71
        """
        weight = state_dict.pop(prefix + "weight")
        bias = state_dict.pop(prefix + "bias", None)
        # We expect the remaining keys to be quant_state keys.
        quant_state_sd = state_dict

        # During serialization, the quant_state is stored as subkeys of "weight." (See
        # `bnb.nn.LinearNF4._save_to_state_dict()`). We validate that they at least have the correct prefix.
        # TODO(ryand): Technically, we should be using `strict`, `missing_keys`, `unexpected_keys`, and `error_msgs`
        # rather than raising an exception to correctly implement this API.
        assert all(k.startswith(prefix + "weight.") for k in quant_state_sd.keys())

        if len(quant_state_sd) > 0:
            # We are loading a pre-quantized state dict.
            self.weight = bnb.nn.Params4bit.from_prequantized(
                data=weight, quantized_stats=quant_state_sd, device=weight.device
            )
            self.bias = bias if bias is None else torch.nn.Parameter(bias, requires_grad=False)
        else:
            # We are loading a non-quantized state dict.

            # We could simply call the `super()._load_from_state_dict()` method here, but then we wouldn't be able to
            # load from a state_dict into a model on the "meta" device. Attempting to load into a model on the "meta"
            # device requires setting `assign=True`, doing this with the default `super()._load_from_state_dict()`
            # implementation causes `Params4Bit` to be replaced by a `torch.nn.Parameter`. By initializing a new
            # `Params4bit` object, we work around this issue. It's a bit hacky, but it gets the job done.
            self.weight = bnb.nn.Params4bit(
                data=weight,
                requires_grad=self.weight.requires_grad,
                compress_statistics=self.weight.compress_statistics,
                quant_type=self.weight.quant_type,
                quant_storage=self.weight.quant_storage,
                module=self,
            )
            self.bias = bias if bias is None else torch.nn.Parameter(bias)


def _replace_param(
    param: torch.nn.Parameter | bnb.nn.Params4bit,
    data: torch.Tensor,
) -> torch.nn.Parameter:
    """A helper function to replace the data of a model parameter with new data in a way that allows replacing params on
    the "meta" device.

    Supports both `torch.nn.Parameter` and `bnb.nn.Params4bit` parameters.
    """
    if param.device.type == "meta":
        # Doing `param.data = data` raises a RuntimeError if param.data was on the "meta" device, so we need to
        # re-create the param instead of overwriting the data.
        if isinstance(param, bnb.nn.Params4bit):
            return bnb.nn.Params4bit(
                data,
                requires_grad=data.requires_grad,
                quant_state=param.quant_state,
                compress_statistics=param.compress_statistics,
                quant_type=param.quant_type,
            )
        return torch.nn.Parameter(data, requires_grad=data.requires_grad)

    param.data = data
    return param


def _convert_linear_layers_to_nf4(
    module: torch.nn.Module,
    ignore_modules: set[str],
    compute_dtype: torch.dtype,
    compress_statistics: bool = False,
    prefix: str = "",
) -> None:
    """Convert all linear layers in the model to NF4 quantized linear layers.

    Args:
        module: All linear layers in this module will be converted.
        ignore_modules: A set of module prefixes to ignore when converting linear layers.
        compute_dtype: The dtype to use for computation in the quantized linear layers.
        compress_statistics: Whether to enable nested quantization (aka double quantization) where the quantization
            constants from the first quantization are quantized again.
        prefix: The prefix of the current module in the model. Used to call this function recursively.
    """
    for name, child in module.named_children():
        fullname = f"{prefix}.{name}" if prefix else name
        if isinstance(child, torch.nn.Linear) and not any(fullname.startswith(s) for s in ignore_modules):
            has_bias = child.bias is not None
            replacement = InvokeLinearNF4(
                child.in_features,
                child.out_features,
                bias=has_bias,
                compute_dtype=compute_dtype,
                compress_statistics=compress_statistics,
            )
            if has_bias:
                replacement.bias = _replace_param(replacement.bias, child.bias.data)
            replacement.weight = _replace_param(replacement.weight, child.weight.data)
            replacement.requires_grad_(False)
            module.__setattr__(name, replacement)
        else:
            _convert_linear_layers_to_nf4(child, ignore_modules, compute_dtype=compute_dtype, prefix=fullname)


def quantize_model_nf4(model: torch.nn.Module, modules_to_not_convert: set[str], compute_dtype: torch.dtype):
    """Apply bitsandbytes nf4 quantization to the model.

    You likely want to call this function inside a `accelerate.init_empty_weights()` context.

    Example usage:
    ```
    # Initialize the model from a config on the meta device.
    with accelerate.init_empty_weights():
        model = ModelClass.from_config(...)

    # Add NF4 quantization linear layers to the model - still on the meta device.
    with accelerate.init_empty_weights():
        model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.float16)

    # Load a state_dict into the model. (Could be either a prequantized or non-quantized state_dict.)
    model.load_state_dict(state_dict, strict=True, assign=True)

    # Move the model to the "cuda" device. If the model was non-quantized, this is where the weight quantization takes
    # place.
    model.to("cuda")
    ```
    """
    _convert_linear_layers_to_nf4(module=model, ignore_modules=modules_to_not_convert, compute_dtype=compute_dtype)

    return model
@@ -0,0 +1,79 @@
from pathlib import Path

import accelerate
from safetensors.torch import load_file, save_file

from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.util import params
from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8
from invokeai.backend.quantization.scripts.load_flux_model_bnb_nf4 import log_time


def main():
    """A script for quantizing a FLUX transformer model using the bitsandbytes LLM.int8() quantization method.

    This script is primarily intended for reference. The script params (e.g. the model_path, modules_to_not_convert,
    etc.) are hardcoded and would need to be modified for other use cases.
    """
    # Load the FLUX transformer model onto the meta device.
    model_path = Path(
        "/data/invokeai/models/.download_cache/https__huggingface.co_black-forest-labs_flux.1-schnell_resolve_main_flux1-schnell.safetensors/flux1-schnell.safetensors"
    )

    with log_time("Intialize FLUX transformer on meta device"):
        # TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
        p = params["flux-schnell"]

        # Initialize the model on the "meta" device.
        with accelerate.init_empty_weights():
            model = Flux(p)

    # TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate
    # `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize.
    modules_to_not_convert: set[str] = set()

    model_int8_path = model_path.parent / "bnb_llm_int8.safetensors"
    if model_int8_path.exists():
        # The quantized model already exists, load it and return it.
        print(f"A pre-quantized model already exists at '{model_int8_path}'. Attempting to load it...")

        # Replace the linear layers with LLM.int8() quantized linear layers (still on the meta device).
        with log_time("Replace linear layers with LLM.int8() layers"), accelerate.init_empty_weights():
            model = quantize_model_llm_int8(model, modules_to_not_convert=modules_to_not_convert)

        with log_time("Load state dict into model"):
            sd = load_file(model_int8_path)
            model.load_state_dict(sd, strict=True, assign=True)

        with log_time("Move model to cuda"):
            model = model.to("cuda")

        print(f"Successfully loaded pre-quantized model from '{model_int8_path}'.")

    else:
        # The quantized model does not exist, quantize the model and save it.
        print(f"No pre-quantized model found at '{model_int8_path}'. Quantizing the model...")

        with log_time("Replace linear layers with LLM.int8() layers"), accelerate.init_empty_weights():
            model = quantize_model_llm_int8(model, modules_to_not_convert=modules_to_not_convert)

        with log_time("Load state dict into model"):
            state_dict = load_file(model_path)
            # TODO(ryand): Cast the state_dict to the appropriate dtype?
            model.load_state_dict(state_dict, strict=True, assign=True)

        with log_time("Move model to cuda and quantize"):
            model = model.to("cuda")

        with log_time("Save quantized model"):
            model_int8_path.parent.mkdir(parents=True, exist_ok=True)
            save_file(model.state_dict(), model_int8_path)

        print(f"Successfully quantized and saved model to '{model_int8_path}'.")

    assert isinstance(model, Flux)
    return model


if __name__ == "__main__":
    main()
@@ -0,0 +1,96 @@
import time
from contextlib import contextmanager
from pathlib import Path

import accelerate
import torch
from safetensors.torch import load_file, save_file

from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.util import params
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4


@contextmanager
def log_time(name: str):
    """Helper context manager to log the time taken by a block of code."""
    start = time.time()
    try:
        yield None
    finally:
        end = time.time()
        print(f"'{name}' took {end - start:.4f} secs")


def main():
    """A script for quantizing a FLUX transformer model using the bitsandbytes NF4 quantization method.

    This script is primarily intended for reference. The script params (e.g. the model_path, modules_to_not_convert,
    etc.) are hardcoded and would need to be modified for other use cases.
    """
    model_path = Path(
        "/data/invokeai/models/.download_cache/https__huggingface.co_black-forest-labs_flux.1-schnell_resolve_main_flux1-schnell.safetensors/flux1-schnell.safetensors"
    )

    # inference_dtype = torch.bfloat16
    with log_time("Intialize FLUX transformer on meta device"):
        # TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
        p = params["flux-schnell"]

        # Initialize the model on the "meta" device.
        with accelerate.init_empty_weights():
            model = Flux(p)

    # TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate
    # `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize.
    modules_to_not_convert: set[str] = set()

    model_nf4_path = model_path.parent / "bnb_nf4.safetensors"
    if model_nf4_path.exists():
        # The quantized model already exists, load it and return it.
        print(f"A pre-quantized model already exists at '{model_nf4_path}'. Attempting to load it...")

        # Replace the linear layers with NF4 quantized linear layers (still on the meta device).
        with log_time("Replace linear layers with NF4 layers"), accelerate.init_empty_weights():
            model = quantize_model_nf4(
                model, modules_to_not_convert=modules_to_not_convert, compute_dtype=torch.bfloat16
            )

        with log_time("Load state dict into model"):
            state_dict = load_file(model_nf4_path)
            model.load_state_dict(state_dict, strict=True, assign=True)

        with log_time("Move model to cuda"):
            model = model.to("cuda")

        print(f"Successfully loaded pre-quantized model from '{model_nf4_path}'.")

    else:
        # The quantized model does not exist, quantize the model and save it.
        print(f"No pre-quantized model found at '{model_nf4_path}'. Quantizing the model...")

        with log_time("Replace linear layers with NF4 layers"), accelerate.init_empty_weights():
            model = quantize_model_nf4(
                model, modules_to_not_convert=modules_to_not_convert, compute_dtype=torch.bfloat16
            )

        with log_time("Load state dict into model"):
            state_dict = load_file(model_path)
            # TODO(ryand): Cast the state_dict to the appropriate dtype?
            model.load_state_dict(state_dict, strict=True, assign=True)

        with log_time("Move model to cuda and quantize"):
            model = model.to("cuda")

        with log_time("Save quantized model"):
            model_nf4_path.parent.mkdir(parents=True, exist_ok=True)
            save_file(model.state_dict(), model_nf4_path)

        print(f"Successfully quantized and saved model to '{model_nf4_path}'.")

    assert isinstance(model, Flux)
    return model


if __name__ == "__main__":
    main()
@@ -0,0 +1,92 @@
from pathlib import Path

import accelerate
from safetensors.torch import load_file, save_file
from transformers import AutoConfig, AutoModelForTextEncoding, T5EncoderModel

from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8
from invokeai.backend.quantization.scripts.load_flux_model_bnb_nf4 import log_time


def load_state_dict_into_t5(model: T5EncoderModel, state_dict: dict):
    # There is a shared reference to a single weight tensor in the model.
    # Both "encoder.embed_tokens.weight" and "shared.weight" refer to the same tensor, so only the latter should
    # be present in the state_dict.
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False, assign=True)
    assert len(unexpected_keys) == 0
    assert set(missing_keys) == {"encoder.embed_tokens.weight"}
    # Assert that the layers we expect to be shared are actually shared.
    assert model.encoder.embed_tokens.weight is model.shared.weight


def main():
    """A script for quantizing a T5 text encoder model using the bitsandbytes LLM.int8() quantization method.

    This script is primarily intended for reference. The script params (e.g. the model_path, modules_to_not_convert,
    etc.) are hardcoded and would need to be modified for other use cases.
    """
    model_path = Path("/data/misc/text_encoder_2")

    with log_time("Intialize T5 on meta device"):
        model_config = AutoConfig.from_pretrained(model_path)
        with accelerate.init_empty_weights():
            model = AutoModelForTextEncoding.from_config(model_config)

    # TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate
    # `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize.
    modules_to_not_convert: set[str] = set()

    model_int8_path = model_path / "bnb_llm_int8.safetensors"
    if model_int8_path.exists():
        # The quantized model already exists, load it and return it.
        print(f"A pre-quantized model already exists at '{model_int8_path}'. Attempting to load it...")

        # Replace the linear layers with LLM.int8() quantized linear layers (still on the meta device).
        with log_time("Replace linear layers with LLM.int8() layers"), accelerate.init_empty_weights():
            model = quantize_model_llm_int8(model, modules_to_not_convert=modules_to_not_convert)

        with log_time("Load state dict into model"):
            sd = load_file(model_int8_path)
            load_state_dict_into_t5(model, sd)

        with log_time("Move model to cuda"):
            model = model.to("cuda")

        print(f"Successfully loaded pre-quantized model from '{model_int8_path}'.")

    else:
        # The quantized model does not exist, quantize the model and save it.
        print(f"No pre-quantized model found at '{model_int8_path}'. Quantizing the model...")

        with log_time("Replace linear layers with LLM.int8() layers"), accelerate.init_empty_weights():
            model = quantize_model_llm_int8(model, modules_to_not_convert=modules_to_not_convert)

        with log_time("Load state dict into model"):
            # Load sharded state dict.
            files = list(model_path.glob("*.safetensors"))
            state_dict = {}
            for file in files:
                sd = load_file(file)
                state_dict.update(sd)
            load_state_dict_into_t5(model, state_dict)

        with log_time("Move model to cuda and quantize"):
            model = model.to("cuda")

        with log_time("Save quantized model"):
            model_int8_path.parent.mkdir(parents=True, exist_ok=True)
            state_dict = model.state_dict()
            state_dict.pop("encoder.embed_tokens.weight")
            save_file(state_dict, model_int8_path)
            # This handling of shared weights could also be achieved with save_model(...), but then we'd lose control
            # over which keys are kept. And, the corresponding load_model(...) function does not support assign=True.
            # save_model(model, model_int8_path)

        print(f"Successfully quantized and saved model to '{model_int8_path}'.")

    assert isinstance(model, T5EncoderModel)
    return model


if __name__ == "__main__":
    main()
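The shared-weight handling above exists because T5's token embedding is tied: `encoder.embed_tokens.weight` and `shared.weight` are the same tensor, so only one copy needs to be serialized. A small illustration (the checkpoint name is an arbitrary small T5 chosen for the example, not one used by this diff):

```python
from transformers import T5EncoderModel

model = T5EncoderModel.from_pretrained("google/t5-v1_1-small")

# The embedding parameter is shared, so dropping "encoder.embed_tokens.weight" from a saved
# state dict loses no information.
assert model.encoder.embed_tokens.weight is model.shared.weight
sd = model.state_dict()
assert sd["shared.weight"].data_ptr() == sd["encoder.embed_tokens.weight"].data_ptr()
```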
@@ -25,11 +25,6 @@ class BasicConditioningInfo:
        return self


-@dataclass
-class ConditioningFieldData:
-    conditionings: List[BasicConditioningInfo]
-
-
@dataclass
class SDXLConditioningInfo(BasicConditioningInfo):
    """SDXL text conditioning information produced by Compel."""
@@ -43,6 +38,17 @@ class SDXLConditioningInfo(BasicConditioningInfo):
        return super().to(device=device, dtype=dtype)


+@dataclass
+class FLUXConditioningInfo:
+    clip_embeds: torch.Tensor
+    t5_embeds: torch.Tensor
+
+
+@dataclass
+class ConditioningFieldData:
+    conditionings: List[BasicConditioningInfo] | List[SDXLConditioningInfo] | List[FLUXConditioningInfo]
+
+
@dataclass
class IPAdapterConditioningInfo:
    cond_image_prompt_embeds: torch.Tensor
|
|||||||
'i18next/no-literal-string': 'error',
|
'i18next/no-literal-string': 'error',
|
||||||
// https://eslint.org/docs/latest/rules/no-console
|
// https://eslint.org/docs/latest/rules/no-console
|
||||||
'no-console': 'error',
|
'no-console': 'error',
|
||||||
|
// https://eslint.org/docs/latest/rules/no-promise-executor-return
|
||||||
|
'no-promise-executor-return': 'error',
|
||||||
|
// https://eslint.org/docs/latest/rules/require-await
|
||||||
|
'require-await': 'error',
|
||||||
},
|
},
|
||||||
overrides: [
|
overrides: [
|
||||||
/**
|
/**
|
||||||
|
@@ -1,5 +1,5 @@
import { PropsWithChildren, memo, useEffect } from 'react';
-import { modelChanged } from '../src/features/parameters/store/generationSlice';
+import { modelChanged } from '../src/features/controlLayers/store/paramsSlice';
import { useAppDispatch } from '../src/app/store/storeHooks';
import { useGlobalModifiersInit } from '@invoke-ai/ui-library';
/**
@@ -10,7 +10,9 @@ export const ReduxInit = memo((props: PropsWithChildren) => {
  const dispatch = useAppDispatch();
  useGlobalModifiersInit();
  useEffect(() => {
-   dispatch(modelChanged({ key: 'test_model', hash: 'some_hash', name: 'some name', base: 'sd-1', type: 'main' }));
+   dispatch(
+     modelChanged({ model: { key: 'test_model', hash: 'some_hash', name: 'some name', base: 'sd-1', type: 'main' } })
+   );
  }, []);

  return props.children;
@@ -9,6 +9,8 @@ const config: KnipConfig = {
    'src/services/api/schema.ts',
    'src/features/nodes/types/v1/**',
    'src/features/nodes/types/v2/**',
+   // TODO(psyche): maybe we can clean up these utils after canvas v2 release
+   'src/features/controlLayers/konva/util.ts',
  ],
  ignoreBinaries: ['only-allow'],
  paths: {
@@ -24,7 +24,7 @@
    "build": "pnpm run lint && vite build",
    "typegen": "node scripts/typegen.js",
    "preview": "vite preview",
-   "lint:knip": "knip",
+   "lint:knip": "knip --tags=-knipignore",
    "lint:dpdm": "dpdm --no-warning --no-tree --transform --exit-code circular:1 src/main.tsx",
    "lint:eslint": "eslint --max-warnings=0 .",
    "lint:prettier": "prettier --check .",
@@ -52,18 +52,19 @@
    }
  },
  "dependencies": {
-   "@chakra-ui/react-use-size": "^2.1.0",
    "@dagrejs/dagre": "^1.1.3",
    "@dagrejs/graphlib": "^2.2.3",
    "@dnd-kit/core": "^6.1.0",
    "@dnd-kit/sortable": "^8.0.0",
    "@dnd-kit/utilities": "^3.2.2",
    "@fontsource-variable/inter": "^5.0.20",
-   "@invoke-ai/ui-library": "^0.0.29",
+   "@invoke-ai/ui-library": "^0.0.32",
    "@nanostores/react": "^0.7.3",
    "@reduxjs/toolkit": "2.2.3",
    "@roarr/browser-log-writer": "^1.3.0",
+   "async-mutex": "^0.5.0",
    "chakra-react-select": "^4.9.1",
+   "cmdk": "^1.0.0",
    "compare-versions": "^6.1.1",
    "dateformat": "^5.0.3",
    "fracturedjsonjs": "^4.0.2",
@@ -74,6 +75,8 @@
    "jsondiffpatch": "^0.6.0",
    "konva": "^9.3.14",
    "lodash-es": "^4.17.21",
+   "lru-cache": "^11.0.0",
+   "nanoid": "^5.0.7",
    "nanostores": "^0.11.2",
    "new-github-issue-url": "^1.0.0",
    "overlayscrollbars": "^2.10.0",
@@ -88,10 +91,8 @@
    "react-hotkeys-hook": "4.5.0",
    "react-i18next": "^14.1.3",
    "react-icons": "^5.2.1",
-   "react-konva": "^18.2.10",
    "react-redux": "9.1.2",
    "react-resizable-panels": "^2.0.23",
-   "react-select": "5.8.0",
    "react-use": "^17.5.1",
    "react-virtuoso": "^4.9.0",
    "reactflow": "^11.11.4",
@@ -102,9 +103,9 @@
    "roarr": "^7.21.1",
    "serialize-error": "^11.0.3",
    "socket.io-client": "^4.7.5",
+   "stable-hash": "^0.0.4",
    "use-debounce": "^10.0.2",
    "use-device-pixel-ratio": "^1.1.2",
-   "use-image": "^1.1.1",
    "uuid": "^10.0.0",
    "zod": "^3.23.8",
    "zod-validation-error": "^3.3.1"
invokeai/frontend/web/pnpm-lock.yaml (generated; 626-line diff suppressed because it is too large)
@@ -80,6 +80,7 @@
    "aboutDesc": "Using Invoke for work? Check out:",
    "aboutHeading": "Own Your Creative Power",
    "accept": "Accept",
+   "apply": "Apply",
    "add": "Add",
    "advanced": "Advanced",
    "ai": "ai",
@@ -115,6 +116,7 @@
    "githubLabel": "Github",
    "goTo": "Go to",
    "hotkeysLabel": "Hotkeys",
+   "loadingImage": "Loading Image",
    "imageFailedToLoad": "Unable to Load Image",
    "img2img": "Image To Image",
    "inpaint": "inpaint",
@@ -325,6 +327,10 @@
    "canceled": "Canceled",
    "completedIn": "Completed in",
    "batch": "Batch",
+   "origin": "Origin",
+   "originCanvas": "Canvas",
+   "originWorkflows": "Workflows",
+   "originOther": "Other",
    "batchFieldValues": "Batch Field Values",
    "item": "Item",
    "session": "Session",
@@ -784,6 +790,7 @@
    "simpleModelPlaceholder": "URL or path to a local file or diffusers folder",
    "source": "Source",
    "starterModels": "Starter Models",
+   "starterModelsInModelManager": "Starter Models can be found in Model Manager",
    "syncModels": "Sync Models",
    "textualInversions": "Textual Inversions",
    "triggerPhrases": "Trigger Phrases",
@@ -1095,7 +1102,6 @@
    "confirmOnDelete": "Confirm On Delete",
    "developer": "Developer",
    "displayInProgress": "Display Progress Images",
-   "enableImageDebugging": "Enable Image Debugging",
    "enableInformationalPopovers": "Enable Informational Popovers",
    "informationalPopoversDisabled": "Informational Popovers Disabled",
    "informationalPopoversDisabledDesc": "Informational popovers have been disabled. Enable them in Settings.",
@@ -1562,7 +1568,7 @@
    "copyToClipboard": "Copy to Clipboard",
    "cursorPosition": "Cursor Position",
    "darkenOutsideSelection": "Darken Outside Selection",
-   "discardAll": "Discard All",
+   "discardAll": "Discard All & Cancel Pending Generations",
    "discardCurrent": "Discard Current",
    "downloadAsImage": "Download As Image",
    "enableMask": "Enable Mask",
@@ -1640,41 +1646,126 @@
    "storeNotInitialized": "Store is not initialized"
  },
  "controlLayers": {
-   "deleteAll": "Delete All",
+   "clearHistory": "Clear History",
+   "generateMode": "Generate",
+   "generateModeDesc": "Create individual images. Generated images are added directly to the gallery.",
+   "composeMode": "Compose",
+   "composeModeDesc": "Compose your work iterative. Generated images are added back to the canvas.",
+   "autoSave": "Auto-save to Gallery",
+   "resetCanvas": "Reset Canvas",
+   "resetAll": "Reset All",
+   "clearCaches": "Clear Caches",
+   "recalculateRects": "Recalculate Rects",
+   "clipToBbox": "Clip Strokes to Bbox",
    "addLayer": "Add Layer",
+   "duplicate": "Duplicate",
    "moveToFront": "Move to Front",
    "moveToBack": "Move to Back",
    "moveForward": "Move Forward",
    "moveBackward": "Move Backward",
    "brushSize": "Brush Size",
+   "width": "Width",
+   "zoom": "Zoom",
+   "resetView": "Reset View",
    "controlLayers": "Control Layers",
    "globalMaskOpacity": "Global Mask Opacity",
    "autoNegative": "Auto Negative",
+   "enableAutoNegative": "Enable Auto Negative",
+   "disableAutoNegative": "Disable Auto Negative",
    "deletePrompt": "Delete Prompt",
    "resetRegion": "Reset Region",
    "debugLayers": "Debug Layers",
    "rectangle": "Rectangle",
-   "maskPreviewColor": "Mask Preview Color",
+   "maskFill": "Mask Fill",
    "addPositivePrompt": "Add $t(common.positivePrompt)",
    "addNegativePrompt": "Add $t(common.negativePrompt)",
    "addIPAdapter": "Add $t(common.ipAdapter)",
-   "regionalGuidance": "Regional Guidance",
    "regionalGuidanceLayer": "$t(controlLayers.regionalGuidance) $t(unifiedCanvas.layer)",
+   "raster": "Raster",
+   "rasterLayer_one": "Raster Layer",
+   "controlLayer_one": "Control Layer",
+   "inpaintMask_one": "Inpaint Mask",
+   "regionalGuidance_one": "Regional Guidance",
+   "ipAdapter_one": "IP Adapter",
+   "rasterLayer_other": "Raster Layers",
+   "controlLayer_other": "Control Layers",
+   "inpaintMask_other": "Inpaint Masks",
+   "regionalGuidance_other": "Regional Guidance",
+   "ipAdapter_other": "IP Adapters",
    "opacity": "Opacity",
+   "regionalGuidance_withCount_hidden": "Regional Guidance ({{count}} hidden)",
+   "controlAdapters_withCount_hidden": "Control Adapters ({{count}} hidden)",
+   "controlLayers_withCount_hidden": "Control Layers ({{count}} hidden)",
+   "rasterLayers_withCount_hidden": "Raster Layers ({{count}} hidden)",
+   "ipAdapters_withCount_hidden": "IP Adapters ({{count}} hidden)",
+   "inpaintMasks_withCount_hidden": "Inpaint Masks ({{count}} hidden)",
+   "regionalGuidance_withCount_visible": "Regional Guidance ({{count}})",
+   "controlAdapters_withCount_visible": "Control Adapters ({{count}})",
+   "controlLayers_withCount_visible": "Control Layers ({{count}})",
+   "rasterLayers_withCount_visible": "Raster Layers ({{count}})",
+   "ipAdapters_withCount_visible": "IP Adapters ({{count}})",
+   "inpaintMasks_withCount_visible": "Inpaint Masks ({{count}})",
    "globalControlAdapter": "Global $t(controlnet.controlAdapter_one)",
    "globalControlAdapterLayer": "Global $t(controlnet.controlAdapter_one) $t(unifiedCanvas.layer)",
    "globalIPAdapter": "Global $t(common.ipAdapter)",
    "globalIPAdapterLayer": "Global $t(common.ipAdapter) $t(unifiedCanvas.layer)",
    "globalInitialImage": "Global Initial Image",
    "globalInitialImageLayer": "$t(controlLayers.globalInitialImage) $t(unifiedCanvas.layer)",
+   "layer": "Layer",
    "opacityFilter": "Opacity Filter",
    "clearProcessor": "Clear Processor",
    "resetProcessor": "Reset Processor to Defaults",
    "noLayersAdded": "No Layers Added",
    "layers_one": "Layer",
-   "layers_other": "Layers"
+   "layers_other": "Layers",
+   "objects_zero": "empty",
+   "objects_one": "{{count}} object",
+   "objects_other": "{{count}} objects",
+   "convertToControlLayer": "Convert to Control Layer",
+   "convertToRasterLayer": "Convert to Raster Layer",
+   "transparency": "Transparency",
+   "enableTransparencyEffect": "Enable Transparency Effect",
+   "disableTransparencyEffect": "Disable Transparency Effect",
+   "hidingType": "Hiding {{type}}",
+   "showingType": "Showing {{type}}",
+   "dynamicGrid": "Dynamic Grid",
+   "logDebugInfo": "Log Debug Info",
+   "locked": "Locked",
+   "unlocked": "Unlocked",
+   "deleteSelected": "Delete Selected",
+   "deleteAll": "Delete All",
+   "flipHorizontal": "Flip Horizontal",
+   "flipVertical": "Flip Vertical",
+   "fill": {
+     "fillStyle": "Fill Style",
+     "solid": "Solid",
+     "grid": "Grid",
+     "crosshatch": "Crosshatch",
+     "vertical": "Vertical",
+     "horizontal": "Horizontal",
+     "diagonal": "Diagonal"
+   },
+   "tool": {
+     "brush": "Brush",
+     "eraser": "Eraser",
+     "rectangle": "Rectangle",
+     "bbox": "Bbox",
+     "move": "Move",
+     "view": "View",
+     "transform": "Transform",
+     "colorPicker": "Color Picker"
+   },
+   "filter": {
+     "filter": "Filter",
+     "filters": "Filters",
+     "filterType": "Filter Type",
+     "preview": "Preview",
+     "apply": "Apply",
+     "cancel": "Cancel"
+   }
  },
  "upscaling": {
+   "upscale": "Upscale",
    "creativity": "Creativity",
    "exceedsMaxSize": "Upscale settings exceed max size limit",
    "exceedsMaxSizeDetails": "Max upscale limit is {{maxUpscaleDimension}}x{{maxUpscaleDimension}} pixels. Please try a smaller image or decrease your scale selection.",
@@ -1723,6 +1814,7 @@
    "positivePrompt": "Positive Prompt",
    "preview": "Preview",
    "private": "Private",
+   "promptTemplateCleared": "Prompt Template Cleared",
    "searchByName": "Search by name",
    "shared": "Shared",
    "sharedTemplates": "Shared Templates",
@@ -1758,5 +1850,30 @@
    "upscaling": "Upscaling",
    "upscalingTab": "$t(ui.tabs.upscaling) $t(common.tab)"
    }
+ },
+ "system": {
+   "enableLogging": "Enable Logging",
+   "logLevel": {
+     "logLevel": "Log Level",
+     "trace": "Trace",
+     "debug": "Debug",
+     "info": "Info",
+     "warn": "Warn",
+     "error": "Error",
+     "fatal": "Fatal"
+   },
+   "logNamespaces": {
+     "logNamespaces": "Log Namespaces",
+     "gallery": "Gallery",
+     "models": "Models",
+     "config": "Config",
+     "canvas": "Canvas",
+     "generation": "Generation",
+     "workflows": "Workflows",
+     "system": "System",
+     "events": "Events",
+     "queue": "Queue",
+     "metadata": "Metadata"
+   }
  }
}
@@ -929,7 +929,7 @@
    "missingInvocationTemplate": "Modello di invocazione mancante",
    "missingFieldTemplate": "Modello di campo mancante",
    "singleFieldType": "{{name}} (Singola)",
-   "imageAccessError": "Impossibile trovare l'immagine {{image_name}}, ripristino delle impostazioni predefinite",
+   "imageAccessError": "Impossibile trovare l'immagine {{image_name}}, ripristino ai valori predefiniti",
    "boardAccessError": "Impossibile trovare la bacheca {{board_id}}, ripristino ai valori predefiniti",
    "modelAccessError": "Impossibile trovare il modello {{key}}, ripristino ai valori predefiniti"
  },
@@ -1782,7 +1782,13 @@
    "updatePromptTemplate": "Aggiorna il modello di prompt",
    "type": "Tipo",
    "promptTemplatesDesc2": "Utilizza la stringa segnaposto <Pre>{{placeholder}}</Pre> per specificare dove inserire il tuo prompt nel modello.",
-   "importTemplates": "Importa modelli di prompt",
-   "importTemplatesDesc": "Il formato deve essere un CSV con colonne 'name' e 'prompt' o 'positive_prompt' e 'negative_prompt' incluse, oppure un file JSON con chiavi 'name' e 'prompt' o 'positive_prompt' e 'negative_prompt"
+   "importTemplates": "Importa modelli di prompt (CSV/JSON)",
+   "exportDownloaded": "Esportazione completata",
+   "exportFailed": "Impossibile generare e scaricare il file CSV",
+   "exportPromptTemplates": "Esporta i miei modelli di prompt (CSV)",
+   "positivePromptColumn": "'prompt' o 'positive_prompt'",
+   "noTemplates": "Nessun modello",
+   "acceptedColumnsKeys": "Colonne/chiavi accettate:",
+   "templateActions": "Azioni modello"
  }
}
|
@@ -91,7 +91,8 @@
     "enabled": "Включено",
     "disabled": "Отключено",
     "comparingDesc": "Сравнение двух изображений",
-    "comparing": "Сравнение"
+    "comparing": "Сравнение",
+    "dontShowMeThese": "Не показывай мне это"
   },
   "gallery": {
     "galleryImageSize": "Размер изображений",
@@ -153,7 +154,11 @@
     "showArchivedBoards": "Показать архивированные доски",
     "searchImages": "Поиск по метаданным",
     "displayBoardSearch": "Отобразить поиск досок",
-    "displaySearch": "Отобразить поиск"
+    "displaySearch": "Отобразить поиск",
+    "exitBoardSearch": "Выйти из поиска досок",
+    "go": "Перейти",
+    "exitSearch": "Выйти из поиска",
+    "jump": "Пыгнуть"
   },
   "hotkeys": {
     "keyboardShortcuts": "Горячие клавиши",
@@ -376,6 +381,10 @@
     "toggleViewer": {
       "title": "Переключить просмотр изображений",
       "desc": "Переключение между средством просмотра изображений и рабочей областью для текущей вкладки."
+    },
+    "postProcess": {
+      "desc": "Обработайте текущее изображение с помощью выбранной модели постобработки",
+      "title": "Обработать изображение"
     }
   },
   "modelManager": {
@@ -589,7 +598,10 @@
     "infillColorValue": "Цвет заливки",
     "globalSettings": "Глобальные настройки",
     "globalNegativePromptPlaceholder": "Глобальный негативный запрос",
-    "globalPositivePromptPlaceholder": "Глобальный запрос"
+    "globalPositivePromptPlaceholder": "Глобальный запрос",
+    "postProcessing": "Постобработка (Shift + U)",
+    "processImage": "Обработка изображения",
+    "sendToUpscale": "Отправить на увеличение"
   },
   "settings": {
     "models": "Модели",
@@ -623,7 +635,9 @@
     "intermediatesCleared_many": "Очищено {{count}} промежуточных",
     "clearIntermediatesDesc1": "Очистка промежуточных элементов приведет к сбросу состояния Canvas и ControlNet.",
     "intermediatesClearedFailed": "Проблема очистки промежуточных",
-    "reloadingIn": "Перезагрузка через"
+    "reloadingIn": "Перезагрузка через",
+    "informationalPopoversDisabled": "Информационные всплывающие окна отключены",
+    "informationalPopoversDisabledDesc": "Информационные всплывающие окна были отключены. Включите их в Настройках."
   },
   "toast": {
     "uploadFailed": "Загрузка не удалась",
@@ -694,7 +708,9 @@
     "sessionRef": "Сессия: {{sessionId}}",
     "outOfMemoryError": "Ошибка нехватки памяти",
     "outOfMemoryErrorDesc": "Ваши текущие настройки генерации превышают возможности системы. Пожалуйста, измените настройки и повторите попытку.",
-    "somethingWentWrong": "Что-то пошло не так"
+    "somethingWentWrong": "Что-то пошло не так",
+    "importFailed": "Импорт неудачен",
+    "importSuccessful": "Импорт успешен"
   },
   "tooltip": {
     "feature": {
@@ -1017,7 +1033,8 @@
     "composition": "Только композиция",
     "hed": "HED",
     "beginEndStepPercentShort": "Начало/конец %",
-    "setControlImageDimensionsForce": "Скопируйте размер в Ш/В (игнорируйте модель)"
+    "setControlImageDimensionsForce": "Скопируйте размер в Ш/В (игнорируйте модель)",
+    "depthAnythingSmallV2": "Small V2"
   },
   "boards": {
     "autoAddBoard": "Авто добавление Доски",
@@ -1042,7 +1059,7 @@
     "downloadBoard": "Скачать доску",
     "deleteBoard": "Удалить доску",
     "deleteBoardAndImages": "Удалить доску и изображения",
-    "deletedBoardsCannotbeRestored": "Удаленные доски не подлежат восстановлению",
+    "deletedBoardsCannotbeRestored": "Удаленные доски не могут быть восстановлены. Выбор «Удалить только доску» переведет изображения в состояние без категории.",
     "assetsWithCount_one": "{{count}} ассет",
     "assetsWithCount_few": "{{count}} ассета",
     "assetsWithCount_many": "{{count}} ассетов",
@@ -1057,7 +1074,11 @@
     "boards": "Доски",
     "addPrivateBoard": "Добавить личную доску",
     "private": "Личные доски",
-    "shared": "Общие доски"
+    "shared": "Общие доски",
+    "hideBoards": "Скрыть доски",
+    "viewBoards": "Просмотреть доски",
+    "noBoards": "Нет досок {{boardType}}",
+    "deletedPrivateBoardsCannotbeRestored": "Удаленные доски не могут быть восстановлены. Выбор «Удалить только доску» переведет изображения в приватное состояние без категории для создателя изображения."
   },
   "dynamicPrompts": {
     "seedBehaviour": {
@@ -1417,6 +1438,30 @@
       "paragraphs": [
         "Метод, с помощью которого применяется текущий IP-адаптер."
       ]
+    },
+    "structure": {
+      "paragraphs": [
+        "Структура контролирует, насколько точно выходное изображение будет соответствовать макету оригинала. Низкая структура допускает значительные изменения, в то время как высокая структура строго сохраняет исходную композицию и макет."
+      ],
+      "heading": "Структура"
+    },
+    "scale": {
+      "paragraphs": [
+        "Масштаб управляет размером выходного изображения и основывается на кратном разрешении входного изображения. Например, при увеличении в 2 раза изображения 1024x1024 на выходе получится 2048 x 2048."
+      ],
+      "heading": "Масштаб"
+    },
+    "creativity": {
+      "paragraphs": [
+        "Креативность контролирует степень свободы, предоставляемой модели при добавлении деталей. При низкой креативности модель остается близкой к оригинальному изображению, в то время как высокая креативность позволяет вносить больше изменений. При использовании подсказки высокая креативность увеличивает влияние подсказки."
+      ],
+      "heading": "Креативность"
+    },
+    "upscaleModel": {
+      "heading": "Модель увеличения",
+      "paragraphs": [
+        "Модель увеличения масштаба масштабирует изображение до выходного размера перед добавлением деталей. Можно использовать любую поддерживаемую модель масштабирования, но некоторые из них специализированы для различных видов изображений, например фотографий или линейных рисунков."
+      ]
     }
   },
   "metadata": {
@@ -1693,7 +1738,78 @@
     "canvasTab": "$t(ui.tabs.canvas) $t(common.tab)",
     "queueTab": "$t(ui.tabs.queue) $t(common.tab)",
     "modelsTab": "$t(ui.tabs.models) $t(common.tab)",
-    "queue": "Очередь"
-  }
+    "queue": "Очередь",
+    "upscaling": "Увеличение",
+    "upscalingTab": "$t(ui.tabs.upscaling) $t(common.tab)"
+    }
+  },
+  "upscaling": {
+    "exceedsMaxSize": "Параметры масштабирования превышают максимальный размер",
+    "exceedsMaxSizeDetails": "Максимальный предел масштабирования составляет {{maxUpscaleDimension}}x{{maxUpscaleDimension}} пикселей. Пожалуйста, попробуйте использовать меньшее изображение или уменьшите масштаб.",
+    "structure": "Структура",
+    "missingTileControlNetModel": "Не установлены подходящие модели ControlNet",
+    "missingUpscaleInitialImage": "Отсутствует увеличиваемое изображение",
+    "missingUpscaleModel": "Отсутствует увеличивающая модель",
+    "creativity": "Креативность",
+    "upscaleModel": "Модель увеличения",
+    "scale": "Масштаб",
+    "mainModelDesc": "Основная модель (архитектура SD1.5 или SDXL)",
+    "upscaleModelDesc": "Модель увеличения (img2img)",
+    "postProcessingModel": "Модель постобработки",
+    "tileControlNetModelDesc": "Модель ControlNet для выбранной архитектуры основной модели",
+    "missingModelsWarning": "Зайдите в <LinkComponent>Менеджер моделей</LinkComponent> чтоб установить необходимые модели:",
+    "postProcessingMissingModelWarning": "Посетите <LinkComponent>Менеджер моделей</LinkComponent>, чтобы установить модель постобработки (img2img)."
+  },
+  "stylePresets": {
+    "noMatchingTemplates": "Нет подходящих шаблонов",
+    "promptTemplatesDesc1": "Шаблоны подсказок добавляют текст к подсказкам, которые вы пишете в окне подсказок.",
+    "sharedTemplates": "Общие шаблоны",
+    "templateDeleted": "Шаблон запроса удален",
+    "toggleViewMode": "Переключить режим просмотра",
+    "type": "Тип",
+    "unableToDeleteTemplate": "Не получилось удалить шаблон запроса",
+    "viewModeTooltip": "Вот как будет выглядеть ваш запрос с выбранным шаблоном. Чтобы его отредактировать, щелкните в любом месте текстового поля.",
+    "viewList": "Просмотреть список шаблонов",
+    "active": "Активно",
+    "choosePromptTemplate": "Выберите шаблон запроса",
+    "defaultTemplates": "Стандартные шаблоны",
+    "deleteImage": "Удалить изображение",
+    "deleteTemplate": "Удалить шаблон",
+    "deleteTemplate2": "Вы уверены, что хотите удалить этот шаблон? Это нельзя отменить.",
+    "editTemplate": "Редактировать шаблон",
+    "exportPromptTemplates": "Экспорт моих шаблонов запроса (CSV)",
+    "exportDownloaded": "Экспорт скачан",
+    "exportFailed": "Невозможно сгенерировать и загрузить CSV",
+    "flatten": "Объединить выбранный шаблон с текущим запросом",
+    "acceptedColumnsKeys": "Принимаемые столбцы/ключи:",
+    "positivePromptColumn": "'prompt' или 'positive_prompt'",
+    "insertPlaceholder": "Вставить заполнитель",
+    "name": "Имя",
+    "negativePrompt": "Негативный запрос",
+    "promptTemplatesDesc3": "Если вы не используете заполнитель, шаблон будет добавлен в конец запроса.",
+    "positivePrompt": "Позитивный запрос",
+    "preview": "Предпросмотр",
+    "private": "Приватный",
+    "templateActions": "Действия с шаблоном",
+    "updatePromptTemplate": "Обновить шаблон запроса",
+    "uploadImage": "Загрузить изображение",
+    "useForTemplate": "Использовать для шаблона запроса",
+    "clearTemplateSelection": "Очистить выбор шаблона",
+    "copyTemplate": "Копировать шаблон",
+    "createPromptTemplate": "Создать шаблон запроса",
+    "importTemplates": "Импортировать шаблоны запроса (CSV/JSON)",
+    "nameColumn": "'name'",
+    "negativePromptColumn": "'negative_prompt'",
+    "myTemplates": "Мои шаблоны",
+    "noTemplates": "Нет шаблонов",
+    "promptTemplatesDesc2": "Используйте строку-заполнитель <Pre>{{placeholder}}</Pre>, чтобы указать место, куда должен быть включен ваш запрос в шаблоне.",
+    "searchByName": "Поиск по имени",
+    "shared": "Общий"
+  },
+  "upsell": {
+    "inviteTeammates": "Пригласите членов команды",
+    "professional": "Профессионал",
+    "professionalUpsell": "Доступно в профессиональной версии Invoke. Нажмите здесь или посетите invoke.com/pricing для получения более подробной информации.",
+    "shareAccess": "Поделиться доступом"
   }
 }
@@ -38,7 +38,7 @@ async function generateTypes(schema) {
   process.stdout.write(`\nOK!\r\n`);
 }
 
-async function main() {
+function main() {
   const encoding = 'utf-8';
 
   if (process.stdin.isTTY) {
@@ -6,6 +6,7 @@ import { appStarted } from 'app/store/middleware/listenerMiddleware/listeners/ap
 import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
 import type { PartialAppConfig } from 'app/types/invokeai';
 import ImageUploadOverlay from 'common/components/ImageUploadOverlay';
+import { useScopeFocusWatcher } from 'common/hooks/interactionScopes';
 import { useClearStorage } from 'common/hooks/useClearStorage';
 import { useFullscreenDropzone } from 'common/hooks/useFullscreenDropzone';
 import { useGlobalHotkeys } from 'common/hooks/useGlobalHotkeys';
@@ -13,12 +14,15 @@ import ChangeBoardModal from 'features/changeBoardModal/components/ChangeBoardMo
 import DeleteImageModal from 'features/deleteImageModal/components/DeleteImageModal';
 import { DynamicPromptsModal } from 'features/dynamicPrompts/components/DynamicPromptsPreviewModal';
 import { useStarterModelsToast } from 'features/modelManagerV2/hooks/useStarterModelsToast';
+import { ClearQueueConfirmationsAlertDialog } from 'features/queue/components/ClearQueueConfirmationAlertDialog';
 import { StylePresetModal } from 'features/stylePresets/components/StylePresetForm/StylePresetModal';
+import RefreshAfterResetModal from 'features/system/components/SettingsModal/RefreshAfterResetModal';
+import SettingsModal from 'features/system/components/SettingsModal/SettingsModal';
 import { configChanged } from 'features/system/store/configSlice';
-import { languageSelector } from 'features/system/store/systemSelectors';
-import InvokeTabs from 'features/ui/components/InvokeTabs';
-import type { InvokeTabName } from 'features/ui/store/tabMap';
+import { selectLanguage } from 'features/system/store/systemSelectors';
+import { AppContent } from 'features/ui/components/AppContent';
 import { setActiveTab } from 'features/ui/store/uiSlice';
+import type { TabName } from 'features/ui/store/uiTypes';
 import { useGetAndLoadLibraryWorkflow } from 'features/workflowLibrary/hooks/useGetAndLoadLibraryWorkflow';
 import { AnimatePresence } from 'framer-motion';
 import i18n from 'i18n';
@@ -39,11 +43,11 @@ interface Props {
     action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters';
   };
   selectedWorkflowId?: string;
-  destination?: InvokeTabName | undefined;
+  destination?: TabName | undefined;
 }
 
 const App = ({ config = DEFAULT_CONFIG, selectedImage, selectedWorkflowId, destination }: Props) => {
-  const language = useAppSelector(languageSelector);
+  const language = useAppSelector(selectLanguage);
   const logger = useLogger('system');
   const dispatch = useAppDispatch();
   const clearStorage = useClearStorage();
@@ -93,6 +97,7 @@ const App = ({ config = DEFAULT_CONFIG, selectedImage, selectedWorkflowId, desti
 
   useStarterModelsToast();
   useSyncQueueStatus();
+  useScopeFocusWatcher();
 
   return (
     <ErrorBoundary onReset={handleReset} FallbackComponent={AppErrorBoundaryFallback}>
@@ -105,7 +110,7 @@ const App = ({ config = DEFAULT_CONFIG, selectedImage, selectedWorkflowId, desti
         {...dropzone.getRootProps()}
       >
         <input {...dropzone.getInputProps()} />
-        <InvokeTabs />
+        <AppContent />
         <AnimatePresence>
           {dropzone.isDragActive && isHandlingUpload && (
             <ImageUploadOverlay dropzone={dropzone} setIsHandlingUpload={setIsHandlingUpload} />
@@ -116,7 +121,10 @@ const App = ({ config = DEFAULT_CONFIG, selectedImage, selectedWorkflowId, desti
       <ChangeBoardModal />
       <DynamicPromptsModal />
       <StylePresetModal />
+      <ClearQueueConfirmationsAlertDialog />
       <PreselectedImage selectedImage={selectedImage} />
+      <SettingsModal />
+      <RefreshAfterResetModal />
     </ErrorBoundary>
   );
 };
@@ -1,5 +1,7 @@
 import { Button, Flex, Heading, Image, Link, Text } from '@invoke-ai/ui-library';
+import { createSelector } from '@reduxjs/toolkit';
 import { useAppSelector } from 'app/store/storeHooks';
+import { selectConfigSlice } from 'features/system/store/configSlice';
 import { toast } from 'features/toast/toast';
 import newGithubIssueUrl from 'new-github-issue-url';
 import InvokeLogoYellow from 'public/assets/images/invoke-symbol-ylw-lrg.svg';
@@ -13,9 +15,11 @@ type Props = {
   resetErrorBoundary: () => void;
 };
 
+const selectIsLocal = createSelector(selectConfigSlice, (config) => config.isLocal);
+
 const AppErrorBoundaryFallback = ({ error, resetErrorBoundary }: Props) => {
   const { t } = useTranslation();
-  const isLocal = useAppSelector((s) => s.config.isLocal);
+  const isLocal = useAppSelector(selectIsLocal);
 
   const handleCopy = useCallback(() => {
     const text = JSON.stringify(serializeError(error), null, 2);
@@ -19,7 +19,7 @@ import type { PartialAppConfig } from 'app/types/invokeai';
 import Loading from 'common/components/Loading/Loading';
 import AppDndContext from 'features/dnd/components/AppDndContext';
 import type { WorkflowCategory } from 'features/nodes/types/workflow';
-import type { InvokeTabName } from 'features/ui/store/tabMap';
+import type { TabName } from 'features/ui/store/uiTypes';
 import type { PropsWithChildren, ReactNode } from 'react';
 import React, { lazy, memo, useEffect, useMemo } from 'react';
 import { Provider } from 'react-redux';
@@ -45,7 +45,7 @@ interface Props extends PropsWithChildren {
     action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters';
   };
   selectedWorkflowId?: string;
-  destination?: InvokeTabName;
+  destination?: TabName;
   customStarUi?: CustomStarUi;
   socketOptions?: Partial<ManagerOptions & SocketOptions>;
   isDebugging?: boolean;
@@ -2,7 +2,7 @@ import { useStore } from '@nanostores/react';
 import { $authToken } from 'app/store/nanostores/authToken';
 import { $baseUrl } from 'app/store/nanostores/baseUrl';
 import { $isDebugging } from 'app/store/nanostores/isDebugging';
-import { useAppDispatch } from 'app/store/storeHooks';
+import { useAppStore } from 'app/store/nanostores/store';
 import type { MapStore } from 'nanostores';
 import { atom, map } from 'nanostores';
 import { useEffect, useMemo } from 'react';
@@ -18,14 +18,19 @@ declare global {
   }
 }
 
+export type AppSocket = Socket<ServerToClientEvents, ClientToServerEvents>;
+
+export const $socket = atom<AppSocket | null>(null);
 export const $socketOptions = map<Partial<ManagerOptions & SocketOptions>>({});
 
 const $isSocketInitialized = atom<boolean>(false);
+export const $isConnected = atom<boolean>(false);
 
 /**
  * Initializes the socket.io connection and sets up event listeners.
  */
 export const useSocketIO = () => {
-  const dispatch = useAppDispatch();
+  const { dispatch, getState } = useAppStore();
   const baseUrl = useStore($baseUrl);
   const authToken = useStore($authToken);
   const addlSocketOptions = useStore($socketOptions);
@@ -61,8 +66,9 @@ export const useSocketIO = () => {
       return;
     }
 
-    const socket: Socket<ServerToClientEvents, ClientToServerEvents> = io(socketUrl, socketOptions);
-    setEventListeners({ dispatch, socket });
+    const socket: AppSocket = io(socketUrl, socketOptions);
+    $socket.set(socket);
+    setEventListeners({ socket, dispatch, getState, setIsConnected: $isConnected.set });
     socket.connect();
 
     if ($isDebugging.get() || import.meta.env.MODE === 'development') {
@@ -84,5 +90,5 @@ export const useSocketIO = () => {
       socket.disconnect();
       $isSocketInitialized.set(false);
     };
-  }, [dispatch, socketOptions, socketUrl]);
+  }, [dispatch, getState, socketOptions, socketUrl]);
 };
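The hunk above moves the socket and its connection state into exported nanostores atoms ($socket, $isConnected) and switches the hook to useAppStore so the event listeners receive getState as well. Below is a minimal sketch of how a consumer might read those atoms; the component and function names are illustrative and the import path is assumed, neither comes from the diff:

```tsx
import { useStore } from '@nanostores/react';
// Assumed path: import from wherever useSocketIO exports these atoms.
import { $isConnected, $socket } from 'services/events/useSocketIO';

// Hypothetical status badge: re-renders whenever the $isConnected atom flips.
export const ConnectionStatusBadge = () => {
  const isConnected = useStore($isConnected);
  return <span>{isConnected ? 'Connected' : 'Disconnected'}</span>;
};

// Outside React the socket can be read imperatively; it is null until useSocketIO has initialized it.
export const disconnectSocket = () => {
  $socket.get()?.disconnect();
};
```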
@@ -15,21 +15,21 @@ export const BASE_CONTEXT = {};
 
 export const $logger = atom<Logger>(Roarr.child(BASE_CONTEXT));
 
-export type LoggerNamespace =
-  | 'images'
-  | 'models'
-  | 'config'
-  | 'canvas'
-  | 'generation'
-  | 'nodes'
-  | 'system'
-  | 'socketio'
-  | 'session'
-  | 'queue'
-  | 'dnd'
-  | 'controlLayers';
+export const zLogNamespace = z.enum([
+  'canvas',
+  'config',
+  'events',
+  'gallery',
+  'generation',
+  'metadata',
+  'models',
+  'system',
+  'queue',
+  'workflows',
+]);
+export type LogNamespace = z.infer<typeof zLogNamespace>;
 
-export const logger = (namespace: LoggerNamespace) => $logger.get().child({ namespace });
+export const logger = (namespace: LogNamespace) => $logger.get().child({ namespace });
 
 export const zLogLevel = z.enum(['trace', 'debug', 'info', 'warn', 'error', 'fatal']);
 export type LogLevel = z.infer<typeof zLogLevel>;
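Replacing the LoggerNamespace union with a zod enum keeps the same static type (via z.infer) while adding a runtime validator, which matters once namespaces can come from persisted settings rather than only from code. A small sketch of both uses, assuming zLogNamespace is exported from 'app/logging/logger' as the hunk suggests:

```ts
import { z } from 'zod';
import { logger, zLogNamespace, type LogNamespace } from 'app/logging/logger';

// Compile time: 'canvas' is a member of the enum, so tsc accepts it.
const log = logger('canvas');
log.debug('canvas logger ready');

// Runtime: validate untrusted input (e.g. a persisted settings blob) before trusting it.
const parseNamespaces = (value: unknown): LogNamespace[] =>
  z.array(zLogNamespace).catch([]).parse(value);

parseNamespaces(['canvas', 'models']); // -> ['canvas', 'models']
parseNamespaces(['canvas', 'not-a-namespace']); // -> [] (validation fails, .catch falls back)
```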
@@ -1,29 +1,41 @@
 import { createLogWriter } from '@roarr/browser-log-writer';
 import { useAppSelector } from 'app/store/storeHooks';
+import {
+  selectSystemLogIsEnabled,
+  selectSystemLogLevel,
+  selectSystemLogNamespaces,
+} from 'features/system/store/systemSlice';
 import { useEffect, useMemo } from 'react';
 import { ROARR, Roarr } from 'roarr';
 
-import type { LoggerNamespace } from './logger';
+import type { LogNamespace } from './logger';
 import { $logger, BASE_CONTEXT, LOG_LEVEL_MAP, logger } from './logger';
 
-export const useLogger = (namespace: LoggerNamespace) => {
-  const consoleLogLevel = useAppSelector((s) => s.system.consoleLogLevel);
-  const shouldLogToConsole = useAppSelector((s) => s.system.shouldLogToConsole);
+export const useLogger = (namespace: LogNamespace) => {
+  const logLevel = useAppSelector(selectSystemLogLevel);
+  const logNamespaces = useAppSelector(selectSystemLogNamespaces);
+  const logIsEnabled = useAppSelector(selectSystemLogIsEnabled);
 
   // The provided Roarr browser log writer uses localStorage to config logging to console
   useEffect(() => {
-    if (shouldLogToConsole) {
+    if (logIsEnabled) {
       // Enable console log output
       localStorage.setItem('ROARR_LOG', 'true');
 
       // Use a filter to show only logs of the given level
-      localStorage.setItem('ROARR_FILTER', `context.logLevel:>=${LOG_LEVEL_MAP[consoleLogLevel]}`);
+      let filter = `context.logLevel:>=${LOG_LEVEL_MAP[logLevel]}`;
+      if (logNamespaces.length > 0) {
+        filter += ` AND (${logNamespaces.map((ns) => `context.namespace:${ns}`).join(' OR ')})`;
+      } else {
+        filter += ' AND context.namespace:undefined';
+      }
+      localStorage.setItem('ROARR_FILTER', filter);
     } else {
       // Disable console log output
       localStorage.setItem('ROARR_LOG', 'false');
     }
     ROARR.write = createLogWriter();
-  }, [consoleLogLevel, shouldLogToConsole]);
+  }, [logLevel, logIsEnabled, logNamespaces]);
 
   // Update the module-scoped logger context as needed
   useEffect(() => {
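For reference, the string this effect writes to localStorage is just a Roarr query: a minimum logLevel plus an OR over the enabled namespaces. A standalone sketch of the same construction with illustrative values (the numeric level assumes LOG_LEVEL_MAP follows Roarr's convention where debug is 20):

```ts
const logLevel = 20; // e.g. LOG_LEVEL_MAP['debug']; the value is assumed for illustration
const logNamespaces = ['canvas', 'models'];

let filter = `context.logLevel:>=${logLevel}`;
if (logNamespaces.length > 0) {
  filter += ` AND (${logNamespaces.map((ns) => `context.namespace:${ns}`).join(' OR ')})`;
} else {
  filter += ' AND context.namespace:undefined';
}

console.log(filter);
// -> context.logLevel:>=20 AND (context.namespace:canvas OR context.namespace:models)
```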
@@ -1,7 +1,7 @@
 import { createAction } from '@reduxjs/toolkit';
-import type { InvokeTabName } from 'features/ui/store/tabMap';
+import type { TabName } from 'features/ui/store/uiTypes';
 
 export const enqueueRequested = createAction<{
-  tabName: InvokeTabName;
+  tabName: TabName;
   prepend: boolean;
 }>('app/enqueueRequested');
@@ -1,2 +1,3 @@
 export const STORAGE_PREFIX = '@@invokeai-';
 export const EMPTY_ARRAY = [];
+export const EMPTY_OBJECT = {};
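EMPTY_ARRAY (and the newly added EMPTY_OBJECT) exist so selectors can return one stable reference for "no data" instead of allocating a fresh [] or {} on every call, which would look like a new value to shallow-equality checks and trigger needless re-renders. A hypothetical selector showing the intent; the state shape below is made up for illustration:

```ts
import { EMPTY_ARRAY } from 'app/store/constants'; // assumed path for the constants shown above

type Item = { id: string };
type State = { items?: Item[] };

// Returning the shared constant keeps the reference identical across calls when there is
// nothing to return, so useAppSelector consumers comparing by reference do not re-render.
const selectItems = (state: State): Item[] => state.items ?? (EMPTY_ARRAY as Item[]);
```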
@@ -1,5 +1,6 @@
 import { createDraftSafeSelectorCreator, createSelectorCreator, lruMemoize } from '@reduxjs/toolkit';
 import type { GetSelectorsOptions } from '@reduxjs/toolkit/dist/entities/state_selectors';
+import type { RootState } from 'app/store/store';
 import { isEqual } from 'lodash-es';
 
 /**
@@ -19,3 +20,5 @@ export const getSelectorsOptions: GetSelectorsOptions = {
     argsMemoize: lruMemoize,
   }),
 };
+
+export const createMemoizedAppSelector = createMemoizedSelector.withTypes<RootState>();
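withTypes<RootState>() pre-binds the store's state type to the memoized selector creator, so call sites get a typed state argument without importing RootState themselves. A usage sketch; the slice and field names are placeholders rather than anything from the diff:

```ts
import { createMemoizedAppSelector } from 'app/store/createMemoizedSelector'; // assumed path

// `state` is already typed as RootState here; the result is memoized with the
// isEqual/lruMemoize options configured above.
const selectActiveTabLabel = createMemoizedAppSelector(
  (state) => state.ui.activeTab, // hypothetical slice/field
  (activeTab) => `tab:${activeTab}`
);
```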
@@ -1,5 +1,4 @@
 import { logger } from 'app/logging/logger';
-import { parseify } from 'common/util/serialize';
 import { PersistError, RehydrateError } from 'redux-remember';
 import { serializeError } from 'serialize-error';
 
@@ -41,6 +40,6 @@ export const errorHandler = (err: PersistError | RehydrateError) => {
   } else if (err instanceof RehydrateError) {
     log.error({ error: serializeError(err) }, 'Problem rehydrating state');
   } else {
-    log.error({ error: parseify(err) }, 'Problem in persistence layer');
+    log.error({ error: serializeError(err) }, 'Problem in persistence layer');
   }
 };
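With this change all three branches of errorHandler log through serialize-error rather than mixing in the project's parseify helper. For context, serializeError turns an Error into a plain, JSON-safe object, roughly as sketched below (output shape abbreviated):

```ts
import { serializeError } from 'serialize-error';

const result = serializeError(new Error('Problem in persistence layer'));
// result is a plain object along the lines of:
// { name: 'Error', message: 'Problem in persistence layer', stack: '...' }
// so it can be passed straight to the Roarr logger without custom serialization.
```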
@@ -1,9 +1,7 @@
 import type { UnknownAction } from '@reduxjs/toolkit';
-import { deepClone } from 'common/util/deepClone';
 import { isAnyGraphBuilt } from 'features/nodes/store/actions';
 import { appInfoApi } from 'services/api/endpoints/appInfo';
 import type { Graph } from 'services/api/types';
-import { socketGeneratorProgress } from 'services/events/actions';
 
 export const actionSanitizer = <A extends UnknownAction>(action: A): A => {
   if (isAnyGraphBuilt(action)) {
@@ -24,13 +22,5 @@ export const actionSanitizer = <A extends UnknownAction>(action: A): A => {
     };
   }
 
-  if (socketGeneratorProgress.match(action)) {
-    const sanitized = deepClone(action);
-    if (sanitized.payload.data.progress_image) {
-      sanitized.payload.data.progress_image.dataURL = '<Progress image omitted>';
-    }
-    return sanitized;
-  }
-
   return action;
 };
@@ -1,7 +1,7 @@
 import type { TypedStartListening } from '@reduxjs/toolkit';
 import { createListenerMiddleware } from '@reduxjs/toolkit';
 import { addAdHocPostProcessingRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/addAdHocPostProcessingRequestedListener';
-import { addCommitStagingAreaImageListener } from 'app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener';
+import { addStagingListeners } from 'app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener';
 import { addAnyEnqueuedListener } from 'app/store/middleware/listenerMiddleware/listeners/anyEnqueued';
 import { addAppConfigReceivedListener } from 'app/store/middleware/listenerMiddleware/listeners/appConfigReceived';
 import { addAppStartedListener } from 'app/store/middleware/listenerMiddleware/listeners/appStarted';
@@ -9,17 +9,6 @@ import { addBatchEnqueuedListener } from 'app/store/middleware/listenerMiddlewar
 import { addDeleteBoardAndImagesFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/boardAndImagesDeleted';
 import { addBoardIdSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/boardIdSelected';
 import { addBulkDownloadListeners } from 'app/store/middleware/listenerMiddleware/listeners/bulkDownload';
-import { addCanvasCopiedToClipboardListener } from 'app/store/middleware/listenerMiddleware/listeners/canvasCopiedToClipboard';
-import { addCanvasDownloadedAsImageListener } from 'app/store/middleware/listenerMiddleware/listeners/canvasDownloadedAsImage';
-import { addCanvasImageToControlNetListener } from 'app/store/middleware/listenerMiddleware/listeners/canvasImageToControlNet';
-import { addCanvasMaskSavedToGalleryListener } from 'app/store/middleware/listenerMiddleware/listeners/canvasMaskSavedToGallery';
-import { addCanvasMaskToControlNetListener } from 'app/store/middleware/listenerMiddleware/listeners/canvasMaskToControlNet';
-import { addCanvasMergedListener } from 'app/store/middleware/listenerMiddleware/listeners/canvasMerged';
-import { addCanvasSavedToGalleryListener } from 'app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery';
-import { addControlAdapterPreprocessor } from 'app/store/middleware/listenerMiddleware/listeners/controlAdapterPreprocessor';
-import { addControlNetAutoProcessListener } from 'app/store/middleware/listenerMiddleware/listeners/controlNetAutoProcess';
-import { addControlNetImageProcessedListener } from 'app/store/middleware/listenerMiddleware/listeners/controlNetImageProcessed';
-import { addEnqueueRequestedCanvasListener } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedCanvas';
 import { addEnqueueRequestedLinear } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear';
 import { addEnqueueRequestedNodes } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedNodes';
 import { addGalleryImageClickedListener } from 'app/store/middleware/listenerMiddleware/listeners/galleryImageClicked';
@@ -37,16 +26,7 @@ import { addModelSelectedListener } from 'app/store/middleware/listenerMiddlewar
 import { addModelsLoadedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelsLoaded';
 import { addDynamicPromptsListener } from 'app/store/middleware/listenerMiddleware/listeners/promptChanged';
 import { addSetDefaultSettingsListener } from 'app/store/middleware/listenerMiddleware/listeners/setDefaultSettings';
-import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected';
-import { addSocketDisconnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketDisconnected';
-import { addGeneratorProgressEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress';
-import { addInvocationCompleteEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete';
-import { addInvocationErrorEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError';
-import { addInvocationStartedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationStarted';
-import { addModelInstallEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall';
-import { addModelLoadEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelLoad';
-import { addSocketQueueItemStatusChangedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueItemStatusChanged';
-import { addStagingAreaImageSavedListener } from 'app/store/middleware/listenerMiddleware/listeners/stagingAreaImageSaved';
+import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketConnected';
 import { addUpdateAllNodesRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested';
 import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested';
 import type { AppDispatch, RootState } from 'app/store/store';
@@ -83,7 +63,6 @@ addGalleryImageClickedListener(startAppListening);
 addGalleryOffsetChangedListener(startAppListening);
 
 // User Invoked
-addEnqueueRequestedCanvasListener(startAppListening);
 addEnqueueRequestedNodes(startAppListening);
 addEnqueueRequestedLinear(startAppListening);
 addEnqueueRequestedUpscale(startAppListening);
@@ -91,31 +70,22 @@ addAnyEnqueuedListener(startAppListening);
 addBatchEnqueuedListener(startAppListening);
 
 // Canvas actions
-addCanvasSavedToGalleryListener(startAppListening);
-addCanvasMaskSavedToGalleryListener(startAppListening);
-addCanvasImageToControlNetListener(startAppListening);
-addCanvasMaskToControlNetListener(startAppListening);
-addCanvasDownloadedAsImageListener(startAppListening);
-addCanvasCopiedToClipboardListener(startAppListening);
-addCanvasMergedListener(startAppListening);
-addStagingAreaImageSavedListener(startAppListening);
-addCommitStagingAreaImageListener(startAppListening);
+// addCanvasSavedToGalleryListener(startAppListening);
+// addCanvasMaskSavedToGalleryListener(startAppListening);
+// addCanvasImageToControlNetListener(startAppListening);
+// addCanvasMaskToControlNetListener(startAppListening);
+// addCanvasDownloadedAsImageListener(startAppListening);
+// addCanvasCopiedToClipboardListener(startAppListening);
+// addCanvasMergedListener(startAppListening);
+// addStagingAreaImageSavedListener(startAppListening);
+// addCommitStagingAreaImageListener(startAppListening);
+addStagingListeners(startAppListening);
 
 // Socket.IO
-addGeneratorProgressEventListener(startAppListening);
-addInvocationCompleteEventListener(startAppListening);
-addInvocationErrorEventListener(startAppListening);
-addInvocationStartedEventListener(startAppListening);
 addSocketConnectedEventListener(startAppListening);
-addSocketDisconnectedEventListener(startAppListening);
-addModelLoadEventListener(startAppListening);
-addModelInstallEventListener(startAppListening);
-addSocketQueueItemStatusChangedEventListener(startAppListening);
-addBulkDownloadListeners(startAppListening);
 
-// ControlNet
-addControlNetImageProcessedListener(startAppListening);
-addControlNetAutoProcessListener(startAppListening);
+// Gallery bulk download
+addBulkDownloadListeners(startAppListening);
 
 // Boards
 addImageAddedToBoardFulfilledListener(startAppListening);
@@ -148,4 +118,4 @@ addAdHocPostProcessingRequestedListener(startAppListening);
 addDynamicPromptsListener(startAppListening);
 
 addSetDefaultSettingsListener(startAppListening);
-addControlAdapterPreprocessor(startAppListening);
+// addControlAdapterPreprocessor(startAppListening);
@@ -1,21 +1,21 @@
 import { createAction } from '@reduxjs/toolkit';
 import { logger } from 'app/logging/logger';
 import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
-import { parseify } from 'common/util/serialize';
+import type { SerializableObject } from 'common/types';
 import { buildAdHocPostProcessingGraph } from 'features/nodes/util/graph/buildAdHocPostProcessingGraph';
 import { toast } from 'features/toast/toast';
 import { t } from 'i18next';
 import { queueApi } from 'services/api/endpoints/queue';
 import type { BatchConfig, ImageDTO } from 'services/api/types';
 
+const log = logger('queue');
+
 export const adHocPostProcessingRequested = createAction<{ imageDTO: ImageDTO }>(`upscaling/postProcessingRequested`);
 
 export const addAdHocPostProcessingRequestedListener = (startAppListening: AppStartListening) => {
   startAppListening({
     actionCreator: adHocPostProcessingRequested,
     effect: async (action, { dispatch, getState }) => {
-      const log = logger('session');
-
       const { imageDTO } = action.payload;
       const state = getState();
 
@@ -39,9 +39,9 @@ export const addAdHocPostProcessingRequestedListener = (startAppListening: AppSt
 
         const enqueueResult = await req.unwrap();
         req.reset();
-        log.debug({ enqueueResult: parseify(enqueueResult) }, t('queue.graphQueued'));
+        log.debug({ enqueueResult } as SerializableObject, t('queue.graphQueued'));
       } catch (error) {
-        log.error({ enqueueBatchArg: parseify(enqueueBatchArg) }, t('queue.graphFailedToQueue'));
+        log.error({ enqueueBatchArg } as SerializableObject, t('queue.graphFailedToQueue'));
 
         if (error instanceof Object && 'status' in error && error.status === 403) {
           return;
@@ -23,7 +23,7 @@ export const addArchivedOrDeletedBoardListener = (startAppListening: AppStartLis
    */
   startAppListening({
     matcher: matchAnyBoardDeleted,
-    effect: async (action, { dispatch, getState }) => {
+    effect: (action, { dispatch, getState }) => {
       const state = getState();
       const deletedBoardId = action.meta.arg.originalArgs;
       const { autoAddBoardId, selectedBoardId } = state.gallery;
@@ -44,7 +44,7 @@ export const addArchivedOrDeletedBoardListener = (startAppListening: AppStartLis
   // If we archived a board, it may end up hidden. If it's selected or the auto-add board, we should reset those.
   startAppListening({
     matcher: boardsApi.endpoints.updateBoard.matchFulfilled,
-    effect: async (action, { dispatch, getState }) => {
+    effect: (action, { dispatch, getState }) => {
       const state = getState();
       const { shouldShowArchivedBoards } = state.gallery;
 
@@ -61,7 +61,7 @@ export const addArchivedOrDeletedBoardListener = (startAppListening: AppStartLis
   // When we hide archived boards, if the selected or the auto-add board is archived, we should reset those.
   startAppListening({
     actionCreator: shouldShowArchivedBoardsChanged,
-    effect: async (action, { dispatch, getState }) => {
+    effect: (action, { dispatch, getState }) => {
       const shouldShowArchivedBoards = action.payload;
 
       // We only need to take action if we have just hidden archived boards.
@@ -100,7 +100,7 @@ export const addArchivedOrDeletedBoardListener = (startAppListening: AppStartLis
    */
   startAppListening({
     matcher: boardsApi.endpoints.listAllBoards.matchFulfilled,
-    effect: async (action, { dispatch, getState }) => {
+    effect: (action, { dispatch, getState }) => {
       const boards = action.payload;
       const state = getState();
       const { selectedBoardId, autoAddBoardId } = state.gallery;
@@ -1,33 +1,37 @@
-import { isAnyOf } from '@reduxjs/toolkit';
 import { logger } from 'app/logging/logger';
 import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
 import {
-  canvasBatchIdsReset,
-  commitStagingAreaImage,
-  discardStagedImages,
-  resetCanvas,
-  setInitialCanvasImage,
-} from 'features/canvas/store/canvasSlice';
+  sessionStagingAreaImageAccepted,
+  sessionStagingAreaReset,
+} from 'features/controlLayers/store/canvasSessionSlice';
+import { rasterLayerAdded } from 'features/controlLayers/store/canvasSlice';
+import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
+import type { CanvasRasterLayerState } from 'features/controlLayers/store/types';
+import { imageDTOToImageObject } from 'features/controlLayers/store/types';
 import { toast } from 'features/toast/toast';
 import { t } from 'i18next';
 import { queueApi } from 'services/api/endpoints/queue';
+import { $lastCanvasProgressEvent } from 'services/events/setEventListeners';
+import { assert } from 'tsafe';
 
-const matcher = isAnyOf(commitStagingAreaImage, discardStagedImages, resetCanvas, setInitialCanvasImage);
-
-export const addCommitStagingAreaImageListener = (startAppListening: AppStartListening) => {
-  startAppListening({
-    matcher,
-    effect: async (_, { dispatch, getState }) => {
 const log = logger('canvas');
-      const state = getState();
-      const { batchIds } = state.canvas;
 
+export const addStagingListeners = (startAppListening: AppStartListening) => {
+  startAppListening({
+    actionCreator: sessionStagingAreaReset,
+    effect: async (_, { dispatch }) => {
       try {
         const req = dispatch(
-          queueApi.endpoints.cancelByBatchIds.initiate({ batch_ids: batchIds }, { fixedCacheKey: 'cancelByBatchIds' })
+          queueApi.endpoints.cancelByBatchOrigin.initiate(
+            { origin: 'canvas' },
+            { fixedCacheKey: 'cancelByBatchOrigin' }
+          )
         );
         const { canceled } = await req.unwrap();
         req.reset();
 
+        $lastCanvasProgressEvent.set(null);
+
         if (canceled > 0) {
           log.debug(`Canceled ${canceled} canvas batches`);
           toast({
@@ -36,7 +40,6 @@ export const addCommitStagingAreaImageListener = (startAppListening: AppStartLis
             status: 'success',
           });
         }
-        dispatch(canvasBatchIdsReset());
       } catch {
         log.error('Failed to cancel canvas batches');
         toast({
@@ -47,4 +50,26 @@ export const addCommitStagingAreaImageListener = (startAppListening: AppStartLis
       }
     },
   });
+
+  startAppListening({
+    actionCreator: sessionStagingAreaImageAccepted,
+    effect: (action, api) => {
+      const { index } = action.payload;
+      const state = api.getState();
+      const stagingAreaImage = state.canvasSession.stagedImages[index];
+
+      assert(stagingAreaImage, 'No staged image found to accept');
+      const { x, y } = selectCanvasSlice(state).bbox.rect;
+
+      const { imageDTO, offsetX, offsetY } = stagingAreaImage;
+      const imageObject = imageDTOToImageObject(imageDTO);
+      const overrides: Partial<CanvasRasterLayerState> = {
+        position: { x: x + offsetX, y: y + offsetY },
+        objects: [imageObject],
+      };
+
+      api.dispatch(rasterLayerAdded({ overrides, isSelected: true }));
+      api.dispatch(sessionStagingAreaReset());
+    },
+  });
 };
@@ -4,7 +4,7 @@ import { queueApi, selectQueueStatus } from 'services/api/endpoints/queue';
 export const addAnyEnqueuedListener = (startAppListening: AppStartListening) => {
   startAppListening({
     matcher: queueApi.endpoints.enqueueBatch.matchFulfilled,
-    effect: async (_, { dispatch, getState }) => {
+    effect: (_, { dispatch, getState }) => {
       const { data } = selectQueueStatus(getState());
 
       if (!data || data.processor.is_started) {
@@ -1,14 +1,14 @@
 import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
-import { setInfillMethod } from 'features/parameters/store/generationSlice';
+import { setInfillMethod } from 'features/controlLayers/store/paramsSlice';
 import { shouldUseNSFWCheckerChanged, shouldUseWatermarkerChanged } from 'features/system/store/systemSlice';
 import { appInfoApi } from 'services/api/endpoints/appInfo';

 export const addAppConfigReceivedListener = (startAppListening: AppStartListening) => {
   startAppListening({
     matcher: appInfoApi.endpoints.getAppConfig.matchFulfilled,
-    effect: async (action, { getState, dispatch }) => {
+    effect: (action, { getState, dispatch }) => {
       const { infill_methods = [], nsfw_methods = [], watermarking_methods = [] } = action.payload;
-      const infillMethod = getState().generation.infillMethod;
+      const infillMethod = getState().params.infillMethod;

       if (!infill_methods.includes(infillMethod)) {
         // if there is no infill method, set it to the first one
@@ -6,7 +6,7 @@ export const appStarted = createAction('app/appStarted');
 export const addAppStartedListener = (startAppListening: AppStartListening) => {
   startAppListening({
     actionCreator: appStarted,
-    effect: async (action, { unsubscribe, cancelActiveListeners }) => {
+    effect: (action, { unsubscribe, cancelActiveListeners }) => {
       // this should only run once
       cancelActiveListeners();
       unsubscribe();
@@ -1,27 +1,30 @@
 import { logger } from 'app/logging/logger';
 import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
-import { parseify } from 'common/util/serialize';
+import type { SerializableObject } from 'common/types';
 import { zPydanticValidationError } from 'features/system/store/zodSchemas';
 import { toast } from 'features/toast/toast';
 import { t } from 'i18next';
 import { truncate, upperFirst } from 'lodash-es';
+import { serializeError } from 'serialize-error';
 import { queueApi } from 'services/api/endpoints/queue';

+const log = logger('queue');
+
 export const addBatchEnqueuedListener = (startAppListening: AppStartListening) => {
   // success
   startAppListening({
     matcher: queueApi.endpoints.enqueueBatch.matchFulfilled,
-    effect: async (action) => {
+    effect: (action) => {
-      const response = action.payload;
+      const enqueueResult = action.payload;
       const arg = action.meta.arg.originalArgs;
-      logger('queue').debug({ enqueueResult: parseify(response) }, 'Batch enqueued');
+      log.debug({ enqueueResult } as SerializableObject, 'Batch enqueued');

       toast({
         id: 'QUEUE_BATCH_SUCCEEDED',
         title: t('queue.batchQueued'),
         status: 'success',
         description: t('queue.batchQueuedDesc', {
-          count: response.enqueued,
+          count: enqueueResult.enqueued,
           direction: arg.prepend ? t('queue.front') : t('queue.back'),
         }),
       });
@@ -31,9 +34,9 @@ export const addBatchEnqueuedListener = (startAppListening: AppStartListening) =
   // error
   startAppListening({
     matcher: queueApi.endpoints.enqueueBatch.matchRejected,
-    effect: async (action) => {
+    effect: (action) => {
       const response = action.payload;
-      const arg = action.meta.arg.originalArgs;
+      const batchConfig = action.meta.arg.originalArgs;

       if (!response) {
         toast({
@@ -42,7 +45,7 @@ export const addBatchEnqueuedListener = (startAppListening: AppStartListening) =
           status: 'error',
           description: t('common.unknownError'),
         });
-        logger('queue').error({ batchConfig: parseify(arg), error: parseify(response) }, t('queue.batchFailedToQueue'));
+        log.error({ batchConfig } as SerializableObject, t('queue.batchFailedToQueue'));
         return;
       }

@@ -68,7 +71,7 @@ export const addBatchEnqueuedListener = (startAppListening: AppStartListening) =
           description: t('common.unknownError'),
         });
       }
-      logger('queue').error({ batchConfig: parseify(arg), error: parseify(response) }, t('queue.batchFailedToQueue'));
+      log.error({ batchConfig, error: serializeError(response) } as SerializableObject, t('queue.batchFailedToQueue'));
     },
   });
 };
@@ -1,47 +1,31 @@
 import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
-import { resetCanvas } from 'features/canvas/store/canvasSlice';
+import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
-import { controlAdaptersReset } from 'features/controlAdapters/store/controlAdaptersSlice';
-import { allLayersDeleted } from 'features/controlLayers/store/controlLayersSlice';
 import { getImageUsage } from 'features/deleteImageModal/store/selectors';
 import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
+import { selectNodesSlice } from 'features/nodes/store/selectors';
 import { imagesApi } from 'services/api/endpoints/images';

 export const addDeleteBoardAndImagesFulfilledListener = (startAppListening: AppStartListening) => {
   startAppListening({
     matcher: imagesApi.endpoints.deleteBoardAndImages.matchFulfilled,
-    effect: async (action, { dispatch, getState }) => {
+    effect: (action, { dispatch, getState }) => {
       const { deleted_images } = action.payload;

       // Remove all deleted images from the UI

-      let wasCanvasReset = false;
       let wasNodeEditorReset = false;
-      let wereControlAdaptersReset = false;
-      let wereControlLayersReset = false;

-      const { canvas, nodes, controlAdapters, controlLayers } = getState();
+      const state = getState();
+      const nodes = selectNodesSlice(state);
+      const canvas = selectCanvasSlice(state);

       deleted_images.forEach((image_name) => {
-        const imageUsage = getImageUsage(canvas, nodes.present, controlAdapters, controlLayers.present, image_name);
+        const imageUsage = getImageUsage(nodes, canvas, image_name);

-        if (imageUsage.isCanvasImage && !wasCanvasReset) {
-          dispatch(resetCanvas());
-          wasCanvasReset = true;
-        }
-
         if (imageUsage.isNodesImage && !wasNodeEditorReset) {
           dispatch(nodeEditorReset());
           wasNodeEditorReset = true;
         }
-
-        if (imageUsage.isControlImage && !wereControlAdaptersReset) {
-          dispatch(controlAdaptersReset());
-          wereControlAdaptersReset = true;
-        }
-
-        if (imageUsage.isControlLayerImage && !wereControlLayersReset) {
-          dispatch(allLayersDeleted());
-          wereControlLayersReset = true;
-        }
       });
     },
   });
@@ -1,21 +1,15 @@
-import { ExternalLink } from '@invoke-ai/ui-library';
 import { logger } from 'app/logging/logger';
 import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
 import { toast } from 'features/toast/toast';
 import { t } from 'i18next';
 import { imagesApi } from 'services/api/endpoints/images';
-import {
-  socketBulkDownloadComplete,
-  socketBulkDownloadError,
-  socketBulkDownloadStarted,
-} from 'services/events/actions';

-const log = logger('images');
+const log = logger('gallery');

 export const addBulkDownloadListeners = (startAppListening: AppStartListening) => {
   startAppListening({
     matcher: imagesApi.endpoints.bulkDownloadImages.matchFulfilled,
-    effect: async (action) => {
+    effect: (action) => {
       log.debug(action.payload, 'Bulk download requested');

       // If we have an item name, we are processing the bulk download locally and should use it as the toast id to
@@ -33,7 +27,7 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) =

   startAppListening({
     matcher: imagesApi.endpoints.bulkDownloadImages.matchRejected,
-    effect: async () => {
+    effect: () => {
       log.debug('Bulk download request failed');

       // There isn't any toast to update if we get this event.
@@ -44,55 +38,4 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) =
       });
     },
   });
-
-  startAppListening({
-    actionCreator: socketBulkDownloadStarted,
-    effect: async (action) => {
-      // This should always happen immediately after the bulk download request, so we don't need to show a toast here.
-      log.debug(action.payload.data, 'Bulk download preparation started');
-    },
-  });
-
-  startAppListening({
-    actionCreator: socketBulkDownloadComplete,
-    effect: async (action) => {
-      log.debug(action.payload.data, 'Bulk download preparation completed');
-
-      const { bulk_download_item_name } = action.payload.data;
-
-      // TODO(psyche): This URL may break in in some environments (e.g. Nvidia workbench) but we need to test it first
-      const url = `/api/v1/images/download/${bulk_download_item_name}`;
-
-      toast({
-        id: bulk_download_item_name,
-        title: t('gallery.bulkDownloadReady', 'Download ready'),
-        status: 'success',
-        description: (
-          <ExternalLink
-            label={t('gallery.clickToDownload', 'Click here to download')}
-            href={url}
-            download={bulk_download_item_name}
-          />
-        ),
-        duration: null,
-      });
-    },
-  });
-
-  startAppListening({
-    actionCreator: socketBulkDownloadError,
-    effect: async (action) => {
-      log.debug(action.payload.data, 'Bulk download preparation failed');
-
-      const { bulk_download_item_name } = action.payload.data;
-
-      toast({
-        id: bulk_download_item_name,
-        title: t('gallery.bulkDownloadFailed'),
-        status: 'error',
-        description: action.payload.data.error,
-        duration: null,
-      });
-    },
-  });
 };
@@ -1,38 +0,0 @@
-import { $logger } from 'app/logging/logger';
-import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
-import { canvasCopiedToClipboard } from 'features/canvas/store/actions';
-import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
-import { copyBlobToClipboard } from 'features/system/util/copyBlobToClipboard';
-import { toast } from 'features/toast/toast';
-import { t } from 'i18next';
-
-export const addCanvasCopiedToClipboardListener = (startAppListening: AppStartListening) => {
-  startAppListening({
-    actionCreator: canvasCopiedToClipboard,
-    effect: async (action, { getState }) => {
-      const moduleLog = $logger.get().child({ namespace: 'canvasCopiedToClipboardListener' });
-      const state = getState();
-
-      try {
-        const blob = getBaseLayerBlob(state);
-
-        copyBlobToClipboard(blob);
-      } catch (err) {
-        moduleLog.error(String(err));
-        toast({
-          id: 'CANVAS_COPY_FAILED',
-          title: t('toast.problemCopyingCanvas'),
-          description: t('toast.problemCopyingCanvasDesc'),
-          status: 'error',
-        });
-        return;
-      }
-
-      toast({
-        id: 'CANVAS_COPY_SUCCEEDED',
-        title: t('toast.canvasCopiedClipboard'),
-        status: 'success',
-      });
-    },
-  });
-};
@@ -1,34 +0,0 @@
-import { $logger } from 'app/logging/logger';
-import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
-import { canvasDownloadedAsImage } from 'features/canvas/store/actions';
-import { downloadBlob } from 'features/canvas/util/downloadBlob';
-import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
-import { toast } from 'features/toast/toast';
-import { t } from 'i18next';
-
-export const addCanvasDownloadedAsImageListener = (startAppListening: AppStartListening) => {
-  startAppListening({
-    actionCreator: canvasDownloadedAsImage,
-    effect: async (action, { getState }) => {
-      const moduleLog = $logger.get().child({ namespace: 'canvasSavedToGalleryListener' });
-      const state = getState();
-
-      let blob;
-      try {
-        blob = await getBaseLayerBlob(state);
-      } catch (err) {
-        moduleLog.error(String(err));
-        toast({
-          id: 'CANVAS_DOWNLOAD_FAILED',
-          title: t('toast.problemDownloadingCanvas'),
-          description: t('toast.problemDownloadingCanvasDesc'),
-          status: 'error',
-        });
-        return;
-      }
-
-      downloadBlob(blob, 'canvas.png');
-      toast({ id: 'CANVAS_DOWNLOAD_SUCCEEDED', title: t('toast.canvasDownloaded'), status: 'success' });
-    },
-  });
-};
@@ -1,60 +0,0 @@
-import { logger } from 'app/logging/logger';
-import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
-import { canvasImageToControlAdapter } from 'features/canvas/store/actions';
-import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
-import { controlAdapterImageChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
-import { toast } from 'features/toast/toast';
-import { t } from 'i18next';
-import { imagesApi } from 'services/api/endpoints/images';
-
-export const addCanvasImageToControlNetListener = (startAppListening: AppStartListening) => {
-  startAppListening({
-    actionCreator: canvasImageToControlAdapter,
-    effect: async (action, { dispatch, getState }) => {
-      const log = logger('canvas');
-      const state = getState();
-      const { id } = action.payload;
-
-      let blob: Blob;
-      try {
-        blob = await getBaseLayerBlob(state, true);
-      } catch (err) {
-        log.error(String(err));
-        toast({
-          id: 'PROBLEM_SAVING_CANVAS',
-          title: t('toast.problemSavingCanvas'),
-          description: t('toast.problemSavingCanvasDesc'),
-          status: 'error',
-        });
-        return;
-      }
-
-      const { autoAddBoardId } = state.gallery;
-
-      const imageDTO = await dispatch(
-        imagesApi.endpoints.uploadImage.initiate({
-          file: new File([blob], 'savedCanvas.png', {
-            type: 'image/png',
-          }),
-          image_category: 'control',
-          is_intermediate: true,
-          board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
-          crop_visible: false,
-          postUploadAction: {
-            type: 'TOAST',
-            title: t('toast.canvasSentControlnetAssets'),
-          },
-        })
-      ).unwrap();
-
-      const { image_name } = imageDTO;
-
-      dispatch(
-        controlAdapterImageChanged({
-          id,
-          controlImage: image_name,
-        })
-      );
-    },
-  });
-};
@@ -1,60 +0,0 @@
-import { logger } from 'app/logging/logger';
-import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
-import { canvasMaskSavedToGallery } from 'features/canvas/store/actions';
-import { getCanvasData } from 'features/canvas/util/getCanvasData';
-import { toast } from 'features/toast/toast';
-import { t } from 'i18next';
-import { imagesApi } from 'services/api/endpoints/images';
-
-export const addCanvasMaskSavedToGalleryListener = (startAppListening: AppStartListening) => {
-  startAppListening({
-    actionCreator: canvasMaskSavedToGallery,
-    effect: async (action, { dispatch, getState }) => {
-      const log = logger('canvas');
-      const state = getState();
-
-      const canvasBlobsAndImageData = await getCanvasData(
-        state.canvas.layerState,
-        state.canvas.boundingBoxCoordinates,
-        state.canvas.boundingBoxDimensions,
-        state.canvas.isMaskEnabled,
-        state.canvas.shouldPreserveMaskedArea
-      );
-
-      if (!canvasBlobsAndImageData) {
-        return;
-      }
-
-      const { maskBlob } = canvasBlobsAndImageData;
-
-      if (!maskBlob) {
-        log.error('Problem getting mask layer blob');
-        toast({
-          id: 'PROBLEM_SAVING_MASK',
-          title: t('toast.problemSavingMask'),
-          description: t('toast.problemSavingMaskDesc'),
-          status: 'error',
-        });
-        return;
-      }
-
-      const { autoAddBoardId } = state.gallery;
-
-      dispatch(
-        imagesApi.endpoints.uploadImage.initiate({
-          file: new File([maskBlob], 'canvasMaskImage.png', {
-            type: 'image/png',
-          }),
-          image_category: 'mask',
-          is_intermediate: false,
-          board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
-          crop_visible: true,
-          postUploadAction: {
-            type: 'TOAST',
-            title: t('toast.maskSavedAssets'),
-          },
-        })
-      );
-    },
-  });
-};
@@ -1,70 +0,0 @@
-import { logger } from 'app/logging/logger';
-import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
-import { canvasMaskToControlAdapter } from 'features/canvas/store/actions';
-import { getCanvasData } from 'features/canvas/util/getCanvasData';
-import { controlAdapterImageChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
-import { toast } from 'features/toast/toast';
-import { t } from 'i18next';
-import { imagesApi } from 'services/api/endpoints/images';
-
-export const addCanvasMaskToControlNetListener = (startAppListening: AppStartListening) => {
-  startAppListening({
-    actionCreator: canvasMaskToControlAdapter,
-    effect: async (action, { dispatch, getState }) => {
-      const log = logger('canvas');
-      const state = getState();
-      const { id } = action.payload;
-      const canvasBlobsAndImageData = await getCanvasData(
-        state.canvas.layerState,
-        state.canvas.boundingBoxCoordinates,
-        state.canvas.boundingBoxDimensions,
-        state.canvas.isMaskEnabled,
-        state.canvas.shouldPreserveMaskedArea
-      );
-
-      if (!canvasBlobsAndImageData) {
-        return;
-      }
-
-      const { maskBlob } = canvasBlobsAndImageData;
-
-      if (!maskBlob) {
-        log.error('Problem getting mask layer blob');
-        toast({
-          id: 'PROBLEM_IMPORTING_MASK',
-          title: t('toast.problemImportingMask'),
-          description: t('toast.problemImportingMaskDesc'),
-          status: 'error',
-        });
-        return;
-      }
-
-      const { autoAddBoardId } = state.gallery;
-
-      const imageDTO = await dispatch(
-        imagesApi.endpoints.uploadImage.initiate({
-          file: new File([maskBlob], 'canvasMaskImage.png', {
-            type: 'image/png',
-          }),
-          image_category: 'mask',
-          is_intermediate: true,
-          board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
-          crop_visible: false,
-          postUploadAction: {
-            type: 'TOAST',
-            title: t('toast.maskSentControlnetAssets'),
-          },
-        })
-      ).unwrap();
-
-      const { image_name } = imageDTO;
-
-      dispatch(
-        controlAdapterImageChanged({
-          id,
-          controlImage: image_name,
-        })
-      );
-    },
-  });
-};
@@ -1,73 +0,0 @@
-import { $logger } from 'app/logging/logger';
-import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
-import { canvasMerged } from 'features/canvas/store/actions';
-import { $canvasBaseLayer } from 'features/canvas/store/canvasNanostore';
-import { setMergedCanvas } from 'features/canvas/store/canvasSlice';
-import { getFullBaseLayerBlob } from 'features/canvas/util/getFullBaseLayerBlob';
-import { toast } from 'features/toast/toast';
-import { t } from 'i18next';
-import { imagesApi } from 'services/api/endpoints/images';
-
-export const addCanvasMergedListener = (startAppListening: AppStartListening) => {
-  startAppListening({
-    actionCreator: canvasMerged,
-    effect: async (action, { dispatch }) => {
-      const moduleLog = $logger.get().child({ namespace: 'canvasCopiedToClipboardListener' });
-      const blob = await getFullBaseLayerBlob();
-
-      if (!blob) {
-        moduleLog.error('Problem getting base layer blob');
-        toast({
-          id: 'PROBLEM_MERGING_CANVAS',
-          title: t('toast.problemMergingCanvas'),
-          description: t('toast.problemMergingCanvasDesc'),
-          status: 'error',
-        });
-        return;
-      }
-
-      const canvasBaseLayer = $canvasBaseLayer.get();
-
-      if (!canvasBaseLayer) {
-        moduleLog.error('Problem getting canvas base layer');
-        toast({
-          id: 'PROBLEM_MERGING_CANVAS',
-          title: t('toast.problemMergingCanvas'),
-          description: t('toast.problemMergingCanvasDesc'),
-          status: 'error',
-        });
-        return;
-      }
-
-      const baseLayerRect = canvasBaseLayer.getClientRect({
-        relativeTo: canvasBaseLayer.getParent() ?? undefined,
-      });
-
-      const imageDTO = await dispatch(
-        imagesApi.endpoints.uploadImage.initiate({
-          file: new File([blob], 'mergedCanvas.png', {
-            type: 'image/png',
-          }),
-          image_category: 'general',
-          is_intermediate: true,
-          postUploadAction: {
-            type: 'TOAST',
-            title: t('toast.canvasMerged'),
-          },
-        })
-      ).unwrap();
-
-      // TODO: I can't figure out how to do the type narrowing in the `take()` so just brute forcing it here
-      const { image_name } = imageDTO;
-
-      dispatch(
-        setMergedCanvas({
-          kind: 'image',
-          layer: 'base',
-          imageName: image_name,
-          ...baseLayerRect,
-        })
-      );
-    },
-  });
-};
@@ -1,53 +0,0 @@
-import { logger } from 'app/logging/logger';
-import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
-import { parseify } from 'common/util/serialize';
-import { canvasSavedToGallery } from 'features/canvas/store/actions';
-import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
-import { toast } from 'features/toast/toast';
-import { t } from 'i18next';
-import { imagesApi } from 'services/api/endpoints/images';
-
-export const addCanvasSavedToGalleryListener = (startAppListening: AppStartListening) => {
-  startAppListening({
-    actionCreator: canvasSavedToGallery,
-    effect: async (action, { dispatch, getState }) => {
-      const log = logger('canvas');
-      const state = getState();
-
-      let blob;
-      try {
-        blob = await getBaseLayerBlob(state);
-      } catch (err) {
-        log.error(String(err));
-        toast({
-          id: 'CANVAS_SAVE_FAILED',
-          title: t('toast.problemSavingCanvas'),
-          description: t('toast.problemSavingCanvasDesc'),
-          status: 'error',
-        });
-        return;
-      }
-
-      const { autoAddBoardId } = state.gallery;
-
-      dispatch(
-        imagesApi.endpoints.uploadImage.initiate({
-          file: new File([blob], 'savedCanvas.png', {
-            type: 'image/png',
-          }),
-          image_category: 'general',
-          is_intermediate: false,
-          board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
-          crop_visible: true,
-          postUploadAction: {
-            type: 'TOAST',
-            title: t('toast.canvasSavedGallery'),
-          },
-          metadata: {
-            _canvas_objects: parseify(state.canvas.layerState.objects),
-          },
-        })
-      );
-    },
-  });
-};
@@ -1,194 +0,0 @@
-import { isAnyOf } from '@reduxjs/toolkit';
-import { logger } from 'app/logging/logger';
-import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
-import type { AppDispatch } from 'app/store/store';
-import { parseify } from 'common/util/serialize';
-import {
-  caLayerImageChanged,
-  caLayerModelChanged,
-  caLayerProcessedImageChanged,
-  caLayerProcessorConfigChanged,
-  caLayerProcessorPendingBatchIdChanged,
-  caLayerRecalled,
-  isControlAdapterLayer,
-} from 'features/controlLayers/store/controlLayersSlice';
-import { CA_PROCESSOR_DATA } from 'features/controlLayers/util/controlAdapters';
-import { toast } from 'features/toast/toast';
-import { t } from 'i18next';
-import { isEqual } from 'lodash-es';
-import { getImageDTO } from 'services/api/endpoints/images';
-import { queueApi } from 'services/api/endpoints/queue';
-import type { BatchConfig } from 'services/api/types';
-import { socketInvocationComplete } from 'services/events/actions';
-import { assert } from 'tsafe';
-
-const matcher = isAnyOf(
-  caLayerImageChanged,
-  caLayerProcessedImageChanged,
-  caLayerProcessorConfigChanged,
-  caLayerModelChanged,
-  caLayerRecalled
-);
-
-const DEBOUNCE_MS = 300;
-const log = logger('session');
-
-/**
- * Simple helper to cancel a batch and reset the pending batch ID
- */
-const cancelProcessorBatch = async (dispatch: AppDispatch, layerId: string, batchId: string) => {
-  const req = dispatch(queueApi.endpoints.cancelByBatchIds.initiate({ batch_ids: [batchId] }));
-  log.trace({ batchId }, 'Cancelling existing preprocessor batch');
-  try {
-    await req.unwrap();
-  } catch {
-    // no-op
-  } finally {
-    req.reset();
-    // Always reset the pending batch ID - the cancel req could fail if the batch doesn't exist
-    dispatch(caLayerProcessorPendingBatchIdChanged({ layerId, batchId: null }));
-  }
-};
-
-export const addControlAdapterPreprocessor = (startAppListening: AppStartListening) => {
-  startAppListening({
-    matcher,
-    effect: async (action, { dispatch, getState, getOriginalState, cancelActiveListeners, delay, take, signal }) => {
-      const layerId = caLayerRecalled.match(action) ? action.payload.id : action.payload.layerId;
-      const state = getState();
-      const originalState = getOriginalState();
-
-      // Cancel any in-progress instances of this listener
-      cancelActiveListeners();
-      log.trace('Control Layer CA auto-process triggered');
-
-      // Delay before starting actual work
-      await delay(DEBOUNCE_MS);
-
-      const layer = state.controlLayers.present.layers.filter(isControlAdapterLayer).find((l) => l.id === layerId);
-
-      if (!layer) {
-        return;
-      }
-
-      // We should only process if the processor settings or image have changed
-      const originalLayer = originalState.controlLayers.present.layers
-        .filter(isControlAdapterLayer)
-        .find((l) => l.id === layerId);
-      const originalImage = originalLayer?.controlAdapter.image;
-      const originalConfig = originalLayer?.controlAdapter.processorConfig;
-
-      const image = layer.controlAdapter.image;
-      const processedImage = layer.controlAdapter.processedImage;
-      const config = layer.controlAdapter.processorConfig;
-
-      if (isEqual(config, originalConfig) && isEqual(image, originalImage) && processedImage) {
-        // Neither config nor image have changed, we can bail
-        return;
-      }
-
-      if (!image || !config) {
-        // - If we have no image, we have nothing to process
-        // - If we have no processor config, we have nothing to process
-        // Clear the processed image and bail
-        dispatch(caLayerProcessedImageChanged({ layerId, imageDTO: null }));
-        return;
-      }
-
-      // At this point, the user has stopped fiddling with the processor settings and there is a processor selected.
-
-      // If there is a pending processor batch, cancel it.
-      if (layer.controlAdapter.processorPendingBatchId) {
-        cancelProcessorBatch(dispatch, layerId, layer.controlAdapter.processorPendingBatchId);
-      }
-
-      // TODO(psyche): I can't get TS to be happy, it thinkgs `config` is `never` but it should be inferred from the generic... I'll just cast it for now
-      const processorNode = CA_PROCESSOR_DATA[config.type].buildNode(image, config as never);
-      const enqueueBatchArg: BatchConfig = {
-        prepend: true,
-        batch: {
-          graph: {
-            nodes: {
-              [processorNode.id]: {
-                ...processorNode,
-                // Control images are always intermediate - do not save to gallery
-                is_intermediate: true,
-              },
-            },
-            edges: [],
-          },
-          runs: 1,
-        },
-      };
-
-      // Kick off the processor batch
-      const req = dispatch(
-        queueApi.endpoints.enqueueBatch.initiate(enqueueBatchArg, {
-          fixedCacheKey: 'enqueueBatch',
-        })
-      );
-
-      try {
-        const enqueueResult = await req.unwrap();
-        // TODO(psyche): Update the pydantic models, pretty sure we will _always_ have a batch_id here, but the model says it's optional
-        assert(enqueueResult.batch.batch_id, 'Batch ID not returned from queue');
-        dispatch(caLayerProcessorPendingBatchIdChanged({ layerId, batchId: enqueueResult.batch.batch_id }));
-        log.debug({ enqueueResult: parseify(enqueueResult) }, t('queue.graphQueued'));
-
-        // Wait for the processor node to complete
-        const [invocationCompleteAction] = await take(
-          (action): action is ReturnType<typeof socketInvocationComplete> =>
-            socketInvocationComplete.match(action) &&
-            action.payload.data.batch_id === enqueueResult.batch.batch_id &&
-            action.payload.data.invocation_source_id === processorNode.id
-        );
-
-        // We still have to check the output type
-        assert(
-          invocationCompleteAction.payload.data.result.type === 'image_output',
-          `Processor did not return an image output, got: ${invocationCompleteAction.payload.data.result}`
-        );
-        const { image_name } = invocationCompleteAction.payload.data.result.image;
-
-        const imageDTO = await getImageDTO(image_name);
-        assert(imageDTO, "Failed to fetch processor output's image DTO");
-
-        // Whew! We made it. Update the layer with the processed image
-        log.debug({ layerId, imageDTO }, 'ControlNet image processed');
-        dispatch(caLayerProcessedImageChanged({ layerId, imageDTO }));
-        dispatch(caLayerProcessorPendingBatchIdChanged({ layerId, batchId: null }));
-      } catch (error) {
-        if (signal.aborted) {
-          // The listener was canceled - we need to cancel the pending processor batch, if there is one (could have changed by now).
-          const pendingBatchId = getState()
-            .controlLayers.present.layers.filter(isControlAdapterLayer)
-            .find((l) => l.id === layerId)?.controlAdapter.processorPendingBatchId;
-          if (pendingBatchId) {
-            cancelProcessorBatch(dispatch, layerId, pendingBatchId);
-          }
-          log.trace('Control Adapter preprocessor cancelled');
-        } else {
-          // Some other error condition...
-          log.error({ enqueueBatchArg: parseify(enqueueBatchArg) }, t('queue.graphFailedToQueue'));
-
-          if (error instanceof Object) {
-            if ('data' in error && 'status' in error) {
-              if (error.status === 403) {
-                dispatch(caLayerImageChanged({ layerId, imageDTO: null }));
-                return;
-              }
-            }
-          }
-
-          toast({
-            id: 'GRAPH_QUEUE_FAILED',
-            title: t('queue.graphFailedToQueue'),
-            status: 'error',
-          });
-        }
-      } finally {
-        req.reset();
-      }
-    },
-  });
-};
@@ -1,85 +0,0 @@
-import type { AnyListenerPredicate } from '@reduxjs/toolkit';
-import { logger } from 'app/logging/logger';
-import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
-import type { RootState } from 'app/store/store';
-import { controlAdapterImageProcessed } from 'features/controlAdapters/store/actions';
-import {
-  controlAdapterAutoConfigToggled,
-  controlAdapterImageChanged,
-  controlAdapterModelChanged,
-  controlAdapterProcessorParamsChanged,
-  controlAdapterProcessortTypeChanged,
-  selectControlAdapterById,
-} from 'features/controlAdapters/store/controlAdaptersSlice';
-import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
-
-type AnyControlAdapterParamChangeAction =
-  | ReturnType<typeof controlAdapterProcessorParamsChanged>
-  | ReturnType<typeof controlAdapterModelChanged>
-  | ReturnType<typeof controlAdapterImageChanged>
-  | ReturnType<typeof controlAdapterProcessortTypeChanged>
-  | ReturnType<typeof controlAdapterAutoConfigToggled>;
-
-const predicate: AnyListenerPredicate<RootState> = (action, state, prevState) => {
-  const isActionMatched =
-    controlAdapterProcessorParamsChanged.match(action) ||
-    controlAdapterModelChanged.match(action) ||
-    controlAdapterImageChanged.match(action) ||
-    controlAdapterProcessortTypeChanged.match(action) ||
-    controlAdapterAutoConfigToggled.match(action);
-
-  if (!isActionMatched) {
-    return false;
-  }
-
-  const { id } = action.payload;
-  const prevCA = selectControlAdapterById(prevState.controlAdapters, id);
-  const ca = selectControlAdapterById(state.controlAdapters, id);
-  if (!prevCA || !isControlNetOrT2IAdapter(prevCA) || !ca || !isControlNetOrT2IAdapter(ca)) {
-    return false;
-  }
-
-  if (controlAdapterAutoConfigToggled.match(action)) {
-    // do not process if the user just disabled auto-config
-    if (prevCA.shouldAutoConfig === true) {
-      return false;
-    }
-  }
-
-  const { controlImage, processorType, shouldAutoConfig } = ca;
-  if (controlAdapterModelChanged.match(action) && !shouldAutoConfig) {
-    // do not process if the action is a model change but the processor settings are dirty
-    return false;
-  }
-
-  const isProcessorSelected = processorType !== 'none';
-
-  const hasControlImage = Boolean(controlImage);
-
-  return isProcessorSelected && hasControlImage;
-};
-
-const DEBOUNCE_MS = 300;
-
-/**
- * Listener that automatically processes a ControlNet image when its processor parameters are changed.
- *
- * The network request is debounced.
- */
-export const addControlNetAutoProcessListener = (startAppListening: AppStartListening) => {
-  startAppListening({
-    predicate,
-    effect: async (action, { dispatch, cancelActiveListeners, delay }) => {
-      const log = logger('session');
-      const { id } = (action as AnyControlAdapterParamChangeAction).payload;
-
-      // Cancel any in-progress instances of this listener
-      cancelActiveListeners();
-      log.trace('ControlNet auto-process triggered');
-      // Delay before starting actual work
-      await delay(DEBOUNCE_MS);
-
-      dispatch(controlAdapterImageProcessed({ id }));
-    },
-  });
-};
@@ -1,118 +0,0 @@
-import { logger } from 'app/logging/logger';
-import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
-import { parseify } from 'common/util/serialize';
-import { controlAdapterImageProcessed } from 'features/controlAdapters/store/actions';
-import {
-  controlAdapterImageChanged,
-  controlAdapterProcessedImageChanged,
-  pendingControlImagesCleared,
-  selectControlAdapterById,
-} from 'features/controlAdapters/store/controlAdaptersSlice';
-import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
-import { toast } from 'features/toast/toast';
-import { t } from 'i18next';
-import { imagesApi } from 'services/api/endpoints/images';
-import { queueApi } from 'services/api/endpoints/queue';
-import type { BatchConfig, ImageDTO } from 'services/api/types';
-import { socketInvocationComplete } from 'services/events/actions';
-
-export const addControlNetImageProcessedListener = (startAppListening: AppStartListening) => {
-  startAppListening({
-    actionCreator: controlAdapterImageProcessed,
-    effect: async (action, { dispatch, getState, take }) => {
-      const log = logger('session');
-      const { id } = action.payload;
-      const ca = selectControlAdapterById(getState().controlAdapters, id);
-
-      if (!ca?.controlImage || !isControlNetOrT2IAdapter(ca)) {
-        log.error('Unable to process ControlNet image');
-        return;
-      }
-
-      if (ca.processorType === 'none' || ca.processorNode.type === 'none') {
-        return;
-      }
-
-      // ControlNet one-off procressing graph is just the processor node, no edges.
-      // Also we need to grab the image.
-
-      const nodeId = ca.processorNode.id;
-      const enqueueBatchArg: BatchConfig = {
-        prepend: true,
-        batch: {
-          graph: {
-            nodes: {
-              [ca.processorNode.id]: {
-                ...ca.processorNode,
-                is_intermediate: true,
-                use_cache: false,
-                image: { image_name: ca.controlImage },
-              },
-            },
-            edges: [],
-          },
-          runs: 1,
-        },
-      };
-
-      try {
-        const req = dispatch(
-          queueApi.endpoints.enqueueBatch.initiate(enqueueBatchArg, {
-            fixedCacheKey: 'enqueueBatch',
-          })
-        );
-        const enqueueResult = await req.unwrap();
-        req.reset();
-        log.debug({ enqueueResult: parseify(enqueueResult) }, t('queue.graphQueued'));
-
-        const [invocationCompleteAction] = await take(
-          (action): action is ReturnType<typeof socketInvocationComplete> =>
-            socketInvocationComplete.match(action) &&
-            action.payload.data.batch_id === enqueueResult.batch.batch_id &&
-            action.payload.data.invocation_source_id === nodeId
-        );
-
-        // We still have to check the output type
-        if (invocationCompleteAction.payload.data.result.type === 'image_output') {
-          const { image_name } = invocationCompleteAction.payload.data.result.image;
-
-          // Wait for the ImageDTO to be received
-          const [{ payload }] = await take(
-            (action) =>
-              imagesApi.endpoints.getImageDTO.matchFulfilled(action) && action.payload.image_name === image_name
-          );
-
-          const processedControlImage = payload as ImageDTO;
-
-          log.debug({ controlNetId: action.payload, processedControlImage }, 'ControlNet image processed');
-
-          // Update the processed image in the store
-          dispatch(
-            controlAdapterProcessedImageChanged({
-              id,
-              processedControlImage: processedControlImage.image_name,
-            })
-          );
-        }
-      } catch (error) {
-        log.error({ enqueueBatchArg: parseify(enqueueBatchArg) }, t('queue.graphFailedToQueue'));
-
-        if (error instanceof Object) {
-          if ('data' in error && 'status' in error) {
-            if (error.status === 403) {
-              dispatch(pendingControlImagesCleared());
-              dispatch(controlAdapterImageChanged({ id, controlImage: null }));
-              return;
-            }
-          }
-        }
-
-        toast({
-          id: 'GRAPH_QUEUE_FAILED',
-          title: t('queue.graphFailedToQueue'),
-          status: 'error',
-        });
-      }
-    },
-  });
-};
@@ -1,144 +0,0 @@
-import { logger } from 'app/logging/logger';
-import { enqueueRequested } from 'app/store/actions';
-import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
-import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
-import { parseify } from 'common/util/serialize';
-import { canvasBatchIdAdded, stagingAreaInitialized } from 'features/canvas/store/canvasSlice';
-import { blobToDataURL } from 'features/canvas/util/blobToDataURL';
-import { getCanvasData } from 'features/canvas/util/getCanvasData';
-import { getCanvasGenerationMode } from 'features/canvas/util/getCanvasGenerationMode';
-import { canvasGraphBuilt } from 'features/nodes/store/actions';
-import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
-import { buildCanvasGraph } from 'features/nodes/util/graph/canvas/buildCanvasGraph';
-import { imagesApi } from 'services/api/endpoints/images';
-import { queueApi } from 'services/api/endpoints/queue';
-import type { ImageDTO } from 'services/api/types';
-
-/**
- * This listener is responsible invoking the canvas. This involves a number of steps:
- *
- * 1. Generate image blobs from the canvas layers
- * 2. Determine the generation mode from the layers (txt2img, img2img, inpaint)
- * 3. Build the canvas graph
- * 4. Create the session with the graph
- * 5. Upload the init image if necessary
- * 6. Upload the mask image if necessary
- * 7. Update the init and mask images with the session ID
- * 8. Initialize the staging area if not yet initialized
- * 9. Dispatch the sessionReadyToInvoke action to invoke the session
- */
-export const addEnqueueRequestedCanvasListener = (startAppListening: AppStartListening) => {
-  startAppListening({
-    predicate: (action): action is ReturnType<typeof enqueueRequested> =>
-      enqueueRequested.match(action) && action.payload.tabName === 'canvas',
-    effect: async (action, { getState, dispatch }) => {
-      const log = logger('queue');
-      const { prepend } = action.payload;
-      const state = getState();
-
-      const { layerState, boundingBoxCoordinates, boundingBoxDimensions, isMaskEnabled, shouldPreserveMaskedArea } =
-        state.canvas;
-
-      // Build canvas blobs
-      const canvasBlobsAndImageData = await getCanvasData(
-        layerState,
-        boundingBoxCoordinates,
-        boundingBoxDimensions,
-        isMaskEnabled,
-        shouldPreserveMaskedArea
-      );
-
-      if (!canvasBlobsAndImageData) {
-        log.error('Unable to create canvas data');
-        return;
-      }
-
-      const { baseBlob, baseImageData, maskBlob, maskImageData } = canvasBlobsAndImageData;
-
-      // Determine the generation mode
-      const generationMode = getCanvasGenerationMode(baseImageData, maskImageData);
-
-      if (state.system.enableImageDebugging) {
-        const baseDataURL = await blobToDataURL(baseBlob);
-        const maskDataURL = await blobToDataURL(maskBlob);
-        openBase64ImageInTab([
-          { base64: maskDataURL, caption: 'mask b64' },
-          { base64: baseDataURL, caption: 'image b64' },
-        ]);
-      }
-
-      log.debug(`Generation mode: ${generationMode}`);
-
-      // Temp placeholders for the init and mask images
-      let canvasInitImage: ImageDTO | undefined;
-      let canvasMaskImage: ImageDTO | undefined;
-
-      // For img2img and inpaint/outpaint, we need to upload the init images
-      if (['img2img', 'inpaint', 'outpaint'].includes(generationMode)) {
-        // upload the image, saving the request id
-        canvasInitImage = await dispatch(
-          imagesApi.endpoints.uploadImage.initiate({
-            file: new File([baseBlob], 'canvasInitImage.png', {
-              type: 'image/png',
-            }),
-            image_category: 'general',
-            is_intermediate: true,
-          })
-        ).unwrap();
-      }
-
-      // For inpaint/outpaint, we also need to upload the mask layer
-      if (['inpaint', 'outpaint'].includes(generationMode)) {
-        // upload the image, saving the request id
-        canvasMaskImage = await dispatch(
-          imagesApi.endpoints.uploadImage.initiate({
-            file: new File([maskBlob], 'canvasMaskImage.png', {
-              type: 'image/png',
-            }),
-            image_category: 'mask',
-            is_intermediate: true,
-          })
-        ).unwrap();
-      }
-
-      const graph = await buildCanvasGraph(state, generationMode, canvasInitImage, canvasMaskImage);
-
-      log.debug({ graph: parseify(graph) }, `Canvas graph built`);
-
-      // currently this action is just listened to for logging
-      dispatch(canvasGraphBuilt(graph));
-
-      const batchConfig = prepareLinearUIBatch(state, graph, prepend);
-
-      try {
-        const req = dispatch(
-          queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
-            fixedCacheKey: 'enqueueBatch',
-          })
-        );
-
-        const enqueueResult = await req.unwrap();
-        req.reset();
-
-        const batchId = enqueueResult.batch.batch_id as string; // we know the is a string, backend provides it
-
-        // Prep the canvas staging area if it is not yet initialized
-        if (!state.canvas.layerState.stagingArea.boundingBox) {
-          dispatch(
-            stagingAreaInitialized({
-              boundingBox: {
-                ...state.canvas.boundingBoxCoordinates,
-                ...state.canvas.boundingBoxDimensions,
-              },
-            })
-          );
-        }
-
-        // Associate the session with the canvas session ID
-        dispatch(canvasBatchIdAdded(batchId));
-      } catch {
-        // no-op
-      }
-    },
-  });
-};
@ -1,10 +1,21 @@
|
|||||||
|
import { logger } from 'app/logging/logger';
|
||||||
import { enqueueRequested } from 'app/store/actions';
|
import { enqueueRequested } from 'app/store/actions';
|
||||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||||
import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
|
import type { SerializableObject } from 'common/types';
|
||||||
|
import type { Result } from 'common/util/result';
|
||||||
|
import { isErr, withResult, withResultAsync } from 'common/util/result';
|
||||||
|
import { $canvasManager } from 'features/controlLayers/konva/CanvasManager';
|
||||||
|
import { sessionStagingAreaReset, sessionStartedStaging } from 'features/controlLayers/store/canvasSessionSlice';
|
||||||
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
|
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
|
||||||
-import { buildGenerationTabGraph } from 'features/nodes/util/graph/generation/buildGenerationTabGraph';
-import { buildGenerationTabSDXLGraph } from 'features/nodes/util/graph/generation/buildGenerationTabSDXLGraph';
+import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
+import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph';
+import type { Graph } from 'features/nodes/util/graph/generation/Graph';
+import { serializeError } from 'serialize-error';
 import { queueApi } from 'services/api/endpoints/queue';
+import type { Invocation } from 'services/api/types';
+import { assert } from 'tsafe';
 
+const log = logger('generation');
 
 export const addEnqueueRequestedLinear = (startAppListening: AppStartListening) => {
   startAppListening({
@@ -12,33 +23,77 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
       enqueueRequested.match(action) && action.payload.tabName === 'generation',
     effect: async (action, { getState, dispatch }) => {
       const state = getState();
-      const { shouldShowProgressInViewer } = state.ui;
-      const model = state.generation.model;
+      const model = state.params.model;
       const { prepend } = action.payload;
 
-      let graph;
+      const manager = $canvasManager.get();
+      assert(manager, 'No model found in state');
 
-      if (model?.base === 'sdxl') {
-        graph = await buildGenerationTabSDXLGraph(state);
-      } else {
-        graph = await buildGenerationTabGraph(state);
+      let didStartStaging = false;
+
+      if (!state.canvasSession.isStaging && state.canvasSession.mode === 'compose') {
+        dispatch(sessionStartedStaging());
+        didStartStaging = true;
       }
 
-      const batchConfig = prepareLinearUIBatch(state, graph, prepend);
+      const abortStaging = () => {
+        if (didStartStaging && getState().canvasSession.isStaging) {
+          dispatch(sessionStagingAreaReset());
+        }
+      };
+
+      let buildGraphResult: Result<
+        { g: Graph; noise: Invocation<'noise'>; posCond: Invocation<'compel' | 'sdxl_compel_prompt'> },
+        Error
+      >;
+
+      assert(model, 'No model found in state');
+      const base = model.base;
+
+      switch (base) {
+        case 'sdxl':
+          buildGraphResult = await withResultAsync(() => buildSDXLGraph(state, manager));
+          break;
+        case 'sd-1':
+        case `sd-2`:
+          buildGraphResult = await withResultAsync(() => buildSD1Graph(state, manager));
+          break;
+        default:
+          assert(false, `No graph builders for base ${base}`);
+      }
+
+      if (isErr(buildGraphResult)) {
+        log.error({ error: serializeError(buildGraphResult.error) }, 'Failed to build graph');
+        abortStaging();
+        return;
+      }
+
+      const { g, noise, posCond } = buildGraphResult.value;
+
+      const prepareBatchResult = withResult(() => prepareLinearUIBatch(state, g, prepend, noise, posCond));
+
+      if (isErr(prepareBatchResult)) {
+        log.error({ error: serializeError(prepareBatchResult.error) }, 'Failed to prepare batch');
+        abortStaging();
+        return;
+      }
 
       const req = dispatch(
-        queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
+        queueApi.endpoints.enqueueBatch.initiate(prepareBatchResult.value, {
           fixedCacheKey: 'enqueueBatch',
         })
       );
-      try {
-        await req.unwrap();
-        if (shouldShowProgressInViewer) {
-          dispatch(isImageViewerOpenChanged(true));
-        }
-      } finally {
       req.reset();
+
+      const enqueueResult = await withResultAsync(() => req.unwrap());
+
+      if (isErr(enqueueResult)) {
+        log.error({ error: serializeError(enqueueResult.error) }, 'Failed to enqueue batch');
+        abortStaging();
+        return;
       }
+
+      log.debug({ batchConfig: prepareBatchResult.value } as SerializableObject, 'Enqueued batch');
     },
   });
 };

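The rewritten enqueue flow above swaps try/catch for a small Result helper (`withResult`, `withResultAsync`, `isErr`): graph building, batch preparation, and the enqueue request each fail as values, and `abortStaging` rolls the canvas session back on any of them. As a rough sketch of how such a helper can be shaped — an illustration under assumed names, not the project's actual Result module — it could look like this:

```ts
// Illustrative sketch only: the real Result utility's shape is not shown in this diff.
type Ok<T> = { ok: true; value: T };
type Err<E> = { ok: false; error: E };
type Result<T, E = Error> = Ok<T> | Err<E>;

const isErr = <T, E>(r: Result<T, E>): r is Err<E> => !r.ok;

// Run a synchronous function, capturing a thrown error as a value.
const withResult = <T>(fn: () => T): Result<T> => {
  try {
    return { ok: true, value: fn() };
  } catch (error) {
    return { ok: false, error: error instanceof Error ? error : new Error(String(error)) };
  }
};

// Same idea for async work, e.g. building a graph or awaiting req.unwrap().
const withResultAsync = async <T>(fn: () => Promise<T>): Promise<Result<T>> => {
  try {
    return { ok: true, value: await fn() };
  } catch (error) {
    return { ok: false, error: error instanceof Error ? error : new Error(String(error)) };
  }
};
```

With a shape like that, each failure branch in the listener can log, roll back staging, and return early instead of letting an exception escape the effect.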
@@ -1,5 +1,6 @@
 import { enqueueRequested } from 'app/store/actions';
 import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
+import { selectNodesSlice } from 'features/nodes/store/selectors';
 import { buildNodesGraph } from 'features/nodes/util/graph/buildNodesGraph';
 import { buildWorkflowWithValidation } from 'features/nodes/util/workflow/buildWorkflow';
 import { queueApi } from 'services/api/endpoints/queue';
@@ -11,12 +12,12 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
       enqueueRequested.match(action) && action.payload.tabName === 'workflows',
     effect: async (action, { getState, dispatch }) => {
       const state = getState();
-      const { nodes, edges } = state.nodes.present;
+      const nodes = selectNodesSlice(state);
       const workflow = state.workflow;
-      const graph = buildNodesGraph(state.nodes.present);
+      const graph = buildNodesGraph(nodes);
       const builtWorkflow = buildWorkflowWithValidation({
-        nodes,
-        edges,
+        nodes: nodes.nodes,
+        edges: nodes.edges,
         workflow,
       });
 
@@ -29,7 +30,8 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
         batch: {
           graph,
           workflow: builtWorkflow,
-          runs: state.generation.iterations,
+          runs: state.params.iterations,
+          origin: 'workflows',
         },
         prepend: action.payload.prepend,
       };

@@ -14,9 +14,9 @@ export const addEnqueueRequestedUpscale = (startAppListening: AppStartListening)
       const { shouldShowProgressInViewer } = state.ui;
       const { prepend } = action.payload;
 
-      const graph = await buildMultidiffusionUpscaleGraph(state);
+      const { g, noise, posCond } = await buildMultidiffusionUpscaleGraph(state);
 
-      const batchConfig = prepareLinearUIBatch(state, graph, prepend);
+      const batchConfig = prepareLinearUIBatch(state, g, prepend, noise, posCond);
 
       const req = dispatch(
         queueApi.endpoints.enqueueBatch.initiate(batchConfig, {

@@ -27,7 +27,7 @@ export const galleryImageClicked = createAction<{
 export const addGalleryImageClickedListener = (startAppListening: AppStartListening) => {
   startAppListening({
     actionCreator: galleryImageClicked,
-    effect: async (action, { dispatch, getState }) => {
+    effect: (action, { dispatch, getState }) => {
       const { imageDTO, shiftKey, ctrlKey, metaKey, altKey } = action.payload;
       const state = getState();
       const queryArgs = selectListImagesQueryArgs(state);

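Every file in this compare range follows the same Redux Toolkit listener-middleware shape: an `addXListener(startAppListening)` function registers either a `matcher`/`predicate` or an `actionCreator` together with an `effect`. A self-contained sketch of that wiring with generic names — the app's typed `AppStartListening` alias, slices, and actions are assumptions left out here:

```ts
import { configureStore, createAction, createListenerMiddleware, createSlice } from '@reduxjs/toolkit';

// Minimal stand-ins for the app's state and actions.
const counterSlice = createSlice({
  name: 'counter',
  initialState: { value: 0 },
  reducers: {
    incremented: (state) => {
      state.value += 1;
    },
  },
});
const pinged = createAction<string>('demo/pinged');

const listenerMiddleware = createListenerMiddleware();

// Same shape as the app's startAppListening({ actionCreator | matcher, effect }) calls.
listenerMiddleware.startListening({
  actionCreator: pinged,
  effect: async (action, { dispatch }) => {
    // Side effects go here: dispatch follow-up actions, call APIs, log, etc.
    dispatch(counterSlice.actions.incremented());
    console.log('pinged with', action.payload);
  },
});

const store = configureStore({
  reducer: { counter: counterSlice.reducer },
  middleware: (getDefaultMiddleware) => getDefaultMiddleware().prepend(listenerMiddleware.middleware),
});

store.dispatch(pinged('hello'));
```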
@@ -1,24 +1,27 @@
 import { logger } from 'app/logging/logger';
 import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
+import type { SerializableObject } from 'common/types';
 import { parseify } from 'common/util/serialize';
 import { $templates } from 'features/nodes/store/nodesSlice';
 import { parseSchema } from 'features/nodes/util/schema/parseSchema';
 import { size } from 'lodash-es';
+import { serializeError } from 'serialize-error';
 import { appInfoApi } from 'services/api/endpoints/appInfo';
 
+const log = logger('system');
 
 export const addGetOpenAPISchemaListener = (startAppListening: AppStartListening) => {
   startAppListening({
     matcher: appInfoApi.endpoints.getOpenAPISchema.matchFulfilled,
     effect: (action, { getState }) => {
-      const log = logger('system');
       const schemaJSON = action.payload;
 
-      log.debug({ schemaJSON: parseify(schemaJSON) }, 'Received OpenAPI schema');
+      log.debug({ schemaJSON: parseify(schemaJSON) } as SerializableObject, 'Received OpenAPI schema');
       const { nodesAllowlist, nodesDenylist } = getState().config;
 
       const nodeTemplates = parseSchema(schemaJSON, nodesAllowlist, nodesDenylist);
 
-      log.debug({ nodeTemplates: parseify(nodeTemplates) }, `Built ${size(nodeTemplates)} node templates`);
+      log.debug({ nodeTemplates } as SerializableObject, `Built ${size(nodeTemplates)} node templates`);
 
       $templates.set(nodeTemplates);
     },
@@ -30,8 +33,7 @@ export const addGetOpenAPISchemaListener = (startAppListening: AppStartListening
       // If action.meta.condition === true, the request was canceled/skipped because another request was in flight or
       // the value was already in the cache. We don't want to log these errors.
       if (!action.meta.condition) {
-        const log = logger('system');
-        log.error({ error: parseify(action.error) }, 'Problem retrieving OpenAPI Schema');
+        log.error({ error: serializeError(action.error) }, 'Problem retrieving OpenAPI Schema');
       }
     },
   });

@@ -2,15 +2,13 @@ import { logger } from 'app/logging/logger';
 import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
 import { imagesApi } from 'services/api/endpoints/images';
 
+const log = logger('gallery');
 
 export const addImageAddedToBoardFulfilledListener = (startAppListening: AppStartListening) => {
   startAppListening({
     matcher: imagesApi.endpoints.addImageToBoard.matchFulfilled,
     effect: (action) => {
-      const log = logger('images');
       const { board_id, imageDTO } = action.meta.arg.originalArgs;
 
-      // TODO: update listImages cache for this board
-
       log.debug({ board_id, imageDTO }, 'Image added to board');
     },
   });
@@ -18,9 +16,7 @@ export const addImageAddedToBoardFulfilledListener = (startAppListening: AppStar
   startAppListening({
     matcher: imagesApi.endpoints.addImageToBoard.matchRejected,
     effect: (action) => {
-      const log = logger('images');
       const { board_id, imageDTO } = action.meta.arg.originalArgs;
 
       log.debug({ board_id, imageDTO }, 'Problem adding image to board');
     },
   });

@@ -1,20 +1,9 @@
 import { logger } from 'app/logging/logger';
 import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
 import type { AppDispatch, RootState } from 'app/store/store';
-import { resetCanvas } from 'features/canvas/store/canvasSlice';
-import {
-  controlAdapterImageChanged,
-  controlAdapterProcessedImageChanged,
-  selectControlAdapterAll,
-} from 'features/controlAdapters/store/controlAdaptersSlice';
-import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
-import {
-  isControlAdapterLayer,
-  isInitialImageLayer,
-  isIPAdapterLayer,
-  isRegionalGuidanceLayer,
-  layerDeleted,
-} from 'features/controlLayers/store/controlLayersSlice';
+import { entityDeleted, ipaImageChanged } from 'features/controlLayers/store/canvasSlice';
+import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
+import { getEntityIdentifier } from 'features/controlLayers/store/types';
 import { imageDeletionConfirmed } from 'features/deleteImageModal/store/actions';
 import { isModalOpenChanged } from 'features/deleteImageModal/store/slice';
 import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
@@ -26,6 +15,10 @@ import { forEach, intersectionBy } from 'lodash-es';
 import { imagesApi } from 'services/api/endpoints/images';
 import type { ImageDTO } from 'services/api/types';
 
+const log = logger('gallery');
+
+//TODO(psyche): handle image deletion (canvas sessions?)
+
 // Some utils to delete images from different parts of the app
 const deleteNodesImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
   state.nodes.present.nodes.forEach((node) => {
@@ -47,52 +40,37 @@ const deleteNodesImages = (state: RootState, dispatch: AppDispatch, imageDTO: Im
   });
 };
 
-const deleteControlAdapterImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
-  forEach(selectControlAdapterAll(state.controlAdapters), (ca) => {
-    if (
-      ca.controlImage === imageDTO.image_name ||
-      (isControlNetOrT2IAdapter(ca) && ca.processedControlImage === imageDTO.image_name)
-    ) {
-      dispatch(
-        controlAdapterImageChanged({
-          id: ca.id,
-          controlImage: null,
-        })
-      );
-      dispatch(
-        controlAdapterProcessedImageChanged({
-          id: ca.id,
-          processedControlImage: null,
-        })
-      );
+// const deleteControlAdapterImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
+//   state.canvas.present.controlAdapters.entities.forEach(({ id, imageObject, processedImageObject }) => {
+//     if (
+//       imageObject?.image.image_name === imageDTO.image_name ||
+//       processedImageObject?.image.image_name === imageDTO.image_name
+//     ) {
+//       dispatch(caImageChanged({ id, imageDTO: null }));
+//       dispatch(caProcessedImageChanged({ id, imageDTO: null }));
+//     }
+//   });
+// };
+const deleteIPAdapterImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
+  selectCanvasSlice(state).ipAdapters.entities.forEach((entity) => {
+    if (entity.ipAdapter.image?.image_name === imageDTO.image_name) {
+      dispatch(ipaImageChanged({ entityIdentifier: getEntityIdentifier(entity), imageDTO: null }));
     }
   });
 };
 
-const deleteControlLayerImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
-  state.controlLayers.present.layers.forEach((l) => {
-    if (isRegionalGuidanceLayer(l)) {
-      if (l.ipAdapters.some((ipa) => ipa.image?.name === imageDTO.image_name)) {
-        dispatch(layerDeleted(l.id));
-      }
-    }
-    if (isControlAdapterLayer(l)) {
-      if (
-        l.controlAdapter.image?.name === imageDTO.image_name ||
-        l.controlAdapter.processedImage?.name === imageDTO.image_name
-      ) {
-        dispatch(layerDeleted(l.id));
-      }
-    }
-    if (isIPAdapterLayer(l)) {
-      if (l.ipAdapter.image?.name === imageDTO.image_name) {
-        dispatch(layerDeleted(l.id));
-      }
-    }
-    if (isInitialImageLayer(l)) {
-      if (l.image?.name === imageDTO.image_name) {
-        dispatch(layerDeleted(l.id));
-      }
+const deleteLayerImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
+  selectCanvasSlice(state).rasterLayers.entities.forEach(({ id, objects }) => {
+    let shouldDelete = false;
+    for (const obj of objects) {
+      if (obj.type === 'image' && obj.image.image_name === imageDTO.image_name) {
+        shouldDelete = true;
+        break;
+      }
+    }
+    if (shouldDelete) {
+      dispatch(entityDeleted({ entityIdentifier: { id, type: 'raster_layer' } }));
     }
   });
 };
@@ -145,14 +123,10 @@ export const addImageDeletionListeners = (startAppListening: AppStartListening)
           }
         }
 
-        // We need to reset the features where the image is in use - none of these work if their image(s) don't exist
-        if (imageUsage.isCanvasImage) {
-          dispatch(resetCanvas());
-        }
 
-        deleteControlAdapterImages(state, dispatch, imageDTO);
         deleteNodesImages(state, dispatch, imageDTO);
-        deleteControlLayerImages(state, dispatch, imageDTO);
+        // deleteControlAdapterImages(state, dispatch, imageDTO);
+        deleteIPAdapterImages(state, dispatch, imageDTO);
+        deleteLayerImages(state, dispatch, imageDTO);
       } catch {
         // no-op
       } finally {
@@ -189,14 +163,11 @@ export const addImageDeletionListeners = (startAppListening: AppStartListening)
 
         // We need to reset the features where the image is in use - none of these work if their image(s) don't exist
 
-        if (imagesUsage.some((i) => i.isCanvasImage)) {
-          dispatch(resetCanvas());
-        }
 
         imageDTOs.forEach((imageDTO) => {
-          deleteControlAdapterImages(state, dispatch, imageDTO);
           deleteNodesImages(state, dispatch, imageDTO);
-          deleteControlLayerImages(state, dispatch, imageDTO);
+          // deleteControlAdapterImages(state, dispatch, imageDTO);
+          deleteIPAdapterImages(state, dispatch, imageDTO);
+          deleteLayerImages(state, dispatch, imageDTO);
         });
       } catch {
         // no-op
@@ -220,7 +191,6 @@ export const addImageDeletionListeners = (startAppListening: AppStartListening)
   startAppListening({
     matcher: imagesApi.endpoints.deleteImage.matchFulfilled,
     effect: (action) => {
-      const log = logger('images');
       log.debug({ imageDTO: action.meta.arg.originalArgs }, 'Image deleted');
     },
   });
@@ -228,7 +198,6 @@ export const addImageDeletionListeners = (startAppListening: AppStartListening)
   startAppListening({
     matcher: imagesApi.endpoints.deleteImage.matchRejected,
     effect: (action) => {
-      const log = logger('images');
      log.debug({ imageDTO: action.meta.arg.originalArgs }, 'Unable to delete image');
     },
   });

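The deletion hunks above drop the whole-canvas `resetCanvas()` fallback in favour of targeted helpers: each one walks a single slice of state and dispatches a narrow update only for entities that reference the deleted image. The general shape of such a helper, sketched with hypothetical entity and action names rather than the app's real ones:

```ts
// Hedged sketch of the targeted-cleanup shape used above; the entity type and
// clearImage action are illustrative stand-ins, not the app's real names.
import { createAction } from '@reduxjs/toolkit';

type ImageRef = { image_name: string };
type Entity = { id: string; image?: ImageRef | null };

const clearImage = createAction<{ id: string }>('entities/clearImage');

type Dispatch = (action: ReturnType<typeof clearImage>) => unknown;

// Walk the entities and clear only those that point at the deleted image,
// instead of resetting the whole feature.
const deleteEntityImages = (entities: Entity[], dispatch: Dispatch, deleted: ImageRef) => {
  for (const entity of entities) {
    if (entity.image?.image_name === deleted.image_name) {
      dispatch(clearImage({ id: entity.id }));
    }
  }
};
```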
@@ -1,28 +1,19 @@
 import { createAction } from '@reduxjs/toolkit';
 import { logger } from 'app/logging/logger';
 import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
-import { parseify } from 'common/util/serialize';
-import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
 import {
-  controlAdapterImageChanged,
-  controlAdapterIsEnabledChanged,
-} from 'features/controlAdapters/store/controlAdaptersSlice';
-import {
-  caLayerImageChanged,
-  iiLayerImageChanged,
-  ipaLayerImageChanged,
-  rgLayerIPAdapterImageChanged,
-} from 'features/controlLayers/store/controlLayersSlice';
+  controlLayerAdded,
+  ipaImageChanged,
+  rasterLayerAdded,
+  rgIPAdapterImageChanged,
+} from 'features/controlLayers/store/canvasSlice';
+import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
+import type { CanvasControlLayerState, CanvasRasterLayerState } from 'features/controlLayers/store/types';
+import { imageDTOToImageObject } from 'features/controlLayers/store/types';
 import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
 import { isValidDrop } from 'features/dnd/util/isValidDrop';
-import {
-  imageSelected,
-  imageToCompareChanged,
-  isImageViewerOpenChanged,
-  selectionChanged,
-} from 'features/gallery/store/gallerySlice';
+import { imageToCompareChanged, isImageViewerOpenChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
 import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
-import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
 import { upscaleInitialImageChanged } from 'features/parameters/store/upscaleSlice';
 import { imagesApi } from 'services/api/endpoints/images';
 
@@ -31,11 +22,12 @@ export const dndDropped = createAction<{
   activeData: TypesafeDraggableData;
 }>('dnd/dndDropped');
 
+const log = logger('system');
 
 export const addImageDroppedListener = (startAppListening: AppStartListening) => {
   startAppListening({
     actionCreator: dndDropped,
-    effect: async (action, { dispatch, getState }) => {
-      const log = logger('dnd');
+    effect: (action, { dispatch, getState }) => {
       const { activeData, overData } = action.payload;
       if (!isValidDrop(overData, activeData)) {
         return;
@@ -46,80 +38,22 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
       } else if (activeData.payloadType === 'GALLERY_SELECTION') {
         log.debug({ activeData, overData }, `Images (${getState().gallery.selection.length}) dropped`);
       } else if (activeData.payloadType === 'NODE_FIELD') {
-        log.debug({ activeData: parseify(activeData), overData: parseify(overData) }, 'Node field dropped');
+        log.debug({ activeData, overData }, 'Node field dropped');
       } else {
         log.debug({ activeData, overData }, `Unknown payload dropped`);
       }
 
-      /**
-       * Image dropped on current image
-       */
-      if (
-        overData.actionType === 'SET_CURRENT_IMAGE' &&
-        activeData.payloadType === 'IMAGE_DTO' &&
-        activeData.payload.imageDTO
-      ) {
-        dispatch(imageSelected(activeData.payload.imageDTO));
-        dispatch(isImageViewerOpenChanged(true));
-        return;
-      }
-
-      /**
-       * Image dropped on ControlNet
-       */
-      if (
-        overData.actionType === 'SET_CONTROL_ADAPTER_IMAGE' &&
-        activeData.payloadType === 'IMAGE_DTO' &&
-        activeData.payload.imageDTO
-      ) {
-        const { id } = overData.context;
-        dispatch(
-          controlAdapterImageChanged({
-            id,
-            controlImage: activeData.payload.imageDTO.image_name,
-          })
-        );
-        dispatch(
-          controlAdapterIsEnabledChanged({
-            id,
-            isEnabled: true,
-          })
-        );
-        return;
-      }
-
-      /**
-       * Image dropped on Control Adapter Layer
-       */
-      if (
-        overData.actionType === 'SET_CA_LAYER_IMAGE' &&
-        activeData.payloadType === 'IMAGE_DTO' &&
-        activeData.payload.imageDTO
-      ) {
-        const { layerId } = overData.context;
-        dispatch(
-          caLayerImageChanged({
-            layerId,
-            imageDTO: activeData.payload.imageDTO,
-          })
-        );
-        return;
-      }
-
       /**
        * Image dropped on IP Adapter Layer
        */
       if (
-        overData.actionType === 'SET_IPA_LAYER_IMAGE' &&
+        overData.actionType === 'SET_IPA_IMAGE' &&
         activeData.payloadType === 'IMAGE_DTO' &&
         activeData.payload.imageDTO
       ) {
-        const { layerId } = overData.context;
+        const { id } = overData.context;
         dispatch(
-          ipaLayerImageChanged({
-            layerId,
-            imageDTO: activeData.payload.imageDTO,
-          })
+          ipaImageChanged({ entityIdentifier: { id, type: 'ip_adapter' }, imageDTO: activeData.payload.imageDTO })
         );
         return;
       }
@@ -128,14 +62,14 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
        * Image dropped on RG Layer IP Adapter
        */
       if (
-        overData.actionType === 'SET_RG_LAYER_IP_ADAPTER_IMAGE' &&
+        overData.actionType === 'SET_RG_IP_ADAPTER_IMAGE' &&
        activeData.payloadType === 'IMAGE_DTO' &&
         activeData.payload.imageDTO
       ) {
-        const { layerId, ipAdapterId } = overData.context;
+        const { id, ipAdapterId } = overData.context;
         dispatch(
-          rgLayerIPAdapterImageChanged({
-            layerId,
+          rgIPAdapterImageChanged({
+            entityIdentifier: { id, type: 'regional_guidance' },
             ipAdapterId,
             imageDTO: activeData.payload.imageDTO,
           })
@@ -144,32 +78,38 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
       }
 
       /**
-       * Image dropped on II Layer Image
+       * Image dropped on Raster layer
       */
       if (
-        overData.actionType === 'SET_II_LAYER_IMAGE' &&
+        overData.actionType === 'ADD_RASTER_LAYER_FROM_IMAGE' &&
         activeData.payloadType === 'IMAGE_DTO' &&
         activeData.payload.imageDTO
       ) {
-        const { layerId } = overData.context;
-        dispatch(
-          iiLayerImageChanged({
-            layerId,
-            imageDTO: activeData.payload.imageDTO,
-          })
-        );
+        const imageObject = imageDTOToImageObject(activeData.payload.imageDTO);
+        const { x, y } = selectCanvasSlice(getState()).bbox.rect;
+        const overrides: Partial<CanvasRasterLayerState> = {
+          objects: [imageObject],
+          position: { x, y },
+        };
+        dispatch(rasterLayerAdded({ overrides, isSelected: true }));
         return;
       }
 
       /**
-       * Image dropped on Canvas
+       * Image dropped on Raster layer
       */
       if (
-        overData.actionType === 'SET_CANVAS_INITIAL_IMAGE' &&
+        overData.actionType === 'ADD_CONTROL_LAYER_FROM_IMAGE' &&
         activeData.payloadType === 'IMAGE_DTO' &&
         activeData.payload.imageDTO
       ) {
-        dispatch(setInitialCanvasImage(activeData.payload.imageDTO, selectOptimalDimension(getState())));
+        const imageObject = imageDTOToImageObject(activeData.payload.imageDTO);
+        const { x, y } = selectCanvasSlice(getState()).bbox.rect;
+        const overrides: Partial<CanvasControlLayerState> = {
+          objects: [imageObject],
+          position: { x, y },
+        };
+        dispatch(controlLayerAdded({ overrides, isSelected: true }));
         return;
       }
 

@@ -2,13 +2,13 @@ import { logger } from 'app/logging/logger';
 import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
 import { imagesApi } from 'services/api/endpoints/images';
 
+const log = logger('gallery');
 
 export const addImageRemovedFromBoardFulfilledListener = (startAppListening: AppStartListening) => {
   startAppListening({
     matcher: imagesApi.endpoints.removeImageFromBoard.matchFulfilled,
     effect: (action) => {
-      const log = logger('images');
       const imageDTO = action.meta.arg.originalArgs;
 
       log.debug({ imageDTO }, 'Image removed from board');
     },
   });
@@ -16,9 +16,7 @@ export const addImageRemovedFromBoardFulfilledListener = (startAppListening: App
   startAppListening({
     matcher: imagesApi.endpoints.removeImageFromBoard.matchRejected,
     effect: (action) => {
-      const log = logger('images');
       const imageDTO = action.meta.arg.originalArgs;
 
       log.debug({ imageDTO }, 'Problem removing image from board');
     },
   });

@@ -6,16 +6,17 @@ import { imagesToDeleteSelected, isModalOpenChanged } from 'features/deleteImage
 export const addImageToDeleteSelectedListener = (startAppListening: AppStartListening) => {
   startAppListening({
     actionCreator: imagesToDeleteSelected,
-    effect: async (action, { dispatch, getState }) => {
+    effect: (action, { dispatch, getState }) => {
       const imageDTOs = action.payload;
       const state = getState();
       const { shouldConfirmOnDelete } = state.system;
       const imagesUsage = selectImageUsage(getState());
 
       const isImageInUse =
-        imagesUsage.some((i) => i.isCanvasImage) ||
-        imagesUsage.some((i) => i.isControlImage) ||
-        imagesUsage.some((i) => i.isNodesImage);
+        imagesUsage.some((i) => i.isLayerImage) ||
+        imagesUsage.some((i) => i.isControlAdapterImage) ||
+        imagesUsage.some((i) => i.isIPAdapterImage) ||
+        imagesUsage.some((i) => i.isLayerImage);
 
       if (shouldConfirmOnDelete || isImageInUse) {
         dispatch(isModalOpenChanged(true));

@@ -1,19 +1,8 @@
 import { logger } from 'app/logging/logger';
 import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
-import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
-import {
-  controlAdapterImageChanged,
-  controlAdapterIsEnabledChanged,
-} from 'features/controlAdapters/store/controlAdaptersSlice';
-import {
-  caLayerImageChanged,
-  iiLayerImageChanged,
-  ipaLayerImageChanged,
-  rgLayerIPAdapterImageChanged,
-} from 'features/controlLayers/store/controlLayersSlice';
+import { ipaImageChanged, rgIPAdapterImageChanged } from 'features/controlLayers/store/canvasSlice';
 import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
 import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
-import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
 import { upscaleInitialImageChanged } from 'features/parameters/store/upscaleSlice';
 import { toast } from 'features/toast/toast';
 import { t } from 'i18next';
@@ -21,11 +10,12 @@ import { omit } from 'lodash-es';
 import { boardsApi } from 'services/api/endpoints/boards';
 import { imagesApi } from 'services/api/endpoints/images';
 
+const log = logger('gallery');
 
 export const addImageUploadedFulfilledListener = (startAppListening: AppStartListening) => {
   startAppListening({
     matcher: imagesApi.endpoints.uploadImage.matchFulfilled,
     effect: (action, { dispatch, getState }) => {
-      const log = logger('images');
       const imageDTO = action.payload;
       const state = getState();
       const { autoAddBoardId } = state.gallery;
@@ -81,15 +71,6 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis
         return;
       }
 
-      if (postUploadAction?.type === 'SET_CANVAS_INITIAL_IMAGE') {
-        dispatch(setInitialCanvasImage(imageDTO, selectOptimalDimension(state)));
-        toast({
-          ...DEFAULT_UPLOADED_TOAST,
-          description: t('toast.setAsCanvasInitialImage'),
-        });
-        return;
-      }
-
       if (postUploadAction?.type === 'SET_UPSCALE_INITIAL_IMAGE') {
         dispatch(upscaleInitialImageChanged(imageDTO));
         toast({
@@ -99,70 +80,33 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis
         return;
       }
 
-      if (postUploadAction?.type === 'SET_CONTROL_ADAPTER_IMAGE') {
+      // if (postUploadAction?.type === 'SET_CA_IMAGE') {
+      //   const { id } = postUploadAction;
+      //   dispatch(caImageChanged({ id, imageDTO }));
+      //   toast({ ...DEFAULT_UPLOADED_TOAST, description: t('toast.setControlImage') });
+      //   return;
+      // }
+
+      if (postUploadAction?.type === 'SET_IPA_IMAGE') {
         const { id } = postUploadAction;
-        dispatch(
-          controlAdapterIsEnabledChanged({
-            id,
-            isEnabled: true,
-          })
-        );
-        dispatch(
-          controlAdapterImageChanged({
-            id,
-            controlImage: imageDTO.image_name,
-          })
-        );
-        toast({
-          ...DEFAULT_UPLOADED_TOAST,
-          description: t('toast.setControlImage'),
-        });
+        dispatch(ipaImageChanged({ entityIdentifier: { id, type: 'ip_adapter' }, imageDTO }));
+        toast({ ...DEFAULT_UPLOADED_TOAST, description: t('toast.setControlImage') });
         return;
       }
 
-      if (postUploadAction?.type === 'SET_CA_LAYER_IMAGE') {
-        const { layerId } = postUploadAction;
-        dispatch(caLayerImageChanged({ layerId, imageDTO }));
-        toast({
-          ...DEFAULT_UPLOADED_TOAST,
-          description: t('toast.setControlImage'),
-        });
-      }
-
-      if (postUploadAction?.type === 'SET_IPA_LAYER_IMAGE') {
-        const { layerId } = postUploadAction;
-        dispatch(ipaLayerImageChanged({ layerId, imageDTO }));
-        toast({
-          ...DEFAULT_UPLOADED_TOAST,
-          description: t('toast.setControlImage'),
-        });
-      }
-
-      if (postUploadAction?.type === 'SET_RG_LAYER_IP_ADAPTER_IMAGE') {
-        const { layerId, ipAdapterId } = postUploadAction;
-        dispatch(rgLayerIPAdapterImageChanged({ layerId, ipAdapterId, imageDTO }));
-        toast({
-          ...DEFAULT_UPLOADED_TOAST,
-          description: t('toast.setControlImage'),
-        });
-      }
-
-      if (postUploadAction?.type === 'SET_II_LAYER_IMAGE') {
-        const { layerId } = postUploadAction;
-        dispatch(iiLayerImageChanged({ layerId, imageDTO }));
-        toast({
-          ...DEFAULT_UPLOADED_TOAST,
-          description: t('toast.setControlImage'),
-        });
+      if (postUploadAction?.type === 'SET_RG_IP_ADAPTER_IMAGE') {
+        const { id, ipAdapterId } = postUploadAction;
+        dispatch(
+          rgIPAdapterImageChanged({ entityIdentifier: { id, type: 'regional_guidance' }, ipAdapterId, imageDTO })
+        );
+        toast({ ...DEFAULT_UPLOADED_TOAST, description: t('toast.setControlImage') });
+        return;
       }
 
       if (postUploadAction?.type === 'SET_NODES_IMAGE') {
         const { nodeId, fieldName } = postUploadAction;
         dispatch(fieldImageValueChanged({ nodeId, fieldName, value: imageDTO }));
-        toast({
-          ...DEFAULT_UPLOADED_TOAST,
-          description: `${t('toast.setNodeField')} ${fieldName}`,
-        });
+        toast({ ...DEFAULT_UPLOADED_TOAST, description: `${t('toast.setNodeField')} ${fieldName}` });
         return;
       }
     },
@@ -171,7 +115,6 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis
   startAppListening({
     matcher: imagesApi.endpoints.uploadImage.matchRejected,
     effect: (action) => {
-      const log = logger('images');
       const sanitizedData = {
         arg: {
           ...omit(action.meta.arg.originalArgs, ['file', 'postUploadAction']),

@@ -6,7 +6,7 @@ import type { ImageDTO } from 'services/api/types';
 export const addImagesStarredListener = (startAppListening: AppStartListening) => {
   startAppListening({
     matcher: imagesApi.endpoints.starImages.matchFulfilled,
-    effect: async (action, { dispatch, getState }) => {
+    effect: (action, { dispatch, getState }) => {
       const { updated_image_names: starredImages } = action.payload;
 
       const state = getState();

@@ -6,7 +6,7 @@ import type { ImageDTO } from 'services/api/types';
 export const addImagesUnstarredListener = (startAppListening: AppStartListening) => {
   startAppListening({
     matcher: imagesApi.endpoints.unstarImages.matchFulfilled,
-    effect: async (action, { dispatch, getState }) => {
+    effect: (action, { dispatch, getState }) => {
       const { updated_image_names: unstarredImages } = action.payload;
 
       const state = getState();

@@ -1,23 +1,18 @@
 import { logger } from 'app/logging/logger';
 import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
-import {
-  controlAdapterIsEnabledChanged,
-  selectControlAdapterAll,
-} from 'features/controlAdapters/store/controlAdaptersSlice';
-import { loraRemoved } from 'features/lora/store/loraSlice';
+import { loraDeleted } from 'features/controlLayers/store/lorasSlice';
+import { modelChanged, vaeSelected } from 'features/controlLayers/store/paramsSlice';
 import { modelSelected } from 'features/parameters/store/actions';
-import { modelChanged, vaeSelected } from 'features/parameters/store/generationSlice';
 import { zParameterModel } from 'features/parameters/types/parameterSchemas';
 import { toast } from 'features/toast/toast';
 import { t } from 'i18next';
-import { forEach } from 'lodash-es';
+
+const log = logger('models');
 
 export const addModelSelectedListener = (startAppListening: AppStartListening) => {
   startAppListening({
     actionCreator: modelSelected,
     effect: (action, { getState, dispatch }) => {
-      const log = logger('models');
-
       const state = getState();
       const result = zParameterModel.safeParse(action.payload);
 
@@ -29,34 +24,36 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
       const newModel = result.data;
 
       const newBaseModel = newModel.base;
-      const didBaseModelChange = state.generation.model?.base !== newBaseModel;
+      const didBaseModelChange = state.params.model?.base !== newBaseModel;
 
       if (didBaseModelChange) {
         // we may need to reset some incompatible submodels
         let modelsCleared = 0;
 
         // handle incompatible loras
-        forEach(state.lora.loras, (lora, id) => {
+        state.loras.loras.forEach((lora) => {
           if (lora.model.base !== newBaseModel) {
-            dispatch(loraRemoved(id));
+            dispatch(loraDeleted({ id: lora.id }));
             modelsCleared += 1;
           }
         });
 
         // handle incompatible vae
-        const { vae } = state.generation;
+        const { vae } = state.params;
         if (vae && vae.base !== newBaseModel) {
           dispatch(vaeSelected(null));
           modelsCleared += 1;
         }
 
         // handle incompatible controlnets
-        selectControlAdapterAll(state.controlAdapters).forEach((ca) => {
-          if (ca.model?.base !== newBaseModel) {
-            dispatch(controlAdapterIsEnabledChanged({ id: ca.id, isEnabled: false }));
-            modelsCleared += 1;
-          }
-        });
+        // state.canvas.present.controlAdapters.entities.forEach((ca) => {
+        //   if (ca.model?.base !== newBaseModel) {
+        //     modelsCleared += 1;
+        //     if (ca.isEnabled) {
+        //       dispatch(entityIsEnabledToggled({ entityIdentifier: { id: ca.id, type: 'control_adapter' } }));
+        //     }
+        //   }
+        // });
 
         if (modelsCleared > 0) {
           toast({
@@ -70,7 +67,7 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
         }
       }
 
-      dispatch(modelChanged(newModel, state.generation.model));
+      dispatch(modelChanged({ model: newModel, previousModel: state.params.model }));
     },
   });
 };

@@ -1,36 +1,42 @@
 import { logger } from 'app/logging/logger';
 import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
 import type { AppDispatch, RootState } from 'app/store/store';
-import type { JSONObject } from 'common/types';
+import type { SerializableObject } from 'common/types';
 import {
-  controlAdapterModelCleared,
-  selectControlAdapterAll,
-} from 'features/controlAdapters/store/controlAdaptersSlice';
-import { heightChanged, widthChanged } from 'features/controlLayers/store/controlLayersSlice';
-import { loraRemoved } from 'features/lora/store/loraSlice';
-import { calculateNewSize } from 'features/parameters/components/ImageSize/calculateNewSize';
-import { modelChanged, vaeSelected } from 'features/parameters/store/generationSlice';
+  bboxHeightChanged,
+  bboxWidthChanged,
+  controlLayerModelChanged,
+  ipaModelChanged,
+  rgIPAdapterModelChanged,
+} from 'features/controlLayers/store/canvasSlice';
+import { loraDeleted } from 'features/controlLayers/store/lorasSlice';
+import { modelChanged, refinerModelChanged, vaeSelected } from 'features/controlLayers/store/paramsSlice';
+import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
+import { getEntityIdentifier } from 'features/controlLayers/store/types';
+import { calculateNewSize } from 'features/parameters/components/DocumentSize/calculateNewSize';
 import { postProcessingModelChanged, upscaleModelChanged } from 'features/parameters/store/upscaleSlice';
 import { zParameterModel, zParameterVAEModel } from 'features/parameters/types/parameterSchemas';
 import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
-import { refinerModelChanged } from 'features/sdxl/store/sdxlSlice';
-import { forEach } from 'lodash-es';
 import type { Logger } from 'roarr';
 import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models';
 import type { AnyModelConfig } from 'services/api/types';
 import {
+  isControlNetOrT2IAdapterModelConfig,
+  isIPAdapterModelConfig,
+  isLoRAModelConfig,
   isNonRefinerMainModelConfig,
   isRefinerMainModelModelConfig,
   isSpandrelImageToImageModelConfig,
   isVAEModelConfig,
 } from 'services/api/types';
 
+const log = logger('models');
 
 export const addModelsLoadedListener = (startAppListening: AppStartListening) => {
   startAppListening({
     predicate: modelsApi.endpoints.getModelConfigs.matchFulfilled,
-    effect: async (action, { getState, dispatch }) => {
+    effect: (action, { getState, dispatch }) => {
       // models loaded, we need to ensure the selected model is available and if not, select the first one
-      const log = logger('models');
       log.info({ models: action.payload.entities }, `Models loaded (${action.payload.ids.length})`);
 
       const state = getState();
@@ -43,6 +49,7 @@ export const addModelsLoadedListener = (startAppListening: AppStartListening) =>
       handleLoRAModels(models, state, dispatch, log);
       handleControlAdapterModels(models, state, dispatch, log);
       handleSpandrelImageToImageModels(models, state, dispatch, log);
+      handleIPAdapterModels(models, state, dispatch, log);
     },
   });
 };
@@ -51,15 +58,15 @@ type ModelHandler = (
   models: AnyModelConfig[],
   state: RootState,
   dispatch: AppDispatch,
-  log: Logger<JSONObject>
+  log: Logger<SerializableObject>
 ) => undefined;
 
 const handleMainModels: ModelHandler = (models, state, dispatch, log) => {
-  const currentModel = state.generation.model;
+  const currentModel = state.params.model;
   const mainModels = models.filter(isNonRefinerMainModelConfig);
   if (mainModels.length === 0) {
     // No models loaded at all
-    dispatch(modelChanged(null));
+    dispatch(modelChanged({ model: null }));
     return;
   }
 
@@ -74,25 +81,16 @@ const handleMainModels: ModelHandler = (models, state, dispatch, log) => {
   if (defaultModelInList) {
     const result = zParameterModel.safeParse(defaultModelInList);
     if (result.success) {
-      dispatch(modelChanged(defaultModelInList, currentModel));
+      dispatch(modelChanged({ model: defaultModelInList, previousModel: currentModel }));
+      const { bbox } = selectCanvasSlice(state);
       const optimalDimension = getOptimalDimension(defaultModelInList);
-      if (
-        getIsSizeOptimal(
-          state.controlLayers.present.size.width,
-          state.controlLayers.present.size.height,
-          optimalDimension
-        )
-      ) {
+      if (getIsSizeOptimal(bbox.rect.width, bbox.rect.height, optimalDimension)) {
         return;
       }
-      const { width, height } = calculateNewSize(
-        state.controlLayers.present.size.aspectRatio.value,
-        optimalDimension * optimalDimension
-      );
+      const { width, height } = calculateNewSize(bbox.aspectRatio.value, optimalDimension * optimalDimension);
 
-      dispatch(widthChanged({ width }));
-      dispatch(heightChanged({ height }));
+      dispatch(bboxWidthChanged({ width }));
+      dispatch(bboxHeightChanged({ height }));
       return;
     }
   }
@@ -104,11 +102,11 @@ const handleMainModels: ModelHandler = (models, state, dispatch, log) => {
     return;
   }
 
-  dispatch(modelChanged(result.data, currentModel));
+  dispatch(modelChanged({ model: result.data, previousModel: currentModel }));
 };
 
 const handleRefinerModels: ModelHandler = (models, state, dispatch, _log) => {
-  const currentRefinerModel = state.sdxl.refinerModel;
+  const currentRefinerModel = state.params.refinerModel;
   const refinerModels = models.filter(isRefinerMainModelModelConfig);
   if (models.length === 0) {
     // No models loaded at all
@@ -127,7 +125,7 @@ const handleRefinerModels: ModelHandler = (models, state, dispatch, _log) => {
 };
 
 const handleVAEModels: ModelHandler = (models, state, dispatch, log) => {
-  const currentVae = state.generation.vae;
+  const currentVae = state.params.vae;
 
   if (currentVae === null) {
     // null is a valid VAE! it means "use the default with the main model"
@@ -160,28 +158,47 @@ const handleVAEModels: ModelHandler = (models, state, dispatch, log) => {
 };
 
 const handleLoRAModels: ModelHandler = (models, state, dispatch, _log) => {
-  const loras = state.lora.loras;
-  forEach(loras, (lora, id) => {
-    const isLoRAAvailable = models.some((m) => m.key === lora.model.key);
+  const loraModels = models.filter(isLoRAModelConfig);
+  state.loras.loras.forEach((lora) => {
+    const isLoRAAvailable = loraModels.some((m) => m.key === lora.model.key);
 
     if (isLoRAAvailable) {
       return;
     }
-    dispatch(loraRemoved(id));
+    dispatch(loraDeleted({ id: lora.id }));
   });
 };
 
 const handleControlAdapterModels: ModelHandler = (models, state, dispatch, _log) => {
-  selectControlAdapterAll(state.controlAdapters).forEach((ca) => {
-    const isModelAvailable = models.some((m) => m.key === ca.model?.key);
+  const caModels = models.filter(isControlNetOrT2IAdapterModelConfig);
+  selectCanvasSlice(state).controlLayers.entities.forEach((entity) => {
+    const isModelAvailable = caModels.some((m) => m.key === entity.controlAdapter.model?.key);
     if (isModelAvailable) {
       return;
     }
+    dispatch(controlLayerModelChanged({ entityIdentifier: getEntityIdentifier(entity), modelConfig: null }));
+  });
+};
 
-    dispatch(controlAdapterModelCleared({ id: ca.id }));
+const handleIPAdapterModels: ModelHandler = (models, state, dispatch, _log) => {
+  const ipaModels = models.filter(isIPAdapterModelConfig);
+  selectCanvasSlice(state).ipAdapters.entities.forEach((entity) => {
+    const isModelAvailable = ipaModels.some((m) => m.key === entity.ipAdapter.model?.key);
+    if (isModelAvailable) {
+      return;
+    }
+    dispatch(ipaModelChanged({ entityIdentifier: getEntityIdentifier(entity), modelConfig: null }));
+  });
+
+  selectCanvasSlice(state).regions.entities.forEach((entity) => {
+    entity.ipAdapters.forEach(({ id: ipAdapterId, model }) => {
+      const isModelAvailable = ipaModels.some((m) => m.key === model?.key);
+      if (isModelAvailable) {
+        return;
+      }
+      dispatch(
        rgIPAdapterModelChanged({ entityIdentifier: getEntityIdentifier(entity), ipAdapterId, modelConfig: null })
+      );
+    });
   });
 };
 
|
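All of the handlers in this hunk follow the same shape: narrow the freshly loaded model configs to one type, walk the matching entities in state, and reset any entity whose model key is no longer installed. Below is a self-contained sketch of that pattern, using hypothetical types and names (`ModelConfigLike`, `ModelRefLike`, `resetStaleModelRefs`) rather than the repo's real API.

```ts
// Minimal sketch of the shared handler pattern above: keep only the configs of
// the relevant type, then clear/reset any stored reference whose model key is
// no longer installed. Everything here is illustrative, not the repo's code.
type ModelConfigLike = { key: string; type: string };
type ModelRefLike = { id: string; model: { key: string } | null };

const resetStaleModelRefs = <T extends ModelRefLike>(
  configs: ModelConfigLike[],
  type: string,
  refs: T[],
  onStale: (ref: T) => void
): void => {
  // Filter once, like `models.filter(isLoRAModelConfig)` etc. in the diff.
  const available = new Set(configs.filter((c) => c.type === type).map((c) => c.key));
  for (const ref of refs) {
    if (ref.model && !available.has(ref.model.key)) {
      onStale(ref); // e.g. dispatch(loraDeleted({ id: ref.id })) in the real listener
    }
  }
};

// Usage sketch:
resetStaleModelRefs(
  [{ key: 'abc', type: 'lora' }],
  'lora',
  [{ id: 'lora-1', model: { key: 'gone' } }],
  (ref) => console.log(`would reset ${ref.id}`)
);
```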
@@ -1,6 +1,6 @@
 import { isAnyOf } from '@reduxjs/toolkit';
 import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
-import { positivePromptChanged } from 'features/controlLayers/store/controlLayersSlice';
+import { positivePromptChanged } from 'features/controlLayers/store/paramsSlice';
 import {
   combinatorialToggled,
   isErrorChanged,
@@ -13,8 +13,9 @@ import {
 import { getShouldProcessPrompt } from 'features/dynamicPrompts/util/getShouldProcessPrompt';
 import { getPresetModifiedPrompts } from 'features/nodes/util/graph/graphBuilderUtils';
 import { activeStylePresetIdChanged } from 'features/stylePresets/store/stylePresetSlice';
+import { stylePresetsApi } from 'services/api/endpoints/stylePresets';
 import { utilitiesApi } from 'services/api/endpoints/utilities';
-import { socketConnected } from 'services/events/actions';
+import { socketConnected } from 'services/events/setEventListeners';

 const matcher = isAnyOf(
   positivePromptChanged,
@@ -22,7 +23,8 @@ const matcher = isAnyOf(
   maxPromptsChanged,
   maxPromptsReset,
   socketConnected,
-  activeStylePresetIdChanged
+  activeStylePresetIdChanged,
+  stylePresetsApi.endpoints.listStylePresets.matchFulfilled
 );

 export const addDynamicPromptsListener = (startAppListening: AppStartListening) => {
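The change to `matcher` above relies on the fact that `isAnyOf` accepts RTK Query endpoint matchers (such as `matchFulfilled`) alongside plain action creators, so the dynamic-prompts listener also re-runs when the style presets list finishes loading. A minimal, self-contained sketch of that pattern follows; the actions and the effect body are placeholders, not the repo's actual listener.

```ts
// Sketch: combine plain action creators with an RTK Query matcher in isAnyOf,
// then trigger one listener effect whenever any of them fires. In the real code
// the matcher also includes stylePresetsApi.endpoints.listStylePresets.matchFulfilled.
import { createAction, createListenerMiddleware, isAnyOf } from '@reduxjs/toolkit';

const promptChanged = createAction<string>('params/promptChanged');
const socketConnected = createAction('socket/connected');

const matcher = isAnyOf(promptChanged, socketConnected);

const listenerMiddleware = createListenerMiddleware();
listenerMiddleware.startListening({
  matcher,
  effect: async (action, { cancelActiveListeners, delay }) => {
    // Debounce: cancel any in-flight run, wait, then recompute dynamic prompts.
    cancelActiveListeners();
    await delay(300);
    console.log('would re-process dynamic prompts after', action.type);
  },
});
```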
Some files were not shown because too many files have changed in this diff.