Compare commits

...

124 Commits

Author SHA1 Message Date
9afb5d6ace Update version to 3.0.2post1 2023-08-12 19:49:33 -04:00
50177b8ed9 Update frontend JS files 2023-08-12 19:49:33 -04:00
a67d8376c7 fix missed spot for autoAddBoardId none 2023-08-12 18:07:01 +10:00
0b11f309ca instead of crashing when a corrupted model is detected, warn and move on 2023-08-11 15:05:14 -04:00
6a8eb392b2 Add support for loading SDXL LoRA weights in diffusers format. 2023-08-11 14:40:22 -04:00
824ca92760 fix maximum python version instructions 2023-08-11 13:49:39 -04:00
80fd4c2176 undo lint changes 2023-08-11 14:26:09 +10:00
3b6e425e17 fix error detail in toast 2023-08-11 14:26:09 +10:00
50415450d8 invalidate board total when images deleted, only run date range logic if board has less than 20 images 2023-08-11 14:26:09 +10:00
06296896a9 Update invokeai version 2023-08-10 22:23:41 -04:00
a7399aca0c Add new JS files for 3.0.2 build 2023-08-10 22:23:41 -04:00
d1ea8b1e98 Two changes to command-line scripts (#4235)
During install testing I discovered two small problems in the
command-line scripts. These are fixed.

## What type of PR is this? (check all applicable)

- [X Bug Fix

## Have you discussed this change with the InvokeAI team?
- [X] Yes
- 
      
## Have you updated all relevant documentation?
- [X] Yes


## Description

- installer - use correct entry point for invokeai-configure
- model merge script - prevent error when `--root` not provided
2023-08-10 21:11:45 -04:00
f851ad7ba0 Two changes to command-line scripts
- installer - use correct entry point for invokeai-configure
- model merge script - prevent error when `--root` not provided
2023-08-10 20:59:22 -04:00
591838a84b Add support for LyCORIS IA3 format (#4234)
## What type of PR is this? (check all applicable)

- [ ] Refactor
- [x] Feature
- [ ] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission


## Have you discussed this change with the InvokeAI team?
- [ ] Yes
- [x] No

      
## Have you updated all relevant documentation?
- [ ] Yes
- [x] No


## Description
Add support for LyCORIS IA3 format

## Related Tickets & Documents
- Closes #4229 

## Added/updated tests?

- [ ] Yes
- [x] No
2023-08-11 03:30:35 +03:00
c0c2ab3dcf Format by black 2023-08-11 03:20:56 +03:00
56023bc725 Add support for LyCORIS IA3 format 2023-08-11 02:08:08 +03:00
2ef6a8995b Temporary force set vae to same precision as unet 2023-08-10 18:01:58 -04:00
d0fee93aac round slider values to nice numbers 2023-08-10 18:00:45 -04:00
1bfe9835cf clip cache settings to permissible values; remove redundant imports in install __init__ file 2023-08-10 18:00:45 -04:00
8e7eae6cc7 Probe LoRAs that do not have the text encoder (#4181)
## What type of PR is this? (check all applicable)

- [X] Bug Fix

## Have you discussed this change with the InvokeAI team?
- [X] No - minor fix

      
## Have you updated all relevant documentation?
- [X] Yes

## Description

It turns out that some LoRAs do not have the text encoder model, and
this was causing the code that distinguishes the model base type during
model import to reject them as having an unknown base model. This PR
enables detection of these cases.
2023-08-10 17:50:20 -04:00
f6522c8971 Merge branch 'main' into fix/detect-more-loras 2023-08-10 17:33:16 -04:00
a969707e45 prevent vae: '' from crashing model 2023-08-10 17:33:04 -04:00
6c8e898f09 Update scripts/verify_checkpoint_template.py
Co-authored-by: Eugene Brodsky <ebr@users.noreply.github.com>
2023-08-10 16:00:33 -04:00
7bad9bcf53 update dependencies and docs to cu118 2023-08-10 15:19:12 -04:00
d42b45116f fix(ui): fix lora sort (#4222)
## What type of PR is this? (check all applicable)

- [ ] Refactor
- [ ] Feature
- [s] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission


## Have you discussed this change with the InvokeAI team?
- [x] Yes
- [ ] No, because:

      

## Description

was sorting with disabled at top of list instead of bottom

fixes #4217

## Related Tickets & Documents

<!--
For pull requests that relate or close an issue, please include them
below. 

For example having the text: "closes #1234" would connect the current
pull
request to issue 1234.  And when we merge the pull request, Github will
automatically close the issue.
-->

- Related Issue #
- Closes #4217

## QA Instructions, Screenshots, Recordings

<!-- 
Please provide steps on how to test changes, any hardware or 
software specifications as well as any other pertinent information. 
-->

![image](https://github.com/invoke-ai/InvokeAI/assets/4822129/dd895b86-05de-4303-8674-9b181037abaa)
2023-08-10 21:04:28 +12:00
d4812bbc8d Merge branch 'main' into fix/ui/fix-lora-sort 2023-08-10 19:00:26 +10:00
3cd05cf6bf fix(ui): fix lora sort
was sorting with disabled at top of list instead of bottom

fixes #4217
2023-08-10 15:31:29 +10:00
2564301aeb fix(ui): fix canvas model switching (#4221)
## What type of PR is this? (check all applicable)

- [ ] Refactor
- [ ] Feature
- [x] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission


## Have you discussed this change with the InvokeAI team?
- [x] Yes
- [ ] No, because:

## Description

There was no check at all to see if the canvas had a valid model already
selected. The first model in the list was selected every time.

Now, we check if its valid. If not, we go through the logic to try and
pick the first valid model.

If there are no valid models, or there was a problem listing models, the
model selection is cleared.

## Related Tickets & Documents

<!--
For pull requests that relate or close an issue, please include them
below. 

For example having the text: "closes #1234" would connect the current
pull
request to issue 1234.  And when we merge the pull request, Github will
automatically close the issue.
-->


- Closes #4125

## QA Instructions, Screenshots, Recordings

<!-- 
Please provide steps on how to test changes, any hardware or 
software specifications as well as any other pertinent information. 
-->

- Go to Canvas tab
- Select a model other than the first one in the list
- Go to a different tab
- Go back to Canvas tab
- The model should be the same as you selected
2023-08-10 17:29:41 +12:00
da0efeaa7f fix(ui): fix canvas model switching
There was no check at all to see if the canvas had a valid model already selected. The first model in the list was selected every time.

Now, we check if its valid. If not, we go through the logic to try and pick the first valid model.

If there are no valid models, or there was a problem listing models, the model selection is cleared.
2023-08-10 15:20:37 +10:00
49cce1eec6 feat: add app_version to image metadata 2023-08-10 14:22:39 +10:00
c8fbaf54b6 Add self.min, not self.max 2023-08-10 09:59:22 +10:00
f86d388786 refactor(diffusers_pipeline): remove unused pipeline methods 🚮 (#4175) 2023-08-09 15:19:27 -07:00
cd2c688562 Merge branch 'main' into refactor/remove_unused_pipeline_methods 2023-08-09 17:26:09 -04:00
2d29ac6f0d Add techjedi's image import script (#4171)
## What type of PR is this? (check all applicable)

- [X ] Feature

## Have you discussed this change with the InvokeAI team?
- [X] Yes

## Have you updated all relevant documentation?
- [X] Yes

## Description

This PR adds the `invokeai-import-images` script, which imports a
directory of 2.*.* -generated images into the current InvokeAI root
directory, preserving and converting their metadata. The script also
handles 3.* images.

Many thanks to @techjedi for writing this. This version differs from the
original in two minor respects:

1. It is installed as an `invokeai-import-images` command.
2. The prompts for image and database paths use file completion provided
by the `prompt_toolkit` library.
## To Test

1. Activate the virtual environment for the destination root to import
INTO
2. Run `invokeai-import-images`
3. Follow the prompts

## Related Tickets & Documents

This is a frequently-requested feature on Discord, but I couldn't find
an Issue.

## QA Instructions, Screenshots, Recordings

<!-- 
Please provide steps on how to test changes, any hardware or 
software specifications as well as any other pertinent information. 
-->

## Added/updated tests?

- [ ] Yes
- [X] No : but should in the future
2023-08-09 13:17:08 -04:00
2c2b731386 fix typo 2023-08-09 13:08:59 -04:00
2f68a1a76c use Stalker's simplified LoRA vector-length detection code 2023-08-09 09:21:29 -04:00
930e7bc754 Merge branch 'main' into feat/image-import-script 2023-08-09 08:54:56 -04:00
7d4ace962a Merge branch 'main' into fix/detect-more-loras 2023-08-09 08:48:27 -04:00
06842f8e0a Update to 3.0.2rc1 2023-08-09 00:29:43 -04:00
c82da330db Pin safetensors to 0.3.1
Safetensors 0.3.2 does not ship an ARM64 wheel so install on macOS fails
2023-08-09 00:29:43 -04:00
628df4ec98 Add updated frontend html file 2023-08-09 00:29:43 -04:00
16b956616f Update version to 3.0.2 2023-08-09 00:29:43 -04:00
604cc17a3a Yarn build JS files 2023-08-09 00:29:43 -04:00
37c9b85549 Add slider for VRAM cache in configure script (#4133)
## What type of PR is this? (check all applicable)

- [X ] Feature

## Have you discussed this change with the InvokeAI team?
- [X] Yes
- [ ] No, because:

      
## Have you updated all relevant documentation?
- [ ] Yes
- [X] No - will be in release notes

## Description

On CUDA systems, this PR adds a new slider to the install-time configure
script for adjusting the VRAM cache and suggests a good starting value
based on the user's max VRAM (this is subject to verification).

On non-CUDA systems this slider is suppressed.

Please test on both CUDA and non-CUDA systems using:
```
invokeai-configure --root ~/invokeai-main/ --skip-sd --skip-support
```

To see and test the default values, move `invokeai.yaml` out of the way
before running.

**Note added 8 August 2023**

This PR also fixes the configure and model install scripts so that if
the window is too small to fit the user interface, the user will be
prompted to interactively resize the window and/or change font size
(with the option to give up). This will prevent `npyscreen` from
generating its horrible tracebacks.

## Related Tickets & Documents

<!--
For pull requests that relate or close an issue, please include them
below. 

For example having the text: "closes #1234" would connect the current
pull
request to issue 1234.  And when we merge the pull request, Github will
automatically close the issue.
-->

- Related Issue #
- Closes #

## QA Instructions, Screenshots, Recordings

<!-- 
Please provide steps on how to test changes, any hardware or 
software specifications as well as any other pertinent information. 
-->

## Added/updated tests?

- [ ] Yes
- [ ] No : _please replace this line with details on why tests
      have not been included_

## [optional] Are there any post deployment tasks we need to perform?
2023-08-09 12:27:54 +10:00
8b39b67ec7 Merge branch 'main' into feat/select-vram-in-config 2023-08-09 12:17:27 +10:00
a933977861 Pick correct config file for sdxl models (#4191)
## What type of PR is this? (check all applicable)

- [X] Bug Fix

## Have you discussed this change with the InvokeAI team?
- [X] Yes
- [ ] No, because:

      
## Have you updated all relevant documentation?
- [X Yes
- [ ] No


## Description

If `models.yaml` is cleared out for some reason, the model manager will
repopulate it by scanning `models`. However, this would fail with a
pydantic validation error if any SDXL checkpoint models were present
because the lack of logic to pick the correct configuration file. This
has now been added.
2023-08-09 11:16:48 +10:00
dfb41d8461 Merge branch 'main' into bugfix/autodetect-sdxl-ckpt-config 2023-08-09 03:57:44 +03:00
4d5169e16d Merge branch 'main' into feat/select-vram-in-config 2023-08-08 13:50:02 -04:00
f56f19710d allow user to interactively resize screen before UI runs 2023-08-08 12:27:25 -04:00
e77400ab62 remove deprecated options from config 2023-08-08 08:33:30 -07:00
13347f6aec blackified 2023-08-08 08:33:30 -07:00
a9bf387e5e turned on Pydantic validate_assignment 2023-08-08 08:33:30 -07:00
8258c87a9f refrain from writing deprecated legacy options to invokeai.yaml 2023-08-08 08:33:30 -07:00
1b1b399fd0 Fix crash when attempting to update a model (#4192)
## What type of PR is this? (check all applicable)

- [X] Bug Fix


## Have you discussed this change with the InvokeAI team?
- [X No, because small fix

      
## Have you updated all relevant documentation?
- [X] Yes

## Description

A logic bug was introduced in PR #4109 that caused Web-based model
updates to fail with a pydantic validation error. This corrects the
problem.

## Related Tickets & Documents

PR #4109
2023-08-08 10:54:27 -04:00
a8d3e078c0 Merge branch 'main' into fix/detect-more-loras 2023-08-08 10:42:45 -04:00
6ed7ba57dd Merge branch 'main' into bugfix/fix-model-updates 2023-08-08 09:05:25 -04:00
2b3b77a276 api(images): allow HEAD request on image/full (#4193) 2023-08-08 00:08:48 -07:00
8b8ec68b30 Merge branch 'main' into feat/image_http_head 2023-08-08 00:02:48 -07:00
e20af5aef0 feat(ui): add LoRA support to SDXL linear UI
new graph modifier `addSDXLLoRasToGraph()` handles adding LoRA to the SDXL t2i and i2i graphs.
2023-08-08 15:02:00 +10:00
57e8ec9488 chore(ui): lint/format 2023-08-08 12:53:47 +10:00
734a9e4271 invalidate board total when images deleted, only run date range logic if board has less than 20 images 2023-08-08 12:53:47 +10:00
fe924daee3 add option to disable multiselect 2023-08-08 12:53:47 +10:00
750f09fbed blackify 2023-08-07 21:01:59 -04:00
4df581811e add template verification script 2023-08-07 21:01:48 -04:00
eb70bc2ae4 add scripts to create model templates and check whether they match 2023-08-07 21:00:47 -04:00
809705c30d api(images): allow HEAD request on image/full 2023-08-07 15:11:47 -07:00
f0918edf98 improve error reporting on unrecognized lora models 2023-08-07 16:38:58 -04:00
a846d82fa1 Add techedi code to avoid rendering prompt/seed with null
- Added techjedi github and real names
2023-08-07 16:29:46 -04:00
22f7cf0638 add stalker's complicated but effective code for finding token vector length in LoRAs 2023-08-07 16:19:57 -04:00
25c669b1d6 Merge remote-tracking branch 'origin/main' into refactor/remove_unused_pipeline_methods 2023-08-07 13:03:10 -07:00
4367061b19 fix(ModelManager): fix overridden VAE with relative path (#4059) 2023-08-07 12:57:32 -07:00
0fd13d3604 Merge branch 'main' into feat/select-vram-in-config 2023-08-07 15:51:59 -04:00
72a3e776b2 fix logic error introduced in PR 4109 2023-08-07 15:38:22 -04:00
af044007d5 pick correct config file for sdxl models 2023-08-07 15:19:49 -04:00
f272a44feb Merge branch 'main' into refactor/model_manager_instantiate 2023-08-07 10:59:28 -07:00
8469d3e95a chore: black 2023-08-07 10:05:52 +10:00
ae17d01e1d Fix hue adjustment (#4182)
* Fix hue adjustment

Hue adjustment wasn't working correctly because color channels got swapped. This has now been fixed and we're using PIL rather than cv2 to do the RGBA->HSV->RGBA conversion. The range of hue adjustment is also the more typical 0..360 degrees.
2023-08-06 23:23:51 +00:00
f3d3316558 probe LoRAs that do not have the text encoder 2023-08-06 16:00:53 -04:00
5a6cefb0ea add backslash to end of incomplete windows paths 2023-08-06 12:34:35 -04:00
1a6f5f0860 use backslash on Windows systems for autoadded delimiter 2023-08-06 12:29:31 -04:00
5bfd6cb66f Merge remote-tracking branch 'origin/main' into refactor/model_manager_instantiate
# Conflicts:
#	invokeai/backend/model_management/model_manager.py
2023-08-05 22:02:28 -07:00
59caff7ff0 refactor(diffusers_pipeline): remove unused img2img wrappers 🚮
invokeai.app no longer needs this as a single method, as it builds on latents2latents instead.
2023-08-05 21:50:52 -07:00
6487e7d906 refactor(diffusers_pipeline): remove unused ModelGroup 🚮
orphaned since #3550 removed the LazilyLoadedModelGroup code, probably unused since ModelCache took over responsibility for sequential offload somewhere around #3335.
2023-08-05 21:50:52 -07:00
77033eabd3 refactor(diffusers_pipeline): remove unused precision 🚮 2023-08-05 21:50:52 -07:00
b80abdd101 refactor(diffusers_pipeline): remove unused image_from_embeddings 🚮 2023-08-05 21:50:52 -07:00
006d782cc8 refactor(diffusers_pipeline): tidy imports 🚮 2023-08-05 21:50:52 -07:00
7f4c387080 test(model_management): factor out name strings 2023-08-05 15:46:46 -07:00
80876bbbd1 Merge remote-tracking branch 'origin/refactor/model_manager_instantiate' into refactor/model_manager_instantiate 2023-08-05 15:25:05 -07:00
7a4ff4c089 Merge branch 'main' into refactor/model_manager_instantiate 2023-08-05 15:23:38 -07:00
44bf308192 test(model_management): add a couple tests for _get_model_path 2023-08-05 15:22:23 -07:00
12e51c84ae blackified 2023-08-05 14:26:16 -07:00
b2eb83deff add docs 2023-08-05 14:26:16 -07:00
0ccc3b509e add techjedi's import script, with some filecompletion tweaks 2023-08-05 14:26:16 -07:00
4043a4c21c blackified 2023-08-05 12:44:58 -04:00
c8ceb96091 add docs 2023-08-05 12:26:52 -04:00
83f75750a9 add techjedi's import script, with some filecompletion tweaks 2023-08-05 12:19:24 -04:00
65ed224bfc Merge branch 'main' into refactor/model_manager_instantiate 2023-08-04 21:34:38 -07:00
da96a41103 Merge branch 'main' into feat/select-vram-in-config 2023-08-05 12:11:50 +10:00
b10cf20eb1 Merge branch 'main' into refactor/model_manager_instantiate
# Conflicts:
#	invokeai/backend/model_management/model_manager.py
2023-08-04 18:28:18 -07:00
1deca89fde Merge branch 'main' into feat/select-vram-in-config 2023-08-03 19:27:58 -04:00
91ebf9f76e Merge branch 'main' into refactor/model_manager_instantiate 2023-08-02 19:01:21 -07:00
02d2cc758d Merge branch 'main' into refactor/model_manager_instantiate 2023-08-02 17:11:23 -07:00
6bc21984c6 Merge branch 'main' into feat/select-vram-in-config 2023-08-02 19:12:43 -04:00
ec48779080 blackify 2023-08-02 14:28:19 -04:00
bc20fe4cb5 Merge branch 'main' into feat/select-vram-in-config 2023-08-02 14:27:17 -04:00
5de42be4a6 reduce VRAM cache default; take max RAM from system 2023-08-02 14:27:13 -04:00
29ac252501 blackify 2023-08-02 09:44:06 -04:00
880727436c fix default vram cache size calculation 2023-08-02 09:43:52 -04:00
77c5c18542 add slider for VRAM cache 2023-08-02 09:11:24 -04:00
1f9e984b0d Merge branch 'main' into refactor/model_manager_instantiate 2023-08-01 16:49:39 -07:00
5998509888 Merge branch 'main' into refactor/model_manager_instantiate 2023-08-01 11:09:43 -07:00
bacdf985f1 doc(model_manager): docstrings 2023-07-31 09:16:32 -07:00
e3519052ae Merge remote-tracking branch 'origin/main' into refactor/model_manager_instantiate 2023-07-31 08:46:09 -07:00
adfd1e52f4 refactor(model_manager): avoid copy/paste logic 2023-07-30 11:53:12 -07:00
0e48c98330 Merge remote-tracking branch 'origin/main' into refactor/model_manager_instantiate
# Conflicts:
#	invokeai/backend/model_management/model_manager.py
2023-07-30 11:33:13 -07:00
ff1c40747e lint: formatting 2023-07-29 20:02:31 -07:00
dbfd1bcb5e Merge branch 'main' into refactor/model_manager_instantiate 2023-07-29 19:53:21 -07:00
ccceb32a85 lint: formatting 2023-07-29 11:50:04 -07:00
21617e60e1 Merge remote-tracking branch 'origin/main' into refactor/model_manager_instantiate 2023-07-29 08:21:26 -07:00
86b8b69e88 internal(ModelManager): add instantiate method 2023-07-28 22:30:25 -07:00
bc9a5038fd refactor(ModelManager): factor out get_model_path 2023-07-28 22:29:36 -07:00
b163ae6a4d refactor(ModelManager): factor out get_model_config 2023-07-28 21:30:20 -07:00
dca685ac25 refactor(ModelManager): refactor rescan-on-miss to exists() method 2023-07-28 21:11:00 -07:00
e70bedba7d refactor(ModelManager): factor out _get_implementation method 2023-07-28 21:03:27 -07:00
65 changed files with 2239 additions and 1098 deletions

View File

@ -161,7 +161,7 @@ the command `npm install -g yarn` if needed)
_For Windows/Linux with an NVIDIA GPU:_
```terminal
pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117
pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118
```
_For Linux with an AMD GPU:_
@ -306,13 +306,30 @@ InvokeAI. The second will prepare the 2.3 directory for use with 3.0.
You may now launch the WebUI in the usual way, by selecting option [1]
from the launcher script
#### Migration Caveats
#### Migrating Images
The migration script will migrate your invokeai settings and models,
including textual inversion models, LoRAs and merges that you may have
installed previously. However it does **not** migrate the generated
images stored in your 2.3-format outputs directory. You will need to
manually import selected images into the 3.0 gallery via drag-and-drop.
images stored in your 2.3-format outputs directory. To do this, you
need to run an additional step:
1. From a working InvokeAI 3.0 root directory, start the launcher and
enter menu option [8] to open the "developer's console".
2. At the developer's console command line, type the command:
```bash
invokeai-import-images
```
3. This will lead you through the process of confirming the desired
source and destination for the imported images. The images will
appear in the gallery board of your choice, and contain the
original prompt, model name, and other parameters used to generate
the image.
(Many kudos to **techjedi** for contributing this script.)
## Hardware Requirements

View File

@ -264,7 +264,7 @@ experimental versions later.
you can create several levels of subfolders and drop your models into
whichever ones you want.
- ***Autoimport FolderLICENSE***
- ***LICENSE***
At the bottom of the screen you will see a checkbox for accepting
the CreativeML Responsible AI Licenses. You need to accept the license
@ -471,7 +471,7 @@ Then type the following commands:
=== "NVIDIA System"
```bash
pip install torch torchvision --force-reinstall --extra-index-url https://download.pytorch.org/whl/cu117
pip install torch torchvision --force-reinstall --extra-index-url https://download.pytorch.org/whl/cu118
pip install xformers
```

View File

@ -148,7 +148,7 @@ manager, please follow these steps:
=== "CUDA (NVidia)"
```bash
pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117
pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118
```
=== "ROCm (AMD)"
@ -312,7 +312,7 @@ installation protocol (important!)
=== "CUDA (NVidia)"
```bash
pip install -e .[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117
pip install -e .[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118
```
=== "ROCm (AMD)"
@ -356,7 +356,7 @@ you can do so using this unsupported recipe:
mkdir ~/invokeai
conda create -n invokeai python=3.10
conda activate invokeai
pip install InvokeAI[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117
pip install InvokeAI[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118
invokeai-configure --root ~/invokeai
invokeai --root ~/invokeai --web
```

View File

@ -34,11 +34,11 @@ directly from NVIDIA. **Do not try to install Ubuntu's
nvidia-cuda-toolkit package. It is out of date and will cause
conflicts among the NVIDIA driver and binaries.**
Go to [CUDA Toolkit 11.7
Downloads](https://developer.nvidia.com/cuda-11-7-0-download-archive),
and use the target selection wizard to choose your operating system,
hardware platform, and preferred installation method (e.g. "local"
versus "network").
Go to [CUDA Toolkit
Downloads](https://developer.nvidia.com/cuda-downloads), and use the
target selection wizard to choose your operating system, hardware
platform, and preferred installation method (e.g. "local" versus
"network").
This will provide you with a downloadable install file or, depending
on your choices, a recipe for downloading and running a install shell
@ -61,7 +61,7 @@ Runtime Site](https://developer.nvidia.com/nvidia-container-runtime)
When installing torch and torchvision manually with `pip`, remember to provide
the argument `--extra-index-url
https://download.pytorch.org/whl/cu117` as described in the [Manual
https://download.pytorch.org/whl/cu118` as described in the [Manual
Installation Guide](020_INSTALL_MANUAL.md).
## :simple-amd: ROCm

View File

@ -28,18 +28,21 @@ command line, then just be sure to activate it's virtual environment.
Then run the following three commands:
```sh
pip install xformers==0.0.16rc425
pip install triton
pip install xformers~=0.0.19
pip install triton # WON'T WORK ON WINDOWS
python -m xformers.info output
```
The first command installs `xformers`, the second installs the
`triton` training accelerator, and the third prints out the `xformers`
installation status. If all goes well, you'll see a report like the
installation status. On Windows, please omit the `triton` package,
which is not available on that platform.
If all goes well, you'll see a report like the
following:
```sh
xFormers 0.0.16rc425
xFormers 0.0.20
memory_efficient_attention.cutlassF: available
memory_efficient_attention.cutlassB: available
memory_efficient_attention.flshattF: available
@ -48,22 +51,28 @@ memory_efficient_attention.smallkF: available
memory_efficient_attention.smallkB: available
memory_efficient_attention.tritonflashattF: available
memory_efficient_attention.tritonflashattB: available
indexing.scaled_index_addF: available
indexing.scaled_index_addB: available
indexing.index_select: available
swiglu.dual_gemm_silu: available
swiglu.gemm_fused_operand_sum: available
swiglu.fused.p.cpp: available
is_triton_available: True
is_functorch_available: False
pytorch.version: 1.13.1+cu117
pytorch.version: 2.0.1+cu118
pytorch.cuda: available
gpu.compute_capability: 8.6
gpu.name: NVIDIA RTX A2000 12GB
gpu.compute_capability: 8.9
gpu.name: NVIDIA GeForce RTX 4070
build.info: available
build.cuda_version: 1107
build.python_version: 3.10.9
build.torch_version: 1.13.1+cu117
build.cuda_version: 1108
build.python_version: 3.10.11
build.torch_version: 2.0.1+cu118
build.env.TORCH_CUDA_ARCH_LIST: 5.0+PTX 6.0 6.1 7.0 7.5 8.0 8.6
build.env.XFORMERS_BUILD_TYPE: Release
build.env.XFORMERS_ENABLE_DEBUG_ASSERTIONS: None
build.env.NVCC_FLAGS: None
build.env.XFORMERS_PACKAGE_FROM: wheel-v0.0.16rc425
build.env.XFORMERS_PACKAGE_FROM: wheel-v0.0.20
build.nvcc_version: 11.8.89
source.privacy: open source
```
@ -83,14 +92,14 @@ installed from source. These instructions were written for a system
running Ubuntu 22.04, but other Linux distributions should be able to
adapt this recipe.
#### 1. Install CUDA Toolkit 11.7
#### 1. Install CUDA Toolkit 11.8
You will need the CUDA developer's toolkit in order to compile and
install xFormers. **Do not try to install Ubuntu's nvidia-cuda-toolkit
package.** It is out of date and will cause conflicts among the NVIDIA
driver and binaries. Instead install the CUDA Toolkit package provided
by NVIDIA itself. Go to [CUDA Toolkit 11.7
Downloads](https://developer.nvidia.com/cuda-11-7-0-download-archive)
by NVIDIA itself. Go to [CUDA Toolkit 11.8
Downloads](https://developer.nvidia.com/cuda-11-8-0-download-archive)
and use the target selection wizard to choose your platform and Linux
distribution. Select an installer type of "runfile (local)" at the
last step.
@ -101,17 +110,17 @@ example, the install script recipe for Ubuntu 22.04 running on a
x86_64 system is:
```
wget https://developer.download.nvidia.com/compute/cuda/11.7.0/local_installers/cuda_11.7.0_515.43.04_linux.run
sudo sh cuda_11.7.0_515.43.04_linux.run
wget https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run
sudo sh cuda_11.8.0_520.61.05_linux.run
```
Rather than cut-and-paste this example, We recommend that you walk
through the toolkit wizard in order to get the most up to date
installer for your system.
#### 2. Confirm/Install pyTorch 1.13 with CUDA 11.7 support
#### 2. Confirm/Install pyTorch 2.01 with CUDA 11.8 support
If you are using InvokeAI 2.3 or higher, these will already be
If you are using InvokeAI 3.0.2 or higher, these will already be
installed. If not, you can check whether you have the needed libraries
using a quick command. Activate the invokeai virtual environment,
either by entering the "developer's console", or manually with a
@ -124,7 +133,7 @@ Then run the command:
python -c 'exec("import torch\nprint(torch.__version__)")'
```
If it prints __1.13.1+cu117__ you're good. If not, you can install the
If it prints __1.13.1+cu118__ you're good. If not, you can install the
most up to date libraries with this command:
```sh

View File

@ -348,7 +348,7 @@ class InvokeAiInstance:
introduction()
from invokeai.frontend.install import invokeai_configure
from invokeai.frontend.install.invokeai_configure import invokeai_configure
# NOTE: currently the config script does its own arg parsing! this means the command-line switches
# from the installer will also automatically propagate down to the config script.
@ -463,10 +463,10 @@ def get_torch_source() -> (Union[str, None], str):
url = "https://download.pytorch.org/whl/cpu"
if device == "cuda":
url = "https://download.pytorch.org/whl/cu117"
url = "https://download.pytorch.org/whl/cu118"
optional_modules = "[xformers,onnx-cuda]"
if device == "cuda_and_dml":
url = "https://download.pytorch.org/whl/cu117"
url = "https://download.pytorch.org/whl/cu118"
optional_modules = "[xformers,onnx-directml]"
# in all other cases, Torch wheels should be coming from PyPi as of Torch 1.13

View File

@ -8,16 +8,13 @@ Preparations:
to work. Instructions are given here:
https://invoke-ai.github.io/InvokeAI/installation/INSTALL_AUTOMATED/
NOTE: At this time we do not recommend Python 3.11. We recommend
Version 3.10.9, which has been extensively tested with InvokeAI.
Before you start the installer, please open up your system's command
line window (Terminal or Command) and type the commands:
python --version
If all is well, it will print "Python 3.X.X", where the version number
is at least 3.9.1, and less than 3.11.
is at least 3.9.*, and not higher than 3.11.*.
If this works, check the version of the Python package manager, pip:

View File

@ -1,22 +1,20 @@
import io
from typing import Optional
from PIL import Image
from fastapi import Body, HTTPException, Path, Query, Request, Response, UploadFile
from fastapi.responses import FileResponse
from fastapi.routing import APIRouter
from PIL import Image
from pydantic import BaseModel, Field
from pydantic import BaseModel
from invokeai.app.invocations.metadata import ImageMetadata
from invokeai.app.models.image import ImageCategory, ResourceOrigin
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
from invokeai.app.services.item_storage import PaginatedResults
from invokeai.app.services.models.image_record import (
ImageDTO,
ImageRecordChanges,
ImageUrlsDTO,
)
from ..dependencies import ApiDependencies
images_router = APIRouter(prefix="/v1/images", tags=["images"])
@ -152,8 +150,9 @@ async def get_image_metadata(
raise HTTPException(status_code=404)
@images_router.get(
@images_router.api_route(
"/i/{image_name}/full",
methods=["GET", "HEAD"],
operation_id="get_image_full",
response_class=Response,
responses={

View File

@ -104,8 +104,12 @@ async def update_model(
): # model manager moved model path during rename - don't overwrite it
info.path = new_info.get("path")
# replace empty string values with None/null to avoid phenomenon of vae: ''
info_dict = info.dict()
info_dict = {x: info_dict[x] if info_dict[x] else None for x in info_dict.keys()}
ApiDependencies.invoker.services.model_manager.update_model(
model_name=model_name, base_model=base_model, model_type=model_type, model_attributes=info.dict()
model_name=model_name, base_model=base_model, model_type=model_type, model_attributes=info_dict
)
model_raw = ApiDependencies.invoker.services.model_manager.list_model(

View File

@ -1,26 +1,23 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from contextlib import contextmanager, ContextDecorator
from functools import partial
from typing import Literal, Optional, get_args
import torch
from pydantic import Field
from invokeai.app.models.image import ColorField, ImageCategory, ImageField, ResourceOrigin
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.backend.generator.inpaint import infill_methods
from ...backend.generator import Inpaint, InvokeAIGenerator
from ...backend.stable_diffusion import PipelineIntermediateState
from ..util.step_callback import stable_diffusion_step_callback
from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext
from .image import ImageOutput
from ...backend.model_management.lora import ModelPatcher
from ...backend.stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
from .model import UNetField, VaeField
from .compel import ConditioningField
from contextlib import contextmanager, ExitStack, ContextDecorator
from .image import ImageOutput
from .model import UNetField, VaeField
from ..util.step_callback import stable_diffusion_step_callback
from ...backend.generator import Inpaint, InvokeAIGenerator
from ...backend.model_management.lora import ModelPatcher
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
INFILL_METHODS = Literal[tuple(infill_methods())]
@ -184,6 +181,8 @@ class InpaintInvocation(BaseInvocation):
device = context.services.model_manager.mgr.cache.execution_device
dtype = context.services.model_manager.mgr.cache.precision
vae.to(dtype=unet.dtype)
pipeline = StableDiffusionGeneratorPipeline(
vae=vae,
text_encoder=None,
@ -193,8 +192,6 @@ class InpaintInvocation(BaseInvocation):
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
precision="float16" if dtype == torch.float16 else "float32",
execution_device=device,
)
yield OldModelInfo(

View File

@ -501,7 +501,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
image = context.services.images.get_pil_image(self.image.image_name)
image_arr = numpy.asarray(image, dtype=numpy.float32) / 255
image_arr = image_arr * (self.max - self.min) + self.max
image_arr = image_arr * (self.max - self.min) + self.min
lerp_image = Image.fromarray(numpy.uint8(image_arr))
@ -661,27 +661,23 @@ class ImageHueAdjustmentInvocation(BaseInvocation):
# Inputs
image: ImageField = Field(default=None, description="The image to adjust")
hue: int = Field(default=0, description="The degrees by which to rotate the hue")
hue: int = Field(default=0, description="The degrees by which to rotate the hue, 0-360")
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
pil_image = context.services.images.get_pil_image(self.image.image_name)
# Convert PIL image to OpenCV format (numpy array), note color channel
# ordering is changed from RGB to BGR
image = numpy.array(pil_image.convert("RGB"))[:, :, ::-1]
# Convert image to HSV color space
hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
hsv_image = numpy.array(pil_image.convert("HSV"))
# Adjust the hue
hsv_image[:, :, 0] = (hsv_image[:, :, 0] + self.hue) % 180
# Convert hue from 0..360 to 0..256
hue = int(256 * ((self.hue % 360) / 360))
# Convert image back to BGR color space
image = cv2.cvtColor(hsv_image, cv2.COLOR_HSV2BGR)
# Increment each hue and wrap around at 255
hsv_image[:, :, 0] = (hsv_image[:, :, 0] + hue) % 256
# Convert back to PIL format and to original color mode
pil_image = Image.fromarray(image[:, :, ::-1], "RGB").convert("RGBA")
pil_image = Image.fromarray(hsv_image, mode="HSV").convert("RGBA")
image_dto = context.services.images.create(
image=pil_image,

View File

@ -5,15 +5,26 @@ from typing import List, Literal, Optional, Union
import einops
import torch
from diffusers import ControlNetModel
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.attention_processor import (
AttnProcessor2_0,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
)
from diffusers.schedulers import SchedulerMixin as Scheduler
from pydantic import BaseModel, Field, validator
from invokeai.app.invocations.metadata import CoreMetadata
from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
from .compel import ConditioningField
from .controlnet_image_processors import ControlField
from .image import ImageOutput
from .model import ModelInfo, UNetField, VaeField
from ..models.image import ImageCategory, ImageField, ResourceOrigin
from ...backend.model_management import ModelPatcher
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.stable_diffusion.diffusers_pipeline import (
@ -24,23 +35,7 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
)
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from ...backend.model_management import ModelPatcher
from ...backend.util.devices import choose_torch_device, torch_dtype, choose_precision
from ..models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
from .compel import ConditioningField
from .controlnet_image_processors import ControlField
from .image import ImageOutput
from .model import ModelInfo, UNetField, VaeField
from invokeai.app.util.controlnet_utils import prepare_control_image
from diffusers.models.attention_processor import (
AttnProcessor2_0,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
)
DEFAULT_PRECISION = choose_precision(choose_torch_device())
@ -231,7 +226,6 @@ class TextToLatentsInvocation(BaseInvocation):
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
precision="float16" if unet.dtype == torch.float16 else "float32",
)
def prep_control_data(

View File

@ -2,6 +2,7 @@ from typing import Literal, Optional, Union
from pydantic import Field
from ...version import __version__
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
@ -23,6 +24,7 @@ class LoRAMetadataField(BaseModelExcludeNull):
class CoreMetadata(BaseModelExcludeNull):
"""Core generation metadata for an image generated in InvokeAI."""
app_version: str = Field(default=__version__, description="The version of InvokeAI used to generate this image")
generation_mode: str = Field(
description="The generation mode that output this image",
)

View File

@ -24,11 +24,10 @@ InvokeAI:
sequential_guidance: false
precision: float16
max_cache_size: 6
max_vram_cache_size: 2.7
max_vram_cache_size: 0.5
always_use_cpu: false
free_gpu_mem: false
Features:
restore: true
esrgan: true
patchmatch: true
internet_available: true
@ -165,7 +164,7 @@ import pydoc
import os
import sys
from argparse import ArgumentParser
from omegaconf import OmegaConf, DictConfig
from omegaconf import OmegaConf, DictConfig, ListConfig
from pathlib import Path
from pydantic import BaseSettings, Field, parse_obj_as
from typing import ClassVar, Dict, List, Set, Literal, Union, get_origin, get_type_hints, get_args
@ -173,6 +172,7 @@ from typing import ClassVar, Dict, List, Set, Literal, Union, get_origin, get_ty
INIT_FILE = Path("invokeai.yaml")
DB_FILE = Path("invokeai.db")
LEGACY_INIT_FILE = Path("invokeai.init")
DEFAULT_MAX_VRAM = 0.5
class InvokeAISettings(BaseSettings):
@ -189,7 +189,12 @@ class InvokeAISettings(BaseSettings):
opt = parser.parse_args(argv)
for name in self.__fields__:
if name not in self._excluded():
setattr(self, name, getattr(opt, name))
value = getattr(opt, name)
if isinstance(value, ListConfig):
value = list(value)
elif isinstance(value, DictConfig):
value = dict(value)
setattr(self, name, value)
def to_yaml(self) -> str:
"""
@ -282,14 +287,10 @@ class InvokeAISettings(BaseSettings):
return [
"type",
"initconf",
"gpu_mem_reserved",
"max_loaded_models",
"version",
"from_file",
"model",
"restore",
"root",
"nsfw_checker",
]
class Config:
@ -388,15 +389,11 @@ class InvokeAIAppConfig(InvokeAISettings):
internet_available : bool = Field(default=True, description="If true, attempt to download models on the fly; otherwise only use local models", category='Features')
log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", category='Features')
patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", category='Features')
restore : bool = Field(default=True, description="Enable/disable face restoration code (DEPRECATED)", category='DEPRECATED')
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance')
max_loaded_models : int = Field(default=3, gt=0, description="(DEPRECATED: use max_cache_size) Maximum number of models to keep in memory for rapid switching", category='DEPRECATED')
max_cache_size : float = Field(default=6.0, gt=0, description="Maximum memory amount used by model cache for rapid switching", category='Memory/Performance')
max_vram_cache_size : float = Field(default=2.75, ge=0, description="Amount of VRAM reserved for model storage", category='Memory/Performance')
gpu_mem_reserved : float = Field(default=2.75, ge=0, description="DEPRECATED: use max_vram_cache_size. Amount of VRAM reserved for model storage", category='DEPRECATED')
nsfw_checker : bool = Field(default=True, description="DEPRECATED: use Web settings to enable/disable", category='DEPRECATED')
precision : Literal[tuple(['auto','float16','float32','autocast'])] = Field(default='auto',description='Floating point precision', category='Memory/Performance')
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance')
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')
@ -414,9 +411,7 @@ class InvokeAIAppConfig(InvokeAISettings):
outdir : Path = Field(default='outputs', description='Default folder for output images', category='Paths')
from_file : Path = Field(default=None, description='Take command input from the indicated file (command-line client only)', category='Paths')
use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', category='Paths')
ignore_missing_core_models : bool = Field(default=False, description='Ignore missing models in models/core/convert')
model : str = Field(default='stable-diffusion-1.5', description='Initial model name', category='Models')
ignore_missing_core_models : bool = Field(default=False, description='Ignore missing models in models/core/convert', category='Features')
log_handlers : List[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>"', category="Logging")
# note - would be better to read the log_format values from logging.py, but this creates circular dependencies issues
@ -426,6 +421,9 @@ class InvokeAIAppConfig(InvokeAISettings):
version : bool = Field(default=False, description="Show InvokeAI version and exit", category="Other")
# fmt: on
class Config:
validate_assignment = True
def parse_args(self, argv: List[str] = None, conf: DictConfig = None, clobber=False):
"""
Update settings with contents of init file, environment, and

View File

@ -1,25 +1,11 @@
"""
invokeai.backend.generator.img2img descends from .generator
"""
from typing import Optional
import torch
from accelerate.utils import set_seed
from diffusers import logging
from ..stable_diffusion import (
ConditioningData,
PostprocessingSettings,
StableDiffusionGeneratorPipeline,
)
from .base import Generator
class Img2Img(Generator):
def __init__(self, model, precision):
super().__init__(model, precision)
self.init_latent = None # by get_noise()
def get_make_image(
self,
sampler,
@ -42,51 +28,4 @@ class Img2Img(Generator):
Returns a function returning an image derived from the prompt and the initial image
Return value depends on the seed at the time you call it.
"""
self.perlin = perlin
# noinspection PyTypeChecker
pipeline: StableDiffusionGeneratorPipeline = self.model
pipeline.scheduler = sampler
uc, c, extra_conditioning_info = conditioning
conditioning_data = ConditioningData(
uc,
c,
cfg_scale,
extra_conditioning_info,
postprocessing_settings=PostprocessingSettings(
threshold=threshold,
warmup=warmup,
h_symmetry_time_pct=h_symmetry_time_pct,
v_symmetry_time_pct=v_symmetry_time_pct,
),
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
def make_image(x_T: torch.Tensor, seed: int):
# FIXME: use x_T for initial seeded noise
# We're not at the moment because the pipeline automatically resizes init_image if
# necessary, which the x_T input might not match.
# In the meantime, reset the seed prior to generating pipeline output so we at least get the same result.
logging.set_verbosity_error() # quench safety check warnings
pipeline_output = pipeline.img2img_from_embeddings(
init_image,
strength,
steps,
conditioning_data,
noise_func=self.get_noise_like,
callback=step_callback,
seed=seed,
)
if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None:
attention_maps_callback(pipeline_output.attention_map_saver)
return pipeline.numpy_to_pil(pipeline_output.images)[0]
return make_image
def get_noise_like(self, like: torch.Tensor):
device = like.device
x = torch.randn_like(like, device=device)
if self.perlin > 0.0:
shape = like.shape
x = (1 - self.perlin) * x + self.perlin * self.get_perlin_noise(shape[3], shape[2])
return x
raise NotImplementedError("replaced by invokeai.app.invocations.latent.LatentsToLatentsInvocation")

View File

@ -377,3 +377,11 @@ class Inpaint(Img2Img):
)
return corrected_result
def get_noise_like(self, like: torch.Tensor):
device = like.device
x = torch.randn_like(like, device=device)
if self.perlin > 0.0:
shape = like.shape
x = (1 - self.perlin) * x + self.perlin * self.get_perlin_noise(shape[3], shape[2])
return x

View File

@ -10,15 +10,17 @@ import sys
import argparse
import io
import os
import psutil
import shutil
import textwrap
import torch
import traceback
import yaml
import warnings
from argparse import Namespace
from enum import Enum
from pathlib import Path
from shutil import get_terminal_size
from typing import get_type_hints
from urllib import request
import npyscreen
@ -44,6 +46,8 @@ from invokeai.app.services.config import (
)
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.frontend.install.model_install import addModelsForm, process_and_execute
# TO DO - Move all the frontend code into invokeai.frontend.install
from invokeai.frontend.install.widgets import (
SingleSelectColumns,
CenteredButtonPress,
@ -53,6 +57,7 @@ from invokeai.frontend.install.widgets import (
CyclingForm,
MIN_COLS,
MIN_LINES,
WindowTooSmallException,
)
from invokeai.backend.install.legacy_arg_parsing import legacy_parser
from invokeai.backend.install.model_install_backend import (
@ -61,6 +66,7 @@ from invokeai.backend.install.model_install_backend import (
ModelInstall,
)
from invokeai.backend.model_management.model_probe import ModelType, BaseModelType
from pydantic.error_wrappers import ValidationError
warnings.filterwarnings("ignore")
transformers.logging.set_verbosity_error()
@ -76,6 +82,13 @@ Default_config_file = config.model_conf_path
SD_Configs = config.legacy_conf_path
PRECISION_CHOICES = ["auto", "float16", "float32"]
GB = 1073741824 # GB in bytes
HAS_CUDA = torch.cuda.is_available()
_, MAX_VRAM = torch.cuda.mem_get_info() if HAS_CUDA else (0, 0)
MAX_VRAM /= GB
MAX_RAM = psutil.virtual_memory().total / GB
INIT_FILE_PREAMBLE = """# InvokeAI initialization file
# This is the InvokeAI initialization file, which contains command-line default values.
@ -86,6 +99,12 @@ INIT_FILE_PREAMBLE = """# InvokeAI initialization file
logger = InvokeAILogger.getLogger()
class DummyWidgetValue(Enum):
zero = 0
true = True
false = False
# --------------------------------------------
def postscript(errors: None):
if not any(errors):
@ -376,15 +395,47 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
max_width=80,
scroll_exit=True,
)
self.max_cache_size = self.add_widget_intelligent(
IntTitleSlider,
name="Size of the RAM cache used for fast model switching (GB)",
value=old_opts.max_cache_size,
out_of=20,
lowest=3,
begin_entry_at=6,
self.nextrely += 1
self.add_widget_intelligent(
npyscreen.TitleFixedText,
name="RAM cache size (GB). Make this at least large enough to hold a single full model.",
begin_entry_at=0,
editable=False,
color="CONTROL",
scroll_exit=True,
)
self.nextrely -= 1
self.max_cache_size = self.add_widget_intelligent(
npyscreen.Slider,
value=clip(old_opts.max_cache_size, range=(3.0, MAX_RAM), step=0.5),
out_of=round(MAX_RAM),
lowest=0.0,
step=0.5,
relx=8,
scroll_exit=True,
)
if HAS_CUDA:
self.nextrely += 1
self.add_widget_intelligent(
npyscreen.TitleFixedText,
name="VRAM cache size (GB). Reserving a small amount of VRAM will modestly speed up the start of image generation.",
begin_entry_at=0,
editable=False,
color="CONTROL",
scroll_exit=True,
)
self.nextrely -= 1
self.max_vram_cache_size = self.add_widget_intelligent(
npyscreen.Slider,
value=clip(old_opts.max_vram_cache_size, range=(0, MAX_VRAM), step=0.25),
out_of=round(MAX_VRAM * 2) / 2,
lowest=0.0,
relx=8,
step=0.25,
scroll_exit=True,
)
else:
self.max_vram_cache_size = DummyWidgetValue.zero
self.nextrely += 1
self.outdir = self.add_widget_intelligent(
FileBox,
@ -401,7 +452,7 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
self.autoimport_dirs = {}
self.autoimport_dirs["autoimport_dir"] = self.add_widget_intelligent(
FileBox,
name=f"Folder to recursively scan for new checkpoints, ControlNets, LoRAs and TI models",
name="Folder to recursively scan for new checkpoints, ControlNets, LoRAs and TI models",
value=str(config.root_path / config.autoimport_dir),
select_dir=True,
must_exist=False,
@ -476,6 +527,7 @@ https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENS
"outdir",
"free_gpu_mem",
"max_cache_size",
"max_vram_cache_size",
"xformers_enabled",
"always_use_cpu",
]:
@ -553,6 +605,16 @@ def default_user_selections(program_opts: Namespace) -> InstallSelections:
)
# -------------------------------------
def clip(value: float, range: tuple[float, float], step: float) -> float:
minimum, maximum = range
if value < minimum:
value = minimum
if value > maximum:
value = maximum
return round(value / step) * step
# -------------------------------------
def initialize_rootdir(root: Path, yes_to_all: bool = False):
logger.info("Initializing InvokeAI runtime directory")
@ -592,13 +654,13 @@ def maybe_create_models_yaml(root: Path):
# -------------------------------------
def run_console_ui(program_opts: Namespace, initfile: Path = None) -> (Namespace, Namespace):
# parse_args() will read from init file if present
invokeai_opts = default_startup_options(initfile)
invokeai_opts.root = program_opts.root
# The third argument is needed in the Windows 11 environment to
# launch a console window running this program.
set_min_terminal_size(MIN_COLS, MIN_LINES)
if not set_min_terminal_size(MIN_COLS, MIN_LINES):
raise WindowTooSmallException(
"Could not increase terminal size. Try running again with a larger window or smaller font size."
)
# the install-models application spawns a subprocess to install
# models, and will crash unless this is set before running.
@ -654,10 +716,13 @@ def migrate_init_file(legacy_format: Path):
old = legacy_parser.parse_args([f"@{str(legacy_format)}"])
new = InvokeAIAppConfig.get_config()
fields = list(get_type_hints(InvokeAIAppConfig).keys())
fields = [x for x, y in InvokeAIAppConfig.__fields__.items() if y.field_info.extra.get("category") != "DEPRECATED"]
for attr in fields:
if hasattr(old, attr):
setattr(new, attr, getattr(old, attr))
try:
setattr(new, attr, getattr(old, attr))
except ValidationError as e:
print(f"* Ignoring incompatible value for field {attr}:\n {str(e)}")
# a few places where the field names have changed and we have to
# manually add in the new names/values
@ -777,6 +842,7 @@ def main():
models_to_download = default_user_selections(opt)
new_init_file = config.root_path / "invokeai.yaml"
if opt.yes_to_all:
write_default_options(opt, new_init_file)
init_options = Namespace(precision="float32" if opt.full_precision else "float16")
@ -802,6 +868,8 @@ def main():
postscript(errors=errors)
if not opt.yes_to_all:
input("Press any key to continue...")
except WindowTooSmallException as e:
logger.error(str(e))
except KeyboardInterrupt:
print("\nGoodbye! Come back soon.")

View File

@ -591,7 +591,6 @@ script, which will perform a full upgrade in place.""",
# TODO: revisit - don't rely on invokeai.yaml to exist yet!
dest_is_setup = (dest_root / "models/core").exists() and (dest_root / "databases").exists()
if not dest_is_setup:
import invokeai.frontend.install.invokeai_configure
from invokeai.backend.install.invokeai_configure import initialize_rootdir
initialize_rootdir(dest_root, True)

View File

@ -143,7 +143,7 @@ class ModelPatcher:
# with torch.autocast(device_type="cpu"):
layer.to(dtype=torch.float32)
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
layer_weight = layer.get_weight() * lora_weight * layer_scale
layer_weight = layer.get_weight(original_weights[module_key]) * lora_weight * layer_scale
if module.weight.shape != layer_weight.shape:
# TODO: debug on lycoris
@ -361,7 +361,8 @@ class ONNXModelPatcher:
layer.to(dtype=torch.float32)
layer_key = layer_key.replace(prefix, "")
layer_weight = layer.get_weight().detach().cpu().numpy() * lora_weight
# TODO: rewrite to pass original tensor weight(required by ia3)
layer_weight = layer.get_weight(None).detach().cpu().numpy() * lora_weight
if layer_key is blended_loras:
blended_loras[layer_key] += layer_weight
else:

View File

@ -228,19 +228,19 @@ the root is the InvokeAI ROOTDIR.
"""
from __future__ import annotations
import os
import hashlib
import os
import textwrap
import yaml
import types
from dataclasses import dataclass
from pathlib import Path
from typing import Literal, Optional, List, Tuple, Union, Dict, Set, Callable, types
from shutil import rmtree, move
from typing import Optional, List, Literal, Tuple, Union, Dict, Set, Callable
import torch
import yaml
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from pydantic import BaseModel, Field
import invokeai.backend.util.logging as logger
@ -259,6 +259,7 @@ from .models import (
ModelNotFoundException,
InvalidModelException,
DuplicateModelException,
ModelBase,
)
# We are only starting to number the config file with release 3.
@ -361,7 +362,7 @@ class ModelManager(object):
if model_key.startswith("_"):
continue
model_name, base_model, model_type = self.parse_key(model_key)
model_class = MODEL_CLASSES[base_model][model_type]
model_class = self._get_implementation(base_model, model_type)
# alias for config file
model_config["model_format"] = model_config.pop("format")
self.models[model_key] = model_class.create_config(**model_config)
@ -381,18 +382,24 @@ class ModelManager(object):
# causing otherwise unreferenced models to be removed from memory
self._read_models()
def model_exists(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
) -> bool:
def model_exists(self, model_name: str, base_model: BaseModelType, model_type: ModelType, *, rescan=False) -> bool:
"""
Given a model name, returns True if it is a valid
identifier.
Given a model name, returns True if it is a valid identifier.
:param model_name: symbolic name of the model in models.yaml
:param model_type: ModelType enum indicating the type of model to return
:param base_model: BaseModelType enum indicating the base model used by this model
:param rescan: if True, scan_models_directory
"""
model_key = self.create_key(model_name, base_model, model_type)
return model_key in self.models
exists = model_key in self.models
# if model not found try to find it (maybe file just pasted)
if rescan and not exists:
self.scan_models_directory(base_model=base_model, model_type=model_type)
exists = self.model_exists(model_name, base_model, model_type, rescan=False)
return exists
@classmethod
def create_key(
@ -443,39 +450,32 @@ class ModelManager(object):
:param model_name: symbolic name of the model in models.yaml
:param model_type: ModelType enum indicating the type of model to return
:param base_model: BaseModelType enum indicating the base model used by this model
:param submode_typel: an ModelType enum indicating the portion of
:param submodel_type: an ModelType enum indicating the portion of
the model to retrieve (e.g. ModelType.Vae)
"""
model_class = MODEL_CLASSES[base_model][model_type]
model_key = self.create_key(model_name, base_model, model_type)
# if model not found try to find it (maybe file just pasted)
if model_key not in self.models:
self.scan_models_directory(base_model=base_model, model_type=model_type)
if model_key not in self.models:
raise ModelNotFoundException(f"Model not found - {model_key}")
if not self.model_exists(model_name, base_model, model_type, rescan=True):
raise ModelNotFoundException(f"Model not found - {model_key}")
model_config = self.models[model_key]
model_path = self.resolve_model_path(model_config.path)
model_config = self._get_model_config(base_model, model_name, model_type)
model_path, is_submodel_override = self._get_model_path(model_config, submodel_type)
if is_submodel_override:
model_type = submodel_type
submodel_type = None
model_class = self._get_implementation(base_model, model_type)
if not model_path.exists():
if model_class.save_to_config:
self.models[model_key].error = ModelError.NotFound
raise Exception(f'Files for model "{model_key}" not found')
raise Exception(f'Files for model "{model_key}" not found at {model_path}')
else:
self.models.pop(model_key, None)
raise ModelNotFoundException(f"Model not found - {model_key}")
# vae/movq override
# TODO:
if submodel_type is not None and hasattr(model_config, submodel_type):
override_path = getattr(model_config, submodel_type)
if override_path:
model_path = self.resolve_path(override_path)
model_type = submodel_type
submodel_type = None
model_class = MODEL_CLASSES[base_model][model_type]
raise ModelNotFoundException(f'Files for model "{model_key}" not found at {model_path}')
# TODO: path
# TODO: is it accurate to use path as id
@ -513,6 +513,55 @@ class ModelManager(object):
_cache=self.cache,
)
def _get_model_path(
self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None
) -> (Path, bool):
"""Extract a model's filesystem path from its config.
:return: The fully qualified Path of the module (or submodule).
"""
model_path = model_config.path
is_submodel_override = False
# Does the config explicitly override the submodel?
if submodel_type is not None and hasattr(model_config, submodel_type):
submodel_path = getattr(model_config, submodel_type)
if submodel_path is not None and len(submodel_path) > 0:
model_path = getattr(model_config, submodel_type)
is_submodel_override = True
model_path = self.resolve_model_path(model_path)
return model_path, is_submodel_override
def _get_model_config(self, base_model: BaseModelType, model_name: str, model_type: ModelType) -> ModelConfigBase:
"""Get a model's config object."""
model_key = self.create_key(model_name, base_model, model_type)
try:
model_config = self.models[model_key]
except KeyError:
raise ModelNotFoundException(f"Model not found - {model_key}")
return model_config
def _get_implementation(self, base_model: BaseModelType, model_type: ModelType) -> type[ModelBase]:
"""Get the concrete implementation class for a specific model type."""
model_class = MODEL_CLASSES[base_model][model_type]
return model_class
def _instantiate(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel_type: Optional[SubModelType] = None,
) -> ModelBase:
"""Make a new instance of this model, without loading it."""
model_config = self._get_model_config(base_model, model_name, model_type)
model_path, is_submodel_override = self._get_model_path(model_config, submodel_type)
# FIXME: do non-overriden submodels get the right class?
constructor = self._get_implementation(base_model, model_type)
instance = constructor(model_path, base_model, model_type)
return instance
def model_info(
self,
model_name: str,
@ -546,9 +595,10 @@ class ModelManager(object):
the combined format of the list_models() method.
"""
models = self.list_models(base_model, model_type, model_name)
if len(models) > 1:
if len(models) >= 1:
return models[0]
return None
else:
return None
def list_models(
self,
@ -660,7 +710,7 @@ class ModelManager(object):
if path := model_attributes.get("path"):
model_attributes["path"] = str(self.relative_model_path(Path(path)))
model_class = MODEL_CLASSES[base_model][model_type]
model_class = self._get_implementation(base_model, model_type)
model_config = model_class.create_config(**model_attributes)
model_key = self.create_key(model_name, base_model, model_type)
@ -851,7 +901,7 @@ class ModelManager(object):
for model_key, model_config in self.models.items():
model_name, base_model, model_type = self.parse_key(model_key)
model_class = MODEL_CLASSES[base_model][model_type]
model_class = self._get_implementation(base_model, model_type)
if model_class.save_to_config:
# TODO: or exclude_unset better fits here?
data_to_save[model_key] = model_config.dict(exclude_defaults=True, exclude={"error"})
@ -909,7 +959,7 @@ class ModelManager(object):
model_path = self.resolve_model_path(model_config.path).absolute()
if not model_path.exists():
model_class = MODEL_CLASSES[cur_base_model][cur_model_type]
model_class = self._get_implementation(cur_base_model, cur_model_type)
if model_class.save_to_config:
model_config.error = ModelError.NotFound
self.models.pop(model_key, None)
@ -925,7 +975,7 @@ class ModelManager(object):
for cur_model_type in ModelType:
if model_type is not None and cur_model_type != model_type:
continue
model_class = MODEL_CLASSES[cur_base_model][cur_model_type]
model_class = self._get_implementation(cur_base_model, cur_model_type)
models_dir = self.resolve_model_path(Path(cur_base_model.value, cur_model_type.value))
if not models_dir.exists():
@ -941,7 +991,9 @@ class ModelManager(object):
raise DuplicateModelException(f"Model with key {model_key} added twice")
model_path = self.relative_model_path(model_path)
model_config: ModelConfigBase = model_class.probe_config(str(model_path))
model_config: ModelConfigBase = model_class.probe_config(
str(model_path), model_base=cur_base_model
)
self.models[model_key] = model_config
new_models_found = True
except DuplicateModelException as e:

View File

@ -17,6 +17,7 @@ from .models import (
SilenceWarnings,
InvalidModelException,
)
from .util import lora_token_vector_length
from .models.base import read_checkpoint_meta
@ -315,38 +316,16 @@ class LoRACheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
checkpoint = self.checkpoint
token_vector_length = lora_token_vector_length(checkpoint)
# SD-2 models are very hard to probe. These probes are brittle and likely to fail in the future
# There are also some "SD-2 LoRAs" that have identical keys and shapes to SD-1 and will be
# misclassified as SD-1
key = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
if key in checkpoint and checkpoint[key].shape[0] == 320:
return BaseModelType.StableDiffusion2
key = "lora_unet_output_blocks_5_1_transformer_blocks_1_ff_net_2.lora_up.weight"
if key in checkpoint:
return BaseModelType.StableDiffusionXL
key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
key3 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a"
lora_token_vector_length = (
checkpoint[key1].shape[1]
if key1 in checkpoint
else checkpoint[key2].shape[1]
if key2 in checkpoint
else checkpoint[key3].shape[0]
if key3 in checkpoint
else None
)
if lora_token_vector_length == 768:
if token_vector_length == 768:
return BaseModelType.StableDiffusion1
elif lora_token_vector_length == 1024:
elif token_vector_length == 1024:
return BaseModelType.StableDiffusion2
elif token_vector_length == 2048:
return BaseModelType.StableDiffusionXL
else:
raise InvalidModelException(f"Unknown LoRA type")
raise InvalidModelException(f"Unknown LoRA type: {self.checkpoint_path}")
class TextualInversionCheckpointProbe(CheckpointProbeBase):

View File

@ -1,18 +1,21 @@
import bisect
import os
import torch
from enum import Enum
from typing import Optional, Dict, Union, Literal, Any
from pathlib import Path
from typing import Dict, Optional, Union
import torch
from safetensors.torch import load_file
from .base import (
BaseModelType,
InvalidModelException,
ModelBase,
ModelConfigBase,
BaseModelType,
ModelNotFoundException,
ModelType,
SubModelType,
classproperty,
InvalidModelException,
ModelNotFoundException,
)
@ -122,41 +125,7 @@ class LoRALayerBase:
self.rank = None # set in layer implementation
self.layer_key = layer_key
def forward(
self,
module: torch.nn.Module,
input_h: Any, # for real looks like Tuple[torch.nn.Tensor] but not sure
multiplier: float,
):
if type(module) == torch.nn.Conv2d:
op = torch.nn.functional.conv2d
extra_args = dict(
stride=module.stride,
padding=module.padding,
dilation=module.dilation,
groups=module.groups,
)
else:
op = torch.nn.functional.linear
extra_args = {}
weight = self.get_weight()
bias = self.bias if self.bias is not None else 0
scale = self.alpha / self.rank if (self.alpha and self.rank) else 1.0
return (
op(
*input_h,
(weight + bias).view(module.weight.shape),
None,
**extra_args,
)
* multiplier
* scale
)
def get_weight(self):
def get_weight(self, orig_weight: torch.Tensor):
raise NotImplementedError()
def calc_size(self) -> int:
@ -197,7 +166,7 @@ class LoRALayer(LoRALayerBase):
self.rank = self.down.shape[0]
def get_weight(self):
def get_weight(self, orig_weight: torch.Tensor):
if self.mid is not None:
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
@ -260,7 +229,7 @@ class LoHALayer(LoRALayerBase):
self.rank = self.w1_b.shape[0]
def get_weight(self):
def get_weight(self, orig_weight: torch.Tensor):
if self.t1 is None:
weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
@ -342,7 +311,7 @@ class LoKRLayer(LoRALayerBase):
else:
self.rank = None # unscaled
def get_weight(self):
def get_weight(self, orig_weight: torch.Tensor):
w1 = self.w1
if w1 is None:
w1 = self.w1_a @ self.w1_b
@ -410,7 +379,7 @@ class FullLayer(LoRALayerBase):
self.rank = None # unscaled
def get_weight(self):
def get_weight(self, orig_weight: torch.Tensor):
return self.weight
def calc_size(self) -> int:
@ -428,6 +397,45 @@ class FullLayer(LoRALayerBase):
self.weight = self.weight.to(device=device, dtype=dtype)
class IA3Layer(LoRALayerBase):
# weight: torch.Tensor
# on_input: torch.Tensor
def __init__(
self,
layer_key: str,
values: dict,
):
super().__init__(layer_key, values)
self.weight = values["weight"]
self.on_input = values["on_input"]
self.rank = None # unscaled
def get_weight(self, orig_weight: torch.Tensor):
weight = self.weight
if not self.on_input:
weight = weight.reshape(-1, 1)
return orig_weight * weight
def calc_size(self) -> int:
model_size = super().calc_size()
model_size += self.weight.nelement() * self.weight.element_size()
model_size += self.on_input.nelement() * self.on_input.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
super().to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype)
self.on_input = self.on_input.to(device=device, dtype=dtype)
# TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix
class LoRAModelRaw: # (torch.nn.Module):
_name: str
@ -477,30 +485,61 @@ class LoRAModelRaw: # (torch.nn.Module):
return model_size
@classmethod
def _convert_sdxl_compvis_keys(cls, state_dict):
def _convert_sdxl_keys_to_diffusers_format(cls, state_dict):
"""Convert the keys of an SDXL LoRA state_dict to diffusers format.
The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in
diffusers format, then this function will have no effect.
This function is adapted from:
https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L385-L409
Args:
state_dict (Dict[str, Tensor]): The SDXL LoRA state_dict.
Raises:
ValueError: If state_dict contains an unrecognized key, or not all keys could be converted.
Returns:
Dict[str, Tensor]: The diffusers-format state_dict.
"""
converted_count = 0 # The number of Stability AI keys converted to diffusers format.
not_converted_count = 0 # The number of keys that were not converted.
# Get a sorted list of Stability AI UNet keys so that we can efficiently search for keys with matching prefixes.
# For example, we want to efficiently find `input_blocks_4_1` in the list when searching for
# `input_blocks_4_1_proj_in`.
stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
stability_unet_keys.sort()
new_state_dict = dict()
for full_key, value in state_dict.items():
if full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
continue # clip same
if full_key.startswith("lora_unet_"):
search_key = full_key.replace("lora_unet_", "")
# Use bisect to find the key in stability_unet_keys that *may* match the search_key's prefix.
position = bisect.bisect_right(stability_unet_keys, search_key)
map_key = stability_unet_keys[position - 1]
# Now, check if the map_key *actually* matches the search_key.
if search_key.startswith(map_key):
new_key = full_key.replace(map_key, SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP[map_key])
new_state_dict[new_key] = value
converted_count += 1
else:
new_state_dict[full_key] = value
not_converted_count += 1
elif full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
# The CLIP text encoders have the same keys in both Stability AI and diffusers formats.
new_state_dict[full_key] = value
continue
else:
raise ValueError(f"Unrecognized SDXL LoRA key prefix: '{full_key}'.")
if not full_key.startswith("lora_unet_"):
raise NotImplementedError(f"Unknown prefix for sdxl lora key - {full_key}")
src_key = full_key.replace("lora_unet_", "")
try:
dst_key = None
while "_" in src_key:
if src_key in SDXL_UNET_COMPVIS_MAP:
dst_key = SDXL_UNET_COMPVIS_MAP[src_key]
break
src_key = "_".join(src_key.split("_")[:-1])
if converted_count > 0 and not_converted_count > 0:
raise ValueError(
f"The SDXL LoRA could only be partially converted to diffusers format. converted={converted_count},"
f" not_converted={not_converted_count}"
)
if dst_key is None:
raise Exception(f"Unknown sdxl lora key - {full_key}")
new_key = full_key.replace(src_key, dst_key)
except:
print(SDXL_UNET_COMPVIS_MAP)
raise
new_state_dict[new_key] = value
return new_state_dict
@classmethod
@ -532,7 +571,7 @@ class LoRAModelRaw: # (torch.nn.Module):
state_dict = cls._group_state(state_dict)
if base_model == BaseModelType.StableDiffusionXL:
state_dict = cls._convert_sdxl_compvis_keys(state_dict)
state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)
for layer_key, values in state_dict.items():
# lora and locon
@ -547,11 +586,15 @@ class LoRAModelRaw: # (torch.nn.Module):
elif "lokr_w1_b" in values or "lokr_w1" in values:
layer = LoKRLayer(layer_key, values)
# diff
elif "diff" in values:
layer = FullLayer(layer_key, values)
# ia3
elif "weight" in values and "on_input" in values:
layer = IA3Layer(layer_key, values)
else:
# TODO: ia3/... format
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
raise Exception("Unknown lora format!")
@ -579,6 +622,7 @@ class LoRAModelRaw: # (torch.nn.Module):
# code from
# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
def make_sdxl_unet_conversion_map():
"""Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format."""
unet_conversion_map_layer = []
for i in range(3): # num_blocks is 3 in sdxl
@ -662,7 +706,6 @@ def make_sdxl_unet_conversion_map():
return unet_conversion_map
SDXL_UNET_COMPVIS_MAP = {
f"{sd}".rstrip(".").replace(".", "_"): f"{hf}".rstrip(".").replace(".", "_")
for sd, hf in make_sdxl_unet_conversion_map()
SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP = {
sd.rstrip(".").replace(".", "_"): hf.rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()
}

View File

@ -1,6 +1,5 @@
import os
import json
import invokeai.backend.util.logging as logger
from enum import Enum
from pydantic import Field
from typing import Literal, Optional
@ -12,6 +11,7 @@ from .base import (
DiffusersModel,
read_checkpoint_meta,
classproperty,
InvalidModelException,
)
from omegaconf import OmegaConf
@ -65,7 +65,7 @@ class StableDiffusionXLModel(DiffusersModel):
in_channels = unet_config["in_channels"]
else:
raise Exception("Not supported stable diffusion diffusers format(possibly onnx?)")
raise InvalidModelException(f"{path} is not a recognized Stable Diffusion diffusers model")
else:
raise NotImplementedError(f"Unknown stable diffusion 2.* format: {model_format}")
@ -80,8 +80,10 @@ class StableDiffusionXLModel(DiffusersModel):
raise Exception("Unkown stable diffusion 2.* model format")
if ckpt_config_path is None:
# TO DO: implement picking
pass
# avoid circular import
from .stable_diffusion import _select_ckpt_config
ckpt_config_path = _select_ckpt_config(kwargs.get("model_base", BaseModelType.StableDiffusionXL), variant)
return cls.create_config(
path=path,

View File

@ -1,9 +1,14 @@
import os
import torch
import safetensors
from enum import Enum
from pathlib import Path
from typing import Optional, Union, Literal
from typing import Optional
import safetensors
import torch
from diffusers.utils import is_safetensors_available
from omegaconf import OmegaConf
from invokeai.app.services.config import InvokeAIAppConfig
from .base import (
ModelBase,
ModelConfigBase,
@ -18,9 +23,6 @@ from .base import (
InvalidModelException,
ModelNotFoundException,
)
from invokeai.app.services.config import InvokeAIAppConfig
from diffusers.utils import is_safetensors_available
from omegaconf import OmegaConf
class VaeModelFormat(str, Enum):
@ -80,7 +82,7 @@ class VaeModel(ModelBase):
@classmethod
def detect_format(cls, path: str):
if not os.path.exists(path):
raise ModelNotFoundException()
raise ModelNotFoundException(f"Does not exist as local file: {path}")
if os.path.isdir(path):
if os.path.exists(os.path.join(path, "config.json")):

View File

@ -0,0 +1,75 @@
# Copyright (c) 2023 The InvokeAI Development Team
"""Utilities used by the Model Manager"""
def lora_token_vector_length(checkpoint: dict) -> int:
"""
Given a checkpoint in memory, return the lora token vector length
:param checkpoint: The checkpoint
"""
def _get_shape_1(key, tensor, checkpoint):
lora_token_vector_length = None
if "." not in key:
return lora_token_vector_length # wrong key format
model_key, lora_key = key.split(".", 1)
# check lora/locon
if lora_key == "lora_down.weight":
lora_token_vector_length = tensor.shape[1]
# check loha (don't worry about hada_t1/hada_t2 as it used only in 4d shapes)
elif lora_key in ["hada_w1_b", "hada_w2_b"]:
lora_token_vector_length = tensor.shape[1]
# check lokr (don't worry about lokr_t2 as it used only in 4d shapes)
elif "lokr_" in lora_key:
if model_key + ".lokr_w1" in checkpoint:
_lokr_w1 = checkpoint[model_key + ".lokr_w1"]
elif model_key + "lokr_w1_b" in checkpoint:
_lokr_w1 = checkpoint[model_key + ".lokr_w1_b"]
else:
return lora_token_vector_length # unknown format
if model_key + ".lokr_w2" in checkpoint:
_lokr_w2 = checkpoint[model_key + ".lokr_w2"]
elif model_key + "lokr_w2_b" in checkpoint:
_lokr_w2 = checkpoint[model_key + ".lokr_w2_b"]
else:
return lora_token_vector_length # unknown format
lora_token_vector_length = _lokr_w1.shape[1] * _lokr_w2.shape[1]
elif lora_key == "diff":
lora_token_vector_length = tensor.shape[1]
# ia3 can be detected only by shape[0] in text encoder
elif lora_key == "weight" and "lora_unet_" not in model_key:
lora_token_vector_length = tensor.shape[0]
return lora_token_vector_length
lora_token_vector_length = None
lora_te1_length = None
lora_te2_length = None
for key, tensor in checkpoint.items():
if key.startswith("lora_unet_") and ("_attn2_to_k." in key or "_attn2_to_v." in key):
lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
elif key.startswith("lora_te") and "_self_attn_" in key:
tmp_length = _get_shape_1(key, tensor, checkpoint)
if key.startswith("lora_te_"):
lora_token_vector_length = tmp_length
elif key.startswith("lora_te1_"):
lora_te1_length = tmp_length
elif key.startswith("lora_te2_"):
lora_te2_length = tmp_length
if lora_te1_length is not None and lora_te2_length is not None:
lora_token_vector_length = lora_te1_length + lora_te2_length
if lora_token_vector_length is not None:
break
return lora_token_vector_length

View File

@ -4,25 +4,21 @@ import dataclasses
import inspect
import math
import secrets
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union
from pydantic import Field
import einops
import PIL.Image
import numpy as np
from accelerate.utils import set_seed
import einops
import psutil
import torch
import torchvision.transforms as T
from accelerate.utils import set_seed
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.controlnet import ControlNetModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
StableDiffusionPipeline,
)
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import (
StableDiffusionImg2ImgPipeline,
)
@ -31,21 +27,20 @@ from diffusers.pipelines.stable_diffusion.safety_checker import (
)
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
from diffusers.utils import PIL_INTERPOLATION
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.outputs import BaseOutput
from pydantic import Field
from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from typing_extensions import ParamSpec
from invokeai.app.services.config import InvokeAIAppConfig
from ..util import CPU_DEVICE, normalize_device
from .diffusion import (
AttentionMapSaver,
InvokeAIDiffuserComponent,
PostprocessingSettings,
)
from .offloading import FullyLoadedModelGroup, ModelGroup
from ..util import normalize_device
@dataclass
@ -289,8 +284,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
_model_group: ModelGroup
ID_LENGTH = 8
def __init__(
@ -303,9 +296,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
safety_checker: Optional[StableDiffusionSafetyChecker],
feature_extractor: Optional[CLIPFeatureExtractor],
requires_safety_checker: bool = False,
precision: str = "float32",
control_model: ControlNetModel = None,
execution_device: Optional[torch.device] = None,
):
super().__init__(
vae,
@ -330,9 +321,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# control_model=control_model,
)
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward)
self._model_group = FullyLoadedModelGroup(execution_device or self.unet.device)
self._model_group.install(*self._submodels)
self.control_model = control_model
def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
@ -368,72 +356,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
else:
self.disable_attention_slicing()
def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings=False):
# overridden method; types match the superclass.
if torch_device is None:
return self
self._model_group.set_device(torch.device(torch_device))
self._model_group.ready()
@property
def device(self) -> torch.device:
return self._model_group.execution_device
@property
def _submodels(self) -> Sequence[torch.nn.Module]:
module_names, _, _ = self.extract_init_dict(dict(self.config))
submodels = []
for name in module_names.keys():
if hasattr(self, name):
value = getattr(self, name)
else:
value = getattr(self.config, name)
if isinstance(value, torch.nn.Module):
submodels.append(value)
return submodels
def image_from_embeddings(
self,
latents: torch.Tensor,
num_inference_steps: int,
conditioning_data: ConditioningData,
*,
noise: torch.Tensor,
callback: Callable[[PipelineIntermediateState], None] = None,
run_id=None,
) -> InvokeAIStableDiffusionPipelineOutput:
r"""
Function invoked when calling the pipeline for generation.
:param conditioning_data:
:param latents: Pre-generated un-noised latents, to be used as inputs for
image generation. Can be used to tweak the same generation with different prompts.
:param num_inference_steps: The number of denoising steps. More denoising steps usually lead to a higher quality
image at the expense of slower inference.
:param noise: Noise to add to the latents, sampled from a Gaussian distribution.
:param callback:
:param run_id:
"""
result_latents, result_attention_map_saver = self.latents_from_embeddings(
latents,
num_inference_steps,
conditioning_data,
noise=noise,
run_id=run_id,
callback=callback,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache()
with torch.inference_mode():
image = self.decode_latents(result_latents)
output = InvokeAIStableDiffusionPipelineOutput(
images=image,
nsfw_content_detected=[],
attention_map_saver=result_attention_map_saver,
)
return self.check_for_safety(output, dtype=conditioning_data.dtype)
def latents_from_embeddings(
self,
latents: torch.Tensor,
@ -450,7 +372,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if self.scheduler.config.get("cpu_only", False):
scheduler_device = torch.device("cpu")
else:
scheduler_device = self._model_group.device_for(self.unet)
scheduler_device = self.unet.device
if timesteps is None:
self.scheduler.set_timesteps(num_inference_steps, device=scheduler_device)
@ -504,7 +426,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
(batch_size,),
timesteps[0],
dtype=timesteps.dtype,
device=self._model_group.device_for(self.unet),
device=self.unet.device,
)
latents = self.scheduler.add_noise(latents, noise, batched_t)
@ -700,79 +622,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
**kwargs,
).sample
def img2img_from_embeddings(
self,
init_image: Union[torch.FloatTensor, PIL.Image.Image],
strength: float,
num_inference_steps: int,
conditioning_data: ConditioningData,
*,
callback: Callable[[PipelineIntermediateState], None] = None,
run_id=None,
noise_func=None,
seed=None,
) -> InvokeAIStableDiffusionPipelineOutput:
if isinstance(init_image, PIL.Image.Image):
init_image = image_resized_to_grid_as_tensor(init_image.convert("RGB"))
if init_image.dim() == 3:
init_image = einops.rearrange(init_image, "c h w -> 1 c h w")
# 6. Prepare latent variables
initial_latents = self.non_noised_latents_from_image(
init_image,
device=self._model_group.device_for(self.unet),
dtype=self.unet.dtype,
)
if seed is not None:
set_seed(seed)
noise = noise_func(initial_latents)
return self.img2img_from_latents_and_embeddings(
initial_latents,
num_inference_steps,
conditioning_data,
strength,
noise,
run_id,
callback,
)
def img2img_from_latents_and_embeddings(
self,
initial_latents,
num_inference_steps,
conditioning_data: ConditioningData,
strength,
noise: torch.Tensor,
run_id=None,
callback=None,
) -> InvokeAIStableDiffusionPipelineOutput:
timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength)
result_latents, result_attention_maps = self.latents_from_embeddings(
latents=initial_latents
if strength < 1.0
else torch.zeros_like(initial_latents, device=initial_latents.device, dtype=initial_latents.dtype),
num_inference_steps=num_inference_steps,
conditioning_data=conditioning_data,
timesteps=timesteps,
noise=noise,
run_id=run_id,
callback=callback,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache()
with torch.inference_mode():
image = self.decode_latents(result_latents)
output = InvokeAIStableDiffusionPipelineOutput(
images=image,
nsfw_content_detected=[],
attention_map_saver=result_attention_maps,
)
return self.check_for_safety(output, dtype=conditioning_data.dtype)
def get_img2img_timesteps(self, num_inference_steps: int, strength: float, device=None) -> (torch.Tensor, int):
img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
assert img2img_pipeline.scheduler is self.scheduler
@ -780,7 +629,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if self.scheduler.config.get("cpu_only", False):
scheduler_device = torch.device("cpu")
else:
scheduler_device = self._model_group.device_for(self.unet)
scheduler_device = self.unet.device
img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=scheduler_device)
timesteps, adjusted_steps = img2img_pipeline.get_timesteps(
@ -806,7 +655,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
noise_func=None,
seed=None,
) -> InvokeAIStableDiffusionPipelineOutput:
device = self._model_group.device_for(self.unet)
device = self.unet.device
latents_dtype = self.unet.dtype
if isinstance(init_image, PIL.Image.Image):
@ -877,42 +726,17 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
nsfw_content_detected=[],
attention_map_saver=result_attention_maps,
)
return self.check_for_safety(output, dtype=conditioning_data.dtype)
return output
def non_noised_latents_from_image(self, init_image, *, device: torch.device, dtype):
init_image = init_image.to(device=device, dtype=dtype)
with torch.inference_mode():
self._model_group.load(self.vae)
init_latent_dist = self.vae.encode(init_image).latent_dist
init_latents = init_latent_dist.sample().to(dtype=dtype) # FIXME: uses torch.randn. make reproducible!
init_latents = 0.18215 * init_latents
return init_latents
def check_for_safety(self, output, dtype):
with torch.inference_mode():
screened_images, has_nsfw_concept = self.run_safety_checker(output.images, dtype=dtype)
screened_attention_map_saver = None
if has_nsfw_concept is None or not has_nsfw_concept:
screened_attention_map_saver = output.attention_map_saver
return InvokeAIStableDiffusionPipelineOutput(
screened_images,
has_nsfw_concept,
# block the attention maps if NSFW content is detected
attention_map_saver=screened_attention_map_saver,
)
def run_safety_checker(self, image, device=None, dtype=None):
# overriding to use the model group for device info instead of requiring the caller to know.
if self.safety_checker is not None:
device = self._model_group.device_for(self.safety_checker)
return super().run_safety_checker(image, device, dtype)
def decode_latents(self, latents):
# Explicit call to get the vae loaded, since `decode` isn't the forward method.
self._model_group.load(self.vae)
return super().decode_latents(latents)
def debug_latents(self, latents, msg):
from invokeai.backend.image_util import debug_image

View File

@ -1,253 +0,0 @@
from __future__ import annotations
import warnings
import weakref
from abc import ABCMeta, abstractmethod
from collections.abc import MutableMapping
from typing import Callable, Union
import torch
from accelerate.utils import send_to_device
from torch.utils.hooks import RemovableHandle
OFFLOAD_DEVICE = torch.device("cpu")
class _NoModel:
"""Symbol that indicates no model is loaded.
(We can't weakref.ref(None), so this was my best idea at the time to come up with something
type-checkable.)
"""
def __bool__(self):
return False
def to(self, device: torch.device):
pass
def __repr__(self):
return "<NO MODEL>"
NO_MODEL = _NoModel()
class ModelGroup(metaclass=ABCMeta):
"""
A group of models.
The use case I had in mind when writing this is the sub-models used by a DiffusionPipeline,
e.g. its text encoder, U-net, VAE, etc.
Those models are :py:class:`diffusers.ModelMixin`, but "model" is interchangeable with
:py:class:`torch.nn.Module` here.
"""
def __init__(self, execution_device: torch.device):
self.execution_device = execution_device
@abstractmethod
def install(self, *models: torch.nn.Module):
"""Add models to this group."""
pass
@abstractmethod
def uninstall(self, models: torch.nn.Module):
"""Remove models from this group."""
pass
@abstractmethod
def uninstall_all(self):
"""Remove all models from this group."""
@abstractmethod
def load(self, model: torch.nn.Module):
"""Load this model to the execution device."""
pass
@abstractmethod
def offload_current(self):
"""Offload the current model(s) from the execution device."""
pass
@abstractmethod
def ready(self):
"""Ready this group for use."""
pass
@abstractmethod
def set_device(self, device: torch.device):
"""Change which device models from this group will execute on."""
pass
@abstractmethod
def device_for(self, model) -> torch.device:
"""Get the device the given model will execute on.
The model should already be a member of this group.
"""
pass
@abstractmethod
def __contains__(self, model):
"""Check if the model is a member of this group."""
pass
def __repr__(self) -> str:
return f"<{self.__class__.__name__} object at {id(self):x}: " f"device={self.execution_device} >"
class LazilyLoadedModelGroup(ModelGroup):
"""
Only one model from this group is loaded on the GPU at a time.
Running the forward method of a model will displace the previously-loaded model,
offloading it to CPU.
If you call other methods on the model, e.g. ``model.encode(x)`` instead of ``model(x)``,
you will need to explicitly load it with :py:method:`.load(model)`.
This implementation relies on pytorch forward-pre-hooks, and it will copy forward arguments
to the appropriate execution device, as long as they are positional arguments and not keyword
arguments. (I didn't make the rules; that's the way the pytorch 1.13 API works for hooks.)
"""
_hooks: MutableMapping[torch.nn.Module, RemovableHandle]
_current_model_ref: Callable[[], Union[torch.nn.Module, _NoModel]]
def __init__(self, execution_device: torch.device):
super().__init__(execution_device)
self._hooks = weakref.WeakKeyDictionary()
self._current_model_ref = weakref.ref(NO_MODEL)
def install(self, *models: torch.nn.Module):
for model in models:
self._hooks[model] = model.register_forward_pre_hook(self._pre_hook)
def uninstall(self, *models: torch.nn.Module):
for model in models:
hook = self._hooks.pop(model)
hook.remove()
if self.is_current_model(model):
# no longer hooked by this object, so don't claim to manage it
self.clear_current_model()
def uninstall_all(self):
self.uninstall(*self._hooks.keys())
def _pre_hook(self, module: torch.nn.Module, forward_input):
self.load(module)
if len(forward_input) == 0:
warnings.warn(
f"Hook for {module.__class__.__name__} got no input. " f"Inputs must be positional, not keywords.",
stacklevel=3,
)
return send_to_device(forward_input, self.execution_device)
def load(self, module):
if not self.is_current_model(module):
self.offload_current()
self._load(module)
def offload_current(self):
module = self._current_model_ref()
if module is not NO_MODEL:
module.to(OFFLOAD_DEVICE)
self.clear_current_model()
def _load(self, module: torch.nn.Module) -> torch.nn.Module:
assert self.is_empty(), f"A model is already loaded: {self._current_model_ref()}"
module = module.to(self.execution_device)
self.set_current_model(module)
return module
def is_current_model(self, model: torch.nn.Module) -> bool:
"""Is the given model the one currently loaded on the execution device?"""
return self._current_model_ref() is model
def is_empty(self):
"""Are none of this group's models loaded on the execution device?"""
return self._current_model_ref() is NO_MODEL
def set_current_model(self, value):
self._current_model_ref = weakref.ref(value)
def clear_current_model(self):
self._current_model_ref = weakref.ref(NO_MODEL)
def set_device(self, device: torch.device):
if device == self.execution_device:
return
self.execution_device = device
current = self._current_model_ref()
if current is not NO_MODEL:
current.to(device)
def device_for(self, model):
if model not in self:
raise KeyError(f"This does not manage this model {type(model).__name__}", model)
return self.execution_device # this implementation only dispatches to one device
def ready(self):
pass # always ready to load on-demand
def __contains__(self, model):
return model in self._hooks
def __repr__(self) -> str:
return (
f"<{self.__class__.__name__} object at {id(self):x}: "
f"current_model={type(self._current_model_ref()).__name__} >"
)
class FullyLoadedModelGroup(ModelGroup):
"""
A group of models without any implicit loading or unloading.
:py:meth:`.ready` loads _all_ the models to the execution device at once.
"""
_models: weakref.WeakSet
def __init__(self, execution_device: torch.device):
super().__init__(execution_device)
self._models = weakref.WeakSet()
def install(self, *models: torch.nn.Module):
for model in models:
self._models.add(model)
model.to(self.execution_device)
def uninstall(self, *models: torch.nn.Module):
for model in models:
self._models.remove(model)
def uninstall_all(self):
self.uninstall(*self._models)
def load(self, model):
model.to(self.execution_device)
def offload_current(self):
for model in self._models:
model.to(OFFLOAD_DEVICE)
def ready(self):
for model in self._models:
self.load(model)
def set_device(self, device: torch.device):
self.execution_device = device
for model in self._models:
if model.device != OFFLOAD_DEVICE:
model.to(device)
def device_for(self, model):
if model not in self:
raise KeyError("This does not manage this model f{type(model).__name__}", model)
return self.execution_device # this implementation only dispatches to one device
def __contains__(self, model):
return model in self._models

View File

@ -1,6 +1,3 @@
"""
Initialization file for invokeai.frontend.config
"""
from .invokeai_configure import main as invokeai_configure
from .invokeai_update import main as invokeai_update
from .model_install import main as invokeai_model_install

View File

@ -0,0 +1,795 @@
# Copyright (c) 2023 - The InvokeAI Team
# Primary Author: David Lovell (github @f412design, discord @techjedi)
# co-author, minor tweaks - Lincoln Stein
# pylint: disable=line-too-long
# pylint: disable=broad-exception-caught
"""Script to import images into the new database system for 3.0.0"""
import os
import datetime
import shutil
import locale
import sqlite3
import json
import glob
import re
import uuid
import yaml
import PIL
import PIL.ImageOps
import PIL.PngImagePlugin
from pathlib import Path
from prompt_toolkit import prompt
from prompt_toolkit.shortcuts import message_dialog
from prompt_toolkit.completion import PathCompleter
from prompt_toolkit.key_binding import KeyBindings
from invokeai.app.services.config import InvokeAIAppConfig
app_config = InvokeAIAppConfig.get_config()
bindings = KeyBindings()
@bindings.add("c-c")
def _(event):
raise KeyboardInterrupt
# release notes
# "Use All" with size dimensions not selectable in the UI will not load dimensions
class Config:
"""Configuration loader."""
def __init__(self):
pass
TIMESTAMP_STRING = datetime.datetime.utcnow().strftime("%Y%m%dT%H%M%SZ")
INVOKE_DIRNAME = "invokeai"
YAML_FILENAME = "invokeai.yaml"
DATABASE_FILENAME = "invokeai.db"
database_path = None
database_backup_dir = None
outputs_path = None
thumbnail_path = None
def find_and_load(self):
"""find the yaml config file and load"""
root = app_config.root_path
if not self.confirm_and_load(os.path.abspath(root)):
print("\r\nSpecify custom database and outputs paths:")
self.confirm_and_load_from_user()
self.database_backup_dir = os.path.join(os.path.dirname(self.database_path), "backup")
self.thumbnail_path = os.path.join(self.outputs_path, "thumbnails")
def confirm_and_load(self, invoke_root):
"""Validates a yaml path exists, confirms the user wants to use it and loads config."""
yaml_path = os.path.join(invoke_root, self.YAML_FILENAME)
if os.path.exists(yaml_path):
db_dir, outdir = self.load_paths_from_yaml(yaml_path)
if os.path.isabs(db_dir):
database_path = os.path.join(db_dir, self.DATABASE_FILENAME)
else:
database_path = os.path.join(invoke_root, db_dir, self.DATABASE_FILENAME)
if os.path.isabs(outdir):
outputs_path = os.path.join(outdir, "images")
else:
outputs_path = os.path.join(invoke_root, outdir, "images")
db_exists = os.path.exists(database_path)
outdir_exists = os.path.exists(outputs_path)
text = f"Found {self.YAML_FILENAME} file at {yaml_path}:"
text += f"\n Database : {database_path}"
text += f"\n Outputs : {outputs_path}"
text += "\n\nUse these paths for import (yes) or choose different ones (no) [Yn]: "
if db_exists and outdir_exists:
if (prompt(text).strip() or "Y").upper().startswith("Y"):
self.database_path = database_path
self.outputs_path = outputs_path
return True
else:
return False
else:
print(" Invalid: One or more paths in this config did not exist and cannot be used.")
else:
message_dialog(
title="Path not found",
text=f"Auto-discovery of configuration failed! Could not find ({yaml_path}), Custom paths can be specified.",
).run()
return False
def confirm_and_load_from_user(self):
default = ""
while True:
database_path = os.path.expanduser(
prompt(
"Database: Specify absolute path to the database to import into: ",
completer=PathCompleter(
expanduser=True, file_filter=lambda x: Path(x).is_dir() or x.endswith((".db"))
),
default=default,
)
)
if database_path.endswith(".db") and os.path.isabs(database_path) and os.path.exists(database_path):
break
default = database_path + "/" if Path(database_path).is_dir() else database_path
default = ""
while True:
outputs_path = os.path.expanduser(
prompt(
"Outputs: Specify absolute path to outputs/images directory to import into: ",
completer=PathCompleter(expanduser=True, only_directories=True),
default=default,
)
)
if outputs_path.endswith("images") and os.path.isabs(outputs_path) and os.path.exists(outputs_path):
break
default = outputs_path + "/" if Path(outputs_path).is_dir() else outputs_path
self.database_path = database_path
self.outputs_path = outputs_path
return
def load_paths_from_yaml(self, yaml_path):
"""Load an Invoke AI yaml file and get the database and outputs paths."""
try:
with open(yaml_path, "rt", encoding=locale.getpreferredencoding()) as file:
yamlinfo = yaml.safe_load(file)
db_dir = yamlinfo.get("InvokeAI", {}).get("Paths", {}).get("db_dir", None)
outdir = yamlinfo.get("InvokeAI", {}).get("Paths", {}).get("outdir", None)
return db_dir, outdir
except Exception:
print(f"Failed to load paths from yaml file! {yaml_path}!")
return None, None
class ImportStats:
"""DTO for tracking work progress."""
def __init__(self):
pass
time_start = datetime.datetime.utcnow()
count_source_files = 0
count_skipped_file_exists = 0
count_skipped_db_exists = 0
count_imported = 0
count_imported_by_version = {}
count_file_errors = 0
@staticmethod
def get_elapsed_time_string():
"""Get a friendly time string for the time elapsed since processing start."""
time_now = datetime.datetime.utcnow()
total_seconds = (time_now - ImportStats.time_start).total_seconds()
hours = int((total_seconds) / 3600)
minutes = int(((total_seconds) % 3600) / 60)
seconds = total_seconds % 60
out_str = f"{hours} hour(s) -" if hours > 0 else ""
out_str += f"{minutes} minute(s) -" if minutes > 0 else ""
out_str += f"{seconds:.2f} second(s)"
return out_str
class InvokeAIMetadata:
"""DTO for core Invoke AI generation properties parsed from metadata."""
def __init__(self):
pass
def __str__(self):
formatted_str = f"{self.generation_mode}~{self.steps}~{self.cfg_scale}~{self.model_name}~{self.scheduler}~{self.seed}~{self.width}~{self.height}~{self.rand_device}~{self.strength}~{self.init_image}"
formatted_str += f"\r\npositive_prompt: {self.positive_prompt}"
formatted_str += f"\r\nnegative_prompt: {self.negative_prompt}"
return formatted_str
generation_mode = None
steps = None
cfg_scale = None
model_name = None
scheduler = None
seed = None
width = None
height = None
rand_device = None
strength = None
init_image = None
positive_prompt = None
negative_prompt = None
imported_app_version = None
def to_json(self):
"""Convert the active instance to json format."""
prop_dict = {}
prop_dict["generation_mode"] = self.generation_mode
# dont render prompt nodes if neither are set to avoid the ui thinking it can set them
# if at least one exists, render them both, but use empty string instead of None if one of them is empty
# this allows the field that is empty to actually be cleared byt he UI instead of leaving the previous value
if self.positive_prompt or self.negative_prompt:
prop_dict["positive_prompt"] = "" if self.positive_prompt is None else self.positive_prompt
prop_dict["negative_prompt"] = "" if self.negative_prompt is None else self.negative_prompt
prop_dict["width"] = self.width
prop_dict["height"] = self.height
# only render seed if it has a value to avoid ui thinking it can set this and then error
if self.seed:
prop_dict["seed"] = self.seed
prop_dict["rand_device"] = self.rand_device
prop_dict["cfg_scale"] = self.cfg_scale
prop_dict["steps"] = self.steps
prop_dict["scheduler"] = self.scheduler
prop_dict["clip_skip"] = 0
prop_dict["model"] = {}
prop_dict["model"]["model_name"] = self.model_name
prop_dict["model"]["base_model"] = None
prop_dict["controlnets"] = []
prop_dict["loras"] = []
prop_dict["vae"] = None
prop_dict["strength"] = self.strength
prop_dict["init_image"] = self.init_image
prop_dict["positive_style_prompt"] = None
prop_dict["negative_style_prompt"] = None
prop_dict["refiner_model"] = None
prop_dict["refiner_cfg_scale"] = None
prop_dict["refiner_steps"] = None
prop_dict["refiner_scheduler"] = None
prop_dict["refiner_aesthetic_store"] = None
prop_dict["refiner_start"] = None
prop_dict["imported_app_version"] = self.imported_app_version
return json.dumps(prop_dict)
class InvokeAIMetadataParser:
"""Parses strings with json data to find Invoke AI core metadata properties."""
def __init__(self):
pass
def parse_meta_tag_dream(self, dream_string):
"""Take as input an png metadata json node for the 'dream' field variant from prior to 1.15"""
props = InvokeAIMetadata()
props.imported_app_version = "pre1.15"
seed_match = re.search("-S\\s*(\\d+)", dream_string)
if seed_match is not None:
try:
props.seed = int(seed_match[1])
except ValueError:
props.seed = None
raw_prompt = re.sub("(-S\\s*\\d+)", "", dream_string)
else:
raw_prompt = dream_string
pos_prompt, neg_prompt = self.split_prompt(raw_prompt)
props.positive_prompt = pos_prompt
props.negative_prompt = neg_prompt
return props
def parse_meta_tag_sd_metadata(self, tag_value):
"""Take as input an png metadata json node for the 'sd-metadata' field variant from 1.15 through 2.3.5 post 2"""
props = InvokeAIMetadata()
props.imported_app_version = tag_value.get("app_version")
props.model_name = tag_value.get("model_weights")
img_node = tag_value.get("image")
if img_node is not None:
props.generation_mode = img_node.get("type")
props.width = img_node.get("width")
props.height = img_node.get("height")
props.seed = img_node.get("seed")
props.rand_device = "cuda" # hardcoded since all generations pre 3.0 used cuda random noise instead of cpu
props.cfg_scale = img_node.get("cfg_scale")
props.steps = img_node.get("steps")
props.scheduler = self.map_scheduler(img_node.get("sampler"))
props.strength = img_node.get("strength")
if props.strength is None:
props.strength = img_node.get("strength_steps") # try second name for this property
props.init_image = img_node.get("init_image_path")
if props.init_image is None: # try second name for this property
props.init_image = img_node.get("init_img")
# remove the path info from init_image so if we move the init image, it will be correctly relative in the new location
if props.init_image is not None:
props.init_image = os.path.basename(props.init_image)
raw_prompt = img_node.get("prompt")
if isinstance(raw_prompt, list):
raw_prompt = raw_prompt[0].get("prompt")
props.positive_prompt, props.negative_prompt = self.split_prompt(raw_prompt)
return props
def parse_meta_tag_invokeai(self, tag_value):
"""Take as input an png metadata json node for the 'invokeai' field variant from 3.0.0 beta 1 through 5"""
props = InvokeAIMetadata()
props.imported_app_version = "3.0.0 or later"
props.generation_mode = tag_value.get("type")
if props.generation_mode is not None:
props.generation_mode = props.generation_mode.replace("t2l", "txt2img").replace("l2l", "img2img")
props.width = tag_value.get("width")
props.height = tag_value.get("height")
props.seed = tag_value.get("seed")
props.cfg_scale = tag_value.get("cfg_scale")
props.steps = tag_value.get("steps")
props.scheduler = tag_value.get("scheduler")
props.strength = tag_value.get("strength")
props.positive_prompt = tag_value.get("positive_conditioning")
props.negative_prompt = tag_value.get("negative_conditioning")
return props
def map_scheduler(self, old_scheduler):
"""Convert the legacy sampler names to matching 3.0 schedulers"""
if old_scheduler is None:
return None
match (old_scheduler):
case "ddim":
return "ddim"
case "plms":
return "pnmd"
case "k_lms":
return "lms"
case "k_dpm_2":
return "kdpm_2"
case "k_dpm_2_a":
return "kdpm_2_a"
case "dpmpp_2":
return "dpmpp_2s"
case "k_dpmpp_2":
return "dpmpp_2m"
case "k_dpmpp_2_a":
return None # invalid, in 2.3.x, selecting this sample would just fallback to last run or plms if new session
case "k_euler":
return "euler"
case "k_euler_a":
return "euler_a"
case "k_heun":
return "heun"
return None
def split_prompt(self, raw_prompt: str):
"""Split the unified prompt strings by extracting all negative prompt blocks out into the negative prompt."""
if raw_prompt is None:
return "", ""
raw_prompt_search = raw_prompt.replace("\r", "").replace("\n", "")
matches = re.findall(r"\[(.+?)\]", raw_prompt_search)
if len(matches) > 0:
negative_prompt = ""
if len(matches) == 1:
negative_prompt = matches[0].strip().strip(",")
else:
for match in matches:
negative_prompt += f"({match.strip().strip(',')})"
positive_prompt = re.sub(r"(\[.+?\])", "", raw_prompt_search).strip()
else:
positive_prompt = raw_prompt_search.strip()
negative_prompt = ""
return positive_prompt, negative_prompt
class DatabaseMapper:
"""Class to abstract database functionality."""
def __init__(self, database_path, database_backup_dir):
self.database_path = database_path
self.database_backup_dir = database_backup_dir
self.connection = None
self.cursor = None
def connect(self):
"""Open connection to the database."""
self.connection = sqlite3.connect(self.database_path)
self.cursor = self.connection.cursor()
def get_board_names(self):
"""Get a list of the current board names from the database."""
sql_get_board_name = "SELECT board_name FROM boards"
self.cursor.execute(sql_get_board_name)
rows = self.cursor.fetchall()
return [row[0] for row in rows]
def does_image_exist(self, image_name):
"""Check database if a image name already exists and return a boolean."""
sql_get_image_by_name = f"SELECT image_name FROM images WHERE image_name='{image_name}'"
self.cursor.execute(sql_get_image_by_name)
rows = self.cursor.fetchall()
return True if len(rows) > 0 else False
def add_new_image_to_database(self, filename, width, height, metadata, modified_date_string):
"""Add an image to the database."""
sql_add_image = f"""INSERT INTO images (image_name, image_origin, image_category, width, height, session_id, node_id, metadata, is_intermediate, created_at, updated_at)
VALUES ('{filename}', 'internal', 'general', {width}, {height}, null, null, '{metadata}', 0, '{modified_date_string}', '{modified_date_string}')"""
self.cursor.execute(sql_add_image)
self.connection.commit()
def get_board_id_with_create(self, board_name):
"""Get the board id for supplied name, and create the board if one does not exist."""
sql_find_board = f"SELECT board_id FROM boards WHERE board_name='{board_name}' COLLATE NOCASE"
self.cursor.execute(sql_find_board)
rows = self.cursor.fetchall()
if len(rows) > 0:
return rows[0][0]
else:
board_date_string = datetime.datetime.utcnow().date().isoformat()
new_board_id = str(uuid.uuid4())
sql_insert_board = f"INSERT INTO boards (board_id, board_name, created_at, updated_at) VALUES ('{new_board_id}', '{board_name}', '{board_date_string}', '{board_date_string}')"
self.cursor.execute(sql_insert_board)
self.connection.commit()
return new_board_id
def add_image_to_board(self, filename, board_id):
"""Add an image mapping to a board."""
add_datetime_str = datetime.datetime.utcnow().isoformat()
sql_add_image_to_board = f"""INSERT INTO board_images (board_id, image_name, created_at, updated_at)
VALUES ('{board_id}', '{filename}', '{add_datetime_str}', '{add_datetime_str}')"""
self.cursor.execute(sql_add_image_to_board)
self.connection.commit()
def disconnect(self):
"""Disconnect from the db, cleaning up connections and cursors."""
if self.cursor is not None:
self.cursor.close()
if self.connection is not None:
self.connection.close()
def backup(self, timestamp_string):
"""Take a backup of the database."""
if not os.path.exists(self.database_backup_dir):
print(f"Database backup directory {self.database_backup_dir} does not exist -> creating...", end="")
os.makedirs(self.database_backup_dir)
print("Done!")
database_backup_path = os.path.join(self.database_backup_dir, f"backup-{timestamp_string}-invokeai.db")
print(f"Making DB Backup at {database_backup_path}...", end="")
shutil.copy2(self.database_path, database_backup_path)
print("Done!")
class MediaImportProcessor:
"""Containing class for script functionality."""
def __init__(self):
pass
board_name_id_map = {}
def get_import_file_list(self):
"""Ask the user for the import folder and scan for the list of files to return."""
while True:
default = ""
while True:
import_dir = os.path.expanduser(
prompt(
"Inputs: Specify absolute path containing InvokeAI .png images to import: ",
completer=PathCompleter(expanduser=True, only_directories=True),
default=default,
)
)
if len(import_dir) > 0 and Path(import_dir).is_dir():
break
default = import_dir
recurse_directories = (
(prompt("Include files from subfolders recursively [yN]? ").strip() or "N").upper().startswith("N")
)
if recurse_directories:
is_recurse = False
matching_file_list = glob.glob(import_dir + "/*.png", recursive=False)
else:
is_recurse = True
matching_file_list = glob.glob(import_dir + "/**/*.png", recursive=True)
if len(matching_file_list) > 0:
return import_dir, is_recurse, matching_file_list
else:
print(f"The specific path {import_dir} exists, but does not contain .png files!")
def get_file_details(self, filepath):
"""Retrieve the embedded metedata fields and dimensions from an image file."""
with PIL.Image.open(filepath) as img:
img.load()
png_width, png_height = img.size
img_info = img.info
return img_info, png_width, png_height
def select_board_option(self, board_names, timestamp_string):
"""Allow the user to choose how a board is selected for imported files."""
while True:
print("\r\nOptions for board selection for imported images:")
print(f"1) Select an existing board name. (found {len(board_names)})")
print("2) Specify a board name to create/add to.")
print("3) Create/add to board named 'IMPORT'.")
print(
f"4) Create/add to board named 'IMPORT' with the current datetime string appended (.e.g IMPORT_{timestamp_string})."
)
print(
"5) Create/add to board named 'IMPORT' with a the original file app_version appended (.e.g IMPORT_2.2.5)."
)
input_option = input("Specify desired board option: ")
match (input_option):
case "1":
if len(board_names) < 1:
print("\r\nThere are no existing board names to choose from. Select another option!")
continue
board_name = self.select_item_from_list(
board_names, "board name", True, "Cancel, go back and choose a different board option."
)
if board_name is not None:
return board_name
case "2":
while True:
board_name = input("Specify new/existing board name: ")
if board_name:
return board_name
case "3":
return "IMPORT"
case "4":
return f"IMPORT_{timestamp_string}"
case "5":
return "IMPORT_APPVERSION"
def select_item_from_list(self, items, entity_name, allow_cancel, cancel_string):
"""A general function to render a list of items to select in the console, prompt the user for a selection and ensure a valid entry is selected."""
print(f"Select a {entity_name.lower()} from the following list:")
index = 1
for item in items:
print(f"{index}) {item}")
index += 1
if allow_cancel:
print(f"{index}) {cancel_string}")
while True:
try:
option_number = int(input("Specify number of selection: "))
except ValueError:
continue
if allow_cancel and option_number == index:
return None
if option_number >= 1 and option_number <= len(items):
return items[option_number - 1]
def import_image(self, filepath: str, board_name_option: str, db_mapper: DatabaseMapper, config: Config):
"""Import a single file by its path"""
parser = InvokeAIMetadataParser()
file_name = os.path.basename(filepath)
file_destination_path = os.path.join(config.outputs_path, file_name)
print("===============================================================================")
print(f"Importing {filepath}")
# check destination to see if the file was previously imported
if os.path.exists(file_destination_path):
print("File already exists in the destination, skipping!")
ImportStats.count_skipped_file_exists += 1
return
# check if file name is already referenced in the database
if db_mapper.does_image_exist(file_name):
print("A reference to a file with this name already exists in the database, skipping!")
ImportStats.count_skipped_db_exists += 1
return
# load image info and dimensions
img_info, png_width, png_height = self.get_file_details(filepath)
# parse metadata
destination_needs_meta_update = True
log_version_note = "(Unknown)"
if "invokeai_metadata" in img_info:
# for the latest, we will just re-emit the same json, no need to parse/modify
converted_field = None
latest_json_string = img_info.get("invokeai_metadata")
log_version_note = "3.0.0+"
destination_needs_meta_update = False
else:
if "sd-metadata" in img_info:
converted_field = parser.parse_meta_tag_sd_metadata(json.loads(img_info.get("sd-metadata")))
elif "invokeai" in img_info:
converted_field = parser.parse_meta_tag_invokeai(json.loads(img_info.get("invokeai")))
elif "dream" in img_info:
converted_field = parser.parse_meta_tag_dream(img_info.get("dream"))
elif "Dream" in img_info:
converted_field = parser.parse_meta_tag_dream(img_info.get("Dream"))
else:
converted_field = InvokeAIMetadata()
destination_needs_meta_update = False
print("File does not have metadata from known Invoke AI versions, add only, no update!")
# use the loaded img dimensions if the metadata didnt have them
if converted_field.width is None:
converted_field.width = png_width
if converted_field.height is None:
converted_field.height = png_height
log_version_note = converted_field.imported_app_version if converted_field else "NoVersion"
log_version_note = log_version_note or "NoVersion"
latest_json_string = converted_field.to_json()
print(f"From Invoke AI Version {log_version_note} with dimensions {png_width} x {png_height}.")
# if metadata needs update, then update metdata and copy in one shot
if destination_needs_meta_update:
print("Updating metadata while copying...", end="")
self.update_file_metadata_while_copying(
filepath, file_destination_path, "invokeai_metadata", latest_json_string
)
print("Done!")
else:
print("No metadata update necessary, copying only...", end="")
shutil.copy2(filepath, file_destination_path)
print("Done!")
# create thumbnail
print("Creating thumbnail...", end="")
thumbnail_path = os.path.join(config.thumbnail_path, os.path.splitext(file_name)[0]) + ".webp"
thumbnail_size = 256, 256
with PIL.Image.open(filepath) as source_image:
source_image.thumbnail(thumbnail_size)
source_image.save(thumbnail_path, "webp")
print("Done!")
# finalize the dynamic board name if there is an APPVERSION token in it.
if converted_field is not None:
board_name = board_name_option.replace("APPVERSION", converted_field.imported_app_version or "NoVersion")
else:
board_name = board_name_option.replace("APPVERSION", "Latest")
# maintain a map of alrady created/looked up ids to avoid DB queries
print("Finding/Creating board...", end="")
if board_name in self.board_name_id_map:
board_id = self.board_name_id_map[board_name]
else:
board_id = db_mapper.get_board_id_with_create(board_name)
self.board_name_id_map[board_name] = board_id
print("Done!")
# add image to db
print("Adding image to database......", end="")
modified_time = datetime.datetime.utcfromtimestamp(os.path.getmtime(filepath))
db_mapper.add_new_image_to_database(file_name, png_width, png_height, latest_json_string, modified_time)
print("Done!")
# add image to board
print("Adding image to board......", end="")
db_mapper.add_image_to_board(file_name, board_id)
print("Done!")
ImportStats.count_imported += 1
if log_version_note in ImportStats.count_imported_by_version:
ImportStats.count_imported_by_version[log_version_note] += 1
else:
ImportStats.count_imported_by_version[log_version_note] = 1
def update_file_metadata_while_copying(self, filepath, file_destination_path, tag_name, tag_value):
"""Perform a metadata update with save to a new destination which accomplishes a copy while updating metadata."""
with PIL.Image.open(filepath) as target_image:
existing_img_info = target_image.info
metadata = PIL.PngImagePlugin.PngInfo()
# re-add any existing invoke ai tags unless they are the one we are trying to add
for key in existing_img_info:
if key != tag_name and key in ("dream", "Dream", "sd-metadata", "invokeai", "invokeai_metadata"):
metadata.add_text(key, existing_img_info[key])
metadata.add_text(tag_name, tag_value)
target_image.save(file_destination_path, pnginfo=metadata)
def process(self):
"""Begin main processing."""
print("===============================================================================")
print("This script will import images generated by earlier versions of")
print("InvokeAI into the currently installed root directory:")
print(f" {app_config.root_path}")
print("If this is not what you want to do, type ctrl-C now to cancel.")
# load config
print("===============================================================================")
print("= Configuration & Settings")
config = Config()
config.find_and_load()
db_mapper = DatabaseMapper(config.database_path, config.database_backup_dir)
db_mapper.connect()
import_dir, is_recurse, import_file_list = self.get_import_file_list()
ImportStats.count_source_files = len(import_file_list)
board_names = db_mapper.get_board_names()
board_name_option = self.select_board_option(board_names, config.TIMESTAMP_STRING)
print("\r\n===============================================================================")
print("= Import Settings Confirmation")
print()
print(f"Database File Path : {config.database_path}")
print(f"Outputs/Images Directory : {config.outputs_path}")
print(f"Import Image Source Directory : {import_dir}")
print(f" Recurse Source SubDirectories : {'Yes' if is_recurse else 'No'}")
print(f"Count of .png file(s) found : {len(import_file_list)}")
print(f"Board name option specified : {board_name_option}")
print(f"Database backup will be taken at : {config.database_backup_dir}")
print("\r\nNotes about the import process:")
print("- Source image files will not be modified, only copied to the outputs directory.")
print("- If the same file name already exists in the destination, the file will be skipped.")
print("- If the same file name already has a record in the database, the file will be skipped.")
print("- Invoke AI metadata tags will be updated/written into the imported copy only.")
print(
"- On the imported copy, only Invoke AI known tags (latest and legacy) will be retained (dream, sd-metadata, invokeai, invokeai_metadata)"
)
print(
"- A property 'imported_app_version' will be added to metadata that can be viewed in the UI's metadata viewer."
)
print(
"- The new 3.x InvokeAI outputs folder structure is flat so recursively found source imges will all be placed into the single outputs/images folder."
)
while True:
should_continue = prompt("\nDo you wish to continue with the import [Yn] ? ").lower() or "y"
if should_continue == "n":
print("\r\nCancelling Import")
return
elif should_continue == "y":
print()
break
db_mapper.backup(config.TIMESTAMP_STRING)
print()
ImportStats.time_start = datetime.datetime.utcnow()
for filepath in import_file_list:
try:
self.import_image(filepath, board_name_option, db_mapper, config)
except sqlite3.Error as sql_ex:
print(f"A database related exception was found processing {filepath}, will continue to next file. ")
print("Exception detail:")
print(sql_ex)
ImportStats.count_file_errors += 1
except Exception as ex:
print(f"Exception processing {filepath}, will continue to next file. ")
print("Exception detail:")
print(ex)
ImportStats.count_file_errors += 1
print("\r\n===============================================================================")
print(f"= Import Complete - Elpased Time: {ImportStats.get_elapsed_time_string()}")
print()
print(f"Source File(s) : {ImportStats.count_source_files}")
print(f"Total Imported : {ImportStats.count_imported}")
print(f"Skipped b/c file already exists on disk : {ImportStats.count_skipped_file_exists}")
print(f"Skipped b/c file already exists in db : {ImportStats.count_skipped_db_exists}")
print(f"Errors during import : {ImportStats.count_file_errors}")
if ImportStats.count_imported > 0:
print("\r\nBreakdown of imported files by version:")
for key, version in ImportStats.count_imported_by_version.items():
print(f" {key:20} : {version}")
def main():
try:
processor = MediaImportProcessor()
processor.process()
except KeyboardInterrupt:
print("\r\n\r\nUser cancelled execution.")
if __name__ == "__main__":
main()

View File

@ -1,4 +1,4 @@
"""
Wrapper for invokeai.backend.configure.invokeai_configure
"""
from ...backend.install.invokeai_configure import main
from ...backend.install.invokeai_configure import main as invokeai_configure

View File

@ -28,7 +28,6 @@ from npyscreen import widget
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.backend.install.model_install_backend import (
ModelInstallList,
InstallSelections,
ModelInstall,
SchedulerPredictionType,
@ -41,12 +40,12 @@ from invokeai.frontend.install.widgets import (
SingleSelectColumns,
TextBox,
BufferBox,
FileBox,
set_min_terminal_size,
select_stable_diffusion_config_file,
CyclingForm,
MIN_COLS,
MIN_LINES,
WindowTooSmallException,
)
from invokeai.app.services.config import InvokeAIAppConfig
@ -156,7 +155,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
BufferBox,
name="Log Messages",
editable=False,
max_height=15,
max_height=6,
)
self.nextrely += 1
@ -693,7 +692,11 @@ def select_and_download_models(opt: Namespace):
# needed to support the probe() method running under a subprocess
torch.multiprocessing.set_start_method("spawn")
set_min_terminal_size(MIN_COLS, MIN_LINES)
if not set_min_terminal_size(MIN_COLS, MIN_LINES):
raise WindowTooSmallException(
"Could not increase terminal size. Try running again with a larger window or smaller font size."
)
installApp = AddModelApplication(opt)
try:
installApp.run()
@ -787,6 +790,8 @@ def main():
curses.echo()
curses.endwin()
logger.info("Goodbye! Come back soon.")
except WindowTooSmallException as e:
logger.error(str(e))
except widget.NotEnoughSpaceForWidget as e:
if str(e).startswith("Height of 1 allocated"):
logger.error("Insufficient vertical space for the interface. Please make your window taller and try again")

View File

@ -21,31 +21,40 @@ MIN_COLS = 130
MIN_LINES = 38
class WindowTooSmallException(Exception):
pass
# -------------------------------------
def set_terminal_size(columns: int, lines: int):
ts = get_terminal_size()
width = max(columns, ts.columns)
height = max(lines, ts.lines)
def set_terminal_size(columns: int, lines: int) -> bool:
OS = platform.uname().system
if OS == "Windows":
pass
# not working reliably - ask user to adjust the window
# _set_terminal_size_powershell(width,height)
elif OS in ["Darwin", "Linux"]:
_set_terminal_size_unix(width, height)
screen_ok = False
while not screen_ok:
ts = get_terminal_size()
width = max(columns, ts.columns)
height = max(lines, ts.lines)
# check whether it worked....
ts = get_terminal_size()
pause = False
if ts.columns < columns:
print("\033[1mThis window is too narrow for the user interface.\033[0m")
pause = True
if ts.lines < lines:
print("\033[1mThis window is too short for the user interface.\033[0m")
pause = True
if pause:
input("Maximize the window then press any key to continue..")
if OS == "Windows":
pass
# not working reliably - ask user to adjust the window
# _set_terminal_size_powershell(width,height)
elif OS in ["Darwin", "Linux"]:
_set_terminal_size_unix(width, height)
# check whether it worked....
ts = get_terminal_size()
if ts.columns < columns or ts.lines < lines:
print(
f"\033[1mThis window is too small for the interface. InvokeAI requires {columns}x{lines} (w x h) characters, but window is {ts.columns}x{ts.lines}\033[0m"
)
resp = input(
"Maximize the window and/or decrease the font size then press any key to continue. Type [Q] to give up.."
)
if resp.upper().startswith("Q"):
break
else:
screen_ok = True
return screen_ok
def _set_terminal_size_powershell(width: int, height: int):
@ -80,14 +89,14 @@ def _set_terminal_size_unix(width: int, height: int):
sys.stdout.flush()
def set_min_terminal_size(min_cols: int, min_lines: int):
def set_min_terminal_size(min_cols: int, min_lines: int) -> bool:
# make sure there's enough room for the ui
term_cols, term_lines = get_terminal_size()
if term_cols >= min_cols and term_lines >= min_lines:
return
return True
cols = max(term_cols, min_cols)
lines = max(term_lines, min_lines)
set_terminal_size(cols, lines)
return set_terminal_size(cols, lines)
class IntSlider(npyscreen.Slider):
@ -164,7 +173,7 @@ class FloatSlider(npyscreen.Slider):
class FloatTitleSlider(npyscreen.TitleText):
_entry_type = FloatSlider
_entry_type = npyscreen.Slider
class SelectColumnBase:

View File

@ -382,7 +382,8 @@ def run_cli(args: Namespace):
def main():
args = _parse_args()
config.parse_args(["--root", str(args.root_dir)])
if args.root_dir:
config.parse_args(["--root", str(args.root_dir)])
try:
if args.front_end:

File diff suppressed because one or more lines are too long

View File

@ -1,4 +1,4 @@
import{B as m,g7 as Je,A as y,a5 as Ka,g8 as Xa,af as va,aj as d,g9 as b,ga as t,gb as Ya,gc as h,gd as ua,ge as Ja,gf as Qa,aL as Za,gg as et,ad as rt,gh as at}from"./index-de589048.js";import{s as fa,n as o,t as tt,o as ha,p as ot,q as ma,v as ga,w as ya,x as it,y as Sa,z as pa,A as xr,B as nt,D as lt,E as st,F as xa,G as $a,H as ka,J as dt,K as _a,L as ct,M as bt,N as vt,O as ut,Q as wa,R as ft,S as ht,T as mt,U as gt,V as yt,W as St,e as pt,X as xt}from"./menu-11348abc.js";var za=String.raw,Ca=za`
import{B as m,g7 as Je,A as y,a5 as Ka,g8 as Xa,af as va,aj as d,g9 as b,ga as t,gb as Ya,gc as h,gd as ua,ge as Ja,gf as Qa,aL as Za,gg as et,ad as rt,gh as at}from"./index-deaa1f26.js";import{s as fa,n as o,t as tt,o as ha,p as ot,q as ma,v as ga,w as ya,x as it,y as Sa,z as pa,A as xr,B as nt,D as lt,E as st,F as xa,G as $a,H as ka,J as dt,K as _a,L as ct,M as bt,N as vt,O as ut,Q as wa,R as ft,S as ht,T as mt,U as gt,V as yt,W as St,e as pt,X as xt}from"./menu-b4489359.js";var za=String.raw,Ca=za`
:root,
:host {
--chakra-vh: 100vh;

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -12,7 +12,7 @@
margin: 0;
}
</style>
<script type="module" crossorigin src="./assets/index-de589048.js"></script>
<script type="module" crossorigin src="./assets/index-deaa1f26.js"></script>
</head>
<body dir="ltr">

View File

@ -20,7 +20,7 @@ export const addStagingAreaImageSavedListener = () => {
// we may need to add it to the autoadd board
const { autoAddBoardId } = getState().gallery;
if (autoAddBoardId) {
if (autoAddBoardId && autoAddBoardId !== 'none') {
await dispatch(
imagesApi.endpoints.addImageToBoard.initiate({
imageDTO: newImageDTO,

View File

@ -1,55 +1,58 @@
import { modelChanged } from 'features/parameters/store/generationSlice';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { forEach } from 'lodash-es';
import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
import {
MainModelConfigEntity,
modelsApi,
} from 'services/api/endpoints/models';
import { mainModelsAdapter, modelsApi } from 'services/api/endpoints/models';
import { startAppListening } from '..';
export const addTabChangedListener = () => {
startAppListening({
actionCreator: setActiveTab,
effect: (action, { getState, dispatch }) => {
effect: async (action, { getState, dispatch }) => {
const activeTabName = action.payload;
if (activeTabName === 'unifiedCanvas') {
// grab the models from RTK Query cache
const { data } = modelsApi.endpoints.getMainModels.select(
NON_REFINER_BASE_MODELS
)(getState());
const currentBaseModel = getState().generation.model?.base_model;
if (!data) {
// no models yet, so we can't do anything
dispatch(modelChanged(null));
if (currentBaseModel && ['sd-1', 'sd-2'].includes(currentBaseModel)) {
// if we're already on a valid model, no change needed
return;
}
// need to filter out all the invalid canvas models (currently, this is just sdxl)
const validCanvasModels: MainModelConfigEntity[] = [];
try {
// just grab fresh models
const modelsRequest = dispatch(
modelsApi.endpoints.getMainModels.initiate(NON_REFINER_BASE_MODELS)
);
const models = await modelsRequest.unwrap();
// cancel this cache subscription
modelsRequest.unsubscribe();
forEach(data.entities, (entity) => {
if (!entity) {
if (!models.ids.length) {
// no valid canvas models
dispatch(modelChanged(null));
return;
}
if (['sd-1', 'sd-2'].includes(entity.base_model)) {
validCanvasModels.push(entity);
// need to filter out all the invalid canvas models (currently sdxl & refiner)
const validCanvasModels = mainModelsAdapter
.getSelectors()
.selectAll(models)
.filter((model) => ['sd-1', 'sd-2'].includes(model.base_model));
const firstValidCanvasModel = validCanvasModels[0];
if (!firstValidCanvasModel) {
// no valid canvas models
dispatch(modelChanged(null));
return;
}
});
// this could still be undefined even tho TS doesn't say so
const firstValidCanvasModel = validCanvasModels[0];
const { base_model, model_name, model_type } = firstValidCanvasModel;
if (!firstValidCanvasModel) {
// uh oh, we have no models that are valid for canvas
dispatch(modelChanged({ base_model, model_name, model_type }));
} catch {
// network request failed, bail
dispatch(modelChanged(null));
return;
}
// only store the model name and base model in redux
const { base_model, model_name, model_type } = firstValidCanvasModel;
dispatch(modelChanged({ base_model, model_name, model_type }));
}
},
});

View File

@ -96,7 +96,8 @@ export type AppFeature =
| 'consoleLogging'
| 'dynamicPrompting'
| 'batches'
| 'syncModels';
| 'syncModels'
| 'multiselect';
/**
* A disable-able Stable Diffusion feature

View File

@ -9,6 +9,7 @@ import { useListImagesQuery } from 'services/api/endpoints/images';
import { ImageDTO } from 'services/api/types';
import { selectionChanged } from '../store/gallerySlice';
import { imagesSelectors } from 'services/api/util';
import { useFeatureStatus } from '../../system/hooks/useFeatureStatus';
const selector = createSelector(
[stateSelector, selectListImagesBaseQueryArgs],
@ -33,11 +34,18 @@ export const useMultiselect = (imageDTO?: ImageDTO) => {
}),
});
const isMultiSelectEnabled = useFeatureStatus('multiselect').isFeatureEnabled;
const handleClick = useCallback(
(e: MouseEvent<HTMLDivElement>) => {
if (!imageDTO) {
return;
}
if (!isMultiSelectEnabled) {
dispatch(selectionChanged([imageDTO]));
return;
}
if (e.shiftKey) {
const rangeEndImageName = imageDTO.image_name;
const lastSelectedImage = selection[selection.length - 1]?.image_name;
@ -71,7 +79,7 @@ export const useMultiselect = (imageDTO?: ImageDTO) => {
dispatch(selectionChanged([imageDTO]));
}
},
[dispatch, imageDTO, imageDTOs, selection]
[dispatch, imageDTO, imageDTOs, selection, isMultiSelectEnabled]
);
const isSelected = useMemo(

View File

@ -31,7 +31,7 @@ const ParamLoraCollapse = () => {
}
return (
<IAICollapse label={'LoRA'} activeLabel={activeLabel}>
<IAICollapse label="LoRA" activeLabel={activeLabel}>
<Flex sx={{ flexDir: 'column', gap: 2 }}>
<ParamLoRASelect />
<ParamLoraList />

View File

@ -1,3 +1,4 @@
import { Divider } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
@ -8,20 +9,21 @@ import ParamLora from './ParamLora';
const selector = createSelector(
stateSelector,
({ lora }) => {
const { loras } = lora;
return { loras };
return { lorasArray: map(lora.loras) };
},
defaultSelectorOptions
);
const ParamLoraList = () => {
const { loras } = useAppSelector(selector);
const { lorasArray } = useAppSelector(selector);
return (
<>
{map(loras, (lora) => (
<ParamLora key={lora.model_name} lora={lora} />
{lorasArray.map((lora, i) => (
<>
{i > 0 && <Divider key={`${lora.model_name}-divider`} pt={1} />}
<ParamLora key={lora.model_name} lora={lora} />
</>
))}
</>
);

View File

@ -54,12 +54,7 @@ const ParamLoRASelect = () => {
});
});
// Sort Alphabetically
data.sort((a, b) =>
a.label && b.label ? (a.label?.localeCompare(b.label) ? 1 : -1) : -1
);
return data.sort((a, b) => (a.disabled && !b.disabled ? -1 : 1));
return data.sort((a, b) => (a.disabled && !b.disabled ? 1 : -1));
}, [loras, loraModels, currentMainModel?.base_model]);
const handleChange = useCallback(

View File

@ -9,7 +9,6 @@ import {
CLIP_SKIP,
LORA_LOADER,
MAIN_MODEL_LOADER,
ONNX_MODEL_LOADER,
METADATA_ACCUMULATOR,
NEGATIVE_CONDITIONING,
POSITIVE_CONDITIONING,
@ -36,15 +35,11 @@ export const addLoRAsToGraph = (
| undefined;
if (loraCount > 0) {
// Remove MAIN_MODEL_LOADER unet connection to feed it to LoRAs
// Remove modelLoaderNodeId unet connection to feed it to LoRAs
graph.edges = graph.edges.filter(
(e) =>
!(
e.source.node_id === MAIN_MODEL_LOADER &&
['unet'].includes(e.source.field)
) &&
!(
e.source.node_id === ONNX_MODEL_LOADER &&
e.source.node_id === modelLoaderNodeId &&
['unet'].includes(e.source.field)
)
);

View File

@ -0,0 +1,212 @@
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import { forEach, size } from 'lodash-es';
import {
MetadataAccumulatorInvocation,
SDXLLoraLoaderInvocation,
} from 'services/api/types';
import {
LORA_LOADER,
METADATA_ACCUMULATOR,
NEGATIVE_CONDITIONING,
POSITIVE_CONDITIONING,
SDXL_MODEL_LOADER,
} from './constants';
export const addSDXLLoRAsToGraph = (
state: RootState,
graph: NonNullableGraph,
baseNodeId: string,
modelLoaderNodeId: string = SDXL_MODEL_LOADER
): void => {
/**
* LoRA nodes get the UNet and CLIP models from the main model loader and apply the LoRA to them.
* They then output the UNet and CLIP models references on to either the next LoRA in the chain,
* or to the inference/conditioning nodes.
*
* So we need to inject a LoRA chain into the graph.
*/
const { loras } = state.lora;
const loraCount = size(loras);
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
| MetadataAccumulatorInvocation
| undefined;
if (loraCount > 0) {
// Remove modelLoaderNodeId unet/clip/clip2 connections to feed it to LoRAs
graph.edges = graph.edges.filter(
(e) =>
!(
e.source.node_id === modelLoaderNodeId &&
['unet'].includes(e.source.field)
) &&
!(
e.source.node_id === modelLoaderNodeId &&
['clip'].includes(e.source.field)
) &&
!(
e.source.node_id === modelLoaderNodeId &&
['clip2'].includes(e.source.field)
)
);
}
// we need to remember the last lora so we can chain from it
let lastLoraNodeId = '';
let currentLoraIndex = 0;
forEach(loras, (lora) => {
const { model_name, base_model, weight } = lora;
const currentLoraNodeId = `${LORA_LOADER}_${model_name.replace('.', '_')}`;
const loraLoaderNode: SDXLLoraLoaderInvocation = {
type: 'sdxl_lora_loader',
id: currentLoraNodeId,
is_intermediate: true,
lora: { model_name, base_model },
weight,
};
// add the lora to the metadata accumulator
if (metadataAccumulator) {
metadataAccumulator.loras.push({
lora: { model_name, base_model },
weight,
});
}
// add to graph
graph.nodes[currentLoraNodeId] = loraLoaderNode;
if (currentLoraIndex === 0) {
// first lora = start the lora chain, attach directly to model loader
graph.edges.push({
source: {
node_id: modelLoaderNodeId,
field: 'unet',
},
destination: {
node_id: currentLoraNodeId,
field: 'unet',
},
});
graph.edges.push({
source: {
node_id: modelLoaderNodeId,
field: 'clip',
},
destination: {
node_id: currentLoraNodeId,
field: 'clip',
},
});
graph.edges.push({
source: {
node_id: modelLoaderNodeId,
field: 'clip2',
},
destination: {
node_id: currentLoraNodeId,
field: 'clip2',
},
});
} else {
// we are in the middle of the lora chain, instead connect to the previous lora
graph.edges.push({
source: {
node_id: lastLoraNodeId,
field: 'unet',
},
destination: {
node_id: currentLoraNodeId,
field: 'unet',
},
});
graph.edges.push({
source: {
node_id: lastLoraNodeId,
field: 'clip',
},
destination: {
node_id: currentLoraNodeId,
field: 'clip',
},
});
graph.edges.push({
source: {
node_id: lastLoraNodeId,
field: 'clip2',
},
destination: {
node_id: currentLoraNodeId,
field: 'clip2',
},
});
}
if (currentLoraIndex === loraCount - 1) {
// final lora, end the lora chain - we need to connect up to inference and conditioning nodes
graph.edges.push({
source: {
node_id: currentLoraNodeId,
field: 'unet',
},
destination: {
node_id: baseNodeId,
field: 'unet',
},
});
graph.edges.push({
source: {
node_id: currentLoraNodeId,
field: 'clip',
},
destination: {
node_id: POSITIVE_CONDITIONING,
field: 'clip',
},
});
graph.edges.push({
source: {
node_id: currentLoraNodeId,
field: 'clip',
},
destination: {
node_id: NEGATIVE_CONDITIONING,
field: 'clip',
},
});
graph.edges.push({
source: {
node_id: currentLoraNodeId,
field: 'clip2',
},
destination: {
node_id: POSITIVE_CONDITIONING,
field: 'clip2',
},
});
graph.edges.push({
source: {
node_id: currentLoraNodeId,
field: 'clip2',
},
destination: {
node_id: NEGATIVE_CONDITIONING,
field: 'clip2',
},
});
}
// increment the lora for the next one in the chain
lastLoraNodeId = currentLoraNodeId;
currentLoraIndex += 1;
});
};

View File

@ -22,6 +22,7 @@ import {
SDXL_LATENTS_TO_LATENTS,
SDXL_MODEL_LOADER,
} from './constants';
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
/**
* Builds the Image to Image tab graph.
@ -364,6 +365,8 @@ export const buildLinearSDXLImageToImageGraph = (
},
});
addSDXLLoRAsToGraph(state, graph, SDXL_LATENTS_TO_LATENTS, SDXL_MODEL_LOADER);
// Add Refiner if enabled
if (shouldUseSDXLRefiner) {
addSDXLRefinerToGraph(state, graph, SDXL_LATENTS_TO_LATENTS);

View File

@ -4,6 +4,7 @@ import { NonNullableGraph } from 'features/nodes/types/types';
import { initialGenerationState } from 'features/parameters/store/generationSlice';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
@ -246,6 +247,8 @@ export const buildLinearSDXLTextToImageGraph = (
},
});
addSDXLLoRAsToGraph(state, graph, SDXL_TEXT_TO_LATENTS, SDXL_MODEL_LOADER);
// Add Refiner if enabled
if (shouldUseSDXLRefiner) {
addSDXLRefinerToGraph(state, graph, SDXL_TEXT_TO_LATENTS);

View File

@ -4,6 +4,7 @@ import ProcessButtons from 'features/parameters/components/ProcessButtons/Proces
import ParamSDXLPromptArea from './ParamSDXLPromptArea';
import ParamSDXLRefinerCollapse from './ParamSDXLRefinerCollapse';
import SDXLImageToImageTabCoreParameters from './SDXLImageToImageTabCoreParameters';
import ParamLoraCollapse from 'features/lora/components/ParamLoraCollapse';
const SDXLImageToImageTabParameters = () => {
return (
@ -12,6 +13,7 @@ const SDXLImageToImageTabParameters = () => {
<ProcessButtons />
<SDXLImageToImageTabCoreParameters />
<ParamSDXLRefinerCollapse />
<ParamLoraCollapse />
<ParamDynamicPromptsCollapse />
<ParamNoiseCollapse />
</>

View File

@ -4,6 +4,7 @@ import ProcessButtons from 'features/parameters/components/ProcessButtons/Proces
import TextToImageTabCoreParameters from 'features/ui/components/tabs/TextToImage/TextToImageTabCoreParameters';
import ParamSDXLPromptArea from './ParamSDXLPromptArea';
import ParamSDXLRefinerCollapse from './ParamSDXLRefinerCollapse';
import ParamLoraCollapse from 'features/lora/components/ParamLoraCollapse';
const SDXLTextToImageTabParameters = () => {
return (
@ -12,6 +13,7 @@ const SDXLTextToImageTabParameters = () => {
<ProcessButtons />
<TextToImageTabCoreParameters />
<ParamSDXLRefinerCollapse />
<ParamLoraCollapse />
<ParamDynamicPromptsCollapse />
<ParamNoiseCollapse />
</>

View File

@ -365,12 +365,19 @@ export const systemSlice = createSlice({
state.statusTranslationKey = 'common.statusConnected';
state.progressImage = null;
let errorDescription = undefined;
if (action.payload?.status === 422) {
errorDescription = 'Validation Error';
} else if (action.payload?.error) {
errorDescription = action.payload?.error as string;
}
state.toastQueue.push(
makeToast({
title: t('toast.serverError'),
status: 'error',
description:
action.payload?.status === 422 ? 'Validation Error' : undefined,
description: errorDescription,
})
);
});

View File

@ -4,6 +4,7 @@ import {
ASSETS_CATEGORIES,
BoardId,
IMAGE_CATEGORIES,
IMAGE_LIMIT,
} from 'features/gallery/store/types';
import { keyBy } from 'lodash';
import { ApiFullTagDescription, LIST_TAG, api } from '..';
@ -167,7 +168,14 @@ export const imagesApi = api.injectEndpoints({
},
};
},
invalidatesTags: (result, error, imageDTOs) => [],
invalidatesTags: (result, error, { imageDTOs }) => {
// for now, assume bulk delete is all on one board
const boardId = imageDTOs[0]?.board_id;
return [
{ type: 'BoardImagesTotal', id: boardId ?? 'none' },
{ type: 'BoardAssetsTotal', id: boardId ?? 'none' },
];
},
async onQueryStarted({ imageDTOs }, { dispatch, queryFulfilled }) {
/**
* Cache changes for `deleteImages`:
@ -889,18 +897,25 @@ export const imagesApi = api.injectEndpoints({
board_id,
},
}),
invalidatesTags: (result, error, { board_id }) => [
// update the destination board
{ type: 'Board', id: board_id ?? 'none' },
// update old board totals
{ type: 'BoardImagesTotal', id: board_id ?? 'none' },
{ type: 'BoardAssetsTotal', id: board_id ?? 'none' },
// update the no_board totals
{ type: 'BoardImagesTotal', id: 'none' },
{ type: 'BoardAssetsTotal', id: 'none' },
],
invalidatesTags: (result, error, { imageDTOs, board_id }) => {
//assume all images are being moved from one board for now
const oldBoardId = imageDTOs[0]?.board_id;
return [
// update the destination board
{ type: 'Board', id: board_id ?? 'none' },
// update new board totals
{ type: 'BoardImagesTotal', id: board_id ?? 'none' },
{ type: 'BoardAssetsTotal', id: board_id ?? 'none' },
// update old board totals
{ type: 'BoardImagesTotal', id: oldBoardId ?? 'none' },
{ type: 'BoardAssetsTotal', id: oldBoardId ?? 'none' },
// update the no_board totals
{ type: 'BoardImagesTotal', id: 'none' },
{ type: 'BoardAssetsTotal', id: 'none' },
];
},
async onQueryStarted(
{ board_id, imageDTOs },
{ board_id: new_board_id, imageDTOs },
{ dispatch, queryFulfilled, getState }
) {
try {
@ -920,7 +935,7 @@ export const imagesApi = api.injectEndpoints({
'getImageDTO',
image_name,
(draft) => {
draft.board_id = board_id;
draft.board_id = new_board_id;
}
)
);
@ -946,7 +961,7 @@ export const imagesApi = api.injectEndpoints({
);
const queryArgs = {
board_id,
board_id: new_board_id,
categories,
};
@ -954,23 +969,24 @@ export const imagesApi = api.injectEndpoints({
queryArgs
)(getState());
const { data: total } = IMAGE_CATEGORIES.includes(
const { data: previousTotal } = IMAGE_CATEGORIES.includes(
imageDTO.image_category
)
? boardsApi.endpoints.getBoardImagesTotal.select(
imageDTO.board_id ?? 'none'
new_board_id ?? 'none'
)(getState())
: boardsApi.endpoints.getBoardAssetsTotal.select(
imageDTO.board_id ?? 'none'
new_board_id ?? 'none'
)(getState());
const isCacheFullyPopulated =
currentCache.data && currentCache.data.ids.length >= (total ?? 0);
currentCache.data &&
currentCache.data.ids.length >= (previousTotal ?? 0);
const isInDateRange = getIsImageInDateRange(
currentCache.data,
imageDTO
);
const isInDateRange =
(previousTotal || 0) >= IMAGE_LIMIT
? getIsImageInDateRange(currentCache.data, imageDTO)
: true;
if (isCacheFullyPopulated || isInDateRange) {
// *upsert* to $cache
@ -981,7 +997,7 @@ export const imagesApi = api.injectEndpoints({
(draft) => {
imagesAdapter.upsertOne(draft, {
...imageDTO,
board_id,
board_id: new_board_id,
});
}
)
@ -1097,10 +1113,10 @@ export const imagesApi = api.injectEndpoints({
const isCacheFullyPopulated =
currentCache.data && currentCache.data.ids.length >= (total ?? 0);
const isInDateRange = getIsImageInDateRange(
currentCache.data,
imageDTO
);
const isInDateRange =
(total || 0) >= IMAGE_LIMIT
? getIsImageInDateRange(currentCache.data, imageDTO)
: true;
if (isCacheFullyPopulated || isInDateRange) {
// *upsert* to $cache
@ -1111,7 +1127,7 @@ export const imagesApi = api.injectEndpoints({
(draft) => {
imagesAdapter.upsertOne(draft, {
...imageDTO,
board_id: undefined,
board_id: 'none',
});
}
)

View File

@ -1443,7 +1443,7 @@ export type components = {
* @description The nodes in this graph
*/
nodes?: {
[key: string]: (components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRawPromptInvocation"] | components["schemas"]["SDXLRefinerRawPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SDXLTextToLatentsInvocation"] | components["schemas"]["SDXLLatentsToLatentsInvocation"] | components["schemas"]["ONNXPromptInvocation"] | components["schemas"]["ONNXTextToLatentsInvocation"] | components["schemas"]["ONNXLatentsToImageInvocation"] | components["schemas"]["ONNXSD1ModelLoaderInvocation"] | components["schemas"]["OnnxModelLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["ParamStringInvocation"] | components["schemas"]["ParamPromptInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]) | undefined;
[key: string]: (components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["SDXLLoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRawPromptInvocation"] | components["schemas"]["SDXLRefinerRawPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["ImageLuminosityAdjustmentInvocation"] | components["schemas"]["ImageSaturationAdjustmentInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SDXLTextToLatentsInvocation"] | components["schemas"]["SDXLLatentsToLatentsInvocation"] | components["schemas"]["ONNXPromptInvocation"] | components["schemas"]["ONNXTextToLatentsInvocation"] | components["schemas"]["ONNXLatentsToImageInvocation"] | components["schemas"]["ONNXSD1ModelLoaderInvocation"] | components["schemas"]["OnnxModelLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["ParamStringInvocation"] | components["schemas"]["ParamPromptInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]) | undefined;
};
/**
* Edges
@ -1486,7 +1486,7 @@ export type components = {
* @description The results of node executions
*/
results: {
[key: string]: (components["schemas"]["ImageOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["VaeLoaderOutput"] | components["schemas"]["MetadataAccumulatorOutput"] | components["schemas"]["CompelOutput"] | components["schemas"]["ClipSkipInvocationOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["ONNXModelLoaderOutput"] | components["schemas"]["PromptOutput"] | components["schemas"]["PromptCollectionOutput"] | components["schemas"]["IntOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["IntCollectionOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["CollectInvocationOutput"]) | undefined;
[key: string]: (components["schemas"]["ImageOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["SDXLLoraLoaderOutput"] | components["schemas"]["VaeLoaderOutput"] | components["schemas"]["MetadataAccumulatorOutput"] | components["schemas"]["CompelOutput"] | components["schemas"]["ClipSkipInvocationOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["ONNXModelLoaderOutput"] | components["schemas"]["PromptOutput"] | components["schemas"]["PromptCollectionOutput"] | components["schemas"]["IntOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["IntCollectionOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["CollectInvocationOutput"]) | undefined;
};
/**
* Errors
@ -1904,6 +1904,40 @@ export type components = {
*/
image_name: string;
};
/**
* ImageHueAdjustmentInvocation
* @description Adjusts the Hue of an image.
*/
ImageHueAdjustmentInvocation: {
/**
* Id
* @description The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Is Intermediate
* @description Whether or not this node is an intermediate node.
* @default false
*/
is_intermediate?: boolean;
/**
* Type
* @default img_hue_adjust
* @enum {string}
*/
type?: "img_hue_adjust";
/**
* Image
* @description The image to adjust
*/
image?: components["schemas"]["ImageField"];
/**
* Hue
* @description The degrees by which to rotate the hue, 0-360
* @default 0
*/
hue?: number;
};
/**
* ImageInverseLerpInvocation
* @description Inverse linear interpolation of all pixels of an image
@ -1984,6 +2018,40 @@ export type components = {
*/
max?: number;
};
/**
* ImageLuminosityAdjustmentInvocation
* @description Adjusts the Luminosity (Value) of an image.
*/
ImageLuminosityAdjustmentInvocation: {
/**
* Id
* @description The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Is Intermediate
* @description Whether or not this node is an intermediate node.
* @default false
*/
is_intermediate?: boolean;
/**
* Type
* @default img_luminosity_adjust
* @enum {string}
*/
type?: "img_luminosity_adjust";
/**
* Image
* @description The image to adjust
*/
image?: components["schemas"]["ImageField"];
/**
* Luminosity
* @description The factor by which to adjust the luminosity (value)
* @default 1
*/
luminosity?: number;
};
/**
* ImageMetadata
* @description An image's generation metadata
@ -2239,6 +2307,40 @@ export type components = {
*/
resample_mode?: "nearest" | "box" | "bilinear" | "hamming" | "bicubic" | "lanczos";
};
/**
* ImageSaturationAdjustmentInvocation
* @description Adjusts the Saturation of an image.
*/
ImageSaturationAdjustmentInvocation: {
/**
* Id
* @description The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Is Intermediate
* @description Whether or not this node is an intermediate node.
* @default false
*/
is_intermediate?: boolean;
/**
* Type
* @default img_saturation_adjust
* @enum {string}
*/
type?: "img_saturation_adjust";
/**
* Image
* @description The image to adjust
*/
image?: components["schemas"]["ImageField"];
/**
* Saturation
* @description The factor by which to adjust the saturation
* @default 1
*/
saturation?: number;
};
/**
* ImageScaleInvocation
* @description Scales an image by a factor
@ -4912,6 +5014,82 @@ export type components = {
*/
denoising_end?: number;
};
/**
* SDXLLoraLoaderInvocation
* @description Apply selected lora to unet and text_encoder.
*/
SDXLLoraLoaderInvocation: {
/**
* Id
* @description The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Is Intermediate
* @description Whether or not this node is an intermediate node.
* @default false
*/
is_intermediate?: boolean;
/**
* Type
* @default sdxl_lora_loader
* @enum {string}
*/
type?: "sdxl_lora_loader";
/**
* Lora
* @description Lora model name
*/
lora?: components["schemas"]["LoRAModelField"];
/**
* Weight
* @description With what weight to apply lora
* @default 0.75
*/
weight?: number;
/**
* Unet
* @description UNet model for applying lora
*/
unet?: components["schemas"]["UNetField"];
/**
* Clip
* @description Clip model for applying lora
*/
clip?: components["schemas"]["ClipField"];
/**
* Clip2
* @description Clip2 model for applying lora
*/
clip2?: components["schemas"]["ClipField"];
};
/**
* SDXLLoraLoaderOutput
* @description Model loader output
*/
SDXLLoraLoaderOutput: {
/**
* Type
* @default sdxl_lora_loader_output
* @enum {string}
*/
type?: "sdxl_lora_loader_output";
/**
* Unet
* @description UNet submodel
*/
unet?: components["schemas"]["UNetField"];
/**
* Clip
* @description Tokenizer and text_encoder submodels
*/
clip?: components["schemas"]["ClipField"];
/**
* Clip2
* @description Tokenizer2 and text_encoder2 submodels
*/
clip2?: components["schemas"]["ClipField"];
};
/**
* SDXLModelLoaderInvocation
* @description Loads an sdxl base model, outputting its submodels.
@ -5961,6 +6139,24 @@ export type components = {
*/
image?: components["schemas"]["ImageField"];
};
/**
* ControlNetModelFormat
* @description An enumeration.
* @enum {string}
*/
ControlNetModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusionXLModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusionXLModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusion1ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusionOnnxModelFormat
* @description An enumeration.
@ -5973,24 +6169,6 @@ export type components = {
* @enum {string}
*/
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusion1ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusionXLModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusionXLModelFormat: "checkpoint" | "diffusers";
/**
* ControlNetModelFormat
* @description An enumeration.
* @enum {string}
*/
ControlNetModelFormat: "checkpoint" | "diffusers";
};
responses: never;
parameters: never;
@ -6101,7 +6279,7 @@ export type operations = {
};
requestBody: {
content: {
"application/json": components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRawPromptInvocation"] | components["schemas"]["SDXLRefinerRawPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SDXLTextToLatentsInvocation"] | components["schemas"]["SDXLLatentsToLatentsInvocation"] | components["schemas"]["ONNXPromptInvocation"] | components["schemas"]["ONNXTextToLatentsInvocation"] | components["schemas"]["ONNXLatentsToImageInvocation"] | components["schemas"]["ONNXSD1ModelLoaderInvocation"] | components["schemas"]["OnnxModelLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["ParamStringInvocation"] | components["schemas"]["ParamPromptInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
"application/json": components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["SDXLLoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRawPromptInvocation"] | components["schemas"]["SDXLRefinerRawPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["ImageLuminosityAdjustmentInvocation"] | components["schemas"]["ImageSaturationAdjustmentInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SDXLTextToLatentsInvocation"] | components["schemas"]["SDXLLatentsToLatentsInvocation"] | components["schemas"]["ONNXPromptInvocation"] | components["schemas"]["ONNXTextToLatentsInvocation"] | components["schemas"]["ONNXLatentsToImageInvocation"] | components["schemas"]["ONNXSD1ModelLoaderInvocation"] | components["schemas"]["OnnxModelLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["ParamStringInvocation"] | components["schemas"]["ParamPromptInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
};
};
responses: {
@ -6138,7 +6316,7 @@ export type operations = {
};
requestBody: {
content: {
"application/json": components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRawPromptInvocation"] | components["schemas"]["SDXLRefinerRawPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SDXLTextToLatentsInvocation"] | components["schemas"]["SDXLLatentsToLatentsInvocation"] | components["schemas"]["ONNXPromptInvocation"] | components["schemas"]["ONNXTextToLatentsInvocation"] | components["schemas"]["ONNXLatentsToImageInvocation"] | components["schemas"]["ONNXSD1ModelLoaderInvocation"] | components["schemas"]["OnnxModelLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["ParamStringInvocation"] | components["schemas"]["ParamPromptInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
"application/json": components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["SDXLLoraLoaderInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["MetadataAccumulatorInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRawPromptInvocation"] | components["schemas"]["SDXLRefinerRawPromptInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["ImageLuminosityAdjustmentInvocation"] | components["schemas"]["ImageSaturationAdjustmentInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SDXLTextToLatentsInvocation"] | components["schemas"]["SDXLLatentsToLatentsInvocation"] | components["schemas"]["ONNXPromptInvocation"] | components["schemas"]["ONNXTextToLatentsInvocation"] | components["schemas"]["ONNXLatentsToImageInvocation"] | components["schemas"]["ONNXSD1ModelLoaderInvocation"] | components["schemas"]["OnnxModelLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["ParamStringInvocation"] | components["schemas"]["ParamPromptInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
};
};
responses: {

View File

@ -60,6 +60,9 @@ type InvokedSessionThunkConfig = {
const isErrorWithStatus = (error: unknown): error is { status: number } =>
isObject(error) && 'status' in error;
const isErrorWithDetail = (error: unknown): error is { detail: string } =>
isObject(error) && 'detail' in error;
/**
* `SessionsService.invokeSession()` thunk
*/
@ -85,7 +88,15 @@ export const sessionInvoked = createAsyncThunk<
error: (error as any).body.detail,
});
}
return rejectWithValue({ arg, status: response.status, error });
if (isErrorWithDetail(error) && response.status === 403) {
return rejectWithValue({
arg,
status: response.status,
error: error.detail
});
}
if (error)
return rejectWithValue({ arg, status: response.status, error });
}
});

View File

@ -166,6 +166,9 @@ export type OnnxModelLoaderInvocation = TypeReq<
export type LoraLoaderInvocation = TypeReq<
components['schemas']['LoraLoaderInvocation']
>;
export type SDXLLoraLoaderInvocation = TypeReq<
components['schemas']['SDXLLoraLoaderInvocation']
>;
export type MetadataAccumulatorInvocation = TypeReq<
components['schemas']['MetadataAccumulatorInvocation']
>;

View File

@ -1 +1 @@
__version__ = "3.0.2a1"
__version__ = "3.0.2post1"

View File

@ -77,7 +77,7 @@ dependencies = [
"realesrgan",
"requests~=2.28.2",
"rich~=13.3",
"safetensors~=0.3.0",
"safetensors==0.3.1",
"scikit-image~=0.21.0",
"send2trash",
"test-tube~=0.7.5",
@ -100,7 +100,7 @@ dependencies = [
"dev" = [
"pudb",
]
"test" = ["pytest>6.0.0", "pytest-cov", "black"]
"test" = ["pytest>6.0.0", "pytest-cov", "pytest-datadir", "black"]
"xformers" = [
"xformers~=0.0.19; sys_platform!='darwin'",
"triton; sys_platform=='linux'",
@ -118,7 +118,7 @@ dependencies = [
[project.scripts]
# legacy entrypoints; provided for backwards compatibility
"configure_invokeai.py" = "invokeai.frontend.install:invokeai_configure"
"configure_invokeai.py" = "invokeai.frontend.install.invokeai_configure:invokeai_configure"
"textual_inversion.py" = "invokeai.frontend.training:invokeai_textual_inversion"
# shortcut commands to start cli and web
@ -130,15 +130,16 @@ dependencies = [
"invokeai-web" = "invokeai.app.api_app:invoke_api"
# full commands
"invokeai-configure" = "invokeai.frontend.install:invokeai_configure"
"invokeai-configure" = "invokeai.frontend.install.invokeai_configure:invokeai_configure"
"invokeai-merge" = "invokeai.frontend.merge:invokeai_merge_diffusers"
"invokeai-ti" = "invokeai.frontend.training:invokeai_textual_inversion"
"invokeai-model-install" = "invokeai.frontend.install:invokeai_model_install"
"invokeai-model-install" = "invokeai.frontend.install.model_install:main"
"invokeai-migrate3" = "invokeai.backend.install.migrate_to_3:main"
"invokeai-update" = "invokeai.frontend.install:invokeai_update"
"invokeai-update" = "invokeai.frontend.install.invokeai_update:main"
"invokeai-metadata" = "invokeai.frontend.CLI.sd_metadata:print_metadata"
"invokeai-node-cli" = "invokeai.app.cli_app:invoke_cli"
"invokeai-node-web" = "invokeai.app.api_app:invoke_api"
"invokeai-import-images" = "invokeai.frontend.install.import_images:main"
[project.urls]
"Homepage" = "https://invoke-ai.github.io/InvokeAI/"

View File

@ -0,0 +1,34 @@
#!/usr/bin/env python
"""
Read a checkpoint/safetensors file and write out a template .json file containing
its metadata for use in fast model probing.
"""
import sys
import argparse
import json
from pathlib import Path
from invokeai.backend.model_management.models.base import read_checkpoint_meta
parser = argparse.ArgumentParser(description="Create a .json template from checkpoint/safetensors model")
parser.add_argument("--checkpoint", "--in", type=Path, help="Path to the input checkpoint/safetensors file")
parser.add_argument("--template", "--out", type=Path, help="Path to the output .json file")
opt = parser.parse_args()
ckpt = read_checkpoint_meta(opt.checkpoint)
while "state_dict" in ckpt:
ckpt = ckpt["state_dict"]
tmpl = {}
for key, tensor in ckpt.items():
tmpl[key] = list(tensor.shape)
try:
with open(opt.template, "w") as f:
json.dump(tmpl, f)
print(f"Template written out as {opt.template}")
except Exception as e:
print(f"An exception occurred while writing template: {str(e)}")

View File

@ -0,0 +1,37 @@
#!/usr/bin/env python
"""
Read a checkpoint/safetensors file and compare it to a template .json.
Returns True if their metadata match.
"""
import sys
import argparse
import json
from pathlib import Path
from invokeai.backend.model_management.models.base import read_checkpoint_meta
parser = argparse.ArgumentParser(description="Compare a checkpoint/safetensors file to a JSON metadata template.")
parser.add_argument("--checkpoint", "--in", type=Path, help="Path to the input checkpoint/safetensors file")
parser.add_argument("--template", "--out", type=Path, help="Path to the template .json file to match against")
opt = parser.parse_args()
ckpt = read_checkpoint_meta(opt.checkpoint)
while "state_dict" in ckpt:
ckpt = ckpt["state_dict"]
checkpoint_metadata = {}
for key, tensor in ckpt.items():
checkpoint_metadata[key] = list(tensor.shape)
with open(opt.template, "r") as f:
template = json.load(f)
if checkpoint_metadata == template:
print("True")
sys.exit(0)
else:
print("False")
sys.exit(-1)

View File

@ -0,0 +1,47 @@
from pathlib import Path
import pytest
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend import ModelManager, BaseModelType, ModelType, SubModelType
BASIC_MODEL_NAME = ("SDXL base", BaseModelType.StableDiffusionXL, ModelType.Main)
VAE_OVERRIDE_MODEL_NAME = ("SDXL with VAE", BaseModelType.StableDiffusionXL, ModelType.Main)
VAE_NULL_OVERRIDE_MODEL_NAME = ("SDXL with empty VAE", BaseModelType.StableDiffusionXL, ModelType.Main)
@pytest.fixture
def model_manager(datadir) -> ModelManager:
InvokeAIAppConfig.get_config(root=datadir)
return ModelManager(datadir / "configs" / "relative_sub.models.yaml")
def test_get_model_names(model_manager: ModelManager):
names = model_manager.model_names()
assert names[:2] == [BASIC_MODEL_NAME, VAE_OVERRIDE_MODEL_NAME]
def test_get_model_path_for_diffusers(model_manager: ModelManager, datadir: Path):
model_config = model_manager._get_model_config(BASIC_MODEL_NAME[1], BASIC_MODEL_NAME[0], BASIC_MODEL_NAME[2])
top_model_path, is_override = model_manager._get_model_path(model_config)
expected_model_path = datadir / "models" / "sdxl" / "main" / "SDXL base 1_0"
assert top_model_path == expected_model_path
assert not is_override
def test_get_model_path_for_overridden_vae(model_manager: ModelManager, datadir: Path):
model_config = model_manager._get_model_config(
VAE_OVERRIDE_MODEL_NAME[1], VAE_OVERRIDE_MODEL_NAME[0], VAE_OVERRIDE_MODEL_NAME[2]
)
vae_model_path, is_override = model_manager._get_model_path(model_config, SubModelType.Vae)
expected_vae_path = datadir / "models" / "sdxl" / "vae" / "sdxl-vae-fp16-fix"
assert vae_model_path == expected_vae_path
assert is_override
def test_get_model_path_for_null_overridden_vae(model_manager: ModelManager, datadir: Path):
model_config = model_manager._get_model_config(
VAE_NULL_OVERRIDE_MODEL_NAME[1], VAE_NULL_OVERRIDE_MODEL_NAME[0], VAE_NULL_OVERRIDE_MODEL_NAME[2]
)
vae_model_path, is_override = model_manager._get_model_path(model_config, SubModelType.Vae)
assert not is_override

View File

@ -0,0 +1,22 @@
__metadata__:
version: 3.0.0
sdxl/main/SDXL base:
path: sdxl/main/SDXL base 1_0
description: SDXL base v1.0
variant: normal
format: diffusers
sdxl/main/SDXL with VAE:
path: sdxl/main/SDXL base 1_0
description: SDXL with customized VAE
vae: sdxl/vae/sdxl-vae-fp16-fix/
variant: normal
format: diffusers
sdxl/main/SDXL with empty VAE:
path: sdxl/main/SDXL base 1_0
description: SDXL with customized VAE
vae: ''
variant: normal
format: diffusers