Compare commits

..

157 Commits

Author SHA1 Message Date
b3384bbc09 fix(installer): use extra index url when updating
If we don't include this, on updating, we will always get the CPU torch/torchvision/xformers.
2023-11-15 20:34:42 +11:00
1ce8615745 fix: pin torch and torchvision exactly 2023-11-15 17:16:54 +11:00
e7019b6e98 Update invokeai_version.py 2023-11-15 15:19:57 +11:00
e96a404530 Updated build to test cu121 updates 2023-11-14 14:01:05 +11:00
a047bad391 Revert torch to use cu121 (#5091)
## What type of PR is this? (check all applicable)

- [ ] Refactor
- [ ] Feature
- [ ] Bug Fix
- [X] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission


## Have you discussed this change with the InvokeAI team?
- [X] Yes
- [ ] No, because:

      
## Have you updated all relevant documentation?
- [x] Yes
- [ ] No


## Description


## Related Tickets & Documents

<!--
For pull requests that relate or close an issue, please include them
below. 

For example having the text: "closes #1234" would connect the current
pull
request to issue 1234.  And when we merge the pull request, Github will
automatically close the issue.
-->

- Related Issue #
- Closes #

## QA Instructions, Screenshots, Recordings

<!-- 
Please provide steps on how to test changes, any hardware or 
software specifications as well as any other pertinent information. 
-->

## Added/updated tests?

- [ ] Yes
- [ ] No : _please replace this line with details on why tests
      have not been included_

## [optional] Are there any post deployment tasks we need to perform?
2023-11-14 13:47:51 +11:00
909afc266e Update 010_INSTALL_AUTOMATED.md 2023-11-13 20:28:00 -05:00
4039dd148d Update 030_INSTALL_CUDA_AND_ROCM.md 2023-11-13 20:28:00 -05:00
ea0f8b8791 Update 020_INSTALL_MANUAL.md 2023-11-13 20:28:00 -05:00
f412582d60 Update README.md to cu121 2023-11-13 20:28:00 -05:00
c5672adb6b Update 070_INSTALL_XFORMERS.md 2023-11-13 20:28:00 -05:00
0e5c3a641a Revert torch to use cu121 2023-11-13 20:28:00 -05:00
9015e72e1e Update README.md to include M3 (#5092)
## What type of PR is this? (check all applicable)

- [ ] Refactor
- [ ] Feature
- [ ] Bug Fix
- [ ] Optimization
- [x] Documentation Update
- [ ] Community Node Submission


## Have you discussed this change with the InvokeAI team?
- [ ] Yes
- [x] No, because:

      
## Have you updated all relevant documentation?
- [x] Yes
- [ ] No


## Description


## Related Tickets & Documents

<!--
For pull requests that relate or close an issue, please include them
below. 

For example having the text: "closes #1234" would connect the current
pull
request to issue 1234.  And when we merge the pull request, Github will
automatically close the issue.
-->

- Related Issue #
- Closes #

## QA Instructions, Screenshots, Recordings

<!-- 
Please provide steps on how to test changes, any hardware or 
software specifications as well as any other pertinent information. 
-->

## Added/updated tests?

- [ ] Yes
- [x] No : _please replace this line with details on why tests
      have not been included_

## [optional] Are there any post deployment tasks we need to perform?
2023-11-14 12:24:57 +11:00
6b05d27c7a Update 040_INSTALL_DOCKER.md 2023-11-14 12:22:46 +11:00
19d0673085 Update 010_INSTALL_AUTOMATED.md 2023-11-14 12:22:08 +11:00
048b4fe7e8 Update README.md to include M3 2023-11-13 19:11:31 -06:00
e8b83fecff fix(backend): apply clip skip after lora
This handles LoRAs that attempt to modify layers skipped by CLIP Skip.
2023-11-14 11:30:15 +11:00
8883ecb2bf Model Manager Refactor Phase 1 - SQL-based config storage (#5039)
## What type of PR is this? (check all applicable)

- [X] Refactor


## Have you discussed this change with the InvokeAI team?
- [X] Extensively
- [ ] No, because:

      
## Have you updated all relevant documentation?
- [X] Yes
- [ ] No


## Description

As discussed with @psychedelicious and @RyanJDick, this is the first
phase of the model manager refactor. In this phase, I've added support
for storing model configuration information the `invokeai.db` SQL3
database. All the code is separate from the original model manager, so
for the time being the frontend is still using the original YAML-based
configuration, so the web app still works.

To keep things clean, I've added a new FastAPI route called
`model_records` which can add, update, retrieve and delete model
records.

The architecture is described in the first section of
`docs/contributing/MODEL_MANAGER.md`.

## QA Instructions, Screenshots, Recordings

There is a pytest for the model sql storage backend in
`tests/backend/model_manager_2/test_model_storage_sql.py`.

To populate `invokeai.db` with models from your current `models.yaml`,
do the following:

1. Stop the running server
2. Back up `invokeai.db`
3. Run `pip install -e .` to install the command used in the next step.
4. Run `invokeai-migrate-models-to-db`

This will iterate through `models.yaml` and create equivalent database
entries in the `model_config` table of `invokeai.db`. Only the models
named in the yaml file will be migrated, so anything that is autoloaded
will be ignored.

Note that in order to get the `model_records` router to be recognized by
the swagger API, I had to rebuild the frontend. Not sure why this was
necessary and would appreciate a pointer on a less radical way to do
this.

## Added/updated tests?

- [X] Yes
- [ ] No
2023-11-13 18:59:25 -05:00
2f97f1d6d5 Merge branch 'main' into refactor/model-manager-2 2023-11-13 18:21:16 -05:00
73d6cc824b Update Pytorch to ~2.1.0 in the installer script (#5089)
## What type of PR is this? (check all applicable)

- [ ] Refactor
- [ ] Feature
- [X] Bug Fix
- [X] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission


## Have you discussed this change with the InvokeAI team?
- [ ] Yes
- [X] No, because it's required

      
## Have you updated all relevant documentation?
- [ ] Yes
- [X] No, not necessary


## Description

We use Pytorch ~2.1.0 as a dependency for InvokeAI, but the installer
still installs 2.0.1 first until Invoke AIs dependencies kick in which
causes it to get deleted anyway and replaced with 2.1.0. This is
unnecessary and probably not wanted.

Fixed the dependencies for the installation script to install Pytorch
~2.1.0 to begin with.

P.s. Is there any reason why "torchmetrics==0.11.4" is pinned? What is
the reason for that? Does that change with Pytorch 2.1? It seems to work
since we use it already. It would be nice to know the reason.

Greetings

## Related Tickets & Documents

<!--
For pull requests that relate or close an issue, please include them
below. 

For example having the text: "closes #1234" would connect the current
pull
request to issue 1234.  And when we merge the pull request, Github will
automatically close the issue.
-->

- Related Issue #
- Closes #

## QA Instructions, Screenshots, Recordings

<!-- 
Please provide steps on how to test changes, any hardware or 
software specifications as well as any other pertinent information. 
-->

## Added/updated tests?

- [ ] Yes
- [ ] No : _please replace this line with details on why tests
      have not been included_

## [optional] Are there any post deployment tasks we need to perform?
2023-11-13 18:20:36 -05:00
acc0a29dca fixed ruff formatting issues 2023-11-13 18:15:17 -05:00
38c1436f02 resolve conflicts; blackify 2023-11-13 18:12:45 -05:00
efbdb75568 implement psychedelicious recommendations as of 13 November 2023-11-13 17:05:01 -05:00
8929495aeb fix(test): remove unused assignment to value 2023-11-14 08:08:23 +11:00
428f0b265f feat(api): add log stmt to update_model_record route 2023-11-14 08:06:35 +11:00
7daee41ad2 fix(api): remove unused ModelsListValidator 2023-11-14 08:01:44 +11:00
7cdd7b6ad7 feat(api): simplifiy list_model_records handler 2023-11-14 08:00:21 +11:00
bc64cde6f9 chore: ruff lint 2023-11-14 07:57:07 +11:00
4465f97cdf Merge branch 'main' into refactor/model-manager-2 2023-11-14 07:51:57 +11:00
fface2cda7 Update torch to ~2.1.0 in the installer 2023-11-13 17:30:51 +01:00
7fcb8959fb chore(ui): cleanup (#5084)
## What type of PR is this? (check all applicable)

- [ ] Refactor
- [ ] Feature
- [ ] Bug Fix
- [x] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission

## Description

Bit of a cleanup. 

[chore(ui): delete unused
files](5eaea9dd64)

[feat(ui): add eslint rule
react/jsx-no-bind](3a0ec635c9)

This rule enforces no arrow functions in component props. In practice,
it means all functions passed as component props must be wrapped in
`useCallback()`.

This is a performance optimization to prevent unnecessary rerenders.

The rule is added and all violations have been fixed, whew!

[chore(ui): move useCopyImageToClipboard to
common/hooks/](f2d26a3a3c)

[chore(ui): move MM components & store to
features/](bb52861896)

Somehow they had ended up in `features/ui/tabs` which isn't right

## QA Instructions, Screenshots, Recordings

UI should still work.

It builds successfully, and I tested things out - looks good to me.
2023-11-13 13:22:41 +05:30
dcf0dc4274 Merge branch 'main' into chore/ui/cleanup 2023-11-13 16:33:08 +11:00
bb52861896 chore(ui): move MM components & store to features/
Somehow they had ended up in `features/ui/tabs` which isn't right
2023-11-13 16:32:03 +11:00
f2d26a3a3c chore(ui): move useCopyImageToClipboard to common/hooks/ 2023-11-13 16:23:46 +11:00
04d8f2dfea fix(backend): fix controlnet zip len
Do not use `strict=True` when scaling controlnet conditioning.

When using `guess_mode` (e.g. `more_control` or `more_prompt`), `down_block_res_samples` and `scales` are zipped.

These two objects are of different lengths, so using zip's strict mode raises an error.

In testing, `len(scales) === len(down_block_res_samples) + 1`.

It appears this behaviour is intentional, as the final "extra" item in `scales` is used immediately afterwards.
2023-11-13 15:45:03 +11:00
355d4cf4e2 Update Accelerate to 0.24.X (#5075)
## What type of PR is this? (check all applicable)

- [ ] Refactor
- [ ] Feature
- [ ] Bug Fix
- [X] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission


## Have you discussed this change with the InvokeAI team?
- [ ] Yes
- [X] No, because: This is just housekeeping

      
## Have you updated all relevant documentation?
- [ ] Yes
- [X] No, not needed


## Description

Update Accelerate to the most recent version. No breaking changes.
Tested for 1 week in productive use now.

## Related Tickets & Documents

<!--
For pull requests that relate or close an issue, please include them
below. 

For example having the text: "closes #1234" would connect the current
pull
request to issue 1234.  And when we merge the pull request, Github will
automatically close the issue.
-->

- Related Issue #
- Closes #

## QA Instructions, Screenshots, Recordings

<!-- 
Please provide steps on how to test changes, any hardware or 
software specifications as well as any other pertinent information. 
-->

## Added/updated tests?

- [ ] Yes
- [ ] No : _please replace this line with details on why tests
      have not been included_

## [optional] Are there any post deployment tasks we need to perform?
2023-11-13 14:20:05 +11:00
a3a828779a Merge branch 'main' into update-accelerate 2023-11-13 14:10:53 +11:00
8c71ff37ae Update config.py
Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
2023-11-12 19:03:39 -05:00
ddb65e6034 Merge branch 'main' into chore/ui/cleanup 2023-11-13 10:53:04 +11:00
8366cd2a00 feat: use ruff for lint & format (#5070)
## What type of PR is this? (check all applicable)

- [ ] Refactor
- [x] Feature
- [x] Bug Fix
- [x] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission

## Description

This PR introduces [`ruff`](https://github.com/astral-sh/ruff) as the
only linter and formatter needed for the project. It is really fast.
Like, alarmingly fast.

It is a drop-in replacement for flake8, isort, black, and much more.
I've configured it similarly to our existing config.

Note: we had enabled a number of flake8 plugins but didn't have the
packages themselves installed, so they did nothing. Ruff used the
existing config, and found a good number of changes needed to adhere to
those flake8 plugins. I've resolved all violations.

### Code changes

- many
[flake8-comprehensions](https://docs.astral.sh/ruff/rules/#flake8-comprehensions-c4)
violations, almost all auto-fixed
- a good handful of
[flake8-bugbear](https://docs.astral.sh/ruff/rules/#flake8-bugbear-b)
violations
- handful of
[pycodestyle](https://docs.astral.sh/ruff/rules/#pycodestyle-e-w)
violations
- some formatting

### Developer Experience

[Ruff integrates with most
editors](https://docs.astral.sh/ruff/integrations/):
- Official VSCode extension
- `ruff-lsp` python package allows it to integrate with any LSP-capable
editor (vim, emacs, etc)
- Can be configured as an external tool in PyCharm

### Github Actions

I've updated the `style-checks` action to use ruff, and deleted the
`pyflakes` action.

## Related Tickets & Documents

<!--
For pull requests that relate or close an issue, please include them
below. 

For example having the text: "closes #1234" would connect the current
pull
request to issue 1234.  And when we merge the pull request, Github will
automatically close the issue.
-->

- Closes #5066 

## QA Instructions, Screenshots, Recordings

Have a poke around, and run the app. There were some logic changes but
it was all pretty straightforward.

~~Not sure how to best test the changed github action.~~ Looks like it
just used the action from this PR, that's kinda unexpected but OK.

<!-- 
Please provide steps on how to test changes, any hardware or 
software specifications as well as any other pertinent information. 
-->

## Added/updated tests?

- [ ] Yes
- [ ] No : _please replace this line with details on why tests
      have not been included_

## [optional] Are there any post deployment tasks we need to perform?
2023-11-13 10:41:43 +11:00
ab1ec3720a Merge branch 'main' into feat/ruff 2023-11-13 10:32:23 +11:00
3a0ec635c9 feat(ui): add eslint rule react/jsx-no-bind
This rule enforces no arrow functions in component props. In practice, it means all functions passed as component props must be wrapped in `useCallback()`.

This is a performance optimization to prevent unnecessary rerenders.

The rule is added and all violations have been fixed, whew!
2023-11-13 10:01:14 +11:00
8afe517204 add note about discriminated union and Body() issue; blackified 2023-11-12 16:50:05 -05:00
5eaea9dd64 chore(ui): delete unused files 2023-11-13 08:43:27 +11:00
71e298b722 Feat (ui): Add VAE Model to Recall Parameters (#5073)
* adding VAE recall when using all parameters

* adding VAE to the RecallParameters tab in ImageMetadataActions

* checking for nil vae and casting to null if undefined

* adding default VAE to recall actions list if VAE is nullish

* fix(ui): use `lodash-es` for tree-shakeable imports

---------

Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
2023-11-12 21:19:12 +00:00
89a039460d feat(ui): add number inputs for canvas brush color picker (#5067)
* drop-down for the color picker

* fixed the bug in alpha value

* designing done

---------

Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
2023-11-12 21:07:26 +00:00
a342e64772 Merge branch 'main' into feat/ruff 2023-11-13 07:54:06 +11:00
ef8dcf5fae blackify 2023-11-12 14:20:32 -05:00
90a038c685 translationBot(ui): update translation (Italian)
Currently translated at 97.7% (1200 of 1228 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2023-11-12 20:24:04 +11:00
024a156114 isort 2023-11-11 13:58:36 -05:00
7ea2a135f1 remove dangling import 2023-11-11 12:24:58 -05:00
af2264b6eb implement workaround for FastAPI and discriminated unions in Body parameter 2023-11-11 12:22:38 -05:00
41bf9ec4a3 Update Accelerate to 0.24.X 2023-11-11 09:46:23 +01:00
520ccdb0a9 Merge branch 'main' into feat/ruff 2023-11-11 15:07:35 +11:00
2b36565e9e awkward workaround for double-Annotated in model_record route 2023-11-10 21:32:44 -05:00
f2c3b7c317 Merge branch 'refactor/model-manager-2' of github.com:invoke-ai/InvokeAI into refactor/model-manager-2 2023-11-10 19:47:01 -05:00
67751a01ab remove unused import 2023-11-10 19:25:05 -05:00
cb8cdefd59 Merge branch 'main' into refactor/model-manager-2 2023-11-10 19:24:19 -05:00
f1c846ba5c blackify 2023-11-10 19:14:29 -05:00
3a6ba236f5 replace _class_map in ModelConfigFactory with a nested discriminated union 2023-11-10 19:14:15 -05:00
1c7ea57492 feat (ui, generation): High Resolution Fix- added automatic resolution toggle and replaced latent upscale with two improved methods (#4905)
* working

* added selector for method

* refactoring graph

* added ersgan method

* fixing yarn build

* add tooltips

* a conjuction

* rephrase

* removed manual sliders, set HRF to calculate dimensions automatically to match 512^2 pixels

* working

* working

* working

* fixed tooltip

* add hrf to use all parameters

* adding hrf method to parameters

* working on parameter recall

* working on parameter recall

* cleaning

* fix(ui): fix unnecessary casts in addHrfToGraph

* chore(ui): use camelCase in addHrfToGraph

* fix(ui): do not add HRF metadata unless HRF is added to graph

* fix(ui): remove unused imports in addHrfToGraph

* feat(ui): do not hide HRF params when disabled, only disable them

* fix(ui): remove unused vars in addHrfToGraph

* feat(ui): default HRF str to 0.35, method ESRGAN

* fix(ui): use isValidBoolean to check hrfEnabled param

* fix(nodes): update CoreMetadataInvocation fields for HRF

* feat(ui): set hrf strength default to 0.45

* fix(ui): set default hrf strength in configSlice

* feat(ui): use translations for HRF features

---------

Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
2023-11-11 00:11:46 +00:00
6494e8e551 chore: ruff format 2023-11-11 10:55:40 +11:00
513fceac82 chore: ruff check - fix pycodestyle 2023-11-11 10:55:33 +11:00
99a8ebe3a0 chore: ruff check - fix flake8-bugbear 2023-11-11 10:55:28 +11:00
3a136420d5 chore: ruff check - fix flake8-comprensions 2023-11-11 10:55:23 +11:00
bd56e9bc81 remove cruft code from router 2023-11-10 18:49:25 -05:00
43f2398e14 feat: use ruff's github output format for action 2023-11-11 10:42:27 +11:00
d0cf98d7f6 feat: add ruff-lsp to support most editors 2023-11-11 10:42:27 +11:00
8111dd6cc5 feat: remove pyflakes gh action
ruff supersedes it
2023-11-11 10:42:27 +11:00
99e4b87fae feat: use ruff in GH style-checks action 2023-11-11 10:42:27 +11:00
884ec0b5df feat: replace isort, flake8 & black with ruff 2023-11-11 10:42:27 +11:00
9ccfa34e04 Update installer.py to use cu118 instead of 121 (#5069)
## What type of PR is this? (check all applicable)

- [ ] Refactor
- [ ] Feature
- [X] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission


## Have you discussed this change with the InvokeAI team?
- [ ] Yes
- [ ] No, because:

      
## Have you updated all relevant documentation?
- [ ] Yes
- [ ] No


## Description


## Related Tickets & Documents

<!--
For pull requests that relate or close an issue, please include them
below. 

For example having the text: "closes #1234" would connect the current
pull
request to issue 1234.  And when we merge the pull request, Github will
automatically close the issue.
-->

- Related Issue #
- Closes #

## QA Instructions, Screenshots, Recordings

<!-- 
Please provide steps on how to test changes, any hardware or 
software specifications as well as any other pertinent information. 
-->

## Added/updated tests?

- [ ] Yes
- [ ] No : _please replace this line with details on why tests
      have not been included_

## [optional] Are there any post deployment tasks we need to perform?
2023-11-11 10:40:47 +11:00
d5aa74623d Merge branch 'main' into Millu-patch-1 2023-11-11 10:39:06 +11:00
d63a614b8b Update Transformers to 4.35 and fix pad_to_multiple_of (#4817)
## What type of PR is this? (check all applicable)

- [ ] Refactor
- [ ] Feature
- [X] Bug Fix
- [X] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission


## Have you discussed this change with the InvokeAI team?
- [X] Yes, with @blessedcoolant 
- [ ] No, because:

      
## Have you updated all relevant documentation?
- [ ] Yes
- [ ] No


## Description

This PR updates Transformers to the most recent version and fixes the
value `pad_to_multiple_of` for `text_encoder.resize_token_embeddings`
which was introduced with
https://github.com/huggingface/transformers/pull/25088 in Transformers
4.32.0.

According to the [Nvidia
Documentation](https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc),
`Performance is better when equivalent matrix dimensions M, N, and K are
aligned to multiples of 8 bytes (or 64 bytes on A100) for FP16`
This fixes the following error that was popping up before every
invocation starting with Transformers 4.32.0
`You are resizing the embedding layer without providing a
pad_to_multiple_of parameter. This means that the new embedding
dimension will be None. This might induce some performance reduction as
Tensor Cores will not be available. For more details about this, or help
on choosing the correct value for resizing, refer to this guide:
https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc`

This is my first "real" fix PR, so I hope this is fine. Please inform me
if there is anything wrong with this. I am glad to help.

Have a nice day and thank you!


## Related Tickets & Documents

<!--
For pull requests that relate or close an issue, please include them
below. 

For example having the text: "closes #1234" would connect the current
pull
request to issue 1234.  And when we merge the pull request, Github will
automatically close the issue.
-->

- Related Issue:
https://github.com/huggingface/transformers/issues/26303
- Related Discord discussion:
https://discord.com/channels/1020123559063990373/1154152783579197571
- Closes #

## QA Instructions, Screenshots, Recordings

<!-- 
Please provide steps on how to test changes, any hardware or 
software specifications as well as any other pertinent information. 
-->

## Added/updated tests?

- [ ] Yes
- [ ] No : _please replace this line with details on why tests
      have not been included_

## [optional] Are there any post deployment tasks we need to perform?
2023-11-11 10:38:33 +11:00
cbc905a4d6 Update installer.py to use cu118 instead of 121 2023-11-11 10:36:07 +11:00
b55fc2935e resolve conflicts with commits done on github 2023-11-10 18:26:48 -05:00
0544917161 multiple small fixes suggested in reviews from psychedelicious and ryan 2023-11-10 18:25:37 -05:00
1161dfe055 Update invokeai/app/api/routers/model_records.py
Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>
2023-11-10 18:24:55 -05:00
433f347d7e Update invokeai/app/api/routers/model_records.py
Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>
2023-11-10 18:22:54 -05:00
33a412a24f Update invokeai/backend/model_manager/config.py
Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
2023-11-10 18:21:38 -05:00
9316534d97 Update invokeai/app/services/model_records/model_records_sql.py
Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
2023-11-10 17:58:15 -05:00
fdaa661245 revert frontend dist files to main 2023-11-10 17:57:18 -05:00
f1c195afb7 Merge branch 'main' into refactor/model-manager-2 2023-11-10 17:54:28 -05:00
6001d3d71d Change pad_to_multiple_of to be 8 for all cases. Add comment about it's temporary status 2023-11-10 17:51:59 -05:00
b9f607be56 Update to 4.35.X 2023-11-10 17:51:59 -05:00
8831d1ee41 Update Documentation 2023-11-10 17:51:59 -05:00
a0be83e370 Update Transformers to 4.34 and fix pad_to_multiple_of 2023-11-10 17:51:59 -05:00
8702a63197 add support for downloading and installing LCM lora diffusers models 2023-11-10 17:51:30 -05:00
d7f0a7919f chore(ui): manually update vite to fix security issue in hoisted dep
`postcss` is a hoisted dependency of `vite`.
2023-11-10 06:58:22 -08:00
356b5a41a9 wip: Add LCMScheduler 2023-11-10 06:54:36 -08:00
e56a6d85a9 Update diffusers to ~=0.23 (#5063)
## What type of PR is this? (check all applicable)

- [ ] Refactor
- [X] Feature
- [ ] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission


## Have you discussed this change with the InvokeAI team?
- [X] Yes
- [ ] No, because:

      
## Have you updated all relevant documentation?
- [X] Yes
- [ ] No


## Description


## Related Tickets & Documents

<!--
For pull requests that relate or close an issue, please include them
below. 

For example having the text: "closes #1234" would connect the current
pull
request to issue 1234.  And when we merge the pull request, Github will
automatically close the issue.
-->

- Related Issue #
- Closes #

## QA Instructions, Screenshots, Recordings

<!-- 
Please provide steps on how to test changes, any hardware or 
software specifications as well as any other pertinent information. 
-->

## Added/updated tests?

- [ ] Yes
- [ ] No : _please replace this line with details on why tests
      have not been included_

## [optional] Are there any post deployment tasks we need to perform?
2023-11-10 12:44:28 +11:00
e22a091d76 Update diffusers to ~=0.23 2023-11-10 11:50:50 +11:00
141d02939a Upstream diffusers PR was merged, this no longer seems necessary (#5060)
## What type of PR is this? (check all applicable)

- [ ] Refactor
- [ ] Feature
- [x] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission


## Have you discussed this change with the InvokeAI team?
- [x] Yes
- [ ] No, because:

      
## Have you updated all relevant documentation?
- [ ] Yes
- [x] No


## Description


## Related Tickets & Documents

<!--
For pull requests that relate or close an issue, please include them
below. 

For example having the text: "closes #1234" would connect the current
pull
request to issue 1234.  And when we merge the pull request, Github will
automatically close the issue.
-->

- Related Issue #
- Closes #

## QA Instructions, Screenshots, Recordings

<!-- 
Please provide steps on how to test changes, any hardware or 
software specifications as well as any other pertinent information. 
-->

## Added/updated tests?

- [ ] Yes
- [ ] No : _please replace this line with details on why tests
      have not been included_

## [optional] Are there any post deployment tasks we need to perform?
2023-11-10 11:47:21 +11:00
5cb372e9d0 Merge branch 'main' into remove-deprecated-sdxl-t2i-hack 2023-11-10 11:33:32 +11:00
f95fe68753 chore(ui): manually bump deps with security issues 2023-11-10 09:50:00 +11:00
6d33893844 chore(ui): update all deps 2023-11-10 09:50:00 +11:00
fc53112d8e chore(ui): remove unused deps 2023-11-10 09:50:00 +11:00
41f7aa6ab4 Remove unused import: 2023-11-09 15:06:01 -05:00
9bec755198 Upstream diffusers PR was merged, this no longer seems necessary 2023-11-09 15:02:24 -05:00
2570497d83 fix(installer): fix import of ValidationError
It was being imported from a deprecated module
2023-11-10 06:11:15 +11:00
5d735a714d translationBot(ui): update translation (Chinese (Simplified))
Currently translated at 100.0% (1219 of 1219 strings)

Co-authored-by: Surisen <zhonghx0804@outlook.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/zh_Hans/
Translation: InvokeAI/Web UI
2023-11-09 10:54:56 -08:00
6aa87f973e fix(nodes): create app/shared/ module to prevent circular imports
We have a number of shared classes, objects, and functions that are used in multiple places. This causes circular import issues.

This commit creates a new `app/shared/` module to hold these shared classes, objects, and functions.

Initially, only `FreeUConfig` and `FieldDescriptions` are moved here. This resolves a circular import issue with custom nodes.

Other shared classes, objects, and functions will be moved here in future commits.
2023-11-09 16:41:55 +11:00
f793fdf3d4 fix(socketio): leave room on unsubscribe
https://discord.com/channels/1020123559063990373/1049495067846524939/1171976251704086559
2023-11-09 12:12:32 +11:00
3b363d0258 fix flake8 lint check failures 2023-11-08 16:52:46 -05:00
36e0faea6b blackify 2023-11-08 16:47:03 -05:00
927f8a66e6 Merge branch 'main' into refactor/model-manager-2 2023-11-08 16:46:08 -05:00
eebc0e7315 Merge branch 'refactor/model-manager-2' of github.com:invoke-ai/InvokeAI into refactor/model-manager-2 2023-11-08 16:45:29 -05:00
6b173cc66f multiple small stylistic changes requested by reviewers 2023-11-08 16:45:26 -05:00
b4732a7308 Update invokeai/app/services/model_records/model_records_base.py
Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>
2023-11-08 13:50:40 -05:00
344a56327a Update invokeai/app/services/model_records/model_records_base.py
Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>
2023-11-08 13:50:01 -05:00
2e404b7cca Fix updater option list numbering
Fix updater option list numbering in invokeai_update.py so that they don't go 1, 2, 2, 3.  The options themselves work fine.
2023-11-07 19:11:25 -08:00
a760bdae9f (fix) update freeU config to be compatible with custom nodes (#5050)
## What type of PR is this? (check all applicable)

- [ ] Refactor
- [ ] Feature
- [X] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission


## Have you discussed this change with the InvokeAI team?
- [X] Yes: @psychedelicious told me to do this :) 
- [ ] No, because:

      
## Have you updated all relevant documentation?
- [ ] Yes
- [ ] No


## Description


## Related Tickets & Documents

<!--
For pull requests that relate or close an issue, please include them
below. 

For example having the text: "closes #1234" would connect the current
pull
request to issue 1234.  And when we merge the pull request, Github will
automatically close the issue.
-->

- Related Issue #
- Closes #

## QA Instructions, Screenshots, Recordings

<!-- 
Please provide steps on how to test changes, any hardware or 
software specifications as well as any other pertinent information. 
-->

## Added/updated tests?

- [ ] Yes
- [ ] No : _please replace this line with details on why tests
      have not been included_

## [optional] Are there any post deployment tasks we need to perform?
2023-11-06 21:43:18 -08:00
4cfd55936c run black formatting 2023-11-07 16:06:18 +11:00
5c3a27aac6 fixed sorts 2023-11-07 16:03:06 +11:00
d573a23090 Moved FreeU Config Import 2023-11-07 15:48:53 +11:00
351abd2ca2 Merge branch 'invoke-ai:main' into main 2023-11-07 15:43:04 +11:00
9733cd4199 Update xformers to 0.0.22post7 2023-11-06 17:17:03 -08:00
9976bc6908 Update installer.py to cu121 2023-11-06 17:17:03 -08:00
c68db6e40f Update xformers to ~0.0.22 2023-11-06 17:17:03 -08:00
3a50798a52 Update xformers to 0.0.22post7 2023-11-07 12:00:39 +11:00
a98426d2c6 Update installer.py to cu121 2023-11-07 11:57:02 +11:00
504f426f0a Update xformers to ~0.0.22 2023-11-07 11:53:39 +11:00
840cbc1d39 xformers==0.0.20 (#4881)
I'm not sure if it's correct way of handling things, but correcting this
string to '==0.0.20' fixes xformers install for me - and maybe for
others it will too. Sorry for absolutely incorrect PR.

Please see [this
thread](https://github.com/facebookresearch/xformers/issues/740), this
is the issue I had (trying to install InvokeAI with
Automatic/Manual/StableMatrix way).

With ~=0.0.19 (0.0.22):
```
(InvokeAI) pip install torch torchvision xformers~=0.0.19
Collecting torch
  Obtaining dependency information for torch from edce54779f/torch-2.1.0-cp311-cp311-win_amd64.whl.metadata
  Using cached torch-2.1.0-cp311-cp311-win_amd64.whl.metadata (25 kB)
Collecting torchvision
  Obtaining dependency information for torchvision from ab6f42af83/torchvision-0.16.0-cp311-cp311-win_amd64.whl.metadata
  Using cached torchvision-0.16.0-cp311-cp311-win_amd64.whl.metadata (6.6 kB)
Collecting xformers
  Using cached xformers-0.0.22.post3.tar.gz (3.9 MB)
  Installing build dependencies ... done
  Getting requirements to build wheel ... error
  error: subprocess-exited-with-error

  × Getting requirements to build wheel did not run successfully.
  │ exit code: 1
  ╰─> [20 lines of output]
      Traceback (most recent call last):
        File "C:\Users\Drun\invokeai\.venv\Lib\site-packages\pip\_vendor\pyproject_hooks\_in_process\_in_process.py", line 353, in <module>
          main()
        File "C:\Users\Drun\invokeai\.venv\Lib\site-packages\pip\_vendor\pyproject_hooks\_in_process\_in_process.py", line 335, in main
          json_out['return_val'] = hook(**hook_input['kwargs'])
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        File "C:\Users\Drun\invokeai\.venv\Lib\site-packages\pip\_vendor\pyproject_hooks\_in_process\_in_process.py", line 118, in get_requires_for_build_wheel
          return hook(config_settings)
                 ^^^^^^^^^^^^^^^^^^^^^
        File "C:\Users\Drun\AppData\Local\Temp\pip-build-env-rmhvraqj\overlay\Lib\site-packages\setuptools\build_meta.py", line 355, in get_requires_for_build_wheel
          return self._get_build_requires(config_settings, requirements=['wheel'])
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        File "C:\Users\Drun\AppData\Local\Temp\pip-build-env-rmhvraqj\overlay\Lib\site-packages\setuptools\build_meta.py", line 325, in _get_build_requires
          self.run_setup()
        File "C:\Users\Drun\AppData\Local\Temp\pip-build-env-rmhvraqj\overlay\Lib\site-packages\setuptools\build_meta.py", line 507, in run_setup
          super(_BuildMetaLegacyBackend, self).run_setup(setup_script=setup_script)
        File "C:\Users\Drun\AppData\Local\Temp\pip-build-env-rmhvraqj\overlay\Lib\site-packages\setuptools\build_meta.py", line 341, in run_setup
          exec(code, locals())
        File "<string>", line 23, in <module>
      ModuleNotFoundError: No module named 'torch'
```

With 0.0.20:

```
(InvokeAI) pip install torch torchvision xformers==0.0.20
Collecting torch
  Obtaining dependency information for torch from edce54779f/torch-2.1.0-cp311-cp311-win_amd64.whl.metadata
  Using cached torch-2.1.0-cp311-cp311-win_amd64.whl.metadata (25 kB)
Collecting torchvision
  Obtaining dependency information for torchvision from ab6f42af83/torchvision-0.16.0-cp311-cp311-win_amd64.whl.metadata
  Using cached torchvision-0.16.0-cp311-cp311-win_amd64.whl.metadata (6.6 kB)
Collecting xformers==0.0.20
  Obtaining dependency information for xformers==0.0.20 from d4a42f582a/xformers-0.0.20-cp311-cp311-win_amd64.whl.metadata
  Using cached xformers-0.0.20-cp311-cp311-win_amd64.whl.metadata (1.1 kB)
Collecting numpy (from xformers==0.0.20)
  Obtaining dependency information for numpy from 3f826c6d15/numpy-1.26.0-cp311-cp311-win_amd64.whl.metadata
  Using cached numpy-1.26.0-cp311-cp311-win_amd64.whl.metadata (61 kB)
Collecting pyre-extensions==0.0.29 (from xformers==0.0.20)
  Using cached pyre_extensions-0.0.29-py3-none-any.whl (12 kB)
Collecting torch
  Using cached torch-2.0.1-cp311-cp311-win_amd64.whl (172.3 MB)
Collecting filelock (from torch)
  Obtaining dependency information for filelock from 97afbafd9d/filelock-3.12.4-py3-none-any.whl.metadata
  Using cached filelock-3.12.4-py3-none-any.whl.metadata (2.8 kB)
Requirement already satisfied: typing-extensions in c:\users\drun\invokeai\.venv\lib\site-packages (from torch) (4.8.0)
Requirement already satisfied: sympy in c:\users\drun\invokeai\.venv\lib\site-packages (from torch) (1.12)
Collecting networkx (from torch)
  Using cached networkx-3.1-py3-none-any.whl (2.1 MB)
Collecting jinja2 (from torch)
  Using cached Jinja2-3.1.2-py3-none-any.whl (133 kB)
Collecting typing-inspect (from pyre-extensions==0.0.29->xformers==0.0.20)
  Obtaining dependency information for typing-inspect from 107a22063b/typing_inspect-0.9.0-py3-none-any.whl.metadata
  Using cached typing_inspect-0.9.0-py3-none-any.whl.metadata (1.5 kB)
Collecting requests (from torchvision)
  Obtaining dependency information for requests from 0e2d847013/requests-2.31.0-py3-none-any.whl.metadata
  Using cached requests-2.31.0-py3-none-any.whl.metadata (4.6 kB)
INFO: pip is looking at multiple versions of torchvision to determine which version is compatible with other requirements. This could take a while.
Collecting torchvision
  Using cached torchvision-0.15.2-cp311-cp311-win_amd64.whl (1.2 MB)
Collecting pillow!=8.3.*,>=5.3.0 (from torchvision)
  Obtaining dependency information for pillow!=8.3.*,>=5.3.0 from debe992677/Pillow-10.0.1-cp311-cp311-win_amd64.whl.metadata
  Using cached Pillow-10.0.1-cp311-cp311-win_amd64.whl.metadata (9.6 kB)
Collecting MarkupSafe>=2.0 (from jinja2->torch)
  Obtaining dependency information for MarkupSafe>=2.0 from 08b85bc194/MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl.metadata
  Using cached MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl.metadata (3.1 kB)
Collecting charset-normalizer<4,>=2 (from requests->torchvision)
  Obtaining dependency information for charset-normalizer<4,>=2 from 50028bbb26/charset_normalizer-3.3.0-cp311-cp311-win_amd64.whl.metadata
  Using cached charset_normalizer-3.3.0-cp311-cp311-win_amd64.whl.metadata (33 kB)
Collecting idna<4,>=2.5 (from requests->torchvision)
  Using cached idna-3.4-py3-none-any.whl (61 kB)
Collecting urllib3<3,>=1.21.1 (from requests->torchvision)
  Obtaining dependency information for urllib3<3,>=1.21.1 from 9957270221/urllib3-2.0.6-py3-none-any.whl.metadata
  Using cached urllib3-2.0.6-py3-none-any.whl.metadata (6.6 kB)
Collecting certifi>=2017.4.17 (from requests->torchvision)
  Obtaining dependency information for certifi>=2017.4.17 from 2234eab223/certifi-2023.7.22-py3-none-any.whl.metadata
  Using cached certifi-2023.7.22-py3-none-any.whl.metadata (2.2 kB)
Requirement already satisfied: mpmath>=0.19 in c:\users\drun\invokeai\.venv\lib\site-packages (from sympy->torch) (1.3.0)
Collecting mypy-extensions>=0.3.0 (from typing-inspect->pyre-extensions==0.0.29->xformers==0.0.20)
  Using cached mypy_extensions-1.0.0-py3-none-any.whl (4.7 kB)
Using cached xformers-0.0.20-cp311-cp311-win_amd64.whl (97.6 MB)
Using cached Pillow-10.0.1-cp311-cp311-win_amd64.whl (2.5 MB)
Using cached filelock-3.12.4-py3-none-any.whl (11 kB)
Using cached numpy-1.26.0-cp311-cp311-win_amd64.whl (15.8 MB)
Using cached requests-2.31.0-py3-none-any.whl (62 kB)
Using cached certifi-2023.7.22-py3-none-any.whl (158 kB)
Using cached charset_normalizer-3.3.0-cp311-cp311-win_amd64.whl (97 kB)
Using cached MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl (17 kB)
Using cached urllib3-2.0.6-py3-none-any.whl (123 kB)
Using cached typing_inspect-0.9.0-py3-none-any.whl (8.8 kB)
Installing collected packages: urllib3, pillow, numpy, networkx, mypy-extensions, MarkupSafe, idna, filelock, charset-normalizer, certifi, typing-inspect, requests, jinja2, torch, pyre-extensions, xformers, torchvision
Successfully installed MarkupSafe-2.1.3 certifi-2023.7.22 charset-normalizer-3.3.0 filelock-3.12.4 idna-3.4 jinja2-3.1.2 mypy-extensions-1.0.0 networkx-3.1 numpy-1.26.0 pillow-10.0.1 pyre-extensions-0.0.29 requests-2.31.0 torch-2.0.1 torchvision-0.15.2 typing-inspect-0.9.0 urllib3-2.0.6 xformers-0.0.20
```

## What type of PR is this? (check all applicable)

- [ ] Refactor
- [ ] Feature
- [x] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission


## Have you discussed this change with the InvokeAI team?
- [ ] Yes
- [x] No, because: I'm no-brainer. It fixed issue for me, so I did PR.
Who knows?

## Technical details:
Windows 11, Standalone clean and freshly-installed Python 3.11
2023-11-06 16:07:03 -08:00
014d6187ab Update pyproject.toml 2023-11-07 10:22:20 +11:00
9fb15fae87 Update pyproject.toml 2023-11-07 10:20:16 +11:00
a07336a020 Merge branch 'main' into patch-1 2023-11-07 10:17:46 +11:00
0718cc2392 Update xformers to 0.0.21 2023-11-07 10:16:44 +11:00
ce22c0fbaa sync pydantic and sql field names; merge routes 2023-11-06 18:08:57 -05:00
935e4632c2 feat(nodes): add freeu support (#4846)
### What type of PR is this? (check all applicable)

- [ ] Refactor
- [x] Feature
- [ ] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission


## Have you discussed this change with the InvokeAI team?
- [ ] Yes
- [x] No, because:

      
## Have you updated all relevant documentation?
- [ ] Yes
- [x] No


## Description

**Note: FreeU is not in the current release of diffusers. Looks like it
will be in release 0.22. This PR needs to wait until that is released.**

[feat(nodes): add freeu
support](15b33ad501)

Add support for FreeU. See:
- https://huggingface.co/docs/diffusers/main/en/using-diffusers/freeu
- https://github.com/ChenyangSi/FreeU

Implementation:
- `ModelPatcher.apply_freeu()` handles the enabling freeu (which is very
simple with diffusers).
- `FreeUConfig` model added to hold the hyperparameters.
- `freeu_config` added as optional sub-field on `UNetField`.
- `FreeUInvocation` added, works like LoRA - chain it to add the FreeU
config to the UNet
- No support for model-dependent presets, this will be a future workflow
editor enhancement

Closes https://github.com/invoke-ai/InvokeAI/issues/4845

## Related Tickets & Documents

<!--
For pull requests that relate or close an issue, please include them
below. 

For example having the text: "closes #1234" would connect the current
pull
request to issue 1234.  And when we merge the pull request, Github will
automatically close the issue.
-->

- Closes #4845 

## QA Instructions, Screenshots, Recordings

You'll need to install diffusers from their github repo before testing
this:
`pip install git+https://github.com/huggingface/diffusers`

1. Create a graph like this:

![image](https://github.com/invoke-ai/InvokeAI/assets/4822129/af17719b-b001-4534-8c4e-883484fd7465)
2. Get a free lunch!

No FreeU:

![image](https://github.com/invoke-ai/InvokeAI/assets/4822129/279d1a69-1577-4c31-ab82-ebf67f65920d)
With FreeU:

![image](https://github.com/invoke-ai/InvokeAI/assets/4822129/c332c778-0b87-4215-8a36-d4822e06f4de)

No FreeU:

![image](https://github.com/invoke-ai/InvokeAI/assets/4822129/ebec097b-ad54-4295-b734-33656738a2cf)
With FreeU:

![image](https://github.com/invoke-ai/InvokeAI/assets/4822129/3423140d-c9ce-4697-9993-d2bb0d0f5634)

No FreeU:

![image](https://github.com/invoke-ai/InvokeAI/assets/4822129/7cb0e39d-aa87-4a48-a3af-b9f47a866814)
With FreeU:

![image](https://github.com/invoke-ai/InvokeAI/assets/4822129/9113d2fe-5bd3-474f-8f33-82cdeb7abf82)
2023-11-06 13:58:32 -08:00
a83d8810c4 Merge branch 'main' into feat/nodes/freeu 2023-11-06 13:47:56 -08:00
76b3f8956b Fix ROCm support in Docker container 2023-11-06 13:47:08 -08:00
ff8a8a1963 Merge branch 'main' into feat/nodes/freeu 2023-11-06 09:04:54 -08:00
cb6d0c8851 Re-add feat/mix cnet t2iadapter (#4929)
Reverts invoke-ai/InvokeAI#4923, which was a revert on the premature
merge.

slide to the left. revert, revert.
2023-11-06 22:29:13 +05:30
67f2616d5a Merge branch 'main' into revert-4923-revert-4914-feat/mix-cnet-t2iadapter 2023-11-06 07:34:51 -08:00
f8f1740668 Set Defaults to 1 2023-11-06 07:11:16 -08:00
e66d0f7372 Merge branch 'main' into feat/nodes/freeu 2023-11-06 05:39:58 -08:00
546aaedbe4 Update pyproject.toml 2023-11-06 05:29:17 -08:00
55f8865524 Merge branch 'main' into refactor/model-manager-2 2023-11-05 21:45:26 -05:00
2d051559d1 fix flake8 complaints 2023-11-05 21:45:08 -05:00
7f650d00de translationBot(ui): update translation (Italian)
Currently translated at 97.7% (1191 of 1219 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2023-11-05 11:12:33 -08:00
db9cef0092 re-run isort 2023-11-04 23:50:07 -04:00
72c34aea75 added add_model_record and get_model_record to router api 2023-11-04 23:42:44 -04:00
edeea5237b add sql-based model config store and api 2023-11-04 23:03:26 -04:00
4e6b579526 translationBot(ui): update translation (Italian)
Currently translated at 97.6% (1190 of 1219 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2023-11-05 12:09:20 +11:00
6334c4adf5 translationBot(ui): update translation (German)
Currently translated at 53.8% (657 of 1219 strings)

Co-authored-by: Alexander Eichhorn <pfannkuchensack@einfach-doof.de>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/de/
Translation: InvokeAI/Web UI
2023-11-05 12:09:20 +11:00
66b2366efc Remove LowRA from Initial Models (#5016)
## What type of PR is this? (check all applicable)

- [ ] Refactor
- [ ] Feature
- [X] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission


## Have you discussed this change with the InvokeAI team?
- [ ] Yes
- [X] No, because:

      
## Have you updated all relevant documentation?
- [ ] Yes
- [X] No


## Description
Removing LowRA from the initial models as it's been deleted from
CivitAI.

## Related Tickets & Documents

https://discord.com/channels/1020123559063990373/1168415065205112872


- Related Issue #
- Closes #

## QA Instructions, Screenshots, Recordings

<!-- 
Please provide steps on how to test changes, any hardware or 
software specifications as well as any other pertinent information. 
-->

## Added/updated tests?

- [ ] Yes
- [ ] No : _please replace this line with details on why tests
      have not been included_

## [optional] Are there any post deployment tasks we need to perform?
2023-11-04 17:14:24 -04:00
e147379aa7 Merge branch 'main' into main 2023-11-04 17:05:01 -04:00
5a821384d3 fix model-not-found error 2023-11-04 08:24:01 -07:00
584b513038 Remove LowRA from Initial Models 2023-11-01 08:55:06 +11:00
55ad4feb5c Revert "Revert "feat(ui): remove special handling for t2i vs controlnet""
This reverts commit bdf4c4944c.
2023-10-17 11:59:19 -04:00
b7555ddae8 Revert "Revert "chore: lint""
This reverts commit 38e7eb8878.
2023-10-17 11:59:19 -04:00
8afc47018b Revert "Revert "Cleaning up (removing diagnostic prints)""
This reverts commit 6e697b7b6f.
2023-10-17 11:59:19 -04:00
a97ec88e06 Revert "Revert "Changes to _apply_standard_conditioning_sequentially() and _apply_cross_attention_controlled_conditioning() to reflect changes to T2I-Adapter implementation to allow usage of T2I-Adapter and ControlNet at the same time.""
This reverts commit c04fb451ee.
2023-10-17 11:59:19 -04:00
282d36b640 Revert "Revert "Fixing some var and arg names.""
This reverts commit 58a0709c1e.
2023-10-17 11:59:19 -04:00
14e25bf277 Merge branch 'main' into feat/nodes/freeu 2023-10-17 16:42:59 +11:00
001bba1719 Merge branch 'main' into feat/nodes/freeu 2023-10-17 15:58:00 +11:00
9db152bf75 xformers==0.0.20
I'm not sure if it's correct way of handling things, but correcting this string to '==0.0.20' fixes xformers install for me - and maybe for others too. 

Please see this thread, this is the issue I had (trying to install InvokeAI):
https://github.com/facebookresearch/xformers/issues/740
2023-10-14 14:59:55 +04:00
15b33ad501 feat(nodes): add freeu support
Add support for FreeU. See:
- https://huggingface.co/docs/diffusers/main/en/using-diffusers/freeu
- https://github.com/ChenyangSi/FreeU

Implementation:
- `ModelPatcher.apply_freeu()` handles the enabling freeu (which is very simple with diffusers).
- `FreeUConfig` model added to hold the hyperparameters.
- `freeu_config` added as optional sub-field on `UNetField`.
- `FreeUInvocation` added, works like LoRA - chain it to add the FreeU config to the UNet
- No support for model-dependent presets, this will be a future workflow editor enhancement

Closes #4845
2023-10-11 13:49:28 +11:00
330 changed files with 8275 additions and 6305 deletions

View File

@ -1,20 +0,0 @@
on:
pull_request:
push:
branches:
- main
- development
- 'release-candidate-*'
jobs:
pyflakes:
name: runner / pyflakes
if: github.event.pull_request.draft == false
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: pyflakes
uses: reviewdog/action-pyflakes@v1
with:
github_token: ${{ secrets.GITHUB_TOKEN }}
reporter: github-pr-review

View File

@ -18,8 +18,7 @@ jobs:
- name: Install dependencies with pip
run: |
pip install black flake8 Flake8-pyproject isort
pip install ruff
- run: isort --check-only .
- run: black --check .
- run: flake8
- run: ruff check --output-format=github .
- run: ruff format --check .

View File

@ -161,7 +161,7 @@ the command `npm install -g yarn` if needed)
_For Windows/Linux with an NVIDIA GPU:_
```terminal
pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118
pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu121
```
_For Linux with an AMD GPU:_
@ -175,7 +175,7 @@ the command `npm install -g yarn` if needed)
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/cpu
```
_For Macintoshes, either Intel or M1/M2:_
_For Macintoshes, either Intel or M1/M2/M3:_
```sh
pip install InvokeAI --use-pep517

View File

@ -11,5 +11,5 @@ INVOKEAI_ROOT=
# HUGGING_FACE_HUB_TOKEN=
## optional variables specific to the docker setup.
# GPU_DRIVER=cuda
# CONTAINER_UID=1000
# GPU_DRIVER=cuda # or rocm
# CONTAINER_UID=1000

View File

@ -18,8 +18,8 @@ ENV INVOKEAI_SRC=/opt/invokeai
ENV VIRTUAL_ENV=/opt/venv/invokeai
ENV PATH="$VIRTUAL_ENV/bin:$PATH"
ARG TORCH_VERSION=2.0.1
ARG TORCHVISION_VERSION=0.15.2
ARG TORCH_VERSION=2.1.0
ARG TORCHVISION_VERSION=0.16
ARG GPU_DRIVER=cuda
ARG TARGETPLATFORM="linux/amd64"
# unused but available
@ -35,7 +35,7 @@ RUN --mount=type=cache,target=/root/.cache/pip \
if [ "$TARGETPLATFORM" = "linux/arm64" ] || [ "$GPU_DRIVER" = "cpu" ]; then \
extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/cpu"; \
elif [ "$GPU_DRIVER" = "rocm" ]; then \
extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/rocm5.4.2"; \
extra_index_url_arg="--index-url https://download.pytorch.org/whl/rocm5.6"; \
else \
extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/cu121"; \
fi &&\

View File

@ -15,6 +15,10 @@ services:
- driver: nvidia
count: 1
capabilities: [gpu]
# For AMD support, comment out the deploy section above and uncomment the devices section below:
#devices:
# - /dev/kfd:/dev/kfd
# - /dev/dri:/dev/dri
build:
context: ..
dockerfile: docker/Dockerfile

View File

@ -7,5 +7,5 @@ set -e
SCRIPTDIR=$(dirname "${BASH_SOURCE[0]}")
cd "$SCRIPTDIR" || exit 1
docker compose up --build -d
docker compose up -d
docker compose logs -f

File diff suppressed because it is too large Load Diff

View File

@ -198,6 +198,7 @@ The list of schedulers has been completely revamped and brought up to date:
| **dpmpp_2m** | DPMSolverMultistepScheduler | original noise scnedule |
| **dpmpp_2m_k** | DPMSolverMultistepScheduler | using karras noise schedule |
| **unipc** | UniPCMultistepScheduler | CPU only |
| **lcm** | LCMScheduler | |
Please see [3.0.0 Release Notes](https://github.com/invoke-ai/InvokeAI/releases/tag/v3.0.0) for further details.

View File

@ -179,7 +179,7 @@ experimental versions later.
you will have the choice of CUDA (NVidia cards), ROCm (AMD cards),
or CPU (no graphics acceleration). On Windows, you'll have the
choice of CUDA vs CPU, and on Macs you'll be offered CPU only. When
you select CPU on M1 or M2 Macintoshes, you will get MPS-based
you select CPU on M1/M2/M3 Macintoshes, you will get MPS-based
graphics acceleration without installing additional drivers. If you
are unsure what GPU you are using, you can ask the installer to
guess.
@ -471,7 +471,7 @@ Then type the following commands:
=== "NVIDIA System"
```bash
pip install torch torchvision --force-reinstall --extra-index-url https://download.pytorch.org/whl/cu118
pip install torch torchvision --force-reinstall --extra-index-url https://download.pytorch.org/whl/cu121
pip install xformers
```

View File

@ -148,7 +148,7 @@ manager, please follow these steps:
=== "CUDA (NVidia)"
```bash
pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118
pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu121
```
=== "ROCm (AMD)"
@ -327,7 +327,7 @@ installation protocol (important!)
=== "CUDA (NVidia)"
```bash
pip install -e .[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118
pip install -e .[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu121
```
=== "ROCm (AMD)"
@ -375,7 +375,7 @@ you can do so using this unsupported recipe:
mkdir ~/invokeai
conda create -n invokeai python=3.10
conda activate invokeai
pip install InvokeAI[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118
pip install InvokeAI[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu121
invokeai-configure --root ~/invokeai
invokeai --root ~/invokeai --web
```

View File

@ -85,7 +85,7 @@ You can find which version you should download from [this link](https://docs.nvi
When installing torch and torchvision manually with `pip`, remember to provide
the argument `--extra-index-url
https://download.pytorch.org/whl/cu118` as described in the [Manual
https://download.pytorch.org/whl/cu121` as described in the [Manual
Installation Guide](020_INSTALL_MANUAL.md).
## :simple-amd: ROCm

View File

@ -30,7 +30,7 @@ methodology for details on why running applications in such a stateless fashion
The container is configured for CUDA by default, but can be built to support AMD GPUs
by setting the `GPU_DRIVER=rocm` environment variable at Docker image build time.
Developers on Apple silicon (M1/M2): You
Developers on Apple silicon (M1/M2/M3): You
[can't access your GPU cores from Docker containers](https://github.com/pytorch/pytorch/issues/81224)
and performance is reduced compared with running it directly on macOS but for
development purposes it's fine. Once you're done with development tasks on your

View File

@ -28,7 +28,7 @@ command line, then just be sure to activate it's virtual environment.
Then run the following three commands:
```sh
pip install xformers~=0.0.19
pip install xformers~=0.0.22
pip install triton # WON'T WORK ON WINDOWS
python -m xformers.info output
```
@ -42,7 +42,7 @@ If all goes well, you'll see a report like the
following:
```sh
xFormers 0.0.20
xFormers 0.0.22
memory_efficient_attention.cutlassF: available
memory_efficient_attention.cutlassB: available
memory_efficient_attention.flshattF: available
@ -59,14 +59,14 @@ swiglu.gemm_fused_operand_sum: available
swiglu.fused.p.cpp: available
is_triton_available: True
is_functorch_available: False
pytorch.version: 2.0.1+cu118
pytorch.version: 2.1.0+cu121
pytorch.cuda: available
gpu.compute_capability: 8.9
gpu.name: NVIDIA GeForce RTX 4070
build.info: available
build.cuda_version: 1108
build.python_version: 3.10.11
build.torch_version: 2.0.1+cu118
build.torch_version: 2.1.0+cu121
build.env.TORCH_CUDA_ARCH_LIST: 5.0+PTX 6.0 6.1 7.0 7.5 8.0 8.6
build.env.XFORMERS_BUILD_TYPE: Release
build.env.XFORMERS_ENABLE_DEBUG_ASSERTIONS: None
@ -92,33 +92,22 @@ installed from source. These instructions were written for a system
running Ubuntu 22.04, but other Linux distributions should be able to
adapt this recipe.
#### 1. Install CUDA Toolkit 11.8
#### 1. Install CUDA Toolkit 12.1
You will need the CUDA developer's toolkit in order to compile and
install xFormers. **Do not try to install Ubuntu's nvidia-cuda-toolkit
package.** It is out of date and will cause conflicts among the NVIDIA
driver and binaries. Instead install the CUDA Toolkit package provided
by NVIDIA itself. Go to [CUDA Toolkit 11.8
Downloads](https://developer.nvidia.com/cuda-11-8-0-download-archive)
by NVIDIA itself. Go to [CUDA Toolkit 12.1
Downloads](https://developer.nvidia.com/cuda-12-1-0-download-archive)
and use the target selection wizard to choose your platform and Linux
distribution. Select an installer type of "runfile (local)" at the
last step.
This will provide you with a recipe for downloading and running a
install shell script that will install the toolkit and drivers. For
example, the install script recipe for Ubuntu 22.04 running on a
x86_64 system is:
install shell script that will install the toolkit and drivers.
```
wget https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run
sudo sh cuda_11.8.0_520.61.05_linux.run
```
Rather than cut-and-paste this example, We recommend that you walk
through the toolkit wizard in order to get the most up to date
installer for your system.
#### 2. Confirm/Install pyTorch 2.01 with CUDA 11.8 support
#### 2. Confirm/Install pyTorch 2.1.0 with CUDA 12.1 support
If you are using InvokeAI 3.0.2 or higher, these will already be
installed. If not, you can check whether you have the needed libraries
@ -133,7 +122,7 @@ Then run the command:
python -c 'exec("import torch\nprint(torch.__version__)")'
```
If it prints __1.13.1+cu118__ you're good. If not, you can install the
If it prints __2.1.0+cu121__ you're good. If not, you can install the
most up to date libraries with this command:
```sh

View File

@ -244,7 +244,7 @@ class InvokeAiInstance:
"numpy~=1.24.0", # choose versions that won't be uninstalled during phase 2
"urllib3~=1.26.0",
"requests~=2.28.0",
"torch~=2.0.0",
"torch~=2.1.0",
"torchmetrics==0.11.4",
"torchvision>=0.14.1",
"--force-reinstall",
@ -460,10 +460,10 @@ def get_torch_source() -> (Union[str, None], str):
url = "https://download.pytorch.org/whl/cpu"
if device == "cuda":
url = "https://download.pytorch.org/whl/cu118"
url = "https://download.pytorch.org/whl/cu121"
optional_modules = "[xformers,onnx-cuda]"
if device == "cuda_and_dml":
url = "https://download.pytorch.org/whl/cu118"
url = "https://download.pytorch.org/whl/cu121"
optional_modules = "[xformers,onnx-directml]"
# in all other cases, Torch wheels should be coming from PyPi as of Torch 1.13

View File

@ -137,7 +137,7 @@ def dest_path(dest=None) -> Path:
path_completer = PathCompleter(
only_directories=True,
expanduser=True,
get_paths=lambda: [browse_start],
get_paths=lambda: [browse_start], # noqa: B023
# get_paths=lambda: [".."].extend(list(browse_start.iterdir()))
)
@ -149,7 +149,7 @@ def dest_path(dest=None) -> Path:
completer=path_completer,
default=str(browse_start) + os.sep,
vi_mode=True,
complete_while_typing=True
complete_while_typing=True,
# Test that this is not needed on Windows
# complete_style=CompleteStyle.READLINE_LIKE,
)

View File

@ -24,6 +24,7 @@ from ..services.item_storage.item_storage_sqlite import SqliteItemStorage
from ..services.latents_storage.latents_storage_disk import DiskLatentsStorage
from ..services.latents_storage.latents_storage_forward_cache import ForwardCacheLatentsStorage
from ..services.model_manager.model_manager_default import ModelManagerService
from ..services.model_records import ModelRecordServiceSQL
from ..services.names.names_default import SimpleNameService
from ..services.session_processor.session_processor_default import DefaultSessionProcessor
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
@ -85,6 +86,7 @@ class ApiDependencies:
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
model_manager = ModelManagerService(config, logger)
model_record_service = ModelRecordServiceSQL(db=db)
names = SimpleNameService()
performance_statistics = InvocationStatsService()
processor = DefaultInvocationProcessor()
@ -111,6 +113,7 @@ class ApiDependencies:
latents=latents,
logger=logger,
model_manager=model_manager,
model_records=model_record_service,
names=names,
performance_statistics=performance_statistics,
processor=processor,

View File

@ -28,7 +28,7 @@ class FastAPIEventService(EventServiceBase):
self.__queue.put(None)
def dispatch(self, event_name: str, payload: Any) -> None:
self.__queue.put(dict(event_name=event_name, payload=payload))
self.__queue.put({"event_name": event_name, "payload": payload})
async def __dispatch_from_queue(self, stop_event: threading.Event):
"""Get events on from the queue and dispatch them, from the correct thread"""

View File

@ -0,0 +1,164 @@
# Copyright (c) 2023 Lincoln D. Stein
"""FastAPI route for model configuration records."""
from hashlib import sha1
from random import randbytes
from typing import List, Optional
from fastapi import Body, Path, Query, Response
from fastapi.routing import APIRouter
from pydantic import BaseModel, ConfigDict
from starlette.exceptions import HTTPException
from typing_extensions import Annotated
from invokeai.app.services.model_records import (
DuplicateModelException,
InvalidModelException,
UnknownModelException,
)
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
ModelType,
)
from ..dependencies import ApiDependencies
model_records_router = APIRouter(prefix="/v1/model/record", tags=["models"])
class ModelsList(BaseModel):
"""Return list of configs."""
models: list[AnyModelConfig]
model_config = ConfigDict(use_enum_values=True)
@model_records_router.get(
"/",
operation_id="list_model_records",
)
async def list_model_records(
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
) -> ModelsList:
"""Get a list of models."""
record_store = ApiDependencies.invoker.services.model_records
found_models: list[AnyModelConfig] = []
if base_models:
for base_model in base_models:
found_models.extend(record_store.search_by_attr(base_model=base_model, model_type=model_type))
else:
found_models.extend(record_store.search_by_attr(model_type=model_type))
return ModelsList(models=found_models)
@model_records_router.get(
"/i/{key}",
operation_id="get_model_record",
responses={
200: {"description": "Success"},
400: {"description": "Bad request"},
404: {"description": "The model could not be found"},
},
)
async def get_model_record(
key: str = Path(description="Key of the model record to fetch."),
) -> AnyModelConfig:
"""Get a model record"""
record_store = ApiDependencies.invoker.services.model_records
try:
return record_store.get_model(key)
except UnknownModelException as e:
raise HTTPException(status_code=404, detail=str(e))
@model_records_router.patch(
"/i/{key}",
operation_id="update_model_record",
responses={
200: {"description": "The model was updated successfully"},
400: {"description": "Bad request"},
404: {"description": "The model could not be found"},
409: {"description": "There is already a model corresponding to the new name"},
},
status_code=200,
response_model=AnyModelConfig,
)
async def update_model_record(
key: Annotated[str, Path(description="Unique key of model")],
info: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")],
) -> AnyModelConfig:
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
logger = ApiDependencies.invoker.services.logger
record_store = ApiDependencies.invoker.services.model_records
try:
model_response = record_store.update_model(key, config=info)
logger.info(f"Updated model: {key}")
except UnknownModelException as e:
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
return model_response
@model_records_router.delete(
"/i/{key}",
operation_id="del_model_record",
responses={
204: {"description": "Model deleted successfully"},
404: {"description": "Model not found"},
},
status_code=204,
)
async def del_model_record(
key: str = Path(description="Unique key of model to remove from model registry."),
) -> Response:
"""Delete Model"""
logger = ApiDependencies.invoker.services.logger
try:
record_store = ApiDependencies.invoker.services.model_records
record_store.del_model(key)
logger.info(f"Deleted model: {key}")
return Response(status_code=204)
except UnknownModelException as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
@model_records_router.post(
"/i/",
operation_id="add_model_record",
responses={
201: {"description": "The model added successfully"},
409: {"description": "There is already a model corresponding to this path or repo_id"},
415: {"description": "Unrecognized file/folder format"},
},
status_code=201,
)
async def add_model_record(
config: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")]
) -> AnyModelConfig:
"""
Add a model using the configuration information appropriate for its type.
"""
logger = ApiDependencies.invoker.services.logger
record_store = ApiDependencies.invoker.services.model_records
if config.key == "<NOKEY>":
config.key = sha1(randbytes(100)).hexdigest()
logger.info(f"Created model {config.key} for {config.name}")
try:
record_store.add_model(config.key, config)
except DuplicateModelException as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
except InvalidModelException as e:
logger.error(str(e))
raise HTTPException(status_code=415)
# now fetch it out
return record_store.get_model(config.key)

View File

@ -1,6 +1,5 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2023 Lincoln D. Stein
import pathlib
from typing import Annotated, List, Literal, Optional, Union
@ -55,7 +54,7 @@ async def list_models(
) -> ModelsList:
"""Gets a list of models"""
if base_models and len(base_models) > 0:
models_raw = list()
models_raw = []
for base_model in base_models:
models_raw.extend(ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type))
else:

View File

@ -34,4 +34,4 @@ class SocketIO:
async def _handle_unsub_queue(self, sid, data, *args, **kwargs):
if "queue_id" in data:
await self.__sio.enter_room(sid, data["queue_id"])
await self.__sio.leave_room(sid, data["queue_id"])

View File

@ -43,6 +43,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
board_images,
boards,
images,
model_records,
models,
session_queue,
sessions,
@ -106,6 +107,7 @@ app.include_router(sessions.session_router, prefix="/api")
app.include_router(utilities.utilities_router, prefix="/api")
app.include_router(models.models_router, prefix="/api")
app.include_router(model_records.model_records_router, prefix="/api")
app.include_router(images.images_router, prefix="/api")
app.include_router(boards.boards_router, prefix="/api")
app.include_router(board_images.board_images_router, prefix="/api")
@ -130,7 +132,7 @@ def custom_openapi() -> dict[str, Any]:
# Add all outputs
all_invocations = BaseInvocation.get_invocations()
output_types = set()
output_type_titles = dict()
output_type_titles = {}
for invoker in all_invocations:
output_type = signature(invoker.invoke).return_annotation
output_types.add(output_type)
@ -171,12 +173,12 @@ def custom_openapi() -> dict[str, Any]:
# print(f"Config with name {name} already defined")
continue
openapi_schema["components"]["schemas"][name] = dict(
title=name,
description="An enumeration.",
type="string",
enum=list(v.value for v in model_config_format_enum),
)
openapi_schema["components"]["schemas"][name] = {
"title": name,
"description": "An enumeration.",
"type": "string",
"enum": [v.value for v in model_config_format_enum],
}
app.openapi_schema = openapi_schema
return app.openapi_schema

View File

@ -25,4 +25,4 @@ spec.loader.exec_module(module)
# add core nodes to __all__
python_files = filter(lambda f: not f.name.startswith("_"), Path(__file__).parent.glob("*.py"))
__all__ = list(f.stem for f in python_files) # type: ignore
__all__ = [f.stem for f in python_files] # type: ignore

View File

@ -16,6 +16,7 @@ from pydantic.fields import FieldInfo, _Unset
from pydantic_core import PydanticUndefined
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.shared.fields import FieldDescriptions
from invokeai.app.util.misc import uuid_string
if TYPE_CHECKING:
@ -30,70 +31,6 @@ class InvalidFieldError(TypeError):
pass
class FieldDescriptions:
denoising_start = "When to start denoising, expressed a percentage of total steps"
denoising_end = "When to stop denoising, expressed a percentage of total steps"
cfg_scale = "Classifier-Free Guidance scale"
scheduler = "Scheduler to use during inference"
positive_cond = "Positive conditioning tensor"
negative_cond = "Negative conditioning tensor"
noise = "Noise tensor"
clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
unet = "UNet (scheduler, LoRAs)"
vae = "VAE"
cond = "Conditioning tensor"
controlnet_model = "ControlNet model to load"
vae_model = "VAE model to load"
lora_model = "LoRA model to load"
main_model = "Main model (UNet, VAE, CLIP) to load"
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
lora_weight = "The weight at which the LoRA is applied to each model"
compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor"
raw_prompt = "Raw prompt text (no parsing)"
sdxl_aesthetic = "The aesthetic score to apply to the conditioning tensor"
skipped_layers = "Number of layers to skip in text encoder"
seed = "Seed for random number generation"
steps = "Number of steps to run"
width = "Width of output (px)"
height = "Height of output (px)"
control = "ControlNet(s) to apply"
ip_adapter = "IP-Adapter to apply"
t2i_adapter = "T2I-Adapter(s) to apply"
denoised_latents = "Denoised latents tensor"
latents = "Latents tensor"
strength = "Strength of denoising (proportional to steps)"
metadata = "Optional metadata to be saved with the image"
metadata_collection = "Collection of Metadata"
metadata_item_polymorphic = "A single metadata item or collection of metadata items"
metadata_item_label = "Label for this metadata item"
metadata_item_value = "The value for this metadata item (may be any type)"
workflow = "Optional workflow to be saved with the image"
interp_mode = "Interpolation mode"
torch_antialias = "Whether or not to apply antialiasing (bilinear or bicubic only)"
fp32 = "Whether or not to use full float32 precision"
precision = "Precision to use"
tiled = "Processing using overlapping tiles (reduce memory consumption)"
detect_res = "Pixel resolution for detection"
image_res = "Pixel resolution for output image"
safe_mode = "Whether or not to use safe mode"
scribble_mode = "Whether or not to use scribble mode"
scale_factor = "The factor by which to scale"
blend_alpha = (
"Blending factor. 0.0 = use input A only, 1.0 = use input B only, 0.5 = 50% mix of input A and input B."
)
num_1 = "The first number"
num_2 = "The second number"
mask = "The mask to use for the operation"
board = "The board to save the image to"
image = "The image to process"
tile_size = "Tile size"
inclusive_low = "The inclusive low value"
exclusive_high = "The exclusive high value"
decimal_places = "The number of decimal places to round to"
class Input(str, Enum):
"""
The type of input a field accepts.
@ -299,35 +236,35 @@ def InputField(
Ignored for non-collection fields.
"""
json_schema_extra_: dict[str, Any] = dict(
input=input,
ui_type=ui_type,
ui_component=ui_component,
ui_hidden=ui_hidden,
ui_order=ui_order,
item_default=item_default,
ui_choice_labels=ui_choice_labels,
_field_kind="input",
)
json_schema_extra_: dict[str, Any] = {
"input": input,
"ui_type": ui_type,
"ui_component": ui_component,
"ui_hidden": ui_hidden,
"ui_order": ui_order,
"item_default": item_default,
"ui_choice_labels": ui_choice_labels,
"_field_kind": "input",
}
field_args = dict(
default=default,
default_factory=default_factory,
title=title,
description=description,
pattern=pattern,
strict=strict,
gt=gt,
ge=ge,
lt=lt,
le=le,
multiple_of=multiple_of,
allow_inf_nan=allow_inf_nan,
max_digits=max_digits,
decimal_places=decimal_places,
min_length=min_length,
max_length=max_length,
)
field_args = {
"default": default,
"default_factory": default_factory,
"title": title,
"description": description,
"pattern": pattern,
"strict": strict,
"gt": gt,
"ge": ge,
"lt": lt,
"le": le,
"multiple_of": multiple_of,
"allow_inf_nan": allow_inf_nan,
"max_digits": max_digits,
"decimal_places": decimal_places,
"min_length": min_length,
"max_length": max_length,
}
"""
Invocation definitions have their fields typed correctly for their `invoke()` functions.
@ -362,24 +299,24 @@ def InputField(
# because we are manually making fields optional, we need to store the original required bool for reference later
if default is PydanticUndefined and default_factory is PydanticUndefined:
json_schema_extra_.update(dict(orig_required=True))
json_schema_extra_.update({"orig_required": True})
else:
json_schema_extra_.update(dict(orig_required=False))
json_schema_extra_.update({"orig_required": False})
# make Input.Any and Input.Connection fields optional, providing None as a default if the field doesn't already have one
if (input is Input.Any or input is Input.Connection) and default_factory is PydanticUndefined:
default_ = None if default is PydanticUndefined else default
provided_args.update(dict(default=default_))
provided_args.update({"default": default_})
if default is not PydanticUndefined:
# before invoking, we'll grab the original default value and set it on the field if the field wasn't provided a value
json_schema_extra_.update(dict(default=default))
json_schema_extra_.update(dict(orig_default=default))
json_schema_extra_.update({"default": default})
json_schema_extra_.update({"orig_default": default})
elif default is not PydanticUndefined and default_factory is PydanticUndefined:
default_ = default
provided_args.update(dict(default=default_))
json_schema_extra_.update(dict(orig_default=default_))
provided_args.update({"default": default_})
json_schema_extra_.update({"orig_default": default_})
elif default_factory is not PydanticUndefined:
provided_args.update(dict(default_factory=default_factory))
provided_args.update({"default_factory": default_factory})
# TODO: cannot serialize default_factory...
# json_schema_extra_.update(dict(orig_default_factory=default_factory))
@ -446,12 +383,12 @@ def OutputField(
decimal_places=decimal_places,
min_length=min_length,
max_length=max_length,
json_schema_extra=dict(
ui_type=ui_type,
ui_hidden=ui_hidden,
ui_order=ui_order,
_field_kind="output",
),
json_schema_extra={
"ui_type": ui_type,
"ui_hidden": ui_hidden,
"ui_order": ui_order,
"_field_kind": "output",
},
)
@ -523,14 +460,14 @@ class BaseInvocationOutput(BaseModel):
@classmethod
def get_output_types(cls) -> Iterable[str]:
return map(lambda i: get_type(i), BaseInvocationOutput.get_outputs())
return (get_type(i) for i in BaseInvocationOutput.get_outputs())
@staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
# Because we use a pydantic Literal field with default value for the invocation type,
# it will be typed as optional in the OpenAPI schema. Make it required manually.
if "required" not in schema or not isinstance(schema["required"], list):
schema["required"] = list()
schema["required"] = []
schema["required"].extend(["type"])
model_config = ConfigDict(
@ -590,16 +527,11 @@ class BaseInvocation(ABC, BaseModel):
@classmethod
def get_invocations_map(cls) -> dict[str, BaseInvocation]:
# Get the type strings out of the literals and into a dictionary
return dict(
map(
lambda i: (get_type(i), i),
BaseInvocation.get_invocations(),
)
)
return {get_type(i): i for i in BaseInvocation.get_invocations()}
@classmethod
def get_invocation_types(cls) -> Iterable[str]:
return map(lambda i: get_type(i), BaseInvocation.get_invocations())
return (get_type(i) for i in BaseInvocation.get_invocations())
@classmethod
def get_output_type(cls) -> BaseInvocationOutput:
@ -618,7 +550,7 @@ class BaseInvocation(ABC, BaseModel):
if uiconfig and hasattr(uiconfig, "version"):
schema["version"] = uiconfig.version
if "required" not in schema or not isinstance(schema["required"], list):
schema["required"] = list()
schema["required"] = []
schema["required"].extend(["type", "id"])
@abstractmethod
@ -672,15 +604,15 @@ class BaseInvocation(ABC, BaseModel):
id: str = Field(
default_factory=uuid_string,
description="The id of this instance of an invocation. Must be unique among all instances of invocations.",
json_schema_extra=dict(_field_kind="internal"),
json_schema_extra={"_field_kind": "internal"},
)
is_intermediate: bool = Field(
default=False,
description="Whether or not this is an intermediate invocation.",
json_schema_extra=dict(ui_type=UIType.IsIntermediate, _field_kind="internal"),
json_schema_extra={"ui_type": UIType.IsIntermediate, "_field_kind": "internal"},
)
use_cache: bool = Field(
default=True, description="Whether or not to use the cache", json_schema_extra=dict(_field_kind="internal")
default=True, description="Whether or not to use the cache", json_schema_extra={"_field_kind": "internal"}
)
UIConfig: ClassVar[Type[UIConfigBase]]
@ -714,7 +646,7 @@ class _Model(BaseModel):
# Get all pydantic model attrs, methods, etc
RESERVED_PYDANTIC_FIELD_NAMES = set(map(lambda m: m[0], inspect.getmembers(_Model())))
RESERVED_PYDANTIC_FIELD_NAMES = {m[0] for m in inspect.getmembers(_Model())}
def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None:
@ -729,9 +661,7 @@ def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None
field_kind = (
# _field_kind is defined via InputField(), OutputField() or by one of the internal fields defined in this file
field.json_schema_extra.get("_field_kind", None)
if field.json_schema_extra
else None
field.json_schema_extra.get("_field_kind", None) if field.json_schema_extra else None
)
# must have a field_kind
@ -792,7 +722,7 @@ def invocation(
# Add OpenAPI schema extras
uiconf_name = cls.__qualname__ + ".UIConfig"
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name:
cls.UIConfig = type(uiconf_name, (UIConfigBase,), dict())
cls.UIConfig = type(uiconf_name, (UIConfigBase,), {})
if title is not None:
cls.UIConfig.title = title
if tags is not None:
@ -819,7 +749,7 @@ def invocation(
invocation_type_annotation = Literal[invocation_type] # type: ignore
invocation_type_field = Field(
title="type", default=invocation_type, json_schema_extra=dict(_field_kind="internal")
title="type", default=invocation_type, json_schema_extra={"_field_kind": "internal"}
)
docstring = cls.__doc__
@ -865,7 +795,7 @@ def invocation_output(
# Add the output type to the model.
output_type_annotation = Literal[output_type] # type: ignore
output_type_field = Field(title="type", default=output_type, json_schema_extra=dict(_field_kind="internal"))
output_type_field = Field(title="type", default=output_type, json_schema_extra={"_field_kind": "internal"})
docstring = cls.__doc__
cls = create_model(
@ -897,7 +827,7 @@ WorkflowFieldValidator = TypeAdapter(WorkflowField)
class WithWorkflow(BaseModel):
workflow: Optional[WorkflowField] = Field(
default=None, description=FieldDescriptions.workflow, json_schema_extra=dict(_field_kind="internal")
default=None, description=FieldDescriptions.workflow, json_schema_extra={"_field_kind": "internal"}
)
@ -915,5 +845,5 @@ MetadataFieldValidator = TypeAdapter(MetadataField)
class WithMetadata(BaseModel):
metadata: Optional[MetadataField] = Field(
default=None, description=FieldDescriptions.metadata, json_schema_extra=dict(_field_kind="internal")
default=None, description=FieldDescriptions.metadata, json_schema_extra={"_field_kind": "internal"}
)

View File

@ -7,6 +7,7 @@ from compel import Compel, ReturnedEmbeddingsType
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput
from invokeai.app.shared.fields import FieldDescriptions
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo,
ExtraConditioningInfo,
@ -19,7 +20,6 @@ from ...backend.util.devices import torch_dtype
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
FieldDescriptions,
Input,
InputField,
InvocationContext,
@ -112,10 +112,11 @@ class CompelInvocation(BaseInvocation):
tokenizer,
ti_manager,
),
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, self.clip.skipped_layers),
text_encoder_info as text_encoder,
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, self.clip.skipped_layers),
):
compel = Compel(
tokenizer=tokenizer,
@ -234,10 +235,11 @@ class SDXLPromptInvocationBase:
tokenizer,
ti_manager,
),
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),
text_encoder_info as text_encoder,
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),
):
compel = Compel(
tokenizer=tokenizer,

View File

@ -28,12 +28,12 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator
from invokeai.app.invocations.primitives import ImageField, ImageOutput
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
from invokeai.app.shared.fields import FieldDescriptions
from ...backend.model_management import BaseModelType
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
FieldDescriptions,
Input,
InputField,
InvocationContext,

View File

@ -131,7 +131,7 @@ def prepare_faces_list(
deduped_faces: list[FaceResultData] = []
if len(face_result_list) == 0:
return list()
return []
for candidate in face_result_list:
should_add = True
@ -210,7 +210,7 @@ def generate_face_box_mask(
# Check if any face is detected.
if results.multi_face_landmarks: # type: ignore # this are via protobuf and not typed
# Search for the face_id in the detected faces.
for face_id, face_landmarks in enumerate(results.multi_face_landmarks): # type: ignore #this are via protobuf and not typed
for _face_id, face_landmarks in enumerate(results.multi_face_landmarks): # type: ignore #this are via protobuf and not typed
# Get the bounding box of the face mesh.
x_coordinates = [landmark.x for landmark in face_landmarks.landmark]
y_coordinates = [landmark.y for landmark in face_landmarks.landmark]

View File

@ -9,19 +9,11 @@ from PIL import Image, ImageChops, ImageFilter, ImageOps
from invokeai.app.invocations.primitives import BoardField, ColorField, ImageField, ImageOutput
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
from invokeai.app.shared.fields import FieldDescriptions
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
from invokeai.backend.image_util.safety_checker import SafetyChecker
from .baseinvocation import (
BaseInvocation,
FieldDescriptions,
Input,
InputField,
InvocationContext,
WithMetadata,
WithWorkflow,
invocation,
)
from .baseinvocation import BaseInvocation, Input, InputField, InvocationContext, WithMetadata, WithWorkflow, invocation
@invocation("show_image", title="Show Image", tags=["image"], category="image", version="1.0.0")

View File

@ -7,7 +7,6 @@ from pydantic import BaseModel, ConfigDict, Field
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
FieldDescriptions,
Input,
InputField,
InvocationContext,
@ -17,6 +16,7 @@ from invokeai.app.invocations.baseinvocation import (
invocation_output,
)
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.shared.fields import FieldDescriptions
from invokeai.backend.model_management.models.base import BaseModelType, ModelType
from invokeai.backend.model_management.models.ip_adapter import get_ip_adapter_image_encoder_model_id

View File

@ -10,7 +10,7 @@ import torch
import torchvision.transforms as T
from diffusers import AutoencoderKL, AutoencoderTiny
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.adapter import FullAdapterXL, T2IAdapter
from diffusers.models.adapter import T2IAdapter
from diffusers.models.attention_processor import (
AttnProcessor2_0,
LoRAAttnProcessor2_0,
@ -34,6 +34,7 @@ from invokeai.app.invocations.primitives import (
)
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
from invokeai.app.shared.fields import FieldDescriptions
from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
@ -57,7 +58,6 @@ from ...backend.util.devices import choose_precision, choose_torch_device
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
FieldDescriptions,
Input,
InputField,
InvocationContext,
@ -77,7 +77,7 @@ if choose_torch_device() == torch.device("mps"):
DEFAULT_PRECISION = choose_precision(choose_torch_device())
SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))]
SAMPLER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())]
@invocation_output("scheduler_output")
@ -562,10 +562,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
t2i_adapter_model: T2IAdapter
with t2i_adapter_model_info as t2i_adapter_model:
total_downscale_factor = t2i_adapter_model.total_downscale_factor
if isinstance(t2i_adapter_model.adapter, FullAdapterXL):
# HACK(ryand): Work around a bug in FullAdapterXL. This is being addressed upstream in diffusers by
# this PR: https://github.com/huggingface/diffusers/pull/5134.
total_downscale_factor = total_downscale_factor // 2
# Resize the T2I-Adapter input image.
# We select the resize dimensions so that after the T2I-Adapter's total_downscale_factor is applied, the
@ -710,6 +706,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
)
with (
ExitStack() as exit_stack,
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),
ModelPatcher.apply_freeu(unet_info.context.model, self.unet.freeu_config),
set_seamless(unet_info.context.model, self.unet.seamless_axes),
unet_info as unet,
# Apply the LoRA after unet has been moved to its target device for faster patching.
@ -1107,7 +1105,7 @@ class BlendLatentsInvocation(BaseInvocation):
latents_b = context.services.latents.get(self.latents_b.latents_name)
if latents_a.shape != latents_b.shape:
raise "Latents to blend must be the same size."
raise Exception("Latents to blend must be the same size.")
# TODO:
device = choose_torch_device()

View File

@ -6,8 +6,9 @@ import numpy as np
from pydantic import ValidationInfo, field_validator
from invokeai.app.invocations.primitives import FloatOutput, IntegerOutput
from invokeai.app.shared.fields import FieldDescriptions
from .baseinvocation import BaseInvocation, FieldDescriptions, InputField, InvocationContext, invocation
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
@invocation("add", title="Add Integers", tags=["math", "add"], category="math", version="1.0.0")
@ -144,17 +145,17 @@ INTEGER_OPERATIONS = Literal[
]
INTEGER_OPERATIONS_LABELS = dict(
ADD="Add A+B",
SUB="Subtract A-B",
MUL="Multiply A*B",
DIV="Divide A/B",
EXP="Exponentiate A^B",
MOD="Modulus A%B",
ABS="Absolute Value of A",
MIN="Minimum(A,B)",
MAX="Maximum(A,B)",
)
INTEGER_OPERATIONS_LABELS = {
"ADD": "Add A+B",
"SUB": "Subtract A-B",
"MUL": "Multiply A*B",
"DIV": "Divide A/B",
"EXP": "Exponentiate A^B",
"MOD": "Modulus A%B",
"ABS": "Absolute Value of A",
"MIN": "Minimum(A,B)",
"MAX": "Maximum(A,B)",
}
@invocation(
@ -182,8 +183,8 @@ class IntegerMathInvocation(BaseInvocation):
operation: INTEGER_OPERATIONS = InputField(
default="ADD", description="The operation to perform", ui_choice_labels=INTEGER_OPERATIONS_LABELS
)
a: int = InputField(default=0, description=FieldDescriptions.num_1)
b: int = InputField(default=0, description=FieldDescriptions.num_2)
a: int = InputField(default=1, description=FieldDescriptions.num_1)
b: int = InputField(default=1, description=FieldDescriptions.num_2)
@field_validator("b")
def no_unrepresentable_results(cls, v: int, info: ValidationInfo):
@ -230,17 +231,17 @@ FLOAT_OPERATIONS = Literal[
]
FLOAT_OPERATIONS_LABELS = dict(
ADD="Add A+B",
SUB="Subtract A-B",
MUL="Multiply A*B",
DIV="Divide A/B",
EXP="Exponentiate A^B",
ABS="Absolute Value of A",
SQRT="Square Root of A",
MIN="Minimum(A,B)",
MAX="Maximum(A,B)",
)
FLOAT_OPERATIONS_LABELS = {
"ADD": "Add A+B",
"SUB": "Subtract A-B",
"MUL": "Multiply A*B",
"DIV": "Divide A/B",
"EXP": "Exponentiate A^B",
"ABS": "Absolute Value of A",
"SQRT": "Square Root of A",
"MIN": "Minimum(A,B)",
"MAX": "Maximum(A,B)",
}
@invocation(
@ -256,8 +257,8 @@ class FloatMathInvocation(BaseInvocation):
operation: FLOAT_OPERATIONS = InputField(
default="ADD", description="The operation to perform", ui_choice_labels=FLOAT_OPERATIONS_LABELS
)
a: float = InputField(default=0, description=FieldDescriptions.num_1)
b: float = InputField(default=0, description=FieldDescriptions.num_2)
a: float = InputField(default=1, description=FieldDescriptions.num_1)
b: float = InputField(default=1, description=FieldDescriptions.num_2)
@field_validator("b")
def no_unrepresentable_results(cls, v: float, info: ValidationInfo):
@ -265,7 +266,7 @@ class FloatMathInvocation(BaseInvocation):
raise ValueError("Cannot divide by zero")
elif info.data["operation"] == "EXP" and info.data["a"] == 0 and v < 0:
raise ValueError("Cannot raise zero to a negative power")
elif info.data["operation"] == "EXP" and type(info.data["a"] ** v) is complex:
elif info.data["operation"] == "EXP" and isinstance(info.data["a"] ** v, complex):
raise ValueError("Root operation resulted in a complex number")
return v

View File

@ -5,7 +5,6 @@ from pydantic import BaseModel, ConfigDict, Field
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
FieldDescriptions,
InputField,
InvocationContext,
MetadataField,
@ -19,6 +18,7 @@ from invokeai.app.invocations.ip_adapter import IPAdapterModelField
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
from invokeai.app.shared.fields import FieldDescriptions
from ...version import __version__
@ -160,13 +160,14 @@ class CoreMetadataInvocation(BaseInvocation):
)
# High resolution fix metadata.
hrf_width: Optional[int] = InputField(
hrf_enabled: Optional[float] = InputField(
default=None,
description="The high resolution fix height and width multipler.",
description="Whether or not high resolution fix was enabled.",
)
hrf_height: Optional[int] = InputField(
# TODO: should this be stricter or do we just let the UI handle it?
hrf_method: Optional[str] = InputField(
default=None,
description="The high resolution fix height and width multipler.",
description="The high resolution fix upscale method.",
)
hrf_strength: Optional[float] = InputField(
default=None,

View File

@ -3,11 +3,13 @@ from typing import List, Optional
from pydantic import BaseModel, ConfigDict, Field
from invokeai.app.shared.fields import FieldDescriptions
from invokeai.app.shared.models import FreeUConfig
from ...backend.model_management import BaseModelType, ModelType, SubModelType
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
FieldDescriptions,
Input,
InputField,
InvocationContext,
@ -36,6 +38,7 @@ class UNetField(BaseModel):
scheduler: ModelInfo = Field(description="Info to load scheduler submodel")
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
freeu_config: Optional[FreeUConfig] = Field(default=None, description="FreeU configuration")
class ClipField(BaseModel):
@ -51,15 +54,34 @@ class VaeField(BaseModel):
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
@invocation_output("model_loader_output")
class ModelLoaderOutput(BaseInvocationOutput):
"""Model loader output"""
@invocation_output("unet_output")
class UNetOutput(BaseInvocationOutput):
"""Base class for invocations that output a UNet field"""
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP")
@invocation_output("vae_output")
class VAEOutput(BaseInvocationOutput):
"""Base class for invocations that output a VAE field"""
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
@invocation_output("clip_output")
class CLIPOutput(BaseInvocationOutput):
"""Base class for invocations that output a CLIP field"""
clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP")
@invocation_output("model_loader_output")
class ModelLoaderOutput(UNetOutput, CLIPOutput, VAEOutput):
"""Model loader output"""
pass
class MainModelField(BaseModel):
"""Main model field"""
@ -366,13 +388,6 @@ class VAEModelField(BaseModel):
model_config = ConfigDict(protected_namespaces=())
@invocation_output("vae_loader_output")
class VaeLoaderOutput(BaseInvocationOutput):
"""VAE output"""
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.0")
class VaeLoaderInvocation(BaseInvocation):
"""Loads a VAE model, outputting a VaeLoaderOutput"""
@ -384,7 +399,7 @@ class VaeLoaderInvocation(BaseInvocation):
title="VAE",
)
def invoke(self, context: InvocationContext) -> VaeLoaderOutput:
def invoke(self, context: InvocationContext) -> VAEOutput:
base_model = self.vae_model.base_model
model_name = self.vae_model.model_name
model_type = ModelType.Vae
@ -395,7 +410,7 @@ class VaeLoaderInvocation(BaseInvocation):
model_type=model_type,
):
raise Exception(f"Unkown vae name: {model_name}!")
return VaeLoaderOutput(
return VAEOutput(
vae=VaeField(
vae=ModelInfo(
model_name=model_name,
@ -457,3 +472,24 @@ class SeamlessModeInvocation(BaseInvocation):
vae.seamless_axes = seamless_axes_list
return SeamlessModeOutput(unet=unet, vae=vae)
@invocation("freeu", title="FreeU", tags=["freeu"], category="unet", version="1.0.0")
class FreeUInvocation(BaseInvocation):
"""
Applies FreeU to the UNet. Suggested values (b1/b2/s1/s2):
SD1.5: 1.2/1.4/0.9/0.2,
SD2: 1.1/1.2/0.9/0.2,
SDXL: 1.1/1.2/0.6/0.4,
"""
unet: UNetField = InputField(description=FieldDescriptions.unet, input=Input.Connection, title="UNet")
b1: float = InputField(default=1.2, ge=-1, le=3, description=FieldDescriptions.freeu_b1)
b2: float = InputField(default=1.4, ge=-1, le=3, description=FieldDescriptions.freeu_b2)
s1: float = InputField(default=0.9, ge=-1, le=3, description=FieldDescriptions.freeu_s1)
s2: float = InputField(default=0.2, ge=-1, le=3, description=FieldDescriptions.freeu_s2)
def invoke(self, context: InvocationContext) -> UNetOutput:
self.unet.freeu_config = FreeUConfig(s1=self.s1, s2=self.s2, b1=self.b1, b2=self.b2)
return UNetOutput(unet=self.unet)

View File

@ -5,13 +5,13 @@ import torch
from pydantic import field_validator
from invokeai.app.invocations.latent import LatentsField
from invokeai.app.shared.fields import FieldDescriptions
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from ...backend.util.devices import choose_torch_device, torch_dtype
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
FieldDescriptions,
InputField,
InvocationContext,
OutputField,

View File

@ -14,6 +14,7 @@ from tqdm import tqdm
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, ImageField, ImageOutput
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
from invokeai.app.shared.fields import FieldDescriptions
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend import BaseModelType, ModelType, SubModelType
@ -23,7 +24,6 @@ from ...backend.util import choose_torch_device
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
FieldDescriptions,
Input,
InputField,
InvocationContext,
@ -54,7 +54,7 @@ ORT_TO_NP_TYPE = {
"tensor(double)": np.float64,
}
PRECISION_VALUES = Literal[tuple(list(ORT_TO_NP_TYPE.keys()))]
PRECISION_VALUES = Literal[tuple(ORT_TO_NP_TYPE.keys())]
@invocation("prompt_onnx", title="ONNX Prompt (Raw)", tags=["prompt", "onnx"], category="conditioning", version="1.0.0")
@ -252,7 +252,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
scheduler.set_timesteps(self.steps)
latents = latents * np.float64(scheduler.init_noise_sigma)
extra_step_kwargs = dict()
extra_step_kwargs = {}
if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):
extra_step_kwargs.update(
eta=0.0,

View File

@ -100,7 +100,7 @@ EASING_FUNCTIONS_MAP = {
"BounceInOut": BounceEaseInOut,
}
EASING_FUNCTION_KEYS = Literal[tuple(list(EASING_FUNCTIONS_MAP.keys()))]
EASING_FUNCTION_KEYS = Literal[tuple(EASING_FUNCTIONS_MAP.keys())]
# actually I think for now could just use CollectionOutput (which is list[Any]
@ -161,7 +161,7 @@ class StepParamEasingInvocation(BaseInvocation):
easing_class = EASING_FUNCTIONS_MAP[self.easing]
if log_diagnostics:
context.services.logger.debug("easing class: " + str(easing_class))
easing_list = list()
easing_list = []
if self.mirror: # "expected" mirroring
# if number of steps is even, squeeze duration down to (number_of_steps)/2
# and create reverse copy of list to append
@ -178,7 +178,7 @@ class StepParamEasingInvocation(BaseInvocation):
end=self.end_value,
duration=base_easing_duration - 1,
)
base_easing_vals = list()
base_easing_vals = []
for step_index in range(base_easing_duration):
easing_val = easing_function.ease(step_index)
base_easing_vals.append(easing_val)

View File

@ -5,10 +5,11 @@ from typing import Optional, Tuple
import torch
from pydantic import BaseModel, Field
from invokeai.app.shared.fields import FieldDescriptions
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
FieldDescriptions,
Input,
InputField,
InvocationContext,

View File

@ -1,8 +1,9 @@
from invokeai.app.shared.fields import FieldDescriptions
from ...backend.model_management import ModelType, SubModelType
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
FieldDescriptions,
Input,
InputField,
InvocationContext,

View File

@ -5,7 +5,6 @@ from pydantic import BaseModel, ConfigDict, Field
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
FieldDescriptions,
Input,
InputField,
InvocationContext,
@ -16,6 +15,7 @@ from invokeai.app.invocations.baseinvocation import (
)
from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.shared.fields import FieldDescriptions
from invokeai.backend.model_management.models.base import BaseModelType

View File

@ -139,7 +139,7 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
(board_id,),
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
images = list(map(lambda r: deserialize_image_record(dict(r)), result))
images = [deserialize_image_record(dict(r)) for r in result]
self._cursor.execute(
"""--sql
@ -167,7 +167,7 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
(board_id,),
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
image_names = list(map(lambda r: r[0], result))
image_names = [r[0] for r in result]
return image_names
except sqlite3.Error as e:
self._conn.rollback()

View File

@ -199,7 +199,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
boards = list(map(lambda r: deserialize_board_record(dict(r)), result))
boards = [deserialize_board_record(dict(r)) for r in result]
# Get the total number of boards
self._cursor.execute(
@ -236,7 +236,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
boards = list(map(lambda r: deserialize_board_record(dict(r)), result))
boards = [deserialize_board_record(dict(r)) for r in result]
return boards

View File

@ -55,7 +55,7 @@ class InvokeAISettings(BaseSettings):
"""
cls = self.__class__
type = get_args(get_type_hints(cls)["type"])[0]
field_dict = dict({type: dict()})
field_dict = {type: {}}
for name, field in self.model_fields.items():
if name in cls._excluded_from_yaml():
continue
@ -64,7 +64,7 @@ class InvokeAISettings(BaseSettings):
)
value = getattr(self, name)
if category not in field_dict[type]:
field_dict[type][category] = dict()
field_dict[type][category] = {}
# keep paths as strings to make it easier to read
field_dict[type][category][name] = str(value) if isinstance(value, Path) else value
conf = OmegaConf.create(field_dict)
@ -89,7 +89,7 @@ class InvokeAISettings(BaseSettings):
# create an upcase version of the environment in
# order to achieve case-insensitive environment
# variables (the way Windows does)
upcase_environ = dict()
upcase_environ = {}
for key, value in os.environ.items():
upcase_environ[key.upper()] = value

View File

@ -188,18 +188,18 @@ DEFAULT_MAX_VRAM = 0.5
class Categories(object):
WebServer = dict(category="Web Server")
Features = dict(category="Features")
Paths = dict(category="Paths")
Logging = dict(category="Logging")
Development = dict(category="Development")
Other = dict(category="Other")
ModelCache = dict(category="Model Cache")
Device = dict(category="Device")
Generation = dict(category="Generation")
Queue = dict(category="Queue")
Nodes = dict(category="Nodes")
MemoryPerformance = dict(category="Memory/Performance")
WebServer = {"category": "Web Server"}
Features = {"category": "Features"}
Paths = {"category": "Paths"}
Logging = {"category": "Logging"}
Development = {"category": "Development"}
Other = {"category": "Other"}
ModelCache = {"category": "Model Cache"}
Device = {"category": "Device"}
Generation = {"category": "Generation"}
Queue = {"category": "Queue"}
Nodes = {"category": "Nodes"}
MemoryPerformance = {"category": "Memory/Performance"}
class InvokeAIAppConfig(InvokeAISettings):
@ -482,7 +482,7 @@ def _find_root() -> Path:
venv = Path(os.environ.get("VIRTUAL_ENV") or ".")
if os.environ.get("INVOKEAI_ROOT"):
root = Path(os.environ["INVOKEAI_ROOT"])
elif any([(venv.parent / x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE]]):
elif any((venv.parent / x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE]):
root = (venv.parent).resolve()
else:
root = Path("~/invokeai").expanduser().resolve()

View File

@ -27,7 +27,7 @@ class EventServiceBase:
payload["timestamp"] = get_timestamp()
self.dispatch(
event_name=EventServiceBase.queue_event,
payload=dict(event=event_name, data=payload),
payload={"event": event_name, "data": payload},
)
# Define events here for every event in the system.
@ -48,18 +48,18 @@ class EventServiceBase:
"""Emitted when there is generation progress"""
self.__emit_queue_event(
event_name="generator_progress",
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
node_id=node.get("id"),
source_node_id=source_node_id,
progress_image=progress_image.model_dump() if progress_image is not None else None,
step=step,
order=order,
total_steps=total_steps,
),
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"node_id": node.get("id"),
"source_node_id": source_node_id,
"progress_image": progress_image.model_dump() if progress_image is not None else None,
"step": step,
"order": order,
"total_steps": total_steps,
},
)
def emit_invocation_complete(
@ -75,15 +75,15 @@ class EventServiceBase:
"""Emitted when an invocation has completed"""
self.__emit_queue_event(
event_name="invocation_complete",
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
node=node,
source_node_id=source_node_id,
result=result,
),
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"node": node,
"source_node_id": source_node_id,
"result": result,
},
)
def emit_invocation_error(
@ -100,16 +100,16 @@ class EventServiceBase:
"""Emitted when an invocation has completed"""
self.__emit_queue_event(
event_name="invocation_error",
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
node=node,
source_node_id=source_node_id,
error_type=error_type,
error=error,
),
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"node": node,
"source_node_id": source_node_id,
"error_type": error_type,
"error": error,
},
)
def emit_invocation_started(
@ -124,14 +124,14 @@ class EventServiceBase:
"""Emitted when an invocation has started"""
self.__emit_queue_event(
event_name="invocation_started",
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
node=node,
source_node_id=source_node_id,
),
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"node": node,
"source_node_id": source_node_id,
},
)
def emit_graph_execution_complete(
@ -140,12 +140,12 @@ class EventServiceBase:
"""Emitted when a session has completed all invocations"""
self.__emit_queue_event(
event_name="graph_execution_state_complete",
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
),
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
},
)
def emit_model_load_started(
@ -162,16 +162,16 @@ class EventServiceBase:
"""Emitted when a model is requested"""
self.__emit_queue_event(
event_name="model_load_started",
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=submodel,
),
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"model_name": model_name,
"base_model": base_model,
"model_type": model_type,
"submodel": submodel,
},
)
def emit_model_load_completed(
@ -189,19 +189,19 @@ class EventServiceBase:
"""Emitted when a model is correctly loaded (returns model info)"""
self.__emit_queue_event(
event_name="model_load_completed",
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=submodel,
hash=model_info.hash,
location=str(model_info.location),
precision=str(model_info.precision),
),
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"model_name": model_name,
"base_model": base_model,
"model_type": model_type,
"submodel": submodel,
"hash": model_info.hash,
"location": str(model_info.location),
"precision": str(model_info.precision),
},
)
def emit_session_retrieval_error(
@ -216,14 +216,14 @@ class EventServiceBase:
"""Emitted when session retrieval fails"""
self.__emit_queue_event(
event_name="session_retrieval_error",
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
error_type=error_type,
error=error,
),
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"error_type": error_type,
"error": error,
},
)
def emit_invocation_retrieval_error(
@ -239,15 +239,15 @@ class EventServiceBase:
"""Emitted when invocation retrieval fails"""
self.__emit_queue_event(
event_name="invocation_retrieval_error",
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
node_id=node_id,
error_type=error_type,
error=error,
),
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"node_id": node_id,
"error_type": error_type,
"error": error,
},
)
def emit_session_canceled(
@ -260,12 +260,12 @@ class EventServiceBase:
"""Emitted when a session is canceled"""
self.__emit_queue_event(
event_name="session_canceled",
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
),
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
},
)
def emit_queue_item_status_changed(
@ -277,39 +277,39 @@ class EventServiceBase:
"""Emitted when a queue item's status changes"""
self.__emit_queue_event(
event_name="queue_item_status_changed",
payload=dict(
queue_id=queue_status.queue_id,
queue_item=dict(
queue_id=session_queue_item.queue_id,
item_id=session_queue_item.item_id,
status=session_queue_item.status,
batch_id=session_queue_item.batch_id,
session_id=session_queue_item.session_id,
error=session_queue_item.error,
created_at=str(session_queue_item.created_at) if session_queue_item.created_at else None,
updated_at=str(session_queue_item.updated_at) if session_queue_item.updated_at else None,
started_at=str(session_queue_item.started_at) if session_queue_item.started_at else None,
completed_at=str(session_queue_item.completed_at) if session_queue_item.completed_at else None,
),
batch_status=batch_status.model_dump(),
queue_status=queue_status.model_dump(),
),
payload={
"queue_id": queue_status.queue_id,
"queue_item": {
"queue_id": session_queue_item.queue_id,
"item_id": session_queue_item.item_id,
"status": session_queue_item.status,
"batch_id": session_queue_item.batch_id,
"session_id": session_queue_item.session_id,
"error": session_queue_item.error,
"created_at": str(session_queue_item.created_at) if session_queue_item.created_at else None,
"updated_at": str(session_queue_item.updated_at) if session_queue_item.updated_at else None,
"started_at": str(session_queue_item.started_at) if session_queue_item.started_at else None,
"completed_at": str(session_queue_item.completed_at) if session_queue_item.completed_at else None,
},
"batch_status": batch_status.model_dump(),
"queue_status": queue_status.model_dump(),
},
)
def emit_batch_enqueued(self, enqueue_result: EnqueueBatchResult) -> None:
"""Emitted when a batch is enqueued"""
self.__emit_queue_event(
event_name="batch_enqueued",
payload=dict(
queue_id=enqueue_result.queue_id,
batch_id=enqueue_result.batch.batch_id,
enqueued=enqueue_result.enqueued,
),
payload={
"queue_id": enqueue_result.queue_id,
"batch_id": enqueue_result.batch.batch_id,
"enqueued": enqueue_result.enqueued,
},
)
def emit_queue_cleared(self, queue_id: str) -> None:
"""Emitted when the queue is cleared"""
self.__emit_queue_event(
event_name="queue_cleared",
payload=dict(queue_id=queue_id),
payload={"queue_id": queue_id},
)

View File

@ -25,7 +25,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
__invoker: Invoker
def __init__(self, output_folder: Union[str, Path]):
self.__cache = dict()
self.__cache = {}
self.__cache_ids = Queue()
self.__max_cache_size = 10 # TODO: get this from config

View File

@ -90,25 +90,23 @@ class ImageRecordDeleteException(Exception):
IMAGE_DTO_COLS = ", ".join(
list(
map(
lambda c: "images." + c,
[
"image_name",
"image_origin",
"image_category",
"width",
"height",
"session_id",
"node_id",
"is_intermediate",
"created_at",
"updated_at",
"deleted_at",
"starred",
],
)
)
[
"images." + c
for c in [
"image_name",
"image_origin",
"image_category",
"width",
"height",
"session_id",
"node_id",
"is_intermediate",
"created_at",
"updated_at",
"deleted_at",
"starred",
]
]
)

View File

@ -263,7 +263,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
if categories is not None:
# Convert the enum values to unique list of strings
category_strings = list(map(lambda c: c.value, set(categories)))
category_strings = [c.value for c in set(categories)]
# Create the correct length of placeholders
placeholders = ",".join("?" * len(category_strings))
@ -307,7 +307,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
# Build the list of images, deserializing each row
self._cursor.execute(images_query, images_params)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
images = list(map(lambda r: deserialize_image_record(dict(r)), result))
images = [deserialize_image_record(dict(r)) for r in result]
# Set up and execute the count query, without pagination
count_query += query_conditions + ";"
@ -386,7 +386,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
"""
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
image_names = list(map(lambda r: r[0], result))
image_names = [r[0] for r in result]
self._cursor.execute(
"""--sql
DELETE FROM images

View File

@ -21,8 +21,8 @@ class ImageServiceABC(ABC):
_on_deleted_callbacks: list[Callable[[str], None]]
def __init__(self) -> None:
self._on_changed_callbacks = list()
self._on_deleted_callbacks = list()
self._on_changed_callbacks = []
self._on_deleted_callbacks = []
def on_changed(self, on_changed: Callable[[ImageDTO], None]) -> None:
"""Register a callback for when an image is changed"""

View File

@ -217,18 +217,16 @@ class ImageService(ImageServiceABC):
board_id,
)
image_dtos = list(
map(
lambda r: image_record_to_dto(
image_record=r,
image_url=self.__invoker.services.urls.get_image_url(r.image_name),
thumbnail_url=self.__invoker.services.urls.get_image_url(r.image_name, True),
board_id=self.__invoker.services.board_image_records.get_board_for_image(r.image_name),
workflow_id=self.__invoker.services.workflow_image_records.get_workflow_for_image(r.image_name),
),
results.items,
image_dtos = [
image_record_to_dto(
image_record=r,
image_url=self.__invoker.services.urls.get_image_url(r.image_name),
thumbnail_url=self.__invoker.services.urls.get_image_url(r.image_name, True),
board_id=self.__invoker.services.board_image_records.get_board_for_image(r.image_name),
workflow_id=self.__invoker.services.workflow_image_records.get_workflow_for_image(r.image_name),
)
)
for r in results.items
]
return OffsetPaginatedResults[ImageDTO](
items=image_dtos,

View File

@ -1,5 +1,5 @@
from abc import ABC
class InvocationProcessorABC(ABC):
class InvocationProcessorABC(ABC): # noqa: B024
pass

View File

@ -26,7 +26,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
self.__invoker_thread = Thread(
name="invoker_processor",
target=self.__process,
kwargs=dict(stop_event=self.__stop_event),
kwargs={"stop_event": self.__stop_event},
)
self.__invoker_thread.daemon = True # TODO: make async and do not use threads
self.__invoker_thread.start()

View File

@ -14,7 +14,7 @@ class MemoryInvocationQueue(InvocationQueueABC):
def __init__(self):
self.__queue = Queue()
self.__cancellations = dict()
self.__cancellations = {}
def get(self) -> InvocationQueueItem:
item = self.__queue.get()

View File

@ -22,6 +22,7 @@ if TYPE_CHECKING:
from .item_storage.item_storage_base import ItemStorageABC
from .latents_storage.latents_storage_base import LatentsStorageBase
from .model_manager.model_manager_base import ModelManagerServiceBase
from .model_records import ModelRecordServiceBase
from .names.names_base import NameServiceBase
from .session_processor.session_processor_base import SessionProcessorBase
from .session_queue.session_queue_base import SessionQueueBase
@ -49,6 +50,7 @@ class InvocationServices:
latents: "LatentsStorageBase"
logger: "Logger"
model_manager: "ModelManagerServiceBase"
model_records: "ModelRecordServiceBase"
processor: "InvocationProcessorABC"
performance_statistics: "InvocationStatsServiceBase"
queue: "InvocationQueueABC"
@ -76,6 +78,7 @@ class InvocationServices:
latents: "LatentsStorageBase",
logger: "Logger",
model_manager: "ModelManagerServiceBase",
model_records: "ModelRecordServiceBase",
processor: "InvocationProcessorABC",
performance_statistics: "InvocationStatsServiceBase",
queue: "InvocationQueueABC",
@ -101,6 +104,7 @@ class InvocationServices:
self.latents = latents
self.logger = logger
self.model_manager = model_manager
self.model_records = model_records
self.processor = processor
self.performance_statistics = performance_statistics
self.queue = queue

View File

@ -122,7 +122,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
def log_stats(self):
completed = set()
errored = set()
for graph_id, node_log in self._stats.items():
for graph_id, _node_log in self._stats.items():
try:
current_graph_state = self._invoker.services.graph_execution_manager.get(graph_id)
except Exception:
@ -142,7 +142,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
cache_stats = self._cache_stats[graph_id]
hwm = cache_stats.high_watermark / GIG
tot = cache_stats.cache_size / GIG
loaded = sum([v for v in cache_stats.loaded_model_sizes.values()]) / GIG
loaded = sum(list(cache_stats.loaded_model_sizes.values())) / GIG
logger.info(f"TOTAL GRAPH EXECUTION TIME: {total_time:7.3f}s")
logger.info("RAM used by InvokeAI process: " + "%4.2fG" % self.ram_used + f" ({self.ram_changed:+5.3f}G)")

View File

@ -15,8 +15,8 @@ class ItemStorageABC(ABC, Generic[T]):
_on_deleted_callbacks: list[Callable[[str], None]]
def __init__(self) -> None:
self._on_changed_callbacks = list()
self._on_deleted_callbacks = list()
self._on_changed_callbacks = []
self._on_deleted_callbacks = []
"""Base item storage class"""

View File

@ -112,7 +112,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
)
result = self._cursor.fetchall()
items = list(map(lambda r: self._parse_item(r[0]), result))
items = [self._parse_item(r[0]) for r in result]
self._cursor.execute(f"""SELECT count(*) FROM {self._table_name};""")
count = self._cursor.fetchone()[0]
@ -132,7 +132,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
)
result = self._cursor.fetchall()
items = list(map(lambda r: self._parse_item(r[0]), result))
items = [self._parse_item(r[0]) for r in result]
self._cursor.execute(
f"""SELECT count(*) FROM {self._table_name} WHERE item LIKE ?;""",

View File

@ -13,8 +13,8 @@ class LatentsStorageBase(ABC):
_on_deleted_callbacks: list[Callable[[str], None]]
def __init__(self) -> None:
self._on_changed_callbacks = list()
self._on_deleted_callbacks = list()
self._on_changed_callbacks = []
self._on_deleted_callbacks = []
@abstractmethod
def get(self, name: str) -> torch.Tensor:

View File

@ -19,7 +19,7 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
def __init__(self, underlying_storage: LatentsStorageBase, max_cache_size: int = 20):
super().__init__()
self.__underlying_storage = underlying_storage
self.__cache = dict()
self.__cache = {}
self.__cache_ids = Queue()
self.__max_cache_size = max_cache_size

View File

@ -0,0 +1,8 @@
"""Init file for model record services."""
from .model_records_base import ( # noqa F401
DuplicateModelException,
InvalidModelException,
ModelRecordServiceBase,
UnknownModelException,
)
from .model_records_sql import ModelRecordServiceSQL # noqa F401

View File

@ -0,0 +1,169 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Abstract base class for storing and retrieving model configuration records.
"""
from abc import ABC, abstractmethod
from pathlib import Path
from typing import List, Optional, Union
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType
# should match the InvokeAI version when this is first released.
CONFIG_FILE_VERSION = "3.2.0"
class DuplicateModelException(Exception):
"""Raised on an attempt to add a model with the same key twice."""
class InvalidModelException(Exception):
"""Raised when an invalid model is detected."""
class UnknownModelException(Exception):
"""Raised on an attempt to fetch or delete a model with a nonexistent key."""
class ConfigFileVersionMismatchException(Exception):
"""Raised on an attempt to open a config with an incompatible version."""
class ModelRecordServiceBase(ABC):
"""Abstract base class for storage and retrieval of model configs."""
@property
@abstractmethod
def version(self) -> str:
"""Return the config file/database schema version."""
pass
@abstractmethod
def add_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig:
"""
Add a model to the database.
:param key: Unique key for the model
:param config: Model configuration record, either a dict with the
required fields or a ModelConfigBase instance.
Can raise DuplicateModelException and InvalidModelConfigException exceptions.
"""
pass
@abstractmethod
def del_model(self, key: str) -> None:
"""
Delete a model.
:param key: Unique key for the model to be deleted
Can raise an UnknownModelException
"""
pass
@abstractmethod
def update_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig:
"""
Update the model, returning the updated version.
:param key: Unique key for the model to be updated
:param config: Model configuration record. Either a dict with the
required fields, or a ModelConfigBase instance.
"""
pass
@abstractmethod
def get_model(self, key: str) -> AnyModelConfig:
"""
Retrieve the configuration for the indicated model.
:param key: Key of model config to be fetched.
Exceptions: UnknownModelException
"""
pass
@abstractmethod
def exists(self, key: str) -> bool:
"""
Return True if a model with the indicated key exists in the databse.
:param key: Unique key for the model to be deleted
"""
pass
@abstractmethod
def search_by_path(
self,
path: Union[str, Path],
) -> List[AnyModelConfig]:
"""Return the model(s) having the indicated path."""
pass
@abstractmethod
def search_by_hash(
self,
hash: str,
) -> List[AnyModelConfig]:
"""Return the model(s) having the indicated original hash."""
pass
@abstractmethod
def search_by_attr(
self,
model_name: Optional[str] = None,
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None,
) -> List[AnyModelConfig]:
"""
Return models matching name, base and/or type.
:param model_name: Filter by name of model (optional)
:param base_model: Filter by base model (optional)
:param model_type: Filter by type of model (optional)
If none of the optional filters are passed, will return all
models in the database.
"""
pass
def all_models(self) -> List[AnyModelConfig]:
"""Return all the model configs in the database."""
return self.search_by_attr()
def model_info_by_name(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> AnyModelConfig:
"""
Return information about a single model using its name, base type and model type.
If there are more than one model that match, raises a DuplicateModelException.
If no model matches, raises an UnknownModelException
"""
model_configs = self.search_by_attr(model_name=model_name, base_model=base_model, model_type=model_type)
if len(model_configs) > 1:
raise DuplicateModelException(
f"More than one model matched the search criteria: base_model='{base_model}', model_type='{model_type}', model_name='{model_name}'."
)
if len(model_configs) == 0:
raise UnknownModelException(
f"More than one model matched the search criteria: base_model='{base_model}', model_type='{model_type}', model_name='{model_name}'."
)
return model_configs[0]
def rename_model(
self,
key: str,
new_name: str,
) -> AnyModelConfig:
"""
Rename the indicated model. Just a special case of update_model().
In some implementations, renaming the model may involve changing where
it is stored on the filesystem. So this is broken out.
:param key: Model key
:param new_name: New name for model
"""
config = self.get_model(key)
config.name = new_name
return self.update_model(key, config)

View File

@ -0,0 +1,397 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
SQL Implementation of the ModelRecordServiceBase API
Typical usage:
from invokeai.backend.model_manager import ModelConfigStoreSQL
store = ModelConfigStoreSQL(sqlite_db)
config = dict(
path='/tmp/pokemon.bin',
name='old name',
base_model='sd-1',
type='embedding',
format='embedding_file',
)
# adding - the key becomes the model's "key" field
store.add_model('key1', config)
# updating
config.name='new name'
store.update_model('key1', config)
# checking for existence
if store.exists('key1'):
print("yes")
# fetching config
new_config = store.get_model('key1')
print(new_config.name, new_config.base)
assert new_config.key == 'key1'
# deleting
store.del_model('key1')
# searching
configs = store.search_by_path(path='/tmp/pokemon.bin')
configs = store.search_by_hash('750a499f35e43b7e1b4d15c207aa2f01')
configs = store.search_by_attr(base_model='sd-2', model_type='main')
"""
import json
import sqlite3
from pathlib import Path
from typing import List, Optional, Union
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
ModelConfigBase,
ModelConfigFactory,
ModelType,
)
from ..shared.sqlite import SqliteDatabase
from .model_records_base import (
CONFIG_FILE_VERSION,
DuplicateModelException,
ModelRecordServiceBase,
UnknownModelException,
)
class ModelRecordServiceSQL(ModelRecordServiceBase):
"""Implementation of the ModelConfigStore ABC using a SQL database."""
_db: SqliteDatabase
_cursor: sqlite3.Cursor
def __init__(self, db: SqliteDatabase):
"""
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
:param conn: sqlite3 connection object
:param lock: threading Lock object
"""
super().__init__()
self._db = db
self._cursor = self._db.conn.cursor()
with self._db.lock:
# Enable foreign keys
self._db.conn.execute("PRAGMA foreign_keys = ON;")
self._create_tables()
self._db.conn.commit()
assert (
str(self.version) == CONFIG_FILE_VERSION
), f"Model config version {self.version} does not match expected version {CONFIG_FILE_VERSION}"
def _create_tables(self) -> None:
"""Create sqlite3 tables."""
# model_config table breaks out the fields that are common to all config objects
# and puts class-specific ones in a serialized json object
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS model_config (
id TEXT NOT NULL PRIMARY KEY,
-- The next 3 fields are enums in python, unrestricted string here
base TEXT NOT NULL,
type TEXT NOT NULL,
name TEXT NOT NULL,
path TEXT NOT NULL,
original_hash TEXT, -- could be null
-- Serialized JSON representation of the whole config object,
-- which will contain additional fields from subclasses
config TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- unique constraint on combo of name, base and type
UNIQUE(name, base, type)
);
"""
)
# metadata table
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS model_manager_metadata (
metadata_key TEXT NOT NULL PRIMARY KEY,
metadata_value TEXT NOT NULL
);
"""
)
# Add trigger for `updated_at`.
self._cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS model_config_updated_at
AFTER UPDATE
ON model_config FOR EACH ROW
BEGIN
UPDATE model_config SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE id = old.id;
END;
"""
)
# Add indexes for searchable fields
for stmt in [
"CREATE INDEX IF NOT EXISTS base_index ON model_config(base);",
"CREATE INDEX IF NOT EXISTS type_index ON model_config(type);",
"CREATE INDEX IF NOT EXISTS name_index ON model_config(name);",
"CREATE UNIQUE INDEX IF NOT EXISTS path_index ON model_config(path);",
]:
self._cursor.execute(stmt)
# Add our version to the metadata table
self._cursor.execute(
"""--sql
INSERT OR IGNORE into model_manager_metadata (
metadata_key,
metadata_value
)
VALUES (?,?);
""",
("version", CONFIG_FILE_VERSION),
)
def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> AnyModelConfig:
"""
Add a model to the database.
:param key: Unique key for the model
:param config: Model configuration record, either a dict with the
required fields or a ModelConfigBase instance.
Can raise DuplicateModelException and InvalidModelConfigException exceptions.
"""
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect.
json_serialized = record.model_dump_json() # and turn it into a json string.
with self._db.lock:
try:
self._cursor.execute(
"""--sql
INSERT INTO model_config (
id,
base,
type,
name,
path,
original_hash,
config
)
VALUES (?,?,?,?,?,?,?);
""",
(
key,
record.base,
record.type,
record.name,
record.path,
record.original_hash,
json_serialized,
),
)
self._db.conn.commit()
except sqlite3.IntegrityError as e:
self._db.conn.rollback()
if "UNIQUE constraint failed" in str(e):
if "model_config.path" in str(e):
msg = f"A model with path '{record.path}' is already installed"
elif "model_config.name" in str(e):
msg = f"A model with name='{record.name}', type='{record.type}', base='{record.base}' is already installed"
else:
msg = f"A model with key '{key}' is already installed"
raise DuplicateModelException(msg) from e
else:
raise e
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
return self.get_model(key)
@property
def version(self) -> str:
"""Return the version of the database schema."""
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT metadata_value FROM model_manager_metadata
WHERE metadata_key=?;
""",
("version",),
)
rows = self._cursor.fetchone()
if not rows:
raise KeyError("Models database does not have metadata key 'version'")
return rows[0]
def del_model(self, key: str) -> None:
"""
Delete a model.
:param key: Unique key for the model to be deleted
Can raise an UnknownModelException
"""
with self._db.lock:
try:
self._cursor.execute(
"""--sql
DELETE FROM model_config
WHERE id=?;
""",
(key,),
)
if self._cursor.rowcount == 0:
raise UnknownModelException("model not found")
self._db.conn.commit()
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
def update_model(self, key: str, config: ModelConfigBase) -> AnyModelConfig:
"""
Update the model, returning the updated version.
:param key: Unique key for the model to be updated
:param config: Model configuration record. Either a dict with the
required fields, or a ModelConfigBase instance.
"""
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect
json_serialized = record.model_dump_json() # and turn it into a json string.
with self._db.lock:
try:
self._cursor.execute(
"""--sql
UPDATE model_config
SET base=?,
type=?,
name=?,
path=?,
config=?
WHERE id=?;
""",
(record.base, record.type, record.name, record.path, json_serialized, key),
)
if self._cursor.rowcount == 0:
raise UnknownModelException("model not found")
self._db.conn.commit()
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
return self.get_model(key)
def get_model(self, key: str) -> AnyModelConfig:
"""
Retrieve the ModelConfigBase instance for the indicated model.
:param key: Key of model config to be fetched.
Exceptions: UnknownModelException
"""
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT config FROM model_config
WHERE id=?;
""",
(key,),
)
rows = self._cursor.fetchone()
if not rows:
raise UnknownModelException("model not found")
model = ModelConfigFactory.make_config(json.loads(rows[0]))
return model
def exists(self, key: str) -> bool:
"""
Return True if a model with the indicated key exists in the databse.
:param key: Unique key for the model to be deleted
"""
count = 0
with self._db.lock:
self._cursor.execute(
"""--sql
select count(*) FROM model_config
WHERE id=?;
""",
(key,),
)
count = self._cursor.fetchone()[0]
return count > 0
def search_by_attr(
self,
model_name: Optional[str] = None,
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None,
) -> List[AnyModelConfig]:
"""
Return models matching name, base and/or type.
:param model_name: Filter by name of model (optional)
:param base_model: Filter by base model (optional)
:param model_type: Filter by type of model (optional)
If none of the optional filters are passed, will return all
models in the database.
"""
results = []
where_clause = []
bindings = []
if model_name:
where_clause.append("name=?")
bindings.append(model_name)
if base_model:
where_clause.append("base=?")
bindings.append(base_model)
if model_type:
where_clause.append("type=?")
bindings.append(model_type)
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
with self._db.lock:
self._cursor.execute(
f"""--sql
select config FROM model_config
{where};
""",
tuple(bindings),
)
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
return results
def search_by_path(self, path: Union[str, Path]) -> List[ModelConfigBase]:
"""Return models with the indicated path."""
results = []
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT config FROM model_config
WHERE model_path=?;
""",
(str(path),),
)
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
return results
def search_by_hash(self, hash: str) -> List[ModelConfigBase]:
"""Return models with the indicated original_hash."""
results = []
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT config FROM model_config
WHERE original_hash=?;
""",
(hash,),
)
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
return results

View File

@ -33,9 +33,11 @@ class DefaultSessionProcessor(SessionProcessorBase):
self.__thread = Thread(
name="session_processor",
target=self.__process,
kwargs=dict(
stop_event=self.__stop_event, poll_now_event=self.__poll_now_event, resume_event=self.__resume_event
),
kwargs={
"stop_event": self.__stop_event,
"poll_now_event": self.__poll_now_event,
"resume_event": self.__resume_event,
},
)
self.__thread.start()

View File

@ -129,12 +129,12 @@ class Batch(BaseModel):
return v
model_config = ConfigDict(
json_schema_extra=dict(
required=[
json_schema_extra={
"required": [
"graph",
"runs",
]
)
}
)
@ -191,8 +191,8 @@ class SessionQueueItemWithoutGraph(BaseModel):
return SessionQueueItemDTO(**queue_item_dict)
model_config = ConfigDict(
json_schema_extra=dict(
required=[
json_schema_extra={
"required": [
"item_id",
"status",
"batch_id",
@ -203,7 +203,7 @@ class SessionQueueItemWithoutGraph(BaseModel):
"created_at",
"updated_at",
]
)
}
)
@ -222,8 +222,8 @@ class SessionQueueItem(SessionQueueItemWithoutGraph):
return SessionQueueItem(**queue_item_dict)
model_config = ConfigDict(
json_schema_extra=dict(
required=[
json_schema_extra={
"required": [
"item_id",
"status",
"batch_id",
@ -235,7 +235,7 @@ class SessionQueueItem(SessionQueueItemWithoutGraph):
"created_at",
"updated_at",
]
)
}
)
@ -355,7 +355,7 @@ def create_session_nfv_tuples(
for item in batch_datum.items
]
node_field_values_to_zip.append(node_field_values)
data.append(list(zip(*node_field_values_to_zip))) # type: ignore [arg-type]
data.append(list(zip(*node_field_values_to_zip, strict=True))) # type: ignore [arg-type]
# create generator to yield session,nfv tuples
count = 0
@ -383,7 +383,7 @@ def calc_session_count(batch: Batch) -> int:
for batch_datum in batch_datum_list:
batch_data_items = range(len(batch_datum.items))
to_zip.append(batch_data_items)
data.append(list(zip(*to_zip)))
data.append(list(zip(*to_zip, strict=True)))
data_product = list(product(*data))
return len(data_product) * batch.runs

View File

@ -78,7 +78,7 @@ def create_system_graphs(graph_library: ItemStorageABC[LibraryGraph]) -> list[Li
"""Creates the default system graphs, or adds new versions if the old ones don't match"""
# TODO: Uncomment this when we are ready to fix this up to prevent breaking changes
graphs: list[LibraryGraph] = list()
graphs: list[LibraryGraph] = []
text_to_image = graph_library.get(default_text_to_image_graph_id)

View File

@ -352,7 +352,7 @@ class Graph(BaseModel):
# Validate that all node ids are unique
node_ids = [n.id for n in self.nodes.values()]
duplicate_node_ids = set([node_id for node_id in node_ids if node_ids.count(node_id) >= 2])
duplicate_node_ids = {node_id for node_id in node_ids if node_ids.count(node_id) >= 2}
if duplicate_node_ids:
raise DuplicateNodeIdError(f"Node ids must be unique, found duplicates {duplicate_node_ids}")
@ -616,7 +616,7 @@ class Graph(BaseModel):
self, node_path: str, prefix: Optional[str] = None
) -> list[tuple["Graph", Union[str, None], Edge]]:
"""Gets all input edges for a node along with the graph they are in and the graph's path"""
edges = list()
edges = []
# Return any input edges that appear in this graph
edges.extend([(self, prefix, e) for e in self.edges if e.destination.node_id == node_path])
@ -658,7 +658,7 @@ class Graph(BaseModel):
self, node_path: str, prefix: Optional[str] = None
) -> list[tuple["Graph", Union[str, None], Edge]]:
"""Gets all output edges for a node along with the graph they are in and the graph's path"""
edges = list()
edges = []
# Return any input edges that appear in this graph
edges.extend([(self, prefix, e) for e in self.edges if e.source.node_id == node_path])
@ -680,8 +680,8 @@ class Graph(BaseModel):
new_input: Optional[EdgeConnection] = None,
new_output: Optional[EdgeConnection] = None,
) -> bool:
inputs = list([e.source for e in self._get_input_edges(node_path, "collection")])
outputs = list([e.destination for e in self._get_output_edges(node_path, "item")])
inputs = [e.source for e in self._get_input_edges(node_path, "collection")]
outputs = [e.destination for e in self._get_output_edges(node_path, "item")]
if new_input is not None:
inputs.append(new_input)
@ -694,7 +694,7 @@ class Graph(BaseModel):
# Get input and output fields (the fields linked to the iterator's input/output)
input_field = get_output_field(self.get_node(inputs[0].node_id), inputs[0].field)
output_fields = list([get_input_field(self.get_node(e.node_id), e.field) for e in outputs])
output_fields = [get_input_field(self.get_node(e.node_id), e.field) for e in outputs]
# Input type must be a list
if get_origin(input_field) != list:
@ -713,8 +713,8 @@ class Graph(BaseModel):
new_input: Optional[EdgeConnection] = None,
new_output: Optional[EdgeConnection] = None,
) -> bool:
inputs = list([e.source for e in self._get_input_edges(node_path, "item")])
outputs = list([e.destination for e in self._get_output_edges(node_path, "collection")])
inputs = [e.source for e in self._get_input_edges(node_path, "item")]
outputs = [e.destination for e in self._get_output_edges(node_path, "collection")]
if new_input is not None:
inputs.append(new_input)
@ -722,18 +722,16 @@ class Graph(BaseModel):
outputs.append(new_output)
# Get input and output fields (the fields linked to the iterator's input/output)
input_fields = list([get_output_field(self.get_node(e.node_id), e.field) for e in inputs])
output_fields = list([get_input_field(self.get_node(e.node_id), e.field) for e in outputs])
input_fields = [get_output_field(self.get_node(e.node_id), e.field) for e in inputs]
output_fields = [get_input_field(self.get_node(e.node_id), e.field) for e in outputs]
# Validate that all inputs are derived from or match a single type
input_field_types = set(
[
t
for input_field in input_fields
for t in ([input_field] if get_origin(input_field) is None else get_args(input_field))
if t != NoneType
]
) # Get unique types
input_field_types = {
t
for input_field in input_fields
for t in ([input_field] if get_origin(input_field) is None else get_args(input_field))
if t != NoneType
} # Get unique types
type_tree = nx.DiGraph()
type_tree.add_nodes_from(input_field_types)
type_tree.add_edges_from([e for e in itertools.permutations(input_field_types, 2) if issubclass(e[1], e[0])])
@ -761,15 +759,15 @@ class Graph(BaseModel):
"""Returns a NetworkX DiGraph representing the layout of this graph"""
# TODO: Cache this?
g = nx.DiGraph()
g.add_nodes_from([n for n in self.nodes.keys()])
g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges]))
g.add_nodes_from(list(self.nodes.keys()))
g.add_edges_from({(e.source.node_id, e.destination.node_id) for e in self.edges})
return g
def nx_graph_with_data(self) -> nx.DiGraph:
"""Returns a NetworkX DiGraph representing the data and layout of this graph"""
g = nx.DiGraph()
g.add_nodes_from([n for n in self.nodes.items()])
g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges]))
g.add_nodes_from(list(self.nodes.items()))
g.add_edges_from({(e.source.node_id, e.destination.node_id) for e in self.edges})
return g
def nx_graph_flat(self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None) -> nx.DiGraph:
@ -791,7 +789,7 @@ class Graph(BaseModel):
# TODO: figure out if iteration nodes need to be expanded
unique_edges = set([(e.source.node_id, e.destination.node_id) for e in self.edges])
unique_edges = {(e.source.node_id, e.destination.node_id) for e in self.edges}
g.add_edges_from([(self._get_node_path(e[0], prefix), self._get_node_path(e[1], prefix)) for e in unique_edges])
return g
@ -843,8 +841,8 @@ class GraphExecutionState(BaseModel):
return v
model_config = ConfigDict(
json_schema_extra=dict(
required=[
json_schema_extra={
"required": [
"id",
"graph",
"execution_graph",
@ -855,7 +853,7 @@ class GraphExecutionState(BaseModel):
"prepared_source_mapping",
"source_prepared_mapping",
]
)
}
)
def next(self) -> Optional[BaseInvocation]:
@ -895,7 +893,7 @@ class GraphExecutionState(BaseModel):
source_node = self.prepared_source_mapping[node_id]
prepared_nodes = self.source_prepared_mapping[source_node]
if all([n in self.executed for n in prepared_nodes]):
if all(n in self.executed for n in prepared_nodes):
self.executed.add(source_node)
self.executed_history.append(source_node)
@ -930,7 +928,7 @@ class GraphExecutionState(BaseModel):
input_collection = getattr(input_collection_prepared_node_output, input_collection_edge.source.field)
self_iteration_count = len(input_collection)
new_nodes: list[str] = list()
new_nodes: list[str] = []
if self_iteration_count == 0:
# TODO: should this raise a warning? It might just happen if an empty collection is input, and should be valid.
return new_nodes
@ -940,7 +938,7 @@ class GraphExecutionState(BaseModel):
# Create new edges for this iteration
# For collect nodes, this may contain multiple inputs to the same field
new_edges: list[Edge] = list()
new_edges: list[Edge] = []
for edge in input_edges:
for input_node_id in (n[1] for n in iteration_node_map if n[0] == edge.source.node_id):
new_edge = Edge(
@ -1034,7 +1032,7 @@ class GraphExecutionState(BaseModel):
# Create execution nodes
next_node = self.graph.get_node(next_node_id)
new_node_ids = list()
new_node_ids = []
if isinstance(next_node, CollectInvocation):
# Collapse all iterator input mappings and create a single execution node for the collect invocation
all_iteration_mappings = list(
@ -1055,7 +1053,10 @@ class GraphExecutionState(BaseModel):
# For every iterator, the parent must either not be a child of that iterator, or must match the prepared iteration for that iterator
# TODO: Handle a node mapping to none
eg = self.execution_graph.nx_graph_flat()
prepared_parent_mappings = [[(n, self._get_iteration_node(n, g, eg, it)) for n in next_node_parents] for it in iterator_node_prepared_combinations] # type: ignore
prepared_parent_mappings = [
[(n, self._get_iteration_node(n, g, eg, it)) for n in next_node_parents]
for it in iterator_node_prepared_combinations
] # type: ignore
# Create execution node for each iteration
for iteration_mappings in prepared_parent_mappings:
@ -1121,7 +1122,7 @@ class GraphExecutionState(BaseModel):
for edge in input_edges
if edge.destination.field == "item"
]
setattr(node, "collection", output_collection)
node.collection = output_collection
else:
for edge in input_edges:
output_value = getattr(self.results[edge.source.node_id], edge.source.field)
@ -1201,7 +1202,7 @@ class LibraryGraph(BaseModel):
@field_validator("exposed_inputs", "exposed_outputs")
def validate_exposed_aliases(cls, v: list[Union[ExposedNodeInput, ExposedNodeOutput]]):
if len(v) != len(set(i.alias for i in v)):
if len(v) != len({i.alias for i in v}):
raise ValueError("Duplicate exposed alias")
return v

View File

@ -0,0 +1,5 @@
"""
This module contains various classes, functions and models which are shared across the app, particularly by invocations.
Lifting these classes, functions and models into this shared module helps to reduce circular imports.
"""

View File

@ -0,0 +1,66 @@
class FieldDescriptions:
denoising_start = "When to start denoising, expressed a percentage of total steps"
denoising_end = "When to stop denoising, expressed a percentage of total steps"
cfg_scale = "Classifier-Free Guidance scale"
scheduler = "Scheduler to use during inference"
positive_cond = "Positive conditioning tensor"
negative_cond = "Negative conditioning tensor"
noise = "Noise tensor"
clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
unet = "UNet (scheduler, LoRAs)"
vae = "VAE"
cond = "Conditioning tensor"
controlnet_model = "ControlNet model to load"
vae_model = "VAE model to load"
lora_model = "LoRA model to load"
main_model = "Main model (UNet, VAE, CLIP) to load"
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
lora_weight = "The weight at which the LoRA is applied to each model"
compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor"
raw_prompt = "Raw prompt text (no parsing)"
sdxl_aesthetic = "The aesthetic score to apply to the conditioning tensor"
skipped_layers = "Number of layers to skip in text encoder"
seed = "Seed for random number generation"
steps = "Number of steps to run"
width = "Width of output (px)"
height = "Height of output (px)"
control = "ControlNet(s) to apply"
ip_adapter = "IP-Adapter to apply"
t2i_adapter = "T2I-Adapter(s) to apply"
denoised_latents = "Denoised latents tensor"
latents = "Latents tensor"
strength = "Strength of denoising (proportional to steps)"
metadata = "Optional metadata to be saved with the image"
metadata_collection = "Collection of Metadata"
metadata_item_polymorphic = "A single metadata item or collection of metadata items"
metadata_item_label = "Label for this metadata item"
metadata_item_value = "The value for this metadata item (may be any type)"
workflow = "Optional workflow to be saved with the image"
interp_mode = "Interpolation mode"
torch_antialias = "Whether or not to apply antialiasing (bilinear or bicubic only)"
fp32 = "Whether or not to use full float32 precision"
precision = "Precision to use"
tiled = "Processing using overlapping tiles (reduce memory consumption)"
detect_res = "Pixel resolution for detection"
image_res = "Pixel resolution for output image"
safe_mode = "Whether or not to use safe mode"
scribble_mode = "Whether or not to use scribble mode"
scale_factor = "The factor by which to scale"
blend_alpha = (
"Blending factor. 0.0 = use input A only, 1.0 = use input B only, 0.5 = 50% mix of input A and input B."
)
num_1 = "The first number"
num_2 = "The second number"
mask = "The mask to use for the operation"
board = "The board to save the image to"
image = "The image to process"
tile_size = "Tile size"
inclusive_low = "The inclusive low value"
exclusive_high = "The exclusive high value"
decimal_places = "The number of decimal places to round to"
freeu_s1 = 'Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.'
freeu_s2 = 'Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.'
freeu_b1 = "Scaling factor for stage 1 to amplify the contributions of backbone features."
freeu_b2 = "Scaling factor for stage 2 to amplify the contributions of backbone features."

View File

@ -0,0 +1,16 @@
from pydantic import BaseModel, Field
from invokeai.app.shared.fields import FieldDescriptions
class FreeUConfig(BaseModel):
"""
Configuration for the FreeU hyperparameters.
- https://huggingface.co/docs/diffusers/main/en/using-diffusers/freeu
- https://github.com/ChenyangSi/FreeU
"""
s1: float = Field(ge=-1, le=3, description=FieldDescriptions.freeu_s1)
s2: float = Field(ge=-1, le=3, description=FieldDescriptions.freeu_s2)
b1: float = Field(ge=-1, le=3, description=FieldDescriptions.freeu_b1)
b2: float = Field(ge=-1, le=3, description=FieldDescriptions.freeu_b2)

View File

@ -59,7 +59,7 @@ def thin_one_time(x, kernels):
def lvmin_thin(x, prunings=True):
y = x
for i in range(32):
for _i in range(32):
y, is_done = thin_one_time(y, lvmin_kernels)
if is_done:
break

View File

@ -21,11 +21,11 @@ def get_metadata_graph_from_raw_session(session_raw: str) -> Optional[dict]:
# sanity check make sure the graph is at least reasonably shaped
if (
type(graph) is not dict
not isinstance(graph, dict)
or "nodes" not in graph
or type(graph["nodes"]) is not dict
or not isinstance(graph["nodes"], dict)
or "edges" not in graph
or type(graph["edges"]) is not list
or not isinstance(graph["edges"], list)
):
# something has gone terribly awry, return an empty dict
return None

View File

@ -88,7 +88,7 @@ class PromptFormatter:
t2i = self.t2i
opt = self.opt
switches = list()
switches = []
switches.append(f'"{opt.prompt}"')
switches.append(f"-s{opt.steps or t2i.steps}")
switches.append(f"-W{opt.width or t2i.width}")

View File

@ -88,7 +88,7 @@ class Txt2Mask(object):
provided image and returns a SegmentedGrayscale object in which the brighter
pixels indicate where the object is inferred to be.
"""
if type(image) is str:
if isinstance(image, str):
image = Image.open(image).convert("RGB")
image = ImageOps.exif_transpose(image)

View File

@ -40,7 +40,7 @@ class InitImageResizer:
(rw, rh) = (int(scale * im.width), int(scale * im.height))
# round everything to multiples of 64
width, height, rw, rh = map(lambda x: x - x % 64, (width, height, rw, rh))
width, height, rw, rh = (x - x % 64 for x in (width, height, rw, rh))
# no resize necessary, but return a copy
if im.width == width and im.height == height:

View File

@ -32,7 +32,7 @@ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionS
from huggingface_hub import HfFolder
from huggingface_hub import login as hf_hub_login
from omegaconf import OmegaConf
from pydantic.error_wrappers import ValidationError
from pydantic import ValidationError
from tqdm import tqdm
from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
@ -197,7 +197,7 @@ def download_with_progress_bar(model_url: str, model_dest: str, label: str = "th
def download_conversion_models():
target_dir = config.models_path / "core/convert"
kwargs = dict() # for future use
kwargs = {} # for future use
try:
logger.info("Downloading core tokenizers and text encoders")
@ -252,26 +252,26 @@ def download_conversion_models():
def download_realesrgan():
logger.info("Installing ESRGAN Upscaling models...")
URLs = [
dict(
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
dest="core/upscaling/realesrgan/RealESRGAN_x4plus.pth",
description="RealESRGAN_x4plus.pth",
),
dict(
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
dest="core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth",
description="RealESRGAN_x4plus_anime_6B.pth",
),
dict(
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
dest="core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
description="ESRGAN_SRx4_DF2KOST_official.pth",
),
dict(
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
dest="core/upscaling/realesrgan/RealESRGAN_x2plus.pth",
description="RealESRGAN_x2plus.pth",
),
{
"url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
"dest": "core/upscaling/realesrgan/RealESRGAN_x4plus.pth",
"description": "RealESRGAN_x4plus.pth",
},
{
"url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
"dest": "core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth",
"description": "RealESRGAN_x4plus_anime_6B.pth",
},
{
"url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
"dest": "core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
"description": "ESRGAN_SRx4_DF2KOST_official.pth",
},
{
"url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
"dest": "core/upscaling/realesrgan/RealESRGAN_x2plus.pth",
"description": "RealESRGAN_x2plus.pth",
},
]
for model in URLs:
download_with_progress_bar(model["url"], config.models_path / model["dest"], model["description"])
@ -680,7 +680,7 @@ def default_user_selections(program_opts: Namespace) -> InstallSelections:
if program_opts.default_only
else [models[x].path or models[x].repo_id for x in installer.recommended_models()]
if program_opts.yes_to_all
else list(),
else [],
)

View File

@ -38,6 +38,7 @@ SAMPLER_CHOICES = [
"k_heun",
"k_lms",
"plms",
"lcm",
]
PRECISION_CHOICES = [

View File

@ -123,8 +123,6 @@ class MigrateTo3(object):
logger.error(str(e))
except KeyboardInterrupt:
raise
except Exception as e:
logger.error(str(e))
for f in files:
# don't copy raw learned_embeds.bin or pytorch_lora_weights.bin
# let them be copied as part of a tree copy operation
@ -143,8 +141,6 @@ class MigrateTo3(object):
logger.error(str(e))
except KeyboardInterrupt:
raise
except Exception as e:
logger.error(str(e))
def migrate_support_models(self):
"""
@ -182,10 +178,10 @@ class MigrateTo3(object):
"""
dest_directory = self.dest_models
kwargs = dict(
cache_dir=self.root_directory / "models/hub",
kwargs = {
"cache_dir": self.root_directory / "models/hub",
# local_files_only = True
)
}
try:
logger.info("Migrating core tokenizers and text encoders")
target_dir = dest_directory / "core" / "convert"
@ -316,11 +312,11 @@ class MigrateTo3(object):
dest_dir = self.dest_models
cache = self.root_directory / "models/hub"
kwargs = dict(
cache_dir=cache,
safety_checker=None,
kwargs = {
"cache_dir": cache,
"safety_checker": None,
# local_files_only = True,
)
}
owner, repo_name = repo_id.split("/")
model_name = model_name or repo_name

View File

@ -120,7 +120,7 @@ class ModelInstall(object):
be treated uniformly. It also sorts the models alphabetically
by their name, to improve the display somewhat.
"""
model_dict = dict()
model_dict = {}
# first populate with the entries in INITIAL_MODELS.yaml
for key, value in self.datasets.items():
@ -134,7 +134,7 @@ class ModelInstall(object):
model_dict[key] = model_info
# supplement with entries in models.yaml
installed_models = [x for x in self.mgr.list_models()]
installed_models = list(self.mgr.list_models())
for md in installed_models:
base = md["base_model"]
@ -176,7 +176,7 @@ class ModelInstall(object):
# logic here a little reversed to maintain backward compatibility
def starter_models(self, all_models: bool = False) -> Set[str]:
models = set()
for key, value in self.datasets.items():
for key, _value in self.datasets.items():
name, base, model_type = ModelManager.parse_key(key)
if all_models or model_type in [ModelType.Main, ModelType.Vae]:
models.add(key)
@ -184,7 +184,7 @@ class ModelInstall(object):
def recommended_models(self) -> Set[str]:
starters = self.starter_models(all_models=True)
return set([x for x in starters if self.datasets[x].get("recommended", False)])
return {x for x in starters if self.datasets[x].get("recommended", False)}
def default_model(self) -> str:
starters = self.starter_models()
@ -234,7 +234,7 @@ class ModelInstall(object):
"""
if not models_installed:
models_installed = dict()
models_installed = {}
model_path_id_or_url = str(model_path_id_or_url).strip("\"' ")
@ -252,10 +252,14 @@ class ModelInstall(object):
# folders style or similar
elif path.is_dir() and any(
[
(path / x).exists()
for x in {"config.json", "model_index.json", "learned_embeds.bin", "pytorch_lora_weights.bin"}
]
(path / x).exists()
for x in {
"config.json",
"model_index.json",
"learned_embeds.bin",
"pytorch_lora_weights.bin",
"pytorch_lora_weights.safetensors",
}
):
models_installed.update({str(model_path_id_or_url): self._install_path(path)})
@ -357,7 +361,7 @@ class ModelInstall(object):
for suffix in ["safetensors", "bin"]:
if f"{prefix}pytorch_lora_weights.{suffix}" in files:
location = self._download_hf_model(
repo_id, ["pytorch_lora_weights.bin"], staging, subfolder=subfolder
repo_id, [f"pytorch_lora_weights.{suffix}"], staging, subfolder=subfolder
) # LoRA
break
elif (
@ -427,17 +431,17 @@ class ModelInstall(object):
rel_path = self.relative_to_root(path, self.config.models_path)
attributes = dict(
path=str(rel_path),
description=str(description),
model_format=info.format,
)
attributes = {
"path": str(rel_path),
"description": str(description),
"model_format": info.format,
}
legacy_conf = None
if info.model_type == ModelType.Main or info.model_type == ModelType.ONNX:
attributes.update(
dict(
variant=info.variant_type,
)
{
"variant": info.variant_type,
}
)
if info.format == "checkpoint":
try:
@ -468,7 +472,7 @@ class ModelInstall(object):
)
if legacy_conf:
attributes.update(dict(config=str(legacy_conf)))
attributes.update({"config": str(legacy_conf)})
return attributes
def relative_to_root(self, path: Path, root: Optional[Path] = None) -> Path:
@ -513,7 +517,7 @@ class ModelInstall(object):
def _download_hf_model(self, repo_id: str, files: List[str], staging: Path, subfolder: None) -> Path:
_, name = repo_id.split("/")
location = staging / name
paths = list()
paths = []
for filename in files:
filePath = Path(filename)
p = hf_download_with_resume(

View File

@ -130,7 +130,9 @@ class IPAttnProcessor2_0(torch.nn.Module):
assert ip_adapter_image_prompt_embeds is not None
assert len(ip_adapter_image_prompt_embeds) == len(self._weights)
for ipa_embed, ipa_weights, scale in zip(ip_adapter_image_prompt_embeds, self._weights, self._scales):
for ipa_embed, ipa_weights, scale in zip(
ip_adapter_image_prompt_embeds, self._weights, self._scales, strict=True
):
# The batch dimensions should match.
assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
# The token_len dimensions should match.

View File

@ -56,7 +56,7 @@ class PerceiverAttention(nn.Module):
x = self.norm1(x)
latents = self.norm2(latents)
b, l, _ = latents.shape
b, L, _ = latents.shape
q = self.to_q(latents)
kv_input = torch.cat((x, latents), dim=-2)
@ -72,7 +72,7 @@ class PerceiverAttention(nn.Module):
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
out = weight @ v
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
out = out.permute(0, 2, 1, 3).reshape(b, L, -1)
return self.to_out(out)

View File

@ -269,7 +269,7 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
resolution *= 2
up_block_types = []
for i in range(len(block_out_channels)):
for _i in range(len(block_out_channels)):
block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
up_block_types.append(block_type)
resolution //= 2
@ -1223,7 +1223,7 @@ def download_from_original_stable_diffusion_ckpt(
# scan model
scan_result = scan_file_path(checkpoint_path)
if scan_result.infected_files != 0:
raise "The model {checkpoint_path} is potentially infected by malware. Aborting import."
raise Exception("The model {checkpoint_path} is potentially infected by malware. Aborting import.")
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = torch.load(checkpoint_path, map_location=device)
@ -1664,7 +1664,7 @@ def download_controlnet_from_original_ckpt(
# scan model
scan_result = scan_file_path(checkpoint_path)
if scan_result.infected_files != 0:
raise "The model {checkpoint_path} is potentially infected by malware. Aborting import."
raise Exception("The model {checkpoint_path} is potentially infected by malware. Aborting import.")
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = torch.load(checkpoint_path, map_location=device)

View File

@ -12,6 +12,8 @@ from diffusers.models import UNet2DConditionModel
from safetensors.torch import load_file
from transformers import CLIPTextModel, CLIPTokenizer
from invokeai.app.shared.models import FreeUConfig
from .models.lora import LoRAModel
"""
@ -102,7 +104,7 @@ class ModelPatcher:
loras: List[Tuple[LoRAModel, float]],
prefix: str,
):
original_weights = dict()
original_weights = {}
try:
with torch.no_grad():
for lora, lora_weight in loras:
@ -164,6 +166,15 @@ class ModelPatcher:
init_tokens_count = None
new_tokens_added = None
# TODO: This is required since Transformers 4.32 see
# https://github.com/huggingface/transformers/pull/25088
# More information by NVIDIA:
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc
# This value might need to be changed in the future and take the GPUs model into account as there seem
# to be ideal values for different GPUS. This value is temporary!
# For references to the current discussion please see https://github.com/invoke-ai/InvokeAI/pull/4817
pad_to_multiple_of = 8
try:
# HACK: The CLIPTokenizer API does not include a way to remove tokens after calling add_tokens(...). As a
# workaround, we create a full copy of `tokenizer` so that its original behavior can be restored after
@ -173,7 +184,7 @@ class ModelPatcher:
# but a pickle roundtrip was found to be much faster (1 sec vs. 0.05 secs).
ti_tokenizer = pickle.loads(pickle.dumps(tokenizer))
ti_manager = TextualInversionManager(ti_tokenizer)
init_tokens_count = text_encoder.resize_token_embeddings(None).num_embeddings
init_tokens_count = text_encoder.resize_token_embeddings(None, pad_to_multiple_of).num_embeddings
def _get_trigger(ti_name, index):
trigger = ti_name
@ -188,7 +199,7 @@ class ModelPatcher:
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))
# modify text_encoder
text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added)
text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added, pad_to_multiple_of)
model_embeddings = text_encoder.get_input_embeddings()
for ti_name, ti in ti_list:
@ -220,7 +231,7 @@ class ModelPatcher:
finally:
if init_tokens_count and new_tokens_added:
text_encoder.resize_token_embeddings(init_tokens_count)
text_encoder.resize_token_embeddings(init_tokens_count, pad_to_multiple_of)
@classmethod
@contextmanager
@ -231,7 +242,7 @@ class ModelPatcher:
):
skipped_layers = []
try:
for i in range(clip_skip):
for _i in range(clip_skip):
skipped_layers.append(text_encoder.text_model.encoder.layers.pop(-1))
yield
@ -240,6 +251,25 @@ class ModelPatcher:
while len(skipped_layers) > 0:
text_encoder.text_model.encoder.layers.append(skipped_layers.pop())
@classmethod
@contextmanager
def apply_freeu(
cls,
unet: UNet2DConditionModel,
freeu_config: Optional[FreeUConfig] = None,
):
did_apply_freeu = False
try:
if freeu_config is not None:
unet.enable_freeu(b1=freeu_config.b1, b2=freeu_config.b2, s1=freeu_config.s1, s2=freeu_config.s2)
did_apply_freeu = True
yield
finally:
if did_apply_freeu:
unet.disable_freeu()
class TextualInversionModel:
embedding: torch.Tensor # [n, 768]|[n, 1280]
@ -294,7 +324,7 @@ class TextualInversionManager(BaseTextualInversionManager):
tokenizer: CLIPTokenizer
def __init__(self, tokenizer: CLIPTokenizer):
self.pad_tokens = dict()
self.pad_tokens = {}
self.tokenizer = tokenizer
def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]:
@ -355,10 +385,10 @@ class ONNXModelPatcher:
if not isinstance(model, IAIOnnxRuntimeModel):
raise Exception("Only IAIOnnxRuntimeModel models supported")
orig_weights = dict()
orig_weights = {}
try:
blended_loras = dict()
blended_loras = {}
for lora, lora_weight in loras:
for layer_key, layer in lora.layers.items():
@ -374,7 +404,7 @@ class ONNXModelPatcher:
else:
blended_loras[layer_key] = layer_weight
node_names = dict()
node_names = {}
for node in model.nodes.values():
node_names[node.name.replace("/", "_").replace(".", "_").lstrip("_")] = node.name

View File

@ -66,11 +66,13 @@ class CacheStats(object):
class ModelLocker(object):
"Forward declaration"
pass
class ModelCache(object):
"Forward declaration"
pass
@ -132,7 +134,7 @@ class ModelCache(object):
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
behaviour.
"""
self.model_infos: Dict[str, ModelBase] = dict()
self.model_infos: Dict[str, ModelBase] = {}
# allow lazy offloading only when vram cache enabled
self.lazy_offloading = lazy_offloading and max_vram_cache_size > 0
self.precision: torch.dtype = precision
@ -147,8 +149,8 @@ class ModelCache(object):
# used for stats collection
self.stats = None
self._cached_models = dict()
self._cache_stack = list()
self._cached_models = {}
self._cache_stack = []
def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
if self._log_memory_usage:

View File

@ -26,5 +26,5 @@ def skip_torch_weight_init():
yield None
finally:
for torch_module, saved_function in zip(torch_modules, saved_functions):
for torch_module, saved_function in zip(torch_modules, saved_functions, strict=True):
torch_module.reset_parameters = saved_function

View File

@ -363,7 +363,7 @@ class ModelManager(object):
else:
return
self.models = dict()
self.models = {}
for model_key, model_config in config.items():
if model_key.startswith("_"):
continue
@ -374,7 +374,7 @@ class ModelManager(object):
self.models[model_key] = model_class.create_config(**model_config)
# check config version number and update on disk/RAM if necessary
self.cache_keys = dict()
self.cache_keys = {}
# add controlnet, lora and textual_inversion models from disk
self.scan_models_directory()
@ -655,7 +655,7 @@ class ModelManager(object):
"""
# TODO: redo
for model_dict in self.list_models():
for model_name, model_info in model_dict.items():
for _model_name, model_info in model_dict.items():
line = f'{model_info["name"]:25s} {model_info["type"]:10s} {model_info["description"]}'
print(line)
@ -902,7 +902,7 @@ class ModelManager(object):
"""
Write current configuration out to the indicated file.
"""
data_to_save = dict()
data_to_save = {}
data_to_save["__metadata__"] = self.config_meta.model_dump()
for model_key, model_config in self.models.items():
@ -1034,7 +1034,7 @@ class ModelManager(object):
self.ignore = ignore
def on_search_started(self):
self.new_models_found = dict()
self.new_models_found = {}
def on_model_found(self, model: Path):
if model not in self.ignore:
@ -1106,7 +1106,7 @@ class ModelManager(object):
# avoid circular import here
from invokeai.backend.install.model_install_backend import ModelInstall
successfully_installed = dict()
successfully_installed = {}
installer = ModelInstall(
config=self.app_config, prediction_type_helper=prediction_type_helper, model_manager=self

View File

@ -92,7 +92,7 @@ class ModelMerger(object):
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
"""
model_paths = list()
model_paths = []
config = self.manager.app_config
base_model = BaseModelType(base_model)
vae = None
@ -124,13 +124,13 @@ class ModelMerger(object):
dump_path = (dump_path / merged_model_name).as_posix()
merged_pipe.save_pretrained(dump_path, safe_serialization=True)
attributes = dict(
path=dump_path,
description=f"Merge of models {', '.join(model_names)}",
model_format="diffusers",
variant=ModelVariantType.Normal.value,
vae=vae,
)
attributes = {
"path": dump_path,
"description": f"Merge of models {', '.join(model_names)}",
"model_format": "diffusers",
"variant": ModelVariantType.Normal.value,
"vae": vae,
}
return self.manager.add_model(
merged_model_name,
base_model=base_model,

View File

@ -183,12 +183,13 @@ class ModelProbe(object):
if model:
class_name = model.__class__.__name__
else:
for suffix in ["bin", "safetensors"]:
if (folder_path / f"learned_embeds.{suffix}").exists():
return ModelType.TextualInversion
if (folder_path / f"pytorch_lora_weights.{suffix}").exists():
return ModelType.Lora
if (folder_path / "unet/model.onnx").exists():
return ModelType.ONNX
if (folder_path / "learned_embeds.bin").exists():
return ModelType.TextualInversion
if (folder_path / "pytorch_lora_weights.bin").exists():
return ModelType.Lora
if (folder_path / "image_encoder.txt").exists():
return ModelType.IPAdapter
@ -236,7 +237,7 @@ class ModelProbe(object):
# scan model
scan_result = scan_file_path(checkpoint)
if scan_result.infected_files != 0:
raise "The model {model_name} is potentially infected by malware. Aborting import."
raise Exception("The model {model_name} is potentially infected by malware. Aborting import.")
# ##################################################3

View File

@ -59,7 +59,7 @@ class ModelSearch(ABC):
for root, dirs, files in os.walk(path, followlinks=True):
if str(Path(root).name).startswith("."):
self._pruned_paths.add(root)
if any([Path(root).is_relative_to(x) for x in self._pruned_paths]):
if any(Path(root).is_relative_to(x) for x in self._pruned_paths):
continue
self._items_scanned += len(dirs) + len(files)
@ -69,16 +69,14 @@ class ModelSearch(ABC):
self._scanned_dirs.add(path)
continue
if any(
[
(path / x).exists()
for x in {
"config.json",
"model_index.json",
"learned_embeds.bin",
"pytorch_lora_weights.bin",
"image_encoder.txt",
}
]
(path / x).exists()
for x in {
"config.json",
"model_index.json",
"learned_embeds.bin",
"pytorch_lora_weights.bin",
"image_encoder.txt",
}
):
try:
self.on_model_found(path)

View File

@ -97,8 +97,8 @@ MODEL_CLASSES = {
# },
}
MODEL_CONFIGS = list()
OPENAPI_MODEL_CONFIGS = list()
MODEL_CONFIGS = []
OPENAPI_MODEL_CONFIGS = []
class OpenAPIModelInfoBase(BaseModel):
@ -109,7 +109,7 @@ class OpenAPIModelInfoBase(BaseModel):
model_config = ConfigDict(protected_namespaces=())
for base_model, models in MODEL_CLASSES.items():
for _base_model, models in MODEL_CLASSES.items():
for model_type, model_class in models.items():
model_configs = set(model_class._get_configs().values())
model_configs.discard(None)
@ -133,7 +133,7 @@ for base_model, models in MODEL_CLASSES.items():
def get_model_config_enums():
enums = list()
enums = []
for model_config in MODEL_CONFIGS:
if hasattr(inspect, "get_annotations"):

View File

@ -153,7 +153,7 @@ class ModelBase(metaclass=ABCMeta):
else:
res_type = sys.modules["diffusers"]
res_type = getattr(res_type, "pipelines")
res_type = res_type.pipelines
for subtype in subtypes:
res_type = getattr(res_type, subtype)
@ -164,7 +164,7 @@ class ModelBase(metaclass=ABCMeta):
with suppress(Exception):
return cls.__configs
configs = dict()
configs = {}
for name in dir(cls):
if name.startswith("__"):
continue
@ -246,8 +246,8 @@ class DiffusersModel(ModelBase):
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
super().__init__(model_path, base_model, model_type)
self.child_types: Dict[str, Type] = dict()
self.child_sizes: Dict[str, int] = dict()
self.child_types: Dict[str, Type] = {}
self.child_sizes: Dict[str, int] = {}
try:
config_data = DiffusionPipeline.load_config(self.model_path)
@ -326,8 +326,8 @@ def calc_model_size_by_fs(model_path: str, subfolder: Optional[str] = None, vari
all_files = os.listdir(model_path)
all_files = [f for f in all_files if os.path.isfile(os.path.join(model_path, f))]
fp16_files = set([f for f in all_files if ".fp16." in f or ".fp16-" in f])
bit8_files = set([f for f in all_files if ".8bit." in f or ".8bit-" in f])
fp16_files = {f for f in all_files if ".fp16." in f or ".fp16-" in f}
bit8_files = {f for f in all_files if ".8bit." in f or ".8bit-" in f}
other_files = set(all_files) - fp16_files - bit8_files
if variant is None:
@ -413,7 +413,7 @@ def _calc_onnx_model_by_data(model) -> int:
def _fast_safetensors_reader(path: str):
checkpoint = dict()
checkpoint = {}
device = torch.device("meta")
with open(path, "rb") as f:
definition_len = int.from_bytes(f.read(8), "little")
@ -483,7 +483,7 @@ class IAIOnnxRuntimeModel:
class _tensor_access:
def __init__(self, model):
self.model = model
self.indexes = dict()
self.indexes = {}
for idx, obj in enumerate(self.model.proto.graph.initializer):
self.indexes[obj.name] = idx
@ -524,7 +524,7 @@ class IAIOnnxRuntimeModel:
class _access_helper:
def __init__(self, raw_proto):
self.indexes = dict()
self.indexes = {}
self.raw_proto = raw_proto
for idx, obj in enumerate(raw_proto):
self.indexes[obj.name] = idx
@ -549,7 +549,7 @@ class IAIOnnxRuntimeModel:
return self.indexes.keys()
def values(self):
return [obj for obj in self.raw_proto]
return list(self.raw_proto)
def __init__(self, model_path: str, provider: Optional[str]):
self.path = model_path

View File

@ -104,7 +104,7 @@ class ControlNetModel(ModelBase):
return ControlNetModelFormat.Diffusers
if os.path.isfile(path):
if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "pth"]]):
if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "pth"]):
return ControlNetModelFormat.Checkpoint
raise InvalidModelException(f"Not a valid model: {path}")

View File

@ -68,11 +68,12 @@ class LoRAModel(ModelBase):
raise ModelNotFoundException()
if os.path.isdir(path):
if os.path.exists(os.path.join(path, "pytorch_lora_weights.bin")):
return LoRAModelFormat.Diffusers
for ext in ["safetensors", "bin"]:
if os.path.exists(os.path.join(path, f"pytorch_lora_weights.{ext}")):
return LoRAModelFormat.Diffusers
if os.path.isfile(path):
if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]):
return LoRAModelFormat.LyCORIS
raise InvalidModelException(f"Not a valid model: {path}")
@ -86,8 +87,10 @@ class LoRAModel(ModelBase):
base_model: BaseModelType,
) -> str:
if cls.detect_format(model_path) == LoRAModelFormat.Diffusers:
# TODO: add diffusers lora when it stabilizes a bit
raise NotImplementedError("Diffusers lora not supported")
for ext in ["safetensors", "bin"]: # return path to the safetensors file inside the folder
path = Path(model_path, f"pytorch_lora_weights.{ext}")
if path.exists():
return path
else:
return model_path
@ -459,7 +462,7 @@ class LoRAModelRaw: # (torch.nn.Module):
dtype: Optional[torch.dtype] = None,
):
# TODO: try revert if exception?
for key, layer in self.layers.items():
for _key, layer in self.layers.items():
layer.to(device=device, dtype=dtype)
def calc_size(self) -> int:
@ -496,7 +499,7 @@ class LoRAModelRaw: # (torch.nn.Module):
stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
stability_unet_keys.sort()
new_state_dict = dict()
new_state_dict = {}
for full_key, value in state_dict.items():
if full_key.startswith("lora_unet_"):
search_key = full_key.replace("lora_unet_", "")
@ -542,7 +545,7 @@ class LoRAModelRaw: # (torch.nn.Module):
model = cls(
name=file_path.stem, # TODO:
layers=dict(),
layers={},
)
if file_path.suffix == ".safetensors":
@ -590,12 +593,12 @@ class LoRAModelRaw: # (torch.nn.Module):
@staticmethod
def _group_state(state_dict: dict):
state_dict_groupped = dict()
state_dict_groupped = {}
for key, value in state_dict.items():
stem, leaf = key.split(".", 1)
if stem not in state_dict_groupped:
state_dict_groupped[stem] = dict()
state_dict_groupped[stem] = {}
state_dict_groupped[stem][leaf] = value
return state_dict_groupped

View File

@ -110,7 +110,7 @@ class StableDiffusion1Model(DiffusersModel):
return StableDiffusion1ModelFormat.Diffusers
if os.path.isfile(model_path):
if any([model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
if any(model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]):
return StableDiffusion1ModelFormat.Checkpoint
raise InvalidModelException(f"Not a valid model: {model_path}")
@ -221,7 +221,7 @@ class StableDiffusion2Model(DiffusersModel):
return StableDiffusion2ModelFormat.Diffusers
if os.path.isfile(model_path):
if any([model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
if any(model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]):
return StableDiffusion2ModelFormat.Checkpoint
raise InvalidModelException(f"Not a valid model: {model_path}")

View File

@ -71,7 +71,7 @@ class TextualInversionModel(ModelBase):
return None # diffusers-ti
if os.path.isfile(path):
if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "bin"]]):
if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "bin"]):
return None
raise InvalidModelException(f"Not a valid model: {path}")

View File

@ -89,7 +89,7 @@ class VaeModel(ModelBase):
return VaeModelFormat.Diffusers
if os.path.isfile(path):
if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]):
return VaeModelFormat.Checkpoint
raise InvalidModelException(f"Not a valid model: {path}")

View File

@ -0,0 +1,323 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Configuration definitions for image generation models.
Typical usage:
from invokeai.backend.model_manager import ModelConfigFactory
raw = dict(path='models/sd-1/main/foo.ckpt',
name='foo',
base='sd-1',
type='main',
config='configs/stable-diffusion/v1-inference.yaml',
variant='normal',
format='checkpoint'
)
config = ModelConfigFactory.make_config(raw)
print(config.name)
Validation errors will raise an InvalidModelConfigException error.
"""
from enum import Enum
from typing import Literal, Optional, Type, Union
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
from typing_extensions import Annotated
class InvalidModelConfigException(Exception):
"""Exception for when config parser doesn't recognized this combination of model type and format."""
class BaseModelType(str, Enum):
"""Base model type."""
Any = "any"
StableDiffusion1 = "sd-1"
StableDiffusion2 = "sd-2"
StableDiffusionXL = "sdxl"
StableDiffusionXLRefiner = "sdxl-refiner"
# Kandinsky2_1 = "kandinsky-2.1"
class ModelType(str, Enum):
"""Model type."""
ONNX = "onnx"
Main = "main"
Vae = "vae"
Lora = "lora"
ControlNet = "controlnet" # used by model_probe
TextualInversion = "embedding"
IPAdapter = "ip_adapter"
CLIPVision = "clip_vision"
T2IAdapter = "t2i_adapter"
class SubModelType(str, Enum):
"""Submodel type."""
UNet = "unet"
TextEncoder = "text_encoder"
TextEncoder2 = "text_encoder_2"
Tokenizer = "tokenizer"
Tokenizer2 = "tokenizer_2"
Vae = "vae"
VaeDecoder = "vae_decoder"
VaeEncoder = "vae_encoder"
Scheduler = "scheduler"
SafetyChecker = "safety_checker"
class ModelVariantType(str, Enum):
"""Variant type."""
Normal = "normal"
Inpaint = "inpaint"
Depth = "depth"
class ModelFormat(str, Enum):
"""Storage format of model."""
Diffusers = "diffusers"
Checkpoint = "checkpoint"
Lycoris = "lycoris"
Onnx = "onnx"
Olive = "olive"
EmbeddingFile = "embedding_file"
EmbeddingFolder = "embedding_folder"
InvokeAI = "invokeai"
class SchedulerPredictionType(str, Enum):
"""Scheduler prediction type."""
Epsilon = "epsilon"
VPrediction = "v_prediction"
Sample = "sample"
class ModelConfigBase(BaseModel):
"""Base class for model configuration information."""
path: str
name: str
base: BaseModelType
type: ModelType
format: ModelFormat
key: str = Field(description="unique key for model", default="<NOKEY>")
original_hash: Optional[str] = Field(
description="original fasthash of model contents", default=None
) # this is assigned at install time and will not change
current_hash: Optional[str] = Field(
description="current fasthash of model contents", default=None
) # if model is converted or otherwise modified, this will hold updated hash
description: Optional[str] = Field(default=None)
source: Optional[str] = Field(description="Model download source (URL or repo_id)", default=None)
model_config = ConfigDict(
use_enum_values=False,
validate_assignment=True,
)
def update(self, attributes: dict):
"""Update the object with fields in dict."""
for key, value in attributes.items():
setattr(self, key, value) # may raise a validation error
class _CheckpointConfig(ModelConfigBase):
"""Model config for checkpoint-style models."""
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
config: str = Field(description="path to the checkpoint model config file")
class _DiffusersConfig(ModelConfigBase):
"""Model config for diffusers-style models."""
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
class LoRAConfig(ModelConfigBase):
"""Model config for LoRA/Lycoris models."""
type: Literal[ModelType.Lora] = ModelType.Lora
format: Literal[ModelFormat.Lycoris, ModelFormat.Diffusers]
class VaeCheckpointConfig(ModelConfigBase):
"""Model config for standalone VAE models."""
type: Literal[ModelType.Vae] = ModelType.Vae
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
class VaeDiffusersConfig(ModelConfigBase):
"""Model config for standalone VAE models (diffusers version)."""
type: Literal[ModelType.Vae] = ModelType.Vae
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
class ControlNetDiffusersConfig(_DiffusersConfig):
"""Model config for ControlNet models (diffusers version)."""
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
class ControlNetCheckpointConfig(_CheckpointConfig):
"""Model config for ControlNet models (diffusers version)."""
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
class TextualInversionConfig(ModelConfigBase):
"""Model config for textual inversion embeddings."""
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
format: Literal[ModelFormat.EmbeddingFile, ModelFormat.EmbeddingFolder]
class _MainConfig(ModelConfigBase):
"""Model config for main models."""
vae: Optional[str] = Field(default=None)
variant: ModelVariantType = ModelVariantType.Normal
ztsnr_training: bool = False
class MainCheckpointConfig(_CheckpointConfig, _MainConfig):
"""Model config for main checkpoint models."""
type: Literal[ModelType.Main] = ModelType.Main
# Note that we do not need prediction_type or upcast_attention here
# because they are provided in the checkpoint's own config file.
class MainDiffusersConfig(_DiffusersConfig, _MainConfig):
"""Model config for main diffusers models."""
type: Literal[ModelType.Main] = ModelType.Main
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False
class ONNXSD1Config(_MainConfig):
"""Model config for ONNX format models based on sd-1."""
type: Literal[ModelType.ONNX] = ModelType.ONNX
format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
base: Literal[BaseModelType.StableDiffusion1] = BaseModelType.StableDiffusion1
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False
class ONNXSD2Config(_MainConfig):
"""Model config for ONNX format models based on sd-2."""
type: Literal[ModelType.ONNX] = ModelType.ONNX
format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
# No yaml config file for ONNX, so these are part of config
base: Literal[BaseModelType.StableDiffusion2] = BaseModelType.StableDiffusion2
prediction_type: SchedulerPredictionType = SchedulerPredictionType.VPrediction
upcast_attention: bool = True
class IPAdapterConfig(ModelConfigBase):
"""Model config for IP Adaptor format models."""
type: Literal[ModelType.IPAdapter] = ModelType.IPAdapter
format: Literal[ModelFormat.InvokeAI]
class CLIPVisionDiffusersConfig(ModelConfigBase):
"""Model config for ClipVision."""
type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision
format: Literal[ModelFormat.Diffusers]
class T2IConfig(ModelConfigBase):
"""Model config for T2I."""
type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter
format: Literal[ModelFormat.Diffusers]
_ONNXConfig = Annotated[Union[ONNXSD1Config, ONNXSD2Config], Field(discriminator="base")]
_ControlNetConfig = Annotated[
Union[ControlNetDiffusersConfig, ControlNetCheckpointConfig],
Field(discriminator="format"),
]
_VaeConfig = Annotated[Union[VaeDiffusersConfig, VaeCheckpointConfig], Field(discriminator="format")]
_MainModelConfig = Annotated[Union[MainDiffusersConfig, MainCheckpointConfig], Field(discriminator="format")]
AnyModelConfig = Union[
_MainModelConfig,
_ONNXConfig,
_VaeConfig,
_ControlNetConfig,
LoRAConfig,
TextualInversionConfig,
IPAdapterConfig,
CLIPVisionDiffusersConfig,
T2IConfig,
]
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
# IMPLEMENTATION NOTE:
# The preferred alternative to the above is a discriminated Union as shown
# below. However, it breaks FastAPI when used as the input Body parameter in a route.
# This is a known issue. Please see:
# https://github.com/tiangolo/fastapi/discussions/9761 and
# https://github.com/tiangolo/fastapi/discussions/9287
# AnyModelConfig = Annotated[
# Union[
# _MainModelConfig,
# _ONNXConfig,
# _VaeConfig,
# _ControlNetConfig,
# LoRAConfig,
# TextualInversionConfig,
# IPAdapterConfig,
# CLIPVisionDiffusersConfig,
# T2IConfig,
# ],
# Field(discriminator="type"),
# ]
class ModelConfigFactory(object):
"""Class for parsing config dicts into StableDiffusion Config obects."""
@classmethod
def make_config(
cls,
model_data: Union[dict, AnyModelConfig],
key: Optional[str] = None,
dest_class: Optional[Type] = None,
) -> AnyModelConfig:
"""
Return the appropriate config object from raw dict values.
:param model_data: A raw dict corresponding the obect fields to be
parsed into a ModelConfigBase obect (or descendent), or a ModelConfigBase
object, which will be passed through unchanged.
:param dest_class: The config class to be returned. If not provided, will
be selected automatically.
"""
if isinstance(model_data, ModelConfigBase):
model = model_data
elif dest_class:
model = dest_class.validate_python(model_data)
else:
model = AnyModelConfigValidator.validate_python(model_data)
if key:
model.key = key
return model

View File

@ -0,0 +1,66 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Fast hashing of diffusers and checkpoint-style models.
Usage:
from invokeai.backend.model_managre.model_hash import FastModelHash
>>> FastModelHash.hash('/home/models/stable-diffusion-v1.5')
'a8e693a126ea5b831c96064dc569956f'
"""
import hashlib
import os
from pathlib import Path
from typing import Dict, Union
from imohash import hashfile
class FastModelHash(object):
"""FastModelHash obect provides one public class method, hash()."""
@classmethod
def hash(cls, model_location: Union[str, Path]) -> str:
"""
Return hexdigest string for model located at model_location.
:param model_location: Path to the model
"""
model_location = Path(model_location)
if model_location.is_file():
return cls._hash_file(model_location)
elif model_location.is_dir():
return cls._hash_dir(model_location)
else:
raise OSError(f"Not a valid file or directory: {model_location}")
@classmethod
def _hash_file(cls, model_location: Union[str, Path]) -> str:
"""
Fasthash a single file and return its hexdigest.
:param model_location: Path to the model file
"""
# we return md5 hash of the filehash to make it shorter
# cryptographic security not needed here
return hashlib.md5(hashfile(model_location)).hexdigest()
@classmethod
def _hash_dir(cls, model_location: Union[str, Path]) -> str:
components: Dict[str, str] = {}
for root, _dirs, files in os.walk(model_location):
for file in files:
# only tally tensor files because diffusers config files change slightly
# depending on how the model was downloaded/converted.
if not file.endswith((".ckpt", ".safetensors", ".bin", ".pt", ".pth")):
continue
path = (Path(root) / file).as_posix()
fast_hash = cls._hash_file(path)
components.update({path: fast_hash})
# hash all the model hashes together, using alphabetic file order
md5 = hashlib.md5()
for _path, fast_hash in sorted(components.items()):
md5.update(fast_hash.encode("utf-8"))
return md5.hexdigest()

View File

@ -0,0 +1,93 @@
# Copyright (c) 2023 Lincoln D. Stein
"""Migrate from the InvokeAI v2 models.yaml format to the v3 sqlite format."""
from hashlib import sha1
from omegaconf import DictConfig, OmegaConf
from pydantic import TypeAdapter
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_records import (
DuplicateModelException,
ModelRecordServiceSQL,
)
from invokeai.app.services.shared.sqlite import SqliteDatabase
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
ModelType,
)
from invokeai.backend.model_manager.hash import FastModelHash
from invokeai.backend.util.logging import InvokeAILogger
ModelsValidator = TypeAdapter(AnyModelConfig)
class MigrateModelYamlToDb:
"""
Migrate the InvokeAI models.yaml format (VERSION 3.0.0) to SQL3 database format (VERSION 3.2.0)
The class has one externally useful method, migrate(), which scans the
currently models.yaml file and imports all its entries into invokeai.db.
Use this way:
from invokeai.backend.model_manager/migrate_to_db import MigrateModelYamlToDb
MigrateModelYamlToDb().migrate()
"""
config: InvokeAIAppConfig
logger: InvokeAILogger
def __init__(self):
self.config = InvokeAIAppConfig.get_config()
self.config.parse_args()
self.logger = InvokeAILogger.get_logger()
def get_db(self) -> ModelRecordServiceSQL:
"""Fetch the sqlite3 database for this installation."""
db = SqliteDatabase(self.config, self.logger)
return ModelRecordServiceSQL(db)
def get_yaml(self) -> DictConfig:
"""Fetch the models.yaml DictConfig for this installation."""
yaml_path = self.config.model_conf_path
return OmegaConf.load(yaml_path)
def migrate(self):
"""Do the migration from models.yaml to invokeai.db."""
db = self.get_db()
yaml = self.get_yaml()
for model_key, stanza in yaml.items():
if model_key == "__metadata__":
assert (
stanza["version"] == "3.0.0"
), f"This script works on version 3.0.0 yaml files, but your configuration points to a {stanza['version']} version"
continue
base_type, model_type, model_name = str(model_key).split("/")
hash = FastModelHash.hash(self.config.models_path / stanza.path)
new_key = sha1(model_key.encode("utf-8")).hexdigest()
stanza["base"] = BaseModelType(base_type)
stanza["type"] = ModelType(model_type)
stanza["name"] = model_name
stanza["original_hash"] = hash
stanza["current_hash"] = hash
new_config = ModelsValidator.validate_python(stanza)
self.logger.info(f"Adding model {model_name} with key {model_key}")
try:
db.add_model(new_key, new_config)
except DuplicateModelException:
self.logger.warning(f"Model {model_name} is already in the database")
def main():
MigrateModelYamlToDb().migrate()
if __name__ == "__main__":
main()

View File

@ -193,6 +193,7 @@ class InvokeAIStableDiffusionPipelineOutput(StableDiffusionPipelineOutput):
attention_map_saver (`AttentionMapSaver`): Object containing attention maps that can be displayed to the user
after generation completes. Optional.
"""
attention_map_saver: Optional[AttentionMapSaver]
@ -546,11 +547,13 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# Handle ControlNet(s) and T2I-Adapter(s)
down_block_additional_residuals = None
mid_block_additional_residual = None
if control_data is not None and t2i_adapter_data is not None:
# TODO(ryand): This is a limitation of the UNet2DConditionModel API, not a fundamental incompatibility
# between ControlNets and T2I-Adapters. We will try to fix this upstream in diffusers.
raise Exception("ControlNet(s) and T2I-Adapter(s) cannot be used simultaneously (yet).")
elif control_data is not None:
down_intrablock_additional_residuals = None
# if control_data is not None and t2i_adapter_data is not None:
# TODO(ryand): This is a limitation of the UNet2DConditionModel API, not a fundamental incompatibility
# between ControlNets and T2I-Adapters. We will try to fix this upstream in diffusers.
# raise Exception("ControlNet(s) and T2I-Adapter(s) cannot be used simultaneously (yet).")
# elif control_data is not None:
if control_data is not None:
down_block_additional_residuals, mid_block_additional_residual = self.invokeai_diffuser.do_controlnet_step(
control_data=control_data,
sample=latent_model_input,
@ -559,7 +562,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
total_step_count=total_step_count,
conditioning_data=conditioning_data,
)
elif t2i_adapter_data is not None:
# elif t2i_adapter_data is not None:
if t2i_adapter_data is not None:
accum_adapter_state = None
for single_t2i_adapter_data in t2i_adapter_data:
# Determine the T2I-Adapter weights for the current denoising step.
@ -584,7 +588,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
for idx, value in enumerate(single_t2i_adapter_data.adapter_state):
accum_adapter_state[idx] += value * t2i_adapter_weight
down_block_additional_residuals = accum_adapter_state
# down_block_additional_residuals = accum_adapter_state
down_intrablock_additional_residuals = accum_adapter_state
uc_noise_pred, c_noise_pred = self.invokeai_diffuser.do_unet_step(
sample=latent_model_input,
@ -593,8 +598,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
total_step_count=total_step_count,
conditioning_data=conditioning_data,
# extra:
down_block_additional_residuals=down_block_additional_residuals,
mid_block_additional_residual=mid_block_additional_residual,
down_block_additional_residuals=down_block_additional_residuals, # for ControlNet
mid_block_additional_residual=mid_block_additional_residual, # for ControlNet
down_intrablock_additional_residuals=down_intrablock_additional_residuals, # for T2I-Adapter
)
guidance_scale = conditioning_data.guidance_scale

View File

@ -54,13 +54,13 @@ class Context:
self.clear_requests(cleanup=True)
def register_cross_attention_modules(self, model):
for name, module in get_cross_attention_modules(model, CrossAttentionType.SELF):
for name, _module in get_cross_attention_modules(model, CrossAttentionType.SELF):
if name in self.self_cross_attention_module_identifiers:
assert False, f"name {name} cannot appear more than once"
raise AssertionError(f"name {name} cannot appear more than once")
self.self_cross_attention_module_identifiers.append(name)
for name, module in get_cross_attention_modules(model, CrossAttentionType.TOKENS):
for name, _module in get_cross_attention_modules(model, CrossAttentionType.TOKENS):
if name in self.tokens_cross_attention_module_identifiers:
assert False, f"name {name} cannot appear more than once"
raise AssertionError(f"name {name} cannot appear more than once")
self.tokens_cross_attention_module_identifiers.append(name)
def request_save_attention_maps(self, cross_attention_type: CrossAttentionType):
@ -170,7 +170,7 @@ class Context:
self.saved_cross_attention_maps = {}
def offload_saved_attention_slices_to_cpu(self):
for key, map_dict in self.saved_cross_attention_maps.items():
for _key, map_dict in self.saved_cross_attention_maps.items():
for offset, slice in map_dict["slices"].items():
map_dict[offset] = slice.to("cpu")
@ -433,7 +433,7 @@ def inject_attention_function(unet, context: Context):
module.identifier = identifier
try:
module.set_attention_slice_wrangler(attention_slice_wrangler)
module.set_slicing_strategy_getter(lambda module: context.get_slicing_strategy(identifier))
module.set_slicing_strategy_getter(lambda module: context.get_slicing_strategy(identifier)) # noqa: B023
except AttributeError as e:
if is_attribute_error_about(e, "set_attention_slice_wrangler"):
print(f"TODO: implement set_attention_slice_wrangler for {type(module)}") # TODO
@ -445,7 +445,7 @@ def remove_attention_function(unet):
cross_attention_modules = get_cross_attention_modules(
unet, CrossAttentionType.TOKENS
) + get_cross_attention_modules(unet, CrossAttentionType.SELF)
for identifier, module in cross_attention_modules:
for _identifier, module in cross_attention_modules:
try:
# clear wrangler callback
module.set_attention_slice_wrangler(None)

View File

@ -56,7 +56,7 @@ class AttentionMapSaver:
merged = None
for key, maps in self.collated_maps.items():
for _key, maps in self.collated_maps.items():
# maps has shape [(H*W), N] for N tokens
# but we want [N, H, W]
this_scale_factor = math.sqrt(maps.shape[0] / (latents_width * latents_height))

Some files were not shown because too many files have changed in this diff Show More