Compare commits

..

187 Commits

Author SHA1 Message Date
3207822738 Update invokeai_version.py 2024-04-25 12:31:59 +10:00
8d86fabf4b chore(ui): lint 2024-04-24 20:09:52 +10:00
af3e910ad3 fix(ui): fix layer arrangement 2024-04-24 20:09:52 +10:00
af25d00964 tidy(ui): use const for brush spacing 2024-04-24 20:09:52 +10:00
d4a30d08ef feat(ui): create new line when mouse held down, leaves canvas and comes back over 2024-04-24 20:09:52 +10:00
bd8a33e824 tidy(ui): clean up renderer functions
- Split logic to create layers/objects from the updating logic
- Organize and comment functions
2024-04-24 20:09:52 +10:00
b425646b7b chore(ui): lint 2024-04-24 20:09:52 +10:00
293e11cfa6 feat(ui): hide add prompt buttons when user has a prompt 2024-04-24 20:09:52 +10:00
c73aabdfbf feat(ui): regional control defaults to having a positive prompt 2024-04-24 20:09:52 +10:00
ca989c54b0 fix(ui): restore OG aspect ratio preview for non-t2i tabs 2024-04-24 20:09:52 +10:00
260e24733f fix: update SDXL IP Adpater starter model to be ViT-H 2024-04-24 00:08:21 -04:00
bb6e3e726d fix: update ip adapter starter models path (#6262)
## Summary

<!--A description of the changes in this PR. Include the kind of change
(fix, feature, docs, etc), the "why" and the "how". Screenshots or
videos are useful for frontend changes.-->

## Related Issues / Discussions

<!--WHEN APPLICABLE: List any related issues or discussions on github or
discord. If this PR closes an issue, please use the "Closes #1234"
format, so that the issue will be automatically closed when the PR
merges.-->

## QA Instructions

<!--WHEN APPLICABLE: Describe how we can test the changes in this PR.-->

## Merge Plan

<!--WHEN APPLICABLE: Large PRs, or PRs that touch sensitive things like
DB schemas, may need some care when merging. For example, a careful
rebase by the change author, timing to not interfere with a pending
release, or a message to contributors on discord after merging.-->

## Checklist

- [ ] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
2024-04-24 08:58:15 +05:30
6b394554e2 fix: update ip adapter starter models path 2024-04-24 08:48:25 +05:30
ae1955a1a8 feat(ui): update canvas graphs to provide unet 2024-04-23 07:32:53 -04:00
1bef13db37 feat(nodes): restore unet check on CreateGradientMaskInvocation
Special handling for inpainting models
2024-04-23 07:32:53 -04:00
a461537087 chore: ruff 2024-04-23 07:32:53 -04:00
99e28da19b feat(ui): add variant to model edit
Also simplify the layouting for all model view/edit components.
2024-04-23 07:32:53 -04:00
42a159beaa chore(ui): typegen 2024-04-23 07:32:53 -04:00
0aa5aadfe8 fix(mm): move variant to MainConfigBase
shoulda been here all along
2024-04-23 07:32:53 -04:00
2537d260e3 tests: add test for probing diffusers model variant type 2024-04-23 07:32:53 -04:00
bbf919a933 chore: frontend check error 2024-04-23 07:32:53 -04:00
01897ec576 remove extra inputs 2024-04-23 07:32:53 -04:00
bc12d6654e chore: comments and ruff 2024-04-23 07:32:53 -04:00
6d7c8d5f57 remove unet test 2024-04-23 07:32:53 -04:00
38604aa408 update canvas graphs 2024-04-23 07:32:53 -04:00
781de914f4 fix threshhold 2024-04-23 07:32:53 -04:00
c094bad233 add unet check in gradient mask node 2024-04-23 07:32:53 -04:00
0063014f2b gradient mask node test for inpaint 2024-04-23 07:32:53 -04:00
d7b5ad02e8 tests: add object serializer test for dangling folders
- Ensure they are deleted on init if ephemeral
- Ensure they are _not_ deleted on init if _not_ ephemeral
2024-04-23 17:12:14 +10:00
2cee436ecf tidy(app): remove unused class 2024-04-23 17:12:14 +10:00
e6386d969f fix(app): only clear tempdirs if ephemeral and before creating tempdir
Also, this needs to happen in init, else it deletes the temp dir created in init
2024-04-23 17:12:14 +10:00
4b2b983646 tidy(api): reverted unnecessary changes in dependencies.py 2024-04-23 17:12:14 +10:00
53808149fb moved cleanup routine into object_serializer_disk.py 2024-04-23 17:12:14 +10:00
21ba55d0a6 add an initialization function that removes dangling tmpdirs from outputs/tensors 2024-04-23 17:12:14 +10:00
28c28b2fc0 fix: 🐛 handle trigger phrase form submits 2024-04-23 16:42:40 +10:00
8b9c4c62a6 chore: v4.2.0a2 2024-04-23 13:08:26 +10:00
cf637ecaa6 fix(ui): disabled ip adapters applied to regional control 2024-04-23 13:08:26 +10:00
fca718bdd3 tidy(ui): remove extraneous cursor sync 2024-04-23 12:11:47 +10:00
5196a2efec fix(ui): minor canvas overflow 2024-04-23 12:11:47 +10:00
385e93443a feat(ui): rp hotkeys
- Shift+C: Reset selected layer mask (same as canvas)
- Shift+D: Delete selected layer (cannot be Del, that deletes an image in gallery)
- Shift+A: Add layer (cannot be Ctrl+Shift+N, that opens a new window)
- Ctrl/Cmd+Wheel: Brush size (same as canvas)
2024-04-23 12:11:47 +10:00
604217313a chore(ui): lint 2024-04-23 12:11:47 +10:00
229423b370 tidy(ui): memo aspectratiopreview 2024-04-23 12:11:47 +10:00
75a548e3eb perf(ui): debounce render wait = 300ms 2024-04-23 12:11:47 +10:00
24dbb65ebb perf(ui): add brush spacing
Only add point to line if the next point is 10 or more px from the last point
2024-04-23 12:11:47 +10:00
c915220965 feat(ui): aspect ratio preview is regional prompts canvas 2024-04-23 12:11:47 +10:00
bb37e25ed0 feat(ui): rp ui layout 2024-04-23 12:11:47 +10:00
dda1111f20 Make it alpha 2024-04-22 10:54:21 -04:00
9d71b91b7f chore: v4.2.0b1 2024-04-22 10:54:21 -04:00
714126b832 build(ui): temp disable circular dependency check
I'll need to think about how to fix this properly. For now, disable the check as the UI can still build fine.
2024-04-22 23:09:39 +10:00
a10c66797d chore(ui): lint 2024-04-22 23:09:39 +10:00
6dcaf75b5f feat(ui): regional prompts spray n pray
Trying a lot of different things as I iterated, so this is smooshed into one big commit... too hard to split it now.

- Iterated on IP adapter handling and UI. Unfortunately there is an bug related to undo/redo. The IP adapter state is split across the `controlAdapters` slice and the `regionalPrompts` slice, but only the `regionalPrompts` slice supports undo/redo. If you delete the IP adapter and then undo/redo to a history state where it existed, you'll get an error. The fix is likely to merge the slices... Maybe there's a workaround.
- Iterated on UI. I think the layers are OK now.
- Removed ability to disable RP globally for now. It's enabled if you have enabled RP layers.
- Many minor tweaks and fixes.
2024-04-22 23:09:39 +10:00
018845cda0 tidy(ui): regional prompts kind -> type 2024-04-22 23:09:39 +10:00
8c0a061fa0 fix(ui): hotkeys dependency array 2024-04-20 11:32:08 -04:00
4895875ded feat(ui): rects on regional prompt UI 2024-04-20 11:32:08 -04:00
cfddbda578 tidy(ui): clean up action names 2024-04-20 11:32:08 -04:00
58d3a9e7d4 refactor(ui): revise regional prompts state to support prompt-less mask layers
This structure is more adaptable to future features like IP-Adapter-only regions, controlnet layers, image masks, etc.
2024-04-20 11:32:08 -04:00
a00e703144 feat(nodes): image mask to tensor invocation
Thanks @JPPhoto!
2024-04-20 11:32:08 -04:00
e4024bdeb9 fix(ui): floor all pixel coords
This prevents rendering objects with sub-pixel positioning, which looks soft
2024-04-20 11:32:08 -04:00
944690ac8e feat(ui): remove drag distance on layers 2024-04-20 11:32:08 -04:00
a7d69aa0a9 fix(ui): brush preview cursor jank 2024-04-20 11:32:08 -04:00
15018fdbc0 fix(ui): brush preview not visible after hotkey 2024-04-20 11:32:08 -04:00
31ace9aff8 feat(ui): tool hotkeys for rp 2024-04-20 11:32:08 -04:00
3f4ea30113 fix(ui): fix missing bbox when a layer is empty 2024-04-20 11:32:08 -04:00
7edcadb371 fix(ui): bbox rendered slightly too small 2024-04-20 11:32:08 -04:00
d582203c62 chore(ui): lint 2024-04-20 14:54:49 +10:00
148a6c08ca fix(ui): fix bbox caching 2024-04-20 14:54:49 +10:00
1e904d281a feat(ui): hook up sd1.5 t2i graph to regional prompts 2024-04-20 14:54:49 +10:00
03d9a75720 feat(ui): better rp colors 2024-04-20 14:54:49 +10:00
5edce0a4de perf(ui): caching efficiency 2024-04-20 14:54:49 +10:00
604bf4e9ec perf(ui): use efficient group caching instead of a compositing rect
Seems to be the same speed and it's less complex.
2024-04-20 14:54:49 +10:00
39d036bb37 feat(ui): update move tool to show all bboxes, mouseover bbox strokes 2024-04-20 14:54:49 +10:00
8a69fbd336 perf(ui): more bbox optimizations
- Keep track of whether the bbox needs to be recalculated (e.g. had lines/points added)
- Keep track of whether the bbox has eraser strokes - if yes, we need to do the full pixel-perfect bbox calculation, otherwise we can use the faster getClientRect
- Use comparison rather than Math.min/max in bbox calculation (slightly faster)
- Return `null` if no pixel data at all in bbox
2024-04-20 14:54:49 +10:00
a71ed10b71 perf(ui): more efficient bbox method with smaller minimum offscreen canvas size 2024-04-20 14:54:49 +10:00
9d3978edcf fix(ui): give min dimensions to rp storybook 2024-04-20 14:54:49 +10:00
18e1d74917 fix(ui): group layer color change history 2024-04-20 14:54:49 +10:00
9276ecfd02 feat(ui): rp ui styling/layout 2024-04-19 09:32:56 -04:00
ea527f5fe1 feat(nodes): add beta classification to mask tensor nodes 2024-04-19 09:32:56 -04:00
d43f9732ab feat(ui): rp ui styling 2024-04-19 09:32:56 -04:00
c613839740 feat(ui): use translations for rp features 2024-04-19 09:32:56 -04:00
bb371cfeca feat(ui): minor styling rp 2024-04-19 09:32:56 -04:00
6a5510146c feat(ui): add default rp brush size 2024-04-19 09:32:56 -04:00
9667f77c41 feat(ui): rp editor styling 2024-04-19 09:32:56 -04:00
e93e0612af tidy(ui): selectedLayer -> selectedLayerId 2024-04-19 09:32:56 -04:00
9528287d56 feat(ui): move ephemeral tool state out of redux 2024-04-19 09:32:56 -04:00
14c722c265 tidy(ui): remove unused conditional 2024-04-19 09:32:56 -04:00
4b2cd2da9f feat(ui): remove special handling of main prompt
Until we have a good handle on what works best, leaving this to the user
2024-04-19 09:32:56 -04:00
3c5b728bee feat(ui): add enabled state for RP 2024-04-19 09:32:56 -04:00
9b5c47748d tidy(ui): isRegionalPromptLayer -> isRPLayer 2024-04-19 09:32:56 -04:00
eb781272f7 tidy(ui): organize rp layer components 2024-04-19 09:32:56 -04:00
642a0de3dd feat(ui): organize layer naming
prep for non-rp layer types
2024-04-19 09:32:56 -04:00
f3b4cecf2e feat(ui): invert tensor mask instead of loading mask image and converting to tensor second time
minor efficiency improvement
2024-04-19 09:32:56 -04:00
499e7a7b74 chore(ui): typegen 2024-04-19 09:32:56 -04:00
aace364677 feat(nodes): add InvertTensorMaskInvocation 2024-04-19 09:32:56 -04:00
c195094e91 fix(ui): do not open panels when collapsed and window resize 2024-04-19 09:32:56 -04:00
e6c57edf87 tidy(ui): shuffle graph builder logic 2024-04-19 09:32:56 -04:00
c217e052a8 tidy(ui): remove unused action 2024-04-19 09:32:56 -04:00
964e2236b9 feat(ui): do not add promptless conditioning nodes 2024-04-19 09:32:56 -04:00
a6e64423d9 feat(ui): per-layer autonegative 2024-04-19 09:32:56 -04:00
d3aa97ab99 feat(ui): add copy graph button to queue item detail view 2024-04-19 09:32:56 -04:00
0d8edd67ab fix(ui): group lines together in undo history 2024-04-19 09:32:56 -04:00
d9dd00ea20 feat(ui): undo/redo in regional prompts
using the `redux-undo` library
2024-04-19 09:32:56 -04:00
170763899a tidy(ui): tidy regional prompts graph helper, add comments 2024-04-19 09:32:56 -04:00
9e1a4a4a07 feat(ui): regional prompts default layer opacity 2024-04-19 09:32:56 -04:00
dcb4a40741 fix(ui): regional prompts brush preview wonkiness 2024-04-19 09:32:56 -04:00
f8bf985256 perf(ui): do not recreate map callback on every render 2024-04-19 09:32:56 -04:00
81f29b9624 tidy(ui): remove errant console.log 2024-04-19 09:32:56 -04:00
f2dde9a035 feat(ui): cleared selected layer styling 2024-04-19 09:32:56 -04:00
85f4a066fb feat(ui): use logger for stage renderer 2024-04-19 09:32:56 -04:00
b9e6b7ba48 feat(ui): restore layer arrange functionality 2024-04-19 09:32:56 -04:00
085f7bdbee feat(ui): add invert negative mode
Adds an additional negative conditioning using the inverted mask of the positive conditioning and the positive prompt. May be useful for mutually exclusive regions.
2024-04-19 09:32:56 -04:00
e4fcb6627a feat(ui): style regional prompt boxes 2024-04-19 09:32:56 -04:00
47aa6357d1 tidy(ui): organize regional prompts files 2024-04-19 09:32:56 -04:00
b81030fe27 tidy(ui): remove unused exports 2024-04-19 09:32:56 -04:00
a1a9f0da73 tidy(ui): remove more unused files 2024-04-19 09:32:56 -04:00
8f4f3b773c tidy(ui): remove unused files, code 2024-04-19 09:32:56 -04:00
00737efc31 tidy(ui): tidy naming of regional prompt utils 2024-04-19 09:32:56 -04:00
5924dc6ff6 feat(ui): transparency on regional prompts canvas 2024-04-19 09:32:56 -04:00
246fabf2a0 feat(ui): scaling regional prompt canvas 2024-04-19 09:32:56 -04:00
30e3e12513 feat(ui): layouting regional prompts 2024-04-19 09:32:56 -04:00
a5bfe2dccb feat(ui): support negative regional prompt 2024-04-19 09:32:56 -04:00
aa6bfc8645 fix(ui): wip misc regional prompting ui 2024-04-19 09:32:56 -04:00
20ccdb6c8f fix(ui): remove extra type in nodestate 2024-04-19 09:32:56 -04:00
8caa7bc2b1 feat(ui): abstract out bbox renderer 2024-04-19 09:32:56 -04:00
ede8826757 feat(ui): remove dep on stage in mouse handlers 2024-04-19 09:32:56 -04:00
ff7aa2558a feat(ui): display prompt when debugging regions 2024-04-19 09:32:56 -04:00
c9bf00b80b feat(ui): restore invoke button (wip) 2024-04-19 09:32:56 -04:00
1f8f429d55 feat(ui): abstract layer renderer 2024-04-19 09:32:56 -04:00
d34e431002 feat(ui): abstract brush preview logic 2024-04-19 09:32:56 -04:00
cdb481e836 feat(ui): use konva generics for types in selector functions 2024-04-19 09:32:56 -04:00
525e6d697c feat(ui): re-implement with imperative konva api (wip) 2024-04-19 09:32:56 -04:00
bbbb5479e8 feat(ui): re-implement with imperative konva api (wip) 2024-04-19 09:32:56 -04:00
ae7797f662 feat(ui): re-implement with imperative konva api (wip) 2024-04-19 09:32:56 -04:00
05deeb68fa feat(ui): draft of graph helper for regional prompts 2024-04-19 09:32:56 -04:00
602a59066e fix(nodes): handle invert in alpha_mask_to_tensor 2024-04-19 09:32:56 -04:00
d1db6198b5 perf(ui): memoize & otherwise optimize regional prompts ui 2024-04-19 09:32:56 -04:00
944fa1a847 chore(ui): lint 2024-04-19 09:32:56 -04:00
52e7daffe7 feat(ui): selected layer styling 2024-04-19 09:32:56 -04:00
cf4c1750cb fix(ui): caching broke layer rendering 2024-04-19 09:32:56 -04:00
de7ecc8e3e feat(ui): tweak bbox styling 2024-04-19 09:32:56 -04:00
6c0481ef51 fix(ui): do not reset layer position when toggling visibility 2024-04-19 09:32:56 -04:00
b9d0da44eb feat(ui): wip layer transparency 2024-04-19 09:32:56 -04:00
0a42d7d510 docs(ui): update docstrings for helper function 2024-04-19 09:32:56 -04:00
c1aae0815d feat(ui): regional prompting layout, styling 2024-04-19 09:32:56 -04:00
e7523bd1d9 fix(ui): fix layer debug 2024-04-19 09:32:56 -04:00
8911017bd1 feat(ui): selectable & draggable layers 2024-04-19 09:32:56 -04:00
fc26f3e430 feat(nodes): add alpha mask to tensor invocation 2024-04-19 09:32:56 -04:00
c89a24d1ea feat(ui): add util to get blobs from layers 2024-04-19 09:32:56 -04:00
52ba4966c9 feat(ui): wip regional prompting UI
- Add eraser tool, applies per layer
2024-04-19 09:32:56 -04:00
822dfa77fc feat(ui): wip regional prompting UI
- Arrange layers
- Layer visibility
- Layered brush preview
- Cleanup
2024-04-19 09:32:56 -04:00
83d359b681 feat(ui): wip regional prompting UI 2024-04-19 09:32:56 -04:00
f87eee810b feat(ui): rough out regional prompts components 2024-04-19 09:32:56 -04:00
1d1e4d02dc feat(ui): rough out regional prompts store 2024-04-19 09:32:56 -04:00
2b9f06dc4c Re-enable app shutdown actions (#6244)
* closes #6242

* only override sigINT during slow model scanning

* fix ruff formatting

---------

Co-authored-by: Lincoln Stein <lstein@gmail.com>
2024-04-19 06:45:42 -04:00
a35386f24c fix: IP Adapter Method having incorrect informational popover 2024-04-18 13:37:55 -04:00
ac1071a5e5 chore: v4.1.0 2024-04-18 07:19:22 +10:00
5295a398f3 translationBot(ui): update translation (Italian)
Currently translated at 98.4% (1122 of 1140 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2024-04-17 08:41:57 +10:00
0c7283c82d translationBot(ui): update translation (Turkish)
Currently translated at 50.8% (580 of 1140 strings)

translationBot(ui): update translation (Korean)

Currently translated at 43.3% (494 of 1140 strings)

translationBot(ui): update translation (Chinese (Simplified))

Currently translated at 80.9% (923 of 1140 strings)

translationBot(ui): update translation (Russian)

Currently translated at 98.8% (1127 of 1140 strings)

translationBot(ui): update translation (Dutch)

Currently translated at 63.7% (727 of 1140 strings)

translationBot(ui): update translation (Japanese)

Currently translated at 50.4% (575 of 1140 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.3% (1121 of 1140 strings)

translationBot(ui): update translation (Spanish)

Currently translated at 27.8% (317 of 1140 strings)

translationBot(ui): update translation (German)

Currently translated at 72.2% (824 of 1140 strings)

Co-authored-by: Anonymous <noreply@weblate.org>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/de/
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/es/
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/ja/
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/ko/
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/nl/
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/ru/
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/tr/
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/zh_Hans/
Translation: InvokeAI/Web UI
2024-04-17 08:41:57 +10:00
73ad173c74 update labels for Style Only and CompositionOnly to be designated as beta 2024-04-17 08:29:10 +10:00
c828a4e59f Add IP Adapter Style & Composition Modes (#6213)
## Summary

Until now IP Adapter had complete control on the contents of the output.
With this PR, users are now able to select "Style Only" or "Composition
Only" to draw just the style or layout of the reference image.

Based off: https://arxiv.org/abs/2404.02733

### New IP Method Option

- `Full` - Both style and layout of the refence image are used.
- `Style Only` - Only the style of the image is used
- `Composition Only` - Only the composition of the image is used.


![opera_0BkqZTwObO](https://github.com/invoke-ai/InvokeAI/assets/54517381/1b2fbbba-44c9-4c25-87cb-3711a17d13e3)

### Example Result


![demo](https://github.com/invoke-ai/InvokeAI/assets/54517381/703f3de5-e685-4691-acda-9338a4c10796)

### Notes

- Supports both SDXL and SD1.5

### Testing

- Just check and test if it works as expected with all IP Adapter models
- both SDXL and SD1.5

## Merge Plan

Good to merge once tested for all edge cases.
2024-04-16 14:23:36 -04:00
6bab040d24 Merge branch 'main' into ip-adapter-style-comp 2024-04-16 21:14:06 +05:30
f46bbaf8c4 fix: make ip-adapter weights not be optional 2024-04-16 21:12:45 +05:30
fce6b3e44c maybe solve race issue 2024-04-16 13:09:26 +10:00
d27907cc6d fix: entire reshaping block needs to be skipped 2024-04-16 04:29:53 +05:30
7ee3fef2db cleanup: better var names for the ip adapter weight collection block 2024-04-16 04:23:50 +05:30
b39ce642b6 cleanup: raise ValueErrors when target_blocks dont match base model 2024-04-16 04:12:30 +05:30
a148c4322c fix: IP Adapter weights being incorrectly applied
They were being overwritten rather than being appended
2024-04-16 04:10:41 +05:30
f6b7bc5d98 fix: Dynamically adapt height of control adapter opts 2024-04-16 01:18:43 +05:30
5f6c6abf9c chore: change IPAdapterAttentionWeights to a dataclass 2024-04-15 23:38:55 +05:30
cd76a31a8f fix: IP Adapter method not being recalled 2024-04-15 22:29:32 +05:30
e93f4d632d [util] Add generic torch device class (#6174)
* introduce new abstraction layer for GPU devices

* add unit test for device abstraction

* fix ruff

* convert TorchDeviceSelect into a stateless class

* move logic to select context-specific execution device into context API

* add mock hardware environments to pytest

* remove dangling mocker fixture

* fix unit test for running on non-CUDA systems

* remove unimplemented get_execution_device() call

* remove autocast precision

* Multiple changes:

1. Remove TorchDeviceSelect.get_execution_device(), as well as calls to
   context.models.get_execution_device().
2. Rename TorchDeviceSelect to TorchDevice
3. Added back the legacy public API defined in `invocation_api`, including
   choose_precision().
4. Added a config file migration script to accommodate removal of precision=autocast.

* add deprecation warnings to choose_torch_device() and choose_precision()

* fix test crash

* remove app_config argument from choose_torch_device() and choose_torch_dtype()

---------

Co-authored-by: Lincoln Stein <lstein@gmail.com>
2024-04-15 13:12:49 +00:00
5a8489bbfc perf(ui): memoize infill components 2024-04-15 22:50:54 +10:00
a24c9d0f7a perf(ui): optimize useFeatureStatus 2024-04-15 22:50:54 +10:00
7a92afc117 perf(ui): fix rerenders in nodes
Unmemoized selector tanking perf
2024-04-15 22:50:54 +10:00
b508945b11 feat(ui): edge labels
Add setting to render labels with format `Source Node label -> Target Node label` on edges.
2024-04-15 22:48:46 +10:00
8426f1e7b2 fix(experimental): Possible fix for conflict with regional embed length mismatch
Pushing this so people can test it out and see if this needs to be handled in a different way.
2024-04-14 12:19:19 +05:30
9cb0f63c44 refactor: fix a bunch of type issues in custom_attention 2024-04-13 14:17:25 +05:30
2d5786d3bb fix: Incorrect composition blocks for SD1.5 2024-04-13 13:52:10 +05:30
27466ffa1a chore: update the ip adapter node version 2024-04-13 13:39:08 +05:30
f50b156511 chore: do not include custom nodes in schema 2024-04-13 12:43:49 +05:30
9fc73743b2 feat: support SD1.5 2024-04-13 12:30:39 +05:30
d4393e4170 chore: linter fixes 2024-04-13 12:14:45 +05:30
145a0b029e Merge branch 'ip-adapter-style-comp' of https://github.com/blessedcoolant/InvokeAI into ip-adapter-style-comp 2024-04-13 12:13:06 +05:30
f2506cc769 chore: ruff fixes
Revert "chore: ruff fixes"

This reverts commit af36fe8c1e.

Revert "chore: ruff fixes"

This reverts commit af36fe8c1e.
2024-04-13 12:12:33 +05:30
7a67fd6a06 Revert "chore: ruff fixes"
This reverts commit af36fe8c1e.
2024-04-13 12:10:20 +05:30
af36fe8c1e chore: ruff fixes 2024-04-13 12:08:52 +05:30
e9f16ac8c7 feat: add UI for IP Adapter Method 2024-04-13 12:06:59 +05:30
6ea183f0d4 wip: Initial Implementation IP Adapter Style & Comp Modes 2024-04-13 11:09:45 +05:30
157 changed files with 5304 additions and 585 deletions

View File

@ -28,7 +28,7 @@ from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.config.config_default import get_config
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.backend.util.devices import get_torch_device_name
from invokeai.backend.util.devices import TorchDevice
from ..backend.util.logging import InvokeAILogger
from .api.dependencies import ApiDependencies
@ -63,7 +63,7 @@ logger = InvokeAILogger.get_logger(config=app_config)
mimetypes.add_type("application/javascript", ".js")
mimetypes.add_type("text/css", ".css")
torch_device_name = get_torch_device_name()
torch_device_name = TorchDevice.get_torch_device_name()
logger.info(f"Using torch device: {torch_device_name}")

View File

@ -24,7 +24,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
ConditioningFieldData,
SDXLConditioningInfo,
)
from invokeai.backend.util.devices import torch_dtype
from invokeai.backend.util.devices import TorchDevice
from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from .model import CLIPField
@ -99,7 +99,7 @@ class CompelInvocation(BaseInvocation):
tokenizer=tokenizer,
text_encoder=text_encoder,
textual_inversion_manager=ti_manager,
dtype_for_device_getter=torch_dtype,
dtype_for_device_getter=TorchDevice.choose_torch_dtype,
truncate_long_prompts=False,
)
@ -193,7 +193,7 @@ class SDXLPromptInvocationBase:
tokenizer=tokenizer,
text_encoder=text_encoder,
textual_inversion_manager=ti_manager,
dtype_for_device_getter=torch_dtype,
dtype_for_device_getter=TorchDevice.choose_torch_dtype,
truncate_long_prompts=False, # TODO:
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip
requires_pooled=get_pooled,

View File

@ -4,20 +4,8 @@ from typing import List, Literal, Optional, Union
from pydantic import BaseModel, Field, field_validator, model_validator
from typing_extensions import Self
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
InputField,
OutputField,
TensorField,
UIType,
)
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, TensorField, UIType
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
@ -36,6 +24,7 @@ class IPAdapterField(BaseModel):
ip_adapter_model: ModelIdentifierField = Field(description="The IP-Adapter model to use.")
image_encoder_model: ModelIdentifierField = Field(description="The name of the CLIP image encoder model.")
weight: Union[float, List[float]] = Field(default=1, description="The weight given to the IP-Adapter.")
target_blocks: List[str] = Field(default=[], description="The IP Adapter blocks to apply")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
)
@ -69,7 +58,7 @@ class IPAdapterOutput(BaseInvocationOutput):
CLIP_VISION_MODEL_MAP = {"ViT-H": "ip_adapter_sd_image_encoder", "ViT-G": "ip_adapter_sdxl_image_encoder"}
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.3.0")
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.4.0")
class IPAdapterInvocation(BaseInvocation):
"""Collects IP-Adapter info to pass to other nodes."""
@ -90,6 +79,9 @@ class IPAdapterInvocation(BaseInvocation):
weight: Union[float, List[float]] = InputField(
default=1, description="The weight given to the IP-Adapter", title="Weight"
)
method: Literal["full", "style", "composition"] = InputField(
default="full", description="The method to apply the IP-Adapter"
)
begin_step_percent: float = InputField(
default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
)
@ -124,12 +116,32 @@ class IPAdapterInvocation(BaseInvocation):
image_encoder_model = self._get_image_encoder(context, image_encoder_model_name)
if self.method == "style":
if ip_adapter_info.base == "sd-1":
target_blocks = ["up_blocks.1"]
elif ip_adapter_info.base == "sdxl":
target_blocks = ["up_blocks.0.attentions.1"]
else:
raise ValueError(f"Unsupported IP-Adapter base type: '{ip_adapter_info.base}'.")
elif self.method == "composition":
if ip_adapter_info.base == "sd-1":
target_blocks = ["down_blocks.2", "mid_block"]
elif ip_adapter_info.base == "sdxl":
target_blocks = ["down_blocks.2.attentions.1"]
else:
raise ValueError(f"Unsupported IP-Adapter base type: '{ip_adapter_info.base}'.")
elif self.method == "full":
target_blocks = ["block"]
else:
raise ValueError(f"Unexpected IP-Adapter method: '{self.method}'.")
return IPAdapterOutput(
ip_adapter=IPAdapterField(
image=self.image,
ip_adapter_model=self.ip_adapter_model,
image_encoder_model=ModelIdentifierField.from_config(image_encoder_model),
weight=self.weight,
target_blocks=target_blocks,
begin_step_percent=self.begin_step_percent,
end_step_percent=self.end_step_percent,
mask=self.mask,

View File

@ -51,6 +51,7 @@ from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager import BaseModelType, LoadedModel
from invokeai.backend.model_manager.config import MainConfigBase, ModelVariantType
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
@ -72,15 +73,12 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
image_resized_to_grid_as_tensor,
)
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from ...backend.util.devices import choose_precision, choose_torch_device
from ...backend.util.devices import TorchDevice
from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from .controlnet_image_processors import ControlField
from .model import ModelIdentifierField, UNetField, VAEField
if choose_torch_device() == torch.device("mps"):
from torch import mps
DEFAULT_PRECISION = choose_precision(choose_torch_device())
DEFAULT_PRECISION = TorchDevice.choose_torch_dtype()
@invocation_output("scheduler_output")
@ -188,7 +186,7 @@ class GradientMaskOutput(BaseInvocationOutput):
title="Create Gradient Mask",
tags=["mask", "denoise"],
category="latents",
version="1.0.0",
version="1.1.0",
)
class CreateGradientMaskInvocation(BaseInvocation):
"""Creates mask for denoising model run."""
@ -201,6 +199,32 @@ class CreateGradientMaskInvocation(BaseInvocation):
minimum_denoise: float = InputField(
default=0.0, ge=0, le=1, description="Minimum denoise level for the coherence region", ui_order=4
)
image: Optional[ImageField] = InputField(
default=None,
description="OPTIONAL: Only connect for specialized Inpainting models, masked_latents will be generated from the image with the VAE",
title="[OPTIONAL] Image",
ui_order=6,
)
unet: Optional[UNetField] = InputField(
description="OPTIONAL: If the Unet is a specialized Inpainting model, masked_latents will be generated from the image with the VAE",
default=None,
input=Input.Connection,
title="[OPTIONAL] UNet",
ui_order=5,
)
vae: Optional[VAEField] = InputField(
default=None,
description="OPTIONAL: Only connect for specialized Inpainting models, masked_latents will be generated from the image with the VAE",
title="[OPTIONAL] VAE",
input=Input.Connection,
ui_order=7,
)
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=8)
fp32: bool = InputField(
default=DEFAULT_PRECISION == "float32",
description=FieldDescriptions.fp32,
ui_order=9,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> GradientMaskOutput:
@ -236,8 +260,27 @@ class CreateGradientMaskInvocation(BaseInvocation):
expanded_mask_image = Image.fromarray((expanded_mask.squeeze(0).numpy() * 255).astype(np.uint8), mode="L")
expanded_image_dto = context.images.save(expanded_mask_image)
masked_latents_name = None
if self.unet is not None and self.vae is not None and self.image is not None:
# all three fields must be present at the same time
main_model_config = context.models.get_config(self.unet.unet.key)
assert isinstance(main_model_config, MainConfigBase)
if main_model_config.variant is ModelVariantType.Inpaint:
mask = blur_tensor
vae_info: LoadedModel = context.models.load(self.vae.vae)
image = context.images.get_pil(self.image.image_name)
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:
image_tensor = image_tensor.unsqueeze(0)
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
masked_latents = ImageToLatentsInvocation.vae_encode(
vae_info, self.fp32, self.tiled, masked_image.clone()
)
masked_latents_name = context.tensors.save(tensor=masked_latents)
return GradientMaskOutput(
denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=None, gradient=True),
denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=masked_latents_name, gradient=True),
expanded_mask_area=ImageField(image_name=expanded_image_dto.image_name),
)
@ -682,6 +725,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
IPAdapterData(
ip_adapter_model=ip_adapter_model,
weight=single_ip_adapter.weight,
target_blocks=single_ip_adapter.target_blocks,
begin_step_percent=single_ip_adapter.begin_step_percent,
end_step_percent=single_ip_adapter.end_step_percent,
ip_adapter_conditioning=IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds),
@ -959,9 +1003,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
result_latents = result_latents.to("cpu")
torch.cuda.empty_cache()
if choose_torch_device() == torch.device("mps"):
mps.empty_cache()
TorchDevice.empty_cache()
name = context.tensors.save(tensor=result_latents)
return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)
@ -1028,9 +1070,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
vae.disable_tiling()
# clear memory as vae decode can request a lot
torch.cuda.empty_cache()
if choose_torch_device() == torch.device("mps"):
mps.empty_cache()
TorchDevice.empty_cache()
with torch.inference_mode():
# copied from diffusers pipeline
@ -1042,9 +1082,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
image = VaeImageProcessor.numpy_to_pil(np_image)[0]
torch.cuda.empty_cache()
if choose_torch_device() == torch.device("mps"):
mps.empty_cache()
TorchDevice.empty_cache()
image_dto = context.images.save(image=image)
@ -1083,9 +1121,7 @@ class ResizeLatentsInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.tensors.load(self.latents.latents_name)
# TODO:
device = choose_torch_device()
device = TorchDevice.choose_torch_device()
resized_latents = torch.nn.functional.interpolate(
latents.to(device),
@ -1096,9 +1132,8 @@ class ResizeLatentsInvocation(BaseInvocation):
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
resized_latents = resized_latents.to("cpu")
torch.cuda.empty_cache()
if device == torch.device("mps"):
mps.empty_cache()
TorchDevice.empty_cache()
name = context.tensors.save(tensor=resized_latents)
return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
@ -1125,8 +1160,7 @@ class ScaleLatentsInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.tensors.load(self.latents.latents_name)
# TODO:
device = choose_torch_device()
device = TorchDevice.choose_torch_device()
# resizing
resized_latents = torch.nn.functional.interpolate(
@ -1138,9 +1172,7 @@ class ScaleLatentsInvocation(BaseInvocation):
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
resized_latents = resized_latents.to("cpu")
torch.cuda.empty_cache()
if device == torch.device("mps"):
mps.empty_cache()
TorchDevice.empty_cache()
name = context.tensors.save(tensor=resized_latents)
return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
@ -1272,8 +1304,7 @@ class BlendLatentsInvocation(BaseInvocation):
if latents_a.shape != latents_b.shape:
raise Exception("Latents to blend must be the same size.")
# TODO:
device = choose_torch_device()
device = TorchDevice.choose_torch_device()
def slerp(
t: Union[float, npt.NDArray[Any]], # FIXME: maybe use np.float32 here?
@ -1326,9 +1357,8 @@ class BlendLatentsInvocation(BaseInvocation):
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
blended_latents = blended_latents.to("cpu")
torch.cuda.empty_cache()
if device == torch.device("mps"):
mps.empty_cache()
TorchDevice.empty_cache()
name = context.tensors.save(tensor=blended_latents)
return LatentsOutput.build(latents_name=name, latents=blended_latents)

View File

@ -1,7 +1,8 @@
import numpy as np
import torch
from invokeai.app.invocations.baseinvocation import BaseInvocation, InvocationContext, invocation
from invokeai.app.invocations.fields import InputField, TensorField, WithMetadata
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, InvocationContext, invocation
from invokeai.app.invocations.fields import ImageField, InputField, TensorField, WithMetadata
from invokeai.app.invocations.primitives import MaskOutput
@ -34,3 +35,86 @@ class RectangleMaskInvocation(BaseInvocation, WithMetadata):
width=self.width,
height=self.height,
)
@invocation(
"alpha_mask_to_tensor",
title="Alpha Mask to Tensor",
tags=["conditioning"],
category="conditioning",
version="1.0.0",
classification=Classification.Beta,
)
class AlphaMaskToTensorInvocation(BaseInvocation):
"""Convert a mask image to a tensor. Opaque regions are 1 and transparent regions are 0."""
image: ImageField = InputField(description="The mask image to convert.")
invert: bool = InputField(default=False, description="Whether to invert the mask.")
def invoke(self, context: InvocationContext) -> MaskOutput:
image = context.images.get_pil(self.image.image_name)
mask = torch.zeros((1, image.height, image.width), dtype=torch.bool)
if self.invert:
mask[0] = torch.tensor(np.array(image)[:, :, 3] == 0, dtype=torch.bool)
else:
mask[0] = torch.tensor(np.array(image)[:, :, 3] > 0, dtype=torch.bool)
return MaskOutput(
mask=TensorField(tensor_name=context.tensors.save(mask)),
height=mask.shape[1],
width=mask.shape[2],
)
@invocation(
"invert_tensor_mask",
title="Invert Tensor Mask",
tags=["conditioning"],
category="conditioning",
version="1.0.0",
classification=Classification.Beta,
)
class InvertTensorMaskInvocation(BaseInvocation):
"""Inverts a tensor mask."""
mask: TensorField = InputField(description="The tensor mask to convert.")
def invoke(self, context: InvocationContext) -> MaskOutput:
mask = context.tensors.load(self.mask.tensor_name)
inverted = ~mask
return MaskOutput(
mask=TensorField(tensor_name=context.tensors.save(inverted)),
height=inverted.shape[1],
width=inverted.shape[2],
)
@invocation(
"image_mask_to_tensor",
title="Image Mask to Tensor",
tags=["conditioning"],
category="conditioning",
version="1.0.0",
)
class ImageMaskToTensorInvocation(BaseInvocation, WithMetadata):
"""Convert a mask image to a tensor. Converts the image to grayscale and uses thresholding at the specified value."""
image: ImageField = InputField(description="The mask image to convert.")
cutoff: int = InputField(ge=0, le=255, description="Cutoff (<)", default=128)
invert: bool = InputField(default=False, description="Whether to invert the mask.")
def invoke(self, context: InvocationContext) -> MaskOutput:
image = context.images.get_pil(self.image.image_name, mode="L")
mask = torch.zeros((1, image.height, image.width), dtype=torch.bool)
if self.invert:
mask[0] = torch.tensor(np.array(image)[:, :] >= self.cutoff, dtype=torch.bool)
else:
mask[0] = torch.tensor(np.array(image)[:, :] < self.cutoff, dtype=torch.bool)
return MaskOutput(
mask=TensorField(tensor_name=context.tensors.save(mask)),
height=mask.shape[1],
width=mask.shape[2],
)

View File

@ -36,6 +36,7 @@ class IPAdapterMetadataField(BaseModel):
image: ImageField = Field(description="The IP-Adapter image prompt.")
ip_adapter_model: ModelIdentifierField = Field(description="The IP-Adapter model.")
clip_vision_model: Literal["ViT-H", "ViT-G"] = Field(description="The CLIP Vision model")
method: Literal["full", "style", "composition"] = Field(description="Method to apply IP Weights with")
weight: Union[float, list[float]] = Field(description="The weight given to the IP-Adapter")
begin_step_percent: float = Field(description="When the IP-Adapter is first applied (% of total steps)")
end_step_percent: float = Field(description="When the IP-Adapter is last applied (% of total steps)")

View File

@ -9,7 +9,7 @@ from invokeai.app.invocations.fields import FieldDescriptions, InputField, Laten
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.misc import SEED_MAX
from ...backend.util.devices import choose_torch_device, torch_dtype
from ...backend.util.devices import TorchDevice
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
@ -46,7 +46,7 @@ def get_noise(
height // downsampling_factor,
width // downsampling_factor,
],
dtype=torch_dtype(device),
dtype=TorchDevice.choose_torch_dtype(device=device),
device=noise_device_type,
generator=generator,
).to("cpu")
@ -111,14 +111,14 @@ class NoiseInvocation(BaseInvocation):
@field_validator("seed", mode="before")
def modulo_seed(cls, v):
"""Returns the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
"""Return the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
return v % (SEED_MAX + 1)
def invoke(self, context: InvocationContext) -> NoiseOutput:
noise = get_noise(
width=self.width,
height=self.height,
device=choose_torch_device(),
device=TorchDevice.choose_torch_device(),
seed=self.seed,
use_cpu=self.use_cpu,
)

View File

@ -4,7 +4,6 @@ from typing import Literal
import cv2
import numpy as np
import torch
from PIL import Image
from pydantic import ConfigDict
@ -14,7 +13,7 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.download_with_progress import download_with_progress_bar
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.devices import TorchDevice
from .baseinvocation import BaseInvocation, invocation
from .fields import InputField, WithBoard, WithMetadata
@ -35,9 +34,6 @@ ESRGAN_MODEL_URLS: dict[str, str] = {
"RealESRGAN_x2plus.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
}
if choose_torch_device() == torch.device("mps"):
from torch import mps
@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.3.2")
class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
@ -120,9 +116,7 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
upscaled_image = upscaler.upscale(cv2_image)
pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA")
torch.cuda.empty_cache()
if choose_torch_device() == torch.device("mps"):
mps.empty_cache()
TorchDevice.empty_cache()
image_dto = context.images.save(image=pil_image)

View File

@ -27,12 +27,12 @@ DEFAULT_RAM_CACHE = 10.0
DEFAULT_VRAM_CACHE = 0.25
DEFAULT_CONVERT_CACHE = 20.0
DEVICE = Literal["auto", "cpu", "cuda", "cuda:1", "mps"]
PRECISION = Literal["auto", "float16", "bfloat16", "float32", "autocast"]
PRECISION = Literal["auto", "float16", "bfloat16", "float32"]
ATTENTION_TYPE = Literal["auto", "normal", "xformers", "sliced", "torch-sdp"]
ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8]
LOG_FORMAT = Literal["plain", "color", "syslog", "legacy"]
LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"]
CONFIG_SCHEMA_VERSION = "4.0.0"
CONFIG_SCHEMA_VERSION = "4.0.1"
def get_default_ram_cache_size() -> float:
@ -105,7 +105,7 @@ class InvokeAIAppConfig(BaseSettings):
lazy_offload: Keep models in VRAM until their space is needed.
log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`, `autocast`
precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`
sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
attention_type: Attention type.<br>Valid values: `auto`, `normal`, `xformers`, `sliced`, `torch-sdp`
attention_slice_size: Slice size, valid when attention_type=="sliced".<br>Valid values: `auto`, `balanced`, `max`, `1`, `2`, `3`, `4`, `5`, `6`, `7`, `8`
@ -370,6 +370,9 @@ def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
# `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used
if k == "max_vram_cache_size" and "vram" not in category_dict:
parsed_config_dict["vram"] = v
# autocast was removed in v4.0.1
if k == "precision" and v == "autocast":
parsed_config_dict["precision"] = "auto"
if k == "conf_path":
parsed_config_dict["legacy_models_yaml_path"] = v
if k == "legacy_conf_dir":
@ -392,6 +395,28 @@ def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
return config
def migrate_v4_0_0_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
"""Migrate v4.0.0 config dictionary to a current config object.
Args:
config_dict: A dictionary of settings from a v4.0.0 config file.
Returns:
An instance of `InvokeAIAppConfig` with the migrated settings.
"""
parsed_config_dict: dict[str, Any] = {}
for k, v in config_dict.items():
# autocast was removed from precision in v4.0.1
if k == "precision" and v == "autocast":
parsed_config_dict["precision"] = "auto"
else:
parsed_config_dict[k] = v
if k == "schema_version":
parsed_config_dict[k] = CONFIG_SCHEMA_VERSION
config = DefaultInvokeAIAppConfig.model_validate(parsed_config_dict)
return config
def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
"""Load and migrate a config file to the latest version.
@ -418,17 +443,21 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
raise RuntimeError(f"Failed to load and migrate v3 config file {config_path}: {e}") from e
migrated_config.write_file(config_path)
return migrated_config
else:
# Attempt to load as a v4 config file
try:
# Meta is not included in the model fields, so we need to validate it separately
config = InvokeAIAppConfig.model_validate(loaded_config_dict)
assert (
config.schema_version == CONFIG_SCHEMA_VERSION
), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION}: {config.schema_version}"
return config
except Exception as e:
raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e
if loaded_config_dict["schema_version"] == "4.0.0":
loaded_config_dict = migrate_v4_0_0_config_dict(loaded_config_dict)
loaded_config_dict.write_file(config_path)
# Attempt to load as a v4 config file
try:
# Meta is not included in the model fields, so we need to validate it separately
config = InvokeAIAppConfig.model_validate(loaded_config_dict)
assert (
config.schema_version == CONFIG_SCHEMA_VERSION
), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION}: {config.schema_version}"
return config
except Exception as e:
raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e
@lru_cache(maxsize=1)

View File

@ -270,7 +270,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
job.dest.parent.mkdir(parents=True, exist_ok=True)
job.download_path = job.dest
assert job.download_path is not None
assert job.download_path
# Don't clobber an existing file. See commit 82c2c85202f88c6d24ff84710f297cfc6ae174af
# for code that instead resumes an interrupted download.
@ -280,9 +280,6 @@ class DownloadQueueService(DownloadQueueServiceBase):
# append ".downloading" to the path
in_progress_path = self._in_progress_path(job.download_path)
# catch rare race condition that is appearing in unit tests.
assert in_progress_path.parent.exists(), f"Directory doesn't exist! in_progress_path={in_progress_path}; parent={in_progress_path.parent}"
# signal caller that the download is starting. At this point, key fields such as
# download_path and total_bytes will be populated. We call it here because the might
# discover that the local file is already complete and generate a COMPLETED status.

View File

@ -3,7 +3,6 @@
import locale
import os
import re
import signal
import threading
import time
from hashlib import sha256
@ -13,6 +12,7 @@ from shutil import copyfile, copytree, move, rmtree
from tempfile import mkdtemp
from typing import Any, Dict, List, Optional, Union
import torch
import yaml
from huggingface_hub import HfFolder
from pydantic.networks import AnyHttpUrl
@ -42,7 +42,8 @@ from invokeai.backend.model_manager.metadata.metadata_base import HuggingFaceMet
from invokeai.backend.model_manager.probe import ModelProbe
from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.util import InvokeAILogger
from invokeai.backend.util.devices import choose_precision, choose_torch_device
from invokeai.backend.util.catch_sigint import catch_sigint
from invokeai.backend.util.devices import TorchDevice
from .model_install_base import (
MODEL_SOURCE_TO_TYPE_MAP,
@ -111,17 +112,6 @@ class ModelInstallService(ModelInstallServiceBase):
def start(self, invoker: Optional[Invoker] = None) -> None:
"""Start the installer thread."""
# Yes, this is weird. When the installer thread is running, the
# thread masks the ^C signal. When we receive a
# sigINT, we stop the thread, reset sigINT, and send a new
# sigINT to the parent process.
def sigint_handler(signum, frame):
self.stop()
signal.signal(signal.SIGINT, signal.SIG_DFL)
signal.raise_signal(signal.SIGINT)
signal.signal(signal.SIGINT, sigint_handler)
with self._lock:
if self._running:
raise Exception("Attempt to start the installer service twice")
@ -131,7 +121,8 @@ class ModelInstallService(ModelInstallServiceBase):
# In normal use, we do not want to scan the models directory - it should never have orphaned models.
# We should only do the scan when the flag is set (which should only be set when testing).
if self.app_config.scan_models_on_startup:
self._register_orphaned_models()
with catch_sigint():
self._register_orphaned_models()
# Check all models' paths and confirm they exist. A model could be missing if it was installed on a volume
# that isn't currently mounted. In this case, we don't want to delete the model from the database, but we do
@ -634,11 +625,10 @@ class ModelInstallService(ModelInstallServiceBase):
self._next_job_id += 1
return id
@staticmethod
def _guess_variant() -> Optional[ModelRepoVariant]:
def _guess_variant(self) -> Optional[ModelRepoVariant]:
"""Guess the best HuggingFace variant type to download."""
precision = choose_precision(choose_torch_device())
return ModelRepoVariant.FP16 if precision == "float16" else None
precision = TorchDevice.choose_torch_dtype()
return ModelRepoVariant.FP16 if precision == torch.float16 else None
def _import_local_model(self, source: LocalModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
return ModelInstallJob(
@ -754,6 +744,8 @@ class ModelInstallService(ModelInstallServiceBase):
self._download_cache[download_job.source] = install_job # matches a download job to an install job
install_job.download_parts.add(download_job)
# only start the jobs once install_job.download_parts is fully populated
for download_job in install_job.download_parts:
self._download_queue.submit_download_job(
download_job,
on_start=self._download_started_callback,
@ -762,6 +754,7 @@ class ModelInstallService(ModelInstallServiceBase):
on_error=self._download_error_callback,
on_cancelled=self._download_cancelled_callback,
)
return install_job
def _stat_size(self, path: Path) -> int:

View File

@ -1,12 +1,14 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
"""Implementation of ModelManagerServiceBase."""
from typing import Optional
import torch
from typing_extensions import Self
from invokeai.app.services.invoker import Invoker
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
from ..config import InvokeAIAppConfig
@ -67,7 +69,7 @@ class ModelManagerService(ModelManagerServiceBase):
model_record_service: ModelRecordServiceBase,
download_queue: DownloadQueueServiceBase,
events: EventServiceBase,
execution_device: torch.device = choose_torch_device(),
execution_device: Optional[torch.device] = None,
) -> Self:
"""
Construct the model manager service instance.
@ -82,7 +84,7 @@ class ModelManagerService(ModelManagerServiceBase):
max_vram_cache_size=app_config.vram,
lazy_offloading=app_config.lazy_offload,
logger=logger,
execution_device=execution_device,
execution_device=execution_device or TorchDevice.choose_torch_device(),
)
convert_cache = ModelConvertCache(cache_path=app_config.convert_cache_path, max_size=app_config.convert_cache)
loader = ModelLoadService(

View File

@ -1,6 +1,6 @@
import shutil
import tempfile
import typing
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Optional, TypeVar
@ -17,12 +17,6 @@ if TYPE_CHECKING:
T = TypeVar("T")
@dataclass
class DeleteAllResult:
deleted_count: int
freed_space_bytes: float
class ObjectSerializerDisk(ObjectSerializerBase[T]):
"""Disk-backed storage for arbitrary python objects. Serialization is handled by `torch.save` and `torch.load`.
@ -35,6 +29,12 @@ class ObjectSerializerDisk(ObjectSerializerBase[T]):
self._ephemeral = ephemeral
self._base_output_dir = output_dir
self._base_output_dir.mkdir(parents=True, exist_ok=True)
if self._ephemeral:
# Remove dangling tempdirs that might have been left over from an earlier unplanned shutdown.
for temp_dir in filter(Path.is_dir, self._base_output_dir.glob("tmp*")):
shutil.rmtree(temp_dir)
# Must specify `ignore_cleanup_errors` to avoid fatal errors during cleanup on Windows
self._tempdir = (
tempfile.TemporaryDirectory(dir=self._base_output_dir, ignore_cleanup_errors=True) if ephemeral else None

View File

@ -13,7 +13,7 @@ from invokeai.app.services.config.config_default import get_config
from invokeai.app.util.download_with_progress import download_with_progress_bar
from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
config = get_config()
@ -56,7 +56,7 @@ class DepthAnythingDetector:
def __init__(self) -> None:
self.model = None
self.model_size: Union[Literal["large", "base", "small"], None] = None
self.device = choose_torch_device()
self.device = TorchDevice.choose_torch_device()
def load_model(self, model_size: Literal["large", "base", "small"] = "small"):
DEPTH_ANYTHING_MODEL_PATH = config.models_path / DEPTH_ANYTHING_MODELS[model_size]["local"]
@ -81,7 +81,7 @@ class DepthAnythingDetector:
self.model.load_state_dict(torch.load(DEPTH_ANYTHING_MODEL_PATH.as_posix(), map_location="cpu"))
self.model.eval()
self.model.to(choose_torch_device())
self.model.to(self.device)
return self.model
def __call__(self, image: Image.Image, resolution: int = 512) -> Image.Image:
@ -94,7 +94,7 @@ class DepthAnythingDetector:
image_height, image_width = np_image.shape[:2]
np_image = transform({"image": np_image})["image"]
tensor_image = torch.from_numpy(np_image).unsqueeze(0).to(choose_torch_device())
tensor_image = torch.from_numpy(np_image).unsqueeze(0).to(self.device)
with torch.no_grad():
depth = self.model(tensor_image)

View File

@ -7,7 +7,7 @@ import onnxruntime as ort
from invokeai.app.services.config.config_default import get_config
from invokeai.app.util.download_with_progress import download_with_progress_bar
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.devices import TorchDevice
from .onnxdet import inference_detector
from .onnxpose import inference_pose
@ -28,9 +28,9 @@ config = get_config()
class Wholebody:
def __init__(self):
device = choose_torch_device()
device = TorchDevice.choose_torch_device()
providers = ["CUDAExecutionProvider"] if device == "cuda" else ["CPUExecutionProvider"]
providers = ["CUDAExecutionProvider"] if device.type == "cuda" else ["CPUExecutionProvider"]
DET_MODEL_PATH = config.models_path / DWPOSE_MODELS["yolox_l.onnx"]["local"]
download_with_progress_bar("yolox_l.onnx", DWPOSE_MODELS["yolox_l.onnx"]["url"], DET_MODEL_PATH)

View File

@ -8,7 +8,7 @@ from PIL import Image
import invokeai.backend.util.logging as logger
from invokeai.app.services.config.config_default import get_config
from invokeai.app.util.download_with_progress import download_with_progress_bar
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.devices import TorchDevice
def norm_img(np_img):
@ -29,7 +29,7 @@ def load_jit_model(url_or_path, device):
class LaMA:
def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any:
device = choose_torch_device()
device = TorchDevice.choose_torch_device()
model_location = get_config().models_path / "core/misc/lama/lama.pt"
if not model_location.exists():

View File

@ -11,7 +11,7 @@ from cv2.typing import MatLike
from tqdm import tqdm
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.devices import TorchDevice
"""
Adapted from https://github.com/xinntao/Real-ESRGAN/blob/master/realesrgan/utils.py
@ -65,7 +65,7 @@ class RealESRGAN:
self.pre_pad = pre_pad
self.mod_scale: Optional[int] = None
self.half = half
self.device = choose_torch_device()
self.device = TorchDevice.choose_torch_device()
loadnet = torch.load(model_path, map_location=torch.device("cpu"))

View File

@ -13,7 +13,7 @@ from transformers import AutoFeatureExtractor
import invokeai.backend.util.logging as logger
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.silence_warnings import SilenceWarnings
CHECKER_PATH = "core/convert/stable-diffusion-safety-checker"
@ -51,7 +51,7 @@ class SafetyChecker:
cls._load_safety_checker()
if cls.safety_checker is None or cls.feature_extractor is None:
return False
device = choose_torch_device()
device = TorchDevice.choose_torch_device()
features = cls.feature_extractor([image], return_tensors="pt")
features.to(device)
cls.safety_checker.to(device)

View File

@ -301,12 +301,12 @@ class MainConfigBase(ModelConfigBase):
default_settings: Optional[MainModelDefaultSettings] = Field(
description="Default settings for this model", default=None
)
variant: ModelVariantType = ModelVariantType.Normal
class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase):
"""Model config for main checkpoint models."""
variant: ModelVariantType = ModelVariantType.Normal
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False

