Compare commits


152 Commits

Author SHA1 Message Date
3c50448ccf Merge branch 'main' into dev/pytorch2 2023-04-06 21:47:46 -04:00
76bcd4d44f Fix typo (#3133)
'hotdot' to 'hotdog'; the world's least important PR :)
2023-04-07 12:38:05 +12:00
50f5e1bc83 Fix typo
'hotdot' to 'hotdog'; the world's least important PR :)
2023-04-06 16:47:57 -07:00
85b020f76c [nodes] Add latent nodes, storage, and fix iteration bugs (#3091)
* Add latents nodes.
* Fix iteration expansion.
* Add collection generator nodes, math nodes.
* Add noise node.
* Add some graph debug commands to the CLI.
* Fix negative id linking in CLI.
* Fix a CLI bug with multiple links per node.
2023-04-06 04:06:05 +00:00
a7833cc9a9 [api] Add models router and list model API. 2023-04-05 23:59:07 -04:00
919294e977 fix build-container.yml (#3117)
Add permission go write packages to GITHUB_TOKEN
2023-04-06 00:25:00 +02:00
7640acfb1f update build-container.yml
- add packages write permission
2023-04-05 15:44:26 +02:00
aed9ecef2a feat(nodes): add thumbnail generation to DiskImageStorage 2023-04-05 08:22:23 +10:00
18cddd7972 Right link on pytorch installer for linux rocm (#3084)
Right link on pytorch installer for linux rocm
2023-04-04 17:40:42 -04:00
e6b25f4ae3 Merge branch 'main' into patch-1 2023-04-04 17:40:12 -04:00
d1c0050e65 fix(nodes): fix typo in list_sessions handler (#3109)
The typo accidentally did not affect functionality; when `query==""`, it
`search()`ed but found everything due to empty query, then paginated
results, so it worked the same as `list()`.

Still, fix it.
2023-04-03 21:24:48 -04:00
ecdfa136a0 fix(nodes): fix typo in list_sessions handler 2023-04-04 00:34:32 +10:00
5cd513ee63 [deps] bump compel version to fix crash on invalid (auto111) syntax (#3107)
currently if users input e.g. `happy (camper:0.3)` it gets parsed
incorrectly, which causes crashes if it's in the negative prompt. Bumping
to compel 1.0.5 fixes the parser to avoid this (note the weight is
parsed as plain text; it's not converted to proper invoke syntax)
2023-04-04 02:30:17 +12:00
ab45086546 Merge branch 'main' into deps_bump_compel 2023-04-04 02:05:40 +12:00
77ba7359f4 fix(nodes): commit changes to db 2023-04-03 19:09:49 +10:00
8cbe2e14d9 bump compel version to fix crash on invalid (auto111) syntax 2023-04-03 10:37:01 +02:00
ee86eedf01 Right link on pytorch installer for linux rocm
Right link on pytorch installer for linux rocm
2023-03-31 17:22:00 -03:00
c4e6511a59 Add support for yet another TI embedding format (main version) (#3050)
- This PR adds support for embedding files that contain a single key
"emb_params". The only example I know of this format is the
"EasyNegative" embedding on HuggingFace, but there are certainly others.

- This PR also adds support for loading embedding files that have been
saved in safetensors format.

- It also cleans up the code so that the logic of probing for and
selecting the right format parser is clear.

- This is the same as #3045, which is on the 2.3 branch.
2023-03-31 03:57:57 -04:00
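
A rough illustration of the probing logic described above; the function name,
return shape, and key handling here are assumptions, not the project's actual
API:

```python
from pathlib import Path

import torch
from safetensors.torch import load_file


def load_embedding_tensors(path: Path) -> dict:
    """Load a TI embedding file, handling safetensors and pickle formats."""
    if path.suffix == ".safetensors":
        state = load_file(str(path), device="cpu")
    else:
        state = torch.load(path, map_location="cpu")
    # "EasyNegative"-style files store a single tensor under "emb_params".
    if "emb_params" in state:
        return {path.stem: state["emb_params"]}
    return state
```
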
44843be4c8 Merge branch 'main' into enhance/support-another-embedding-format-main 2023-03-30 23:16:52 -04:00
054e963bef add basic autocomplete functionality to node cli (#3035)
- Commands, invocations and their parameters will now autocomplete using
introspection.
- Two types of parameter *arguments* will also autocomplete:
  - --sampler_name  will autocomplete the scheduler name
  - --model will autocomplete the model name
- There don't seem to be commands for reading/writing image files yet,
so path autocompletion is not implemented
2023-03-30 08:25:36 -04:00
afb66a7884 Merge branch 'main' into feat/node-cli-autocompleter 2023-03-30 07:51:51 -04:00
b9df9e26f2 Merge branch 'main' into enhance/support-another-embedding-format-main 2023-03-30 07:51:23 -04:00
25ae36ceb5 I18n build mode (#3051)
Add build mode option to bundle English translation with UI
2023-03-29 22:26:45 -04:00
3ae8daedaa Merge branch 'main' into i18n-build-mode 2023-03-29 22:26:17 -04:00
e11c1d66ab handle multiple tokens and embeddings in single file 2023-03-29 22:05:06 -04:00
b913e1e11e improve importation and conversion of legacy checkpoint files (#3053)
A long-standing issue with importing legacy checkpoints (both ckpt and
safetensors) is that the user has to identify the correct config file,
either by providing its path or by selecting which type of model the
checkpoint is (e.g. "v1 inpainting"). In addition, some users wish to
provide custom VAEs for use with the model. Currently this is done in
the WebUI by importing the model, editing it, and then typing in the
path to the VAE.

## Model configuration file selection

To improve the user experience, the model manager's `heuristic_import()`
method has been enhanced as follows:

1. When initially called, the caller can pass a config file path, in
which case it will be used.

2. If no config file is provided, the method looks for a .yaml file in the
same directory as the model which bears the same basename. e.g.
```
   my-new-model.safetensors
   my-new-model.yaml
```
The yaml file is then used as the configuration file for importation and
conversion.

3. If no such file is found, then the method opens up the checkpoint and
probes it to determine whether it is V1, V1-inpaint or V2. If it is a V1
format, then the appropriate v1-inference.yaml config file is used.
Unfortunately there are two V2 variants that cannot be distinguished by
introspection.

4. If the probe algorithm is unable to determine the model type, then
its last-ditch effort is to execute an optional callback function that
can be provided by the caller. This callback, named
`config_file_callback`, receives the path to the legacy checkpoint and
returns the path to the config file to use. The CLI uses this to put up a
multiple-choice prompt to the user. The WebUI **could** use this to
prompt the user to choose from a radio-button selection.

5. If the config file cannot be determined, then the import is
abandoned.
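
A minimal sketch of how a caller might supply the callback described in step
4. The exact `heuristic_import()` signature and the chosen config path are
assumptions for illustration only:

```python
from pathlib import Path

def choose_config(checkpoint_path: Path) -> Path:
    # A real caller (e.g. the CLI) would put up a multiple-choice prompt;
    # this stand-in simply returns a fixed guess.
    print(f"Could not auto-detect a config for {checkpoint_path}")
    return Path("configs/stable-diffusion/v1-inference.yaml")

# `manager` is assumed to be an already-initialized ModelManager.
manager.heuristic_import(
    "models/my-new-model.safetensors",
    config_file_callback=choose_config,  # consulted only if probing fails
)
```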

## Custom VAE Selection

The user can attach a custom VAE to the imported and converted model by
copying the desired VAE into the same directory as the file to be
imported, and giving it the same basename. E.g.:

```
    my-new-model.safetensors
    my-new-model.vae.pt
```

For this to work, the VAE must end with ".vae.pt", ".vae.ckpt", or
".vae.safetensors". The indicated VAE will be converted into diffusers
format and stored with the converted models file, so the ".pt" file can
be deleted after conversion.

No facility is currently provided to swap a diffusers VAE at import
time, but this can be done after the fact using the WebUI and CLI's
model editing functions.

Note that this is the same fix that was applied to the 2.3 branch in
#3043 . This applies to `main`.
2023-03-29 17:22:15 -04:00
3c4b6d5735 Merge branch 'main' into enhance/heuristic-import-improvements 2023-03-29 16:54:43 -04:00
e6123eac19 Merge branch 'main' into i18n-build-mode 2023-03-29 05:33:14 -07:00
30ca25897e Fix bugs in online ckpt conversion of 2.0 models (#3057)
## Enable the on-the-fly conversion of models based on SD 2.0/2.1 into
diffusers

This commit fixes bugs related to the on-the-fly conversion and loading
of legacy checkpoint models built on SD-2.0 base.

- When legacy checkpoints built on SD-2.0 models were converted
on-the-fly using --ckpt_convert, generation would crash with a precision
incompatibility error. This problem has been found and fixed.
2023-03-28 23:34:53 -04:00
abaee6b9ed Merge branch 'main' into feat/node-cli-autocompleter 2023-03-28 23:32:10 -04:00
4d7c9e1ab7 Merge branch 'main' into bugfix/convert-2.0-models 2023-03-28 23:01:36 -04:00
cc5687f26c [nodes] downgrade fastapi+uvicorn to fix openapi schema 2023-03-28 22:53:20 -04:00
cdb3616dca Merge branch 'main' into enhance/support-another-embedding-format-main 2023-03-28 21:03:06 -04:00
78e76f26f9 Merge branch 'main' into i18n-build-mode 2023-03-28 11:04:32 -04:00
9a7580dedd fix bugs in online ckpt conversion of 2.0 models
This commit fixes bugs related to the on-the-fly conversion and loading of
legacy checkpoint models built on SD-2.0 base.

- When legacy checkpoints built on SD-2.0 models were converted
  on-the-fly using --ckpt_convert, generation would crash with a
  precision incompatibility error.
2023-03-28 00:17:20 -04:00
dc2da8cff4 Doc: updating ROCm version in documentation (#3041)
The PyTorch ROCm version in the documentation is outdated (`rocm5.2`),
which leads to errors during the installation of InvokeAI.

This PR updates the documentation with the latest PyTorch ROCm `5.4.2`
version.
2023-03-27 22:37:43 -04:00
019a9f0329 address change requests in PR
1. Prompt has changed to "invoke> ".
2. Function to initialize the autocompleter has been renamed "set_autocompleter()"
2023-03-27 12:20:24 -04:00
fe5d9ad171 improve importation and conversion of legacy checkpoint files
A long-standing issue with importing legacy checkpoints (both ckpt and
safetensors) is that the user has to identify the correct config file,
either by providing its path or by selecting which type of model the
checkpoint is (e.g. "v1 inpainting"). In addition, some users wish to
provide custom VAEs for use with the model. Currently this is done in
the WebUI by importing the model, editing it, and then typing in the
path to the VAE.

To improve the user experience, the model manager's
`heuristic_import()` method has been enhanced as follows:

1. When initially called, the caller can pass a config file path, in
which case it will be used.

2. If no config file is provided, the method looks for a .yaml file in the
same directory as the model which bears the same basename. e.g.
```
   my-new-model.safetensors
   my-new-model.yaml
```
   The yaml file is then used as the configuration file for
   importation and conversion.

3. If no such file is found, then the method opens up the checkpoint
   and probes it to determine whether it is V1, V1-inpaint or V2.
   If it is a V1 format, then the appropriate v1-inference.yaml config
   file is used. Unfortunately there are two V2 variants that cannot be
   distinguished by introspection.

4. If the probe algorithm is unable to determine the model type, then its
   last-ditch effort is to execute an optional callback function that can
   be provided by the caller. This callback, named `config_file_callback`,
   receives the path to the legacy checkpoint and returns the path to the
   config file to use. The CLI uses this to put up a multiple-choice prompt
   to the user. The WebUI **could** use this to prompt the user to choose
   from a radio-button selection.

5. If the config file cannot be determined, then the import is abandoned.

The user can attach a custom VAE to the imported and converted model
by copying the desired VAE into the same directory as the file to be
imported, and giving it the same basename. E.g.:

```
    my-new-model.safetensors
    my-new-model.vae.pt
```

For this to work, the VAE must end with ".vae.pt", ".vae.ckpt", or
".vae.safetensors". The indicated VAE will be converted into diffusers
format and stored with the converted models file, so the ".pt" file
can be deleted after conversion.

No facility is currently provided to swap a diffusers VAE at import
time, but this can be done after the fact using the WebUI and CLI's
model editing functions.
2023-03-27 11:27:45 -04:00
dbc0093b31 Merge remote-tracking branch 'origin' into i18n-build-mode 2023-03-27 10:57:41 -04:00
92e512b8b6 add package mode option for i18next 2023-03-27 10:49:52 -04:00
abe4dc8ac1 Add support for yet another textual inversion embedding format
- This PR adds support for embedding files that contain a single key
  "emb_params". The only example I know of this format is the
  "EasyNegative" embedding on HuggingFace, but there are certainly
  others.

- This PR also adds support for loading embedding files that have been
  saved in safetensors format.

- It also cleans up the code so that the logic of probing for and
  selecting the right format parser is clear.
2023-03-27 09:39:03 -04:00
dc14701d20 Merge branch 'main' into feat/node-cli-autocompleter 2023-03-26 23:46:10 -04:00
737e0f3085 doc: fixing error in rocm version 2023-03-26 12:40:20 +02:00
81b7ea4362 doc: updating ROCm version for pip install 2023-03-26 12:32:12 +02:00
09dfde0ba1 fix(ui): fix viewer tooltip localisation strings (#3037)
fixes #2923
2023-03-26 20:35:52 +13:00
3ba7e966b5 Merge branch 'main' into fix/ui/viewer-localisation 2023-03-26 20:35:12 +13:00
a1cd4834d1 nodes: add cancelation, updated progress callback, typing fixes (#3036)
keeping `main` up to date with my api nodes branch:
- bd7e515290: [nodes] Add cancelation to
the API @Kyle0654
- 5fe38f7: fix(backend): simple typing fixes
  - just picking some low-hanging fruit to improve IDE hinting
- c34ac91: fix(nodes): fix cancel; fix callback for img2img, inpaint
  - makes node cancelation immediate, fixes progress images on nodes, and
    fixes callbacks for img2img/inpaint
- 4221cf7: fix(nodes): fix schema generation for output classes
  - did this previously for some other class; node outputs needed to not be
    optional
2023-03-26 20:34:27 +13:00
a724038dc6 fix(ui): fix viewer tooltip localisation strings
fixes #2923
2023-03-26 17:43:00 +11:00
4221cf7731 fix(nodes): fix schema generation for output classes
All output classes need to have their properties flagged as `required` for the schema generation to work as needed.
2023-03-26 17:20:10 +11:00
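
For context, a hedged pydantic-v1 illustration of one way to flag an output
class's properties as required in the generated schema; the class and field
names below are invented, not the project's:

```python
from pydantic import BaseModel, Field

class ExampleOutput(BaseModel):  # hypothetical node output class
    type: str = Field(default="example_output")
    image_name: str = Field(default="")

    class Config:
        # List the properties as required in the JSON schema even though
        # they carry defaults on the Python side.
        schema_extra = {"required": ["type", "image_name"]}
```
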
c34ac91ff0 fix(nodes): fix cancel; fix callback for img2img, inpaint 2023-03-26 17:07:40 +11:00
5fe38f7c88 fix(backend): simple typing fixes 2023-03-26 17:07:03 +11:00
bd7e515290 [nodes] Add cancelation to the API 2023-03-26 15:47:32 +11:00
076fac07eb feat[web]: use the predicted denoised image for previews (#2915)
Some schedulers report not only the noisy latents at the current
timestep, but also their estimate so far of what the de-noised latents
will be.

It makes for a more legible preview than the noisy latents do.

I think this is a huge improvement, but there are a few considerations:
- Need to not spook @JPPhoto by changing how previews look.
- Some schedulers (most notably **DPM Solver++**) don't provide this
data, and it falls back to the current behavior there. That's not
terrible, but seeing such a big difference in how _previews_ look from
one scheduler to the next might mislead people into thinking there's a
bigger difference in their overall effectiveness than there really is.

My fear of configuration-option-overwhelm leaves me inclined to _not_
add a configuration option for this, but we could.
2023-03-26 00:29:00 -04:00
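
A sketch of the idea, assuming the diffusers scheduler API (the helper below
and its names are illustrative, not the project's code):

```python
import torch
from diffusers import SchedulerMixin

def step_with_preview(scheduler: SchedulerMixin,
                      noise_pred: torch.Tensor,
                      t: int,
                      latents: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Advance one denoising step and pick the latents to use for a preview."""
    step_output = scheduler.step(noise_pred, t, latents)
    # Prefer the scheduler's estimate of the final, de-noised latents.
    preview = getattr(step_output, "pred_original_sample", None)
    if preview is None:
        # Some schedulers (e.g. DPM Solver++) don't report a prediction;
        # fall back to the noisy latents, matching the previous behavior.
        preview = step_output.prev_sample
    return step_output.prev_sample, preview
```
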
9348161600 add basic autocomplete functionality to node cli
- Commands, invocations and their parameters will now autocomplete
  using introspection.
- Two types of parameter *arguments* will also autocomplete:
  - --sampler_name  will autocomplete the scheduler name
  - --model will autocomplete the model name
- There don't seem to be commands for reading/writing image files yet, so
  path autocompletion is not implemented
2023-03-26 00:24:27 -04:00
dac3c158a5 Merge branch 'main' into feat/preview_predicted_x0
- resolve conflicts with generate.py invocation
- remove unused symbols that pyflakes complains about
- add **untested** code for passing intermediate latent image to the
  step callback in the format expected.
2023-03-25 16:07:18 -04:00
17d8bbf330 ask for escalated privileges in push workflows 2023-03-25 15:22:25 -04:00
9344687a56 installer: fix indentation in invoke.sh template (tabs -> spaces) 2023-03-25 13:57:09 -04:00
cf534d735c duplicate of PR #3016, but based on main 2023-03-25 13:57:09 -04:00
501924bc60 do not reexport PipelineIntermediateState 2023-03-25 13:57:09 -04:00
d117251747 make step_callback work again in generate() call
This PR fixes #2951 and restores the step_callback argument in the
refactored generate() method. Note that this issue states that
"something is still wrong because steps and step are zero." However,
I think this is confusion over the call signature of the callback, which
since the diffusers merge has been `callback(state:PipelineIntermediateState)`

This is the test script that I used to determine that `step` is being passed
correctly:

```

from pathlib import Path
from invokeai.backend import ModelManager, PipelineIntermediateState
from invokeai.backend.globals import global_config_dir
from invokeai.backend.generator import Txt2Img

def my_callback(state:PipelineIntermediateState, total_steps:int):
    print(f'callback(step={state.step}/{total_steps})')

def main():
    manager = ModelManager(Path(global_config_dir()) / "models.yaml")
    model = manager.get_model('stable-diffusion-1.5')
    print ('=== TXT2IMG TEST ===')
    steps=30
    output = next(Txt2Img(model).generate(prompt='banana sushi',
                                          iterations=None,
                                          steps=steps,
                                          step_callback=lambda x: my_callback(x,steps)
                                          )
                  )
    print(f'image={output.image}, seed={output.seed}, steps={output.params.steps}')

if __name__=='__main__':
    main()
```
2023-03-25 13:57:09 -04:00
6ea61a8486 fix issue with embeddings being loaded twice (#3029)
This bug was causing a bunch of annoying warnings about not overwriting
previously loaded tokens.

- as noted by JPPhoto
2023-03-26 04:45:20 +13:00
e4d903af20 Merge branch 'main' into bugfix/load-embeddings-once 2023-03-26 04:19:43 +13:00
2d9797da35 (fix)[docs] Fixed snippet/code formatting (#2918)
It was pasted as plain text, now it's a code fence.
2023-03-25 10:49:13 -04:00
07ea806553 Merge branch 'main' into patch-1 2023-03-25 10:48:25 -04:00
5ac0316c62 fix issue with embeddings being loaded twice
- as noted by JPPhoto
2023-03-25 10:45:03 -04:00
9536ba22af Convert custom VAEs during legacy checkpoint loading (#3010)
- When a legacy checkpoint model is loaded via --convert_ckpt and its
models.yaml stanza refers to a custom VAE path (using the 'vae:' key),
the custom VAE will be converted and used within the diffusers model.
Otherwise the VAE contained within the legacy model will be used.
    
- Note that the checkpoint import functions in the CLI or Web UIs
continue to default to the standard stabilityai/sd-vae-ft-mse VAE. This
can be fixed after the fact by editing the VAE key using either the CLI or
Web UI.
   
- Fixes issue #2917
2023-03-25 00:37:12 -04:00
5503749085 Merge branch 'main' into feat/use-custom-vaes 2023-03-25 17:09:38 +13:00
9bfe2fa371 add github API token to mkdocs workflow (#3023)
The mkdocs-workflow has been failing over the past week due to
permission denied errors. I *think* this is the result of not passing
the GitHub API token to the workflow, and this is a speculative fix for
the issue.
2023-03-24 17:59:53 -04:00
d8ce6e4426 Merge branch 'bugfix/mkdocs-workflow' of github.com:invoke-ai/InvokeAI into bugfix/mkdocs-workflow 2023-03-24 17:58:16 -04:00
43d2d6d98c add blessedcoolant as backup to mauwii codeowner 2023-03-24 17:58:03 -04:00
64c233efd4 Merge branch 'main' into bugfix/mkdocs-workflow 2023-03-24 17:47:14 -04:00
2245a4e117 doc(readme): fix incorrect install command (#3024)
Hi, I am trying to install InvokeAI on my Linux machine; the command in
README.md does not install the correct dependency.
2023-03-24 17:46:58 -04:00
9ceec40b76 Merge branch 'main' into feat/use-custom-vaes 2023-03-24 17:45:02 -04:00
0f13b90059 doc(readme): fix incorrect install command 2023-03-24 23:21:51 +08:00
d91fc16ae4 add github API token to mkdocs workflow 2023-03-24 09:17:30 -04:00
bc01a96f9d re-implement model scanning when loading legacy checkpoint files (#3012)
- This PR turns on pickle scanning before a legacy checkpoint file is
loaded from disk within the checkpoint_to_diffusers module.

- Also miscellaneous diagnostic message cleanup.

- See also #3011 for a similar patch to the 2.3 branch.
2023-03-24 08:57:07 -04:00
85b2822f5e Merge branch 'main' into security/scan-ckpt-files-main 2023-03-24 08:39:59 -04:00
c33d8694bb build: do not run python tests on ui build (#2987)
`invokeai/frontend/web/dist/**` should not be triggering the full test
suite.
2023-03-25 00:54:40 +13:00
685bd027f0 Merge branch 'main' into build/no-test-on-ui-build 2023-03-25 00:51:26 +13:00
f592d620d5 ui: translations update from weblate (#3021)
Translations update from [Hosted Weblate](https://hosted.weblate.org)
for [InvokeAI/Web
UI](https://hosted.weblate.org/projects/invokeai/web-ui/).



Current translation status:

![Weblate translation
status](https://hosted.weblate.org/widgets/invokeai/-/web-ui/horizontal-auto.svg)
2023-03-24 19:25:17 +11:00
Tom
2b127b73ac translationBot(ui): update translation (French)
Currently translated at 82.7% (417 of 504 strings)

Co-authored-by: Tom <tom.fouthier@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/fr/
Translation: InvokeAI/Web UI
2023-03-24 04:49:27 +01:00
8855902cfe translationBot(ui): update translation (Spanish)
Currently translated at 100.0% (504 of 504 strings)

translationBot(ui): update translation (Spanish)

Currently translated at 100.0% (501 of 501 strings)

Co-authored-by: gallegonovato <fran-carro@hotmail.es>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/es/
Translation: InvokeAI/Web UI
2023-03-24 04:49:27 +01:00
9d8ddc6a08 translationBot(ui): update translation files
Updated by "Cleanup translation files" hook in Weblate.

Co-authored-by: Hosted Weblate <hosted@weblate.org>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/
Translation: InvokeAI/Web UI
2023-03-24 04:49:27 +01:00
4ca5189e73 translationBot(ui): update translation (Italian)
Currently translated at 100.0% (504 of 504 strings)

translationBot(ui): update translation (Italian)

Currently translated at 100.0% (501 of 501 strings)

translationBot(ui): update translation (Italian)

Currently translated at 100.0% (500 of 500 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2023-03-24 04:49:27 +01:00
873597cb84 Allow loading all types of dreambooth models - Fix issue #2932 (#2933)
Allows loading models with EMA using `model_ema.diffusion_model.xxxx` or
`model_ema.xxxx` weights.

Fixes #2932
2023-03-23 23:40:04 -04:00
44d742f232 Merge branch 'main' into security/scan-ckpt-files-main 2023-03-23 23:33:49 -04:00
5dec5b6f51 Merge branch 'main' into dev/pytorch2 2023-03-23 23:31:21 -04:00
6e7dbf99f3 Merge branch 'main' into bugfix/dreambooth_ema 2023-03-23 23:24:15 -04:00
1ba1076888 Tidy up Tests and Provide Documentation (#2869)
A bit of basic housekeeping and documentation to explain how to get a local
development environment running (including the tests).
2023-03-23 23:23:20 -04:00
cafa108f69 Merge branch 'main' into tests 2023-03-23 23:22:27 -04:00
deeff36e16 Merge branch 'main' into security/scan-ckpt-files-main 2023-03-23 23:20:52 -04:00
d770b14358 [deps] upgrade compel for better .swap defaults and a bugfix (#3014) 2023-03-23 19:01:12 -04:00
20414ba4ad Merge branch 'main' into deps_upgrade_compel 2023-03-23 18:38:46 -04:00
92721a1d45 do not reexport PipelineIntermediateState 2023-03-24 09:32:47 +11:00
f329fddab9 make step_callback work again in generate() call
This PR fixes #2951 and restores the step_callback argument in the
refactored generate() method. Note that this issue states that
"something is still wrong because steps and step are zero." However,
I think this is confusion over the call signature of the callback, which
since the diffusers merge has been `callback(state:PipelineIntermediateState)`

This is the test script that I used to determine that `step` is being passed
correctly:

```

from pathlib import Path
from invokeai.backend import ModelManager, PipelineIntermediateState
from invokeai.backend.globals import global_config_dir
from invokeai.backend.generator import Txt2Img

def my_callback(state:PipelineIntermediateState, total_steps:int):
    print(f'callback(step={state.step}/{total_steps})')

def main():
    manager = ModelManager(Path(global_config_dir()) / "models.yaml")
    model = manager.get_model('stable-diffusion-1.5')
    print ('=== TXT2IMG TEST ===')
    steps=30
    output = next(Txt2Img(model).generate(prompt='banana sushi',
                                          iterations=None,
                                          steps=steps,
                                          step_callback=lambda x: my_callback(x,steps)
                                          )
                  )
    print(f'image={output.image}, seed={output.seed}, steps={output.params.steps}')

if __name__=='__main__':
    main()
```
2023-03-24 09:32:47 +11:00
f2efde27f6 load embeddings after a ckpt legacy model is converted to diffusers (#3013)
This PR corrects a bug in which embeddings were not being applied when a
non-diffusers model was loaded.

- Fixes #2954
- Also improves diagnostic reporting during embedding loading.
2023-03-23 18:10:19 -04:00
02c58f22be upgrade compel for better .swap defaults and a bugfix 2023-03-23 22:34:54 +01:00
f751dcd245 load embeddings after a ckpt legacy model is converted to diffusers
- Fixes #2954
- Also improves diagnostic reporting during embedding loading.
2023-03-23 15:21:58 -04:00
a97107bd90 handle VAEs that do not have a "state_dict" key 2023-03-23 15:11:29 -04:00
b2ce45a417 re-implement model scanning when loading legacy checkpoint files
- This PR turns on pickle scanning before a legacy checkpoint file
  is loaded from disk within the checkpoint_to_diffusers module.

- Also miscellaneous diagnostic message cleanup.
2023-03-23 15:03:30 -04:00
4e0b5d85ba convert custom VAEs into diffusers
- When a legacy checkpoint model is loaded via --convert_ckpt and its
  models.yaml stanza refers to a custom VAE path (using the 'vae:'
  key), the custom VAE will be converted and used within the diffusers
  model. Otherwise the VAE contained within the legacy model will be
  used.

- Note that the heuristic_import() method, which imports arbitrary
  legacy files on disk and URLs, will continue to default to the
  standard stabilityai/sd-vae-ft-mse VAE. This can be fixed after
  the fact by editing the models.yaml stanza using the Web or CLI
  UIs.

- Fixes issue #2917
2023-03-23 13:14:19 -04:00
a958ae5e29 Merge branch 'main' into feat/use-custom-vaes 2023-03-23 10:32:56 -04:00
4d50fbf8dc Merge branch 'main' into build/no-test-on-ui-build 2023-03-23 01:08:24 +13:00
485f6e5954 Export more for header (#2996)
* export more items needed for dynamic header
* remove build mode that is no longer needed
2023-03-23 01:07:16 +13:00
1f6ce838ba Merge branch 'main' into export-more-for-header 2023-03-22 07:49:15 -04:00
0dc5773849 [nodes] Update fastapi packages to latest (except FastAPI, which has an annotation bug in the newest version) (#3004) 2023-03-22 19:12:45 +13:00
bc347f749c [nodes] Update fastapi packages to latest (except FastAPI, which has an annotation bug in the newest version) 2023-03-21 19:45:17 -07:00
1b215059e7 Merge branch 'main' into export-more-for-header 2023-03-21 16:29:53 -04:00
db079a2733 remove unneeded build:package code 2023-03-21 10:29:27 -04:00
26f71d3536 change back 2023-03-21 10:28:29 -04:00
eb7ae2588c unused var 2023-03-21 10:21:58 -04:00
278c14ba2e try jsx.element 2023-03-21 10:18:38 -04:00
74e83dda54 update type 2023-03-21 10:10:48 -04:00
28c1fca477 Merge branch 'main' into build/no-test-on-ui-build 2023-03-20 02:21:40 +13:00
1f0324102a chore(ui): build 2023-03-19 23:16:29 +11:00
a782ad092d feat(ui): localise iaialertdialog defaults 2023-03-19 23:16:29 +11:00
eae4eb419a fix(ui): popovers trigger on click (accessibility) 2023-03-19 23:16:29 +11:00
fb7f38f46e fix(ui): make alertdialogs centered 2023-03-19 23:16:29 +11:00
93d0cae455 fix(ui): fix alertdialogs closing immediately 2023-03-19 23:16:29 +11:00
35f6b5d562 fix(ui): make invoketabs not lazy 2023-03-19 23:16:29 +11:00
2aefa06ef1 fix(ui): Clean up manual add forms 2023-03-19 23:16:29 +11:00
5906888477 feat(ui): add current image loading fallback 2023-03-19 23:16:29 +11:00
f22c7d0da6 feat(ui): add more w/h options 2023-03-19 23:16:29 +11:00
93b38707b2 feat(ui): tidy up model manager styling
fixes #2970
2023-03-19 23:16:29 +11:00
6ecf53078f fix(ui): Misalignment of model search entries 2023-03-19 23:16:29 +11:00
9c93b7cb59 build: do not run python tests on ui build
`invokeai/frontend/web/dist/**` should not be triggering the full test suite.
2023-03-19 23:01:30 +11:00
7789e8319c Fix some text and a link (#2910)
- Fix link to `070_INSTALL_XFORMERS.md`.
- Fix some spelling.
2023-03-19 05:55:18 +13:00
7d7a28beb3 Merge branch 'main' into main-text-fixup-PR 2023-03-18 09:54:41 -07:00
e158ad8534 deps: upgrade to PyTorch 2.0 (replaces xformers) 2023-03-15 15:45:48 -07:00
c2922d5991 add settingsmodal 2023-03-15 16:12:51 -04:00
85888030c3 more things needed for header 2023-03-15 14:38:22 -04:00
7cf59c1e60 Merge branch 'main' into main-text-fixup-PR 2023-03-16 04:43:22 +13:00
3ca654d256 speculative fix for alternative vaes 2023-03-13 23:27:29 -04:00
0996bd5acf Merge branch 'main' into patch-1 2023-03-14 03:18:58 +13:00
288cee9611 Merge remote-tracking branch 'origin/main' into feat/preview_predicted_x0
# Conflicts:
#	invokeai/app/invocations/generate.py
2023-03-12 20:56:02 -07:00
5c5106c14a Add keys when non EMA 2023-03-12 16:22:22 -05:00
c367b21c71 Fix issue #2932 2023-03-12 15:40:33 -05:00
06aa5a8120 Merge branch 'main' into feat/preview_predicted_x0 2023-03-11 14:50:30 -06:00
f45483e519 Merge branch 'main' into feat/preview_predicted_x0 2023-03-10 22:25:26 -06:00
4156bfd810 Fixed snippet/code formatting
It was pasted as plain text, now it's a code fence.
2023-03-11 02:08:59 +01:00
63f59201f8 Merge branch 'main' into feat/preview_predicted_x0 2023-03-10 12:34:07 -06:00
4a00f1cc74 Merge branch 'main' into feat/preview_predicted_x0 2023-03-10 09:20:01 -06:00
fe6858f2d9 feat: use the predicted denoised image for previews
Some schedulers report not only the noisy latents at the current timestep,
but also their estimate so far of what the de-noised latents will be.

It makes for a more legible preview than the noisy latents do.
2023-03-09 20:28:06 -08:00
26cd1728ac Fix some text and a link 2023-03-09 20:03:11 -06:00
1cd4cdd0e5 Merge branch 'main' into tests 2023-03-08 12:19:04 -05:00
63725d7534 add .pytest.ini to .gitignore 2023-03-07 09:08:27 +00:00
00f30ea457 Merge branch 'main' into tests 2023-03-07 09:03:18 +00:00
5eff035f55 Merge branch 'main' into tests 2023-03-06 08:37:07 -05:00
486c445afb fix typos and replace frontend README content 2023-03-05 21:05:09 +00:00
4547c48013 add docs for local development including tests 2023-03-05 19:59:06 +00:00
c247f430f7 combine pytest.ini with pyproject.toml 2023-03-05 17:00:08 +00:00
3d6a358042 remove .coveragerc from source control 2023-03-05 16:59:12 +00:00
109 changed files with 2514 additions and 1110 deletions

View File

@ -1,6 +0,0 @@
[run]
omit='.env/*'
source='.'
[report]
show_missing = true

.github/CODEOWNERS (8 changed lines)
View File

@ -1,16 +1,16 @@
# continuous integration
/.github/workflows/ @mauwii @lstein
/.github/workflows/ @mauwii @lstein @blessedcoolant
# documentation
/docs/ @lstein @mauwii @tildebyte
/mkdocs.yml @lstein @mauwii
/docs/ @lstein @mauwii @tildebyte @blessedcoolant
/mkdocs.yml @lstein @mauwii @blessedcoolant
# nodes
/invokeai/app/ @Kyle0654 @blessedcoolant
# installation and configuration
/pyproject.toml @mauwii @lstein @blessedcoolant
/docker/ @mauwii @lstein
/docker/ @mauwii @lstein @blessedcoolant
/scripts/ @ebr @lstein
/installer/ @lstein @ebr
/invokeai/assets @lstein @ebr

View File

@ -16,6 +16,10 @@ on:
- 'v*.*.*'
workflow_dispatch:
permissions:
contents: write
packages: write
jobs:
docker:
if: github.event.pull_request.draft == false

View File

@ -5,6 +5,9 @@ on:
- 'main'
- 'development'
permissions:
contents: write
jobs:
mkdocs-material:
if: github.event.pull_request.draft == false

View File

@ -6,7 +6,6 @@ on:
- '!pyproject.toml'
- '!invokeai/**'
- 'invokeai/frontend/web/**'
- '!invokeai/frontend/web/dist/**'
merge_group:
workflow_dispatch:

View File

@ -7,13 +7,11 @@ on:
- 'pyproject.toml'
- 'invokeai/**'
- '!invokeai/frontend/web/**'
- 'invokeai/frontend/web/dist/**'
pull_request:
paths:
- 'pyproject.toml'
- 'invokeai/**'
- '!invokeai/frontend/web/**'
- 'invokeai/frontend/web/dist/**'
types:
- 'ready_for_review'
- 'opened'

.gitignore (2 changed lines)
View File

@ -63,6 +63,7 @@ pip-delete-this-directory.txt
htmlcov/
.tox/
.nox/
.coveragerc
.coverage
.coverage.*
.cache
@ -73,6 +74,7 @@ cov.xml
*.py,cover
.hypothesis/
.pytest_cache/
.pytest.ini
cover/
junit/

View File

@ -1,5 +0,0 @@
[pytest]
DJANGO_SETTINGS_MODULE = webtas.settings
; python_files = tests.py test_*.py *_tests.py
addopts = --cov=. --cov-config=.coveragerc --cov-report xml:cov.xml

View File

@ -139,13 +139,13 @@ not supported.
_For Windows/Linux with an NVIDIA GPU:_
```terminal
pip install InvokeAI[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117
pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117
```
_For Linux with an AMD GPU:_
```sh
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.2
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
```
_For Macintoshes, either Intel or M1/M2:_

coverage/.gitignore (new file, 4 lines)
View File

@ -0,0 +1,4 @@
# Ignore everything in this directory
*
# Except this file
!.gitignore

Two new binary image files added (not shown): 470 KiB and 457 KiB.

View File

@ -0,0 +1,83 @@
# Local Development
If you are looking to contribute you will need to have a local development
environment. See the
[Developer Install](../installation/020_INSTALL_MANUAL.md#developer-install) for
full details.
Broadly, this involves cloning the repository, installing the prerequisites, and
installing InvokeAI itself (in editable form). Assuming this is working, choose
your area of focus.
## Documentation
We use [mkdocs](https://www.mkdocs.org) for our documentation with the
[material theme](https://squidfunk.github.io/mkdocs-material/). Documentation is
written in markdown files under the `./docs` folder and then built into a static
website for hosting with GitHub Pages at
[invoke-ai.github.io/InvokeAI](https://invoke-ai.github.io/InvokeAI).
To contribute to the documentation you'll need to install the dependencies. Note
the use of `"`.
```zsh
pip install ".[docs]"
```
Now run the documentation locally, with hot reloading of changes:
```zsh
mkdocs serve
```
You can then browse to `http://127.0.0.1:8080` to view the documentation.
## Backend
The backend is contained within the `./invokeai/backend` folder structure. To
get started, however, please install the development dependencies.
From the root of the repository run the following command. Note the use of `"`.
```zsh
pip install ".[test]"
```
This is an optional group of packages defined within `pyproject.toml` and is
required for testing the changes you make to the code.
### Running Tests
We use [pytest](https://docs.pytest.org/en/7.2.x/) for our test suite. Tests can
be found under the `./tests` folder and can be run with a single `pytest`
command. Optionally, to review test coverage you can append `--cov`.
```zsh
pytest --cov
```
Test outcomes and coverage will be reported in the terminal. In addition, a more
detailed report is created in both XML and HTML format in the `./coverage`
folder. The HTML report in particular can help identify missing statements
that require tests to ensure coverage. It can be viewed by opening
`./coverage/html/index.html`.
For example:
```zsh
pytest --cov; open ./coverage/html/index.html
```
??? info "HTML coverage report output"
![html-overview](../assets/contributing/html-overview.png)
![html-detail](../assets/contributing/html-detail.png)
## Front End
<!--#TODO: get input from blessedcoolant here, for the moment inserted the frontend README via snippets extension.-->
--8<-- "invokeai/frontend/web/README.md"

View File

@ -168,11 +168,15 @@ used by Stable Diffusion 1.4 and 1.5.
After installation, your `models.yaml` should contain an entry that looks like
this one:
inpainting-1.5: weights: models/ldm/stable-diffusion-v1/sd-v1-5-inpainting.ckpt
description: SD inpainting v1.5 config:
configs/stable-diffusion/v1-inpainting-inference.yaml vae:
models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt width: 512
height: 512
```yml
inpainting-1.5:
weights: models/ldm/stable-diffusion-v1/sd-v1-5-inpainting.ckpt
description: SD inpainting v1.5
config: configs/stable-diffusion/v1-inpainting-inference.yaml
vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
width: 512
height: 512
```
As shown in the example, you may include a VAE fine-tuning weights file as well.
This is strongly recommended.

View File

@ -268,7 +268,7 @@ model is so good at inpainting, a good substitute is to use the `clipseg` text
masking option:
```bash
invoke> a fluffy cat eating a hotdot
invoke> a fluffy cat eating a hotdog
Outputs:
[1010] outputs/000025.2182095108.png: a fluffy cat eating a hotdog
invoke> a smiling dog eating a hotdog -I 000025.2182095108.png -tm cat

View File

@ -17,7 +17,7 @@ notebooks.
You will need a GPU to perform training in a reasonable length of
time, and at least 12 GB of VRAM. We recommend using the [`xformers`
library](../installation/070_INSTALL_XFORMERS) to accelerate the
library](../installation/070_INSTALL_XFORMERS.md) to accelerate the
training process further. During training, about ~8 GB is temporarily
needed in order to store intermediate models, checkpoints and logs.

View File

@ -417,7 +417,7 @@ Then type the following commands:
=== "AMD System"
```bash
pip install torch torchvision --force-reinstall --extra-index-url https://download.pytorch.org/whl/rocm5.2
pip install torch torchvision --force-reinstall --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
```
### Corrupted configuration file

View File

@ -154,7 +154,7 @@ manager, please follow these steps:
=== "ROCm (AMD)"
```bash
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.2
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
```
=== "CPU (Intel Macs & non-GPU systems)"
@ -315,7 +315,7 @@ installation protocol (important!)
=== "ROCm (AMD)"
```bash
pip install -e . --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.2
pip install -e . --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
```
=== "CPU (Intel Macs & non-GPU systems)"

View File

@ -110,7 +110,7 @@ recipes are available
When installing torch and torchvision manually with `pip`, remember to provide
the argument `--extra-index-url
https://download.pytorch.org/whl/rocm5.2` as described in the [Manual
https://download.pytorch.org/whl/rocm5.4.2` as described in the [Manual
Installation Guide](020_INSTALL_MANUAL.md).
This will be done automatically for you if you use the installer

View File

@ -24,7 +24,7 @@ You need to have opencv installed so that pypatchmatch can be built:
brew install opencv
```
The next time you start `invoke`, after sucesfully installing opencv, pypatchmatch will be built.
The next time you start `invoke`, after successfully installing opencv, pypatchmatch will be built.
## Linux
@ -56,7 +56,7 @@ Prior to installing PyPatchMatch, you need to take the following steps:
5. Confirm that pypatchmatch is installed. At the command-line prompt enter
`python`, and then at the `>>>` line type
`from patchmatch import patch_match`: It should look like the follwing:
`from patchmatch import patch_match`: It should look like the following:
```py
Python 3.9.5 (default, Nov 23 2021, 15:27:38)
@ -108,4 +108,4 @@ Prior to installing PyPatchMatch, you need to take the following steps:
[**Next, Follow Steps 4-6 from the Debian Section above**](#linux)
If you see no errors, then you're ready to go!
If you see no errors you're ready to go!

View File

@ -456,13 +456,12 @@ def get_torch_source() -> (Union[str, None],str):
optional_modules = None
if OS == "Linux":
if device == "rocm":
url = "https://download.pytorch.org/whl/rocm5.2"
url = "https://download.pytorch.org/whl/rocm5.4.2"
elif device == "cpu":
url = "https://download.pytorch.org/whl/cpu"
if device == 'cuda':
url = 'https://download.pytorch.org/whl/cu117'
optional_modules = '[xformers]'
url = 'https://download.pytorch.org/whl/cu118'
# in all other cases, Torch wheels should be coming from PyPi as of Torch 1.13

View File

@ -24,9 +24,9 @@ if [ "$(uname -s)" == "Darwin" ]; then
export PYTORCH_ENABLE_MPS_FALLBACK=1
fi
while true
do
if [ "$0" != "bash" ]; then
while true
do
echo "Do you want to generate images using the"
echo "1. command-line interface"
echo "2. browser-based UI"
@ -67,29 +67,29 @@ if [ "$0" != "bash" ]; then
;;
7)
invokeai-configure --root ${INVOKEAI_ROOT} --yes --default_only
;;
8)
echo "Developer Console:"
;;
8)
echo "Developer Console:"
file_name=$(basename "${BASH_SOURCE[0]}")
bash --init-file "$file_name"
;;
9)
echo "Update:"
echo "Update:"
invokeai-update
;;
10)
invokeai --help
;;
[qQ])
[qQ])
exit 0
;;
*)
echo "Invalid selection"
exit;;
esac
done
else # in developer console
python --version
echo "Press ^D to exit"
export PS1="(InvokeAI) \u@\h \w> "
fi
done

View File

@ -3,6 +3,8 @@
import os
from argparse import Namespace
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
from ...backend import Globals
from ..services.model_manager_initializer import get_model_manager
from ..services.restoration_services import RestorationServices
@ -54,7 +56,9 @@ class ApiDependencies:
os.path.join(os.path.dirname(__file__), "../../../../outputs")
)
images = DiskImageStorage(output_folder)
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents'))
images = DiskImageStorage(f'{output_folder}/images')
# TODO: build a file/path manager?
db_location = os.path.join(output_folder, "invokeai.db")
@ -62,6 +66,7 @@ class ApiDependencies:
services = InvocationServices(
model_manager=get_model_manager(config),
events=events,
latents=latents,
images=images,
queue=MemoryInvocationQueue(),
graph_execution_manager=SqliteItemStorage[GraphExecutionState](

View File

@ -23,6 +23,16 @@ async def get_image(
filename = ApiDependencies.invoker.services.images.get_path(image_type, image_name)
return FileResponse(filename)
@images_router.get("/{image_type}/thumbnails/{image_name}", operation_id="get_thumbnail")
async def get_thumbnail(
image_type: ImageType = Path(description="The type of image to get"),
image_name: str = Path(description="The name of the image to get"),
):
"""Gets a thumbnail"""
# TODO: This is not really secure at all. At least make sure only output results are served
filename = ApiDependencies.invoker.services.images.get_path(image_type, 'thumbnails/' + image_name)
return FileResponse(filename)
@images_router.post(
"/uploads/",

View File

@ -0,0 +1,279 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from typing import Annotated, Any, List, Literal, Optional, Union
from fastapi.routing import APIRouter
from pydantic import BaseModel, Field, parse_obj_as
from ..dependencies import ApiDependencies
models_router = APIRouter(prefix="/v1/models", tags=["models"])
class VaeRepo(BaseModel):
repo_id: str = Field(description="The repo ID to use for this VAE")
path: Optional[str] = Field(description="The path to the VAE")
subfolder: Optional[str] = Field(description="The subfolder to use for this VAE")
class ModelInfo(BaseModel):
description: Optional[str] = Field(description="A description of the model")
class CkptModelInfo(ModelInfo):
format: Literal['ckpt'] = 'ckpt'
config: str = Field(description="The path to the model config")
weights: str = Field(description="The path to the model weights")
vae: str = Field(description="The path to the model VAE")
width: Optional[int] = Field(description="The width of the model")
height: Optional[int] = Field(description="The height of the model")
class DiffusersModelInfo(ModelInfo):
format: Literal['diffusers'] = 'diffusers'
vae: Optional[VaeRepo] = Field(description="The VAE repo to use for this model")
repo_id: Optional[str] = Field(description="The repo ID to use for this model")
path: Optional[str] = Field(description="The path to the model")
class ModelsList(BaseModel):
models: dict[str, Annotated[Union[(CkptModelInfo,DiffusersModelInfo)], Field(discriminator="format")]]
@models_router.get(
"/",
operation_id="list_models",
responses={200: {"model": ModelsList }},
)
async def list_models() -> ModelsList:
"""Gets a list of models"""
models_raw = ApiDependencies.invoker.services.model_manager.list_models()
models = parse_obj_as(ModelsList, { "models": models_raw })
return models
# @socketio.on("requestSystemConfig")
# def handle_request_capabilities():
# print(">> System config requested")
# config = self.get_system_config()
# config["model_list"] = self.generate.model_manager.list_models()
# config["infill_methods"] = infill_methods()
# socketio.emit("systemConfig", config)
# @socketio.on("searchForModels")
# def handle_search_models(search_folder: str):
# try:
# if not search_folder:
# socketio.emit(
# "foundModels",
# {"search_folder": None, "found_models": None},
# )
# else:
# (
# search_folder,
# found_models,
# ) = self.generate.model_manager.search_models(search_folder)
# socketio.emit(
# "foundModels",
# {"search_folder": search_folder, "found_models": found_models},
# )
# except Exception as e:
# self.handle_exceptions(e)
# print("\n")
# @socketio.on("addNewModel")
# def handle_add_model(new_model_config: dict):
# try:
# model_name = new_model_config["name"]
# del new_model_config["name"]
# model_attributes = new_model_config
# if len(model_attributes["vae"]) == 0:
# del model_attributes["vae"]
# update = False
# current_model_list = self.generate.model_manager.list_models()
# if model_name in current_model_list:
# update = True
# print(f">> Adding New Model: {model_name}")
# self.generate.model_manager.add_model(
# model_name=model_name,
# model_attributes=model_attributes,
# clobber=True,
# )
# self.generate.model_manager.commit(opt.conf)
# new_model_list = self.generate.model_manager.list_models()
# socketio.emit(
# "newModelAdded",
# {
# "new_model_name": model_name,
# "model_list": new_model_list,
# "update": update,
# },
# )
# print(f">> New Model Added: {model_name}")
# except Exception as e:
# self.handle_exceptions(e)
# @socketio.on("deleteModel")
# def handle_delete_model(model_name: str):
# try:
# print(f">> Deleting Model: {model_name}")
# self.generate.model_manager.del_model(model_name)
# self.generate.model_manager.commit(opt.conf)
# updated_model_list = self.generate.model_manager.list_models()
# socketio.emit(
# "modelDeleted",
# {
# "deleted_model_name": model_name,
# "model_list": updated_model_list,
# },
# )
# print(f">> Model Deleted: {model_name}")
# except Exception as e:
# self.handle_exceptions(e)
# @socketio.on("requestModelChange")
# def handle_set_model(model_name: str):
# try:
# print(f">> Model change requested: {model_name}")
# model = self.generate.set_model(model_name)
# model_list = self.generate.model_manager.list_models()
# if model is None:
# socketio.emit(
# "modelChangeFailed",
# {"model_name": model_name, "model_list": model_list},
# )
# else:
# socketio.emit(
# "modelChanged",
# {"model_name": model_name, "model_list": model_list},
# )
# except Exception as e:
# self.handle_exceptions(e)
# @socketio.on("convertToDiffusers")
# def convert_to_diffusers(model_to_convert: dict):
# try:
# if model_info := self.generate.model_manager.model_info(
# model_name=model_to_convert["model_name"]
# ):
# if "weights" in model_info:
# ckpt_path = Path(model_info["weights"])
# original_config_file = Path(model_info["config"])
# model_name = model_to_convert["model_name"]
# model_description = model_info["description"]
# else:
# self.socketio.emit(
# "error", {"message": "Model is not a valid checkpoint file"}
# )
# else:
# self.socketio.emit(
# "error", {"message": "Could not retrieve model info."}
# )
# if not ckpt_path.is_absolute():
# ckpt_path = Path(Globals.root, ckpt_path)
# if original_config_file and not original_config_file.is_absolute():
# original_config_file = Path(Globals.root, original_config_file)
# diffusers_path = Path(
# ckpt_path.parent.absolute(), f"{model_name}_diffusers"
# )
# if model_to_convert["save_location"] == "root":
# diffusers_path = Path(
# global_converted_ckpts_dir(), f"{model_name}_diffusers"
# )
# if (
# model_to_convert["save_location"] == "custom"
# and model_to_convert["custom_location"] is not None
# ):
# diffusers_path = Path(
# model_to_convert["custom_location"], f"{model_name}_diffusers"
# )
# if diffusers_path.exists():
# shutil.rmtree(diffusers_path)
# self.generate.model_manager.convert_and_import(
# ckpt_path,
# diffusers_path,
# model_name=model_name,
# model_description=model_description,
# vae=None,
# original_config_file=original_config_file,
# commit_to_conf=opt.conf,
# )
# new_model_list = self.generate.model_manager.list_models()
# socketio.emit(
# "modelConverted",
# {
# "new_model_name": model_name,
# "model_list": new_model_list,
# "update": True,
# },
# )
# print(f">> Model Converted: {model_name}")
# except Exception as e:
# self.handle_exceptions(e)
# @socketio.on("mergeDiffusersModels")
# def merge_diffusers_models(model_merge_info: dict):
# try:
# models_to_merge = model_merge_info["models_to_merge"]
# model_ids_or_paths = [
# self.generate.model_manager.model_name_or_path(x)
# for x in models_to_merge
# ]
# merged_pipe = merge_diffusion_models(
# model_ids_or_paths,
# model_merge_info["alpha"],
# model_merge_info["interp"],
# model_merge_info["force"],
# )
# dump_path = global_models_dir() / "merged_models"
# if model_merge_info["model_merge_save_path"] is not None:
# dump_path = Path(model_merge_info["model_merge_save_path"])
# os.makedirs(dump_path, exist_ok=True)
# dump_path = dump_path / model_merge_info["merged_model_name"]
# merged_pipe.save_pretrained(dump_path, safe_serialization=1)
# merged_model_config = dict(
# model_name=model_merge_info["merged_model_name"],
# description=f'Merge of models {", ".join(models_to_merge)}',
# commit_to_conf=opt.conf,
# )
# if vae := self.generate.model_manager.config[models_to_merge[0]].get(
# "vae", None
# ):
# print(f">> Using configured VAE assigned to {models_to_merge[0]}")
# merged_model_config.update(vae=vae)
# self.generate.model_manager.import_diffuser_model(
# dump_path, **merged_model_config
# )
# new_model_list = self.generate.model_manager.list_models()
# socketio.emit(
# "modelsMerged",
# {
# "merged_models": models_to_merge,
# "merged_model_name": model_merge_info["merged_model_name"],
# "model_list": new_model_list,
# "update": True,
# },
# )
# print(f">> Models Merged: {models_to_merge}")
# print(f">> New Model Added: {model_merge_info['merged_model_name']}")
# except Exception as e:
# self.handle_exceptions(e)

View File

@ -51,7 +51,7 @@ async def list_sessions(
query: str = Query(default="", description="The query string to search for"),
) -> PaginatedResults[GraphExecutionState]:
"""Gets a list of sessions, optionally searching"""
if filter == "":
if query == "":
result = ApiDependencies.invoker.services.graph_execution_manager.list(
page, per_page
)

View File

@ -14,7 +14,7 @@ from pydantic.schema import schema
from ..backend import Args
from .api.dependencies import ApiDependencies
from .api.routers import images, sessions
from .api.routers import images, sessions, models
from .api.sockets import SocketIO
from .invocations import *
from .invocations.baseinvocation import BaseInvocation
@ -76,6 +76,8 @@ app.include_router(sessions.session_router, prefix="/api")
app.include_router(images.images_router, prefix="/api")
app.include_router(models.models_router, prefix="/api")
# Build a custom OpenAPI to include all outputs
# TODO: can outputs be included on metadata of invocation schemas somehow?

View File

@ -4,7 +4,8 @@ from abc import ABC, abstractmethod
import argparse
from typing import Any, Callable, Iterable, Literal, get_args, get_origin, get_type_hints
from pydantic import BaseModel, Field
import networkx as nx
import matplotlib.pyplot as plt
from ..invocations.image import ImageField
from ..services.graph import GraphExecutionState
from ..services.invoker import Invoker
@ -46,7 +47,7 @@ def add_parsers(
f"--{name}",
dest=name,
type=field_type,
default=field.default,
default=field.default if field.default_factory is None else field.default_factory(),
choices=allowed_values,
help=field.field_info.description,
)
@ -55,7 +56,7 @@ def add_parsers(
f"--{name}",
dest=name,
type=field.type_,
default=field.default,
default=field.default if field.default_factory is None else field.default_factory(),
help=field.field_info.description,
)
@ -200,3 +201,39 @@ class SetDefaultCommand(BaseCommand):
del context.defaults[self.field]
else:
context.defaults[self.field] = self.value
class DrawGraphCommand(BaseCommand):
"""Debugs a graph"""
type: Literal['draw_graph'] = 'draw_graph'
def run(self, context: CliContext) -> None:
session: GraphExecutionState = context.invoker.services.graph_execution_manager.get(context.session.id)
nxgraph = session.graph.nx_graph_flat()
# Draw the networkx graph
plt.figure(figsize=(20, 20))
pos = nx.spectral_layout(nxgraph)
nx.draw_networkx_nodes(nxgraph, pos, node_size=1000)
nx.draw_networkx_edges(nxgraph, pos, width=2)
nx.draw_networkx_labels(nxgraph, pos, font_size=20, font_family="sans-serif")
plt.axis("off")
plt.show()
class DrawExecutionGraphCommand(BaseCommand):
"""Debugs an execution graph"""
type: Literal['draw_xgraph'] = 'draw_xgraph'
def run(self, context: CliContext) -> None:
session: GraphExecutionState = context.invoker.services.graph_execution_manager.get(context.session.id)
nxgraph = session.execution_graph.nx_graph_flat()
# Draw the networkx graph
plt.figure(figsize=(20, 20))
pos = nx.spectral_layout(nxgraph)
nx.draw_networkx_nodes(nxgraph, pos, node_size=1000)
nx.draw_networkx_edges(nxgraph, pos, width=2)
nx.draw_networkx_labels(nxgraph, pos, font_size=20, font_family="sans-serif")
plt.axis("off")
plt.show()

View File

@ -0,0 +1,167 @@
"""
Readline helper functions for cli_app.py
You may import the global singleton `completer` to get access to the
completer object.
"""
import atexit
import readline
import shlex
from pathlib import Path
from typing import List, Dict, Literal, get_args, get_type_hints, get_origin
from ...backend import ModelManager, Globals
from ..invocations.baseinvocation import BaseInvocation
from .commands import BaseCommand
# singleton object, class variable
completer = None
class Completer(object):
def __init__(self, model_manager: ModelManager):
self.commands = self.get_commands()
self.matches = None
self.linebuffer = None
self.manager = model_manager
return
def complete(self, text, state):
"""
Complete commands and switches from the node CLI command line.
Switches are determined in a context-specific manner.
"""
buffer = readline.get_line_buffer()
if state == 0:
options = None
try:
current_command, current_switch = self.get_current_command(buffer)
options = self.get_command_options(current_command, current_switch)
except IndexError:
pass
options = options or list(self.parse_commands().keys())
if not text: # first time
self.matches = options
else:
self.matches = [s for s in options if s and s.startswith(text)]
try:
match = self.matches[state]
except IndexError:
match = None
return match
@classmethod
def get_commands(self)->List[object]:
"""
Return a list of all the client commands and invocations.
"""
return BaseCommand.get_commands() + BaseInvocation.get_invocations()
def get_current_command(self, buffer: str)->tuple[str, str]:
"""
Parse the readline buffer to find the most recent command and its switch.
"""
if len(buffer)==0:
return None, None
tokens = shlex.split(buffer)
command = None
switch = None
for t in tokens:
if t[0].isalpha():
if switch is None:
command = t
else:
switch = t
# don't try to autocomplete switches that are already complete
if switch and buffer.endswith(' '):
switch=None
return command or '', switch or ''
def parse_commands(self)->Dict[str, List[str]]:
"""
Return a dict in which the keys are the command name
and the values are the parameters the command takes.
"""
result = dict()
for command in self.commands:
hints = get_type_hints(command)
name = get_args(hints['type'])[0]
result.update({name:hints})
return result
def get_command_options(self, command: str, switch: str)->List[str]:
"""
Return all the parameters that can be passed to the command as
command-line switches. Returns None if the command is unrecognized.
"""
parsed_commands = self.parse_commands()
if command not in parsed_commands:
return None
# handle switches in the format "-foo=bar"
argument = None
if switch and '=' in switch:
switch, argument = switch.split('=')
parameter = switch.strip('-')
if parameter in parsed_commands[command]:
if argument is None:
return self.get_parameter_options(parameter, parsed_commands[command][parameter])
else:
return [f"--{parameter}={x}" for x in self.get_parameter_options(parameter, parsed_commands[command][parameter])]
else:
return [f"--{x}" for x in parsed_commands[command].keys()]
def get_parameter_options(self, parameter: str, typehint)->List[str]:
"""
Given a parameter name and its type hint (such as a Literal), return the available completions.
"""
if get_origin(typehint) == Literal:
return get_args(typehint)
if parameter == 'model':
return self.manager.model_names()
def _pre_input_hook(self):
if self.linebuffer:
readline.insert_text(self.linebuffer)
readline.redisplay()
self.linebuffer = None
def set_autocompleter(model_manager: ModelManager) -> Completer:
global completer
if completer:
return completer
completer = Completer(model_manager)
readline.set_completer(completer.complete)
# pyreadline3 does not have a set_auto_history() method
try:
readline.set_auto_history(True)
except:
pass
readline.set_pre_input_hook(completer._pre_input_hook)
readline.set_completer_delims(" ")
readline.parse_and_bind("tab: complete")
readline.parse_and_bind("set print-completions-horizontally off")
readline.parse_and_bind("set page-completions on")
readline.parse_and_bind("set skip-completed-text on")
readline.parse_and_bind("set show-all-if-ambiguous on")
histfile = Path(Globals.root, ".invoke_history")
try:
readline.read_history_file(histfile)
readline.set_history_length(1000)
except FileNotFoundError:
pass
except OSError: # file likely corrupted
newname = f"{histfile}.old"
print(
f"## Your history file {histfile} couldn't be loaded and may be corrupted. Renaming it to {newname}"
)
histfile.replace(Path(newname))
atexit.register(readline.write_history_file, histfile)
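The module above exposes set_autocompleter() as its single entry point. A minimal sketch of wiring it up outside cli_app.py, assuming the module is importable as invokeai.app.cli.completer and that a ModelManager instance is already available:

from invokeai.app.cli.completer import set_autocompleter

completer = set_autocompleter(model_manager)  # installs the readline hooks shown above
# From here on, <TAB> at the "invoke> " prompt completes command names, --switches,
# --sampler_name scheduler names and --model model names.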

View File

@ -2,6 +2,7 @@
import argparse
import os
import re
import shlex
import time
from typing import (
@ -12,14 +13,17 @@ from typing import (
from pydantic import BaseModel
from pydantic.fields import Field
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
from ..backend import Args
from .cli.commands import BaseCommand, CliContext, ExitCli, add_parsers, get_graph_execution_history
from .cli.completer import set_autocompleter
from .invocations import *
from .invocations.baseinvocation import BaseInvocation
from .services.events import EventServiceBase
from .services.model_manager_initializer import get_model_manager
from .services.restoration_services import RestorationServices
from .services.graph import Edge, EdgeConnection, GraphExecutionState
from .services.graph import Edge, EdgeConnection, GraphExecutionState, are_connection_types_compatible
from .services.image_storage import DiskImageStorage
from .services.invocation_queue import MemoryInvocationQueue
from .services.invocation_services import InvocationServices
@ -43,7 +47,7 @@ def add_invocation_args(command_parser):
"-l",
action="append",
nargs=3,
help="A link in the format 'dest_field source_node source_field'. source_node can be relative to history (e.g. -1)",
help="A link in the format 'source_node source_field dest_field'. source_node can be relative to history (e.g. -1)",
)
command_parser.add_argument(
@ -93,6 +97,9 @@ def generate_matching_edges(
invalid_fields = set(["type", "id"])
matching_fields = matching_fields.difference(invalid_fields)
# Validate types
matching_fields = [f for f in matching_fields if are_connection_types_compatible(afields[f], bfields[f])]
edges = [
Edge(
source=EdgeConnection(node_id=a.id, field=field),
@ -130,6 +137,12 @@ def invoke_cli():
config.parse_args()
model_manager = get_model_manager(config)
# This initializes the autocompleter and returns it.
# Currently nothing is done with the returned Completer
# object, but the object can be used to change autocompletion
# behavior on the fly, if desired.
completer = set_autocompleter(model_manager)
events = EventServiceBase()
output_folder = os.path.abspath(
@ -142,7 +155,8 @@ def invoke_cli():
services = InvocationServices(
model_manager=model_manager,
events=events,
images=DiskImageStorage(output_folder),
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')),
images=DiskImageStorage(f'{output_folder}/images'),
queue=MemoryInvocationQueue(),
graph_execution_manager=SqliteItemStorage[GraphExecutionState](
filename=db_location, table_name="graph_executions"
@ -155,6 +169,8 @@ def invoke_cli():
session: GraphExecutionState = invoker.create_execution_state()
parser = get_command_parser()
re_negid = re.compile('^-[0-9]+$')
# Uncomment to print out previous sessions at startup
# print(services.session_manager.list())
@ -162,8 +178,8 @@ def invoke_cli():
while True:
try:
cmd_input = input("> ")
except KeyboardInterrupt:
cmd_input = input("invoke> ")
except (KeyboardInterrupt, EOFError):
# Ctrl-c exits
break
@ -220,7 +236,11 @@ def invoke_cli():
# Parse provided links
if "link_node" in args and args["link_node"]:
for link in args["link_node"]:
link_node = context.session.graph.get_node(link)
node_id = link
if re_negid.match(node_id):
node_id = str(current_id + int(node_id))
link_node = context.session.graph.get_node(node_id)
matching_edges = generate_matching_edges(
link_node, command.command
)
@ -230,10 +250,15 @@ def invoke_cli():
if "link" in args and args["link"]:
for link in args["link"]:
edges = [e for e in edges if e.destination.node_id != command.command.id and e.destination.field != link[2]]
edges = [e for e in edges if e.destination.node_id != command.command.id or e.destination.field != link[2]]
node_id = link[0]
if re_negid.match(node_id):
node_id = str(current_id + int(node_id))
edges.append(
Edge(
source=EdgeConnection(node_id=link[1], field=link[0]),
source=EdgeConnection(node_id=node_id, field=link[1]),
destination=EdgeConnection(
node_id=command.command.id, field=link[2]
)
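Taken together, these hunks reorder the --link arguments to 'source_node source_field dest_field' and allow source_node to be a negative offset into the command history. A hypothetical session chaining the latent nodes added in this PR (the --link long form of the -l switch and the model name are assumptions):

invoke> noise --seed 42 --width 512 --height 512
invoke> t2l --prompt "a photo of a dog" --steps 20 --link -1 noise noise
invoke> l2i --model stable-diffusion-1.5 --link -1 latents latents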

View File

@ -0,0 +1,50 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from typing import Literal
import cv2 as cv
import numpy as np
import numpy.random
from PIL import Image, ImageOps
from pydantic import Field
from ..services.image_storage import ImageType
from .baseinvocation import BaseInvocation, InvocationContext, BaseInvocationOutput
from .image import ImageField, ImageOutput
class IntCollectionOutput(BaseInvocationOutput):
"""A collection of integers"""
type: Literal["int_collection"] = "int_collection"
# Outputs
collection: list[int] = Field(default=[], description="The int collection")
class RangeInvocation(BaseInvocation):
"""Creates a range"""
type: Literal["range"] = "range"
# Inputs
start: int = Field(default=0, description="The start of the range")
stop: int = Field(default=10, description="The stop of the range")
step: int = Field(default=1, description="The step of the range")
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
return IntCollectionOutput(collection=list(range(self.start, self.stop, self.step)))
class RandomRangeInvocation(BaseInvocation):
"""Creates a collection of random numbers"""
type: Literal["random_range"] = "random_range"
# Inputs
low: int = Field(default=0, description="The inclusive low value")
high: int = Field(default=np.iinfo(np.int32).max, description="The exclusive high value")
size: int = Field(default=1, description="The number of values to generate")
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
return IntCollectionOutput(collection=list(numpy.random.randint(self.low, self.high, size=self.size)))

View File

@ -1,22 +1,19 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from datetime import datetime, timezone
from typing import Any, Literal, Optional, Union
from functools import partial
from typing import Literal, Optional, Union
import numpy as np
from torch import Tensor
from PIL import Image
from pydantic import Field
from skimage.exposure.histogram_matching import match_histograms
from ..services.image_storage import ImageType
from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator, Generator
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.util.util import image_to_dataURL
from ..util.util import diffusers_step_callback_adapter, CanceledException
SAMPLER_NAME_VALUES = Literal[
tuple(InvokeAIGenerator.schedulers())
@ -45,32 +42,26 @@ class TextToImageInvocation(BaseInvocation):
# TODO: pass this an emitter method or something? or a session for dispatching?
def dispatch_progress(
self, context: InvocationContext, sample: Tensor, step: int
) -> None:
# TODO: only output a preview image when requested
image = Generator.sample_to_lowres_estimated_image(sample)
self, context: InvocationContext, intermediate_state: PipelineIntermediateState
) -> None:
if (context.services.queue.is_canceled(context.graph_execution_state_id)):
raise CanceledException
(width, height) = image.size
width *= 8
height *= 8
dataURL = image_to_dataURL(image, image_format="JPEG")
context.services.events.emit_generator_progress(
context.graph_execution_state_id,
self.id,
{
"width": width,
"height": height,
"dataURL": dataURL
},
step,
self.steps,
)
step = intermediate_state.step
if intermediate_state.predicted_original is not None:
# Some schedulers report not only the noisy latents at the current timestep,
# but also their estimate so far of what the de-noised latents will be.
sample = intermediate_state.predicted_original
else:
sample = intermediate_state.latents
diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context)
def invoke(self, context: InvocationContext) -> ImageOutput:
def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, state.latents, state.step)
# def step_callback(state: PipelineIntermediateState):
# if (context.services.queue.is_canceled(context.graph_execution_state_id)):
# raise CanceledException
# self.dispatch_progress(context, state.latents, state.step)
# Handle invalid model parameter
# TODO: figure out if this can be done via a validator that uses the model_cache
@ -79,7 +70,7 @@ class TextToImageInvocation(BaseInvocation):
model= context.services.model_manager.get_model()
outputs = Txt2Img(model).generate(
prompt=self.prompt,
step_callback=step_callback,
step_callback=partial(self.dispatch_progress, context),
**self.dict(
exclude={"prompt"}
), # Shorthand for passing all of the parameters above manually
@ -116,6 +107,22 @@ class ImageToImageInvocation(TextToImageInvocation):
description="Whether or not the result should be fit to the aspect ratio of the input image",
)
def dispatch_progress(
self, context: InvocationContext, intermediate_state: PipelineIntermediateState
) -> None:
if (context.services.queue.is_canceled(context.graph_execution_state_id)):
raise CanceledException
step = intermediate_state.step
if intermediate_state.predicted_original is not None:
# Some schedulers report not only the noisy latents at the current timestep,
# but also their estimate so far of what the de-noised latents will be.
sample = intermediate_state.predicted_original
else:
sample = intermediate_state.latents
diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = (
None
@ -126,24 +133,23 @@ class ImageToImageInvocation(TextToImageInvocation):
)
mask = None
def step_callback(sample, step=0):
self.dispatch_progress(context, sample, step)
# Handle invalid model parameter
# TODO: figure out if this can be done via a validator that uses the model_cache
# TODO: How to get the default model name now?
model = context.services.model_manager.get_model()
generator_output = next(
Img2Img(model).generate(
outputs = Img2Img(model).generate(
prompt=self.prompt,
init_image=image,
init_mask=mask,
step_callback=step_callback,
step_callback=partial(self.dispatch_progress, context),
**self.dict(
exclude={"prompt", "image", "mask"}
), # Shorthand for passing all of the parameters above manually
)
)
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
# each time it is called. We only need the first one.
generator_output = next(outputs)
result_image = generator_output.image
@ -173,6 +179,22 @@ class InpaintInvocation(ImageToImageInvocation):
description="The amount by which to replace masked areas with latent noise",
)
def dispatch_progress(
self, context: InvocationContext, intermediate_state: PipelineIntermediateState
) -> None:
if (context.services.queue.is_canceled(context.graph_execution_state_id)):
raise CanceledException
step = intermediate_state.step
if intermediate_state.predicted_original is not None:
# Some schedulers report not only the noisy latents at the current timestep,
# but also their estimate so far of what the de-noised latents will be.
sample = intermediate_state.predicted_original
else:
sample = intermediate_state.latents
diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = (
None
@ -187,24 +209,23 @@ class InpaintInvocation(ImageToImageInvocation):
else context.services.images.get(self.mask.image_type, self.mask.image_name)
)
def step_callback(sample, step=0):
self.dispatch_progress(context, sample, step)
# Handle invalid model parameter
# TODO: figure out if this can be done via a validator that uses the model_cache
# TODO: How to get the default model name now?
manager = context.services.model_manager.get_model()
generator_output = next(
Inpaint(model).generate(
model = context.services.model_manager.get_model()
outputs = Inpaint(model).generate(
prompt=self.prompt,
init_image=image,
mask_image=mask,
step_callback=step_callback,
init_img=image,
init_mask=mask,
step_callback=partial(self.dispatch_progress, context),
**self.dict(
exclude={"prompt", "image", "mask"}
), # Shorthand for passing all of the parameters above manually
)
)
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
# each time it is called. We only need the first one.
generator_output = next(outputs)
result_image = generator_output.image

View File

@ -28,12 +28,28 @@ class ImageOutput(BaseInvocationOutput):
image: ImageField = Field(default=None, description="The output image")
#fmt: on
class Config:
schema_extra = {
'required': [
'type',
'image',
]
}
class MaskOutput(BaseInvocationOutput):
"""Base class for invocations that output a mask"""
#fmt: off
type: Literal["mask"] = "mask"
mask: ImageField = Field(default=None, description="The output mask")
#fomt: on
#fmt: on
class Config:
schema_extra = {
'required': [
'type',
'mask',
]
}
# TODO: this isn't really necessary anymore
class LoadImageInvocation(BaseInvocation):

View File

@ -0,0 +1,321 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from typing import Literal, Optional
from pydantic import BaseModel, Field
from torch import Tensor
import torch
from ...backend.model_management.model_manager import ModelManager
from ...backend.util.devices import CUDA_DEVICE, torch_dtype
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
from ...backend.image_util.seamless import configure_model_padding
from ...backend.prompting.conditioning import get_uc_and_c_and_ec
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
import numpy as np
from accelerate.utils import set_seed
from ..services.image_storage import ImageType
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput
from ...backend.generator import Generator
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.util.util import image_to_dataURL
from diffusers.schedulers import SchedulerMixin as Scheduler
import diffusers
from diffusers import DiffusionPipeline
class LatentsField(BaseModel):
"""A latents field used for passing latents between invocations"""
latents_name: Optional[str] = Field(default=None, description="The name of the latents")
class LatentsOutput(BaseInvocationOutput):
"""Base class for invocations that output latents"""
#fmt: off
type: Literal["latent_output"] = "latent_output"
latents: LatentsField = Field(default=None, description="The output latents")
#fmt: on
class NoiseOutput(BaseInvocationOutput):
"""Invocation noise output"""
#fmt: off
type: Literal["noise_output"] = "noise_output"
noise: LatentsField = Field(default=None, description="The output noise")
#fmt: on
# TODO: this seems like a hack
scheduler_map = dict(
ddim=diffusers.DDIMScheduler,
dpmpp_2=diffusers.DPMSolverMultistepScheduler,
k_dpm_2=diffusers.KDPM2DiscreteScheduler,
k_dpm_2_a=diffusers.KDPM2AncestralDiscreteScheduler,
k_dpmpp_2=diffusers.DPMSolverMultistepScheduler,
k_euler=diffusers.EulerDiscreteScheduler,
k_euler_a=diffusers.EulerAncestralDiscreteScheduler,
k_heun=diffusers.HeunDiscreteScheduler,
k_lms=diffusers.LMSDiscreteScheduler,
plms=diffusers.PNDMScheduler,
)
SAMPLER_NAME_VALUES = Literal[
tuple(list(scheduler_map.keys()))
]
def get_scheduler(scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
scheduler_class = scheduler_map.get(scheduler_name,'ddim')
scheduler = scheduler_class.from_config(model.scheduler.config)
# hack copied over from generate.py
if not hasattr(scheduler, 'uses_inpainting_model'):
scheduler.uses_inpainting_model = lambda: False
return scheduler
def get_noise(width:int, height:int, device:torch.device, seed:int = 0, latent_channels:int=4, use_mps_noise:bool=False, downsampling_factor:int = 8):
# limit noise to only the diffusion image channels, not the mask channels
input_channels = min(latent_channels, 4)
use_device = "cpu" if (use_mps_noise or device.type == "mps") else device
generator = torch.Generator(device=use_device).manual_seed(seed)
x = torch.randn(
[
1,
input_channels,
height // downsampling_factor,
width // downsampling_factor,
],
dtype=torch_dtype(device),
device=use_device,
generator=generator,
).to(device)
# if self.perlin > 0.0:
# perlin_noise = self.get_perlin_noise(
# width // self.downsampling_factor, height // self.downsampling_factor
# )
# x = (1 - self.perlin) * x + self.perlin * perlin_noise
return x
class NoiseInvocation(BaseInvocation):
"""Generates latent noise."""
type: Literal["noise"] = "noise"
# Inputs
seed: int = Field(default=0, ge=0, le=np.iinfo(np.uint32).max, description="The seed to use", )
width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting noise", )
height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting noise", )
def invoke(self, context: InvocationContext) -> NoiseOutput:
device = torch.device(CUDA_DEVICE)
noise = get_noise(self.width, self.height, device, self.seed)
name = f'{context.graph_execution_state_id}__{self.id}'
context.services.latents.set(name, noise)
return NoiseOutput(
noise=LatentsField(latents_name=name)
)
# Text to image
class TextToLatentsInvocation(BaseInvocation):
"""Generates latents from a prompt."""
type: Literal["t2l"] = "t2l"
# Inputs
# TODO: consider making prompt optional to enable providing prompt through a link
# fmt: off
prompt: Optional[str] = Field(description="The prompt to generate an image from")
seed: int = Field(default=-1,ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)", )
noise: Optional[LatentsField] = Field(description="The noise to use")
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting image", )
height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting image", )
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
sampler_name: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The sampler to use" )
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
model: str = Field(default="", description="The model to use (currently ignored)")
progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
# fmt: on
# TODO: pass this an emitter method or something? or a session for dispatching?
def dispatch_progress(
self, context: InvocationContext, sample: Tensor, step: int
) -> None:
# TODO: only output a preview image when requested
image = Generator.sample_to_lowres_estimated_image(sample)
(width, height) = image.size
width *= 8
height *= 8
dataURL = image_to_dataURL(image, image_format="JPEG")
context.services.events.emit_generator_progress(
context.graph_execution_state_id,
self.id,
{
"width": width,
"height": height,
"dataURL": dataURL
},
step,
self.steps,
)
def get_model(self, model_manager: ModelManager) -> StableDiffusionGeneratorPipeline:
model_info = model_manager.get_model(self.model)
model_name = model_info['model_name']
model_hash = model_info['hash']
model: StableDiffusionGeneratorPipeline = model_info['model']
model.scheduler = get_scheduler(
model=model,
scheduler_name=self.sampler_name
)
if isinstance(model, DiffusionPipeline):
for component in [model.unet, model.vae]:
configure_model_padding(component,
self.seamless,
self.seamless_axes
)
else:
configure_model_padding(model,
self.seamless,
self.seamless_axes
)
return model
def get_conditioning_data(self, model: StableDiffusionGeneratorPipeline) -> ConditioningData:
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(self.prompt, model=model)
conditioning_data = ConditioningData(
uc,
c,
self.cfg_scale,
extra_conditioning_info,
postprocessing_settings=PostprocessingSettings(
threshold=0.0,#threshold,
warmup=0.2,#warmup,
h_symmetry_time_pct=None,#h_symmetry_time_pct,
v_symmetry_time_pct=None#v_symmetry_time_pct,
),
).add_scheduler_args_if_applicable(model.scheduler, eta=None)#ddim_eta)
return conditioning_data
def invoke(self, context: InvocationContext) -> LatentsOutput:
noise = context.services.latents.get(self.noise.latents_name)
def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, state.latents, state.step)
model = self.get_model(context.services.model_manager)
conditioning_data = self.get_conditioning_data(model)
# TODO: Verify the noise is the right size
result_latents, result_attention_map_saver = model.latents_from_embeddings(
latents=torch.zeros_like(noise, dtype=torch_dtype(model.device)),
noise=noise,
num_inference_steps=self.steps,
conditioning_data=conditioning_data,
callback=step_callback
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache()
name = f'{context.graph_execution_state_id}__{self.id}'
context.services.latents.set(name, result_latents)
return LatentsOutput(
latents=LatentsField(latents_name=name)
)
class LatentsToLatentsInvocation(TextToLatentsInvocation):
"""Generates latents using latents as base image."""
type: Literal["l2l"] = "l2l"
# Inputs
latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
strength: float = Field(default=0.5, description="The strength of the latents to use")
def invoke(self, context: InvocationContext) -> LatentsOutput:
noise = context.services.latents.get(self.noise.latents_name)
latent = context.services.latents.get(self.latents.latents_name)
def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, state.latents, state.step)
model = self.get_model(context.services.model_manager)
conditioning_data = self.get_conditioning_data(model)
# TODO: Verify the noise is the right size
initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
latent, device=model.device, dtype=latent.dtype
)
timesteps, _ = model.get_img2img_timesteps(
self.steps,
self.strength,
device=model.device,
)
result_latents, result_attention_map_saver = model.latents_from_embeddings(
latents=initial_latents,
timesteps=timesteps,
noise=noise,
num_inference_steps=self.steps,
conditioning_data=conditioning_data,
callback=step_callback
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache()
name = f'{context.graph_execution_state_id}__{self.id}'
context.services.latents.set(name, result_latents)
return LatentsOutput(
latents=LatentsField(latents_name=name)
)
# Latent to image
class LatentsToImageInvocation(BaseInvocation):
"""Generates an image from latents."""
type: Literal["l2i"] = "l2i"
# Inputs
latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
model: str = Field(default="", description="The model to use")
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.services.latents.get(self.latents.latents_name)
# TODO: this only really needs the vae
model_info = context.services.model_manager.get_model(self.model)
model: StableDiffusionGeneratorPipeline = model_info['model']
with torch.inference_mode():
np_image = model.decode_latents(latents)
image = model.numpy_to_pil(np_image)[0]
image_type = ImageType.RESULT
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
)
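For orientation, a small sketch combining the two standalone helpers above; it assumes a StableDiffusionGeneratorPipeline named pipe is already loaded (for example via the ModelManager) and that CUDA_DEVICE resolves to a usable device:

import torch

device = torch.device(CUDA_DEVICE)
noise = get_noise(512, 512, device, seed=42)       # 1 x 4 x 64 x 64 latent noise
pipe.scheduler = get_scheduler("k_euler_a", pipe)  # swap in Euler ancestral from scheduler_map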

View File

@ -0,0 +1,68 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from datetime import datetime, timezone
from typing import Literal, Optional
import numpy
from PIL import Image, ImageFilter, ImageOps
from pydantic import BaseModel, Field
from ..services.image_storage import ImageType
from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
class IntOutput(BaseInvocationOutput):
"""An integer output"""
#fmt: off
type: Literal["int_output"] = "int_output"
a: int = Field(default=None, description="The output integer")
#fmt: on
class AddInvocation(BaseInvocation):
"""Adds two numbers"""
#fmt: off
type: Literal["add"] = "add"
a: int = Field(default=0, description="The first number")
b: int = Field(default=0, description="The second number")
#fmt: on
def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=self.a + self.b)
class SubtractInvocation(BaseInvocation):
"""Subtracts two numbers"""
#fmt: off
type: Literal["sub"] = "sub"
a: int = Field(default=0, description="The first number")
b: int = Field(default=0, description="The second number")
#fmt: on
def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=self.a - self.b)
class MultiplyInvocation(BaseInvocation):
"""Multiplies two numbers"""
#fmt: off
type: Literal["mul"] = "mul"
a: int = Field(default=0, description="The first number")
b: int = Field(default=0, description="The second number")
#fmt: on
def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=self.a * self.b)
class DivideInvocation(BaseInvocation):
"""Divides two numbers"""
#fmt: off
type: Literal["div"] = "div"
a: int = Field(default=0, description="The first number")
b: int = Field(default=0, description="The second number")
#fmt: on
def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=int(self.a / self.b))
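As a quick sanity check, the arithmetic invocations above can be exercised directly from Python; the id values are illustrative, and None is passed for the context because invoke() does not use it here:

print(AddInvocation(id="1", a=2, b=3).invoke(None).a)       # 5
print(MultiplyInvocation(id="2", a=4, b=5).invoke(None).a)  # 20
print(DivideInvocation(id="3", a=7, b=2).invoke(None).a)    # 3, int() truncates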

View File

@ -12,3 +12,11 @@ class PromptOutput(BaseInvocationOutput):
prompt: str = Field(default=None, description="The output prompt")
#fmt: on
class Config:
schema_extra = {
'required': [
'type',
'prompt',
]
}

View File

@ -127,6 +127,13 @@ class NodeAlreadyExecutedError(Exception):
class GraphInvocationOutput(BaseInvocationOutput):
type: Literal["graph_output"] = "graph_output"
class Config:
schema_extra = {
'required': [
'type',
'image',
]
}
# TODO: Fill this out and move to invocations
class GraphInvocation(BaseInvocation):
@ -147,6 +154,13 @@ class IterateInvocationOutput(BaseInvocationOutput):
item: Any = Field(description="The item being iterated over")
class Config:
schema_extra = {
'required': [
'type',
'item',
]
}
# TODO: Fill this out and move to invocations
class IterateInvocation(BaseInvocation):
@ -169,6 +183,13 @@ class CollectInvocationOutput(BaseInvocationOutput):
collection: list[Any] = Field(description="The collection of input items")
class Config:
schema_extra = {
'required': [
'type',
'collection',
]
}
class CollectInvocation(BaseInvocation):
"""Collects values into a collection"""
@ -1048,9 +1069,8 @@ class GraphExecutionState(BaseModel):
n
for n in prepared_nodes
if all(
pit
nx.has_path(execution_graph, pit[0], n)
for pit in parent_iterators
if nx.has_path(execution_graph, pit[0], n)
)
),
None,

View File

@ -9,6 +9,7 @@ from queue import Queue
from typing import Dict
from PIL.Image import Image
from invokeai.app.util.save_thumbnail import save_thumbnail
from invokeai.backend.image_util import PngWriter
@ -66,6 +67,9 @@ class DiskImageStorage(ImageStorageBase):
Path(os.path.join(output_folder, image_type)).mkdir(
parents=True, exist_ok=True
)
Path(os.path.join(output_folder, image_type, "thumbnails")).mkdir(
parents=True, exist_ok=True
)
def get(self, image_type: ImageType, image_name: str) -> Image:
image_path = self.get_path(image_type, image_name)
@ -87,7 +91,11 @@ class DiskImageStorage(ImageStorageBase):
self.__pngWriter.save_image_and_prompt_to_png(
image, "", image_subpath, None
) # TODO: just pass full path to png writer
save_thumbnail(
image=image,
filename=image_name,
path=os.path.join(self.__output_folder, image_type, "thumbnails"),
)
image_path = self.get_path(image_type, image_name)
self.__set_cache(image_path, image)

View File

@ -2,6 +2,7 @@
from invokeai.backend import ModelManager
from .events import EventServiceBase
from .latent_storage import LatentsStorageBase
from .image_storage import ImageStorageBase
from .restoration_services import RestorationServices
from .invocation_queue import InvocationQueueABC
@ -11,6 +12,7 @@ class InvocationServices:
"""Services that can be used by invocations"""
events: EventServiceBase
latents: LatentsStorageBase
images: ImageStorageBase
queue: InvocationQueueABC
model_manager: ModelManager
@ -24,6 +26,7 @@ class InvocationServices:
self,
model_manager: ModelManager,
events: EventServiceBase,
latents: LatentsStorageBase,
images: ImageStorageBase,
queue: InvocationQueueABC,
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
@ -32,6 +35,7 @@ class InvocationServices:
):
self.model_manager = model_manager
self.events = events
self.latents = latents
self.images = images
self.queue = queue
self.graph_execution_manager = graph_execution_manager

View File

@ -33,7 +33,6 @@ class Invoker:
self.services.graph_execution_manager.set(graph_execution_state)
# Queue the invocation
print(f"queueing item {invocation.id}")
self.services.queue.put(
InvocationQueueItem(
# session_id = session.id,

View File

@ -0,0 +1,93 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
import os
from abc import ABC, abstractmethod
from pathlib import Path
from queue import Queue
from typing import Dict
import torch
class LatentsStorageBase(ABC):
"""Responsible for storing and retrieving latents."""
@abstractmethod
def get(self, name: str) -> torch.Tensor:
pass
@abstractmethod
def set(self, name: str, data: torch.Tensor) -> None:
pass
@abstractmethod
def delete(self, name: str) -> None:
pass
class ForwardCacheLatentsStorage(LatentsStorageBase):
"""Caches the latest N latents in memory, writing-thorugh to and reading from underlying storage"""
__cache: Dict[str, torch.Tensor]
__cache_ids: Queue
__max_cache_size: int
__underlying_storage: LatentsStorageBase
def __init__(self, underlying_storage: LatentsStorageBase, max_cache_size: int = 20):
self.__underlying_storage = underlying_storage
self.__cache = dict()
self.__cache_ids = Queue()
self.__max_cache_size = max_cache_size
def get(self, name: str) -> torch.Tensor:
cache_item = self.__get_cache(name)
if cache_item is not None:
return cache_item
latent = self.__underlying_storage.get(name)
self.__set_cache(name, latent)
return latent
def set(self, name: str, data: torch.Tensor) -> None:
self.__underlying_storage.set(name, data)
self.__set_cache(name, data)
def delete(self, name: str) -> None:
self.__underlying_storage.delete(name)
if name in self.__cache:
del self.__cache[name]
def __get_cache(self, name: str) -> torch.Tensor|None:
return None if name not in self.__cache else self.__cache[name]
def __set_cache(self, name: str, data: torch.Tensor):
if not name in self.__cache:
self.__cache[name] = data
self.__cache_ids.put(name)
if self.__cache_ids.qsize() > self.__max_cache_size:
self.__cache.pop(self.__cache_ids.get())
class DiskLatentsStorage(LatentsStorageBase):
"""Stores latents in a folder on disk without caching"""
__output_folder: str
def __init__(self, output_folder: str):
self.__output_folder = output_folder
Path(output_folder).mkdir(parents=True, exist_ok=True)
def get(self, name: str) -> torch.Tensor:
latent_path = self.get_path(name)
return torch.load(latent_path)
def set(self, name: str, data: torch.Tensor) -> None:
latent_path = self.get_path(name)
torch.save(data, latent_path)
def delete(self, name: str) -> None:
latent_path = self.get_path(name)
os.remove(latent_path)
def get_path(self, name: str) -> str:
return os.path.join(self.__output_folder, name)
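A minimal sketch mirroring the wiring in cli_app.py: a write-through memory cache in front of on-disk latents storage. The folder path is illustrative, and the key follows the f'{graph_execution_state_id}__{self.id}' naming used by the latent invocations:

import torch

latents = ForwardCacheLatentsStorage(DiskLatentsStorage("outputs/latents"))
latents.set("session123__noise1", torch.randn(1, 4, 64, 64))  # written to disk and cached
x = latents.get("session123__noise1")                         # served from the in-memory cache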

View File

@ -4,7 +4,7 @@ from threading import Event, Thread
from ..invocations.baseinvocation import InvocationContext
from .invocation_queue import InvocationQueueItem
from .invoker import InvocationProcessorABC, Invoker
from ..util.util import CanceledException
class DefaultInvocationProcessor(InvocationProcessorABC):
__invoker_thread: Thread
@ -82,6 +82,9 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
except KeyboardInterrupt:
pass
except CanceledException:
pass
except Exception as e:
error = traceback.format_exc()

View File

@ -59,6 +59,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
f"""INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);""",
(item.json(),),
)
self._conn.commit()
finally:
self._lock.release()
self._on_changed(item)
@ -84,6 +85,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
self._cursor.execute(
f"""DELETE FROM {self._table_name} WHERE id = ?;""", (str(id),)
)
self._conn.commit()
finally:
self._lock.release()
self._on_deleted(id)

View File

@ -0,0 +1,25 @@
import os
from PIL import Image
def save_thumbnail(
image: Image.Image,
filename: str,
path: str,
size: int = 256,
) -> str:
"""
Saves a thumbnail of an image, returning its path.
"""
base_filename = os.path.splitext(filename)[0]
thumbnail_path = os.path.join(path, base_filename + ".webp")
if os.path.exists(thumbnail_path):
return thumbnail_path
image_copy = image.copy()
image_copy.thumbnail(size=(size, size))
image_copy.save(thumbnail_path, "WEBP")
return thumbnail_path
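A short usage sketch for save_thumbnail; the image, filename and path are illustrative, and the thumbnails directory is assumed to exist (DiskImageStorage creates it at startup, as shown earlier):

from PIL import Image

img = Image.new("RGB", (512, 512), "white")
thumb_path = save_thumbnail(img, "000001.abcdef.png", "outputs/results/thumbnails")
# -> outputs/results/thumbnails/000001.abcdef.webp; reused on the next call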

invokeai/app/util/util.py (new file, 42 lines)
View File

@ -0,0 +1,42 @@
import torch
from PIL import Image
from ..invocations.baseinvocation import InvocationContext
from ...backend.util.util import image_to_dataURL
from ...backend.generator.base import Generator
from ...backend.stable_diffusion import PipelineIntermediateState
class CanceledException(Exception):
pass
def fast_latents_step_callback(sample: torch.Tensor, step: int, steps: int, id: str, context: InvocationContext, ):
# TODO: only output a preview image when requested
image = Generator.sample_to_lowres_estimated_image(sample)
(width, height) = image.size
width *= 8
height *= 8
dataURL = image_to_dataURL(image, image_format="JPEG")
context.services.events.emit_generator_progress(
context.graph_execution_state_id,
id,
{
"width": width,
"height": height,
"dataURL": dataURL
},
step,
steps,
)
def diffusers_step_callback_adapter(*cb_args, **kwargs):
"""
txt2img gives us a Tensor in the step_callback, while img2img gives us a PipelineIntermediateState.
This adapter grabs the needed data and passes it along to the callback function.
"""
if isinstance(cb_args[0], PipelineIntermediateState):
progress_state: PipelineIntermediateState = cb_args[0]
return fast_latents_step_callback(progress_state.latents, progress_state.step, **kwargs)
else:
return fast_latents_step_callback(*cb_args, **kwargs)
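For reference, a sketch of how the invocations shown earlier feed this adapter; self and context stand in for an invocation instance and its InvocationContext:

from functools import partial

# inside invoke():  step_callback = partial(self.dispatch_progress, context)
# dispatch_progress() unpacks the PipelineIntermediateState and ends with
#   diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context)
# which renders a low-res JPEG preview and emits it via emit_generator_progress().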

View File

@ -21,7 +21,7 @@ from PIL import Image, ImageChops, ImageFilter
from accelerate.utils import set_seed
from diffusers import DiffusionPipeline
from tqdm import trange
from typing import List, Iterator, Type
from typing import Callable, List, Iterator, Optional, Type
from dataclasses import dataclass, field
from diffusers.schedulers import SchedulerMixin as Scheduler
@ -35,23 +35,23 @@ downsampling = 8
@dataclass
class InvokeAIGeneratorBasicParams:
seed: int=None
seed: Optional[int]=None
width: int=512
height: int=512
cfg_scale: int=7.5
cfg_scale: float=7.5
steps: int=20
ddim_eta: float=0.0
scheduler: int='ddim'
scheduler: str='ddim'
precision: str='float16'
perlin: float=0.0
threshold: int=0.0
threshold: float=0.0
seamless: bool=False
seamless_axes: List[str]=field(default_factory=lambda: ['x', 'y'])
h_symmetry_time_pct: float=None
v_symmetry_time_pct: float=None
h_symmetry_time_pct: Optional[float]=None
v_symmetry_time_pct: Optional[float]=None
variation_amount: float = 0.0
with_variations: list=field(default_factory=list)
safety_checker: SafetyChecker=None
safety_checker: Optional[SafetyChecker]=None
@dataclass
class InvokeAIGeneratorOutput:
@ -61,10 +61,10 @@ class InvokeAIGeneratorOutput:
and the model hash, as well as all the generate() parameters that went into
generating the image (in .params, also available as attributes)
'''
image: Image
image: Image.Image
seed: int
model_hash: str
attention_maps_images: List[Image]
attention_maps_images: List[Image.Image]
params: Namespace
# we are interposing a wrapper around the original Generator classes so that
@ -92,8 +92,8 @@ class InvokeAIGenerator(metaclass=ABCMeta):
def generate(self,
prompt: str='',
callback: callable=None,
step_callback: callable=None,
callback: Optional[Callable]=None,
step_callback: Optional[Callable]=None,
iterations: int=1,
**keyword_args,
)->Iterator[InvokeAIGeneratorOutput]:
@ -154,6 +154,7 @@ class InvokeAIGenerator(metaclass=ABCMeta):
for i in iteration_count:
results = generator.generate(prompt,
conditioning=(uc, c, extra_conditioning_info),
step_callback=step_callback,
sampler=scheduler,
**generator_args,
)
@ -205,10 +206,10 @@ class Txt2Img(InvokeAIGenerator):
# ------------------------------------
class Img2Img(InvokeAIGenerator):
def generate(self,
init_image: Image | torch.FloatTensor,
init_image: Image.Image | torch.FloatTensor,
strength: float=0.75,
**keyword_args
)->List[InvokeAIGeneratorOutput]:
)->Iterator[InvokeAIGeneratorOutput]:
return super().generate(init_image=init_image,
strength=strength,
**keyword_args
@ -222,7 +223,7 @@ class Img2Img(InvokeAIGenerator):
# Takes all the arguments of Img2Img and adds the mask image and the seam/infill stuff
class Inpaint(Img2Img):
def generate(self,
mask_image: Image | torch.FloatTensor,
mask_image: Image.Image | torch.FloatTensor,
# Seam settings - when 0, doesn't fill seam
seam_size: int = 0,
seam_blur: int = 0,
@ -235,7 +236,7 @@ class Inpaint(Img2Img):
inpaint_height=None,
inpaint_fill: tuple(int) = (0x7F, 0x7F, 0x7F, 0xFF),
**keyword_args
)->List[InvokeAIGeneratorOutput]:
)->Iterator[InvokeAIGeneratorOutput]:
return super().generate(
mask_image=mask_image,
seam_size=seam_size,
@ -262,7 +263,7 @@ class Embiggen(Txt2Img):
embiggen: list=None,
embiggen_tiles: list = None,
strength: float=0.75,
**kwargs)->List[InvokeAIGeneratorOutput]:
**kwargs)->Iterator[InvokeAIGeneratorOutput]:
return super().generate(embiggen=embiggen,
embiggen_tiles=embiggen_tiles,
strength=strength,

View File

@ -372,22 +372,32 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
unet_key = "model.diffusion_model."
# at least 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
if sum(k.startswith("model_ema") for k in keys) > 100:
print(f" | Checkpoint {path} has both EMA and non-EMA weights.")
print(f" | Checkpoint {path} has both EMA and non-EMA weights.")
if extract_ema:
print(" | Extracting EMA weights (usually better for inference)")
print(" | Extracting EMA weights (usually better for inference)")
for key in keys:
if key.startswith("model.diffusion_model"):
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(
flat_ema_key
)
flat_ema_key_alt = "model_ema." + "".join(key.split(".")[2:])
if flat_ema_key in checkpoint:
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(
flat_ema_key
)
elif flat_ema_key_alt in checkpoint:
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(
flat_ema_key_alt
)
else:
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(
key
)
else:
print(
" | Extracting only the non-EMA weights (usually better for fine-tuning)"
" | Extracting only the non-EMA weights (usually better for fine-tuning)"
)
for key in keys:
if key.startswith(unet_key):
if key.startswith("model.diffusion_model") and key in checkpoint:
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
new_checkpoint = {}
@ -1026,6 +1036,15 @@ def convert_open_clip_checkpoint(checkpoint):
return text_model
def replace_checkpoint_vae(checkpoint, vae_path:str):
if vae_path.endswith(".safetensors"):
vae_ckpt = load_file(vae_path)
else:
vae_ckpt = torch.load(vae_path, map_location="cpu")
state_dict = vae_ckpt['state_dict'] if "state_dict" in vae_ckpt else vae_ckpt
for vae_key in state_dict:
new_key = f'first_stage_model.{vae_key}'
checkpoint[new_key] = state_dict[vae_key]
def load_pipeline_from_original_stable_diffusion_ckpt(
checkpoint_path: str,
@ -1038,8 +1057,10 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
extract_ema: bool = True,
upcast_attn: bool = False,
vae: AutoencoderKL = None,
vae_path: str = None,
precision: torch.dtype = torch.float32,
return_generator_pipeline: bool = False,
scan_needed:bool=True,
) -> Union[StableDiffusionPipeline, StableDiffusionGeneratorPipeline]:
"""
Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
@ -1067,6 +1088,8 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
:param precision: precision to use - torch.float16, torch.float32 or torch.autocast
:param upcast_attention: Whether the attention computation should always be upcasted. This is necessary when
running stable diffusion 2.1.
:param vae: A diffusers VAE to load into the pipeline.
:param vae_path: Path to a checkpoint VAE that will be converted into diffusers and loaded into the pipeline.
"""
with warnings.catch_warnings():
@ -1074,12 +1097,13 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
verbosity = dlogging.get_verbosity()
dlogging.set_verbosity_error()
checkpoint = (
torch.load(checkpoint_path)
if Path(checkpoint_path).suffix == ".ckpt"
else load_file(checkpoint_path)
)
if Path(checkpoint_path).suffix == '.ckpt':
if scan_needed:
ModelManager.scan_model(checkpoint_path,checkpoint_path)
checkpoint = torch.load(checkpoint_path)
else:
checkpoint = load_file(checkpoint_path)
cache_dir = global_cache_dir("hub")
pipeline_class = (
StableDiffusionGeneratorPipeline
@ -1091,7 +1115,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
if "global_step" in checkpoint:
global_step = checkpoint["global_step"]
else:
print(" | global_step key not found in model")
print(" | global_step key not found in model")
global_step = None
# sometimes there is a state_dict key and sometimes not
@ -1202,9 +1226,19 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
unet.load_state_dict(converted_unet_checkpoint)
# Convert the VAE model, or use the one passed
if not vae:
print(" | Using checkpoint model's original VAE")
# If a replacement VAE path was specified, we'll incorporate that into
# the checkpoint model and then convert it
if vae_path:
print(f" | Converting VAE {vae_path}")
replace_checkpoint_vae(checkpoint,vae_path)
# otherwise we use the original VAE, provided that
# an externally loaded diffusers VAE was not passed
elif not vae:
print(" | Using checkpoint model's original VAE")
if vae:
print(" | Using replacement diffusers VAE")
else: # convert the original or replacement VAE
vae_config = create_vae_diffusers_config(
original_config, image_size=image_size
)
@ -1214,8 +1248,6 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
vae = AutoencoderKL(**vae_config)
vae.load_state_dict(converted_vae_checkpoint)
else:
print(" | Using external VAE specified in config")
# Convert the text model.
model_type = pipeline_type
@ -1232,10 +1264,10 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
cache_dir=cache_dir,
)
pipe = pipeline_class(
vae=vae,
text_encoder=text_model,
vae=vae.to(precision),
text_encoder=text_model.to(precision),
tokenizer=tokenizer,
unet=unet,
unet=unet.to(precision),
scheduler=scheduler,
safety_checker=None,
feature_extractor=None,

View File

@ -18,7 +18,7 @@ import warnings
from enum import Enum
from pathlib import Path
from shutil import move, rmtree
from typing import Any, Optional, Union
from typing import Any, Optional, Union, Callable
import safetensors
import safetensors.torch
@ -34,7 +34,7 @@ from picklescan.scanner import scan_file_path
from invokeai.backend.globals import Globals, global_cache_dir
from ..stable_diffusion import StableDiffusionGeneratorPipeline
from ..util import CUDA_DEVICE, CPU_DEVICE, ask_user, download_with_resume
from ..util import CUDA_DEVICE, ask_user, download_with_resume
class SDLegacyType(Enum):
V1 = 1
@ -45,9 +45,6 @@ class SDLegacyType(Enum):
UNKNOWN = 99
DEFAULT_MAX_MODELS = 2
VAE_TO_REPO_ID = { # hack, see note in convert_and_import()
"vae-ft-mse-840000-ema-pruned": "stabilityai/sd-vae-ft-mse",
}
class ModelManager(object):
'''
@ -285,13 +282,13 @@ class ModelManager(object):
self.stack.remove(model_name)
if delete_files:
if weights:
print(f"** deleting file {weights}")
print(f"** Deleting file {weights}")
Path(weights).unlink(missing_ok=True)
elif path:
print(f"** deleting directory {path}")
print(f"** Deleting directory {path}")
rmtree(path, ignore_errors=True)
elif repo_id:
print(f"** deleting the cached model directory for {repo_id}")
print(f"** Deleting the cached model directory for {repo_id}")
self._delete_model_from_cache(repo_id)
def add_model(
@ -362,6 +359,7 @@ class ModelManager(object):
raise NotImplementedError(
f"Unknown model format {model_name}: {model_format}"
)
self._add_embeddings_to_model(model)
# usage statistics
toc = time.time()
@ -381,9 +379,9 @@ class ModelManager(object):
print(f">> Loading diffusers model from {name_or_path}")
if using_fp16:
print(" | Using faster float16 precision")
print(" | Using faster float16 precision")
else:
print(" | Using more accurate float32 precision")
print(" | Using more accurate float32 precision")
# TODO: scan weights maybe?
pipeline_args: dict[str, Any] = dict(
@ -434,10 +432,8 @@ class ModelManager(object):
# square images???
width = pipeline.unet.config.sample_size * pipeline.vae_scale_factor
height = width
print(f" | Default image dimensions = {width} x {height}")
self._add_embeddings_to_model(pipeline)
print(f" | Default image dimensions = {width} x {height}")
return pipeline, width, height, model_hash
def _load_ckpt_model(self, model_name, mconfig):
@ -457,15 +453,21 @@ class ModelManager(object):
from . import load_pipeline_from_original_stable_diffusion_ckpt
self.offload_model(self.current_model)
if vae_config := self._choose_diffusers_vae(model_name):
vae = self._load_vae(vae_config)
try:
if self.list_models()[self.current_model]['status'] == 'active':
self.offload_model(self.current_model)
except Exception as e:
pass
vae_path = None
if vae:
vae_path = vae if os.path.isabs(vae) else os.path.normpath(os.path.join(Globals.root, vae))
if self._has_cuda():
torch.cuda.empty_cache()
pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
checkpoint_path=weights,
original_config_file=config,
vae=vae,
vae_path=vae_path,
return_generator_pipeline=True,
precision=torch.float16 if self.precision == "float16" else torch.float32,
)
@ -473,7 +475,6 @@ class ModelManager(object):
pipeline.enable_offload_submodels(self.device)
else:
pipeline.to(self.device)
return (
pipeline,
width,
@ -512,18 +513,20 @@ class ModelManager(object):
print(f">> Offloading {model_name} to CPU")
model = self.models[model_name]["model"]
model.offload_all()
self.current_model = None
gc.collect()
if self._has_cuda():
torch.cuda.empty_cache()
@classmethod
def scan_model(self, model_name, checkpoint):
"""
Apply picklescanner to the indicated checkpoint and issue a warning
and option to exit if an infected file is identified.
"""
# scan model
print(f">> Scanning Model: {model_name}")
print(f" | Scanning Model: {model_name}")
scan_result = scan_file_path(checkpoint)
if scan_result.infected_files != 0:
if scan_result.infected_files == 1:
@ -546,7 +549,7 @@ class ModelManager(object):
print("### Exiting InvokeAI")
sys.exit()
else:
print(">> Model scanned ok")
print(" | Model scanned ok")
def import_diffuser_model(
self,
@ -627,14 +630,13 @@ class ModelManager(object):
def heuristic_import(
self,
path_url_or_repo: str,
convert: bool = True,
model_name: str = None,
description: str = None,
model_config_file: Path = None,
commit_to_conf: Path = None,
config_file_callback: Callable[[Path], Path] = None,
) -> str:
"""
Accept a string which could be:
"""Accept a string which could be:
- a HF diffusers repo_id
- a URL pointing to a legacy .ckpt or .safetensors file
- a local path pointing to a legacy .ckpt or .safetensors file
@ -648,16 +650,20 @@ class ModelManager(object):
The model_name and/or description can be provided. If not, they will
be generated automatically.
If convert is true, legacy models will be converted to diffusers
before importing.
If commit_to_conf is provided, the newly loaded model will be written
to the `models.yaml` file at the indicated path. Otherwise, the changes
will only remain in memory.
The (potentially derived) name of the model is returned on success, or None
on failure. When multiple models are added from a directory, only the last
imported one is returned.
The routine will do its best to figure out the config file
needed to convert a legacy checkpoint file, but if it can't, it
will call the config_file_callback routine, if provided. The
callback accepts a single argument, the Path to the checkpoint
file, and returns a Path to the config file to use.
The (potentially derived) name of the model is returned on
success, or None on failure. When multiple models are added
from a directory, only the last imported one is returned.
"""
model_path: Path = None
thing = path_url_or_repo # to save typing
@ -665,7 +671,7 @@ class ModelManager(object):
print(f">> Probing {thing} for import")
if thing.startswith(("http:", "https:", "ftp:")):
print(f" | {thing} appears to be a URL")
print(f" | {thing} appears to be a URL")
model_path = self._resolve_path(
thing, "models/ldm/stable-diffusion-v1"
) # _resolve_path does a download if needed
@ -673,15 +679,15 @@ class ModelManager(object):
elif Path(thing).is_file() and thing.endswith((".ckpt", ".safetensors")):
if Path(thing).stem in ["model", "diffusion_pytorch_model"]:
print(
f" | {Path(thing).name} appears to be part of a diffusers model. Skipping import"
f" | {Path(thing).name} appears to be part of a diffusers model. Skipping import"
)
return
else:
print(f" | {thing} appears to be a checkpoint file on disk")
print(f" | {thing} appears to be a checkpoint file on disk")
model_path = self._resolve_path(thing, "models/ldm/stable-diffusion-v1")
elif Path(thing).is_dir() and Path(thing, "model_index.json").exists():
print(f" | {thing} appears to be a diffusers file on disk")
print(f" | {thing} appears to be a diffusers file on disk")
model_name = self.import_diffuser_model(
thing,
vae=dict(repo_id="stabilityai/sd-vae-ft-mse"),
@ -692,25 +698,25 @@ class ModelManager(object):
elif Path(thing).is_dir():
if (Path(thing) / "model_index.json").exists():
print(f" | {thing} appears to be a diffusers model.")
print(f" | {thing} appears to be a diffusers model.")
model_name = self.import_diffuser_model(
thing, commit_to_conf=commit_to_conf
)
else:
print(
f" |{thing} appears to be a directory. Will scan for models to import"
f" |{thing} appears to be a directory. Will scan for models to import"
)
for m in list(Path(thing).rglob("*.ckpt")) + list(
Path(thing).rglob("*.safetensors")
):
if model_name := self.heuristic_import(
str(m), convert, commit_to_conf=commit_to_conf
str(m), commit_to_conf=commit_to_conf
):
print(f" >> {model_name} successfully imported")
return model_name
elif re.match(r"^[\w.+-]+/[\w.+-]+$", thing):
print(f" | {thing} appears to be a HuggingFace diffusers repo_id")
print(f" | {thing} appears to be a HuggingFace diffusers repo_id")
model_name = self.import_diffuser_model(
thing, commit_to_conf=commit_to_conf
)
@ -727,55 +733,75 @@ class ModelManager(object):
return
if model_path.stem in self.config: # already imported
print(" | Already imported. Skipping")
print(" | Already imported. Skipping")
return model_path.stem
# another round of heuristics to guess the correct config file.
checkpoint = (
torch.load(model_path)
if model_path.suffix == ".ckpt"
else safetensors.torch.load_file(model_path)
)
checkpoint = None
if model_path.suffix in [".ckpt",".pt"]:
self.scan_model(model_path,model_path)
checkpoint = torch.load(model_path)
else:
checkpoint = safetensors.torch.load_file(model_path)
# additional probing needed if no config file provided
if model_config_file is None:
model_type = self.probe_model_type(checkpoint)
if model_type == SDLegacyType.V1:
print(" | SD-v1 model detected")
model_config_file = Path(
Globals.root, "configs/stable-diffusion/v1-inference.yaml"
)
elif model_type == SDLegacyType.V1_INPAINT:
print(" | SD-v1 inpainting model detected")
model_config_file = Path(
Globals.root, "configs/stable-diffusion/v1-inpainting-inference.yaml"
)
elif model_type == SDLegacyType.V2_v:
print(
" | SD-v2-v model detected; model will be converted to diffusers format"
)
model_config_file = Path(
Globals.root, "configs/stable-diffusion/v2-inference-v.yaml"
)
convert = True
elif model_type == SDLegacyType.V2_e:
print(
" | SD-v2-e model detected; model will be converted to diffusers format"
)
model_config_file = Path(
Globals.root, "configs/stable-diffusion/v2-inference.yaml"
)
convert = True
elif model_type == SDLegacyType.V2:
print(
f"** {thing} is a V2 checkpoint file, but its parameterization cannot be determined. Please provide configuration file path."
)
return
# look for a like-named .yaml file in same directory
if model_path.with_suffix(".yaml").exists():
model_config_file = model_path.with_suffix(".yaml")
print(f" | Using config file {model_config_file.name}")
else:
print(
f"** {thing} is a legacy checkpoint file but not a known Stable Diffusion model. Please provide configuration file path."
)
return
model_type = self.probe_model_type(checkpoint)
if model_type == SDLegacyType.V1:
print(" | SD-v1 model detected")
model_config_file = Path(
Globals.root, "configs/stable-diffusion/v1-inference.yaml"
)
elif model_type == SDLegacyType.V1_INPAINT:
print(" | SD-v1 inpainting model detected")
model_config_file = Path(
Globals.root, "configs/stable-diffusion/v1-inpainting-inference.yaml"
)
elif model_type == SDLegacyType.V2_v:
print(
" | SD-v2-v model detected"
)
model_config_file = Path(
Globals.root, "configs/stable-diffusion/v2-inference-v.yaml"
)
elif model_type == SDLegacyType.V2_e:
print(
" | SD-v2-e model detected"
)
model_config_file = Path(
Globals.root, "configs/stable-diffusion/v2-inference.yaml"
)
elif model_type == SDLegacyType.V2:
print(
f"** {thing} is a V2 checkpoint file, but its parameterization cannot be determined. Please provide configuration file path."
)
return
else:
print(
f"** {thing} is a legacy checkpoint file but not a known Stable Diffusion model. Please provide configuration file path."
)
return
if not model_config_file and config_file_callback:
model_config_file = config_file_callback(model_path)
# despite our best efforts, we could not find a model config file, so give up
if not model_config_file:
return
# look for a custom vae, a like-named file ending with .vae in the same directory
vae_path = None
for suffix in ["pt", "ckpt", "safetensors"]:
if (model_path.with_suffix(f".vae.{suffix}")).exists():
vae_path = model_path.with_suffix(f".vae.{suffix}")
print(f" | Using VAE file {vae_path.name}")
vae = None if vae_path else dict(repo_id="stabilityai/sd-vae-ft-mse")
diffuser_path = Path(
Globals.root, "models", Globals.converted_ckpts_dir, model_path.stem
@ -783,23 +809,27 @@ class ModelManager(object):
model_name = self.convert_and_import(
model_path,
diffusers_path=diffuser_path,
vae=dict(repo_id="stabilityai/sd-vae-ft-mse"),
vae=vae,
vae_path=str(vae_path),
model_name=model_name,
model_description=description,
original_config_file=model_config_file,
commit_to_conf=commit_to_conf,
scan_needed=False,
)
return model_name
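For orientation, here is a minimal sketch of driving the import path above from calling code. It assumes an already-constructed ModelManager instance named mgr and a hypothetical local checkpoint at models/my-model.safetensors, and it assumes heuristic_import accepts the config_file_callback and commit_to_conf arguments referenced in the body; the exact signature is not shown in this diff.

from pathlib import Path

def pick_config(model_path: Path) -> Path:
    # Hypothetical fallback: hand back the stock SD-v1 config when probing fails.
    return Path("configs/stable-diffusion/v1-inference.yaml")

# mgr is assumed to be an already-initialized ModelManager
imported = mgr.heuristic_import(
    "models/my-model.safetensors",            # hypothetical checkpoint path
    config_file_callback=pick_config,         # consulted only if the heuristics find nothing
    commit_to_conf=Path("configs/models.yaml"),
)
print(imported)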
def convert_and_import(
self,
ckpt_path: Path,
diffusers_path: Path,
model_name=None,
model_description=None,
vae=None,
original_config_file: Path = None,
commit_to_conf: Path = None,
self,
ckpt_path: Path,
diffusers_path: Path,
model_name=None,
model_description=None,
vae:dict=None,
vae_path:Path=None,
original_config_file: Path = None,
commit_to_conf: Path = None,
scan_needed: bool=True,
) -> str:
"""
Convert a legacy ckpt weights file to diffuser model and import
@ -822,23 +852,28 @@ class ModelManager(object):
return
model_name = model_name or diffusers_path.name
model_description = model_description or f"Optimized version of {model_name}"
print(f">> Optimizing {model_name} (30-60s)")
model_description = model_description or f"Converted version of {model_name}"
print(f" | Converting {model_name} to diffusers (30-60s)")
try:
# By passing the specified VAE to the conversion function, the autoencoder
# will be built into the model rather than tacked on afterward via the config file
vae_model = self._load_vae(vae) if vae else None
vae_model=None
if vae:
vae_model=self._load_vae(vae)
vae_path=None
convert_ckpt_to_diffusers(
ckpt_path,
diffusers_path,
extract_ema=True,
original_config_file=original_config_file,
vae=vae_model,
vae_path=vae_path,
scan_needed=scan_needed,
)
print(
f" | Success. Optimized model is now located at {str(diffusers_path)}"
f" | Success. Converted model is now located at {str(diffusers_path)}"
)
print(f" | Writing new config file entry for {model_name}")
print(f" | Writing new config file entry for {model_name}")
new_config = dict(
path=str(diffusers_path),
description=model_description,
@ -849,7 +884,7 @@ class ModelManager(object):
self.add_model(model_name, new_config, True)
if commit_to_conf:
self.commit(commit_to_conf)
print(">> Conversion succeeded")
print(" | Conversion succeeded")
except Exception as e:
print(f"** Conversion failed: {str(e)}")
print(
@ -879,36 +914,6 @@ class ModelManager(object):
return search_folder, found_models
def _choose_diffusers_vae(
self, model_name: str, vae: str = None
) -> Union[dict, str]:
# In the event that the original entry is using a custom ckpt VAE, we try to
# map that VAE onto a diffuser VAE using a hard-coded dictionary.
# I would prefer to do this differently: We load the ckpt model into memory, swap the
# VAE in memory, and then pass that to convert_ckpt_to_diffuser() so that the swapped
# VAE is built into the model. However, when I tried this I got obscure key errors.
if vae:
return vae
if model_name in self.config and (
vae_ckpt_path := self.model_info(model_name).get("vae", None)
):
vae_basename = Path(vae_ckpt_path).stem
diffusers_vae = None
if diffusers_vae := VAE_TO_REPO_ID.get(vae_basename, None):
print(
f">> {vae_basename} VAE corresponds to known {diffusers_vae} diffusers version"
)
vae = {"repo_id": diffusers_vae}
else:
print(
f'** Custom VAE "{vae_basename}" found, but corresponding diffusers model unknown'
)
print(
'** Using "stabilityai/sd-vae-ft-mse"; If this isn\'t right, please edit the model config'
)
vae = {"repo_id": "stabilityai/sd-vae-ft-mse"}
return vae
def _make_cache_room(self) -> None:
num_loaded_models = len(self.models)
if num_loaded_models >= self.max_loaded_models:
@ -1105,7 +1110,7 @@ class ModelManager(object):
with open(hashpath) as f:
hash = f.read()
return hash
print(" | Calculating sha256 hash of model files")
print(" | Calculating sha256 hash of model files")
tic = time.time()
sha = hashlib.sha256()
count = 0
@ -1117,7 +1122,7 @@ class ModelManager(object):
sha.update(chunk)
hash = sha.hexdigest()
toc = time.time()
print(f" | sha256 = {hash} ({count} files hashed in", "%4.2fs)" % (toc - tic))
print(f" | sha256 = {hash} ({count} files hashed in", "%4.2fs)" % (toc - tic))
with open(hashpath, "w") as f:
f.write(hash)
return hash
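The hash above is computed by streaming every model file through a single sha256 digest in fixed-size chunks. A self-contained sketch of the same idea, with a hypothetical directory path and an assumed suffix filter:

import hashlib
from pathlib import Path

def hash_model_dir(model_dir: str, chunk_size: int = 2 ** 20) -> str:
    """Fold every weight file under model_dir into one sha256 digest (sketch)."""
    sha = hashlib.sha256()
    count = 0
    for path in sorted(Path(model_dir).rglob("*")):
        if path.suffix in (".ckpt", ".safetensors", ".bin", ".pt"):
            count += 1
            with open(path, "rb") as f:
                for chunk in iter(lambda: f.read(chunk_size), b""):
                    sha.update(chunk)
    print(f"{count} files hashed")
    return sha.hexdigest()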
@ -1162,12 +1167,12 @@ class ModelManager(object):
local_files_only=not Globals.internet_available,
)
print(f" | Loading diffusers VAE from {name_or_path}")
print(f" | Loading diffusers VAE from {name_or_path}")
if using_fp16:
vae_args.update(torch_dtype=torch.float16)
fp_args_list = [{"revision": "fp16"}, {}]
else:
print(" | Using more accurate float32 precision")
print(" | Using more accurate float32 precision")
fp_args_list = [{}]
vae = None
@ -1208,7 +1213,7 @@ class ModelManager(object):
hashes_to_delete.add(revision.commit_hash)
strategy = cache_info.delete_revisions(*hashes_to_delete)
print(
f"** deletion of this model is expected to free {strategy.expected_freed_size_str}"
f"** Deletion of this model is expected to free {strategy.expected_freed_size_str}"
)
strategy.execute()
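The deletion above relies on huggingface_hub's cache-scanning API; a condensed sketch of the same pattern, with a placeholder repo id:

from huggingface_hub import scan_cache_dir

cache_info = scan_cache_dir()
hashes_to_delete = set()
for repo in cache_info.repos:
    if repo.repo_id == "some-org/some-model":  # placeholder repo id
        for revision in repo.revisions:
            hashes_to_delete.add(revision.commit_hash)

strategy = cache_info.delete_revisions(*hashes_to_delete)
print(f"** Deletion of this model is expected to free {strategy.expected_freed_size_str}")
strategy.execute()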

View File

@ -6,7 +6,6 @@ The interface is through the Concepts() object.
"""
import os
import re
import traceback
from typing import Callable
from urllib import error as ul_error
from urllib import request
@ -15,7 +14,6 @@ from huggingface_hub import (
HfApi,
HfFolder,
ModelFilter,
ModelSearchArguments,
hf_hub_url,
)
@ -84,7 +82,7 @@ class HuggingFaceConceptsLibrary(object):
"""
if not concept_name in self.list_concepts():
print(
f"This concept is not a local embedding trigger, nor is it a HuggingFace concept. Generation will continue without the concept."
f"{concept_name} is not a local embedding trigger, nor is it a HuggingFace concept. Generation will continue without the concept."
)
return None
return self.get_concept_file(concept_name.lower(), "learned_embeds.bin")
@ -236,7 +234,7 @@ class HuggingFaceConceptsLibrary(object):
except ul_error.HTTPError as e:
if e.code == 404:
print(
f"This concept is not known to the Hugging Face library. Generation will continue without the concept."
f"Concept {concept_name} is not known to the Hugging Face library. Generation will continue without the concept."
)
else:
print(
@ -246,7 +244,7 @@ class HuggingFaceConceptsLibrary(object):
return False
except ul_error.URLError as e:
print(
f"ERROR: {str(e)}. This may reflect a network issue. Generation will continue without the concept."
f"ERROR while downloading {concept_name}: {str(e)}. This may reflect a network issue. Generation will continue without the concept."
)
os.rmdir(dest)
return False

View File

@ -531,7 +531,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
run_id: str = None,
additional_guidance: List[Callable] = None,
):
self._adjust_memory_efficient_attention(latents)
# FIXME: do we still use any slicing now that PyTorch 2.0 has scaled dot-product attention on all platforms?
# self._adjust_memory_efficient_attention(latents)
if run_id is None:
run_id = secrets.token_urlsafe(self.ID_LENGTH)
if additional_guidance is None:
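The FIXME above refers to PyTorch 2.0 shipping scaled dot-product attention natively. A small, hypothetical sketch of how the slicing decision could be gated on that capability — not the project's actual logic, and `pipeline` is assumed to be a diffusers pipeline exposing the standard enable/disable methods:

import torch

def maybe_enable_attention_slicing(pipeline) -> None:
    # PyTorch >= 2.0 exposes F.scaled_dot_product_attention, which is already
    # memory-efficient; only fall back to manual slicing on older runtimes.
    if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
        pipeline.disable_attention_slicing()
    else:
        pipeline.enable_attention_slicing()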

View File

@ -1,16 +1,26 @@
import os
import traceback
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Union
from typing import Optional, Union, List
import safetensors.torch
import torch
from compel.embeddings_provider import BaseTextualInversionManager
from picklescan.scanner import scan_file_path
from transformers import CLIPTextModel, CLIPTokenizer
from .concepts_lib import HuggingFaceConceptsLibrary
@dataclass
class EmbeddingInfo:
name: str
embedding: torch.Tensor
num_vectors_per_token: int
token_dim: int
trained_steps: int = None
trained_model_name: str = None
trained_model_checksum: str = None
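For illustration, this is how a parser populates the new dataclass from a raw tensor; the trigger name and tensor shape below are made up (768 is the SD-v1 text-encoder token dimension):

import torch

emb = torch.zeros(2, 768)            # two vectors per token, 768-dim token space
info = EmbeddingInfo(
    name="<my-trigger>",             # hypothetical trigger token
    embedding=emb,
    num_vectors_per_token=emb.size()[0],
    token_dim=emb.size()[1],
)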
@dataclass
class TextualInversion:
@ -72,66 +82,46 @@ class TextualInversionManager(BaseTextualInversionManager):
if str(ckpt_path).endswith(".DS_Store"):
return
try:
scan_result = scan_file_path(str(ckpt_path))
if scan_result.infected_files == 1:
embedding_list = self._parse_embedding(str(ckpt_path))
for embedding_info in embedding_list:
if (self.text_encoder.get_input_embeddings().weight.data[0].shape[0] != embedding_info.token_dim):
print(
f"\n### Security Issues Found in Model: {scan_result.issues_count}"
f" ** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info.token_dim}."
)
print("### For your safety, InvokeAI will not load this embed.")
return
except Exception:
print(
f"### {ckpt_path.parents[0].name}/{ckpt_path.name} is damaged or corrupt."
)
return
continue
embedding_info = self._parse_embedding(str(ckpt_path))
if embedding_info is None:
# We've already put out an error message about the bad embedding in _parse_embedding, so just return.
return
elif (
self.text_encoder.get_input_embeddings().weight.data[0].shape[0]
!= embedding_info["token_dim"]
):
print(
f"** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info['token_dim']}."
)
return
# Resolve the situation in which an earlier embedding has claimed the same
# trigger string. We replace the trigger with '<source_file>', as we used to.
trigger_str = embedding_info["name"]
sourcefile = (
f"{ckpt_path.parent.name}/{ckpt_path.name}"
if ckpt_path.name == "learned_embeds.bin"
else ckpt_path.name
)
if trigger_str in self.trigger_to_sourcefile:
replacement_trigger_str = (
f"<{ckpt_path.parent.name}>"
# Resolve the situation in which an earlier embedding has claimed the same
# trigger string. We replace the trigger with '<source_file>', as we used to.
trigger_str = embedding_info.name
sourcefile = (
f"{ckpt_path.parent.name}/{ckpt_path.name}"
if ckpt_path.name == "learned_embeds.bin"
else f"<{ckpt_path.stem}>"
else ckpt_path.name
)
print(
f">> {sourcefile}: Trigger token '{trigger_str}' is already claimed by '{self.trigger_to_sourcefile[trigger_str]}'. Trigger this concept with {replacement_trigger_str}"
)
trigger_str = replacement_trigger_str
try:
self._add_textual_inversion(
trigger_str,
embedding_info["embedding"],
defer_injecting_tokens=defer_injecting_tokens,
)
# remember which source file claims this trigger
self.trigger_to_sourcefile[trigger_str] = sourcefile
if trigger_str in self.trigger_to_sourcefile:
replacement_trigger_str = (
f"<{ckpt_path.parent.name}>"
if ckpt_path.name == "learned_embeds.bin"
else f"<{ckpt_path.stem}>"
)
print(
f">> {sourcefile}: Trigger token '{trigger_str}' is already claimed by '{self.trigger_to_sourcefile[trigger_str]}'. Trigger this concept with {replacement_trigger_str}"
)
trigger_str = replacement_trigger_str
except ValueError as e:
print(f' | Ignoring incompatible embedding {embedding_info["name"]}')
print(f" | The error was {str(e)}")
try:
self._add_textual_inversion(
trigger_str,
embedding_info.embedding,
defer_injecting_tokens=defer_injecting_tokens,
)
# remember which source file claims this trigger
self.trigger_to_sourcefile[trigger_str] = sourcefile
except ValueError as e:
print(f' | Ignoring incompatible embedding {embedding_info["name"]}')
print(f" | The error was {str(e)}")
def _add_textual_inversion(
self, trigger_str, embedding, defer_injecting_tokens=False
@ -309,111 +299,130 @@ class TextualInversionManager(BaseTextualInversionManager):
return token_id
def _parse_embedding(self, embedding_file: str):
file_type = embedding_file.split(".")[-1]
if file_type == "pt":
return self._parse_embedding_pt(embedding_file)
elif file_type == "bin":
return self._parse_embedding_bin(embedding_file)
def _parse_embedding(self, embedding_file: str)->List[EmbeddingInfo]:
suffix = Path(embedding_file).suffix
try:
if suffix in [".pt",".ckpt",".bin"]:
scan_result = scan_file_path(embedding_file)
if scan_result.infected_files > 0:
print(
f" ** Security Issues Found in Model: {scan_result.issues_count}"
)
print(" ** For your safety, InvokeAI will not load this embed.")
return list()
ckpt = torch.load(embedding_file,map_location="cpu")
else:
ckpt = safetensors.torch.load_file(embedding_file)
except Exception as e:
print(f" ** Notice: unrecognized embedding file format: {embedding_file}: {e}")
return list()
# try to figure out what kind of embedding file it is and parse accordingly
keys = list(ckpt.keys())
if all(x in keys for x in ['string_to_token','string_to_param','name','step']):
return self._parse_embedding_v1(ckpt, embedding_file) # example rem_rezero.pt
elif all(x in keys for x in ['string_to_token','string_to_param']):
return self._parse_embedding_v2(ckpt, embedding_file) # example midj-strong.pt
elif 'emb_params' in keys:
return self._parse_embedding_v3(ckpt, embedding_file) # example easynegative.safetensors
else:
print(f"** Notice: unrecognized embedding file format: {embedding_file}")
return None
return self._parse_embedding_v4(ckpt, embedding_file) # usually a '.bin' file
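In other words, the dispatcher distinguishes formats purely by key layout. A condensed restatement as a standalone helper (names are illustrative):

def classify_embedding(ckpt: dict) -> str:
    """Return which parser variant a loaded embedding dict maps to (sketch)."""
    keys = set(ckpt.keys())
    if {"string_to_token", "string_to_param", "name", "step"} <= keys:
        return "v1"   # classic .pt with training metadata, e.g. rem_rezero.pt
    if {"string_to_token", "string_to_param"} <= keys:
        return "v2"   # same layout with the metadata stripped, e.g. midj-strong.pt
    if "emb_params" in keys:
        return "v3"   # single-key safetensors, e.g. easynegative.safetensors
    return "v4"       # bare {token: tensor} dict, typical of diffusers .bin files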
def _parse_embedding_pt(self, embedding_file):
embedding_ckpt = torch.load(embedding_file, map_location="cpu")
embedding_info = {}
def _parse_embedding_v1(self, embedding_ckpt: dict, file_path: str)->List[EmbeddingInfo]:
basename = Path(file_path).stem
print(f' | Loading v1 embedding file: {basename}')
# Check if valid embedding file
if "string_to_token" and "string_to_param" in embedding_ckpt:
# Catch variants that do not have the expected keys or values.
try:
embedding_info["name"] = embedding_ckpt["name"] or os.path.basename(
os.path.splitext(embedding_file)[0]
)
embeddings = list()
token_counter = -1
for token,embedding in embedding_ckpt["string_to_param"].items():
if token_counter < 0:
trigger = embedding_ckpt["name"]
elif token_counter == 0:
trigger = f'<{basename}>'
else:
trigger = f'<{basename}-{int(token_counter:=token_counter)}>'
token_counter += 1
embedding_info = EmbeddingInfo(
name = trigger,
embedding = embedding,
num_vectors_per_token = embedding.size()[0],
token_dim = embedding.size()[1],
trained_steps = embedding_ckpt["step"],
trained_model_name = embedding_ckpt["sd_checkpoint_name"],
trained_model_checksum = embedding_ckpt["sd_checkpoint"]
)
embeddings.append(embedding_info)
return embeddings
# Check num of embeddings and warn user only the first will be used
embedding_info["num_of_embeddings"] = len(
embedding_ckpt["string_to_token"]
)
if embedding_info["num_of_embeddings"] > 1:
print(">> More than 1 embedding found. Will use the first one")
embedding = list(embedding_ckpt["string_to_param"].values())[0]
except (AttributeError, KeyError):
return self._handle_broken_pt_variants(embedding_ckpt, embedding_file)
embedding_info["embedding"] = embedding
embedding_info["num_vectors_per_token"] = embedding.size()[0]
embedding_info["token_dim"] = embedding.size()[1]
try:
embedding_info["trained_steps"] = embedding_ckpt["step"]
embedding_info["trained_model_name"] = embedding_ckpt[
"sd_checkpoint_name"
]
embedding_info["trained_model_checksum"] = embedding_ckpt[
"sd_checkpoint"
]
except AttributeError:
print(">> No Training Details Found. Passing ...")
# .pt files found at https://cyberes.github.io/stable-diffusion-textual-inversion-models/
# They are actually .bin files
elif len(embedding_ckpt.keys()) == 1:
embedding_info = self._parse_embedding_bin(embedding_file)
else:
print(">> Invalid embedding format")
embedding_info = None
return embedding_info
def _parse_embedding_bin(self, embedding_file):
embedding_ckpt = torch.load(embedding_file, map_location="cpu")
embedding_info = {}
if list(embedding_ckpt.keys()) == 0:
print(">> Invalid concepts file")
embedding_info = None
else:
for token in list(embedding_ckpt.keys()):
embedding_info["name"] = (
token
or f"<{os.path.basename(os.path.splitext(embedding_file)[0])}>"
)
embedding_info["embedding"] = embedding_ckpt[token]
embedding_info[
"num_vectors_per_token"
] = 1 # All Concepts seem to default to 1
embedding_info["token_dim"] = embedding_info["embedding"].size()[0]
return embedding_info
def _handle_broken_pt_variants(
self, embedding_ckpt: dict, embedding_file: str
) -> dict:
def _parse_embedding_v2 (
self, embedding_ckpt: dict, file_path: str
) -> List[EmbeddingInfo]:
"""
This handles the broken .pt file variants. We only know of one at present.
This handles embedding .pt file variant #2.
"""
embedding_info = {}
basename = Path(file_path).stem
print(f' | Loading v2 embedding file: {basename}')
embeddings = list()
if isinstance(
list(embedding_ckpt["string_to_token"].values())[0], torch.Tensor
):
for token in list(embedding_ckpt["string_to_token"].keys()):
embedding_info["name"] = (
token
if token != "*"
else f"<{os.path.basename(os.path.splitext(embedding_file)[0])}>"
token_counter = 0
for token,embedding in embedding_ckpt["string_to_param"].items():
trigger = token if token != '*' \
else f'<{basename}>' if token_counter == 0 \
else f'<{basename}-{int(token_counter:=token_counter+1)}>'
embedding_info = EmbeddingInfo(
name = trigger,
embedding = embedding,
num_vectors_per_token = embedding.size()[0],
token_dim = embedding.size()[1],
)
embedding_info["embedding"] = embedding_ckpt[
"string_to_param"
].state_dict()[token]
embedding_info["num_vectors_per_token"] = embedding_info[
"embedding"
].shape[0]
embedding_info["token_dim"] = embedding_info["embedding"].size()[1]
embeddings.append(embedding_info)
else:
print(">> Invalid embedding format")
embedding_info = None
print(f" ** {basename}: Unrecognized embedding format")
return embedding_info
return embeddings
def _parse_embedding_v3(self, embedding_ckpt: dict, file_path: str)->List[EmbeddingInfo]:
"""
Parse 'version 3' of the .pt textual inversion embedding files.
"""
basename = Path(file_path).stem
print(f' | Loading v3 embedding file: {basename}')
embedding = embedding_ckpt['emb_params']
embedding_info = EmbeddingInfo(
name = f'<{basename}>',
embedding = embedding,
num_vectors_per_token = embedding.size()[0],
token_dim = embedding.size()[1],
)
return [embedding_info]
def _parse_embedding_v4(self, embedding_ckpt: dict, filepath: str)->List[EmbeddingInfo]:
"""
Parse 'version 4' of the textual inversion embedding files. This one
is usually associated with .bin files trained by HuggingFace diffusers.
"""
basename = Path(filepath).stem
short_path = Path(filepath).parents[0].name+'/'+Path(filepath).name
print(f' | Loading v4 embedding file: {short_path}')
embeddings = list()
if len(embedding_ckpt.keys()) == 0:
print(f" ** Invalid embeddings file: {short_path}")
else:
for token,embedding in embedding_ckpt.items():
embedding_info = EmbeddingInfo(
name = token or f"<{basename}>",
embedding = embedding,
num_vectors_per_token = 1, # All Concepts seem to default to 1
token_dim = embedding.size()[0],
)
embeddings.append(embedding_info)
return embeddings
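As a concrete, hypothetical example, a diffusers-trained learned_embeds.bin loads to a bare token-to-tensor mapping, which the v4 parser above turns into one EmbeddingInfo per token:

import torch

ckpt = {"<my-concept>": torch.zeros(768)}   # illustrative shape of a .bin embedding

for token, embedding in ckpt.items():
    info = EmbeddingInfo(
        name=token,
        embedding=embedding,
        num_vectors_per_token=1,            # diffusers concepts default to a single vector
        token_dim=embedding.size()[0],
    )
    print(info.name, info.token_dim)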

View File

@ -1022,7 +1022,7 @@ class InvokeAIWebServer:
"RGB"
)
def image_progress(sample, step):
def image_progress(intermediate_state: PipelineIntermediateState):
if self.canceled.is_set():
raise CanceledException
@ -1030,6 +1030,14 @@ class InvokeAIWebServer:
nonlocal generation_parameters
nonlocal progress
step = intermediate_state.step
if intermediate_state.predicted_original is not None:
# Some schedulers report not only the noisy latents at the current timestep,
# but also their estimate so far of what the de-noised latents will be.
sample = intermediate_state.predicted_original
else:
sample = intermediate_state.latents
generation_messages = {
"txt2img": "common.statusGeneratingTextToImage",
"img2img": "common.statusGeneratingImageToImage",
@ -1302,16 +1310,9 @@ class InvokeAIWebServer:
progress.set_current_iteration(progress.current_iteration + 1)
def diffusers_step_callback_adapter(*cb_args, **kwargs):
if isinstance(cb_args[0], PipelineIntermediateState):
progress_state: PipelineIntermediateState = cb_args[0]
return image_progress(progress_state.latents, progress_state.step)
else:
return image_progress(*cb_args, **kwargs)
self.generate.prompt2image(
**generation_parameters,
step_callback=diffusers_step_callback_adapter,
step_callback=image_progress,
image_callback=image_done,
)
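With the adapter removed, prompt2image's step_callback now receives a PipelineIntermediateState directly. A minimal sketch of a compatible callback, using only the fields referenced in the handler above:

def simple_progress(intermediate_state):
    # Prefer the scheduler's running estimate of the de-noised image when available.
    sample = (
        intermediate_state.predicted_original
        if intermediate_state.predicted_original is not None
        else intermediate_state.latents
    )
    print(f"step {intermediate_state.step}: sample shape {tuple(sample.shape)}")

# Hypothetical wiring, mirroring the call above:
# generate.prompt2image(**generation_parameters, step_callback=simple_progress, image_callback=image_done)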

View File

@ -626,7 +626,7 @@ def set_default_output_dir(opt: Args, completer: Completer):
completer.set_default_dir(opt.outdir)
def import_model(model_path: str, gen, opt, completer, convert=False):
def import_model(model_path: str, gen, opt, completer):
"""
model_path can be (1) a URL to a .ckpt file; (2) a local .ckpt file path;
(3) a huggingface repository id; or (4) a local directory containing a
@ -657,7 +657,6 @@ def import_model(model_path: str, gen, opt, completer, convert=False):
model_path,
model_name=model_name,
description=model_desc,
convert=convert,
)
if not imported_name:
@ -666,7 +665,6 @@ def import_model(model_path: str, gen, opt, completer, convert=False):
model_path,
model_name=model_name,
description=model_desc,
convert=convert,
model_config_file=config_file,
)
if not imported_name:
@ -757,7 +755,6 @@ def _get_model_name_and_desc(
)
return model_name, model_description
def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer):
model_name_or_path = model_name_or_path.replace("\\", "/") # windows
manager = gen.model_manager
@ -772,16 +769,10 @@ def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer):
original_config_file = Path(model_info["config"])
model_name = model_name_or_path
model_description = model_info["description"]
vae = model_info["vae"]
vae_path = model_info.get("vae")
else:
print(f"** {model_name_or_path} is not a legacy .ckpt weights file")
return
if vae_repo := invokeai.backend.model_management.model_manager.VAE_TO_REPO_ID.get(
Path(vae).stem
):
vae_repo = dict(repo_id=vae_repo)
else:
vae_repo = None
model_name = manager.convert_and_import(
ckpt_path,
diffusers_path=Path(
@ -790,11 +781,11 @@ def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer):
model_name=model_name,
model_description=model_description,
original_config_file=original_config_file,
vae=vae_repo,
vae_path=vae_path,
)
else:
try:
import_model(model_name_or_path, gen, opt, completer, convert=True)
import_model(model_name_or_path, gen, opt, completer)
except KeyboardInterrupt:
return

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -1,4 +1,4 @@
import{j as y,cN as Ie,r as _,cO as bt,q as Lr,cP as o,cQ as b,cR as v,cS as S,cT as Vr,cU as ut,cV as vt,cM as ft,cW as mt,n as gt,cX as ht,E as pt}from"./index-d64f4654.js";import{d as yt,i as St,T as xt,j as $t,h as kt}from"./storeHooks-0eed8e9f.js";var Or=`
import{j as y,cN as Ie,r as _,cO as bt,q as Lr,cP as o,cQ as b,cR as v,cS as S,cT as Vr,cU as ut,cV as vt,cM as ft,cW as mt,n as gt,cX as ht,E as pt}from"./index-f7f41e1f.js";import{d as yt,i as St,T as xt,j as $t,h as kt}from"./storeHooks-eaf47ae3.js";var Or=`
:root {
--chakra-vh: 100vh;
}

View File

@ -12,7 +12,7 @@
margin: 0;
}
</style>
<script type="module" crossorigin src="./assets/index-d64f4654.js"></script>
<script type="module" crossorigin src="./assets/index-f7f41e1f.js"></script>
<link rel="stylesheet" href="./assets/index-5483945c.css">
</head>

View File

@ -64,6 +64,8 @@
"trainingDesc2": "InvokeAI already supports training custom embeddings using Textual Inversion using the main script.",
"upload": "Upload",
"close": "Close",
"cancel": "Cancel",
"accept": "Accept",
"load": "Load",
"back": "Back",
"statusConnected": "Connected",
@ -333,6 +335,7 @@
"addNewModel": "Add New Model",
"addCheckpointModel": "Add Checkpoint / Safetensor Model",
"addDiffuserModel": "Add Diffusers",
"scanForModels": "Scan For Models",
"addManually": "Add Manually",
"manual": "Manual",
"name": "Name",

View File

@ -1,4 +1,6 @@
import React, { PropsWithChildren } from 'react';
import { IAIPopoverProps } from '../web/src/common/components/IAIPopover';
import { IAIIconButtonProps } from '../web/src/common/components/IAIIconButton';
export {};
@ -50,9 +52,27 @@ declare module '@invoke-ai/invoke-ai-ui' {
declare class InvokeAiLogoComponent extends React.Component<InvokeAILogoComponentProps> {
public constructor(props: InvokeAILogoComponentProps);
}
declare class IAIPopover extends React.Component<IAIPopoverProps> {
public constructor(props: IAIPopoverProps);
}
declare class IAIIconButton extends React.Component<IAIIconButtonProps> {
public constructor(props: IAIIconButtonProps);
}
declare class SettingsModal extends React.Component<SettingsModalProps> {
public constructor(props: SettingsModalProps);
}
}
declare function Invoke(props: PropsWithChildren): JSX.Element;
export { ThemeChanger, InvokeAiLogoComponent };
export {
ThemeChanger,
InvokeAiLogoComponent,
IAIPopover,
IAIIconButton,
SettingsModal,
};
export = Invoke;

View File

@ -6,7 +6,6 @@
"prepare": "cd ../../../ && husky install invokeai/frontend/web/.husky",
"dev": "concurrently \"vite dev\" \"yarn run theme:watch\"",
"build": "yarn run lint && vite build",
"build:package": "vite build --mode=package",
"preview": "vite preview",
"lint:madge": "madge --circular src/main.tsx",
"lint:eslint": "eslint --max-warnings=0 .",

View File

@ -8,7 +8,6 @@
"darkTheme": "داكن",
"lightTheme": "فاتح",
"greenTheme": "أخضر",
"text2img": "نص إلى صورة",
"img2img": "صورة إلى صورة",
"unifiedCanvas": "لوحة موحدة",
"nodes": "عقد",

View File

@ -7,7 +7,6 @@
"darkTheme": "Dunkel",
"lightTheme": "Hell",
"greenTheme": "Grün",
"text2img": "Text zu Bild",
"img2img": "Bild zu Bild",
"nodes": "Knoten",
"langGerman": "Deutsch",

View File

@ -64,6 +64,8 @@
"trainingDesc2": "InvokeAI already supports training custom embeddings using Textual Inversion using the main script.",
"upload": "Upload",
"close": "Close",
"cancel": "Cancel",
"accept": "Accept",
"load": "Load",
"back": "Back",
"statusConnected": "Connected",
@ -333,6 +335,7 @@
"addNewModel": "Add New Model",
"addCheckpointModel": "Add Checkpoint / Safetensor Model",
"addDiffuserModel": "Add Diffusers",
"scanForModels": "Scan For Models",
"addManually": "Add Manually",
"manual": "Manual",
"name": "Name",

View File

@ -8,7 +8,6 @@
"darkTheme": "Oscuro",
"lightTheme": "Claro",
"greenTheme": "Verde",
"text2img": "Texto a Imagen",
"img2img": "Imagen a Imagen",
"unifiedCanvas": "Lienzo Unificado",
"nodes": "Nodos",
@ -70,7 +69,11 @@
"langHebrew": "Hebreo",
"pinOptionsPanel": "Pin del panel de opciones",
"loading": "Cargando",
"loadingInvokeAI": "Cargando invocar a la IA"
"loadingInvokeAI": "Cargando invocar a la IA",
"postprocessing": "Tratamiento posterior",
"txt2img": "De texto a imagen",
"accept": "Aceptar",
"cancel": "Cancelar"
},
"gallery": {
"generations": "Generaciones",
@ -404,7 +407,8 @@
"none": "ninguno",
"pickModelType": "Elige el tipo de modelo",
"v2_768": "v2 (768px)",
"addDifference": "Añadir una diferencia"
"addDifference": "Añadir una diferencia",
"scanForModels": "Buscar modelos"
},
"parameters": {
"images": "Imágenes",
@ -574,7 +578,7 @@
"autoSaveToGallery": "Guardar automáticamente en galería",
"saveBoxRegionOnly": "Guardar solo región dentro de la caja",
"limitStrokesToBox": "Limitar trazos a la caja",
"showCanvasDebugInfo": "Mostrar información de depuración de lienzo",
"showCanvasDebugInfo": "Mostrar la información adicional del lienzo",
"clearCanvasHistory": "Limpiar historial de lienzo",
"clearHistory": "Limpiar historial",
"clearCanvasHistoryMessage": "Limpiar el historial de lienzo también restablece completamente el lienzo unificado. Esto incluye todo el historial de deshacer/rehacer, las imágenes en el área de preparación y la capa base del lienzo.",

View File

@ -8,7 +8,6 @@
"darkTheme": "Sombre",
"lightTheme": "Clair",
"greenTheme": "Vert",
"text2img": "Texte en image",
"img2img": "Image en image",
"unifiedCanvas": "Canvas unifié",
"nodes": "Nœuds",
@ -47,7 +46,19 @@
"statusLoadingModel": "Chargement du modèle",
"statusModelChanged": "Modèle changé",
"discordLabel": "Discord",
"githubLabel": "Github"
"githubLabel": "Github",
"accept": "Accepter",
"statusMergingModels": "Mélange des modèles",
"loadingInvokeAI": "Chargement de Invoke AI",
"cancel": "Annuler",
"langEnglish": "Anglais",
"statusConvertingModel": "Conversion du modèle",
"statusModelConverted": "Modèle converti",
"loading": "Chargement",
"pinOptionsPanel": "Épingler la page d'options",
"statusMergedModels": "Modèles mélangés",
"txt2img": "Texte vers image",
"postprocessing": "Post-Traitement"
},
"gallery": {
"generations": "Générations",
@ -518,5 +529,15 @@
"betaDarkenOutside": "Assombrir à l'extérieur",
"betaLimitToBox": "Limiter à la boîte",
"betaPreserveMasked": "Conserver masqué"
},
"accessibility": {
"uploadImage": "Charger une image",
"reset": "Réinitialiser",
"nextImage": "Image suivante",
"previousImage": "Image précédente",
"useThisParameter": "Utiliser ce paramètre",
"zoomIn": "Zoom avant",
"zoomOut": "Zoom arrière",
"showOptionsPanel": "Montrer la page d'options"
}
}

View File

@ -125,7 +125,6 @@
"langSimplifiedChinese": "סינית",
"langUkranian": "אוקראינית",
"langSpanish": "ספרדית",
"text2img": "טקסט לתמונה",
"img2img": "תמונה לתמונה",
"unifiedCanvas": "קנבס מאוחד",
"nodes": "צמתים",

View File

@ -8,7 +8,6 @@
"darkTheme": "Scuro",
"lightTheme": "Chiaro",
"greenTheme": "Verde",
"text2img": "Testo a Immagine",
"img2img": "Immagine a Immagine",
"unifiedCanvas": "Tela unificata",
"nodes": "Nodi",
@ -70,7 +69,11 @@
"loading": "Caricamento in corso",
"oceanTheme": "Oceano",
"langHebrew": "Ebraico",
"loadingInvokeAI": "Caricamento Invoke AI"
"loadingInvokeAI": "Caricamento Invoke AI",
"postprocessing": "Post Elaborazione",
"txt2img": "Testo a Immagine",
"accept": "Accetta",
"cancel": "Annulla"
},
"gallery": {
"generations": "Generazioni",
@ -404,7 +407,8 @@
"v2_768": "v2 (768px)",
"none": "niente",
"addDifference": "Aggiungi differenza",
"pickModelType": "Scegli il tipo di modello"
"pickModelType": "Scegli il tipo di modello",
"scanForModels": "Cerca modelli"
},
"parameters": {
"images": "Immagini",
@ -574,7 +578,7 @@
"autoSaveToGallery": "Salvataggio automatico nella Galleria",
"saveBoxRegionOnly": "Salva solo l'area di selezione",
"limitStrokesToBox": "Limita i tratti all'area di selezione",
"showCanvasDebugInfo": "Mostra informazioni di debug della Tela",
"showCanvasDebugInfo": "Mostra ulteriori informazioni sulla Tela",
"clearCanvasHistory": "Cancella cronologia Tela",
"clearHistory": "Cancella la cronologia",
"clearCanvasHistoryMessage": "La cancellazione della cronologia della tela lascia intatta la tela corrente, ma cancella in modo irreversibile la cronologia degli annullamenti e dei ripristini.",
@ -612,7 +616,7 @@
"copyMetadataJson": "Copia i metadati JSON",
"exitViewer": "Esci dal visualizzatore",
"zoomIn": "Zoom avanti",
"zoomOut": "Zoom Indietro",
"zoomOut": "Zoom indietro",
"rotateCounterClockwise": "Ruotare in senso antiorario",
"rotateClockwise": "Ruotare in senso orario",
"flipHorizontally": "Capovolgi orizzontalmente",

View File

@ -11,7 +11,6 @@
"langArabic": "العربية",
"langEnglish": "English",
"langDutch": "Nederlands",
"text2img": "텍스트->이미지",
"unifiedCanvas": "통합 캔버스",
"langFrench": "Français",
"langGerman": "Deutsch",

View File

@ -8,7 +8,6 @@
"darkTheme": "Donker",
"lightTheme": "Licht",
"greenTheme": "Groen",
"text2img": "Tekst naar afbeelding",
"img2img": "Afbeelding naar afbeelding",
"unifiedCanvas": "Centraal canvas",
"nodes": "Knooppunten",

View File

@ -8,7 +8,6 @@
"darkTheme": "Ciemny",
"lightTheme": "Jasny",
"greenTheme": "Zielony",
"text2img": "Tekst na obraz",
"img2img": "Obraz na obraz",
"unifiedCanvas": "Tryb uniwersalny",
"nodes": "Węzły",

View File

@ -20,7 +20,6 @@
"langSpanish": "Espanhol",
"langRussian": "Русский",
"langUkranian": "Украї́нська",
"text2img": "Texto para Imagem",
"img2img": "Imagem para Imagem",
"unifiedCanvas": "Tela Unificada",
"nodes": "Nós",

View File

@ -8,7 +8,6 @@
"darkTheme": "Noite",
"lightTheme": "Dia",
"greenTheme": "Verde",
"text2img": "Texto Para Imagem",
"img2img": "Imagem Para Imagem",
"unifiedCanvas": "Tela Unificada",
"nodes": "Nódulos",

View File

@ -8,7 +8,6 @@
"darkTheme": "Темная",
"lightTheme": "Светлая",
"greenTheme": "Зеленая",
"text2img": "Изображение из текста (text2img)",
"img2img": "Изображение в изображение (img2img)",
"unifiedCanvas": "Универсальный холст",
"nodes": "Ноды",

View File

@ -8,7 +8,6 @@
"darkTheme": "Темна",
"lightTheme": "Світла",
"greenTheme": "Зелена",
"text2img": "Зображення із тексту (text2img)",
"img2img": "Зображення із зображення (img2img)",
"unifiedCanvas": "Універсальне полотно",
"nodes": "Вузли",

View File

@ -8,7 +8,6 @@
"darkTheme": "暗色",
"lightTheme": "亮色",
"greenTheme": "绿色",
"text2img": "文字到图像",
"img2img": "图像到图像",
"unifiedCanvas": "统一画布",
"nodes": "节点",

View File

@ -33,7 +33,6 @@
"langBrPortuguese": "巴西葡萄牙語",
"langRussian": "俄語",
"langSpanish": "西班牙語",
"text2img": "文字到圖像",
"unifiedCanvas": "統一畫布"
}
}

View File

@ -31,18 +31,14 @@ export const DIFFUSERS_SAMPLERS: Array<string> = [
];
// Valid image widths
export const WIDTHS: Array<number> = [
64, 128, 192, 256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960,
1024, 1088, 1152, 1216, 1280, 1344, 1408, 1472, 1536, 1600, 1664, 1728, 1792,
1856, 1920, 1984, 2048,
];
export const WIDTHS: Array<number> = Array.from(Array(65)).map(
(_x, i) => i * 64
);
// Valid image heights
export const HEIGHTS: Array<number> = [
64, 128, 192, 256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960,
1024, 1088, 1152, 1216, 1280, 1344, 1408, 1472, 1536, 1600, 1664, 1728, 1792,
1856, 1920, 1984, 2048,
];
export const HEIGHTS: Array<number> = Array.from(Array(65)).map(
(_x, i) => i * 64
);
// Valid upscaling levels
export const UPSCALING_LEVELS: Array<{ key: string; value: number }> = [

View File

@ -9,6 +9,7 @@ import {
useDisclosure,
} from '@chakra-ui/react';
import { cloneElement, memo, ReactElement, ReactNode, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import IAIButton from './IAIButton';
type Props = {
@ -22,10 +23,12 @@ type Props = {
};
const IAIAlertDialog = forwardRef((props: Props, ref) => {
const { t } = useTranslation();
const {
acceptButtonText = 'Accept',
acceptButtonText = t('common.accept'),
acceptCallback,
cancelButtonText = 'Cancel',
cancelButtonText = t('common.cancel'),
cancelCallback,
children,
title,
@ -56,6 +59,7 @@ const IAIAlertDialog = forwardRef((props: Props, ref) => {
isOpen={isOpen}
leastDestructiveRef={cancelRef}
onClose={onClose}
isCentered
>
<AlertDialogOverlay>
<AlertDialogContent>

View File

@ -0,0 +1,8 @@
import { chakra } from '@chakra-ui/react';
/**
* Chakra-enabled <form />
*/
const IAIForm = chakra.form;
export default IAIForm;

View File

@ -0,0 +1,23 @@
import { Flex } from '@chakra-ui/react';
import { ReactElement } from 'react';
export function IAIFormItemWrapper({
children,
}: {
children: ReactElement | ReactElement[];
}) {
return (
<Flex
sx={{
flexDirection: 'column',
padding: 4,
rowGap: 4,
borderRadius: 'base',
width: 'full',
bg: 'base.900',
}}
>
{children}
</Flex>
);
}

View File

@ -8,7 +8,7 @@ import {
} from '@chakra-ui/react';
import { memo, ReactNode } from 'react';
type IAIPopoverProps = PopoverProps & {
export type IAIPopoverProps = PopoverProps & {
triggerComponent: ReactNode;
triggerContainerProps?: BoxProps;
children: ReactNode;

View File

@ -2,6 +2,15 @@ import Component from './component';
import InvokeAiLogoComponent from './features/system/components/InvokeAILogoComponent';
import ThemeChanger from './features/system/components/ThemeChanger';
import IAIPopover from './common/components/IAIPopover';
import IAIIconButton from './common/components/IAIIconButton';
import SettingsModal from './features/system/components/SettingsModal/SettingsModal';
export default Component;
export { InvokeAiLogoComponent, ThemeChanger };
export {
InvokeAiLogoComponent,
ThemeChanger,
IAIPopover,
IAIIconButton,
SettingsModal,
};

View File

@ -104,7 +104,6 @@ const IAICanvasMaskOptions = () => {
return (
<IAIPopover
trigger="hover"
triggerComponent={
<ButtonGroup>
<IAIIconButton

View File

@ -88,7 +88,7 @@ const IAICanvasSettingsButtonPopover = () => {
return (
<IAIPopover
trigger="hover"
isLazy={false}
triggerComponent={
<IAIIconButton
tooltip={t('unifiedCanvas.canvasSettings')}

View File

@ -219,7 +219,6 @@ const IAICanvasToolChooserOptions = () => {
onClick={handleSelectColorPickerTool}
/>
<IAIPopover
trigger="hover"
triggerComponent={
<IAIIconButton
aria-label={t('unifiedCanvas.brushOptions')}

View File

@ -405,7 +405,6 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
>
<ButtonGroup isAttached={true}>
<IAIPopover
trigger="hover"
triggerComponent={
<IAIIconButton
aria-label={`${t('parameters.sendTo')}...`}
@ -505,7 +504,6 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
<ButtonGroup isAttached={true}>
<IAIPopover
trigger="hover"
triggerComponent={
<IAIIconButton
icon={<FaGrinStars />}
@ -535,7 +533,6 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
</IAIPopover>
<IAIPopover
trigger="hover"
triggerComponent={
<IAIIconButton
icon={<FaExpandArrowsAlt />}

View File

@ -0,0 +1,24 @@
import { Flex, Spinner, SpinnerProps } from '@chakra-ui/react';
type CurrentImageFallbackProps = SpinnerProps;
const CurrentImageFallback = (props: CurrentImageFallbackProps) => {
const { size = 'xl', ...rest } = props;
return (
<Flex
sx={{
w: 'full',
h: 'full',
alignItems: 'center',
justifyContent: 'center',
position: 'absolute',
color: 'base.400',
}}
>
<Spinner size={size} {...rest} />
</Flex>
);
};
export default CurrentImageFallback;

View File

@ -7,6 +7,7 @@ import { isEqual } from 'lodash';
import { APP_METADATA_HEIGHT } from 'theme/util/constants';
import { gallerySelector } from '../store/gallerySelectors';
import CurrentImageFallback from './CurrentImageFallback';
import ImageMetadataViewer from './ImageMetaDataViewer/ImageMetadataViewer';
import NextPrevImageButtons from './NextPrevImageButtons';
@ -48,6 +49,7 @@ export default function CurrentImagePreview() {
src={imageToDisplay.url}
width={imageToDisplay.width}
height={imageToDisplay.height}
fallback={!isIntermediate ? <CurrentImageFallback /> : undefined}
sx={{
objectFit: 'contain',
maxWidth: '100%',

View File

@ -34,7 +34,7 @@ const ReactPanZoomButtons = ({
<IAIIconButton
icon={<BiZoomIn />}
aria-label={t('accessibility.zoomIn')}
tooltip="Zoom In"
tooltip={t('accessibility.zoomIn')}
onClick={() => zoomIn()}
fontSize={20}
/>
@ -42,7 +42,7 @@ const ReactPanZoomButtons = ({
<IAIIconButton
icon={<BiZoomOut />}
aria-label={t('accessibility.zoomOut')}
tooltip="Zoom Out"
tooltip={t('accessibility.zoomOut')}
onClick={() => zoomOut()}
fontSize={20}
/>
@ -50,7 +50,7 @@ const ReactPanZoomButtons = ({
<IAIIconButton
icon={<BiRotateLeft />}
aria-label={t('accessibility.rotateCounterClockwise')}
tooltip="Rotate Counter-Clockwise"
tooltip={t('accessibility.rotateCounterClockwise')}
onClick={rotateCounterClockwise}
fontSize={20}
/>
@ -58,7 +58,7 @@ const ReactPanZoomButtons = ({
<IAIIconButton
icon={<BiRotateRight />}
aria-label={t('accessibility.rotateClockwise')}
tooltip="Rotate Clockwise"
tooltip={t('accessibility.rotateClockwise')}
onClick={rotateClockwise}
fontSize={20}
/>
@ -66,7 +66,7 @@ const ReactPanZoomButtons = ({
<IAIIconButton
icon={<MdFlip />}
aria-label={t('accessibility.flipHorizontally')}
tooltip="Flip Horizontally"
tooltip={t('accessibility.flipHorizontally')}
onClick={flipHorizontally}
fontSize={20}
/>
@ -74,7 +74,7 @@ const ReactPanZoomButtons = ({
<IAIIconButton
icon={<MdFlip style={{ transform: 'rotate(90deg)' }} />}
aria-label={t('accessibility.flipVertically')}
tooltip="Flip Vertically"
tooltip={t('accessibility.flipVertically')}
onClick={flipVertically}
fontSize={20}
/>
@ -82,7 +82,7 @@ const ReactPanZoomButtons = ({
<IAIIconButton
icon={<BiReset />}
aria-label={t('accessibility.reset')}
tooltip="Reset"
tooltip={t('accessibility.reset')}
onClick={() => {
resetTransform();
reset();

View File

@ -55,7 +55,6 @@ export default function LanguagePicker() {
return (
<IAIPopover
trigger="hover"
triggerComponent={
<IAIIconButton
aria-label={t('common.languagePickerLabel')}

View File

@ -1,4 +1,5 @@
import {
Flex,
FormControl,
FormErrorMessage,
FormHelperText,
@ -25,10 +26,10 @@ import { useTranslation } from 'react-i18next';
import type { InvokeModelConfigProps } from 'app/invokeai';
import type { RootState } from 'app/store';
import IAIIconButton from 'common/components/IAIIconButton';
import { setAddNewModelUIOption } from 'features/ui/store/uiSlice';
import type { FieldInputProps, FormikProps } from 'formik';
import { BiArrowBack } from 'react-icons/bi';
import IAIForm from 'common/components/IAIForm';
import { IAIFormItemWrapper } from 'common/components/IAIForms/IAIFormItemWrapper';
const MIN_MODEL_SIZE = 64;
const MAX_MODEL_SIZE = 2048;
@ -72,243 +73,250 @@ export default function AddCheckpointModel() {
return (
<VStack gap={2} alignItems="flex-start">
<IAIIconButton
aria-label={t('common.back')}
tooltip={t('common.back')}
onClick={() => dispatch(setAddNewModelUIOption(null))}
width="max-content"
position="absolute"
zIndex={1}
size="sm"
insetInlineEnd={12}
top={3}
icon={<BiArrowBack />}
/>
<Flex columnGap={4}>
<IAICheckbox
isChecked={!addManually}
label={t('modelManager.scanForModels')}
onChange={() => setAddmanually(!addManually)}
/>
<IAICheckbox
label={t('modelManager.addManually')}
isChecked={addManually}
onChange={() => setAddmanually(!addManually)}
/>
</Flex>
<SearchModels />
<IAICheckbox
label={t('modelManager.addManually')}
isChecked={addManually}
onChange={() => setAddmanually(!addManually)}
/>
{addManually && (
{addManually ? (
<Formik
initialValues={addModelFormValues}
onSubmit={addModelFormSubmitHandler}
>
{({ handleSubmit, errors, touched }) => (
<form onSubmit={handleSubmit}>
<IAIForm onSubmit={handleSubmit} sx={{ w: 'full' }}>
<VStack rowGap={2}>
<Text fontSize={20} fontWeight="bold" alignSelf="start">
{t('modelManager.manual')}
</Text>
{/* Name */}
<FormControl
isInvalid={!!errors.name && touched.name}
isRequired
>
<FormLabel htmlFor="name" fontSize="sm">
{t('modelManager.name')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="name"
name="name"
type="text"
validate={baseValidation}
width="2xl"
/>
{!!errors.name && touched.name ? (
<FormErrorMessage>{errors.name}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.nameValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
<IAIFormItemWrapper>
<FormControl
isInvalid={!!errors.name && touched.name}
isRequired
>
<FormLabel htmlFor="name" fontSize="sm">
{t('modelManager.name')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="name"
name="name"
type="text"
validate={baseValidation}
width="full"
/>
{!!errors.name && touched.name ? (
<FormErrorMessage>{errors.name}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.nameValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
{/* Description */}
<FormControl
isInvalid={!!errors.description && touched.description}
isRequired
>
<FormLabel htmlFor="description" fontSize="sm">
{t('modelManager.description')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="description"
name="description"
type="text"
width="2xl"
/>
{!!errors.description && touched.description ? (
<FormErrorMessage>{errors.description}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.descriptionValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
<IAIFormItemWrapper>
<FormControl
isInvalid={!!errors.description && touched.description}
isRequired
>
<FormLabel htmlFor="description" fontSize="sm">
{t('modelManager.description')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="description"
name="description"
type="text"
width="full"
/>
{!!errors.description && touched.description ? (
<FormErrorMessage>
{errors.description}
</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.descriptionValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
{/* Config */}
<FormControl
isInvalid={!!errors.config && touched.config}
isRequired
>
<FormLabel htmlFor="config" fontSize="sm">
{t('modelManager.config')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="config"
name="config"
type="text"
width="2xl"
/>
{!!errors.config && touched.config ? (
<FormErrorMessage>{errors.config}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.configValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
<IAIFormItemWrapper>
<FormControl
isInvalid={!!errors.config && touched.config}
isRequired
>
<FormLabel htmlFor="config" fontSize="sm">
{t('modelManager.config')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="config"
name="config"
type="text"
width="full"
/>
{!!errors.config && touched.config ? (
<FormErrorMessage>{errors.config}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.configValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
{/* Weights */}
<FormControl
isInvalid={!!errors.weights && touched.weights}
isRequired
>
<FormLabel htmlFor="config" fontSize="sm">
{t('modelManager.modelLocation')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="weights"
name="weights"
type="text"
width="2xl"
/>
{!!errors.weights && touched.weights ? (
<FormErrorMessage>{errors.weights}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.modelLocationValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
<IAIFormItemWrapper>
<FormControl
isInvalid={!!errors.weights && touched.weights}
isRequired
>
<FormLabel htmlFor="config" fontSize="sm">
{t('modelManager.modelLocation')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="weights"
name="weights"
type="text"
width="full"
/>
{!!errors.weights && touched.weights ? (
<FormErrorMessage>{errors.weights}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.modelLocationValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
{/* VAE */}
<FormControl isInvalid={!!errors.vae && touched.vae}>
<FormLabel htmlFor="vae" fontSize="sm">
{t('modelManager.vaeLocation')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="vae"
name="vae"
type="text"
width="2xl"
/>
{!!errors.vae && touched.vae ? (
<FormErrorMessage>{errors.vae}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.vaeLocationValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
<IAIFormItemWrapper>
<FormControl isInvalid={!!errors.vae && touched.vae}>
<FormLabel htmlFor="vae" fontSize="sm">
{t('modelManager.vaeLocation')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="vae"
name="vae"
type="text"
width="full"
/>
{!!errors.vae && touched.vae ? (
<FormErrorMessage>{errors.vae}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.vaeLocationValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
<HStack width="100%">
{/* Width */}
<FormControl isInvalid={!!errors.width && touched.width}>
<FormLabel htmlFor="width" fontSize="sm">
{t('modelManager.width')}
</FormLabel>
<VStack alignItems="start">
<Field id="width" name="width">
{({
field,
form,
}: {
field: FieldInputProps<number>;
form: FormikProps<InvokeModelConfigProps>;
}) => (
<IAINumberInput
id="width"
name="width"
min={MIN_MODEL_SIZE}
max={MAX_MODEL_SIZE}
step={64}
width="90%"
value={form.values.width}
onChange={(value) =>
form.setFieldValue(field.name, Number(value))
}
/>
)}
</Field>
<IAIFormItemWrapper>
<FormControl isInvalid={!!errors.width && touched.width}>
<FormLabel htmlFor="width" fontSize="sm">
{t('modelManager.width')}
</FormLabel>
<VStack alignItems="start">
<Field id="width" name="width">
{({
field,
form,
}: {
field: FieldInputProps<number>;
form: FormikProps<InvokeModelConfigProps>;
}) => (
<IAINumberInput
id="width"
name="width"
min={MIN_MODEL_SIZE}
max={MAX_MODEL_SIZE}
step={64}
value={form.values.width}
onChange={(value) =>
form.setFieldValue(field.name, Number(value))
}
/>
)}
</Field>
{!!errors.width && touched.width ? (
<FormErrorMessage>{errors.width}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.widthValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
{!!errors.width && touched.width ? (
<FormErrorMessage>{errors.width}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.widthValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
{/* Height */}
<FormControl isInvalid={!!errors.height && touched.height}>
<FormLabel htmlFor="height" fontSize="sm">
{t('modelManager.height')}
</FormLabel>
<VStack alignItems="start">
<Field id="height" name="height">
{({
field,
form,
}: {
field: FieldInputProps<number>;
form: FormikProps<InvokeModelConfigProps>;
}) => (
<IAINumberInput
id="height"
name="height"
min={MIN_MODEL_SIZE}
max={MAX_MODEL_SIZE}
width="90%"
step={64}
value={form.values.height}
onChange={(value) =>
form.setFieldValue(field.name, Number(value))
}
/>
)}
</Field>
<IAIFormItemWrapper>
<FormControl isInvalid={!!errors.height && touched.height}>
<FormLabel htmlFor="height" fontSize="sm">
{t('modelManager.height')}
</FormLabel>
<VStack alignItems="start">
<Field id="height" name="height">
{({
field,
form,
}: {
field: FieldInputProps<number>;
form: FormikProps<InvokeModelConfigProps>;
}) => (
<IAINumberInput
id="height"
name="height"
min={MIN_MODEL_SIZE}
max={MAX_MODEL_SIZE}
step={64}
value={form.values.height}
onChange={(value) =>
form.setFieldValue(field.name, Number(value))
}
/>
)}
</Field>
{!!errors.height && touched.height ? (
<FormErrorMessage>{errors.height}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.heightValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
{!!errors.height && touched.height ? (
<FormErrorMessage>{errors.height}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.heightValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
</HStack>
<IAIButton
@ -319,9 +327,11 @@ export default function AddCheckpointModel() {
{t('modelManager.addModel')}
</IAIButton>
</VStack>
</form>
</IAIForm>
)}
</Formik>
) : (
<SearchModels />
)}
</VStack>
);

View File

@ -11,36 +11,14 @@ import { InvokeDiffusersModelConfigProps } from 'app/invokeai';
import { addNewModel } from 'app/socketio/actions';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIIconButton from 'common/components/IAIIconButton';
import IAIInput from 'common/components/IAIInput';
import { setAddNewModelUIOption } from 'features/ui/store/uiSlice';
import { Field, Formik } from 'formik';
import { useTranslation } from 'react-i18next';
import { BiArrowBack } from 'react-icons/bi';
import type { RootState } from 'app/store';
import type { ReactElement } from 'react';
function FormItemWrapper({
children,
}: {
children: ReactElement | ReactElement[];
}) {
return (
<Flex
sx={{
flexDirection: 'column',
padding: 4,
rowGap: 4,
borderRadius: 'base',
width: 'full',
bg: 'base.900',
}}
>
{children}
</Flex>
);
}
import IAIForm from 'common/components/IAIForm';
import { IAIFormItemWrapper } from 'common/components/IAIForms/IAIFormItemWrapper';
export default function AddDiffusersModel() {
const dispatch = useAppDispatch();
@ -89,26 +67,14 @@ export default function AddDiffusersModel() {
return (
<Flex>
<IAIIconButton
aria-label={t('common.back')}
tooltip={t('common.back')}
onClick={() => dispatch(setAddNewModelUIOption(null))}
width="max-content"
position="absolute"
zIndex={1}
size="sm"
insetInlineEnd={12}
top={3}
icon={<BiArrowBack />}
/>
<Formik
initialValues={addModelFormValues}
onSubmit={addModelFormSubmitHandler}
>
{({ handleSubmit, errors, touched }) => (
<form onSubmit={handleSubmit}>
<IAIForm onSubmit={handleSubmit}>
<VStack rowGap={2}>
<FormItemWrapper>
<IAIFormItemWrapper>
{/* Name */}
<FormControl
isInvalid={!!errors.name && touched.name}
@ -136,9 +102,9 @@ export default function AddDiffusersModel() {
)}
</VStack>
</FormControl>
</FormItemWrapper>
</IAIFormItemWrapper>
<FormItemWrapper>
<IAIFormItemWrapper>
{/* Description */}
<FormControl
isInvalid={!!errors.description && touched.description}
@ -165,9 +131,9 @@ export default function AddDiffusersModel() {
)}
</VStack>
</FormControl>
</FormItemWrapper>
</IAIFormItemWrapper>
<FormItemWrapper>
<IAIFormItemWrapper>
<Text fontWeight="bold" fontSize="sm">
{t('modelManager.formMessageDiffusersModelLocation')}
</Text>
@ -226,9 +192,9 @@ export default function AddDiffusersModel() {
)}
</VStack>
</FormControl>
</FormItemWrapper>
</IAIFormItemWrapper>
<FormItemWrapper>
<IAIFormItemWrapper>
{/* VAE Path */}
<Text fontWeight="bold">
{t('modelManager.formMessageDiffusersVAELocation')}
@ -290,13 +256,13 @@ export default function AddDiffusersModel() {
)}
</VStack>
</FormControl>
</FormItemWrapper>
</IAIFormItemWrapper>
<IAIButton type="submit" isLoading={isProcessing}>
{t('modelManager.addModel')}
</IAIButton>
</VStack>
</form>
</IAIForm>
)}
</Formik>
</Flex>

View File

@ -14,7 +14,7 @@ import {
import IAIButton from 'common/components/IAIButton';
import { FaPlus } from 'react-icons/fa';
import { FaArrowLeft, FaPlus } from 'react-icons/fa';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { useTranslation } from 'react-i18next';
@ -23,6 +23,7 @@ import type { RootState } from 'app/store';
import { setAddNewModelUIOption } from 'features/ui/store/uiSlice';
import AddCheckpointModel from './AddCheckpointModel';
import AddDiffusersModel from './AddDiffusersModel';
import IAIIconButton from 'common/components/IAIIconButton';
function AddModelBox({
text,
@ -83,8 +84,22 @@ export default function AddModel() {
closeOnOverlayClick={false}
>
<ModalOverlay />
<ModalContent margin="auto" paddingInlineEnd={4}>
<ModalHeader>{t('modelManager.addNewModel')}</ModalHeader>
<ModalContent margin="auto">
<ModalHeader>{t('modelManager.addNewModel')} </ModalHeader>
{addNewModelUIOption !== null && (
<IAIIconButton
aria-label={t('common.back')}
tooltip={t('common.back')}
onClick={() => dispatch(setAddNewModelUIOption(null))}
position="absolute"
variant="ghost"
zIndex={1}
size="sm"
insetInlineEnd={12}
top={2}
icon={<FaArrowLeft />}
/>
)}
<ModalCloseButton />
<ModalBody>
{addNewModelUIOption == null && (

View File

@ -28,6 +28,7 @@ import { isEqual, pickBy } from 'lodash';
import ModelConvert from './ModelConvert';
import IAIFormHelperText from 'common/components/IAIForms/IAIFormHelperText';
import IAIFormErrorMessage from 'common/components/IAIForms/IAIFormErrorMessage';
import IAIForm from 'common/components/IAIForm';
const selector = createSelector(
[systemSelector],
@ -120,7 +121,7 @@ export default function CheckpointModelEdit() {
onSubmit={editModelFormSubmitHandler}
>
{({ handleSubmit, errors, touched }) => (
<form onSubmit={handleSubmit}>
<IAIForm onSubmit={handleSubmit}>
<VStack rowGap={2} alignItems="start">
{/* Description */}
<FormControl
@ -317,7 +318,7 @@ export default function CheckpointModelEdit() {
{t('modelManager.updateModel')}
</IAIButton>
</VStack>
</form>
</IAIForm>
)}
</Formik>
</Flex>

View File

@ -18,6 +18,7 @@ import type { RootState } from 'app/store';
import { isEqual, pickBy } from 'lodash';
import IAIFormHelperText from 'common/components/IAIForms/IAIFormHelperText';
import IAIFormErrorMessage from 'common/components/IAIForms/IAIFormErrorMessage';
import IAIForm from 'common/components/IAIForm';
const selector = createSelector(
[systemSelector],
@ -116,7 +117,7 @@ export default function DiffusersModelEdit() {
onSubmit={editModelFormSubmitHandler}
>
{({ handleSubmit, errors, touched }) => (
<form onSubmit={handleSubmit}>
<IAIForm onSubmit={handleSubmit}>
<VStack rowGap={2} alignItems="start">
{/* Description */}
<FormControl
@ -259,7 +260,7 @@ export default function DiffusersModelEdit() {
{t('modelManager.updateModel')}
</IAIButton>
</VStack>
</form>
</IAIForm>
)}
</Formik>
</Flex>

View File

@ -12,14 +12,13 @@ import {
RadioGroup,
Spacer,
Text,
VStack,
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { systemSelector } from 'features/system/store/systemSelectors';
import { useTranslation } from 'react-i18next';
import { FaPlus, FaSearch } from 'react-icons/fa';
import { FaSearch, FaTrash } from 'react-icons/fa';
import { addNewModel, searchForModels } from 'app/socketio/actions';
import {
@ -34,7 +33,7 @@ import IAIInput from 'common/components/IAIInput';
import { Field, Formik } from 'formik';
import { forEach, remove } from 'lodash';
import type { ChangeEvent, ReactNode } from 'react';
import { BiReset } from 'react-icons/bi';
import IAIForm from 'common/components/IAIForm';
const existingModelsSelector = createSelector([systemSelector], (system) => {
const { model_list } = system;
@@ -71,34 +70,32 @@ function SearchModelEntry({
};
return (
<VStack>
<Flex
flexDirection="column"
gap={2}
backgroundColor={
modelsToAdd.includes(model.name) ? 'accent.650' : 'base.800'
}
paddingX={4}
paddingY={2}
borderRadius={4}
>
<Flex gap={4}>
<IAICheckbox
value={model.name}
label={<Text fontWeight={500}>{model.name}</Text>}
isChecked={modelsToAdd.includes(model.name)}
isDisabled={existingModels.includes(model.location)}
onChange={foundModelsChangeHandler}
></IAICheckbox>
{existingModels.includes(model.location) && (
<Badge colorScheme="accent">{t('modelManager.modelExists')}</Badge>
)}
</Flex>
<Text fontStyle="italic" variant="subtext">
{model.location}
</Text>
<Flex
flexDirection="column"
gap={2}
backgroundColor={
modelsToAdd.includes(model.name) ? 'accent.650' : 'base.800'
}
paddingX={4}
paddingY={2}
borderRadius={4}
>
<Flex gap={4} alignItems="center" justifyContent="space-between">
<IAICheckbox
value={model.name}
label={<Text fontWeight={500}>{model.name}</Text>}
isChecked={modelsToAdd.includes(model.name)}
isDisabled={existingModels.includes(model.location)}
onChange={foundModelsChangeHandler}
></IAICheckbox>
{existingModels.includes(model.location) && (
<Badge colorScheme="accent">{t('modelManager.modelExists')}</Badge>
)}
</Flex>
</VStack>
<Text fontStyle="italic" variant="subtext">
{model.location}
</Text>
</Flex>
);
}
@@ -215,10 +212,10 @@ export default function SearchModels() {
}
return (
<>
<Flex flexDirection="column" rowGap={4}>
{newFoundModels}
{shouldShowExistingModelsInSearch && existingFoundModels}
</>
</Flex>
);
};
@@ -245,26 +242,26 @@ export default function SearchModels() {
<Text
sx={{
fontWeight: 500,
fontSize: 'sm',
}}
variant="subtext"
>
{t('modelManager.checkpointFolder')}
</Text>
<Text sx={{ fontWeight: 500, fontSize: 'sm' }}>{searchFolder}</Text>
<Text sx={{ fontWeight: 500 }}>{searchFolder}</Text>
</Flex>
<Spacer />
<IAIIconButton
aria-label={t('modelManager.scanAgain')}
tooltip={t('modelManager.scanAgain')}
icon={<BiReset />}
icon={<FaSearch />}
fontSize={18}
disabled={isProcessing}
onClick={() => dispatch(searchForModels(searchFolder))}
/>
<IAIIconButton
aria-label={t('modelManager.clearCheckpointFolder')}
icon={<FaPlus style={{ transform: 'rotate(45deg)' }} />}
tooltip={t('modelManager.clearCheckpointFolder')}
icon={<FaTrash />}
onClick={resetSearchModelHandler}
/>
</Flex>
@@ -276,9 +273,9 @@ export default function SearchModels() {
}}
>
{({ handleSubmit }) => (
<form onSubmit={handleSubmit}>
<HStack columnGap={2} alignItems="flex-end" width="100%">
<FormControl isRequired width="lg">
<IAIForm onSubmit={handleSubmit} width="100%">
<HStack columnGap={2} alignItems="flex-end">
<FormControl flexGrow={1}>
<Field
as={IAIInput}
id="checkpointFolder"
@@ -294,12 +291,12 @@ export default function SearchModels() {
tooltip={t('modelManager.findModels')}
type="submit"
disabled={isProcessing}
paddingX={10}
px={8}
>
{t('modelManager.findModels')}
</IAIButton>
</HStack>
</form>
</IAIForm>
)}
</Formik>
)}
@@ -410,7 +407,6 @@ export default function SearchModels() {
maxHeight={72}
overflowY="scroll"
borderRadius="sm"
paddingInlineEnd={4}
gap={2}
>
{foundModels.length > 0 ? (


@@ -50,7 +50,6 @@ export default function ThemeChanger() {
return (
<IAIPopover
trigger="hover"
triggerComponent={
<IAIIconButton
aria-label={t('common.themeLabel')}


@@ -166,20 +166,8 @@ export default function InvokeTabs() {
[]
);
/**
* isLazy means the tabs are mounted and unmounted when changing them. There is a tradeoff here,
* as mounting is expensive, but so is retaining all tabs in the DOM at all times.
*
* Removing isLazy messes with the outside click watcher, which is used by ResizableDrawer.
* Because you have multiple handlers listening for an outside click, any click anywhere triggers
* the watcher for the hidden drawers, closing the open drawer.
*
* TODO: Add logic to the `useOutsideClick` in ResizableDrawer to enable it only for the active
* tab's drawer.
*/
return (
<Tabs
isLazy
defaultIndex={activeTab}
index={activeTab}
onChange={(index: number) => {

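The block comment removed in the hunk above explained the isLazy tradeoff and left a TODO: make ResizableDrawer's useOutsideClick fire only for the active tab's drawer. Purely as an illustrative sketch of that TODO (isActiveTabDrawer and onClose are placeholder names, not code from this PR), Chakra's useOutsideClick takes an enabled flag that could be keyed off the active tab:

// Illustrative sketch only; placeholder names, not the project's ResizableDrawer code.
import { useRef } from 'react';
import { useOutsideClick } from '@chakra-ui/react';

export function useDrawerOutsideClick(isActiveTabDrawer: boolean, onClose: () => void) {
  const drawerRef = useRef<HTMLDivElement>(null);

  useOutsideClick({
    ref: drawerRef,
    // Only the drawer that belongs to the active tab listens for outside
    // clicks, so hidden drawers no longer react and close the open one.
    enabled: isActiveTabDrawer,
    handler: onClose,
  });

  return drawerRef;
}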
Some files were not shown because too many files have changed in this diff.