Compare commits

...

682 Commits

Author SHA1 Message Date
5db7b48cd8 chore(ui): lint 2024-08-28 21:38:57 +10:00
ea014c66ac chore: release v4.2.9.dev6 2024-08-28 21:33:12 +10:00
25918c28aa feat(ui): migrate add node popover to cmdk
Put this together as a way to figure out the library before moving on to the full app cmdk. Works great.
2024-08-28 21:32:10 +10:00
0c60469401 fix(ui): schema parsing now that node_pack is guaranteed to be present 2024-08-28 21:31:15 +10:00
f1aa50f447 chore(ui): typegen 2024-08-28 21:06:45 +10:00
a413b261f0 fix(app): node_pack not added to openapi schema correctly 2024-08-28 21:06:33 +10:00
4a1a6639f6 fix(ui): unnecessary z-index on invoke button 2024-08-28 16:51:44 +10:00
201c370ca1 feat(ui): split settings modal 2024-08-28 16:51:38 +10:00
d070c7c726 perf(ui): disable useInert on modals
This hook forcibly updates _all_ portals with `data-hidden=true` when the modal opens - then reverts it when the modal closes. It's intended to help screen readers. Unfortunately, this absolutely tanks performance because we have many portals. React needs to do alot of layout calculations (not re-renders).

IMO this behaviour is a bug in chakra. The modals which generated the portals are hidden by default, so this data attr should really be set by default. Dunno why it isn't.
2024-08-28 16:49:19 +10:00
e38e20a992 feat(ui): fix queue item count badge positioning
Previously this badge, floating over the queue menu button next to the invoke button, was rendered within the existing layout. When I initially positioned it, the app layout interfered - it would extend into an area reserved for a flex gap, which cut off the badge.

As a (bad) workaround, I had shifted the whole app down a few pixels to make room for it. What I should have done is what I've done in this commit - render the badge in a portal to take it out of the layout so we don't need that extra vertical padding.

Sleekified some styling a bit too.
2024-08-28 16:37:50 +10:00
39a94ec70e fix(ui): transparency effect not updating 2024-08-28 16:08:12 +10:00
c7bfae2d1e feat(ui): tidy canvas toolbar buttons 2024-08-28 16:04:37 +10:00
e7944c427d feat(ui): revised viewer toggle @joshistoast 2024-08-28 16:04:25 +10:00
48ed4e120d fix(ui): opacity reset value incorrect 2024-08-28 16:03:48 +10:00
a5b038a1b1 revert(ui): roll back flip, doesn't work with rotate yet 2024-08-28 12:03:38 +10:00
dc752c98b0 fix(ui): disable opacity slider fully when no valid entity selected 2024-08-28 09:00:55 +10:00
85a47cc6fe fix(ui): layer preview image sometimes not rendering
The canvas size was dynamic based on the container div's size. When the div was hidden (e.g. when selecting another tab), the container's effective size is 0. This resulted in the preview image canvas being drawn at a scale of 0.

Fixed by using an absolute size for the canvas container.
2024-08-28 08:59:38 +10:00
6450f42cfa feat(ui): tweak regional prompt box styles 2024-08-28 08:44:04 +10:00
3876f71ff4 feat(ui): tweak enabled/locked toggle styles 2024-08-28 08:25:37 +10:00
cf819e8eab feat(ui): tweak filter styling 2024-08-28 08:07:36 +10:00
2217fb8485 feat(ui): add flip & reset to transform 2024-08-28 07:57:15 +10:00
43652e830a tidy(ui): use helper to sync scaled bbox size on model change 2024-08-28 07:11:05 +10:00
a3417bf81d fix(ui): randomize seed toggle linked to prompt concat 2024-08-28 07:05:14 +10:00
06637161e3 chore: release v4.2.9.dev5 2024-08-27 19:57:34 +10:00
c4f4b16a36 chore(ui): lint 2024-08-27 19:57:02 +10:00
3001718f9f feat(ui): generalize mask fill, add to action bar 2024-08-27 19:56:57 +10:00
ac16fa65a3 feat(ui): implement interaction locking on layers 2024-08-27 19:27:38 +10:00
bc0b5335ff feat(ui): iterate on layer actions
- Add lock toggle
- Tweak lock and enabled styles
- Update entity list action bar w/ delete & delete all
- Move add layer menu to action bar
- Adjust opacity slider style
2024-08-27 19:08:57 +10:00
e91c7c5a30 feat(ui): collapsible entity groups 2024-08-27 17:59:21 +10:00
74791cc490 tidy(ui): rename some classes to be consistent 2024-08-27 17:27:38 +10:00
68409c6a0f feat(ui): tuned canvas undo/redo
- Throttle pushing to history for actions of the same type, starting with 1000ms throttle.
- History has a limit of 64 items, same as workflow editor
- Add clear history button
- Fix an issue where entity transformers would reset the entity state when the entity is fully transparent, resetting the redo stack. This could happen when you undo to the starting state of a layer
2024-08-27 17:10:01 +10:00
85613b220c tidy(ui): move all undoable reducers back to canvas slice 2024-08-27 16:25:42 +10:00
80085ad854 fix(ui): dnd image count 2024-08-27 16:16:11 +10:00
b6bfa65104 fix(ui): canvas entity opacity scale 2024-08-27 16:16:02 +10:00
0bfa033089 perf(ui): optimize all selectors 2
Mostly selector optimization. Still a few places to tidy up but I'll get to that later.
2024-08-27 16:04:05 +10:00
6f0974b5bc perf(ui): optimize all selectors 1
I learned that the inline selector syntax recreates the selector function on every render:

```ts
const val = useAppSelector((s) => s.slice.val)
```

Not good! Better is to create a selector outside the function and use it. Doing that for all selectors now, most of the way through now. Feels snappier.
2024-08-27 13:19:14 +10:00
c3b53fc4f6 feat(ui): rough out undo/redo on canvas 2024-08-27 12:05:45 +10:00
8f59a32d81 chore: release v4.2.9.dev4
Canvas dev build.
2024-08-27 11:12:28 +10:00
3b4f20f433 fix(ui): handle error from internal konva method
We are dipping into konva's private API for preview images and it appears to be unsafe (got an error once). Wrapped in a try/catch.
2024-08-27 11:02:18 +10:00
73da6e9628 feat(ui): split out loras state from canvas rendering state 2024-08-27 11:02:18 +10:00
2fc482141d feat(ui): split out session state from canvas rendering state 2024-08-27 11:02:18 +10:00
a41ec5f3fc feat(ui): split out settings state from canvas rendering state 2024-08-27 11:02:18 +10:00
c906225d03 feat(ui): split out tool state from canvas rendering state 2024-08-27 11:02:18 +10:00
2985ea3716 feat(ui): split out params/compositing state from canvas rendering state
First step to restoring undo/redo - the undoable state must be in its own slice. So params and settings must be isolated.
2024-08-27 11:02:18 +10:00
4f151c6c6f feat(ui): add CanvasModuleBase class to standardize canvas APIs
I did this ages ago but undid it for some reason, not sure why. Caught a few issues related to subscriptions.
2024-08-27 11:02:04 +10:00
c4f5252c1a feat(ui): move selected tool and tool buffer out of redux
This ephemeral state can live in the canvas classes.
2024-08-27 11:02:04 +10:00
77c13f2cf3 feat(ui): move ephemeral state into canvas classes
Things like `$lastCursorPos` are now created within the canvas drawing classes. Consumers in react access them via `useCanvasManager`.

For example:
```tsx
const canvasManager = useCanvasManager();
const lastCursorPos = useStore(canvasManager.stateApi.$lastCursorPos);
```
2024-08-27 11:02:04 +10:00
3270d36fca feat(ui): normalize all actions to accept an entityIdentifier
Previously, canvas actions specific to an entity type only needed the id of that entity type. This allowed you to pass in the id of an entity of the wrong type.

All actions for a specific entity now take a full entity identifier, and the entity identifier type can be narrowed.

`selectEntity` and `selectEntityOrThrow` now need a full entity identifier, and narrow their return values to a specific entity type _if_ the entity identifier is narrowed.

The types for canvas entities are updated with optional type parameters for this purpose.

All reducers, actions and components have been updated.
2024-08-27 11:02:04 +10:00
b6b30ff01f feat(ui): move events into modules who care about them 2024-08-27 11:02:04 +10:00
aa9bfdff35 fix(ui): color picker resets brush opacity 2024-08-27 11:02:04 +10:00
80308cc3b8 fix(ui): scaled bbox loses sync 2024-08-27 11:02:04 +10:00
f6db73bf1f feat(ui): add context menu to entity list 2024-08-27 11:02:04 +10:00
ef9f61a39f chore(ui): bump @invoke-ai/ui-library 2024-08-27 11:02:04 +10:00
a1a0881133 fix(ui): missing vae precision in graph builders 2024-08-27 11:02:04 +10:00
9956919ab6 chore: release v4.2.9.dev3
Instead of using dates, just going to increment.
2024-08-27 11:02:04 +10:00
abc07f57d6 feat(ui): use new Result utils for enqueueing 2024-08-27 11:02:04 +10:00
1a1cae79f1 fix(ui): graph building issue w/ controlnet 2024-08-27 11:02:04 +10:00
bcfafe7b06 feat(ui): add Result type & helpers
Wrappers to capture errors and turn into results:
- `withResult` wraps a sync function
- `withResultAsync` wraps an async function

Comments, tests.
2024-08-27 11:02:04 +10:00
34e8ced592 chore: release v4.2.9.dev20240824 2024-08-27 11:02:04 +10:00
1fdada65b6 fix(ui): lint & fix issues with adding regional ip adapters 2024-08-27 11:02:04 +10:00
433f3e1971 feat(ui): add knipignore tag
I'm not ready to delete some things but still want to build the app.
2024-08-27 11:02:04 +10:00
a60e23f825 feat(ui): duplicate entity 2024-08-27 11:02:04 +10:00
f69de3148e feat(ui): autocomplete on getPrefixeId 2024-08-27 11:02:04 +10:00
cbcd36ef54 feat(ui): paste canvas gens back on source in generate mode 2024-08-27 11:02:04 +10:00
aa76134340 chore(ui): typegen 2024-08-27 11:02:04 +10:00
55758acae8 feat(nodes): CanvasV2MaskAndCropInvocation can paste generated image back on source
This is needed for `Generate` mode.
2024-08-27 11:02:04 +10:00
196e43b5e5 fix(ui): extraneous entity preview updates 2024-08-27 11:02:04 +10:00
38b9828441 fix(ui): newly-added entities are selected 2024-08-27 11:02:04 +10:00
0048a7077e feat(ui): add crosshair to color picker 2024-08-27 11:02:04 +10:00
527a39a3ad fix(ui): color picker ignores alpha 2024-08-27 11:02:04 +10:00
30ce4c55c7 fix(ui): calculate renderable entities correctly in tool module 2024-08-27 11:02:04 +10:00
ca082d4288 feat(ui): better color picker 2024-08-27 11:02:04 +10:00
5e59a4f43a feat(ui): colored mask preview image 2024-08-27 11:02:04 +10:00
9f86605049 fix(ui): new rectangles don't trigger rerender 2024-08-27 11:02:04 +10:00
79058a7894 chore: bump version v4.2.9.dev20240823 2024-08-27 11:02:04 +10:00
bb3ad8c2f1 feat(ui): disable most interaction while filtering 2024-08-27 11:02:04 +10:00
799688514b fix(ui): filter preview offset 2024-08-27 11:02:04 +10:00
b7344b0df2 feat(ui): tweak layout of staging area toolbar 2024-08-27 11:02:04 +10:00
7e382c5f3f chore(ui): typegen 2024-08-27 11:02:04 +10:00
9cf357e184 tidy(app): clean up app changes for canvas v2 2024-08-27 11:01:53 +10:00
95b6c773d4 feat(ui): use singleton for clear q confirm dialog 2024-08-27 11:01:53 +10:00
89d8c5ba00 fix(ui): rip out broken recall logic, NO TS ERRORS 2024-08-27 11:01:53 +10:00
59580cf6ed chore(ui): lint 2024-08-27 11:01:53 +10:00
2b0c084f5b fix(ui): staging area interaction scopes 2024-08-27 11:01:53 +10:00
4d896073ff fix(ui): staging area actions 2024-08-27 11:01:53 +10:00
9f69503a80 tidy(ui): more cleanup 2024-08-27 11:01:53 +10:00
0311e852a0 fix(ui): upscale tab graph 2024-08-27 11:01:53 +10:00
7003a3d546 fix(ui): sdxl graph builder 2024-08-27 11:01:53 +10:00
dc73072e27 fix(ui): select next entity in the list when deleting 2024-08-27 11:01:53 +10:00
e549c44ad7 feat(ui): fix delete layer hotkey 2024-08-27 11:01:52 +10:00
45a4231cbe tidy(ui): "eye dropper" -> "color picker" 2024-08-27 11:01:52 +10:00
81f046ebac tidy(ui): regional guidance buttons 2024-08-27 11:01:52 +10:00
6ef6c593c4 feat(ui): update entity list menu 2024-08-27 11:01:52 +10:00
5b53eefef7 feat(ui): add log debug button 2024-08-27 11:01:52 +10:00
9a9919c0af chore(ui): lint 2024-08-27 11:01:52 +10:00
10661b33d4 chore(ui): prettier 2024-08-27 11:01:52 +10:00
52193d604d chore(ui): eslint 2024-08-27 11:01:52 +10:00
2568441e6a tidy(ui): remove unused stuff 4 2024-08-27 11:01:52 +10:00
1a14860b3b tidy(ui): remove unused stuff 3 2024-08-27 11:01:52 +10:00
9ff7647ec5 tidy(ui): remove unused pkg @chakra-ui/react-use-size 2024-08-27 11:01:52 +10:00
b49106e8fe feat(ui): revise graph building for control layers, fix issues w/ invocation complete events 2024-08-27 11:01:52 +10:00
906d0902a3 feat(ui): use unique id for metadata in Graph class 2024-08-27 11:01:52 +10:00
fbde6f5a7f tidy(ui): remove unused stuff 2 2024-08-27 11:01:52 +10:00
b388268987 tidy(ui): remove unused stuff 2024-08-27 11:01:52 +10:00
3b4164bd62 tidy(ui): reduce use of parseify util 2024-08-27 11:01:52 +10:00
b7fc6fe573 feat(ui): refine canvas entity list items & menus 2024-08-27 11:01:52 +10:00
2954a19d27 feat(ui): canvas layer preview, revised reactivity for adapters 2024-08-27 11:01:52 +10:00
aa45ce7fbd feat(ui): add SyncableMap
Can be used with useSyncExternal store to make a `Map` reactive.
2024-08-27 11:01:52 +10:00
77e5078e4a tidy(ui): removed unused transform methods from canvasmanager 2024-08-27 11:01:52 +10:00
603cc7bf2e feat(ui): transform tool ux 2024-08-27 11:01:52 +10:00
cd517a102d feat(ui): rough out canvas mode 2024-08-27 11:01:52 +10:00
9a442918b5 feat(ui): add canvas autosave checkbox 2024-08-27 11:01:52 +10:00
f9c03d85a5 fix(ui): memory leak when getting image DTO
must unsubscribe!
2024-08-27 11:01:52 +10:00
10d07c71c4 feat(ui): rework settings menu 2024-08-27 11:01:52 +10:00
cd05a78219 feat(ui): no entities fallback buttons 2024-08-27 11:01:52 +10:00
f8ee572abc perf(ui): optimize gallery image delete button rendering 2024-08-27 11:01:52 +10:00
d918654509 feat(ui): remove "solid" background option 2024-08-27 11:01:52 +10:00
582e30c542 tidy(ui): organise files and classes 2024-08-27 11:01:52 +10:00
34a6555301 tidy(ui): abstract compositing logic to module 2024-08-27 11:01:52 +10:00
fff860090b fix(ui): fix canvas cache property access 2024-08-27 11:01:52 +10:00
f4971197c1 tidy(ui): clean up CanvasFilter class 2024-08-27 11:01:52 +10:00
621d5e0462 tidy(ui): clean up a few bits and bobs 2024-08-27 11:01:52 +10:00
0b68a69a6c tidy(ui): abstract canvas rendering logic to module 2024-08-27 11:01:52 +10:00
9a599ce595 tidy(ui): abstract caching logic to module 2024-08-27 11:01:52 +10:00
1467ba276f tidy(ui): abstract worker logic to module 2024-08-27 11:01:52 +10:00
708facf707 tidy(ui): abstract stage logic into module 2024-08-27 11:01:52 +10:00
9c6c6adb1f feat(ui): add entity group hiding 2024-08-27 11:01:52 +10:00
c335b8581c feat(ui): move all caching out of redux
While we lose the benefit of the caches persisting across reloads, this is a much simpler way to handle things. If we need a persistent cache, we can explore it in the future.
2024-08-27 11:01:52 +10:00
f1348e45bd feat(ui): revised rasterization caching
- use `stable-hash` to generate stable, non-crypto hashes for cache entries, instead of using deep object comparisons
- use an object to store image name caches
2024-08-27 11:01:52 +10:00
ce6cf9b079 feat(ui): revise filter implementation 2024-08-27 11:01:52 +10:00
13ec80736a fix(ui): add button to delete inpaint mask 2024-08-27 11:01:52 +10:00
c9690a4b21 feat(ui): add contexts/hooks to access entity adapters directly 2024-08-27 11:01:52 +10:00
489e875a6e feat(ui): add CanvasManagerProviderGate
This context waits to render its children its until the canvas manager is available. Then its children have access to the manager directly via hook.
2024-08-27 11:01:52 +10:00
8651396048 feat(ui) do not set $canvasManager until ready 2024-08-27 11:01:52 +10:00
2bab5a6179 fix(ui): inpaint mask naming 2024-08-27 11:01:52 +10:00
006f06b615 feat(ui): efficient canvas compositing
Also solves issue of exporting layers at different opacities than what is visible
2024-08-27 11:01:52 +10:00
d603923d1b feat(ui): allow multiple inpaint masks
This is easier than making it a nullable singleton
2024-08-27 11:01:52 +10:00
86878e855b fix(ui): missing rasterization cache invalidations 2024-08-27 11:01:52 +10:00
35de60a8fa feat(ui): iterate on filter UI, flow 2024-08-27 11:01:52 +10:00
2c444a1941 fix(ui): rehydration data loss 2024-08-27 11:01:52 +10:00
3dfef01889 feat(ui): sort log namespaces 2024-08-27 11:01:52 +10:00
a845a2daa5 fix(ui): do not merge arrays by index during rehydration 2024-08-27 11:01:52 +10:00
df41f4fbce fix(ui): clone parsed data during state rehydration
Without this, the objects and arrays in `parsed` could be mutated, and the log statment would show the mutated data.
2024-08-27 11:01:52 +10:00
76482da6f5 fix(ui): fix logger filter
was accidetnally replacing the filter instead of appending to it.
2024-08-27 11:01:52 +10:00
8205abbbbf fix(ui): race condition queue status
Sequence of events causing the race condition:
- Enqueue batch
- Invalidate `SessionQueueStatus` tag
- Request updated queue status via HTTP - batch still processing at this point
- Batch completes
- Event emitted saying so
- Optimistically update the queue status cache, it is correct
- HTTP request makes it back and overwrites the optimistic update, indicating the batch is still in progress

FIxed by not invalidating the cache.
2024-08-27 11:01:52 +10:00
926873de26 fix(ui): handle opacity for masks 2024-08-27 11:01:52 +10:00
00cb1903ba feat(ui): default background to checkerboard 2024-08-27 11:01:52 +10:00
58ba38b9c7 feat(ui): clean up logging namespaces, allow skipping namespaces 2024-08-27 11:01:52 +10:00
2f6a5617f9 chore(ui): bump ui library 2024-08-27 11:01:52 +10:00
e0d84743be fix(ui): do not allow drawing if layer disabled 2024-08-27 11:01:52 +10:00
ee7c62acc4 fix(ui): stale state causing race conditions & extraneous renders 2024-08-27 11:01:52 +10:00
daf3e58bd9 fix(ui): do not clear buffer when rendering "real" objects 2024-08-27 11:01:52 +10:00
c5b9209057 tidy(ui): remove "filter" from CanvasImageState 2024-08-27 11:01:52 +10:00
2a4d6d98e2 feat(ui): better editable title 2024-08-27 11:01:52 +10:00
cfdf59d906 fix(ui): stroke eraserline 2024-08-27 11:01:52 +10:00
f91ce1a47c feat(ui): restore transparency effect for control layers 2024-08-27 11:01:52 +10:00
4af2888168 feat(ui): use text cursor for entity title 2024-08-27 11:01:52 +10:00
8471c6fe86 tidy(ui): remove extraneous logging in CanvasStateApi 2024-08-27 11:01:52 +10:00
fe65a5a2db feat(ui): better buffer commit logic 2024-08-27 11:01:52 +10:00
0df26e967c feat(ui): render buffer separately from "real" objects 2024-08-27 11:01:52 +10:00
d4822b305e fix(ui): pixelRect should always be integer 2024-08-27 11:01:52 +10:00
8df5447563 fix(ui): only update stage attrs when stage itself is dragged 2024-08-27 11:01:52 +10:00
7b5a43df9b feat(ui): add line simplification
This fixes some awkward issues where line segments stack up.
2024-08-27 11:01:52 +10:00
61ef630175 fix(ui): various things listening when they need not listen 2024-08-27 11:01:52 +10:00
4eda2ef555 feat(ui): layer opacity via caching 2024-08-27 11:01:52 +10:00
57f4489520 feat(ui): reset view fits all visible objects 2024-08-27 11:01:52 +10:00
fb6cf9e3da fix(ui): rerenders when changing canvas scale 2024-08-27 11:01:52 +10:00
f776326cff fix(ui): do not render rasterized layer unless renderObjects=true 2024-08-27 11:01:52 +10:00
5be32d5733 feat(ui): revise app layout strategy, add interaction scopes for hotkeys 2024-08-27 11:01:52 +10:00
1f73435241 feat(ui): tweak mask patterns 2024-08-27 11:01:52 +10:00
3251a00631 fix(ui): dynamic prompts recalcs when presets are loaded 2024-08-27 11:01:52 +10:00
49c4ad1dd7 fix(ui): use style preset prompts correctly 2024-08-27 11:01:52 +10:00
5857e95c4a fix(ui): discard selected staging image not all other images 2024-08-27 11:01:52 +10:00
85be2532c6 fix(ui): respect image size in staging preview 2024-08-27 11:01:52 +10:00
8b81a00def tidy(ui): cleanup after events change 2024-08-27 11:01:52 +10:00
8544595c27 feat(ui): move socket event handling out of redux
Download events and invocation status events (including progress images) are very frequent. There's no real need for these to pass through redux. Handling them outside redux is a significant performance win - far fewer store subscription calls, far fewer trips through middleware.

All event handling is moved outside middleware. Cleanup of unused actions and listeners to follow.
2024-08-27 11:01:52 +10:00
a6a5d1470c fix(ui): rebase conflicts 2024-08-27 11:01:39 +10:00
febcc12ec9 fix(ui): update compositing rect when fill changes 2024-08-27 11:01:39 +10:00
ab64078b76 feat(ui): add canvas background style 2024-08-27 11:01:39 +10:00
0ff3459b07 feat(ui): mask layers choose own opacity 2024-08-27 11:01:39 +10:00
2abd7c9bfe feat(ui): mask fill patterns 2024-08-27 11:01:39 +10:00
8e5330bdc9 build(ui): add vite types to tsconfig 2024-08-27 11:01:39 +10:00
1ecec4ea3a fix(ui): do not smooth pixel data when using eyeDropper 2024-08-27 11:01:39 +10:00
700dbe69f3 tidy(ui): tool components & translations 2024-08-27 11:01:39 +10:00
af7ba3b7e4 feat(ui): rough out eyedropper tool
It's a bit slow bc we are converting the stage to canvas on every mouse move. Also need to improve the visual but it works.
2024-08-27 11:01:39 +10:00
1f7144d62e fix(ui): ip adapters work 2024-08-27 11:01:39 +10:00
4e389e415b feat(ui): rename layers 2024-08-27 11:01:39 +10:00
2db29fb6ab feat(ui): revise entity menus 2024-08-27 11:01:39 +10:00
79653fcff5 feat(ui): split control layers from raster layers for UI and internal state, same rendering as raster layers 2024-08-27 11:01:39 +10:00
9e39180fbc feat(ui): implement cache for image rasterization, rip out some old controladapters code 2024-08-27 11:01:39 +10:00
dc0f832d8f feat(ui, app): use layer as control (wip) 2024-08-27 11:01:39 +10:00
0dcd6aa5d9 feat(ui): add contextmenu for canvas entities 2024-08-27 11:01:39 +10:00
9f4a8f11f8 feat(ui): more better logging & naming 2024-08-27 11:01:39 +10:00
6f9579d6ec feat(ui): better logging w/ path 2024-08-27 11:01:39 +10:00
2b2aabb234 feat(ui): always show marks on canvas scale slider 2024-08-27 11:01:39 +10:00
a79b9633ab fix(ui): do not import button from chakra 2024-08-27 11:01:39 +10:00
7b628c908b fix(ui): scaled bbox preview 2024-08-27 11:01:39 +10:00
181703a709 feat(ui): tidy up atoms 2024-08-27 11:01:39 +10:00
c439e3c204 feat(ui): convert all my pubsubs to atoms
its the same but better
2024-08-27 11:01:39 +10:00
11e81eb456 feat(ui): add trnalsation 2024-08-27 11:01:39 +10:00
3e24bf640e fix(ui): give up on thumbnail loading, causes flash during transformer 2024-08-27 11:01:39 +10:00
a17664fb75 fix(ui): depth anything v2 2024-08-27 11:01:39 +10:00
5aed23dc91 tidy(ui): remove unused code, comments 2024-08-27 11:01:39 +10:00
7f389716d0 fix(ui): staging area works 2024-08-27 11:01:39 +10:00
eca13b674a feat(nodes): temp disable canvas output crop 2024-08-27 11:01:39 +10:00
7c982a1bdf fix(ui): max scale 1 when reset view 2024-08-27 11:01:39 +10:00
bdfe6870fd feat(ui): better scale changer component, reset view functionality 2024-08-27 11:01:39 +10:00
7eebbc0dd9 fix(ui): img2img 2024-08-27 11:01:39 +10:00
6ae46d7c8b feat(ui): add manual scale controls 2024-08-27 11:01:39 +10:00
8aae30566e fix(ui): do not await clearBuffer 2024-08-27 11:01:39 +10:00
4b7c3e221c feat(ui): dnd image into layer 2024-08-27 11:01:39 +10:00
4015795b7f fix(ui): do not await commitBuffer 2024-08-27 11:01:39 +10:00
132dd61d8d fix(ui): properly destroy entities in manager cleanup 2024-08-27 11:01:39 +10:00
6f2b548dd1 tidy(ui): clearer component names for regional guidance 2024-08-27 11:01:39 +10:00
49dd316f17 tidy(ui): clearer component names for ip adapter 2024-08-27 11:01:39 +10:00
e0a8bb149d tidy(ui): clearer component names for inpaint mask 2024-08-27 11:01:39 +10:00
020b6db34b tidy(ui): clearer component names for control adapters 2024-08-27 11:01:39 +10:00
f6d2f0bf8c feat(ui): simplify canvas list item headers 2024-08-27 11:01:39 +10:00
1280cce803 fix(ui): ip adapter list item 2024-08-27 11:01:39 +10:00
82463d12e2 tidy(ui): clean up unused logic 2024-08-27 11:01:39 +10:00
ba6c1b84e4 feat(ui): clean up state, add mutex for image loading, add thumbnail loading 2024-08-27 11:01:39 +10:00
92b6d3198a chore(ui): add async-mutex dep 2024-08-27 11:01:39 +10:00
693ae1af50 feat(ui): txt2img, img2img, inpaint & outpaint working 2024-08-27 11:01:39 +10:00
3873a3096c feat(ui): no padding on transformer outlines 2024-08-27 11:01:39 +10:00
4a9f6ab5ef feat(ui): restore object count to layer titles 2024-08-27 11:01:39 +10:00
2fc29c5125 tidy(ui): "useIsEntitySelected" -> "useEntityIsSelected" 2024-08-27 11:01:39 +10:00
5fbc876cfd tidy(ui): move transformer statics into class 2024-08-27 11:01:39 +10:00
89b0673ac9 tidy(ui): massive cleanup
- create a context for entity identifiers, massively simplifying UI for each entity int he list
- consolidate common redux actions
- remove now-unused code
2024-08-27 11:01:39 +10:00
3179a16189 perf(ui): do not add duplicate points to lines 2024-08-27 11:01:39 +10:00
3ce216b391 feat(ui): up line tension to 0.3 2024-08-27 11:01:39 +10:00
8a3a94e21a perf(ui): disable stroke, perfect draw on compositing rect 2024-08-27 11:01:39 +10:00
f61af188f9 tidy(ui): remove unused code, initial image 2024-08-27 11:01:39 +10:00
73fc52bfed tidy(ui): remove unused state & actions 2024-08-27 11:01:39 +10:00
4eeff4eef8 feat(ui): region mask rendering 2024-08-27 11:01:39 +10:00
cff2c43030 feat(ui): esc cancels drawing buffer
maybe this is not wanted? we'll see
2024-08-27 11:01:39 +10:00
504b1f2425 fix(ui): render transformer over objects, fix issue w/ inpaint rect color 2024-08-27 11:01:38 +10:00
eff5b56990 fix(ui): brush preview fill for inpaint/region 2024-08-27 11:01:38 +10:00
0e673f1a18 fix(ui): no objects rendered until vis toggled 2024-08-27 11:01:38 +10:00
4fea22aea4 feat(ui): inpaint mask transform 2024-08-27 11:01:38 +10:00
c2478c9ac3 fix(ui): layer accidental early set isFirstRender=false 2024-08-27 11:01:38 +10:00
ed8243825e fix(ui): inpaint mask rendering 2024-08-27 11:01:38 +10:00
c7ac2b5278 feat(ui): wip inpaint mask uses new API 2024-08-27 11:01:38 +10:00
a783003556 feat(ui): move updatePosition to transformer 2024-08-27 11:01:38 +10:00
44ba1c6113 feat(ui): move resetScale to transformer 2024-08-27 11:01:38 +10:00
a02d67fcc6 tidy(ui): more imperative naming 2024-08-27 11:01:38 +10:00
14c8f7c4f5 tidy(ui): use imperative names for setters in stateapi 2024-08-27 11:01:38 +10:00
a487ecb50f fix(ui): commit drawing buffer on tool change, fixing bbox not calculating 2024-08-27 11:01:38 +10:00
fc55862823 fix(ui): sync transformer when requesting bbox calc 2024-08-27 11:01:38 +10:00
e5400601d6 tidy(ui): rename union CanvasEntity -> CanvasEntityState 2024-08-27 11:01:38 +10:00
735f9f1483 fix(ui): request rect calc immediately on transform, hiding rect 2024-08-27 11:01:38 +10:00
d139db0a0f feat(ui): move bbox calculation to transformer 2024-08-27 11:01:38 +10:00
a35bb450b1 feat(ui): use set for transformer subscriptions 2024-08-27 11:01:38 +10:00
26c01dfa48 tidy(ui): clean up worker tasks when complete 2024-08-27 11:01:38 +10:00
2b1839374a tidy(ui): remove unused code in CanvasTool 2024-08-27 11:01:38 +10:00
59f5f18e1d feat(ui): use pubsub for isTransforming on manager 2024-08-27 11:01:38 +10:00
e41fcb081c docs(ui): update transformer docstrings 2024-08-27 11:01:38 +10:00
3032042b35 feat(ui): revised event pubsub, transformer logic split out 2024-08-27 11:01:38 +10:00
544db61044 feat(ui): add simple pubsub 2024-08-27 11:01:38 +10:00
67a3aa6dff feat(ui): document & clean up object renderer 2024-08-27 11:01:38 +10:00
34f4468b20 feat(ui): split out object renderer 2024-08-27 11:01:38 +10:00
d99ae58001 fix(ui): unable to hold shit while transforming to retain ratio 2024-08-27 11:01:38 +10:00
d60ec53762 tidy(ui): rename canvas stuff 2024-08-27 11:01:38 +10:00
0ece9361d5 tidy(ui): consolidate getLoggingContext builders 2024-08-27 11:01:38 +10:00
69987a2f00 fix(ui): align all tools to 1px grid
- Offset brush tool by 0.5px when width is odd, ensuring each stroke edge is exactly on a pixel boundary
- Round the rect tool also
2024-08-27 11:01:38 +10:00
7346bfccb9 feat(ui): disable image smoothing on layers 2024-08-27 11:01:38 +10:00
b705083ce2 fix(ui): round position when rasterizing layer 2024-08-27 11:01:38 +10:00
4b3c82df6f feat(ui): continue modularizing transform 2024-08-27 11:01:38 +10:00
46290205d5 feat(ui): fix a few things that didn't unsubscribe correctly, add helper to manage subscriptions 2024-08-27 11:01:38 +10:00
7b15585b80 feat(ui): merge bbox outline into transformer 2024-08-27 11:01:38 +10:00
cbfeeb9079 fix(ui): update parent's pos not transformers 2024-08-27 11:01:38 +10:00
fdac20b43e feat(ui): merge interaction rect into transformer class 2024-08-27 11:01:38 +10:00
ae2312104e feat(ui): prepare staging area 2024-08-27 11:01:38 +10:00
74b6674af6 feat(ui): typing for logging context 2024-08-27 11:01:38 +10:00
81adce3238 feat(ui): remove inheritance of CanvasObject
JS is terrible
2024-08-27 11:01:38 +10:00
48b7f460a8 feat(ui): split & document transformer logic, iterate on class structures 2024-08-27 11:01:38 +10:00
38a8232341 feat(ui): rotation snap to nearest 45deg when holding shift 2024-08-27 11:01:38 +10:00
39a51c4f4c feat(ui): expose subscribe method for nanostores 2024-08-27 11:01:38 +10:00
220bfeb37d tidy(ui): remove layer scaling reducers 2024-08-27 11:01:38 +10:00
e766279950 fix(ui): pixel-perfect transforms 2024-08-27 11:01:38 +10:00
db82406525 fix(ui): layer visibility toggle 2024-08-27 11:01:38 +10:00
2d44f332d9 fix(nodes): fix canvas mask erode
it wasn't eroding enough and caused incorrect transparency in result images
2024-08-27 11:01:38 +10:00
ea1526689a fix(ui): do not reset layer on first render 2024-08-27 11:01:38 +10:00
200338ed72 feat(ui): revised logging and naming setup, fix staging area 2024-08-27 11:01:38 +10:00
88003a61bd feat(ui): add repr methods to layer and object classes 2024-08-27 11:01:38 +10:00
b82089b30b feat(ui): use nanoid(10) instead of uuidv4 for canvas
Shorter ids makes it much more readable
2024-08-27 11:01:38 +10:00
efd780d395 build(ui): add nanoid as explicit dep 2024-08-27 11:01:38 +10:00
3f90f783de fix(ui): move CanvasImage's konva image to correct object 2024-08-27 11:01:38 +10:00
4f5b755117 fix(ui): prevent flash when applying transform 2024-08-27 11:01:38 +10:00
a455f12581 build(ui): add eslint rules for async stuff 2024-08-27 11:01:38 +10:00
de516db383 feat(ui): trying to fix flicker after transform 2024-08-27 11:01:38 +10:00
6781575293 feat(ui): transform cleanup 2024-08-27 11:01:38 +10:00
65b0e40fc8 feat(ui): fix transform when rotated 2024-08-27 11:01:38 +10:00
19e78a07b7 fix(ui): use pixel bbox when image is in layer 2024-08-27 11:01:38 +10:00
e3b60dda07 fix(ui): transforming when axes flipped 2024-08-27 11:01:38 +10:00
a9696d3193 feat(ui): hallelujah (???) 2024-08-27 11:01:38 +10:00
336d72873f feat(ui): add debug button 2024-08-27 11:01:38 +10:00
111a380bce fix(ui): transformer padding 2024-08-27 11:01:38 +10:00
012a8351af feat(ui): wip transform mode 2 2024-08-27 11:01:38 +10:00
deeb80ea9b feat(ui): wip transform mode 2024-08-27 11:01:38 +10:00
1e97a917d6 feat(ui): wip transform mode 2024-08-27 11:01:38 +10:00
a15ba925db fix(ui): dnd to canvas broke 2024-08-27 11:01:38 +10:00
6e5a968aad fix(ui): conflicts after rebasing 2024-08-27 11:01:38 +10:00
915edaa02f fix(ui): imageDropped listener 2024-08-27 11:01:38 +10:00
f6b0fa7c18 wip 2024-08-27 11:01:38 +10:00
3f1fba0f35 fix(ui): transform tool seems to be working 2024-08-27 11:01:38 +10:00
d9fa85a4c6 fix(ui): move tool fixes, add transform tool 2024-08-27 11:01:38 +10:00
022bb8649c feat(ui): move tool now only moves 2024-08-27 11:01:38 +10:00
673bc33a01 feat(ui): layer bbox calc in worker 2024-08-27 11:01:38 +10:00
579a64928d feat(ui): tweaked entity & group selection styles 2024-08-27 11:01:38 +10:00
5316df7d7d feat(ui): canvas entity list headers 2024-08-27 11:01:38 +10:00
62fe61dc30 tidy(ui): CanvasRegion 2024-08-27 11:01:38 +10:00
ea819a4a2f tidy(ui): CanvasRect 2024-08-27 11:01:38 +10:00
868a25dae2 tidy(ui): CanvasLayer 2024-08-27 11:01:38 +10:00
3a25d00cb6 tidy(ui): CanvasInpaintMask 2024-08-27 11:01:38 +10:00
6301b74d87 tidy(ui): CanvasInitialImage 2024-08-27 11:01:38 +10:00
b43b90c299 tidy(ui): CanvasImage 2024-08-27 11:01:38 +10:00
0a43444ab3 tidy(ui): CanvasEraserLine 2024-08-27 11:01:38 +10:00
75ea4d2155 tidy(ui): CanvasControlAdapter 2024-08-27 11:01:38 +10:00
9d281388e0 tidy(ui): CanvasBrushLine 2024-08-27 11:01:38 +10:00
178e1cc50b tidy(ui): CanvasBbox 2024-08-27 11:01:38 +10:00
6eafe53b2d tidy(ui): CanvasBackground 2024-08-27 11:01:38 +10:00
7514b2b7d4 tidy(ui): update canvas classes, organise location of konva nodes 2024-08-27 11:01:38 +10:00
51dee2dba2 feat(ui): add names to all konva objects
Makes troubleshooting much simpler
2024-08-27 11:01:38 +10:00
a81c3d841c fix(ui): do not await creating new canvas image
If you await this, it causes a race condition where multiple images are created.
2024-08-27 11:01:38 +10:00
154487f966 feat(ui): use position and dimensions instead of separate x,y,width,height attrs 2024-08-27 11:01:38 +10:00
588daafcf5 fix(ui): remove weird rtkq hook wrapper
I do not understand why I did that initially but it doesn't work with TS.
2024-08-27 11:01:38 +10:00
ab00097aed feat(ui): rename types size and position to dimensions and coordinate 2024-08-27 11:01:38 +10:00
5aefae71a2 tidy(ui): hide layer settings by default 2024-08-27 11:01:38 +10:00
0edd598970 fix(ui): layer rendering when starting as disabled 2024-08-27 11:01:38 +10:00
0bb485031f feat(invocation): reduce canvas v2 mask & crop mask dilation 2024-08-27 11:01:38 +10:00
01ffd86367 feat(ui): de-jank staging area and progress images 2024-08-27 11:01:38 +10:00
e2e02f31b6 feat(ui): update staging handling to work w/ cropped mask 2024-08-27 11:01:38 +10:00
0008617348 chore(ui): typegen 2024-08-27 11:01:38 +10:00
f8f21c0edd feat(app): update CanvasV2MaskAndCropInvocation 2024-08-27 11:01:38 +10:00
010916158b feat(ui): use new canvas output node 2024-08-27 11:01:38 +10:00
df0ba004ca chore(ui): typegen 2024-08-27 11:01:38 +10:00
038b29e15b feat(app): add CanvasV2MaskAndCropInvocation & CanvasV2MaskAndCropOutput
This handles some masking and cropping that the canvas needs.
2024-08-27 11:01:21 +10:00
763ab73923 fix(ui): restore nodes output tracking 2024-08-27 11:01:21 +10:00
9a060b4437 feat(ui): rip out document size
barely knew ye
2024-08-27 11:01:21 +10:00
2307a30892 feat(ui): convert initial image to layer when starting canvas session 2024-08-27 11:01:21 +10:00
2006f84f6e fix(ui): fix layer transparency calculation 2024-08-27 11:01:21 +10:00
23b15fef6a fix(ui): reset initial image when resetting canvas 2024-08-27 11:01:21 +10:00
41e72f929d fix(ui): reset node executions states when loading workflow 2024-08-27 11:01:21 +10:00
29fc49bb3b fix(ui): entity display list 2024-08-27 11:01:21 +10:00
6f65b6a40f feat(ui): img2img working 2024-08-27 11:01:21 +10:00
1e9b22e3a4 feat(ui): rough out img2img on canvas 2024-08-27 11:01:21 +10:00
3dc2c723c3 UNDO ME WIP 2024-08-27 11:01:21 +10:00
6f007cbd48 feat(ui): log invocation source id on socket event 2024-08-27 11:01:21 +10:00
9a402dd10e feat(ui): restore document size overlay renderer 2024-08-27 11:01:21 +10:00
4249d0e13b feat(ui): make documnet size a rect 2024-08-27 11:01:21 +10:00
f6050bad67 refactor(ui): remove modular imagesize components
This is no longer necessary with canvas v2 and added a ton of extraneous redux actions when changing the image size. Also renamed to document size
2024-08-27 11:01:21 +10:00
d3a4b7b51b feat(ui): initialState is for generation mode 2024-08-27 11:01:21 +10:00
3f5f9ac764 feat(ui): split out canvas entity list component 2024-08-27 11:01:21 +10:00
e7c1299a7f feat(ui): hide bbox button when no canvas session active 2024-08-27 11:01:21 +10:00
56a3918a1e tidy(ui): remove unused naming objects/utils
The canvas manager means we don't need to worry about konva node names as we never directly select konva nodes.
2024-08-27 11:01:21 +10:00
dabf7718cf feat(ui): split up tool chooser buttons
Prep for distinct toolbars for generation vs canvas modes
2024-08-27 11:01:21 +10:00
ef22c29288 feat(ui): add useAssertSingleton util hook
This simple hook asserts that it is only ever called once. Particularly useful for things like hotkeys hooks.
2024-08-27 11:01:21 +10:00
74f06074f7 feat(ui): "stagingArea" -> "session" 2024-08-27 11:01:21 +10:00
b97bf52faa feat(ui): add reset button to canvas 2024-08-27 11:01:21 +10:00
0a77f5cec8 feat(ui): add snapToRect util 2024-08-27 11:01:21 +10:00
f8e92f7b73 fix(ui): fiddle with control adapter filters
some jank still
2024-08-27 11:01:21 +10:00
530c6e3a59 feat(ui): temp disable doc size overlay 2024-08-27 11:01:21 +10:00
11b95cfaf4 feat(ui): no animation on layer selection
Felt sluggish
2024-08-27 11:01:21 +10:00
707c005a26 feat(ui): use canvas as source for control images (wip) 2024-08-27 11:01:21 +10:00
19fa8e7e33 fix(ui): control adapter translate & scale 2024-08-27 11:01:21 +10:00
b252ded366 tidy(ui): removed unused state related to non-buffered drawing 2024-08-27 11:01:21 +10:00
84305d4e73 feat(ui): control adapter image rendering 2024-08-27 11:01:21 +10:00
1150b41e14 fix(ui): do not floor bbox calc, it cuts off the last pixels 2024-08-27 11:01:21 +10:00
88d8ccb34b feat(ui): fix issue where creating line needs 2 points 2024-08-27 11:01:21 +10:00
4111b3f1aa fix(ui): edge cases when holding shift and drawing lines 2024-08-27 11:01:21 +10:00
36862be2aa fix(ui): set buffered rect color to full alpha 2024-08-27 11:01:21 +10:00
425665e0d9 fix(ui): handle mouseup correctly 2024-08-27 11:01:21 +10:00
9abd604f69 feat(ui): buffered rect drawing 2024-08-27 11:01:21 +10:00
59bdc288b5 fix(ui): buffered drawing edge cases 2024-08-27 11:01:21 +10:00
eb37d2958e perf(ui): do not use stage.find 2024-08-27 11:01:21 +10:00
2a9738a341 perf(ui): object groups do not listen 2024-08-27 11:01:21 +10:00
6aac1cf33a perf(ui): buffered drawing (wip) 2024-08-27 11:01:21 +10:00
9ca4d072ab tidy(ui): organise files 2024-08-27 11:01:21 +10:00
7aaf14c26b tidy(ui): organise files 2024-08-27 11:01:21 +10:00
cf598ca175 tidy(ui): organise files 2024-08-27 11:01:21 +10:00
a722790afc fix(ui): background rendering 2024-08-27 11:01:21 +10:00
320151a040 pkg(ui): remove unused deps react-konva & use-image 2024-08-27 11:01:21 +10:00
c090f511c3 feat(ui): organize konva state and files 2024-08-27 11:01:21 +10:00
86dd1475b3 fix(ui): merge conflicts in image deletion listener 2024-08-27 11:01:21 +10:00
0b71ac258c fix(ui): region rendering 2024-08-27 11:01:21 +10:00
54e1eae509 fix(ui): inpaint mask rendering 2024-08-27 11:01:21 +10:00
bf57b2dc77 fix(ui): staging area rendering 2024-08-27 11:01:21 +10:00
de3c27b44f fix(ui): stale selected entity 2024-08-27 11:01:21 +10:00
05717fea93 fix(ui): staging area image offset 2024-08-27 11:01:21 +10:00
191584d229 feat(ui): tweak layer ui component 2024-08-27 11:01:21 +10:00
6069169e6b fix(ui): resetting layer resets position 2024-08-27 11:01:21 +10:00
07438587f3 feat(ui): updated layer list component styling 2024-08-27 11:01:21 +10:00
913e36d6fd feat(ui): transformable layers 2024-08-27 11:01:21 +10:00
139004b976 feat(ui): move tool icon is pointer like in other apps 2024-08-27 11:01:21 +10:00
ef4269d585 feat(ui): do not floor cursor position 2024-08-27 11:01:21 +10:00
954cb129a4 feat(ui): disable gallery hotkeys while staging 2024-08-27 11:01:21 +10:00
02c4b28de5 feat(ui): revised canvas progress & staging image handling 2024-08-27 11:01:21 +10:00
febea88b58 feat(ui): show queue item origin in queue list 2024-08-27 11:01:21 +10:00
40ccfac514 chore(ui): typegen 2024-08-27 11:01:21 +10:00
831fb814cc feat(app): add origin to session queue
The origin is an optional field indicating the queue item's origin. For example, "canvas" when the queue item originated from the canvas or "workflows" when the queue item originated from the workflows tab. If omitted, we assume the queue item originated from the API directly.

