Compare commits

..

400 Commits

Author SHA1 Message Date
87261bdbc9 FLUX memory management improvements (#6791)
## Summary

This PR contains several improvements to memory management for FLUX
workflows.

It is now possible to achieve better FLUX model caching performance, but
this still requires users to manually configure their `ram`/`vram`
settings. E.g. a `vram` setting of 16.0 should allow for all quantized
FLUX models to be kept in memory on the GPU.

Changes:
- Check the size of a model on disk and free the requisite space in the
model cache before loading it. (This behaviour existed previously, but
was removed in https://github.com/invoke-ai/InvokeAI/pull/6072/files.
The removal did not seem to be intentional).
- Removed the hack to free 24GB of space in the cache before loading the
FLUX model.
- Split the T5 embedding and CLIP embedding steps into separate
functions so that the two models don't both have to be held in RAM at
the same time.
- Fix a bug in `InvokeLinear8bitLt` that was causing some tensors to be
left on the GPU when the model was offloaded to the CPU. (This class is
getting very messy due to the non-standard state_dict handling in
`bnb.nn.Linear8bitLt`. )
- Tidy up some dtype handling in FluxTextToImageInvocation to avoid
situations where we hold references to two copies of the same tensor
unnecessarily.
- (minor) Misc cleanup of ModelCache: improve docs and remove unused
vars.

Future:
We should revisit our default ram/vram configs. The current defaults are
very conservative, and users could see major performance improvements
from tuning these values.

## QA Instructions

I tested the FLUX workflow with the following configurations and
verified that the cache hit rates and memory usage matched the expected
behaviour:
- `ram = 16` and `vram = 16`
- `ram = 16` and `vram = 1`
- `ram = 1` and `vram = 1`

Note that the changes in this PR are not isolated to FLUX. Since we now
check the size of models on disk, we may see slight changes in model
cache offload patterns for other models as well.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [x] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
2024-08-29 15:17:45 -04:00
4e4b6c6dbc Tidy variable management and dtype handling in FluxTextToImageInvocation. 2024-08-29 19:08:18 +00:00
5e8cf9fb6a Remove hack to clear cache from the FluxTextToImageInvocation. We now clear the cache based on the on-disk model size. 2024-08-29 19:08:18 +00:00
c738fe051f Split T5 encoding and CLIP encoding into separate functions to ensure that all model references are locally-scoped so that the two models don't have to be help in memory at the same time. 2024-08-29 19:08:18 +00:00
29fe1533f2 Fix bug in InvokeLinear8bitLt that was causing old state information to persist after loading from a state dict. This manifested as state tensors being left on the GPU even when a model had been offloaded to the CPU cache. 2024-08-29 19:08:18 +00:00
77090070bd Check the size of a model on disk and make room for it in the cache before loading it. 2024-08-29 19:08:18 +00:00
6ba9b1b6b0 Tidy up GIG -> GB and remove unused GIG constant. 2024-08-29 19:08:18 +00:00
c578b8df1e Improve ModelCache docs. 2024-08-29 19:08:18 +00:00
cad9a41433 Remove unused MOdelCache.exists(...) function. 2024-08-29 19:08:18 +00:00
5fefb3b0f4 Remove unused param from ModelCache. 2024-08-29 19:08:18 +00:00
5284a870b0 Remove unused constructor params from ModelCache. 2024-08-29 19:08:18 +00:00
e064377c05 Remove default model cache sizes from model_cache_default.py. These defaults were misleading, because the config defaults take precedence over them. 2024-08-29 19:08:18 +00:00
3e569c8312 feat(ui): add fields for CLIP embed models and Flux VAE models in workflows 2024-08-29 11:52:51 -04:00
16825ee6e9 feat(nodes): bump version of flux model node, update default workflow 2024-08-29 11:52:51 -04:00
3f5340fa53 feat(nodes): add submodels as inputs to FLUX main model node instead of hardcoded names 2024-08-29 11:52:51 -04:00
f2a1a39b33 Add selectedStylePreset to app parameters (#6787)
## Summary
- Add selectedStylePreset to app parameters
<!--A description of the changes in this PR. Include the kind of change
(fix, feature, docs, etc), the "why" and the "how". Screenshots or
videos are useful for frontend changes.-->

## Related Issues / Discussions

<!--WHEN APPLICABLE: List any related issues or discussions on github or
discord. If this PR closes an issue, please use the "Closes #1234"
format, so that the issue will be automatically closed when the PR
merges.-->

## QA Instructions

<!--WHEN APPLICABLE: Describe how you have tested the changes in this
PR. Provide enough detail that a reviewer can reproduce your tests.-->

## Merge Plan

<!--WHEN APPLICABLE: Large PRs, or PRs that touch sensitive things like
DB schemas, may need some care when merging. For example, a careful
rebase by the change author, timing to not interfere with a pending
release, or a message to contributors on discord after merging.-->

## Checklist

- [ ] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
2024-08-28 10:53:07 -04:00
326de55d3e remove api changes and only preselect style preset 2024-08-28 09:53:29 -04:00
b2df909570 added selectedStylePreset to preload presets when app loads 2024-08-28 09:50:44 -04:00
026ac36b06 Revert "added selectedStylePreset to preload presets when app loads"
This reverts commit e97fd85904.
2024-08-28 09:44:08 -04:00
92125e5fd2 bug fixes 2024-08-27 16:13:38 -04:00
c0c139da88 formatting ruff 2024-08-27 15:46:51 -04:00
404ad6a7fd cleanup 2024-08-27 15:42:42 -04:00
fc39086fb4 call stylePresetSelected 2024-08-27 15:34:31 -04:00
cd215700fe added route for selecting style preset 2024-08-27 15:34:07 -04:00
e97fd85904 added selectedStylePreset to preload presets when app loads 2024-08-27 15:33:24 -04:00
0a263fa5b1 chore: bump version to v4.2.9rc1 2024-08-27 12:09:27 -04:00
fae3836a8d fix CLIP 2024-08-27 10:29:10 -04:00
b3d2eb4178 add translations for new model types in MM, remove clip vision from filter since its not displayed in list 2024-08-27 10:29:10 -04:00
576f1cbb75 build: remove broken scripts
These two scripts are broken and can cause data loss. Remove them.

They are not in the launcher script, but _are_ available to users in the terminal/file browser.

Hopefully, when we removing them here, `pip` will delete them on next installation of the package...
2024-08-27 22:01:45 +10:00
50085b40bb Update starter model size estimates. 2024-08-26 20:17:50 -04:00
cff382715a default workflow: add steps to exposed fields, add more notes 2024-08-26 20:17:50 -04:00
54d54d1bf2 Run ruff 2024-08-26 20:17:50 -04:00
e84ea68282 remove prompt 2024-08-26 20:17:50 -04:00
160dd36782 update default workflow for flux 2024-08-26 20:17:50 -04:00
65bb46bcca Rename params for flux and flux vae, add comments explaining use of the config_path in model config 2024-08-26 20:17:50 -04:00
2d185fb766 Run ruff 2024-08-26 20:17:50 -04:00
2ba9b02932 Fix type error in tsc 2024-08-26 20:17:50 -04:00
849da67cc7 Remove no longer used code in the flux denoise function 2024-08-26 20:17:50 -04:00
3ea6c9666e Remove in progress images until we're able to make the valuable 2024-08-26 20:17:50 -04:00
cf633e4ef2 Only install starter models if not already installed 2024-08-26 20:17:50 -04:00
bbf934d980 Remove outdated TODO. 2024-08-26 20:17:50 -04:00
620f733110 ruff format 2024-08-26 20:17:50 -04:00
67928609a3 Downgrade accelerate and huggingface-hub deps to original versions. 2024-08-26 20:17:50 -04:00
5f15afb7db Remove flux repo dependency 2024-08-26 20:17:50 -04:00
635d2f480d ruff 2024-08-26 20:17:50 -04:00
70c278c810 Remove dependency on flux config files 2024-08-26 20:17:50 -04:00
56b9906e2e Setup scaffolding for in progress images and add ability to cancel the flux node 2024-08-26 20:17:50 -04:00
a808ce81fd Replace swish() with torch.nn.functional.silu(h). They are functionally equivalent, but in my test VAE deconding was ~8% faster after the change. 2024-08-26 20:17:50 -04:00
83f82c5ddf Switch the CLIP-L start model to use our hosted version - which is much smaller. 2024-08-26 20:17:50 -04:00
101de8c25d Update t5 encoder formats to accurately reflect the quantization strategy and data type 2024-08-26 20:17:50 -04:00
3339a4baf0 Downgrade revert torch version after removing optimum-qanto, and other minor version-related fixes. 2024-08-26 20:17:50 -04:00
dff4a88baa Move quantization scripts to a scripts/ subdir. 2024-08-26 20:17:50 -04:00
a21f6c4964 Update docs for T5 quantization script. 2024-08-26 20:17:50 -04:00
97562504b7 Remove all references to optimum-quanto and downgrade diffusers. 2024-08-26 20:17:50 -04:00
75d8ac378c Update the T5 8-bit quantized starter model to use the BnB LLM.int8() variant. 2024-08-26 20:17:50 -04:00
b9dd354e2b Fixes to the T5XXL quantization script. 2024-08-26 20:17:50 -04:00
33c2fbd201 Add script for quantizing a T5 model. 2024-08-26 20:17:50 -04:00
5063be92bf Switch flux to using its own conditioning field 2024-08-26 20:17:50 -04:00
1047584b3e Only import bnb quantize file if bitsandbytes is installed 2024-08-26 20:17:50 -04:00
6764dcfdaa Load and unload clip/t5 encoders and run inference separately in text encoding 2024-08-26 20:17:50 -04:00
012864ceb1 Update macos test vm to macOS-14 2024-08-26 20:17:50 -04:00
a0bf20bcee Run FLUX VAE decoding in the user's preferred dtype rather than float32. Tested, and seems to work well at float16. 2024-08-26 20:17:50 -04:00
14ab339b33 Move prepare_latent_image_patches(...) to sampling.py with all of the related FLUX inference code. 2024-08-26 20:17:50 -04:00
25c91efbb6 Rename field positive_prompt -> prompt. 2024-08-26 20:17:50 -04:00
1c1f2c6664 Add comment about incorrect T5 Tokenizer size calculation. 2024-08-26 20:17:50 -04:00
d7c22b3bf7 Tidy is_schnell detection logic. 2024-08-26 20:17:50 -04:00
185f2a395f Make FLUX get_noise(...) consistent across devices/dtypes. 2024-08-26 20:17:50 -04:00
0c5649491e Mark FLUX nodes as prototypes. 2024-08-26 20:17:50 -04:00
94aba5892a Attribute black-forest-labs/flux for much of the flux code 2024-08-26 20:17:50 -04:00
ef093dde29 Don't install bitsandbytes on macOS 2024-08-26 20:17:50 -04:00
34451e5f27 added FLUX dev to starter models 2024-08-26 20:17:50 -04:00
1f9bdd1a9a Undo changes to the v2 dir of frontend types 2024-08-26 20:17:50 -04:00
c27d59baf7 Run ruff 2024-08-26 20:17:50 -04:00
f130ddec7c Remove automatic install of models during flux model loader, remove no longer used import function on context 2024-08-26 20:17:50 -04:00
a0a259eef1 Fix max_seq_len field description. 2024-08-26 20:17:50 -04:00
b66f19d4d1 Add docs to the quantization scripts. 2024-08-26 20:17:50 -04:00
4105a78b83 Update load_flux_model_bnb_llm_int8.py to work with a single-file FLUX transformer checkpoint. 2024-08-26 20:17:50 -04:00
19a68afb3a Fix bug in InvokeInt8Params that was causing it to use double the necessary VRAM. 2024-08-26 20:17:50 -04:00
fd68a2475b add better workflow name 2024-08-26 20:17:50 -04:00
28ff7ba830 add better workflow description 2024-08-26 20:17:50 -04:00
5d0b248fdb fix(worker) fix T5 type 2024-08-26 20:17:50 -04:00
01a4e0f6ef update default workflow 2024-08-26 20:17:50 -04:00
91e0731506 fix schema 2024-08-26 20:17:50 -04:00
d1f904d41f tsc and lint fix 2024-08-26 20:17:50 -04:00
269388c9f4 feat(ui): create new field for t5 encoder models in nodes 2024-08-26 20:17:50 -04:00
b8486379ce fix(ui): pass base/type when installing models, add flux formats to MM badges 2024-08-26 20:17:50 -04:00
400eb94d3b fix(ui): only exclude flux main models from linear UI dropdown, not model manager list 2024-08-26 20:17:50 -04:00
e210c96485 add FLUX schnell starter models and submodels as dependenices or adhoc download options 2024-08-26 20:17:50 -04:00
5f567f41f4 add case for clip embed models in probe 2024-08-26 20:17:50 -04:00
5fed573a29 update flux_model_loader node to take a T5 encoder from node field instead of hardcoded list, assume all models have been downloaded 2024-08-26 20:17:50 -04:00
cfac7c8189 Move requantize.py to the quatnization/ dir. 2024-08-26 20:17:50 -04:00
1787de6836 Add docs to the requantize(...) function explaining why it was copied from optimum-quanto. 2024-08-26 20:17:50 -04:00
ac96f187bd Remove duplicate log_time(...) function. 2024-08-26 20:17:50 -04:00
72398350b4 More flux loader cleanup 2024-08-26 20:17:50 -04:00
df9445c351 Various styling and exception type updates 2024-08-26 20:17:50 -04:00
87b7a2e39b Switch inheritance class of flux model loaders 2024-08-26 20:17:50 -04:00
f7e46622a1 Update doc string for import_local_model and remove access_token since it's only usable for local file paths 2024-08-26 20:17:50 -04:00
71f18353a9 Address minor review comments. 2024-08-26 20:17:50 -04:00
4228de707b Rename t5Encoder -> t5_encoder. 2024-08-26 20:17:50 -04:00
b6a05629ef add default workflow for flux t2i 2024-08-26 20:17:50 -04:00
fbaa820643 exclude flux models from main model dropdown 2024-08-26 20:17:50 -04:00
db2a2d5e38 Some cleanup of the tags and description of flux nodes 2024-08-26 20:17:50 -04:00
8ba6e6b1f8 Add t5 encoders and clip embeds to the model manager 2024-08-26 20:17:50 -04:00
57168d719b Fix styling/lint 2024-08-26 20:17:50 -04:00
dee6d2c98e Fix support for 8b quantized t5 encoders, update exception messages in flux loaders 2024-08-26 20:17:50 -04:00
e49105ece5 Add tqdm progress bar to FLUX denoising. 2024-08-26 20:17:50 -04:00
0c5e11f521 Fix FLUX output image clamping. And a few other minor fixes to make inference work with the full bfloat16 FLUX transformer model. 2024-08-26 20:17:50 -04:00
a63f842a13 Select dev/schnell based on state dict, use correct max seq len based on dev/schnell, and shift in inference, separate vae flux params into separate config 2024-08-26 20:17:50 -04:00
4bd7fda694 Install sub directories with folders correctly, ensure consistent dtype of tensors in flux pipeline and vae 2024-08-26 20:17:50 -04:00
81f0886d6f Working inference node with quantized bnb nf4 checkpoint 2024-08-26 20:17:50 -04:00
2eb87f3306 Remove unused param on _run_vae_decoding in flux text to image 2024-08-26 20:17:50 -04:00
723f3ab0a9 Add nf4 bnb quantized format 2024-08-26 20:17:50 -04:00
1bd90e0fd4 Run ruff, setup initial text to image node 2024-08-26 20:17:50 -04:00
436f18ff55 Add backend functions and classes for Flux implementation, Update the way flux encoders/tokenizers are loaded for prompt encoding, Update way flux vae is loaded 2024-08-26 20:17:50 -04:00
cde9696214 Some UI cleanup, regenerate schema 2024-08-26 20:17:50 -04:00
2d9042fb93 Run Ruff 2024-08-26 20:17:50 -04:00
9ed53af520 Run Ruff 2024-08-26 20:17:50 -04:00
56fda669fd Manage quantization of models within the loader 2024-08-26 20:17:50 -04:00
1d8545a76c Remove changes to v1 workflow 2024-08-26 20:17:50 -04:00
5f59a828f9 Setup flux model loading in the UI 2024-08-26 20:17:50 -04:00
1fa6bddc89 WIP on moving from diffusers to FLUX 2024-08-26 20:17:50 -04:00
d3a5ca5247 More improvements for LLM.int8() - not fully tested. 2024-08-26 20:17:50 -04:00
f01f56a98e LLM.int8() quantization is working, but still some rough edges to solve. 2024-08-26 20:17:50 -04:00
99b0f79784 Clean up NF4 implementation. 2024-08-26 20:17:50 -04:00
e1eb104345 NF4 inference working 2024-08-26 20:17:50 -04:00
5c2f95ef50 NF4 loading working... I think. 2024-08-26 20:17:50 -04:00
b63df9bab9 wip 2024-08-26 20:17:50 -04:00
a52c899c6d Split a FluxTextEncoderInvocation out from the FluxTextToImageInvocation. This has the advantage that we benfit from automatic caching when the prompt isn't changed. 2024-08-26 20:17:50 -04:00
eeabb7ebe5 Make quantized loading fast for both T5XXL and FLUX transformer. 2024-08-26 20:17:50 -04:00
8b1cef978c Make quantized loading fast. 2024-08-26 20:17:50 -04:00
152da482cd WIP - experimentation 2024-08-26 20:17:50 -04:00
3cf0365a35 Make float16 inference work with FLUX on 24GB GPU. 2024-08-26 20:17:50 -04:00
5870742bb9 Add support for 8-bit quantizatino of the FLUX T5XXL text encoder. 2024-08-26 20:17:50 -04:00
01d8c62c57 Make 8-bit quantization save/reload work for the FLUX transformer. Reload is still very slow with the current optimum.quanto implementation. 2024-08-26 20:17:50 -04:00
55a242b2d6 Minor improvements to FLUX workflow. 2024-08-26 20:17:50 -04:00
45263b339f Got FLUX schnell working with 8-bit quantization. Still lots of rough edges to clean up. 2024-08-26 20:17:50 -04:00
3319491861 Use the FluxPipeline.encode_prompt() api rather than trying to run the two text encoders separately. 2024-08-26 20:17:50 -04:00
e687afac90 Add sentencepiece dependency for the T5 tokenizer. 2024-08-26 20:17:50 -04:00
b39031ea53 First draft of FluxTextToImageInvocation. 2024-08-26 20:17:50 -04:00
0b77511271 Update HF download logic to work for black-forest-labs/FLUX.1-schnell. 2024-08-26 20:17:50 -04:00
c99cd989c1 Update imports for compatibility with bumped diffusers version. 2024-08-26 20:17:50 -04:00
317fdadb21 Bump diffusers version to include FLUX support. 2024-08-26 20:17:50 -04:00
4e294f9e3e disable export button if no non-default presets 2024-08-26 09:23:15 -04:00
526e0f30a0 Added support for bounding boxes in the Invocation API
Adding built-in bounding boxes as a core type would help developers of nodes that include bounding box support.
2024-08-26 08:03:30 +10:00
231e5ec94a chore: bump version v4.2.8post1 2024-08-23 06:55:30 +10:00
e5bb6f9693 lint fix 2024-08-23 06:46:19 +10:00
da7dee44c6 fix(ui): use empty string fallback if unable to parse prompts when creating style preset from existing image 2024-08-23 06:46:19 +10:00
83144f4fe3 fix(docs): follow-up docker readme fixes 2024-08-22 11:19:07 -04:00
c451f52ea3 chore(ui): lint 2024-08-22 21:00:09 +10:00
8a2c78f2e1 fix(ui): dynamic prompts not recalculating when deleting or updating a style preset
The root cause was the active style preset not being reset when it was deleted, or no longer present in the list of style presets.

- Add extra reducer to `stylePresetSlice` to reset the active preset if it is deleted or otherwise unavailable
- Update the dynamic prompts listener to trigger on delete/update/list of style presets
2024-08-22 21:00:09 +10:00
bcc78bde9b chore: bump version to v4.2.8 2024-08-22 21:00:09 +10:00
054bb6fe0a translationBot(ui): update translation (Russian)
Currently translated at 100.0% (1367 of 1367 strings)

Co-authored-by: Васянатор <ilabulanov339@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/ru/
Translation: InvokeAI/Web UI
2024-08-22 13:09:56 +10:00
4f4aa6d92e translationBot(ui): update translation (Italian)
Currently translated at 98.4% (1346 of 1367 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.4% (1346 of 1367 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2024-08-22 13:09:56 +10:00
eac51ac6f5 translationBot(ui): update translation files
Updated by "Cleanup translation files" hook in Weblate.

Co-authored-by: Hosted Weblate <hosted@weblate.org>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/
Translation: InvokeAI/Web UI
2024-08-22 13:09:56 +10:00
9f349a7c0a fix(ui): do not constrain width of hide/show boards button
lets translations display fully
2024-08-22 11:36:07 +10:00
918afa5b15 fix(ui): show more of current board name 2024-08-22 11:36:07 +10:00
eb1113f95c feat(ui): add translation string for "Upscale" 2024-08-22 11:36:07 +10:00
4f4ba7b462 tidy(ui): clean up ActiveStylePreset markup 2024-08-21 09:06:41 +10:00
2298be0e6b fix(ui): error handling if unable to convert image URL to blob 2024-08-21 09:06:41 +10:00
63494dfca7 remove extra slash in exports path 2024-08-21 09:06:41 +10:00
36a1d39454 fix(ui): handle badge styling when template name is long 2024-08-21 09:06:41 +10:00
a6f6d5c400 fix(ui): add loading state to button when creating or updating a style preset 2024-08-21 09:06:41 +10:00
e85f221aca fix(ui): clear prompt template when prompts are recalled 2024-08-21 09:04:35 +10:00
d4797e37dc fix(ui): properly unwrap delete style preset API request so that error is caught 2024-08-19 16:12:39 -04:00
3e7923d072 fix(api): allow updating of type for style preset 2024-08-19 16:12:39 -04:00
a85d69ce3d tidy(ui): getViewModeChunks.tsx -> .ts 2024-08-19 08:25:39 +10:00
96db006c99 fix(ui): edge case with getViewModeChunks 2024-08-19 08:25:39 +10:00
8ca57d03d8 tests(ui): add tests for getViewModeChunks 2024-08-19 08:25:39 +10:00
6c404ce5f8 fix(ui): prompt template preset preview out of order 2024-08-19 08:25:39 +10:00
584e07182b fix(ui): use translations for style preset strings 2024-08-17 21:27:53 +10:00
f787e9acf6 chore: bump version v4.2.8rc2 2024-08-16 21:47:06 +10:00
5a24b89e54 fix(app): include style preset defaults in build 2024-08-16 21:47:06 +10:00
9b482e2a4f chore: bump version to v4.2.8rc1 2024-08-16 10:53:19 +10:00
Max
df4dbe2d57 Fix invoke.sh not detecting symlinks
When invoke.sh is executed using a symlink with a working directory outside of InvokeAI's root directory, it will fail.

invoke.sh attempts to cd into the correct directory at the start of the script, but will cd into the directory of the symlink instead. This commit fixes that.
2024-08-16 10:40:59 +10:00
713bd11177 feat(ui, api): prompt template export (#6745)
## Summary

Adds option to download all prompt templates to a CSV

## Related Issues / Discussions

<!--WHEN APPLICABLE: List any related issues or discussions on github or
discord. If this PR closes an issue, please use the "Closes #1234"
format, so that the issue will be automatically closed when the PR
merges.-->

## QA Instructions

<!--WHEN APPLICABLE: Describe how you have tested the changes in this
PR. Provide enough detail that a reviewer can reproduce your tests.-->

## Merge Plan

<!--WHEN APPLICABLE: Large PRs, or PRs that touch sensitive things like
DB schemas, may need some care when merging. For example, a careful
rebase by the change author, timing to not interfere with a pending
release, or a message to contributors on discord after merging.-->

## Checklist

- [ ] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
2024-08-16 10:38:50 +10:00
182571df4b Merge branch 'main' into maryhipp/export-presets 2024-08-16 10:17:07 +10:00
29bfe492b6 ui: translations update from weblate (#6746)
Translations update from [Hosted Weblate](https://hosted.weblate.org)
for [InvokeAI/Web
UI](https://hosted.weblate.org/projects/invokeai/web-ui/).



Current translation status:

![Weblate translation
status](https://hosted.weblate.org/widget/invokeai/web-ui/horizontal-auto.svg)
2024-08-16 10:16:51 +10:00
3fb4e3050c feat(ui): focus in textarea after inserting placeholder 2024-08-16 10:14:25 +10:00
39c7ec3cd9 feat(ui): per type fallbacks for templates 2024-08-16 10:11:43 +10:00
26bfbdec7f feat(ui): use buttons instead of menu for preset import/export 2024-08-16 09:58:19 +10:00
7a3eaa8da9 feat(api): save file as prompt_templates.csv 2024-08-16 09:51:46 +10:00
599db7296f export only user style presets 2024-08-15 16:07:32 -04:00
042aab4295 translationBot(ui): update translation (Italian)
Currently translated at 98.6% (1340 of 1359 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2024-08-15 20:44:02 +02:00
24f298283f clean up, add context menu to import/download templates 2024-08-15 12:39:55 -04:00
68dac6349d Merge remote-tracking branch 'origin/main' into maryhipp/export-presets 2024-08-15 11:21:56 -04:00
b675fc19e8 feat: add base prop for selectedWorkflow to allow loading a workflow on launch (#6742)
## Summary
added a base prop for selectedWorkflow to allow loading a workflow on
launch

<!--A description of the changes in this PR. Include the kind of change
(fix, feature, docs, etc), the "why" and the "how". Screenshots or
videos are useful for frontend changes.-->

## Related Issues / Discussions

<!--WHEN APPLICABLE: List any related issues or discussions on github or
discord. If this PR closes an issue, please use the "Closes #1234"
format, so that the issue will be automatically closed when the PR
merges.-->

## QA Instructions
can test by loading InvokeAIUI with a selectedWorkflow prop of the
workflow ID
<!--WHEN APPLICABLE: Describe how you have tested the changes in this
PR. Provide enough detail that a reviewer can reproduce your tests.-->

## Merge Plan

<!--WHEN APPLICABLE: Large PRs, or PRs that touch sensitive things like
DB schemas, may need some care when merging. For example, a careful
rebase by the change author, timing to not interfere with a pending
release, or a message to contributors on discord after merging.-->

## Checklist

- [ ] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
2024-08-15 10:52:23 -04:00
659019cfd6 Merge branch 'main' into chainchompa/preselect-workflows 2024-08-15 10:40:44 -04:00
dcd61e1f82 pin ruff version in python check gha 2024-08-15 09:47:49 -04:00
f5c99b1488 exclude jupyter notebooks from ruff 2024-08-15 09:47:49 -04:00
810be3e1d4 update import directions to include JSON 2024-08-15 09:47:49 -04:00
60d754d1df feat(api): tidy style presets import logic
- Extract parsing into utility function
- Log import errors
- Forbid extra properties on the imported data
2024-08-15 09:47:49 -04:00
bd07c86db9 feat(ui): make style preset menu trigger look like button 2024-08-15 09:47:49 -04:00
bcbf8b6bd8 feat(ui): revert to using {prompt} for prompt template placeholder 2024-08-15 09:47:49 -04:00
356661459b feat(api): support JSON for preset imports
This allows us to support Fooocus format presets.
2024-08-15 09:47:49 -04:00
deb917825e feat(api): use pydantic validation during style preset import
- Enforce name is present and not an empty string
- Provide empty string as default for positive and negative prompt
- Add `positive_prompt` as validation alias for `prompt` field
- Strip whitespace automatically
- Create `TypeAdapter` to validate the whole list in one go
2024-08-15 09:47:49 -04:00
15415c6d85 feat(ui): use dropzone for style preset upload
Easier to accept multiple file types and supper drag and drop in the future.
2024-08-15 09:47:49 -04:00
76b0380b5f feat(ui): create component to upload CSV of style presets to import 2024-08-15 09:47:49 -04:00
2d58754789 feat(api): add endpoint to take a CSV, parse it, validate it, and create many style preset entries 2024-08-15 09:47:49 -04:00
9cdf1f599c Merge branch 'main' into chainchompa/preselect-workflows 2024-08-15 09:25:19 -04:00
268be97ba0 remove ref, make options optional for useGetLoadWorkflow 2024-08-15 09:18:41 -04:00
a9014673a0 wip export 2024-08-15 09:00:11 -04:00
d36c43a10f ui: translations update from weblate (#6727)
Translations update from [Hosted Weblate](https://hosted.weblate.org)
for [InvokeAI/Web
UI](https://hosted.weblate.org/projects/invokeai/web-ui/).



Current translation status:

![Weblate translation
status](https://hosted.weblate.org/widget/invokeai/web-ui/horizontal-auto.svg)
2024-08-15 08:48:03 +10:00
54a5c4e482 translationBot(ui): update translation (Chinese (Simplified))
Currently translated at 98.1% (1296 of 1320 strings)

Co-authored-by: Phrixus2023 <920414016@qq.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/zh_Hans/
Translation: InvokeAI/Web UI
2024-08-15 00:46:01 +02:00
5e09a244e3 translationBot(ui): update translation (Italian)
Currently translated at 98.5% (1336 of 1355 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.5% (1302 of 1321 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.6% (1302 of 1320 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2024-08-15 00:46:01 +02:00
88648dca1a change selectedWorkflow to selectedWorkflowId 2024-08-14 11:22:37 -04:00
8840df2b00 Merge branch 'main' into chainchompa/preselect-workflows 2024-08-14 09:02:12 -04:00
af159acbdf cleanup 2024-08-14 08:58:38 -04:00
471719bbbe add base prop for selectedWorkflow to allow loading a workflow on launch 2024-08-14 08:47:02 -04:00
b126f2ffd5 feat(ui, api): prompt templates (#6729)
## Summary

Adds prompt templates to the UI. Demo video is attached.
* added default prompt templates to seed database on startup (these
cannot be edited or deleted by users via the UI)
* can create fresh prompt template, create from an image in gallery that
has prompt metadata, or copy an existing prompt template and modify
* if a template is active, can view what your prompt will be invoked as
by switching to "view mode"



https://github.com/user-attachments/assets/32d84e0c-b04c-48da-bae5-aa6eb685d209



## Related Issues / Discussions

<!--WHEN APPLICABLE: List any related issues or discussions on github or
discord. If this PR closes an issue, please use the "Closes #1234"
format, so that the issue will be automatically closed when the PR
merges.-->

## QA Instructions

<!--WHEN APPLICABLE: Describe how you have tested the changes in this
PR. Provide enough detail that a reviewer can reproduce your tests.-->

## Merge Plan

<!--WHEN APPLICABLE: Large PRs, or PRs that touch sensitive things like
DB schemas, may need some care when merging. For example, a careful
rebase by the change author, timing to not interfere with a pending
release, or a message to contributors on discord after merging.-->

## Checklist

- [ ] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
2024-08-14 12:49:31 +10:00
9938f12ef0 Merge branch 'main' into maryhipp/style-presets 2024-08-14 12:33:30 +10:00
982c266073 tidy: remove extra characters in prompt templates 2024-08-14 12:31:57 +10:00
5c37391883 fix(ui): do not show [prompt] in preset preview 2024-08-14 12:29:05 +10:00
ddeafc6833 fix(ui): minimize layout shift when overlaying preset prompt preview 2024-08-14 12:24:57 +10:00
41b2d5d013 fix(ui): prompt preview not working preset starts with [prompt] 2024-08-14 12:21:38 +10:00
29d6f48901 fix(ui): prompt shows thru prompt label text 2024-08-14 12:01:49 +10:00
d5c9f4e47f chore(ui): revert framer-motion upgrade
`framer-motion` 11 breaks a lot of stuff in profoundly unintuitive ways, holy crap. UI lib rolled back its dep, pulling in latest version of that
2024-08-14 06:12:00 +10:00
24d73387d8 build(ui): fix chakra deps
We had multiple versions of @emotion/react, stemming from an extraneous dependency on @chakra-ui/react. Removed the extraneosu dep
2024-08-14 06:12:00 +10:00
e0d3927265 feat: add flag for allowPrivateStylePresets that shows a type field when creating a style preset 2024-08-13 14:08:54 -04:00
e5f7c2a9b7 add type safety / validation to form data payloads and allow type to be passed through api 2024-08-13 13:00:31 -04:00
b0760710d5 add the rest of default style presets, update image service to return default images correctly by name, add tooltip popover to images in UI 2024-08-13 11:33:15 -04:00
764accc921 update config docstring 2024-08-12 15:17:40 -04:00
6a01fce9c1 fix payloads for stringified data 2024-08-12 15:16:22 -04:00
9c732ac3b1 Merge remote-tracking branch 'origin/main' into maryhipp/style-presets 2024-08-12 14:53:45 -04:00
b70891c661 update descriptoin of placeholder in modal 2024-08-12 13:37:04 -04:00
4dbf851741 ui: add labels to prompt boxes 2024-08-12 13:33:39 -04:00
6c927a9fd4 move mdoal state into nanostore 2024-08-12 12:46:02 -04:00
096f001634 ui: add ability to copy template 2024-08-12 12:32:31 -04:00
4837e578b2 api: update dir path for style preset images, update payload for create/update formdata 2024-08-12 12:00:14 -04:00
1e547ef912 UI more pr feedback 2024-08-12 11:59:25 -04:00
f6b8970bd1 fix(app): create reference to events task to prevent accidental GC
This wasn't a problem, but it's advised in the official docs so I've done it.
2024-08-12 07:49:58 +10:00
29325a7214 fix(app): use asyncio queue and existing event loop for events
Around the time we (I) implemented pydantic events, I noticed a short pause between progress images every 4 or 5 steps when generating with SDXL. It didn't happen with SD1.5, but I did notice that with SD1.5, we'd get 4 or 5 progress events simultaneously. I'd expect one event every ~25ms, matching my it/s with SD1.5. Mysterious!

Digging in, I found an issue is related to our use of a synchronous queue for events. When the event queue is empty, we must call `asyncio.sleep` before checking again. We were sleeping for 100ms.

Said another way, every time we clear the event queue, we have to wait 100ms before another event can be dispatched, even if it is put on the queue immediately after we start waiting. In practice, this means our events get buffered into batches, dispatched once every 100ms.

This explains why I was getting batches of 4 or 5 SD1.5 progress events at once, but not the intermittent SDXL delay.

But this 100ms wait has another effect when the events are put on the queue in intervals that don't perfectly line up with the 100ms wait. This is most noticeable when the time between events is >100ms, and can add up to 100ms delay before the event is dispatched.

For example, say the queue is empty and we start a 100ms wait. Then, immediately after - like 0.01ms later - we push an event on to the queue. We still need to wait another 99.9ms before that event will be dispatched. That's the SDXL delay.

The easy fix is to reduce the sleep to something like 0.01 seconds, but this feels kinda dirty. Can't we just wait on the queue and dispatch every event immediately? Not with the normal synchronous queue - but we can with `asyncio.Queue`.

I switched the events queue to use `asyncio.Queue` (as seen in this commit), which lets us asynchronous wait on the queue in a loop.

Unfortunately, I ran into another issue - events now felt like their timing was inconsistent, but in a different way than with the 100ms sleep. The time between pushing events on the queue and dispatching them was not consistently ~0ms as I'd expect - it was highly variable from ~0ms up to ~100ms.

This is resolved by passing the asyncio loop directly into the events service and using its methods to create the task and interact with the queue. I don't fully understand why this resolved the issue, because either way we are interacting with the same event loop (as shown by `asyncio.get_running_loop()`). I suppose there's some scheduling magic happening.
2024-08-12 07:49:58 +10:00
8ecf72838d fix(api): image downloads with correct filename
Closes #6730
2024-08-10 09:53:56 -04:00
c3ab8a6aa8 chore(ui): bump rest of deps 2024-08-10 07:45:23 -04:00
1931aa3e70 chore(ui): typegen 2024-08-10 07:45:23 -04:00
d3d8055055 feat(ui): update typegen script 2024-08-10 07:45:23 -04:00
476b0a0403 chore(ui): bump openapi-typescript 2024-08-10 07:45:23 -04:00
f66584713c fix(api): sort OpenAPI schema properties for InvocationOutputMap
This makes the schema output deterministic!
2024-08-10 07:45:23 -04:00
33624fc2fa fix(api): duplicate operation id for get_image_full
There's a FastAPI bug that results in the OpenAPI spec outputting the same operation id for each operation when specifying multiple HTTP methods.

- Discussion: https://github.com/tiangolo/fastapi/discussions/8449
- Pending PR to fix: https://github.com/tiangolo/fastapi/pull/10694

In our case, we have a `get_image_full` endpoint that handles GET and HEAD.

This results in an invalid OpenAPI schema. A workaround is to use two route decorators for the operation handler. This works as expected - HEAD requests get the header, and GET requests get the resource. And the OpenAPI schema is valid.
2024-08-10 07:45:23 -04:00
41c3e73a3c fix tests 2024-08-09 16:31:42 -04:00
97553a7de2 API/DB updates per PR feedback 2024-08-09 16:27:37 -04:00
12ba15bfa9 UI updates per PR feedback 2024-08-09 16:00:13 -04:00
09d1e190e7 show warning for maxUpscaleDimension if model tab is disabled 2024-08-09 14:07:55 -04:00
8eb5d08499 missed translation 2024-08-08 16:01:16 -04:00
9be6acde7d require name to submit style preset 2024-08-08 15:53:21 -04:00
5f83bb0069 update config docstring 2024-08-08 15:20:43 -04:00
b138882abc fix tests? 2024-08-08 15:18:32 -04:00
0cd7cdb52e remove send2trash 2024-08-08 15:13:36 -04:00
1d8b7e2bcf ruff 2024-08-08 15:08:45 -04:00
6461f4758d lint fix 2024-08-08 15:07:58 -04:00
3189ab6863 get dynamic prompts working 2024-08-08 15:07:23 -04:00
3f9a674d4b seed default presets and handle them in UI 2024-08-08 15:02:41 -04:00
587f59b25b focus on prompt textarea when exiting view mode by clicking 2024-08-08 14:38:50 -04:00
4952eada87 ruff format 2024-08-08 14:22:40 -04:00
581029ebaa ruff 2024-08-08 14:21:37 -04:00
42d68780de lint 2024-08-08 14:19:33 -04:00
28032a2f80 more cleanup 2024-08-08 14:18:05 -04:00
e381e021e9 knip lint 2024-08-08 14:00:17 -04:00
641af64f93 regnerate schema 2024-08-08 13:58:25 -04:00
a7b83c8b5b Merge remote-tracking branch 'origin/main' into maryhipp/style-presets 2024-08-08 13:56:59 -04:00
4cc41e0188 translations and lint fix 2024-08-08 13:56:37 -04:00
442fc02429 resize images to 100x100 for style preset images 2024-08-08 12:56:55 -04:00
9a4d075074 fix path for style_preset_images, fix png type when converting blobs to files, built view mode components 2024-08-08 12:31:20 -04:00
17ff8196cb Remove tmp code 2024-08-07 22:06:05 -04:00
68f993998a Add support for norm layer 2024-08-07 22:06:05 -04:00
7da6120b39 Fix LoKR refactor bug 2024-08-07 22:06:05 -04:00
6cd40965c4 Depth Anything V2 (#6674)
- Updated the previous DepthAnything manual implementation to use the
`transformers` implementation instead. So we can get upstream features.
- Plugged in the DepthAnything models to be handled by Invoke's Model
Manager.
- `small_v2` model will use DepthAnythingV2. This has been added as a
new model option and is now also the default in the Linear UI.


![opera_TxRhmbFole](https://github.com/user-attachments/assets/2a25abe3-ba0b-4f97-b75a-2ce5fd6246e6)


# Merge

Review and merge.
2024-08-07 20:26:58 +05:30
408a1d6dbb Merge branch 'main' into depth_anything_v2 2024-08-07 10:45:56 -04:00
0b0abfbe8f clean up image implementation 2024-08-07 10:36:38 -04:00
cc96dcf0ed style preset images 2024-08-07 09:58:27 -04:00
2604fd9fde a whole bunch of stuff 2024-08-06 15:31:13 -04:00
140670d00e translationBot(ui): update translation files
Updated by "Cleanup translation files" hook in Weblate.

translationBot(ui): update translation files

Updated by "Cleanup translation files" hook in Weblate.

Co-authored-by: Hosted Weblate <hosted@weblate.org>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/
Translation: InvokeAI/Web UI
2024-08-06 17:54:47 +10:00
70233fae5d translationBot(ui): update translation (Chinese (Simplified))
Currently translated at 98.1% (1296 of 1321 strings)

Co-authored-by: Phrixus2023 <920414016@qq.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/zh_Hans/
Translation: InvokeAI/Web UI
2024-08-06 17:54:47 +10:00
6f457a6c4c translationBot(ui): update translation (German)
Currently translated at 65.1% (860 of 1321 strings)

Co-authored-by: Alexander Eichhorn <pfannkuchensack@einfach-doof.de>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/de/
Translation: InvokeAI/Web UI
2024-08-06 17:54:47 +10:00
B N
5c319f5356 translationBot(ui): update translation (German)
Currently translated at 64.8% (857 of 1321 strings)

Co-authored-by: B N <berndnieschalk@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/de/
Translation: InvokeAI/Web UI
2024-08-06 17:54:47 +10:00
991a04f090 translationBot(ui): update translation (Italian)
Currently translated at 98.6% (1303 of 1321 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.6% (1302 of 1320 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.6% (1294 of 1312 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2024-08-06 17:54:47 +10:00
c39fa75113 docs(ui): add comment in useIsTooLargeToUpscale 2024-08-06 11:49:35 +10:00
f7863e17ce docs(ui): add docstring for maxUpscaleDimension 2024-08-06 11:49:35 +10:00
7c526390ed fix(ui): compare upscaledPixels vs square of max dimension 2024-08-06 11:49:35 +10:00
2cff20f87a update translations, change config value to be dimension instead of total pixels 2024-08-06 11:49:35 +10:00
90ec757802 lint 2024-08-06 11:49:35 +10:00
4b85dfcefe (ui): restore optioanl limit on upcsale output resolution 2024-08-06 11:49:35 +10:00
21deefdc41 (ui): add image resolution badge to initial upscale image 2024-08-06 11:49:35 +10:00
857d74bbfe wip apply and calculate prompt with interpolation 2024-08-05 19:11:48 -04:00
fd7a635777 (ui) the most basic crud ui: view list of presets, create a new preset, edit/delete existing presets 2024-08-05 15:48:23 -04:00
af9110e964 fix prompt concat logic 2024-08-05 13:42:28 -04:00
a61209206b remove custom SDXL prompts component 2024-08-05 13:40:46 -04:00
e05cc62e5f add style presets API layer to UI 2024-08-05 13:37:07 -04:00
4d4f921a4e build: exclude matplotlib 3.9.1
There was a problem w/ this release on windows and the builds were pulled from pypi. When installing invoke on windows, pip attempts to build from source, but most (all?) systems won't have the prerequisites for this and installs fail.

This also affects GH actions.

The simple fix is to exclude version 3.9.1 from our deps.

For more information, see https://github.com/matplotlib/matplotlib/issues/28551
2024-08-05 08:38:44 +10:00
98db8f395b feat(app): clean up DiskImageStorage types 2024-08-04 09:43:20 +10:00
f465a956a3 feat(ui): remove "images can be restored" messages 2024-08-04 09:43:20 +10:00
9edb02d7ef build: remove send2trash dependency 2024-08-04 09:43:20 +10:00
6c4cf58a31 feat(app): delete model_images instead of using send2trash 2024-08-04 09:43:20 +10:00
08993c0d29 feat(app): delete images instead of using send2trash
Closes #6709
2024-08-04 09:43:20 +10:00
4f8a4b0f22 Merge branch 'main' into depth_anything_v2 2024-08-03 00:38:57 +05:30
a743f3c9b5 fix: implement model to func for depth anything 2024-08-03 00:37:17 +05:30
217fe40d99 feat(api): add style_presets router, make sure all CRUD is working, add is_default 2024-08-02 12:29:54 -04:00
b76bf50b93 feat(db,api): create new table for style presets, build out record storage service for style presets 2024-08-01 22:20:11 -04:00
571ba87e13 fix(ui): include upscale metadata for SDXL multidiffusion 2024-08-01 21:30:42 -04:00
f27b6e2b44 Add Grounded SAM support (text prompt image segmentation) (#6701)
## Summary

This PR enables Grounded SAM workflows
(https://arxiv.org/pdf/2401.14159) via the following:
- `GroundingDinoInvocation` for running a Grounding DINO model.
- `SegmentAnythingModelInvocation` for running a SAM model.
- `MaskTensorToImageInvocation` for convenient visualization.

Other notes:
- Uses the transformers implementation of Grounding DINO and SAM.
- The new models are treated as 'utility models' meaning that they are
not visible in the Models tab, and are downloaded automatically the
first time that they are used.

<img width="874" alt="image"
src="https://github.com/user-attachments/assets/1cbaa97d-0e27-4943-86b1-dc7327ba8675">

## Example

Input image

![be10ec0c-20a8-4ac7-840e-d1a05fffdb6a](https://github.com/user-attachments/assets/bf21572c-635d-4703-b4ab-7aba658a9671)

Prompt: "wheels", all other configs default
Result:

![2221c44e-64e6-4b18-b4cb-610514b7a554](https://github.com/user-attachments/assets/344b91f4-7f4a-4b70-8e2e-3b4a0e55176d)

## Related Issues / Discussions

Thanks to @blessedcoolant for the initial draft here:
https://github.com/invoke-ai/InvokeAI/pull/6678

## QA Instructions

Manual tests:
- [ ] Test that default settings work well.
- [ ] Test with / without apply_polygon_refinement
- [ ] Test mask_filter options
- [ ] Test detection_threshold values
- [ ] Test RGB input image
- [ ] Test RGBA input image
- [ ] Test grayscale input image
- [ ] Smoke test that an empty mask is returned when 0 objects are
detected
- [ ] Test on CPU
- [ ] Test on MPS (Works on Mac OS, but had to force both models to run
on CPU instead of MPS)

Performance:
- Peak GPU memory utilization with both Grounding DINO and SAM models
loaded is ~4.5GB. (The models do not need to be loaded at the same time,
so could be offloaded by the MM if needed.)
- On an RTX4090, with the models already cached, node execution takes
~0.6 secs.
- On my CPU, with the models cached, node execution takes ~10secs.

## Merge Plan

No special instructions.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
2024-08-01 20:40:18 +02:00
981475a624 Merge branch 'main' into ryan/grounded-sam 2024-08-01 20:30:35 +02:00
27ac61a4fb Expose all model options in the GroundingDinoInvocation and the SegmentAnythingInvocation. 2024-08-01 14:23:32 -04:00
675ffc2757 Remove BoundingBoxInvocation field name overrides. 2024-08-01 14:05:44 -04:00
44b21f10f1 Add a pydantic model_validator to BoundingBoxField to check the validity of the coords. 2024-08-01 14:00:57 -04:00
c6d49e8b1f Shorten SegmentAnythingInvocation and GroundingDinoInvocatino docstrings, since they are used as the invocation descriptions in the UI. 2024-08-01 10:17:42 -04:00
e6a512aa86 (minor) Tweak order of mask operations. 2024-08-01 10:12:24 -04:00
c3a6a6fb22 Rename SegmentAnythingModelInvocation -> SegmentAnythingInvocation. 2024-08-01 10:00:36 -04:00
b9dc3460ba Rename SegmentAnythingModel -> SegmentAnythingPipeline. 2024-08-01 09:57:47 -04:00
63581ec980 (minor) Add None check to fix static type checking error. 2024-08-01 09:51:53 -04:00
08b1feeed7 add base prop for destination to direct users to different tabs on initial load (#6706)
## Summary
- we want a way to load the studio while being directed to a specific
tab, introduced a destination prop to achieve that
<!--A description of the changes in this PR. Include the kind of change
(fix, feature, docs, etc), the "why" and the "how". Screenshots or
videos are useful for frontend changes.-->

## Related Issues / Discussions

<!--WHEN APPLICABLE: List any related issues or discussions on github or
discord. If this PR closes an issue, please use the "Closes #1234"
format, so that the issue will be automatically closed when the PR
merges.-->

## QA Instructions

<!--WHEN APPLICABLE: Describe how you have tested the changes in this
PR. Provide enough detail that a reviewer can reproduce your tests.-->

## Merge Plan

<!--WHEN APPLICABLE: Large PRs, or PRs that touch sensitive things like
DB schemas, may need some care when merging. For example, a careful
rebase by the change author, timing to not interfere with a pending
release, or a message to contributors on discord after merging.-->

## Checklist

- [ ] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
2024-07-31 19:25:36 -04:00
f5cfdcf32d feat: Add BoundingBox Primitive Node 2024-08-01 04:09:08 +05:30
e78fb428f0 simplify destination prop handling 2024-07-31 18:06:22 -04:00
31e270e32c add base prop for destination to direct users to different tabs 2024-07-31 17:20:51 -04:00
b5832768dc Return a MaskOutput from SegmentAnythingModelInvocation. And add a MaskTensorToImageInvocation. 2024-07-31 17:16:14 -04:00
4ce64b69cb Modular backend - LoRA/LyCORIS (#6667)
## Summary

Code for lora patching from #6577.
Additionally made it the way, that lora can patch not only `weight`, but
also `bias`, because saw some loras which doing it.

## Related Issues / Discussions

#6606 

https://invokeai.notion.site/Modular-Stable-Diffusion-Backend-Design-Document-e8952daab5d5472faecdc4a72d377b0d

## QA Instructions

Run with and without set `USE_MODULAR_DENOISE` environment.

## Merge Plan

Replace old lora patcher with new after review done.
If you think that there should be some kind of tests - feel free to add.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
2024-07-31 21:31:31 +02:00
5a9173f766 Merge branch 'main' into stalker-modular_lora 2024-07-31 15:13:22 -04:00
0bb7ed44f6 Add some docs to OriginalWeightsStorage and fix type hints. 2024-07-31 15:08:24 -04:00
332bc9da5b fix: Update depth anything node default to v2 2024-07-31 23:52:29 +05:30
08def3da95 fix: Update canvas depth anything processor default to v2 2024-07-31 23:50:13 +05:30
daf899f9c4 fix: Move the manual image resizing out of the depth anything pipeline 2024-07-31 23:38:12 +05:30
13fb2d1f49 fix: Add Depth Anything V2 as a new option
It is also now the default in the UI replacing Depth Anything V1 small
2024-07-31 23:29:43 +05:30
95dde802ea fix: assert the return depth map to be a PIL image 2024-07-31 23:22:01 +05:30
fca119773b Split invokeai/backend/image_util/segment_anything/ dir into grounding_dino/ and segment_anything/ 2024-07-31 12:28:47 -04:00
0193267a53 Split GroundedSamInvocation into GroundingDinoInvocation and SegmentAnythingModelInvocation. 2024-07-31 12:20:23 -04:00
b4cf78a95d fix: make DA Pipeline a subclass of RawModel 2024-07-31 21:14:49 +05:30
73386826d6 Make GroundingDinoPipeline and SegmentAnythingModel subclasses of RawModel for type checking purposes. 2024-07-31 10:25:34 -04:00
9f448fecb7 Move invokeai/backend/grounded_sam -> invokeai/backend/image_util/grounded_sam 2024-07-31 10:00:30 -04:00
bcd1483a14 Re-order GroundedSAMInvocation._to_numpy_masks(...) to do slightly more work on the GPU. 2024-07-31 09:51:14 -04:00
e206890e25 Use staticmethods rather than inner functions for the Grounding DINO and SAM model loaders. 2024-07-31 09:28:52 -04:00
0a7048f650 (minor) Simplify GroundedSAMInvocation._merge_masks(...). 2024-07-31 08:58:51 -04:00
e8ecf5e155 (minor) Move apply_polygon_refinement condition up a layer. 2024-07-31 08:50:56 -04:00
33e8604b57 Make Grounding DINO DetectionResult a Pydantic model. 2024-07-31 08:47:00 -04:00
cec7399366 (minor) Use a new variable name to satisfy type checks. 2024-07-31 08:27:01 -04:00
bdae81e429 (minor) Simplify GroundedSAMInvocation._filter_detections() 2024-07-31 08:25:19 -04:00
67c32f3d6c Fix typo: zip(..., strict=True) 2024-07-31 08:15:28 -04:00
94d64b8a78 Fix gradient mask values range (#6688)
## Summary

Gradient mask node outputs mask tensor with values in range [-1, 1],
which unexpected range for mask.
It handled in denoise node the way it translates to [0, 2] mask, which
looks even more wrongly)
From discussion with @dunkeroni I understand him as he thought that
negative values will be treated same as 0, so clamping values not change
intended node logic.

## Related Issues / Discussions

#6643 

## QA Instructions

\-

## Merge Plan

\-

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
2024-07-31 06:37:32 +05:30
fa3c0c81b3 Merge branch 'main' into stalker7779/fix_gradient_mask 2024-07-31 06:30:44 +05:30
66547b99c1 Add more karras schedulers (#6695)
## Summary

Add karras variants of `deis`, `unipc`, `kdpm2` and `kdpm_2_a`
schedulers.
Also added `dpmpp_3` schedulers, but `dpmpp_3s` currently bugged, so
added only 3m:
https://github.com/huggingface/diffusers/issues/9007

## Related Issues / Discussions

\-

## QA Instructions

\-

## Merge Plan

~@psychedelicious We need to decide what to do with schedulers order, as
it looks a bit broken:~

![image](https://github.com/user-attachments/assets/e41674af-d87c-4432-8014-c90bd86965a6)

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
2024-07-31 06:09:26 +05:30
328e58be4c Merge branch 'main' into stalker7779/new_karras_schedulers 2024-07-31 05:56:13 +05:30
18f89ed5ed fix: Make DepthAnything work with Invoke's Model Management 2024-07-31 03:57:54 +05:30
5701c79fab Prevent Grounding DINO and Segment Anything from being moved to MPS - they don't work on MPS devices. 2024-07-30 23:04:15 +02:00
2da9f913f3 Add detection_result.py - was forgotten in a prior commit 2024-07-30 16:04:29 -04:00
6b10b59abe Make GroundedSAMInvocation work with any input image mode (RGB, RGBA, grayscale). 2024-07-30 15:55:57 -04:00
918f77bce0 Move some logic from GroundedSAMInvocation to the backend classes. 2024-07-30 15:34:33 -04:00
f170697ebe Merge branch 'main' into depth_anything_v2 2024-07-31 00:53:32 +05:30
556c6a1d84 fix: Update DepthAnything to use the transformers implementation 2024-07-31 00:51:55 +05:30
aca2a2fa13 Add mask_filter and detection_threshold options to the GroundedSAMInvocation. 2024-07-30 14:22:40 -04:00
ff6398f7d8 Add a GroundedSamInvocation for image segmentation from a text prompt (Grounding DINO + Segment Anything Model). 2024-07-30 11:12:26 -04:00
cf996472b9 Suggested changes
Co-Authored-By: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
2024-07-30 04:50:56 +03:00
156d14c349 Run api regen 2024-07-30 04:05:21 +03:00
86f705bf48 Optimize weights handling 2024-07-30 03:39:01 +03:00
1fd9631f2d Comments fix
Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com>
2024-07-30 00:39:50 +03:00
2227a2357f Suggested changes + simplify weights logic in patching
Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com>
2024-07-30 00:34:37 +03:00
58e7ab157d Ruff format 2024-07-29 22:59:17 +03:00
8d16fa6a49 Remove dpmpp_3s schedulers as it bugged now 2024-07-29 22:55:45 +03:00
55e810efa3 Add dpmpp_3 schedulers 2024-07-29 22:52:15 +03:00
2755316021 update delete board modal to be more descriptive (#6690)
## Summary

<!--A description of the changes in this PR. Include the kind of change
(fix, feature, docs, etc), the "why" and the "how". Screenshots or
videos are useful for frontend changes.-->

## Related Issues / Discussions

<!--WHEN APPLICABLE: List any related issues or discussions on github or
discord. If this PR closes an issue, please use the "Closes #1234"
format, so that the issue will be automatically closed when the PR
merges.-->

## QA Instructions

<!--WHEN APPLICABLE: Describe how you have tested the changes in this
PR. Provide enough detail that a reviewer can reproduce your tests.-->

## Merge Plan

<!--WHEN APPLICABLE: Large PRs, or PRs that touch sensitive things like
DB schemas, may need some care when merging. For example, a careful
rebase by the change author, timing to not interfere with a pending
release, or a message to contributors on discord after merging.-->

## Checklist

- [ ] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
2024-07-29 13:43:17 -04:00
6525f18610 Merge branch 'main' into chainchompa/board-delete-info 2024-07-29 12:52:36 -04:00
2ad13ac7eb Modular backend - inpaint (#6643)
## Summary

Code for inpainting and inpaint models handling from
https://github.com/invoke-ai/InvokeAI/pull/6577.
Separated in 2 extensions as discussed briefly before, so wait for
discussion about such implementation.

## Related Issues / Discussions

#6606

https://invokeai.notion.site/Modular-Stable-Diffusion-Backend-Design-Document-e8952daab5d5472faecdc4a72d377b0d

## QA Instructions

Run with and without set `USE_MODULAR_DENOISE` environment.
Try and compare outputs between backends in cases:
- Normal generation on inpaint model
- Inpainting on inpaint model
- Inpainting on normal model

## Merge Plan

Nope.
If you think that there should be some kind of tests - feel free to add.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
2024-07-29 10:27:25 -04:00
693a3eaff5 Merge branch 'main' into stalker-modular_inpaint-2 2024-07-29 10:14:45 -04:00
ffca792d5b edited copy for deleted boards message 2024-07-29 09:46:08 -04:00
86a92bb6b5 Add more karras schedulers 2024-07-29 15:14:34 +03:00
171a4e6d80 fix(ui): race condition when deleting a board and resetting selected/auto-add
We were checking the selected and auto-add board ids against the query cache to see if they still exist. If not, we reset.

This only works if the query cache is updated by the time we do the check - race condition!

We already have the board id from the query args, so there's no need to check the query cache - just compare the deleted board ID directly.

Previously this file's several listeners were all in a single one and I had adapted/split its logic up a bit wonkily, introducing these problems.
2024-07-29 11:36:03 +10:00
e3a75a8adf fix(ui): fix logic to reset selected/auto-add boards when toggling show archived boards
The logic was incorrect in two ways:
1. We only ran the logic if we _enable_ showing archived boards. It should be run we we _disable_ showing archived boards.
2. If we couldn't find the selected board in the query cache, we didn't do the reset. This is wrong - if the board isn't in the query cache, we _should_ do the reset. This inverted logic makes more sense before the fix for issue 1.
2024-07-29 11:36:03 +10:00
ee7503ce13 Modular backend - T2I Adapter (#6662)
## Summary

T2I Adapter code from #6577.

## Related Issues / Discussions

#6606 

https://invokeai.notion.site/Modular-Stable-Diffusion-Backend-Design-Document-e8952daab5d5472faecdc4a72d377b0d

## QA Instructions

Run with and without set `USE_MODULAR_DENOISE` environment.

## Merge Plan

Nope.
If you think that there should be some kind of tests - feel free to add.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
2024-07-28 15:52:04 -04:00
8500bac3ca Use logger for warning 2024-07-28 22:51:52 +03:00
310719eb4c Merge branch 'main' into stalker-modular_t2i_adapter 2024-07-28 15:30:00 -04:00
e8e24822ec Modular backend - Seamless (#6651)
## Summary

Seamless code from #6577.

## Related Issues / Discussions

#6606 

https://invokeai.notion.site/Modular-Stable-Diffusion-Backend-Design-Document-e8952daab5d5472faecdc4a72d377b0d

## QA Instructions

Run with and without set `USE_MODULAR_DENOISE` environment.

## Merge Plan

Nope.
If you think that there should be some kind of tests - feel free to add.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
2024-07-28 13:57:38 -04:00
c57a7afb87 Merge branch 'main' into stalker7779/modular_seamless 2024-07-28 13:49:43 -04:00
84d028898c Revert wrong comment copy 2024-07-27 13:20:58 +03:00
ed0174fbc6 Suggested changes
Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com>
2024-07-27 13:18:28 +03:00
9e582563eb Suggested changes
Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com>
2024-07-27 04:25:15 +03:00
faa88f72bf Make lora as separate extensions 2024-07-27 02:39:53 +03:00
0d69a31df0 Merge branch 'main' into chainchompa/board-delete-info 2024-07-26 14:03:18 -04:00
5b84e117b2 Suggested changes
Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com>
2024-07-26 20:51:12 +03:00
eb257d2d28 update delete board modal to be more descriptive 2024-07-26 13:34:25 -04:00
5810cee6c9 Suggested changes
Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com>
2024-07-26 19:47:28 +03:00
eef88d1f83 Update gradient mask node version 2024-07-26 19:33:41 +03:00
78f6850fc0 Fix gradient mask values range 2024-07-26 19:28:00 +03:00
bd8890be11 Revert "Fix create gradient mask node output"
This reverts commit 9d1fcba415.
2024-07-26 19:24:46 +03:00
adf1a977ea Suggested changes
Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com>
2024-07-26 19:22:26 +03:00
e5d9ca013e fix: use v1 models for large and base versions 2024-07-25 17:24:12 +05:30
4166c756ce wip: depth_anything_v2 init lint fixes 2024-07-25 14:41:22 +05:30
4f0dfbd34d wip: depth_anything_v2 initial implementation 2024-07-25 13:53:06 +05:30
46c632e7cc Change layer detection keys according to LyCORIS repository 2024-07-25 02:10:47 +03:00
653f63ae71 Add layer keys check 2024-07-25 02:03:08 +03:00
8a9e2f57a4 Handle bias in full/diff lora layer 2024-07-25 02:02:37 +03:00
31949ed2f2 Refactor code a bit 2024-07-25 02:00:30 +03:00
0ccb304b8b Ruff format 2024-07-24 16:01:29 +03:00
ab0bfa709a Handle loras in modular denoise 2024-07-24 05:07:29 +03:00
6af659b1da Handle t2i adapter in modular denoise 2024-07-24 02:55:33 +03:00
416d29fb83 Ruff format 2024-07-24 01:17:28 +03:00
19c00241c6 Use non-inverted mask generally(except inpaint model handling) 2024-07-24 00:59:13 +03:00
c323a760a5 Suggested changes
Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com>
2024-07-23 23:34:28 +03:00
9d1fcba415 Fix create gradient mask node output 2024-07-23 23:29:28 +03:00
ca21996a97 Remove old seamless class 2024-07-23 18:04:33 +03:00
62aa064e56 Handle seamless in modular denoise 2024-07-23 18:03:59 +03:00
87eb018380 Revert debug change 2024-07-22 23:49:20 +03:00
5003e5d763 Same changes as in other PRs, add check for running inpainting on inpaint model without source image
Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com>
2024-07-22 23:47:39 +03:00
58f3072b91 Handle inpainting on normal models 2024-07-21 22:17:29 +03:00
9e7b470189 Handle inpaint models 2024-07-21 20:45:55 +03:00
233 changed files with 30017 additions and 21249 deletions

View File

@ -62,7 +62,7 @@ jobs:
- name: install ruff
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
run: pip install ruff
run: pip install ruff==0.6.0
shell: bash
- name: ruff check

View File

@ -60,7 +60,7 @@ jobs:
extra-index-url: 'https://download.pytorch.org/whl/cpu'
github-env: $GITHUB_ENV
- platform: macos-default
os: macOS-12
os: macOS-14
github-env: $GITHUB_ENV
- platform: windows-cpu
os: windows-2022

View File

@ -1,20 +1,22 @@
# Invoke in Docker
- Ensure that Docker can use the GPU on your system
- This documentation assumes Linux, but should work similarly under Windows with WSL2
First things first:
- Ensure that Docker can use your [NVIDIA][nvidia docker docs] or [AMD][amd docker docs] GPU.
- This document assumes a Linux system, but should work similarly under Windows with WSL2.
- We don't recommend running Invoke in Docker on macOS at this time. It works, but very slowly.
## Quickstart :lightning:
## Quickstart
No `docker compose`, no persistence, just a simple one-liner using the official images:
No `docker compose`, no persistence, single command, using the official images:
**CUDA:**
**CUDA (NVIDIA GPU):**
```bash
docker run --runtime=nvidia --gpus=all --publish 9090:9090 ghcr.io/invoke-ai/invokeai
```
**ROCm:**
**ROCm (AMD GPU):**
```bash
docker run --device /dev/kfd --device /dev/dri --publish 9090:9090 ghcr.io/invoke-ai/invokeai:main-rocm
@ -22,12 +24,20 @@ docker run --device /dev/kfd --device /dev/dri --publish 9090:9090 ghcr.io/invok
Open `http://localhost:9090` in your browser once the container finishes booting, install some models, and generate away!
> [!TIP]
> To persist your data (including downloaded models) outside of the container, add a `--volume/-v` flag to the above command, e.g.: `docker run --volume /some/local/path:/invokeai <...the rest of the command>`
### Data persistence
To persist your generated images and downloaded models outside of the container, add a `--volume/-v` flag to the above command, e.g.:
```bash
docker run --volume /some/local/path:/invokeai {...etc...}
```
`/some/local/path/invokeai` will contain all your data.
It can *usually* be reused between different installs of Invoke. Tread with caution and read the release notes!
## Customize the container
We ship the `run.sh` script, which is a convenient wrapper around `docker compose` for cases where custom image build args are needed. Alternatively, the familiar `docker compose` commands work just as well.
The included `run.sh` script is a convenience wrapper around `docker compose`. It can be helpful for passing additional build arguments to `docker compose`. Alternatively, the familiar `docker compose` commands work just as well.
```bash
cd docker
@ -38,11 +48,14 @@ cp .env.sample .env
It will take a few minutes to build the image the first time. Once the application starts up, open `http://localhost:9090` in your browser to invoke!
>[!TIP]
>When using the `run.sh` script, the container will continue running after Ctrl+C. To shut it down, use the `docker compose down` command.
## Docker setup in detail
#### Linux
1. Ensure builkit is enabled in the Docker daemon settings (`/etc/docker/daemon.json`)
1. Ensure buildkit is enabled in the Docker daemon settings (`/etc/docker/daemon.json`)
2. Install the `docker compose` plugin using your package manager, or follow a [tutorial](https://docs.docker.com/compose/install/linux/#install-using-the-repository).
- The deprecated `docker-compose` (hyphenated) CLI probably won't work. Update to a recent version.
3. Ensure docker daemon is able to access the GPU.
@ -98,25 +111,7 @@ GPU_DRIVER=cuda
Any environment variables supported by InvokeAI can be set here. See the [Configuration docs](https://invoke-ai.github.io/InvokeAI/features/CONFIGURATION/) for further detail.
## Even More Customizing!
---
See the `docker-compose.yml` file. The `command` instruction can be uncommented and used to run arbitrary startup commands. Some examples below.
### Reconfigure the runtime directory
Can be used to download additional models from the supported model list
In conjunction with `INVOKEAI_ROOT` can be also used to initialize a runtime directory
```yaml
command:
- invokeai-configure
- --yes
```
Or install models:
```yaml
command:
- invokeai-model-install
```
[nvidia docker docs]: https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html
[amd docker docs]: https://rocm.docs.amd.com/projects/install-on-linux/en/latest/how-to/docker.html

View File

@ -17,7 +17,7 @@
set -eu
# Ensure we're in the correct folder in case user's CWD is somewhere else
scriptdir=$(dirname "$0")
scriptdir=$(dirname $(readlink -f "$0"))
cd "$scriptdir"
. .venv/bin/activate

View File

@ -1,5 +1,6 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import asyncio
from logging import Logger
import torch
@ -31,6 +32,8 @@ from invokeai.app.services.session_processor.session_processor_default import (
)
from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
from invokeai.app.services.style_preset_images.style_preset_images_disk import StylePresetImageFileStorageDisk
from invokeai.app.services.style_preset_records.style_preset_records_sqlite import SqliteStylePresetRecordsStorage
from invokeai.app.services.urls.urls_default import LocalUrlService
from invokeai.app.services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
@ -63,7 +66,12 @@ class ApiDependencies:
invoker: Invoker
@staticmethod
def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger = logger) -> None:
def initialize(
config: InvokeAIAppConfig,
event_handler_id: int,
loop: asyncio.AbstractEventLoop,
logger: Logger = logger,
) -> None:
logger.info(f"InvokeAI version {__version__}")
logger.info(f"Root directory = {str(config.root_path)}")
@ -74,6 +82,7 @@ class ApiDependencies:
image_files = DiskImageFileStorage(f"{output_folder}/images")
model_images_folder = config.models_path
style_presets_folder = config.style_presets_path
db = init_db(config=config, logger=logger, image_files=image_files)
@ -84,7 +93,7 @@ class ApiDependencies:
board_images = BoardImagesService()
board_records = SqliteBoardRecordStorage(db=db)
boards = BoardService()
events = FastAPIEventService(event_handler_id)
events = FastAPIEventService(event_handler_id, loop=loop)
bulk_download = BulkDownloadService()
image_records = SqliteImageRecordStorage(db=db)
images = ImageService()
@ -109,6 +118,8 @@ class ApiDependencies:
session_queue = SqliteSessionQueue(db=db)
urls = LocalUrlService()
workflow_records = SqliteWorkflowRecordsStorage(db=db)
style_preset_records = SqliteStylePresetRecordsStorage(db=db)
style_preset_image_files = StylePresetImageFileStorageDisk(style_presets_folder / "images")
services = InvocationServices(
board_image_records=board_image_records,
@ -134,6 +145,8 @@ class ApiDependencies:
workflow_records=workflow_records,
tensors=tensors,
conditioning=conditioning,
style_preset_records=style_preset_records,
style_preset_image_files=style_preset_image_files,
)
ApiDependencies.invoker = Invoker(services)

View File

@ -218,9 +218,8 @@ async def get_image_workflow(
raise HTTPException(status_code=404)
@images_router.api_route(
@images_router.get(
"/i/{image_name}/full",
methods=["GET", "HEAD"],
operation_id="get_image_full",
response_class=Response,
responses={
@ -231,6 +230,18 @@ async def get_image_workflow(
404: {"description": "Image not found"},
},
)
@images_router.head(
"/i/{image_name}/full",
operation_id="get_image_full_head",
response_class=Response,
responses={
200: {
"description": "Return the full-resolution image",
"content": {"image/png": {}},
},
404: {"description": "Image not found"},
},
)
async def get_image_full(
image_name: str = Path(description="The name of full-resolution image file to get"),
) -> Response:
@ -242,6 +253,7 @@ async def get_image_full(
content = f.read()
response = Response(content, media_type="image/png")
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
response.headers["Content-Disposition"] = f'inline; filename="{image_name}"'
return response
except Exception:
raise HTTPException(status_code=404)

View File

@ -0,0 +1,274 @@
import csv
import io
import json
import traceback
from typing import Optional
import pydantic
from fastapi import APIRouter, File, Form, HTTPException, Path, Response, UploadFile
from fastapi.responses import FileResponse
from PIL import Image
from pydantic import BaseModel, Field
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.api.routers.model_manager import IMAGE_MAX_AGE
from invokeai.app.services.style_preset_images.style_preset_images_common import StylePresetImageFileNotFoundException
from invokeai.app.services.style_preset_records.style_preset_records_common import (
InvalidPresetImportDataError,
PresetData,
PresetType,
StylePresetChanges,
StylePresetNotFoundError,
StylePresetRecordWithImage,
StylePresetWithoutId,
UnsupportedFileTypeError,
parse_presets_from_file,
)
class StylePresetFormData(BaseModel):
name: str = Field(description="Preset name")
positive_prompt: str = Field(description="Positive prompt")
negative_prompt: str = Field(description="Negative prompt")
type: PresetType = Field(description="Preset type")
style_presets_router = APIRouter(prefix="/v1/style_presets", tags=["style_presets"])
@style_presets_router.get(
"/i/{style_preset_id}",
operation_id="get_style_preset",
responses={
200: {"model": StylePresetRecordWithImage},
},
)
async def get_style_preset(
style_preset_id: str = Path(description="The style preset to get"),
) -> StylePresetRecordWithImage:
"""Gets a style preset"""
try:
image = ApiDependencies.invoker.services.style_preset_image_files.get_url(style_preset_id)
style_preset = ApiDependencies.invoker.services.style_preset_records.get(style_preset_id)
return StylePresetRecordWithImage(image=image, **style_preset.model_dump())
except StylePresetNotFoundError:
raise HTTPException(status_code=404, detail="Style preset not found")
@style_presets_router.patch(
"/i/{style_preset_id}",
operation_id="update_style_preset",
responses={
200: {"model": StylePresetRecordWithImage},
},
)
async def update_style_preset(
image: Optional[UploadFile] = File(description="The image file to upload", default=None),
style_preset_id: str = Path(description="The id of the style preset to update"),
data: str = Form(description="The data of the style preset to update"),
) -> StylePresetRecordWithImage:
"""Updates a style preset"""
if image is not None:
if not image.content_type or not image.content_type.startswith("image"):
raise HTTPException(status_code=415, detail="Not an image")
contents = await image.read()
try:
pil_image = Image.open(io.BytesIO(contents))
except Exception:
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
raise HTTPException(status_code=415, detail="Failed to read image")
try:
ApiDependencies.invoker.services.style_preset_image_files.save(style_preset_id, pil_image)
except ValueError as e:
raise HTTPException(status_code=409, detail=str(e))
else:
try:
ApiDependencies.invoker.services.style_preset_image_files.delete(style_preset_id)
except StylePresetImageFileNotFoundException:
pass
try:
parsed_data = json.loads(data)
validated_data = StylePresetFormData(**parsed_data)
name = validated_data.name
type = validated_data.type
positive_prompt = validated_data.positive_prompt
negative_prompt = validated_data.negative_prompt
except pydantic.ValidationError:
raise HTTPException(status_code=400, detail="Invalid preset data")
preset_data = PresetData(positive_prompt=positive_prompt, negative_prompt=negative_prompt)
changes = StylePresetChanges(name=name, preset_data=preset_data, type=type)
style_preset_image = ApiDependencies.invoker.services.style_preset_image_files.get_url(style_preset_id)
style_preset = ApiDependencies.invoker.services.style_preset_records.update(
style_preset_id=style_preset_id, changes=changes
)
return StylePresetRecordWithImage(image=style_preset_image, **style_preset.model_dump())
@style_presets_router.delete(
"/i/{style_preset_id}",
operation_id="delete_style_preset",
)
async def delete_style_preset(
style_preset_id: str = Path(description="The style preset to delete"),
) -> None:
"""Deletes a style preset"""
try:
ApiDependencies.invoker.services.style_preset_image_files.delete(style_preset_id)
except StylePresetImageFileNotFoundException:
pass
ApiDependencies.invoker.services.style_preset_records.delete(style_preset_id)
@style_presets_router.post(
"/",
operation_id="create_style_preset",
responses={
200: {"model": StylePresetRecordWithImage},
},
)
async def create_style_preset(
image: Optional[UploadFile] = File(description="The image file to upload", default=None),
data: str = Form(description="The data of the style preset to create"),
) -> StylePresetRecordWithImage:
"""Creates a style preset"""
try:
parsed_data = json.loads(data)
validated_data = StylePresetFormData(**parsed_data)
name = validated_data.name
type = validated_data.type
positive_prompt = validated_data.positive_prompt
negative_prompt = validated_data.negative_prompt
except pydantic.ValidationError:
raise HTTPException(status_code=400, detail="Invalid preset data")
preset_data = PresetData(positive_prompt=positive_prompt, negative_prompt=negative_prompt)
style_preset = StylePresetWithoutId(name=name, preset_data=preset_data, type=type)
new_style_preset = ApiDependencies.invoker.services.style_preset_records.create(style_preset=style_preset)
if image is not None:
if not image.content_type or not image.content_type.startswith("image"):
raise HTTPException(status_code=415, detail="Not an image")
contents = await image.read()
try:
pil_image = Image.open(io.BytesIO(contents))
except Exception:
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
raise HTTPException(status_code=415, detail="Failed to read image")
try:
ApiDependencies.invoker.services.style_preset_image_files.save(new_style_preset.id, pil_image)
except ValueError as e:
raise HTTPException(status_code=409, detail=str(e))
preset_image = ApiDependencies.invoker.services.style_preset_image_files.get_url(new_style_preset.id)
return StylePresetRecordWithImage(image=preset_image, **new_style_preset.model_dump())
@style_presets_router.get(
"/",
operation_id="list_style_presets",
responses={
200: {"model": list[StylePresetRecordWithImage]},
},
)
async def list_style_presets() -> list[StylePresetRecordWithImage]:
"""Gets a page of style presets"""
style_presets_with_image: list[StylePresetRecordWithImage] = []
style_presets = ApiDependencies.invoker.services.style_preset_records.get_many()
for preset in style_presets:
image = ApiDependencies.invoker.services.style_preset_image_files.get_url(preset.id)
style_preset_with_image = StylePresetRecordWithImage(image=image, **preset.model_dump())
style_presets_with_image.append(style_preset_with_image)
return style_presets_with_image
@style_presets_router.get(
"/i/{style_preset_id}/image",
operation_id="get_style_preset_image",
responses={
200: {
"description": "The style preset image was fetched successfully",
},
400: {"description": "Bad request"},
404: {"description": "The style preset image could not be found"},
},
status_code=200,
)
async def get_style_preset_image(
style_preset_id: str = Path(description="The id of the style preset image to get"),
) -> FileResponse:
"""Gets an image file that previews the model"""
try:
path = ApiDependencies.invoker.services.style_preset_image_files.get_path(style_preset_id)
response = FileResponse(
path,
media_type="image/png",
filename=style_preset_id + ".png",
content_disposition_type="inline",
)
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
return response
except Exception:
raise HTTPException(status_code=404)
@style_presets_router.get(
"/export",
operation_id="export_style_presets",
responses={200: {"content": {"text/csv": {}}, "description": "A CSV file with the requested data."}},
status_code=200,
)
async def export_style_presets():
# Create an in-memory stream to store the CSV data
output = io.StringIO()
writer = csv.writer(output)
# Write the header
writer.writerow(["name", "prompt", "negative_prompt"])
style_presets = ApiDependencies.invoker.services.style_preset_records.get_many(type=PresetType.User)
for preset in style_presets:
writer.writerow([preset.name, preset.preset_data.positive_prompt, preset.preset_data.negative_prompt])
csv_data = output.getvalue()
output.close()
return Response(
content=csv_data,
media_type="text/csv",
headers={"Content-Disposition": "attachment; filename=prompt_templates.csv"},
)
@style_presets_router.post(
"/import",
operation_id="import_style_presets",
)
async def import_style_presets(file: UploadFile = File(description="The file to import")):
try:
style_presets = await parse_presets_from_file(file)
ApiDependencies.invoker.services.style_preset_records.create_many(style_presets)
except InvalidPresetImportDataError as e:
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
raise HTTPException(status_code=400, detail=str(e))
except UnsupportedFileTypeError as e:
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
raise HTTPException(status_code=415, detail=str(e))

View File

@ -30,6 +30,7 @@ from invokeai.app.api.routers import (
images,
model_manager,
session_queue,
style_presets,
utilities,
workflows,
)
@ -55,11 +56,13 @@ mimetypes.add_type("text/css", ".css")
torch_device_name = TorchDevice.get_torch_device_name()
logger.info(f"Using torch device: {torch_device_name}")
loop = asyncio.new_event_loop()
@asynccontextmanager
async def lifespan(app: FastAPI):
# Add startup event to load dependencies
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger)
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, loop=loop, logger=logger)
yield
# Shut down threads
ApiDependencies.shutdown()
@ -106,6 +109,7 @@ app.include_router(board_images.board_images_router, prefix="/api")
app.include_router(app_info.app_router, prefix="/api")
app.include_router(session_queue.session_queue_router, prefix="/api")
app.include_router(workflows.workflows_router, prefix="/api")
app.include_router(style_presets.style_presets_router, prefix="/api")
app.openapi = get_openapi_func(app)
@ -184,8 +188,6 @@ def invoke_api() -> None:
check_cudnn(logger)
# Start our own event loop for eventing usage
loop = asyncio.new_event_loop()
config = uvicorn.Config(
app=app,
host=app_config.host,

View File

@ -80,12 +80,12 @@ class CompelInvocation(BaseInvocation):
with (
# apply all patches while the model is on the target device
text_encoder_info.model_on_device() as (model_state_dict, text_encoder),
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
tokenizer_info as tokenizer,
ModelPatcher.apply_lora_text_encoder(
text_encoder,
loras=_lora_loader(),
model_state_dict=model_state_dict,
cached_weights=cached_weights,
),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder, self.clip.skipped_layers),
@ -175,13 +175,13 @@ class SDXLPromptInvocationBase:
with (
# apply all patches while the model is on the target device
text_encoder_info.model_on_device() as (state_dict, text_encoder),
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
tokenizer_info as tokenizer,
ModelPatcher.apply_lora(
text_encoder,
loras=_lora_loader(),
prefix=lora_prefix,
model_state_dict=state_dict,
cached_weights=cached_weights,
),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder, clip_field.skipped_layers),

View File

@ -21,6 +21,8 @@ from controlnet_aux import (
from controlnet_aux.util import HWC3, ade_palette
from PIL import Image
from pydantic import BaseModel, Field, field_validator, model_validator
from transformers import pipeline
from transformers.pipelines import DepthEstimationPipeline
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
@ -44,13 +46,12 @@ from invokeai.app.invocations.util import validate_begin_end_step, validate_weig
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, heuristic_resize
from invokeai.backend.image_util.canny import get_canny_edges
from invokeai.backend.image_util.depth_anything import DEPTH_ANYTHING_MODELS, DepthAnythingDetector
from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
from invokeai.backend.image_util.dw_openpose import DWPOSE_MODELS, DWOpenposeDetector
from invokeai.backend.image_util.hed import HEDProcessor
from invokeai.backend.image_util.lineart import LineartProcessor
from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor
from invokeai.backend.image_util.util import np_to_pil, pil_to_np
from invokeai.backend.util.devices import TorchDevice
class ControlField(BaseModel):
@ -592,7 +593,14 @@ class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
return color_map
DEPTH_ANYTHING_MODEL_SIZES = Literal["large", "base", "small"]
DEPTH_ANYTHING_MODEL_SIZES = Literal["large", "base", "small", "small_v2"]
# DepthAnything V2 Small model is licensed under Apache 2.0 but not the base and large models.
DEPTH_ANYTHING_MODELS = {
"large": "LiheYoung/depth-anything-large-hf",
"base": "LiheYoung/depth-anything-base-hf",
"small": "LiheYoung/depth-anything-small-hf",
"small_v2": "depth-anything/Depth-Anything-V2-Small-hf",
}
@invocation(
@ -600,28 +608,33 @@ DEPTH_ANYTHING_MODEL_SIZES = Literal["large", "base", "small"]
title="Depth Anything Processor",
tags=["controlnet", "depth", "depth anything"],
category="controlnet",
version="1.1.2",
version="1.1.3",
)
class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
"""Generates a depth map based on the Depth Anything algorithm"""
model_size: DEPTH_ANYTHING_MODEL_SIZES = InputField(
default="small", description="The size of the depth model to use"
default="small_v2", description="The size of the depth model to use"
)
resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
def run_processor(self, image: Image.Image) -> Image.Image:
def loader(model_path: Path):
return DepthAnythingDetector.load_model(
model_path, model_size=self.model_size, device=TorchDevice.choose_torch_device()
)
def load_depth_anything(model_path: Path):
depth_anything_pipeline = pipeline(model=str(model_path), task="depth-estimation", local_files_only=True)
assert isinstance(depth_anything_pipeline, DepthEstimationPipeline)
return DepthAnythingPipeline(depth_anything_pipeline)
with self._context.models.load_remote_model(
source=DEPTH_ANYTHING_MODELS[self.model_size], loader=loader
) as model:
depth_anything_detector = DepthAnythingDetector(model, TorchDevice.choose_torch_device())
processed_image = depth_anything_detector(image=image, resolution=self.resolution)
return processed_image
source=DEPTH_ANYTHING_MODELS[self.model_size], loader=load_depth_anything
) as depth_anything_detector:
assert isinstance(depth_anything_detector, DepthAnythingPipeline)
depth_map = depth_anything_detector.generate_depth(image)
# Resizing to user target specified size
new_height = int(image.size[1] * (self.resolution / image.size[0]))
depth_map = depth_map.resize((self.resolution, new_height))
return depth_map
@invocation(

View File

@ -39,7 +39,7 @@ class GradientMaskOutput(BaseInvocationOutput):
title="Create Gradient Mask",
tags=["mask", "denoise"],
category="latents",
version="1.1.0",
version="1.2.0",
)
class CreateGradientMaskInvocation(BaseInvocation):
"""Creates mask for denoising model run."""
@ -93,6 +93,7 @@ class CreateGradientMaskInvocation(BaseInvocation):
# redistribute blur so that the original edges are 0 and blur outwards to 1
blur_tensor = (blur_tensor - 0.5) * 2
blur_tensor[blur_tensor < 0] = 0.0
threshold = 1 - self.minimum_denoise

View File

@ -37,9 +37,9 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager import BaseModelType
from invokeai.backend.model_manager import BaseModelType, ModelVariantType
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
from invokeai.backend.stable_diffusion import PipelineIntermediateState
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
from invokeai.backend.stable_diffusion.diffusers_pipeline import (
ControlNetData,
@ -60,8 +60,13 @@ from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionB
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetExt
from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
from invokeai.backend.stable_diffusion.extensions.inpaint import InpaintExt
from invokeai.backend.stable_diffusion.extensions.inpaint_model import InpaintModelExt
from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
from invokeai.backend.stable_diffusion.extensions.t2i_adapter import T2IAdapterExt
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
@ -498,6 +503,33 @@ class DenoiseLatentsInvocation(BaseInvocation):
)
)
@staticmethod
def parse_t2i_adapter_field(
exit_stack: ExitStack,
context: InvocationContext,
t2i_adapters: Optional[Union[T2IAdapterField, list[T2IAdapterField]]],
ext_manager: ExtensionsManager,
) -> None:
if t2i_adapters is None:
return
# Handle the possibility that t2i_adapters could be a list or a single T2IAdapterField.
if isinstance(t2i_adapters, T2IAdapterField):
t2i_adapters = [t2i_adapters]
for t2i_adapter_field in t2i_adapters:
ext_manager.add_extension(
T2IAdapterExt(
node_context=context,
model_id=t2i_adapter_field.t2i_adapter_model,
image=context.images.get_pil(t2i_adapter_field.image.image_name),
weight=t2i_adapter_field.weight,
begin_step_percent=t2i_adapter_field.begin_step_percent,
end_step_percent=t2i_adapter_field.end_step_percent,
resize_mode=t2i_adapter_field.resize_mode,
)
)
def prep_ip_adapter_image_prompts(
self,
context: InvocationContext,
@ -707,7 +739,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
else:
masked_latents = torch.where(mask < 0.5, 0.0, latents)
return 1 - mask, masked_latents, self.denoise_mask.gradient
return mask, masked_latents, self.denoise_mask.gradient
@staticmethod
def prepare_noise_and_latents(
@ -765,10 +797,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
dtype = TorchDevice.choose_torch_dtype()
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
latents = latents.to(device=device, dtype=dtype)
if noise is not None:
noise = noise.to(device=device, dtype=dtype)
_, _, latent_height, latent_width = latents.shape
conditioning_data = self.get_conditioning_data(
@ -801,21 +829,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
denoising_end=self.denoising_end,
)
denoise_ctx = DenoiseContext(
inputs=DenoiseInputs(
orig_latents=latents,
timesteps=timesteps,
init_timestep=init_timestep,
noise=noise,
seed=seed,
scheduler_step_kwargs=scheduler_step_kwargs,
conditioning_data=conditioning_data,
attention_processor_cls=CustomAttnProcessor2_0,
),
unet=None,
scheduler=scheduler,
)
# get the unet's config so that we can pass the base to sd_step_callback()
unet_config = context.models.get_config(self.unet.unet.key)
@ -833,6 +846,50 @@ class DenoiseLatentsInvocation(BaseInvocation):
if self.unet.freeu_config:
ext_manager.add_extension(FreeUExt(self.unet.freeu_config))
### lora
if self.unet.loras:
for lora_field in self.unet.loras:
ext_manager.add_extension(
LoRAExt(
node_context=context,
model_id=lora_field.lora,
weight=lora_field.weight,
)
)
### seamless
if self.unet.seamless_axes:
ext_manager.add_extension(SeamlessExt(self.unet.seamless_axes))
### inpaint
mask, masked_latents, is_gradient_mask = self.prep_inpaint_mask(context, latents)
# NOTE: We used to identify inpainting models by inpecting the shape of the loaded UNet model weights. Now we
# use the ModelVariantType config. During testing, there was a report of a user with models that had an
# incorrect ModelVariantType value. Re-installing the model fixed the issue. If this issue turns out to be
# prevalent, we will have to revisit how we initialize the inpainting extensions.
if unet_config.variant == ModelVariantType.Inpaint:
ext_manager.add_extension(InpaintModelExt(mask, masked_latents, is_gradient_mask))
elif mask is not None:
ext_manager.add_extension(InpaintExt(mask, is_gradient_mask))
# Initialize context for modular denoise
latents = latents.to(device=device, dtype=dtype)
if noise is not None:
noise = noise.to(device=device, dtype=dtype)
denoise_ctx = DenoiseContext(
inputs=DenoiseInputs(
orig_latents=latents,
timesteps=timesteps,
init_timestep=init_timestep,
noise=noise,
seed=seed,
scheduler_step_kwargs=scheduler_step_kwargs,
conditioning_data=conditioning_data,
attention_processor_cls=CustomAttnProcessor2_0,
),
unet=None,
scheduler=scheduler,
)
# context for loading additional models
with ExitStack() as exit_stack:
# later should be smth like:
@ -840,6 +897,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
# ext = extension_field.to_extension(exit_stack, context, ext_manager)
# ext_manager.add_extension(ext)
self.parse_controlnet_field(exit_stack, context, self.control, ext_manager)
self.parse_t2i_adapter_field(exit_stack, context, self.t2i_adapter, ext_manager)
# ext: t2i/ip adapter
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
@ -871,6 +929,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
# At this point, the mask ranges from 0 (leave unchanged) to 1 (inpaint).
# We invert the mask here for compatibility with the old backend implementation.
if mask is not None:
mask = 1 - mask
# TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
# below. Investigate whether this is appropriate.
@ -913,14 +975,14 @@ class DenoiseLatentsInvocation(BaseInvocation):
assert isinstance(unet_info.model, UNet2DConditionModel)
with (
ExitStack() as exit_stack,
unet_info.model_on_device() as (model_state_dict, unet),
unet_info.model_on_device() as (cached_weights, unet),
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
set_seamless(unet, self.unet.seamless_axes), # FIXME
SeamlessExt.static_patch_model(unet, self.unet.seamless_axes), # FIXME
# Apply the LoRA after unet has been moved to its target device for faster patching.
ModelPatcher.apply_lora_unet(
unet,
loras=_lora_loader(),
model_state_dict=model_state_dict,
cached_weights=cached_weights,
),
):
assert isinstance(unet, UNet2DConditionModel)

View File

@ -1,7 +1,7 @@
from enum import Enum
from typing import Any, Callable, Optional, Tuple
from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter
from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter, model_validator
from pydantic.fields import _Unset
from pydantic_core import PydanticUndefined
@ -40,14 +40,18 @@ class UIType(str, Enum, metaclass=MetaEnum):
# region Model Field Types
MainModel = "MainModelField"
FluxMainModel = "FluxMainModelField"
SDXLMainModel = "SDXLMainModelField"
SDXLRefinerModel = "SDXLRefinerModelField"
ONNXModel = "ONNXModelField"
VAEModel = "VAEModelField"
FluxVAEModel = "FluxVAEModelField"
LoRAModel = "LoRAModelField"
ControlNetModel = "ControlNetModelField"
IPAdapterModel = "IPAdapterModelField"
T2IAdapterModel = "T2IAdapterModelField"
T5EncoderModel = "T5EncoderModelField"
CLIPEmbedModel = "CLIPEmbedModelField"
SpandrelImageToImageModel = "SpandrelImageToImageModelField"
# endregion
@ -125,13 +129,17 @@ class FieldDescriptions:
negative_cond = "Negative conditioning tensor"
noise = "Noise tensor"
clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
t5_encoder = "T5 tokenizer and text encoder"
clip_embed_model = "CLIP Embed loader"
unet = "UNet (scheduler, LoRAs)"
transformer = "Transformer"
vae = "VAE"
cond = "Conditioning tensor"
controlnet_model = "ControlNet model to load"
vae_model = "VAE model to load"
lora_model = "LoRA model to load"
main_model = "Main model (UNet, VAE, CLIP) to load"
flux_model = "Flux model (Transformer) to load"
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
@ -231,6 +239,12 @@ class ColorField(BaseModel):
return (self.r, self.g, self.b, self.a)
class FluxConditioningField(BaseModel):
"""A conditioning tensor primitive value"""
conditioning_name: str = Field(description="The name of conditioning tensor")
class ConditioningField(BaseModel):
"""A conditioning tensor primitive value"""
@ -242,6 +256,31 @@ class ConditioningField(BaseModel):
)
class BoundingBoxField(BaseModel):
"""A bounding box primitive value."""
x_min: int = Field(ge=0, description="The minimum x-coordinate of the bounding box (inclusive).")
x_max: int = Field(ge=0, description="The maximum x-coordinate of the bounding box (exclusive).")
y_min: int = Field(ge=0, description="The minimum y-coordinate of the bounding box (inclusive).")
y_max: int = Field(ge=0, description="The maximum y-coordinate of the bounding box (exclusive).")
score: Optional[float] = Field(
default=None,
ge=0.0,
le=1.0,
description="The score associated with the bounding box. In the range [0, 1]. This value is typically set "
"when the bounding box was produced by a detector and has an associated confidence score.",
)
@model_validator(mode="after")
def check_coords(self):
if self.x_min > self.x_max:
raise ValueError(f"x_min ({self.x_min}) is greater than x_max ({self.x_max}).")
if self.y_min > self.y_max:
raise ValueError(f"y_min ({self.y_min}) is greater than y_max ({self.y_max}).")
return self
class MetadataField(RootModel[dict[str, Any]]):
"""
Pydantic model for metadata with custom root of type dict[str, Any].

View File

@ -0,0 +1,92 @@
from typing import Literal
import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
from invokeai.app.invocations.model import CLIPField, T5EncoderField
from invokeai.app.invocations.primitives import FluxConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.modules.conditioner import HFEncoder
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
@invocation(
"flux_text_encoder",
title="FLUX Text Encoding",
tags=["prompt", "conditioning", "flux"],
category="conditioning",
version="1.0.0",
classification=Classification.Prototype,
)
class FluxTextEncoderInvocation(BaseInvocation):
"""Encodes and preps a prompt for a flux image."""
clip: CLIPField = InputField(
title="CLIP",
description=FieldDescriptions.clip,
input=Input.Connection,
)
t5_encoder: T5EncoderField = InputField(
title="T5Encoder",
description=FieldDescriptions.t5_encoder,
input=Input.Connection,
)
t5_max_seq_len: Literal[256, 512] = InputField(
description="Max sequence length for the T5 encoder. Expected to be 256 for FLUX schnell models and 512 for FLUX dev models."
)
prompt: str = InputField(description="Text prompt to encode.")
@torch.no_grad()
def invoke(self, context: InvocationContext) -> FluxConditioningOutput:
# Note: The T5 and CLIP encoding are done in separate functions to ensure that all model references are locally
# scoped. This ensures that the T5 model can be freed and gc'd before loading the CLIP model (if necessary).
t5_embeddings = self._t5_encode(context)
clip_embeddings = self._clip_encode(context)
conditioning_data = ConditioningFieldData(
conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)]
)
conditioning_name = context.conditioning.save(conditioning_data)
return FluxConditioningOutput.build(conditioning_name)
def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)
prompt = [self.prompt]
with (
t5_text_encoder_info as t5_text_encoder,
t5_tokenizer_info as t5_tokenizer,
):
assert isinstance(t5_text_encoder, T5EncoderModel)
assert isinstance(t5_tokenizer, T5Tokenizer)
t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len)
prompt_embeds = t5_encoder(prompt)
assert isinstance(prompt_embeds, torch.Tensor)
return prompt_embeds
def _clip_encode(self, context: InvocationContext) -> torch.Tensor:
clip_tokenizer_info = context.models.load(self.clip.tokenizer)
clip_text_encoder_info = context.models.load(self.clip.text_encoder)
prompt = [self.prompt]
with (
clip_text_encoder_info as clip_text_encoder,
clip_tokenizer_info as clip_tokenizer,
):
assert isinstance(clip_text_encoder, CLIPTextModel)
assert isinstance(clip_tokenizer, CLIPTokenizer)
clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True, 77)
pooled_prompt_embeds = clip_encoder(prompt)
assert isinstance(pooled_prompt_embeds, torch.Tensor)
return pooled_prompt_embeds

View File

@ -0,0 +1,169 @@
import torch
from einops import rearrange
from PIL import Image
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
FieldDescriptions,
FluxConditioningField,
Input,
InputField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import TransformerField, VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.session_processor.session_processor_common import CanceledException
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, prepare_latent_img_patches, unpack
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
from invokeai.backend.util.devices import TorchDevice
@invocation(
"flux_text_to_image",
title="FLUX Text to Image",
tags=["image", "flux"],
category="image",
version="1.0.0",
classification=Classification.Prototype,
)
class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Text-to-image generation using a FLUX model."""
transformer: TransformerField = InputField(
description=FieldDescriptions.flux_model,
input=Input.Connection,
title="Transformer",
)
vae: VAEField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
)
positive_text_conditioning: FluxConditioningField = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection
)
width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
num_steps: int = InputField(
default=4, description="Number of diffusion steps. Recommend values are schnell: 4, dev: 50."
)
guidance: float = InputField(
default=4.0,
description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. FLUX dev only, ignored for schnell.",
)
seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = self._run_diffusion(context)
image = self._run_vae_decoding(context, latents)
image_dto = context.images.save(image=image)
return ImageOutput.build(image_dto)
def _run_diffusion(
self,
context: InvocationContext,
):
inference_dtype = torch.bfloat16
# Load the conditioning data.
cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
assert len(cond_data.conditionings) == 1
flux_conditioning = cond_data.conditionings[0]
assert isinstance(flux_conditioning, FLUXConditioningInfo)
flux_conditioning = flux_conditioning.to(dtype=inference_dtype)
t5_embeddings = flux_conditioning.t5_embeds
clip_embeddings = flux_conditioning.clip_embeds
transformer_info = context.models.load(self.transformer.transformer)
# Prepare input noise.
x = get_noise(
num_samples=1,
height=self.height,
width=self.width,
device=TorchDevice.choose_torch_device(),
dtype=inference_dtype,
seed=self.seed,
)
x, img_ids = prepare_latent_img_patches(x)
is_schnell = "schnell" in transformer_info.config.config_path
timesteps = get_schedule(
num_steps=self.num_steps,
image_seq_len=x.shape[1],
shift=not is_schnell,
)
bs, t5_seq_len, _ = t5_embeddings.shape
txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())
with transformer_info as transformer:
assert isinstance(transformer, Flux)
def step_callback() -> None:
if context.util.is_canceled():
raise CanceledException
# TODO: Make this look like the image before re-enabling
# latent_image = unpack(img.float(), self.height, self.width)
# latent_image = latent_image.squeeze() # Remove unnecessary dimensions
# flattened_tensor = latent_image.reshape(-1) # Flatten to shape [48*128*128]
# # Create a new tensor of the required shape [255, 255, 3]
# latent_image = flattened_tensor[: 255 * 255 * 3].reshape(255, 255, 3) # Reshape to RGB format
# # Convert to a NumPy array and then to a PIL Image
# image = Image.fromarray(latent_image.cpu().numpy().astype(np.uint8))
# (width, height) = image.size
# width *= 8
# height *= 8
# dataURL = image_to_dataURL(image, image_format="JPEG")
# # TODO: move this whole function to invocation context to properly reference these variables
# context._services.events.emit_invocation_denoise_progress(
# context._data.queue_item,
# context._data.invocation,
# state,
# ProgressImage(dataURL=dataURL, width=width, height=height),
# )
x = denoise(
model=transformer,
img=x,
img_ids=img_ids,
txt=t5_embeddings,
txt_ids=txt_ids,
vec=clip_embeddings,
timesteps=timesteps,
step_callback=step_callback,
guidance=self.guidance,
)
x = unpack(x.float(), self.height, self.width)
return x
def _run_vae_decoding(
self,
context: InvocationContext,
latents: torch.Tensor,
) -> Image.Image:
vae_info = context.models.load(self.vae.vae)
with vae_info as vae:
assert isinstance(vae, AutoEncoder)
latents = latents.to(dtype=TorchDevice.choose_torch_dtype())
img = vae.decode(latents)
img = img.clamp(-1, 1)
img = rearrange(img[0], "c h w -> h w c")
img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
return img_pil

View File

@ -0,0 +1,100 @@
from pathlib import Path
from typing import Literal
import torch
from PIL import Image
from transformers import pipeline
from transformers.pipelines import ZeroShotObjectDetectionPipeline
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField
from invokeai.app.invocations.primitives import BoundingBoxCollectionOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.grounding_dino.detection_result import DetectionResult
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
GroundingDinoModelKey = Literal["grounding-dino-tiny", "grounding-dino-base"]
GROUNDING_DINO_MODEL_IDS: dict[GroundingDinoModelKey, str] = {
"grounding-dino-tiny": "IDEA-Research/grounding-dino-tiny",
"grounding-dino-base": "IDEA-Research/grounding-dino-base",
}
@invocation(
"grounding_dino",
title="Grounding DINO (Text Prompt Object Detection)",
tags=["prompt", "object detection"],
category="image",
version="1.0.0",
)
class GroundingDinoInvocation(BaseInvocation):
"""Runs a Grounding DINO model. Performs zero-shot bounding-box object detection from a text prompt."""
# Reference:
# - https://arxiv.org/pdf/2303.05499
# - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam
# - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
model: GroundingDinoModelKey = InputField(description="The Grounding DINO model to use.")
prompt: str = InputField(description="The prompt describing the object to segment.")
image: ImageField = InputField(description="The image to segment.")
detection_threshold: float = InputField(
description="The detection threshold for the Grounding DINO model. All detected bounding boxes with scores above this threshold will be returned.",
ge=0.0,
le=1.0,
default=0.3,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> BoundingBoxCollectionOutput:
# The model expects a 3-channel RGB image.
image_pil = context.images.get_pil(self.image.image_name, mode="RGB")
detections = self._detect(
context=context, image=image_pil, labels=[self.prompt], threshold=self.detection_threshold
)
# Convert detections to BoundingBoxCollectionOutput.
bounding_boxes: list[BoundingBoxField] = []
for detection in detections:
bounding_boxes.append(
BoundingBoxField(
x_min=detection.box.xmin,
x_max=detection.box.xmax,
y_min=detection.box.ymin,
y_max=detection.box.ymax,
score=detection.score,
)
)
return BoundingBoxCollectionOutput(collection=bounding_boxes)
@staticmethod
def _load_grounding_dino(model_path: Path):
grounding_dino_pipeline = pipeline(
model=str(model_path),
task="zero-shot-object-detection",
local_files_only=True,
# TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
# model, and figure out how to make it work in the pipeline.
# torch_dtype=TorchDevice.choose_torch_dtype(),
)
assert isinstance(grounding_dino_pipeline, ZeroShotObjectDetectionPipeline)
return GroundingDinoPipeline(grounding_dino_pipeline)
def _detect(
self,
context: InvocationContext,
image: Image.Image,
labels: list[str],
threshold: float = 0.3,
) -> list[DetectionResult]:
"""Use Grounding DINO to detect bounding boxes for a set of labels in an image."""
# TODO(ryand): I copied this "."-handling logic from the transformers example code. Test it and see if it
# actually makes a difference.
labels = [label if label.endswith(".") else label + "." for label in labels]
with context.models.load_remote_model(
source=GROUNDING_DINO_MODEL_IDS[self.model], loader=GroundingDinoInvocation._load_grounding_dino
) as detector:
assert isinstance(detector, GroundingDinoPipeline)
return detector.detect(image=image, candidate_labels=labels, threshold=threshold)

View File

@ -24,7 +24,7 @@ from invokeai.app.invocations.fields import (
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion import set_seamless
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
from invokeai.backend.util.devices import TorchDevice
@ -59,7 +59,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
latents = latents.to(vae.device)
if self.fp32:

View File

@ -1,9 +1,10 @@
import numpy as np
import torch
from PIL import Image
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, InvocationContext, invocation
from invokeai.app.invocations.fields import ImageField, InputField, TensorField, WithMetadata
from invokeai.app.invocations.primitives import MaskOutput
from invokeai.app.invocations.fields import ImageField, InputField, TensorField, WithBoard, WithMetadata
from invokeai.app.invocations.primitives import ImageOutput, MaskOutput
@invocation(
@ -118,3 +119,27 @@ class ImageMaskToTensorInvocation(BaseInvocation, WithMetadata):
height=mask.shape[1],
width=mask.shape[2],
)
@invocation(
"tensor_mask_to_image",
title="Tensor Mask to Image",
tags=["mask"],
category="mask",
version="1.0.0",
)
class MaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Convert a mask tensor to an image."""
mask: TensorField = InputField(description="The mask tensor to convert.")
def invoke(self, context: InvocationContext) -> ImageOutput:
mask = context.tensors.load(self.mask.tensor_name)
# Ensure that the mask is binary.
if mask.dtype != torch.bool:
mask = mask > 0.5
mask_np = (mask.float() * 255).byte().cpu().numpy()
mask_pil = Image.fromarray(mask_np, mode="L")
image_dto = context.images.save(image=mask_pil)
return ImageOutput.build(image_dto)

View File

@ -1,5 +1,5 @@
import copy
from typing import List, Optional
from typing import List, Literal, Optional
from pydantic import BaseModel, Field
@ -13,7 +13,14 @@ from invokeai.app.invocations.baseinvocation import (
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType
from invokeai.backend.flux.util import max_seq_lengths
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
CheckpointConfigBase,
ModelType,
SubModelType,
)
class ModelIdentifierField(BaseModel):
@ -60,6 +67,15 @@ class CLIPField(BaseModel):
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
class TransformerField(BaseModel):
transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
class T5EncoderField(BaseModel):
tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel")
text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")
class VAEField(BaseModel):
vae: ModelIdentifierField = Field(description="Info to load vae submodel")
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
@ -122,6 +138,78 @@ class ModelIdentifierInvocation(BaseInvocation):
return ModelIdentifierOutput(model=self.model)
@invocation_output("flux_model_loader_output")
class FluxModelLoaderOutput(BaseInvocationOutput):
"""Flux base model loader output"""
transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP")
t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5_encoder, title="T5 Encoder")
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
max_seq_len: Literal[256, 512] = OutputField(
description="The max sequence length to used for the T5 encoder. (256 for schnell transformer, 512 for dev transformer)",
title="Max Seq Length",
)
@invocation(
"flux_model_loader",
title="Flux Main Model",
tags=["model", "flux"],
category="model",
version="1.0.4",
classification=Classification.Prototype,
)
class FluxModelLoaderInvocation(BaseInvocation):
"""Loads a flux base model, outputting its submodels."""
model: ModelIdentifierField = InputField(
description=FieldDescriptions.flux_model,
ui_type=UIType.FluxMainModel,
input=Input.Direct,
)
t5_encoder_model: ModelIdentifierField = InputField(
description=FieldDescriptions.t5_encoder, ui_type=UIType.T5EncoderModel, input=Input.Direct, title="T5 Encoder"
)
clip_embed_model: ModelIdentifierField = InputField(
description=FieldDescriptions.clip_embed_model,
ui_type=UIType.CLIPEmbedModel,
input=Input.Direct,
title="CLIP Embed",
)
vae_model: ModelIdentifierField = InputField(
description=FieldDescriptions.vae_model, ui_type=UIType.FluxVAEModel, title="VAE"
)
def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
for key in [self.model.key, self.t5_encoder_model.key, self.clip_embed_model.key, self.vae_model.key]:
if not context.models.exists(key):
raise ValueError(f"Unknown model: {key}")
transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
vae = self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})
tokenizer = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
clip_encoder = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
tokenizer2 = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
t5_encoder = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
transformer_config = context.models.get_config(transformer)
assert isinstance(transformer_config, CheckpointConfigBase)
return FluxModelLoaderOutput(
transformer=TransformerField(transformer=transformer),
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
vae=VAEField(vae=vae),
max_seq_len=max_seq_lengths[transformer_config.config_path],
)
@invocation(
"main_model_loader",
title="Main Model",

View File

@ -7,10 +7,12 @@ import torch
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
BoundingBoxField,
ColorField,
ConditioningField,
DenoiseMaskField,
FieldDescriptions,
FluxConditioningField,
ImageField,
Input,
InputField,
@ -413,6 +415,17 @@ class MaskOutput(BaseInvocationOutput):
height: int = OutputField(description="The height of the mask in pixels.")
@invocation_output("flux_conditioning_output")
class FluxConditioningOutput(BaseInvocationOutput):
"""Base class for nodes that output a single conditioning tensor"""
conditioning: FluxConditioningField = OutputField(description=FieldDescriptions.cond)
@classmethod
def build(cls, conditioning_name: str) -> "FluxConditioningOutput":
return cls(conditioning=FluxConditioningField(conditioning_name=conditioning_name))
@invocation_output("conditioning_output")
class ConditioningOutput(BaseInvocationOutput):
"""Base class for nodes that output a single conditioning tensor"""
@ -469,3 +482,42 @@ class ConditioningCollectionInvocation(BaseInvocation):
# endregion
# region BoundingBox
@invocation_output("bounding_box_output")
class BoundingBoxOutput(BaseInvocationOutput):
"""Base class for nodes that output a single bounding box"""
bounding_box: BoundingBoxField = OutputField(description="The output bounding box.")
@invocation_output("bounding_box_collection_output")
class BoundingBoxCollectionOutput(BaseInvocationOutput):
"""Base class for nodes that output a collection of bounding boxes"""
collection: list[BoundingBoxField] = OutputField(description="The output bounding boxes.", title="Bounding Boxes")
@invocation(
"bounding_box",
title="Bounding Box",
tags=["primitives", "segmentation", "collection", "bounding box"],
category="primitives",
version="1.0.0",
)
class BoundingBoxInvocation(BaseInvocation):
"""Create a bounding box manually by supplying box coordinates"""
x_min: int = InputField(default=0, description="x-coordinate of the bounding box's top left vertex")
y_min: int = InputField(default=0, description="y-coordinate of the bounding box's top left vertex")
x_max: int = InputField(default=0, description="x-coordinate of the bounding box's bottom right vertex")
y_max: int = InputField(default=0, description="y-coordinate of the bounding box's bottom right vertex")
def invoke(self, context: InvocationContext) -> BoundingBoxOutput:
bounding_box = BoundingBoxField(x_min=self.x_min, y_min=self.y_min, x_max=self.x_max, y_max=self.y_max)
return BoundingBoxOutput(bounding_box=bounding_box)
# endregion

View File

@ -0,0 +1,161 @@
from pathlib import Path
from typing import Literal
import numpy as np
import torch
from PIL import Image
from transformers import AutoModelForMaskGeneration, AutoProcessor
from transformers.models.sam import SamModel
from transformers.models.sam.processing_sam import SamProcessor
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField, TensorField
from invokeai.app.invocations.primitives import MaskOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.segment_anything.mask_refinement import mask_to_polygon, polygon_to_mask
from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
SegmentAnythingModelKey = Literal["segment-anything-base", "segment-anything-large", "segment-anything-huge"]
SEGMENT_ANYTHING_MODEL_IDS: dict[SegmentAnythingModelKey, str] = {
"segment-anything-base": "facebook/sam-vit-base",
"segment-anything-large": "facebook/sam-vit-large",
"segment-anything-huge": "facebook/sam-vit-huge",
}
@invocation(
"segment_anything",
title="Segment Anything",
tags=["prompt", "segmentation"],
category="segmentation",
version="1.0.0",
)
class SegmentAnythingInvocation(BaseInvocation):
"""Runs a Segment Anything Model."""
# Reference:
# - https://arxiv.org/pdf/2304.02643
# - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam
# - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
model: SegmentAnythingModelKey = InputField(description="The Segment Anything model to use.")
image: ImageField = InputField(description="The image to segment.")
bounding_boxes: list[BoundingBoxField] = InputField(description="The bounding boxes to prompt the SAM model with.")
apply_polygon_refinement: bool = InputField(
description="Whether to apply polygon refinement to the masks. This will smooth the edges of the masks slightly and ensure that each mask consists of a single closed polygon (before merging).",
default=True,
)
mask_filter: Literal["all", "largest", "highest_box_score"] = InputField(
description="The filtering to apply to the detected masks before merging them into a final output.",
default="all",
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> MaskOutput:
# The models expect a 3-channel RGB image.
image_pil = context.images.get_pil(self.image.image_name, mode="RGB")
if len(self.bounding_boxes) == 0:
combined_mask = torch.zeros(image_pil.size[::-1], dtype=torch.bool)
else:
masks = self._segment(context=context, image=image_pil)
masks = self._filter_masks(masks=masks, bounding_boxes=self.bounding_boxes)
# masks contains bool values, so we merge them via max-reduce.
combined_mask, _ = torch.stack(masks).max(dim=0)
mask_tensor_name = context.tensors.save(combined_mask)
height, width = combined_mask.shape
return MaskOutput(mask=TensorField(tensor_name=mask_tensor_name), width=width, height=height)
@staticmethod
def _load_sam_model(model_path: Path):
sam_model = AutoModelForMaskGeneration.from_pretrained(
model_path,
local_files_only=True,
# TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
# model, and figure out how to make it work in the pipeline.
# torch_dtype=TorchDevice.choose_torch_dtype(),
)
assert isinstance(sam_model, SamModel)
sam_processor = AutoProcessor.from_pretrained(model_path, local_files_only=True)
assert isinstance(sam_processor, SamProcessor)
return SegmentAnythingPipeline(sam_model=sam_model, sam_processor=sam_processor)
def _segment(
self,
context: InvocationContext,
image: Image.Image,
) -> list[torch.Tensor]:
"""Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes."""
# Convert the bounding boxes to the SAM input format.
sam_bounding_boxes = [[bb.x_min, bb.y_min, bb.x_max, bb.y_max] for bb in self.bounding_boxes]
with (
context.models.load_remote_model(
source=SEGMENT_ANYTHING_MODEL_IDS[self.model], loader=SegmentAnythingInvocation._load_sam_model
) as sam_pipeline,
):
assert isinstance(sam_pipeline, SegmentAnythingPipeline)
masks = sam_pipeline.segment(image=image, bounding_boxes=sam_bounding_boxes)
masks = self._process_masks(masks)
if self.apply_polygon_refinement:
masks = self._apply_polygon_refinement(masks)
return masks
def _process_masks(self, masks: torch.Tensor) -> list[torch.Tensor]:
"""Convert the tensor output from the Segment Anything model from a tensor of shape
[num_masks, channels, height, width] to a list of tensors of shape [height, width].
"""
assert masks.dtype == torch.bool
# [num_masks, channels, height, width] -> [num_masks, height, width]
masks, _ = masks.max(dim=1)
# Split the first dimension into a list of masks.
return list(masks.cpu().unbind(dim=0))
def _apply_polygon_refinement(self, masks: list[torch.Tensor]) -> list[torch.Tensor]:
"""Apply polygon refinement to the masks.
Convert each mask to a polygon, then back to a mask. This has the following effect:
- Smooth the edges of the mask slightly.
- Ensure that each mask consists of a single closed polygon
- Removes small mask pieces.
- Removes holes from the mask.
"""
# Convert tensor masks to np masks.
np_masks = [mask.cpu().numpy().astype(np.uint8) for mask in masks]
# Apply polygon refinement.
for idx, mask in enumerate(np_masks):
shape = mask.shape
assert len(shape) == 2 # Assert length to satisfy type checker.
polygon = mask_to_polygon(mask)
mask = polygon_to_mask(polygon, shape)
np_masks[idx] = mask
# Convert np masks back to tensor masks.
masks = [torch.tensor(mask, dtype=torch.bool) for mask in np_masks]
return masks
def _filter_masks(self, masks: list[torch.Tensor], bounding_boxes: list[BoundingBoxField]) -> list[torch.Tensor]:
"""Filter the detected masks based on the specified mask filter."""
assert len(masks) == len(bounding_boxes)
if self.mask_filter == "all":
return masks
elif self.mask_filter == "largest":
# Find the largest mask.
return [max(masks, key=lambda x: float(x.sum()))]
elif self.mask_filter == "highest_box_score":
# Find the index of the bounding box with the highest score.
# Note that we fallback to -1.0 if the score is None. This is mainly to satisfy the type checker. In most
# cases the scores should all be non-None when using this filtering mode. That being said, -1.0 is a
# reasonable fallback since the expected score range is [0.0, 1.0].
max_score_idx = max(range(len(bounding_boxes)), key=lambda i: bounding_boxes[i].score or -1.0)
return [masks[max_score_idx]]
else:
raise ValueError(f"Invalid mask filter: {self.mask_filter}")

View File

@ -91,6 +91,7 @@ class InvokeAIAppConfig(BaseSettings):
db_dir: Path to InvokeAI databases directory.
outputs_dir: Path to directory for outputs.
custom_nodes_dir: Path to directory for custom nodes.
style_presets_dir: Path to directory for style presets.
log_handlers: Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>".
log_format: Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style.<br>Valid values: `plain`, `color`, `syslog`, `legacy`
log_level: Emit logging messages at this level or higher.<br>Valid values: `debug`, `info`, `warning`, `error`, `critical`
@ -153,6 +154,7 @@ class InvokeAIAppConfig(BaseSettings):
db_dir: Path = Field(default=Path("databases"), description="Path to InvokeAI databases directory.")
outputs_dir: Path = Field(default=Path("outputs"), description="Path to directory for outputs.")
custom_nodes_dir: Path = Field(default=Path("nodes"), description="Path to directory for custom nodes.")
style_presets_dir: Path = Field(default=Path("style_presets"), description="Path to directory for style presets.")
# LOGGING
log_handlers: list[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>".')
@ -300,6 +302,11 @@ class InvokeAIAppConfig(BaseSettings):
"""Path to the models directory, resolved to an absolute path.."""
return self._resolve(self.models_dir)
@property
def style_presets_path(self) -> Path:
"""Path to the style presets directory, resolved to an absolute path.."""
return self._resolve(self.style_presets_dir)
@property
def convert_cache_path(self) -> Path:
"""Path to the converted cache models directory, resolved to an absolute path.."""

View File

@ -1,46 +1,44 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import asyncio
import threading
from queue import Empty, Queue
from fastapi_events.dispatcher import dispatch
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.events.events_common import (
EventBase,
)
from invokeai.app.services.events.events_common import EventBase
class FastAPIEventService(EventServiceBase):
def __init__(self, event_handler_id: int) -> None:
def __init__(self, event_handler_id: int, loop: asyncio.AbstractEventLoop) -> None:
self.event_handler_id = event_handler_id
self._queue = Queue[EventBase | None]()
self._queue = asyncio.Queue[EventBase | None]()
self._stop_event = threading.Event()
asyncio.create_task(self._dispatch_from_queue(stop_event=self._stop_event))
self._loop = loop
# We need to store a reference to the task so it doesn't get GC'd
# See: https://docs.python.org/3/library/asyncio-task.html#creating-tasks
self._background_tasks: set[asyncio.Task[None]] = set()
task = self._loop.create_task(self._dispatch_from_queue(stop_event=self._stop_event))
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.remove)
super().__init__()
def stop(self, *args, **kwargs):
self._stop_event.set()
self._queue.put(None)
self._loop.call_soon_threadsafe(self._queue.put_nowait, None)
def dispatch(self, event: EventBase) -> None:
self._queue.put(event)
self._loop.call_soon_threadsafe(self._queue.put_nowait, event)
async def _dispatch_from_queue(self, stop_event: threading.Event):
"""Get events on from the queue and dispatch them, from the correct thread"""
while not stop_event.is_set():
try:
event = self._queue.get(block=False)
event = await self._queue.get()
if not event: # Probably stopping
continue
# Leave the payloads as live pydantic models
dispatch(event, middleware_id=self.event_handler_id, payload_schema_dump=False)
except Empty:
await asyncio.sleep(0.1)
pass
except asyncio.CancelledError as e:
raise e # Raise a proper error

View File

@ -1,11 +1,10 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
from pathlib import Path
from queue import Queue
from typing import Dict, Optional, Union
from typing import Optional, Union
from PIL import Image, PngImagePlugin
from PIL.Image import Image as PILImageType
from send2trash import send2trash
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
from invokeai.app.services.image_files.image_files_common import (
@ -20,18 +19,12 @@ from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
class DiskImageFileStorage(ImageFileStorageBase):
"""Stores images on disk"""
__output_folder: Path
__cache_ids: Queue # TODO: this is an incredibly naive cache
__cache: Dict[Path, PILImageType]
__max_cache_size: int
__invoker: Invoker
def __init__(self, output_folder: Union[str, Path]):
self.__cache = {}
self.__cache_ids = Queue()
self.__cache: dict[Path, PILImageType] = {}
self.__cache_ids = Queue[Path]()
self.__max_cache_size = 10 # TODO: get this from config
self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder)
self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder)
self.__thumbnails_folder = self.__output_folder / "thumbnails"
# Validate required output folders at launch
self.__validate_storage_folders()
@ -103,7 +96,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
image_path = self.get_path(image_name)
if image_path.exists():
send2trash(image_path)
image_path.unlink()
if image_path in self.__cache:
del self.__cache[image_path]
@ -111,7 +104,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
thumbnail_path = self.get_path(thumbnail_name, True)
if thumbnail_path.exists():
send2trash(thumbnail_path)
thumbnail_path.unlink()
if thumbnail_path in self.__cache:
del self.__cache[thumbnail_path]
except Exception as e:

View File

@ -4,6 +4,8 @@ from __future__ import annotations
from typing import TYPE_CHECKING
from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase
from invokeai.app.services.style_preset_images.style_preset_images_base import StylePresetImageFileStorageBase
from invokeai.app.services.style_preset_records.style_preset_records_base import StylePresetRecordsStorageBase
if TYPE_CHECKING:
from logging import Logger
@ -61,6 +63,8 @@ class InvocationServices:
workflow_records: "WorkflowRecordsStorageBase",
tensors: "ObjectSerializerBase[torch.Tensor]",
conditioning: "ObjectSerializerBase[ConditioningFieldData]",
style_preset_records: "StylePresetRecordsStorageBase",
style_preset_image_files: "StylePresetImageFileStorageBase",
):
self.board_images = board_images
self.board_image_records = board_image_records
@ -85,3 +89,5 @@ class InvocationServices:
self.workflow_records = workflow_records
self.tensors = tensors
self.conditioning = conditioning
self.style_preset_records = style_preset_records
self.style_preset_image_files = style_preset_image_files

View File

@ -2,7 +2,6 @@ from pathlib import Path
from PIL import Image
from PIL.Image import Image as PILImageType
from send2trash import send2trash
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_images.model_images_base import ModelImageFileStorageBase
@ -70,7 +69,7 @@ class ModelImageFileStorageDisk(ModelImageFileStorageBase):
if not self._validate_path(path):
raise ModelImageFileNotFoundException
send2trash(path)
path.unlink()
except Exception as e:
raise ModelImageFileDeleteException from e

View File

@ -783,8 +783,9 @@ class ModelInstallService(ModelInstallServiceBase):
# So what we do is to synthesize a folder named "sdxl-turbo_vae" here.
if subfolder:
top = Path(remote_files[0].path.parts[0]) # e.g. "sdxl-turbo/"
path_to_remove = top / subfolder.parts[-1] # sdxl-turbo/vae/
path_to_add = Path(f"{top}_{subfolder}")
path_to_remove = top / subfolder # sdxl-turbo/vae/
subfolder_rename = subfolder.name.replace("/", "_").replace("\\", "_")
path_to_add = Path(f"{top}_{subfolder_rename}")
else:
path_to_remove = Path(".")
path_to_add = Path(".")

View File

@ -77,6 +77,7 @@ class ModelRecordChanges(BaseModelExcludeNull):
type: Optional[ModelType] = Field(description="Type of model", default=None)
key: Optional[str] = Field(description="Database ID for this model", default=None)
hash: Optional[str] = Field(description="hash of model file", default=None)
format: Optional[str] = Field(description="format of model file", default=None)
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
default_settings: Optional[MainModelDefaultSettings | ControlAdapterDefaultSettings] = Field(
description="Default settings for this model", default=None

View File

@ -16,6 +16,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_10 import
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_11 import build_migration_11
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_12 import build_migration_12
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_13 import build_migration_13
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_14 import build_migration_14
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
@ -49,6 +50,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
migrator.register_migration(build_migration_11(app_config=config, logger=logger))
migrator.register_migration(build_migration_12(app_config=config))
migrator.register_migration(build_migration_13())
migrator.register_migration(build_migration_14())
migrator.run_migrations()
return db

View File

@ -0,0 +1,61 @@
import sqlite3
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
class Migration14Callback:
def __call__(self, cursor: sqlite3.Cursor) -> None:
self._create_style_presets(cursor)
def _create_style_presets(self, cursor: sqlite3.Cursor) -> None:
"""Create the table used to store style presets."""
tables = [
"""--sql
CREATE TABLE IF NOT EXISTS style_presets (
id TEXT NOT NULL PRIMARY KEY,
name TEXT NOT NULL,
preset_data TEXT NOT NULL,
type TEXT NOT NULL DEFAULT "user",
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW'))
);
"""
]
# Add trigger for `updated_at`.
triggers = [
"""--sql
CREATE TRIGGER IF NOT EXISTS style_presets
AFTER UPDATE
ON style_presets FOR EACH ROW
BEGIN
UPDATE style_presets SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE id = old.id;
END;
"""
]
# Add indexes for searchable fields
indices = [
"CREATE INDEX IF NOT EXISTS idx_style_presets_name ON style_presets(name);",
]
for stmt in tables + indices + triggers:
cursor.execute(stmt)
def build_migration_14() -> Migration:
"""
Build the migration from database version 13 to 14..
This migration does the following:
- Create the table used to store style presets.
"""
migration_14 = Migration(
from_version=13,
to_version=14,
callback=Migration14Callback(),
)
return migration_14

Binary file not shown.

After

Width:  |  Height:  |  Size: 98 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 138 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 122 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 123 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 160 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 146 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 119 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 117 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 110 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 46 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 79 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 156 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 141 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 96 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 91 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 88 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 107 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 132 KiB

View File

@ -0,0 +1,33 @@
from abc import ABC, abstractmethod
from pathlib import Path
from PIL.Image import Image as PILImageType
class StylePresetImageFileStorageBase(ABC):
"""Low-level service responsible for storing and retrieving image files."""
@abstractmethod
def get(self, style_preset_id: str) -> PILImageType:
"""Retrieves a style preset image as PIL Image."""
pass
@abstractmethod
def get_path(self, style_preset_id: str) -> Path:
"""Gets the internal path to a style preset image."""
pass
@abstractmethod
def get_url(self, style_preset_id: str) -> str | None:
"""Gets the URL to fetch a style preset image."""
pass
@abstractmethod
def save(self, style_preset_id: str, image: PILImageType) -> None:
"""Saves a style preset image."""
pass
@abstractmethod
def delete(self, style_preset_id: str) -> None:
"""Deletes a style preset image."""
pass

View File

@ -0,0 +1,19 @@
class StylePresetImageFileNotFoundException(Exception):
"""Raised when an image file is not found in storage."""
def __init__(self, message: str = "Style preset image file not found"):
super().__init__(message)
class StylePresetImageFileSaveException(Exception):
"""Raised when an image cannot be saved."""
def __init__(self, message: str = "Style preset image file not saved"):
super().__init__(message)
class StylePresetImageFileDeleteException(Exception):
"""Raised when an image cannot be deleted."""
def __init__(self, message: str = "Style preset image file not deleted"):
super().__init__(message)

View File

@ -0,0 +1,88 @@
from pathlib import Path
from PIL import Image
from PIL.Image import Image as PILImageType
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.style_preset_images.style_preset_images_base import StylePresetImageFileStorageBase
from invokeai.app.services.style_preset_images.style_preset_images_common import (
StylePresetImageFileDeleteException,
StylePresetImageFileNotFoundException,
StylePresetImageFileSaveException,
)
from invokeai.app.services.style_preset_records.style_preset_records_common import PresetType
from invokeai.app.util.misc import uuid_string
from invokeai.app.util.thumbnails import make_thumbnail
class StylePresetImageFileStorageDisk(StylePresetImageFileStorageBase):
"""Stores images on disk"""
def __init__(self, style_preset_images_folder: Path):
self._style_preset_images_folder = style_preset_images_folder
self._validate_storage_folders()
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
def get(self, style_preset_id: str) -> PILImageType:
try:
path = self.get_path(style_preset_id)
return Image.open(path)
except FileNotFoundError as e:
raise StylePresetImageFileNotFoundException from e
def save(self, style_preset_id: str, image: PILImageType) -> None:
try:
self._validate_storage_folders()
image_path = self._style_preset_images_folder / (style_preset_id + ".webp")
thumbnail = make_thumbnail(image, 256)
thumbnail.save(image_path, format="webp")
except Exception as e:
raise StylePresetImageFileSaveException from e
def get_path(self, style_preset_id: str) -> Path:
style_preset = self._invoker.services.style_preset_records.get(style_preset_id)
if style_preset.type is PresetType.Default:
default_images_dir = Path(__file__).parent / Path("default_style_preset_images")
path = default_images_dir / (style_preset.name + ".png")
else:
path = self._style_preset_images_folder / (style_preset_id + ".webp")
return path
def get_url(self, style_preset_id: str) -> str | None:
path = self.get_path(style_preset_id)
if not self._validate_path(path):
return
url = self._invoker.services.urls.get_style_preset_image_url(style_preset_id)
# The image URL never changes, so we must add random query string to it to prevent caching
url += f"?{uuid_string()}"
return url
def delete(self, style_preset_id: str) -> None:
try:
path = self.get_path(style_preset_id)
if not self._validate_path(path):
raise StylePresetImageFileNotFoundException
path.unlink()
except StylePresetImageFileNotFoundException as e:
raise StylePresetImageFileNotFoundException from e
except Exception as e:
raise StylePresetImageFileDeleteException from e
def _validate_path(self, path: Path) -> bool:
"""Validates the path given for an image."""
return path.exists()
def _validate_storage_folders(self) -> None:
"""Checks if the required folders exist and create them if they don't"""
self._style_preset_images_folder.mkdir(parents=True, exist_ok=True)

View File

@ -0,0 +1,146 @@
[
{
"name": "Photography (General)",
"type": "default",
"preset_data": {
"positive_prompt": "{prompt}. photography. f/2.8 macro photo, bokeh, photorealism",
"negative_prompt": "painting, digital art. sketch, blurry"
}
},
{
"name": "Photography (Studio Lighting)",
"type": "default",
"preset_data": {
"positive_prompt": "{prompt}, photography. f/8 photo. centered subject, studio lighting.",
"negative_prompt": "painting, digital art. sketch, blurry"
}
},
{
"name": "Photography (Landscape)",
"type": "default",
"preset_data": {
"positive_prompt": "{prompt}, landscape photograph, f/12, lifelike, highly detailed.",
"negative_prompt": "painting, digital art. sketch, blurry"
}
},
{
"name": "Photography (Portrait)",
"type": "default",
"preset_data": {
"positive_prompt": "{prompt}. photography. portraiture. catch light in eyes. one flash. rembrandt lighting. Soft box. dark shadows. High contrast. 80mm lens. F2.8.",
"negative_prompt": "painting, digital art. sketch, blurry"
}
},
{
"name": "Photography (Black and White)",
"type": "default",
"preset_data": {
"positive_prompt": "{prompt} photography. natural light. 80mm lens. F1.4. strong contrast, hard light. dark contrast. blurred background. black and white",
"negative_prompt": "painting, digital art. sketch, colour+"
}
},
{
"name": "Architectural Visualization",
"type": "default",
"preset_data": {
"positive_prompt": "{prompt}. architectural photography, f/12, luxury, aesthetically pleasing form and function.",
"negative_prompt": "painting, digital art. sketch, blurry"
}
},
{
"name": "Concept Art (Fantasy)",
"type": "default",
"preset_data": {
"positive_prompt": "concept artwork of a {prompt}. (digital painterly art style)++, mythological, (textured 2d dry media brushpack)++, glazed brushstrokes, otherworldly. painting+, illustration+",
"negative_prompt": "photo. distorted, blurry, out of focus. sketch. (cgi, 3d.)++"
}
},
{
"name": "Concept Art (Sci-Fi)",
"type": "default",
"preset_data": {
"positive_prompt": "(concept art)++, {prompt}, (sleek futurism)++, (textured 2d dry media)++, metallic highlights, digital painting style",
"negative_prompt": "photo. distorted, blurry, out of focus. sketch. (cgi, 3d.)++"
}
},
{
"name": "Concept Art (Character)",
"type": "default",
"preset_data": {
"positive_prompt": "(character concept art)++, stylized painterly digital painting of {prompt}, (painterly, impasto. Dry brush.)++",
"negative_prompt": "photo. distorted, blurry, out of focus. sketch. (cgi, 3d.)++"
}
},
{
"name": "Concept Art (Painterly)",
"type": "default",
"preset_data": {
"positive_prompt": "{prompt} oil painting. high contrast. impasto. sfumato. chiaroscuro. Palette knife.",
"negative_prompt": "photo. smooth. border. frame"
}
},
{
"name": "Environment Art",
"type": "default",
"preset_data": {
"positive_prompt": "{prompt} environment artwork, hyper-realistic digital painting style with cinematic composition, atmospheric, depth and detail, voluminous. textured dry brush 2d media",
"negative_prompt": "photo, distorted, blurry, out of focus. sketch."
}
},
{
"name": "Interior Design (Visualization)",
"type": "default",
"preset_data": {
"positive_prompt": "{prompt} interior design photo, gentle shadows, light mid-tones, dimension, mix of smooth and textured surfaces, focus on negative space and clean lines, focus",
"negative_prompt": "photo, distorted. sketch."
}
},
{
"name": "Product Rendering",
"type": "default",
"preset_data": {
"positive_prompt": "{prompt} high quality product photography, 3d rendering with key lighting, shallow depth of field, simple plain background, studio lighting.",
"negative_prompt": "blurry, sketch, messy, dirty. unfinished."
}
},
{
"name": "Sketch",
"type": "default",
"preset_data": {
"positive_prompt": "{prompt} black and white pencil drawing, off-center composition, cross-hatching for shadows, bold strokes, textured paper. sketch+++",
"negative_prompt": "blurry, photo, painting, color. messy, dirty. unfinished. frame, borders."
}
},
{
"name": "Line Art",
"type": "default",
"preset_data": {
"positive_prompt": "{prompt} Line art. bold outline. simplistic. white background. 2d",
"negative_prompt": "photo. digital art. greyscale. solid black. painting"
}
},
{
"name": "Anime",
"type": "default",
"preset_data": {
"positive_prompt": "{prompt} anime++, bold outline, cel-shaded coloring, shounen, seinen",
"negative_prompt": "(photo)+++. greyscale. solid black. painting"
}
},
{
"name": "Illustration",
"type": "default",
"preset_data": {
"positive_prompt": "{prompt} illustration, bold linework, illustrative details, vector art style, flat coloring",
"negative_prompt": "(photo)+++. greyscale. painting, black and white."
}
},
{
"name": "Vehicles",
"type": "default",
"preset_data": {
"positive_prompt": "A weird futuristic normal auto, {prompt} elegant design, nice color, nice wheels",
"negative_prompt": "sketch. digital art. greyscale. painting"
}
}
]

View File

@ -0,0 +1,42 @@
from abc import ABC, abstractmethod
from invokeai.app.services.style_preset_records.style_preset_records_common import (
PresetType,
StylePresetChanges,
StylePresetRecordDTO,
StylePresetWithoutId,
)
class StylePresetRecordsStorageBase(ABC):
"""Base class for style preset storage services."""
@abstractmethod
def get(self, style_preset_id: str) -> StylePresetRecordDTO:
"""Get style preset by id."""
pass
@abstractmethod
def create(self, style_preset: StylePresetWithoutId) -> StylePresetRecordDTO:
"""Creates a style preset."""
pass
@abstractmethod
def create_many(self, style_presets: list[StylePresetWithoutId]) -> None:
"""Creates many style presets."""
pass
@abstractmethod
def update(self, style_preset_id: str, changes: StylePresetChanges) -> StylePresetRecordDTO:
"""Updates a style preset."""
pass
@abstractmethod
def delete(self, style_preset_id: str) -> None:
"""Deletes a style preset."""
pass
@abstractmethod
def get_many(self, type: PresetType | None = None) -> list[StylePresetRecordDTO]:
"""Gets many workflows."""
pass

View File

@ -0,0 +1,139 @@
import codecs
import csv
import json
from enum import Enum
from typing import Any, Optional
import pydantic
from fastapi import UploadFile
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, TypeAdapter
from invokeai.app.util.metaenum import MetaEnum
class StylePresetNotFoundError(Exception):
"""Raised when a style preset is not found"""
class PresetData(BaseModel, extra="forbid"):
positive_prompt: str = Field(description="Positive prompt")
negative_prompt: str = Field(description="Negative prompt")
PresetDataValidator = TypeAdapter(PresetData)
class PresetType(str, Enum, metaclass=MetaEnum):
User = "user"
Default = "default"
Project = "project"
class StylePresetChanges(BaseModel, extra="forbid"):
name: Optional[str] = Field(default=None, description="The style preset's new name.")
preset_data: Optional[PresetData] = Field(default=None, description="The updated data for style preset.")
type: Optional[PresetType] = Field(description="The updated type of the style preset")
class StylePresetWithoutId(BaseModel):
name: str = Field(description="The name of the style preset.")
preset_data: PresetData = Field(description="The preset data")
type: PresetType = Field(description="The type of style preset")
class StylePresetRecordDTO(StylePresetWithoutId):
id: str = Field(description="The style preset ID.")
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "StylePresetRecordDTO":
data["preset_data"] = PresetDataValidator.validate_json(data.get("preset_data", ""))
return StylePresetRecordDTOValidator.validate_python(data)
StylePresetRecordDTOValidator = TypeAdapter(StylePresetRecordDTO)
class StylePresetRecordWithImage(StylePresetRecordDTO):
image: Optional[str] = Field(description="The path for image")
class StylePresetImportRow(BaseModel):
name: str = Field(min_length=1, description="The name of the preset.")
positive_prompt: str = Field(
default="",
description="The positive prompt for the preset.",
validation_alias=AliasChoices("positive_prompt", "prompt"),
)
negative_prompt: str = Field(default="", description="The negative prompt for the preset.")
model_config = ConfigDict(str_strip_whitespace=True, extra="forbid")
StylePresetImportList = list[StylePresetImportRow]
StylePresetImportListTypeAdapter = TypeAdapter(StylePresetImportList)
class UnsupportedFileTypeError(ValueError):
"""Raised when an unsupported file type is encountered"""
pass
class InvalidPresetImportDataError(ValueError):
"""Raised when invalid preset import data is encountered"""
pass
async def parse_presets_from_file(file: UploadFile) -> list[StylePresetWithoutId]:
"""Parses style presets from a file. The file must be a CSV or JSON file.
If CSV, the file must have the following columns:
- name
- prompt (or positive_prompt)
- negative_prompt
If JSON, the file must be a list of objects with the following keys:
- name
- prompt (or positive_prompt)
- negative_prompt
Args:
file (UploadFile): The file to parse.
Returns:
list[StylePresetWithoutId]: The parsed style presets.
Raises:
UnsupportedFileTypeError: If the file type is not supported.
InvalidPresetImportDataError: If the data in the file is invalid.
"""
if file.content_type not in ["text/csv", "application/json"]:
raise UnsupportedFileTypeError()
if file.content_type == "text/csv":
csv_reader = csv.DictReader(codecs.iterdecode(file.file, "utf-8"))
data = list(csv_reader)
else: # file.content_type == "application/json":
json_data = await file.read()
data = json.loads(json_data)
try:
imported_presets = StylePresetImportListTypeAdapter.validate_python(data)
style_presets: list[StylePresetWithoutId] = []
for imported in imported_presets:
preset_data = PresetData(positive_prompt=imported.positive_prompt, negative_prompt=imported.negative_prompt)
style_preset = StylePresetWithoutId(name=imported.name, preset_data=preset_data, type=PresetType.User)
style_presets.append(style_preset)
except pydantic.ValidationError as e:
if file.content_type == "text/csv":
msg = "Invalid CSV format: must include columns 'name', 'prompt', and 'negative_prompt' and name cannot be blank"
else: # file.content_type == "application/json":
msg = "Invalid JSON format: must be a list of objects with keys 'name', 'prompt', and 'negative_prompt' and name cannot be blank"
raise InvalidPresetImportDataError(msg) from e
finally:
file.file.close()
return style_presets

View File

@ -0,0 +1,215 @@
import json
from pathlib import Path
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.style_preset_records.style_preset_records_base import StylePresetRecordsStorageBase
from invokeai.app.services.style_preset_records.style_preset_records_common import (
PresetType,
StylePresetChanges,
StylePresetNotFoundError,
StylePresetRecordDTO,
StylePresetWithoutId,
)
from invokeai.app.util.misc import uuid_string
class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._lock = db.lock
self._conn = db.conn
self._cursor = self._conn.cursor()
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
self._sync_default_style_presets()
def get(self, style_preset_id: str) -> StylePresetRecordDTO:
"""Gets a style preset by ID."""
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT *
FROM style_presets
WHERE id = ?;
""",
(style_preset_id,),
)
row = self._cursor.fetchone()
if row is None:
raise StylePresetNotFoundError(f"Style preset with id {style_preset_id} not found")
return StylePresetRecordDTO.from_dict(dict(row))
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
def create(self, style_preset: StylePresetWithoutId) -> StylePresetRecordDTO:
style_preset_id = uuid_string()
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO style_presets (
id,
name,
preset_data,
type
)
VALUES (?, ?, ?, ?);
""",
(
style_preset_id,
style_preset.name,
style_preset.preset_data.model_dump_json(),
style_preset.type,
),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
return self.get(style_preset_id)
def create_many(self, style_presets: list[StylePresetWithoutId]) -> None:
style_preset_ids = []
try:
self._lock.acquire()
for style_preset in style_presets:
style_preset_id = uuid_string()
style_preset_ids.append(style_preset_id)
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO style_presets (
id,
name,
preset_data,
type
)
VALUES (?, ?, ?, ?);
""",
(
style_preset_id,
style_preset.name,
style_preset.preset_data.model_dump_json(),
style_preset.type,
),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
return None
def update(self, style_preset_id: str, changes: StylePresetChanges) -> StylePresetRecordDTO:
try:
self._lock.acquire()
# Change the name of a style preset
if changes.name is not None:
self._cursor.execute(
"""--sql
UPDATE style_presets
SET name = ?
WHERE id = ?;
""",
(changes.name, style_preset_id),
)
# Change the preset data for a style preset
if changes.preset_data is not None:
self._cursor.execute(
"""--sql
UPDATE style_presets
SET preset_data = ?
WHERE id = ?;
""",
(changes.preset_data.model_dump_json(), style_preset_id),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
return self.get(style_preset_id)
def delete(self, style_preset_id: str) -> None:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
DELETE from style_presets
WHERE id = ?;
""",
(style_preset_id,),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
return None
def get_many(self, type: PresetType | None = None) -> list[StylePresetRecordDTO]:
try:
self._lock.acquire()
main_query = """
SELECT
*
FROM style_presets
"""
if type is not None:
main_query += "WHERE type = ? "
main_query += "ORDER BY LOWER(name) ASC"
if type is not None:
self._cursor.execute(main_query, (type,))
else:
self._cursor.execute(main_query)
rows = self._cursor.fetchall()
style_presets = [StylePresetRecordDTO.from_dict(dict(row)) for row in rows]
return style_presets
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
def _sync_default_style_presets(self) -> None:
"""Syncs default style presets to the database. Internal use only."""
# First delete all existing default style presets
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
DELETE FROM style_presets
WHERE type = "default";
"""
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
# Next, parse and create the default style presets
with self._lock, open(Path(__file__).parent / Path("default_style_presets.json"), "r") as file:
presets = json.load(file)
for preset in presets:
style_preset = StylePresetWithoutId.model_validate(preset)
self.create(style_preset)

View File

@ -13,3 +13,8 @@ class UrlServiceBase(ABC):
def get_model_image_url(self, model_key: str) -> str:
"""Gets the URL for a model image"""
pass
@abstractmethod
def get_style_preset_image_url(self, style_preset_id: str) -> str:
"""Gets the URL for a style preset image"""
pass

View File

@ -19,3 +19,6 @@ class LocalUrlService(UrlServiceBase):
def get_model_image_url(self, model_key: str) -> str:
return f"{self._base_url_v2}/models/i/{model_key}/image"
def get_style_preset_image_url(self, style_preset_id: str) -> str:
return f"{self._base_url}/style_presets/i/{style_preset_id}/image"

View File

@ -0,0 +1,260 @@
{
"name": "FLUX Text to Image",
"author": "InvokeAI",
"description": "A simple text-to-image workflow using FLUX dev or schnell models. Prerequisite model downloads: T5 Encoder, CLIP-L Encoder, and FLUX VAE. Quantized and un-quantized versions can be found in the starter models tab within your Model Manager. We recommend 4 steps for FLUX schnell models and 30 steps for FLUX dev models.",
"version": "1.0.4",
"contact": "",
"tags": "text2image, flux",
"notes": "Prerequisite model downloads: T5 Encoder, CLIP-L Encoder, and FLUX VAE. Quantized and un-quantized versions can be found in the starter models tab within your Model Manager. We recommend 4 steps for FLUX schnell models and 30 steps for FLUX dev models.",
"exposedFields": [
{
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"fieldName": "model"
},
{
"nodeId": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"fieldName": "prompt"
},
{
"nodeId": "159bdf1b-79e7-4174-b86e-d40e646964c8",
"fieldName": "num_steps"
},
{
"nodeId": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"fieldName": "t5_encoder_model"
}
],
"meta": {
"version": "3.0.0",
"category": "default"
},
"nodes": [
{
"id": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"type": "invocation",
"data": {
"id": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"type": "flux_model_loader",
"version": "1.0.4",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": true,
"useCache": false,
"inputs": {
"model": {
"name": "model",
"label": ""
},
"t5_encoder_model": {
"name": "t5_encoder_model",
"label": ""
},
"clip_embed_model": {
"name": "clip_embed_model",
"label": ""
},
"vae_model": {
"name": "vae_model",
"label": ""
}
}
},
"position": {
"x": 381.1882713063478,
"y": -95.89663532854017
}
},
{
"id": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"type": "invocation",
"data": {
"id": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"type": "flux_text_encoder",
"version": "1.0.0",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": true,
"useCache": true,
"inputs": {
"clip": {
"name": "clip",
"label": ""
},
"t5_encoder": {
"name": "t5_encoder",
"label": ""
},
"t5_max_seq_len": {
"name": "t5_max_seq_len",
"label": "T5 Max Seq Len",
"value": 256
},
"prompt": {
"name": "prompt",
"label": "",
"value": "a cat"
}
}
},
"position": {
"x": 824.1970602278849,
"y": 146.98251001061735
}
},
{
"id": "4754c534-a5f3-4ad0-9382-7887985e668c",
"type": "invocation",
"data": {
"id": "4754c534-a5f3-4ad0-9382-7887985e668c",
"type": "rand_int",
"version": "1.0.1",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": true,
"useCache": false,
"inputs": {
"low": {
"name": "low",
"label": "",
"value": 0
},
"high": {
"name": "high",
"label": "",
"value": 2147483647
}
}
},
"position": {
"x": 822.9899179655476,
"y": 360.9657214885052
}
},
{
"id": "159bdf1b-79e7-4174-b86e-d40e646964c8",
"type": "invocation",
"data": {
"id": "159bdf1b-79e7-4174-b86e-d40e646964c8",
"type": "flux_text_to_image",
"version": "1.0.0",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": false,
"useCache": true,
"inputs": {
"board": {
"name": "board",
"label": ""
},
"metadata": {
"name": "metadata",
"label": ""
},
"transformer": {
"name": "transformer",
"label": ""
},
"vae": {
"name": "vae",
"label": ""
},
"positive_text_conditioning": {
"name": "positive_text_conditioning",
"label": ""
},
"width": {
"name": "width",
"label": "",
"value": 1024
},
"height": {
"name": "height",
"label": "",
"value": 1024
},
"num_steps": {
"name": "num_steps",
"label": "Steps (Recommend 30 for Dev, 4 for Schnell)",
"value": 30
},
"guidance": {
"name": "guidance",
"label": "",
"value": 4
},
"seed": {
"name": "seed",
"label": "",
"value": 0
}
}
},
"position": {
"x": 1216.3900791301849,
"y": 5.500841807102248
}
}
],
"edges": [
{
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90max_seq_len-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_max_seq_len",
"type": "default",
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"sourceHandle": "max_seq_len",
"targetHandle": "t5_max_seq_len"
},
{
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90vae-159bdf1b-79e7-4174-b86e-d40e646964c8vae",
"type": "default",
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
"sourceHandle": "vae",
"targetHandle": "vae"
},
{
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90t5_encoder-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_encoder",
"type": "default",
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"sourceHandle": "t5_encoder",
"targetHandle": "t5_encoder"
},
{
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90clip-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cclip",
"type": "default",
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"sourceHandle": "clip",
"targetHandle": "clip"
},
{
"id": "reactflow__edge-f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90transformer-159bdf1b-79e7-4174-b86e-d40e646964c8transformer",
"type": "default",
"source": "f8d9d7c8-9ed7-4bd7-9e42-ab0e89bfac90",
"target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
"sourceHandle": "transformer",
"targetHandle": "transformer"
},
{
"id": "reactflow__edge-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cconditioning-159bdf1b-79e7-4174-b86e-d40e646964c8positive_text_conditioning",
"type": "default",
"source": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c",
"target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
"sourceHandle": "conditioning",
"targetHandle": "positive_text_conditioning"
},
{
"id": "reactflow__edge-4754c534-a5f3-4ad0-9382-7887985e668cvalue-159bdf1b-79e7-4174-b86e-d40e646964c8seed",
"type": "default",
"source": "4754c534-a5f3-4ad0-9382-7887985e668c",
"target": "159bdf1b-79e7-4174-b86e-d40e646964c8",
"sourceHandle": "value",
"targetHandle": "seed"
}
]
}

View File

@ -81,7 +81,7 @@ def get_openapi_func(
# Add the output map to the schema
openapi_schema["components"]["schemas"]["InvocationOutputMap"] = {
"type": "object",
"properties": invocation_output_map_properties,
"properties": dict(sorted(invocation_output_map_properties.items())),
"required": invocation_output_map_required,
}

View File

@ -0,0 +1,32 @@
# Initially pulled from https://github.com/black-forest-labs/flux
import torch
from einops import rearrange
from torch import Tensor
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
q, k = apply_rope(q, k, pe)
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
x = rearrange(x, "B H L D -> B L (H D)")
return x
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
assert dim % 2 == 0
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
omega = 1.0 / (theta**scale)
out = torch.einsum("...n,d->...nd", pos, omega)
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
return out.float()
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)

View File

@ -0,0 +1,117 @@
# Initially pulled from https://github.com/black-forest-labs/flux
from dataclasses import dataclass
import torch
from torch import Tensor, nn
from invokeai.backend.flux.modules.layers import (
DoubleStreamBlock,
EmbedND,
LastLayer,
MLPEmbedder,
SingleStreamBlock,
timestep_embedding,
)
@dataclass
class FluxParams:
in_channels: int
vec_in_dim: int
context_in_dim: int
hidden_size: int
mlp_ratio: float
num_heads: int
depth: int
depth_single_blocks: int
axes_dim: list[int]
theta: int
qkv_bias: bool
guidance_embed: bool
class Flux(nn.Module):
"""
Transformer model for flow matching on sequences.
"""
def __init__(self, params: FluxParams):
super().__init__()
self.params = params
self.in_channels = params.in_channels
self.out_channels = self.in_channels
if params.hidden_size % params.num_heads != 0:
raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
pe_dim = params.hidden_size // params.num_heads
if sum(params.axes_dim) != pe_dim:
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
self.hidden_size = params.hidden_size
self.num_heads = params.num_heads
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
self.guidance_in = (
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
)
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
self.double_blocks = nn.ModuleList(
[
DoubleStreamBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
)
for _ in range(params.depth)
]
)
self.single_blocks = nn.ModuleList(
[
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
for _ in range(params.depth_single_blocks)
]
)
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
def forward(
self,
img: Tensor,
img_ids: Tensor,
txt: Tensor,
txt_ids: Tensor,
timesteps: Tensor,
y: Tensor,
guidance: Tensor | None = None,
) -> Tensor:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
# running on sequences img
img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256))
if self.params.guidance_embed:
if guidance is None:
raise ValueError("Didn't get guidance strength for guidance distilled model.")
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
vec = vec + self.vector_in(y)
txt = self.txt_in(txt)
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
for block in self.double_blocks:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
img = torch.cat((txt, img), 1)
for block in self.single_blocks:
img = block(img, vec=vec, pe=pe)
img = img[:, txt.shape[1] :, ...]
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
return img

View File

@ -0,0 +1,310 @@
# Initially pulled from https://github.com/black-forest-labs/flux
from dataclasses import dataclass
import torch
from einops import rearrange
from torch import Tensor, nn
@dataclass
class AutoEncoderParams:
resolution: int
in_channels: int
ch: int
out_ch: int
ch_mult: list[int]
num_res_blocks: int
z_channels: int
scale_factor: float
shift_factor: float
class AttnBlock(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
self.in_channels = in_channels
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
def attention(self, h_: Tensor) -> Tensor:
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
b, c, h, w = q.shape
q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
h_ = nn.functional.scaled_dot_product_attention(q, k, v)
return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
def forward(self, x: Tensor) -> Tensor:
return x + self.proj_out(self.attention(x))
class ResnetBlock(nn.Module):
def __init__(self, in_channels: int, out_channels: int):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if self.in_channels != self.out_channels:
self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x):
h = x
h = self.norm1(h)
h = torch.nn.functional.silu(h)
h = self.conv1(h)
h = self.norm2(h)
h = torch.nn.functional.silu(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
x = self.nin_shortcut(x)
return x + h
class Downsample(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
# no asymmetric padding in torch conv, must do it ourselves
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
def forward(self, x: Tensor):
pad = (0, 1, 0, 1)
x = nn.functional.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
return x
class Upsample(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x: Tensor):
x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
x = self.conv(x)
return x
class Encoder(nn.Module):
def __init__(
self,
resolution: int,
in_channels: int,
ch: int,
ch_mult: list[int],
num_res_blocks: int,
z_channels: int,
):
super().__init__()
self.ch = ch
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
# downsampling
self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
curr_res = resolution
in_ch_mult = (1,) + tuple(ch_mult)
self.in_ch_mult = in_ch_mult
self.down = nn.ModuleList()
block_in = self.ch
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for _ in range(self.num_res_blocks):
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
block_in = block_out
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
down.downsample = Downsample(block_in)
curr_res = curr_res // 2
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
self.mid.attn_1 = AttnBlock(block_in)
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
# end
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x: Tensor) -> Tensor:
# downsampling
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1])
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
h = hs[-1]
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)
# end
h = self.norm_out(h)
h = torch.nn.functional.silu(h)
h = self.conv_out(h)
return h
class Decoder(nn.Module):
def __init__(
self,
ch: int,
out_ch: int,
ch_mult: list[int],
num_res_blocks: int,
in_channels: int,
resolution: int,
z_channels: int,
):
super().__init__()
self.ch = ch
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.ffactor = 2 ** (self.num_resolutions - 1)
# compute in_ch_mult, block_in and curr_res at lowest res
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
# z to block_in
self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
self.mid.attn_1 = AttnBlock(block_in)
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for _ in range(self.num_res_blocks + 1):
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
block_in = block_out
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
def forward(self, z: Tensor) -> Tensor:
# z to block_in
h = self.conv_in(z)
# middle
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
h = self.norm_out(h)
h = torch.nn.functional.silu(h)
h = self.conv_out(h)
return h
class DiagonalGaussian(nn.Module):
def __init__(self, sample: bool = True, chunk_dim: int = 1):
super().__init__()
self.sample = sample
self.chunk_dim = chunk_dim
def forward(self, z: Tensor) -> Tensor:
mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
if self.sample:
std = torch.exp(0.5 * logvar)
return mean + std * torch.randn_like(mean)
else:
return mean
class AutoEncoder(nn.Module):
def __init__(self, params: AutoEncoderParams):
super().__init__()
self.encoder = Encoder(
resolution=params.resolution,
in_channels=params.in_channels,
ch=params.ch,
ch_mult=params.ch_mult,
num_res_blocks=params.num_res_blocks,
z_channels=params.z_channels,
)
self.decoder = Decoder(
resolution=params.resolution,
in_channels=params.in_channels,
ch=params.ch,
out_ch=params.out_ch,
ch_mult=params.ch_mult,
num_res_blocks=params.num_res_blocks,
z_channels=params.z_channels,
)
self.reg = DiagonalGaussian()
self.scale_factor = params.scale_factor
self.shift_factor = params.shift_factor
def encode(self, x: Tensor) -> Tensor:
z = self.reg(self.encoder(x))
z = self.scale_factor * (z - self.shift_factor)
return z
def decode(self, z: Tensor) -> Tensor:
z = z / self.scale_factor + self.shift_factor
return self.decoder(z)
def forward(self, x: Tensor) -> Tensor:
return self.decode(self.encode(x))

View File

@ -0,0 +1,33 @@
# Initially pulled from https://github.com/black-forest-labs/flux
from torch import Tensor, nn
from transformers import PreTrainedModel, PreTrainedTokenizer
class HFEncoder(nn.Module):
def __init__(self, encoder: PreTrainedModel, tokenizer: PreTrainedTokenizer, is_clip: bool, max_length: int):
super().__init__()
self.max_length = max_length
self.is_clip = is_clip
self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
self.tokenizer = tokenizer
self.hf_module = encoder
self.hf_module = self.hf_module.eval().requires_grad_(False)
def forward(self, text: list[str]) -> Tensor:
batch_encoding = self.tokenizer(
text,
truncation=True,
max_length=self.max_length,
return_length=False,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt",
)
outputs = self.hf_module(
input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
attention_mask=None,
output_hidden_states=False,
)
return outputs[self.output_key]

View File

@ -0,0 +1,253 @@
# Initially pulled from https://github.com/black-forest-labs/flux
import math
from dataclasses import dataclass
import torch
from einops import rearrange
from torch import Tensor, nn
from invokeai.backend.flux.math import attention, rope
class EmbedND(nn.Module):
def __init__(self, dim: int, theta: int, axes_dim: list[int]):
super().__init__()
self.dim = dim
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids: Tensor) -> Tensor:
n_axes = ids.shape[-1]
emb = torch.cat(
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
dim=-3,
)
return emb.unsqueeze(1)
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
"""
Create sinusoidal timestep embeddings.
:param t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an (N, D) Tensor of positional embeddings.
"""
t = time_factor * t
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
if torch.is_floating_point(t):
embedding = embedding.to(t)
return embedding
class MLPEmbedder(nn.Module):
def __init__(self, in_dim: int, hidden_dim: int):
super().__init__()
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
self.silu = nn.SiLU()
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
def forward(self, x: Tensor) -> Tensor:
return self.out_layer(self.silu(self.in_layer(x)))
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int):
super().__init__()
self.scale = nn.Parameter(torch.ones(dim))
def forward(self, x: Tensor):
x_dtype = x.dtype
x = x.float()
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
return (x * rrms).to(dtype=x_dtype) * self.scale
class QKNorm(torch.nn.Module):
def __init__(self, dim: int):
super().__init__()
self.query_norm = RMSNorm(dim)
self.key_norm = RMSNorm(dim)
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
q = self.query_norm(q)
k = self.key_norm(k)
return q.to(v), k.to(v)
class SelfAttention(nn.Module):
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.norm = QKNorm(head_dim)
self.proj = nn.Linear(dim, dim)
def forward(self, x: Tensor, pe: Tensor) -> Tensor:
qkv = self.qkv(x)
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
q, k = self.norm(q, k, v)
x = attention(q, k, v, pe=pe)
x = self.proj(x)
return x
@dataclass
class ModulationOut:
shift: Tensor
scale: Tensor
gate: Tensor
class Modulation(nn.Module):
def __init__(self, dim: int, double: bool):
super().__init__()
self.is_double = double
self.multiplier = 6 if double else 3
self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
return (
ModulationOut(*out[:3]),
ModulationOut(*out[3:]) if self.is_double else None,
)
class DoubleStreamBlock(nn.Module):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.num_heads = num_heads
self.hidden_size = hidden_size
self.img_mod = Modulation(hidden_size, double=True)
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.img_mlp = nn.Sequential(
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
nn.GELU(approximate="tanh"),
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
)
self.txt_mod = Modulation(hidden_size, double=True)
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.txt_mlp = nn.Sequential(
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
nn.GELU(approximate="tanh"),
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
)
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec)
# prepare image for attention
img_modulated = self.img_norm1(img)
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
img_qkv = self.img_attn.qkv(img_modulated)
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
# prepare txt for attention
txt_modulated = self.txt_norm1(txt)
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
txt_qkv = self.txt_attn.qkv(txt_modulated)
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
# run actual attention
q = torch.cat((txt_q, img_q), dim=2)
k = torch.cat((txt_k, img_k), dim=2)
v = torch.cat((txt_v, img_v), dim=2)
attn = attention(q, k, v, pe=pe)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
# calculate the img bloks
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
# calculate the txt bloks
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
return img, txt
class SingleStreamBlock(nn.Module):
"""
A DiT block with parallel linear layers as described in
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
"""
def __init__(
self,
hidden_size: int,
num_heads: int,
mlp_ratio: float = 4.0,
qk_scale: float | None = None,
):
super().__init__()
self.hidden_dim = hidden_size
self.num_heads = num_heads
head_dim = hidden_size // num_heads
self.scale = qk_scale or head_dim**-0.5
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
# qkv and mlp_in
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
# proj and mlp_out
self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
self.norm = QKNorm(head_dim)
self.hidden_size = hidden_size
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.mlp_act = nn.GELU(approximate="tanh")
self.modulation = Modulation(hidden_size, double=False)
def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
mod, _ = self.modulation(vec)
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
q, k = self.norm(q, k, v)
# compute attention
attn = attention(q, k, v, pe=pe)
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
return x + mod.gate * output
class LastLayer(nn.Module):
def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
def forward(self, x: Tensor, vec: Tensor) -> Tensor:
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
x = self.linear(x)
return x

View File

@ -0,0 +1,167 @@
# Initially pulled from https://github.com/black-forest-labs/flux
import math
from typing import Callable
import torch
from einops import rearrange, repeat
from torch import Tensor
from tqdm import tqdm
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.modules.conditioner import HFEncoder
def get_noise(
num_samples: int,
height: int,
width: int,
device: torch.device,
dtype: torch.dtype,
seed: int,
):
# We always generate noise on the same device and dtype then cast to ensure consistency across devices/dtypes.
rand_device = "cpu"
rand_dtype = torch.float16
return torch.randn(
num_samples,
16,
# allow for packing
2 * math.ceil(height / 16),
2 * math.ceil(width / 16),
device=rand_device,
dtype=rand_dtype,
generator=torch.Generator(device=rand_device).manual_seed(seed),
).to(device=device, dtype=dtype)
def prepare(t5: HFEncoder, clip: HFEncoder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
bs, c, h, w = img.shape
if bs == 1 and not isinstance(prompt, str):
bs = len(prompt)
img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
if img.shape[0] == 1 and bs > 1:
img = repeat(img, "1 ... -> bs ...", bs=bs)
img_ids = torch.zeros(h // 2, w // 2, 3)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
if isinstance(prompt, str):
prompt = [prompt]
txt = t5(prompt)
if txt.shape[0] == 1 and bs > 1:
txt = repeat(txt, "1 ... -> bs ...", bs=bs)
txt_ids = torch.zeros(bs, txt.shape[1], 3)
vec = clip(prompt)
if vec.shape[0] == 1 and bs > 1:
vec = repeat(vec, "1 ... -> bs ...", bs=bs)
return {
"img": img,
"img_ids": img_ids.to(img.device),
"txt": txt.to(img.device),
"txt_ids": txt_ids.to(img.device),
"vec": vec.to(img.device),
}
def time_shift(mu: float, sigma: float, t: Tensor):
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
m = (y2 - y1) / (x2 - x1)
b = y1 - m * x1
return lambda x: m * x + b
def get_schedule(
num_steps: int,
image_seq_len: int,
base_shift: float = 0.5,
max_shift: float = 1.15,
shift: bool = True,
) -> list[float]:
# extra step for zero
timesteps = torch.linspace(1, 0, num_steps + 1)
# shifting the schedule to favor high timesteps for higher signal images
if shift:
# eastimate mu based on linear estimation between two points
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
timesteps = time_shift(mu, 1.0, timesteps)
return timesteps.tolist()
def denoise(
model: Flux,
# model input
img: Tensor,
img_ids: Tensor,
txt: Tensor,
txt_ids: Tensor,
vec: Tensor,
# sampling parameters
timesteps: list[float],
step_callback: Callable[[], None],
guidance: float = 4.0,
):
# guidance_vec is ignored for schnell.
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
pred = model(
img=img,
img_ids=img_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
timesteps=t_vec,
guidance=guidance_vec,
)
img = img + (t_prev - t_curr) * pred
step_callback()
return img
def unpack(x: Tensor, height: int, width: int) -> Tensor:
return rearrange(
x,
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
h=math.ceil(height / 16),
w=math.ceil(width / 16),
ph=2,
pw=2,
)
def prepare_latent_img_patches(latent_img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Convert an input image in latent space to patches for diffusion.
This implementation was extracted from:
https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/sampling.py#L32
Returns:
tuple[Tensor, Tensor]: (img, img_ids), as defined in the original flux repo.
"""
bs, c, h, w = latent_img.shape
# Pixel unshuffle with a scale of 2, and flatten the height/width dimensions to get an array of patches.
img = rearrange(latent_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
if img.shape[0] == 1 and bs > 1:
img = repeat(img, "1 ... -> bs ...", bs=bs)
# Generate patch position ids.
img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device, dtype=img.dtype)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device, dtype=img.dtype)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device, dtype=img.dtype)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
return img, img_ids

View File

@ -0,0 +1,71 @@
# Initially pulled from https://github.com/black-forest-labs/flux
from dataclasses import dataclass
from typing import Dict, Literal
from invokeai.backend.flux.model import FluxParams
from invokeai.backend.flux.modules.autoencoder import AutoEncoderParams
@dataclass
class ModelSpec:
params: FluxParams
ae_params: AutoEncoderParams
ckpt_path: str | None
ae_path: str | None
repo_id: str | None
repo_flow: str | None
repo_ae: str | None
max_seq_lengths: Dict[str, Literal[256, 512]] = {
"flux-dev": 512,
"flux-schnell": 256,
}
ae_params = {
"flux": AutoEncoderParams(
resolution=256,
in_channels=3,
ch=128,
out_ch=3,
ch_mult=[1, 2, 4, 4],
num_res_blocks=2,
z_channels=16,
scale_factor=0.3611,
shift_factor=0.1159,
)
}
params = {
"flux-dev": FluxParams(
in_channels=64,
vec_in_dim=768,
context_in_dim=4096,
hidden_size=3072,
mlp_ratio=4.0,
num_heads=24,
depth=19,
depth_single_blocks=38,
axes_dim=[16, 56, 56],
theta=10_000,
qkv_bias=True,
guidance_embed=True,
),
"flux-schnell": FluxParams(
in_channels=64,
vec_in_dim=768,
context_in_dim=4096,
hidden_size=3072,
mlp_ratio=4.0,
num_heads=24,
depth=19,
depth_single_blocks=38,
axes_dim=[16, 56, 56],
theta=10_000,
qkv_bias=True,
guidance_embed=False,
),
}

View File

@ -1,90 +0,0 @@
from pathlib import Path
from typing import Literal
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from einops import repeat
from PIL import Image
from torchvision.transforms import Compose
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize
from invokeai.backend.util.logging import InvokeAILogger
config = get_config()
logger = InvokeAILogger.get_logger(config=config)
DEPTH_ANYTHING_MODELS = {
"large": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true",
"base": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true",
"small": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true",
}
transform = Compose(
[
Resize(
width=518,
height=518,
resize_target=False,
keep_aspect_ratio=True,
ensure_multiple_of=14,
resize_method="lower_bound",
image_interpolation_method=cv2.INTER_CUBIC,
),
NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
PrepareForNet(),
]
)
class DepthAnythingDetector:
def __init__(self, model: DPT_DINOv2, device: torch.device) -> None:
self.model = model
self.device = device
@staticmethod
def load_model(
model_path: Path, device: torch.device, model_size: Literal["large", "base", "small"] = "small"
) -> DPT_DINOv2:
match model_size:
case "small":
model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384])
case "base":
model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])
case "large":
model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
model.load_state_dict(torch.load(model_path.as_posix(), map_location="cpu"))
model.eval()
model.to(device)
return model
def __call__(self, image: Image.Image, resolution: int = 512) -> Image.Image:
if not self.model:
logger.warn("DepthAnything model was not loaded. Returning original image")
return image
np_image = np.array(image, dtype=np.uint8)
np_image = np_image[:, :, ::-1] / 255.0
image_height, image_width = np_image.shape[:2]
np_image = transform({"image": np_image})["image"]
tensor_image = torch.from_numpy(np_image).unsqueeze(0).to(self.device)
with torch.no_grad():
depth = self.model(tensor_image)
depth = F.interpolate(depth[None], (image_height, image_width), mode="bilinear", align_corners=False)[0, 0]
depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
depth_map = repeat(depth, "h w -> h w 3").cpu().numpy().astype(np.uint8)
depth_map = Image.fromarray(depth_map)
new_height = int(image_height * (resolution / image_width))
depth_map = depth_map.resize((resolution, new_height))
return depth_map

View File

@ -0,0 +1,31 @@
from typing import Optional
import torch
from PIL import Image
from transformers.pipelines import DepthEstimationPipeline
from invokeai.backend.raw_model import RawModel
class DepthAnythingPipeline(RawModel):
"""Custom wrapper for the Depth Estimation pipeline from transformers adding compatibility
for Invoke's Model Management System"""
def __init__(self, pipeline: DepthEstimationPipeline) -> None:
self._pipeline = pipeline
def generate_depth(self, image: Image.Image) -> Image.Image:
depth_map = self._pipeline(image)["depth"]
assert isinstance(depth_map, Image.Image)
return depth_map
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
if device is not None and device.type not in {"cpu", "cuda"}:
device = None
self._pipeline.model.to(device=device, dtype=dtype)
self._pipeline.device = self._pipeline.model.device
def calc_size(self) -> int:
from invokeai.backend.model_manager.load.model_util import calc_module_size
return calc_module_size(self._pipeline.model)

View File

@ -1,145 +0,0 @@
import torch.nn as nn
def _make_scratch(in_shape, out_shape, groups=1, expand=False):
scratch = nn.Module()
out_shape1 = out_shape
out_shape2 = out_shape
out_shape3 = out_shape
if len(in_shape) >= 4:
out_shape4 = out_shape
if expand:
out_shape1 = out_shape
out_shape2 = out_shape * 2
out_shape3 = out_shape * 4
if len(in_shape) >= 4:
out_shape4 = out_shape * 8
scratch.layer1_rn = nn.Conv2d(
in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
scratch.layer2_rn = nn.Conv2d(
in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
scratch.layer3_rn = nn.Conv2d(
in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
if len(in_shape) >= 4:
scratch.layer4_rn = nn.Conv2d(
in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
return scratch
class ResidualConvUnit(nn.Module):
"""Residual convolution module."""
def __init__(self, features, activation, bn):
"""Init.
Args:
features (int): number of features
"""
super().__init__()
self.bn = bn
self.groups = 1
self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
if self.bn:
self.bn1 = nn.BatchNorm2d(features)
self.bn2 = nn.BatchNorm2d(features)
self.activation = activation
self.skip_add = nn.quantized.FloatFunctional()
def forward(self, x):
"""Forward pass.
Args:
x (tensor): input
Returns:
tensor: output
"""
out = self.activation(x)
out = self.conv1(out)
if self.bn:
out = self.bn1(out)
out = self.activation(out)
out = self.conv2(out)
if self.bn:
out = self.bn2(out)
if self.groups > 1:
out = self.conv_merge(out)
return self.skip_add.add(out, x)
class FeatureFusionBlock(nn.Module):
"""Feature fusion block."""
def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, size=None):
"""Init.
Args:
features (int): number of features
"""
super(FeatureFusionBlock, self).__init__()
self.deconv = deconv
self.align_corners = align_corners
self.groups = 1
self.expand = expand
out_features = features
if self.expand:
out_features = features // 2
self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
self.skip_add = nn.quantized.FloatFunctional()
self.size = size
def forward(self, *xs, size=None):
"""Forward pass.
Returns:
tensor: output
"""
output = xs[0]
if len(xs) == 2:
res = self.resConfUnit1(xs[1])
output = self.skip_add.add(output, res)
output = self.resConfUnit2(output)
if (size is None) and (self.size is None):
modifier = {"scale_factor": 2}
elif size is None:
modifier = {"size": self.size}
else:
modifier = {"size": size}
output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
output = self.out_conv(output)
return output

View File

@ -1,183 +0,0 @@
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from invokeai.backend.image_util.depth_anything.model.blocks import FeatureFusionBlock, _make_scratch
torchhub_path = Path(__file__).parent.parent / "torchhub"
def _make_fusion_block(features, use_bn, size=None):
return FeatureFusionBlock(
features,
nn.ReLU(False),
deconv=False,
bn=use_bn,
expand=False,
align_corners=True,
size=size,
)
class DPTHead(nn.Module):
def __init__(self, nclass, in_channels, features, out_channels, use_bn=False, use_clstoken=False):
super(DPTHead, self).__init__()
self.nclass = nclass
self.use_clstoken = use_clstoken
self.projects = nn.ModuleList(
[
nn.Conv2d(
in_channels=in_channels,
out_channels=out_channel,
kernel_size=1,
stride=1,
padding=0,
)
for out_channel in out_channels
]
)
self.resize_layers = nn.ModuleList(
[
nn.ConvTranspose2d(
in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
),
nn.ConvTranspose2d(
in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
),
nn.Identity(),
nn.Conv2d(
in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
),
]
)
if use_clstoken:
self.readout_projects = nn.ModuleList()
for _ in range(len(self.projects)):
self.readout_projects.append(nn.Sequential(nn.Linear(2 * in_channels, in_channels), nn.GELU()))
self.scratch = _make_scratch(
out_channels,
features,
groups=1,
expand=False,
)
self.scratch.stem_transpose = None
self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
head_features_1 = features
head_features_2 = 32
if nclass > 1:
self.scratch.output_conv = nn.Sequential(
nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1),
nn.ReLU(True),
nn.Conv2d(head_features_1, nclass, kernel_size=1, stride=1, padding=0),
)
else:
self.scratch.output_conv1 = nn.Conv2d(
head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
)
self.scratch.output_conv2 = nn.Sequential(
nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
nn.ReLU(True),
nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
nn.ReLU(True),
nn.Identity(),
)
def forward(self, out_features, patch_h, patch_w):
out = []
for i, x in enumerate(out_features):
if self.use_clstoken:
x, cls_token = x[0], x[1]
readout = cls_token.unsqueeze(1).expand_as(x)
x = self.readout_projects[i](torch.cat((x, readout), -1))
else:
x = x[0]
x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
x = self.projects[i](x)
x = self.resize_layers[i](x)
out.append(x)
layer_1, layer_2, layer_3, layer_4 = out
layer_1_rn = self.scratch.layer1_rn(layer_1)
layer_2_rn = self.scratch.layer2_rn(layer_2)
layer_3_rn = self.scratch.layer3_rn(layer_3)
layer_4_rn = self.scratch.layer4_rn(layer_4)
path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
out = self.scratch.output_conv1(path_1)
out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
out = self.scratch.output_conv2(out)
return out
class DPT_DINOv2(nn.Module):
def __init__(
self,
features,
out_channels,
encoder="vitl",
use_bn=False,
use_clstoken=False,
):
super(DPT_DINOv2, self).__init__()
assert encoder in ["vits", "vitb", "vitl"]
# # in case the Internet connection is not stable, please load the DINOv2 locally
# if use_local:
# self.pretrained = torch.hub.load(
# torchhub_path / "facebookresearch_dinov2_main",
# "dinov2_{:}14".format(encoder),
# source="local",
# pretrained=False,
# )
# else:
# self.pretrained = torch.hub.load(
# "facebookresearch/dinov2",
# "dinov2_{:}14".format(encoder),
# )
self.pretrained = torch.hub.load(
"facebookresearch/dinov2",
"dinov2_{:}14".format(encoder),
)
dim = self.pretrained.blocks[0].attn.qkv.in_features
self.depth_head = DPTHead(1, dim, features, out_channels=out_channels, use_bn=use_bn, use_clstoken=use_clstoken)
def forward(self, x):
h, w = x.shape[-2:]
features = self.pretrained.get_intermediate_layers(x, 4, return_class_token=True)
patch_h, patch_w = h // 14, w // 14
depth = self.depth_head(features, patch_h, patch_w)
depth = F.interpolate(depth, size=(h, w), mode="bilinear", align_corners=True)
depth = F.relu(depth)
return depth.squeeze(1)

View File

@ -1,227 +0,0 @@
import math
import cv2
import numpy as np
import torch
import torch.nn.functional as F
def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
"""Rezise the sample to ensure the given size. Keeps aspect ratio.
Args:
sample (dict): sample
size (tuple): image size
Returns:
tuple: new size
"""
shape = list(sample["disparity"].shape)
if shape[0] >= size[0] and shape[1] >= size[1]:
return sample
scale = [0, 0]
scale[0] = size[0] / shape[0]
scale[1] = size[1] / shape[1]
scale = max(scale)
shape[0] = math.ceil(scale * shape[0])
shape[1] = math.ceil(scale * shape[1])
# resize
sample["image"] = cv2.resize(sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method)
sample["disparity"] = cv2.resize(sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST)
sample["mask"] = cv2.resize(
sample["mask"].astype(np.float32),
tuple(shape[::-1]),
interpolation=cv2.INTER_NEAREST,
)
sample["mask"] = sample["mask"].astype(bool)
return tuple(shape)
class Resize(object):
"""Resize sample to given size (width, height)."""
def __init__(
self,
width,
height,
resize_target=True,
keep_aspect_ratio=False,
ensure_multiple_of=1,
resize_method="lower_bound",
image_interpolation_method=cv2.INTER_AREA,
):
"""Init.
Args:
width (int): desired output width
height (int): desired output height
resize_target (bool, optional):
True: Resize the full sample (image, mask, target).
False: Resize image only.
Defaults to True.
keep_aspect_ratio (bool, optional):
True: Keep the aspect ratio of the input sample.
Output sample might not have the given width and height, and
resize behaviour depends on the parameter 'resize_method'.
Defaults to False.
ensure_multiple_of (int, optional):
Output width and height is constrained to be multiple of this parameter.
Defaults to 1.
resize_method (str, optional):
"lower_bound": Output will be at least as large as the given size.
"upper_bound": Output will be at max as large as the given size. (Output size might be smaller
than given size.)
"minimal": Scale as least as possible. (Output size might be smaller than given size.)
Defaults to "lower_bound".
"""
self.__width = width
self.__height = height
self.__resize_target = resize_target
self.__keep_aspect_ratio = keep_aspect_ratio
self.__multiple_of = ensure_multiple_of
self.__resize_method = resize_method
self.__image_interpolation_method = image_interpolation_method
def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
if max_val is not None and y > max_val:
y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
if y < min_val:
y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
return y
def get_size(self, width, height):
# determine new height and width
scale_height = self.__height / height
scale_width = self.__width / width
if self.__keep_aspect_ratio:
if self.__resize_method == "lower_bound":
# scale such that output size is lower bound
if scale_width > scale_height:
# fit width
scale_height = scale_width
else:
# fit height
scale_width = scale_height
elif self.__resize_method == "upper_bound":
# scale such that output size is upper bound
if scale_width < scale_height:
# fit width
scale_height = scale_width
else:
# fit height
scale_width = scale_height
elif self.__resize_method == "minimal":
# scale as least as possbile
if abs(1 - scale_width) < abs(1 - scale_height):
# fit width
scale_height = scale_width
else:
# fit height
scale_width = scale_height
else:
raise ValueError(f"resize_method {self.__resize_method} not implemented")
if self.__resize_method == "lower_bound":
new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height)
new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width)
elif self.__resize_method == "upper_bound":
new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height)
new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width)
elif self.__resize_method == "minimal":
new_height = self.constrain_to_multiple_of(scale_height * height)
new_width = self.constrain_to_multiple_of(scale_width * width)
else:
raise ValueError(f"resize_method {self.__resize_method} not implemented")
return (new_width, new_height)
def __call__(self, sample):
width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0])
# resize sample
sample["image"] = cv2.resize(
sample["image"],
(width, height),
interpolation=self.__image_interpolation_method,
)
if self.__resize_target:
if "disparity" in sample:
sample["disparity"] = cv2.resize(
sample["disparity"],
(width, height),
interpolation=cv2.INTER_NEAREST,
)
if "depth" in sample:
sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST)
if "semseg_mask" in sample:
# sample["semseg_mask"] = cv2.resize(
# sample["semseg_mask"], (width, height), interpolation=cv2.INTER_NEAREST
# )
sample["semseg_mask"] = F.interpolate(
torch.from_numpy(sample["semseg_mask"]).float()[None, None, ...], (height, width), mode="nearest"
).numpy()[0, 0]
if "mask" in sample:
sample["mask"] = cv2.resize(
sample["mask"].astype(np.float32),
(width, height),
interpolation=cv2.INTER_NEAREST,
)
# sample["mask"] = sample["mask"].astype(bool)
# print(sample['image'].shape, sample['depth'].shape)
return sample
class NormalizeImage(object):
"""Normlize image by given mean and std."""
def __init__(self, mean, std):
self.__mean = mean
self.__std = std
def __call__(self, sample):
sample["image"] = (sample["image"] - self.__mean) / self.__std
return sample
class PrepareForNet(object):
"""Prepare sample for usage as network input."""
def __init__(self):
pass
def __call__(self, sample):
image = np.transpose(sample["image"], (2, 0, 1))
sample["image"] = np.ascontiguousarray(image).astype(np.float32)
if "mask" in sample:
sample["mask"] = sample["mask"].astype(np.float32)
sample["mask"] = np.ascontiguousarray(sample["mask"])
if "depth" in sample:
depth = sample["depth"].astype(np.float32)
sample["depth"] = np.ascontiguousarray(depth)
if "semseg_mask" in sample:
sample["semseg_mask"] = sample["semseg_mask"].astype(np.float32)
sample["semseg_mask"] = np.ascontiguousarray(sample["semseg_mask"])
return sample

View File

@ -0,0 +1,22 @@
from pydantic import BaseModel, ConfigDict
class BoundingBox(BaseModel):
"""Bounding box helper class."""
xmin: int
ymin: int
xmax: int
ymax: int
class DetectionResult(BaseModel):
"""Detection result from Grounding DINO."""
score: float
label: str
box: BoundingBox
model_config = ConfigDict(
# Allow arbitrary types for mask, since it will be a numpy array.
arbitrary_types_allowed=True
)

View File

@ -0,0 +1,37 @@
from typing import Optional
import torch
from PIL import Image
from transformers.pipelines import ZeroShotObjectDetectionPipeline
from invokeai.backend.image_util.grounding_dino.detection_result import DetectionResult
from invokeai.backend.raw_model import RawModel
class GroundingDinoPipeline(RawModel):
"""A wrapper class for a ZeroShotObjectDetectionPipeline that makes it compatible with the model manager's memory
management system.
"""
def __init__(self, pipeline: ZeroShotObjectDetectionPipeline):
self._pipeline = pipeline
def detect(self, image: Image.Image, candidate_labels: list[str], threshold: float = 0.1) -> list[DetectionResult]:
results = self._pipeline(image=image, candidate_labels=candidate_labels, threshold=threshold)
assert results is not None
results = [DetectionResult.model_validate(result) for result in results]
return results
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
# HACK(ryand): The GroundingDinoPipeline does not work on MPS devices. We only allow it to be moved to CPU or
# CUDA.
if device is not None and device.type not in {"cpu", "cuda"}:
device = None
self._pipeline.model.to(device=device, dtype=dtype)
self._pipeline.device = self._pipeline.model.device
def calc_size(self) -> int:
# HACK(ryand): Fix the circular import issue.
from invokeai.backend.model_manager.load.model_util import calc_module_size
return calc_module_size(self._pipeline.model)

View File

@ -0,0 +1,50 @@
# This file contains utilities for Grounded-SAM mask refinement based on:
# https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
import cv2
import numpy as np
import numpy.typing as npt
def mask_to_polygon(mask: npt.NDArray[np.uint8]) -> list[tuple[int, int]]:
"""Convert a binary mask to a polygon.
Returns:
list[list[int]]: List of (x, y) coordinates representing the vertices of the polygon.
"""
# Find contours in the binary mask.
contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
# Find the contour with the largest area.
largest_contour = max(contours, key=cv2.contourArea)
# Extract the vertices of the contour.
polygon = largest_contour.reshape(-1, 2).tolist()
return polygon
def polygon_to_mask(
polygon: list[tuple[int, int]], image_shape: tuple[int, int], fill_value: int = 1
) -> npt.NDArray[np.uint8]:
"""Convert a polygon to a segmentation mask.
Args:
polygon (list): List of (x, y) coordinates representing the vertices of the polygon.
image_shape (tuple): Shape of the image (height, width) for the mask.
fill_value (int): Value to fill the polygon with.
Returns:
np.ndarray: Segmentation mask with the polygon filled (with value 255).
"""
# Create an empty mask.
mask = np.zeros(image_shape, dtype=np.uint8)
# Convert polygon to an array of points.
pts = np.array(polygon, dtype=np.int32)
# Fill the polygon with white color (255).
cv2.fillPoly(mask, [pts], color=(fill_value,))
return mask

View File

@ -0,0 +1,53 @@
from typing import Optional
import torch
from PIL import Image
from transformers.models.sam import SamModel
from transformers.models.sam.processing_sam import SamProcessor
from invokeai.backend.raw_model import RawModel
class SegmentAnythingPipeline(RawModel):
"""A wrapper class for the transformers SAM model and processor that makes it compatible with the model manager."""
def __init__(self, sam_model: SamModel, sam_processor: SamProcessor):
self._sam_model = sam_model
self._sam_processor = sam_processor
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
# HACK(ryand): The SAM pipeline does not work on MPS devices. We only allow it to be moved to CPU or CUDA.
if device is not None and device.type not in {"cpu", "cuda"}:
device = None
self._sam_model.to(device=device, dtype=dtype)
def calc_size(self) -> int:
# HACK(ryand): Fix the circular import issue.
from invokeai.backend.model_manager.load.model_util import calc_module_size
return calc_module_size(self._sam_model)
def segment(self, image: Image.Image, bounding_boxes: list[list[int]]) -> torch.Tensor:
"""Run the SAM model.
Args:
image (Image.Image): The image to segment.
bounding_boxes (list[list[int]]): The bounding box prompts. Each bounding box is in the format
[xmin, ymin, xmax, ymax].
Returns:
torch.Tensor: The segmentation masks. dtype: torch.bool. shape: [num_masks, channels, height, width].
"""
# Add batch dimension of 1 to the bounding boxes.
boxes = [bounding_boxes]
inputs = self._sam_processor(images=image, input_boxes=boxes, return_tensors="pt").to(self._sam_model.device)
outputs = self._sam_model(**inputs)
masks = self._sam_processor.post_process_masks(
masks=outputs.pred_masks,
original_sizes=inputs.original_sizes,
reshaped_input_sizes=inputs.reshaped_input_sizes,
)
# There should be only one batch.
assert len(masks) == 1
return masks[0]

View File

@ -3,12 +3,13 @@
import bisect
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Set, Tuple, Union
import torch
from safetensors.torch import load_file
from typing_extensions import Self
import invokeai.backend.util.logging as logger
from invokeai.backend.model_manager import BaseModelType
from invokeai.backend.raw_model import RawModel
@ -46,9 +47,19 @@ class LoRALayerBase:
self.rank = None # set in layer implementation
self.layer_key = layer_key
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
raise NotImplementedError()
def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
return self.bias
def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor]:
params = {"weight": self.get_weight(orig_module.weight)}
bias = self.get_bias(orig_module.bias)
if bias is not None:
params["bias"] = bias
return params
def calc_size(self) -> int:
model_size = 0
for val in [self.bias]:
@ -60,6 +71,17 @@ class LoRALayerBase:
if self.bias is not None:
self.bias = self.bias.to(device=device, dtype=dtype)
def check_keys(self, values: Dict[str, torch.Tensor], known_keys: Set[str]):
"""Log a warning if values contains unhandled keys."""
# {"alpha", "bias_indices", "bias_values", "bias_size"} are hard-coded, because they are handled by
# `LoRALayerBase`. Sub-classes should provide the known_keys that they handled.
all_known_keys = known_keys | {"alpha", "bias_indices", "bias_values", "bias_size"}
unknown_keys = set(values.keys()) - all_known_keys
if unknown_keys:
logger.warning(
f"Unexpected keys found in LoRA/LyCORIS layer, model might work incorrectly! Keys: {unknown_keys}"
)
# TODO: find and debug lora/locon with bias
class LoRALayer(LoRALayerBase):
@ -76,14 +98,19 @@ class LoRALayer(LoRALayerBase):
self.up = values["lora_up.weight"]
self.down = values["lora_down.weight"]
if "lora_mid.weight" in values:
self.mid: Optional[torch.Tensor] = values["lora_mid.weight"]
else:
self.mid = None
self.mid = values.get("lora_mid.weight", None)
self.rank = self.down.shape[0]
self.check_keys(
values,
{
"lora_up.weight",
"lora_down.weight",
"lora_mid.weight",
},
)
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
if self.mid is not None:
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
@ -125,20 +152,23 @@ class LoHALayer(LoRALayerBase):
self.w1_b = values["hada_w1_b"]
self.w2_a = values["hada_w2_a"]
self.w2_b = values["hada_w2_b"]
if "hada_t1" in values:
self.t1: Optional[torch.Tensor] = values["hada_t1"]
else:
self.t1 = None
if "hada_t2" in values:
self.t2: Optional[torch.Tensor] = values["hada_t2"]
else:
self.t2 = None
self.t1 = values.get("hada_t1", None)
self.t2 = values.get("hada_t2", None)
self.rank = self.w1_b.shape[0]
self.check_keys(
values,
{
"hada_w1_a",
"hada_w1_b",
"hada_w2_a",
"hada_w2_b",
"hada_t1",
"hada_t2",
},
)
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
if self.t1 is None:
weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
@ -186,37 +216,45 @@ class LoKRLayer(LoRALayerBase):
):
super().__init__(layer_key, values)
if "lokr_w1" in values:
self.w1: Optional[torch.Tensor] = values["lokr_w1"]
self.w1_a = None
self.w1_b = None
else:
self.w1 = None
self.w1 = values.get("lokr_w1", None)
if self.w1 is None:
self.w1_a = values["lokr_w1_a"]
self.w1_b = values["lokr_w1_b"]
if "lokr_w2" in values:
self.w2: Optional[torch.Tensor] = values["lokr_w2"]
self.w2_a = None
self.w2_b = None
else:
self.w2 = None
self.w1_b = None
self.w1_a = None
self.w2 = values.get("lokr_w2", None)
if self.w2 is None:
self.w2_a = values["lokr_w2_a"]
self.w2_b = values["lokr_w2_b"]
if "lokr_t2" in values:
self.t2: Optional[torch.Tensor] = values["lokr_t2"]
else:
self.t2 = None
self.w2_a = None
self.w2_b = None
if "lokr_w1_b" in values:
self.rank = values["lokr_w1_b"].shape[0]
elif "lokr_w2_b" in values:
self.rank = values["lokr_w2_b"].shape[0]
self.t2 = values.get("lokr_t2", None)
if self.w1_b is not None:
self.rank = self.w1_b.shape[0]
elif self.w2_b is not None:
self.rank = self.w2_b.shape[0]
else:
self.rank = None # unscaled
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
self.check_keys(
values,
{
"lokr_w1",
"lokr_w1_a",
"lokr_w1_b",
"lokr_w2",
"lokr_w2_a",
"lokr_w2_b",
"lokr_t2",
},
)
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
w1: Optional[torch.Tensor] = self.w1
if w1 is None:
assert self.w1_a is not None
@ -272,7 +310,9 @@ class LoKRLayer(LoRALayerBase):
class FullLayer(LoRALayerBase):
# bias handled in LoRALayerBase(calc_size, to)
# weight: torch.Tensor
# bias: Optional[torch.Tensor]
def __init__(
self,
@ -282,15 +322,12 @@ class FullLayer(LoRALayerBase):
super().__init__(layer_key, values)
self.weight = values["diff"]
if len(values.keys()) > 1:
_keys = list(values.keys())
_keys.remove("diff")
raise NotImplementedError(f"Unexpected keys in lora diff layer: {_keys}")
self.bias = values.get("diff_b", None)
self.rank = None # unscaled
self.check_keys(values, {"diff", "diff_b"})
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
return self.weight
def calc_size(self) -> int:
@ -319,8 +356,9 @@ class IA3Layer(LoRALayerBase):
self.on_input = values["on_input"]
self.rank = None # unscaled
self.check_keys(values, {"weight", "on_input"})
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
weight = self.weight
if not self.on_input:
weight = weight.reshape(-1, 1)
@ -340,7 +378,39 @@ class IA3Layer(LoRALayerBase):
self.on_input = self.on_input.to(device=device, dtype=dtype)
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
class NormLayer(LoRALayerBase):
# bias handled in LoRALayerBase(calc_size, to)
# weight: torch.Tensor
# bias: Optional[torch.Tensor]
def __init__(
self,
layer_key: str,
values: Dict[str, torch.Tensor],
):
super().__init__(layer_key, values)
self.weight = values["w_norm"]
self.bias = values.get("b_norm", None)
self.rank = None # unscaled
self.check_keys(values, {"w_norm", "b_norm"})
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
return self.weight
def calc_size(self) -> int:
model_size = super().calc_size()
model_size += self.weight.nelement() * self.weight.element_size()
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
super().to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype)
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer]
class LoRAModelRaw(RawModel): # (torch.nn.Module):
@ -458,16 +528,19 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module):
state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)
for layer_key, values in state_dict.items():
# Detect layers according to LyCORIS detection logic(`weight_list_det`)
# https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules
# lora and locon
if "lora_down.weight" in values:
if "lora_up.weight" in values:
layer: AnyLoRALayer = LoRALayer(layer_key, values)
# loha
elif "hada_w1_b" in values:
elif "hada_w1_a" in values:
layer = LoHALayer(layer_key, values)
# lokr
elif "lokr_w1_b" in values or "lokr_w1" in values:
elif "lokr_w1" in values or "lokr_w1_a" in values:
layer = LoKRLayer(layer_key, values)
# diff
@ -475,9 +548,13 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module):
layer = FullLayer(layer_key, values)
# ia3
elif "weight" in values and "on_input" in values:
elif "on_input" in values:
layer = IA3Layer(layer_key, values)
# norms
elif "w_norm" in values:
layer = NormLayer(layer_key, values)
else:
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
raise Exception("Unknown lora format!")

View File

@ -52,6 +52,7 @@ class BaseModelType(str, Enum):
StableDiffusion2 = "sd-2"
StableDiffusionXL = "sdxl"
StableDiffusionXLRefiner = "sdxl-refiner"
Flux = "flux"
# Kandinsky2_1 = "kandinsky-2.1"
@ -66,7 +67,9 @@ class ModelType(str, Enum):
TextualInversion = "embedding"
IPAdapter = "ip_adapter"
CLIPVision = "clip_vision"
CLIPEmbed = "clip_embed"
T2IAdapter = "t2i_adapter"
T5Encoder = "t5_encoder"
SpandrelImageToImage = "spandrel_image_to_image"
@ -74,6 +77,7 @@ class SubModelType(str, Enum):
"""Submodel type."""
UNet = "unet"
Transformer = "transformer"
TextEncoder = "text_encoder"
TextEncoder2 = "text_encoder_2"
Tokenizer = "tokenizer"
@ -104,6 +108,9 @@ class ModelFormat(str, Enum):
EmbeddingFile = "embedding_file"
EmbeddingFolder = "embedding_folder"
InvokeAI = "invokeai"
T5Encoder = "t5_encoder"
BnbQuantizedLlmInt8b = "bnb_quantized_int8b"
BnbQuantizednf4b = "bnb_quantized_nf4b"
class SchedulerPredictionType(str, Enum):
@ -186,7 +193,9 @@ class ModelConfigBase(BaseModel):
class CheckpointConfigBase(ModelConfigBase):
"""Model config for checkpoint-style models."""
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
format: Literal[ModelFormat.Checkpoint, ModelFormat.BnbQuantizednf4b] = Field(
description="Format of the provided checkpoint model", default=ModelFormat.Checkpoint
)
config_path: str = Field(description="path to the checkpoint model config file")
converted_at: Optional[float] = Field(
description="When this model was last converted to diffusers", default_factory=time.time
@ -205,6 +214,26 @@ class LoRAConfigBase(ModelConfigBase):
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
class T5EncoderConfigBase(ModelConfigBase):
type: Literal[ModelType.T5Encoder] = ModelType.T5Encoder
class T5EncoderConfig(T5EncoderConfigBase):
format: Literal[ModelFormat.T5Encoder] = ModelFormat.T5Encoder
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.T5Encoder.value}.{ModelFormat.T5Encoder.value}")
class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase):
format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = ModelFormat.BnbQuantizedLlmInt8b
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.T5Encoder.value}.{ModelFormat.BnbQuantizedLlmInt8b.value}")
class LoRALyCORISConfig(LoRAConfigBase):
"""Model config for LoRA/Lycoris models."""
@ -229,7 +258,6 @@ class VAECheckpointConfig(CheckpointConfigBase):
"""Model config for standalone VAE models."""
type: Literal[ModelType.VAE] = ModelType.VAE
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
@staticmethod
def get_tag() -> Tag:
@ -268,7 +296,6 @@ class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase)
"""Model config for ControlNet models (diffusers version)."""
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
@staticmethod
def get_tag() -> Tag:
@ -317,6 +344,21 @@ class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase):
return Tag(f"{ModelType.Main.value}.{ModelFormat.Checkpoint.value}")
class MainBnbQuantized4bCheckpointConfig(CheckpointConfigBase, MainConfigBase):
"""Model config for main checkpoint models."""
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.format = ModelFormat.BnbQuantizednf4b
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.Main.value}.{ModelFormat.BnbQuantizednf4b.value}")
class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase):
"""Model config for main diffusers models."""
@ -350,6 +392,17 @@ class IPAdapterCheckpointConfig(IPAdapterBaseConfig):
return Tag(f"{ModelType.IPAdapter.value}.{ModelFormat.Checkpoint.value}")
class CLIPEmbedDiffusersConfig(DiffusersConfigBase):
"""Model config for Clip Embeddings."""
type: Literal[ModelType.CLIPEmbed] = ModelType.CLIPEmbed
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}")
class CLIPVisionDiffusersConfig(DiffusersConfigBase):
"""Model config for CLIPVision."""
@ -408,12 +461,15 @@ AnyModelConfig = Annotated[
Union[
Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()],
Annotated[MainBnbQuantized4bCheckpointConfig, MainBnbQuantized4bCheckpointConfig.get_tag()],
Annotated[VAEDiffusersConfig, VAEDiffusersConfig.get_tag()],
Annotated[VAECheckpointConfig, VAECheckpointConfig.get_tag()],
Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()],
Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()],
Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
Annotated[T5EncoderConfig, T5EncoderConfig.get_tag()],
Annotated[T5EncoderBnbQuantizedLlmInt8bConfig, T5EncoderBnbQuantizedLlmInt8bConfig.get_tag()],
Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()],
Annotated[TextualInversionFolderConfig, TextualInversionFolderConfig.get_tag()],
Annotated[IPAdapterInvokeAIConfig, IPAdapterInvokeAIConfig.get_tag()],
@ -421,6 +477,7 @@ AnyModelConfig = Annotated[
Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
Annotated[SpandrelImageToImageConfig, SpandrelImageToImageConfig.get_tag()],
Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
Annotated[CLIPEmbedDiffusersConfig, CLIPEmbedDiffusersConfig.get_tag()],
],
Discriminator(get_model_discriminator_value),
]

View File

@ -72,6 +72,7 @@ class ModelLoader(ModelLoaderBase):
pass
config.path = str(self._get_model_path(config))
self._ram_cache.make_room(self.get_size_fs(config, Path(config.path), submodel_type))
loaded_model = self._load_model(config, submodel_type)
self._ram_cache.put(

View File

@ -193,15 +193,6 @@ class ModelCacheBase(ABC, Generic[T]):
"""
pass
@abstractmethod
def exists(
self,
key: str,
submodel_type: Optional[SubModelType] = None,
) -> bool:
"""Return true if the model identified by key and submodel_type is in the cache."""
pass
@abstractmethod
def cache_size(self) -> int:
"""Get the total size of the models currently cached."""

View File

@ -1,22 +1,6 @@
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team
# TODO: Add Stalker's proper name to copyright
"""
Manage a RAM cache of diffusion/transformer models for fast switching.
They are moved between GPU VRAM and CPU RAM as necessary. If the cache
grows larger than a preset maximum, then the least recently used
model will be cleared and (re)loaded from disk when next needed.
The cache returns context manager generators designed to load the
model into the GPU within the context, and unload outside the
context. Use like this:
cache = ModelCache(max_cache_size=7.5)
with cache.get_model('runwayml/stable-diffusion-1-5') as SD1,
cache.get_model('stabilityai/stable-diffusion-2') as SD2:
do_something_in_GPU(SD1,SD2)
"""
""" """
import gc
import math
@ -40,45 +24,64 @@ from invokeai.backend.model_manager.load.model_util import calc_model_size_by_da
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
# Maximum size of the cache, in gigs
# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
DEFAULT_MAX_CACHE_SIZE = 6.0
# amount of GPU memory to hold in reserve for use by generations (GB)
DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75
# actual size of a gig
GIG = 1073741824
# Size of a GB in bytes.
GB = 2**30
# Size of a MB in bytes.
MB = 2**20
class ModelCache(ModelCacheBase[AnyModel]):
"""Implementation of ModelCacheBase."""
"""A cache for managing models in memory.
The cache is based on two levels of model storage:
- execution_device: The device where most models are executed (typically "cuda", "mps", or "cpu").
- storage_device: The device where models are offloaded when not in active use (typically "cpu").
The model cache is based on the following assumptions:
- storage_device_mem_size > execution_device_mem_size
- disk_to_storage_device_transfer_time >> storage_device_to_execution_device_transfer_time
A copy of all models in the cache is always kept on the storage_device. A subset of the models also have a copy on
the execution_device.
Models are moved between the storage_device and the execution_device as necessary. Cache size limits are enforced
on both the storage_device and the execution_device. The execution_device cache uses a smallest-first offload
policy. The storage_device cache uses a least-recently-used (LRU) offload policy.
Note: Neither of these offload policies has really been compared against alternatives. It's likely that different
policies would be better, although the optimal policies are likely heavily dependent on usage patterns and HW
configuration.
The cache returns context manager generators designed to load the model into the execution device (often GPU) within
the context, and unload outside the context.
Example usage:
```
cache = ModelCache(max_cache_size=7.5, max_vram_cache_size=6.0)
with cache.get_model('runwayml/stable-diffusion-1-5') as SD1:
do_something_on_gpu(SD1)
```
"""
def __init__(
self,
max_cache_size: float = DEFAULT_MAX_CACHE_SIZE,
max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE,
max_cache_size: float,
max_vram_cache_size: float,
execution_device: torch.device = torch.device("cuda"),
storage_device: torch.device = torch.device("cpu"),
precision: torch.dtype = torch.float16,
sequential_offload: bool = False,
lazy_offloading: bool = True,
sha_chunksize: int = 16777216,
log_memory_usage: bool = False,
logger: Optional[Logger] = None,
):
"""
Initialize the model RAM cache.
:param max_cache_size: Maximum size of the RAM cache [6.0 GB]
:param max_cache_size: Maximum size of the storage_device cache in GBs.
:param max_vram_cache_size: Maximum size of the execution_device cache in GBs.
:param execution_device: Torch device to load active model into [torch.device('cuda')]
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
:param precision: Precision for loaded models [torch.float16]
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
:param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded.
:param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
@ -86,7 +89,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
"""
# allow lazy offloading only when vram cache enabled
self._lazy_offloading = lazy_offloading and max_vram_cache_size > 0
self._precision: torch.dtype = precision
self._max_cache_size: float = max_cache_size
self._max_vram_cache_size: float = max_vram_cache_size
self._execution_device: torch.device = execution_device
@ -145,15 +147,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
total += cache_record.size
return total
def exists(
self,
key: str,
submodel_type: Optional[SubModelType] = None,
) -> bool:
"""Return true if the model identified by key and submodel_type is in the cache."""
key = self._make_cache_key(key, submodel_type)
return key in self._cached_models
def put(
self,
key: str,
@ -203,7 +196,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
# more stats
if self.stats:
stats_name = stats_name or key
self.stats.cache_size = int(self._max_cache_size * GIG)
self.stats.cache_size = int(self._max_cache_size * GB)
self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
self.stats.in_cache = len(self._cached_models)
self.stats.loaded_model_sizes[stats_name] = max(
@ -231,10 +224,13 @@ class ModelCache(ModelCacheBase[AnyModel]):
return model_key
def offload_unlocked_models(self, size_required: int) -> None:
"""Move any unused models from VRAM."""
reserved = self._max_vram_cache_size * GIG
"""Offload models from the execution_device to make room for size_required.
:param size_required: The amount of space to clear in the execution_device cache, in bytes.
"""
reserved = self._max_vram_cache_size * GB
vram_in_use = torch.cuda.memory_allocated() + size_required
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB")
self.logger.debug(f"{(vram_in_use/GB):.2f}GB VRAM needed for models; max allowed={(reserved/GB):.2f}GB")
for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
if vram_in_use <= reserved:
break
@ -245,7 +241,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
cache_entry.loaded = False
vram_in_use = torch.cuda.memory_allocated() + size_required
self.logger.debug(
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GB):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GB):.2f}GB"
)
TorchDevice.empty_cache()
@ -303,7 +299,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
self.logger.debug(
f"Moved model '{cache_entry.key}' from {source_device} to"
f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
f"Estimated model size: {(cache_entry.size/GIG):.3f} GB."
f"Estimated model size: {(cache_entry.size/GB):.3f} GB."
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
)
@ -326,14 +322,14 @@ class ModelCache(ModelCacheBase[AnyModel]):
f"Moving model '{cache_entry.key}' from {source_device} to"
f" {target_device} caused an unexpected change in VRAM usage. The model's"
" estimated size may be incorrect. Estimated model size:"
f" {(cache_entry.size/GIG):.3f} GB.\n"
f" {(cache_entry.size/GB):.3f} GB.\n"
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
)
def print_cuda_stats(self) -> None:
"""Log CUDA diagnostics."""
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
ram = "%4.2fG" % (self.cache_size() / GIG)
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GB)
ram = "%4.2fG" % (self.cache_size() / GB)
in_ram_models = 0
in_vram_models = 0
@ -353,17 +349,20 @@ class ModelCache(ModelCacheBase[AnyModel]):
)
def make_room(self, size: int) -> None:
"""Make enough room in the cache to accommodate a new model of indicated size."""
# calculate how much memory this model will require
# multiplier = 2 if self.precision==torch.float32 else 1
"""Make enough room in the cache to accommodate a new model of indicated size.
Note: This function deletes all of the cache's internal references to a model in order to free it. If there are
external references to the model, there's nothing that the cache can do about it, and those models will not be
garbage-collected.
"""
bytes_needed = size
maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes
maximum_size = self.max_cache_size * GB # stored in GB, convert to bytes
current_size = self.cache_size()
if current_size + bytes_needed > maximum_size:
self.logger.debug(
f"Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional"
f" {(bytes_needed/GIG):.2f} GB"
f"Max cache size exceeded: {(current_size/GB):.2f}/{self.max_cache_size:.2f} GB, need an additional"
f" {(bytes_needed/GB):.2f} GB"
)
self.logger.debug(f"Before making_room: cached_models={len(self._cached_models)}")
@ -380,7 +379,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
if not cache_entry.locked:
self.logger.debug(
f"Removing {model_key} from RAM cache to free at least {(size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
f"Removing {model_key} from RAM cache to free at least {(size/GB):.2f} GB (-{(cache_entry.size/GB):.2f} GB)"
)
current_size -= cache_entry.size
models_cleared += 1

View File

@ -0,0 +1,234 @@
# Copyright (c) 2024, Brandon W. Rising and the InvokeAI Development Team
"""Class for Flux model loading in InvokeAI."""
from pathlib import Path
from typing import Optional
import accelerate
import torch
from safetensors.torch import load_file
from transformers import AutoConfig, AutoModelForTextEncoding, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.flux.util import ae_params, params
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.config import (
CheckpointConfigBase,
CLIPEmbedDiffusersConfig,
MainBnbQuantized4bCheckpointConfig,
MainCheckpointConfig,
T5EncoderBnbQuantizedLlmInt8bConfig,
T5EncoderConfig,
VAECheckpointConfig,
)
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.util.silence_warnings import SilenceWarnings
try:
from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
bnb_available = True
except ImportError:
bnb_available = False
app_config = get_config()
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.VAE, format=ModelFormat.Checkpoint)
class FluxVAELoader(ModelLoader):
"""Class to load VAE models."""
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, VAECheckpointConfig):
raise ValueError("Only VAECheckpointConfig models are currently supported here.")
model_path = Path(config.path)
with SilenceWarnings():
model = AutoEncoder(ae_params[config.config_path])
sd = load_file(model_path)
model.load_state_dict(sd, assign=True)
model.to(dtype=self._torch_dtype)
return model
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPEmbed, format=ModelFormat.Diffusers)
class ClipCheckpointModel(ModelLoader):
"""Class to load main models."""
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, CLIPEmbedDiffusersConfig):
raise ValueError("Only CLIPEmbedDiffusersConfig models are currently supported here.")
match submodel_type:
case SubModelType.Tokenizer:
return CLIPTokenizer.from_pretrained(Path(config.path) / "tokenizer")
case SubModelType.TextEncoder:
return CLIPTextModel.from_pretrained(Path(config.path) / "text_encoder")
raise ValueError(
f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.BnbQuantizedLlmInt8b)
class BnbQuantizedLlmInt8bCheckpointModel(ModelLoader):
"""Class to load main models."""
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, T5EncoderBnbQuantizedLlmInt8bConfig):
raise ValueError("Only T5EncoderBnbQuantizedLlmInt8bConfig models are currently supported here.")
if not bnb_available:
raise ImportError(
"The bnb modules are not available. Please install bitsandbytes if available on your platform."
)
match submodel_type:
case SubModelType.Tokenizer2:
return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
case SubModelType.TextEncoder2:
te2_model_path = Path(config.path) / "text_encoder_2"
model_config = AutoConfig.from_pretrained(te2_model_path)
with accelerate.init_empty_weights():
model = AutoModelForTextEncoding.from_config(model_config)
model = quantize_model_llm_int8(model, modules_to_not_convert=set())
state_dict_path = te2_model_path / "bnb_llm_int8_model.safetensors"
state_dict = load_file(state_dict_path)
self._load_state_dict_into_t5(model, state_dict)
return model
raise ValueError(
f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
)
@classmethod
def _load_state_dict_into_t5(cls, model: T5EncoderModel, state_dict: dict[str, torch.Tensor]):
# There is a shared reference to a single weight tensor in the model.
# Both "encoder.embed_tokens.weight" and "shared.weight" refer to the same tensor, so only the latter should
# be present in the state_dict.
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False, assign=True)
assert len(unexpected_keys) == 0
assert set(missing_keys) == {"encoder.embed_tokens.weight"}
# Assert that the layers we expect to be shared are actually shared.
assert model.encoder.embed_tokens.weight is model.shared.weight
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder)
class T5EncoderCheckpointModel(ModelLoader):
"""Class to load main models."""
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, T5EncoderConfig):
raise ValueError("Only T5EncoderConfig models are currently supported here.")
match submodel_type:
case SubModelType.Tokenizer2:
return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
case SubModelType.TextEncoder2:
return T5EncoderModel.from_pretrained(Path(config.path) / "text_encoder_2")
raise ValueError(
f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
)
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.Checkpoint)
class FluxCheckpointModel(ModelLoader):
"""Class to load main models."""
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, CheckpointConfigBase):
raise ValueError("Only CheckpointConfigBase models are currently supported here.")
match submodel_type:
case SubModelType.Transformer:
return self._load_from_singlefile(config)
raise ValueError(
f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
)
def _load_from_singlefile(
self,
config: AnyModelConfig,
) -> AnyModel:
assert isinstance(config, MainCheckpointConfig)
model_path = Path(config.path)
with SilenceWarnings():
model = Flux(params[config.config_path])
sd = load_file(model_path)
model.load_state_dict(sd, assign=True)
return model
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.BnbQuantizednf4b)
class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
"""Class to load main models."""
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, CheckpointConfigBase):
raise ValueError("Only CheckpointConfigBase models are currently supported here.")
match submodel_type:
case SubModelType.Transformer:
return self._load_from_singlefile(config)
raise ValueError(
f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
)
def _load_from_singlefile(
self,
config: AnyModelConfig,
) -> AnyModel:
assert isinstance(config, MainBnbQuantized4bCheckpointConfig)
if not bnb_available:
raise ImportError(
"The bnb modules are not available. Please install bitsandbytes if available on your platform."
)
model_path = Path(config.path)
with SilenceWarnings():
with accelerate.init_empty_weights():
model = Flux(params[config.config_path])
model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
sd = load_file(model_path)
model.load_state_dict(sd, assign=True)
return model

View File

@ -78,7 +78,12 @@ class GenericDiffusersLoader(ModelLoader):
# TO DO: Add exception handling
def _hf_definition_to_type(self, module: str, class_name: str) -> ModelMixin: # fix with correct type
if module in ["diffusers", "transformers"]:
if module in [
"diffusers",
"transformers",
"invokeai.backend.quantization.fast_quantized_transformers_model",
"invokeai.backend.quantization.fast_quantized_diffusion_model",
]:
res_type = sys.modules[module]
else:
res_type = sys.modules["diffusers"].pipelines

View File

@ -36,8 +36,18 @@ VARIANT_TO_IN_CHANNEL_MAP = {
}
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Main, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Main, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.Main, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(
base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Diffusers
)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Main, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Main, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.Main, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(
base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Checkpoint
)
class StableDiffusionDiffusersModel(GenericDiffusersLoader):
"""Class to load main models."""

View File

@ -9,8 +9,11 @@ from typing import Optional
import torch
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from transformers import CLIPTokenizer
from transformers import CLIPTokenizer, T5Tokenizer, T5TokenizerFast
from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager.config import AnyModel
@ -34,8 +37,30 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
elif isinstance(model, CLIPTokenizer):
# TODO(ryand): Accurately calculate the tokenizer's size. It's small enough that it shouldn't matter for now.
return 0
elif isinstance(model, (TextualInversionModelRaw, IPAdapter, LoRAModelRaw, SpandrelImageToImageModel)):
elif isinstance(
model,
(
TextualInversionModelRaw,
IPAdapter,
LoRAModelRaw,
SpandrelImageToImageModel,
GroundingDinoPipeline,
SegmentAnythingPipeline,
DepthAnythingPipeline,
),
):
return model.calc_size()
elif isinstance(
model,
(
T5TokenizerFast,
T5Tokenizer,
),
):
# HACK(ryand): len(model) just returns the vocabulary size, so this is blatantly wrong. It should be small
# relative to the text encoder that it's used with, so shouldn't matter too much, but we should fix this at some
# point.
return len(model)
else:
# TODO(ryand): Promote this from a log to an exception once we are confident that we are handling all of the
# supported model types.

View File

@ -95,6 +95,7 @@ class ModelProbe(object):
}
CLASS2TYPE = {
"FluxPipeline": ModelType.Main,
"StableDiffusionPipeline": ModelType.Main,
"StableDiffusionInpaintPipeline": ModelType.Main,
"StableDiffusionXLPipeline": ModelType.Main,
@ -106,6 +107,7 @@ class ModelProbe(object):
"ControlNetModel": ModelType.ControlNet,
"CLIPVisionModelWithProjection": ModelType.CLIPVision,
"T2IAdapter": ModelType.T2IAdapter,
"CLIPModel": ModelType.CLIPEmbed,
}
@classmethod
@ -161,7 +163,7 @@ class ModelProbe(object):
fields["description"] = (
fields.get("description") or f"{fields['base'].value} {model_type.value} model {fields['name']}"
)
fields["format"] = fields.get("format") or probe.get_format()
fields["format"] = ModelFormat(fields.get("format")) if "format" in fields else probe.get_format()
fields["hash"] = fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path)
fields["default_settings"] = fields.get("default_settings")
@ -176,10 +178,10 @@ class ModelProbe(object):
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()
# additional fields needed for main and controlnet models
if (
fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE]
and fields["format"] is ModelFormat.Checkpoint
):
if fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE] and fields["format"] in [
ModelFormat.Checkpoint,
ModelFormat.BnbQuantizednf4b,
]:
ckpt_config_path = cls._get_checkpoint_config_path(
model_path,
model_type=fields["type"],
@ -222,7 +224,8 @@ class ModelProbe(object):
ckpt = ckpt.get("state_dict", ckpt)
for key in [str(k) for k in ckpt.keys()]:
if key.startswith(("cond_stage_model.", "first_stage_model.", "model.diffusion_model.")):
if key.startswith(("cond_stage_model.", "first_stage_model.", "model.diffusion_model.", "double_blocks.")):
# Keys starting with double_blocks are associated with Flux models
return ModelType.Main
elif key.startswith(("encoder.conv_in", "decoder.conv_in")):
return ModelType.VAE
@ -321,10 +324,27 @@ class ModelProbe(object):
return possible_conf.absolute()
if model_type is ModelType.Main:
config_file = LEGACY_CONFIGS[base_type][variant_type]
if isinstance(config_file, dict): # need another tier for sd-2.x models
config_file = config_file[prediction_type]
config_file = f"stable-diffusion/{config_file}"
if base_type == BaseModelType.Flux:
# TODO: Decide between dev/schnell
checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)
state_dict = checkpoint.get("state_dict") or checkpoint
if "guidance_in.out_layer.weight" in state_dict:
# For flux, this is a key in invokeai.backend.flux.util.params
# Due to model type and format being the descriminator for model configs this
# is used rather than attempting to support flux with separate model types and format
# If changed in the future, please fix me
config_file = "flux-dev"
else:
# For flux, this is a key in invokeai.backend.flux.util.params
# Due to model type and format being the descriminator for model configs this
# is used rather than attempting to support flux with separate model types and format
# If changed in the future, please fix me
config_file = "flux-schnell"
else:
config_file = LEGACY_CONFIGS[base_type][variant_type]
if isinstance(config_file, dict): # need another tier for sd-2.x models
config_file = config_file[prediction_type]
config_file = f"stable-diffusion/{config_file}"
elif model_type is ModelType.ControlNet:
config_file = (
"controlnet/cldm_v15.yaml"
@ -333,7 +353,13 @@ class ModelProbe(object):
)
elif model_type is ModelType.VAE:
config_file = (
"stable-diffusion/v1-inference.yaml"
# For flux, this is a key in invokeai.backend.flux.util.ae_params
# Due to model type and format being the descriminator for model configs this
# is used rather than attempting to support flux with separate model types and format
# If changed in the future, please fix me
"flux"
if base_type is BaseModelType.Flux
else "stable-diffusion/v1-inference.yaml"
if base_type is BaseModelType.StableDiffusion1
else "stable-diffusion/sd_xl_base.yaml"
if base_type is BaseModelType.StableDiffusionXL
@ -416,11 +442,15 @@ class CheckpointProbeBase(ProbeBase):
self.checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)
def get_format(self) -> ModelFormat:
state_dict = self.checkpoint.get("state_dict") or self.checkpoint
if "double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4" in state_dict:
return ModelFormat.BnbQuantizednf4b
return ModelFormat("checkpoint")
def get_variant_type(self) -> ModelVariantType:
model_type = ModelProbe.get_model_type_from_checkpoint(self.model_path, self.checkpoint)
if model_type != ModelType.Main:
base_type = self.get_base_type()
if model_type != ModelType.Main or base_type == BaseModelType.Flux:
return ModelVariantType.Normal
state_dict = self.checkpoint.get("state_dict") or self.checkpoint
in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
@ -440,6 +470,8 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
checkpoint = self.checkpoint
state_dict = self.checkpoint.get("state_dict") or checkpoint
if "double_blocks.0.img_attn.norm.key_norm.scale" in state_dict:
return BaseModelType.Flux
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
return BaseModelType.StableDiffusion1
@ -482,6 +514,7 @@ class VaeCheckpointProbe(CheckpointProbeBase):
(r"xl", BaseModelType.StableDiffusionXL),
(r"sd2", BaseModelType.StableDiffusion2),
(r"vae", BaseModelType.StableDiffusion1),
(r"FLUX.1-schnell_ae", BaseModelType.Flux),
]:
if re.search(regexp, self.model_path.name, re.IGNORECASE):
return basetype
@ -713,6 +746,11 @@ class TextualInversionFolderProbe(FolderProbeBase):
return TextualInversionCheckpointProbe(path).get_base_type()
class T5EncoderFolderProbe(FolderProbeBase):
def get_format(self) -> ModelFormat:
return ModelFormat.T5Encoder
class ONNXFolderProbe(PipelineFolderProbe):
def get_base_type(self) -> BaseModelType:
# Due to the way the installer is set up, the configuration file for safetensors
@ -805,6 +843,11 @@ class CLIPVisionFolderProbe(FolderProbeBase):
return BaseModelType.Any
class CLIPEmbedFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
return BaseModelType.Any
class SpandrelImageToImageFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
raise NotImplementedError()
@ -835,8 +878,10 @@ ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.VAE, VaeFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.LoRA, LoRAFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.T5Encoder, T5EncoderFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.CLIPEmbed, CLIPEmbedFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.SpandrelImageToImage, SpandrelImageToImageFolderProbe)

View File

@ -2,7 +2,7 @@ from typing import Optional
from pydantic import BaseModel
from invokeai.backend.model_manager.config import BaseModelType, ModelType
from invokeai.backend.model_manager.config import BaseModelType, ModelFormat, ModelType
class StarterModelWithoutDependencies(BaseModel):
@ -11,6 +11,7 @@ class StarterModelWithoutDependencies(BaseModel):
name: str
base: BaseModelType
type: ModelType
format: Optional[ModelFormat] = None
is_installed: bool = False
@ -51,10 +52,76 @@ cyberrealistic_negative = StarterModel(
type=ModelType.TextualInversion,
)
t5_base_encoder = StarterModel(
name="t5_base_encoder",
base=BaseModelType.Any,
source="InvokeAI/t5-v1_1-xxl::bfloat16",
description="T5-XXL text encoder (used in FLUX pipelines). ~8GB",
type=ModelType.T5Encoder,
)
t5_8b_quantized_encoder = StarterModel(
name="t5_bnb_int8_quantized_encoder",
base=BaseModelType.Any,
source="InvokeAI/t5-v1_1-xxl::bnb_llm_int8",
description="T5-XXL text encoder with bitsandbytes LLM.int8() quantization (used in FLUX pipelines). ~5GB",
type=ModelType.T5Encoder,
format=ModelFormat.BnbQuantizedLlmInt8b,
)
clip_l_encoder = StarterModel(
name="clip-vit-large-patch14",
base=BaseModelType.Any,
source="InvokeAI/clip-vit-large-patch14-text-encoder::bfloat16",
description="CLIP-L text encoder (used in FLUX pipelines). ~250MB",
type=ModelType.CLIPEmbed,
)
flux_vae = StarterModel(
name="FLUX.1-schnell_ae",
base=BaseModelType.Flux,
source="black-forest-labs/FLUX.1-schnell::ae.safetensors",
description="FLUX VAE compatible with both schnell and dev variants.",
type=ModelType.VAE,
)
# List of starter models, displayed on the frontend.
# The order/sort of this list is not changed by the frontend - set it how you want it here.
STARTER_MODELS: list[StarterModel] = [
# region: Main
StarterModel(
name="FLUX Schnell (Quantized)",
base=BaseModelType.Flux,
source="InvokeAI/flux_schnell::transformer/bnb_nf4/flux1-schnell-bnb_nf4.safetensors",
description="FLUX schnell transformer quantized to bitsandbytes NF4 format. Total size with dependencies: ~12GB",
type=ModelType.Main,
dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder],
),
StarterModel(
name="FLUX Dev (Quantized)",
base=BaseModelType.Flux,
source="InvokeAI/flux_dev::transformer/bnb_nf4/flux1-dev-bnb_nf4.safetensors",
description="FLUX dev transformer quantized to bitsandbytes NF4 format. Total size with dependencies: ~12GB",
type=ModelType.Main,
dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder],
),
StarterModel(
name="FLUX Schnell",
base=BaseModelType.Flux,
source="InvokeAI/flux_schnell::transformer/base/flux1-schnell.safetensors",
description="FLUX schnell transformer in bfloat16. Total size with dependencies: ~33GB",
type=ModelType.Main,
dependencies=[t5_base_encoder, flux_vae, clip_l_encoder],
),
StarterModel(
name="FLUX Dev",
base=BaseModelType.Flux,
source="InvokeAI/flux_dev::transformer/base/flux1-dev.safetensors",
description="FLUX dev transformer in bfloat16. Total size with dependencies: ~33GB",
type=ModelType.Main,
dependencies=[t5_base_encoder, flux_vae, clip_l_encoder],
),
StarterModel(
name="CyberRealistic v4.1",
base=BaseModelType.StableDiffusion1,
@ -125,6 +192,7 @@ STARTER_MODELS: list[StarterModel] = [
# endregion
# region VAE
sdxl_fp16_vae_fix,
flux_vae,
# endregion
# region LoRA
StarterModel(
@ -450,6 +518,11 @@ STARTER_MODELS: list[StarterModel] = [
type=ModelType.SpandrelImageToImage,
),
# endregion
# region TextEncoders
t5_base_encoder,
t5_8b_quantized_encoder,
clip_l_encoder,
# endregion
]
assert len(STARTER_MODELS) == len({m.source for m in STARTER_MODELS}), "Duplicate starter models"

View File

@ -54,6 +54,7 @@ def filter_files(
"lora_weights.safetensors",
"weights.pb",
"onnx_data",
"spiece.model", # Added for `black-forest-labs/FLUX.1-schnell`.
)
):
paths.append(file)
@ -62,13 +63,13 @@ def filter_files(
# downloading random checkpoints that might also be in the repo. However there is no guarantee
# that a checkpoint doesn't contain "model" in its name, and no guarantee that future diffusers models
# will adhere to this naming convention, so this is an area to be careful of.
elif re.search(r"model(\.[^.]+)?\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$", file.name):
elif re.search(r"model.*\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$", file.name):
paths.append(file)
# limit search to subfolder if requested
if subfolder:
subfolder = root / subfolder
paths = [x for x in paths if x.parent == Path(subfolder)]
paths = [x for x in paths if Path(subfolder) in x.parents]
# _filter_by_variant uniquifies the paths and returns a set
return sorted(_filter_by_variant(paths, variant))
@ -97,7 +98,9 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
if variant == ModelRepoVariant.Flax:
result.add(path)
elif path.suffix in [".json", ".txt"]:
# Note: '.model' was added to support:
# https://huggingface.co/black-forest-labs/FLUX.1-schnell/blob/768d12a373ed5cc9ef9a9dea7504dc09fcc14842/tokenizer_2/spiece.model
elif path.suffix in [".json", ".txt", ".model"]:
result.add(path)
elif variant in [
@ -140,6 +143,23 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
continue
for candidate_list in subfolder_weights.values():
# Check if at least one of the files has the explicit fp16 variant.
at_least_one_fp16 = False
for candidate in candidate_list:
if len(candidate.path.suffixes) == 2 and candidate.path.suffixes[0] == ".fp16":
at_least_one_fp16 = True
break
if not at_least_one_fp16:
# If none of the candidates in this candidate_list have the explicit fp16 variant label, then this
# candidate_list probably doesn't adhere to the variant naming convention that we expected. In this case,
# we'll simply keep all the candidates. An example of a model that hits this case is
# `black-forest-labs/FLUX.1-schnell` (as of commit 012d2fd).
for candidate in candidate_list:
result.add(candidate.path)
# The candidate_list seems to have the expected variant naming convention. We'll select the highest scoring
# candidate.
highest_score_candidate = max(candidate_list, key=lambda candidate: candidate.score)
if highest_score_candidate:
result.add(highest_score_candidate.path)

View File

@ -17,8 +17,9 @@ from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager import AnyModel
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt
from invokeai.backend.textual_inversion import TextualInversionManager, TextualInversionModelRaw
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
"""
loras = [
@ -85,13 +86,13 @@ class ModelPatcher:
cls,
unet: UNet2DConditionModel,
loras: Iterator[Tuple[LoRAModelRaw, float]],
model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
) -> Generator[None, None, None]:
with cls.apply_lora(
unet,
loras=loras,
prefix="lora_unet_",
model_state_dict=model_state_dict,
cached_weights=cached_weights,
):
yield
@ -101,9 +102,9 @@ class ModelPatcher:
cls,
text_encoder: CLIPTextModel,
loras: Iterator[Tuple[LoRAModelRaw, float]],
model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
) -> Generator[None, None, None]:
with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", model_state_dict=model_state_dict):
with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", cached_weights=cached_weights):
yield
@classmethod
@ -113,7 +114,7 @@ class ModelPatcher:
model: AnyModel,
loras: Iterator[Tuple[LoRAModelRaw, float]],
prefix: str,
model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
) -> Generator[None, None, None]:
"""
Apply one or more LoRAs to a model.
@ -121,66 +122,26 @@ class ModelPatcher:
:param model: The model to patch.
:param loras: An iterator that returns the LoRA to patch in and its patch weight.
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
:model_state_dict: Read-only copy of the model's state dict in CPU, for unpatching purposes.
:cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
"""
original_weights = {}
original_weights = OriginalWeightsStorage(cached_weights)
try:
with torch.no_grad():
for lora, lora_weight in loras:
# assert lora.device.type == "cpu"
for layer_key, layer in lora.layers.items():
if not layer_key.startswith(prefix):
continue
for lora_model, lora_weight in loras:
LoRAExt.patch_model(
model=model,
prefix=prefix,
lora=lora_model,
lora_weight=lora_weight,
original_weights=original_weights,
)
del lora_model
# TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
# should be improved in the following ways:
# 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
# LoRA model is applied.
# 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
# intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
# weights to have valid keys.
assert isinstance(model, torch.nn.Module)
module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
# All of the LoRA weight calculations will be done on the same device as the module weight.
# (Performance will be best if this is a CUDA device.)
device = module.weight.device
dtype = module.weight.dtype
if module_key not in original_weights:
if model_state_dict is not None: # we were provided with the CPU copy of the state dict
original_weights[module_key] = model_state_dict[module_key + ".weight"]
else:
original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
# We intentionally move to the target device first, then cast. Experimentally, this was found to
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
# same thing in a single call to '.to(...)'.
layer.to(device=device)
layer.to(dtype=torch.float32)
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
layer.to(device=TorchDevice.CPU_DEVICE)
assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??!
if module.weight.shape != layer_weight.shape:
# TODO: debug on lycoris
assert hasattr(layer_weight, "reshape")
layer_weight = layer_weight.reshape(module.weight.shape)
assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??!
module.weight += layer_weight.to(dtype=dtype)
yield # wait for context manager exit
yield
finally:
assert hasattr(model, "get_submodule") # mypy not picking up fact that torch.nn.Module has get_submodule()
with torch.no_grad():
for module_key, weight in original_weights.items():
model.get_submodule(module_key).weight.copy_(weight)
for param_key, weight in original_weights.get_changed_weights():
model.get_parameter(param_key).copy_(weight)
@classmethod
@contextmanager

View File

@ -0,0 +1,135 @@
import bitsandbytes as bnb
import torch
# This file contains utils for working with models that use bitsandbytes LLM.int8() quantization.
# The utils in this file are partially inspired by:
# https://github.com/Lightning-AI/pytorch-lightning/blob/1551a16b94f5234a4a78801098f64d0732ef5cb5/src/lightning/fabric/plugins/precision/bitsandbytes.py
# NOTE(ryand): All of the custom state_dict manipulation logic in this file is pretty hacky. This could be made much
# cleaner by re-implementing bnb.nn.Linear8bitLt with proper use of buffers and less magic. But, for now, we try to
# stick close to the bitsandbytes classes to make interoperability easier with other models that might use bitsandbytes.
class InvokeInt8Params(bnb.nn.Int8Params):
"""We override cuda() to avoid re-quantizing the weights in the following cases:
- We loaded quantized weights from a state_dict on the cpu, and then moved the model to the gpu.
- We are moving the model back-and-forth between the cpu and gpu.
"""
def cuda(self, device):
if self.has_fp16_weights:
return super().cuda(device)
elif self.CB is not None and self.SCB is not None:
self.data = self.data.cuda()
self.CB = self.data
self.SCB = self.SCB.cuda()
else:
# we store the 8-bit rows-major weight
# we convert this weight to the turning/ampere weight during the first inference pass
B = self.data.contiguous().half().cuda(device)
CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
del CBt
del SCBt
self.data = CB
self.CB = CB
self.SCB = SCB
return self
class InvokeLinear8bitLt(bnb.nn.Linear8bitLt):
def _load_from_state_dict(
self,
state_dict: dict[str, torch.Tensor],
prefix: str,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
weight = state_dict.pop(prefix + "weight")
bias = state_dict.pop(prefix + "bias", None)
# See `bnb.nn.Linear8bitLt._save_to_state_dict()` for the serialization logic of SCB and weight_format.
scb = state_dict.pop(prefix + "SCB", None)
# Currently, we only support weight_format=0.
weight_format = state_dict.pop(prefix + "weight_format", None)
assert weight_format == 0
# TODO(ryand): Technically, we should be using `strict`, `missing_keys`, `unexpected_keys`, and `error_msgs`
# rather than raising an exception to correctly implement this API.
assert len(state_dict) == 0
if scb is not None:
# We are loading a pre-quantized state dict.
self.weight = InvokeInt8Params(
data=weight,
requires_grad=self.weight.requires_grad,
has_fp16_weights=False,
# Note: After quantization, CB is the same as weight.
CB=weight,
SCB=scb,
)
self.bias = bias if bias is None else torch.nn.Parameter(bias)
else:
# We are loading a non-quantized state dict.
# We could simply call the `super()._load_from_state_dict()` method here, but then we wouldn't be able to
# load from a state_dict into a model on the "meta" device. Attempting to load into a model on the "meta"
# device requires setting `assign=True`, doing this with the default `super()._load_from_state_dict()`
# implementation causes `Params4Bit` to be replaced by a `torch.nn.Parameter`. By initializing a new
# `Params4bit` object, we work around this issue. It's a bit hacky, but it gets the job done.
self.weight = InvokeInt8Params(
data=weight,
requires_grad=self.weight.requires_grad,
has_fp16_weights=False,
CB=None,
SCB=None,
)
self.bias = bias if bias is None else torch.nn.Parameter(bias)
# Reset the state. The persisted fields are based on the initialization behaviour in
# `bnb.nn.Linear8bitLt.__init__()`.
new_state = bnb.MatmulLtState()
new_state.threshold = self.state.threshold
new_state.has_fp16_weights = False
new_state.use_pool = self.state.use_pool
self.state = new_state
def _convert_linear_layers_to_llm_8bit(
module: torch.nn.Module, ignore_modules: set[str], outlier_threshold: float, prefix: str = ""
) -> None:
"""Convert all linear layers in the module to bnb.nn.Linear8bitLt layers."""
for name, child in module.named_children():
fullname = f"{prefix}.{name}" if prefix else name
if isinstance(child, torch.nn.Linear) and not any(fullname.startswith(s) for s in ignore_modules):
has_bias = child.bias is not None
replacement = InvokeLinear8bitLt(
child.in_features,
child.out_features,
bias=has_bias,
has_fp16_weights=False,
threshold=outlier_threshold,
)
replacement.weight.data = child.weight.data
if has_bias:
replacement.bias.data = child.bias.data
replacement.requires_grad_(False)
module.__setattr__(name, replacement)
else:
_convert_linear_layers_to_llm_8bit(
child, ignore_modules, outlier_threshold=outlier_threshold, prefix=fullname
)
def quantize_model_llm_int8(model: torch.nn.Module, modules_to_not_convert: set[str], outlier_threshold: float = 6.0):
"""Apply bitsandbytes LLM.8bit() quantization to the model."""
_convert_linear_layers_to_llm_8bit(
module=model, ignore_modules=modules_to_not_convert, outlier_threshold=outlier_threshold
)
return model

View File

@ -0,0 +1,156 @@
import bitsandbytes as bnb
import torch
# This file contains utils for working with models that use bitsandbytes NF4 quantization.
# The utils in this file are partially inspired by:
# https://github.com/Lightning-AI/pytorch-lightning/blob/1551a16b94f5234a4a78801098f64d0732ef5cb5/src/lightning/fabric/plugins/precision/bitsandbytes.py
# NOTE(ryand): All of the custom state_dict manipulation logic in this file is pretty hacky. This could be made much
# cleaner by re-implementing bnb.nn.LinearNF4 with proper use of buffers and less magic. But, for now, we try to stick
# close to the bitsandbytes classes to make interoperability easier with other models that might use bitsandbytes.
class InvokeLinearNF4(bnb.nn.LinearNF4):
"""A class that extends `bnb.nn.LinearNF4` to add the following functionality:
- Ability to load Linear NF4 layers from a pre-quantized state_dict.
- Ability to load Linear NF4 layers from a state_dict when the model is on the "meta" device.
"""
def _load_from_state_dict(
self,
state_dict: dict[str, torch.Tensor],
prefix: str,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
"""This method is based on the logic in the bitsandbytes serialization unit tests for `Linear4bit`:
https://github.com/bitsandbytes-foundation/bitsandbytes/blob/6d714a5cce3db5bd7f577bc447becc7a92d5ccc7/tests/test_linear4bit.py#L52-L71
"""
weight = state_dict.pop(prefix + "weight")
bias = state_dict.pop(prefix + "bias", None)
# We expect the remaining keys to be quant_state keys.
quant_state_sd = state_dict
# During serialization, the quant_state is stored as subkeys of "weight." (See
# `bnb.nn.LinearNF4._save_to_state_dict()`). We validate that they at least have the correct prefix.
# TODO(ryand): Technically, we should be using `strict`, `missing_keys`, `unexpected_keys`, and `error_msgs`
# rather than raising an exception to correctly implement this API.
assert all(k.startswith(prefix + "weight.") for k in quant_state_sd.keys())
if len(quant_state_sd) > 0:
# We are loading a pre-quantized state dict.
self.weight = bnb.nn.Params4bit.from_prequantized(
data=weight, quantized_stats=quant_state_sd, device=weight.device
)
self.bias = bias if bias is None else torch.nn.Parameter(bias, requires_grad=False)
else:
# We are loading a non-quantized state dict.
# We could simply call the `super()._load_from_state_dict()` method here, but then we wouldn't be able to
# load from a state_dict into a model on the "meta" device. Attempting to load into a model on the "meta"
# device requires setting `assign=True`, doing this with the default `super()._load_from_state_dict()`
# implementation causes `Params4Bit` to be replaced by a `torch.nn.Parameter`. By initializing a new
# `Params4bit` object, we work around this issue. It's a bit hacky, but it gets the job done.
self.weight = bnb.nn.Params4bit(
data=weight,
requires_grad=self.weight.requires_grad,
compress_statistics=self.weight.compress_statistics,
quant_type=self.weight.quant_type,
quant_storage=self.weight.quant_storage,
module=self,
)
self.bias = bias if bias is None else torch.nn.Parameter(bias)
def _replace_param(
param: torch.nn.Parameter | bnb.nn.Params4bit,
data: torch.Tensor,
) -> torch.nn.Parameter:
"""A helper function to replace the data of a model parameter with new data in a way that allows replacing params on
the "meta" device.
Supports both `torch.nn.Parameter` and `bnb.nn.Params4bit` parameters.
"""
if param.device.type == "meta":
# Doing `param.data = data` raises a RuntimeError if param.data was on the "meta" device, so we need to
# re-create the param instead of overwriting the data.
if isinstance(param, bnb.nn.Params4bit):
return bnb.nn.Params4bit(
data,
requires_grad=data.requires_grad,
quant_state=param.quant_state,
compress_statistics=param.compress_statistics,
quant_type=param.quant_type,
)
return torch.nn.Parameter(data, requires_grad=data.requires_grad)
param.data = data
return param
def _convert_linear_layers_to_nf4(
module: torch.nn.Module,
ignore_modules: set[str],
compute_dtype: torch.dtype,
compress_statistics: bool = False,
prefix: str = "",
) -> None:
"""Convert all linear layers in the model to NF4 quantized linear layers.
Args:
module: All linear layers in this module will be converted.
ignore_modules: A set of module prefixes to ignore when converting linear layers.
compute_dtype: The dtype to use for computation in the quantized linear layers.
compress_statistics: Whether to enable nested quantization (aka double quantization) where the quantization
constants from the first quantization are quantized again.
prefix: The prefix of the current module in the model. Used to call this function recursively.
"""
for name, child in module.named_children():
fullname = f"{prefix}.{name}" if prefix else name
if isinstance(child, torch.nn.Linear) and not any(fullname.startswith(s) for s in ignore_modules):
has_bias = child.bias is not None
replacement = InvokeLinearNF4(
child.in_features,
child.out_features,
bias=has_bias,
compute_dtype=compute_dtype,
compress_statistics=compress_statistics,
)
if has_bias:
replacement.bias = _replace_param(replacement.bias, child.bias.data)
replacement.weight = _replace_param(replacement.weight, child.weight.data)
replacement.requires_grad_(False)
module.__setattr__(name, replacement)
else:
_convert_linear_layers_to_nf4(child, ignore_modules, compute_dtype=compute_dtype, prefix=fullname)
def quantize_model_nf4(model: torch.nn.Module, modules_to_not_convert: set[str], compute_dtype: torch.dtype):
"""Apply bitsandbytes nf4 quantization to the model.
You likely want to call this function inside a `accelerate.init_empty_weights()` context.
Example usage:
```
# Initialize the model from a config on the meta device.
with accelerate.init_empty_weights():
model = ModelClass.from_config(...)
# Add NF4 quantization linear layers to the model - still on the meta device.
with accelerate.init_empty_weights():
model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.float16)
# Load a state_dict into the model. (Could be either a prequantized or non-quantized state_dict.)
model.load_state_dict(state_dict, strict=True, assign=True)
# Move the model to the "cuda" device. If the model was non-quantized, this is where the weight quantization takes
# place.
model.to("cuda")
```
"""
_convert_linear_layers_to_nf4(module=model, ignore_modules=modules_to_not_convert, compute_dtype=compute_dtype)
return model

View File

@ -0,0 +1,79 @@
from pathlib import Path
import accelerate
from safetensors.torch import load_file, save_file
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.util import params
from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8
from invokeai.backend.quantization.scripts.load_flux_model_bnb_nf4 import log_time
def main():
"""A script for quantizing a FLUX transformer model using the bitsandbytes LLM.int8() quantization method.
This script is primarily intended for reference. The script params (e.g. the model_path, modules_to_not_convert,
etc.) are hardcoded and would need to be modified for other use cases.
"""
# Load the FLUX transformer model onto the meta device.
model_path = Path(
"/data/invokeai/models/.download_cache/https__huggingface.co_black-forest-labs_flux.1-schnell_resolve_main_flux1-schnell.safetensors/flux1-schnell.safetensors"
)
with log_time("Intialize FLUX transformer on meta device"):
# TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
p = params["flux-schnell"]
# Initialize the model on the "meta" device.
with accelerate.init_empty_weights():
model = Flux(p)
# TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate
# `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize.
modules_to_not_convert: set[str] = set()
model_int8_path = model_path.parent / "bnb_llm_int8.safetensors"
if model_int8_path.exists():
# The quantized model already exists, load it and return it.
print(f"A pre-quantized model already exists at '{model_int8_path}'. Attempting to load it...")
# Replace the linear layers with LLM.int8() quantized linear layers (still on the meta device).
with log_time("Replace linear layers with LLM.int8() layers"), accelerate.init_empty_weights():
model = quantize_model_llm_int8(model, modules_to_not_convert=modules_to_not_convert)
with log_time("Load state dict into model"):
sd = load_file(model_int8_path)
model.load_state_dict(sd, strict=True, assign=True)
with log_time("Move model to cuda"):
model = model.to("cuda")
print(f"Successfully loaded pre-quantized model from '{model_int8_path}'.")
else:
# The quantized model does not exist, quantize the model and save it.
print(f"No pre-quantized model found at '{model_int8_path}'. Quantizing the model...")
with log_time("Replace linear layers with LLM.int8() layers"), accelerate.init_empty_weights():
model = quantize_model_llm_int8(model, modules_to_not_convert=modules_to_not_convert)
with log_time("Load state dict into model"):
state_dict = load_file(model_path)
# TODO(ryand): Cast the state_dict to the appropriate dtype?
model.load_state_dict(state_dict, strict=True, assign=True)
with log_time("Move model to cuda and quantize"):
model = model.to("cuda")
with log_time("Save quantized model"):
model_int8_path.parent.mkdir(parents=True, exist_ok=True)
save_file(model.state_dict(), model_int8_path)
print(f"Successfully quantized and saved model to '{model_int8_path}'.")
assert isinstance(model, Flux)
return model
if __name__ == "__main__":
main()

View File

@ -0,0 +1,96 @@
import time
from contextlib import contextmanager
from pathlib import Path
import accelerate
import torch
from safetensors.torch import load_file, save_file
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.util import params
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
@contextmanager
def log_time(name: str):
"""Helper context manager to log the time taken by a block of code."""
start = time.time()
try:
yield None
finally:
end = time.time()
print(f"'{name}' took {end - start:.4f} secs")
def main():
"""A script for quantizing a FLUX transformer model using the bitsandbytes NF4 quantization method.
This script is primarily intended for reference. The script params (e.g. the model_path, modules_to_not_convert,
etc.) are hardcoded and would need to be modified for other use cases.
"""
model_path = Path(
"/data/invokeai/models/.download_cache/https__huggingface.co_black-forest-labs_flux.1-schnell_resolve_main_flux1-schnell.safetensors/flux1-schnell.safetensors"
)
# inference_dtype = torch.bfloat16
with log_time("Intialize FLUX transformer on meta device"):
# TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
p = params["flux-schnell"]
# Initialize the model on the "meta" device.
with accelerate.init_empty_weights():
model = Flux(p)
# TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate
# `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize.
modules_to_not_convert: set[str] = set()
model_nf4_path = model_path.parent / "bnb_nf4.safetensors"
if model_nf4_path.exists():
# The quantized model already exists, load it and return it.
print(f"A pre-quantized model already exists at '{model_nf4_path}'. Attempting to load it...")
# Replace the linear layers with NF4 quantized linear layers (still on the meta device).
with log_time("Replace linear layers with NF4 layers"), accelerate.init_empty_weights():
model = quantize_model_nf4(
model, modules_to_not_convert=modules_to_not_convert, compute_dtype=torch.bfloat16
)
with log_time("Load state dict into model"):
state_dict = load_file(model_nf4_path)
model.load_state_dict(state_dict, strict=True, assign=True)
with log_time("Move model to cuda"):
model = model.to("cuda")
print(f"Successfully loaded pre-quantized model from '{model_nf4_path}'.")
else:
# The quantized model does not exist, quantize the model and save it.
print(f"No pre-quantized model found at '{model_nf4_path}'. Quantizing the model...")
with log_time("Replace linear layers with NF4 layers"), accelerate.init_empty_weights():
model = quantize_model_nf4(
model, modules_to_not_convert=modules_to_not_convert, compute_dtype=torch.bfloat16
)
with log_time("Load state dict into model"):
state_dict = load_file(model_path)
# TODO(ryand): Cast the state_dict to the appropriate dtype?
model.load_state_dict(state_dict, strict=True, assign=True)
with log_time("Move model to cuda and quantize"):
model = model.to("cuda")
with log_time("Save quantized model"):
model_nf4_path.parent.mkdir(parents=True, exist_ok=True)
save_file(model.state_dict(), model_nf4_path)
print(f"Successfully quantized and saved model to '{model_nf4_path}'.")
assert isinstance(model, Flux)
return model
if __name__ == "__main__":
main()

View File

@ -0,0 +1,92 @@
from pathlib import Path
import accelerate
from safetensors.torch import load_file, save_file
from transformers import AutoConfig, AutoModelForTextEncoding, T5EncoderModel
from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8
from invokeai.backend.quantization.scripts.load_flux_model_bnb_nf4 import log_time
def load_state_dict_into_t5(model: T5EncoderModel, state_dict: dict):
# There is a shared reference to a single weight tensor in the model.
# Both "encoder.embed_tokens.weight" and "shared.weight" refer to the same tensor, so only the latter should
# be present in the state_dict.
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False, assign=True)
assert len(unexpected_keys) == 0
assert set(missing_keys) == {"encoder.embed_tokens.weight"}
# Assert that the layers we expect to be shared are actually shared.
assert model.encoder.embed_tokens.weight is model.shared.weight
def main():
"""A script for quantizing a T5 text encoder model using the bitsandbytes LLM.int8() quantization method.
This script is primarily intended for reference. The script params (e.g. the model_path, modules_to_not_convert,
etc.) are hardcoded and would need to be modified for other use cases.
"""
model_path = Path("/data/misc/text_encoder_2")
with log_time("Intialize T5 on meta device"):
model_config = AutoConfig.from_pretrained(model_path)
with accelerate.init_empty_weights():
model = AutoModelForTextEncoding.from_config(model_config)
# TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate
# `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize.
modules_to_not_convert: set[str] = set()
model_int8_path = model_path / "bnb_llm_int8.safetensors"
if model_int8_path.exists():
# The quantized model already exists, load it and return it.
print(f"A pre-quantized model already exists at '{model_int8_path}'. Attempting to load it...")
# Replace the linear layers with LLM.int8() quantized linear layers (still on the meta device).
with log_time("Replace linear layers with LLM.int8() layers"), accelerate.init_empty_weights():
model = quantize_model_llm_int8(model, modules_to_not_convert=modules_to_not_convert)
with log_time("Load state dict into model"):
sd = load_file(model_int8_path)
load_state_dict_into_t5(model, sd)
with log_time("Move model to cuda"):
model = model.to("cuda")
print(f"Successfully loaded pre-quantized model from '{model_int8_path}'.")
else:
# The quantized model does not exist, quantize the model and save it.
print(f"No pre-quantized model found at '{model_int8_path}'. Quantizing the model...")
with log_time("Replace linear layers with LLM.int8() layers"), accelerate.init_empty_weights():
model = quantize_model_llm_int8(model, modules_to_not_convert=modules_to_not_convert)
with log_time("Load state dict into model"):
# Load sharded state dict.
files = list(model_path.glob("*.safetensors"))
state_dict = {}
for file in files:
sd = load_file(file)
state_dict.update(sd)
load_state_dict_into_t5(model, state_dict)
with log_time("Move model to cuda and quantize"):
model = model.to("cuda")
with log_time("Save quantized model"):
model_int8_path.parent.mkdir(parents=True, exist_ok=True)
state_dict = model.state_dict()
state_dict.pop("encoder.embed_tokens.weight")
save_file(state_dict, model_int8_path)
# This handling of shared weights could also be achieved with save_model(...), but then we'd lose control
# over which keys are kept. And, the corresponding load_model(...) function does not support assign=True.
# save_model(model, model_int8_path)
print(f"Successfully quantized and saved model to '{model_int8_path}'.")
assert isinstance(model, T5EncoderModel)
return model
if __name__ == "__main__":
main()

View File

@ -7,11 +7,9 @@ from invokeai.backend.stable_diffusion.diffusers_pipeline import ( # noqa: F401
StableDiffusionGeneratorPipeline,
)
from invokeai.backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent # noqa: F401
from invokeai.backend.stable_diffusion.seamless import set_seamless # noqa: F401
__all__ = [
"PipelineIntermediateState",
"StableDiffusionGeneratorPipeline",
"InvokeAIDiffuserComponent",
"set_seamless",
]

View File

@ -25,11 +25,6 @@ class BasicConditioningInfo:
return self
@dataclass
class ConditioningFieldData:
conditionings: List[BasicConditioningInfo]
@dataclass
class SDXLConditioningInfo(BasicConditioningInfo):
"""SDXL text conditioning information produced by Compel."""
@ -43,6 +38,22 @@ class SDXLConditioningInfo(BasicConditioningInfo):
return super().to(device=device, dtype=dtype)
@dataclass
class FLUXConditioningInfo:
clip_embeds: torch.Tensor
t5_embeds: torch.Tensor
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
self.clip_embeds = self.clip_embeds.to(device=device, dtype=dtype)
self.t5_embeds = self.t5_embeds.to(device=device, dtype=dtype)
return self
@dataclass
class ConditioningFieldData:
conditionings: List[BasicConditioningInfo] | List[SDXLConditioningInfo] | List[FLUXConditioningInfo]
@dataclass
class IPAdapterConditioningInfo:
cond_image_prompt_embeds: torch.Tensor

Some files were not shown because too many files have changed in this diff Show More