View File

@ -18,7 +18,7 @@ from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoad
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data, calc_model_size_by_fs
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
from invokeai.backend.util.devices import choose_torch_device, torch_dtype
from invokeai.backend.util.devices import TorchDevice
# TO DO: The loader is not thread safe!
@ -37,7 +37,7 @@ class ModelLoader(ModelLoaderBase):
self._logger = logger
self._ram_cache = ram_cache
self._convert_cache = convert_cache
self._torch_dtype = torch_dtype(choose_torch_device())
self._torch_dtype = TorchDevice.choose_torch_dtype()
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""

View File

@ -30,15 +30,12 @@ import torch
from invokeai.backend.model_manager import AnyModel, SubModelType
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
from .model_cache_base import CacheRecord, CacheStats, ModelCacheBase, ModelLockerBase
from .model_locker import ModelLocker
if choose_torch_device() == torch.device("mps"):
from torch import mps
# Maximum size of the cache, in gigs
# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
DEFAULT_MAX_CACHE_SIZE = 6.0
@ -244,9 +241,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
)
torch.cuda.empty_cache()
if choose_torch_device() == torch.device("mps"):
mps.empty_cache()
TorchDevice.empty_cache()
def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
"""Move model into the indicated device.
@ -416,10 +411,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
self.stats.cleared = models_cleared
gc.collect()
torch.cuda.empty_cache()
if choose_torch_device() == torch.device("mps"):
mps.empty_cache()
TorchDevice.empty_cache()
self.logger.debug(f"After making room: cached_models={len(self._cached_models)}")
def _delete_cache_entry(self, cache_entry: CacheRecord[AnyModel]) -> None:

View File

@ -17,7 +17,7 @@ from diffusers.utils import logging as dlogging
from invokeai.app.services.model_install import ModelInstallServiceBase
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
from invokeai.backend.util.devices import choose_torch_device, torch_dtype
from invokeai.backend.util.devices import TorchDevice
from . import (
AnyModelConfig,
@ -43,6 +43,7 @@ class ModelMerger(object):
Initialize a ModelMerger object with the model installer.
"""
self._installer = installer
self._dtype = TorchDevice.choose_torch_dtype()
def merge_diffusion_models(
self,
@ -68,7 +69,7 @@ class ModelMerger(object):
warnings.simplefilter("ignore")
verbosity = dlogging.get_verbosity()
dlogging.set_verbosity_error()
dtype = torch.float16 if variant == "fp16" else torch_dtype(choose_torch_device())
dtype = torch.float16 if variant == "fp16" else self._dtype
# Note that checkpoint_merger will not work with downloaded HuggingFace fp16 models
# until upstream https://github.com/huggingface/diffusers/pull/6670 is merged and released.
@ -151,7 +152,7 @@ class ModelMerger(object):
dump_path.mkdir(parents=True, exist_ok=True)
dump_path = dump_path / merged_model_name
dtype = torch.float16 if variant == "fp16" else torch_dtype(choose_torch_device())
dtype = torch.float16 if variant == "fp16" else self._dtype
merged_pipe.save_pretrained(dump_path.as_posix(), safe_serialization=True, torch_dtype=dtype, variant=variant)
# register model and get its unique key

View File

@ -155,7 +155,7 @@ STARTER_MODELS: list[StarterModel] = [
StarterModel(
name="IP Adapter",
base=BaseModelType.StableDiffusion1,
source="InvokeAI/ip_adapter_sd15",
source="https://huggingface.co/InvokeAI/ip_adapter_sd15/resolve/main/ip-adapter_sd15.safetensors",
description="IP-Adapter for SD 1.5 models",
type=ModelType.IPAdapter,
dependencies=[ip_adapter_sd_image_encoder],
@ -163,7 +163,7 @@ STARTER_MODELS: list[StarterModel] = [
StarterModel(
name="IP Adapter Plus",
base=BaseModelType.StableDiffusion1,
source="InvokeAI/ip_adapter_plus_sd15",
source="https://huggingface.co/InvokeAI/ip_adapter_plus_sd15/resolve/main/ip-adapter-plus_sd15.safetensors",
description="Refined IP-Adapter for SD 1.5 models",
type=ModelType.IPAdapter,
dependencies=[ip_adapter_sd_image_encoder],
@ -171,7 +171,7 @@ STARTER_MODELS: list[StarterModel] = [
StarterModel(
name="IP Adapter Plus Face",
base=BaseModelType.StableDiffusion1,
source="InvokeAI/ip_adapter_plus_face_sd15",
source="https://huggingface.co/InvokeAI/ip_adapter_plus_face_sd15/resolve/main/ip-adapter-plus-face_sd15.safetensors",
description="Refined IP-Adapter for SD 1.5 models, adapted for faces",
type=ModelType.IPAdapter,
dependencies=[ip_adapter_sd_image_encoder],
@ -179,7 +179,7 @@ STARTER_MODELS: list[StarterModel] = [
StarterModel(
name="IP Adapter SDXL",
base=BaseModelType.StableDiffusionXL,
source="InvokeAI/ip_adapter_sdxl",
source="https://huggingface.co/InvokeAI/ip_adapter_sdxl_vit_h/resolve/main/ip-adapter_sdxl_vit-h.safetensors",
description="IP-Adapter for SDXL models",
type=ModelType.IPAdapter,
dependencies=[ip_adapter_sdxl_image_encoder],

View File

@ -21,14 +21,11 @@ from pydantic import Field
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
IPAdapterData,
TextConditioningData,
)
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData
from invokeai.backend.util.attention import auto_detect_slice_size
from invokeai.backend.util.devices import normalize_device
from invokeai.backend.util.devices import TorchDevice
@dataclass
@ -258,7 +255,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if self.unet.device.type == "cpu" or self.unet.device.type == "mps":
mem_free = psutil.virtual_memory().free
elif self.unet.device.type == "cuda":
mem_free, _ = torch.cuda.mem_get_info(normalize_device(self.unet.device))
mem_free, _ = torch.cuda.mem_get_info(TorchDevice.normalize(self.unet.device))
else:
raise ValueError(f"unrecognized device {self.unet.device}")
# input tensor of [1, 4, h/8, w/8]
@ -394,8 +391,13 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
unet_attention_patcher = None
self.use_ip_adapter = use_ip_adapter
attn_ctx = nullcontext()
if use_ip_adapter or use_regional_prompting:
ip_adapters = [ipa.ip_adapter_model for ipa in ip_adapter_data] if use_ip_adapter else None
ip_adapters: Optional[List[UNetIPAdapterData]] = (
[{"ip_adapter": ipa.ip_adapter_model, "target_blocks": ipa.target_blocks} for ipa in ip_adapter_data]
if use_ip_adapter
else None
)
unet_attention_patcher = UNetAttentionPatcher(ip_adapters)
attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)

View File

@ -53,6 +53,7 @@ class IPAdapterData:
ip_adapter_model: IPAdapter
ip_adapter_conditioning: IPAdapterConditioningInfo
mask: torch.Tensor
target_blocks: List[str]
# Either a single weight applied to all steps, or a list of weights for each step.
weight: Union[float, List[float]] = 1.0

View File

@ -1,4 +1,5 @@
from typing import Optional
from dataclasses import dataclass
from typing import List, Optional, cast
import torch
import torch.nn.functional as F
@ -9,6 +10,12 @@ from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import Regiona
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
@dataclass
class IPAdapterAttentionWeights:
ip_adapter_weights: IPAttentionProcessorWeights
skip: bool
class CustomAttnProcessor2_0(AttnProcessor2_0):
"""A custom implementation of AttnProcessor2_0 that supports additional Invoke features.
This implementation is based on
@ -20,7 +27,7 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
def __init__(
self,
ip_adapter_weights: Optional[list[IPAttentionProcessorWeights]] = None,
ip_adapter_attention_weights: Optional[List[IPAdapterAttentionWeights]] = None,
):
"""Initialize a CustomAttnProcessor2_0.
Note: Arguments that are the same for all attention layers are passed to __call__(). Arguments that are
@ -30,23 +37,22 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
for the i'th IP-Adapter.
"""
super().__init__()
self._ip_adapter_weights = ip_adapter_weights
def _is_ip_adapter_enabled(self) -> bool:
return self._ip_adapter_weights is not None
self._ip_adapter_attention_weights = ip_adapter_attention_weights
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
# For regional prompting:
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
# For Regional Prompting:
regional_prompt_data: Optional[RegionalPromptData] = None,
percent_through: Optional[torch.FloatTensor] = None,
percent_through: Optional[torch.Tensor] = None,
# For IP-Adapter:
regional_ip_data: Optional[RegionalIPData] = None,
*args,
**kwargs,
) -> torch.FloatTensor:
"""Apply attention.
Args:
@ -130,17 +136,19 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
# Apply IP-Adapter conditioning.
if is_cross_attention:
if self._is_ip_adapter_enabled():
if self._ip_adapter_attention_weights:
assert regional_ip_data is not None
ip_masks = regional_ip_data.get_masks(query_seq_len=query_seq_len)
assert (
len(regional_ip_data.image_prompt_embeds)
== len(self._ip_adapter_weights)
== len(self._ip_adapter_attention_weights)
== len(regional_ip_data.scales)
== ip_masks.shape[1]
)
for ipa_index, ipa_embed in enumerate(regional_ip_data.image_prompt_embeds):
ipa_weights = self._ip_adapter_weights[ipa_index]
ipa_weights = self._ip_adapter_attention_weights[ipa_index].ip_adapter_weights
ipa_scale = regional_ip_data.scales[ipa_index]
ip_mask = ip_masks[0, ipa_index, ...]
@ -153,29 +161,33 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
# Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
ip_key = ipa_weights.to_k_ip(ip_hidden_states)
ip_value = ipa_weights.to_v_ip(ip_hidden_states)
if not self._ip_adapter_attention_weights[ipa_index].skip:
ip_key = ipa_weights.to_k_ip(ip_hidden_states)
ip_value = ipa_weights.to_v_ip(ip_hidden_states)
# Expected ip_key and ip_value shape: (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
# Expected ip_key and ip_value shape:
# (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# Expected ip_key and ip_value shape: (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
# Expected ip_key and ip_value shape:
# (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
# Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim)
# Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim)
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(
batch_size, -1, attn.heads * head_dim
)
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
ip_hidden_states = ip_hidden_states.to(query.dtype)
ip_hidden_states = ip_hidden_states.to(query.dtype)
# Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
hidden_states = hidden_states + ipa_scale * ip_hidden_states * ip_mask
# Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
hidden_states = hidden_states + ipa_scale * ip_hidden_states * ip_mask
else:
# If IP-Adapter is not enabled, then regional_ip_data should not be passed in.
assert regional_ip_data is None
@ -188,11 +200,15 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# End of unmodified block from AttnProcessor2_0
return hidden_states
# casting torch.Tensor to torch.FloatTensor to avoid type issues
return cast(torch.FloatTensor, hidden_states)

View File

@ -1,17 +1,25 @@
from contextlib import contextmanager
from typing import Optional
from typing import List, Optional, TypedDict
from diffusers.models import UNet2DConditionModel
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import (
CustomAttnProcessor2_0,
IPAdapterAttentionWeights,
)
class UNetIPAdapterData(TypedDict):
ip_adapter: IPAdapter
target_blocks: List[str]
class UNetAttentionPatcher:
"""A class for patching a UNet with CustomAttnProcessor2_0 attention layers."""
def __init__(self, ip_adapters: Optional[list[IPAdapter]]):
self._ip_adapters = ip_adapters
def __init__(self, ip_adapter_data: Optional[List[UNetIPAdapterData]]):
self._ip_adapters = ip_adapter_data
def _prepare_attention_processors(self, unet: UNet2DConditionModel):
"""Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention
@ -26,9 +34,22 @@ class UNetAttentionPatcher:
attn_procs[name] = CustomAttnProcessor2_0()
else:
# Collect the weights from each IP Adapter for the idx'th attention processor.
attn_procs[name] = CustomAttnProcessor2_0(
[ip_adapter.attn_weights.get_attention_processor_weights(idx) for ip_adapter in self._ip_adapters],
)
ip_adapter_attention_weights_collection: list[IPAdapterAttentionWeights] = []
for ip_adapter in self._ip_adapters:
ip_adapter_weights = ip_adapter["ip_adapter"].attn_weights.get_attention_processor_weights(idx)
skip = True
for block in ip_adapter["target_blocks"]:
if block in name:
skip = False
break
ip_adapter_attention_weights: IPAdapterAttentionWeights = IPAdapterAttentionWeights(
ip_adapter_weights=ip_adapter_weights, skip=skip
)
ip_adapter_attention_weights_collection.append(ip_adapter_attention_weights)
attn_procs[name] = CustomAttnProcessor2_0(ip_adapter_attention_weights_collection)
return attn_procs
@contextmanager

View File

@ -2,7 +2,6 @@
Initialization file for invokeai.backend.util
"""
from .devices import choose_precision, choose_torch_device
from .logging import InvokeAILogger
from .util import GIG, Chdir, directory_size
@ -11,6 +10,4 @@ __all__ = [
"directory_size",
"Chdir",
"InvokeAILogger",
"choose_precision",
"choose_torch_device",
]

View File

@ -0,0 +1,29 @@
"""
This module defines a context manager `catch_sigint()` which temporarily replaces
the sigINT handler defined by the ASGI in order to allow the user to ^C the application
and shut it down immediately. This was implemented in order to allow the user to interrupt
slow model hashing during startup.
Use like this:
from invokeai.backend.util.catch_sigint import catch_sigint
with catch_sigint():
run_some_hard_to_interrupt_process()
"""
import signal
from contextlib import contextmanager
from typing import Generator
def sigint_handler(signum, frame): # type: ignore
signal.signal(signal.SIGINT, signal.SIG_DFL)
signal.raise_signal(signal.SIGINT)
@contextmanager
def catch_sigint() -> Generator[None, None, None]:
original_handler = signal.getsignal(signal.SIGINT)
signal.signal(signal.SIGINT, sigint_handler)
yield
signal.signal(signal.SIGINT, original_handler)

View File

@ -1,89 +1,110 @@
from __future__ import annotations
from contextlib import nullcontext
from typing import Literal, Optional, Union
from typing import Dict, Literal, Optional, Union
import torch
from torch import autocast
from deprecated import deprecated
from invokeai.app.services.config.config_default import PRECISION, get_config
from invokeai.app.services.config.config_default import get_config
# legacy APIs
TorchPrecisionNames = Literal["float32", "float16", "bfloat16"]
CPU_DEVICE = torch.device("cpu")
CUDA_DEVICE = torch.device("cuda")
MPS_DEVICE = torch.device("mps")
@deprecated("Use TorchDevice.choose_torch_dtype() instead.") # type: ignore
def choose_precision(device: torch.device) -> TorchPrecisionNames:
"""Return the string representation of the recommended torch device."""
torch_dtype = TorchDevice.choose_torch_dtype(device)
return PRECISION_TO_NAME[torch_dtype]
@deprecated("Use TorchDevice.choose_torch_device() instead.") # type: ignore
def choose_torch_device() -> torch.device:
"""Convenience routine for guessing which GPU device to run model on"""
config = get_config()
if config.device == "auto":
if torch.cuda.is_available():
return torch.device("cuda")
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
return torch.device("mps")
"""Return the torch.device to use for accelerated inference."""
return TorchDevice.choose_torch_device()
@deprecated("Use TorchDevice.choose_torch_dtype() instead.") # type: ignore
def torch_dtype(device: torch.device) -> torch.dtype:
"""Return the torch precision for the recommended torch device."""
return TorchDevice.choose_torch_dtype(device)
NAME_TO_PRECISION: Dict[TorchPrecisionNames, torch.dtype] = {
"float32": torch.float32,
"float16": torch.float16,
"bfloat16": torch.bfloat16,
}
PRECISION_TO_NAME: Dict[torch.dtype, TorchPrecisionNames] = {v: k for k, v in NAME_TO_PRECISION.items()}
class TorchDevice:
"""Abstraction layer for torch devices."""
@classmethod
def choose_torch_device(cls) -> torch.device:
"""Return the torch.device to use for accelerated inference."""
app_config = get_config()
if app_config.device != "auto":
device = torch.device(app_config.device)
elif torch.cuda.is_available():
device = CUDA_DEVICE
elif torch.backends.mps.is_available():
device = MPS_DEVICE
else:
return CPU_DEVICE
else:
return torch.device(config.device)
device = CPU_DEVICE
return cls.normalize(device)
@classmethod
def choose_torch_dtype(cls, device: Optional[torch.device] = None) -> torch.dtype:
"""Return the precision to use for accelerated inference."""
device = device or cls.choose_torch_device()
config = get_config()
if device.type == "cuda" and torch.cuda.is_available():
device_name = torch.cuda.get_device_name(device)
if "GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name:
# These GPUs have limited support for float16
return cls._to_dtype("float32")
elif config.precision == "auto":
# Default to float16 for CUDA devices
return cls._to_dtype("float16")
else:
# Use the user-defined precision
return cls._to_dtype(config.precision)
def get_torch_device_name() -> str:
device = choose_torch_device()
return torch.cuda.get_device_name(device) if device.type == "cuda" else device.type.upper()
elif device.type == "mps" and torch.backends.mps.is_available():
if config.precision == "auto":
# Default to float16 for MPS devices
return cls._to_dtype("float16")
else:
# Use the user-defined precision
return cls._to_dtype(config.precision)
# CPU / safe fallback
return cls._to_dtype("float32")
@classmethod
def get_torch_device_name(cls) -> str:
"""Return the device name for the current torch device."""
device = cls.choose_torch_device()
return torch.cuda.get_device_name(device) if device.type == "cuda" else device.type.upper()
def choose_precision(device: torch.device) -> Literal["float32", "float16", "bfloat16"]:
"""Return an appropriate precision for the given torch device."""
app_config = get_config()
if device.type == "cuda":
device_name = torch.cuda.get_device_name(device)
if "GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name:
# These GPUs have limited support for float16
return "float32"
elif app_config.precision == "auto" or app_config.precision == "autocast":
# Default to float16 for CUDA devices
return "float16"
else:
# Use the user-defined precision
return app_config.precision
elif device.type == "mps":
if app_config.precision == "auto" or app_config.precision == "autocast":
# Default to float16 for MPS devices
return "float16"
else:
# Use the user-defined precision
return app_config.precision
# CPU / safe fallback
return "float32"
def torch_dtype(device: Optional[torch.device] = None) -> torch.dtype:
device = device or choose_torch_device()
precision = choose_precision(device)
if precision == "float16":
return torch.float16
if precision == "bfloat16":
return torch.bfloat16
else:
# "auto", "autocast", "float32"
return torch.float32
def choose_autocast(precision: PRECISION):
"""Returns an autocast context or nullcontext for the given precision string"""
# float16 currently requires autocast to avoid errors like:
# 'expected scalar type Half but found Float'
if precision == "autocast" or precision == "float16":
return autocast
return nullcontext
def normalize_device(device: Union[str, torch.device]) -> torch.device:
"""Ensure device has a device index defined, if appropriate."""
device = torch.device(device)
if device.index is None:
# cuda might be the only torch backend that currently uses the device index?
# I don't see anything like `current_device` for cpu or mps.
if device.type == "cuda":
@classmethod
def normalize(cls, device: Union[str, torch.device]) -> torch.device:
"""Add the device index to CUDA devices."""
device = torch.device(device)
if device.index is None and device.type == "cuda" and torch.cuda.is_available():
device = torch.device(device.type, torch.cuda.current_device())
return device
return device
@classmethod
def empty_cache(cls) -> None:
"""Clear the GPU device cache."""
if torch.backends.mps.is_available():
torch.mps.empty_cache()
if torch.cuda.is_available():
torch.cuda.empty_cache()
@classmethod
def _to_dtype(cls, precision_name: TorchPrecisionNames) -> torch.dtype:
return NAME_TO_PRECISION[precision_name]

View File

@ -11,6 +11,7 @@ import { createStore } from '../src/app/store/store';
// @ts-ignore
import translationEN from '../public/locales/en.json';
import { ReduxInit } from './ReduxInit';
import { $store } from 'app/store/nanostores/store';
i18n.use(initReactI18next).init({
lng: 'en',
@ -25,6 +26,7 @@ i18n.use(initReactI18next).init({
});
const store = createStore(undefined, false);
$store.set(store);
$baseUrl.set('http://localhost:9090');
const preview: Preview = {

View File

@ -25,7 +25,7 @@
"typegen": "node scripts/typegen.js",
"preview": "vite preview",
"lint:knip": "knip",
"lint:dpdm": "dpdm --no-warning --no-tree --transform --exit-code circular:1 src/main.tsx",
"lint:dpdm": "dpdm --no-warning --no-tree --transform --exit-code circular:0 src/main.tsx",
"lint:eslint": "eslint --max-warnings=0 .",
"lint:prettier": "prettier --check .",
"lint:tsc": "tsc --noEmit",
@ -95,6 +95,7 @@
"reactflow": "^11.10.4",
"redux-dynamic-middlewares": "^2.2.0",
"redux-remember": "^5.1.0",
"redux-undo": "^1.1.0",
"rfdc": "^1.3.1",
"roarr": "^7.21.1",
"serialize-error": "^11.0.3",

View File

@ -140,6 +140,9 @@ dependencies:
redux-remember:
specifier: ^5.1.0
version: 5.1.0(redux@5.0.1)
redux-undo:
specifier: ^1.1.0
version: 1.1.0
rfdc:
specifier: ^1.3.1
version: 1.3.1
@ -11962,6 +11965,10 @@ packages:
redux: 5.0.1
dev: false
/redux-undo@1.1.0:
resolution: {integrity: sha512-zzLFh2qeF0MTIlzDhDLm9NtkfBqCllQJ3OCuIl5RKlG/ayHw6GUdIFdMhzMS9NnrnWdBX5u//ExMOHpfudGGOg==}
dev: false
/redux@5.0.1:
resolution: {integrity: sha512-M9/ELqF6fy8FwmkpnF0S3YKOqMyoWJ4+CS5Efg2ct3oY9daQvd/Pc71FpGZsVsbl3Cpb+IIcjBDUnnyBdQbq4w==}
dev: false

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.7 KiB

View File

@ -85,7 +85,8 @@
"loadMore": "Mehr laden",
"noImagesInGallery": "Keine Bilder in der Galerie",
"loading": "Lade",
"deleteImage": "Lösche Bild",
"deleteImage_one": "Lösche Bild",
"deleteImage_other": "",
"copy": "Kopieren",
"download": "Runterladen",
"setCurrentImage": "Setze aktuelle Bild",

View File

@ -69,6 +69,7 @@
"auto": "Auto",
"back": "Back",
"batch": "Batch Manager",
"beta": "Beta",
"cancel": "Cancel",
"copy": "Copy",
"copyError": "$t(gallery.copy) Error",
@ -83,6 +84,8 @@
"direction": "Direction",
"ipAdapter": "IP Adapter",
"t2iAdapter": "T2I Adapter",
"positivePrompt": "Positive Prompt",
"negativePrompt": "Negative Prompt",
"discordLabel": "Discord",
"dontAskMeAgain": "Don't ask me again",
"error": "Error",
@ -135,7 +138,9 @@
"red": "Red",
"green": "Green",
"blue": "Blue",
"alpha": "Alpha"
"alpha": "Alpha",
"selected": "Selected",
"viewer": "Viewer"
},
"controlnet": {
"controlAdapter_one": "Control Adapter",
@ -213,6 +218,10 @@
"resize": "Resize",
"resizeSimple": "Resize (Simple)",
"resizeMode": "Resize Mode",
"ipAdapterMethod": "Method",
"full": "Full",
"style": "Style Only",
"composition": "Composition Only",
"safe": "Safe",
"saveControlImage": "Save Control Image",
"scribble": "scribble",
@ -770,6 +779,8 @@
"float": "Float",
"fullyContainNodes": "Fully Contain Nodes to Select",
"fullyContainNodesHelp": "Nodes must be fully inside the selection box to be selected",
"showEdgeLabels": "Show Edge Labels",
"showEdgeLabelsHelp": "Show labels on edges, indicating the connected nodes",
"hideLegendNodes": "Hide Field Type Legend",
"hideMinimapnodes": "Hide MiniMap",
"inputMayOnlyHaveOneConnection": "Input may only have one connection",
@ -886,6 +897,7 @@
"denoisingStrength": "Denoising Strength",
"downloadImage": "Download Image",
"general": "General",
"globalSettings": "Global Settings",
"height": "Height",
"imageFit": "Fit Initial Image To Output Size",
"images": "Images",
@ -1176,6 +1188,10 @@
"heading": "Resize Mode",
"paragraphs": ["Method to fit Control Adapter's input image size to the output generation size."]
},
"ipAdapterMethod": {
"heading": "Method",
"paragraphs": ["Method by which to apply the current IP Adapter."]
},
"controlNetWeight": {
"heading": "Weight",
"paragraphs": [
@ -1494,5 +1510,27 @@
},
"app": {
"storeNotInitialized": "Store is not initialized"
},
"regionalPrompts": {
"deleteAll": "Delete All",
"addLayer": "Add Layer",
"moveToFront": "Move to Front",
"moveToBack": "Move to Back",
"moveForward": "Move Forward",
"moveBackward": "Move Backward",
"brushSize": "Brush Size",
"regionalControl": "Regional Control (ALPHA)",
"enableRegionalPrompts": "Enable $t(regionalPrompts.regionalPrompts)",
"globalMaskOpacity": "Global Mask Opacity",
"autoNegative": "Auto Negative",
"toggleVisibility": "Toggle Layer Visibility",
"deletePrompt": "Delete Prompt",
"resetRegion": "Reset Region",
"debugLayers": "Debug Layers",
"rectangle": "Rectangle",
"maskPreviewColor": "Mask Preview Color",
"addPositivePrompt": "Add $t(common.positivePrompt)",
"addNegativePrompt": "Add $t(common.negativePrompt)",
"addIPAdapter": "Add $t(common.ipAdapter)"
}
}

View File

@ -33,7 +33,9 @@
"autoSwitchNewImages": "Auto seleccionar Imágenes nuevas",
"loadMore": "Cargar más",
"noImagesInGallery": "No hay imágenes para mostrar",
"deleteImage": "Eliminar Imagen",
"deleteImage_one": "Eliminar Imagen",
"deleteImage_many": "",
"deleteImage_other": "",
"deleteImageBin": "Las imágenes eliminadas se enviarán a la papelera de tu sistema operativo.",
"deleteImagePermanent": "Las imágenes eliminadas no se pueden restaurar.",
"assets": "Activos",

View File

@ -82,7 +82,9 @@
"autoSwitchNewImages": "Passaggio automatico a nuove immagini",
"loadMore": "Carica altro",
"noImagesInGallery": "Nessuna immagine da visualizzare",
"deleteImage": "Elimina l'immagine",
"deleteImage_one": "Elimina l'immagine",
"deleteImage_many": "Elimina {{count}} immagini",
"deleteImage_other": "Elimina {{count}} immagini",
"deleteImagePermanent": "Le immagini eliminate non possono essere ripristinate.",
"deleteImageBin": "Le immagini eliminate verranno spostate nel cestino del tuo sistema operativo.",
"assets": "Risorse",

View File

@ -90,7 +90,7 @@
"problemDeletingImages": "画像の削除中に問題が発生",
"drop": "ドロップ",
"dropOrUpload": "$t(gallery.drop) またはアップロード",
"deleteImage": "画像を削除",
"deleteImage_other": "画像を削除",
"deleteImageBin": "削除された画像はOSのゴミ箱に送られます。",
"deleteImagePermanent": "削除された画像は復元できません。",
"download": "ダウンロード",

View File

@ -82,7 +82,7 @@
"drop": "드랍",
"problemDeletingImages": "이미지 삭제 중 발생한 문제",
"downloadSelection": "선택 항목 다운로드",
"deleteImage": "이미지 삭제",
"deleteImage_other": "이미지 삭제",
"currentlyInUse": "이 이미지는 현재 다음 기능에서 사용되고 있습니다:",
"dropOrUpload": "$t(gallery.drop) 또는 업로드",
"copy": "복사",

View File

@ -42,7 +42,8 @@
"autoSwitchNewImages": "Wissel autom. naar nieuwe afbeeldingen",
"loadMore": "Laad meer",
"noImagesInGallery": "Geen afbeeldingen om te tonen",
"deleteImage": "Verwijder afbeelding",
"deleteImage_one": "Verwijder afbeelding",
"deleteImage_other": "",
"deleteImageBin": "Verwijderde afbeeldingen worden naar de prullenbak van je besturingssysteem gestuurd.",
"deleteImagePermanent": "Verwijderde afbeeldingen kunnen niet worden hersteld.",
"assets": "Eigen onderdelen",

View File

@ -86,7 +86,9 @@
"noImagesInGallery": "Изображений нет",
"deleteImagePermanent": "Удаленные изображения невозможно восстановить.",
"deleteImageBin": "Удаленные изображения будут отправлены в корзину вашей операционной системы.",
"deleteImage": "Удалить изображение",
"deleteImage_one": "Удалить изображение",
"deleteImage_few": "",
"deleteImage_many": "",
"assets": "Ресурсы",
"autoAssignBoardOnClick": "Авто-назначение доски по клику",
"deleteSelection": "Удалить выделенное",

View File

@ -298,7 +298,8 @@
"noImagesInGallery": "Gösterilecek Görsel Yok",
"autoSwitchNewImages": "Yeni Görseli Biter Bitmez Gör",
"currentlyInUse": "Bu görsel şurada kullanımda:",
"deleteImage": "Görseli Sil",
"deleteImage_one": "Görseli Sil",
"deleteImage_other": "",
"loadMore": "Daha Getir",
"setCurrentImage": "Çalışma Görseli Yap",
"unableToLoad": "Galeri Yüklenemedi",

View File

@ -78,7 +78,7 @@
"autoSwitchNewImages": "自动切换到新图像",
"loadMore": "加载更多",
"noImagesInGallery": "无图像可用于显示",
"deleteImage": "删除图片",
"deleteImage_other": "删除图片",
"deleteImageBin": "被删除的图片会发送到你操作系统的回收站。",
"deleteImagePermanent": "删除的图片无法被恢复。",
"assets": "素材",

View File

@ -27,7 +27,8 @@ export type LoggerNamespace =
| 'socketio'
| 'session'
| 'queue'
| 'dnd';
| 'dnd'
| 'regionalPrompts';
export const logger = (namespace: LoggerNamespace) => $logger.get().child({ namespace });

View File

@ -21,6 +21,11 @@ import { workflowPersistConfig, workflowSlice } from 'features/nodes/store/workf
import { generationPersistConfig, generationSlice } from 'features/parameters/store/generationSlice';
import { postprocessingPersistConfig, postprocessingSlice } from 'features/parameters/store/postprocessingSlice';
import { queueSlice } from 'features/queue/store/queueSlice';
import {
regionalPromptsPersistConfig,
regionalPromptsSlice,
regionalPromptsUndoableConfig,
} from 'features/regionalPrompts/store/regionalPromptsSlice';
import { sdxlPersistConfig, sdxlSlice } from 'features/sdxl/store/sdxlSlice';
import { configSlice } from 'features/system/store/configSlice';
import { systemPersistConfig, systemSlice } from 'features/system/store/systemSlice';
@ -30,6 +35,7 @@ import { defaultsDeep, keys, omit, pick } from 'lodash-es';
import dynamicMiddlewares from 'redux-dynamic-middlewares';
import type { SerializeFunction, UnserializeFunction } from 'redux-remember';
import { rememberEnhancer, rememberReducer } from 'redux-remember';
import undoable from 'redux-undo';
import { serializeError } from 'serialize-error';
import { api } from 'services/api';
import { authToastMiddleware } from 'services/api/authToastMiddleware';
@ -59,6 +65,7 @@ const allReducers = {
[queueSlice.name]: queueSlice.reducer,
[workflowSlice.name]: workflowSlice.reducer,
[hrfSlice.name]: hrfSlice.reducer,
[regionalPromptsSlice.name]: undoable(regionalPromptsSlice.reducer, regionalPromptsUndoableConfig),
[api.reducerPath]: api.reducer,
};
@ -103,6 +110,7 @@ const persistConfigs: { [key in keyof typeof allReducers]?: PersistConfig } = {
[loraPersistConfig.name]: loraPersistConfig,
[modelManagerV2PersistConfig.name]: modelManagerV2PersistConfig,
[hrfPersistConfig.name]: hrfPersistConfig,
[regionalPromptsPersistConfig.name]: regionalPromptsPersistConfig,
};
const unserialize: UnserializeFunction = (data, key) => {
@ -114,6 +122,7 @@ const unserialize: UnserializeFunction = (data, key) => {
try {
const { initialState, migrate } = persistConfig;
const parsed = JSON.parse(data);
// strip out old keys
const stripped = pick(parsed, keys(initialState));
// run (additive) migrations
@ -141,7 +150,9 @@ const serialize: SerializeFunction = (data, key) => {
if (!persistConfig) {
throw new Error(`No persist config for slice "${key}"`);
}
const result = omit(data, persistConfig.persistDenylist);
// Heuristic to determine if the slice is undoable - could just hardcode it in the persistConfig
const isUndoable = 'present' in data && 'past' in data && 'future' in data && '_latestUnfiltered' in data;
const result = omit(isUndoable ? data.present : data, persistConfig.persistDenylist);
return JSON.stringify(result);
};

View File

@ -26,7 +26,7 @@ const sx: ChakraProps['sx'] = {
const colorPickerStyles: CSSProperties = { width: '100%' };
const numberInputWidth: ChakraProps['w'] = '4.2rem';
const numberInputWidth: ChakraProps['w'] = '3.5rem';
const IAIColorPicker = (props: IAIColorPickerProps) => {
const { color, onChange, withNumberInput, ...rest } = props;
@ -41,7 +41,7 @@ const IAIColorPicker = (props: IAIColorPickerProps) => {
{withNumberInput && (
<Flex gap={5}>
<FormControl gap={0}>
<FormLabel>{t('common.red')}</FormLabel>
<FormLabel>{t('common.red')[0]}</FormLabel>
<CompositeNumberInput
value={color.r}
onChange={handleChangeR}
@ -53,7 +53,7 @@ const IAIColorPicker = (props: IAIColorPickerProps) => {
/>
</FormControl>
<FormControl gap={0}>
<FormLabel>{t('common.green')}</FormLabel>
<FormLabel>{t('common.green')[0]}</FormLabel>
<CompositeNumberInput
value={color.g}
onChange={handleChangeG}
@ -65,7 +65,7 @@ const IAIColorPicker = (props: IAIColorPickerProps) => {
/>
</FormControl>
<FormControl gap={0}>
<FormLabel>{t('common.blue')}</FormLabel>
<FormLabel>{t('common.blue')[0]}</FormLabel>
<CompositeNumberInput
value={color.b}
onChange={handleChangeB}
@ -77,7 +77,7 @@ const IAIColorPicker = (props: IAIColorPickerProps) => {
/>
</FormControl>
<FormControl gap={0}>
<FormLabel>{t('common.alpha')}</FormLabel>
<FormLabel>{t('common.alpha')[0]}</FormLabel>
<CompositeNumberInput
value={color.a}
onChange={handleChangeA}

View File

@ -24,6 +24,7 @@ export type Feature =
| 'dynamicPromptsSeedBehaviour'
| 'imageFit'
| 'infillMethod'
| 'ipAdapterMethod'
| 'lora'
| 'loraWeight'
| 'noiseUseCPU'

View File

@ -0,0 +1,84 @@
import type { ChakraProps } from '@invoke-ai/ui-library';
import { CompositeNumberInput, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import type { CSSProperties } from 'react';
import { memo, useCallback } from 'react';
import { RgbColorPicker as ColorfulRgbColorPicker } from 'react-colorful';
import type { ColorPickerBaseProps, RgbColor } from 'react-colorful/dist/types';
import { useTranslation } from 'react-i18next';
type RgbColorPickerProps = ColorPickerBaseProps<RgbColor> & {
withNumberInput?: boolean;
};
const colorPickerPointerStyles: NonNullable<ChakraProps['sx']> = {
width: 6,
height: 6,
borderColor: 'base.100',
};
const sx: ChakraProps['sx'] = {
'.react-colorful__hue-pointer': colorPickerPointerStyles,
'.react-colorful__saturation-pointer': colorPickerPointerStyles,
'.react-colorful__alpha-pointer': colorPickerPointerStyles,
gap: 5,
flexDir: 'column',
};
const colorPickerStyles: CSSProperties = { width: '100%' };
const numberInputWidth: ChakraProps['w'] = '3.5rem';
const RgbColorPicker = (props: RgbColorPickerProps) => {
const { color, onChange, withNumberInput, ...rest } = props;
const { t } = useTranslation();
const handleChangeR = useCallback((r: number) => onChange({ ...color, r }), [color, onChange]);
const handleChangeG = useCallback((g: number) => onChange({ ...color, g }), [color, onChange]);
const handleChangeB = useCallback((b: number) => onChange({ ...color, b }), [color, onChange]);
return (
<Flex sx={sx}>
<ColorfulRgbColorPicker color={color} onChange={onChange} style={colorPickerStyles} {...rest} />
{withNumberInput && (
<Flex gap={5}>
<FormControl gap={0}>
<FormLabel>{t('common.red')[0]}</FormLabel>
<CompositeNumberInput
value={color.r}
onChange={handleChangeR}
min={0}
max={255}
step={1}
w={numberInputWidth}
defaultValue={90}
/>
</FormControl>
<FormControl gap={0}>
<FormLabel>{t('common.green')[0]}</FormLabel>
<CompositeNumberInput
value={color.g}
onChange={handleChangeG}
min={0}
max={255}
step={1}
w={numberInputWidth}
defaultValue={90}
/>
</FormControl>
<FormControl gap={0}>
<FormLabel>{t('common.blue')[0]}</FormLabel>
<CompositeNumberInput
value={color.b}
onChange={handleChangeB}
min={0}
max={255}
step={1}
w={numberInputWidth}
defaultValue={255}
/>
</FormControl>
</Flex>
)}
</Flex>
);
};
export default memo(RgbColorPicker);

View File

@ -9,7 +9,7 @@ import { useHotkeys } from 'react-hotkeys-hook';
export const useGlobalHotkeys = () => {
const dispatch = useAppDispatch();
const isModelManagerEnabled = useFeatureStatus('modelManager').isFeatureEnabled;
const isModelManagerEnabled = useFeatureStatus('modelManager');
const { queueBack, isDisabled: isDisabledQueueBack, isLoading: isLoadingQueueBack } = useQueueBack();
useHotkeys(

View File

@ -0,0 +1,85 @@
import { moveBackward, moveForward, moveToBack, moveToFront } from 'common/util/arrayUtils';
import { describe, expect, it } from 'vitest';
describe('Array Manipulation Functions', () => {
const originalArray = ['a', 'b', 'c', 'd'];
describe('moveForwardOne', () => {
it('should move an item forward by one position', () => {
const array = [...originalArray];
const result = moveForward(array, (item) => item === 'b');
expect(result).toEqual(['a', 'c', 'b', 'd']);
});
it('should do nothing if the item is at the end', () => {
const array = [...originalArray];
const result = moveForward(array, (item) => item === 'd');
expect(result).toEqual(['a', 'b', 'c', 'd']);
});
it("should leave the array unchanged if the item isn't in the array", () => {
const array = [...originalArray];
const result = moveForward(array, (item) => item === 'z');
expect(result).toEqual(originalArray);
});
});
describe('moveToFront', () => {
it('should move an item to the front', () => {
const array = [...originalArray];
const result = moveToFront(array, (item) => item === 'c');
expect(result).toEqual(['c', 'a', 'b', 'd']);
});
it('should do nothing if the item is already at the front', () => {
const array = [...originalArray];
const result = moveToFront(array, (item) => item === 'a');
expect(result).toEqual(['a', 'b', 'c', 'd']);
});
it("should leave the array unchanged if the item isn't in the array", () => {
const array = [...originalArray];
const result = moveToFront(array, (item) => item === 'z');
expect(result).toEqual(originalArray);
});
});
describe('moveBackwardsOne', () => {
it('should move an item backward by one position', () => {
const array = [...originalArray];
const result = moveBackward(array, (item) => item === 'c');
expect(result).toEqual(['a', 'c', 'b', 'd']);
});
it('should do nothing if the item is at the beginning', () => {
const array = [...originalArray];
const result = moveBackward(array, (item) => item === 'a');
expect(result).toEqual(['a', 'b', 'c', 'd']);
});
it("should leave the array unchanged if the item isn't in the array", () => {
const array = [...originalArray];
const result = moveBackward(array, (item) => item === 'z');
expect(result).toEqual(originalArray);
});
});
describe('moveToBack', () => {
it('should move an item to the back', () => {
const array = [...originalArray];
const result = moveToBack(array, (item) => item === 'b');
expect(result).toEqual(['a', 'c', 'd', 'b']);
});
it('should do nothing if the item is already at the back', () => {
const array = [...originalArray];
const result = moveToBack(array, (item) => item === 'd');
expect(result).toEqual(['a', 'b', 'c', 'd']);
});
it("should leave the array unchanged if the item isn't in the array", () => {
const array = [...originalArray];
const result = moveToBack(array, (item) => item === 'z');
expect(result).toEqual(originalArray);
});
});
});

View File

@ -0,0 +1,37 @@
export const moveForward = <T>(array: T[], callback: (item: T) => boolean): T[] => {
const index = array.findIndex(callback);
if (index >= 0 && index < array.length - 1) {
//@ts-expect-error - These indicies are safe per the previous check
[array[index], array[index + 1]] = [array[index + 1], array[index]];
}
return array;
};
export const moveToFront = <T>(array: T[], callback: (item: T) => boolean): T[] => {
const index = array.findIndex(callback);
if (index > 0) {
const [item] = array.splice(index, 1);
//@ts-expect-error - These indicies are safe per the previous check
array.unshift(item);
}
return array;
};
export const moveBackward = <T>(array: T[], callback: (item: T) => boolean): T[] => {
const index = array.findIndex(callback);
if (index > 0) {
//@ts-expect-error - These indicies are safe per the previous check
[array[index], array[index - 1]] = [array[index - 1], array[index]];
}
return array;
};
export const moveToBack = <T>(array: T[], callback: (item: T) => boolean): T[] => {
const index = array.findIndex(callback);
if (index >= 0 && index < array.length - 1) {
const [item] = array.splice(index, 1);
//@ts-expect-error - These indicies are safe per the previous check
array.push(item);
}
return array;
};

View File

@ -10,6 +10,18 @@ import { clamp } from 'lodash-es';
import type { MutableRefObject } from 'react';
import { useCallback } from 'react';
export const calculateNewBrushSize = (brushSize: number, delta: number) => {
// This equation was derived by fitting a curve to the desired brush sizes and deltas
// see https://github.com/invoke-ai/InvokeAI/pull/5542#issuecomment-1915847565
const targetDelta = Math.sign(delta) * 0.7363 * Math.pow(1.0394, brushSize);
// This needs to be clamped to prevent the delta from getting too large
const finalDelta = clamp(targetDelta, -20, 20);
// The new brush size is also clamped to prevent it from getting too large or small
const newBrushSize = clamp(brushSize + finalDelta, 1, 500);
return newBrushSize;
};
const useCanvasWheel = (stageRef: MutableRefObject<Konva.Stage | null>) => {
const dispatch = useAppDispatch();
const stageScale = useAppSelector((s) => s.canvas.stageScale);
@ -36,15 +48,7 @@ const useCanvasWheel = (stageRef: MutableRefObject<Konva.Stage | null>) => {
}
if ($ctrl.get() || $meta.get()) {
// This equation was derived by fitting a curve to the desired brush sizes and deltas
// see https://github.com/invoke-ai/InvokeAI/pull/5542#issuecomment-1915847565
const targetDelta = Math.sign(delta) * 0.7363 * Math.pow(1.0394, brushSize);
// This needs to be clamped to prevent the delta from getting too large
const finalDelta = clamp(targetDelta, -20, 20);
// The new brush size is also clamped to prevent it from getting too large or small
const newBrushSize = clamp(brushSize + finalDelta, 1, 500);
dispatch(setBrushSize(newBrushSize));
dispatch(setBrushSize(calculateNewBrushSize(brushSize, delta)));
} else {
const cursorPos = stageRef.current.getPointerPosition();
let delta = e.evt.deltaY;

View File

@ -7,3 +7,22 @@ export const blobToDataURL = (blob: Blob): Promise<string> => {
reader.readAsDataURL(blob);
});
};
export function imageDataToDataURL(imageData: ImageData): string {
const { width, height } = imageData;
// Create a canvas to transfer the ImageData to
const canvas = document.createElement('canvas');
canvas.width = width;
canvas.height = height;
// Draw the ImageData onto the canvas
const ctx = canvas.getContext('2d');
if (!ctx) {
throw new Error('Unable to get canvas context');
}
ctx.putImageData(imageData, 0, 0);
// Convert the canvas to a data URL (base64)
return canvas.toDataURL();
}

View File

@ -1,6 +1,11 @@
import type { RgbaColor } from 'react-colorful';
import type { RgbaColor, RgbColor } from 'react-colorful';
export const rgbaColorToString = (color: RgbaColor): string => {
const { r, g, b, a } = color;
return `rgba(${r}, ${g}, ${b}, ${a})`;
};
export const rgbColorToString = (color: RgbColor): string => {
const { r, g, b } = color;
return `rgba(${r}, ${g}, ${b})`;
};

View File

@ -21,6 +21,7 @@ import ControlAdapterShouldAutoConfig from './ControlAdapterShouldAutoConfig';
import ControlNetCanvasImageImports from './imports/ControlNetCanvasImageImports';
import { ParamControlAdapterBeginEnd } from './parameters/ParamControlAdapterBeginEnd';
import ParamControlAdapterControlMode from './parameters/ParamControlAdapterControlMode';
import ParamControlAdapterIPMethod from './parameters/ParamControlAdapterIPMethod';
import ParamControlAdapterProcessorSelect from './parameters/ParamControlAdapterProcessorSelect';
import ParamControlAdapterResizeMode from './parameters/ParamControlAdapterResizeMode';
import ParamControlAdapterWeight from './parameters/ParamControlAdapterWeight';
@ -111,7 +112,8 @@ const ControlAdapterConfig = (props: { id: string; number: number }) => {
<Flex w="full" flexDir="column" gap={4}>
<Flex gap={8} w="full" alignItems="center">
<Flex flexDir="column" gap={2} h={32} w="full">
<Flex flexDir="column" gap={4} h={controlAdapterType === 'ip_adapter' ? 40 : 32} w="full">
<ParamControlAdapterIPMethod id={id} />
<ParamControlAdapterWeight id={id} />
<ParamControlAdapterBeginEnd id={id} />
</Flex>

View File

@ -0,0 +1,63 @@
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { useControlAdapterIPMethod } from 'features/controlAdapters/hooks/useControlAdapterIPMethod';
import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled';
import { controlAdapterIPMethodChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
import type { IPMethod } from 'features/controlAdapters/store/types';
import { isIPMethod } from 'features/controlAdapters/store/types';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
type Props = {
id: string;
};
const ParamControlAdapterIPMethod = ({ id }: Props) => {
const isEnabled = useControlAdapterIsEnabled(id);
const method = useControlAdapterIPMethod(id);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const options: { label: string; value: IPMethod }[] = useMemo(
() => [
{ label: t('controlnet.full'), value: 'full' },
{ label: `${t('controlnet.style')} (${t('common.beta')})`, value: 'style' },
{ label: `${t('controlnet.composition')} (${t('common.beta')})`, value: 'composition' },
],
[t]
);
const handleIPMethodChanged = useCallback<ComboboxOnChange>(
(v) => {
if (!isIPMethod(v?.value)) {
return;
}
dispatch(
controlAdapterIPMethodChanged({
id,
method: v.value,
})
);
},
[id, dispatch]
);
const value = useMemo(() => options.find((o) => o.value === method), [options, method]);
if (!method) {
return null;
}
return (
<FormControl>
<InformationalPopover feature="ipAdapterMethod">
<FormLabel>{t('controlnet.ipAdapterMethod')}</FormLabel>
</InformationalPopover>
<Combobox value={value} options={options} isDisabled={!isEnabled} onChange={handleIPMethodChanged} />
</FormControl>
);
};
export default memo(ParamControlAdapterIPMethod);

View File

@ -0,0 +1,24 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import {
selectControlAdapterById,
selectControlAdaptersSlice,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { useMemo } from 'react';
export const useControlAdapterIPMethod = (id: string) => {
const selector = useMemo(
() =>
createMemoizedSelector(selectControlAdaptersSlice, (controlAdapters) => {
const cn = selectControlAdapterById(controlAdapters, id);
if (cn && cn?.type === 'ip_adapter') {
return cn.method;
}
}),
[id]
);
const method = useAppSelector(selector);
return method;
};

View File

@ -6,6 +6,7 @@ import { deepClone } from 'common/util/deepClone';
import { buildControlAdapter } from 'features/controlAdapters/util/buildControlAdapter';
import { buildControlAdapterProcessor } from 'features/controlAdapters/util/buildControlAdapterProcessor';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { maskLayerIPAdapterAdded } from 'features/regionalPrompts/store/regionalPromptsSlice';
import { merge, uniq } from 'lodash-es';
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import { socketInvocationError } from 'services/events/actions';
@ -21,6 +22,7 @@ import type {
ControlAdapterType,
ControlMode,
ControlNetConfig,
IPMethod,
RequiredControlAdapterProcessorNode,
ResizeMode,
T2IAdapterConfig,
@ -245,6 +247,10 @@ export const controlAdaptersSlice = createSlice({
}
caAdapter.updateOne(state, { id, changes: { controlMode } });
},
controlAdapterIPMethodChanged: (state, action: PayloadAction<{ id: string; method: IPMethod }>) => {
const { id, method } = action.payload;
caAdapter.updateOne(state, { id, changes: { method } });
},
controlAdapterCLIPVisionModelChanged: (
state,
action: PayloadAction<{ id: string; clipVisionModel: CLIPVisionModel }>
@ -377,6 +383,10 @@ export const controlAdaptersSlice = createSlice({
builder.addCase(socketInvocationError, (state) => {
state.pendingControlImages = [];
});
builder.addCase(maskLayerIPAdapterAdded, (state, action) => {
caAdapter.addOne(state, buildControlAdapter(action.meta.uuid, 'ip_adapter'));
});
},
});
@ -390,6 +400,7 @@ export const {
controlAdapterIsEnabledChanged,
controlAdapterModelChanged,
controlAdapterCLIPVisionModelChanged,
controlAdapterIPMethodChanged,
controlAdapterWeightChanged,
controlAdapterBeginStepPctChanged,
controlAdapterEndStepPctChanged,

View File

@ -210,6 +210,10 @@ const zResizeMode = z.enum(['just_resize', 'crop_resize', 'fill_resize', 'just_r
export type ResizeMode = z.infer<typeof zResizeMode>;
export const isResizeMode = (v: unknown): v is ResizeMode => zResizeMode.safeParse(v).success;
const zIPMethod = z.enum(['full', 'style', 'composition']);
export type IPMethod = z.infer<typeof zIPMethod>;
export const isIPMethod = (v: unknown): v is IPMethod => zIPMethod.safeParse(v).success;
export type ControlNetConfig = {
type: 'controlnet';
id: string;
@ -253,6 +257,7 @@ export type IPAdapterConfig = {
model: ParameterIPAdapterModel | null;
clipVisionModel: CLIPVisionModel;
weight: number;
method: IPMethod;
beginStepPct: number;
endStepPct: number;
};

View File

@ -46,6 +46,7 @@ export const initialIPAdapter: Omit<IPAdapterConfig, 'id'> = {
isEnabled: true,
controlImage: null,
model: null,
method: 'full',
clipVisionModel: 'ViT-H',
weight: 1,
beginStepPct: 0,

View File

@ -32,7 +32,7 @@ const BoardContextMenu = ({ board, board_id, setBoardToDelete, children }: Props
const isSelectedForAutoAdd = useAppSelector(selectIsSelectedForAutoAdd);
const boardName = useBoardName(board_id);
const isBulkDownloadEnabled = useFeatureStatus('bulkDownload').isFeatureEnabled;
const isBulkDownloadEnabled = useFeatureStatus('bulkDownload');
const [bulkDownload] = useBulkDownloadImagesMutation();

View File

@ -54,7 +54,7 @@ const CurrentImageButtons = () => {
const selection = useAppSelector((s) => s.gallery.selection);
const shouldDisableToolbarButtons = useAppSelector(selectShouldDisableToolbarButtons);
const isUpscalingEnabled = useFeatureStatus('upscaling').isFeatureEnabled;
const isUpscalingEnabled = useFeatureStatus('upscaling');
const isQueueMutationInProgress = useIsQueueMutationInProgress();
const toaster = useAppToaster();
const { t } = useTranslation();

View File

@ -20,7 +20,7 @@ const MultipleSelectionMenuItems = () => {
const selection = useAppSelector((s) => s.gallery.selection);
const customStarUi = useStore($customStarUI);
const isBulkDownloadEnabled = useFeatureStatus('bulkDownload').isFeatureEnabled;
const isBulkDownloadEnabled = useFeatureStatus('bulkDownload');
const [starImages] = useStarImagesMutation();
const [unstarImages] = useUnstarImagesMutation();

View File

@ -45,7 +45,7 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const toaster = useAppToaster();
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
const isCanvasEnabled = useFeatureStatus('unifiedCanvas');
const customStarUi = useStore($customStarUI);
const { downloadImage } = useDownloadImage();

View File

@ -1,4 +1,4 @@
import { Box, Flex, IconButton, Tooltip } from '@invoke-ai/ui-library';
import { Box, Flex, IconButton, Tooltip, useShiftModifier } from '@invoke-ai/ui-library';
import { getOverlayScrollbarsParams } from 'common/components/OverlayScrollbars/constants';
import { isString } from 'lodash-es';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
@ -9,18 +9,19 @@ import { PiCopyBold, PiDownloadSimpleBold } from 'react-icons/pi';
type Props = {
label: string;
data: object | string;
data: unknown;
fileName?: string;
withDownload?: boolean;
withCopy?: boolean;
extraCopyActions?: { label: string; getData: (data: unknown) => unknown }[];
};
const overlayscrollbarsOptions = getOverlayScrollbarsParams('scroll', 'scroll').options;
const DataViewer = (props: Props) => {
const { label, data, fileName, withDownload = true, withCopy = true } = props;
const { label, data, fileName, withDownload = true, withCopy = true, extraCopyActions } = props;
const dataString = useMemo(() => (isString(data) ? data : JSON.stringify(data, null, 2)), [data]);
const shift = useShiftModifier();
const handleCopy = useCallback(() => {
navigator.clipboard.writeText(dataString);
}, [dataString]);
@ -67,6 +68,10 @@ const DataViewer = (props: Props) => {
/>
</Tooltip>
)}
{shift &&
extraCopyActions?.map(({ label, getData }) => (
<ExtraCopyAction label={label} getData={getData} data={data} key={label} />
))}
</Flex>
</Flex>
);
@ -78,3 +83,27 @@ const overlayScrollbarsStyles: CSSProperties = {
height: '100%',
width: '100%',
};
type ExtraCopyActionProps = {
label: string;
data: unknown;
getData: (data: unknown) => unknown;
};
const ExtraCopyAction = ({ label, data, getData }: ExtraCopyActionProps) => {
const { t } = useTranslation();
const handleCopy = useCallback(() => {
navigator.clipboard.writeText(JSON.stringify(getData(data), null, 2));
}, [data, getData]);
return (
<Tooltip label={`${t('gallery.copy')} ${label} JSON`}>
<IconButton
aria-label={`${t('gallery.copy')} ${label} JSON`}
icon={<PiCopyBold size={16} />}
variant="ghost"
opacity={0.7}
onClick={handleCopy}
/>
</Tooltip>
);
};

View File

@ -18,7 +18,7 @@ export const useMultiselect = (imageDTO?: ImageDTO) => {
[imageDTO?.image_name]
);
const isSelected = useAppSelector(selectIsSelected);
const isMultiSelectEnabled = useFeatureStatus('multiselect').isFeatureEnabled;
const isMultiSelectEnabled = useFeatureStatus('multiselect');
const handleClick = useCallback(
(e: MouseEvent<HTMLDivElement>) => {

View File

@ -8,7 +8,7 @@ import ParamHrfStrength from './ParamHrfStrength';
import ParamHrfToggle from './ParamHrfToggle';
export const HrfSettings = memo(() => {
const isHRFFeatureEnabled = useFeatureStatus('hrf').isFeatureEnabled;
const isHRFFeatureEnabled = useFeatureStatus('hrf');
const hrfEnabled = useAppSelector((s) => s.hrf.hrfEnabled);
if (!isHRFFeatureEnabled) {

View File

@ -386,6 +386,10 @@ const parseIPAdapter: MetadataParseFunc<IPAdapterConfigMetadata> = async (metada
.nullish()
.catch(null)
.parse(await getProperty(metadataItem, 'weight'));
const method = zIPAdapterField.shape.method
.nullish()
.catch(null)
.parse(await getProperty(metadataItem, 'method'));
const begin_step_percent = zIPAdapterField.shape.begin_step_percent
.nullish()
.catch(null)
@ -403,6 +407,7 @@ const parseIPAdapter: MetadataParseFunc<IPAdapterConfigMetadata> = async (metada
clipVisionModel: 'ViT-H',
controlImage: image?.image_name ?? null,
weight: weight ?? initialIPAdapter.weight,
method: method ?? initialIPAdapter.method,
beginStepPct: begin_step_percent ?? initialIPAdapter.beginStepPct,
endStepPct: end_step_percent ?? initialIPAdapter.endStepPct,
};

View File

@ -10,7 +10,7 @@ const TOAST_ID = 'starterModels';
export const useStarterModelsToast = () => {
const { t } = useTranslation();
const isEnabled = useFeatureStatus('starterModels').isFeatureEnabled;
const isEnabled = useFeatureStatus('starterModels');
const [didToast, setDidToast] = useState(false);
const [mainModels, { data }] = useMainModels();
const toast = useToast();

View File

@ -1,4 +1,4 @@
import { Button, Flex, Heading, Text } from '@invoke-ai/ui-library';
import { Button, Flex, Heading, SimpleGrid, Text } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useControlNetOrT2IAdapterDefaultSettings } from 'features/modelManagerV2/hooks/useControlNetOrT2IAdapterDefaultSettings';
import { DefaultPreprocessor } from 'features/modelManagerV2/subpanels/ModelPanel/ControlNetOrT2IAdapterDefaultSettings/DefaultPreprocessor';
@ -92,13 +92,9 @@ export const ControlNetOrT2IAdapterDefaultSettings = () => {
</Button>
</Flex>
<Flex flexDir="column" gap={8}>
<Flex gap={8}>
<Flex gap={4} w="full">
<DefaultPreprocessor control={control} name="preprocessor" />
</Flex>
</Flex>
</Flex>
<SimpleGrid columns={2} gap={8}>
<DefaultPreprocessor control={control} name="preprocessor" />
</SimpleGrid>
</>
);
};

View File

@ -1,4 +1,4 @@
import { Button, Flex, Heading, Text } from '@invoke-ai/ui-library';
import { Button, Flex, Heading, SimpleGrid, Text } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useMainModelDefaultSettings } from 'features/modelManagerV2/hooks/useMainModelDefaultSettings';
import { DefaultHeight } from 'features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/DefaultHeight';
@ -122,40 +122,16 @@ export const MainModelDefaultSettings = () => {
</Button>
</Flex>
<Flex flexDir="column" gap={8}>
<Flex gap={8}>
<Flex gap={4} w="full">
<DefaultVae control={control} name="vae" />
</Flex>
<Flex gap={4} w="full">
<DefaultVaePrecision control={control} name="vaePrecision" />
</Flex>
</Flex>
<Flex gap={8}>
<Flex gap={4} w="full">
<DefaultScheduler control={control} name="scheduler" />
</Flex>
<Flex gap={4} w="full">
<DefaultSteps control={control} name="steps" />
</Flex>
</Flex>
<Flex gap={8}>
<Flex gap={4} w="full">
<DefaultCfgScale control={control} name="cfgScale" />
</Flex>
<Flex gap={4} w="full">
<DefaultCfgRescaleMultiplier control={control} name="cfgRescaleMultiplier" />
</Flex>
</Flex>
<Flex gap={8}>
<Flex gap={4} w="full">
<DefaultWidth control={control} optimalDimension={optimalDimension} />
</Flex>
<Flex gap={4} w="full">
<DefaultHeight control={control} optimalDimension={optimalDimension} />
</Flex>
</Flex>
</Flex>
<SimpleGrid columns={2} gap={8}>
<DefaultVae control={control} name="vae" />
<DefaultVaePrecision control={control} name="vaePrecision" />
<DefaultScheduler control={control} name="scheduler" />
<DefaultSteps control={control} name="steps" />
<DefaultCfgScale control={control} name="cfgScale" />
<DefaultCfgRescaleMultiplier control={control} name="cfgRescaleMultiplier" />
<DefaultWidth control={control} optimalDimension={optimalDimension} />
<DefaultHeight control={control} optimalDimension={optimalDimension} />
</SimpleGrid>
</>
);
};

View File

@ -6,6 +6,7 @@ import {
FormLabel,
Heading,
Input,
SimpleGrid,
Text,
Textarea,
} from '@invoke-ai/ui-library';
@ -66,25 +67,21 @@ export const ModelEdit = ({ form }: Props) => {
<Heading as="h3" fontSize="md" mt="4">
{t('modelManager.modelSettings')}
</Heading>
<Flex gap={4}>
<SimpleGrid columns={2} gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.baseModel')}</FormLabel>
<BaseModelSelect control={form.control} />
</FormControl>
</Flex>
{data.type === 'main' && data.format === 'checkpoint' && (
<>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.variant')}</FormLabel>
<ModelVariantSelect control={form.control} />
</FormControl>
{data.type === 'main' && data.format === 'checkpoint' && (
<>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.pathToConfig')}</FormLabel>
<Input {...form.register('config_path', stringFieldOptions)} />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.variant')}</FormLabel>
<ModelVariantSelect control={form.control} />
</FormControl>
</Flex>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.predictionType')}</FormLabel>
<PredictionTypeSelect control={form.control} />
@ -93,9 +90,9 @@ export const ModelEdit = ({ form }: Props) => {
<FormLabel>{t('modelManager.upcastAttention')}</FormLabel>
<Checkbox {...form.register('upcast_attention')} />
</FormControl>
</Flex>
</>
)}
</>
)}
</SimpleGrid>
</Flex>
</form>
</Flex>

View File

@ -1,4 +1,4 @@
import { Box, Flex, Text } from '@invoke-ai/ui-library';
import { Box, Flex, SimpleGrid, Text } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector } from 'app/store/storeHooks';
import { ControlNetOrT2IAdapterDefaultSettings } from 'features/modelManagerV2/subpanels/ModelPanel/ControlNetOrT2IAdapterDefaultSettings/ControlNetOrT2IAdapterDefaultSettings';
@ -24,57 +24,32 @@ export const ModelView = () => {
return (
<Flex flexDir="column" h="full" gap={4}>
<Box layerStyle="second" borderRadius="base" p={4}>
<Flex flexDir="column" gap={4}>
<Flex gap={2}>
<ModelAttrView label={t('modelManager.baseModel')} value={data.base} />
<ModelAttrView label={t('modelManager.modelType')} value={data.type} />
</Flex>
<Flex gap={2}>
<ModelAttrView label={t('common.format')} value={data.format} />
<ModelAttrView label={t('modelManager.path')} value={data.path} />
</Flex>
<SimpleGrid columns={2} gap={4}>
<ModelAttrView label={t('modelManager.baseModel')} value={data.base} />
<ModelAttrView label={t('modelManager.modelType')} value={data.type} />
<ModelAttrView label={t('common.format')} value={data.format} />
<ModelAttrView label={t('modelManager.path')} value={data.path} />
{data.type === 'main' && <ModelAttrView label={t('modelManager.variant')} value={data.variant} />}
{data.type === 'main' && data.format === 'diffusers' && data.repo_variant && (
<Flex gap={2}>
<ModelAttrView label={t('modelManager.repoVariant')} value={data.repo_variant} />
</Flex>
<ModelAttrView label={t('modelManager.repoVariant')} value={data.repo_variant} />
)}
{data.type === 'main' && data.format === 'checkpoint' && (
<>
<Flex gap={2}>
<ModelAttrView label={t('modelManager.pathToConfig')} value={data.config_path} />
<ModelAttrView label={t('modelManager.variant')} value={data.variant} />
</Flex>
<Flex gap={2}>
<ModelAttrView label={t('modelManager.predictionType')} value={data.prediction_type} />
<ModelAttrView label={t('modelManager.upcastAttention')} value={`${data.upcast_attention}`} />
</Flex>
<ModelAttrView label={t('modelManager.pathToConfig')} value={data.config_path} />
<ModelAttrView label={t('modelManager.predictionType')} value={data.prediction_type} />
<ModelAttrView label={t('modelManager.upcastAttention')} value={`${data.upcast_attention}`} />
</>
)}
{data.type === 'ip_adapter' && data.format === 'invokeai' && (
<Flex gap={2}>
<ModelAttrView label={t('modelManager.imageEncoderModelId')} value={data.image_encoder_model_id} />
</Flex>
<ModelAttrView label={t('modelManager.imageEncoderModelId')} value={data.image_encoder_model_id} />
)}
</Flex>
</SimpleGrid>
</Box>
<Box layerStyle="second" borderRadius="base" p={4}>
{data.type === 'main' && data.base !== 'sdxl-refiner' && <MainModelDefaultSettings />}
{(data.type === 'controlnet' || data.type === 't2i_adapter') && <ControlNetOrT2IAdapterDefaultSettings />}
{(data.type === 'main' || data.type === 'lora') && <TriggerPhrases />}
</Box>
{data.type === 'main' && data.base !== 'sdxl-refiner' && (
<Box layerStyle="second" borderRadius="base" p={4}>
<MainModelDefaultSettings />
</Box>
)}
{(data.type === 'controlnet' || data.type === 't2i_adapter') && (
<Box layerStyle="second" borderRadius="base" p={4}>
<ControlNetOrT2IAdapterDefaultSettings />
</Box>
)}
{(data.type === 'main' || data.type === 'lora') && (
<Box layerStyle="second" borderRadius="base" p={4}>
<TriggerPhrases />
</Box>
)}
</Flex>
);
};

View File

@ -77,9 +77,17 @@ export const TriggerPhrases = () => {
[updateModel, selectedModelKey, triggerPhrases]
);
const onTriggerPhraseAddFormSubmit = useCallback(
(e: React.FormEvent<HTMLFormElement>) => {
e.preventDefault();
addTriggerPhrase();
},
[addTriggerPhrase]
);
return (
<Flex flexDir="column" w="full" gap="5">
<form>
<form onSubmit={onTriggerPhraseAddFormSubmit}>
<FormControl w="full" isInvalid={Boolean(errors.length)} orientation="vertical">
<FormLabel>{t('modelManager.triggerPhrases')}</FormLabel>
<Flex flexDir="column" w="full">

View File

@ -1,8 +1,9 @@
import { Flex, Text } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import type { CSSProperties } from 'react';
import { memo, useMemo } from 'react';
import type { EdgeProps } from 'reactflow';
import { BaseEdge, getBezierPath } from 'reactflow';
import { BaseEdge, EdgeLabelRenderer, getBezierPath } from 'reactflow';
import { makeEdgeSelector } from './util/makeEdgeSelector';
@ -25,9 +26,10 @@ const InvocationDefaultEdge = ({
[source, sourceHandleId, target, targetHandleId, selected]
);
const { isSelected, shouldAnimate, stroke } = useAppSelector(selector);
const { isSelected, shouldAnimate, stroke, label } = useAppSelector(selector);
const shouldShowEdgeLabels = useAppSelector((s) => s.nodes.shouldShowEdgeLabels);
const [edgePath] = getBezierPath({
const [edgePath, labelX, labelY] = getBezierPath({
sourceX,
sourceY,
sourcePosition,
@ -47,7 +49,33 @@ const InvocationDefaultEdge = ({
[isSelected, shouldAnimate, stroke]
);
return <BaseEdge path={edgePath} markerEnd={markerEnd} style={edgeStyles} />;
return (
<>
<BaseEdge path={edgePath} markerEnd={markerEnd} style={edgeStyles} />
{label && shouldShowEdgeLabels && (
<EdgeLabelRenderer>
<Flex
className="nodrag nopan"
pointerEvents="all"
position="absolute"
transform={`translate(-50%, -50%) translate(${labelX}px,${labelY}px)`}
bg="base.800"
borderRadius="base"
borderWidth={1}
borderColor={isSelected ? 'undefined' : 'transparent'}
opacity={isSelected ? 1 : 0.5}
py={1}
px={3}
shadow="md"
>
<Text size="sm" fontWeight="semibold" color={isSelected ? 'base.100' : 'base.300'}>
{label}
</Text>
</Flex>
</EdgeLabelRenderer>
)}
</>
);
};
export default memo(InvocationDefaultEdge);

View File

@ -1,7 +1,7 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectFieldOutputTemplate } from 'features/nodes/store/selectors';
import { selectFieldOutputTemplate, selectNodeTemplate } from 'features/nodes/store/selectors';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { getFieldColor } from './getEdgeColor';
@ -10,6 +10,7 @@ const defaultReturnValue = {
isSelected: false,
shouldAnimate: false,
stroke: colorTokenToCssVar('base.500'),
label: '',
};
export const makeEdgeSelector = (
@ -19,25 +20,34 @@ export const makeEdgeSelector = (
targetHandleId: string | null | undefined,
selected?: boolean
) =>
createMemoizedSelector(selectNodesSlice, (nodes): { isSelected: boolean; shouldAnimate: boolean; stroke: string } => {
const sourceNode = nodes.nodes.find((node) => node.id === source);
const targetNode = nodes.nodes.find((node) => node.id === target);
createMemoizedSelector(
selectNodesSlice,
(nodes): { isSelected: boolean; shouldAnimate: boolean; stroke: string; label: string } => {
const sourceNode = nodes.nodes.find((node) => node.id === source);
const targetNode = nodes.nodes.find((node) => node.id === target);
const isInvocationToInvocationEdge = isInvocationNode(sourceNode) && isInvocationNode(targetNode);
const isInvocationToInvocationEdge = isInvocationNode(sourceNode) && isInvocationNode(targetNode);
const isSelected = Boolean(sourceNode?.selected || targetNode?.selected || selected);
if (!sourceNode || !sourceHandleId) {
return defaultReturnValue;
const isSelected = Boolean(sourceNode?.selected || targetNode?.selected || selected);
if (!sourceNode || !sourceHandleId || !targetNode || !targetHandleId) {
return defaultReturnValue;
}
const outputFieldTemplate = selectFieldOutputTemplate(nodes, sourceNode.id, sourceHandleId);
const sourceType = isInvocationToInvocationEdge ? outputFieldTemplate?.type : undefined;
const stroke = sourceType && nodes.shouldColorEdges ? getFieldColor(sourceType) : colorTokenToCssVar('base.500');
const sourceNodeTemplate = selectNodeTemplate(nodes, sourceNode.id);
const targetNodeTemplate = selectNodeTemplate(nodes, targetNode.id);
const label = `${sourceNodeTemplate?.title || sourceNode.data?.label} -> ${targetNodeTemplate?.title || targetNode.data?.label}`;
return {
isSelected,
shouldAnimate: nodes.shouldAnimateEdges && isSelected,
stroke,
label,
};
}
const outputFieldTemplate = selectFieldOutputTemplate(nodes, sourceNode.id, sourceHandleId);
const sourceType = isInvocationToInvocationEdge ? outputFieldTemplate?.type : undefined;
const stroke = sourceType && nodes.shouldColorEdges ? getFieldColor(sourceType) : colorTokenToCssVar('base.500');
return {
isSelected,
shouldAnimate: nodes.shouldAnimateEdges && isSelected,
stroke,
};
});
);

View File

@ -16,7 +16,7 @@ const props: ChakraProps = { w: 'unset' };
const InvocationNodeFooter = ({ nodeId }: Props) => {
const hasImageOutput = useHasImageOutput(nodeId);
const isCacheEnabled = useFeatureStatus('invocationCache').isFeatureEnabled;
const isCacheEnabled = useFeatureStatus('invocationCache');
return (
<Flex
className={DRAG_HANDLE_CLASSNAME}

View File

@ -24,6 +24,7 @@ import {
selectNodesSlice,
shouldAnimateEdgesChanged,
shouldColorEdgesChanged,
shouldShowEdgeLabelsChanged,
shouldSnapToGridChanged,
shouldValidateGraphChanged,
} from 'features/nodes/store/nodesSlice';
@ -35,12 +36,20 @@ import { SelectionMode } from 'reactflow';
const formLabelProps: FormLabelProps = { flexGrow: 1 };
const selector = createMemoizedSelector(selectNodesSlice, (nodes) => {
const { shouldAnimateEdges, shouldValidateGraph, shouldSnapToGrid, shouldColorEdges, selectionMode } = nodes;
const {
shouldAnimateEdges,
shouldValidateGraph,
shouldSnapToGrid,
shouldColorEdges,
shouldShowEdgeLabels,
selectionMode,
} = nodes;
return {
shouldAnimateEdges,
shouldValidateGraph,
shouldSnapToGrid,
shouldColorEdges,
shouldShowEdgeLabels,
selectionModeIsChecked: selectionMode === SelectionMode.Full,
};
});
@ -52,8 +61,14 @@ type Props = {
const WorkflowEditorSettings = ({ children }: Props) => {
const { isOpen, onOpen, onClose } = useDisclosure();
const dispatch = useAppDispatch();
const { shouldAnimateEdges, shouldValidateGraph, shouldSnapToGrid, shouldColorEdges, selectionModeIsChecked } =
useAppSelector(selector);
const {
shouldAnimateEdges,
shouldValidateGraph,
shouldSnapToGrid,
shouldColorEdges,
shouldShowEdgeLabels,
selectionModeIsChecked,
} = useAppSelector(selector);
const handleChangeShouldValidate = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
@ -90,6 +105,13 @@ const WorkflowEditorSettings = ({ children }: Props) => {
[dispatch]
);
const handleChangeShouldShowEdgeLabels = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(shouldShowEdgeLabelsChanged(e.target.checked));
},
[dispatch]
);
const { t } = useTranslation();
return (
@ -137,6 +159,14 @@ const WorkflowEditorSettings = ({ children }: Props) => {
<FormHelperText>{t('nodes.fullyContainNodesHelp')}</FormHelperText>
</FormControl>
<Divider />
<FormControl>
<Flex w="full">
<FormLabel>{t('nodes.showEdgeLabels')}</FormLabel>
<Switch isChecked={shouldShowEdgeLabels} onChange={handleChangeShouldShowEdgeLabels} />
</Flex>
<FormHelperText>{t('nodes.showEdgeLabelsHelp')}</FormHelperText>
</FormControl>
<Divider />
<Heading size="sm" pt={4}>
{t('common.advanced')}
</Heading>

View File

@ -1,5 +1,5 @@
import { createSelector } from '@reduxjs/toolkit';
import { EMPTY_ARRAY } from 'app/store/constants';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectNodeTemplate } from 'features/nodes/store/selectors';
@ -10,7 +10,7 @@ import { useMemo } from 'react';
export const useOutputFieldNames = (nodeId: string) => {
const selector = useMemo(
() =>
createSelector(selectNodesSlice, (nodes) => {
createMemoizedSelector(selectNodesSlice, (nodes) => {
const template = selectNodeTemplate(nodes, nodeId);
if (!template) {
return EMPTY_ARRAY;

View File

@ -5,8 +5,7 @@ import { useHasImageOutput } from './useHasImageOutput';
export const useWithFooter = (nodeId: string) => {
const hasImageOutput = useHasImageOutput(nodeId);
const isCacheEnabled = useFeatureStatus('invocationCache').isFeatureEnabled;
const isCacheEnabled = useFeatureStatus('invocationCache');
const withFooter = useMemo(() => hasImageOutput || isCacheEnabled, [hasImageOutput, isCacheEnabled]);
return withFooter;
};

View File

@ -103,6 +103,7 @@ const initialNodesState: NodesState = {
shouldAnimateEdges: true,
shouldSnapToGrid: false,
shouldColorEdges: true,
shouldShowEdgeLabels: false,
isAddNodePopoverOpen: false,
nodeOpacity: 1,
selectedNodes: [],
@ -549,6 +550,9 @@ export const nodesSlice = createSlice({
shouldAnimateEdgesChanged: (state, action: PayloadAction<boolean>) => {
state.shouldAnimateEdges = action.payload;
},
shouldShowEdgeLabelsChanged: (state, action: PayloadAction<boolean>) => {
state.shouldShowEdgeLabels = action.payload;
},
shouldSnapToGridChanged: (state, action: PayloadAction<boolean>) => {
state.shouldSnapToGrid = action.payload;
},
@ -831,6 +835,7 @@ export const {
viewportChanged,
edgeAdded,
nodeTemplatesBuilt,
shouldShowEdgeLabelsChanged,
} = nodesSlice.actions;
// This is used for tracking `state.workflow.isTouched`

View File

@ -23,6 +23,7 @@ export type NodesState = {
nodeOpacity: number;
shouldSnapToGrid: boolean;
shouldColorEdges: boolean;
shouldShowEdgeLabels: boolean;
selectedNodes: string[];
selectedEdges: string[];
nodeExecutionStates: Record<string, NodeExecutionState>;

View File

@ -109,6 +109,7 @@ export const zIPAdapterField = z.object({
image: zImageField,
ip_adapter_model: zModelIdentifierField,
weight: z.number(),
method: z.enum(['full', 'style', 'composition']),
begin_step_percent: z.number().optional(),
end_step_percent: z.number().optional(),
});

View File

@ -19,12 +19,14 @@ export const addIPAdapterToLinearGraph = async (
graph: NonNullableGraph,
baseNodeId: string
): Promise<void> => {
const validIPAdapters = selectValidIPAdapters(state.controlAdapters).filter(({ model, controlImage, isEnabled }) => {
const hasModel = Boolean(model);
const doesBaseMatch = model?.base === state.generation.model?.base;
const hasControlImage = controlImage;
return isEnabled && hasModel && doesBaseMatch && hasControlImage;
});
const validIPAdapters = selectValidIPAdapters(state.controlAdapters)
.filter(({ model, controlImage, isEnabled }) => {
const hasModel = Boolean(model);
const doesBaseMatch = model?.base === state.generation.model?.base;
const hasControlImage = controlImage;
return isEnabled && hasModel && doesBaseMatch && hasControlImage;
})
.filter((ca) => !state.regionalPrompts.present.layers.some((l) => l.ipAdapterIds.includes(ca.id)));
if (validIPAdapters.length) {
// Even though denoise_latents' ip adapter input is collection or scalar, keep it simple and always use a collect
@ -48,7 +50,7 @@ export const addIPAdapterToLinearGraph = async (
if (!ipAdapter.model) {
return;
}
const { id, weight, model, clipVisionModel, beginStepPct, endStepPct, controlImage } = ipAdapter;
const { id, weight, model, clipVisionModel, method, beginStepPct, endStepPct, controlImage } = ipAdapter;
assert(controlImage, 'IP Adapter image is required');
@ -57,6 +59,7 @@ export const addIPAdapterToLinearGraph = async (
type: 'ip_adapter',
is_intermediate: true,
weight: weight,
method: method,
ip_adapter_model: model,
clip_vision_model: clipVisionModel,
begin_step_percent: beginStepPct,
@ -84,7 +87,7 @@ export const addIPAdapterToLinearGraph = async (
};
const buildIPAdapterMetadata = (ipAdapter: IPAdapterConfig): S['IPAdapterMetadataField'] => {
const { controlImage, beginStepPct, endStepPct, model, clipVisionModel, weight } = ipAdapter;
const { controlImage, beginStepPct, endStepPct, model, clipVisionModel, method, weight } = ipAdapter;
assert(model, 'IP Adapter model is required');
@ -102,6 +105,7 @@ const buildIPAdapterMetadata = (ipAdapter: IPAdapterConfig): S['IPAdapterMetadat
ip_adapter_model: model,
clip_vision_model: clipVisionModel,
weight,
method,
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
image,

View File

@ -0,0 +1,346 @@
import { getStore } from 'app/store/nanostores/store';
import type { RootState } from 'app/store/store';
import { selectAllIPAdapters } from 'features/controlAdapters/store/controlAdaptersSlice';
import {
IP_ADAPTER_COLLECT,
NEGATIVE_CONDITIONING,
NEGATIVE_CONDITIONING_COLLECT,
POSITIVE_CONDITIONING,
POSITIVE_CONDITIONING_COLLECT,
PROMPT_REGION_INVERT_TENSOR_MASK_PREFIX,
PROMPT_REGION_MASK_TO_TENSOR_PREFIX,
PROMPT_REGION_NEGATIVE_COND_PREFIX,
PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX,
PROMPT_REGION_POSITIVE_COND_PREFIX,
} from 'features/nodes/util/graph/constants';
import { isVectorMaskLayer } from 'features/regionalPrompts/store/regionalPromptsSlice';
import { getRegionalPromptLayerBlobs } from 'features/regionalPrompts/util/getLayerBlobs';
import { size } from 'lodash-es';
import { imagesApi } from 'services/api/endpoints/images';
import type { CollectInvocation, Edge, IPAdapterInvocation, NonNullableGraph, S } from 'services/api/types';
import { assert } from 'tsafe';
export const addRegionalPromptsToGraph = async (state: RootState, graph: NonNullableGraph, denoiseNodeId: string) => {
if (!state.regionalPrompts.present.isEnabled) {
return;
}
const { dispatch } = getStore();
const isSDXL = state.generation.model?.base === 'sdxl';
const layers = state.regionalPrompts.present.layers
// Only support vector mask layers now
// TODO: Image masks
.filter(isVectorMaskLayer)
// Only visible layers are rendered on the canvas
.filter((l) => l.isVisible)
// Only layers with prompts get added to the graph
.filter((l) => {
const hasTextPrompt = Boolean(l.positivePrompt || l.negativePrompt);
const hasIPAdapter = l.ipAdapterIds.length !== 0;
return hasTextPrompt || hasIPAdapter;
});
const regionalIPAdapters = selectAllIPAdapters(state.controlAdapters).filter(
({ id, model, controlImage, isEnabled }) => {
const hasModel = Boolean(model);
const doesBaseMatch = model?.base === state.generation.model?.base;
const hasControlImage = controlImage;
const isRegional = layers.some((l) => l.ipAdapterIds.includes(id));
return isEnabled && hasModel && doesBaseMatch && hasControlImage && isRegional;
}
);
const layerIds = layers.map((l) => l.id);
const blobs = await getRegionalPromptLayerBlobs(layerIds);
assert(size(blobs) === size(layerIds), 'Mismatch between layer IDs and blobs');
// TODO: We should probably just use conditioning collectors by default, and skip all this fanagling with re-routing
// the existing conditioning nodes.
// With regional prompts we have multiple conditioning nodes which much be routed into collectors. Set those up
const posCondCollectNode: CollectInvocation = {
id: POSITIVE_CONDITIONING_COLLECT,
type: 'collect',
};
graph.nodes[POSITIVE_CONDITIONING_COLLECT] = posCondCollectNode;
const negCondCollectNode: CollectInvocation = {
id: NEGATIVE_CONDITIONING_COLLECT,
type: 'collect',
};
graph.nodes[NEGATIVE_CONDITIONING_COLLECT] = negCondCollectNode;
// Re-route the denoise node's OG conditioning inputs to the collect nodes
const newEdges: Edge[] = [];
for (const edge of graph.edges) {
if (edge.destination.node_id === denoiseNodeId && edge.destination.field === 'positive_conditioning') {
newEdges.push({
source: edge.source,
destination: {
node_id: POSITIVE_CONDITIONING_COLLECT,
field: 'item',
},
});
} else if (edge.destination.node_id === denoiseNodeId && edge.destination.field === 'negative_conditioning') {
newEdges.push({
source: edge.source,
destination: {
node_id: NEGATIVE_CONDITIONING_COLLECT,
field: 'item',
},
});
} else {
newEdges.push(edge);
}
}
graph.edges = newEdges;
// Connect collectors to the denoise nodes - must happen _after_ rerouting else you get cycles
graph.edges.push({
source: {
node_id: POSITIVE_CONDITIONING_COLLECT,
field: 'collection',
},
destination: {
node_id: denoiseNodeId,
field: 'positive_conditioning',
},
});
graph.edges.push({
source: {
node_id: NEGATIVE_CONDITIONING_COLLECT,
field: 'collection',
},
destination: {
node_id: denoiseNodeId,
field: 'negative_conditioning',
},
});
if (!graph.nodes[IP_ADAPTER_COLLECT] && regionalIPAdapters.length > 0) {
const ipAdapterCollectNode: CollectInvocation = {
id: IP_ADAPTER_COLLECT,
type: 'collect',
is_intermediate: true,
};
graph.nodes[IP_ADAPTER_COLLECT] = ipAdapterCollectNode;
graph.edges.push({
source: { node_id: IP_ADAPTER_COLLECT, field: 'collection' },
destination: {
node_id: denoiseNodeId,
field: 'ip_adapter',
},
});
}
// Upload the blobs to the backend, add each to graph
// TODO: Store the uploaded image names in redux to reuse them, so long as the layer hasn't otherwise changed. This
// would be a great perf win - not only would we skip re-uploading the same image, but we'd be able to use the node
// cache (currently, when we re-use the same mask data, since it is a different image, the node cache is not used).
for (const layer of layers) {
const blob = blobs[layer.id];
assert(blob, `Blob for layer ${layer.id} not found`);
const file = new File([blob], `${layer.id}_mask.png`, { type: 'image/png' });
const req = dispatch(
imagesApi.endpoints.uploadImage.initiate({ file, image_category: 'mask', is_intermediate: true })
);
req.reset();
// TODO: This will raise on network error
const { image_name } = await req.unwrap();
// The main mask-to-tensor node
const maskToTensorNode: S['AlphaMaskToTensorInvocation'] = {
id: `${PROMPT_REGION_MASK_TO_TENSOR_PREFIX}_${layer.id}`,
type: 'alpha_mask_to_tensor',
image: {
image_name,
},
};
graph.nodes[maskToTensorNode.id] = maskToTensorNode;
if (layer.positivePrompt) {
// The main positive conditioning node
const regionalPositiveCondNode: S['SDXLCompelPromptInvocation'] | S['CompelInvocation'] = isSDXL
? {
type: 'sdxl_compel_prompt',
id: `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${layer.id}`,
prompt: layer.positivePrompt,
style: layer.positivePrompt, // TODO: Should we put the positive prompt in both fields?
}
: {
type: 'compel',
id: `${PROMPT_REGION_POSITIVE_COND_PREFIX}_${layer.id}`,
prompt: layer.positivePrompt,
};
graph.nodes[regionalPositiveCondNode.id] = regionalPositiveCondNode;
// Connect the mask to the conditioning
graph.edges.push({
source: { node_id: maskToTensorNode.id, field: 'mask' },
destination: { node_id: regionalPositiveCondNode.id, field: 'mask' },
});
// Connect the conditioning to the collector
graph.edges.push({
source: { node_id: regionalPositiveCondNode.id, field: 'conditioning' },
destination: { node_id: posCondCollectNode.id, field: 'item' },
});
// Copy the connections to the "global" positive conditioning node to the regional cond
for (const edge of graph.edges) {
if (edge.destination.node_id === POSITIVE_CONDITIONING && edge.destination.field !== 'prompt') {
graph.edges.push({
source: edge.source,
destination: { node_id: regionalPositiveCondNode.id, field: edge.destination.field },
});
}
}
}
if (layer.negativePrompt) {
// The main negative conditioning node
const regionalNegativeCondNode: S['SDXLCompelPromptInvocation'] | S['CompelInvocation'] = isSDXL
? {
type: 'sdxl_compel_prompt',
id: `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${layer.id}`,
prompt: layer.negativePrompt,
style: layer.negativePrompt,
}
: {
type: 'compel',
id: `${PROMPT_REGION_NEGATIVE_COND_PREFIX}_${layer.id}`,
prompt: layer.negativePrompt,
};
graph.nodes[regionalNegativeCondNode.id] = regionalNegativeCondNode;
// Connect the mask to the conditioning
graph.edges.push({
source: { node_id: maskToTensorNode.id, field: 'mask' },
destination: { node_id: regionalNegativeCondNode.id, field: 'mask' },
});
// Connect the conditioning to the collector
graph.edges.push({
source: { node_id: regionalNegativeCondNode.id, field: 'conditioning' },
destination: { node_id: negCondCollectNode.id, field: 'item' },
});
// Copy the connections to the "global" negative conditioning node to the regional cond
for (const edge of graph.edges) {
if (edge.destination.node_id === NEGATIVE_CONDITIONING && edge.destination.field !== 'prompt') {
graph.edges.push({
source: edge.source,
destination: { node_id: regionalNegativeCondNode.id, field: edge.destination.field },
});
}
}
}
// If we are using the "invert" auto-negative setting, we need to add an additional negative conditioning node
if (layer.autoNegative === 'invert' && layer.positivePrompt) {
// We re-use the mask image, but invert it when converting to tensor
const invertTensorMaskNode: S['InvertTensorMaskInvocation'] = {
id: `${PROMPT_REGION_INVERT_TENSOR_MASK_PREFIX}_${layer.id}`,
type: 'invert_tensor_mask',
};
graph.nodes[invertTensorMaskNode.id] = invertTensorMaskNode;
// Connect the OG mask image to the inverted mask-to-tensor node
graph.edges.push({
source: {
node_id: maskToTensorNode.id,
field: 'mask',
},
destination: {
node_id: invertTensorMaskNode.id,
field: 'mask',
},
});
// Create the conditioning node. It's going to be connected to the negative cond collector, but it uses the
// positive prompt
const regionalPositiveCondInvertedNode: S['SDXLCompelPromptInvocation'] | S['CompelInvocation'] = isSDXL
? {
type: 'sdxl_compel_prompt',
id: `${PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX}_${layer.id}`,
prompt: layer.positivePrompt,
style: layer.positivePrompt,
}
: {
type: 'compel',
id: `${PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX}_${layer.id}`,
prompt: layer.positivePrompt,
};
graph.nodes[regionalPositiveCondInvertedNode.id] = regionalPositiveCondInvertedNode;
// Connect the inverted mask to the conditioning
graph.edges.push({
source: { node_id: invertTensorMaskNode.id, field: 'mask' },
destination: { node_id: regionalPositiveCondInvertedNode.id, field: 'mask' },
});
// Connect the conditioning to the negative collector
graph.edges.push({
source: { node_id: regionalPositiveCondInvertedNode.id, field: 'conditioning' },
destination: { node_id: negCondCollectNode.id, field: 'item' },
});
// Copy the connections to the "global" positive conditioning node to our regional node
for (const edge of graph.edges) {
if (edge.destination.node_id === POSITIVE_CONDITIONING && edge.destination.field !== 'prompt') {
graph.edges.push({
source: edge.source,
destination: { node_id: regionalPositiveCondInvertedNode.id, field: edge.destination.field },
});
}
}
}
for (const ipAdapterId of layer.ipAdapterIds) {
const ipAdapter = selectAllIPAdapters(state.controlAdapters)
.filter(({ id, model, controlImage, isEnabled }) => {
const hasModel = Boolean(model);
const doesBaseMatch = model?.base === state.generation.model?.base;
const hasControlImage = controlImage;
const isRegional = layers.some((l) => l.ipAdapterIds.includes(id));
return isEnabled && hasModel && doesBaseMatch && hasControlImage && isRegional;
})
.find((ca) => ca.id === ipAdapterId);
if (!ipAdapter?.model) {
return;
}
const { id, weight, model, clipVisionModel, method, beginStepPct, endStepPct, controlImage } = ipAdapter;
assert(controlImage, 'IP Adapter image is required');
const ipAdapterNode: IPAdapterInvocation = {
id: `ip_adapter_${id}`,
type: 'ip_adapter',
is_intermediate: true,
weight: weight,
method: method,
ip_adapter_model: model,
clip_vision_model: clipVisionModel,
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
image: {
image_name: controlImage,
},
};
graph.nodes[ipAdapterNode.id] = ipAdapterNode;
// Connect the mask to the conditioning
graph.edges.push({
source: { node_id: maskToTensorNode.id, field: 'mask' },
destination: { node_id: ipAdapterNode.id, field: 'mask' },
});
graph.edges.push({
source: { node_id: ipAdapterNode.id, field: 'ip_adapter' },
destination: {
node_id: IP_ADAPTER_COLLECT,
field: 'item',
},
});
}
}
};

View File

@ -9,6 +9,7 @@ import {
CANVAS_TEXT_TO_IMAGE_GRAPH,
IMAGE_TO_IMAGE_GRAPH,
IMAGE_TO_LATENTS,
INPAINT_CREATE_MASK,
INPAINT_IMAGE,
LATENTS_TO_IMAGE,
MAIN_MODEL_LOADER,
@ -145,6 +146,16 @@ export const addVAEToGraph = async (
field: 'vae',
},
},
{
source: {
node_id: isSeamlessEnabled ? SEAMLESS : isAutoVae ? modelLoaderNodeId : VAE_LOADER,
field: 'vae',
},
destination: {
node_id: INPAINT_CREATE_MASK,
field: 'vae',
},
},
{
source: {

View File

@ -133,6 +133,8 @@ export const buildCanvasInpaintGraph = async (
coherence_mode: canvasCoherenceMode,
minimum_denoise: canvasCoherenceMinDenoise,
edge_radius: canvasCoherenceEdgeSize,
tiled: false,
fp32: fp32,
},
[DENOISE_LATENTS]: {
type: 'denoise_latents',
@ -182,6 +184,16 @@ export const buildCanvasInpaintGraph = async (
field: 'clip',
},
},
{
source: {
node_id: modelLoaderNodeId,
field: 'unet',
},
destination: {
node_id: INPAINT_CREATE_MASK,
field: 'unet',
},
},
// Connect CLIP Skip to Conditioning
{
source: {
@ -331,6 +343,16 @@ export const buildCanvasInpaintGraph = async (
field: 'mask',
},
},
{
source: {
node_id: INPAINT_IMAGE_RESIZE_UP,
field: 'image',
},
destination: {
node_id: INPAINT_CREATE_MASK,
field: 'image',
},
},
// Resize Down
{
source: {

View File

@ -157,6 +157,8 @@ export const buildCanvasOutpaintGraph = async (
coherence_mode: canvasCoherenceMode,
edge_radius: canvasCoherenceEdgeSize,
minimum_denoise: canvasCoherenceMinDenoise,
tiled: false,
fp32: fp32,
},
[DENOISE_LATENTS]: {
type: 'denoise_latents',
@ -207,6 +209,16 @@ export const buildCanvasOutpaintGraph = async (
field: 'clip',
},
},
{
source: {
node_id: modelLoaderNodeId,
field: 'unet',
},
destination: {
node_id: INPAINT_CREATE_MASK,
field: 'unet',
},
},
// Connect CLIP Skip to Conditioning
{
source: {
@ -453,6 +465,16 @@ export const buildCanvasOutpaintGraph = async (
field: 'image',
},
},
{
source: {
node_id: INPAINT_IMAGE_RESIZE_UP,
field: 'image',
},
destination: {
node_id: INPAINT_CREATE_MASK,
field: 'image',
},
},
// Resize Results Down
{
source: {

View File

@ -135,6 +135,8 @@ export const buildCanvasSDXLInpaintGraph = async (
coherence_mode: canvasCoherenceMode,
minimum_denoise: refinerModel ? Math.max(0.2, canvasCoherenceMinDenoise) : canvasCoherenceMinDenoise,
edge_radius: canvasCoherenceEdgeSize,
tiled: false,
fp32: fp32,
},
[SDXL_DENOISE_LATENTS]: {
type: 'denoise_latents',
@ -214,6 +216,16 @@ export const buildCanvasSDXLInpaintGraph = async (
field: 'clip2',
},
},
{
source: {
node_id: modelLoaderNodeId,
field: 'unet',
},
destination: {
node_id: INPAINT_CREATE_MASK,
field: 'unet',
},
},
// Connect Everything To Inpaint Node
{
source: {
@ -342,6 +354,16 @@ export const buildCanvasSDXLInpaintGraph = async (
field: 'mask',
},
},
{
source: {
node_id: INPAINT_IMAGE_RESIZE_UP,
field: 'image',
},
destination: {
node_id: INPAINT_CREATE_MASK,
field: 'image',
},
},
// Resize Down
{
source: {

View File

@ -157,6 +157,8 @@ export const buildCanvasSDXLOutpaintGraph = async (
coherence_mode: canvasCoherenceMode,
edge_radius: canvasCoherenceEdgeSize,
minimum_denoise: refinerModel ? Math.max(0.2, canvasCoherenceMinDenoise) : canvasCoherenceMinDenoise,
tiled: false,
fp32: fp32,
},
[SDXL_DENOISE_LATENTS]: {
type: 'denoise_latents',
@ -237,6 +239,16 @@ export const buildCanvasSDXLOutpaintGraph = async (
field: 'clip2',
},
},
{
source: {
node_id: modelLoaderNodeId,
field: 'unet',
},
destination: {
node_id: INPAINT_CREATE_MASK,
field: 'unet',
},
},
// Connect Infill Result To Inpaint Image
{
source: {
@ -451,6 +463,16 @@ export const buildCanvasSDXLOutpaintGraph = async (
field: 'image',
},
},
{
source: {
node_id: INPAINT_IMAGE_RESIZE_UP,
field: 'image',
},
destination: {
node_id: INPAINT_CREATE_MASK,
field: 'image',
},
},
// Take combined mask and resize
{
source: {

View File

@ -1,6 +1,7 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { addRegionalPromptsToGraph } from 'features/nodes/util/graph/addRegionalPromptsToGraph';
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
@ -273,6 +274,8 @@ export const buildLinearSDXLTextToImageGraph = async (state: RootState): Promise
await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
await addRegionalPromptsToGraph(state, graph, SDXL_DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!

View File

@ -1,6 +1,7 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { addRegionalPromptsToGraph } from 'features/nodes/util/graph/addRegionalPromptsToGraph';
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
@ -255,6 +256,8 @@ export const buildLinearTextToImageGraph = async (state: RootState): Promise<Non
await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
await addRegionalPromptsToGraph(state, graph, DENOISE_LATENTS);
// High resolution fix.
if (state.hrf.hrfEnabled) {
addHrfToGraph(state, graph);

View File

@ -46,6 +46,13 @@ export const SDXL_REFINER_DENOISE_LATENTS = 'sdxl_refiner_denoise_latents';
export const SDXL_REFINER_INPAINT_CREATE_MASK = 'refiner_inpaint_create_mask';
export const SEAMLESS = 'seamless';
export const SDXL_REFINER_SEAMLESS = 'refiner_seamless';
export const PROMPT_REGION_MASK_TO_TENSOR_PREFIX = 'prompt_region_mask_to_tensor';
export const PROMPT_REGION_INVERT_TENSOR_MASK_PREFIX = 'prompt_region_invert_tensor_mask';
export const PROMPT_REGION_POSITIVE_COND_PREFIX = 'prompt_region_positive_cond';
export const PROMPT_REGION_NEGATIVE_COND_PREFIX = 'prompt_region_negative_cond';
export const PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX = 'prompt_region_positive_cond_inverted';
export const POSITIVE_CONDITIONING_COLLECT = 'positive_conditioning_collect';
export const NEGATIVE_CONDITIONING_COLLECT = 'negative_conditioning_collect';
// friendly graph ids
export const TEXT_TO_IMAGE_GRAPH = 'text_to_image_graph';

View File

@ -1,24 +1,18 @@
import { Box, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIColorPicker from 'common/components/IAIColorPicker';
import { selectGenerationSlice, setInfillColorValue } from 'features/parameters/store/generationSlice';
import { memo, useCallback, useMemo } from 'react';
import { memo, useCallback } from 'react';
import type { RgbaColor } from 'react-colorful';
import { useTranslation } from 'react-i18next';
const selectInfillColor = createMemoizedSelector(selectGenerationSlice, (generation) => generation.infillColorValue);
const ParamInfillColorOptions = () => {
const dispatch = useAppDispatch();
const selector = useMemo(
() =>
createSelector(selectGenerationSlice, (generation) => ({
infillColor: generation.infillColorValue,
})),
[]
);
const { infillColor } = useAppSelector(selector);
const infillColor = useAppSelector(selectInfillColor);
const infillMethod = useAppSelector((s) => s.generation.infillMethod);

View File

@ -1,35 +1,23 @@
import { Box, CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIColorPicker from 'common/components/IAIColorPicker';
import {
selectGenerationSlice,
setInfillMosaicMaxColor,
setInfillMosaicMinColor,
setInfillMosaicTileHeight,
setInfillMosaicTileWidth,
} from 'features/parameters/store/generationSlice';
import { memo, useCallback, useMemo } from 'react';
import { memo, useCallback } from 'react';
import type { RgbaColor } from 'react-colorful';
import { useTranslation } from 'react-i18next';
const ParamInfillMosaicTileSize = () => {
const dispatch = useAppDispatch();
const selector = useMemo(
() =>
createSelector(selectGenerationSlice, (generation) => ({
infillMosaicTileWidth: generation.infillMosaicTileWidth,
infillMosaicTileHeight: generation.infillMosaicTileHeight,
infillMosaicMinColor: generation.infillMosaicMinColor,
infillMosaicMaxColor: generation.infillMosaicMaxColor,
})),
[]
);
const { infillMosaicTileWidth, infillMosaicTileHeight, infillMosaicMinColor, infillMosaicMaxColor } =
useAppSelector(selector);
const infillMosaicTileWidth = useAppSelector((s) => s.generation.infillMosaicTileWidth);
const infillMosaicTileHeight = useAppSelector((s) => s.generation.infillMosaicTileHeight);
const infillMosaicMinColor = useAppSelector((s) => s.generation.infillMosaicMinColor);
const infillMosaicMaxColor = useAppSelector((s) => s.generation.infillMosaicMaxColor);
const infillMethod = useAppSelector((s) => s.generation.infillMethod);
const { t } = useTranslation();

View File

@ -0,0 +1,13 @@
import { Flex } from '@invoke-ai/ui-library';
import { StageComponent } from 'features/regionalPrompts/components/StageComponent';
import { memo } from 'react';
export const AspectRatioCanvasPreview = memo(() => {
return (
<Flex w="full" h="full" alignItems="center" justifyContent="center" position="relative">
<StageComponent asPreview />
</Flex>
);
});
AspectRatioCanvasPreview.displayName = 'AspectRatioCanvasPreview';

View File

@ -2,7 +2,7 @@ import { useSize } from '@chakra-ui/react-use-size';
import { Flex, Icon } from '@invoke-ai/ui-library';
import { useImageSizeContext } from 'features/parameters/components/ImageSize/ImageSizeContext';
import { AnimatePresence, motion } from 'framer-motion';
import { useMemo, useRef } from 'react';
import { memo, useMemo, useRef } from 'react';
import { PiFrameCorners } from 'react-icons/pi';
import {
@ -15,7 +15,7 @@ import {
MOTION_ICON_INITIAL,
} from './constants';
export const AspectRatioPreview = () => {
export const AspectRatioIconPreview = memo(() => {
const ctx = useImageSizeContext();
const containerRef = useRef<HTMLDivElement>(null);
const containerSize = useSize(containerRef);
@ -70,4 +70,6 @@ export const AspectRatioPreview = () => {
</Flex>
</Flex>
);
};
});
AspectRatioIconPreview.displayName = 'AspectRatioIconPreview';

View File

@ -1,6 +1,5 @@
import type { FormLabelProps } from '@invoke-ai/ui-library';
import { Flex, FormControlGroup } from '@invoke-ai/ui-library';
import { AspectRatioPreview } from 'features/parameters/components/ImageSize/AspectRatioPreview';
import { AspectRatioSelect } from 'features/parameters/components/ImageSize/AspectRatioSelect';
import type { ImageSizeContextInnerValue } from 'features/parameters/components/ImageSize/ImageSizeContext';
import { ImageSizeContext } from 'features/parameters/components/ImageSize/ImageSizeContext';
@ -13,10 +12,11 @@ import { memo } from 'react';
type ImageSizeProps = ImageSizeContextInnerValue & {
widthComponent: ReactNode;
heightComponent: ReactNode;
previewComponent: ReactNode;
};
export const ImageSize = memo((props: ImageSizeProps) => {
const { widthComponent, heightComponent, ...ctx } = props;
const { widthComponent, heightComponent, previewComponent, ...ctx } = props;
return (
<ImageSizeContext.Provider value={ctx}>
<Flex gap={4} alignItems="center">
@ -33,7 +33,7 @@ export const ImageSize = memo((props: ImageSizeProps) => {
</FormControlGroup>
</Flex>
<Flex w="108px" h="108px" flexShrink={0} flexGrow={0}>
<AspectRatioPreview />
{previewComponent}
</Flex>
</Flex>
</ImageSizeContext.Provider>

View File

@ -1,7 +1,6 @@
import type { ComboboxOption } from '@invoke-ai/ui-library';
import type { AspectRatioID, AspectRatioState } from './types';
// When the aspect ratio is between these two values, we show the icon (experimentally determined)
export const ICON_LOW_CUTOFF = 0.23;
export const ICON_HIGH_CUTOFF = 1 / ICON_LOW_CUTOFF;
@ -25,7 +24,6 @@ export const ICON_CONTAINER_STYLES = {
alignItems: 'center',
justifyContent: 'center',
};
export const ASPECT_RATIO_OPTIONS: ComboboxOption[] = [
{ label: 'Free' as const, value: 'Free' },
{ label: '16:9' as const, value: '16:9' },

Some files were not shown because too many files have changed in this diff Show More