- Add migration to add the nullable column to the `session_queue` table.
- Update relevant event payloads with the new field.
- Add `cancel_by_origin` method to `session_queue` service and corresponding route. This is required for the canvas to bail out early when staging images.
- Add `origin` to both `SessionQueueItem` and `Batch` - it needs to be provided initially via the batch and then passed onto the queue item.
-
2024-08-27 11:01:21 +10:00
bf166fdd61 fix(ui): denoise start on outpainting 2024-08-27 11:01:21 +10:00
384bde3539 feat(ui): add redux events for queue cleared & batch enqueued socket events 2024-08-27 11:01:21 +10:00
6f1d238d0a feat(ui): canvas staging area works 2024-08-27 11:01:21 +10:00
ac524153a7 feat(ui): switch to view tool when staging 2024-08-27 11:01:21 +10:00
2cad2b15cf tidy(ui): disable preview images on every enqueue 2024-08-27 11:01:21 +10:00
fd63e202fe feat(ui): rough out save staging image 2024-08-27 11:01:21 +10:00
a0250e47e3 feat(ui): staging area image visibility toggle 2024-08-27 11:01:21 +10:00
7d8ece45bb fix(ui): batch building after removing canvas files 2024-08-27 11:01:21 +10:00
ffb8f053da feat(ui): make Graph class's getMetadataNode public 2024-08-27 11:01:21 +10:00
fb46f457f9 tidy(ui): remove old canvas graphs 2024-08-27 11:01:21 +10:00
6d4f4152a7 fix(ui): do not select already-selected entity 2024-08-27 11:01:20 +10:00
d3d0ac7327 tidy(ui): naming things 2024-08-27 11:01:20 +10:00
f57df46995 tidy(ui): file organisation 2024-08-27 11:01:20 +10:00
a747171745 fix(ui): reset cursor pos when fitting document 2024-08-27 11:01:20 +10:00
e9ae9e80d4 feat(ui): staging area works more better 2024-08-27 11:01:20 +10:00
3b9a59b98d feat(ui): staging area barely works 2024-08-27 11:01:20 +10:00
8a381e7f74 feat(ui): consolidate konva API 2024-08-27 11:01:20 +10:00
61513fc800 feat(ui): consolidate konva API 2024-08-27 11:01:20 +10:00
1d26c49e92 feat(ui): staging area (rendering wip) 2024-08-27 11:01:20 +10:00
77be9836d2 tidy(ui): type "Dimensions" -> "Size" 2024-08-27 11:01:20 +10:00
4427960acb feat(ui): add updateNode to Graph 2024-08-27 11:01:20 +10:00
84aa4fb7bc feat(ui): sdxl graphs 2024-08-27 11:01:20 +10:00
01df96cbe0 feat(ui): sd1 outpaint graph 2024-08-27 11:01:20 +10:00
1ac0634f57 tests(ui): add missing tests for Graph class 2024-08-27 11:01:20 +10:00
8e7d3634b1 feat(ui): add Graph.getid() util 2024-08-27 11:01:20 +10:00
fadafe5c77 feat(ui): outpaint graph, organize builder a bit 2024-08-27 11:01:20 +10:00
b2ea1f6690 feat(ui): inpaint sd1 graph 2024-08-27 11:01:20 +10:00
0a03c1f882 feat(ui): temp disable image caching while testing 2024-08-27 11:01:20 +10:00
ce8b490ed8 feat(ui): txt2img & img2img graphs 2024-08-27 11:01:20 +10:00
86eccba80d feat(ui): minor change to canvas bbox state type 2024-08-27 11:01:20 +10:00
97453e7c6c feat(ui): simplified konva node to blob/imagedata utils 2024-08-27 11:01:20 +10:00
9e1084b701 feat(ui): node manager getter/setter 2024-08-27 11:01:20 +10:00
41aec81f3f feat(ui): generation mode calculation, fudged graphs 2024-08-27 11:01:20 +10:00
a5741a0551 feat(ui): add utils for getting images from canvas 2024-08-27 11:01:20 +10:00
72f73c231a feat(ui): even more simplified API - lean on the konva node manager to abstract imperative state API & rendering 2024-08-27 11:01:20 +10:00
b8fcaa274e feat(ui): revised docstrings for renderers & simplified api 2024-08-27 11:01:20 +10:00
a33bbf48bb feat(ui): inpaint mask UI components 2024-08-27 11:01:20 +10:00
98d9490fb9 feat(ui): inpaint mask rendering (wip) 2024-08-27 11:01:20 +10:00
51c643c4f8 fix(ui): models loaded handler 2024-08-27 11:01:20 +10:00
2d7370ca6c feat(ui): internal state for inpaint mask 2024-08-27 11:01:20 +10:00
0d68141387 refactor(ui): divvy up canvas state a bit 2024-08-27 11:01:20 +10:00
e88a8c6639 feat(ui): get region and base layer canvas to blob logic working 2024-08-27 11:01:20 +10:00
2d04bb286e refactor(ui): node manager handles more tedious annoying stuff 2024-08-27 11:01:20 +10:00
6d9ba24c32 feat(ui): use node manager for addRegions 2024-08-27 11:01:20 +10:00
7645a1c86e feat(ui): persist bbox 2024-08-27 11:01:20 +10:00
dc284b9a48 fix(ui): fix generation graphs 2024-08-27 11:01:20 +10:00
8a38332d44 feat(ui): add toggle for clipToBbox 2024-08-27 11:01:20 +10:00
2407d7d148 feat(ui): rename konva node manager 2024-08-27 11:01:20 +10:00
23275cf607 refactor(ui): create classes to abstract mgmt of konva nodes 2024-08-27 11:01:20 +10:00
8a0e02d335 tidy(ui): organise renderers 2024-08-27 11:01:20 +10:00
913873623c refactor(ui): create entity to konva node map abstraction (wip)
Instead of chaining konva `find` and `findOne` methods, all konva nodes are added to a mapping object. Finding and manipulating them is much simpler.

Done for regions and layers, wip for control adapters.
2024-08-27 11:01:20 +10:00
5d81e7dd4d perf(ui): fix lag w/ region rendering
Needed to memoize these selectors
2024-08-27 11:01:20 +10:00
418650fdf3 feat(ui): move canvas fill color picker to right 2024-08-27 11:01:20 +10:00
7c24e56d9f refactor(ui): remove unused ellipse & polygon objects 2024-08-27 11:01:20 +10:00
f6faed46c3 fix(ui): incorrect rect/brush/eraser positions 2024-08-27 11:01:20 +10:00
1d393eecf1 refactor(ui): enable global debugging flag 2024-08-27 11:01:20 +10:00
03a72240c0 refactor(ui): disable the preview renderer for now 2024-08-27 11:01:20 +10:00
acf62450fb tweak(ui): canvas editor layout 2024-08-27 11:01:20 +10:00
d0269310cf perf(ui): memoize layeractionsmenu valid actions 2024-08-27 11:01:20 +10:00
43618a74e7 refactor(ui): decouple konva renderer from react
Subscribe to redux store directly, skipping all the react overhead.

With react in dev mode, a typical frame while using the brush tool on almost-empty canvas is reduced from ~7.5ms to ~3.5ms. All things considered, this still feels slow, but it's a massive improvement.
2024-08-27 11:01:20 +10:00
56fed637ec feat(ui): clip lines to bbox 2024-08-27 11:01:20 +10:00
944ae4a604 fix(ui): document fit positioning 2024-08-27 11:01:20 +10:00
f3da609102 feat(ui): document bounds overlay 2024-08-27 11:01:20 +10:00
11c8a8cf72 tidy(ui): background layer 2024-08-27 11:01:20 +10:00
656978676f refactor(ui): use "entity" instead of "data" for canvas 2024-08-27 11:01:20 +10:00
d216e6519f feat(ui): brush size border radius = 1 2024-08-27 11:01:20 +10:00
0a2bcae0e3 fix(ui): canvas HUD doesn't interrupt tool 2024-08-27 11:01:20 +10:00
31cf244420 refactor(ui): split up canvas entity renderers, temp disable preview 2024-08-27 11:01:20 +10:00
d761effac1 fix(ui): delete all layers button 2024-08-27 11:01:20 +10:00
5efaf2c661 fix(ui): ignore keyboard shortcuts in input/textarea elements 2024-08-27 11:01:20 +10:00
b79161fbec fix(ui): canvas entity ids getting clobbered 2024-08-27 11:01:20 +10:00
89a2f2134b fix(ui): move lora followup fixes 2024-08-27 11:01:20 +10:00
343c3b19b1 chore(ui): lint 2024-08-27 11:01:20 +10:00
e6ec646b2c refactor(ui): move loras to canvas slice 2024-08-27 11:01:20 +10:00
b557fe4e40 fix(ui): layer is selected when added 2024-08-27 11:01:20 +10:00
45908dfbd2 feat(ui): r to center & fit stage on document 2024-08-27 11:01:20 +10:00
1cf0673a22 feat(ui): better HUD 2024-08-27 11:01:20 +10:00
189847f8a5 fix(ui): always use current brush width when making straight lines 2024-08-27 11:01:20 +10:00
b27929f702 feat(ui): hold shift w/ brush to draw straight line 2024-08-27 11:01:20 +10:00
b1a6a9835d fix(ui): update bg on canvas resize 2024-08-27 11:01:20 +10:00
5564a16d4b refactor(ui): better hud 2024-08-27 11:01:20 +10:00
b2aa447d50 refactor(ui): scaled tool preview border 2024-08-27 11:01:20 +10:00
8666f9a848 refactor(ui): port remaining canvasV1 rendering logic to V2, remove old code 2024-08-27 11:01:20 +10:00
e7dc7c4917 refactor(ui): fix more types 2024-08-27 11:01:20 +10:00
513d0f1e5c refactor(ui): metadata recall (wip)
just enough let the app run
2024-08-27 11:01:20 +10:00
c6ce7618cf refactor(ui): undo/redo button temp fix 2024-08-27 11:01:20 +10:00
9b78b8dc91 refactor(ui): fix renderer stuff 2024-08-27 11:01:20 +10:00
3bae233f40 refactor(ui): fix misc types 2024-08-27 11:01:20 +10:00
0a305c4291 refactor(ui): fix gallery stuff 2024-08-27 11:01:20 +10:00
54fe4ddf3e refactor(ui): fix delete image stuff 2024-08-27 11:01:20 +10:00
aa877b981c refactor(ui): fix useIsReadyToEnqueue for new adapterType field 2024-08-27 11:01:20 +10:00
b8ec4348a5 refactor(ui): update generation tab graphs 2024-08-27 11:01:20 +10:00
0e3e27668c refactor(ui): add adapterType to ControlAdapterData 2024-08-27 11:01:20 +10:00
e8c8025119 refactor(ui): update components & logic to use new unified slice (again) 2024-08-27 11:01:20 +10:00
6a72eda5d2 refactor(ui): update components & logic to use new unified slice 2024-08-27 11:01:20 +10:00
35c8c54466 refactor(ui): merge compositing, params into canvasV2 slice 2024-08-27 10:59:24 +10:00
2eda45ca5b refactor(ui): add scaled bbox state 2024-08-27 10:59:24 +10:00
ecd6e7960c refactor(ui): update dnd/image upload 2024-08-27 10:59:24 +10:00
0d3a61cbdb refactor(ui): update size/prompts state 2024-08-27 10:59:24 +10:00
6737c275d7 refactor(ui): rip out old control adapter implementation 2024-08-27 10:59:24 +10:00
201a3e5838 refactor(ui): canvas v2 (wip)
fix entity count select
2024-08-27 10:59:24 +10:00
85cb239219 refactor(ui): canvas v2 (wip)
delete unused file
2024-08-27 10:59:24 +10:00
002f45e383 refactor(ui): canvas v2 (wip)
merge all canvas state reducers into one big slice (but with the logic split across files so it's not hell)
2024-08-27 10:59:24 +10:00
470e5ba290 refactor(ui): canvas v2 (wip)
Fix a few more components
2024-08-27 10:59:24 +10:00
b6722b3a10 refactor(ui): canvas v2 (wip)
missed a spot
2024-08-27 10:59:24 +10:00
6a624916ca refactor(ui): canvas v2 (wip)
Redo all UI components for different canvas entity types
2024-08-27 10:59:24 +10:00
65653e0932 refactor(ui): canvas v2 (wip) 2024-08-27 10:59:24 +10:00
a7e59d6697 refactor(ui): canvas v2 (wip) 2024-08-27 10:59:24 +10:00
f1356105c1 refactor(ui): canvas v2 (wip) 2024-08-27 10:59:24 +10:00
8acc6379fb refactor(ui): canvas v2 (wip) 2024-08-27 10:59:24 +10:00
d13014c5d9 feat(ui): bbox tool 2024-08-27 10:59:24 +10:00
db11d8ba90 fix(ui): rect tool preview 2024-08-27 10:59:24 +10:00
46a4ee2360 fix(ui): multiple stages 2024-08-27 10:59:24 +10:00
a19c053b88 feat(ui): decouple konva logic from nanostores 2024-08-27 10:59:24 +10:00
5b47a32d31 feat(ui): store all stage attrs in nanostores 2024-08-27 10:59:24 +10:00
32ae9efb6a feat(ui): round stage scale 2024-08-27 10:59:24 +10:00
b81bee551a chore(ui): bump konva 2024-08-27 10:59:24 +10:00
8c9dffd082 feat(ui): generation bbox transformation working
whew
2024-08-27 10:59:24 +10:00
cdfae643e4 feat(ui): wip generation bbox 2024-08-27 10:59:24 +10:00
cb93108206 feat(ui): wip generation bbox 2024-08-27 10:59:24 +10:00
c11343dc1c feat(ui): CL zoom and pan, some rendering optimizations 2024-08-27 10:59:24 +10:00
3733c6f89d Revert "feat(ui): add x,y,scaleX,scaleY,rotation to objects"
This reverts commit 53318b396c967c72326a7e4dea09667b2ab20bdd.
2024-08-27 10:59:24 +10:00
3c23a0eac0 feat(ui): layers manage their own bbox 2024-08-27 10:59:24 +10:00
7fd69ab1f1 docs(ui): konva image object docstrings 2024-08-27 10:59:24 +10:00
3f511774de feat(ui): add x,y,scaleX,scaleY,rotation to objects 2024-08-27 10:59:24 +10:00
3600100879 fix(ui): show color picker when using rect tool 2024-08-27 10:59:24 +10:00
8fae372103 feat(ui): image loading fallback for raster layers 2024-08-27 10:59:24 +10:00
2fce1fe04c feat(ui): bbox calc for raster layers 2024-08-27 10:59:24 +10:00
22f3e975c7 feat(ui): do not fill brush preview when drawing 2024-08-27 10:59:24 +10:00
c6ced5a210 fix(ui): brush spacing handling 2024-08-27 10:59:24 +10:00
5d66e85205 fix(ui): jank when starting a shape when not already focused on stage 2024-08-27 10:59:24 +10:00
649f163bf7 feat(ui): wip raster layers
I meant to split this up into smaller commits and undo some of it, but I committed afterwards and it's tedious to undo.
2024-08-27 10:59:24 +10:00
4c821bd930 feat(ui): support image objects on raster layers
Just the UI and internal state, not rendering yet.
2024-08-27 10:59:24 +10:00
6359de4e25 tidy(ui): clean up event handlers
Separate logic for each tool in preparation for ellipse and polygon tools.
2024-08-27 10:59:24 +10:00
2c610c8cd4 feat(ui): raster layer reset, object group util 2024-08-27 10:59:24 +10:00
7efe8a249b feat(ui): rect shape preview now has fill 2024-08-27 10:59:24 +10:00
bd421d184e feat(ui): cancel shape drawing on esc 2024-08-27 10:59:24 +10:00
78f5844ba0 feat(ui): temp disable history on CL 2024-08-27 10:59:23 +10:00
f6bb4d5051 feat(ui): raster layer logic
- Deduplicate shared logic
- Split up giant renderers file into separate cohesive files
- Tons of cleanup
- Progress on raster layer functionality
2024-08-27 10:59:23 +10:00
56642c3e87 feat(ui): add raster layer rendering and interaction (WIP) 2024-08-27 10:59:23 +10:00
9fd8678d3d feat(ui): scaffold out raster layers
Raster layers may have images, lines and shapes. These will replace initial image layers and provide sketching functionality like we have on canvas.
2024-08-27 10:59:23 +10:00
8b89518fd6 refactor(ui): revise types for line and rect objects
- Create separate object types for brush and eraser lines, instead of a single type that has a `tool` field.
- Create new object type for rect shapes.
- Add logic to schemas to migrate old object types to new.
- Update renderers & reducers.
2024-08-27 10:59:23 +10:00
50085b40bb Update starter model size estimates. 2024-08-26 20:17:50 -04:00
cff382715a default workflow: add steps to exposed fields, add more notes 2024-08-26 20:17:50 -04:00
54d54d1bf2 Run ruff 2024-08-26 20:17:50 -04:00
e84ea68282 remove prompt 2024-08-26 20:17:50 -04:00
160dd36782 update default workflow for flux 2024-08-26 20:17:50 -04:00
65bb46bcca Rename params for flux and flux vae, add comments explaining use of the config_path in model config 2024-08-26 20:17:50 -04:00
2d185fb766 Run ruff 2024-08-26 20:17:50 -04:00
2ba9b02932 Fix type error in tsc 2024-08-26 20:17:50 -04:00
849da67cc7 Remove no longer used code in the flux denoise function 2024-08-26 20:17:50 -04:00
3ea6c9666e Remove in progress images until we're able to make the valuable 2024-08-26 20:17:50 -04:00
cf633e4ef2 Only install starter models if not already installed 2024-08-26 20:17:50 -04:00
bbf934d980 Remove outdated TODO. 2024-08-26 20:17:50 -04:00
620f733110 ruff format 2024-08-26 20:17:50 -04:00
67928609a3 Downgrade accelerate and huggingface-hub deps to original versions. 2024-08-26 20:17:50 -04:00
5f15afb7db Remove flux repo dependency 2024-08-26 20:17:50 -04:00
635d2f480d ruff 2024-08-26 20:17:50 -04:00
70c278c810 Remove dependency on flux config files 2024-08-26 20:17:50 -04:00
56b9906e2e Setup scaffolding for in progress images and add ability to cancel the flux node 2024-08-26 20:17:50 -04:00
a808ce81fd Replace swish() with torch.nn.functional.silu(h). They are functionally equivalent, but in my test VAE deconding was ~8% faster after the change. 2024-08-26 20:17:50 -04:00
83f82c5ddf Switch the CLIP-L start model to use our hosted version - which is much smaller. 2024-08-26 20:17:50 -04:00
101de8c25d Update t5 encoder formats to accurately reflect the quantization strategy and data type 2024-08-26 20:17:50 -04:00
3339a4baf0 Downgrade revert torch version after removing optimum-qanto, and other minor version-related fixes. 2024-08-26 20:17:50 -04:00
dff4a88baa Move quantization scripts to a scripts/ subdir. 2024-08-26 20:17:50 -04:00
a21f6c4964 Update docs for T5 quantization script. 2024-08-26 20:17:50 -04:00
97562504b7 Remove all references to optimum-quanto and downgrade diffusers. 2024-08-26 20:17:50 -04:00
75d8ac378c Update the T5 8-bit quantized starter model to use the BnB LLM.int8() variant. 2024-08-26 20:17:50 -04:00
b9dd354e2b Fixes to the T5XXL quantization script. 2024-08-26 20:17:50 -04:00
33c2fbd201 Add script for quantizing a T5 model. 2024-08-26 20:17:50 -04:00
5063be92bf Switch flux to using its own conditioning field 2024-08-26 20:17:50 -04:00
1047584b3e Only import bnb quantize file if bitsandbytes is installed 2024-08-26 20:17:50 -04:00
6764dcfdaa Load and unload clip/t5 encoders and run inference separately in text encoding 2024-08-26 20:17:50 -04:00
012864ceb1 Update macos test vm to macOS-14 2024-08-26 20:17:50 -04:00
a0bf20bcee Run FLUX VAE decoding in the user's preferred dtype rather than float32. Tested, and seems to work well at float16. 2024-08-26 20:17:50 -04:00
14ab339b33 Move prepare_latent_image_patches(...) to sampling.py with all of the related FLUX inference code. 2024-08-26 20:17:50 -04:00
25c91efbb6 Rename field positive_prompt -> prompt. 2024-08-26 20:17:50 -04:00
1c1f2c6664 Add comment about incorrect T5 Tokenizer size calculation. 2024-08-26 20:17:50 -04:00
d7c22b3bf7 Tidy is_schnell detection logic. 2024-08-26 20:17:50 -04:00
185f2a395f Make FLUX get_noise(...) consistent across devices/dtypes. 2024-08-26 20:17:50 -04:00
0c5649491e Mark FLUX nodes as prototypes. 2024-08-26 20:17:50 -04:00
94aba5892a Attribute black-forest-labs/flux for much of the flux code 2024-08-26 20:17:50 -04:00
ef093dde29 Don't install bitsandbytes on macOS 2024-08-26 20:17:50 -04:00
34451e5f27 added FLUX dev to starter models 2024-08-26 20:17:50 -04:00
1f9bdd1a9a Undo changes to the v2 dir of frontend types 2024-08-26 20:17:50 -04:00
c27d59baf7 Run ruff 2024-08-26 20:17:50 -04:00
f130ddec7c Remove automatic install of models during flux model loader, remove no longer used import function on context 2024-08-26 20:17:50 -04:00
a0a259eef1 Fix max_seq_len field description. 2024-08-26 20:17:50 -04:00
b66f19d4d1 Add docs to the quantization scripts. 2024-08-26 20:17:50 -04:00
4105a78b83 Update load_flux_model_bnb_llm_int8.py to work with a single-file FLUX transformer checkpoint. 2024-08-26 20:17:50 -04:00
19a68afb3a Fix bug in InvokeInt8Params that was causing it to use double the necessary VRAM. 2024-08-26 20:17:50 -04:00
fd68a2475b add better workflow name 2024-08-26 20:17:50 -04:00
28ff7ba830 add better workflow description 2024-08-26 20:17:50 -04:00
5d0b248fdb fix(worker) fix T5 type 2024-08-26 20:17:50 -04:00
01a4e0f6ef update default workflow 2024-08-26 20:17:50 -04:00
91e0731506 fix schema 2024-08-26 20:17:50 -04:00
d1f904d41f tsc and lint fix 2024-08-26 20:17:50 -04:00
269388c9f4 feat(ui): create new field for t5 encoder models in nodes 2024-08-26 20:17:50 -04:00
b8486379ce fix(ui): pass base/type when installing models, add flux formats to MM badges 2024-08-26 20:17:50 -04:00
400eb94d3b fix(ui): only exclude flux main models from linear UI dropdown, not model manager list 2024-08-26 20:17:50 -04:00
e210c96485 add FLUX schnell starter models and submodels as dependenices or adhoc download options 2024-08-26 20:17:50 -04:00
5f567f41f4 add case for clip embed models in probe 2024-08-26 20:17:50 -04:00
5fed573a29 update flux_model_loader node to take a T5 encoder from node field instead of hardcoded list, assume all models have been downloaded 2024-08-26 20:17:50 -04:00
cfac7c8189 Move requantize.py to the quatnization/ dir. 2024-08-26 20:17:50 -04:00
1787de6836 Add docs to the requantize(...) function explaining why it was copied from optimum-quanto. 2024-08-26 20:17:50 -04:00
ac96f187bd Remove duplicate log_time(...) function. 2024-08-26 20:17:50 -04:00
72398350b4 More flux loader cleanup 2024-08-26 20:17:50 -04:00
df9445c351 Various styling and exception type updates 2024-08-26 20:17:50 -04:00
87b7a2e39b Switch inheritance class of flux model loaders 2024-08-26 20:17:50 -04:00
f7e46622a1 Update doc string for import_local_model and remove access_token since it's only usable for local file paths 2024-08-26 20:17:50 -04:00
71f18353a9 Address minor review comments. 2024-08-26 20:17:50 -04:00
4228de707b Rename t5Encoder -> t5_encoder. 2024-08-26 20:17:50 -04:00
b6a05629ef add default workflow for flux t2i 2024-08-26 20:17:50 -04:00
fbaa820643 exclude flux models from main model dropdown 2024-08-26 20:17:50 -04:00
db2a2d5e38 Some cleanup of the tags and description of flux nodes 2024-08-26 20:17:50 -04:00
8ba6e6b1f8 Add t5 encoders and clip embeds to the model manager 2024-08-26 20:17:50 -04:00
57168d719b Fix styling/lint 2024-08-26 20:17:50 -04:00
dee6d2c98e Fix support for 8b quantized t5 encoders, update exception messages in flux loaders 2024-08-26 20:17:50 -04:00
e49105ece5 Add tqdm progress bar to FLUX denoising. 2024-08-26 20:17:50 -04:00
0c5e11f521 Fix FLUX output image clamping. And a few other minor fixes to make inference work with the full bfloat16 FLUX transformer model. 2024-08-26 20:17:50 -04:00
a63f842a13 Select dev/schnell based on state dict, use correct max seq len based on dev/schnell, and shift in inference, separate vae flux params into separate config 2024-08-26 20:17:50 -04:00
4bd7fda694 Install sub directories with folders correctly, ensure consistent dtype of tensors in flux pipeline and vae 2024-08-26 20:17:50 -04:00
81f0886d6f Working inference node with quantized bnb nf4 checkpoint 2024-08-26 20:17:50 -04:00
2eb87f3306 Remove unused param on _run_vae_decoding in flux text to image 2024-08-26 20:17:50 -04:00
723f3ab0a9 Add nf4 bnb quantized format 2024-08-26 20:17:50 -04:00
1bd90e0fd4 Run ruff, setup initial text to image node 2024-08-26 20:17:50 -04:00
436f18ff55 Add backend functions and classes for Flux implementation, Update the way flux encoders/tokenizers are loaded for prompt encoding, Update way flux vae is loaded 2024-08-26 20:17:50 -04:00
cde9696214 Some UI cleanup, regenerate schema 2024-08-26 20:17:50 -04:00
2d9042fb93 Run Ruff 2024-08-26 20:17:50 -04:00
9ed53af520 Run Ruff 2024-08-26 20:17:50 -04:00
56fda669fd Manage quantization of models within the loader 2024-08-26 20:17:50 -04:00
1d8545a76c Remove changes to v1 workflow 2024-08-26 20:17:50 -04:00
5f59a828f9 Setup flux model loading in the UI 2024-08-26 20:17:50 -04:00
1fa6bddc89 WIP on moving from diffusers to FLUX 2024-08-26 20:17:50 -04:00
d3a5ca5247 More improvements for LLM.int8() - not fully tested. 2024-08-26 20:17:50 -04:00
f01f56a98e LLM.int8() quantization is working, but still some rough edges to solve. 2024-08-26 20:17:50 -04:00
99b0f79784 Clean up NF4 implementation. 2024-08-26 20:17:50 -04:00
e1eb104345 NF4 inference working 2024-08-26 20:17:50 -04:00
5c2f95ef50 NF4 loading working... I think. 2024-08-26 20:17:50 -04:00
b63df9bab9 wip 2024-08-26 20:17:50 -04:00
a52c899c6d Split a FluxTextEncoderInvocation out from the FluxTextToImageInvocation. This has the advantage that we benfit from automatic caching when the prompt isn't changed. 2024-08-26 20:17:50 -04:00
eeabb7ebe5 Make quantized loading fast for both T5XXL and FLUX transformer. 2024-08-26 20:17:50 -04:00
8b1cef978c Make quantized loading fast. 2024-08-26 20:17:50 -04:00
152da482cd WIP - experimentation 2024-08-26 20:17:50 -04:00
3cf0365a35 Make float16 inference work with FLUX on 24GB GPU. 2024-08-26 20:17:50 -04:00
5870742bb9 Add support for 8-bit quantizatino of the FLUX T5XXL text encoder. 2024-08-26 20:17:50 -04:00
01d8c62c57 Make 8-bit quantization save/reload work for the FLUX transformer. Reload is still very slow with the current optimum.quanto implementation. 2024-08-26 20:17:50 -04:00
55a242b2d6 Minor improvements to FLUX workflow. 2024-08-26 20:17:50 -04:00
45263b339f Got FLUX schnell working with 8-bit quantization. Still lots of rough edges to clean up. 2024-08-26 20:17:50 -04:00
3319491861 Use the FluxPipeline.encode_prompt() api rather than trying to run the two text encoders separately. 2024-08-26 20:17:50 -04:00
e687afac90 Add sentencepiece dependency for the T5 tokenizer. 2024-08-26 20:17:50 -04:00
b39031ea53 First draft of FluxTextToImageInvocation. 2024-08-26 20:17:50 -04:00
0b77511271 Update HF download logic to work for black-forest-labs/FLUX.1-schnell. 2024-08-26 20:17:50 -04:00
c99cd989c1 Update imports for compatibility with bumped diffusers version. 2024-08-26 20:17:50 -04:00
317fdadb21 Bump diffusers version to include FLUX support. 2024-08-26 20:17:50 -04:00
4e294f9e3e disable export button if no non-default presets 2024-08-26 09:23:15 -04:00
526e0f30a0 Added support for bounding boxes in the Invocation API
Adding built-in bounding boxes as a core type would help developers of nodes that include bounding box support.
2024-08-26 08:03:30 +10:00
231e5ec94a chore: bump version v4.2.8post1 2024-08-23 06:55:30 +10:00
e5bb6f9693 lint fix 2024-08-23 06:46:19 +10:00
da7dee44c6 fix(ui): use empty string fallback if unable to parse prompts when creating style preset from existing image 2024-08-23 06:46:19 +10:00
83144f4fe3 fix(docs): follow-up docker readme fixes 2024-08-22 11:19:07 -04:00
c451f52ea3 chore(ui): lint 2024-08-22 21:00:09 +10:00
8a2c78f2e1 fix(ui): dynamic prompts not recalculating when deleting or updating a style preset
The root cause was the active style preset not being reset when it was deleted, or no longer present in the list of style presets.

- Add extra reducer to `stylePresetSlice` to reset the active preset if it is deleted or otherwise unavailable
- Update the dynamic prompts listener to trigger on delete/update/list of style presets
2024-08-22 21:00:09 +10:00
bcc78bde9b chore: bump version to v4.2.8 2024-08-22 21:00:09 +10:00
054bb6fe0a translationBot(ui): update translation (Russian)
Currently translated at 100.0% (1367 of 1367 strings)

Co-authored-by: Васянатор <ilabulanov339@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/ru/
Translation: InvokeAI/Web UI
2024-08-22 13:09:56 +10:00
4f4aa6d92e translationBot(ui): update translation (Italian)
Currently translated at 98.4% (1346 of 1367 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.4% (1346 of 1367 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2024-08-22 13:09:56 +10:00
eac51ac6f5 translationBot(ui): update translation files
Updated by "Cleanup translation files" hook in Weblate.

Co-authored-by: Hosted Weblate <hosted@weblate.org>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/
Translation: InvokeAI/Web UI
2024-08-22 13:09:56 +10:00
9f349a7c0a fix(ui): do not constrain width of hide/show boards button
lets translations display fully
2024-08-22 11:36:07 +10:00
918afa5b15 fix(ui): show more of current board name 2024-08-22 11:36:07 +10:00
eb1113f95c feat(ui): add translation string for "Upscale" 2024-08-22 11:36:07 +10:00
4f4ba7b462 tidy(ui): clean up ActiveStylePreset markup 2024-08-21 09:06:41 +10:00
2298be0e6b fix(ui): error handling if unable to convert image URL to blob 2024-08-21 09:06:41 +10:00
63494dfca7 remove extra slash in exports path 2024-08-21 09:06:41 +10:00
36a1d39454 fix(ui): handle badge styling when template name is long 2024-08-21 09:06:41 +10:00
a6f6d5c400 fix(ui): add loading state to button when creating or updating a style preset 2024-08-21 09:06:41 +10:00
e85f221aca fix(ui): clear prompt template when prompts are recalled 2024-08-21 09:04:35 +10:00
d4797e37dc fix(ui): properly unwrap delete style preset API request so that error is caught 2024-08-19 16:12:39 -04:00
3e7923d072 fix(api): allow updating of type for style preset 2024-08-19 16:12:39 -04:00
a85d69ce3d tidy(ui): getViewModeChunks.tsx -> .ts 2024-08-19 08:25:39 +10:00
96db006c99 fix(ui): edge case with getViewModeChunks 2024-08-19 08:25:39 +10:00
8ca57d03d8 tests(ui): add tests for getViewModeChunks 2024-08-19 08:25:39 +10:00
6c404ce5f8 fix(ui): prompt template preset preview out of order 2024-08-19 08:25:39 +10:00
897 changed files with 27147 additions and 30123 deletions

View File

@ -60,7 +60,7 @@ jobs:
extra-index-url: 'https://download.pytorch.org/whl/cpu'
github-env: $GITHUB_ENV
- platform: macos-default
os: macOS-12
os: macOS-14
github-env: $GITHUB_ENV
- platform: windows-cpu
os: windows-2022

View File

@ -1,20 +1,22 @@
# Invoke in Docker
- Ensure that Docker can use the GPU on your system
- This documentation assumes Linux, but should work similarly under Windows with WSL2
First things first:
- Ensure that Docker can use your [NVIDIA][nvidia docker docs] or [AMD][amd docker docs] GPU.
- This document assumes a Linux system, but should work similarly under Windows with WSL2.
- We don't recommend running Invoke in Docker on macOS at this time. It works, but very slowly.
## Quickstart :lightning:
## Quickstart
No `docker compose`, no persistence, just a simple one-liner using the official images:
No `docker compose`, no persistence, single command, using the official images:
**CUDA:**
**CUDA (NVIDIA GPU):**
```bash
docker run --runtime=nvidia --gpus=all --publish 9090:9090 ghcr.io/invoke-ai/invokeai
```
**ROCm:**
**ROCm (AMD GPU):**
```bash
docker run --device /dev/kfd --device /dev/dri --publish 9090:9090 ghcr.io/invoke-ai/invokeai:main-rocm
@ -22,12 +24,20 @@ docker run --device /dev/kfd --device /dev/dri --publish 9090:9090 ghcr.io/invok
Open `http://localhost:9090` in your browser once the container finishes booting, install some models, and generate away!
> [!TIP]
> To persist your data (including downloaded models) outside of the container, add a `--volume/-v` flag to the above command, e.g.: `docker run --volume /some/local/path:/invokeai <...the rest of the command>`
### Data persistence
To persist your generated images and downloaded models outside of the container, add a `--volume/-v` flag to the above command, e.g.:
```bash
docker run --volume /some/local/path:/invokeai {...etc...}
```
`/some/local/path/invokeai` will contain all your data.
It can *usually* be reused between different installs of Invoke. Tread with caution and read the release notes!
## Customize the container
We ship the `run.sh` script, which is a convenient wrapper around `docker compose` for cases where custom image build args are needed. Alternatively, the familiar `docker compose` commands work just as well.
The included `run.sh` script is a convenience wrapper around `docker compose`. It can be helpful for passing additional build arguments to `docker compose`. Alternatively, the familiar `docker compose` commands work just as well.
```bash
cd docker
@ -38,11 +48,14 @@ cp .env.sample .env
It will take a few minutes to build the image the first time. Once the application starts up, open `http://localhost:9090` in your browser to invoke!
>[!TIP]
>When using the `run.sh` script, the container will continue running after Ctrl+C. To shut it down, use the `docker compose down` command.
## Docker setup in detail
#### Linux
1. Ensure builkit is enabled in the Docker daemon settings (`/etc/docker/daemon.json`)
1. Ensure buildkit is enabled in the Docker daemon settings (`/etc/docker/daemon.json`)
2. Install the `docker compose` plugin using your package manager, or follow a [tutorial](https://docs.docker.com/compose/install/linux/#install-using-the-repository).
- The deprecated `docker-compose` (hyphenated) CLI probably won't work. Update to a recent version.
3. Ensure docker daemon is able to access the GPU.
@ -98,25 +111,7 @@ GPU_DRIVER=cuda
Any environment variables supported by InvokeAI can be set here. See the [Configuration docs](https://invoke-ai.github.io/InvokeAI/features/CONFIGURATION/) for further detail.
## Even More Customizing!
---
See the `docker-compose.yml` file. The `command` instruction can be uncommented and used to run arbitrary startup commands. Some examples below.
### Reconfigure the runtime directory
Can be used to download additional models from the supported model list
In conjunction with `INVOKEAI_ROOT` can be also used to initialize a runtime directory
```yaml
command:
- invokeai-configure
- --yes
```
Or install models:
```yaml
command:
- invokeai-model-install
```
[nvidia docker docs]: https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html
[amd docker docs]: https://rocm.docs.amd.com/projects/install-on-linux/en/latest/how-to/docker.html

View File

@ -11,6 +11,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
Batch,
BatchStatus,
CancelByBatchIDsResult,
CancelByOriginResult,
ClearResult,
EnqueueBatchResult,
PruneResult,
@ -105,6 +106,19 @@ async def cancel_by_batch_ids(
return ApiDependencies.invoker.services.session_queue.cancel_by_batch_ids(queue_id=queue_id, batch_ids=batch_ids)
@session_queue_router.put(
"/{queue_id}/cancel_by_origin",
operation_id="cancel_by_origin",
responses={200: {"model": CancelByBatchIDsResult}},
)
async def cancel_by_origin(
queue_id: str = Path(description="The queue id to perform this operation on"),
origin: str = Query(description="The origin to cancel all queue items for"),
) -> CancelByOriginResult:
"""Immediately cancels all queue items with the given origin"""
return ApiDependencies.invoker.services.session_queue.cancel_by_origin(queue_id=queue_id, origin=origin)
@session_queue_router.put(
"/{queue_id}/clear",
operation_id="clear",

View File

@ -26,13 +26,10 @@ from invokeai.app.services.style_preset_records.style_preset_records_common impo
)
class StylePresetUpdateFormData(BaseModel):
class StylePresetFormData(BaseModel):
name: str = Field(description="Preset name")
positive_prompt: str = Field(description="Positive prompt")
negative_prompt: str = Field(description="Negative prompt")
class StylePresetCreateFormData(StylePresetUpdateFormData):
type: PresetType = Field(description="Preset type")
@ -95,9 +92,10 @@ async def update_style_preset(
try:
parsed_data = json.loads(data)
validated_data = StylePresetUpdateFormData(**parsed_data)
validated_data = StylePresetFormData(**parsed_data)
name = validated_data.name
type = validated_data.type
positive_prompt = validated_data.positive_prompt
negative_prompt = validated_data.negative_prompt
@ -105,7 +103,7 @@ async def update_style_preset(
raise HTTPException(status_code=400, detail="Invalid preset data")
preset_data = PresetData(positive_prompt=positive_prompt, negative_prompt=negative_prompt)
changes = StylePresetChanges(name=name, preset_data=preset_data)
changes = StylePresetChanges(name=name, preset_data=preset_data, type=type)
style_preset_image = ApiDependencies.invoker.services.style_preset_image_files.get_url(style_preset_id)
style_preset = ApiDependencies.invoker.services.style_preset_records.update(
@ -145,7 +143,7 @@ async def create_style_preset(
try:
parsed_data = json.loads(data)
validated_data = StylePresetCreateFormData(**parsed_data)
validated_data = StylePresetFormData(**parsed_data)
name = validated_data.name
type = validated_data.type

View File

@ -20,7 +20,6 @@ from typing import (
Type,
TypeVar,
Union,
cast,
)
import semver
@ -80,7 +79,7 @@ class UIConfigBase(BaseModel):
version: str = Field(
description='The node\'s version. Should be a valid semver string e.g. "1.0.0" or "3.8.13".',
)
node_pack: Optional[str] = Field(default=None, description="Whether or not this is a custom node")
node_pack: str = Field(description="The node pack that this node belongs to, will be 'invokeai' for built-in nodes")
classification: Classification = Field(default=Classification.Stable, description="The node's classification")
model_config = ConfigDict(
@ -230,18 +229,16 @@ class BaseInvocation(ABC, BaseModel):
@staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocation]) -> None:
"""Adds various UI-facing attributes to the invocation's OpenAPI schema."""
uiconfig = cast(UIConfigBase | None, getattr(model_class, "UIConfig", None))
if uiconfig is not None:
if uiconfig.title is not None:
schema["title"] = uiconfig.title
if uiconfig.tags is not None:
schema["tags"] = uiconfig.tags
if uiconfig.category is not None:
schema["category"] = uiconfig.category
if uiconfig.node_pack is not None:
schema["node_pack"] = uiconfig.node_pack
schema["classification"] = uiconfig.classification
schema["version"] = uiconfig.version
if title := model_class.UIConfig.title:
schema["title"] = title
if tags := model_class.UIConfig.tags:
schema["tags"] = tags
if category := model_class.UIConfig.category:
schema["category"] = category
if node_pack := model_class.UIConfig.node_pack:
schema["node_pack"] = node_pack
schema["classification"] = model_class.UIConfig.classification
schema["version"] = model_class.UIConfig.version
if "required" not in schema or not isinstance(schema["required"], list):
schema["required"] = []
schema["class"] = "invocation"
@ -312,7 +309,7 @@ class BaseInvocation(ABC, BaseModel):
json_schema_extra={"field_kind": FieldKind.NodeAttribute},
)
UIConfig: ClassVar[Type[UIConfigBase]]
UIConfig: ClassVar[UIConfigBase]
model_config = ConfigDict(
protected_namespaces=(),
@ -441,30 +438,25 @@ def invocation(
validate_fields(cls.model_fields, invocation_type)
# Add OpenAPI schema extras
uiconfig_name = cls.__qualname__ + ".UIConfig"
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconfig_name:
cls.UIConfig = type(uiconfig_name, (UIConfigBase,), {})
cls.UIConfig.title = title
cls.UIConfig.tags = tags
cls.UIConfig.category = category
cls.UIConfig.classification = classification
# Grab the node pack's name from the module name, if it's a custom node
is_custom_node = cls.__module__.rsplit(".", 1)[0] == "invokeai.app.invocations"
if is_custom_node:
cls.UIConfig.node_pack = cls.__module__.split(".")[0]
else:
cls.UIConfig.node_pack = None
uiconfig: dict[str, Any] = {}
uiconfig["title"] = title
uiconfig["tags"] = tags
uiconfig["category"] = category
uiconfig["classification"] = classification
# The node pack is the module name - will be "invokeai" for built-in nodes
uiconfig["node_pack"] = cls.__module__.split(".")[0]
if version is not None:
try:
semver.Version.parse(version)
except ValueError as e:
raise InvalidVersionError(f'Invalid version string for node "{invocation_type}": "{version}"') from e
cls.UIConfig.version = version
uiconfig["version"] = version
else:
logger.warn(f'No version specified for node "{invocation_type}", using "1.0.0"')
cls.UIConfig.version = "1.0.0"
uiconfig["version"] = "1.0.0"
cls.UIConfig = UIConfigBase(**uiconfig)
if use_cache is not None:
cls.model_fields["use_cache"].default = use_cache

View File

@ -40,6 +40,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
# region Model Field Types
MainModel = "MainModelField"
FluxMainModel = "FluxMainModelField"
SDXLMainModel = "SDXLMainModelField"
SDXLRefinerModel = "SDXLRefinerModelField"
ONNXModel = "ONNXModelField"
@ -48,6 +49,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
ControlNetModel = "ControlNetModelField"
IPAdapterModel = "IPAdapterModelField"
T2IAdapterModel = "T2IAdapterModelField"
T5EncoderModel = "T5EncoderModelField"
SpandrelImageToImageModel = "SpandrelImageToImageModelField"
# endregion
@ -125,13 +127,16 @@ class FieldDescriptions:
negative_cond = "Negative conditioning tensor"
noise = "Noise tensor"
clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
t5_encoder = "T5 tokenizer and text encoder"
unet = "UNet (scheduler, LoRAs)"
transformer = "Transformer"
vae = "VAE"
cond = "Conditioning tensor"
controlnet_model = "ControlNet model to load"
vae_model = "VAE model to load"
lora_model = "LoRA model to load"
main_model = "Main model (UNet, VAE, CLIP) to load"
flux_model = "Flux model (Transformer) to load"
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
@ -231,6 +236,12 @@ class ColorField(BaseModel):
return (self.r, self.g, self.b, self.a)
class FluxConditioningField(BaseModel):
"""A conditioning tensor primitive value"""
conditioning_name: str = Field(description="The name of conditioning tensor")
class ConditioningField(BaseModel):
"""A conditioning tensor primitive value"""

View File

@ -0,0 +1,86 @@
from typing import Literal
import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
from invokeai.app.invocations.model import CLIPField, T5EncoderField
from invokeai.app.invocations.primitives import FluxConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.modules.conditioner import HFEncoder
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
@invocation(
"flux_text_encoder",
title="FLUX Text Encoding",
tags=["prompt", "conditioning", "flux"],
category="conditioning",
version="1.0.0",
classification=Classification.Prototype,
)
class FluxTextEncoderInvocation(BaseInvocation):
"""Encodes and preps a prompt for a flux image."""
clip: CLIPField = InputField(
title="CLIP",
description=FieldDescriptions.clip,
input=Input.Connection,
)
t5_encoder: T5EncoderField = InputField(
title="T5Encoder",
description=FieldDescriptions.t5_encoder,
input=Input.Connection,
)
t5_max_seq_len: Literal[256, 512] = InputField(
description="Max sequence length for the T5 encoder. Expected to be 256 for FLUX schnell models and 512 for FLUX dev models."
)
prompt: str = InputField(description="Text prompt to encode.")
@torch.no_grad()
def invoke(self, context: InvocationContext) -> FluxConditioningOutput:
t5_embeddings, clip_embeddings = self._encode_prompt(context)
conditioning_data = ConditioningFieldData(
conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)]
)
conditioning_name = context.conditioning.save(conditioning_data)
return FluxConditioningOutput.build(conditioning_name)
def _encode_prompt(self, context: InvocationContext) -> tuple[torch.Tensor, torch.Tensor]:
# Load CLIP.
clip_tokenizer_info = context.models.load(self.clip.tokenizer)
clip_text_encoder_info = context.models.load(self.clip.text_encoder)
# Load T5.
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)
prompt = [self.prompt]
with (
t5_text_encoder_info as t5_text_encoder,
t5_tokenizer_info as t5_tokenizer,
):
assert isinstance(t5_text_encoder, T5EncoderModel)
assert isinstance(t5_tokenizer, T5Tokenizer)
t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len)
prompt_embeds = t5_encoder(prompt)
with (
clip_text_encoder_info as clip_text_encoder,
clip_tokenizer_info as clip_tokenizer,
):
assert isinstance(clip_text_encoder, CLIPTextModel)
assert isinstance(clip_tokenizer, CLIPTokenizer)
clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True, 77)
pooled_prompt_embeds = clip_encoder(prompt)
assert isinstance(prompt_embeds, torch.Tensor)
assert isinstance(pooled_prompt_embeds, torch.Tensor)
return prompt_embeds, pooled_prompt_embeds

View File

@ -0,0 +1,172 @@
import torch
from einops import rearrange
from PIL import Image
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
FieldDescriptions,
FluxConditioningField,
Input,
InputField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import TransformerField, VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.session_processor.session_processor_common import CanceledException
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, prepare_latent_img_patches, unpack
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
from invokeai.backend.util.devices import TorchDevice
@invocation(
"flux_text_to_image",
title="FLUX Text to Image",
tags=["image", "flux"],
category="image",
version="1.0.0",
classification=Classification.Prototype,
)
class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Text-to-image generation using a FLUX model."""
transformer: TransformerField = InputField(
description=FieldDescriptions.flux_model,
input=Input.Connection,
title="Transformer",
)
vae: VAEField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
)
positive_text_conditioning: FluxConditioningField = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection
)
width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
num_steps: int = InputField(
default=4, description="Number of diffusion steps. Recommend values are schnell: 4, dev: 50."
)
guidance: float = InputField(
default=4.0,
description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. FLUX dev only, ignored for schnell.",
)
seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
# Load the conditioning data.
cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
assert len(cond_data.conditionings) == 1
flux_conditioning = cond_data.conditionings[0]
assert isinstance(flux_conditioning, FLUXConditioningInfo)
latents = self._run_diffusion(context, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds)
image = self._run_vae_decoding(context, latents)
image_dto = context.images.save(image=image)
return ImageOutput.build(image_dto)
def _run_diffusion(
self,
context: InvocationContext,
clip_embeddings: torch.Tensor,
t5_embeddings: torch.Tensor,
):
transformer_info = context.models.load(self.transformer.transformer)
inference_dtype = torch.bfloat16
# Prepare input noise.
x = get_noise(
num_samples=1,
height=self.height,
width=self.width,
device=TorchDevice.choose_torch_device(),
dtype=inference_dtype,
seed=self.seed,
)
img, img_ids = prepare_latent_img_patches(x)
is_schnell = "schnell" in transformer_info.config.config_path
timesteps = get_schedule(
num_steps=self.num_steps,
image_seq_len=img.shape[1],
shift=not is_schnell,
)
bs, t5_seq_len, _ = t5_embeddings.shape
txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())
# HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from
# disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems
# if the cache is not empty.
context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)
with transformer_info as transformer:
assert isinstance(transformer, Flux)
def step_callback() -> None:
if context.util.is_canceled():
raise CanceledException
# TODO: Make this look like the image before re-enabling
# latent_image = unpack(img.float(), self.height, self.width)
# latent_image = latent_image.squeeze() # Remove unnecessary dimensions
# flattened_tensor = latent_image.reshape(-1) # Flatten to shape [48*128*128]
# # Create a new tensor of the required shape [255, 255, 3]
# latent_image = flattened_tensor[: 255 * 255 * 3].reshape(255, 255, 3) # Reshape to RGB format
# # Convert to a NumPy array and then to a PIL Image
# image = Image.fromarray(latent_image.cpu().numpy().astype(np.uint8))
# (width, height) = image.size
# width *= 8
# height *= 8
# dataURL = image_to_dataURL(image, image_format="JPEG")
# # TODO: move this whole function to invocation context to properly reference these variables
# context._services.events.emit_invocation_denoise_progress(
# context._data.queue_item,
# context._data.invocation,
# state,
# ProgressImage(dataURL=dataURL, width=width, height=height),
# )
x = denoise(
model=transformer,
img=img,
img_ids=img_ids,
txt=t5_embeddings,
txt_ids=txt_ids,
vec=clip_embeddings,
timesteps=timesteps,
step_callback=step_callback,
guidance=self.guidance,
)
x = unpack(x.float(), self.height, self.width)
return x
def _run_vae_decoding(
self,
context: InvocationContext,
latents: torch.Tensor,
) -> Image.Image:
vae_info = context.models.load(self.vae.vae)
with vae_info as vae:
assert isinstance(vae, AutoEncoder)
latents = latents.to(dtype=TorchDevice.choose_torch_dtype())
img = vae.decode(latents)
img = img.clamp(-1, 1)
img = rearrange(img[0], "c h w -> h w c")
img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
return img_pil

View File

@ -6,13 +6,19 @@ import cv2
import numpy
from PIL import Image, ImageChops, ImageFilter, ImageOps
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.constants import IMAGE_MODES
from invokeai.app.invocations.fields import (
ColorField,
FieldDescriptions,
ImageField,
InputField,
OutputField,
WithBoard,
WithMetadata,
)
@ -1007,3 +1013,62 @@ class MaskFromIDInvocation(BaseInvocation, WithMetadata, WithBoard):
image_dto = context.images.save(image=mask, image_category=ImageCategory.MASK)
return ImageOutput.build(image_dto)
@invocation_output("canvas_v2_mask_and_crop_output")
class CanvasV2MaskAndCropOutput(ImageOutput):
offset_x: int = OutputField(description="The x offset of the image, after cropping")
offset_y: int = OutputField(description="The y offset of the image, after cropping")
@invocation(
"canvas_v2_mask_and_crop",
title="Canvas V2 Mask and Crop",
tags=["image", "mask", "id"],
category="image",
version="1.0.0",
classification=Classification.Prototype,
)
class CanvasV2MaskAndCropInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Handles Canvas V2 image output masking and cropping"""
source_image: ImageField | None = InputField(
default=None,
description="The source image onto which the masked generated image is pasted. If omitted, the masked generated image is returned with transparency.",
)
generated_image: ImageField = InputField(description="The image to apply the mask to")
mask: ImageField = InputField(description="The mask to apply")
mask_blur: int = InputField(default=0, ge=0, description="The amount to blur the mask by")
def _prepare_mask(self, mask: Image.Image) -> Image.Image:
mask_array = numpy.array(mask)
kernel = numpy.ones((self.mask_blur, self.mask_blur), numpy.uint8)
dilated_mask_array = cv2.erode(mask_array, kernel, iterations=3)
dilated_mask = Image.fromarray(dilated_mask_array)
if self.mask_blur > 0:
mask = dilated_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))
return ImageOps.invert(mask.convert("L"))
def invoke(self, context: InvocationContext) -> CanvasV2MaskAndCropOutput:
mask = self._prepare_mask(context.images.get_pil(self.mask.image_name))
if self.source_image:
generated_image = context.images.get_pil(self.generated_image.image_name)
source_image = context.images.get_pil(self.source_image.image_name)
source_image.paste(generated_image, (0, 0), mask)
image_dto = context.images.save(image=source_image)
else:
generated_image = context.images.get_pil(self.generated_image.image_name)
generated_image.putalpha(mask)
image_dto = context.images.save(image=generated_image)
# bbox = image.getbbox()
# image = image.crop(bbox)
return CanvasV2MaskAndCropOutput(
image=ImageField(image_name=image_dto.image_name),
offset_x=0,
offset_y=0,
width=image_dto.width,
height=image_dto.height,
)

View File

@ -1,5 +1,5 @@
import copy
from typing import List, Optional
from typing import List, Literal, Optional
from pydantic import BaseModel, Field
@ -13,7 +13,14 @@ from invokeai.app.invocations.baseinvocation import (
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType
from invokeai.backend.flux.util import max_seq_lengths
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
CheckpointConfigBase,
ModelType,
SubModelType,
)
class ModelIdentifierField(BaseModel):
@ -60,6 +67,15 @@ class CLIPField(BaseModel):
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
class TransformerField(BaseModel):
transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
class T5EncoderField(BaseModel):
tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel")
text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")
class VAEField(BaseModel):
vae: ModelIdentifierField = Field(description="Info to load vae submodel")
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
@ -122,6 +138,112 @@ class ModelIdentifierInvocation(BaseInvocation):
return ModelIdentifierOutput(model=self.model)
@invocation_output("flux_model_loader_output")
class FluxModelLoaderOutput(BaseInvocationOutput):
"""Flux base model loader output"""
transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP")
t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5_encoder, title="T5 Encoder")
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
max_seq_len: Literal[256, 512] = OutputField(
description="The max sequence length to used for the T5 encoder. (256 for schnell transformer, 512 for dev transformer)",
title="Max Seq Length",
)
@invocation(
"flux_model_loader",
title="Flux Main Model",
tags=["model", "flux"],
category="model",
version="1.0.3",
classification=Classification.Prototype,
)
class FluxModelLoaderInvocation(BaseInvocation):
"""Loads a flux base model, outputting its submodels."""
model: ModelIdentifierField = InputField(
description=FieldDescriptions.flux_model,
ui_type=UIType.FluxMainModel,
input=Input.Direct,
)
t5_encoder: ModelIdentifierField = InputField(
description=FieldDescriptions.t5_encoder,
ui_type=UIType.T5EncoderModel,
input=Input.Direct,
)
def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
model_key = self.model.key
if not context.models.exists(model_key):
raise ValueError(f"Unknown model: {model_key}")
transformer = self._get_model(context, SubModelType.Transformer)
tokenizer = self._get_model(context, SubModelType.Tokenizer)
tokenizer2 = self._get_model(context, SubModelType.Tokenizer2)
clip_encoder = self._get_model(context, SubModelType.TextEncoder)
t5_encoder = self._get_model(context, SubModelType.TextEncoder2)
vae = self._get_model(context, SubModelType.VAE)
transformer_config = context.models.get_config(transformer)
assert isinstance(transformer_config, CheckpointConfigBase)
return FluxModelLoaderOutput(
transformer=TransformerField(transformer=transformer),
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
vae=VAEField(vae=vae),
max_seq_len=max_seq_lengths[transformer_config.config_path],
)
def _get_model(self, context: InvocationContext, submodel: SubModelType) -> ModelIdentifierField:
match submodel:
case SubModelType.Transformer:
return self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
case SubModelType.VAE:
return self._pull_model_from_mm(
context,
SubModelType.VAE,
"FLUX.1-schnell_ae",
ModelType.VAE,
BaseModelType.Flux,
)
case submodel if submodel in [SubModelType.Tokenizer, SubModelType.TextEncoder]:
return self._pull_model_from_mm(
context,
submodel,
"clip-vit-large-patch14",
ModelType.CLIPEmbed,
BaseModelType.Any,
)
case submodel if submodel in [SubModelType.Tokenizer2, SubModelType.TextEncoder2]:
return self._pull_model_from_mm(
context,
submodel,
self.t5_encoder.name,
ModelType.T5Encoder,
BaseModelType.Any,
)
case _:
raise Exception(f"{submodel.value} is not a supported submodule for a flux model")
def _pull_model_from_mm(
self,
context: InvocationContext,
submodel: SubModelType,
name: str,
type: ModelType,
base: BaseModelType,
):
if models := context.models.search_by_attrs(name=name, base=base, type=type):
if len(models) != 1:
raise Exception(f"Multiple models detected for selected model with name {name}")
return ModelIdentifierField.from_config(models[0]).model_copy(update={"submodel_type": submodel})
else:
raise ValueError(f"Please install the {base}:{type} model named {name} via starter models")
@invocation(
"main_model_loader",
title="Main Model",

View File

@ -12,6 +12,7 @@ from invokeai.app.invocations.fields import (
ConditioningField,
DenoiseMaskField,
FieldDescriptions,
FluxConditioningField,
ImageField,
Input,
InputField,
@ -414,6 +415,17 @@ class MaskOutput(BaseInvocationOutput):
height: int = OutputField(description="The height of the mask in pixels.")
@invocation_output("flux_conditioning_output")
class FluxConditioningOutput(BaseInvocationOutput):
"""Base class for nodes that output a single conditioning tensor"""
conditioning: FluxConditioningField = OutputField(description=FieldDescriptions.cond)
@classmethod
def build(cls, conditioning_name: str) -> "FluxConditioningOutput":
return cls(conditioning=FluxConditioningField(conditioning_name=conditioning_name))
@invocation_output("conditioning_output")
class ConditioningOutput(BaseInvocationOutput):
"""Base class for nodes that output a single conditioning tensor"""

View File

@ -88,6 +88,7 @@ class QueueItemEventBase(QueueEventBase):
item_id: int = Field(description="The ID of the queue item")
batch_id: str = Field(description="The ID of the queue batch")
origin: str | None = Field(default=None, description="The origin of the batch")
class InvocationEventBase(QueueItemEventBase):
@ -95,8 +96,6 @@ class InvocationEventBase(QueueItemEventBase):
session_id: str = Field(description="The ID of the session (aka graph execution state)")
queue_id: str = Field(description="The ID of the queue")
item_id: int = Field(description="The ID of the queue item")
batch_id: str = Field(description="The ID of the queue batch")
session_id: str = Field(description="The ID of the session (aka graph execution state)")
invocation: AnyInvocation = Field(description="The ID of the invocation")
invocation_source_id: str = Field(description="The ID of the prepared invocation's source node")
@ -114,6 +113,7 @@ class InvocationStartedEvent(InvocationEventBase):
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
origin=queue_item.origin,
session_id=queue_item.session_id,
invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
@ -147,6 +147,7 @@ class InvocationDenoiseProgressEvent(InvocationEventBase):
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
origin=queue_item.origin,
session_id=queue_item.session_id,
invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
@ -184,6 +185,7 @@ class InvocationCompleteEvent(InvocationEventBase):
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
origin=queue_item.origin,
session_id=queue_item.session_id,
invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
@ -216,6 +218,7 @@ class InvocationErrorEvent(InvocationEventBase):
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
origin=queue_item.origin,
session_id=queue_item.session_id,
invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
@ -253,6 +256,7 @@ class QueueItemStatusChangedEvent(QueueItemEventBase):
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
origin=queue_item.origin,
session_id=queue_item.session_id,
status=queue_item.status,
error_type=queue_item.error_type,
@ -279,12 +283,14 @@ class BatchEnqueuedEvent(QueueEventBase):
description="The number of invocations initially requested to be enqueued (may be less than enqueued if queue was full)"
)
priority: int = Field(description="The priority of the batch")
origin: str | None = Field(default=None, description="The origin of the batch")
@classmethod
def build(cls, enqueue_result: EnqueueBatchResult) -> "BatchEnqueuedEvent":
return cls(
queue_id=enqueue_result.queue_id,
batch_id=enqueue_result.batch.batch_id,
origin=enqueue_result.batch.origin,
enqueued=enqueue_result.enqueued,
requested=enqueue_result.requested,
priority=enqueue_result.priority,

View File

@ -783,8 +783,9 @@ class ModelInstallService(ModelInstallServiceBase):
# So what we do is to synthesize a folder named "sdxl-turbo_vae" here.
if subfolder:
top = Path(remote_files[0].path.parts[0]) # e.g. "sdxl-turbo/"
path_to_remove = top / subfolder.parts[-1] # sdxl-turbo/vae/
path_to_add = Path(f"{top}_{subfolder}")
path_to_remove = top / subfolder # sdxl-turbo/vae/
subfolder_rename = subfolder.name.replace("/", "_").replace("\\", "_")
path_to_add = Path(f"{top}_{subfolder_rename}")
else:
path_to_remove = Path(".")
path_to_add = Path(".")

View File

@ -77,6 +77,7 @@ class ModelRecordChanges(BaseModelExcludeNull):
type: Optional[ModelType] = Field(description="Type of model", default=None)
key: Optional[str] = Field(description="Database ID for this model", default=None)
hash: Optional[str] = Field(description="hash of model file", default=None)
format: Optional[str] = Field(description="format of model file", default=None)
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
default_settings: Optional[MainModelDefaultSettings | ControlAdapterDefaultSettings] = Field(
description="Default settings for this model", default=None

View File

@ -6,6 +6,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
Batch,
BatchStatus,
CancelByBatchIDsResult,
CancelByOriginResult,
CancelByQueueIDResult,
ClearResult,
EnqueueBatchResult,
@ -95,6 +96,11 @@ class SessionQueueBase(ABC):
"""Cancels all queue items with matching batch IDs"""
pass
@abstractmethod
def cancel_by_origin(self, queue_id: str, origin: str) -> CancelByOriginResult:
"""Cancels all queue items with the given batch origin"""
pass
@abstractmethod
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
"""Cancels all queue items with matching queue ID"""

View File

@ -77,6 +77,7 @@ BatchDataCollection: TypeAlias = list[list[BatchDatum]]
class Batch(BaseModel):
batch_id: str = Field(default_factory=uuid_string, description="The ID of the batch")
origin: str | None = Field(default=None, description="The origin of this batch.")
data: Optional[BatchDataCollection] = Field(default=None, description="The batch data collection.")
graph: Graph = Field(description="The graph to initialize the session with")
workflow: Optional[WorkflowWithoutID] = Field(
@ -195,6 +196,7 @@ class SessionQueueItemWithoutGraph(BaseModel):
status: QUEUE_ITEM_STATUS = Field(default="pending", description="The status of this queue item")
priority: int = Field(default=0, description="The priority of this queue item")
batch_id: str = Field(description="The ID of the batch associated with this queue item")
origin: str | None = Field(default=None, description="The origin of this queue item. ")
session_id: str = Field(
description="The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed."
)
@ -294,6 +296,7 @@ class SessionQueueStatus(BaseModel):
class BatchStatus(BaseModel):
queue_id: str = Field(..., description="The ID of the queue")
batch_id: str = Field(..., description="The ID of the batch")
origin: str | None = Field(..., description="The origin of the batch")
pending: int = Field(..., description="Number of queue items with status 'pending'")
in_progress: int = Field(..., description="Number of queue items with status 'in_progress'")
completed: int = Field(..., description="Number of queue items with status 'complete'")
@ -328,6 +331,12 @@ class CancelByBatchIDsResult(BaseModel):
canceled: int = Field(..., description="Number of queue items canceled")
class CancelByOriginResult(BaseModel):
"""Result of canceling by list of batch ids"""
canceled: int = Field(..., description="Number of queue items canceled")
class CancelByQueueIDResult(CancelByBatchIDsResult):
"""Result of canceling by queue id"""
@ -433,6 +442,7 @@ class SessionQueueValueToInsert(NamedTuple):
field_values: Optional[str] # field_values json
priority: int # priority
workflow: Optional[str] # workflow json
origin: str | None
ValuesToInsert: TypeAlias = list[SessionQueueValueToInsert]
@ -453,6 +463,7 @@ def prepare_values_to_insert(queue_id: str, batch: Batch, priority: int, max_new
json.dumps(field_values, default=to_jsonable_python) if field_values else None, # field_values (json)
priority, # priority
json.dumps(workflow, default=to_jsonable_python) if workflow else None, # workflow (json)
batch.origin, # origin
)
)
return values_to_insert

View File

@ -10,6 +10,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
Batch,
BatchStatus,
CancelByBatchIDsResult,
CancelByOriginResult,
CancelByQueueIDResult,
ClearResult,
EnqueueBatchResult,
@ -127,8 +128,8 @@ class SqliteSessionQueue(SessionQueueBase):
self.__cursor.executemany(
"""--sql
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow)
VALUES (?, ?, ?, ?, ?, ?, ?)
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""",
values_to_insert,
)
@ -417,11 +418,7 @@ class SqliteSessionQueue(SessionQueueBase):
)
self.__conn.commit()
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
queue_status = self.get_queue_status(queue_id=queue_id)
self.__invoker.services.events.emit_queue_item_status_changed(
current_queue_item, batch_status, queue_status
)
self._set_queue_item_status(current_queue_item.item_id, "canceled")
except Exception:
self.__conn.rollback()
raise
@ -429,6 +426,46 @@ class SqliteSessionQueue(SessionQueueBase):
self.__lock.release()
return CancelByBatchIDsResult(canceled=count)
def cancel_by_origin(self, queue_id: str, origin: str) -> CancelByOriginResult:
try:
current_queue_item = self.get_current(queue_id)
self.__lock.acquire()
where = """--sql
WHERE
queue_id == ?
AND origin == ?
AND status != 'canceled'
AND status != 'completed'
AND status != 'failed'
"""
params = (queue_id, origin)
self.__cursor.execute(
f"""--sql
SELECT COUNT(*)
FROM session_queue
{where};
""",
params,
)
count = self.__cursor.fetchone()[0]
self.__cursor.execute(
f"""--sql
UPDATE session_queue
SET status = 'canceled'
{where};
""",
params,
)
self.__conn.commit()
if current_queue_item is not None and current_queue_item.origin == origin:
self._set_queue_item_status(current_queue_item.item_id, "canceled")
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
return CancelByOriginResult(canceled=count)
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
try:
current_queue_item = self.get_current(queue_id)
@ -541,7 +578,8 @@ class SqliteSessionQueue(SessionQueueBase):
started_at,
session_id,
batch_id,
queue_id
queue_id,
origin
FROM session_queue
WHERE queue_id = ?
"""
@ -621,7 +659,7 @@ class SqliteSessionQueue(SessionQueueBase):
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT status, count(*)
SELECT status, count(*), origin
FROM session_queue
WHERE
queue_id = ?
@ -633,6 +671,7 @@ class SqliteSessionQueue(SessionQueueBase):
result = cast(list[sqlite3.Row], self.__cursor.fetchall())
total = sum(row[1] for row in result)
counts: dict[str, int] = {row[0]: row[1] for row in result}
origin = result[0]["origin"] if result else None
except Exception:
self.__conn.rollback()
raise
@ -641,6 +680,7 @@ class SqliteSessionQueue(SessionQueueBase):
return BatchStatus(
batch_id=batch_id,
origin=origin,
queue_id=queue_id,
pending=counts.get("pending", 0),
in_progress=counts.get("in_progress", 0),

View File

@ -17,6 +17,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_11 import
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_12 import build_migration_12
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_13 import build_migration_13
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_14 import build_migration_14
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_15 import build_migration_15
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
@ -51,6 +52,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
migrator.register_migration(build_migration_12(app_config=config))
migrator.register_migration(build_migration_13())
migrator.register_migration(build_migration_14())
migrator.register_migration(build_migration_15())
migrator.run_migrations()
return db

View File

@ -0,0 +1,31 @@
import sqlite3
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
class Migration15Callback:
def __call__(self, cursor: sqlite3.Cursor) -> None:
self._add_origin_col(cursor)
def _add_origin_col(self, cursor: sqlite3.Cursor) -> None:
"""
- Adds `origin` column to the session queue table.
"""
cursor.execute("ALTER TABLE session_queue ADD COLUMN origin TEXT;")
def build_migration_15() -> Migration:
"""
Build the migration from database version 14 to 15.
This migration does the following:
- Adds `origin` column to the session queue table.
"""
migration_15 = Migration(
from_version=14,
to_version=15,
callback=Migration15Callback(),
)
return migration_15

View File

@ -32,6 +32,7 @@ class PresetType(str, Enum, metaclass=MetaEnum):
class StylePresetChanges(BaseModel, extra="forbid"):
name: Optional[str] = Field(default=None, description="The style preset's new name.")
preset_data: Optional[PresetData] = Field(default=None, description="The updated data for style preset.")
type: Optional[PresetType] = Field(description="The updated type of the style preset")
class StylePresetWithoutId(BaseModel):

View File

@ -0,0 +1,266 @@
{
"name": "FLUX Text to Image",
"author": "InvokeAI",
"description": "A simple text-to-image workflow using FLUX dev or schnell models. Prerequisite model downloads: T5 Encoder, CLIP-L Encoder, and FLUX VAE. Quantized and un-quantized versions can be found in the starter models tab within your Model Manager. We recommend 4 steps for FLUX schnell models and 30 steps for FLUX dev models.",
"version": "1.0.0",
"contact": "",
"tags": "text2image, flux",
"notes": "Prerequisite model downloads: T5 Encoder, CLIP-L Encoder, and FLUX VAE. Quantized and un-quantized versions can be found in the starter models tab within your Model Manager. We recommend 4 steps for FLUX schnell models and 30 steps for FLUX dev models.",
"exposedFields": [
{
"nodeId": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
"fieldName": "model"
},
{
"nodeId": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"fieldName": "prompt"
},
{
"nodeId": "159bdf1b-79e7-4174-b86e-d40e646964c8",
"fieldName": "num_steps"
},
{
"nodeId": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
"fieldName": "t5_encoder"
}
],
"meta": {
"version": "3.0.0",
"category": "default"
},
"nodes": [
{
"id": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
"type": "invocation",
"data": {
"id": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
"type": "flux_model_loader",
"version": "1.0.3",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": true,
"useCache": false,
"inputs": {
"model": {
"name": "model",
"label": "Model (Starter Models can be found in Model Manager)",
"value": {
"key": "f04a7a2f-c74d-4538-8d5e-879a53501662",
"hash": "random:4875da7a9508444ffa706f61961c260d0c6729f6181a86b31fad06df1277b850",
"name": "FLUX Dev (Quantized)",
"base": "flux",
"type": "main"
}
},
"t5_encoder": {
"name": "t5_encoder",
"label": "T 5 Encoder (Starter Models can be found in Model Manager)",
"value": {
"key": "20dcd9ec-5fbb-4012-8401-049e707da5e5",
"hash": "random:f986be43ff3502169e4adbdcee158afb0e0a65a1edc4cab16ae59963630cfd8f",
"name": "t5_bnb_int8_quantized_encoder",
"base": "any",
"type": "t5_encoder"
}
}
}
},
"position": {
"x": 337.09365228062825,
"y": 40.63469521079861
}
},
{
"id": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"type": "invocation",
"data": {
"id": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"type": "flux_text_encoder",
"version": "1.0.0",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": true,
"useCache": true,
"inputs": {
"clip": {
"name": "clip",
"label": ""
},
"t5_encoder": {
"name": "t5_encoder",
"label": ""
},
"t5_max_seq_len": {
"name": "t5_max_seq_len",
"label": "T5 Max Seq Len",
"value": 256
},
"prompt": {
"name": "prompt",
"label": "",
"value": "a cat"
}
}
},
"position": {
"x": 824.1970602278849,
"y": 146.98251001061735
}
},
{
"id": "4754c534-a5f3-4ad0-9382-7887985e668c",
"type": "invocation",
"data": {
"id": "4754c534-a5f3-4ad0-9382-7887985e668c",
"type": "rand_int",
"version": "1.0.1",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": true,
"useCache": false,
"inputs": {
"low": {
"name": "low",
"label": "",
"value": 0
},
"high": {
"name": "high",
"label": "",
"value": 2147483647
}
}
},
"position": {
"x": 822.9899179655476,
"y": 360.9657214885052
}
},
{
"id": "159bdf1b-79e7-4174-b86e-d40e646964c8",
"type": "invocation",
"data": {
"id": "159bdf1b-79e7-4174-b86e-d40e646964c8",
"type": "flux_text_to_image",
"version": "1.0.0",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": false,
"useCache": true,
"inputs": {
"board": {
"name": "board",
"label": ""
},
"metadata": {
"name": "metadata",
"label": ""
},
"transformer": {
"name": "transformer",
"label": ""
},
"vae": {
"name": "vae",
"label": ""
},
"positive_text_conditioning": {
"name": "positive_text_conditioning",
"label": ""
},
"width": {
"name": "width",
"label": "",
"value": 1024
},
"height": {
"name": "height",
"label": "",
"value": 1024
},
"num_steps": {
"name": "num_steps",
"label": "Steps (Recommend 30 for Dev, 4 for Schnell)",
"value": 30
},
"guidance": {
"name": "guidance",
"label": "",
"value": 4
},
"seed": {
"name": "seed",
"label": "",
"value": 0
}
}
},
"position": {
"x": 1216.3900791301849,
"y": 5.500841807102248
}
}
],
"edges": [
{
"id": "reactflow__edge-4f0207c2-ff40-41fd-b047-ad33fbb1c33amax_seq_len-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_max_seq_len",
"type": "default",
"source": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"sourceHandle": "max_seq_len",
"targetHandle": "t5_max_seq_len"
},
{
"id": "reactflow__edge-4f0207c2-ff40-41fd-b047-ad33fbb1c33avae-159bdf1b-79e7-4174-b86e-d40e646964c8vae",
"type": "default",
"source": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
"target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
"sourceHandle": "vae",
"targetHandle": "vae"
},
{
"id": "reactflow__edge-4f0207c2-ff40-41fd-b047-ad33fbb1c33atransformer-159bdf1b-79e7-4174-b86e-d40e646964c8transformer",
"type": "default",
"source": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
"target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
"sourceHandle": "transformer",
"targetHandle": "transformer"
},
{
"id": "reactflow__edge-4f0207c2-ff40-41fd-b047-ad33fbb1c33at5_encoder-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_encoder",
"type": "default",
"source": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"sourceHandle": "t5_encoder",
"targetHandle": "t5_encoder"
},
{
"id": "reactflow__edge-4f0207c2-ff40-41fd-b047-ad33fbb1c33aclip-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cclip",
"type": "default",
"source": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a",
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"sourceHandle": "clip",
"targetHandle": "clip"
},
{
"id": "reactflow__edge-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cconditioning-159bdf1b-79e7-4174-b86e-d40e646964c8positive_text_conditioning",
"type": "default",
"source": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
"sourceHandle": "conditioning",
"targetHandle": "positive_text_conditioning"
},
{
"id": "reactflow__edge-4754c534-a5f3-4ad0-9382-7887985e668cvalue-159bdf1b-79e7-4174-b86e-d40e646964c8seed",
"type": "default",
"source": "4754c534-a5f3-4ad0-9382-7887985e668c",
"target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
"sourceHandle": "value",
"targetHandle": "seed"
}
]
}

View File

@ -0,0 +1,32 @@
# Initially pulled from https://github.com/black-forest-labs/flux
import torch
from einops import rearrange
from torch import Tensor
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
q, k = apply_rope(q, k, pe)
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
x = rearrange(x, "B H L D -> B L (H D)")
return x
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
assert dim % 2 == 0
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
omega = 1.0 / (theta**scale)
out = torch.einsum("...n,d->...nd", pos, omega)
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
return out.float()
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)

View File

@ -0,0 +1,117 @@
# Initially pulled from https://github.com/black-forest-labs/flux
from dataclasses import dataclass
import torch
from torch import Tensor, nn
from invokeai.backend.flux.modules.layers import (
DoubleStreamBlock,
EmbedND,
LastLayer,
MLPEmbedder,
SingleStreamBlock,
timestep_embedding,
)
@dataclass
class FluxParams:
in_channels: int
vec_in_dim: int
context_in_dim: int
hidden_size: int
mlp_ratio: float
num_heads: int
depth: int
depth_single_blocks: int
axes_dim: list[int]
theta: int
qkv_bias: bool
guidance_embed: bool
class Flux(nn.Module):
"""
Transformer model for flow matching on sequences.
"""
def __init__(self, params: FluxParams):
super().__init__()
self.params = params
self.in_channels = params.in_channels
self.out_channels = self.in_channels
if params.hidden_size % params.num_heads != 0:
raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
pe_dim = params.hidden_size // params.num_heads
if sum(params.axes_dim) != pe_dim:
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
self.hidden_size = params.hidden_size
self.num_heads = params.num_heads
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
self.guidance_in = (
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
)
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
self.double_blocks = nn.ModuleList(
[
DoubleStreamBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
)
for _ in range(params.depth)
]
)
self.single_blocks = nn.ModuleList(
[
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
for _ in range(params.depth_single_blocks)
]
)
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
def forward(
self,
img: Tensor,
img_ids: Tensor,
txt: Tensor,
txt_ids: Tensor,
timesteps: Tensor,
y: Tensor,
guidance: Tensor | None = None,
) -> Tensor:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
# running on sequences img
img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256))
if self.params.guidance_embed:
if guidance is None:
raise ValueError("Didn't get guidance strength for guidance distilled model.")
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
vec = vec + self.vector_in(y)
txt = self.txt_in(txt)
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
for block in self.double_blocks:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
img = torch.cat((txt, img), 1)
for block in self.single_blocks:
img = block(img, vec=vec, pe=pe)
img = img[:, txt.shape[1] :, ...]
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
return img

View File

@ -0,0 +1,310 @@
# Initially pulled from https://github.com/black-forest-labs/flux
from dataclasses import dataclass
import torch
from einops import rearrange
from torch import Tensor, nn
@dataclass
class AutoEncoderParams:
resolution: int
in_channels: int
ch: int
out_ch: int
ch_mult: list[int]
num_res_blocks: int
z_channels: int
scale_factor: float
shift_factor: float
class AttnBlock(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
self.in_channels = in_channels
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
def attention(self, h_: Tensor) -> Tensor:
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
b, c, h, w = q.shape
q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
h_ = nn.functional.scaled_dot_product_attention(q, k, v)
return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
def forward(self, x: Tensor) -> Tensor:
return x + self.proj_out(self.attention(x))
class ResnetBlock(nn.Module):
def __init__(self, in_channels: int, out_channels: int):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if self.in_channels != self.out_channels:
self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x):
h = x
h = self.norm1(h)
h = torch.nn.functional.silu(h)
h = self.conv1(h)
h = self.norm2(h)
h = torch.nn.functional.silu(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
x = self.nin_shortcut(x)
return x + h
class Downsample(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
# no asymmetric padding in torch conv, must do it ourselves
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
def forward(self, x: Tensor):
pad = (0, 1, 0, 1)
x = nn.functional.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
return x
class Upsample(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x: Tensor):
x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
x = self.conv(x)
return x
class Encoder(nn.Module):
def __init__(
self,
resolution: int,
in_channels: int,
ch: int,
ch_mult: list[int],
num_res_blocks: int,
z_channels: int,
):
super().__init__()
self.ch = ch
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
# downsampling
self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
curr_res = resolution
in_ch_mult = (1,) + tuple(ch_mult)
self.in_ch_mult = in_ch_mult
self.down = nn.ModuleList()
block_in = self.ch
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for _ in range(self.num_res_blocks):
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
block_in = block_out
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
down.downsample = Downsample(block_in)
curr_res = curr_res // 2
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
self.mid.attn_1 = AttnBlock(block_in)
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
# end
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x: Tensor) -> Tensor:
# downsampling
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1])
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
h = hs[-1]
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)
# end
h = self.norm_out(h)
h = torch.nn.functional.silu(h)
h = self.conv_out(h)
return h
class Decoder(nn.Module):
def __init__(
self,
ch: int,
out_ch: int,
ch_mult: list[int],
num_res_blocks: int,
in_channels: int,
resolution: int,
z_channels: int,
):
super().__init__()
self.ch = ch
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.ffactor = 2 ** (self.num_resolutions - 1)
# compute in_ch_mult, block_in and curr_res at lowest res
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
# z to block_in
self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
self.mid.attn_1 = AttnBlock(block_in)
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for _ in range(self.num_res_blocks + 1):
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
block_in = block_out
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
def forward(self, z: Tensor) -> Tensor:
# z to block_in
h = self.conv_in(z)
# middle
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
h = self.norm_out(h)
h = torch.nn.functional.silu(h)
h = self.conv_out(h)
return h
class DiagonalGaussian(nn.Module):
def __init__(self, sample: bool = True, chunk_dim: int = 1):
super().__init__()
self.sample = sample
self.chunk_dim = chunk_dim
def forward(self, z: Tensor) -> Tensor:
mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
if self.sample:
std = torch.exp(0.5 * logvar)
return mean + std * torch.randn_like(mean)
else:
return mean
class AutoEncoder(nn.Module):
def __init__(self, params: AutoEncoderParams):
super().__init__()
self.encoder = Encoder(
resolution=params.resolution,
in_channels=params.in_channels,
ch=params.ch,
ch_mult=params.ch_mult,
num_res_blocks=params.num_res_blocks,
z_channels=params.z_channels,
)
self.decoder = Decoder(
resolution=params.resolution,
in_channels=params.in_channels,
ch=params.ch,
out_ch=params.out_ch,
ch_mult=params.ch_mult,
num_res_blocks=params.num_res_blocks,
z_channels=params.z_channels,
)
self.reg = DiagonalGaussian()
self.scale_factor = params.scale_factor
self.shift_factor = params.shift_factor
def encode(self, x: Tensor) -> Tensor:
z = self.reg(self.encoder(x))
z = self.scale_factor * (z - self.shift_factor)
return z
def decode(self, z: Tensor) -> Tensor:
z = z / self.scale_factor + self.shift_factor
return self.decoder(z)
def forward(self, x: Tensor) -> Tensor:
return self.decode(self.encode(x))

View File

@ -0,0 +1,33 @@
# Initially pulled from https://github.com/black-forest-labs/flux
from torch import Tensor, nn
from transformers import PreTrainedModel, PreTrainedTokenizer
class HFEncoder(nn.Module):
def __init__(self, encoder: PreTrainedModel, tokenizer: PreTrainedTokenizer, is_clip: bool, max_length: int):
super().__init__()
self.max_length = max_length
self.is_clip = is_clip
self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
self.tokenizer = tokenizer
self.hf_module = encoder
self.hf_module = self.hf_module.eval().requires_grad_(False)
def forward(self, text: list[str]) -> Tensor:
batch_encoding = self.tokenizer(
text,
truncation=True,
max_length=self.max_length,
return_length=False,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt",
)
outputs = self.hf_module(
input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
attention_mask=None,
output_hidden_states=False,
)
return outputs[self.output_key]

View File

@ -0,0 +1,253 @@
# Initially pulled from https://github.com/black-forest-labs/flux
import math
from dataclasses import dataclass
import torch
from einops import rearrange
from torch import Tensor, nn
from invokeai.backend.flux.math import attention, rope
class EmbedND(nn.Module):
def __init__(self, dim: int, theta: int, axes_dim: list[int]):
super().__init__()
self.dim = dim
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids: Tensor) -> Tensor:
n_axes = ids.shape[-1]
emb = torch.cat(
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
dim=-3,
)
return emb.unsqueeze(1)
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
"""
Create sinusoidal timestep embeddings.
:param t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an (N, D) Tensor of positional embeddings.
"""
t = time_factor * t
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
if torch.is_floating_point(t):
embedding = embedding.to(t)
return embedding
class MLPEmbedder(nn.Module):
def __init__(self, in_dim: int, hidden_dim: int):
super().__init__()
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
self.silu = nn.SiLU()
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
def forward(self, x: Tensor) -> Tensor:
return self.out_layer(self.silu(self.in_layer(x)))
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int):
super().__init__()
self.scale = nn.Parameter(torch.ones(dim))
def forward(self, x: Tensor):
x_dtype = x.dtype
x = x.float()
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
return (x * rrms).to(dtype=x_dtype) * self.scale
class QKNorm(torch.nn.Module):
def __init__(self, dim: int):
super().__init__()
self.query_norm = RMSNorm(dim)
self.key_norm = RMSNorm(dim)
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
q = self.query_norm(q)
k = self.key_norm(k)
return q.to(v), k.to(v)
class SelfAttention(nn.Module):
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.norm = QKNorm(head_dim)
self.proj = nn.Linear(dim, dim)
def forward(self, x: Tensor, pe: Tensor) -> Tensor:
qkv = self.qkv(x)
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
q, k = self.norm(q, k, v)
x = attention(q, k, v, pe=pe)
x = self.proj(x)
return x
@dataclass
class ModulationOut:
shift: Tensor
scale: Tensor
gate: Tensor
class Modulation(nn.Module):
def __init__(self, dim: int, double: bool):
super().__init__()
self.is_double = double
self.multiplier = 6 if double else 3
self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
return (
ModulationOut(*out[:3]),
ModulationOut(*out[3:]) if self.is_double else None,
)
class DoubleStreamBlock(nn.Module):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.num_heads = num_heads
self.hidden_size = hidden_size
self.img_mod = Modulation(hidden_size, double=True)
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.img_mlp = nn.Sequential(
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
nn.GELU(approximate="tanh"),
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
)
self.txt_mod = Modulation(hidden_size, double=True)
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.txt_mlp = nn.Sequential(
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
nn.GELU(approximate="tanh"),
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
)
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec)
# prepare image for attention
img_modulated = self.img_norm1(img)
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
img_qkv = self.img_attn.qkv(img_modulated)
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
# prepare txt for attention
txt_modulated = self.txt_norm1(txt)
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
txt_qkv = self.txt_attn.qkv(txt_modulated)
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
# run actual attention
q = torch.cat((txt_q, img_q), dim=2)
k = torch.cat((txt_k, img_k), dim=2)
v = torch.cat((txt_v, img_v), dim=2)
attn = attention(q, k, v, pe=pe)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
# calculate the img bloks
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
# calculate the txt bloks
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
return img, txt
class SingleStreamBlock(nn.Module):
"""
A DiT block with parallel linear layers as described in
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
"""
def __init__(
self,
hidden_size: int,
num_heads: int,
mlp_ratio: float = 4.0,
qk_scale: float | None = None,
):
super().__init__()
self.hidden_dim = hidden_size
self.num_heads = num_heads
head_dim = hidden_size // num_heads
self.scale = qk_scale or head_dim**-0.5
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
# qkv and mlp_in
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
# proj and mlp_out
self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
self.norm = QKNorm(head_dim)
self.hidden_size = hidden_size
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.mlp_act = nn.GELU(approximate="tanh")
self.modulation = Modulation(hidden_size, double=False)
def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
mod, _ = self.modulation(vec)
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
q, k = self.norm(q, k, v)
# compute attention
attn = attention(q, k, v, pe=pe)
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
return x + mod.gate * output
class LastLayer(nn.Module):
def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
def forward(self, x: Tensor, vec: Tensor) -> Tensor:
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
x = self.linear(x)
return x

View File

@ -0,0 +1,176 @@
# Initially pulled from https://github.com/black-forest-labs/flux
import math
from typing import Callable
import torch
from einops import rearrange, repeat
from torch import Tensor
from tqdm import tqdm
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.modules.conditioner import HFEncoder
def get_noise(
num_samples: int,
height: int,
width: int,
device: torch.device,
dtype: torch.dtype,
seed: int,
):
# We always generate noise on the same device and dtype then cast to ensure consistency across devices/dtypes.
rand_device = "cpu"
rand_dtype = torch.float16
return torch.randn(
num_samples,
16,
# allow for packing
2 * math.ceil(height / 16),
2 * math.ceil(width / 16),
device=rand_device,
dtype=rand_dtype,
generator=torch.Generator(device=rand_device).manual_seed(seed),
).to(device=device, dtype=dtype)
def prepare(t5: HFEncoder, clip: HFEncoder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
bs, c, h, w = img.shape
if bs == 1 and not isinstance(prompt, str):
bs = len(prompt)
img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
if img.shape[0] == 1 and bs > 1:
img = repeat(img, "1 ... -> bs ...", bs=bs)
img_ids = torch.zeros(h // 2, w // 2, 3)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
if isinstance(prompt, str):
prompt = [prompt]
txt = t5(prompt)
if txt.shape[0] == 1 and bs > 1:
txt = repeat(txt, "1 ... -> bs ...", bs=bs)
txt_ids = torch.zeros(bs, txt.shape[1], 3)
vec = clip(prompt)
if vec.shape[0] == 1 and bs > 1:
vec = repeat(vec, "1 ... -> bs ...", bs=bs)
return {
"img": img,
"img_ids": img_ids.to(img.device),
"txt": txt.to(img.device),
"txt_ids": txt_ids.to(img.device),
"vec": vec.to(img.device),
}
def time_shift(mu: float, sigma: float, t: Tensor):
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
m = (y2 - y1) / (x2 - x1)
b = y1 - m * x1
return lambda x: m * x + b
def get_schedule(
num_steps: int,
image_seq_len: int,
base_shift: float = 0.5,
max_shift: float = 1.15,
shift: bool = True,
) -> list[float]:
# extra step for zero
timesteps = torch.linspace(1, 0, num_steps + 1)
# shifting the schedule to favor high timesteps for higher signal images
if shift:
# eastimate mu based on linear estimation between two points
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
timesteps = time_shift(mu, 1.0, timesteps)
return timesteps.tolist()
def denoise(
model: Flux,
# model input
img: Tensor,
img_ids: Tensor,
txt: Tensor,
txt_ids: Tensor,
vec: Tensor,
# sampling parameters
timesteps: list[float],
step_callback: Callable[[], None],
guidance: float = 4.0,
):
dtype = model.txt_in.bias.dtype
# TODO(ryand): This shouldn't be necessary if we manage the dtypes properly in the caller.
img = img.to(dtype=dtype)
img_ids = img_ids.to(dtype=dtype)
txt = txt.to(dtype=dtype)
txt_ids = txt_ids.to(dtype=dtype)
vec = vec.to(dtype=dtype)
# this is ignored for schnell
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
pred = model(
img=img,
img_ids=img_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
timesteps=t_vec,
guidance=guidance_vec,
)
img = img + (t_prev - t_curr) * pred
step_callback()
return img
def unpack(x: Tensor, height: int, width: int) -> Tensor:
return rearrange(
x,
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
h=math.ceil(height / 16),
w=math.ceil(width / 16),
ph=2,
pw=2,
)
def prepare_latent_img_patches(latent_img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Convert an input image in latent space to patches for diffusion.
This implementation was extracted from:
https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/sampling.py#L32
Returns:
tuple[Tensor, Tensor]: (img, img_ids), as defined in the original flux repo.
"""
bs, c, h, w = latent_img.shape
# Pixel unshuffle with a scale of 2, and flatten the height/width dimensions to get an array of patches.
img = rearrange(latent_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
if img.shape[0] == 1 and bs > 1:
img = repeat(img, "1 ... -> bs ...", bs=bs)
# Generate patch position ids.
img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
return img, img_ids

View File

@ -0,0 +1,71 @@
# Initially pulled from https://github.com/black-forest-labs/flux
from dataclasses import dataclass
from typing import Dict, Literal
from invokeai.backend.flux.model import FluxParams
from invokeai.backend.flux.modules.autoencoder import AutoEncoderParams
@dataclass
class ModelSpec:
params: FluxParams
ae_params: AutoEncoderParams
ckpt_path: str | None
ae_path: str | None
repo_id: str | None
repo_flow: str | None
repo_ae: str | None
max_seq_lengths: Dict[str, Literal[256, 512]] = {
"flux-dev": 512,
"flux-schnell": 256,
}
ae_params = {
"flux": AutoEncoderParams(
resolution=256,
in_channels=3,
ch=128,
out_ch=3,
ch_mult=[1, 2, 4, 4],
num_res_blocks=2,
z_channels=16,
scale_factor=0.3611,
shift_factor=0.1159,
)
}
params = {
"flux-dev": FluxParams(
in_channels=64,
vec_in_dim=768,
context_in_dim=4096,
hidden_size=3072,
mlp_ratio=4.0,
num_heads=24,
depth=19,
depth_single_blocks=38,
axes_dim=[16, 56, 56],
theta=10_000,
qkv_bias=True,
guidance_embed=True,
),
"flux-schnell": FluxParams(
in_channels=64,
vec_in_dim=768,
context_in_dim=4096,
hidden_size=3072,
mlp_ratio=4.0,
num_heads=24,
depth=19,
depth_single_blocks=38,
axes_dim=[16, 56, 56],
theta=10_000,
qkv_bias=True,
guidance_embed=False,
),
}

View File

@ -52,6 +52,7 @@ class BaseModelType(str, Enum):
StableDiffusion2 = "sd-2"
StableDiffusionXL = "sdxl"
StableDiffusionXLRefiner = "sdxl-refiner"
Flux = "flux"
# Kandinsky2_1 = "kandinsky-2.1"
@ -66,7 +67,9 @@ class ModelType(str, Enum):
TextualInversion = "embedding"
IPAdapter = "ip_adapter"
CLIPVision = "clip_vision"
CLIPEmbed = "clip_embed"
T2IAdapter = "t2i_adapter"
T5Encoder = "t5_encoder"
SpandrelImageToImage = "spandrel_image_to_image"
@ -74,6 +77,7 @@ class SubModelType(str, Enum):
"""Submodel type."""
UNet = "unet"
Transformer = "transformer"
TextEncoder = "text_encoder"
TextEncoder2 = "text_encoder_2"
Tokenizer = "tokenizer"
@ -104,6 +108,9 @@ class ModelFormat(str, Enum):
EmbeddingFile = "embedding_file"
EmbeddingFolder = "embedding_folder"
InvokeAI = "invokeai"
T5Encoder = "t5_encoder"
BnbQuantizedLlmInt8b = "bnb_quantized_int8b"
BnbQuantizednf4b = "bnb_quantized_nf4b"
class SchedulerPredictionType(str, Enum):
@ -186,7 +193,9 @@ class ModelConfigBase(BaseModel):
class CheckpointConfigBase(ModelConfigBase):
"""Model config for checkpoint-style models."""
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
format: Literal[ModelFormat.Checkpoint, ModelFormat.BnbQuantizednf4b] = Field(
description="Format of the provided checkpoint model", default=ModelFormat.Checkpoint
)
config_path: str = Field(description="path to the checkpoint model config file")
converted_at: Optional[float] = Field(
description="When this model was last converted to diffusers", default_factory=time.time
@ -205,6 +214,26 @@ class LoRAConfigBase(ModelConfigBase):
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
class T5EncoderConfigBase(ModelConfigBase):
type: Literal[ModelType.T5Encoder] = ModelType.T5Encoder
class T5EncoderConfig(T5EncoderConfigBase):
format: Literal[ModelFormat.T5Encoder] = ModelFormat.T5Encoder
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.T5Encoder.value}.{ModelFormat.T5Encoder.value}")
class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase):
format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = ModelFormat.BnbQuantizedLlmInt8b
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.T5Encoder.value}.{ModelFormat.BnbQuantizedLlmInt8b.value}")
class LoRALyCORISConfig(LoRAConfigBase):
"""Model config for LoRA/Lycoris models."""
@ -229,7 +258,6 @@ class VAECheckpointConfig(CheckpointConfigBase):
"""Model config for standalone VAE models."""
type: Literal[ModelType.VAE] = ModelType.VAE
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
@staticmethod
def get_tag() -> Tag:
@ -268,7 +296,6 @@ class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase)
"""Model config for ControlNet models (diffusers version)."""
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
@staticmethod
def get_tag() -> Tag:
@ -317,6 +344,21 @@ class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase):
return Tag(f"{ModelType.Main.value}.{ModelFormat.Checkpoint.value}")
class MainBnbQuantized4bCheckpointConfig(CheckpointConfigBase, MainConfigBase):
"""Model config for main checkpoint models."""
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.format = ModelFormat.BnbQuantizednf4b
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.Main.value}.{ModelFormat.BnbQuantizednf4b.value}")
class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase):
"""Model config for main diffusers models."""
@ -350,6 +392,17 @@ class IPAdapterCheckpointConfig(IPAdapterBaseConfig):
return Tag(f"{ModelType.IPAdapter.value}.{ModelFormat.Checkpoint.value}")
class CLIPEmbedDiffusersConfig(DiffusersConfigBase):
"""Model config for Clip Embeddings."""
type: Literal[ModelType.CLIPEmbed] = ModelType.CLIPEmbed
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}")
class CLIPVisionDiffusersConfig(DiffusersConfigBase):
"""Model config for CLIPVision."""
@ -408,12 +461,15 @@ AnyModelConfig = Annotated[
Union[
Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()],
Annotated[MainBnbQuantized4bCheckpointConfig, MainBnbQuantized4bCheckpointConfig.get_tag()],
Annotated[VAEDiffusersConfig, VAEDiffusersConfig.get_tag()],
Annotated[VAECheckpointConfig, VAECheckpointConfig.get_tag()],
Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()],
Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()],
Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
Annotated[T5EncoderConfig, T5EncoderConfig.get_tag()],
Annotated[T5EncoderBnbQuantizedLlmInt8bConfig, T5EncoderBnbQuantizedLlmInt8bConfig.get_tag()],
Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()],
Annotated[TextualInversionFolderConfig, TextualInversionFolderConfig.get_tag()],
Annotated[IPAdapterInvokeAIConfig, IPAdapterInvokeAIConfig.get_tag()],
@ -421,6 +477,7 @@ AnyModelConfig = Annotated[
Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
Annotated[SpandrelImageToImageConfig, SpandrelImageToImageConfig.get_tag()],
Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
Annotated[CLIPEmbedDiffusersConfig, CLIPEmbedDiffusersConfig.get_tag()],
],
Discriminator(get_model_discriminator_value),
]

View File

@ -0,0 +1,234 @@
# Copyright (c) 2024, Brandon W. Rising and the InvokeAI Development Team
"""Class for Flux model loading in InvokeAI."""
from pathlib import Path
from typing import Optional
import accelerate
import torch
from safetensors.torch import load_file
from transformers import AutoConfig, AutoModelForTextEncoding, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.flux.util import ae_params, params
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.config import (
CheckpointConfigBase,
CLIPEmbedDiffusersConfig,
MainBnbQuantized4bCheckpointConfig,
MainCheckpointConfig,
T5EncoderBnbQuantizedLlmInt8bConfig,
T5EncoderConfig,
VAECheckpointConfig,
)
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.util.silence_warnings import SilenceWarnings
try:
from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
bnb_available = True
except ImportError:
bnb_available = False
app_config = get_config()
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.VAE, format=ModelFormat.Checkpoint)
class FluxVAELoader(ModelLoader):
"""Class to load VAE models."""
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, VAECheckpointConfig):
raise ValueError("Only VAECheckpointConfig models are currently supported here.")
model_path = Path(config.path)
with SilenceWarnings():
model = AutoEncoder(ae_params[config.config_path])
sd = load_file(model_path)
model.load_state_dict(sd, assign=True)
model.to(dtype=self._torch_dtype)
return model
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPEmbed, format=ModelFormat.Diffusers)
class ClipCheckpointModel(ModelLoader):
"""Class to load main models."""
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, CLIPEmbedDiffusersConfig):
raise ValueError("Only CLIPEmbedDiffusersConfig models are currently supported here.")
match submodel_type:
case SubModelType.Tokenizer:
return CLIPTokenizer.from_pretrained(Path(config.path) / "tokenizer")
case SubModelType.TextEncoder:
return CLIPTextModel.from_pretrained(Path(config.path) / "text_encoder")
raise ValueError(
f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.BnbQuantizedLlmInt8b)
class BnbQuantizedLlmInt8bCheckpointModel(ModelLoader):
"""Class to load main models."""
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, T5EncoderBnbQuantizedLlmInt8bConfig):
raise ValueError("Only T5EncoderBnbQuantizedLlmInt8bConfig models are currently supported here.")
if not bnb_available:
raise ImportError(
"The bnb modules are not available. Please install bitsandbytes if available on your platform."
)
match submodel_type:
case SubModelType.Tokenizer2:
return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
case SubModelType.TextEncoder2:
te2_model_path = Path(config.path) / "text_encoder_2"
model_config = AutoConfig.from_pretrained(te2_model_path)
with accelerate.init_empty_weights():
model = AutoModelForTextEncoding.from_config(model_config)
model = quantize_model_llm_int8(model, modules_to_not_convert=set())
state_dict_path = te2_model_path / "bnb_llm_int8_model.safetensors"
state_dict = load_file(state_dict_path)
self._load_state_dict_into_t5(model, state_dict)
return model
raise ValueError(
f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
)
@classmethod
def _load_state_dict_into_t5(cls, model: T5EncoderModel, state_dict: dict[str, torch.Tensor]):
# There is a shared reference to a single weight tensor in the model.
# Both "encoder.embed_tokens.weight" and "shared.weight" refer to the same tensor, so only the latter should
# be present in the state_dict.
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False, assign=True)
assert len(unexpected_keys) == 0
assert set(missing_keys) == {"encoder.embed_tokens.weight"}
# Assert that the layers we expect to be shared are actually shared.
assert model.encoder.embed_tokens.weight is model.shared.weight
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder)
class T5EncoderCheckpointModel(ModelLoader):
"""Class to load main models."""
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, T5EncoderConfig):
raise ValueError("Only T5EncoderConfig models are currently supported here.")
match submodel_type:
case SubModelType.Tokenizer2:
return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
case SubModelType.TextEncoder2:
return T5EncoderModel.from_pretrained(Path(config.path) / "text_encoder_2")
raise ValueError(
f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
)
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.Checkpoint)
class FluxCheckpointModel(ModelLoader):
"""Class to load main models."""
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, CheckpointConfigBase):
raise ValueError("Only CheckpointConfigBase models are currently supported here.")
match submodel_type:
case SubModelType.Transformer:
return self._load_from_singlefile(config)
raise ValueError(
f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
)
def _load_from_singlefile(
self,
config: AnyModelConfig,
) -> AnyModel:
assert isinstance(config, MainCheckpointConfig)
model_path = Path(config.path)
with SilenceWarnings():
model = Flux(params[config.config_path])
sd = load_file(model_path)
model.load_state_dict(sd, assign=True)
return model
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.BnbQuantizednf4b)
class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
"""Class to load main models."""
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, CheckpointConfigBase):
raise ValueError("Only CheckpointConfigBase models are currently supported here.")
match submodel_type:
case SubModelType.Transformer:
return self._load_from_singlefile(config)
raise ValueError(
f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
)
def _load_from_singlefile(
self,
config: AnyModelConfig,
) -> AnyModel:
assert isinstance(config, MainBnbQuantized4bCheckpointConfig)
if not bnb_available:
raise ImportError(
"The bnb modules are not available. Please install bitsandbytes if available on your platform."
)
model_path = Path(config.path)
with SilenceWarnings():
with accelerate.init_empty_weights():
model = Flux(params[config.config_path])
model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
sd = load_file(model_path)
model.load_state_dict(sd, assign=True)
return model

View File

@ -78,7 +78,12 @@ class GenericDiffusersLoader(ModelLoader):
# TO DO: Add exception handling
def _hf_definition_to_type(self, module: str, class_name: str) -> ModelMixin: # fix with correct type
if module in ["diffusers", "transformers"]:
if module in [
"diffusers",
"transformers",
"invokeai.backend.quantization.fast_quantized_transformers_model",
"invokeai.backend.quantization.fast_quantized_diffusion_model",
]:
res_type = sys.modules[module]
else:
res_type = sys.modules["diffusers"].pipelines

View File

@ -36,8 +36,18 @@ VARIANT_TO_IN_CHANNEL_MAP = {
}
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Main, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Main, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.Main, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(
base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Diffusers
)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Main, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Main, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.Main, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(
base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Checkpoint
)
class StableDiffusionDiffusersModel(GenericDiffusersLoader):
"""Class to load main models."""

View File

@ -9,7 +9,7 @@ from typing import Optional
import torch
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from transformers import CLIPTokenizer
from transformers import CLIPTokenizer, T5Tokenizer, T5TokenizerFast
from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
@ -50,6 +50,17 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
),
):
return model.calc_size()
elif isinstance(
model,
(
T5TokenizerFast,
T5Tokenizer,
),
):
# HACK(ryand): len(model) just returns the vocabulary size, so this is blatantly wrong. It should be small
# relative to the text encoder that it's used with, so shouldn't matter too much, but we should fix this at some
# point.
return len(model)
else:
# TODO(ryand): Promote this from a log to an exception once we are confident that we are handling all of the
# supported model types.

View File

@ -95,6 +95,7 @@ class ModelProbe(object):
}
CLASS2TYPE = {
"FluxPipeline": ModelType.Main,
"StableDiffusionPipeline": ModelType.Main,
"StableDiffusionInpaintPipeline": ModelType.Main,
"StableDiffusionXLPipeline": ModelType.Main,
@ -106,6 +107,7 @@ class ModelProbe(object):
"ControlNetModel": ModelType.ControlNet,
"CLIPVisionModelWithProjection": ModelType.CLIPVision,
"T2IAdapter": ModelType.T2IAdapter,
"CLIPModel": ModelType.CLIPEmbed,
}
@classmethod
@ -161,7 +163,7 @@ class ModelProbe(object):
fields["description"] = (
fields.get("description") or f"{fields['base'].value} {model_type.value} model {fields['name']}"
)
fields["format"] = fields.get("format") or probe.get_format()
fields["format"] = ModelFormat(fields.get("format")) if "format" in fields else probe.get_format()
fields["hash"] = fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path)
fields["default_settings"] = fields.get("default_settings")
@ -176,10 +178,10 @@ class ModelProbe(object):
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()
# additional fields needed for main and controlnet models
if (
fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE]
and fields["format"] is ModelFormat.Checkpoint
):
if fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE] and fields["format"] in [
ModelFormat.Checkpoint,
ModelFormat.BnbQuantizednf4b,
]:
ckpt_config_path = cls._get_checkpoint_config_path(
model_path,
model_type=fields["type"],
@ -222,7 +224,8 @@ class ModelProbe(object):
ckpt = ckpt.get("state_dict", ckpt)
for key in [str(k) for k in ckpt.keys()]:
if key.startswith(("cond_stage_model.", "first_stage_model.", "model.diffusion_model.")):
if key.startswith(("cond_stage_model.", "first_stage_model.", "model.diffusion_model.", "double_blocks.")):
# Keys starting with double_blocks are associated with Flux models
return ModelType.Main
elif key.startswith(("encoder.conv_in", "decoder.conv_in")):
return ModelType.VAE
@ -321,10 +324,27 @@ class ModelProbe(object):
return possible_conf.absolute()
if model_type is ModelType.Main:
config_file = LEGACY_CONFIGS[base_type][variant_type]
if isinstance(config_file, dict): # need another tier for sd-2.x models
config_file = config_file[prediction_type]
config_file = f"stable-diffusion/{config_file}"
if base_type == BaseModelType.Flux:
# TODO: Decide between dev/schnell
checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)
state_dict = checkpoint.get("state_dict") or checkpoint
if "guidance_in.out_layer.weight" in state_dict:
# For flux, this is a key in invokeai.backend.flux.util.params
# Due to model type and format being the descriminator for model configs this
# is used rather than attempting to support flux with separate model types and format
# If changed in the future, please fix me
config_file = "flux-dev"
else:
# For flux, this is a key in invokeai.backend.flux.util.params
# Due to model type and format being the descriminator for model configs this
# is used rather than attempting to support flux with separate model types and format
# If changed in the future, please fix me
config_file = "flux-schnell"
else:
config_file = LEGACY_CONFIGS[base_type][variant_type]
if isinstance(config_file, dict): # need another tier for sd-2.x models
config_file = config_file[prediction_type]
config_file = f"stable-diffusion/{config_file}"
elif model_type is ModelType.ControlNet:
config_file = (
"controlnet/cldm_v15.yaml"
@ -333,7 +353,13 @@ class ModelProbe(object):
)
elif model_type is ModelType.VAE:
config_file = (
"stable-diffusion/v1-inference.yaml"
# For flux, this is a key in invokeai.backend.flux.util.ae_params
# Due to model type and format being the descriminator for model configs this
# is used rather than attempting to support flux with separate model types and format
# If changed in the future, please fix me
"flux"
if base_type is BaseModelType.Flux
else "stable-diffusion/v1-inference.yaml"
if base_type is BaseModelType.StableDiffusion1
else "stable-diffusion/sd_xl_base.yaml"
if base_type is BaseModelType.StableDiffusionXL
@ -416,11 +442,15 @@ class CheckpointProbeBase(ProbeBase):
self.checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)
def get_format(self) -> ModelFormat:
state_dict = self.checkpoint.get("state_dict") or self.checkpoint
if "double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4" in state_dict:
return ModelFormat.BnbQuantizednf4b
return ModelFormat("checkpoint")
def get_variant_type(self) -> ModelVariantType:
model_type = ModelProbe.get_model_type_from_checkpoint(self.model_path, self.checkpoint)
if model_type != ModelType.Main:
base_type = self.get_base_type()
if model_type != ModelType.Main or base_type == BaseModelType.Flux:
return ModelVariantType.Normal
state_dict = self.checkpoint.get("state_dict") or self.checkpoint
in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
@ -440,6 +470,8 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
checkpoint = self.checkpoint
state_dict = self.checkpoint.get("state_dict") or checkpoint
if "double_blocks.0.img_attn.norm.key_norm.scale" in state_dict:
return BaseModelType.Flux
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
return BaseModelType.StableDiffusion1
@ -482,6 +514,7 @@ class VaeCheckpointProbe(CheckpointProbeBase):
(r"xl", BaseModelType.StableDiffusionXL),
(r"sd2", BaseModelType.StableDiffusion2),
(r"vae", BaseModelType.StableDiffusion1),
(r"FLUX.1-schnell_ae", BaseModelType.Flux),
]:
if re.search(regexp, self.model_path.name, re.IGNORECASE):
return basetype
@ -713,6 +746,11 @@ class TextualInversionFolderProbe(FolderProbeBase):
return TextualInversionCheckpointProbe(path).get_base_type()
class T5EncoderFolderProbe(FolderProbeBase):
def get_format(self) -> ModelFormat:
return ModelFormat.T5Encoder
class ONNXFolderProbe(PipelineFolderProbe):
def get_base_type(self) -> BaseModelType:
# Due to the way the installer is set up, the configuration file for safetensors
@ -805,6 +843,11 @@ class CLIPVisionFolderProbe(FolderProbeBase):
return BaseModelType.Any
class CLIPEmbedFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
return BaseModelType.Any
class SpandrelImageToImageFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
raise NotImplementedError()
@ -835,8 +878,10 @@ ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.VAE, VaeFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.LoRA, LoRAFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.T5Encoder, T5EncoderFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.CLIPEmbed, CLIPEmbedFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.SpandrelImageToImage, SpandrelImageToImageFolderProbe)

View File

@ -2,7 +2,7 @@ from typing import Optional
from pydantic import BaseModel
from invokeai.backend.model_manager.config import BaseModelType, ModelType
from invokeai.backend.model_manager.config import BaseModelType, ModelFormat, ModelType
class StarterModelWithoutDependencies(BaseModel):
@ -11,6 +11,7 @@ class StarterModelWithoutDependencies(BaseModel):
name: str
base: BaseModelType
type: ModelType
format: Optional[ModelFormat] = None
is_installed: bool = False
@ -51,10 +52,76 @@ cyberrealistic_negative = StarterModel(
type=ModelType.TextualInversion,
)
t5_base_encoder = StarterModel(
name="t5_base_encoder",
base=BaseModelType.Any,
source="InvokeAI/t5-v1_1-xxl::bfloat16",
description="T5-XXL text encoder (used in FLUX pipelines). ~8GB",
type=ModelType.T5Encoder,
)
t5_8b_quantized_encoder = StarterModel(
name="t5_bnb_int8_quantized_encoder",
base=BaseModelType.Any,
source="InvokeAI/t5-v1_1-xxl::bnb_llm_int8",
description="T5-XXL text encoder with bitsandbytes LLM.int8() quantization (used in FLUX pipelines). ~5GB",
type=ModelType.T5Encoder,
format=ModelFormat.BnbQuantizedLlmInt8b,
)
clip_l_encoder = StarterModel(
name="clip-vit-large-patch14",
base=BaseModelType.Any,
source="InvokeAI/clip-vit-large-patch14-text-encoder::bfloat16",
description="CLIP-L text encoder (used in FLUX pipelines). ~250MB",
type=ModelType.CLIPEmbed,
)
flux_vae = StarterModel(
name="FLUX.1-schnell_ae",
base=BaseModelType.Flux,
source="black-forest-labs/FLUX.1-schnell::ae.safetensors",
description="FLUX VAE compatible with both schnell and dev variants.",
type=ModelType.VAE,
)
# List of starter models, displayed on the frontend.
# The order/sort of this list is not changed by the frontend - set it how you want it here.
STARTER_MODELS: list[StarterModel] = [
# region: Main
StarterModel(
name="FLUX Schnell (Quantized)",
base=BaseModelType.Flux,
source="InvokeAI/flux_schnell::transformer/bnb_nf4/flux1-schnell-bnb_nf4.safetensors",
description="FLUX schnell transformer quantized to bitsandbytes NF4 format. Total size with dependencies: ~12GB",
type=ModelType.Main,
dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder],
),
StarterModel(
name="FLUX Dev (Quantized)",
base=BaseModelType.Flux,
source="InvokeAI/flux_dev::transformer/bnb_nf4/flux1-dev-bnb_nf4.safetensors",
description="FLUX dev transformer quantized to bitsandbytes NF4 format. Total size with dependencies: ~12GB",
type=ModelType.Main,
dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder],
),
StarterModel(
name="FLUX Schnell",
base=BaseModelType.Flux,
source="InvokeAI/flux_schnell::transformer/base/flux1-schnell.safetensors",
description="FLUX schnell transformer in bfloat16. Total size with dependencies: ~33GB",
type=ModelType.Main,
dependencies=[t5_base_encoder, flux_vae, clip_l_encoder],
),
StarterModel(
name="FLUX Dev",
base=BaseModelType.Flux,
source="InvokeAI/flux_dev::transformer/base/flux1-dev.safetensors",
description="FLUX dev transformer in bfloat16. Total size with dependencies: ~33GB",
type=ModelType.Main,
dependencies=[t5_base_encoder, flux_vae, clip_l_encoder],
),
StarterModel(
name="CyberRealistic v4.1",
base=BaseModelType.StableDiffusion1,
@ -125,6 +192,7 @@ STARTER_MODELS: list[StarterModel] = [
# endregion
# region VAE
sdxl_fp16_vae_fix,
flux_vae,
# endregion
# region LoRA
StarterModel(
@ -450,6 +518,11 @@ STARTER_MODELS: list[StarterModel] = [
type=ModelType.SpandrelImageToImage,
),
# endregion
# region TextEncoders
t5_base_encoder,
t5_8b_quantized_encoder,
clip_l_encoder,
# endregion
]
assert len(STARTER_MODELS) == len({m.source for m in STARTER_MODELS}), "Duplicate starter models"

View File

@ -54,6 +54,7 @@ def filter_files(
"lora_weights.safetensors",
"weights.pb",
"onnx_data",
"spiece.model", # Added for `black-forest-labs/FLUX.1-schnell`.
)
):
paths.append(file)
@ -62,13 +63,13 @@ def filter_files(
# downloading random checkpoints that might also be in the repo. However there is no guarantee
# that a checkpoint doesn't contain "model" in its name, and no guarantee that future diffusers models
# will adhere to this naming convention, so this is an area to be careful of.
elif re.search(r"model(\.[^.]+)?\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$", file.name):
elif re.search(r"model.*\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$", file.name):
paths.append(file)
# limit search to subfolder if requested
if subfolder:
subfolder = root / subfolder
paths = [x for x in paths if x.parent == Path(subfolder)]
paths = [x for x in paths if Path(subfolder) in x.parents]
# _filter_by_variant uniquifies the paths and returns a set
return sorted(_filter_by_variant(paths, variant))
@ -97,7 +98,9 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
if variant == ModelRepoVariant.Flax:
result.add(path)
elif path.suffix in [".json", ".txt"]:
# Note: '.model' was added to support:
# https://huggingface.co/black-forest-labs/FLUX.1-schnell/blob/768d12a373ed5cc9ef9a9dea7504dc09fcc14842/tokenizer_2/spiece.model
elif path.suffix in [".json", ".txt", ".model"]:
result.add(path)
elif variant in [
@ -140,6 +143,23 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
continue
for candidate_list in subfolder_weights.values():
# Check if at least one of the files has the explicit fp16 variant.
at_least_one_fp16 = False
for candidate in candidate_list:
if len(candidate.path.suffixes) == 2 and candidate.path.suffixes[0] == ".fp16":
at_least_one_fp16 = True
break
if not at_least_one_fp16:
# If none of the candidates in this candidate_list have the explicit fp16 variant label, then this
# candidate_list probably doesn't adhere to the variant naming convention that we expected. In this case,
# we'll simply keep all the candidates. An example of a model that hits this case is
# `black-forest-labs/FLUX.1-schnell` (as of commit 012d2fd).
for candidate in candidate_list:
result.add(candidate.path)
# The candidate_list seems to have the expected variant naming convention. We'll select the highest scoring
# candidate.
highest_score_candidate = max(candidate_list, key=lambda candidate: candidate.score)
if highest_score_candidate:
result.add(highest_score_candidate.path)

View File

@ -0,0 +1,125 @@
import bitsandbytes as bnb
import torch
# This file contains utils for working with models that use bitsandbytes LLM.int8() quantization.
# The utils in this file are partially inspired by:
# https://github.com/Lightning-AI/pytorch-lightning/blob/1551a16b94f5234a4a78801098f64d0732ef5cb5/src/lightning/fabric/plugins/precision/bitsandbytes.py
# NOTE(ryand): All of the custom state_dict manipulation logic in this file is pretty hacky. This could be made much
# cleaner by re-implementing bnb.nn.Linear8bitLt with proper use of buffers and less magic. But, for now, we try to
# stick close to the bitsandbytes classes to make interoperability easier with other models that might use bitsandbytes.
class InvokeInt8Params(bnb.nn.Int8Params):
"""We override cuda() to avoid re-quantizing the weights in the following cases:
- We loaded quantized weights from a state_dict on the cpu, and then moved the model to the gpu.
- We are moving the model back-and-forth between the cpu and gpu.
"""
def cuda(self, device):
if self.has_fp16_weights:
return super().cuda(device)
elif self.CB is not None and self.SCB is not None:
self.data = self.data.cuda()
self.CB = self.data
self.SCB = self.SCB.cuda()
else:
# we store the 8-bit rows-major weight
# we convert this weight to the turning/ampere weight during the first inference pass
B = self.data.contiguous().half().cuda(device)
CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
del CBt
del SCBt
self.data = CB
self.CB = CB
self.SCB = SCB
return self
class InvokeLinear8bitLt(bnb.nn.Linear8bitLt):
def _load_from_state_dict(
self,
state_dict: dict[str, torch.Tensor],
prefix: str,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
weight = state_dict.pop(prefix + "weight")
bias = state_dict.pop(prefix + "bias", None)
# See `bnb.nn.Linear8bitLt._save_to_state_dict()` for the serialization logic of SCB and weight_format.
scb = state_dict.pop(prefix + "SCB", None)
# weight_format is unused, but we pop it so we can validate that there are no unexpected keys.
_weight_format = state_dict.pop(prefix + "weight_format", None)
# TODO(ryand): Technically, we should be using `strict`, `missing_keys`, `unexpected_keys`, and `error_msgs`
# rather than raising an exception to correctly implement this API.
assert len(state_dict) == 0
if scb is not None:
# We are loading a pre-quantized state dict.
self.weight = InvokeInt8Params(
data=weight,
requires_grad=self.weight.requires_grad,
has_fp16_weights=False,
# Note: After quantization, CB is the same as weight.
CB=weight,
SCB=scb,
)
self.bias = bias if bias is None else torch.nn.Parameter(bias)
else:
# We are loading a non-quantized state dict.
# We could simply call the `super()._load_from_state_dict()` method here, but then we wouldn't be able to
# load from a state_dict into a model on the "meta" device. Attempting to load into a model on the "meta"
# device requires setting `assign=True`, doing this with the default `super()._load_from_state_dict()`
# implementation causes `Params4Bit` to be replaced by a `torch.nn.Parameter`. By initializing a new
# `Params4bit` object, we work around this issue. It's a bit hacky, but it gets the job done.
self.weight = InvokeInt8Params(
data=weight,
requires_grad=self.weight.requires_grad,
has_fp16_weights=False,
CB=None,
SCB=None,
)
self.bias = bias if bias is None else torch.nn.Parameter(bias)
def _convert_linear_layers_to_llm_8bit(
module: torch.nn.Module, ignore_modules: set[str], outlier_threshold: float, prefix: str = ""
) -> None:
"""Convert all linear layers in the module to bnb.nn.Linear8bitLt layers."""
for name, child in module.named_children():
fullname = f"{prefix}.{name}" if prefix else name
if isinstance(child, torch.nn.Linear) and not any(fullname.startswith(s) for s in ignore_modules):
has_bias = child.bias is not None
replacement = InvokeLinear8bitLt(
child.in_features,
child.out_features,
bias=has_bias,
has_fp16_weights=False,
threshold=outlier_threshold,
)
replacement.weight.data = child.weight.data
if has_bias:
replacement.bias.data = child.bias.data
replacement.requires_grad_(False)
module.__setattr__(name, replacement)
else:
_convert_linear_layers_to_llm_8bit(
child, ignore_modules, outlier_threshold=outlier_threshold, prefix=fullname
)
def quantize_model_llm_int8(model: torch.nn.Module, modules_to_not_convert: set[str], outlier_threshold: float = 6.0):
"""Apply bitsandbytes LLM.8bit() quantization to the model."""
_convert_linear_layers_to_llm_8bit(
module=model, ignore_modules=modules_to_not_convert, outlier_threshold=outlier_threshold
)
return model

View File

@ -0,0 +1,156 @@
import bitsandbytes as bnb
import torch
# This file contains utils for working with models that use bitsandbytes NF4 quantization.
# The utils in this file are partially inspired by:
# https://github.com/Lightning-AI/pytorch-lightning/blob/1551a16b94f5234a4a78801098f64d0732ef5cb5/src/lightning/fabric/plugins/precision/bitsandbytes.py
# NOTE(ryand): All of the custom state_dict manipulation logic in this file is pretty hacky. This could be made much
# cleaner by re-implementing bnb.nn.LinearNF4 with proper use of buffers and less magic. But, for now, we try to stick
# close to the bitsandbytes classes to make interoperability easier with other models that might use bitsandbytes.
class InvokeLinearNF4(bnb.nn.LinearNF4):
"""A class that extends `bnb.nn.LinearNF4` to add the following functionality:
- Ability to load Linear NF4 layers from a pre-quantized state_dict.
- Ability to load Linear NF4 layers from a state_dict when the model is on the "meta" device.
"""
def _load_from_state_dict(
self,
state_dict: dict[str, torch.Tensor],
prefix: str,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
"""This method is based on the logic in the bitsandbytes serialization unit tests for `Linear4bit`:
https://github.com/bitsandbytes-foundation/bitsandbytes/blob/6d714a5cce3db5bd7f577bc447becc7a92d5ccc7/tests/test_linear4bit.py#L52-L71
"""
weight = state_dict.pop(prefix + "weight")
bias = state_dict.pop(prefix + "bias", None)
# We expect the remaining keys to be quant_state keys.
quant_state_sd = state_dict
# During serialization, the quant_state is stored as subkeys of "weight." (See
# `bnb.nn.LinearNF4._save_to_state_dict()`). We validate that they at least have the correct prefix.
# TODO(ryand): Technically, we should be using `strict`, `missing_keys`, `unexpected_keys`, and `error_msgs`
# rather than raising an exception to correctly implement this API.
assert all(k.startswith(prefix + "weight.") for k in quant_state_sd.keys())
if len(quant_state_sd) > 0:
# We are loading a pre-quantized state dict.
self.weight = bnb.nn.Params4bit.from_prequantized(
data=weight, quantized_stats=quant_state_sd, device=weight.device
)
self.bias = bias if bias is None else torch.nn.Parameter(bias, requires_grad=False)
else:
# We are loading a non-quantized state dict.
# We could simply call the `super()._load_from_state_dict()` method here, but then we wouldn't be able to
# load from a state_dict into a model on the "meta" device. Attempting to load into a model on the "meta"
# device requires setting `assign=True`, doing this with the default `super()._load_from_state_dict()`
# implementation causes `Params4Bit` to be replaced by a `torch.nn.Parameter`. By initializing a new
# `Params4bit` object, we work around this issue. It's a bit hacky, but it gets the job done.
self.weight = bnb.nn.Params4bit(
data=weight,
requires_grad=self.weight.requires_grad,
compress_statistics=self.weight.compress_statistics,
quant_type=self.weight.quant_type,
quant_storage=self.weight.quant_storage,
module=self,
)
self.bias = bias if bias is None else torch.nn.Parameter(bias)
def _replace_param(
param: torch.nn.Parameter | bnb.nn.Params4bit,
data: torch.Tensor,
) -> torch.nn.Parameter:
"""A helper function to replace the data of a model parameter with new data in a way that allows replacing params on
the "meta" device.
Supports both `torch.nn.Parameter` and `bnb.nn.Params4bit` parameters.
"""
if param.device.type == "meta":
# Doing `param.data = data` raises a RuntimeError if param.data was on the "meta" device, so we need to
# re-create the param instead of overwriting the data.
if isinstance(param, bnb.nn.Params4bit):
return bnb.nn.Params4bit(
data,
requires_grad=data.requires_grad,
quant_state=param.quant_state,
compress_statistics=param.compress_statistics,
quant_type=param.quant_type,
)
return torch.nn.Parameter(data, requires_grad=data.requires_grad)
param.data = data
return param
def _convert_linear_layers_to_nf4(
module: torch.nn.Module,
ignore_modules: set[str],
compute_dtype: torch.dtype,
compress_statistics: bool = False,
prefix: str = "",
) -> None:
"""Convert all linear layers in the model to NF4 quantized linear layers.
Args:
module: All linear layers in this module will be converted.
ignore_modules: A set of module prefixes to ignore when converting linear layers.
compute_dtype: The dtype to use for computation in the quantized linear layers.
compress_statistics: Whether to enable nested quantization (aka double quantization) where the quantization
constants from the first quantization are quantized again.
prefix: The prefix of the current module in the model. Used to call this function recursively.
"""
for name, child in module.named_children():
fullname = f"{prefix}.{name}" if prefix else name
if isinstance(child, torch.nn.Linear) and not any(fullname.startswith(s) for s in ignore_modules):
has_bias = child.bias is not None
replacement = InvokeLinearNF4(
child.in_features,
child.out_features,
bias=has_bias,
compute_dtype=compute_dtype,
compress_statistics=compress_statistics,
)
if has_bias:
replacement.bias = _replace_param(replacement.bias, child.bias.data)
replacement.weight = _replace_param(replacement.weight, child.weight.data)
replacement.requires_grad_(False)
module.__setattr__(name, replacement)
else:
_convert_linear_layers_to_nf4(child, ignore_modules, compute_dtype=compute_dtype, prefix=fullname)
def quantize_model_nf4(model: torch.nn.Module, modules_to_not_convert: set[str], compute_dtype: torch.dtype):
"""Apply bitsandbytes nf4 quantization to the model.
You likely want to call this function inside a `accelerate.init_empty_weights()` context.
Example usage:
```
# Initialize the model from a config on the meta device.
with accelerate.init_empty_weights():
model = ModelClass.from_config(...)
# Add NF4 quantization linear layers to the model - still on the meta device.
with accelerate.init_empty_weights():
model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.float16)
# Load a state_dict into the model. (Could be either a prequantized or non-quantized state_dict.)
model.load_state_dict(state_dict, strict=True, assign=True)
# Move the model to the "cuda" device. If the model was non-quantized, this is where the weight quantization takes
# place.
model.to("cuda")
```
"""
_convert_linear_layers_to_nf4(module=model, ignore_modules=modules_to_not_convert, compute_dtype=compute_dtype)
return model

View File

@ -0,0 +1,79 @@
from pathlib import Path
import accelerate
from safetensors.torch import load_file, save_file
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.util import params
from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8
from invokeai.backend.quantization.scripts.load_flux_model_bnb_nf4 import log_time
def main():
"""A script for quantizing a FLUX transformer model using the bitsandbytes LLM.int8() quantization method.
This script is primarily intended for reference. The script params (e.g. the model_path, modules_to_not_convert,
etc.) are hardcoded and would need to be modified for other use cases.
"""
# Load the FLUX transformer model onto the meta device.
model_path = Path(
"/data/invokeai/models/.download_cache/https__huggingface.co_black-forest-labs_flux.1-schnell_resolve_main_flux1-schnell.safetensors/flux1-schnell.safetensors"
)
with log_time("Intialize FLUX transformer on meta device"):
# TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
p = params["flux-schnell"]
# Initialize the model on the "meta" device.
with accelerate.init_empty_weights():
model = Flux(p)
# TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate
# `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize.
modules_to_not_convert: set[str] = set()
model_int8_path = model_path.parent / "bnb_llm_int8.safetensors"
if model_int8_path.exists():
# The quantized model already exists, load it and return it.
print(f"A pre-quantized model already exists at '{model_int8_path}'. Attempting to load it...")
# Replace the linear layers with LLM.int8() quantized linear layers (still on the meta device).
with log_time("Replace linear layers with LLM.int8() layers"), accelerate.init_empty_weights():
model = quantize_model_llm_int8(model, modules_to_not_convert=modules_to_not_convert)
with log_time("Load state dict into model"):
sd = load_file(model_int8_path)
model.load_state_dict(sd, strict=True, assign=True)
with log_time("Move model to cuda"):
model = model.to("cuda")
print(f"Successfully loaded pre-quantized model from '{model_int8_path}'.")
else:
# The quantized model does not exist, quantize the model and save it.
print(f"No pre-quantized model found at '{model_int8_path}'. Quantizing the model...")
with log_time("Replace linear layers with LLM.int8() layers"), accelerate.init_empty_weights():
model = quantize_model_llm_int8(model, modules_to_not_convert=modules_to_not_convert)
with log_time("Load state dict into model"):
state_dict = load_file(model_path)
# TODO(ryand): Cast the state_dict to the appropriate dtype?
model.load_state_dict(state_dict, strict=True, assign=True)
with log_time("Move model to cuda and quantize"):
model = model.to("cuda")
with log_time("Save quantized model"):
model_int8_path.parent.mkdir(parents=True, exist_ok=True)
save_file(model.state_dict(), model_int8_path)
print(f"Successfully quantized and saved model to '{model_int8_path}'.")
assert isinstance(model, Flux)
return model
if __name__ == "__main__":
main()

View File

@ -0,0 +1,96 @@
import time
from contextlib import contextmanager
from pathlib import Path
import accelerate
import torch
from safetensors.torch import load_file, save_file
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.util import params
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
@contextmanager
def log_time(name: str):
"""Helper context manager to log the time taken by a block of code."""
start = time.time()
try:
yield None
finally:
end = time.time()
print(f"'{name}' took {end - start:.4f} secs")
def main():
"""A script for quantizing a FLUX transformer model using the bitsandbytes NF4 quantization method.
This script is primarily intended for reference. The script params (e.g. the model_path, modules_to_not_convert,
etc.) are hardcoded and would need to be modified for other use cases.
"""
model_path = Path(
"/data/invokeai/models/.download_cache/https__huggingface.co_black-forest-labs_flux.1-schnell_resolve_main_flux1-schnell.safetensors/flux1-schnell.safetensors"
)
# inference_dtype = torch.bfloat16
with log_time("Intialize FLUX transformer on meta device"):
# TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
p = params["flux-schnell"]
# Initialize the model on the "meta" device.
with accelerate.init_empty_weights():
model = Flux(p)
# TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate
# `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize.
modules_to_not_convert: set[str] = set()
model_nf4_path = model_path.parent / "bnb_nf4.safetensors"
if model_nf4_path.exists():
# The quantized model already exists, load it and return it.
print(f"A pre-quantized model already exists at '{model_nf4_path}'. Attempting to load it...")
# Replace the linear layers with NF4 quantized linear layers (still on the meta device).
with log_time("Replace linear layers with NF4 layers"), accelerate.init_empty_weights():
model = quantize_model_nf4(
model, modules_to_not_convert=modules_to_not_convert, compute_dtype=torch.bfloat16
)
with log_time("Load state dict into model"):
state_dict = load_file(model_nf4_path)
model.load_state_dict(state_dict, strict=True, assign=True)
with log_time("Move model to cuda"):
model = model.to("cuda")
print(f"Successfully loaded pre-quantized model from '{model_nf4_path}'.")
else:
# The quantized model does not exist, quantize the model and save it.
print(f"No pre-quantized model found at '{model_nf4_path}'. Quantizing the model...")
with log_time("Replace linear layers with NF4 layers"), accelerate.init_empty_weights():
model = quantize_model_nf4(
model, modules_to_not_convert=modules_to_not_convert, compute_dtype=torch.bfloat16
)
with log_time("Load state dict into model"):
state_dict = load_file(model_path)
# TODO(ryand): Cast the state_dict to the appropriate dtype?
model.load_state_dict(state_dict, strict=True, assign=True)
with log_time("Move model to cuda and quantize"):
model = model.to("cuda")
with log_time("Save quantized model"):
model_nf4_path.parent.mkdir(parents=True, exist_ok=True)
save_file(model.state_dict(), model_nf4_path)
print(f"Successfully quantized and saved model to '{model_nf4_path}'.")
assert isinstance(model, Flux)
return model
if __name__ == "__main__":
main()

View File

@ -0,0 +1,92 @@
from pathlib import Path
import accelerate
from safetensors.torch import load_file, save_file
from transformers import AutoConfig, AutoModelForTextEncoding, T5EncoderModel
from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8
from invokeai.backend.quantization.scripts.load_flux_model_bnb_nf4 import log_time
def load_state_dict_into_t5(model: T5EncoderModel, state_dict: dict):
# There is a shared reference to a single weight tensor in the model.
# Both "encoder.embed_tokens.weight" and "shared.weight" refer to the same tensor, so only the latter should
# be present in the state_dict.
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False, assign=True)
assert len(unexpected_keys) == 0
assert set(missing_keys) == {"encoder.embed_tokens.weight"}
# Assert that the layers we expect to be shared are actually shared.
assert model.encoder.embed_tokens.weight is model.shared.weight
def main():
"""A script for quantizing a T5 text encoder model using the bitsandbytes LLM.int8() quantization method.
This script is primarily intended for reference. The script params (e.g. the model_path, modules_to_not_convert,
etc.) are hardcoded and would need to be modified for other use cases.
"""
model_path = Path("/data/misc/text_encoder_2")
with log_time("Intialize T5 on meta device"):
model_config = AutoConfig.from_pretrained(model_path)
with accelerate.init_empty_weights():
model = AutoModelForTextEncoding.from_config(model_config)
# TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate
# `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize.
modules_to_not_convert: set[str] = set()
model_int8_path = model_path / "bnb_llm_int8.safetensors"
if model_int8_path.exists():
# The quantized model already exists, load it and return it.
print(f"A pre-quantized model already exists at '{model_int8_path}'. Attempting to load it...")
# Replace the linear layers with LLM.int8() quantized linear layers (still on the meta device).
with log_time("Replace linear layers with LLM.int8() layers"), accelerate.init_empty_weights():
model = quantize_model_llm_int8(model, modules_to_not_convert=modules_to_not_convert)
with log_time("Load state dict into model"):
sd = load_file(model_int8_path)
load_state_dict_into_t5(model, sd)
with log_time("Move model to cuda"):
model = model.to("cuda")
print(f"Successfully loaded pre-quantized model from '{model_int8_path}'.")
else:
# The quantized model does not exist, quantize the model and save it.
print(f"No pre-quantized model found at '{model_int8_path}'. Quantizing the model...")
with log_time("Replace linear layers with LLM.int8() layers"), accelerate.init_empty_weights():
model = quantize_model_llm_int8(model, modules_to_not_convert=modules_to_not_convert)
with log_time("Load state dict into model"):
# Load sharded state dict.
files = list(model_path.glob("*.safetensors"))
state_dict = {}
for file in files:
sd = load_file(file)
state_dict.update(sd)
load_state_dict_into_t5(model, state_dict)
with log_time("Move model to cuda and quantize"):
model = model.to("cuda")
with log_time("Save quantized model"):
model_int8_path.parent.mkdir(parents=True, exist_ok=True)
state_dict = model.state_dict()
state_dict.pop("encoder.embed_tokens.weight")
save_file(state_dict, model_int8_path)
# This handling of shared weights could also be achieved with save_model(...), but then we'd lose control
# over which keys are kept. And, the corresponding load_model(...) function does not support assign=True.
# save_model(model, model_int8_path)
print(f"Successfully quantized and saved model to '{model_int8_path}'.")
assert isinstance(model, T5EncoderModel)
return model
if __name__ == "__main__":
main()

View File

@ -25,11 +25,6 @@ class BasicConditioningInfo:
return self
@dataclass
class ConditioningFieldData:
conditionings: List[BasicConditioningInfo]
@dataclass
class SDXLConditioningInfo(BasicConditioningInfo):
"""SDXL text conditioning information produced by Compel."""
@ -43,6 +38,17 @@ class SDXLConditioningInfo(BasicConditioningInfo):
return super().to(device=device, dtype=dtype)
@dataclass
class FLUXConditioningInfo:
clip_embeds: torch.Tensor
t5_embeds: torch.Tensor
@dataclass
class ConditioningFieldData:
conditionings: List[BasicConditioningInfo] | List[SDXLConditioningInfo] | List[FLUXConditioningInfo]
@dataclass
class IPAdapterConditioningInfo:
cond_image_prompt_embeds: torch.Tensor

View File

@ -12,6 +12,10 @@ module.exports = {
'i18next/no-literal-string': 'error',
// https://eslint.org/docs/latest/rules/no-console
'no-console': 'error',
// https://eslint.org/docs/latest/rules/no-promise-executor-return
'no-promise-executor-return': 'error',
// https://eslint.org/docs/latest/rules/require-await
'require-await': 'error',
},
overrides: [
/**

View File

@ -1,5 +1,5 @@
import { PropsWithChildren, memo, useEffect } from 'react';
import { modelChanged } from '../src/features/parameters/store/generationSlice';
import { modelChanged } from '../src/features/controlLayers/store/paramsSlice';
import { useAppDispatch } from '../src/app/store/storeHooks';
import { useGlobalModifiersInit } from '@invoke-ai/ui-library';
/**
@ -10,7 +10,9 @@ export const ReduxInit = memo((props: PropsWithChildren) => {
const dispatch = useAppDispatch();
useGlobalModifiersInit();
useEffect(() => {
dispatch(modelChanged({ key: 'test_model', hash: 'some_hash', name: 'some name', base: 'sd-1', type: 'main' }));
dispatch(
modelChanged({ model: { key: 'test_model', hash: 'some_hash', name: 'some name', base: 'sd-1', type: 'main' } })
);
}, []);
return props.children;

View File

@ -9,6 +9,8 @@ const config: KnipConfig = {
'src/services/api/schema.ts',
'src/features/nodes/types/v1/**',
'src/features/nodes/types/v2/**',
// TODO(psyche): maybe we can clean up these utils after canvas v2 release
'src/features/controlLayers/konva/util.ts',
],
ignoreBinaries: ['only-allow'],
paths: {

View File

@ -24,7 +24,7 @@
"build": "pnpm run lint && vite build",
"typegen": "node scripts/typegen.js",
"preview": "vite preview",
"lint:knip": "knip",
"lint:knip": "knip --tags=-knipignore",
"lint:dpdm": "dpdm --no-warning --no-tree --transform --exit-code circular:1 src/main.tsx",
"lint:eslint": "eslint --max-warnings=0 .",
"lint:prettier": "prettier --check .",
@ -52,18 +52,19 @@
}
},
"dependencies": {
"@chakra-ui/react-use-size": "^2.1.0",
"@dagrejs/dagre": "^1.1.3",
"@dagrejs/graphlib": "^2.2.3",
"@dnd-kit/core": "^6.1.0",
"@dnd-kit/sortable": "^8.0.0",
"@dnd-kit/utilities": "^3.2.2",
"@fontsource-variable/inter": "^5.0.20",
"@invoke-ai/ui-library": "^0.0.29",
"@invoke-ai/ui-library": "^0.0.32",
"@nanostores/react": "^0.7.3",
"@reduxjs/toolkit": "2.2.3",
"@roarr/browser-log-writer": "^1.3.0",
"async-mutex": "^0.5.0",
"chakra-react-select": "^4.9.1",
"cmdk": "^1.0.0",
"compare-versions": "^6.1.1",
"dateformat": "^5.0.3",
"fracturedjsonjs": "^4.0.2",
@ -74,6 +75,8 @@
"jsondiffpatch": "^0.6.0",
"konva": "^9.3.14",
"lodash-es": "^4.17.21",
"lru-cache": "^11.0.0",
"nanoid": "^5.0.7",
"nanostores": "^0.11.2",
"new-github-issue-url": "^1.0.0",
"overlayscrollbars": "^2.10.0",
@ -88,10 +91,8 @@
"react-hotkeys-hook": "4.5.0",
"react-i18next": "^14.1.3",
"react-icons": "^5.2.1",
"react-konva": "^18.2.10",
"react-redux": "9.1.2",
"react-resizable-panels": "^2.0.23",
"react-select": "5.8.0",
"react-use": "^17.5.1",
"react-virtuoso": "^4.9.0",
"reactflow": "^11.11.4",
@ -102,9 +103,9 @@
"roarr": "^7.21.1",
"serialize-error": "^11.0.3",
"socket.io-client": "^4.7.5",
"stable-hash": "^0.0.4",
"use-debounce": "^10.0.2",
"use-device-pixel-ratio": "^1.1.2",
"use-image": "^1.1.1",
"uuid": "^10.0.0",
"zod": "^3.23.8",
"zod-validation-error": "^3.3.1"

File diff suppressed because it is too large Load Diff

View File

@ -80,6 +80,7 @@
"aboutDesc": "Using Invoke for work? Check out:",
"aboutHeading": "Own Your Creative Power",
"accept": "Accept",
"apply": "Apply",
"add": "Add",
"advanced": "Advanced",
"ai": "ai",
@ -115,6 +116,7 @@
"githubLabel": "Github",
"goTo": "Go to",
"hotkeysLabel": "Hotkeys",
"loadingImage": "Loading Image",
"imageFailedToLoad": "Unable to Load Image",
"img2img": "Image To Image",
"inpaint": "inpaint",
@ -325,6 +327,10 @@
"canceled": "Canceled",
"completedIn": "Completed in",
"batch": "Batch",
"origin": "Origin",
"originCanvas": "Canvas",
"originWorkflows": "Workflows",
"originOther": "Other",
"batchFieldValues": "Batch Field Values",
"item": "Item",
"session": "Session",
@ -784,6 +790,7 @@
"simpleModelPlaceholder": "URL or path to a local file or diffusers folder",
"source": "Source",
"starterModels": "Starter Models",
"starterModelsInModelManager": "Starter Models can be found in Model Manager",
"syncModels": "Sync Models",
"textualInversions": "Textual Inversions",
"triggerPhrases": "Trigger Phrases",
@ -1095,7 +1102,6 @@
"confirmOnDelete": "Confirm On Delete",
"developer": "Developer",
"displayInProgress": "Display Progress Images",
"enableImageDebugging": "Enable Image Debugging",
"enableInformationalPopovers": "Enable Informational Popovers",
"informationalPopoversDisabled": "Informational Popovers Disabled",
"informationalPopoversDisabledDesc": "Informational popovers have been disabled. Enable them in Settings.",
@ -1562,7 +1568,7 @@
"copyToClipboard": "Copy to Clipboard",
"cursorPosition": "Cursor Position",
"darkenOutsideSelection": "Darken Outside Selection",
"discardAll": "Discard All",
"discardAll": "Discard All & Cancel Pending Generations",
"discardCurrent": "Discard Current",
"downloadAsImage": "Download As Image",
"enableMask": "Enable Mask",
@ -1640,41 +1646,126 @@
"storeNotInitialized": "Store is not initialized"
},
"controlLayers": {
"deleteAll": "Delete All",
"clearHistory": "Clear History",
"generateMode": "Generate",
"generateModeDesc": "Create individual images. Generated images are added directly to the gallery.",
"composeMode": "Compose",
"composeModeDesc": "Compose your work iterative. Generated images are added back to the canvas.",
"autoSave": "Auto-save to Gallery",
"resetCanvas": "Reset Canvas",
"resetAll": "Reset All",
"clearCaches": "Clear Caches",
"recalculateRects": "Recalculate Rects",
"clipToBbox": "Clip Strokes to Bbox",
"addLayer": "Add Layer",
"duplicate": "Duplicate",
"moveToFront": "Move to Front",
"moveToBack": "Move to Back",
"moveForward": "Move Forward",
"moveBackward": "Move Backward",
"brushSize": "Brush Size",
"width": "Width",
"zoom": "Zoom",
"resetView": "Reset View",
"controlLayers": "Control Layers",
"globalMaskOpacity": "Global Mask Opacity",
"autoNegative": "Auto Negative",
"enableAutoNegative": "Enable Auto Negative",
"disableAutoNegative": "Disable Auto Negative",
"deletePrompt": "Delete Prompt",
"resetRegion": "Reset Region",
"debugLayers": "Debug Layers",
"rectangle": "Rectangle",
"maskPreviewColor": "Mask Preview Color",
"maskFill": "Mask Fill",
"addPositivePrompt": "Add $t(common.positivePrompt)",
"addNegativePrompt": "Add $t(common.negativePrompt)",
"addIPAdapter": "Add $t(common.ipAdapter)",
"regionalGuidance": "Regional Guidance",
"regionalGuidanceLayer": "$t(controlLayers.regionalGuidance) $t(unifiedCanvas.layer)",
"raster": "Raster",
"rasterLayer_one": "Raster Layer",
"controlLayer_one": "Control Layer",
"inpaintMask_one": "Inpaint Mask",
"regionalGuidance_one": "Regional Guidance",
"ipAdapter_one": "IP Adapter",
"rasterLayer_other": "Raster Layers",
"controlLayer_other": "Control Layers",
"inpaintMask_other": "Inpaint Masks",
"regionalGuidance_other": "Regional Guidance",
"ipAdapter_other": "IP Adapters",
"opacity": "Opacity",
"regionalGuidance_withCount_hidden": "Regional Guidance ({{count}} hidden)",
"controlAdapters_withCount_hidden": "Control Adapters ({{count}} hidden)",
"controlLayers_withCount_hidden": "Control Layers ({{count}} hidden)",
"rasterLayers_withCount_hidden": "Raster Layers ({{count}} hidden)",
"ipAdapters_withCount_hidden": "IP Adapters ({{count}} hidden)",
"inpaintMasks_withCount_hidden": "Inpaint Masks ({{count}} hidden)",
"regionalGuidance_withCount_visible": "Regional Guidance ({{count}})",
"controlAdapters_withCount_visible": "Control Adapters ({{count}})",
"controlLayers_withCount_visible": "Control Layers ({{count}})",
"rasterLayers_withCount_visible": "Raster Layers ({{count}})",
"ipAdapters_withCount_visible": "IP Adapters ({{count}})",
"inpaintMasks_withCount_visible": "Inpaint Masks ({{count}})",
"globalControlAdapter": "Global $t(controlnet.controlAdapter_one)",
"globalControlAdapterLayer": "Global $t(controlnet.controlAdapter_one) $t(unifiedCanvas.layer)",
"globalIPAdapter": "Global $t(common.ipAdapter)",
"globalIPAdapterLayer": "Global $t(common.ipAdapter) $t(unifiedCanvas.layer)",
"globalInitialImage": "Global Initial Image",
"globalInitialImageLayer": "$t(controlLayers.globalInitialImage) $t(unifiedCanvas.layer)",
"layer": "Layer",
"opacityFilter": "Opacity Filter",
"clearProcessor": "Clear Processor",
"resetProcessor": "Reset Processor to Defaults",
"noLayersAdded": "No Layers Added",
"layers_one": "Layer",
"layers_other": "Layers"
"layers_other": "Layers",
"objects_zero": "empty",
"objects_one": "{{count}} object",
"objects_other": "{{count}} objects",
"convertToControlLayer": "Convert to Control Layer",
"convertToRasterLayer": "Convert to Raster Layer",
"transparency": "Transparency",
"enableTransparencyEffect": "Enable Transparency Effect",
"disableTransparencyEffect": "Disable Transparency Effect",
"hidingType": "Hiding {{type}}",
"showingType": "Showing {{type}}",
"dynamicGrid": "Dynamic Grid",
"logDebugInfo": "Log Debug Info",
"locked": "Locked",
"unlocked": "Unlocked",
"deleteSelected": "Delete Selected",
"deleteAll": "Delete All",
"flipHorizontal": "Flip Horizontal",
"flipVertical": "Flip Vertical",
"fill": {
"fillStyle": "Fill Style",
"solid": "Solid",
"grid": "Grid",
"crosshatch": "Crosshatch",
"vertical": "Vertical",
"horizontal": "Horizontal",
"diagonal": "Diagonal"
},
"tool": {
"brush": "Brush",
"eraser": "Eraser",
"rectangle": "Rectangle",
"bbox": "Bbox",
"move": "Move",
"view": "View",
"transform": "Transform",
"colorPicker": "Color Picker"
},
"filter": {
"filter": "Filter",
"filters": "Filters",
"filterType": "Filter Type",
"preview": "Preview",
"apply": "Apply",
"cancel": "Cancel"
}
},
"upscaling": {
"upscale": "Upscale",
"creativity": "Creativity",
"exceedsMaxSize": "Upscale settings exceed max size limit",
"exceedsMaxSizeDetails": "Max upscale limit is {{maxUpscaleDimension}}x{{maxUpscaleDimension}} pixels. Please try a smaller image or decrease your scale selection.",
@ -1723,6 +1814,7 @@
"positivePrompt": "Positive Prompt",
"preview": "Preview",
"private": "Private",
"promptTemplateCleared": "Prompt Template Cleared",
"searchByName": "Search by name",
"shared": "Shared",
"sharedTemplates": "Shared Templates",
@ -1758,5 +1850,30 @@
"upscaling": "Upscaling",
"upscalingTab": "$t(ui.tabs.upscaling) $t(common.tab)"
}
},
"system": {
"enableLogging": "Enable Logging",
"logLevel": {
"logLevel": "Log Level",
"trace": "Trace",
"debug": "Debug",
"info": "Info",
"warn": "Warn",
"error": "Error",
"fatal": "Fatal"
},
"logNamespaces": {
"logNamespaces": "Log Namespaces",
"gallery": "Gallery",
"models": "Models",
"config": "Config",
"canvas": "Canvas",
"generation": "Generation",
"workflows": "Workflows",
"system": "System",
"events": "Events",
"queue": "Queue",
"metadata": "Metadata"
}
}
}

View File

@ -929,7 +929,7 @@
"missingInvocationTemplate": "Modello di invocazione mancante",
"missingFieldTemplate": "Modello di campo mancante",
"singleFieldType": "{{name}} (Singola)",
"imageAccessError": "Impossibile trovare l'immagine {{image_name}}, ripristino delle impostazioni predefinite",
"imageAccessError": "Impossibile trovare l'immagine {{image_name}}, ripristino ai valori predefiniti",
"boardAccessError": "Impossibile trovare la bacheca {{board_id}}, ripristino ai valori predefiniti",
"modelAccessError": "Impossibile trovare il modello {{key}}, ripristino ai valori predefiniti"
},
@ -1782,7 +1782,13 @@
"updatePromptTemplate": "Aggiorna il modello di prompt",
"type": "Tipo",
"promptTemplatesDesc2": "Utilizza la stringa segnaposto <Pre>{{placeholder}}</Pre> per specificare dove inserire il tuo prompt nel modello.",
"importTemplates": "Importa modelli di prompt",
"importTemplatesDesc": "Il formato deve essere un CSV con colonne 'name' e 'prompt' o 'positive_prompt' e 'negative_prompt' incluse, oppure un file JSON con chiavi 'name' e 'prompt' o 'positive_prompt' e 'negative_prompt"
"importTemplates": "Importa modelli di prompt (CSV/JSON)",
"exportDownloaded": "Esportazione completata",
"exportFailed": "Impossibile generare e scaricare il file CSV",
"exportPromptTemplates": "Esporta i miei modelli di prompt (CSV)",
"positivePromptColumn": "'prompt' o 'positive_prompt'",
"noTemplates": "Nessun modello",
"acceptedColumnsKeys": "Colonne/chiavi accettate:",
"templateActions": "Azioni modello"
}
}

View File

@ -91,7 +91,8 @@
"enabled": "Включено",
"disabled": "Отключено",
"comparingDesc": "Сравнение двух изображений",
"comparing": "Сравнение"
"comparing": "Сравнение",
"dontShowMeThese": "Не показывай мне это"
},
"gallery": {
"galleryImageSize": "Размер изображений",
@ -153,7 +154,11 @@
"showArchivedBoards": "Показать архивированные доски",
"searchImages": "Поиск по метаданным",
"displayBoardSearch": "Отобразить поиск досок",
"displaySearch": "Отобразить поиск"
"displaySearch": "Отобразить поиск",
"exitBoardSearch": "Выйти из поиска досок",
"go": "Перейти",
"exitSearch": "Выйти из поиска",
"jump": "Пыгнуть"
},
"hotkeys": {
"keyboardShortcuts": "Горячие клавиши",
@ -376,6 +381,10 @@
"toggleViewer": {
"title": "Переключить просмотр изображений",
"desc": "Переключение между средством просмотра изображений и рабочей областью для текущей вкладки."
},
"postProcess": {
"desc": "Обработайте текущее изображение с помощью выбранной модели постобработки",
"title": "Обработать изображение"
}
},
"modelManager": {
@ -589,7 +598,10 @@
"infillColorValue": "Цвет заливки",
"globalSettings": "Глобальные настройки",
"globalNegativePromptPlaceholder": "Глобальный негативный запрос",
"globalPositivePromptPlaceholder": "Глобальный запрос"
"globalPositivePromptPlaceholder": "Глобальный запрос",
"postProcessing": "Постобработка (Shift + U)",
"processImage": "Обработка изображения",
"sendToUpscale": "Отправить на увеличение"
},
"settings": {
"models": "Модели",
@ -623,7 +635,9 @@
"intermediatesCleared_many": "Очищено {{count}} промежуточных",
"clearIntermediatesDesc1": "Очистка промежуточных элементов приведет к сбросу состояния Canvas и ControlNet.",
"intermediatesClearedFailed": "Проблема очистки промежуточных",
"reloadingIn": "Перезагрузка через"
"reloadingIn": "Перезагрузка через",
"informationalPopoversDisabled": "Информационные всплывающие окна отключены",
"informationalPopoversDisabledDesc": "Информационные всплывающие окна были отключены. Включите их в Настройках."
},
"toast": {
"uploadFailed": "Загрузка не удалась",
@ -694,7 +708,9 @@
"sessionRef": "Сессия: {{sessionId}}",
"outOfMemoryError": "Ошибка нехватки памяти",
"outOfMemoryErrorDesc": "Ваши текущие настройки генерации превышают возможности системы. Пожалуйста, измените настройки и повторите попытку.",
"somethingWentWrong": "Что-то пошло не так"
"somethingWentWrong": "Что-то пошло не так",
"importFailed": "Импорт неудачен",
"importSuccessful": "Импорт успешен"
},
"tooltip": {
"feature": {
@ -1017,7 +1033,8 @@
"composition": "Только композиция",
"hed": "HED",
"beginEndStepPercentShort": "Начало/конец %",
"setControlImageDimensionsForce": "Скопируйте размер в Ш/В (игнорируйте модель)"
"setControlImageDimensionsForce": "Скопируйте размер в Ш/В (игнорируйте модель)",
"depthAnythingSmallV2": "Small V2"
},
"boards": {
"autoAddBoard": "Авто добавление Доски",
@ -1042,7 +1059,7 @@
"downloadBoard": "Скачать доску",
"deleteBoard": "Удалить доску",
"deleteBoardAndImages": "Удалить доску и изображения",
"deletedBoardsCannotbeRestored": "Удаленные доски не подлежат восстановлению",
"deletedBoardsCannotbeRestored": "Удаленные доски не могут быть восстановлены. Выбор «Удалить только доску» переведет изображения в состояние без категории.",
"assetsWithCount_one": "{{count}} ассет",
"assetsWithCount_few": "{{count}} ассета",
"assetsWithCount_many": "{{count}} ассетов",
@ -1057,7 +1074,11 @@
"boards": "Доски",
"addPrivateBoard": "Добавить личную доску",
"private": "Личные доски",
"shared": "Общие доски"
"shared": "Общие доски",
"hideBoards": "Скрыть доски",
"viewBoards": "Просмотреть доски",
"noBoards": "Нет досок {{boardType}}",
"deletedPrivateBoardsCannotbeRestored": "Удаленные доски не могут быть восстановлены. Выбор «Удалить только доску» переведет изображения в приватное состояние без категории для создателя изображения."
},
"dynamicPrompts": {
"seedBehaviour": {
@ -1417,6 +1438,30 @@
"paragraphs": [
"Метод, с помощью которого применяется текущий IP-адаптер."
]
},
"structure": {
"paragraphs": [
"Структура контролирует, насколько точно выходное изображение будет соответствовать макету оригинала. Низкая структура допускает значительные изменения, в то время как высокая структура строго сохраняет исходную композицию и макет."
],
"heading": "Структура"
},
"scale": {
"paragraphs": [
"Масштаб управляет размером выходного изображения и основывается на кратном разрешении входного изображения. Например, при увеличении в 2 раза изображения 1024x1024 на выходе получится 2048 x 2048."
],
"heading": "Масштаб"
},
"creativity": {
"paragraphs": [
"Креативность контролирует степень свободы, предоставляемой модели при добавлении деталей. При низкой креативности модель остается близкой к оригинальному изображению, в то время как высокая креативность позволяет вносить больше изменений. При использовании подсказки высокая креативность увеличивает влияние подсказки."
],
"heading": "Креативность"
},
"upscaleModel": {
"heading": "Модель увеличения",
"paragraphs": [
"Модель увеличения масштаба масштабирует изображение до выходного размера перед добавлением деталей. Можно использовать любую поддерживаемую модель масштабирования, но некоторые из них специализированы для различных видов изображений, например фотографий или линейных рисунков."
]
}
},
"metadata": {
@ -1693,7 +1738,78 @@
"canvasTab": "$t(ui.tabs.canvas) $t(common.tab)",
"queueTab": "$t(ui.tabs.queue) $t(common.tab)",
"modelsTab": "$t(ui.tabs.models) $t(common.tab)",
"queue": "Очередь"
"queue": "Очередь",
"upscaling": "Увеличение",
"upscalingTab": "$t(ui.tabs.upscaling) $t(common.tab)"
}
},
"upscaling": {
"exceedsMaxSize": "Параметры масштабирования превышают максимальный размер",
"exceedsMaxSizeDetails": "Максимальный предел масштабирования составляет {{maxUpscaleDimension}}x{{maxUpscaleDimension}} пикселей. Пожалуйста, попробуйте использовать меньшее изображение или уменьшите масштаб.",
"structure": "Структура",
"missingTileControlNetModel": "Не установлены подходящие модели ControlNet",
"missingUpscaleInitialImage": "Отсутствует увеличиваемое изображение",
"missingUpscaleModel": "Отсутствует увеличивающая модель",
"creativity": "Креативность",
"upscaleModel": "Модель увеличения",
"scale": "Масштаб",
"mainModelDesc": "Основная модель (архитектура SD1.5 или SDXL)",
"upscaleModelDesc": "Модель увеличения (img2img)",
"postProcessingModel": "Модель постобработки",
"tileControlNetModelDesc": "Модель ControlNet для выбранной архитектуры основной модели",
"missingModelsWarning": "Зайдите в <LinkComponent>Менеджер моделей</LinkComponent> чтоб установить необходимые модели:",
"postProcessingMissingModelWarning": "Посетите <LinkComponent>Менеджер моделей</LinkComponent>, чтобы установить модель постобработки (img2img)."
},
"stylePresets": {
"noMatchingTemplates": "Нет подходящих шаблонов",
"promptTemplatesDesc1": "Шаблоны подсказок добавляют текст к подсказкам, которые вы пишете в окне подсказок.",
"sharedTemplates": "Общие шаблоны",
"templateDeleted": "Шаблон запроса удален",
"toggleViewMode": "Переключить режим просмотра",
"type": "Тип",
"unableToDeleteTemplate": "Не получилось удалить шаблон запроса",
"viewModeTooltip": "Вот как будет выглядеть ваш запрос с выбранным шаблоном. Чтобы его отредактировать, щелкните в любом месте текстового поля.",
"viewList": "Просмотреть список шаблонов",
"active": "Активно",
"choosePromptTemplate": "Выберите шаблон запроса",
"defaultTemplates": "Стандартные шаблоны",
"deleteImage": "Удалить изображение",
"deleteTemplate": "Удалить шаблон",
"deleteTemplate2": "Вы уверены, что хотите удалить этот шаблон? Это нельзя отменить.",
"editTemplate": "Редактировать шаблон",
"exportPromptTemplates": "Экспорт моих шаблонов запроса (CSV)",
"exportDownloaded": "Экспорт скачан",
"exportFailed": "Невозможно сгенерировать и загрузить CSV",
"flatten": "Объединить выбранный шаблон с текущим запросом",
"acceptedColumnsKeys": "Принимаемые столбцы/ключи:",
"positivePromptColumn": "'prompt' или 'positive_prompt'",
"insertPlaceholder": "Вставить заполнитель",
"name": "Имя",
"negativePrompt": "Негативный запрос",
"promptTemplatesDesc3": "Если вы не используете заполнитель, шаблон будет добавлен в конец запроса.",
"positivePrompt": "Позитивный запрос",
"preview": "Предпросмотр",
"private": "Приватный",
"templateActions": "Действия с шаблоном",
"updatePromptTemplate": "Обновить шаблон запроса",
"uploadImage": "Загрузить изображение",
"useForTemplate": "Использовать для шаблона запроса",
"clearTemplateSelection": "Очистить выбор шаблона",
"copyTemplate": "Копировать шаблон",
"createPromptTemplate": "Создать шаблон запроса",
"importTemplates": "Импортировать шаблоны запроса (CSV/JSON)",
"nameColumn": "'name'",
"negativePromptColumn": "'negative_prompt'",
"myTemplates": "Мои шаблоны",
"noTemplates": "Нет шаблонов",
"promptTemplatesDesc2": "Используйте строку-заполнитель <Pre>{{placeholder}}</Pre>, чтобы указать место, куда должен быть включен ваш запрос в шаблоне.",
"searchByName": "Поиск по имени",
"shared": "Общий"
},
"upsell": {
"inviteTeammates": "Пригласите членов команды",
"professional": "Профессионал",
"professionalUpsell": "Доступно в профессиональной версии Invoke. Нажмите здесь или посетите invoke.com/pricing для получения более подробной информации.",
"shareAccess": "Поделиться доступом"
}
}

View File

@ -38,7 +38,7 @@ async function generateTypes(schema) {
process.stdout.write(`\nOK!\r\n`);
}
async function main() {
function main() {
const encoding = 'utf-8';
if (process.stdin.isTTY) {

View File

@ -6,6 +6,7 @@ import { appStarted } from 'app/store/middleware/listenerMiddleware/listeners/ap
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import type { PartialAppConfig } from 'app/types/invokeai';
import ImageUploadOverlay from 'common/components/ImageUploadOverlay';
import { useScopeFocusWatcher } from 'common/hooks/interactionScopes';
import { useClearStorage } from 'common/hooks/useClearStorage';
import { useFullscreenDropzone } from 'common/hooks/useFullscreenDropzone';
import { useGlobalHotkeys } from 'common/hooks/useGlobalHotkeys';
@ -13,12 +14,15 @@ import ChangeBoardModal from 'features/changeBoardModal/components/ChangeBoardMo
import DeleteImageModal from 'features/deleteImageModal/components/DeleteImageModal';
import { DynamicPromptsModal } from 'features/dynamicPrompts/components/DynamicPromptsPreviewModal';
import { useStarterModelsToast } from 'features/modelManagerV2/hooks/useStarterModelsToast';
import { ClearQueueConfirmationsAlertDialog } from 'features/queue/components/ClearQueueConfirmationAlertDialog';
import { StylePresetModal } from 'features/stylePresets/components/StylePresetForm/StylePresetModal';
import RefreshAfterResetModal from 'features/system/components/SettingsModal/RefreshAfterResetModal';
import SettingsModal from 'features/system/components/SettingsModal/SettingsModal';
import { configChanged } from 'features/system/store/configSlice';
import { languageSelector } from 'features/system/store/systemSelectors';
import InvokeTabs from 'features/ui/components/InvokeTabs';
import type { InvokeTabName } from 'features/ui/store/tabMap';
import { selectLanguage } from 'features/system/store/systemSelectors';
import { AppContent } from 'features/ui/components/AppContent';
import { setActiveTab } from 'features/ui/store/uiSlice';
import type { TabName } from 'features/ui/store/uiTypes';
import { useGetAndLoadLibraryWorkflow } from 'features/workflowLibrary/hooks/useGetAndLoadLibraryWorkflow';
import { AnimatePresence } from 'framer-motion';
import i18n from 'i18n';
@ -39,11 +43,11 @@ interface Props {
action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters';
};
selectedWorkflowId?: string;
destination?: InvokeTabName | undefined;
destination?: TabName | undefined;
}
const App = ({ config = DEFAULT_CONFIG, selectedImage, selectedWorkflowId, destination }: Props) => {
const language = useAppSelector(languageSelector);
const language = useAppSelector(selectLanguage);
const logger = useLogger('system');
const dispatch = useAppDispatch();
const clearStorage = useClearStorage();
@ -93,6 +97,7 @@ const App = ({ config = DEFAULT_CONFIG, selectedImage, selectedWorkflowId, desti
useStarterModelsToast();
useSyncQueueStatus();
useScopeFocusWatcher();
return (
<ErrorBoundary onReset={handleReset} FallbackComponent={AppErrorBoundaryFallback}>
@ -105,7 +110,7 @@ const App = ({ config = DEFAULT_CONFIG, selectedImage, selectedWorkflowId, desti
{...dropzone.getRootProps()}
>
<input {...dropzone.getInputProps()} />
<InvokeTabs />
<AppContent />
<AnimatePresence>
{dropzone.isDragActive && isHandlingUpload && (
<ImageUploadOverlay dropzone={dropzone} setIsHandlingUpload={setIsHandlingUpload} />
@ -116,7 +121,10 @@ const App = ({ config = DEFAULT_CONFIG, selectedImage, selectedWorkflowId, desti
<ChangeBoardModal />
<DynamicPromptsModal />
<StylePresetModal />
<ClearQueueConfirmationsAlertDialog />
<PreselectedImage selectedImage={selectedImage} />
<SettingsModal />
<RefreshAfterResetModal />
</ErrorBoundary>
);
};

View File

@ -1,5 +1,7 @@
import { Button, Flex, Heading, Image, Link, Text } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { selectConfigSlice } from 'features/system/store/configSlice';
import { toast } from 'features/toast/toast';
import newGithubIssueUrl from 'new-github-issue-url';
import InvokeLogoYellow from 'public/assets/images/invoke-symbol-ylw-lrg.svg';
@ -13,9 +15,11 @@ type Props = {
resetErrorBoundary: () => void;
};
const selectIsLocal = createSelector(selectConfigSlice, (config) => config.isLocal);
const AppErrorBoundaryFallback = ({ error, resetErrorBoundary }: Props) => {
const { t } = useTranslation();
const isLocal = useAppSelector((s) => s.config.isLocal);
const isLocal = useAppSelector(selectIsLocal);
const handleCopy = useCallback(() => {
const text = JSON.stringify(serializeError(error), null, 2);

View File

@ -19,7 +19,7 @@ import type { PartialAppConfig } from 'app/types/invokeai';
import Loading from 'common/components/Loading/Loading';
import AppDndContext from 'features/dnd/components/AppDndContext';
import type { WorkflowCategory } from 'features/nodes/types/workflow';
import type { InvokeTabName } from 'features/ui/store/tabMap';
import type { TabName } from 'features/ui/store/uiTypes';
import type { PropsWithChildren, ReactNode } from 'react';
import React, { lazy, memo, useEffect, useMemo } from 'react';
import { Provider } from 'react-redux';
@ -45,7 +45,7 @@ interface Props extends PropsWithChildren {
action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters';
};
selectedWorkflowId?: string;
destination?: InvokeTabName;
destination?: TabName;
customStarUi?: CustomStarUi;
socketOptions?: Partial<ManagerOptions & SocketOptions>;
isDebugging?: boolean;

View File

@ -2,7 +2,7 @@ import { useStore } from '@nanostores/react';
import { $authToken } from 'app/store/nanostores/authToken';
import { $baseUrl } from 'app/store/nanostores/baseUrl';
import { $isDebugging } from 'app/store/nanostores/isDebugging';
import { useAppDispatch } from 'app/store/storeHooks';
import { useAppStore } from 'app/store/nanostores/store';
import type { MapStore } from 'nanostores';
import { atom, map } from 'nanostores';
import { useEffect, useMemo } from 'react';
@ -18,14 +18,19 @@ declare global {
}
}
export type AppSocket = Socket<ServerToClientEvents, ClientToServerEvents>;
export const $socket = atom<AppSocket | null>(null);
export const $socketOptions = map<Partial<ManagerOptions & SocketOptions>>({});
const $isSocketInitialized = atom<boolean>(false);
export const $isConnected = atom<boolean>(false);
/**
* Initializes the socket.io connection and sets up event listeners.
*/
export const useSocketIO = () => {
const dispatch = useAppDispatch();
const { dispatch, getState } = useAppStore();
const baseUrl = useStore($baseUrl);
const authToken = useStore($authToken);
const addlSocketOptions = useStore($socketOptions);
@ -61,8 +66,9 @@ export const useSocketIO = () => {
return;
}
const socket: Socket<ServerToClientEvents, ClientToServerEvents> = io(socketUrl, socketOptions);
setEventListeners({ dispatch, socket });
const socket: AppSocket = io(socketUrl, socketOptions);
$socket.set(socket);
setEventListeners({ socket, dispatch, getState, setIsConnected: $isConnected.set });
socket.connect();
if ($isDebugging.get() || import.meta.env.MODE === 'development') {
@ -84,5 +90,5 @@ export const useSocketIO = () => {
socket.disconnect();
$isSocketInitialized.set(false);
};
}, [dispatch, socketOptions, socketUrl]);
}, [dispatch, getState, socketOptions, socketUrl]);
};

View File

@ -15,21 +15,21 @@ export const BASE_CONTEXT = {};
export const $logger = atom<Logger>(Roarr.child(BASE_CONTEXT));
export type LoggerNamespace =
| 'images'
| 'models'
| 'config'
| 'canvas'
| 'generation'
| 'nodes'
| 'system'
| 'socketio'
| 'session'
| 'queue'
| 'dnd'
| 'controlLayers';
export const zLogNamespace = z.enum([
'canvas',
'config',
'events',
'gallery',
'generation',
'metadata',
'models',
'system',
'queue',
'workflows',
]);
export type LogNamespace = z.infer<typeof zLogNamespace>;
export const logger = (namespace: LoggerNamespace) => $logger.get().child({ namespace });
export const logger = (namespace: LogNamespace) => $logger.get().child({ namespace });
export const zLogLevel = z.enum(['trace', 'debug', 'info', 'warn', 'error', 'fatal']);
export type LogLevel = z.infer<typeof zLogLevel>;

View File

@ -1,29 +1,41 @@
import { createLogWriter } from '@roarr/browser-log-writer';
import { useAppSelector } from 'app/store/storeHooks';
import {
selectSystemLogIsEnabled,
selectSystemLogLevel,
selectSystemLogNamespaces,
} from 'features/system/store/systemSlice';
import { useEffect, useMemo } from 'react';
import { ROARR, Roarr } from 'roarr';
import type { LoggerNamespace } from './logger';
import type { LogNamespace } from './logger';
import { $logger, BASE_CONTEXT, LOG_LEVEL_MAP, logger } from './logger';
export const useLogger = (namespace: LoggerNamespace) => {
const consoleLogLevel = useAppSelector((s) => s.system.consoleLogLevel);
const shouldLogToConsole = useAppSelector((s) => s.system.shouldLogToConsole);
export const useLogger = (namespace: LogNamespace) => {
const logLevel = useAppSelector(selectSystemLogLevel);
const logNamespaces = useAppSelector(selectSystemLogNamespaces);
const logIsEnabled = useAppSelector(selectSystemLogIsEnabled);
// The provided Roarr browser log writer uses localStorage to config logging to console
useEffect(() => {
if (shouldLogToConsole) {
if (logIsEnabled) {
// Enable console log output
localStorage.setItem('ROARR_LOG', 'true');
// Use a filter to show only logs of the given level
localStorage.setItem('ROARR_FILTER', `context.logLevel:>=${LOG_LEVEL_MAP[consoleLogLevel]}`);
let filter = `context.logLevel:>=${LOG_LEVEL_MAP[logLevel]}`;
if (logNamespaces.length > 0) {
filter += ` AND (${logNamespaces.map((ns) => `context.namespace:${ns}`).join(' OR ')})`;
} else {
filter += ' AND context.namespace:undefined';
}
localStorage.setItem('ROARR_FILTER', filter);
} else {
// Disable console log output
localStorage.setItem('ROARR_LOG', 'false');
}
ROARR.write = createLogWriter();
}, [consoleLogLevel, shouldLogToConsole]);
}, [logLevel, logIsEnabled, logNamespaces]);
// Update the module-scoped logger context as needed
useEffect(() => {

View File

@ -1,7 +1,7 @@
import { createAction } from '@reduxjs/toolkit';
import type { InvokeTabName } from 'features/ui/store/tabMap';
import type { TabName } from 'features/ui/store/uiTypes';
export const enqueueRequested = createAction<{
tabName: InvokeTabName;
tabName: TabName;
prepend: boolean;
}>('app/enqueueRequested');

View File

@ -1,2 +1,3 @@
export const STORAGE_PREFIX = '@@invokeai-';
export const EMPTY_ARRAY = [];
export const EMPTY_OBJECT = {};

View File

@ -1,5 +1,6 @@
import { createDraftSafeSelectorCreator, createSelectorCreator, lruMemoize } from '@reduxjs/toolkit';
import type { GetSelectorsOptions } from '@reduxjs/toolkit/dist/entities/state_selectors';
import type { RootState } from 'app/store/store';
import { isEqual } from 'lodash-es';
/**
@ -19,3 +20,5 @@ export const getSelectorsOptions: GetSelectorsOptions = {
argsMemoize: lruMemoize,
}),
};
export const createMemoizedAppSelector = createMemoizedSelector.withTypes<RootState>();

View File

@ -1,5 +1,4 @@
import { logger } from 'app/logging/logger';
import { parseify } from 'common/util/serialize';
import { PersistError, RehydrateError } from 'redux-remember';
import { serializeError } from 'serialize-error';
@ -41,6 +40,6 @@ export const errorHandler = (err: PersistError | RehydrateError) => {
} else if (err instanceof RehydrateError) {
log.error({ error: serializeError(err) }, 'Problem rehydrating state');
} else {
log.error({ error: parseify(err) }, 'Problem in persistence layer');
log.error({ error: serializeError(err) }, 'Problem in persistence layer');
}
};

View File

@ -1,9 +1,7 @@
import type { UnknownAction } from '@reduxjs/toolkit';
import { deepClone } from 'common/util/deepClone';
import { isAnyGraphBuilt } from 'features/nodes/store/actions';
import { appInfoApi } from 'services/api/endpoints/appInfo';
import type { Graph } from 'services/api/types';
import { socketGeneratorProgress } from 'services/events/actions';
export const actionSanitizer = <A extends UnknownAction>(action: A): A => {
if (isAnyGraphBuilt(action)) {
@ -24,13 +22,5 @@ export const actionSanitizer = <A extends UnknownAction>(action: A): A => {
};
}
if (socketGeneratorProgress.match(action)) {
const sanitized = deepClone(action);
if (sanitized.payload.data.progress_image) {
sanitized.payload.data.progress_image.dataURL = '<Progress image omitted>';
}
return sanitized;
}
return action;
};

View File

@ -1,7 +1,7 @@
import type { TypedStartListening } from '@reduxjs/toolkit';
import { createListenerMiddleware } from '@reduxjs/toolkit';
import { addAdHocPostProcessingRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/addAdHocPostProcessingRequestedListener';
import { addCommitStagingAreaImageListener } from 'app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener';
import { addStagingListeners } from 'app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener';
import { addAnyEnqueuedListener } from 'app/store/middleware/listenerMiddleware/listeners/anyEnqueued';
import { addAppConfigReceivedListener } from 'app/store/middleware/listenerMiddleware/listeners/appConfigReceived';
import { addAppStartedListener } from 'app/store/middleware/listenerMiddleware/listeners/appStarted';
@ -9,17 +9,6 @@ import { addBatchEnqueuedListener } from 'app/store/middleware/listenerMiddlewar
import { addDeleteBoardAndImagesFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/boardAndImagesDeleted';
import { addBoardIdSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/boardIdSelected';
import { addBulkDownloadListeners } from 'app/store/middleware/listenerMiddleware/listeners/bulkDownload';
import { addCanvasCopiedToClipboardListener } from 'app/store/middleware/listenerMiddleware/listeners/canvasCopiedToClipboard';
import { addCanvasDownloadedAsImageListener } from 'app/store/middleware/listenerMiddleware/listeners/canvasDownloadedAsImage';
import { addCanvasImageToControlNetListener } from 'app/store/middleware/listenerMiddleware/listeners/canvasImageToControlNet';
import { addCanvasMaskSavedToGalleryListener } from 'app/store/middleware/listenerMiddleware/listeners/canvasMaskSavedToGallery';
import { addCanvasMaskToControlNetListener } from 'app/store/middleware/listenerMiddleware/listeners/canvasMaskToControlNet';
import { addCanvasMergedListener } from 'app/store/middleware/listenerMiddleware/listeners/canvasMerged';
import { addCanvasSavedToGalleryListener } from 'app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery';
import { addControlAdapterPreprocessor } from 'app/store/middleware/listenerMiddleware/listeners/controlAdapterPreprocessor';
import { addControlNetAutoProcessListener } from 'app/store/middleware/listenerMiddleware/listeners/controlNetAutoProcess';
import { addControlNetImageProcessedListener } from 'app/store/middleware/listenerMiddleware/listeners/controlNetImageProcessed';
import { addEnqueueRequestedCanvasListener } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedCanvas';
import { addEnqueueRequestedLinear } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear';
import { addEnqueueRequestedNodes } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedNodes';
import { addGalleryImageClickedListener } from 'app/store/middleware/listenerMiddleware/listeners/galleryImageClicked';
@ -37,16 +26,7 @@ import { addModelSelectedListener } from 'app/store/middleware/listenerMiddlewar
import { addModelsLoadedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelsLoaded';
import { addDynamicPromptsListener } from 'app/store/middleware/listenerMiddleware/listeners/promptChanged';
import { addSetDefaultSettingsListener } from 'app/store/middleware/listenerMiddleware/listeners/setDefaultSettings';
import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected';
import { addSocketDisconnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketDisconnected';
import { addGeneratorProgressEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress';
import { addInvocationCompleteEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete';
import { addInvocationErrorEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError';
import { addInvocationStartedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationStarted';
import { addModelInstallEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall';
import { addModelLoadEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelLoad';
import { addSocketQueueItemStatusChangedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueItemStatusChanged';
import { addStagingAreaImageSavedListener } from 'app/store/middleware/listenerMiddleware/listeners/stagingAreaImageSaved';
import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketConnected';
import { addUpdateAllNodesRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested';
import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested';
import type { AppDispatch, RootState } from 'app/store/store';
@ -83,7 +63,6 @@ addGalleryImageClickedListener(startAppListening);
addGalleryOffsetChangedListener(startAppListening);
// User Invoked
addEnqueueRequestedCanvasListener(startAppListening);
addEnqueueRequestedNodes(startAppListening);
addEnqueueRequestedLinear(startAppListening);
addEnqueueRequestedUpscale(startAppListening);
@ -91,31 +70,22 @@ addAnyEnqueuedListener(startAppListening);
addBatchEnqueuedListener(startAppListening);
// Canvas actions
addCanvasSavedToGalleryListener(startAppListening);
addCanvasMaskSavedToGalleryListener(startAppListening);
addCanvasImageToControlNetListener(startAppListening);
addCanvasMaskToControlNetListener(startAppListening);
addCanvasDownloadedAsImageListener(startAppListening);
addCanvasCopiedToClipboardListener(startAppListening);
addCanvasMergedListener(startAppListening);
addStagingAreaImageSavedListener(startAppListening);
addCommitStagingAreaImageListener(startAppListening);
// addCanvasSavedToGalleryListener(startAppListening);
// addCanvasMaskSavedToGalleryListener(startAppListening);
// addCanvasImageToControlNetListener(startAppListening);
// addCanvasMaskToControlNetListener(startAppListening);
// addCanvasDownloadedAsImageListener(startAppListening);
// addCanvasCopiedToClipboardListener(startAppListening);
// addCanvasMergedListener(startAppListening);
// addStagingAreaImageSavedListener(startAppListening);
// addCommitStagingAreaImageListener(startAppListening);
addStagingListeners(startAppListening);
// Socket.IO
addGeneratorProgressEventListener(startAppListening);
addInvocationCompleteEventListener(startAppListening);
addInvocationErrorEventListener(startAppListening);
addInvocationStartedEventListener(startAppListening);
addSocketConnectedEventListener(startAppListening);
addSocketDisconnectedEventListener(startAppListening);
addModelLoadEventListener(startAppListening);
addModelInstallEventListener(startAppListening);
addSocketQueueItemStatusChangedEventListener(startAppListening);
addBulkDownloadListeners(startAppListening);
// ControlNet
addControlNetImageProcessedListener(startAppListening);
addControlNetAutoProcessListener(startAppListening);
// Gallery bulk download
addBulkDownloadListeners(startAppListening);
// Boards
addImageAddedToBoardFulfilledListener(startAppListening);
@ -148,4 +118,4 @@ addAdHocPostProcessingRequestedListener(startAppListening);
addDynamicPromptsListener(startAppListening);
addSetDefaultSettingsListener(startAppListening);
addControlAdapterPreprocessor(startAppListening);
// addControlAdapterPreprocessor(startAppListening);

View File

@ -1,21 +1,21 @@
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { parseify } from 'common/util/serialize';
import type { SerializableObject } from 'common/types';
import { buildAdHocPostProcessingGraph } from 'features/nodes/util/graph/buildAdHocPostProcessingGraph';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { queueApi } from 'services/api/endpoints/queue';
import type { BatchConfig, ImageDTO } from 'services/api/types';
const log = logger('queue');
export const adHocPostProcessingRequested = createAction<{ imageDTO: ImageDTO }>(`upscaling/postProcessingRequested`);
export const addAdHocPostProcessingRequestedListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: adHocPostProcessingRequested,
effect: async (action, { dispatch, getState }) => {
const log = logger('session');
const { imageDTO } = action.payload;
const state = getState();
@ -39,9 +39,9 @@ export const addAdHocPostProcessingRequestedListener = (startAppListening: AppSt
const enqueueResult = await req.unwrap();
req.reset();
log.debug({ enqueueResult: parseify(enqueueResult) }, t('queue.graphQueued'));
log.debug({ enqueueResult } as SerializableObject, t('queue.graphQueued'));
} catch (error) {
log.error({ enqueueBatchArg: parseify(enqueueBatchArg) }, t('queue.graphFailedToQueue'));
log.error({ enqueueBatchArg } as SerializableObject, t('queue.graphFailedToQueue'));
if (error instanceof Object && 'status' in error && error.status === 403) {
return;

View File

@ -23,7 +23,7 @@ export const addArchivedOrDeletedBoardListener = (startAppListening: AppStartLis
*/
startAppListening({
matcher: matchAnyBoardDeleted,
effect: async (action, { dispatch, getState }) => {
effect: (action, { dispatch, getState }) => {
const state = getState();
const deletedBoardId = action.meta.arg.originalArgs;
const { autoAddBoardId, selectedBoardId } = state.gallery;
@ -44,7 +44,7 @@ export const addArchivedOrDeletedBoardListener = (startAppListening: AppStartLis
// If we archived a board, it may end up hidden. If it's selected or the auto-add board, we should reset those.
startAppListening({
matcher: boardsApi.endpoints.updateBoard.matchFulfilled,
effect: async (action, { dispatch, getState }) => {
effect: (action, { dispatch, getState }) => {
const state = getState();
const { shouldShowArchivedBoards } = state.gallery;
@ -61,7 +61,7 @@ export const addArchivedOrDeletedBoardListener = (startAppListening: AppStartLis
// When we hide archived boards, if the selected or the auto-add board is archived, we should reset those.
startAppListening({
actionCreator: shouldShowArchivedBoardsChanged,
effect: async (action, { dispatch, getState }) => {
effect: (action, { dispatch, getState }) => {
const shouldShowArchivedBoards = action.payload;
// We only need to take action if we have just hidden archived boards.
@ -100,7 +100,7 @@ export const addArchivedOrDeletedBoardListener = (startAppListening: AppStartLis
*/
startAppListening({
matcher: boardsApi.endpoints.listAllBoards.matchFulfilled,
effect: async (action, { dispatch, getState }) => {
effect: (action, { dispatch, getState }) => {
const boards = action.payload;
const state = getState();
const { selectedBoardId, autoAddBoardId } = state.gallery;

View File

@ -1,33 +1,37 @@
import { isAnyOf } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import {
canvasBatchIdsReset,
commitStagingAreaImage,
discardStagedImages,
resetCanvas,
setInitialCanvasImage,
} from 'features/canvas/store/canvasSlice';
sessionStagingAreaImageAccepted,
sessionStagingAreaReset,
} from 'features/controlLayers/store/canvasSessionSlice';
import { rasterLayerAdded } from 'features/controlLayers/store/canvasSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import type { CanvasRasterLayerState } from 'features/controlLayers/store/types';
import { imageDTOToImageObject } from 'features/controlLayers/store/types';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { queueApi } from 'services/api/endpoints/queue';
import { $lastCanvasProgressEvent } from 'services/events/setEventListeners';
import { assert } from 'tsafe';
const matcher = isAnyOf(commitStagingAreaImage, discardStagedImages, resetCanvas, setInitialCanvasImage);
const log = logger('canvas');
export const addCommitStagingAreaImageListener = (startAppListening: AppStartListening) => {
export const addStagingListeners = (startAppListening: AppStartListening) => {
startAppListening({
matcher,
effect: async (_, { dispatch, getState }) => {
const log = logger('canvas');
const state = getState();
const { batchIds } = state.canvas;
actionCreator: sessionStagingAreaReset,
effect: async (_, { dispatch }) => {
try {
const req = dispatch(
queueApi.endpoints.cancelByBatchIds.initiate({ batch_ids: batchIds }, { fixedCacheKey: 'cancelByBatchIds' })
queueApi.endpoints.cancelByBatchOrigin.initiate(
{ origin: 'canvas' },
{ fixedCacheKey: 'cancelByBatchOrigin' }
)
);
const { canceled } = await req.unwrap();
req.reset();
$lastCanvasProgressEvent.set(null);
if (canceled > 0) {
log.debug(`Canceled ${canceled} canvas batches`);
toast({
@ -36,7 +40,6 @@ export const addCommitStagingAreaImageListener = (startAppListening: AppStartLis
status: 'success',
});
}
dispatch(canvasBatchIdsReset());
} catch {
log.error('Failed to cancel canvas batches');
toast({
@ -47,4 +50,26 @@ export const addCommitStagingAreaImageListener = (startAppListening: AppStartLis
}
},
});
startAppListening({
actionCreator: sessionStagingAreaImageAccepted,
effect: (action, api) => {
const { index } = action.payload;
const state = api.getState();
const stagingAreaImage = state.canvasSession.stagedImages[index];
assert(stagingAreaImage, 'No staged image found to accept');
const { x, y } = selectCanvasSlice(state).bbox.rect;
const { imageDTO, offsetX, offsetY } = stagingAreaImage;
const imageObject = imageDTOToImageObject(imageDTO);
const overrides: Partial<CanvasRasterLayerState> = {
position: { x: x + offsetX, y: y + offsetY },
objects: [imageObject],
};
api.dispatch(rasterLayerAdded({ overrides, isSelected: true }));
api.dispatch(sessionStagingAreaReset());
},
});
};

View File

@ -4,7 +4,7 @@ import { queueApi, selectQueueStatus } from 'services/api/endpoints/queue';
export const addAnyEnqueuedListener = (startAppListening: AppStartListening) => {
startAppListening({
matcher: queueApi.endpoints.enqueueBatch.matchFulfilled,
effect: async (_, { dispatch, getState }) => {
effect: (_, { dispatch, getState }) => {
const { data } = selectQueueStatus(getState());
if (!data || data.processor.is_started) {

View File

@ -1,14 +1,14 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { setInfillMethod } from 'features/parameters/store/generationSlice';
import { setInfillMethod } from 'features/controlLayers/store/paramsSlice';
import { shouldUseNSFWCheckerChanged, shouldUseWatermarkerChanged } from 'features/system/store/systemSlice';
import { appInfoApi } from 'services/api/endpoints/appInfo';
export const addAppConfigReceivedListener = (startAppListening: AppStartListening) => {
startAppListening({
matcher: appInfoApi.endpoints.getAppConfig.matchFulfilled,
effect: async (action, { getState, dispatch }) => {
effect: (action, { getState, dispatch }) => {
const { infill_methods = [], nsfw_methods = [], watermarking_methods = [] } = action.payload;
const infillMethod = getState().generation.infillMethod;
const infillMethod = getState().params.infillMethod;
if (!infill_methods.includes(infillMethod)) {
// if there is no infill method, set it to the first one

View File

@ -6,7 +6,7 @@ export const appStarted = createAction('app/appStarted');
export const addAppStartedListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: appStarted,
effect: async (action, { unsubscribe, cancelActiveListeners }) => {
effect: (action, { unsubscribe, cancelActiveListeners }) => {
// this should only run once
cancelActiveListeners();
unsubscribe();

View File

@ -1,27 +1,30 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { parseify } from 'common/util/serialize';
import type { SerializableObject } from 'common/types';
import { zPydanticValidationError } from 'features/system/store/zodSchemas';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { truncate, upperFirst } from 'lodash-es';
import { serializeError } from 'serialize-error';
import { queueApi } from 'services/api/endpoints/queue';
const log = logger('queue');
export const addBatchEnqueuedListener = (startAppListening: AppStartListening) => {
// success
startAppListening({
matcher: queueApi.endpoints.enqueueBatch.matchFulfilled,
effect: async (action) => {
const response = action.payload;
effect: (action) => {
const enqueueResult = action.payload;
const arg = action.meta.arg.originalArgs;
logger('queue').debug({ enqueueResult: parseify(response) }, 'Batch enqueued');
log.debug({ enqueueResult } as SerializableObject, 'Batch enqueued');
toast({
id: 'QUEUE_BATCH_SUCCEEDED',
title: t('queue.batchQueued'),
status: 'success',
description: t('queue.batchQueuedDesc', {
count: response.enqueued,
count: enqueueResult.enqueued,
direction: arg.prepend ? t('queue.front') : t('queue.back'),
}),
});
@ -31,9 +34,9 @@ export const addBatchEnqueuedListener = (startAppListening: AppStartListening) =
// error
startAppListening({
matcher: queueApi.endpoints.enqueueBatch.matchRejected,
effect: async (action) => {
effect: (action) => {
const response = action.payload;
const arg = action.meta.arg.originalArgs;
const batchConfig = action.meta.arg.originalArgs;
if (!response) {
toast({
@ -42,7 +45,7 @@ export const addBatchEnqueuedListener = (startAppListening: AppStartListening) =
status: 'error',
description: t('common.unknownError'),
});
logger('queue').error({ batchConfig: parseify(arg), error: parseify(response) }, t('queue.batchFailedToQueue'));
log.error({ batchConfig } as SerializableObject, t('queue.batchFailedToQueue'));
return;
}
@ -68,7 +71,7 @@ export const addBatchEnqueuedListener = (startAppListening: AppStartListening) =
description: t('common.unknownError'),
});
}
logger('queue').error({ batchConfig: parseify(arg), error: parseify(response) }, t('queue.batchFailedToQueue'));
log.error({ batchConfig, error: serializeError(response) } as SerializableObject, t('queue.batchFailedToQueue'));
},
});
};

View File

@ -1,47 +1,31 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { resetCanvas } from 'features/canvas/store/canvasSlice';
import { controlAdaptersReset } from 'features/controlAdapters/store/controlAdaptersSlice';
import { allLayersDeleted } from 'features/controlLayers/store/controlLayersSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { getImageUsage } from 'features/deleteImageModal/store/selectors';
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
import { selectNodesSlice } from 'features/nodes/store/selectors';
import { imagesApi } from 'services/api/endpoints/images';
export const addDeleteBoardAndImagesFulfilledListener = (startAppListening: AppStartListening) => {
startAppListening({
matcher: imagesApi.endpoints.deleteBoardAndImages.matchFulfilled,
effect: async (action, { dispatch, getState }) => {
effect: (action, { dispatch, getState }) => {
const { deleted_images } = action.payload;
// Remove all deleted images from the UI
let wasCanvasReset = false;
let wasNodeEditorReset = false;
let wereControlAdaptersReset = false;
let wereControlLayersReset = false;
const { canvas, nodes, controlAdapters, controlLayers } = getState();
const state = getState();
const nodes = selectNodesSlice(state);
const canvas = selectCanvasSlice(state);
deleted_images.forEach((image_name) => {
const imageUsage = getImageUsage(canvas, nodes.present, controlAdapters, controlLayers.present, image_name);
if (imageUsage.isCanvasImage && !wasCanvasReset) {
dispatch(resetCanvas());
wasCanvasReset = true;
}
const imageUsage = getImageUsage(nodes, canvas, image_name);
if (imageUsage.isNodesImage && !wasNodeEditorReset) {
dispatch(nodeEditorReset());
wasNodeEditorReset = true;
}
if (imageUsage.isControlImage && !wereControlAdaptersReset) {
dispatch(controlAdaptersReset());
wereControlAdaptersReset = true;
}
if (imageUsage.isControlLayerImage && !wereControlLayersReset) {
dispatch(allLayersDeleted());
wereControlLayersReset = true;
}
});
},
});

View File

@ -1,21 +1,15 @@
import { ExternalLink } from '@invoke-ai/ui-library';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';
import {
socketBulkDownloadComplete,
socketBulkDownloadError,
socketBulkDownloadStarted,
} from 'services/events/actions';
const log = logger('images');
const log = logger('gallery');
export const addBulkDownloadListeners = (startAppListening: AppStartListening) => {
startAppListening({
matcher: imagesApi.endpoints.bulkDownloadImages.matchFulfilled,
effect: async (action) => {
effect: (action) => {
log.debug(action.payload, 'Bulk download requested');
// If we have an item name, we are processing the bulk download locally and should use it as the toast id to
@ -33,7 +27,7 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) =
startAppListening({
matcher: imagesApi.endpoints.bulkDownloadImages.matchRejected,
effect: async () => {
effect: () => {
log.debug('Bulk download request failed');
// There isn't any toast to update if we get this event.
@ -44,55 +38,4 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) =
});
},
});
startAppListening({
actionCreator: socketBulkDownloadStarted,
effect: async (action) => {
// This should always happen immediately after the bulk download request, so we don't need to show a toast here.
log.debug(action.payload.data, 'Bulk download preparation started');
},
});
startAppListening({
actionCreator: socketBulkDownloadComplete,
effect: async (action) => {
log.debug(action.payload.data, 'Bulk download preparation completed');
const { bulk_download_item_name } = action.payload.data;
// TODO(psyche): This URL may break in in some environments (e.g. Nvidia workbench) but we need to test it first
const url = `/api/v1/images/download/${bulk_download_item_name}`;
toast({
id: bulk_download_item_name,
title: t('gallery.bulkDownloadReady', 'Download ready'),
status: 'success',
description: (
<ExternalLink
label={t('gallery.clickToDownload', 'Click here to download')}
href={url}
download={bulk_download_item_name}
/>
),
duration: null,
});
},
});
startAppListening({
actionCreator: socketBulkDownloadError,
effect: async (action) => {
log.debug(action.payload.data, 'Bulk download preparation failed');
const { bulk_download_item_name } = action.payload.data;
toast({
id: bulk_download_item_name,
title: t('gallery.bulkDownloadFailed'),
status: 'error',
description: action.payload.data.error,
duration: null,
});
},
});
};

View File

@ -1,38 +0,0 @@
import { $logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { canvasCopiedToClipboard } from 'features/canvas/store/actions';
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
import { copyBlobToClipboard } from 'features/system/util/copyBlobToClipboard';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
export const addCanvasCopiedToClipboardListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: canvasCopiedToClipboard,
effect: async (action, { getState }) => {
const moduleLog = $logger.get().child({ namespace: 'canvasCopiedToClipboardListener' });
const state = getState();
try {
const blob = getBaseLayerBlob(state);
copyBlobToClipboard(blob);
} catch (err) {
moduleLog.error(String(err));
toast({
id: 'CANVAS_COPY_FAILED',
title: t('toast.problemCopyingCanvas'),
description: t('toast.problemCopyingCanvasDesc'),
status: 'error',
});
return;
}
toast({
id: 'CANVAS_COPY_SUCCEEDED',
title: t('toast.canvasCopiedClipboard'),
status: 'success',
});
},
});
};

View File

@ -1,34 +0,0 @@
import { $logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { canvasDownloadedAsImage } from 'features/canvas/store/actions';
import { downloadBlob } from 'features/canvas/util/downloadBlob';
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
export const addCanvasDownloadedAsImageListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: canvasDownloadedAsImage,
effect: async (action, { getState }) => {
const moduleLog = $logger.get().child({ namespace: 'canvasSavedToGalleryListener' });
const state = getState();
let blob;
try {
blob = await getBaseLayerBlob(state);
} catch (err) {
moduleLog.error(String(err));
toast({
id: 'CANVAS_DOWNLOAD_FAILED',
title: t('toast.problemDownloadingCanvas'),
description: t('toast.problemDownloadingCanvasDesc'),
status: 'error',
});
return;
}
downloadBlob(blob, 'canvas.png');
toast({ id: 'CANVAS_DOWNLOAD_SUCCEEDED', title: t('toast.canvasDownloaded'), status: 'success' });
},
});
};

View File

@ -1,60 +0,0 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { canvasImageToControlAdapter } from 'features/canvas/store/actions';
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
import { controlAdapterImageChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';
export const addCanvasImageToControlNetListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: canvasImageToControlAdapter,
effect: async (action, { dispatch, getState }) => {
const log = logger('canvas');
const state = getState();
const { id } = action.payload;
let blob: Blob;
try {
blob = await getBaseLayerBlob(state, true);
} catch (err) {
log.error(String(err));
toast({
id: 'PROBLEM_SAVING_CANVAS',
title: t('toast.problemSavingCanvas'),
description: t('toast.problemSavingCanvasDesc'),
status: 'error',
});
return;
}
const { autoAddBoardId } = state.gallery;
const imageDTO = await dispatch(
imagesApi.endpoints.uploadImage.initiate({
file: new File([blob], 'savedCanvas.png', {
type: 'image/png',
}),
image_category: 'control',
is_intermediate: true,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
crop_visible: false,
postUploadAction: {
type: 'TOAST',
title: t('toast.canvasSentControlnetAssets'),
},
})
).unwrap();
const { image_name } = imageDTO;
dispatch(
controlAdapterImageChanged({
id,
controlImage: image_name,
})
);
},
});
};

View File

@ -1,60 +0,0 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { canvasMaskSavedToGallery } from 'features/canvas/store/actions';
import { getCanvasData } from 'features/canvas/util/getCanvasData';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';
export const addCanvasMaskSavedToGalleryListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: canvasMaskSavedToGallery,
effect: async (action, { dispatch, getState }) => {
const log = logger('canvas');
const state = getState();
const canvasBlobsAndImageData = await getCanvasData(
state.canvas.layerState,
state.canvas.boundingBoxCoordinates,
state.canvas.boundingBoxDimensions,
state.canvas.isMaskEnabled,
state.canvas.shouldPreserveMaskedArea
);
if (!canvasBlobsAndImageData) {
return;
}
const { maskBlob } = canvasBlobsAndImageData;
if (!maskBlob) {
log.error('Problem getting mask layer blob');
toast({
id: 'PROBLEM_SAVING_MASK',
title: t('toast.problemSavingMask'),
description: t('toast.problemSavingMaskDesc'),
status: 'error',
});
return;
}
const { autoAddBoardId } = state.gallery;
dispatch(
imagesApi.endpoints.uploadImage.initiate({
file: new File([maskBlob], 'canvasMaskImage.png', {
type: 'image/png',
}),
image_category: 'mask',
is_intermediate: false,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
crop_visible: true,
postUploadAction: {
type: 'TOAST',
title: t('toast.maskSavedAssets'),
},
})
);
},
});
};

View File

@ -1,70 +0,0 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { canvasMaskToControlAdapter } from 'features/canvas/store/actions';
import { getCanvasData } from 'features/canvas/util/getCanvasData';
import { controlAdapterImageChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';
export const addCanvasMaskToControlNetListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: canvasMaskToControlAdapter,
effect: async (action, { dispatch, getState }) => {
const log = logger('canvas');
const state = getState();
const { id } = action.payload;
const canvasBlobsAndImageData = await getCanvasData(
state.canvas.layerState,
state.canvas.boundingBoxCoordinates,
state.canvas.boundingBoxDimensions,
state.canvas.isMaskEnabled,
state.canvas.shouldPreserveMaskedArea
);
if (!canvasBlobsAndImageData) {
return;
}
const { maskBlob } = canvasBlobsAndImageData;
if (!maskBlob) {
log.error('Problem getting mask layer blob');
toast({
id: 'PROBLEM_IMPORTING_MASK',
title: t('toast.problemImportingMask'),
description: t('toast.problemImportingMaskDesc'),
status: 'error',
});
return;
}
const { autoAddBoardId } = state.gallery;
const imageDTO = await dispatch(
imagesApi.endpoints.uploadImage.initiate({
file: new File([maskBlob], 'canvasMaskImage.png', {
type: 'image/png',
}),
image_category: 'mask',
is_intermediate: true,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
crop_visible: false,
postUploadAction: {
type: 'TOAST',
title: t('toast.maskSentControlnetAssets'),
},
})
).unwrap();
const { image_name } = imageDTO;
dispatch(
controlAdapterImageChanged({
id,
controlImage: image_name,
})
);
},
});
};

View File

@ -1,73 +0,0 @@
import { $logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { canvasMerged } from 'features/canvas/store/actions';
import { $canvasBaseLayer } from 'features/canvas/store/canvasNanostore';
import { setMergedCanvas } from 'features/canvas/store/canvasSlice';
import { getFullBaseLayerBlob } from 'features/canvas/util/getFullBaseLayerBlob';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';
export const addCanvasMergedListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: canvasMerged,
effect: async (action, { dispatch }) => {
const moduleLog = $logger.get().child({ namespace: 'canvasCopiedToClipboardListener' });
const blob = await getFullBaseLayerBlob();
if (!blob) {
moduleLog.error('Problem getting base layer blob');
toast({
id: 'PROBLEM_MERGING_CANVAS',
title: t('toast.problemMergingCanvas'),
description: t('toast.problemMergingCanvasDesc'),
status: 'error',
});
return;
}
const canvasBaseLayer = $canvasBaseLayer.get();
if (!canvasBaseLayer) {
moduleLog.error('Problem getting canvas base layer');
toast({
id: 'PROBLEM_MERGING_CANVAS',
title: t('toast.problemMergingCanvas'),
description: t('toast.problemMergingCanvasDesc'),
status: 'error',
});
return;
}
const baseLayerRect = canvasBaseLayer.getClientRect({
relativeTo: canvasBaseLayer.getParent() ?? undefined,
});
const imageDTO = await dispatch(
imagesApi.endpoints.uploadImage.initiate({
file: new File([blob], 'mergedCanvas.png', {
type: 'image/png',
}),
image_category: 'general',
is_intermediate: true,
postUploadAction: {
type: 'TOAST',
title: t('toast.canvasMerged'),
},
})
).unwrap();
// TODO: I can't figure out how to do the type narrowing in the `take()` so just brute forcing it here
const { image_name } = imageDTO;
dispatch(
setMergedCanvas({
kind: 'image',
layer: 'base',
imageName: image_name,
...baseLayerRect,
})
);
},
});
};

View File

@ -1,53 +0,0 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { parseify } from 'common/util/serialize';
import { canvasSavedToGallery } from 'features/canvas/store/actions';
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';
export const addCanvasSavedToGalleryListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: canvasSavedToGallery,
effect: async (action, { dispatch, getState }) => {
const log = logger('canvas');
const state = getState();
let blob;
try {
blob = await getBaseLayerBlob(state);
} catch (err) {
log.error(String(err));
toast({
id: 'CANVAS_SAVE_FAILED',
title: t('toast.problemSavingCanvas'),
description: t('toast.problemSavingCanvasDesc'),
status: 'error',
});
return;
}
const { autoAddBoardId } = state.gallery;
dispatch(
imagesApi.endpoints.uploadImage.initiate({
file: new File([blob], 'savedCanvas.png', {
type: 'image/png',
}),
image_category: 'general',
is_intermediate: false,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
crop_visible: true,
postUploadAction: {
type: 'TOAST',
title: t('toast.canvasSavedGallery'),
},
metadata: {
_canvas_objects: parseify(state.canvas.layerState.objects),
},
})
);
},
});
};

View File

@ -1,194 +0,0 @@
import { isAnyOf } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppDispatch } from 'app/store/store';
import { parseify } from 'common/util/serialize';
import {
caLayerImageChanged,
caLayerModelChanged,
caLayerProcessedImageChanged,
caLayerProcessorConfigChanged,
caLayerProcessorPendingBatchIdChanged,
caLayerRecalled,
isControlAdapterLayer,
} from 'features/controlLayers/store/controlLayersSlice';
import { CA_PROCESSOR_DATA } from 'features/controlLayers/util/controlAdapters';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { isEqual } from 'lodash-es';
import { getImageDTO } from 'services/api/endpoints/images';
import { queueApi } from 'services/api/endpoints/queue';
import type { BatchConfig } from 'services/api/types';
import { socketInvocationComplete } from 'services/events/actions';
import { assert } from 'tsafe';
const matcher = isAnyOf(
caLayerImageChanged,
caLayerProcessedImageChanged,
caLayerProcessorConfigChanged,
caLayerModelChanged,
caLayerRecalled
);
const DEBOUNCE_MS = 300;
const log = logger('session');
/**
* Simple helper to cancel a batch and reset the pending batch ID
*/
const cancelProcessorBatch = async (dispatch: AppDispatch, layerId: string, batchId: string) => {
const req = dispatch(queueApi.endpoints.cancelByBatchIds.initiate({ batch_ids: [batchId] }));
log.trace({ batchId }, 'Cancelling existing preprocessor batch');
try {
await req.unwrap();
} catch {
// no-op
} finally {
req.reset();
// Always reset the pending batch ID - the cancel req could fail if the batch doesn't exist
dispatch(caLayerProcessorPendingBatchIdChanged({ layerId, batchId: null }));
}
};
export const addControlAdapterPreprocessor = (startAppListening: AppStartListening) => {
startAppListening({
matcher,
effect: async (action, { dispatch, getState, getOriginalState, cancelActiveListeners, delay, take, signal }) => {
const layerId = caLayerRecalled.match(action) ? action.payload.id : action.payload.layerId;
const state = getState();
const originalState = getOriginalState();
// Cancel any in-progress instances of this listener
cancelActiveListeners();
log.trace('Control Layer CA auto-process triggered');
// Delay before starting actual work
await delay(DEBOUNCE_MS);
const layer = state.controlLayers.present.layers.filter(isControlAdapterLayer).find((l) => l.id === layerId);
if (!layer) {
return;
}
// We should only process if the processor settings or image have changed
const originalLayer = originalState.controlLayers.present.layers
.filter(isControlAdapterLayer)
.find((l) => l.id === layerId);
const originalImage = originalLayer?.controlAdapter.image;
const originalConfig = originalLayer?.controlAdapter.processorConfig;
const image = layer.controlAdapter.image;
const processedImage = layer.controlAdapter.processedImage;
const config = layer.controlAdapter.processorConfig;
if (isEqual(config, originalConfig) && isEqual(image, originalImage) && processedImage) {
// Neither config nor image have changed, we can bail
return;
}
if (!image || !config) {
// - If we have no image, we have nothing to process
// - If we have no processor config, we have nothing to process
// Clear the processed image and bail
dispatch(caLayerProcessedImageChanged({ layerId, imageDTO: null }));
return;
}
// At this point, the user has stopped fiddling with the processor settings and there is a processor selected.
// If there is a pending processor batch, cancel it.
if (layer.controlAdapter.processorPendingBatchId) {
cancelProcessorBatch(dispatch, layerId, layer.controlAdapter.processorPendingBatchId);
}
// TODO(psyche): I can't get TS to be happy, it thinkgs `config` is `never` but it should be inferred from the generic... I'll just cast it for now
const processorNode = CA_PROCESSOR_DATA[config.type].buildNode(image, config as never);
const enqueueBatchArg: BatchConfig = {
prepend: true,
batch: {
graph: {
nodes: {
[processorNode.id]: {
...processorNode,
// Control images are always intermediate - do not save to gallery
is_intermediate: true,
},
},
edges: [],
},
runs: 1,
},
};
// Kick off the processor batch
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(enqueueBatchArg, {
fixedCacheKey: 'enqueueBatch',
})
);
try {
const enqueueResult = await req.unwrap();
// TODO(psyche): Update the pydantic models, pretty sure we will _always_ have a batch_id here, but the model says it's optional
assert(enqueueResult.batch.batch_id, 'Batch ID not returned from queue');
dispatch(caLayerProcessorPendingBatchIdChanged({ layerId, batchId: enqueueResult.batch.batch_id }));
log.debug({ enqueueResult: parseify(enqueueResult) }, t('queue.graphQueued'));
// Wait for the processor node to complete
const [invocationCompleteAction] = await take(
(action): action is ReturnType<typeof socketInvocationComplete> =>
socketInvocationComplete.match(action) &&
action.payload.data.batch_id === enqueueResult.batch.batch_id &&
action.payload.data.invocation_source_id === processorNode.id
);
// We still have to check the output type
assert(
invocationCompleteAction.payload.data.result.type === 'image_output',
`Processor did not return an image output, got: ${invocationCompleteAction.payload.data.result}`
);
const { image_name } = invocationCompleteAction.payload.data.result.image;
const imageDTO = await getImageDTO(image_name);
assert(imageDTO, "Failed to fetch processor output's image DTO");
// Whew! We made it. Update the layer with the processed image
log.debug({ layerId, imageDTO }, 'ControlNet image processed');
dispatch(caLayerProcessedImageChanged({ layerId, imageDTO }));
dispatch(caLayerProcessorPendingBatchIdChanged({ layerId, batchId: null }));
} catch (error) {
if (signal.aborted) {
// The listener was canceled - we need to cancel the pending processor batch, if there is one (could have changed by now).
const pendingBatchId = getState()
.controlLayers.present.layers.filter(isControlAdapterLayer)
.find((l) => l.id === layerId)?.controlAdapter.processorPendingBatchId;
if (pendingBatchId) {
cancelProcessorBatch(dispatch, layerId, pendingBatchId);
}
log.trace('Control Adapter preprocessor cancelled');
} else {
// Some other error condition...
log.error({ enqueueBatchArg: parseify(enqueueBatchArg) }, t('queue.graphFailedToQueue'));
if (error instanceof Object) {
if ('data' in error && 'status' in error) {
if (error.status === 403) {
dispatch(caLayerImageChanged({ layerId, imageDTO: null }));
return;
}
}
}
toast({
id: 'GRAPH_QUEUE_FAILED',
title: t('queue.graphFailedToQueue'),
status: 'error',
});
}
} finally {
req.reset();
}
},
});
};

View File

@ -1,85 +0,0 @@
import type { AnyListenerPredicate } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { RootState } from 'app/store/store';
import { controlAdapterImageProcessed } from 'features/controlAdapters/store/actions';
import {
controlAdapterAutoConfigToggled,
controlAdapterImageChanged,
controlAdapterModelChanged,
controlAdapterProcessorParamsChanged,
controlAdapterProcessortTypeChanged,
selectControlAdapterById,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
type AnyControlAdapterParamChangeAction =
| ReturnType<typeof controlAdapterProcessorParamsChanged>
| ReturnType<typeof controlAdapterModelChanged>
| ReturnType<typeof controlAdapterImageChanged>
| ReturnType<typeof controlAdapterProcessortTypeChanged>
| ReturnType<typeof controlAdapterAutoConfigToggled>;
const predicate: AnyListenerPredicate<RootState> = (action, state, prevState) => {
const isActionMatched =
controlAdapterProcessorParamsChanged.match(action) ||
controlAdapterModelChanged.match(action) ||
controlAdapterImageChanged.match(action) ||
controlAdapterProcessortTypeChanged.match(action) ||
controlAdapterAutoConfigToggled.match(action);
if (!isActionMatched) {
return false;
}
const { id } = action.payload;
const prevCA = selectControlAdapterById(prevState.controlAdapters, id);
const ca = selectControlAdapterById(state.controlAdapters, id);
if (!prevCA || !isControlNetOrT2IAdapter(prevCA) || !ca || !isControlNetOrT2IAdapter(ca)) {
return false;
}
if (controlAdapterAutoConfigToggled.match(action)) {
// do not process if the user just disabled auto-config
if (prevCA.shouldAutoConfig === true) {
return false;
}
}
const { controlImage, processorType, shouldAutoConfig } = ca;
if (controlAdapterModelChanged.match(action) && !shouldAutoConfig) {
// do not process if the action is a model change but the processor settings are dirty
return false;
}
const isProcessorSelected = processorType !== 'none';
const hasControlImage = Boolean(controlImage);
return isProcessorSelected && hasControlImage;
};
const DEBOUNCE_MS = 300;
/**
* Listener that automatically processes a ControlNet image when its processor parameters are changed.
*
* The network request is debounced.
*/
export const addControlNetAutoProcessListener = (startAppListening: AppStartListening) => {
startAppListening({
predicate,
effect: async (action, { dispatch, cancelActiveListeners, delay }) => {
const log = logger('session');
const { id } = (action as AnyControlAdapterParamChangeAction).payload;
// Cancel any in-progress instances of this listener
cancelActiveListeners();
log.trace('ControlNet auto-process triggered');
// Delay before starting actual work
await delay(DEBOUNCE_MS);
dispatch(controlAdapterImageProcessed({ id }));
},
});
};

View File

@ -1,118 +0,0 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { parseify } from 'common/util/serialize';
import { controlAdapterImageProcessed } from 'features/controlAdapters/store/actions';
import {
controlAdapterImageChanged,
controlAdapterProcessedImageChanged,
pendingControlImagesCleared,
selectControlAdapterById,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';
import { queueApi } from 'services/api/endpoints/queue';
import type { BatchConfig, ImageDTO } from 'services/api/types';
import { socketInvocationComplete } from 'services/events/actions';
export const addControlNetImageProcessedListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: controlAdapterImageProcessed,
effect: async (action, { dispatch, getState, take }) => {
const log = logger('session');
const { id } = action.payload;
const ca = selectControlAdapterById(getState().controlAdapters, id);
if (!ca?.controlImage || !isControlNetOrT2IAdapter(ca)) {
log.error('Unable to process ControlNet image');
return;
}
if (ca.processorType === 'none' || ca.processorNode.type === 'none') {
return;
}
// ControlNet one-off procressing graph is just the processor node, no edges.
// Also we need to grab the image.
const nodeId = ca.processorNode.id;
const enqueueBatchArg: BatchConfig = {
prepend: true,
batch: {
graph: {
nodes: {
[ca.processorNode.id]: {
...ca.processorNode,
is_intermediate: true,
use_cache: false,
image: { image_name: ca.controlImage },
},
},
edges: [],
},
runs: 1,
},
};
try {
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(enqueueBatchArg, {
fixedCacheKey: 'enqueueBatch',
})
);
const enqueueResult = await req.unwrap();
req.reset();
log.debug({ enqueueResult: parseify(enqueueResult) }, t('queue.graphQueued'));
const [invocationCompleteAction] = await take(
(action): action is ReturnType<typeof socketInvocationComplete> =>
socketInvocationComplete.match(action) &&
action.payload.data.batch_id === enqueueResult.batch.batch_id &&
action.payload.data.invocation_source_id === nodeId
);
// We still have to check the output type
if (invocationCompleteAction.payload.data.result.type === 'image_output') {
const { image_name } = invocationCompleteAction.payload.data.result.image;
// Wait for the ImageDTO to be received
const [{ payload }] = await take(
(action) =>
imagesApi.endpoints.getImageDTO.matchFulfilled(action) && action.payload.image_name === image_name
);
const processedControlImage = payload as ImageDTO;
log.debug({ controlNetId: action.payload, processedControlImage }, 'ControlNet image processed');
// Update the processed image in the store
dispatch(
controlAdapterProcessedImageChanged({
id,
processedControlImage: processedControlImage.image_name,
})
);
}
} catch (error) {
log.error({ enqueueBatchArg: parseify(enqueueBatchArg) }, t('queue.graphFailedToQueue'));
if (error instanceof Object) {
if ('data' in error && 'status' in error) {
if (error.status === 403) {
dispatch(pendingControlImagesCleared());
dispatch(controlAdapterImageChanged({ id, controlImage: null }));
return;
}
}
}
toast({
id: 'GRAPH_QUEUE_FAILED',
title: t('queue.graphFailedToQueue'),
status: 'error',
});
}
},
});
};

View File

@ -1,144 +0,0 @@
import { logger } from 'app/logging/logger';
import { enqueueRequested } from 'app/store/actions';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
import { parseify } from 'common/util/serialize';
import { canvasBatchIdAdded, stagingAreaInitialized } from 'features/canvas/store/canvasSlice';
import { blobToDataURL } from 'features/canvas/util/blobToDataURL';
import { getCanvasData } from 'features/canvas/util/getCanvasData';
import { getCanvasGenerationMode } from 'features/canvas/util/getCanvasGenerationMode';
import { canvasGraphBuilt } from 'features/nodes/store/actions';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildCanvasGraph } from 'features/nodes/util/graph/canvas/buildCanvasGraph';
import { imagesApi } from 'services/api/endpoints/images';
import { queueApi } from 'services/api/endpoints/queue';
import type { ImageDTO } from 'services/api/types';
/**
* This listener is responsible invoking the canvas. This involves a number of steps:
*
* 1. Generate image blobs from the canvas layers
* 2. Determine the generation mode from the layers (txt2img, img2img, inpaint)
* 3. Build the canvas graph
* 4. Create the session with the graph
* 5. Upload the init image if necessary
* 6. Upload the mask image if necessary
* 7. Update the init and mask images with the session ID
* 8. Initialize the staging area if not yet initialized
* 9. Dispatch the sessionReadyToInvoke action to invoke the session
*/
export const addEnqueueRequestedCanvasListener = (startAppListening: AppStartListening) => {
startAppListening({
predicate: (action): action is ReturnType<typeof enqueueRequested> =>
enqueueRequested.match(action) && action.payload.tabName === 'canvas',
effect: async (action, { getState, dispatch }) => {
const log = logger('queue');
const { prepend } = action.payload;
const state = getState();
const { layerState, boundingBoxCoordinates, boundingBoxDimensions, isMaskEnabled, shouldPreserveMaskedArea } =
state.canvas;
// Build canvas blobs
const canvasBlobsAndImageData = await getCanvasData(
layerState,
boundingBoxCoordinates,
boundingBoxDimensions,
isMaskEnabled,
shouldPreserveMaskedArea
);
if (!canvasBlobsAndImageData) {
log.error('Unable to create canvas data');
return;
}
const { baseBlob, baseImageData, maskBlob, maskImageData } = canvasBlobsAndImageData;
// Determine the generation mode
const generationMode = getCanvasGenerationMode(baseImageData, maskImageData);
if (state.system.enableImageDebugging) {
const baseDataURL = await blobToDataURL(baseBlob);
const maskDataURL = await blobToDataURL(maskBlob);
openBase64ImageInTab([
{ base64: maskDataURL, caption: 'mask b64' },
{ base64: baseDataURL, caption: 'image b64' },
]);
}
log.debug(`Generation mode: ${generationMode}`);
// Temp placeholders for the init and mask images
let canvasInitImage: ImageDTO | undefined;
let canvasMaskImage: ImageDTO | undefined;
// For img2img and inpaint/outpaint, we need to upload the init images
if (['img2img', 'inpaint', 'outpaint'].includes(generationMode)) {
// upload the image, saving the request id
canvasInitImage = await dispatch(
imagesApi.endpoints.uploadImage.initiate({
file: new File([baseBlob], 'canvasInitImage.png', {
type: 'image/png',
}),
image_category: 'general',
is_intermediate: true,
})
).unwrap();
}
// For inpaint/outpaint, we also need to upload the mask layer
if (['inpaint', 'outpaint'].includes(generationMode)) {
// upload the image, saving the request id
canvasMaskImage = await dispatch(
imagesApi.endpoints.uploadImage.initiate({
file: new File([maskBlob], 'canvasMaskImage.png', {
type: 'image/png',
}),
image_category: 'mask',
is_intermediate: true,
})
).unwrap();
}
const graph = await buildCanvasGraph(state, generationMode, canvasInitImage, canvasMaskImage);
log.debug({ graph: parseify(graph) }, `Canvas graph built`);
// currently this action is just listened to for logging
dispatch(canvasGraphBuilt(graph));
const batchConfig = prepareLinearUIBatch(state, graph, prepend);
try {
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
fixedCacheKey: 'enqueueBatch',
})
);
const enqueueResult = await req.unwrap();
req.reset();
const batchId = enqueueResult.batch.batch_id as string; // we know the is a string, backend provides it
// Prep the canvas staging area if it is not yet initialized
if (!state.canvas.layerState.stagingArea.boundingBox) {
dispatch(
stagingAreaInitialized({
boundingBox: {
...state.canvas.boundingBoxCoordinates,
...state.canvas.boundingBoxDimensions,
},
})
);
}
// Associate the session with the canvas session ID
dispatch(canvasBatchIdAdded(batchId));
} catch {
// no-op
}
},
});
};

View File

@ -1,10 +1,21 @@
import { logger } from 'app/logging/logger';
import { enqueueRequested } from 'app/store/actions';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
import type { SerializableObject } from 'common/types';
import type { Result } from 'common/util/result';
import { isErr, withResult, withResultAsync } from 'common/util/result';
import { $canvasManager } from 'features/controlLayers/konva/CanvasManager';
import { sessionStagingAreaReset, sessionStartedStaging } from 'features/controlLayers/store/canvasSessionSlice';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildGenerationTabGraph } from 'features/nodes/util/graph/generation/buildGenerationTabGraph';
import { buildGenerationTabSDXLGraph } from 'features/nodes/util/graph/generation/buildGenerationTabSDXLGraph';
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import { serializeError } from 'serialize-error';
import { queueApi } from 'services/api/endpoints/queue';
import type { Invocation } from 'services/api/types';
import { assert } from 'tsafe';
const log = logger('generation');
export const addEnqueueRequestedLinear = (startAppListening: AppStartListening) => {
startAppListening({
@ -12,33 +23,77 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
enqueueRequested.match(action) && action.payload.tabName === 'generation',
effect: async (action, { getState, dispatch }) => {
const state = getState();
const { shouldShowProgressInViewer } = state.ui;
const model = state.generation.model;
const model = state.params.model;
const { prepend } = action.payload;
let graph;
const manager = $canvasManager.get();
assert(manager, 'No model found in state');
if (model?.base === 'sdxl') {
graph = await buildGenerationTabSDXLGraph(state);
} else {
graph = await buildGenerationTabGraph(state);
let didStartStaging = false;
if (!state.canvasSession.isStaging && state.canvasSession.mode === 'compose') {
dispatch(sessionStartedStaging());
didStartStaging = true;
}
const batchConfig = prepareLinearUIBatch(state, graph, prepend);
const abortStaging = () => {
if (didStartStaging && getState().canvasSession.isStaging) {
dispatch(sessionStagingAreaReset());
}
};
let buildGraphResult: Result<
{ g: Graph; noise: Invocation<'noise'>; posCond: Invocation<'compel' | 'sdxl_compel_prompt'> },
Error
>;
assert(model, 'No model found in state');
const base = model.base;
switch (base) {
case 'sdxl':
buildGraphResult = await withResultAsync(() => buildSDXLGraph(state, manager));
break;
case 'sd-1':
case `sd-2`:
buildGraphResult = await withResultAsync(() => buildSD1Graph(state, manager));
break;
default:
assert(false, `No graph builders for base ${base}`);
}
if (isErr(buildGraphResult)) {
log.error({ error: serializeError(buildGraphResult.error) }, 'Failed to build graph');
abortStaging();
return;
}
const { g, noise, posCond } = buildGraphResult.value;
const prepareBatchResult = withResult(() => prepareLinearUIBatch(state, g, prepend, noise, posCond));
if (isErr(prepareBatchResult)) {
log.error({ error: serializeError(prepareBatchResult.error) }, 'Failed to prepare batch');
abortStaging();
return;
}
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
queueApi.endpoints.enqueueBatch.initiate(prepareBatchResult.value, {
fixedCacheKey: 'enqueueBatch',
})
);
try {
await req.unwrap();
if (shouldShowProgressInViewer) {
dispatch(isImageViewerOpenChanged(true));
}
} finally {
req.reset();
req.reset();
const enqueueResult = await withResultAsync(() => req.unwrap());
if (isErr(enqueueResult)) {
log.error({ error: serializeError(enqueueResult.error) }, 'Failed to enqueue batch');
abortStaging();
return;
}
log.debug({ batchConfig: prepareBatchResult.value } as SerializableObject, 'Enqueued batch');
},
});
};

View File

@ -1,5 +1,6 @@
import { enqueueRequested } from 'app/store/actions';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { selectNodesSlice } from 'features/nodes/store/selectors';
import { buildNodesGraph } from 'features/nodes/util/graph/buildNodesGraph';
import { buildWorkflowWithValidation } from 'features/nodes/util/workflow/buildWorkflow';
import { queueApi } from 'services/api/endpoints/queue';
@ -11,12 +12,12 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
enqueueRequested.match(action) && action.payload.tabName === 'workflows',
effect: async (action, { getState, dispatch }) => {
const state = getState();
const { nodes, edges } = state.nodes.present;
const nodes = selectNodesSlice(state);
const workflow = state.workflow;
const graph = buildNodesGraph(state.nodes.present);
const graph = buildNodesGraph(nodes);
const builtWorkflow = buildWorkflowWithValidation({
nodes,
edges,
nodes: nodes.nodes,
edges: nodes.edges,
workflow,
});
@ -29,7 +30,8 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
batch: {
graph,
workflow: builtWorkflow,
runs: state.generation.iterations,
runs: state.params.iterations,
origin: 'workflows',
},
prepend: action.payload.prepend,
};

View File

@ -14,9 +14,9 @@ export const addEnqueueRequestedUpscale = (startAppListening: AppStartListening)
const { shouldShowProgressInViewer } = state.ui;
const { prepend } = action.payload;
const graph = await buildMultidiffusionUpscaleGraph(state);
const { g, noise, posCond } = await buildMultidiffusionUpscaleGraph(state);
const batchConfig = prepareLinearUIBatch(state, graph, prepend);
const batchConfig = prepareLinearUIBatch(state, g, prepend, noise, posCond);
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {

View File

@ -27,7 +27,7 @@ export const galleryImageClicked = createAction<{
export const addGalleryImageClickedListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: galleryImageClicked,
effect: async (action, { dispatch, getState }) => {
effect: (action, { dispatch, getState }) => {
const { imageDTO, shiftKey, ctrlKey, metaKey, altKey } = action.payload;
const state = getState();
const queryArgs = selectListImagesQueryArgs(state);

View File

@ -1,24 +1,27 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { SerializableObject } from 'common/types';
import { parseify } from 'common/util/serialize';
import { $templates } from 'features/nodes/store/nodesSlice';
import { parseSchema } from 'features/nodes/util/schema/parseSchema';
import { size } from 'lodash-es';
import { serializeError } from 'serialize-error';
import { appInfoApi } from 'services/api/endpoints/appInfo';
const log = logger('system');
export const addGetOpenAPISchemaListener = (startAppListening: AppStartListening) => {
startAppListening({
matcher: appInfoApi.endpoints.getOpenAPISchema.matchFulfilled,
effect: (action, { getState }) => {
const log = logger('system');
const schemaJSON = action.payload;
log.debug({ schemaJSON: parseify(schemaJSON) }, 'Received OpenAPI schema');
log.debug({ schemaJSON: parseify(schemaJSON) } as SerializableObject, 'Received OpenAPI schema');
const { nodesAllowlist, nodesDenylist } = getState().config;
const nodeTemplates = parseSchema(schemaJSON, nodesAllowlist, nodesDenylist);
log.debug({ nodeTemplates: parseify(nodeTemplates) }, `Built ${size(nodeTemplates)} node templates`);
log.debug({ nodeTemplates } as SerializableObject, `Built ${size(nodeTemplates)} node templates`);
$templates.set(nodeTemplates);
},
@ -30,8 +33,7 @@ export const addGetOpenAPISchemaListener = (startAppListening: AppStartListening
// If action.meta.condition === true, the request was canceled/skipped because another request was in flight or
// the value was already in the cache. We don't want to log these errors.
if (!action.meta.condition) {
const log = logger('system');
log.error({ error: parseify(action.error) }, 'Problem retrieving OpenAPI Schema');
log.error({ error: serializeError(action.error) }, 'Problem retrieving OpenAPI Schema');
}
},
});

View File

@ -2,15 +2,13 @@ import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { imagesApi } from 'services/api/endpoints/images';
const log = logger('gallery');
export const addImageAddedToBoardFulfilledListener = (startAppListening: AppStartListening) => {
startAppListening({
matcher: imagesApi.endpoints.addImageToBoard.matchFulfilled,
effect: (action) => {
const log = logger('images');
const { board_id, imageDTO } = action.meta.arg.originalArgs;
// TODO: update listImages cache for this board
log.debug({ board_id, imageDTO }, 'Image added to board');
},
});
@ -18,9 +16,7 @@ export const addImageAddedToBoardFulfilledListener = (startAppListening: AppStar
startAppListening({
matcher: imagesApi.endpoints.addImageToBoard.matchRejected,
effect: (action) => {
const log = logger('images');
const { board_id, imageDTO } = action.meta.arg.originalArgs;
log.debug({ board_id, imageDTO }, 'Problem adding image to board');
},
});

View File

@ -1,20 +1,9 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppDispatch, RootState } from 'app/store/store';
import { resetCanvas } from 'features/canvas/store/canvasSlice';
import {
controlAdapterImageChanged,
controlAdapterProcessedImageChanged,
selectControlAdapterAll,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
import {
isControlAdapterLayer,
isInitialImageLayer,
isIPAdapterLayer,
isRegionalGuidanceLayer,
layerDeleted,
} from 'features/controlLayers/store/controlLayersSlice';
import { entityDeleted, ipaImageChanged } from 'features/controlLayers/store/canvasSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { getEntityIdentifier } from 'features/controlLayers/store/types';
import { imageDeletionConfirmed } from 'features/deleteImageModal/store/actions';
import { isModalOpenChanged } from 'features/deleteImageModal/store/slice';
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
@ -26,6 +15,10 @@ import { forEach, intersectionBy } from 'lodash-es';
import { imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
const log = logger('gallery');
//TODO(psyche): handle image deletion (canvas sessions?)
// Some utils to delete images from different parts of the app
const deleteNodesImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
state.nodes.present.nodes.forEach((node) => {
@ -47,52 +40,37 @@ const deleteNodesImages = (state: RootState, dispatch: AppDispatch, imageDTO: Im
});
};
const deleteControlAdapterImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
forEach(selectControlAdapterAll(state.controlAdapters), (ca) => {
if (
ca.controlImage === imageDTO.image_name ||
(isControlNetOrT2IAdapter(ca) && ca.processedControlImage === imageDTO.image_name)
) {
dispatch(
controlAdapterImageChanged({
id: ca.id,
controlImage: null,
})
);
dispatch(
controlAdapterProcessedImageChanged({
id: ca.id,
processedControlImage: null,
})
);
// const deleteControlAdapterImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
// state.canvas.present.controlAdapters.entities.forEach(({ id, imageObject, processedImageObject }) => {
// if (
// imageObject?.image.image_name === imageDTO.image_name ||
// processedImageObject?.image.image_name === imageDTO.image_name
// ) {
// dispatch(caImageChanged({ id, imageDTO: null }));
// dispatch(caProcessedImageChanged({ id, imageDTO: null }));
// }
// });
// };
const deleteIPAdapterImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
selectCanvasSlice(state).ipAdapters.entities.forEach((entity) => {
if (entity.ipAdapter.image?.image_name === imageDTO.image_name) {
dispatch(ipaImageChanged({ entityIdentifier: getEntityIdentifier(entity), imageDTO: null }));
}
});
};
const deleteControlLayerImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
state.controlLayers.present.layers.forEach((l) => {
if (isRegionalGuidanceLayer(l)) {
if (l.ipAdapters.some((ipa) => ipa.image?.name === imageDTO.image_name)) {
dispatch(layerDeleted(l.id));
const deleteLayerImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
selectCanvasSlice(state).rasterLayers.entities.forEach(({ id, objects }) => {
let shouldDelete = false;
for (const obj of objects) {
if (obj.type === 'image' && obj.image.image_name === imageDTO.image_name) {
shouldDelete = true;
break;
}
}
if (isControlAdapterLayer(l)) {
if (
l.controlAdapter.image?.name === imageDTO.image_name ||
l.controlAdapter.processedImage?.name === imageDTO.image_name
) {
dispatch(layerDeleted(l.id));
}
}
if (isIPAdapterLayer(l)) {
if (l.ipAdapter.image?.name === imageDTO.image_name) {
dispatch(layerDeleted(l.id));
}
}
if (isInitialImageLayer(l)) {
if (l.image?.name === imageDTO.image_name) {
dispatch(layerDeleted(l.id));
}
if (shouldDelete) {
dispatch(entityDeleted({ entityIdentifier: { id, type: 'raster_layer' } }));
}
});
};
@ -145,14 +123,10 @@ export const addImageDeletionListeners = (startAppListening: AppStartListening)
}
}
// We need to reset the features where the image is in use - none of these work if their image(s) don't exist
if (imageUsage.isCanvasImage) {
dispatch(resetCanvas());
}
deleteControlAdapterImages(state, dispatch, imageDTO);
deleteNodesImages(state, dispatch, imageDTO);
deleteControlLayerImages(state, dispatch, imageDTO);
// deleteControlAdapterImages(state, dispatch, imageDTO);
deleteIPAdapterImages(state, dispatch, imageDTO);
deleteLayerImages(state, dispatch, imageDTO);
} catch {
// no-op
} finally {
@ -189,14 +163,11 @@ export const addImageDeletionListeners = (startAppListening: AppStartListening)
// We need to reset the features where the image is in use - none of these work if their image(s) don't exist
if (imagesUsage.some((i) => i.isCanvasImage)) {
dispatch(resetCanvas());
}
imageDTOs.forEach((imageDTO) => {
deleteControlAdapterImages(state, dispatch, imageDTO);
deleteNodesImages(state, dispatch, imageDTO);
deleteControlLayerImages(state, dispatch, imageDTO);
// deleteControlAdapterImages(state, dispatch, imageDTO);
deleteIPAdapterImages(state, dispatch, imageDTO);
deleteLayerImages(state, dispatch, imageDTO);
});
} catch {
// no-op
@ -220,7 +191,6 @@ export const addImageDeletionListeners = (startAppListening: AppStartListening)
startAppListening({
matcher: imagesApi.endpoints.deleteImage.matchFulfilled,
effect: (action) => {
const log = logger('images');
log.debug({ imageDTO: action.meta.arg.originalArgs }, 'Image deleted');
},
});
@ -228,7 +198,6 @@ export const addImageDeletionListeners = (startAppListening: AppStartListening)
startAppListening({
matcher: imagesApi.endpoints.deleteImage.matchRejected,
effect: (action) => {
const log = logger('images');
log.debug({ imageDTO: action.meta.arg.originalArgs }, 'Unable to delete image');
},
});

View File

@ -1,28 +1,19 @@
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { parseify } from 'common/util/serialize';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import {
controlAdapterImageChanged,
controlAdapterIsEnabledChanged,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import {
caLayerImageChanged,
iiLayerImageChanged,
ipaLayerImageChanged,
rgLayerIPAdapterImageChanged,
} from 'features/controlLayers/store/controlLayersSlice';
controlLayerAdded,
ipaImageChanged,
rasterLayerAdded,
rgIPAdapterImageChanged,
} from 'features/controlLayers/store/canvasSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import type { CanvasControlLayerState, CanvasRasterLayerState } from 'features/controlLayers/store/types';
import { imageDTOToImageObject } from 'features/controlLayers/store/types';
import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
import { isValidDrop } from 'features/dnd/util/isValidDrop';
import {
imageSelected,
imageToCompareChanged,
isImageViewerOpenChanged,
selectionChanged,
} from 'features/gallery/store/gallerySlice';
import { imageToCompareChanged, isImageViewerOpenChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
import { upscaleInitialImageChanged } from 'features/parameters/store/upscaleSlice';
import { imagesApi } from 'services/api/endpoints/images';
@ -31,11 +22,12 @@ export const dndDropped = createAction<{
activeData: TypesafeDraggableData;
}>('dnd/dndDropped');
const log = logger('system');
export const addImageDroppedListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: dndDropped,
effect: async (action, { dispatch, getState }) => {
const log = logger('dnd');
effect: (action, { dispatch, getState }) => {
const { activeData, overData } = action.payload;
if (!isValidDrop(overData, activeData)) {
return;
@ -46,80 +38,22 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
} else if (activeData.payloadType === 'GALLERY_SELECTION') {
log.debug({ activeData, overData }, `Images (${getState().gallery.selection.length}) dropped`);
} else if (activeData.payloadType === 'NODE_FIELD') {
log.debug({ activeData: parseify(activeData), overData: parseify(overData) }, 'Node field dropped');
log.debug({ activeData, overData }, 'Node field dropped');
} else {
log.debug({ activeData, overData }, `Unknown payload dropped`);
}
/**
* Image dropped on current image
*/
if (
overData.actionType === 'SET_CURRENT_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
dispatch(imageSelected(activeData.payload.imageDTO));
dispatch(isImageViewerOpenChanged(true));
return;
}
/**
* Image dropped on ControlNet
*/
if (
overData.actionType === 'SET_CONTROL_ADAPTER_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const { id } = overData.context;
dispatch(
controlAdapterImageChanged({
id,
controlImage: activeData.payload.imageDTO.image_name,
})
);
dispatch(
controlAdapterIsEnabledChanged({
id,
isEnabled: true,
})
);
return;
}
/**
* Image dropped on Control Adapter Layer
*/
if (
overData.actionType === 'SET_CA_LAYER_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const { layerId } = overData.context;
dispatch(
caLayerImageChanged({
layerId,
imageDTO: activeData.payload.imageDTO,
})
);
return;
}
/**
* Image dropped on IP Adapter Layer
*/
if (
overData.actionType === 'SET_IPA_LAYER_IMAGE' &&
overData.actionType === 'SET_IPA_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const { layerId } = overData.context;
const { id } = overData.context;
dispatch(
ipaLayerImageChanged({
layerId,
imageDTO: activeData.payload.imageDTO,
})
ipaImageChanged({ entityIdentifier: { id, type: 'ip_adapter' }, imageDTO: activeData.payload.imageDTO })
);
return;
}
@ -128,14 +62,14 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
* Image dropped on RG Layer IP Adapter
*/
if (
overData.actionType === 'SET_RG_LAYER_IP_ADAPTER_IMAGE' &&
overData.actionType === 'SET_RG_IP_ADAPTER_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const { layerId, ipAdapterId } = overData.context;
const { id, ipAdapterId } = overData.context;
dispatch(
rgLayerIPAdapterImageChanged({
layerId,
rgIPAdapterImageChanged({
entityIdentifier: { id, type: 'regional_guidance' },
ipAdapterId,
imageDTO: activeData.payload.imageDTO,
})
@ -144,32 +78,38 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
}
/**
* Image dropped on II Layer Image
* Image dropped on Raster layer
*/
if (
overData.actionType === 'SET_II_LAYER_IMAGE' &&
overData.actionType === 'ADD_RASTER_LAYER_FROM_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const { layerId } = overData.context;
dispatch(
iiLayerImageChanged({
layerId,
imageDTO: activeData.payload.imageDTO,
})
);
const imageObject = imageDTOToImageObject(activeData.payload.imageDTO);
const { x, y } = selectCanvasSlice(getState()).bbox.rect;
const overrides: Partial<CanvasRasterLayerState> = {
objects: [imageObject],
position: { x, y },
};
dispatch(rasterLayerAdded({ overrides, isSelected: true }));
return;
}
/**
* Image dropped on Canvas
* Image dropped on Raster layer
*/
if (
overData.actionType === 'SET_CANVAS_INITIAL_IMAGE' &&
overData.actionType === 'ADD_CONTROL_LAYER_FROM_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
dispatch(setInitialCanvasImage(activeData.payload.imageDTO, selectOptimalDimension(getState())));
const imageObject = imageDTOToImageObject(activeData.payload.imageDTO);
const { x, y } = selectCanvasSlice(getState()).bbox.rect;
const overrides: Partial<CanvasControlLayerState> = {
objects: [imageObject],
position: { x, y },
};
dispatch(controlLayerAdded({ overrides, isSelected: true }));
return;
}

View File

@ -2,13 +2,13 @@ import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { imagesApi } from 'services/api/endpoints/images';
const log = logger('gallery');
export const addImageRemovedFromBoardFulfilledListener = (startAppListening: AppStartListening) => {
startAppListening({
matcher: imagesApi.endpoints.removeImageFromBoard.matchFulfilled,
effect: (action) => {
const log = logger('images');
const imageDTO = action.meta.arg.originalArgs;
log.debug({ imageDTO }, 'Image removed from board');
},
});
@ -16,9 +16,7 @@ export const addImageRemovedFromBoardFulfilledListener = (startAppListening: App
startAppListening({
matcher: imagesApi.endpoints.removeImageFromBoard.matchRejected,
effect: (action) => {
const log = logger('images');
const imageDTO = action.meta.arg.originalArgs;
log.debug({ imageDTO }, 'Problem removing image from board');
},
});

View File

@ -6,16 +6,17 @@ import { imagesToDeleteSelected, isModalOpenChanged } from 'features/deleteImage
export const addImageToDeleteSelectedListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: imagesToDeleteSelected,
effect: async (action, { dispatch, getState }) => {
effect: (action, { dispatch, getState }) => {
const imageDTOs = action.payload;
const state = getState();
const { shouldConfirmOnDelete } = state.system;
const imagesUsage = selectImageUsage(getState());
const isImageInUse =
imagesUsage.some((i) => i.isCanvasImage) ||
imagesUsage.some((i) => i.isControlImage) ||
imagesUsage.some((i) => i.isNodesImage);
imagesUsage.some((i) => i.isLayerImage) ||
imagesUsage.some((i) => i.isControlAdapterImage) ||
imagesUsage.some((i) => i.isIPAdapterImage) ||
imagesUsage.some((i) => i.isLayerImage);
if (shouldConfirmOnDelete || isImageInUse) {
dispatch(isModalOpenChanged(true));

View File

@ -1,19 +1,8 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import {
controlAdapterImageChanged,
controlAdapterIsEnabledChanged,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import {
caLayerImageChanged,
iiLayerImageChanged,
ipaLayerImageChanged,
rgLayerIPAdapterImageChanged,
} from 'features/controlLayers/store/controlLayersSlice';
import { ipaImageChanged, rgIPAdapterImageChanged } from 'features/controlLayers/store/canvasSlice';
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
import { upscaleInitialImageChanged } from 'features/parameters/store/upscaleSlice';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
@ -21,11 +10,12 @@ import { omit } from 'lodash-es';
import { boardsApi } from 'services/api/endpoints/boards';
import { imagesApi } from 'services/api/endpoints/images';
const log = logger('gallery');
export const addImageUploadedFulfilledListener = (startAppListening: AppStartListening) => {
startAppListening({
matcher: imagesApi.endpoints.uploadImage.matchFulfilled,
effect: (action, { dispatch, getState }) => {
const log = logger('images');
const imageDTO = action.payload;
const state = getState();
const { autoAddBoardId } = state.gallery;
@ -81,15 +71,6 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis
return;
}
if (postUploadAction?.type === 'SET_CANVAS_INITIAL_IMAGE') {
dispatch(setInitialCanvasImage(imageDTO, selectOptimalDimension(state)));
toast({
...DEFAULT_UPLOADED_TOAST,
description: t('toast.setAsCanvasInitialImage'),
});
return;
}
if (postUploadAction?.type === 'SET_UPSCALE_INITIAL_IMAGE') {
dispatch(upscaleInitialImageChanged(imageDTO));
toast({
@ -99,70 +80,33 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis
return;
}
if (postUploadAction?.type === 'SET_CONTROL_ADAPTER_IMAGE') {
// if (postUploadAction?.type === 'SET_CA_IMAGE') {
// const { id } = postUploadAction;
// dispatch(caImageChanged({ id, imageDTO }));
// toast({ ...DEFAULT_UPLOADED_TOAST, description: t('toast.setControlImage') });
// return;
// }
if (postUploadAction?.type === 'SET_IPA_IMAGE') {
const { id } = postUploadAction;
dispatch(
controlAdapterIsEnabledChanged({
id,
isEnabled: true,
})
);
dispatch(
controlAdapterImageChanged({
id,
controlImage: imageDTO.image_name,
})
);
toast({
...DEFAULT_UPLOADED_TOAST,
description: t('toast.setControlImage'),
});
dispatch(ipaImageChanged({ entityIdentifier: { id, type: 'ip_adapter' }, imageDTO }));
toast({ ...DEFAULT_UPLOADED_TOAST, description: t('toast.setControlImage') });
return;
}
if (postUploadAction?.type === 'SET_CA_LAYER_IMAGE') {
const { layerId } = postUploadAction;
dispatch(caLayerImageChanged({ layerId, imageDTO }));
toast({
...DEFAULT_UPLOADED_TOAST,
description: t('toast.setControlImage'),
});
}
if (postUploadAction?.type === 'SET_IPA_LAYER_IMAGE') {
const { layerId } = postUploadAction;
dispatch(ipaLayerImageChanged({ layerId, imageDTO }));
toast({
...DEFAULT_UPLOADED_TOAST,
description: t('toast.setControlImage'),
});
}
if (postUploadAction?.type === 'SET_RG_LAYER_IP_ADAPTER_IMAGE') {
const { layerId, ipAdapterId } = postUploadAction;
dispatch(rgLayerIPAdapterImageChanged({ layerId, ipAdapterId, imageDTO }));
toast({
...DEFAULT_UPLOADED_TOAST,
description: t('toast.setControlImage'),
});
}
if (postUploadAction?.type === 'SET_II_LAYER_IMAGE') {
const { layerId } = postUploadAction;
dispatch(iiLayerImageChanged({ layerId, imageDTO }));
toast({
...DEFAULT_UPLOADED_TOAST,
description: t('toast.setControlImage'),
});
if (postUploadAction?.type === 'SET_RG_IP_ADAPTER_IMAGE') {
const { id, ipAdapterId } = postUploadAction;
dispatch(
rgIPAdapterImageChanged({ entityIdentifier: { id, type: 'regional_guidance' }, ipAdapterId, imageDTO })
);
toast({ ...DEFAULT_UPLOADED_TOAST, description: t('toast.setControlImage') });
return;
}
if (postUploadAction?.type === 'SET_NODES_IMAGE') {
const { nodeId, fieldName } = postUploadAction;
dispatch(fieldImageValueChanged({ nodeId, fieldName, value: imageDTO }));
toast({
...DEFAULT_UPLOADED_TOAST,
description: `${t('toast.setNodeField')} ${fieldName}`,
});
toast({ ...DEFAULT_UPLOADED_TOAST, description: `${t('toast.setNodeField')} ${fieldName}` });
return;
}
},
@ -171,7 +115,6 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis
startAppListening({
matcher: imagesApi.endpoints.uploadImage.matchRejected,
effect: (action) => {
const log = logger('images');
const sanitizedData = {
arg: {
...omit(action.meta.arg.originalArgs, ['file', 'postUploadAction']),

View File

@ -6,7 +6,7 @@ import type { ImageDTO } from 'services/api/types';
export const addImagesStarredListener = (startAppListening: AppStartListening) => {
startAppListening({
matcher: imagesApi.endpoints.starImages.matchFulfilled,
effect: async (action, { dispatch, getState }) => {
effect: (action, { dispatch, getState }) => {
const { updated_image_names: starredImages } = action.payload;
const state = getState();

View File

@ -6,7 +6,7 @@ import type { ImageDTO } from 'services/api/types';
export const addImagesUnstarredListener = (startAppListening: AppStartListening) => {
startAppListening({
matcher: imagesApi.endpoints.unstarImages.matchFulfilled,
effect: async (action, { dispatch, getState }) => {
effect: (action, { dispatch, getState }) => {
const { updated_image_names: unstarredImages } = action.payload;
const state = getState();

View File

@ -1,23 +1,18 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import {
controlAdapterIsEnabledChanged,
selectControlAdapterAll,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { loraRemoved } from 'features/lora/store/loraSlice';
import { loraDeleted } from 'features/controlLayers/store/lorasSlice';
import { modelChanged, vaeSelected } from 'features/controlLayers/store/paramsSlice';
import { modelSelected } from 'features/parameters/store/actions';
import { modelChanged, vaeSelected } from 'features/parameters/store/generationSlice';
import { zParameterModel } from 'features/parameters/types/parameterSchemas';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { forEach } from 'lodash-es';
const log = logger('models');
export const addModelSelectedListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: modelSelected,
effect: (action, { getState, dispatch }) => {
const log = logger('models');
const state = getState();
const result = zParameterModel.safeParse(action.payload);
@ -29,34 +24,36 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
const newModel = result.data;
const newBaseModel = newModel.base;
const didBaseModelChange = state.generation.model?.base !== newBaseModel;
const didBaseModelChange = state.params.model?.base !== newBaseModel;
if (didBaseModelChange) {
// we may need to reset some incompatible submodels
let modelsCleared = 0;
// handle incompatible loras
forEach(state.lora.loras, (lora, id) => {
state.loras.loras.forEach((lora) => {
if (lora.model.base !== newBaseModel) {
dispatch(loraRemoved(id));
dispatch(loraDeleted({ id: lora.id }));
modelsCleared += 1;
}
});
// handle incompatible vae
const { vae } = state.generation;
const { vae } = state.params;
if (vae && vae.base !== newBaseModel) {
dispatch(vaeSelected(null));
modelsCleared += 1;
}
// handle incompatible controlnets
selectControlAdapterAll(state.controlAdapters).forEach((ca) => {
if (ca.model?.base !== newBaseModel) {
dispatch(controlAdapterIsEnabledChanged({ id: ca.id, isEnabled: false }));
modelsCleared += 1;
}
});
// state.canvas.present.controlAdapters.entities.forEach((ca) => {
// if (ca.model?.base !== newBaseModel) {
// modelsCleared += 1;
// if (ca.isEnabled) {
// dispatch(entityIsEnabledToggled({ entityIdentifier: { id: ca.id, type: 'control_adapter' } }));
// }
// }
// });
if (modelsCleared > 0) {
toast({
@ -70,7 +67,7 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
}
}
dispatch(modelChanged(newModel, state.generation.model));
dispatch(modelChanged({ model: newModel, previousModel: state.params.model }));
},
});
};

View File

@ -1,36 +1,42 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppDispatch, RootState } from 'app/store/store';
import type { JSONObject } from 'common/types';
import type { SerializableObject } from 'common/types';
import {
controlAdapterModelCleared,
selectControlAdapterAll,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { heightChanged, widthChanged } from 'features/controlLayers/store/controlLayersSlice';
import { loraRemoved } from 'features/lora/store/loraSlice';
import { calculateNewSize } from 'features/parameters/components/ImageSize/calculateNewSize';
import { modelChanged, vaeSelected } from 'features/parameters/store/generationSlice';
bboxHeightChanged,
bboxWidthChanged,
controlLayerModelChanged,
ipaModelChanged,
rgIPAdapterModelChanged,
} from 'features/controlLayers/store/canvasSlice';
import { loraDeleted } from 'features/controlLayers/store/lorasSlice';
import { modelChanged, refinerModelChanged, vaeSelected } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { getEntityIdentifier } from 'features/controlLayers/store/types';
import { calculateNewSize } from 'features/parameters/components/DocumentSize/calculateNewSize';
import { postProcessingModelChanged, upscaleModelChanged } from 'features/parameters/store/upscaleSlice';
import { zParameterModel, zParameterVAEModel } from 'features/parameters/types/parameterSchemas';
import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
import { refinerModelChanged } from 'features/sdxl/store/sdxlSlice';
import { forEach } from 'lodash-es';
import type { Logger } from 'roarr';
import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
import {
isControlNetOrT2IAdapterModelConfig,
isIPAdapterModelConfig,
isLoRAModelConfig,
isNonRefinerMainModelConfig,
isRefinerMainModelModelConfig,
isSpandrelImageToImageModelConfig,
isVAEModelConfig,
} from 'services/api/types';
const log = logger('models');
export const addModelsLoadedListener = (startAppListening: AppStartListening) => {
startAppListening({
predicate: modelsApi.endpoints.getModelConfigs.matchFulfilled,
effect: async (action, { getState, dispatch }) => {
effect: (action, { getState, dispatch }) => {
// models loaded, we need to ensure the selected model is available and if not, select the first one
const log = logger('models');
log.info({ models: action.payload.entities }, `Models loaded (${action.payload.ids.length})`);
const state = getState();
@ -43,6 +49,7 @@ export const addModelsLoadedListener = (startAppListening: AppStartListening) =>
handleLoRAModels(models, state, dispatch, log);
handleControlAdapterModels(models, state, dispatch, log);
handleSpandrelImageToImageModels(models, state, dispatch, log);
handleIPAdapterModels(models, state, dispatch, log);
},
});
};
@ -51,15 +58,15 @@ type ModelHandler = (
models: AnyModelConfig[],
state: RootState,
dispatch: AppDispatch,
log: Logger<JSONObject>
log: Logger<SerializableObject>
) => undefined;
const handleMainModels: ModelHandler = (models, state, dispatch, log) => {
const currentModel = state.generation.model;
const currentModel = state.params.model;
const mainModels = models.filter(isNonRefinerMainModelConfig);
if (mainModels.length === 0) {
// No models loaded at all
dispatch(modelChanged(null));
dispatch(modelChanged({ model: null }));
return;
}
@ -74,25 +81,16 @@ const handleMainModels: ModelHandler = (models, state, dispatch, log) => {
if (defaultModelInList) {
const result = zParameterModel.safeParse(defaultModelInList);
if (result.success) {
dispatch(modelChanged(defaultModelInList, currentModel));
dispatch(modelChanged({ model: defaultModelInList, previousModel: currentModel }));
const { bbox } = selectCanvasSlice(state);
const optimalDimension = getOptimalDimension(defaultModelInList);
if (
getIsSizeOptimal(
state.controlLayers.present.size.width,
state.controlLayers.present.size.height,
optimalDimension
)
) {
if (getIsSizeOptimal(bbox.rect.width, bbox.rect.height, optimalDimension)) {
return;
}
const { width, height } = calculateNewSize(
state.controlLayers.present.size.aspectRatio.value,
optimalDimension * optimalDimension
);
const { width, height } = calculateNewSize(bbox.aspectRatio.value, optimalDimension * optimalDimension);
dispatch(widthChanged({ width }));
dispatch(heightChanged({ height }));
dispatch(bboxWidthChanged({ width }));
dispatch(bboxHeightChanged({ height }));
return;
}
}
@ -104,11 +102,11 @@ const handleMainModels: ModelHandler = (models, state, dispatch, log) => {
return;
}
dispatch(modelChanged(result.data, currentModel));
dispatch(modelChanged({ model: result.data, previousModel: currentModel }));
};
const handleRefinerModels: ModelHandler = (models, state, dispatch, _log) => {
const currentRefinerModel = state.sdxl.refinerModel;
const currentRefinerModel = state.params.refinerModel;
const refinerModels = models.filter(isRefinerMainModelModelConfig);
if (models.length === 0) {
// No models loaded at all
@ -127,7 +125,7 @@ const handleRefinerModels: ModelHandler = (models, state, dispatch, _log) => {
};
const handleVAEModels: ModelHandler = (models, state, dispatch, log) => {
const currentVae = state.generation.vae;
const currentVae = state.params.vae;
if (currentVae === null) {
// null is a valid VAE! it means "use the default with the main model"
@ -160,28 +158,47 @@ const handleVAEModels: ModelHandler = (models, state, dispatch, log) => {
};
const handleLoRAModels: ModelHandler = (models, state, dispatch, _log) => {
const loras = state.lora.loras;
forEach(loras, (lora, id) => {
const isLoRAAvailable = models.some((m) => m.key === lora.model.key);
const loraModels = models.filter(isLoRAModelConfig);
state.loras.loras.forEach((lora) => {
const isLoRAAvailable = loraModels.some((m) => m.key === lora.model.key);
if (isLoRAAvailable) {
return;
}
dispatch(loraRemoved(id));
dispatch(loraDeleted({ id: lora.id }));
});
};
const handleControlAdapterModels: ModelHandler = (models, state, dispatch, _log) => {
selectControlAdapterAll(state.controlAdapters).forEach((ca) => {
const isModelAvailable = models.some((m) => m.key === ca.model?.key);
const caModels = models.filter(isControlNetOrT2IAdapterModelConfig);
selectCanvasSlice(state).controlLayers.entities.forEach((entity) => {
const isModelAvailable = caModels.some((m) => m.key === entity.controlAdapter.model?.key);
if (isModelAvailable) {
return;
}
dispatch(controlLayerModelChanged({ entityIdentifier: getEntityIdentifier(entity), modelConfig: null }));
});
};
dispatch(controlAdapterModelCleared({ id: ca.id }));
const handleIPAdapterModels: ModelHandler = (models, state, dispatch, _log) => {
const ipaModels = models.filter(isIPAdapterModelConfig);
selectCanvasSlice(state).ipAdapters.entities.forEach((entity) => {
const isModelAvailable = ipaModels.some((m) => m.key === entity.ipAdapter.model?.key);
if (isModelAvailable) {
return;
}
dispatch(ipaModelChanged({ entityIdentifier: getEntityIdentifier(entity), modelConfig: null }));
});
selectCanvasSlice(state).regions.entities.forEach((entity) => {
entity.ipAdapters.forEach(({ id: ipAdapterId, model }) => {
const isModelAvailable = ipaModels.some((m) => m.key === model?.key);
if (isModelAvailable) {
return;
}
dispatch(
rgIPAdapterModelChanged({ entityIdentifier: getEntityIdentifier(entity), ipAdapterId, modelConfig: null })
);
});
});
};

View File

@ -1,6 +1,6 @@
import { isAnyOf } from '@reduxjs/toolkit';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { positivePromptChanged } from 'features/controlLayers/store/controlLayersSlice';
import { positivePromptChanged } from 'features/controlLayers/store/paramsSlice';
import {
combinatorialToggled,
isErrorChanged,
@ -13,8 +13,9 @@ import {
import { getShouldProcessPrompt } from 'features/dynamicPrompts/util/getShouldProcessPrompt';
import { getPresetModifiedPrompts } from 'features/nodes/util/graph/graphBuilderUtils';
import { activeStylePresetIdChanged } from 'features/stylePresets/store/stylePresetSlice';
import { stylePresetsApi } from 'services/api/endpoints/stylePresets';
import { utilitiesApi } from 'services/api/endpoints/utilities';
import { socketConnected } from 'services/events/actions';
import { socketConnected } from 'services/events/setEventListeners';
const matcher = isAnyOf(
positivePromptChanged,
@ -22,7 +23,8 @@ const matcher = isAnyOf(
maxPromptsChanged,
maxPromptsReset,
socketConnected,
activeStylePresetIdChanged
activeStylePresetIdChanged,
stylePresetsApi.endpoints.listStylePresets.matchFulfilled
);
export const addDynamicPromptsListener = (startAppListening: AppStartListening) => {

Some files were not shown because too many files have changed in this diff Show More