Compare commits

...

108 Commits

Author SHA1 Message Date
7d46e8430b Run tests on this branch 2024-03-19 01:05:27 -04:00
b3f8f22998 more debug statements 2024-03-19 00:20:32 -04:00
6a75c5ba08 Merge branch 'allow-model-type-passthrough-on-probe' of github.com:invoke-ai/InvokeAI into allow-model-type-passthrough-on-probe 2024-03-19 00:01:21 -04:00
dcbb1ff894 add debugging statements to catch hang 2024-03-19 00:01:09 -04:00
5751455618 Remove redundant embedding file read 2024-03-18 23:58:33 -04:00
a2232c2e09 Disallow adding new installs to the queue during stop events 2024-03-18 15:37:21 -04:00
30da11998b Use wait_for_job instead of wait_for_installs 2024-03-18 12:23:46 -04:00
1f000306f3 Wrap in try except for InvalidModelConfigException 2024-03-18 12:10:06 -04:00
c778d74a42 Skip hashing in test_heuristic_import_with_type 2024-03-18 11:53:49 -04:00
97bcd54408 Increase timeout for test_heuristic_import_with_type, fix Url import 2024-03-18 11:50:32 -04:00
62d7e38030 Run ruff 2024-03-18 10:11:56 -04:00
9b6f3bded9 Fix test to run on windows vms 2024-03-18 10:10:45 -04:00
aec567179d Pull format type setting out of model_type if statement 2024-03-17 22:44:26 -04:00
28bfc1c935 Simplify logic for determining model type in probe 2024-03-17 22:44:26 -04:00
39f62ac63c Run Ruff 2024-03-17 22:44:26 -04:00
9d0952c2ef Add unit test 2024-03-17 22:44:26 -04:00
902e26507d Allow type field to be a string 2024-03-17 22:44:26 -04:00
b83427d7ce Allow users to specify model type and skip detection step of probe 2024-03-17 22:44:26 -04:00
7387b0bdc9 install missing clip_vision encoders if required by an ip adapter (#5982)
Co-authored-by: Lincoln Stein <lstein@gmail.com>
2024-03-18 02:19:53 +00:00
7ea9cac9a3 Add sdxl controlnet models (#5980)
* allow removal of models with legacy relative path addressing

* added five controlnet models for sdxl to INITIAL_MODELS

---------

Co-authored-by: Lincoln Stein <lstein@gmail.com>
2024-03-18 01:15:15 +00:00
ea5bc94b9c Resolve when instantiating _cached_model_paths 2024-03-18 11:17:23 +11:00
a1743647b7 Stop registering and moving models which have symlinks in the models dir 2024-03-18 11:17:23 +11:00
a6d64f69e1 Remove core conversion models (#5981)
## What type of PR is this? (check all applicable)

- [ ] Refactor
- [ ] Feature
- [ ] Bug Fix
- [X] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission


## Have you discussed this change with the InvokeAI team?
- [X] Yes
- [ ] No, because:

      
## Have you updated all relevant documentation?
- [X] Yes
- [ ] No


## Description

We've been using a forked copy of the diffusers safetensors->diffusers
model conversion code, which was hacked to read CLIP and the other
models needed for conversion from the local invokeai root models
directory. This was getting unsustainable as the code bases diverged,
and also required the installation and maintenance of the "core/convert"
directory.

This PR gets rid of the hacked conversion code and reverts to using the
native diffusers methods. Core convert models are no longer installed at
root configure time. Instead we rely on the HuggingFace hub system to
download the conversion models if and when they are needed. They are
relatively small and the initial delay seems minor.

Conversion of SD-1, SD-2 (both epsilon and v-prediction), SDXL, VAE and
ControlNet SD-1/2 models has been tested. ControlNet SDXL models are
still a WIP due to the need for some work on the prober.

The main implication of this change is that InvokeAI is no longer
internet-independent and will need an internet connection at least the
first time a safetensors file needs to be converted. However, there are
several other places where the "no internet" rule is already violated,
and I suggest that we abandon this principle.
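
As an illustration of the native diffusers path this PR switches to, single-file
checkpoints can be loaded with `from_single_file`, which pulls the ancillary
models (CLIP, tokenizer, etc.) from the HuggingFace Hub on first use and caches
them. A minimal sketch, not this PR's actual code; the checkpoint path is
hypothetical:

```python
import torch
from diffusers import StableDiffusionPipeline

# On first use, the Hub pieces needed for conversion are downloaded and cached.
pipe = StableDiffusionPipeline.from_single_file(
    "/path/to/model.safetensors",  # hypothetical local checkpoint
    torch_dtype=torch.float16,
)
```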

## Related Tickets & Documents

<!--
For pull requests that relate or close an issue, please include them
below. 

For example having the text: "closes #1234" would connect the current
pull
request to issue 1234.  And when we merge the pull request, Github will
automatically close the issue.
-->

- Related Issue #
- Closes #5964 

## QA Instructions, Screenshots, Recordings

1. Remove or move `$INVOKEAI_ROOT/models/.cache`
2. Move `$INVOKEAI_ROOT/models/core/convert`
3. Try generating with an unconverted .safetensors model.

<!-- 
Please provide steps on how to test changes, any hardware or 
software specifications as well as any other pertinent information. 
-->

## Merge Plan

Merge when approved.

<!--
A merge plan describes how this PR should be handled after it is
approved.

Example merge plans:
- "This PR can be merged when approved"
- "This must be squash-merged when approved"
- "DO NOT MERGE - I will rebase and tidy commits before merging"
- "#dev-chat on discord needs to be advised of this change when it is
merged"

A merge plan is particularly important for large PRs or PRs that touch
the
database in any way.
-->

## Added/updated tests?

- [ ] Yes
- [ ] No : _please replace this line with details on why tests
      have not been included_

## [optional] Are there any post deployment tasks we need to perform?
2024-03-18 11:11:15 +11:00
e74e78894f Merge branch 'main' into optimization/remove-core-conversion-models 2024-03-17 19:25:44 -04:00
71a1740740 Remove core safetensors->diffusers conversion models
- No longer install core conversion models. Use the HuggingFace cache to load
  them if and when needed.

- Call directly into the diffusers library to perform conversions with only shallow
   wrappers around them to massage arguments, etc.

- At root configuration time, do not create all the possible model subdirectories,
  but let them be created and populated at model install time.

- Remove checks for missing core conversion files, since they are no
  longer installed.
2024-03-17 19:13:18 -04:00
b79f2f337e Allow removal of models with legacy relative path addressing (#5979)
* allow removal of models with legacy relative path addressing

* fix ruff error

---------

Co-authored-by: Lincoln Stein <lstein@gmail.com>
2024-03-17 18:07:35 +00:00
a0420d1442 fix ruff error 2024-03-17 14:01:04 -04:00
a17021ba0c allow removal of models with legacy relative path addressing 2024-03-17 09:58:16 -04:00
faa1ffb06f Update diffusers 0.26.3 -> 0.27.0 and other HF packages 2024-03-16 05:44:58 -07:00
8c04eec210 fix initial main model logic 2024-03-15 10:22:16 -04:00
330e1354b4 Run typegen, update version to 4.0.0rc2 2024-03-14 17:01:36 -04:00
21621eebf0 feat(ui): handle control adapter processed images
- Add helper functions to build metadata for control adapters, including the processed images
- Update parsers to parse the new metadata
2024-03-14 12:34:03 -07:00
c24f2046e7 chore(ui): typegen 2024-03-14 12:34:03 -07:00
297408d67e feat(nodes): add control adapter processed images to metadata
In the client, a controlnet or t2i adapter has two images:
- The source control image: the image the user selected (required)
- The processed control image: the user's image after we've processed it (optional)

The processed image is optional because a user may provide a pre-processed image.

We only actually use one of these images when building the graph, and until this change, we only stored one of them in image metadata. This created a situation where only a processed image was stored in metadata - say, a canny edge map - and the user-selected image wasn't provided.

By adding the processed image to metadata, we can recall both the control image and optional processed image.

This commit is followed by UI-facing changes to support the change.
2024-03-14 12:34:03 -07:00
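
A minimal sketch of the resulting metadata shape, mirroring the fields visible in the diff further down; `ImageField` is simplified here so the example stands alone:

```python
from typing import Optional

from pydantic import BaseModel, Field

class ImageField(BaseModel):
    """Simplified stand-in for the app's ImageField."""
    image_name: str

class ControlNetMetadataField(BaseModel):
    # The user-selected source image (always present).
    image: ImageField = Field(description="The control image")
    # Optional: absent when the user supplied a pre-processed image.
    processed_image: Optional[ImageField] = Field(
        default=None, description="The control image, after processing."
    )
```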
0131e7d928 fix(ui): recall control adapter metadata fields 2024-03-14 12:34:03 -07:00
06ff105a1f fix(ui): reset loras/control adapters when using recall all or remix 2024-03-14 12:34:03 -07:00
bb8e6bbee6 docs: update CONFIGURATION.md
Update model hashing docs
2024-03-15 00:14:48 +11:00
328dc99f3a fix(ui): log model load events
- Fix types
- Fix logging in listener
2024-03-14 18:29:55 +05:30
ef55077e84 feat(events): add submodel_type to model load events
This was lost during MM2 migration
2024-03-14 18:29:55 +05:30
ba3d8af161 fix(events): dump event payloads to serializable format 2024-03-14 18:29:55 +05:30
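
For context on "serializable format": pydantic v2's `model_dump(mode="json")` coerces values such as paths, datetimes, and enums into JSON-safe primitives, whereas the default Python mode leaves them as rich objects. A small standalone sketch:

```python
from datetime import datetime
from pathlib import Path

from pydantic import BaseModel

class Event(BaseModel):
    path: Path
    when: datetime

event = Event(path=Path("/models/foo.safetensors"), when=datetime(2024, 3, 14))
event.model_dump()             # {'path': PosixPath('/models/foo.safetensors'), ...} - not JSON-safe
event.model_dump(mode="json")  # {'path': '/models/foo.safetensors', 'when': '2024-03-14T00:00:00'}
```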
b07b7af710 feat(ui): single getModelConfigs query (#5962)
## What type of PR is this? (check all applicable)

- [ ] Refactor
- [x] Feature
- [ ] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission

## Description

Single query, with simple wrapper hooks (type-safe). Updated everywhere
in frontend.

## QA Instructions, Screenshots, Recordings

Things that use models should work. All of this code is strictly
typechecked, so we can be confident in this change.

<!-- 
Please provide steps on how to test changes, any hardware or 
software specifications as well as any other pertinent information. 
-->

## Merge Plan

This PR can be merged when approved

<!--
A merge plan describes how this PR should be handled after it is
approved.

Example merge plans:
- "This PR can be merged when approved"
- "This must be squash-merged when approved"
- "DO NOT MERGE - I will rebase and tidy commits before merging"
- "#dev-chat on discord needs to be advised of this change when it is
merged"

A merge plan is particularly important for large PRs or PRs that touch
the
database in any way.
-->
2024-03-14 18:20:38 +05:30
19d66d5ec7 feat(ui): single getModelConfigs query
Single query, with simple wrapper hooks (type-safe). Updated everywhere in frontend.
2024-03-14 23:37:40 +11:00
ed20255abf fix(nodes): depth anything processor (#5956) (#5961)
We were passing a PIL image when we needed to pass the np image.

Closes #5956

## What type of PR is this? (check all applicable)

- [ ] Refactor
- [ ] Feature
- [x] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission


## Description

We were passing a PIL image when we needed to pass the np image.

Closes #5956
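
The fix amounts to converting the PIL image to a numpy array before handing it to the detector. A minimal illustration (the filename is hypothetical):

```python
import numpy as np
from PIL import Image

pil_image = Image.open("control.png").convert("RGB")  # hypothetical input
np_image = np.array(pil_image)  # HxWxC uint8 array - the form the processor expects
```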

## Related Tickets & Documents

<!--
For pull requests that relate or close an issue, please include them
below. 

For example having the text: "closes #1234" would connect the current
pull
request to issue 1234.  And when we merge the pull request, Github will
automatically close the issue.
-->

- Related Issue #
- Closes #5956

## QA Instructions, Screenshots, Recordings

Depth anything processor should work.

<!-- 
Please provide steps on how to test changes, any hardware or 
software specifications as well as any other pertinent information. 
-->

## Merge Plan

This PR can be merged when approved

<!--
A merge plan describes how this PR should be handled after it is
approved.

Example merge plans:
- "This PR can be merged when approved"
- "This must be squash-merged when approved"
- "DO NOT MERGE - I will rebase and tidy commits before merging"
- "#dev-chat on discord needs to be advised of this change when it is
merged"

A merge plan is particularly important for large PRs or PRs that touch
the
database in any way.
-->
2024-03-14 14:57:40 +05:30
fed1f983db fix(nodes): depth anything processor (#5956)
We were passing a PIL image when we needed to pass the np image.

Closes #5956
2024-03-14 20:14:53 +11:00
a386544a1d chore: ruff 2024-03-14 17:32:02 +11:00
0851de9090 closes #5932 2024-03-14 17:32:02 +11:00
1bd8e33f8c Work around missing core conversion model issue
- This adds additional logic to the safetensors->diffusers conversion script
  to check for and install missing core conversion models at runtime.

- Fixes #5934
2024-03-14 16:52:01 +11:00
e3f29ed320 tests: update default settings tests 2024-03-14 16:03:37 +11:00
3fd824306c feat(mm): probe for main model default settings
Currently, this is just the width and height, derived from the model base.
2024-03-14 16:03:37 +11:00
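
The derivation is essentially a lookup from model base to its native resolution. The actual table lives in the probe; the values below are the well-known native sizes, shown for illustration only:

```python
# Illustrative only - not copied from the model probe.
DEFAULT_SIZE_BY_BASE: dict[str, tuple[int, int]] = {
    "sd-1": (512, 512),    # Stable Diffusion 1.x native resolution
    "sd-2": (768, 768),    # SD 2.x 768-v models (base variants are 512)
    "sdxl": (1024, 1024),  # SDXL native resolution
}

def default_settings(base: str) -> dict[str, int]:
    width, height = DEFAULT_SIZE_BY_BASE[base]
    return {"width": width, "height": height}
```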
2584a950aa feat(ui): add w/h to default model settings 2024-03-14 16:03:37 +11:00
1adaf63253 chore(ui): typegen 2024-03-14 16:03:37 +11:00
b9f1a4bd65 feat(nodes): add w/h defaults for models 2024-03-14 16:03:37 +11:00
731942dbed feat(nodes): add constraints & descriptions to default settings 2024-03-14 16:03:37 +11:00
4117cea5bf tidy(mm): remove misplaced comment 2024-03-14 15:54:42 +11:00
21617f3bc1 docs: update description for hashing_algorithm in config 2024-03-14 15:54:42 +11:00
9fcd67b5c0 feat(mm): add algorithm prefix to hashes
For example:
- md5:a0cd925fc063f98dbf029eee315060c3
- sha1:9e362940e5603fdc60566ea100a288ba2fe48b8c
- blake3:ce3f0c5f3c05d119f4a5dcaf209b50d3149046a0d3a9adee9fed4c83cad6b4d0
2024-03-14 15:54:42 +11:00
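
The prefix makes each hash self-describing, so values produced by different algorithms can coexist in the same database column. A sketch of the format using `hashlib` (illustrative, not the app's `ModelHash` implementation; note that BLAKE3 is not in `hashlib` and needs the third-party `blake3` package):

```python
import hashlib

def prefixed_hash(data: bytes, algorithm: str = "sha256") -> str:
    """Return '<algorithm>:<hexdigest>', matching the format above."""
    return f"{algorithm}:{hashlib.new(algorithm, data).hexdigest()}"

print(prefixed_hash(b"model bytes", "md5"))  # md5:...
```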
a4be935458 docs: update config docs 2024-03-14 15:54:42 +11:00
eb6e6548ed feat(mm): faster hashing for spinning disk HDDs
BLAKE3 has poor performance on spinning disks when parallelized. See https://github.com/BLAKE3-team/BLAKE3/issues/31

- Replace `skip_model_hash` setting with `hashing_algorithm`. Any algorithm we support is accepted.
- Add `random` algorithm: hashes a UUID with BLAKE3 to create a random "hash". Equivalent to the previous skip functionality.
- Add `blake3_single` algorithm: hashes on a single thread using BLAKE3, fixes the aforementioned performance issue
- Update model probe to accept the algorithm to hash with as an optional arg, defaulting to `blake3`
- Update all calls of the probe to use the app's configured hashing algorithm
- Update an external script that probes models
- Update tests
- Move ModelHash into its own module to avoid circular import issues
2024-03-14 15:54:42 +11:00
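
A rough sketch of the two BLAKE3 modes using the `blake3` PyPI package; the `max_threads` constructor argument and `AUTO` attribute are per that package's docs, but treat the details as hedged:

```python
import blake3

def hash_file(path: str, single_thread: bool = False) -> str:
    # AUTO enables multithreading (fast on SSDs); max_threads=1 forces the
    # single-threaded mode, which is much faster on spinning disks.
    threads = 1 if single_thread else blake3.blake3.AUTO
    hasher = blake3.blake3(max_threads=threads)
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(2**20), b""):
            hasher.update(chunk)
    return f"blake3:{hasher.hexdigest()}"
```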
8287fcf097 feat: ✏️ rename "Workflow Editor" tab label to "Workflows" 2024-03-14 12:22:23 +11:00
dd475e28ed chore(ui): remove unused translation keys via script 2024-03-14 11:38:29 +11:00
24e741e2d1 feat(ui): add script to clean translations
This script removes unused translations from the `en.json` source translation file:
- Parse `en.json` to build a list of all keys, e.g. `controlnet.depthAnything`
- Check every frontend file for every key
- If the key is not found, it is removed from the translation file
- A key counts as used (and is kept) if there is an exact match (e.g. `controlnet.depthAnything`) or a stem match (e.g. `depthAnything`) anywhere in the frontend code, as sketched below
2024-03-14 11:38:29 +11:00
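
A rough Python sketch of the algorithm described above; the real script lives in the frontend package and may differ in detail, and the paths here are hypothetical:

```python
import json
from pathlib import Path
from typing import Iterator

def collect_keys(obj: dict, prefix: str = "") -> Iterator[str]:
    """Flatten nested translation keys, e.g. 'controlnet.depthAnything'."""
    for k, v in obj.items():
        key = f"{prefix}.{k}" if prefix else k
        if isinstance(v, dict):
            yield from collect_keys(v, key)
        else:
            yield key

en = json.loads(Path("public/locales/en.json").read_text())  # hypothetical path
sources = [p.read_text() for p in Path("src").rglob("*.ts*")]

unused = [
    key
    for key in collect_keys(en)
    # A key counts as used if its full path or its final stem appears anywhere.
    if not any(key in src or key.split(".")[-1] in src for src in sources)
]
```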
e0bf9ce5c6 tidy(ui): use normal quotes in translations 2024-03-14 11:38:29 +11:00
c66e8b395e fix(ui): remove unused input on depth anything processor node 2024-03-14 10:53:57 +11:00
4c417adc82 fix(ui): use revised metadata model types
We can also totally remove the fetch logic because we store the same model data in state now.
2024-03-14 10:53:57 +11:00
437a413ca3 chore(ui): typegen 2024-03-14 10:53:57 +11:00
4492bedd19 tidy(nodes): use ModelIdentifierField for model metadata
Until recently, this had a different shape than the ModelMetadataField. They are now the same, so we can re-use the ModelIdentifierField.
2024-03-14 10:53:57 +11:00
db12ce95a8 fix(ui): invalid collect node error w/ control adapters
The graph builders used awaited functions within `Array.prototype.forEach` loops. This doesn't do what you'd think. This caused graphs to be enqueued before they were fully constructed.

 Changed to `for..of` loops to fix this.
2024-03-14 10:53:57 +11:00
ee3a1a95ef fix(ui): control adapters require control images
There wasn't enough validation of control adapters during graph building. It was possible for a graph to be built with an empty collect node, causing an error. Addressed with an extra check.

This should never happen in practice, because the invoke button should be disabled if an invalid CA is active.
2024-03-14 10:53:57 +11:00
4bb5aba70e feat(ui): only fetch TIs on first load, add comment 2024-03-14 07:38:09 +11:00
cd55c23713 initiate TI model query when socket connects so the user doesn't have to wait when opening prompt trigger phrases 2024-03-14 07:38:09 +11:00
1d2743af1b remove log 2024-03-14 07:25:48 +11:00
99d2099ccd add key for controladapter CustomSelect too 2024-03-14 07:25:48 +11:00
b64a693f16 try adding a key to force rerender when items load 2024-03-14 07:25:48 +11:00
9d523a3094 chore: cleanup DepthAnything code (#5945)
## What type of PR is this? (check all applicable)

- [x] Optimization

## Description

Was merged into next but never carried over to main. So cleaning up
again.
2024-03-13 20:46:54 +05:30
af660163ca chore: cleanup DepthAnything code 2024-03-13 20:35:52 +05:30
7e4b462fca docs: OVERVIEW.md typo 2024-03-13 22:43:20 +11:00
4468dd6948 docs: update OVERVIEW.md
Update pkg scripts.
2024-03-13 22:43:20 +11:00
4f39e248dd docs: update OVERVIEW.md
Fix links
2024-03-13 22:43:20 +11:00
44b3e5d43f docs: update INVOCATION_API.md
Add blurb about `WithMetadata` and `WithBoard` mixins.
2024-03-13 22:43:20 +11:00
8894a9e48a docs: update WORKFLOWS.md 2024-03-13 22:43:20 +11:00
c73f58e486 docs: move frontend docs to mkdocs 2024-03-13 22:43:20 +11:00
614fece147 chore(ui): prettier 2024-03-13 21:02:29 +11:00
8ef8082d65 feat(ui): style add model panel 2024-03-13 21:02:29 +11:00
d93d4afbb7 feat(ui): style HF scan tab 2024-03-13 21:02:29 +11:00
01207a2fa5 fix(mm): config.json indicates diffusers model 2024-03-13 21:02:29 +11:00
d0800c4888 ui consistency, moved is_diffusers logic to backend, extended HuggingFaceMetadata, removed logic from service 2024-03-13 21:02:29 +11:00
2a300ecada updated add model copy, added search to hugging face results 2024-03-13 21:02:29 +11:00
90340a39c7 clean up python errors 2024-03-13 21:02:29 +11:00
ee77abb4fe updated simple install button to match other tabs 2024-03-13 21:02:29 +11:00
004bca5c42 updated endpoint types 2024-03-13 21:02:29 +11:00
5ad048a161 fixed error handling 2024-03-13 21:02:29 +11:00
6369ccd05e added placeholders, updated some copy 2024-03-13 21:02:29 +11:00
3a5314f1ca install model if diffusers or single file, cleaned up backend logic to not mess with existing model install 2024-03-13 21:02:29 +11:00
4c0896e436 removed log 2024-03-13 21:02:29 +11:00
f7cd3cf1f4 added hf models import tab and route for getting available hf models 2024-03-13 21:02:29 +11:00
efea1a8a7d ci: add always_run input to checks & tests, use this on release workflow
This bypasses the `changed-files` check, and forces the checks to run. The release workflow sets this flag to ensure that the checks and tests are always run for a release.
2024-03-13 15:32:42 +11:00
d0d695c020 disable trigger phrase form if empty 2024-03-12 21:08:15 -04:00
2a648da557 updated model manager to display when import item is cancelled 2024-03-13 09:18:05 +11:00
54f1a1f952 Update l2i invoke and seamless to support AutoencoderTiny, remove attention processors if no mid_block is detected (#5936)

## What type of PR is this? (check all applicable)

- [ ] Refactor
- [ ] Feature
- [x] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission


## Have you discussed this change with the InvokeAI team?
- [x] Yes
- [ ] No, because:

      
## Have you updated all relevant documentation?
- [ ] Yes
- [x] No


## Description
L2i throws an assertion error when run with `madebyollin/taesdxl` because
loading it requires a different class in diffusers. This is a small
PR to update seamless and l2i to accept AutoencoderTiny models and not
throw exceptions while processing them.
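
For reference, `madebyollin/taesdxl` is a TAESD-style VAE that must be loaded with diffusers' `AutoencoderTiny` class; its decoder also has no `mid_block`, which is why the attention-processor check is skipped for it. A minimal loading sketch, not this PR's code:

```python
import torch
from diffusers import AutoencoderTiny, StableDiffusionXLPipeline

# AutoencoderTiny is the class this VAE requires; AutoencoderKL would fail.
vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=torch.float16)
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", vae=vae, torch_dtype=torch.float16
)
```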

## QA Instructions, Screenshots, Recordings

<img width="445" alt="Screenshot 2024-03-12 at 12 04 29 PM"
src="https://github.com/invoke-ai/InvokeAI/assets/58442074/34a17e44-d911-4fef-8fc1-71f7b688688c">
Run an sdxl pipeline using a vae that requires AutoencoderTiny and
validate that the image successfully encodes and decodes.

## Merge Plan

This PR can be merged when approved
2024-03-12 21:52:32 +05:30
8d2a4db902 Found another instance of expecting a mid_block on the decoder in a vae 2024-03-12 12:11:38 -04:00
7b393656de Update l2i invoke and seamless to support AutoencoderTiny, remove attention processors if no mid_block is detected 2024-03-12 12:00:24 -04:00
43948e0758 feat(ui): add setting for always show image size badge 2024-03-12 18:52:23 +11:00
cc03fcbcb6 style(ui): tweak image dimension badge overlay styles 2024-03-12 18:52:23 +11:00
d1e445fa49 fix(ui): changed to theme tokens 2024-03-12 18:52:23 +11:00
adba8489f2 fix(ui): made changes to avoid overlapping 2024-03-12 18:52:23 +11:00
d919022ba5 fix(ui): fixed requested changes and made the badge display on hover 2024-03-12 18:52:23 +11:00
e076898798 fix(ui): logic to remove badge for small image size 2024-03-12 18:52:23 +11:00
9f19b766a4 feat(ui): Add image size badge to gallery images 2024-03-12 18:52:23 +11:00
112 changed files with 2565 additions and 3830 deletions

View File

@ -1,7 +1,7 @@
# Runs frontend code quality checks.
#
# Checks for changes to frontend files before running the checks.
# When manually triggered or when called from another workflow, always runs the checks.
# If always_run is true, always runs the checks.
name: 'frontend checks'
@ -16,7 +16,19 @@ on:
- 'synchronize'
merge_group:
workflow_dispatch:
inputs:
always_run:
description: 'Always run the checks'
required: true
type: boolean
default: true
workflow_call:
inputs:
always_run:
description: 'Always run the checks'
required: true
type: boolean
default: true
defaults:
run:
@ -30,7 +42,7 @@ jobs:
- uses: actions/checkout@v4
- name: check for changed frontend files
if: ${{ github.event_name != 'workflow_dispatch' && github.event_name != 'workflow_call' }}
if: ${{ inputs.always_run != true }}
id: changed-files
uses: tj-actions/changed-files@v42
with:
@ -39,30 +51,30 @@ jobs:
- 'invokeai/frontend/web/**'
- name: install dependencies
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || inputs.always_run == true }}
uses: ./.github/actions/install-frontend-deps
- name: tsc
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || inputs.always_run == true }}
run: 'pnpm lint:tsc'
shell: bash
- name: dpdm
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || inputs.always_run == true }}
run: 'pnpm lint:dpdm'
shell: bash
- name: eslint
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || inputs.always_run == true }}
run: 'pnpm lint:eslint'
shell: bash
- name: prettier
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || inputs.always_run == true }}
run: 'pnpm lint:prettier'
shell: bash
- name: knip
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || inputs.always_run == true }}
run: 'pnpm lint:knip'
shell: bash

View File

@ -1,7 +1,7 @@
# Runs frontend tests.
#
# Checks for changes to frontend files before running the tests.
# When manually triggered or called from another workflow, always runs the tests.
# If always_run is true, always runs the tests.
name: 'frontend tests'
@ -16,7 +16,19 @@ on:
- 'synchronize'
merge_group:
workflow_dispatch:
inputs:
always_run:
description: 'Always run the tests'
required: true
type: boolean
default: true
workflow_call:
inputs:
always_run:
description: 'Always run the tests'
required: true
type: boolean
default: true
defaults:
run:
@ -30,7 +42,7 @@ jobs:
- uses: actions/checkout@v4
- name: check for changed frontend files
if: ${{ github.event_name != 'workflow_dispatch' && github.event_name != 'workflow_call' }}
if: ${{ inputs.always_run != true }}
id: changed-files
uses: tj-actions/changed-files@v42
with:
@ -39,10 +51,10 @@ jobs:
- 'invokeai/frontend/web/**'
- name: install dependencies
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || inputs.always_run == true }}
uses: ./.github/actions/install-frontend-deps
- name: vitest
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || inputs.always_run == true }}
run: 'pnpm test:no-watch'
shell: bash

View File

@ -1,7 +1,7 @@
# Runs python code quality checks.
#
# Checks for changes to python files before running the checks.
# When manually triggered or called from another workflow, always runs the tests.
# If always_run is true, always runs the checks.
#
# TODO: Add mypy or pyright to the checks.
@ -18,7 +18,19 @@ on:
- 'synchronize'
merge_group:
workflow_dispatch:
inputs:
always_run:
description: 'Always run the checks'
required: true
type: boolean
default: true
workflow_call:
inputs:
always_run:
description: 'Always run the checks'
required: true
type: boolean
default: true
jobs:
python-checks:
@ -29,7 +41,7 @@ jobs:
uses: actions/checkout@v4
- name: check for changed python files
if: ${{ github.event_name != 'workflow_dispatch' && github.event_name != 'workflow_call' }}
if: ${{ inputs.always_run != true }}
id: changed-files
uses: tj-actions/changed-files@v42
with:
@ -41,7 +53,7 @@ jobs:
- 'tests/**'
- name: setup python
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
uses: actions/setup-python@v5
with:
python-version: '3.10'
@ -49,16 +61,16 @@ jobs:
cache-dependency-path: pyproject.toml
- name: install ruff
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
run: pip install ruff
shell: bash
- name: ruff check
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
run: ruff check --output-format=github .
shell: bash
- name: ruff format
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
run: ruff format --check .
shell: bash

View File

@ -1,7 +1,7 @@
# Runs python tests on a matrix of python versions and platforms.
#
# Checks for changes to python files before running the tests.
# When manually triggered or called from another workflow, always runs the tests.
# If always_run is true, always runs the tests.
name: 'python tests'
@ -9,6 +9,7 @@ on:
push:
branches:
- 'main'
- 'bug-install-job-running-multiple-times'
pull_request:
types:
- 'ready_for_review'
@ -16,7 +17,19 @@ on:
- 'synchronize'
merge_group:
workflow_dispatch:
inputs:
always_run:
description: 'Always run the tests'
required: true
type: boolean
default: true
workflow_call:
inputs:
always_run:
description: 'Always run the tests'
required: true
type: boolean
default: true
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
@ -63,7 +76,7 @@ jobs:
uses: actions/checkout@v4
- name: check for changed python files
if: ${{ github.event_name != 'workflow_dispatch' && github.event_name != 'workflow_call' }}
if: ${{ inputs.always_run != true }}
id: changed-files
uses: tj-actions/changed-files@v42
with:
@ -75,7 +88,7 @@ jobs:
- 'tests/**'
- name: setup python
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
@ -83,12 +96,12 @@ jobs:
cache-dependency-path: pyproject.toml
- name: install dependencies
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
env:
PIP_EXTRA_INDEX_URL: ${{ matrix.extra-index-url }}
run: >
pip3 install --editable=".[test]"
- name: run pytest
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
run: pytest

View File

@ -30,15 +30,23 @@ jobs:
frontend-checks:
uses: ./.github/workflows/frontend-checks.yml
with:
always_run: true
frontend-tests:
uses: ./.github/workflows/frontend-tests.yml
with:
always_run: true
python-checks:
uses: ./.github/workflows/python-checks.yml
with:
always_run: true
python-tests:
uses: ./.github/workflows/python-tests.yml
with:
always_run: true
build:
uses: ./.github/workflows/build-installer.yml

View File

@ -0,0 +1,133 @@
# Invoke UI
Invoke's UI is made possible by many contributors and open-source libraries. Thank you!
## Dev environment
### Setup
1. Install [node] and [pnpm].
1. Run `pnpm i` to install all packages.
#### Run in dev mode
1. From `invokeai/frontend/web/`, run `pnpm dev`.
1. From repo root, run `python scripts/invokeai-web.py`.
1. Point your browser to the dev server address, e.g. <http://localhost:5173/>
### Package scripts
- `dev`: run the frontend in dev mode, enabling hot reloading
- `build`: run all checks (madge, eslint, prettier, tsc) and then build the frontend
- `typegen`: generate types from the OpenAPI schema (see [Type generation])
- `lint:dpdm`: check circular dependencies
- `lint:eslint`: check code quality
- `lint:prettier`: check code formatting
- `lint:tsc`: check type issues
- `lint:knip`: check for unused exports or objects (failures here are just suggestions, not hard fails)
- `lint`: run all checks concurrently
- `fix`: run `eslint` and `prettier`, fixing fixable issues
### Type generation
We use [openapi-typescript] to generate types from the app's OpenAPI schema.
The generated types are committed to the repo in [schema.ts].
```sh
# from the repo root, start the server
python scripts/invokeai-web.py
# from invokeai/frontend/web/, run the script
pnpm typegen
```
### Localization
We use [i18next] for localization, but translation to languages other than English happens on our [Weblate] project.
Only the English source strings should be changed on this repo.
### VSCode
#### Example debugger config
```jsonc
{
"version": "0.2.0",
"configurations": [
{
"type": "chrome",
"request": "launch",
"name": "Invoke UI",
"url": "http://localhost:5173",
"webRoot": "${workspaceFolder}/invokeai/frontend/web"
}
]
}
```
#### Remote dev
We've noticed an intermittent timeout issue with the VSCode remote dev port forwarding.
We suggest disabling the editor's port forwarding feature and doing it manually via SSH:
```sh
ssh -L 9090:localhost:9090 -L 5173:localhost:5173 user@host
```
## Contributing Guidelines
Thanks for your interest in contributing to the Invoke Web UI!
Please follow these guidelines when contributing.
### Check in before investing your time
Please check in before you invest your time on anything besides a trivial fix, in case it conflicts with ongoing work or isn't aligned with the vision for the app.
If a feature request or issue doesn't already exist for the thing you want to work on, please create one.
Ping `@psychedelicious` on [discord] in the `#frontend-dev` channel or in the feature request / issue you want to work on - we're happy to chat.
### Code conventions
- This is a fairly complex app with a deep component tree. Please use memoization (`useCallback`, `useMemo`, `memo`) with enthusiasm.
- If you need to add some global, ephemeral state, please use [nanostores] if possible.
- Be careful with your redux selectors. If they need to be parameterized, consider creating them inside a `useMemo`.
- Feel free to use `lodash` (via `lodash-es`) to make the intent of your code clear.
- Please add comments describing the "why", not the "how" (unless it is really arcane).
### Commit format
Please use the [conventional commits] spec for the web UI, with a scope of "ui":
- `chore(ui): bump deps`
- `chore(ui): lint`
- `feat(ui): add some cool new feature`
- `fix(ui): fix some bug`
### Submitting a PR
- Ensure your branch is tidy. Use an interactive rebase to clean up the commit history and reword the commit messages if they are not descriptive.
- Run `pnpm lint`. Some issues are auto-fixable with `pnpm fix`.
- Fill out the PR form when creating the PR.
- It doesn't need to be super detailed, but a screenshot or video is nice if you changed something visually.
- If a section isn't relevant, delete it. There are no UI tests at this time.
## Other docs
- [Workflows - Design and Implementation]
- [State Management]
[node]: https://nodejs.org/en/download/
[pnpm]: https://github.com/pnpm/pnpm
[discord]: https://discord.gg/ZmtBAhwWhy
[i18next]: https://github.com/i18next/react-i18next
[Weblate]: https://hosted.weblate.org/engage/invokeai/
[openapi-typescript]: https://github.com/drwpow/openapi-typescript
[Type generation]: #type-generation
[schema.ts]: https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/frontend/web/src/services/api/schema.ts
[conventional commits]: https://www.conventionalcommits.org/en/v1.0.0/
[Workflows - Design and Implementation]: ./WORKFLOWS.md
[State Management]: ./STATE_MGMT.md

View File

@ -1,40 +1,5 @@
# Workflows - Design and Implementation
<!-- @import "[TOC]" {cmd="toc" depthFrom=1 depthTo=6 orderedList=false} -->
<!-- code_chunk_output -->
- [Workflows - Design and Implementation](#workflows---design-and-implementation)
- [Design](#design)
- [Linear UI](#linear-ui)
- [Workflow Editor](#workflow-editor)
- [Workflows](#workflows)
- [Workflow -> reactflow state -> InvokeAI graph](#workflow---reactflow-state---invokeai-graph)
- [Nodes vs Invocations](#nodes-vs-invocations)
- [Workflow Linear View](#workflow-linear-view)
- [OpenAPI Schema](#openapi-schema)
- [Field Instances and Templates](#field-instances-and-templates)
- [Stateful vs Stateless Fields](#stateful-vs-stateless-fields)
- [Collection and Polymorphic Fields](#collection-and-polymorphic-fields)
- [Implementation](#implementation)
- [zod Schemas and Types](#zod-schemas-and-types)
- [OpenAPI Schema Parsing](#openapi-schema-parsing)
- [Parsing Field Types](#parsing-field-types)
- [Primitive Types](#primitive-types)
- [Complex Types](#complex-types)
- [Collection Types](#collection-types)
- [Collection or Scalar Types](#collection-or-scalar-types)
- [Optional Fields](#optional-fields)
- [Building Field Input Templates](#building-field-input-templates)
- [Building Field Output Templates](#building-field-output-templates)
- [Managing reactflow State](#managing-reactflow-state)
- [Building Nodes and Edges](#building-nodes-and-edges)
- [Building a Workflow](#building-a-workflow)
- [Loading a Workflow](#loading-a-workflow)
- [Workflow Migrations](#workflow-migrations)
<!-- /code_chunk_output -->
> This document describes, at a high level, the design and implementation of workflows in the InvokeAI frontend. There are a substantial number of implementation details not included, but which are hopefully clear from the code.
InvokeAI's backend uses graphs, composed of **nodes** and **edges**, to process data and generate images.
@ -152,13 +117,13 @@ Stateless fields do not store their value in the node, so their field instances
"Custom" fields will always be treated as stateless fields.
##### Collection and Polymorphic Fields
##### Collection and Scalar Fields
Field types have a name and two flags which may identify it as a **collection** or **polymorphic** field.
Field types have a name and two flags which may identify it as a **collection** or **collection or scalar** field.
If a field is annotated in python as a list, its field type is parsed and flagged as a collection type (e.g. `list[int]`).
If a field is annotated in python as a list, its field type is parsed and flagged as a **collection** type (e.g. `list[int]`).
If it is annotated as a union of a type and list, the type will be flagged as a polymorphic type (e.g. `Union[int, list[int]]`). Fields may not be unions of different types (e.g. `Union[int, list[str]]` and `Union[int, str]` are not allowed).
If it is annotated as a union of a type and list, the type will be flagged as a **collection or scalar** type (e.g. `Union[int, list[int]]`). Fields may not be unions of different types (e.g. `Union[int, list[str]]` and `Union[int, str]` are not allowed).
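
As a concrete illustration of the two annotations described above (hypothetical field definitions, not taken from the codebase):

```python
from typing import Union

from pydantic import BaseModel

class ExampleFields(BaseModel):
    ints: list[int]                     # flagged as a "collection" field type
    int_or_ints: Union[int, list[int]]  # flagged as a "collection or scalar" field type
```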
## Implementation
@ -338,13 +303,13 @@ Migration logic is in [migrations.ts].
[reactflow]: https://github.com/xyflow/xyflow 'reactflow'
[reactflow-concepts]: https://reactflow.dev/learn/concepts/terms-and-definitions
[reactflow-events]: https://reactflow.dev/api-reference/react-flow#event-handlers
[buildWorkflow.ts]: ../src/features/nodes/util/workflow/buildWorkflow.ts
[nodesSlice.ts]: ../src/features/nodes/store/nodesSlice.ts
[buildLinearTextToImageGraph.ts]: ../src/features/nodes/util/graph/buildLinearTextToImageGraph.ts
[buildNodesGraph.ts]: ../src/features/nodes/util/graph/buildNodesGraph.ts
[buildInvocationNode.ts]: ../src/features/nodes/util/node/buildInvocationNode.ts
[validateWorkflow.ts]: ../src/features/nodes/util/workflow/validateWorkflow.ts
[migrations.ts]: ../src/features/nodes/util/workflow/migrations.ts
[parseSchema.ts]: ../src/features/nodes/util/schema/parseSchema.ts
[buildFieldInputTemplate.ts]: ../src/features/nodes/util/schema/buildFieldInputTemplate.ts
[buildFieldOutputTemplate.ts]: ../src/features/nodes/util/schema/buildFieldOutputTemplate.ts
[buildWorkflow.ts]: https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/frontend/web/src/features/nodes/util/workflow/buildWorkflow.ts
[nodesSlice.ts]: https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts
[buildLinearTextToImageGraph.ts]: https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearTextToImageGraph.ts
[buildNodesGraph.ts]: https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/frontend/web/src/features/nodes/util/graph/buildNodesGraph.ts
[buildInvocationNode.ts]: https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/frontend/web/src/features/nodes/util/node/buildInvocationNode.ts
[validateWorkflow.ts]: https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts
[migrations.ts]: https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/frontend/web/src/features/nodes/util/workflow/migrations.ts
[parseSchema.ts]: https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.ts
[buildFieldInputTemplate.ts]: https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts
[buildFieldOutputTemplate.ts]: https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldOutputTemplate.ts

View File

@ -184,16 +184,26 @@ The provided token will be added as a `Bearer` token to the network requests to
### Model Hashing
Models are hashed during installation with the `BLAKE3` algorithm, providing a stable identifier for models across all platforms.
Models are hashed during installation, providing a stable identifier for models across all platforms. The default algorithm is `blake3`, with a multi-threaded implementation.
Model hashing is a one-time operation, but it may take a couple minutes to hash a large model collection. You may opt out of model hashing and instead have a random UUID assigned instead:
If your models are stored on a spinning hard drive, we suggest using `blake3_single`, the single-threaded implementation. The hashes are the same, but it's much faster on spinning disks.
```yaml
InvokeAI:
Model Install:
skip_model_hash: true
hashing_algorithm: blake3_single
```
Model hashing is a one-time operation, but it may take a couple minutes to hash a large model collection. You may opt out of model hashing entirely by setting the algorithm to `random`.
```yaml
InvokeAI:
Model Install:
hashing_algorithm: random
```
Most common algorithms are supported, like `md5`, `sha256`, and `sha512`. These are typically much, much slower than `blake3`.
### Paths
These options set the paths of various directories and files used by

View File

@ -22,6 +22,24 @@ class MyInvocation(BaseInvocation):
...
```
The full API is documented below.
## Invocation Mixins
Two important mixins are provided to facilitate working with metadata and gallery boards.
### `WithMetadata`
Inherit from this class (in addition to `BaseInvocation`) to add a `metadata` input to your node. When you do this, you can access the metadata dict from `self.metadata` in the `invoke()` function.
The dict will be populated via the node's input, and you can add any metadata you'd like to it. When you call `context.images.save()`, if the metadata dict has any data, it will be automatically embedded in the image.
### `WithBoard`
Inherit from this class (in addition to `BaseInvocation`) to add a `board` input to your node. This renders as a drop-down to select a board. The user's selection will be accessible from `self.board` in the `invoke()` function.
When you call `context.images.save()`, if a board was selected, the image will be added to that board as it is saved.
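
Putting the two mixins together, a node might look like the following sketch. It is based only on the description above: imports are elided and the `@invocation` arguments are invented for illustration.

```python
# Sketch only - imports elided; the decorator arguments are hypothetical.
@invocation("my_image_node", title="My Image Node", version="1.0.0")
class MyImageInvocation(BaseInvocation, WithMetadata, WithBoard):
    def invoke(self, context: InvocationContext):
        image = ...  # produce a PIL image
        # self.metadata (if populated) is embedded in the image on save,
        # and the user's board selection (self.board) is honored.
        image_dto = context.images.save(image=image)
        ...
```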
<!-- prettier-ignore-start -->
::: invokeai.app.services.shared.invocation_context.InvocationContext
options:

View File

@ -11,7 +11,7 @@ from fastapi import Body, Path, Query, Response, UploadFile
from fastapi.responses import FileResponse
from fastapi.routing import APIRouter
from PIL import Image
from pydantic import BaseModel, ConfigDict, Field
from pydantic import AnyHttpUrl, BaseModel, ConfigDict, Field
from starlette.exceptions import HTTPException
from typing_extensions import Annotated
@ -29,6 +29,8 @@ from invokeai.backend.model_manager.config import (
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
from invokeai.backend.model_manager.metadata.metadata_base import ModelMetadataWithFiles, UnknownMetadataException
from invokeai.backend.model_manager.search import ModelSearch
from ..dependencies import ApiDependencies
@ -246,6 +248,40 @@ async def scan_for_models(
return scan_results
class HuggingFaceModels(BaseModel):
urls: List[AnyHttpUrl] | None = Field(description="URLs for all checkpoint format models in the metadata")
is_diffusers: bool = Field(description="Whether the metadata is for a Diffusers format model")
@model_manager_router.get(
"/hugging_face",
operation_id="get_hugging_face_models",
responses={
200: {"description": "Hugging Face repo scanned successfully"},
400: {"description": "Invalid hugging face repo"},
},
status_code=200,
response_model=HuggingFaceModels,
)
async def get_hugging_face_models(
hugging_face_repo: str = Query(description="Hugging face repo to search for models", default=None),
) -> HuggingFaceModels:
try:
metadata = HuggingFaceMetadataFetch().from_id(hugging_face_repo)
except UnknownMetadataException:
raise HTTPException(
status_code=400,
detail="No HuggingFace repository found",
)
assert isinstance(metadata, ModelMetadataWithFiles)
return HuggingFaceModels(
urls=metadata.ckpt_urls,
is_diffusers=metadata.is_diffusers,
)
@model_manager_router.patch(
"/i/{key}",
operation_id="update_model_record",

View File

@ -574,7 +574,7 @@ DEPTH_ANYTHING_MODEL_SIZES = Literal["large", "base", "small"]
title="Depth Anything Processor",
tags=["controlnet", "depth", "depth anything"],
category="controlnet",
version="1.0.0",
version="1.0.1",
)
class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
"""Generates a depth map based on the Depth Anything algorithm"""
@ -583,13 +583,12 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
default="small", description="The size of the depth model to use"
)
resolution: int = InputField(default=512, ge=64, multiple_of=64, description=FieldDescriptions.image_res)
offload: bool = InputField(default=False)
def run_processor(self, image: Image.Image):
depth_anything_detector = DepthAnythingDetector()
depth_anything_detector.load_model(model_size=self.model_size)
processed_image = depth_anything_detector(image=image, resolution=self.resolution, offload=self.offload)
processed_image = depth_anything_detector(image=image, resolution=self.resolution)
return processed_image

View File

@ -15,7 +15,7 @@ from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import BaseModelType, IPAdapterConfig, ModelType
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, IPAdapterConfig, ModelType
class IPAdapterField(BaseModel):
@ -89,17 +89,32 @@ class IPAdapterInvocation(BaseInvocation):
assert isinstance(ip_adapter_info, IPAdapterConfig)
image_encoder_model_id = ip_adapter_info.image_encoder_model_id
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
image_encoder_models = context.models.search_by_attrs(
name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision
)
assert len(image_encoder_models) == 1
image_encoder_model = self._get_image_encoder(context, image_encoder_model_name)
return IPAdapterOutput(
ip_adapter=IPAdapterField(
image=self.image,
ip_adapter_model=self.ip_adapter_model,
image_encoder_model=ModelIdentifierField.from_config(image_encoder_models[0]),
image_encoder_model=ModelIdentifierField.from_config(image_encoder_model),
weight=self.weight,
begin_step_percent=self.begin_step_percent,
end_step_percent=self.end_step_percent,
),
)
def _get_image_encoder(self, context: InvocationContext, image_encoder_model_name: str) -> AnyModelConfig:
found = False
while not found:
image_encoder_models = context.models.search_by_attrs(
name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision
)
found = len(image_encoder_models) > 0
if not found:
context.logger.warning(
f"The image encoder required by this IP Adapter ({image_encoder_model_name}) is not installed."
)
context.logger.warning("Downloading and installing now. This may take a while.")
installer = context._services.model_manager.install
job = installer.heuristic_import(f"InvokeAI/{image_encoder_model_name}")
installer.wait_for_job(job, timeout=600) # wait up to 10 minutes - then raise a TimeoutException
assert len(image_encoder_models) == 1
return image_encoder_models[0]

View File

@ -837,14 +837,14 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, (UNet2DConditionModel, AutoencoderKL))
assert isinstance(vae_info.model, (UNet2DConditionModel, AutoencoderKL, AutoencoderTiny))
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
assert isinstance(vae, torch.nn.Module)
latents = latents.to(vae.device)
if self.fp32:
vae.to(dtype=torch.float32)
use_torch_2_0_or_xformers = isinstance(
use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance(
vae.decoder.mid_block.attentions[0].processor,
(
AttnProcessor2_0,
@ -1018,7 +1018,7 @@ class ImageToLatentsInvocation(BaseInvocation):
if upcast:
vae.to(dtype=torch.float32)
use_torch_2_0_or_xformers = isinstance(
use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance(
vae.decoder.mid_block.attentions[0].processor,
(
AttnProcessor2_0,

View File

@ -20,8 +20,8 @@ from invokeai.app.invocations.fields import (
OutputField,
UIType,
)
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import BaseModelType, ModelType
from ...version import __version__
@ -31,20 +31,10 @@ class MetadataItemField(BaseModel):
value: Any = Field(description=FieldDescriptions.metadata_item_value)
class ModelMetadataField(BaseModel):
"""Model Metadata Field"""
key: str
hash: str
name: str
base: BaseModelType
type: ModelType
class LoRAMetadataField(BaseModel):
"""LoRA Metadata Field"""
model: ModelMetadataField = Field(description=FieldDescriptions.lora_model)
model: ModelIdentifierField = Field(description=FieldDescriptions.lora_model)
weight: float = Field(description=FieldDescriptions.lora_weight)
@ -52,19 +42,16 @@ class IPAdapterMetadataField(BaseModel):
"""IP Adapter Field, minus the CLIP Vision Encoder model"""
image: ImageField = Field(description="The IP-Adapter image prompt.")
ip_adapter_model: ModelMetadataField = Field(
description="The IP-Adapter model.",
)
weight: Union[float, list[float]] = Field(
description="The weight given to the IP-Adapter",
)
ip_adapter_model: ModelIdentifierField = Field(description="The IP-Adapter model.")
weight: Union[float, list[float]] = Field(description="The weight given to the IP-Adapter")
begin_step_percent: float = Field(description="When the IP-Adapter is first applied (% of total steps)")
end_step_percent: float = Field(description="When the IP-Adapter is last applied (% of total steps)")
class T2IAdapterMetadataField(BaseModel):
image: ImageField = Field(description="The T2I-Adapter image prompt.")
t2i_adapter_model: ModelMetadataField = Field(description="The T2I-Adapter model to use.")
image: ImageField = Field(description="The control image.")
processed_image: Optional[ImageField] = Field(default=None, description="The control image, after processing.")
t2i_adapter_model: ModelIdentifierField = Field(description="The T2I-Adapter model to use.")
weight: Union[float, list[float]] = Field(default=1, description="The weight given to the T2I-Adapter")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the T2I-Adapter is first applied (% of total steps)"
@ -77,7 +64,8 @@ class T2IAdapterMetadataField(BaseModel):
class ControlNetMetadataField(BaseModel):
image: ImageField = Field(description="The control image")
control_model: ModelMetadataField = Field(description="The ControlNet model to use")
processed_image: Optional[ImageField] = Field(default=None, description="The control image, after processing.")
control_model: ModelIdentifierField = Field(description="The ControlNet model to use")
control_weight: Union[float, list[float]] = Field(default=1, description="The weight given to the ControlNet")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
@ -178,7 +166,7 @@ class CoreMetadataInvocation(BaseInvocation):
default=None,
description="The number of skipped CLIP layers",
)
model: Optional[ModelMetadataField] = InputField(default=None, description="The main model used for inference")
model: Optional[ModelIdentifierField] = InputField(default=None, description="The main model used for inference")
controlnets: Optional[list[ControlNetMetadataField]] = InputField(
default=None, description="The ControlNets used for inference"
)
@ -197,7 +185,7 @@ class CoreMetadataInvocation(BaseInvocation):
default=None,
description="The name of the initial image",
)
vae: Optional[ModelMetadataField] = InputField(
vae: Optional[ModelIdentifierField] = InputField(
default=None,
description="The VAE used for decoding, if the main model's default was not used",
)
@ -228,7 +216,7 @@ class CoreMetadataInvocation(BaseInvocation):
)
# SDXL Refiner
refiner_model: Optional[ModelMetadataField] = InputField(
refiner_model: Optional[ModelIdentifierField] = InputField(
default=None,
description="The SDXL Refiner model used",
)

View File

@ -179,6 +179,8 @@ from pydantic import BaseModel, Field, field_validator
from pydantic.config import JsonDict
from pydantic_settings import SettingsConfigDict
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
from .config_base import InvokeAISettings
INIT_FILE = Path("invokeai.yaml")
@ -259,7 +261,7 @@ class InvokeAIAppConfig(InvokeAISettings):
profile_prefix: **Development**: An optional prefix for profile output files.
profiles_dir: **Development**: Path to profiles output directory.
version: **CLIArgs**: CLI arg - show InvokeAI version and exit.
skip_model_hash: **Model Install**: Skip model hashing, instead assigning a UUID to models. Useful when using a memory db to reduce model installation time, or if you don't care about storing stable hashes for models.
hashing_algorithm: **Model Install**: Model hashing algorithm for model installs. 'blake3' is best for SSDs. 'blake3_single' is best for spinning disk HDDs. 'random' disables hashing, instead assigning a UUID to models. Useful when using a memory db to reduce model installation time, or if you don't care about storing stable hashes for models. Alternatively, any other hashlib algorithm is accepted, though these are not nearly as performant as blake3.
remote_api_tokens: **Model Install**: List of regular expression and token pairs used when downloading models from URLs. The download URL is tested against the regex, and if it matches, the token is provided as a Bearer token.
ram: **Model Cache**: Maximum memory amount used by memory model cache for rapid switching (GB).
vram: **Model Cache**: Amount of VRAM reserved for model storage (GB)
@ -360,7 +362,7 @@ class InvokeAIAppConfig(InvokeAISettings):
node_cache_size : int = Field(default=512, description="How many cached nodes to keep in memory.", json_schema_extra=Categories.Nodes)
# MODEL INSTALL
skip_model_hash : bool = Field(default=False, description="Skip model hashing, instead assigning a UUID to models. Useful when using a memory db to reduce model installation time, or if you don't care about storing stable hashes for models.", json_schema_extra=Categories.ModelInstall)
hashing_algorithm : HASHING_ALGORITHMS = Field(default="blake3", description="Model hashing algorithm for model installs. 'blake3' is best for SSDs. 'blake3_single' is best for spinning disk HDDs. 'random' disables hashing, instead assigning a UUID to models. Useful when using a memory db to reduce model installation time, or if you don't care about storing stable hashes for models. Alternatively, any other hashlib algorithm is accepted, though these are not nearly as performant as blake3.", json_schema_extra=Categories.ModelInstall)
remote_api_tokens : Optional[list[URLRegexToken]] = Field(
default=None,
description="List of regular expression and token pairs used when downloading models from URLs. The download URL is tested against the regex, and if it matches, the token is provided in as a Bearer token.",

View File

@ -12,6 +12,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
)
from invokeai.app.util.misc import get_timestamp
from invokeai.backend.model_manager import AnyModelConfig
from invokeai.backend.model_manager.config import SubModelType
class EventServiceBase:
@ -80,7 +81,7 @@ class EventServiceBase:
"graph_execution_state_id": graph_execution_state_id,
"node_id": node_id,
"source_node_id": source_node_id,
"progress_image": progress_image.model_dump() if progress_image is not None else None,
"progress_image": progress_image.model_dump(mode="json") if progress_image is not None else None,
"step": step,
"order": order,
"total_steps": total_steps,
@ -180,6 +181,7 @@ class EventServiceBase:
queue_batch_id: str,
graph_execution_state_id: str,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> None:
"""Emitted when a model is requested"""
self.__emit_queue_event(
@ -189,7 +191,8 @@ class EventServiceBase:
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"model_config": model_config.model_dump(),
"model_config": model_config.model_dump(mode="json"),
"submodel_type": submodel_type,
},
)
@ -200,6 +203,7 @@ class EventServiceBase:
queue_batch_id: str,
graph_execution_state_id: str,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> None:
"""Emitted when a model is correctly loaded (returns model info)"""
self.__emit_queue_event(
@ -209,7 +213,8 @@ class EventServiceBase:
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"model_config": model_config.model_dump(),
"model_config": model_config.model_dump(mode="json"),
"submodel_type": submodel_type,
},
)
@ -254,8 +259,8 @@ class EventServiceBase:
"started_at": str(session_queue_item.started_at) if session_queue_item.started_at else None,
"completed_at": str(session_queue_item.completed_at) if session_queue_item.completed_at else None,
},
"batch_status": batch_status.model_dump(),
"queue_status": queue_status.model_dump(),
"batch_status": batch_status.model_dump(mode="json"),
"queue_status": queue_status.model_dump(mode="json"),
},
)
@ -405,7 +410,7 @@ class EventServiceBase:
payload={"source": source, "total_bytes": total_bytes, "key": key, "id": id},
)
def emit_model_install_cancelled(self, source: str) -> None:
def emit_model_install_cancelled(self, source: str, id: int) -> None:
"""
Emit when an install job is cancelled.
@ -413,7 +418,7 @@ class EventServiceBase:
"""
self.__emit_model_event(
event_name="model_install_cancelled",
payload={"source": source},
payload={"source": source, "id": id},
)
def emit_model_install_error(self, source: str, error_type: str, error: str, id: int) -> None:


@ -22,7 +22,6 @@ from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
from invokeai.app.util.misc import uuid_string
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
@ -134,6 +133,14 @@ class ModelInstallService(ModelInstallServiceBase):
self._download_cache.clear()
self._running = False
def _put_in_queue(self, job: ModelInstallJob) -> None:
print(f'DEBUG: in _put_in_queue(job={job.id})')
if self._stop_event.is_set():
self.cancel_job(job)
else:
print(f'DEBUG: putting {job.id} into the install queue')
self._install_queue.put(job)
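
The guard above is the classic "refuse new work during shutdown" pattern; a generic, standalone sketch of the same idea:

```py
# Generic sketch of guarding queue puts with a stop event (not InvokeAI code).
import queue
import threading

stop_event = threading.Event()
install_queue: "queue.Queue[int]" = queue.Queue()

def put_in_queue(item: int) -> None:
    if stop_event.is_set():
        print(f"cancelling {item}: stop requested")  # stand-in for cancel_job()
    else:
        install_queue.put(item)
```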
def register_path(
self,
model_path: Union[Path, str],
@ -154,10 +161,7 @@ class ModelInstallService(ModelInstallServiceBase):
model_path = Path(model_path)
config = config or {}
if self._app_config.skip_model_hash:
config["hash"] = uuid_string()
info: AnyModelConfig = ModelProbe.probe(Path(model_path), config)
info: AnyModelConfig = ModelProbe.probe(Path(model_path), config, hash_algo=self._app_config.hashing_algorithm)
if preferred_name := config.get("name"):
preferred_name = Path(preferred_name).with_suffix(model_path.suffix)
@ -222,7 +226,7 @@ class ModelInstallService(ModelInstallServiceBase):
if isinstance(source, LocalModelSource):
install_job = self._import_local_model(source, config)
self._install_queue.put(install_job) # synchronously install
self._put_in_queue(install_job) # synchronously install
elif isinstance(source, HFModelSource):
install_job = self._import_from_hf(source, config)
elif isinstance(source, URLModelSource):
@ -332,7 +336,7 @@ class ModelInstallService(ModelInstallServiceBase):
yaml_path.rename(yaml_path.with_suffix(".yaml.bak"))
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
self._cached_model_paths = {Path(x.path).absolute() for x in self.record_store.all_models()}
self._cached_model_paths = {Path(x.path).resolve() for x in self.record_store.all_models()}
callback = self._scan_install if install else self._scan_register
search = ModelSearch(on_model_found=callback)
self._models_installed.clear()
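
The switch from `absolute()` to `resolve()` means symlinked model paths compare by their canonical target; a small sketch with hypothetical paths:

```py
# Paths hypothetical; resolve() canonicalizes symlinks, absolute() does not.
from pathlib import Path

link = Path("models/sd-1/main/model.safetensors")  # suppose this is a symlink
link.absolute()  # .../models/sd-1/main/model.safetensors -- the link itself
link.resolve()   # /mnt/storage/model.safetensors          -- the link's target
```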
@ -346,7 +350,7 @@ class ModelInstallService(ModelInstallServiceBase):
"""Unregister the model. Delete its files only if they are within our models directory."""
model = self.record_store.get_model(key)
models_dir = self.app_config.models_path
model_path = Path(model.path)
model_path = models_dir / Path(model.path) # handle legacy relative model paths
if model_path.is_relative_to(models_dir):
self.unconditionally_delete(key)
else:
@ -354,7 +358,7 @@ class ModelInstallService(ModelInstallServiceBase):
def unconditionally_delete(self, key: str) -> None: # noqa D102
model = self.record_store.get_model(key)
model_path = Path(model.path)
model_path = self.app_config.models_path / model.path
if model_path.is_dir():
rmtree(model_path)
else:
@ -407,10 +411,11 @@ class ModelInstallService(ModelInstallServiceBase):
done = True
continue
try:
print('DEBUG: _install_next_item() checking for a job to install')
job = self._install_queue.get(timeout=1)
except Empty:
continue
print(f'DEBUG: _install_next_item() got job {job.id}, status={job.status}')
assert job.local_path is not None
try:
if job.cancelled:
@ -436,6 +441,7 @@ class ModelInstallService(ModelInstallServiceBase):
else:
key = self.install_path(job.local_path, job.config_in)
job.config_out = self.record_store.get_model(key)
print(f'DEBUG: _install_next_item() signaling completion for job={job.id}, status={job.status}')
self._signal_job_completed(job)
except InvalidModelConfigException as excp:
@ -496,6 +502,8 @@ class ModelInstallService(ModelInstallServiceBase):
for cur_base_model in BaseModelType:
for cur_model_type in ModelType:
models_dir = self._app_config.models_path / Path(cur_base_model.value, cur_model_type.value)
if not models_dir.exists():
continue
installed.update(self.scan_directory(models_dir))
self._logger.info(f"{len(installed)} new models registered; {len(defunct_models)} unregistered")
@ -522,7 +530,7 @@ class ModelInstallService(ModelInstallServiceBase):
new_path = models_dir / model.base.value / model.type.value / old_path.name
if old_path == new_path:
if old_path == new_path or new_path.exists() and old_path == new_path.resolve():
return model
self._logger.info(f"Moving {model.name} to {new_path}.")
@ -585,10 +593,7 @@ class ModelInstallService(ModelInstallServiceBase):
) -> str:
config = config or {}
if self._app_config.skip_model_hash:
config["hash"] = uuid_string()
info = info or ModelProbe.probe(model_path, config)
info = info or ModelProbe.probe(model_path, config, hash_algo=self._app_config.hashing_algorithm)
model_path = model_path.resolve()
@ -786,14 +791,16 @@ class ModelInstallService(ModelInstallServiceBase):
def _download_complete_callback(self, download_job: DownloadJob) -> None:
self._logger.info(f"{download_job.source}: model download complete")
print(f'DEBUG: _download_complete_callback(download_job={download_job.source})')
with self._lock:
install_job = self._download_cache[download_job.source]
self._download_cache.pop(download_job.source, None)
install_job = self._download_cache.pop(download_job.source, None)
print(f'DEBUG: download_job={download_job.source} / install_job={install_job}')
# are there any more active jobs left in this task?
if install_job.downloading and all(x.complete for x in install_job.download_parts):
if install_job and install_job.downloading and all(x.complete for x in install_job.download_parts):
print(f'DEBUG: setting job {install_job.id} to DOWNLOADS_DONE')
install_job.status = InstallStatus.DOWNLOADS_DONE
self._install_queue.put(install_job)
print(f'DEBUG: putting {install_job.id} into the install queue')
self._put_in_queue(install_job)
# Let other threads know that the number of downloads has changed
self._downloads_changed_event.set()
@ -835,7 +842,7 @@ class ModelInstallService(ModelInstallServiceBase):
if all(x.in_terminal_state for x in install_job.download_parts):
# When all parts have reached their terminal state, we finalize the job to clean up the temporary directory and other resources
self._install_queue.put(install_job)
self._put_in_queue(install_job)
# ------------------------------------------------------------------------------------------------
# Internal methods that put events on the event bus
@ -892,7 +899,7 @@ class ModelInstallService(ModelInstallServiceBase):
def _signal_job_cancelled(self, job: ModelInstallJob) -> None:
self._logger.info(f"{job.source}: model installation was cancelled")
if self._event_bus:
self._event_bus.emit_model_install_cancelled(str(job.source))
self._event_bus.emit_model_install_cancelled(str(job.source), id=job.id)
@staticmethod
def get_fetcher_from_url(url: str):


@ -68,6 +68,7 @@ class ModelLoadService(ModelLoadServiceBase):
self._emit_load_event(
context_data=context_data,
model_config=model_config,
submodel_type=submodel_type,
)
implementation, model_config, submodel_type = self._registry.get_implementation(model_config, submodel_type) # type: ignore
@ -82,6 +83,7 @@ class ModelLoadService(ModelLoadServiceBase):
self._emit_load_event(
context_data=context_data,
model_config=model_config,
submodel_type=submodel_type,
loaded=True,
)
return loaded_model
@ -91,6 +93,7 @@ class ModelLoadService(ModelLoadServiceBase):
context_data: InvocationContextData,
model_config: AnyModelConfig,
loaded: Optional[bool] = False,
submodel_type: Optional[SubModelType] = None,
) -> None:
if not self._invoker:
return
@ -102,6 +105,7 @@ class ModelLoadService(ModelLoadServiceBase):
queue_batch_id=context_data.queue_item.batch_id,
graph_execution_state_id=context_data.queue_item.session_id,
model_config=model_config,
submodel_type=submodel_type,
)
else:
self._invoker.services.events.emit_model_load_completed(
@ -110,4 +114,5 @@ class ModelLoadService(ModelLoadServiceBase):
queue_batch_id=context_data.queue_item.batch_id,
graph_execution_state_id=context_data.queue_item.session_id,
model_config=model_config,
submodel_type=submodel_type,
)


@ -13,9 +13,11 @@ from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.backend.util.util import download_with_progress_bar
config = InvokeAIAppConfig.get_config()
logger = InvokeAILogger.get_logger(config=config)
DEPTH_ANYTHING_MODELS = {
"large": {
@ -54,8 +56,9 @@ class DepthAnythingDetector:
def __init__(self) -> None:
self.model = None
self.model_size: Union[Literal["large", "base", "small"], None] = None
self.device = choose_torch_device()
def load_model(self, model_size=Literal["large", "base", "small"]):
def load_model(self, model_size: Literal["large", "base", "small"] = "small"):
DEPTH_ANYTHING_MODEL_PATH = pathlib.Path(config.models_path / DEPTH_ANYTHING_MODELS[model_size]["local"])
if not DEPTH_ANYTHING_MODEL_PATH.exists():
download_with_progress_bar(DEPTH_ANYTHING_MODELS[model_size]["url"], DEPTH_ANYTHING_MODEL_PATH)
@ -71,8 +74,6 @@ class DepthAnythingDetector:
self.model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])
case "large":
self.model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
case _:
raise TypeError("Not a supported model")
self.model.load_state_dict(torch.load(DEPTH_ANYTHING_MODEL_PATH.as_posix(), map_location="cpu"))
self.model.eval()
@ -80,20 +81,20 @@ class DepthAnythingDetector:
self.model.to(choose_torch_device())
return self.model
def to(self, device):
self.model.to(device)
return self
def __call__(self, image: Image.Image, resolution: int = 512) -> Image.Image:
if not self.model:
logger.warn("DepthAnything model was not loaded. Returning original image")
return image
def __call__(self, image, resolution=512, offload=False):
image = np.array(image, dtype=np.uint8)
image = image[:, :, ::-1] / 255.0
np_image = np.array(image, dtype=np.uint8)
np_image = np_image[:, :, ::-1] / 255.0
image_height, image_width = image.shape[:2]
image = transform({"image": image})["image"]
image = torch.from_numpy(image).unsqueeze(0).to(choose_torch_device())
image_height, image_width = np_image.shape[:2]
np_image = transform({"image": np_image})["image"]
tensor_image = torch.from_numpy(np_image).unsqueeze(0).to(choose_torch_device())
with torch.no_grad():
depth = self.model(image)
depth = self.model(tensor_image)
depth = F.interpolate(depth[None], (image_height, image_width), mode="bilinear", align_corners=False)[0, 0]
depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
@ -103,7 +104,4 @@ class DepthAnythingDetector:
new_height = int(image_height * (resolution / image_width))
depth_map = depth_map.resize((resolution, new_height))
if offload:
del self.model
return depth_map
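
A hedged usage sketch of the reworked detector (file names illustrative; the model weights are downloaded on first use):

```py
# Illustrative usage; "photo.png" is a hypothetical input file.
from PIL import Image

detector = DepthAnythingDetector()
detector.load_model(model_size="small")        # fetches weights if missing
depth_map = detector(Image.open("photo.png"))  # returns the original image if no model loaded
depth_map.save("depth.png")
```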


@ -11,17 +11,6 @@ def check_invokeai_root(config: InvokeAIAppConfig):
try:
assert config.db_path.parent.exists(), f"{config.db_path.parent} not found"
assert config.models_path.exists(), f"{config.models_path} not found"
if not config.ignore_missing_core_models:
for model in [
"CLIP-ViT-bigG-14-laion2B-39B-b160k",
"bert-base-uncased",
"clip-vit-large-patch14",
"sd-vae-ft-mse",
"stable-diffusion-2-clip",
"stable-diffusion-safety-checker",
]:
path = config.models_path / f"core/convert/{model}"
assert path.exists(), f"{path} is missing"
except Exception as e:
print()
print(f"An exception has occurred: {str(e)}")
@ -32,10 +21,5 @@ def check_invokeai_root(config: InvokeAIAppConfig):
print(
'** From the command line, activate the virtual environment and run "invokeai-configure --yes --skip-sd-weights" **'
)
print(
'** (To skip this check completely, add "--ignore_missing_core_models" to your CLI args. Not installing '
"these core models will prevent the loading of some or all .safetensors and .ckpt files. However, you can "
"always come back and install these core models in the future.)"
)
input("Press any key to continue...")
sys.exit(0)


@ -25,20 +25,20 @@ import npyscreen
import psutil
import torch
import transformers
from diffusers import AutoencoderKL, ModelMixin
from diffusers import ModelMixin
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from huggingface_hub import HfFolder
from huggingface_hub import login as hf_hub_login
from omegaconf import DictConfig, OmegaConf
from pydantic.error_wrappers import ValidationError
from tqdm import tqdm
from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from transformers import AutoFeatureExtractor
import invokeai.configs as configs
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.install.install_helper import InstallHelper, InstallSelections
from invokeai.backend.install.legacy_arg_parsing import legacy_parser
from invokeai.backend.model_manager import BaseModelType, ModelType
from invokeai.backend.model_manager import ModelType
from invokeai.backend.util import choose_precision, choose_torch_device
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.frontend.install.model_install import addModelsForm
@ -210,51 +210,15 @@ def download_with_progress_bar(model_url: str, model_dest: str, label: str = "th
print(traceback.format_exc(), file=sys.stderr)
def download_conversion_models():
def download_safety_checker():
target_dir = config.models_path / "core/convert"
kwargs = {} # for future use
try:
logger.info("Downloading core tokenizers and text encoders")
# bert
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning)
bert = BertTokenizerFast.from_pretrained("bert-base-uncased", **kwargs)
bert.save_pretrained(target_dir / "bert-base-uncased", safe_serialization=True)
# sd-1
repo_id = "openai/clip-vit-large-patch14"
hf_download_from_pretrained(CLIPTokenizer, repo_id, target_dir / "clip-vit-large-patch14")
hf_download_from_pretrained(CLIPTextModel, repo_id, target_dir / "clip-vit-large-patch14")
# sd-2
repo_id = "stabilityai/stable-diffusion-2"
pipeline = CLIPTokenizer.from_pretrained(repo_id, subfolder="tokenizer", **kwargs)
pipeline.save_pretrained(target_dir / "stable-diffusion-2-clip" / "tokenizer", safe_serialization=True)
pipeline = CLIPTextModel.from_pretrained(repo_id, subfolder="text_encoder", **kwargs)
pipeline.save_pretrained(target_dir / "stable-diffusion-2-clip" / "text_encoder", safe_serialization=True)
# sd-xl - tokenizer_2
repo_id = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
_, model_name = repo_id.split("/")
pipeline = CLIPTokenizer.from_pretrained(repo_id, **kwargs)
pipeline.save_pretrained(target_dir / model_name, safe_serialization=True)
pipeline = CLIPTextConfig.from_pretrained(repo_id, **kwargs)
pipeline.save_pretrained(target_dir / model_name, safe_serialization=True)
# VAE
logger.info("Downloading stable diffusion VAE")
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", **kwargs)
vae.save_pretrained(target_dir / "sd-vae-ft-mse", safe_serialization=True)
# safety checking
logger.info("Downloading safety checker")
repo_id = "CompVis/stable-diffusion-safety-checker"
pipeline = AutoFeatureExtractor.from_pretrained(repo_id, **kwargs)
pipeline.save_pretrained(target_dir / "stable-diffusion-safety-checker", safe_serialization=True)
pipeline = StableDiffusionSafetyChecker.from_pretrained(repo_id, **kwargs)
pipeline.save_pretrained(target_dir / "stable-diffusion-safety-checker", safe_serialization=True)
except KeyboardInterrupt:
@ -307,7 +271,7 @@ def download_lama():
def download_support_models() -> None:
download_realesrgan()
download_lama()
download_conversion_models()
download_safety_checker()
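
With the bundled `core/convert` models gone, diffusers/transformers fetch whatever a conversion needs from the HuggingFace hub cache on first use (one network round-trip, then cached); a minimal sketch using a repo id from the removed code above:

```py
# Sketch: on-demand fetch into ~/.cache/huggingface instead of models/core/convert.
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
```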
# -------------------------------------
@ -744,12 +708,7 @@ def initialize_rootdir(root: Path, yes_to_all: bool = False):
shutil.copytree(configs_src, configs_dest, dirs_exist_ok=True)
dest = root / "models"
for model_base in BaseModelType:
for model_type in ModelType:
path = dest / model_base.value / model_type.value
path.mkdir(parents=True, exist_ok=True)
path = dest / "core"
path.mkdir(parents=True, exist_ok=True)
dest.mkdir(parents=True, exist_ok=True)
# -------------------------------------


@ -1,12 +1,4 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Fast hashing of diffusers and checkpoint-style models.
Usage:
from invokeai.backend.model_managre.model_hash import FastModelHash
>>> FastModelHash.hash('/home/models/stable-diffusion-v1.5')
'a8e693a126ea5b831c96064dc569956f'
"""
import hashlib
import os
@ -15,9 +7,9 @@ from typing import Callable, Literal, Optional, Union
from blake3 import blake3
MODEL_FILE_EXTENSIONS = (".ckpt", ".safetensors", ".bin", ".pt", ".pth")
from invokeai.app.util.misc import uuid_string
ALGORITHM = Literal[
HASHING_ALGORITHMS = Literal[
"md5",
"sha1",
"sha224",
@ -33,12 +25,15 @@ ALGORITHM = Literal[
"shake_128",
"shake_256",
"blake3",
"blake3_single",
"random",
]
MODEL_FILE_EXTENSIONS = (".ckpt", ".safetensors", ".bin", ".pt", ".pth")
class ModelHash:
"""
Creates a hash of a model using a specified algorithm.
Creates a hash of a model using a specified algorithm. The hash is prefixed by the algorithm used.
Args:
algorithm: Hashing algorithm to use. Defaults to BLAKE3.
@ -53,20 +48,29 @@ class ModelHash:
The final hash is computed by hashing the hashes of all model files in the directory using BLAKE3, ensuring
that directory hashes are never weaker than the file hashes.
A convenience algorithm choice of "random" is also available, which returns a random string. This is not a hash.
Usage:
```py
# BLAKE3 hash
ModelHash().hash("path/to/some/model.safetensors")
ModelHash().hash("path/to/some/model.safetensors") # "blake3:ce3f0c5f3c05d119f4a5dcaf209b50d3149046a0d3a9adee9fed4c83cad6b4d0"
# MD5
ModelHash("md5").hash("path/to/model/dir/")
ModelHash("md5").hash("path/to/model/dir/") # "md5:a0cd925fc063f98dbf029eee315060c3"
```
"""
def __init__(self, algorithm: ALGORITHM = "blake3", file_filter: Optional[Callable[[str], bool]] = None) -> None:
def __init__(
self, algorithm: HASHING_ALGORITHMS = "blake3", file_filter: Optional[Callable[[str], bool]] = None
) -> None:
self.algorithm: HASHING_ALGORITHMS = algorithm
if algorithm == "blake3":
self._hash_file = self._blake3
elif algorithm == "blake3_single":
self._hash_file = self._blake3_single
elif algorithm in hashlib.algorithms_available:
self._hash_file = self._get_hashlib(algorithm)
elif algorithm == "random":
self._hash_file = self._random
else:
raise ValueError(f"Algorithm {algorithm} not available")
@ -87,10 +91,12 @@ class ModelHash:
"""
model_path = Path(model_path)
# blake3_single is a single-threaded version of blake3, prefix should still be "blake3:"
prefix = self._get_prefix(self.algorithm)
if model_path.is_file():
return self._hash_file(model_path)
return prefix + self._hash_file(model_path)
elif model_path.is_dir():
return self._hash_dir(model_path)
return prefix + self._hash_dir(model_path)
else:
raise OSError(f"Not a valid file or directory: {model_path}")
@ -114,6 +120,7 @@ class ModelHash:
composite_hasher = blake3()
for h in component_hashes:
composite_hasher.update(h.encode("utf-8"))
return composite_hasher.hexdigest()
@staticmethod
@ -137,7 +144,7 @@ class ModelHash:
@staticmethod
def _blake3(file_path: Path) -> str:
"""Hashes a file using BLAKE3
"""Hashes a file using BLAKE3, using parallelized and memory-mapped I/O to avoid reading the entire file into memory.
Args:
file_path: Path to the file to hash
@ -150,7 +157,21 @@ class ModelHash:
return file_hasher.hexdigest()
@staticmethod
def _get_hashlib(algorithm: ALGORITHM) -> Callable[[Path], str]:
def _blake3_single(file_path: Path) -> str:
"""Hashes a file using BLAKE3, without parallelism. Suitable for spinning hard drives.
Args:
file_path: Path to the file to hash
Returns:
Hexdigest of the hash of the file
"""
file_hasher = blake3()
file_hasher.update_mmap(file_path)
return file_hasher.hexdigest()
@staticmethod
def _get_hashlib(algorithm: HASHING_ALGORITHMS) -> Callable[[Path], str]:
"""Factory function that returns a function to hash a file with the given algorithm.
Args:
@ -172,6 +193,13 @@ class ModelHash:
return hashlib_hasher
@staticmethod
def _random(_file_path: Path) -> str:
"""Returns a random string. This is not a hash.
The string is a UUID, hashed with BLAKE3 to ensure that it is unique."""
return blake3(uuid_string().encode()).hexdigest()
@staticmethod
def _default_file_filter(file_path: str) -> bool:
"""A default file filter that only includes files with the following extensions: .ckpt, .safetensors, .bin, .pt, .pth
@ -183,3 +211,9 @@ class ModelHash:
True if the file matches the given extensions, otherwise False
"""
return file_path.endswith(MODEL_FILE_EXTENSIONS)
@staticmethod
def _get_prefix(algorithm: HASHING_ALGORITHMS) -> str:
"""Return the prefix for the given algorithm, e.g. \"blake3:\" or \"md5:\"."""
# blake3_single is a single-threaded version of blake3, prefix should still be "blake3:"
return "blake3:" if algorithm == "blake3_single" else f"{algorithm}:"


@ -131,13 +131,20 @@ class ModelSourceType(str, Enum):
HFRepoID = "hf_repo_id"
DEFAULTS_PRECISION = Literal["fp16", "fp32"]
class MainModelDefaultSettings(BaseModel):
vae: str | None
vae_precision: str | None
scheduler: SCHEDULER_NAME_VALUES | None
steps: int | None
cfg_scale: float | None
cfg_rescale_multiplier: float | None
vae: str | None = Field(default=None, description="Default VAE for this model (model key)")
vae_precision: DEFAULTS_PRECISION | None = Field(default=None, description="Default VAE precision for this model")
scheduler: SCHEDULER_NAME_VALUES | None = Field(default=None, description="Default scheduler for this model")
steps: int | None = Field(default=None, gt=0, description="Default number of steps for this model")
cfg_scale: float | None = Field(default=None, ge=1, description="Default CFG Scale for this model")
cfg_rescale_multiplier: float | None = Field(
default=None, ge=0, lt=1, description="Default CFG Rescale Multiplier for this model"
)
width: int | None = Field(default=None, multiple_of=8, ge=64, description="Default width for this model")
height: int | None = Field(default=None, multiple_of=8, ge=64, description="Default height for this model")
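
A hedged sketch of what the new `Field` constraints buy (values illustrative):

```py
# Illustrative; shows validation behavior of the constrained fields above.
from invokeai.backend.model_manager.config import MainModelDefaultSettings

MainModelDefaultSettings(steps=30, cfg_scale=7.5, width=1024, height=1024)  # ok
MainModelDefaultSettings(width=1000)  # ok: 1000 is a multiple of 8 and >= 64
MainModelDefaultSettings(width=100)   # raises ValidationError: not a multiple of 8
```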
class ControlAdapterDefaultSettings(BaseModel):

File diff suppressed because it is too large


@ -3,9 +3,6 @@
from pathlib import Path
import torch
from safetensors.torch import load_file as safetensors_load_file
from invokeai.backend.model_manager import (
AnyModelConfig,
BaseModelType,
@ -37,27 +34,25 @@ class ControlNetLoader(GenericDiffusersLoader):
return True
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
raise Exception(f"ControlNet conversion not supported for model type: {config.base}")
else:
assert isinstance(config, CheckpointConfigBase)
config_file = config.config_path
assert isinstance(config, CheckpointConfigBase)
config_file = config.config_path
if model_path.suffix == ".safetensors":
checkpoint = safetensors_load_file(model_path, device="cpu")
else:
checkpoint = torch.load(model_path, map_location="cpu")
# sometimes weights are hidden under "state_dict", and sometimes not
if "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"]
convert_controlnet_to_diffusers(
model_path,
output_path,
original_config_file=self._app_config.root_path / config_file,
image_size=512,
scan_needed=True,
from_safetensors=model_path.suffix == ".safetensors",
image_size = (
512
if config.base == BaseModelType.StableDiffusion1
else 768
if config.base == BaseModelType.StableDiffusion2
else 1024
)
self._logger.info(f"Converting {model_path} to diffusers format")
with open(self._app_config.root_path / config_file, "r") as config_stream:
convert_controlnet_to_diffusers(
model_path,
output_path,
original_config_file=config_stream,
image_size=image_size,
precision=self._torch_dtype,
from_safetensors=model_path.suffix == ".safetensors",
)
return output_path
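
If the nested conditional above is hard to scan, it is equivalent to this explicit mapping (a sketch; names as imported in this file):

```py
# Equivalent form of the nested conditional choosing image_size (sketch).
from invokeai.backend.model_manager import BaseModelType

IMAGE_SIZE_BY_BASE = {
    BaseModelType.StableDiffusion1: 512,
    BaseModelType.StableDiffusion2: 768,
}
image_size = IMAGE_SIZE_BY_BASE.get(config.base, 1024)  # anything else (e.g. SDXL) -> 1024
```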


@ -4,9 +4,6 @@
from pathlib import Path
from typing import Optional
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
@ -14,7 +11,7 @@ from invokeai.backend.model_manager import (
ModelFormat,
ModelRepoVariant,
ModelType,
ModelVariantType,
SchedulerPredictionType,
SubModelType,
)
from invokeai.backend.model_manager.config import CheckpointConfigBase, MainCheckpointConfig
@ -68,27 +65,31 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
assert isinstance(config, MainCheckpointConfig)
variant = config.variant
base = config.base
pipeline_class = (
StableDiffusionInpaintPipeline if variant == ModelVariantType.Inpaint else StableDiffusionPipeline
)
config_file = config.config_path
prediction_type = config.prediction_type.value
upcast_attention = config.upcast_attention
image_size = (
1024
if base == BaseModelType.StableDiffusionXL
else 768
if config.prediction_type == SchedulerPredictionType.VPrediction and base == BaseModelType.StableDiffusion2
else 512
)
self._logger.info(f"Converting {model_path} to diffusers format")
convert_ckpt_to_diffusers(
model_path,
output_path,
model_type=self.model_base_to_model_type[base],
model_version=base,
model_variant=variant,
original_config_file=self._app_config.root_path / config_file,
extract_ema=True,
scan_needed=True,
pipeline_class=pipeline_class,
from_safetensors=model_path.suffix == ".safetensors",
precision=self._torch_dtype,
prediction_type=prediction_type,
image_size=image_size,
upcast_attention=upcast_attention,
load_safety_checker=False,
)
return output_path


@ -57,12 +57,12 @@ class VAELoader(GenericDiffusersLoader):
ckpt_config = OmegaConf.load(self._app_config.root_path / config_file)
assert isinstance(ckpt_config, DictConfig)
self._logger.info(f"Converting {model_path} to diffusers format")
vae_model = convert_ldm_vae_to_diffusers(
checkpoint=checkpoint,
vae_config=ckpt_config,
image_size=512,
precision=self._torch_dtype,
)
vae_model.to(self._torch_dtype) # set precision appropriately
vae_model.save_pretrained(output_path, safe_serialization=True)
return output_path


@ -90,8 +90,35 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
)
)
# diffusers models have a `model_index.json` or `config.json` file
is_diffusers = any(str(f.url).endswith(("model_index.json", "config.json")) for f in files)
# These URLs will be exposed to the user - I think these are the only file types we fully support
ckpt_urls = (
None
if is_diffusers
else [
f.url
for f in files
if str(f.url).endswith(
(
".safetensors",
".bin",
".pth",
".pt",
".ckpt",
)
)
]
)
return HuggingFaceMetadata(
id=model_info.id, name=name, files=files, api_response=json.dumps(model_info.__dict__, default=str)
id=model_info.id,
name=name,
files=files,
api_response=json.dumps(model_info.__dict__, default=str),
is_diffusers=is_diffusers,
ckpt_urls=ckpt_urls,
)
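
Restated as a tiny standalone helper (a sketch; the file names are hypothetical):

```py
# Standalone restatement of the diffusers/checkpoint classification rule above.
CKPT_EXTENSIONS = (".safetensors", ".bin", ".pth", ".pt", ".ckpt")

def classify(urls: list[str]) -> tuple[bool, list[str] | None]:
    is_diffusers = any(u.endswith(("model_index.json", "config.json")) for u in urls)
    ckpt_urls = None if is_diffusers else [u for u in urls if u.endswith(CKPT_EXTENSIONS)]
    return is_diffusers, ckpt_urls

print(classify(["model_index.json", "vae/config.json"]))   # (True, None)
print(classify(["v1-5-pruned.safetensors", "README.md"]))  # (False, ['v1-5-pruned.safetensors'])
```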
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:


@ -84,6 +84,10 @@ class HuggingFaceMetadata(ModelMetadataWithFiles):
type: Literal["huggingface"] = "huggingface"
id: str = Field(description="The HF model id")
api_response: Optional[str] = Field(description="Response from the HF API as stringified JSON", default=None)
is_diffusers: bool = Field(description="Whether the metadata is for a Diffusers format model", default=False)
ckpt_urls: Optional[List[AnyHttpUrl]] = Field(
description="URLs for all checkpoint format models in the metadata", default=None
)
def download_urls(
self,


@ -9,6 +9,7 @@ from picklescan.scanner import scan_file_path
import invokeai.backend.util.logging as logger
from invokeai.app.util.misc import uuid_string
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
from invokeai.backend.util.util import SilenceWarnings
from .config import (
@ -16,6 +17,7 @@ from .config import (
BaseModelType,
ControlAdapterDefaultSettings,
InvalidModelConfigException,
MainModelDefaultSettings,
ModelConfigFactory,
ModelFormat,
ModelRepoVariant,
@ -24,7 +26,6 @@ from .config import (
ModelVariantType,
SchedulerPredictionType,
)
from .hash import ModelHash
from .util.model_util import lora_token_vector_length, read_checkpoint_meta
CkptType = Dict[str, Any]
@ -113,9 +114,7 @@ class ModelProbe(object):
@classmethod
def probe(
cls,
model_path: Path,
fields: Optional[Dict[str, Any]] = None,
cls, model_path: Path, fields: Optional[Dict[str, Any]] = None, hash_algo: HASHING_ALGORITHMS = "blake3"
) -> AnyModelConfig:
"""
Probe the model at model_path and return its configuration record.
@ -133,11 +132,12 @@ class ModelProbe(object):
format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint
model_info = None
model_type = None
if format_type is ModelFormat.Diffusers:
model_type = cls.get_model_type_from_folder(model_path)
else:
model_type = cls.get_model_type_from_checkpoint(model_path)
model_type = ModelType(fields["type"]) if "type" in fields and fields["type"] else None
if not model_type:
if format_type is ModelFormat.Diffusers:
model_type = cls.get_model_type_from_folder(model_path)
else:
model_type = cls.get_model_type_from_checkpoint(model_path)
format_type = ModelFormat.ONNX if model_type == ModelType.ONNX else format_type
probe_class = cls.PROBES[format_type].get(model_type)
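
A hedged usage sketch of the type passthrough (module path assumed from the imports above; the model path is hypothetical):

```py
# Supplying "type" in fields skips probe-side model-type detection.
from pathlib import Path

from invokeai.backend.model_manager.probe import ModelProbe

config = ModelProbe.probe(
    Path("some_embedding.pt"),      # hypothetical path
    fields={"type": "embedding"},   # string form is accepted: ModelType("embedding")
    hash_algo="blake3_single",
)
```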
@ -157,16 +157,18 @@ class ModelProbe(object):
fields["image_encoder_model_id"] = fields.get("image_encoder_model_id") or probe.get_image_encoder_model_id()
fields["name"] = fields.get("name") or cls.get_model_name(model_path)
fields["description"] = (
fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
fields.get("description") or f"{fields['base'].value} {model_type.value} model {fields['name']}"
)
fields["format"] = fields.get("format") or probe.get_format()
fields["hash"] = fields.get("hash") or ModelHash().hash(model_path)
fields["hash"] = fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path)
fields["default_settings"] = (
fields.get("default_settings") or probe.get_default_settings(fields["name"])
if isinstance(probe, ControlAdapterProbe)
else None
)
fields["default_settings"] = fields.get("default_settings")
if not fields["default_settings"]:
if fields["type"] in {ModelType.ControlNet, ModelType.T2IAdapter}:
fields["default_settings"] = get_default_settings_controlnet_t2i_adapter(fields["name"])
elif fields["type"] is ModelType.Main:
fields["default_settings"] = get_default_settings_main(fields["base"])
if format_type == ModelFormat.Diffusers and isinstance(probe, FolderProbeBase):
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()
@ -318,7 +320,7 @@ class ModelProbe(object):
@classmethod
def _scan_and_load_checkpoint(cls, model_path: Path) -> CkptType:
with SilenceWarnings():
if model_path.suffix.endswith((".ckpt", ".pt", ".bin")):
if model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
cls._scan_model(model_path.name, model_path)
model = torch.load(model_path)
assert isinstance(model, dict)
@ -338,36 +340,41 @@ class ModelProbe(object):
raise Exception("The model {model_name} is potentially infected by malware. Aborting import.")
class ControlAdapterProbe(ProbeBase):
"""Adds `get_default_settings` for ControlNet and T2IAdapter probes"""
# Probing utilities
MODEL_NAME_TO_PREPROCESSOR = {
"canny": "canny_image_processor",
"mlsd": "mlsd_image_processor",
"depth": "depth_anything_image_processor",
"bae": "normalbae_image_processor",
"normal": "normalbae_image_processor",
"sketch": "pidi_image_processor",
"scribble": "lineart_image_processor",
"lineart": "lineart_image_processor",
"lineart_anime": "lineart_anime_image_processor",
"softedge": "hed_image_processor",
"shuffle": "content_shuffle_image_processor",
"pose": "dw_openpose_image_processor",
"mediapipe": "mediapipe_face_processor",
"pidi": "pidi_image_processor",
"zoe": "zoe_depth_image_processor",
"color": "color_map_image_processor",
}
# TODO(psyche): It would be nice to get these from the invocations, but that creates circular dependencies.
# "canny": CannyImageProcessorInvocation.get_type()
MODEL_NAME_TO_PREPROCESSOR = {
"canny": "canny_image_processor",
"mlsd": "mlsd_image_processor",
"depth": "depth_anything_image_processor",
"bae": "normalbae_image_processor",
"normal": "normalbae_image_processor",
"sketch": "pidi_image_processor",
"scribble": "lineart_image_processor",
"lineart": "lineart_image_processor",
"lineart_anime": "lineart_anime_image_processor",
"softedge": "hed_image_processor",
"shuffle": "content_shuffle_image_processor",
"pose": "dw_openpose_image_processor",
"mediapipe": "mediapipe_face_processor",
"pidi": "pidi_image_processor",
"zoe": "zoe_depth_image_processor",
"color": "color_map_image_processor",
}
@classmethod
def get_default_settings(cls, model_name: str) -> Optional[ControlAdapterDefaultSettings]:
for k, v in cls.MODEL_NAME_TO_PREPROCESSOR.items():
if k in model_name:
return ControlAdapterDefaultSettings(preprocessor=v)
return None
def get_default_settings_controlnet_t2i_adapter(model_name: str) -> Optional[ControlAdapterDefaultSettings]:
for k, v in MODEL_NAME_TO_PREPROCESSOR.items():
if k in model_name:
return ControlAdapterDefaultSettings(preprocessor=v)
return None
def get_default_settings_main(model_base: BaseModelType) -> Optional[MainModelDefaultSettings]:
if model_base is BaseModelType.StableDiffusion1 or model_base is BaseModelType.StableDiffusion2:
return MainModelDefaultSettings(width=512, height=512)
elif model_base is BaseModelType.StableDiffusionXL:
return MainModelDefaultSettings(width=1024, height=1024)
# We don't provide defaults for BaseModelType.StableDiffusionXLRefiner, as they are not standalone models.
return None
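
Illustrating the helper above (enum members as used in this file):

```py
# Illustrative calls; return values shown in comments.
from invokeai.backend.model_manager.config import BaseModelType

get_default_settings_main(BaseModelType.StableDiffusion1)          # MainModelDefaultSettings(width=512, height=512)
get_default_settings_main(BaseModelType.StableDiffusionXL)         # MainModelDefaultSettings(width=1024, height=1024)
get_default_settings_main(BaseModelType.StableDiffusionXLRefiner)  # None -- refiners are not standalone
```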
# ##################################################
@ -493,7 +500,7 @@ class TextualInversionCheckpointProbe(CheckpointProbeBase):
raise InvalidModelConfigException(f"{self.model_path}: Could not determine base type")
class ControlNetCheckpointProbe(CheckpointProbeBase, ControlAdapterProbe):
class ControlNetCheckpointProbe(CheckpointProbeBase):
"""Class for probing controlnets."""
def get_base_type(self) -> BaseModelType:
@ -521,7 +528,7 @@ class CLIPVisionCheckpointProbe(CheckpointProbeBase):
raise NotImplementedError()
class T2IAdapterCheckpointProbe(CheckpointProbeBase, ControlAdapterProbe):
class T2IAdapterCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
raise NotImplementedError()
@ -659,7 +666,7 @@ class ONNXFolderProbe(PipelineFolderProbe):
return ModelVariantType.Normal
class ControlNetFolderProbe(FolderProbeBase, ControlAdapterProbe):
class ControlNetFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
config_file = self.model_path / "config.json"
if not config_file.exists():
@ -733,7 +740,7 @@ class CLIPVisionFolderProbe(FolderProbeBase):
return BaseModelType.Any
class T2IAdapterFolderProbe(FolderProbeBase, ControlAdapterProbe):
class T2IAdapterFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
config_file = self.model_path / "config.json"
if not config_file.exists():


@ -5,6 +5,7 @@ from typing import Callable, List, Union
import torch.nn as nn
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
@ -26,7 +27,7 @@ def _conv_forward_asymmetric(self, input, weight, bias):
@contextmanager
def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL], seamless_axes: List[str]):
def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL, AutoencoderTiny], seamless_axes: List[str]):
# Callable: (input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor
to_restore: list[tuple[nn.Conv2d | nn.ConvTranspose2d, Callable]] = []
try:
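
A hedged usage sketch of the widened signature (the `AutoencoderTiny` instance here is hypothetical; normally it would be loaded from a checkpoint):

```py
# Illustrative only: seamless tiling now also accepts AutoencoderTiny.
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny

vae = AutoencoderTiny()  # hypothetical; real code loads pretrained weights
with set_seamless(vae, seamless_axes=["x"]):
    pass  # convolutions inside this block pad circularly along the x axis
```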


@ -62,40 +62,72 @@ sd-1/main/trinart_stable_diffusion_v2:
recommended: False
sd-1/controlnet/qrcode_monster:
source: monster-labs/control_v1p_sd15_qrcode_monster
description: Controlnet model that generates scannable creative QR codes
subfolder: v2
sd-1/controlnet/canny:
description: Controlnet weights trained on sd-1.5 with canny conditioning.
source: lllyasviel/control_v11p_sd15_canny
recommended: True
sd-1/controlnet/inpaint:
source: lllyasviel/control_v11p_sd15_inpaint
description: Controlnet weights trained on sd-1.5 with inpaint conditioning
sd-1/controlnet/mlsd:
description: Controlnet weights trained on sd-1.5 with MLSD conditioning
source: lllyasviel/control_v11p_sd15_mlsd
sd-1/controlnet/depth:
description: Controlnet weights trained on sd-1.5 with depth conditioning
source: lllyasviel/control_v11f1p_sd15_depth
recommended: True
sd-1/controlnet/normal_bae:
description: Controlnet weights trained on sd-1.5 with normalbae image conditioning
source: lllyasviel/control_v11p_sd15_normalbae
sd-1/controlnet/seg:
description: Controlnet weights trained on sd-1.5 with seg image conditioning
source: lllyasviel/control_v11p_sd15_seg
sd-1/controlnet/lineart:
description: Controlnet weights trained on sd-1.5 with lineart image conditioning
source: lllyasviel/control_v11p_sd15_lineart
recommended: True
sd-1/controlnet/lineart_anime:
description: Controlnet weights trained on sd-1.5 with anime image conditioning
source: lllyasviel/control_v11p_sd15s2_lineart_anime
sd-1/controlnet/openpose:
description: Controlnet weights trained on sd-1.5 with openpose image conditioning
source: lllyasviel/control_v11p_sd15_openpose
recommended: True
sd-1/controlnet/scribble:
source: lllyasviel/control_v11p_sd15_scribble
description: Controlnet weights trained on sd-1.5 with scribble image conditioning
recommended: False
sd-1/controlnet/softedge:
source: lllyasviel/control_v11p_sd15_softedge
description: Controlnet weights trained on sd-1.5 with soft edge conditioning
sd-1/controlnet/shuffle:
source: lllyasviel/control_v11e_sd15_shuffle
description: Controlnet weights trained on sd-1.5 with shuffle image conditioning
sd-1/controlnet/tile:
source: lllyasviel/control_v11f1e_sd15_tile
description: Controlnet weights trained on sd-1.5 with tiled image conditioning
sd-1/controlnet/ip2p:
source: lllyasviel/control_v11e_sd15_ip2p
description: Controlnet weights trained on sd-1.5 with ip2p conditioning.
sdxl/controlnet/canny-sdxl:
description: Controlnet weights trained on sdxl-1.0 with canny conditioning.
source: diffusers/controlnet-canny-sdxl-1.0
recommended: True
sdxl/controlnet/depth-sdxl:
description: Controlnet weights trained on sdxl-1.0 with depth conditioning.
source: diffusers/controlnet-depth-sdxl-1.0
recommended: True
sdxl/controlnet/softedge-dexined-sdxl:
description: Controlnet weights trained on sdxl-1.0 with dexined soft edge preprocessing.
source: SargeZT/controlnet-sd-xl-1.0-softedge-dexined
sdxl/controlnet/depth-16bit-zoe-sdxl:
description: Controlnet weights trained on sdxl-1.0 with Zoe's preprocessor (16 bits).
source: SargeZT/controlnet-sd-xl-1.0-depth-16bit-zoe
sdxl/controlnet/depth-zoe-sdxl:
description: Controlnet weights trained on sdxl-1.0 with Zoe's preprocessor (32 bits).
source: diffusers/controlnet-zoe-depth-sdxl-1.0
sd-1/t2i_adapter/canny-sd15:
source: TencentARC/t2iadapter_canny_sd15v2
sd-1/t2i_adapter/sketch-sd15:


@ -608,8 +608,9 @@ def main() -> None:
config.parse_args(invoke_args)
logger = InvokeAILogger().get_logger(config=config)
if not config.model_conf_path.exists():
if not config.models_path.exists():
logger.info("Your InvokeAI root directory is not set up. Calling invokeai-configure.")
sys.argv = ["invokeai_configure", "--yes", "--skip-sd-weights"]
from invokeai.frontend.install.invokeai_configure import invokeai_configure
invokeai_configure()


@ -1,150 +1,3 @@
# Invoke UI
<!-- @import "[TOC]" {cmd="toc" depthFrom=2 depthTo=3 orderedList=false} -->
<!-- code_chunk_output -->
- [Dev environment](#dev-environment)
- [Setup](#setup)
- [Package scripts](#package-scripts)
- [Type generation](#type-generation)
- [Localization](#localization)
- [VSCode](#vscode)
- [Contributing](#contributing)
- [Check in before investing your time](#check-in-before-investing-your-time)
- [Commit format](#commit-format)
- [Submitting a PR](#submitting-a-pr)
- [Other docs](#other-docs)
<!-- /code_chunk_output -->
Invoke's UI is made possible by many contributors and open-source libraries. Thank you!
## Dev environment
### Setup
1. Install [node] and [pnpm].
1. Run `pnpm i` to install all packages.
#### Run in dev mode
1. From `invokeai/frontend/web/`, run `pnpm dev`.
1. From repo root, run `python scripts/invokeai-web.py`.
1. Point your browser to the dev server address, e.g. <http://localhost:5173/>
### Package scripts
- `dev`: run the frontend in dev mode, enabling hot reloading
- `build`: run all checks (madge, eslint, prettier, tsc) and then build the frontend
- `typegen`: generate types from the OpenAPI schema (see [Type generation])
- `lint:madge`: check frontend for circular dependencies
- `lint:eslint`: check frontend for code quality
- `lint:prettier`: check frontend for code formatting
- `lint:tsc`: check frontend for type issues
- `lint`: run all checks concurrently
- `fix`: run `eslint` and `prettier`, fixing fixable issues
### Type generation
We use [openapi-typescript] to generate types from the app's OpenAPI schema.
The generated types are committed to the repo in [schema.ts].
```sh
# from the repo root, start the server
python scripts/invokeai-web.py
# from invokeai/frontend/web/, run the script
pnpm typegen
```
### Localization
We use [i18next] for localization, but translation to languages other than English happens on our [Weblate] project.
Only the English source strings should be changed on this repo.
### VSCode
#### Example debugger config
```jsonc
{
"version": "0.2.0",
"configurations": [
{
"type": "chrome",
"request": "launch",
"name": "Invoke UI",
"url": "http://localhost:5173",
"webRoot": "${workspaceFolder}/invokeai/frontend/web",
},
],
}
```
#### Remote dev
We've noticed an intermittent timeout issue with the VSCode remote dev port forwarding.
We suggest disabling the editor's port forwarding feature and doing it manually via SSH:
```sh
ssh -L 9090:localhost:9090 -L 5173:localhost:5173 user@host
```
## Contributing
Thanks for your interest in contributing to the Invoke Web UI!
Please follow these guidelines when contributing.
### Check in before investing your time
Please check in before you invest your time on anything besides a trivial fix, in case it conflicts with ongoing work or isn't aligned with the vision for the app.
If a feature request or issue doesn't already exist for the thing you want to work on, please create one.
Ping `@psychedelicious` on [discord] in the `#frontend-dev` channel or in the feature request / issue you want to work on - we're happy to chat.
### Code conventions
- This is a fairly complex app with a deep component tree. Please use memoization (`useCallback`, `useMemo`, `memo`) with enthusiasm.
- If you need to add some global, ephemeral state, please use [nanostores] if possible.
- Be careful with your redux selectors. If they need to be parameterized, consider creating them inside a `useMemo`.
- Feel free to use `lodash` (via `lodash-es`) to make the intent of your code clear.
- Please add comments describing the "why", not the "how" (unless it is really arcane).
### Commit format
Please use the [conventional commits] spec for the web UI, with a scope of "ui":
- `chore(ui): bump deps`
- `chore(ui): lint`
- `feat(ui): add some cool new feature`
- `fix(ui): fix some bug`
### Submitting a PR
- Ensure your branch is tidy. Use an interactive rebase to clean up the commit history and reword the commit messages if they are not descriptive.
- Run `pnpm lint`. Some issues are auto-fixable with `pnpm fix`.
- Fill out the PR form when creating the PR.
- It doesn't need to be super detailed, but a screenshot or video is nice if you changed something visually.
- If a section isn't relevant, delete it. There are no UI tests at this time.
## Other docs
- [Workflows - Design and Implementation]
- [State Management]
[node]: https://nodejs.org/en/download/
[pnpm]: https://github.com/pnpm/pnpm
[discord]: https://discord.gg/ZmtBAhwWhy
[i18next]: https://github.com/i18next/react-i18next
[Weblate]: https://hosted.weblate.org/engage/invokeai/
[openapi-typescript]: https://github.com/drwpow/openapi-typescript
[Type generation]: #type-generation
[schema.ts]: ../src/services/api/schema.ts
[conventional commits]: https://www.conventionalcommits.org/en/v1.0.0/
[Workflows - Design and Implementation]: ./docs/WORKFLOWS_DESIGN_IMPLEMENTATION.md
[State Management]: ./docs/STATE_MGMT.md
<https://invoke-ai.github.io/InvokeAI/contributing/frontend/OVERVIEW/>

File diff suppressed because it is too large


@ -0,0 +1,88 @@
# Cleans translations by removing unused keys
# Usage: python clean_translations.py
# Note: Must be run from invokeai/frontend/web/scripts directory
#
# After running the script, open `en.json` and check for empty objects (`{}`) and remove them manually.
import json
import os
import re
from typing import TypeAlias, Union
from tqdm import tqdm
RecursiveDict: TypeAlias = dict[str, Union["RecursiveDict", str]]
class TranslationCleaner:
file_cache: dict[str, str] = {}
def _get_keys(self, obj: RecursiveDict, current_path: str = "", keys: list[str] | None = None):
if keys is None:
keys = []
for key in obj:
new_path = f"{current_path}.{key}" if current_path else key
next_ = obj[key]
if isinstance(next_, dict):
self._get_keys(next_, new_path, keys)
elif "_" in key:
# This typically means it's a pluralized key
continue
else:
keys.append(new_path)
return keys
def _search_codebase(self, key: str):
for root, _dirs, files in os.walk("../src"):
for file in files:
if file.endswith(".ts") or file.endswith(".tsx"):
full_path = os.path.join(root, file)
if full_path in self.file_cache:
content = self.file_cache[full_path]
else:
with open(full_path, "r") as f:
content = f.read()
self.file_cache[full_path] = content
# match the whole key, surrounded by quotes
if re.search(r"['\"`]" + re.escape(key) + r"['\"`]", self.file_cache[full_path]):
return True
# match the stem of the key, with quotes at the end
if re.search(re.escape(key.split(".")[-1]) + r"['\"`]", self.file_cache[full_path]):
return True
return False
def _remove_key(self, obj: RecursiveDict, key: str):
path = key.split(".")
last_key = path[-1]
for k in path[:-1]:
obj = obj[k]
del obj[last_key]
def clean(self, obj: RecursiveDict) -> RecursiveDict:
keys = self._get_keys(obj)
pbar = tqdm(keys, desc="Checking keys")
for key in pbar:
if not self._search_codebase(key):
self._remove_key(obj, key)
return obj
def main():
try:
with open("../public/locales/en.json", "r") as f:
data = json.load(f)
except FileNotFoundError as e:
raise FileNotFoundError(
"Unable to find en.json file - must be run from invokeai/frontend/web/scripts directory"
) from e
cleaner = TranslationCleaner()
cleaned_data = cleaner.clean(data)
with open("../public/locales/en.json", "w") as f:
json.dump(cleaned_data, f, indent=4)
if __name__ == "__main__":
main()


@ -1,10 +1,10 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppDispatch, RootState } from 'app/store/store';
import type { JSONObject } from 'common/types';
import {
controlAdapterModelCleared,
selectAllControlNets,
selectAllIPAdapters,
selectAllT2IAdapters,
selectControlAdapterAll,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { loraRemoved } from 'features/lora/store/loraSlice';
import { calculateNewSize } from 'features/parameters/components/ImageSize/calculateNewSize';
@ -12,212 +12,161 @@ import { heightChanged, modelChanged, vaeSelected, widthChanged } from 'features
import { zParameterModel, zParameterVAEModel } from 'features/parameters/types/parameterSchemas';
import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
import { refinerModelChanged } from 'features/sdxl/store/sdxlSlice';
import { forEach, some } from 'lodash-es';
import { mainModelsAdapterSelectors, modelsApi, vaeModelsAdapterSelectors } from 'services/api/endpoints/models';
import type { TypeGuardFor } from 'services/api/types';
import { forEach } from 'lodash-es';
import type { Logger } from 'roarr';
import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
import { isNonRefinerMainModelConfig, isRefinerMainModelModelConfig, isVAEModelConfig } from 'services/api/types';
export const addModelsLoadedListener = (startAppListening: AppStartListening) => {
startAppListening({
predicate: (action): action is TypeGuardFor<typeof modelsApi.endpoints.getMainModels.matchFulfilled> =>
modelsApi.endpoints.getMainModels.matchFulfilled(action) &&
!action.meta.arg.originalArgs.includes('sdxl-refiner'),
predicate: modelsApi.endpoints.getModelConfigs.matchFulfilled,
effect: async (action, { getState, dispatch }) => {
// models loaded, we need to ensure the selected model is available and if not, select the first one
const log = logger('models');
log.info({ models: action.payload.entities }, `Main models loaded (${action.payload.ids.length})`);
log.info({ models: action.payload.entities }, `Models loaded (${action.payload.ids.length})`);
const state = getState();
const currentModel = state.generation.model;
const models = mainModelsAdapterSelectors.selectAll(action.payload);
const models = modelConfigsAdapterSelectors.selectAll(action.payload);
if (models.length === 0) {
// No models loaded at all
dispatch(modelChanged(null));
return;
}
const isCurrentModelAvailable = currentModel ? models.some((m) => m.key === currentModel.key) : false;
if (isCurrentModelAvailable) {
return;
}
const defaultModel = state.config.sd.defaultModel;
const defaultModelInList = defaultModel ? models.find((m) => m.key === defaultModel) : false;
if (defaultModelInList) {
const result = zParameterModel.safeParse(defaultModelInList);
if (result.success) {
dispatch(modelChanged(defaultModelInList, currentModel));
const optimalDimension = getOptimalDimension(defaultModelInList);
if (getIsSizeOptimal(state.generation.width, state.generation.height, optimalDimension)) {
return;
}
const { width, height } = calculateNewSize(
state.generation.aspectRatio.value,
optimalDimension * optimalDimension
);
dispatch(widthChanged(width));
dispatch(heightChanged(height));
return;
}
}
const result = zParameterModel.safeParse(models[0]);
if (!result.success) {
log.error({ error: result.error.format() }, 'Failed to parse main model');
return;
}
dispatch(modelChanged(result.data, currentModel));
},
});
startAppListening({
predicate: (action): action is TypeGuardFor<typeof modelsApi.endpoints.getMainModels.matchFulfilled> =>
modelsApi.endpoints.getMainModels.matchFulfilled(action) && action.meta.arg.originalArgs.includes('sdxl-refiner'),
effect: async (action, { getState, dispatch }) => {
// models loaded, we need to ensure the selected model is available and if not, select the first one
const log = logger('models');
log.info({ models: action.payload.entities }, `SDXL Refiner models loaded (${action.payload.ids.length})`);
const currentModel = getState().sdxl.refinerModel;
const models = mainModelsAdapterSelectors.selectAll(action.payload);
if (models.length === 0) {
// No models loaded at all
dispatch(refinerModelChanged(null));
return;
}
const isCurrentModelAvailable = currentModel ? models.some((m) => m.key === currentModel.key) : false;
if (!isCurrentModelAvailable) {
dispatch(refinerModelChanged(null));
return;
}
},
});
startAppListening({
matcher: modelsApi.endpoints.getVaeModels.matchFulfilled,
effect: async (action, { getState, dispatch }) => {
// VAEs loaded, need to reset the VAE if it's no longer available
const log = logger('models');
log.info({ models: action.payload.entities }, `VAEs loaded (${action.payload.ids.length})`);
const currentVae = getState().generation.vae;
if (currentVae === null) {
// null is a valid VAE! it means "use the default with the main model"
return;
}
const isCurrentVAEAvailable = some(action.payload.entities, (m) => m?.key === currentVae?.key);
if (isCurrentVAEAvailable) {
return;
}
const firstModel = vaeModelsAdapterSelectors.selectAll(action.payload)[0];
if (!firstModel) {
// No custom VAEs loaded at all; use the default
dispatch(vaeSelected(null));
return;
}
const result = zParameterVAEModel.safeParse(firstModel);
if (!result.success) {
log.error({ error: result.error.format() }, 'Failed to parse VAE model');
return;
}
dispatch(vaeSelected(result.data));
},
});
startAppListening({
matcher: modelsApi.endpoints.getLoRAModels.matchFulfilled,
effect: async (action, { getState, dispatch }) => {
// LoRA models loaded - need to remove missing LoRAs from state
const log = logger('models');
log.info({ models: action.payload.entities }, `LoRAs loaded (${action.payload.ids.length})`);
const loras = getState().lora.loras;
forEach(loras, (lora, id) => {
const isLoRAAvailable = some(action.payload.entities, (m) => m?.key === lora?.model.key);
if (isLoRAAvailable) {
return;
}
dispatch(loraRemoved(id));
});
},
});
startAppListening({
matcher: modelsApi.endpoints.getControlNetModels.matchFulfilled,
effect: async (action, { getState, dispatch }) => {
// ControlNet models loaded - need to remove missing ControlNets from state
const log = logger('models');
log.info({ models: action.payload.entities }, `ControlNet models loaded (${action.payload.ids.length})`);
selectAllControlNets(getState().controlAdapters).forEach((ca) => {
const isModelAvailable = some(action.payload.entities, (m) => m?.key === ca?.model?.key);
if (isModelAvailable) {
return;
}
dispatch(controlAdapterModelCleared({ id: ca.id }));
});
},
});
startAppListening({
matcher: modelsApi.endpoints.getT2IAdapterModels.matchFulfilled,
effect: async (action, { getState, dispatch }) => {
// T2I Adapter models loaded - need to remove missing T2I Adapters from state
const log = logger('models');
log.info({ models: action.payload.entities }, `T2I Adapter models loaded (${action.payload.ids.length})`);
selectAllT2IAdapters(getState().controlAdapters).forEach((ca) => {
const isModelAvailable = some(action.payload.entities, (m) => m?.key === ca?.model?.key);
if (isModelAvailable) {
return;
}
dispatch(controlAdapterModelCleared({ id: ca.id }));
});
},
});
startAppListening({
matcher: modelsApi.endpoints.getIPAdapterModels.matchFulfilled,
effect: async (action, { getState, dispatch }) => {
// IP Adapter models loaded - need to remove missing IP Adapters from state
const log = logger('models');
log.info({ models: action.payload.entities }, `IP Adapter models loaded (${action.payload.ids.length})`);
selectAllIPAdapters(getState().controlAdapters).forEach((ca) => {
const isModelAvailable = some(action.payload.entities, (m) => m?.key === ca?.model?.key);
if (isModelAvailable) {
return;
}
dispatch(controlAdapterModelCleared({ id: ca.id }));
});
},
});
startAppListening({
matcher: modelsApi.endpoints.getTextualInversionModels.matchFulfilled,
effect: async (action) => {
const log = logger('models');
log.info({ models: action.payload.entities }, `Embeddings loaded (${action.payload.ids.length})`);
},
});
startAppListening({
matcher: modelsApi.endpoints.getModelConfigs.matchFulfilled,
effect: async (action, { getState, dispatch }) => {
const log = logger('models');
log.info({ models: action.payload.entities }, `Models loaded (${action.payload.ids.length})`);
const state = getState();
const models = modelConfigsAdapterSelectors.selectAll(action.payload);
handleMainModels(models, state, dispatch, log);
handleRefinerModels(models, state, dispatch, log);
handleVAEModels(models, state, dispatch, log);
handleLoRAModels(models, state, dispatch, log);
handleControlAdapterModels(models, state, dispatch, log);
},
});
});
};
type ModelHandler = (
models: AnyModelConfig[],
state: RootState,
dispatch: AppDispatch,
log: Logger<JSONObject>
) => undefined;
const handleMainModels: ModelHandler = (models, state, dispatch, log) => {
const currentModel = state.generation.model;
const mainModels = models.filter(isNonRefinerMainModelConfig);
if (mainModels.length === 0) {
// No models loaded at all
dispatch(modelChanged(null));
return;
}
const isCurrentMainModelAvailable = currentModel ? mainModels.some((m) => m.key === currentModel.key) : false;
if (isCurrentMainModelAvailable) {
return;
}
const defaultModel = state.config.sd.defaultModel;
const defaultModelInList = defaultModel ? mainModels.find((m) => m.key === defaultModel) : false;
if (defaultModelInList) {
const result = zParameterModel.safeParse(defaultModelInList);
if (result.success) {
dispatch(modelChanged(defaultModelInList, currentModel));
const optimalDimension = getOptimalDimension(defaultModelInList);
if (getIsSizeOptimal(state.generation.width, state.generation.height, optimalDimension)) {
return;
}
const { width, height } = calculateNewSize(
state.generation.aspectRatio.value,
optimalDimension * optimalDimension
);
dispatch(widthChanged(width));
dispatch(heightChanged(height));
return;
}
}
const result = zParameterModel.safeParse(mainModels[0]);
if (!result.success) {
log.error({ error: result.error.format() }, 'Failed to parse main model');
return;
}
dispatch(modelChanged(result.data, currentModel));
};
const handleRefinerModels: ModelHandler = (models, state, dispatch, _log) => {
const currentRefinerModel = state.sdxl.refinerModel;
const refinerModels = models.filter(isRefinerMainModelModelConfig);
if (models.length === 0) {
// No models loaded at all
dispatch(refinerModelChanged(null));
return;
}
const isCurrentRefinerModelAvailable = currentRefinerModel
? refinerModels.some((m) => m.key === currentRefinerModel.key)
: false;
if (!isCurrentRefinerModelAvailable) {
dispatch(refinerModelChanged(null));
return;
}
};
const handleVAEModels: ModelHandler = (models, state, dispatch, log) => {
const currentVae = state.generation.vae;
if (currentVae === null) {
// null is a valid VAE! it means "use the default with the main model"
return;
}
const vaeModels = models.filter(isVAEModelConfig);
const isCurrentVAEAvailable = vaeModels.some((m) => m.key === currentVae.key);
if (isCurrentVAEAvailable) {
return;
}
const firstModel = vaeModels[0];
if (!firstModel) {
// No custom VAEs loaded at all; use the default
dispatch(vaeSelected(null));
return;
}
const result = zParameterVAEModel.safeParse(firstModel);
if (!result.success) {
log.error({ error: result.error.format() }, 'Failed to parse VAE model');
return;
}
dispatch(vaeSelected(result.data));
};
const handleLoRAModels: ModelHandler = (models, state, dispatch, _log) => {
const loras = state.lora.loras;
forEach(loras, (lora, id) => {
const isLoRAAvailable = models.some((m) => m.key === lora.model.key);
if (isLoRAAvailable) {
return;
}
dispatch(loraRemoved(id));
});
};
const handleControlAdapterModels: ModelHandler = (models, state, dispatch, _log) => {
selectControlAdapterAll(state.controlAdapters).forEach((ca) => {
const isModelAvailable = models.some((m) => m.key === ca.model?.key);
if (isModelAvailable) {
return;
}
dispatch(controlAdapterModelCleared({ id: ca.id }));
});
};
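Note: the handlers above all share one shape: filter the full AnyModelConfig[] with a type guard, then validate the chosen config with zod before dispatching. A minimal sketch of such a guard, assuming the `type` discriminant used by the config types in services/api/types (the real guards, such as isVAEModelConfig, live there):

import type { AnyModelConfig, VAEModelConfig } from 'services/api/types';

// Narrow AnyModelConfig to VAE configs by the `type` discriminant (sketch only).
const isVAEConfigSketch = (config: AnyModelConfig): config is VAEModelConfig => config.type === 'vae';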


@ -1,26 +1,29 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { setDefaultSettings } from 'features/parameters/store/actions';
import {
heightChanged,
setCfgRescaleMultiplier,
setCfgScale,
setScheduler,
setSteps,
vaePrecisionChanged,
vaeSelected,
widthChanged,
} from 'features/parameters/store/generationSlice';
import {
isParameterCFGRescaleMultiplier,
isParameterCFGScale,
isParameterHeight,
isParameterPrecision,
isParameterScheduler,
isParameterSteps,
isParameterWidth,
zParameterVAEModel,
} from 'features/parameters/types/parameterSchemas';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next';
import { map } from 'lodash-es';
import { modelsApi } from 'services/api/endpoints/models';
import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models';
import { isNonRefinerMainModelConfig } from 'services/api/types';
export const addSetDefaultSettingsListener = (startAppListening: AppStartListening) => {
@ -35,14 +38,19 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni
return;
}
const modelConfig = await dispatch(modelsApi.endpoints.getModelConfig.initiate(currentModel.key)).unwrap();
const request = dispatch(modelsApi.endpoints.getModelConfigs.initiate());
const data = await request.unwrap();
request.unsubscribe();
const models = modelConfigsAdapterSelectors.selectAll(data);
const modelConfig = models.find((model) => model.key === currentModel.key);
if (!modelConfig) {
return;
}
if (isNonRefinerMainModelConfig(modelConfig) && modelConfig.default_settings) {
const { vae, vae_precision, cfg_scale, cfg_rescale_multiplier, steps, scheduler } =
const { vae, vae_precision, cfg_scale, cfg_rescale_multiplier, steps, scheduler, width, height } =
modelConfig.default_settings;
if (vae) {
@ -51,11 +59,8 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni
if (vae === 'default') {
dispatch(vaeSelected(null));
} else {
const { data } = modelsApi.endpoints.getVaeModels.select()(state);
const vaeArray = map(data?.entities);
const validVae = vaeArray.find((model) => model.key === vae);
const result = zParameterVAEModel.safeParse(validVae);
const vaeModel = models.find((model) => model.key === vae);
const result = zParameterVAEModel.safeParse(vaeModel);
if (!result.success) {
return;
}
@ -93,6 +98,18 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni
}
}
if (width) {
if (isParameterWidth(width)) {
dispatch(widthChanged(width));
}
}
if (height) {
if (isParameterHeight(height)) {
dispatch(heightChanged(height));
}
}
dispatch(addToast(makeToast({ title: t('toast.parameterSet', { parameter: 'Default settings' }) })));
}
},


@ -4,6 +4,7 @@ import { $baseUrl } from 'app/store/nanostores/baseUrl';
import { isEqual } from 'lodash-es';
import { atom } from 'nanostores';
import { api } from 'services/api';
import { modelsApi } from 'services/api/endpoints/models';
import { queueApi, selectQueueStatus } from 'services/api/endpoints/queue';
import { socketConnected } from 'services/events/actions';
@ -29,6 +30,11 @@ export const addSocketConnectedEventListener = (startAppListening: AppStartListe
// Bail on the recovery logic if this is the first connection - we don't need to recover anything
if ($isFirstConnection.get()) {
// Populate the model configs on first connection. This query cache has a 24hr timeout, so we can immediately
// unsubscribe.
const request = dispatch(modelsApi.endpoints.getModelConfigs.initiate());
request.unsubscribe();
$isFirstConnection.set(false);
return;
}
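For context on the block above: `initiate()` opens an RTK Query cache subscription, and because this endpoint's cache is retained for 24 hours, the handle can be released immediately while later consumers read the warmed cache. A minimal consumer sketch, using the endpoint `select()` helper that appears elsewhere in this diff:

// Read the prefetched model configs synchronously from the RTK Query cache;
// no network request is triggered while the cache entry is still warm.
const { data } = modelsApi.endpoints.getModelConfigs.select()(getState());
const models = data ? modelConfigsAdapterSelectors.selectAll(data) : [];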


@ -2,6 +2,7 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'
import { api } from 'services/api';
import { modelsApi } from 'services/api/endpoints/models';
import {
socketModelInstallCancelled,
socketModelInstallCompleted,
socketModelInstallDownloading,
socketModelInstallError,
@ -63,4 +64,21 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
);
},
});
startAppListening({
actionCreator: socketModelInstallCancelled,
effect: (action, { dispatch }) => {
const { id } = action.payload.data;
dispatch(
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id);
if (modelImport) {
modelImport.status = 'cancelled';
}
return draft;
})
);
},
});
};
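Note: `modelsApi.util.updateQueryData` applies an Immer recipe to the cached result of the named endpoint, so the install queue UI updates without a refetch. The same shape works for any status transition; a hypothetical 'completed' variant as a sketch (the listener above only ever writes 'cancelled'):

dispatch(
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id);
if (modelImport) {
modelImport.status = 'completed'; // hypothetical; shown only to illustrate the recipe shape
}
return draft;
})
);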


@ -8,14 +8,16 @@ export const addModelLoadEventListener = (startAppListening: AppStartListening)
startAppListening({
actionCreator: socketModelLoadStarted,
effect: (action) => {
const { base_model, model_name, model_type, submodel } = action.payload.data;
const { model_config, submodel_type } = action.payload.data;
const { name, base, type } = model_config;
let message = `Model load started: ${base_model}/${model_type}/${model_name}`;
if (submodel) {
message = message.concat(`/${submodel}`);
const extras: string[] = [base, type];
if (submodel_type) {
extras.push(submodel_type);
}
const message = `Model load started: ${name} (${extras.join(', ')})`;
log.debug(action.payload, message);
},
});
@ -23,14 +25,16 @@ export const addModelLoadEventListener = (startAppListening: AppStartListening)
startAppListening({
actionCreator: socketModelLoadCompleted,
effect: (action) => {
const { base_model, model_name, model_type, submodel } = action.payload.data;
const { model_config, submodel_type } = action.payload.data;
const { name, base, type } = model_config;
let message = `Model load complete: ${base_model}/${model_type}/${model_name}`;
if (submodel) {
message = message.concat(`/${submodel}`);
const extras: string[] = [base, type];
if (submodel_type) {
extras.push(submodel_type);
}
const message = `Model load complete: ${name} (${extras.join(', ')})`;
log.debug(action.payload, message);
},
});


@ -1,15 +1,14 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import type { EntityState } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import type { GroupBase } from 'chakra-react-select';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import { groupBy, map, reduce } from 'lodash-es';
import { groupBy, reduce } from 'lodash-es';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import type { AnyModelConfig } from 'services/api/types';
type UseGroupedModelComboboxArg<T extends AnyModelConfig> = {
modelEntities: EntityState<T, string> | undefined;
modelConfigs: T[];
selectedModel?: ModelIdentifierField | null;
onChange: (value: T | null) => void;
getIsDisabled?: (model: T) => boolean;
@ -29,13 +28,12 @@ export const useGroupedModelCombobox = <T extends AnyModelConfig>(
): UseGroupedModelComboboxReturn => {
const { t } = useTranslation();
const base_model = useAppSelector((s) => s.generation.model?.base ?? 'sdxl');
const { modelEntities, selectedModel, getIsDisabled, onChange, isLoading } = arg;
const { modelConfigs, selectedModel, getIsDisabled, onChange, isLoading } = arg;
const options = useMemo<GroupBase<ComboboxOption>[]>(() => {
if (!modelEntities) {
if (!modelConfigs) {
return [];
}
const modelEntitiesArray = map(modelEntities.entities);
const groupedModels = groupBy(modelEntitiesArray, 'base');
const groupedModels = groupBy(modelConfigs, 'base');
const _options = reduce(
groupedModels,
(acc, val, label) => {
@ -53,7 +51,7 @@ export const useGroupedModelCombobox = <T extends AnyModelConfig>(
);
_options.sort((a) => (a.label === base_model ? -1 : 1));
return _options;
}, [getIsDisabled, modelEntities, base_model]);
}, [getIsDisabled, modelConfigs, base_model]);
const value = useMemo(
() =>
@ -67,14 +65,14 @@ export const useGroupedModelCombobox = <T extends AnyModelConfig>(
onChange(null);
return;
}
const model = modelEntities?.entities[v.value];
const model = modelConfigs.find((m) => m.key === v.value);
if (!model) {
onChange(null);
return;
}
onChange(model);
},
[modelEntities?.entities, onChange]
[modelConfigs, onChange]
);
const placeholder = useMemo(() => {


@ -1,13 +1,11 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import type { EntityState } from '@reduxjs/toolkit';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import { map } from 'lodash-es';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import type { AnyModelConfig } from 'services/api/types';
type UseModelComboboxArg<T extends AnyModelConfig> = {
modelEntities: EntityState<T, string> | undefined;
modelConfigs: T[];
selectedModel?: ModelIdentifierField | null;
onChange: (value: T | null) => void;
getIsDisabled?: (model: T) => boolean;
@ -25,19 +23,14 @@ type UseModelComboboxReturn = {
export const useModelCombobox = <T extends AnyModelConfig>(arg: UseModelComboboxArg<T>): UseModelComboboxReturn => {
const { t } = useTranslation();
const { modelEntities, selectedModel, getIsDisabled, onChange, isLoading, optionsFilter = () => true } = arg;
const { modelConfigs, selectedModel, getIsDisabled, onChange, isLoading, optionsFilter = () => true } = arg;
const options = useMemo<ComboboxOption[]>(() => {
if (!modelEntities) {
return [];
}
return map(modelEntities.entities)
.filter(optionsFilter)
.map((model) => ({
label: model.name,
value: model.key,
isDisabled: getIsDisabled ? getIsDisabled(model) : false,
}));
}, [optionsFilter, getIsDisabled, modelEntities]);
return modelConfigs.filter(optionsFilter).map((model) => ({
label: model.name,
value: model.key,
isDisabled: getIsDisabled ? getIsDisabled(model) : false,
}));
}, [optionsFilter, getIsDisabled, modelConfigs]);
const value = useMemo(
() => options.find((m) => (selectedModel ? m.value === selectedModel.key : false)),
@ -50,14 +43,14 @@ export const useModelCombobox = <T extends AnyModelConfig>(arg: UseModelCombobox
onChange(null);
return;
}
const model = modelEntities?.entities[v.value];
const model = modelConfigs.find((m) => m.key === v.value);
if (!model) {
onChange(null);
return;
}
onChange(model);
},
[modelEntities?.entities, onChange]
[modelConfigs, onChange]
);
const placeholder = useMemo(() => {


@ -1,15 +1,12 @@
import type { Item } from '@invoke-ai/ui-library';
import type { EntityState } from '@reduxjs/toolkit';
import { EMPTY_ARRAY } from 'app/store/constants';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import { MODEL_TYPE_SHORT_MAP } from 'features/parameters/types/constants';
import { filter } from 'lodash-es';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import type { AnyModelConfig } from 'services/api/types';
type UseModelCustomSelectArg<T extends AnyModelConfig> = {
data: EntityState<T, string> | undefined;
modelConfigs: T[];
isLoading: boolean;
selectedModel?: ModelIdentifierField | null;
onChange: (value: T | null) => void;
@ -28,7 +25,7 @@ const modelFilterDefault = () => true;
const isModelDisabledDefault = () => false;
export const useModelCustomSelect = <T extends AnyModelConfig>({
data,
modelConfigs,
isLoading,
selectedModel,
onChange,
@ -39,30 +36,28 @@ export const useModelCustomSelect = <T extends AnyModelConfig>({
const items: Item[] = useMemo(
() =>
data
? filter(data.entities, modelFilter).map<Item>((m) => ({
label: m.name,
value: m.key,
description: m.description,
group: MODEL_TYPE_SHORT_MAP[m.base],
isDisabled: isModelDisabled(m),
}))
: EMPTY_ARRAY,
[data, isModelDisabled, modelFilter]
modelConfigs.filter(modelFilter).map<Item>((m) => ({
label: m.name,
value: m.key,
description: m.description,
group: MODEL_TYPE_SHORT_MAP[m.base],
isDisabled: isModelDisabled(m),
})),
[modelConfigs, isModelDisabled, modelFilter]
);
const _onChange = useCallback(
(item: Item | null) => {
if (!item || !data) {
if (!item || !modelConfigs) {
return;
}
const model = data.entities[item.value];
const model = modelConfigs.find((m) => m.key === item.value);
if (!model) {
return;
}
onChange(model);
},
[data, onChange]
[modelConfigs, onChange]
);
const selectedItem = useMemo(() => items.find((o) => o.value === selectedModel?.key) ?? null, [selectedModel, items]);


@ -3,7 +3,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useModelCustomSelect } from 'common/hooks/useModelCustomSelect';
import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled';
import { useControlAdapterModel } from 'features/controlAdapters/hooks/useControlAdapterModel';
import { useControlAdapterModelQuery } from 'features/controlAdapters/hooks/useControlAdapterModelQuery';
import { useControlAdapterModels } from 'features/controlAdapters/hooks/useControlAdapterModels';
import { useControlAdapterType } from 'features/controlAdapters/hooks/useControlAdapterType';
import { controlAdapterModelChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
import { memo, useCallback, useMemo } from 'react';
@ -20,7 +20,7 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
const dispatch = useAppDispatch();
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
const { data, isLoading } = useControlAdapterModelQuery(controlAdapterType);
const [modelConfigs, { isLoading }] = useControlAdapterModels(controlAdapterType);
const _onChange = useCallback(
(modelConfig: ControlNetModelConfig | IPAdapterModelConfig | T2IAdapterModelConfig | null) => {
@ -43,7 +43,7 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
);
const { items, selectedItem, onChange, placeholder } = useModelCustomSelect({
data,
modelConfigs,
isLoading,
selectedModel,
onChange: _onChange,
@ -52,7 +52,13 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
return (
<FormControl isDisabled={!items.length || !isEnabled} isInvalid={!selectedItem || !items.length}>
<CustomSelect selectedItem={selectedItem} placeholder={placeholder} items={items} onChange={onChange} />
<CustomSelect
key={items.length}
selectedItem={selectedItem}
placeholder={placeholder}
items={items}
onChange={onChange}
/>
</FormControl>
);
};


@ -1,17 +1,16 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useControlAdapterModels } from 'features/controlAdapters/hooks/useControlAdapterModels';
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
import { controlAdapterAdded } from 'features/controlAdapters/store/controlAdaptersSlice';
import { type ControlAdapterType, isControlAdapterProcessorType } from 'features/controlAdapters/store/types';
import { useCallback, useMemo } from 'react';
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import { useControlAdapterModels } from './useControlAdapterModels';
export const useAddControlAdapter = (type: ControlAdapterType) => {
const baseModel = useAppSelector((s) => s.generation.model?.base);
const dispatch = useAppDispatch();
const models = useControlAdapterModels(type);
const [models] = useControlAdapterModels(type);
const firstModel: ControlNetModelConfig | T2IAdapterModelConfig | IPAdapterModelConfig | undefined = useMemo(() => {
// prefer to use a model that matches the base model


@ -1,26 +0,0 @@
import type { ControlAdapterType } from 'features/controlAdapters/store/types';
import {
useGetControlNetModelsQuery,
useGetIPAdapterModelsQuery,
useGetT2IAdapterModelsQuery,
} from 'services/api/endpoints/models';
export const useControlAdapterModelQuery = (type: ControlAdapterType) => {
const controlNetModelsQuery = useGetControlNetModelsQuery();
const t2iAdapterModelsQuery = useGetT2IAdapterModelsQuery();
const ipAdapterModelsQuery = useGetIPAdapterModelsQuery();
if (type === 'controlnet') {
return controlNetModelsQuery;
}
if (type === 't2i_adapter') {
return t2iAdapterModelsQuery;
}
if (type === 'ip_adapter') {
return ipAdapterModelsQuery;
}
// Assert that the end of the function is not reachable.
const exhaustiveCheck: never = type;
return exhaustiveCheck;
};


@ -1,31 +1,10 @@
import type { ControlAdapterType } from 'features/controlAdapters/store/types';
import { useMemo } from 'react';
import {
controlNetModelsAdapterSelectors,
ipAdapterModelsAdapterSelectors,
t2iAdapterModelsAdapterSelectors,
useGetControlNetModelsQuery,
useGetIPAdapterModelsQuery,
useGetT2IAdapterModelsQuery,
} from 'services/api/endpoints/models';
import { useControlNetModels, useIPAdapterModels, useT2IAdapterModels } from 'services/api/hooks/modelsByType';
export const useControlAdapterModels = (type?: ControlAdapterType) => {
const { data: controlNetModelsData } = useGetControlNetModelsQuery();
const controlNetModels = useMemo(
() => (controlNetModelsData ? controlNetModelsAdapterSelectors.selectAll(controlNetModelsData) : []),
[controlNetModelsData]
);
const { data: t2iAdapterModelsData } = useGetT2IAdapterModelsQuery();
const t2iAdapterModels = useMemo(
() => (t2iAdapterModelsData ? t2iAdapterModelsAdapterSelectors.selectAll(t2iAdapterModelsData) : []),
[t2iAdapterModelsData]
);
const { data: ipAdapterModelsData } = useGetIPAdapterModelsQuery();
const ipAdapterModels = useMemo(
() => (ipAdapterModelsData ? ipAdapterModelsAdapterSelectors.selectAll(ipAdapterModelsData) : []),
[ipAdapterModelsData]
);
export const useControlAdapterModels = (type: ControlAdapterType) => {
const controlNetModels = useControlNetModels();
const t2iAdapterModels = useT2IAdapterModels();
const ipAdapterModels = useIPAdapterModels();
if (type === 'controlnet') {
return controlNetModels;
@ -36,5 +15,8 @@ export const useControlAdapterModels = (type?: ControlAdapterType) => {
if (type === 'ip_adapter') {
return ipAdapterModels;
}
return [];
// Assert that the end of the function is not reachable.
const exhaustiveCheck: never = type;
return exhaustiveCheck;
};
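The rewritten hook returns the `[modelConfigs, { isLoading }]` tuple produced by the modelsByType helpers, so callers destructure it directly, as ParamControlAdapterModel and useAddControlAdapter do above. A minimal usage sketch:

// Tuple destructuring; modelConfigs is ControlNetModelConfig[] for this argument.
const [modelConfigs, { isLoading }] = useControlAdapterModels('controlnet');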


@ -93,7 +93,6 @@ export const CONTROLNET_PROCESSORS: ControlNetProcessorsDict = {
type: 'depth_anything_image_processor',
model_size: 'small',
resolution: 512,
offload: false,
},
},
hed_image_processor: {


@ -338,6 +338,21 @@ export const controlAdaptersSlice = createSlice({
pendingControlImagesCleared: (state) => {
state.pendingControlImages = [];
},
ipAdaptersReset: (state) => {
selectAllIPAdapters(state).forEach((ca) => {
caAdapter.removeOne(state, ca.id);
});
},
controlNetsReset: (state) => {
selectAllControlNets(state).forEach((ca) => {
caAdapter.removeOne(state, ca.id);
});
},
t2iAdaptersReset: (state) => {
selectAllT2IAdapters(state).forEach((ca) => {
caAdapter.removeOne(state, ca.id);
});
},
},
extraReducers: (builder) => {
builder.addCase(controlAdapterImageProcessed, (state, action) => {
@ -376,6 +391,9 @@ export const {
controlAdapterAutoConfigToggled,
pendingControlImagesCleared,
controlAdapterModelCleared,
ipAdaptersReset,
controlNetsReset,
t2iAdaptersReset,
} = controlAdaptersSlice.actions;
export const isAnyControlAdapterAdded = isAnyOf(controlAdapterAdded, controlAdapterRecalled);


@ -15,6 +15,7 @@ import {
} from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
alwaysShowImageSizeBadgeChanged,
autoAssignBoardOnClickChanged,
setGalleryImageMinimumWidth,
shouldAutoSwitchChanged,
@ -36,6 +37,7 @@ const GallerySettingsPopover = () => {
const galleryImageMinimumWidth = useAppSelector((s) => s.gallery.galleryImageMinimumWidth);
const shouldAutoSwitch = useAppSelector((s) => s.gallery.shouldAutoSwitch);
const autoAssignBoardOnClick = useAppSelector((s) => s.gallery.autoAssignBoardOnClick);
const alwaysShowImageSizeBadge = useAppSelector((s) => s.gallery.alwaysShowImageSizeBadge);
const handleChangeGalleryImageMinimumWidth = useCallback(
(v: number) => {
@ -56,6 +58,11 @@ const GallerySettingsPopover = () => {
[dispatch]
);
const handleChangeAlwaysShowImageSizeBadgeChanged = useCallback(
(e: ChangeEvent<HTMLInputElement>) => dispatch(alwaysShowImageSizeBadgeChanged(e.target.checked)),
[dispatch]
);
return (
<Popover isLazy>
<PopoverTrigger>
@ -88,6 +95,10 @@ const GallerySettingsPopover = () => {
<FormLabel>{t('gallery.autoAssignBoardOnClick')}</FormLabel>
<Checkbox isChecked={autoAssignBoardOnClick} onChange={handleChangeAutoAssignBoardOnClick} />
</FormControl>
<FormControl>
<FormLabel>{t('gallery.alwaysShowImageSizeBadge')}</FormLabel>
<Checkbox isChecked={alwaysShowImageSizeBadge} onChange={handleChangeAlwaysShowImageSizeBadgeChanged} />
</FormControl>
</FormControlGroup>
<BoardAutoAddSelect />
</Flex>


@ -1,5 +1,5 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Box, Flex, useShiftModifier } from '@invoke-ai/ui-library';
import { Box, Flex, Text, useShiftModifier } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { $customStarUI } from 'app/store/nanostores/customStarUI';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
@ -22,6 +22,16 @@ const imageIconStyleOverrides: SystemStyleObject = {
bottom: 2,
top: 'auto',
};
const boxSx: SystemStyleObject = {
containerType: 'inline-size',
};
const badgeSx: SystemStyleObject = {
'@container (max-width: 80px)': {
'&': { display: 'none' },
},
};
interface HoverableImageProps {
imageName: string;
index: number;
@ -34,6 +44,7 @@ const GalleryImage = (props: HoverableImageProps) => {
const shift = useShiftModifier();
const { t } = useTranslation();
const selectedBoardId = useAppSelector((s) => s.gallery.selectedBoardId);
const alwaysShowImageSizeBadge = useAppSelector((s) => s.gallery.alwaysShowImageSizeBadge);
const { handleClick, isSelected, areMultiplesSelected } = useMultiselect(imageDTO);
const customStarUi = useStore($customStarUI);
@ -121,7 +132,7 @@ const GalleryImage = (props: HoverableImageProps) => {
}
return (
<Box w="full" h="full" className="gallerygrid-image" data-testid={dataTestId}>
<Box w="full" h="full" className="gallerygrid-image" data-testid={dataTestId} sx={boxSx}>
<Flex
ref={imageContainerRef}
userSelect="none"
@ -145,6 +156,23 @@ const GalleryImage = (props: HoverableImageProps) => {
onMouseOut={handleMouseOut}
>
<>
{(isHovered || alwaysShowImageSizeBadge) && (
<Text
position="absolute"
background="base.900"
color="base.50"
fontSize="sm"
fontWeight="semibold"
bottom={0}
left={0}
opacity={0.7}
px={2}
lineHeight={1.25}
borderTopEndRadius="base"
borderBottomStartRadius="base"
sx={badgeSx}
>{`${imageDTO.width}x${imageDTO.height}`}</Text>
)}
<IAIDndImageIcon onClick={toggleStarredState} icon={starIcon} tooltip={starTooltip} />
{isHovered && shift && (


@ -15,6 +15,7 @@ const initialGalleryState: GalleryState = {
autoAssignBoardOnClick: true,
autoAddBoardId: 'none',
galleryImageMinimumWidth: 90,
alwaysShowImageSizeBadge: false,
selectedBoardId: 'none',
galleryView: 'images',
boardSearchText: '',
@ -71,6 +72,9 @@ export const gallerySlice = createSlice({
state.limit += IMAGE_LIMIT;
}
},
alwaysShowImageSizeBadgeChanged: (state, action: PayloadAction<boolean>) => {
state.alwaysShowImageSizeBadge = action.payload;
},
},
extraReducers: (builder) => {
builder.addMatcher(isAnyBoardDeleted, (state, action) => {
@ -107,6 +111,7 @@ export const {
selectionChanged,
boardSearchTextChanged,
moreImagesLoaded,
alwaysShowImageSizeBadgeChanged,
} = gallerySlice.actions;
const isAnyBoardDeleted = isAnyOf(


@ -19,4 +19,5 @@ export type GalleryState = {
boardSearchText: string;
offset: number;
limit: number;
alwaysShowImageSizeBadge: boolean;
};


@ -7,14 +7,14 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { loraAdded, selectLoraSlice } from 'features/lora/store/loraSlice';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
import { useLoRAModels } from 'services/api/hooks/modelsByType';
import type { LoRAModelConfig } from 'services/api/types';
const selectAddedLoRAs = createMemoizedSelector(selectLoraSlice, (lora) => lora.loras);
const LoRASelect = () => {
const dispatch = useAppDispatch();
const { data, isLoading } = useGetLoRAModelsQuery();
const [modelConfigs, { isLoading }] = useLoRAModels();
const { t } = useTranslation();
const addedLoRAs = useAppSelector(selectAddedLoRAs);
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
@ -37,7 +37,7 @@ const LoRASelect = () => {
);
const { options, onChange } = useGroupedModelCombobox({
modelEntities: data,
modelConfigs,
getIsDisabled,
onChange: _onChange,
});


@ -3,6 +3,7 @@ import { createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import { zModelIdentifierField } from 'features/nodes/types/common';
import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
import { cloneDeep } from 'lodash-es';
import type { LoRAModelConfig } from 'services/api/types';
export type LoRA = {
@ -57,10 +58,12 @@ export const loraSlice = createSlice({
}
lora.isEnabled = isEnabled;
},
lorasReset: () => cloneDeep(initialLoraState),
},
});
export const { loraAdded, loraRemoved, loraWeightChanged, loraIsEnabledChanged, loraRecalled } = loraSlice.actions;
export const { loraAdded, loraRemoved, loraWeightChanged, loraIsEnabledChanged, loraRecalled, lorasReset } =
loraSlice.actions;
export const selectLoraSlice = (state: RootState) => state.lora;


@ -225,28 +225,34 @@ const parseControlNet: MetadataParseFunc<ControlNetConfigMetadata> = async (meta
const control_model = await getProperty(metadataItem, 'control_model');
const key = await getModelKey(control_model, 'controlnet');
const controlNetModel = await fetchModelConfigWithTypeGuard(key, isControlNetModelConfig);
const image = zControlField.shape.image.nullish().catch(null).parse(getProperty(metadataItem, 'image'));
const image = zControlField.shape.image
.nullish()
.catch(null)
.parse(await getProperty(metadataItem, 'image'));
const processedImage = zControlField.shape.image
.nullish()
.catch(null)
.parse(await getProperty(metadataItem, 'processed_image'));
const control_weight = zControlField.shape.control_weight
.nullish()
.catch(null)
.parse(getProperty(metadataItem, 'control_weight'));
.parse(await getProperty(metadataItem, 'control_weight'));
const begin_step_percent = zControlField.shape.begin_step_percent
.nullish()
.catch(null)
.parse(getProperty(metadataItem, 'begin_step_percent'));
.parse(await getProperty(metadataItem, 'begin_step_percent'));
const end_step_percent = zControlField.shape.end_step_percent
.nullish()
.catch(null)
.parse(getProperty(metadataItem, 'end_step_percent'));
.parse(await getProperty(metadataItem, 'end_step_percent'));
const control_mode = zControlField.shape.control_mode
.nullish()
.catch(null)
.parse(getProperty(metadataItem, 'control_mode'));
.parse(await getProperty(metadataItem, 'control_mode'));
const resize_mode = zControlField.shape.resize_mode
.nullish()
.catch(null)
.parse(getProperty(metadataItem, 'resize_mode'));
.parse(await getProperty(metadataItem, 'resize_mode'));
const { processorType, processorNode } = buildControlAdapterProcessor(controlNetModel);
@ -260,7 +266,7 @@ const parseControlNet: MetadataParseFunc<ControlNetConfigMetadata> = async (meta
controlMode: control_mode ?? initialControlNet.controlMode,
resizeMode: resize_mode ?? initialControlNet.resizeMode,
controlImage: image?.image_name ?? null,
processedControlImage: image?.image_name ?? null,
processedControlImage: processedImage?.image_name ?? null,
processorType,
processorNode,
shouldAutoConfig: true,
@ -284,20 +290,30 @@ const parseT2IAdapter: MetadataParseFunc<T2IAdapterConfigMetadata> = async (meta
const key = await getModelKey(t2i_adapter_model, 't2i_adapter');
const t2iAdapterModel = await fetchModelConfigWithTypeGuard(key, isT2IAdapterModelConfig);
const image = zT2IAdapterField.shape.image.nullish().catch(null).parse(getProperty(metadataItem, 'image'));
const weight = zT2IAdapterField.shape.weight.nullish().catch(null).parse(getProperty(metadataItem, 'weight'));
const image = zT2IAdapterField.shape.image
.nullish()
.catch(null)
.parse(await getProperty(metadataItem, 'image'));
const processedImage = zT2IAdapterField.shape.image
.nullish()
.catch(null)
.parse(await getProperty(metadataItem, 'processed_image'));
const weight = zT2IAdapterField.shape.weight
.nullish()
.catch(null)
.parse(await getProperty(metadataItem, 'weight'));
const begin_step_percent = zT2IAdapterField.shape.begin_step_percent
.nullish()
.catch(null)
.parse(getProperty(metadataItem, 'begin_step_percent'));
.parse(await getProperty(metadataItem, 'begin_step_percent'));
const end_step_percent = zT2IAdapterField.shape.end_step_percent
.nullish()
.catch(null)
.parse(getProperty(metadataItem, 'end_step_percent'));
.parse(await getProperty(metadataItem, 'end_step_percent'));
const resize_mode = zT2IAdapterField.shape.resize_mode
.nullish()
.catch(null)
.parse(getProperty(metadataItem, 'resize_mode'));
.parse(await getProperty(metadataItem, 'resize_mode'));
const { processorType, processorNode } = buildControlAdapterProcessor(t2iAdapterModel);
@ -310,7 +326,7 @@ const parseT2IAdapter: MetadataParseFunc<T2IAdapterConfigMetadata> = async (meta
endStepPct: end_step_percent ?? initialT2IAdapter.endStepPct,
resizeMode: resize_mode ?? initialT2IAdapter.resizeMode,
controlImage: image?.image_name ?? null,
processedControlImage: image?.image_name ?? null,
processedControlImage: processedImage?.image_name ?? null,
processorType,
processorNode,
shouldAutoConfig: true,
@ -334,16 +350,22 @@ const parseIPAdapter: MetadataParseFunc<IPAdapterConfigMetadata> = async (metada
const key = await getModelKey(ip_adapter_model, 'ip_adapter');
const ipAdapterModel = await fetchModelConfigWithTypeGuard(key, isIPAdapterModelConfig);
const image = zIPAdapterField.shape.image.nullish().catch(null).parse(getProperty(metadataItem, 'image'));
const weight = zIPAdapterField.shape.weight.nullish().catch(null).parse(getProperty(metadataItem, 'weight'));
const image = zIPAdapterField.shape.image
.nullish()
.catch(null)
.parse(await getProperty(metadataItem, 'image'));
const weight = zIPAdapterField.shape.weight
.nullish()
.catch(null)
.parse(await getProperty(metadataItem, 'weight'));
const begin_step_percent = zIPAdapterField.shape.begin_step_percent
.nullish()
.catch(null)
.parse(getProperty(metadataItem, 'begin_step_percent'));
.parse(await getProperty(metadataItem, 'begin_step_percent'));
const end_step_percent = zIPAdapterField.shape.end_step_percent
.nullish()
.catch(null)
.parse(getProperty(metadataItem, 'end_step_percent'));
.parse(await getProperty(metadataItem, 'end_step_percent'));
const ipAdapter: IPAdapterConfigMetadata = {
id: uuidv4(),
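The recurring change in this file is awaiting `getProperty` before handing the value to zod. Passing the unresolved promise meant the schema never matched, so `.catch(null)` silently swallowed every value. A generic before/after sketch of the fix:

// Broken: `getProperty` returns a Promise, so validation fails and .catch(null) wins.
const broken = zIPAdapterField.shape.weight.nullish().catch(null).parse(getProperty(metadataItem, 'weight'));
// Fixed: resolve the value first, then validate it.
const fixed = zIPAdapterField.shape.weight.nullish().catch(null).parse(await getProperty(metadataItem, 'weight'));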


@ -1,8 +1,13 @@
import { getStore } from 'app/store/nanostores/store';
import { controlAdapterRecalled } from 'features/controlAdapters/store/controlAdaptersSlice';
import {
controlAdapterRecalled,
controlNetsReset,
ipAdaptersReset,
t2iAdaptersReset,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { setHrfEnabled, setHrfMethod, setHrfStrength } from 'features/hrf/store/hrfSlice';
import type { LoRA } from 'features/lora/store/loraSlice';
import { loraRecalled } from 'features/lora/store/loraSlice';
import { loraRecalled, lorasReset } from 'features/lora/store/loraSlice';
import type {
ControlNetConfigMetadata,
IPAdapterConfigMetadata,
@ -166,7 +171,11 @@ const recallLoRA: MetadataRecallFunc<LoRA> = (lora) => {
};
const recallAllLoRAs: MetadataRecallFunc<LoRA[]> = (loras) => {
if (!loras.length) {
return;
}
const { dispatch } = getStore();
dispatch(lorasReset());
loras.forEach((lora) => {
dispatch(loraRecalled(lora));
});
@ -177,7 +186,11 @@ const recallControlNet: MetadataRecallFunc<ControlNetConfigMetadata> = (controlN
};
const recallControlNets: MetadataRecallFunc<ControlNetConfigMetadata[]> = (controlNets) => {
if (!controlNets.length) {
return;
}
const { dispatch } = getStore();
dispatch(controlNetsReset());
controlNets.forEach((controlNet) => {
dispatch(controlAdapterRecalled(controlNet));
});
@ -188,7 +201,11 @@ const recallT2IAdapter: MetadataRecallFunc<T2IAdapterConfigMetadata> = (t2iAdapt
};
const recallT2IAdapters: MetadataRecallFunc<T2IAdapterConfigMetadata[]> = (t2iAdapters) => {
if (!t2iAdapters.length) {
return;
}
const { dispatch } = getStore();
dispatch(t2iAdaptersReset());
t2iAdapters.forEach((t2iAdapter) => {
dispatch(controlAdapterRecalled(t2iAdapter));
});
@ -199,7 +216,11 @@ const recallIPAdapter: MetadataRecallFunc<IPAdapterConfigMetadata> = (ipAdapter)
};
const recallIPAdapters: MetadataRecallFunc<IPAdapterConfigMetadata[]> = (ipAdapters) => {
if (!ipAdapters.length) {
return;
}
const { dispatch } = getStore();
dispatch(ipAdaptersReset());
ipAdapters.forEach((ipAdapter) => {
dispatch(controlAdapterRecalled(ipAdapter));
});


@ -1,6 +1,7 @@
import { skipToken } from '@reduxjs/toolkit/query';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { getOptimalDimension } from 'features/parameters/util/optimalDimension';
import { selectConfigSlice } from 'features/system/store/configSlice';
import { isNil } from 'lodash-es';
import { useMemo } from 'react';
@ -8,7 +9,7 @@ import { useGetModelConfigWithTypeGuard } from 'services/api/hooks/useGetModelCo
import { isNonRefinerMainModelConfig } from 'services/api/types';
const initialStatesSelector = createMemoizedSelector(selectConfigSlice, (config) => {
const { steps, guidance, scheduler, cfgRescaleMultiplier, vaePrecision } = config.sd;
const { steps, guidance, scheduler, cfgRescaleMultiplier, vaePrecision, width, height } = config.sd;
return {
initialSteps: steps.initial,
@ -16,14 +17,23 @@ const initialStatesSelector = createMemoizedSelector(selectConfigSlice, (config)
initialScheduler: scheduler,
initialCfgRescaleMultiplier: cfgRescaleMultiplier.initial,
initialVaePrecision: vaePrecision,
initialWidth: width.initial,
initialHeight: height.initial,
};
});
export const useMainModelDefaultSettings = (modelKey?: string | null) => {
const { modelConfig, isLoading } = useGetModelConfigWithTypeGuard(modelKey ?? skipToken, isNonRefinerMainModelConfig);
const { initialSteps, initialCfg, initialScheduler, initialCfgRescaleMultiplier, initialVaePrecision } =
useAppSelector(initialStatesSelector);
const {
initialSteps,
initialCfg,
initialScheduler,
initialCfgRescaleMultiplier,
initialVaePrecision,
initialWidth,
initialHeight,
} = useAppSelector(initialStatesSelector);
const defaultSettingsDefaults = useMemo(() => {
return {
@ -51,15 +61,25 @@ export const useMainModelDefaultSettings = (modelKey?: string | null) => {
isEnabled: !isNil(modelConfig?.default_settings?.cfg_rescale_multiplier),
value: modelConfig?.default_settings?.cfg_rescale_multiplier || initialCfgRescaleMultiplier,
},
width: {
isEnabled: !isNil(modelConfig?.default_settings?.width),
value: modelConfig?.default_settings?.width || initialWidth,
},
height: {
isEnabled: !isNil(modelConfig?.default_settings?.height),
value: modelConfig?.default_settings?.height || initialHeight,
},
};
}, [
modelConfig?.default_settings,
modelConfig,
initialVaePrecision,
initialScheduler,
initialSteps,
initialCfg,
initialScheduler,
initialCfgRescaleMultiplier,
initialVaePrecision,
initialWidth,
initialHeight,
]);
return { defaultSettingsDefaults, isLoading };
return { defaultSettingsDefaults, isLoading, optimalDimension: getOptimalDimension(modelConfig) };
};


@ -1,13 +1,16 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import type { PersistConfig } from 'app/store/store';
import type { ModelType } from 'services/api/types';
export type FilterableModelType = Exclude<ModelType, 'onnx' | 'clip_vision'>;
type ModelManagerState = {
_version: 1;
selectedModelKey: string | null;
selectedModelMode: 'edit' | 'view';
searchTerm: string;
filteredModelType: string | null;
filteredModelType: FilterableModelType | null;
scanPath: string | undefined;
};
@ -35,7 +38,7 @@ export const modelManagerV2Slice = createSlice({
state.searchTerm = action.payload;
},
setFilteredModelType: (state, action: PayloadAction<string | null>) => {
setFilteredModelType: (state, action: PayloadAction<FilterableModelType | null>) => {
state.filteredModelType = action.payload;
},
setScanPath: (state, action: PayloadAction<string | undefined>) => {


@ -0,0 +1,102 @@
import { Button, Flex, FormControl, FormErrorMessage, FormHelperText, FormLabel, Input } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import type { ChangeEventHandler } from 'react';
import { useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useInstallModelMutation, useLazyGetHuggingFaceModelsQuery } from 'services/api/endpoints/models';
import { HuggingFaceResults } from './HuggingFaceResults';
export const HuggingFaceForm = () => {
const [huggingFaceRepo, setHuggingFaceRepo] = useState('');
const [displayResults, setDisplayResults] = useState(false);
const [errorMessage, setErrorMessage] = useState('');
const { t } = useTranslation();
const dispatch = useAppDispatch();
const [_getHuggingFaceModels, { isLoading, data }] = useLazyGetHuggingFaceModelsQuery();
const [installModel] = useInstallModelMutation();
const handleInstallModel = useCallback(
(source: string) => {
installModel({ source })
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('toast.modelAddedSimple'),
status: 'success',
})
)
);
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: `${error.data.detail} `,
status: 'error',
})
)
);
}
});
},
[installModel, dispatch, t]
);
const getModels = useCallback(async () => {
_getHuggingFaceModels(huggingFaceRepo)
.unwrap()
.then((response) => {
if (response.is_diffusers) {
handleInstallModel(huggingFaceRepo);
setDisplayResults(false);
} else if (response.urls?.length === 1 && response.urls[0]) {
handleInstallModel(response.urls[0]);
setDisplayResults(false);
} else {
setDisplayResults(true);
}
})
.catch((error) => {
setErrorMessage(error.data.detail || '');
});
}, [_getHuggingFaceModels, handleInstallModel, huggingFaceRepo]);
const handleSetHuggingFaceRepo: ChangeEventHandler<HTMLInputElement> = useCallback((e) => {
setHuggingFaceRepo(e.target.value);
setErrorMessage('');
}, []);
return (
<Flex flexDir="column" height="100%" gap={3}>
<FormControl isInvalid={!!errorMessage.length} w="full" orientation="vertical" flexShrink={0}>
<FormLabel>{t('modelManager.huggingFaceRepoID')}</FormLabel>
<Flex gap={3} alignItems="center" w="full">
<Input
placeholder={t('modelManager.huggingFacePlaceholder')}
value={huggingFaceRepo}
onChange={handleSetHuggingFaceRepo}
/>
<Button
onClick={getModels}
isLoading={isLoading}
isDisabled={huggingFaceRepo.length === 0}
size="sm"
flexShrink={0}
>
{t('modelManager.installRepo')}
</Button>
</Flex>
<FormHelperText>{t('modelManager.huggingFaceHelper')}</FormHelperText>
{!!errorMessage.length && <FormErrorMessage>{errorMessage}</FormErrorMessage>}
</FormControl>
{data && data.urls && displayResults && <HuggingFaceResults results={data.urls} />}
</Flex>
);
};


@ -0,0 +1,57 @@
import { Flex, IconButton, Text } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPlusBold } from 'react-icons/pi';
import { useInstallModelMutation } from 'services/api/endpoints/models';
type Props = {
result: string;
};
export const HuggingFaceResultItem = ({ result }: Props) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const [installModel] = useInstallModelMutation();
const handleInstall = useCallback(() => {
installModel({ source: result })
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('toast.modelAddedSimple'),
status: 'success',
})
)
);
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: `${error.data.detail} `,
status: 'error',
})
)
);
}
});
}, [installModel, result, dispatch, t]);
return (
<Flex alignItems="center" justifyContent="space-between" w="100%" gap={3}>
<Flex fontSize="sm" flexDir="column">
<Text fontWeight="semibold">{result.split('/').slice(-1)[0]}</Text>
<Text variant="subtext" noOfLines={1} wordBreak="break-all">
{result}
</Text>
</Flex>
<IconButton aria-label={t('modelManager.install')} icon={<PiPlusBold />} onClick={handleInstall} size="sm" />
</Flex>
);
};


@ -0,0 +1,123 @@
import {
Button,
Divider,
Flex,
Heading,
IconButton,
Input,
InputGroup,
InputRightElement,
} from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import type { ChangeEventHandler } from 'react';
import { useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { PiXBold } from 'react-icons/pi';
import { useInstallModelMutation } from 'services/api/endpoints/models';
import { HuggingFaceResultItem } from './HuggingFaceResultItem';
type HuggingFaceResultsProps = {
results: string[];
};
export const HuggingFaceResults = ({ results }: HuggingFaceResultsProps) => {
const { t } = useTranslation();
const [searchTerm, setSearchTerm] = useState('');
const dispatch = useAppDispatch();
const [installModel] = useInstallModelMutation();
const filteredResults = useMemo(() => {
return results.filter((result) => {
const modelName = result.split('/').slice(-1)[0];
return modelName?.toLowerCase().includes(searchTerm.toLowerCase());
});
}, [results, searchTerm]);
const handleSearch: ChangeEventHandler<HTMLInputElement> = useCallback((e) => {
setSearchTerm(e.target.value.trim());
}, []);
const clearSearch = useCallback(() => {
setSearchTerm('');
}, []);
const handleAddAll = useCallback(() => {
for (const result of filteredResults) {
installModel({ source: result })
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('toast.modelAddedSimple'),
status: 'success',
})
)
);
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: `${error.data.detail} `,
status: 'error',
})
)
);
}
});
}
}, [filteredResults, installModel, dispatch, t]);
return (
<>
<Divider />
<Flex flexDir="column" gap={3} height="100%">
<Flex justifyContent="space-between" alignItems="center">
<Heading size="sm">{t('modelManager.availableModels')}</Heading>
<Flex alignItems="center" gap={3}>
<Button size="sm" onClick={handleAddAll} isDisabled={results.length === 0} flexShrink={0}>
{t('modelManager.installAll')}
</Button>
<InputGroup w={64} size="xs">
<Input
placeholder={t('modelManager.search')}
value={searchTerm}
data-testid="board-search-input"
onChange={handleSearch}
size="xs"
/>
{searchTerm && (
<InputRightElement h="full" pe={2}>
<IconButton
size="sm"
variant="link"
aria-label={t('boards.clearSearch')}
icon={<PiXBold />}
onClick={clearSearch}
/>
</InputRightElement>
)}
</InputGroup>
</Flex>
</Flex>
<Flex height="100%" layerStyle="third" borderRadius="base" p={3}>
<ScrollableContent>
<Flex flexDir="column" gap={3}>
{filteredResults.map((result) => (
<HuggingFaceResultItem key={result} result={result} />
))}
</Flex>
</ScrollableContent>
</Flex>
</Flex>
</>
);
};


@ -67,19 +67,21 @@ export const InstallModelForm = () => {
<Flex flexDir="column" gap={4}>
<Flex gap={2} alignItems="flex-end" justifyContent="space-between">
<FormControl orientation="vertical">
<FormLabel>{t('modelManager.modelLocation')}</FormLabel>
<Input {...register('location')} />
<FormLabel>{t('modelManager.urlOrLocalPath')}</FormLabel>
<Flex alignItems="center" gap={3} w="full">
<Input placeholder={t('modelManager.simpleModelPlaceholder')} {...register('location')} />
<Button
onClick={handleSubmit(onSubmit)}
isDisabled={!formState.dirtyFields.location}
isLoading={isLoading}
type="submit"
size="sm"
>
{t('modelManager.install')}
</Button>
</Flex>
<FormHelperText>{t('modelManager.urlOrLocalPathHelper')}</FormHelperText>
</FormControl>
<Button
onClick={handleSubmit(onSubmit)}
isDisabled={!formState.dirtyFields.location}
isLoading={isLoading}
type="submit"
size="sm"
mb={1}
>
{t('modelManager.addModel')}
</Button>
</Flex>
<FormControl>


@ -1,4 +1,4 @@
import { Box, Button, Flex, Text } from '@invoke-ai/ui-library';
import { Box, Button, Flex, Heading } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { addToast } from 'features/system/store/systemSlice';
@ -50,9 +50,9 @@ export const ModelInstallQueue = () => {
}, [data]);
return (
<Flex flexDir="column" p={3} h="full">
<Flex flexDir="column" p={3} h="full" gap={3}>
<Flex justifyContent="space-between" alignItems="center">
<Text>{t('modelManager.importQueue')}</Text>
<Heading size="sm">{t('modelManager.installQueue')}</Heading>
<Button
size="sm"
isDisabled={!pruneAvailable}
@ -62,9 +62,9 @@ export const ModelInstallQueue = () => {
{t('modelManager.prune')}
</Button>
</Flex>
<Box mt={3} layerStyle="first" p={3} borderRadius="base" w="full" h="full">
<Box layerStyle="first" p={3} borderRadius="base" w="full" h="full">
<ScrollableContent>
<Flex flexDir="column-reverse" gap="2">
<Flex flexDir="column-reverse" gap="2" w="full">
{data?.map((model) => <ModelInstallQueueItem key={model.id} installJob={model} />)}
</Flex>
</ScrollableContent>


@ -1,4 +1,4 @@
import { Badge, Tooltip } from '@invoke-ai/ui-library';
import { Badge } from '@invoke-ai/ui-library';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import type { ModelInstallStatus } from 'services/api/types';
@ -13,13 +13,7 @@ const STATUSES = {
cancelled: { colorScheme: 'orange', translationKey: 'queue.canceled' },
};
const ModelInstallQueueBadge = ({
status,
errorReason,
}: {
status?: ModelInstallStatus;
errorReason?: string | null;
}) => {
const ModelInstallQueueBadge = ({ status }: { status?: ModelInstallStatus }) => {
const { t } = useTranslation();
if (!status || !Object.keys(STATUSES).includes(status)) {
@ -27,9 +21,9 @@ const ModelInstallQueueBadge = ({
}
return (
<Tooltip label={errorReason}>
<Badge colorScheme={STATUSES[status].colorScheme}>{t(STATUSES[status].translationKey)}</Badge>
</Tooltip>
<Badge textAlign="center" w="134px" colorScheme={STATUSES[status].colorScheme}>
{t(STATUSES[status].translationKey)}
</Badge>
);
};
export default memo(ModelInstallQueueBadge);


@ -1,4 +1,4 @@
import { Box, Flex, IconButton, Progress, Text, Tooltip } from '@invoke-ai/ui-library';
import { Flex, IconButton, Progress, Text, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
@ -7,7 +7,7 @@ import { isNil } from 'lodash-es';
import { useCallback, useMemo } from 'react';
import { PiXBold } from 'react-icons/pi';
import { useCancelModelInstallMutation } from 'services/api/endpoints/models';
import type { HFModelSource, LocalModelSource, ModelInstallJob, URLModelSource } from 'services/api/types';
import type { ModelInstallJob } from 'services/api/types';
import ModelInstallQueueBadge from './ModelInstallQueueBadge';
@ -16,7 +16,7 @@ type ModelListItemProps = {
};
const formatBytes = (bytes: number) => {
const units = ['b', 'kb', 'mb', 'gb', 'tb'];
const units = ['B', 'KB', 'MB', 'GB', 'TB'];
let i = 0;
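The diff cuts `formatBytes` off after the loop counter; a plausible completion under the shown units array (a sketch, not necessarily the repository's exact body):

const formatBytes = (bytes: number) => {
const units = ['B', 'KB', 'MB', 'GB', 'TB'];
let i = 0;
// Divide down until the value fits the current unit or we run out of units.
for (i = 0; bytes >= 1024 && i < units.length - 1; i++) {
bytes = bytes / 1024;
}
return `${bytes.toFixed(2)} ${units[i]}`;
};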
@ -33,18 +33,6 @@ export const ModelInstallQueueItem = (props: ModelListItemProps) => {
const [deleteImportModel] = useCancelModelInstallMutation();
const source = useMemo(() => {
if (installJob.source.type === 'hf') {
return installJob.source as HFModelSource;
} else if (installJob.source.type === 'local') {
return installJob.source as LocalModelSource;
} else if (installJob.source.type === 'url') {
return installJob.source as URLModelSource;
} else {
return installJob.source as LocalModelSource;
}
}, [installJob.source]);
const handleDeleteModelImport = useCallback(() => {
deleteImportModel(installJob.id)
.unwrap()
@ -72,18 +60,31 @@ export const ModelInstallQueueItem = (props: ModelListItemProps) => {
});
}, [deleteImportModel, installJob, dispatch]);
const modelName = useMemo(() => {
switch (source.type) {
const sourceLocation = useMemo(() => {
switch (installJob.source.type) {
case 'hf':
return source.repo_id;
return installJob.source.repo_id;
case 'url':
return source.url;
return installJob.source.url;
case 'local':
return source.path.split('\\').slice(-1)[0];
return installJob.source.path;
default:
return '';
return t('common.unknown');
}
}, [source]);
}, [installJob.source]);
const modelName = useMemo(() => {
switch (installJob.source.type) {
case 'hf':
return installJob.source.repo_id;
case 'url':
return installJob.source.url.split('/').slice(-1)[0] ?? t('common.unknown');
case 'local':
return installJob.source.path.split('\\').slice(-1)[0] ?? t('common.unknown');
default:
return t('common.unknown');
}
}, [installJob.source]);
const progressValue = useMemo(() => {
if (isNil(installJob.bytes) || isNil(installJob.total_bytes)) {
@ -97,48 +98,67 @@ export const ModelInstallQueueItem = (props: ModelListItemProps) => {
return (installJob.bytes / installJob.total_bytes) * 100;
}, [installJob.bytes, installJob.total_bytes]);
return (
<Flex gap={3} w="full" alignItems="center">
<Tooltip maxW={600} label={<TooltipLabel name={modelName} source={sourceLocation} installJob={installJob} />}>
<Flex gap={3} w="full" alignItems="center">
<Text w={96} whiteSpace="nowrap" overflow="hidden" textOverflow="ellipsis">
{modelName}
</Text>
<Progress
w="full"
flexGrow={1}
value={progressValue ?? 0}
isIndeterminate={progressValue === null}
aria-label={t('accessibility.invokeProgressBar')}
h={2}
/>
<ModelInstallQueueBadge status={installJob.status} />
</Flex>
</Tooltip>
<IconButton
isDisabled={
installJob.status !== 'downloading' && installJob.status !== 'waiting' && installJob.status !== 'running'
}
size="xs"
tooltip={t('modelManager.cancel')}
aria-label={t('modelManager.cancel')}
icon={<PiXBold />}
onClick={handleDeleteModelImport}
variant="ghost"
/>
</Flex>
);
};
type TooltipLabelProps = {
installJob: ModelInstallJob;
name: string;
source: string;
};
const TooltipLabel = ({ name, source, installJob }: TooltipLabelProps) => {
const progressString = useMemo(() => {
if (installJob.status !== 'downloading' || installJob.bytes === undefined || installJob.total_bytes === undefined) {
if (installJob.status === 'downloading' || installJob.bytes === undefined || installJob.total_bytes === undefined) {
return '';
}
return `${formatBytes(installJob.bytes)} / ${formatBytes(installJob.total_bytes)}`;
}, [installJob.bytes, installJob.total_bytes, installJob.status]);
return (
<Flex gap="2" w="full" alignItems="center">
<Tooltip label={modelName}>
<Text width="30%" whiteSpace="nowrap" overflow="hidden" textOverflow="ellipsis">
{modelName}
</Text>
</Tooltip>
<Flex flexDir="column" flex={1}>
<Tooltip label={progressString}>
<Progress
value={progressValue ?? 0}
isIndeterminate={progressValue === null}
aria-label={t('accessibility.invokeProgressBar')}
h={2}
/>
</Tooltip>
<>
<Flex gap={3} justifyContent="space-between">
<Text fontWeight="semibold">{name}</Text>
{progressString && <Text>{progressString}</Text>}
</Flex>
<Box minW="100px" textAlign="center">
<ModelInstallQueueBadge status={installJob.status} errorReason={installJob.error_reason} />
</Box>
<Box minW="20px">
{(installJob.status === 'downloading' ||
installJob.status === 'waiting' ||
installJob.status === 'running') && (
<IconButton
isRound={true}
size="xs"
tooltip={t('modelManager.cancel')}
aria-label={t('modelManager.cancel')}
icon={<PiXBold />}
onClick={handleDeleteModelImport}
/>
)}
</Box>
</Flex>
<Text fontStyle="italic" wordBreak="break-all">
{source}
</Text>
{installJob.error_reason && (
<Text color="error.500">
{t('queue.failed')}: {installJob.error}
</Text>
)}
</>
);
};


@ -1,4 +1,4 @@
import { Button, Flex, FormControl, FormErrorMessage, FormLabel, Input } from '@invoke-ai/ui-library';
import { Button, Flex, FormControl, FormErrorMessage, FormHelperText, FormLabel, Input } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { setScanPath } from 'features/modelManagerV2/store/modelManagerV2Slice';
import type { ChangeEventHandler } from 'react';
@ -17,11 +17,13 @@ export const ScanModelsForm = () => {
const [_scanFolder, { isLoading, data }] = useLazyScanFolderQuery();
const scanFolder = useCallback(async () => {
_scanFolder({ scan_path: scanPath }).catch((error) => {
if (error) {
setErrorMessage(error.data.detail);
}
});
_scanFolder({ scan_path: scanPath })
.unwrap()
.catch((error) => {
if (error) {
setErrorMessage(error.data.detail);
}
});
}, [_scanFolder, scanPath]);
const handleSetScanPath: ChangeEventHandler<HTMLInputElement> = useCallback(
@ -33,25 +35,23 @@ export const ScanModelsForm = () => {
);
return (
<Flex flexDir="column" height="100%">
<FormControl isInvalid={!!errorMessage.length} w="full">
<Flex flexDir="column" w="full">
<Flex gap={2} alignItems="flex-end" justifyContent="space-between">
<Flex direction="column" w="full">
<FormLabel>{t('common.folder')}</FormLabel>
<Input value={scanPath} onChange={handleSetScanPath} />
</Flex>
<Button
onClick={scanFolder}
isLoading={isLoading}
isDisabled={scanPath === undefined || scanPath.length === 0}
>
{t('modelManager.scanFolder')}
</Button>
</Flex>
{!!errorMessage.length && <FormErrorMessage>{errorMessage}</FormErrorMessage>}
<Flex flexDir="column" height="100%" gap={3}>
<FormControl isInvalid={!!errorMessage.length} w="full" orientation="vertical" flexShrink={0}>
<FormLabel>{t('common.folder')}</FormLabel>
<Flex gap={3} alignItems="center" w="full">
<Input placeholder={t('modelManager.scanPlaceholder')} value={scanPath} onChange={handleSetScanPath} />
<Button
onClick={scanFolder}
isLoading={isLoading}
isDisabled={scanPath === undefined || scanPath.length === 0}
size="sm"
flexShrink={0}
>
{t('modelManager.scanFolder')}
</Button>
</Flex>
<FormHelperText>{t('modelManager.scanFolderHelper')}</FormHelperText>
{!!errorMessage.length && <FormErrorMessage>{errorMessage}</FormErrorMessage>}
</FormControl>
{data && <ScanModelsResults results={data} />}
</Flex>
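
The `.unwrap()` added above is what makes the `.catch` reachable: an RTK Query lazy-trigger promise resolves with a `{ data }` or `{ error }` envelope and never rejects, while `unwrap()` converts it into a promise that rejects with the error payload. A self-contained sketch of the same idea (the envelope type and fake trigger are illustrative, not RTK internals):

```ts
// Illustrative stand-in for an RTK Query trigger result: the raw promise
// always resolves, so error handling must go through unwrap().
type Envelope<T> = { data?: T; error?: { data: { detail: string } } };

const unwrap = <T>(p: Promise<Envelope<T>>): Promise<T> =>
  p.then((r) => (r.error ? Promise.reject(r.error) : (r.data as T)));

const fakeScan = async (): Promise<Envelope<string[]>> => ({
  error: { data: { detail: 'scan_path not found' } },
});

unwrap(fakeScan()).catch((error) => {
  console.log(error.data.detail); // 'scan_path not found'
});
```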

View File

@ -1,10 +1,10 @@
import { Badge, Box, Flex, IconButton, Text, Tooltip } from '@invoke-ai/ui-library';
import { Badge, Box, Flex, IconButton, Text } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { IoAdd } from 'react-icons/io5';
import { PiPlusBold } from 'react-icons/pi';
import type { ScanFolderResponse } from 'services/api/endpoints/models';
import { useInstallModelMutation } from 'services/api/endpoints/models';
@ -45,7 +45,7 @@ export const ScanModelResultItem = ({ result }: Props) => {
}, [installModel, result, dispatch, t]);
return (
<Flex justifyContent="space-between">
<Flex alignItems="center" justifyContent="space-between" w="100%" gap={3}>
<Flex fontSize="sm" flexDir="column">
<Text fontWeight="semibold">{result.path.split('\\').slice(-1)[0]}</Text>
<Text variant="subtext">{result.path}</Text>
@ -54,9 +54,7 @@ export const ScanModelResultItem = ({ result }: Props) => {
{result.is_installed ? (
<Badge>{t('common.installed')}</Badge>
) : (
<Tooltip label={t('modelManager.quickAdd')}>
<IconButton aria-label={t('modelManager.quickAdd')} icon={<IoAdd />} onClick={handleQuickAdd} />
</Tooltip>
<IconButton aria-label={t('modelManager.install')} icon={<PiPlusBold />} onClick={handleQuickAdd} size="sm" />
)}
</Box>
</Flex>
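
Note that the display name above comes from splitting the scanned path on backslashes only, which assumes Windows-style paths. A tolerant variant (illustrative, not the diff's code) splits on either separator:

```ts
// Hypothetical cross-platform basename; the component above only handles
// the backslash case.
const basename = (path: string): string => path.split(/[\\/]/).slice(-1)[0] ?? path;

console.log(basename('C:\\models\\sd15.safetensors')); // 'sd15.safetensors'
console.log(basename('/models/sd15.safetensors')); // 'sd15.safetensors'
```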

View File

@ -80,17 +80,15 @@ export const ScanModelsResults = ({ results }: ScanModelResultsProps) => {
return (
<>
<Divider mt={4} />
<Flex flexDir="column" gap={2} mt={4} height="100%">
<Divider />
<Flex flexDir="column" gap={3} height="100%">
<Flex justifyContent="space-between" alignItems="center">
<Heading fontSize="md" as="h4">
{t('modelManager.scanResults')}
</Heading>
<Flex alignItems="center" gap="4">
<Button onClick={handleAddAll} isDisabled={filteredResults.length === 0}>
{t('modelManager.addAll')}
<Heading size="sm">{t('modelManager.scanResults')}</Heading>
<Flex alignItems="center" gap={3}>
<Button size="sm" onClick={handleAddAll} isDisabled={filteredResults.length === 0}>
{t('modelManager.installAll')}
</Button>
<InputGroup maxW="300px" size="xs">
<InputGroup w={64} size="xs">
<Input
placeholder={t('modelManager.search')}
value={searchTerm}
@ -107,13 +105,14 @@ export const ScanModelsResults = ({ results }: ScanModelResultsProps) => {
aria-label={t('boards.clearSearch')}
icon={<PiXBold />}
onClick={clearSearch}
flexShrink={0}
/>
</InputRightElement>
)}
</InputGroup>
</Flex>
</Flex>
<Flex height="100%" layerStyle="third" borderRadius="base" p={4} mt={4} mb={4}>
<Flex height="100%" layerStyle="third" borderRadius="base" p={3}>
<ScrollableContent>
<Flex flexDir="column" gap={3}>
{filteredResults.map((result) => (

View File

@ -1,6 +1,7 @@
import { Box, Flex, Heading, Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library';
import { useTranslation } from 'react-i18next';
import { HuggingFaceForm } from './AddModelPanel/HuggingFaceFolder/HuggingFaceForm';
import { InstallModelForm } from './AddModelPanel/InstallModelForm';
import { ModelInstallQueue } from './AddModelPanel/ModelInstallQueue/ModelInstallQueue';
import { ScanModelsForm } from './AddModelPanel/ScanFolder/ScanFolderForm';
@ -8,27 +9,27 @@ import { ScanModelsForm } from './AddModelPanel/ScanFolder/ScanFolderForm';
export const InstallModels = () => {
const { t } = useTranslation();
return (
<Flex layerStyle="first" p={3} borderRadius="base" w="full" h="full" flexDir="column" gap={2}>
<Box w="full" p={2}>
<Heading fontSize="xl">{t('modelManager.addModel')}</Heading>
</Box>
<Box layerStyle="second" borderRadius="base" w="full" h="50%" overflow="hidden">
<Tabs variant="collapse" height="100%">
<TabList>
<Tab>{t('common.simple')}</Tab>
<Tab>{t('modelManager.scan')}</Tab>
</TabList>
<TabPanels p={3} height="100%">
<TabPanel>
<InstallModelForm />
</TabPanel>
<TabPanel height="100%">
<ScanModelsForm />
</TabPanel>
</TabPanels>
</Tabs>
</Box>
<Box layerStyle="second" borderRadius="base" w="full" h="50%">
<Flex layerStyle="first" borderRadius="base" w="full" h="full" flexDir="column" gap={4}>
<Heading fontSize="xl">{t('modelManager.addModel')}</Heading>
<Tabs variant="collapse" height="50%" display="flex" flexDir="column">
<TabList>
<Tab>{t('modelManager.urlOrLocalPath')}</Tab>
<Tab>{t('modelManager.huggingFace')}</Tab>
<Tab>{t('modelManager.scanFolder')}</Tab>
</TabList>
<TabPanels p={3} height="100%">
<TabPanel>
<InstallModelForm />
</TabPanel>
<TabPanel height="100%">
<HuggingFaceForm />
</TabPanel>
<TabPanel height="100%">
<ScanModelsForm />
</TabPanel>
</TabPanels>
</Tabs>
<Box layerStyle="second" borderRadius="base" h="50%">
<ModelInstallQueue />
</Box>
</Flex>

View File

@ -1,122 +1,105 @@
import { Flex, Spinner, Text } from '@invoke-ai/ui-library';
import type { EntityState } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { forEach } from 'lodash-es';
import { memo } from 'react';
import { ALL_BASE_MODELS } from 'services/api/constants';
import { memo, useMemo } from 'react';
import {
useGetControlNetModelsQuery,
useGetIPAdapterModelsQuery,
useGetLoRAModelsQuery,
useGetMainModelsQuery,
useGetT2IAdapterModelsQuery,
useGetTextualInversionModelsQuery,
useGetVaeModelsQuery,
} from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
useControlNetModels,
useEmbeddingModels,
useIPAdapterModels,
useLoRAModels,
useMainModels,
useT2IAdapterModels,
useVAEModels,
} from 'services/api/hooks/modelsByType';
import type { AnyModelConfig, ModelType } from 'services/api/types';
import { ModelListWrapper } from './ModelListWrapper';
const ModelList = () => {
const { searchTerm, filteredModelType } = useAppSelector((s) => s.modelmanagerV2);
const { filteredMainModels, isLoadingMainModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
selectFromResult: ({ data, isLoading }) => ({
filteredMainModels: modelsFilter(data, searchTerm, filteredModelType),
isLoadingMainModels: isLoading,
}),
});
const { filteredLoraModels, isLoadingLoraModels } = useGetLoRAModelsQuery(undefined, {
selectFromResult: ({ data, isLoading }) => ({
filteredLoraModels: modelsFilter(data, searchTerm, filteredModelType),
isLoadingLoraModels: isLoading,
}),
});
const { filteredTextualInversionModels, isLoadingTextualInversionModels } = useGetTextualInversionModelsQuery(
undefined,
{
selectFromResult: ({ data, isLoading }) => ({
filteredTextualInversionModels: modelsFilter(data, searchTerm, filteredModelType),
isLoadingTextualInversionModels: isLoading,
}),
}
const [mainModels, { isLoading: isLoadingMainModels }] = useMainModels();
const filteredMainModels = useMemo(
() => modelsFilter(mainModels, searchTerm, filteredModelType),
[mainModels, searchTerm, filteredModelType]
);
const { filteredControlnetModels, isLoadingControlnetModels } = useGetControlNetModelsQuery(undefined, {
selectFromResult: ({ data, isLoading }) => ({
filteredControlnetModels: modelsFilter(data, searchTerm, filteredModelType),
isLoadingControlnetModels: isLoading,
}),
});
const [loraModels, { isLoading: isLoadingLoRAModels }] = useLoRAModels();
const filteredLoRAModels = useMemo(
() => modelsFilter(loraModels, searchTerm, filteredModelType),
[loraModels, searchTerm, filteredModelType]
);
const { filteredT2iAdapterModels, isLoadingT2IAdapterModels } = useGetT2IAdapterModelsQuery(undefined, {
selectFromResult: ({ data, isLoading }) => ({
filteredT2iAdapterModels: modelsFilter(data, searchTerm, filteredModelType),
isLoadingT2IAdapterModels: isLoading,
}),
});
const [embeddingModels, { isLoading: isLoadingEmbeddingModels }] = useEmbeddingModels();
const filteredEmbeddingModels = useMemo(
() => modelsFilter(embeddingModels, searchTerm, filteredModelType),
[embeddingModels, searchTerm, filteredModelType]
);
const { filteredIpAdapterModels, isLoadingIpAdapterModels } = useGetIPAdapterModelsQuery(undefined, {
selectFromResult: ({ data, isLoading }) => ({
filteredIpAdapterModels: modelsFilter(data, searchTerm, filteredModelType),
isLoadingIpAdapterModels: isLoading,
}),
});
const [controlNetModels, { isLoading: isLoadingControlNetModels }] = useControlNetModels();
const filteredControlNetModels = useMemo(
() => modelsFilter(controlNetModels, searchTerm, filteredModelType),
[controlNetModels, searchTerm, filteredModelType]
);
const { filteredVaeModels, isLoadingVaeModels } = useGetVaeModelsQuery(undefined, {
selectFromResult: ({ data, isLoading }) => ({
filteredVaeModels: modelsFilter(data, searchTerm, filteredModelType),
isLoadingVaeModels: isLoading,
}),
});
const [t2iAdapterModels, { isLoading: isLoadingT2IAdapterModels }] = useT2IAdapterModels();
const filteredT2IAdapterModels = useMemo(
() => modelsFilter(t2iAdapterModels, searchTerm, filteredModelType),
[t2iAdapterModels, searchTerm, filteredModelType]
);
const [ipAdapterModels, { isLoading: isLoadingIPAdapterModels }] = useIPAdapterModels();
const filteredIPAdapterModels = useMemo(
() => modelsFilter(ipAdapterModels, searchTerm, filteredModelType),
[ipAdapterModels, searchTerm, filteredModelType]
);
const [vaeModels, { isLoading: isLoadingVAEModels }] = useVAEModels();
const filteredVAEModels = useMemo(
() => modelsFilter(vaeModels, searchTerm, filteredModelType),
[vaeModels, searchTerm, filteredModelType]
);
return (
<ScrollableContent>
<Flex flexDirection="column" w="full" h="full" gap={4}>
{/* Main Model List */}
{isLoadingMainModels && <FetchingModelsLoader loadingMessage="Loading Main..." />}
{isLoadingMainModels && <FetchingModelsLoader loadingMessage="Loading Main Models..." />}
{!isLoadingMainModels && filteredMainModels.length > 0 && (
<ModelListWrapper title="Main" modelList={filteredMainModels} key="main" />
)}
{/* LoRAs List */}
{isLoadingLoraModels && <FetchingModelsLoader loadingMessage="Loading LoRAs..." />}
{!isLoadingLoraModels && filteredLoraModels.length > 0 && (
<ModelListWrapper title="LoRAs" modelList={filteredLoraModels} key="loras" />
{isLoadingLoRAModels && <FetchingModelsLoader loadingMessage="Loading LoRAs..." />}
{!isLoadingLoRAModels && filteredLoRAModels.length > 0 && (
<ModelListWrapper title="LoRA" modelList={filteredLoRAModels} key="loras" />
)}
{/* TI List */}
{isLoadingTextualInversionModels && <FetchingModelsLoader loadingMessage="Loading Textual Inversions..." />}
{!isLoadingTextualInversionModels && filteredTextualInversionModels.length > 0 && (
<ModelListWrapper
title="Textual Inversions"
modelList={filteredTextualInversionModels}
key="textual-inversions"
/>
{isLoadingEmbeddingModels && <FetchingModelsLoader loadingMessage="Loading Embeddings..." />}
{!isLoadingEmbeddingModels && filteredEmbeddingModels.length > 0 && (
<ModelListWrapper title="Embedding" modelList={filteredEmbeddingModels} key="textual-inversions" />
)}
{/* VAE List */}
{isLoadingVaeModels && <FetchingModelsLoader loadingMessage="Loading VAEs..." />}
{!isLoadingVaeModels && filteredVaeModels.length > 0 && (
<ModelListWrapper title="VAEs" modelList={filteredVaeModels} key="vae" />
{isLoadingVAEModels && <FetchingModelsLoader loadingMessage="Loading VAEs..." />}
{!isLoadingVAEModels && filteredVAEModels.length > 0 && (
<ModelListWrapper title="VAE" modelList={filteredVAEModels} key="vae" />
)}
{/* Controlnet List */}
{isLoadingControlnetModels && <FetchingModelsLoader loadingMessage="Loading Controlnets..." />}
{!isLoadingControlnetModels && filteredControlnetModels.length > 0 && (
<ModelListWrapper title="Controlnets" modelList={filteredControlnetModels} key="controlnets" />
{isLoadingControlNetModels && <FetchingModelsLoader loadingMessage="Loading ControlNets..." />}
{!isLoadingControlNetModels && filteredControlNetModels.length > 0 && (
<ModelListWrapper title="ControlNet" modelList={filteredControlNetModels} key="controlnets" />
)}
{/* IP Adapter List */}
{isLoadingIpAdapterModels && <FetchingModelsLoader loadingMessage="Loading IP Adapters..." />}
{!isLoadingIpAdapterModels && filteredIpAdapterModels.length > 0 && (
<ModelListWrapper title="IP Adapters" modelList={filteredIpAdapterModels} key="ip-adapters" />
{isLoadingIPAdapterModels && <FetchingModelsLoader loadingMessage="Loading IP Adapters..." />}
{!isLoadingIPAdapterModels && filteredIPAdapterModels.length > 0 && (
<ModelListWrapper title="IP Adapter" modelList={filteredIPAdapterModels} key="ip-adapters" />
)}
{/* T2I Adapters List */}
{isLoadingT2IAdapterModels && <FetchingModelsLoader loadingMessage="Loading T2I Adapters..." />}
{!isLoadingT2IAdapterModels && filteredT2iAdapterModels.length > 0 && (
<ModelListWrapper title="T2I Adapters" modelList={filteredT2iAdapterModels} key="t2i-adapters" />
{!isLoadingT2IAdapterModels && filteredT2IAdapterModels.length > 0 && (
<ModelListWrapper title="T2I Adapter" modelList={filteredT2IAdapterModels} key="t2i-adapters" />
)}
</Flex>
</ScrollableContent>
@ -126,25 +109,16 @@ const ModelList = () => {
export default memo(ModelList);
const modelsFilter = <T extends AnyModelConfig>(
data: EntityState<T, string> | undefined,
data: T[],
nameFilter: string,
filteredModelType: string | null
filteredModelType: ModelType | null
): T[] => {
const filteredModels: T[] = [];
forEach(data?.entities, (model) => {
if (!model) {
return;
}
return data.filter((model) => {
const matchesFilter = model.name.toLowerCase().includes(nameFilter.toLowerCase());
const matchesType = filteredModelType ? model.type === filteredModelType : true;
if (matchesFilter && matchesType) {
filteredModels.push(model);
}
return matchesFilter && matchesType;
});
return filteredModels;
};
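
With the new hooks returning plain arrays instead of `EntityState`, the filter above becomes an ordinary `Array.prototype.filter`. A self-contained sketch of the same logic, with the model shape reduced to the two fields the filter actually reads:

```ts
// Minimal model shape for the filter; real configs carry many more fields.
type Model = { name: string; type: string };

const filterModels = <T extends Model>(data: T[], nameFilter: string, type: string | null): T[] =>
  data.filter(
    (m) => m.name.toLowerCase().includes(nameFilter.toLowerCase()) && (type ? m.type === type : true)
  );

const models: Model[] = [
  { name: 'sdxl-base', type: 'main' },
  { name: 'canny-sdxl', type: 'controlnet' },
];
console.log(filterModels(models, 'sdxl', 'main')); // [{ name: 'sdxl-base', type: 'main' }]
```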
const FetchingModelsLoader = memo(({ loadingMessage }: { loadingMessage?: string }) => {

View File

@ -1,11 +1,13 @@
import { Button, Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import type { FilterableModelType } from 'features/modelManagerV2/store/modelManagerV2Slice';
import { setFilteredModelType } from 'features/modelManagerV2/store/modelManagerV2Slice';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { IoFilter } from 'react-icons/io5';
import { PiFunnelBold } from 'react-icons/pi';
import { objectKeys } from 'tsafe';
const MODEL_TYPE_LABELS: { [key: string]: string } = {
const MODEL_TYPE_LABELS: Record<FilterableModelType, string> = {
main: 'Main',
lora: 'LoRA',
embedding: 'Textual Inversion',
@ -13,7 +15,6 @@ const MODEL_TYPE_LABELS: { [key: string]: string } = {
vae: 'VAE',
t2i_adapter: 'T2I Adapter',
ip_adapter: 'IP Adapter',
clip_vision: 'Clip Vision',
};
export const ModelTypeFilter = () => {
@ -22,7 +23,7 @@ export const ModelTypeFilter = () => {
const filteredModelType = useAppSelector((s) => s.modelmanagerV2.filteredModelType);
const selectModelType = useCallback(
(option: string) => {
(option: FilterableModelType) => {
dispatch(setFilteredModelType(option));
},
[dispatch]
@ -34,12 +35,12 @@ export const ModelTypeFilter = () => {
return (
<Menu>
<MenuButton as={Button} size="sm" leftIcon={<IoFilter />}>
<MenuButton as={Button} size="sm" leftIcon={<PiFunnelBold />}>
{filteredModelType ? MODEL_TYPE_LABELS[filteredModelType] : t('modelManager.allModels')}
</MenuButton>
<MenuList>
<MenuItem onClick={clearModelType}>{t('modelManager.allModels')}</MenuItem>
{Object.keys(MODEL_TYPE_LABELS).map((option) => (
{objectKeys(MODEL_TYPE_LABELS).map((option) => (
<MenuItem
key={option}
bg={filteredModelType === option ? 'base.700' : 'transparent'}
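
The switch from `Object.keys` to tsafe's `objectKeys` above is about typing: `Object.keys()` widens to `string[]`, while `objectKeys` preserves the record's key union, so each menu option is a `FilterableModelType` rather than a bare string. A minimal sketch with a reduced label record:

```ts
import { objectKeys } from 'tsafe';

const MODEL_TYPE_LABELS = { main: 'Main', lora: 'LoRA' } as const;

for (const option of objectKeys(MODEL_TYPE_LABELS)) {
  // option is typed as 'main' | 'lora', so the lookup below is type-safe
  console.log(option, MODEL_TYPE_LABELS[option]);
}
```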

View File

@ -0,0 +1,81 @@
import { CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import type { MainModelDefaultSettingsFormData } from './MainModelDefaultSettings';
type DefaultHeightType = MainModelDefaultSettingsFormData['height'];
type Props = {
control: UseControllerProps<MainModelDefaultSettingsFormData>['control'];
optimalDimension: number;
};
export function DefaultHeight({ control, optimalDimension }: Props) {
const { field } = useController({ control, name: 'height' });
const sliderMin = useAppSelector((s) => s.config.sd.height.sliderMin);
const sliderMax = useAppSelector((s) => s.config.sd.height.sliderMax);
const numberInputMin = useAppSelector((s) => s.config.sd.height.numberInputMin);
const numberInputMax = useAppSelector((s) => s.config.sd.height.numberInputMax);
const coarseStep = useAppSelector((s) => s.config.sd.height.coarseStep);
const fineStep = useAppSelector((s) => s.config.sd.height.fineStep);
const { t } = useTranslation();
const marks = useMemo(() => [sliderMin, optimalDimension, sliderMax], [sliderMin, optimalDimension, sliderMax]);
const onChange = useCallback(
(v: number) => {
const updatedValue = {
...(field.value as DefaultHeightType),
value: v,
};
field.onChange(updatedValue);
},
[field]
);
const value = useMemo(() => {
return field.value.value;
}, [field.value]);
const isDisabled = useMemo(() => {
return !field.value.isEnabled;
}, [field.value]);
return (
<FormControl flexDir="column" gap={2} alignItems="flex-start">
<Flex justifyContent="space-between" w="full">
<InformationalPopover feature="paramHeight">
<FormLabel>{t('parameters.height')}</FormLabel>
</InformationalPopover>
<SettingToggle control={control} name="height" />
</Flex>
<Flex w="full" gap={4}>
<CompositeSlider
value={value}
min={sliderMin}
max={sliderMax}
step={coarseStep}
fineStep={fineStep}
onChange={onChange}
marks={marks}
isDisabled={isDisabled}
/>
<CompositeNumberInput
value={value}
min={numberInputMin}
max={numberInputMax}
step={coarseStep}
fineStep={fineStep}
onChange={onChange}
isDisabled={isDisabled}
/>
</Flex>
</FormControl>
);
}
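
The `onChange` above has to preserve the field's `isEnabled` flag, because the react-hook-form field stores a `{ value, isEnabled }` pair while the slider only reports a number. The update pattern in isolation (types reduced for illustration):

```ts
// The form field is an object, so a numeric slider change must spread the
// previous value rather than replace it wholesale.
type NumberField = { value: number; isEnabled: boolean };

const applySliderChange = (prev: NumberField, v: number): NumberField => ({ ...prev, value: v });

console.log(applySliderChange({ value: 512, isEnabled: true }, 768));
// { value: 768, isEnabled: true }
```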

View File

@ -4,12 +4,12 @@ import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
import { map } from 'lodash-es';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import { useGetModelConfigQuery, useGetVaeModelsQuery } from 'services/api/endpoints/models';
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
import { useVAEModels } from 'services/api/hooks/modelsByType';
import type { MainModelDefaultSettingsFormData } from './MainModelDefaultSettings';
@ -21,18 +21,16 @@ export function DefaultVae(props: UseControllerProps<MainModelDefaultSettingsFor
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const { data: modelData } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
const { compatibleOptions } = useGetVaeModelsQuery(undefined, {
selectFromResult: ({ data }) => {
const modelArray = map(data?.entities);
const compatibleOptions = modelArray
.filter((vae) => vae.base === modelData?.base)
.map((vae) => ({ label: vae.name, value: vae.key }));
const [vaeModels] = useVAEModels();
const compatibleOptions = useMemo(() => {
const compatibleOptions = vaeModels
.filter((vae) => vae.base === modelData?.base)
.map((vae) => ({ label: vae.name, value: vae.key }));
const defaultOption = { label: 'Default VAE', value: 'default' };
const defaultOption = { label: 'Default VAE', value: 'default' };
return { compatibleOptions: [defaultOption, ...compatibleOptions] };
},
});
return [defaultOption, ...compatibleOptions];
}, [modelData?.base, vaeModels]);
const onChange = useCallback<ComboboxOnChange>(
(v) => {

View File

@ -0,0 +1,81 @@
import { CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import type { MainModelDefaultSettingsFormData } from './MainModelDefaultSettings';
type DefaultWidthType = MainModelDefaultSettingsFormData['width'];
type Props = {
control: UseControllerProps<MainModelDefaultSettingsFormData>['control'];
optimalDimension: number;
};
export function DefaultWidth({ control, optimalDimension }: Props) {
const { field } = useController({ control, name: 'width' });
const sliderMin = useAppSelector((s) => s.config.sd.width.sliderMin);
const sliderMax = useAppSelector((s) => s.config.sd.width.sliderMax);
const numberInputMin = useAppSelector((s) => s.config.sd.width.numberInputMin);
const numberInputMax = useAppSelector((s) => s.config.sd.width.numberInputMax);
const coarseStep = useAppSelector((s) => s.config.sd.width.coarseStep);
const fineStep = useAppSelector((s) => s.config.sd.width.fineStep);
const { t } = useTranslation();
const marks = useMemo(() => [sliderMin, optimalDimension, sliderMax], [sliderMin, optimalDimension, sliderMax]);
const onChange = useCallback(
(v: number) => {
const updatedValue = {
...(field.value as DefaultWidthType),
value: v,
};
field.onChange(updatedValue);
},
[field]
);
const value = useMemo(() => {
return (field.value as DefaultWidthType).value;
}, [field.value]);
const isDisabled = useMemo(() => {
return !(field.value as DefaultWidthType).isEnabled;
}, [field.value]);
return (
<FormControl flexDir="column" gap={2} alignItems="flex-start">
<Flex justifyContent="space-between" w="full">
<InformationalPopover feature="paramWidth">
<FormLabel>{t('parameters.width')}</FormLabel>
</InformationalPopover>
<SettingToggle control={control} name="width" />
</Flex>
<Flex w="full" gap={4}>
<CompositeSlider
value={value}
min={sliderMin}
max={sliderMax}
step={coarseStep}
fineStep={fineStep}
onChange={onChange}
marks={marks}
isDisabled={isDisabled}
/>
<CompositeNumberInput
value={value}
min={numberInputMin}
max={numberInputMax}
step={coarseStep}
fineStep={fineStep}
onChange={onChange}
isDisabled={isDisabled}
/>
</Flex>
</FormControl>
);
}

View File

@ -1,6 +1,8 @@
import { Button, Flex, Heading, Text } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useMainModelDefaultSettings } from 'features/modelManagerV2/hooks/useMainModelDefaultSettings';
import { DefaultHeight } from 'features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/DefaultHeight';
import { DefaultWidth } from 'features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/DefaultWidth';
import type { ParameterScheduler } from 'features/parameters/types/parameterSchemas';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
@ -25,11 +27,13 @@ export interface FormField<T> {
export type MainModelDefaultSettingsFormData = {
vae: FormField<string>;
vaePrecision: FormField<string>;
vaePrecision: FormField<'fp16' | 'fp32'>;
scheduler: FormField<ParameterScheduler>;
steps: FormField<number>;
cfgScale: FormField<number>;
cfgRescaleMultiplier: FormField<number>;
width: FormField<number>;
height: FormField<number>;
};
export const MainModelDefaultSettings = () => {
@ -37,8 +41,11 @@ export const MainModelDefaultSettings = () => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const { defaultSettingsDefaults, isLoading: isLoadingDefaultSettings } =
useMainModelDefaultSettings(selectedModelKey);
const {
defaultSettingsDefaults,
isLoading: isLoadingDefaultSettings,
optimalDimension,
} = useMainModelDefaultSettings(selectedModelKey);
const [updateModel, { isLoading: isLoadingUpdateModel }] = useUpdateModelMutation();
@ -59,6 +66,8 @@ export const MainModelDefaultSettings = () => {
cfg_rescale_multiplier: data.cfgRescaleMultiplier.isEnabled ? data.cfgRescaleMultiplier.value : null,
steps: data.steps.isEnabled ? data.steps.value : null,
scheduler: data.scheduler.isEnabled ? data.scheduler.value : null,
width: data.width.isEnabled ? data.width.value : null,
height: data.height.isEnabled ? data.height.value : null,
};
updateModel({
@ -139,6 +148,14 @@ export const MainModelDefaultSettings = () => {
<DefaultCfgRescaleMultiplier control={control} name="cfgRescaleMultiplier" />
</Flex>
</Flex>
<Flex gap={8}>
<Flex gap={4} w="full">
<DefaultWidth control={control} optimalDimension={optimalDimension} />
</Flex>
<Flex gap={4} w="full">
<DefaultHeight control={control} optimalDimension={optimalDimension} />
</Flex>
</Flex>
</Flex>
</>
);
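
Each default setting travels as a `FormField<T>` (`{ value, isEnabled }`), and the submit handler above persists `null` for disabled fields. A sketch of that mapping for the new width/height fields (the helper name is illustrative):

```ts
interface FormField<T> {
  value: T;
  isEnabled: boolean;
}

// Disabled defaults are stored as null so the backend can distinguish
// "no default set" from an explicit value.
const toDimensionPayload = (width: FormField<number>, height: FormField<number>) => ({
  width: width.isEnabled ? width.value : null,
  height: height.isEnabled ? height.value : null,
});

console.log(toDimensionPayload({ value: 1024, isEnabled: true }, { value: 1024, isEnabled: false }));
// { width: 1024, height: null }
```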

View File

@ -90,7 +90,7 @@ export const TriggerPhrases = () => {
size="sm"
type="submit"
onClick={addTriggerPhrase}
isDisabled={Boolean(errors.length)}
isDisabled={!phrase || Boolean(errors.length)}
isLoading={isLoading}
>
{t('common.add')}

View File

@ -4,7 +4,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldControlNetModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { ControlNetModelFieldInputInstance, ControlNetModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
import { useControlNetModels } from 'services/api/hooks/modelsByType';
import type { ControlNetModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
@ -14,7 +14,7 @@ type Props = FieldComponentProps<ControlNetModelFieldInputInstance, ControlNetMo
const ControlNetModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const { data, isLoading } = useGetControlNetModelsQuery();
const [modelConfigs, { isLoading }] = useControlNetModels();
const _onChange = useCallback(
(value: ControlNetModelConfig | null) => {
@ -33,7 +33,7 @@ const ControlNetModelFieldInputComponent = (props: Props) => {
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelEntities: data,
modelConfigs,
onChange: _onChange,
selectedModel: field.value,
isLoading,
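
This and the following field components all switch from per-type RTK Query hooks to the new `modelsByType` hooks, which return a `[modelConfigs, { isLoading }]` tuple of plain arrays. A hypothetical sketch of that hook shape, not the actual implementation in `services/api/hooks/modelsByType`:

```ts
import { useMemo } from 'react';

type ModelConfig = { key: string; name: string; type: string };

// Stub standing in for a shared RTK Query list hook; the real hook
// fetches the model configs from the API.
const useGetModelConfigsQuery = (): { data?: ModelConfig[]; isLoading: boolean } => ({
  data: [],
  isLoading: false,
});

// Each per-type hook could then be a thin memoized filter over the shared list.
const useModelsByType = (type: string): [ModelConfig[], { isLoading: boolean }] => {
  const { data, isLoading } = useGetModelConfigsQuery();
  const modelConfigs = useMemo(() => (data ?? []).filter((m) => m.type === type), [data, type]);
  return [modelConfigs, { isLoading }];
};

export const useControlNetModelsSketch = () => useModelsByType('controlnet');
```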

View File

@ -4,7 +4,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldIPAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { IPAdapterModelFieldInputInstance, IPAdapterModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useGetIPAdapterModelsQuery } from 'services/api/endpoints/models';
import { useIPAdapterModels } from 'services/api/hooks/modelsByType';
import type { IPAdapterModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
@ -14,7 +14,7 @@ const IPAdapterModelFieldInputComponent = (
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const { data: ipAdapterModels } = useGetIPAdapterModelsQuery();
const [modelConfigs, { isLoading }] = useIPAdapterModels();
const _onChange = useCallback(
(value: IPAdapterModelConfig | null) => {
@ -33,9 +33,10 @@ const IPAdapterModelFieldInputComponent = (
);
const { options, value, onChange } = useGroupedModelCombobox({
modelEntities: ipAdapterModels,
modelConfigs,
onChange: _onChange,
selectedModel: field.value,
isLoading,
});
return (

View File

@ -4,7 +4,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldLoRAModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { LoRAModelFieldInputInstance, LoRAModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
import { useLoRAModels } from 'services/api/hooks/modelsByType';
import type { LoRAModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
@ -14,7 +14,7 @@ type Props = FieldComponentProps<LoRAModelFieldInputInstance, LoRAModelFieldInpu
const LoRAModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const { data, isLoading } = useGetLoRAModelsQuery();
const [modelConfigs, { isLoading }] = useLoRAModels();
const _onChange = useCallback(
(value: LoRAModelConfig | null) => {
if (!value) {
@ -32,7 +32,7 @@ const LoRAModelFieldInputComponent = (props: Props) => {
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelEntities: data,
modelConfigs,
onChange: _onChange,
selectedModel: field.value,
isLoading,

View File

@ -5,8 +5,7 @@ import { SyncModelsIconButton } from 'features/modelManagerV2/components/SyncMod
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { MainModelFieldInputInstance, MainModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { NON_SDXL_MAIN_MODELS } from 'services/api/constants';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import { useNonSDXLMainModels } from 'services/api/hooks/modelsByType';
import type { MainModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
@ -16,7 +15,7 @@ type Props = FieldComponentProps<MainModelFieldInputInstance, MainModelFieldInpu
const MainModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const { data, isLoading } = useGetMainModelsQuery(NON_SDXL_MAIN_MODELS);
const [modelConfigs, { isLoading }] = useNonSDXLMainModels();
const _onChange = useCallback(
(value: MainModelConfig | null) => {
if (!value) {
@ -33,7 +32,7 @@ const MainModelFieldInputComponent = (props: Props) => {
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelEntities: data,
modelConfigs,
onChange: _onChange,
isLoading,
selectedModel: field.value,

View File

@ -8,8 +8,7 @@ import type {
SDXLRefinerModelFieldInputTemplate,
} from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { REFINER_BASE_MODELS } from 'services/api/constants';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import { useRefinerModels } from 'services/api/hooks/modelsByType';
import type { MainModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
@ -19,7 +18,7 @@ type Props = FieldComponentProps<SDXLRefinerModelFieldInputInstance, SDXLRefiner
const RefinerModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const { data, isLoading } = useGetMainModelsQuery(REFINER_BASE_MODELS);
const [modelConfigs, { isLoading }] = useRefinerModels();
const _onChange = useCallback(
(value: MainModelConfig | null) => {
if (!value) {
@ -36,7 +35,7 @@ const RefinerModelFieldInputComponent = (props: Props) => {
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelEntities: data,
modelConfigs,
onChange: _onChange,
isLoading,
selectedModel: field.value,

View File

@ -5,8 +5,7 @@ import { SyncModelsIconButton } from 'features/modelManagerV2/components/SyncMod
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { SDXLMainModelFieldInputInstance, SDXLMainModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { SDXL_MAIN_MODELS } from 'services/api/constants';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import { useSDXLModels } from 'services/api/hooks/modelsByType';
import type { MainModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
@ -16,7 +15,7 @@ type Props = FieldComponentProps<SDXLMainModelFieldInputInstance, SDXLMainModelF
const SDXLMainModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const { data, isLoading } = useGetMainModelsQuery(SDXL_MAIN_MODELS);
const [modelConfigs, { isLoading }] = useSDXLModels();
const _onChange = useCallback(
(value: MainModelConfig | null) => {
if (!value) {
@ -33,7 +32,7 @@ const SDXLMainModelFieldInputComponent = (props: Props) => {
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelEntities: data,
modelConfigs,
onChange: _onChange,
isLoading,
selectedModel: field.value,

View File

@ -4,7 +4,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldT2IAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { T2IAdapterModelFieldInputInstance, T2IAdapterModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useGetT2IAdapterModelsQuery } from 'services/api/endpoints/models';
import { useT2IAdapterModels } from 'services/api/hooks/modelsByType';
import type { T2IAdapterModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
@ -15,7 +15,7 @@ const T2IAdapterModelFieldInputComponent = (
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const { data: t2iAdapterModels } = useGetT2IAdapterModelsQuery();
const [modelConfigs, { isLoading }] = useT2IAdapterModels();
const _onChange = useCallback(
(value: T2IAdapterModelConfig | null) => {
@ -34,9 +34,10 @@ const T2IAdapterModelFieldInputComponent = (
);
const { options, value, onChange } = useGroupedModelCombobox({
modelEntities: t2iAdapterModels,
modelConfigs,
onChange: _onChange,
selectedModel: field.value,
isLoading,
});
return (

View File

@ -5,7 +5,7 @@ import { SyncModelsIconButton } from 'features/modelManagerV2/components/SyncMod
import { fieldVaeModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { VAEModelFieldInputInstance, VAEModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
import { useVAEModels } from 'services/api/hooks/modelsByType';
import type { VAEModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
@ -15,7 +15,7 @@ type Props = FieldComponentProps<VAEModelFieldInputInstance, VAEModelFieldInputT
const VAEModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const { data, isLoading } = useGetVaeModelsQuery();
const [modelConfigs, { isLoading }] = useVAEModels();
const _onChange = useCallback(
(value: VAEModelConfig | null) => {
if (!value) {
@ -32,7 +32,7 @@ const VAEModelFieldInputComponent = (props: Props) => {
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelEntities: data,
modelConfigs,
onChange: _onChange,
selectedModel: field.value,
isLoading,

View File

@ -1,16 +1,18 @@
import type { RootState } from 'app/store/store';
import { selectValidControlNets } from 'features/controlAdapters/store/controlAdaptersSlice';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import type { ControlAdapterProcessorType, ControlNetConfig } from 'features/controlAdapters/store/types';
import type { ImageField } from 'features/nodes/types/common';
import type {
CollectInvocation,
ControlNetInvocation,
CoreMetadataInvocation,
NonNullableGraph,
S,
} from 'services/api/types';
import { isControlNetModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
import { CONTROL_NET_COLLECT } from './constants';
import { getModelMetadataField, upsertMetadata } from './metadata';
import { upsertMetadata } from './metadata';
export const addControlNetToLinearGraph = async (
state: RootState,
@ -18,12 +20,14 @@ export const addControlNetToLinearGraph = async (
baseNodeId: string
): Promise<void> => {
const validControlNets = selectValidControlNets(state.controlAdapters).filter(
(ca) => ca.model?.base === state.generation.model?.base
);
({ model, processedControlImage, processorType, controlImage, isEnabled }) => {
const hasModel = Boolean(model);
const doesBaseMatch = model?.base === state.generation.model?.base;
const hasControlImage = (processedControlImage && processorType !== 'none') || controlImage;
// const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
// | MetadataAccumulatorInvocation
// | undefined;
return isEnabled && hasModel && doesBaseMatch && hasControlImage;
}
);
const controlNetMetadata: CoreMetadataInvocation['controlnets'] = [];
@ -43,7 +47,7 @@ export const addControlNetToLinearGraph = async (
},
});
validControlNets.forEach(async (controlNet) => {
for (const controlNet of validControlNets) {
if (!controlNet.model) {
return;
}
@ -70,36 +74,12 @@ export const addControlNetToLinearGraph = async (
resize_mode: resizeMode,
control_model: model,
control_weight: weight,
image: buildControlImage(controlImage, processedControlImage, processorType),
};
if (processedControlImage && processorType !== 'none') {
// We've already processed the image in the app, so we can just use the processed image
controlNetNode.image = {
image_name: processedControlImage,
};
} else if (controlImage) {
// The control image is preprocessed
controlNetNode.image = {
image_name: controlImage,
};
} else {
// Skip ControlNets without an unprocessed image - should never happen if everything is working correctly
return;
}
graph.nodes[controlNetNode.id] = controlNetNode;
graph.nodes[controlNetNode.id] = controlNetNode as ControlNetInvocation;
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isControlNetModelConfig);
controlNetMetadata.push({
control_model: getModelMetadataField(modelConfig),
control_weight: weight,
control_mode: controlMode,
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
resize_mode: resizeMode,
image: controlNetNode.image,
});
controlNetMetadata.push(buildControlNetMetadata(controlNet));
graph.edges.push({
source: { node_id: controlNetNode.id, field: 'control' },
@ -108,7 +88,66 @@ export const addControlNetToLinearGraph = async (
field: 'item',
},
});
});
}
upsertMetadata(graph, { controlnets: controlNetMetadata });
}
};
const buildControlImage = (
controlImage: string | null,
processedControlImage: string | null,
processorType: ControlAdapterProcessorType
): ImageField => {
let image: ImageField | null = null;
if (processedControlImage && processorType !== 'none') {
// We've already processed the image in the app, so we can just use the processed image
image = {
image_name: processedControlImage,
};
} else if (controlImage) {
// The control image is preprocessed
image = {
image_name: controlImage,
};
}
assert(image, 'ControlNet image is required');
return image;
};
const buildControlNetMetadata = (controlNet: ControlNetConfig): S['ControlNetMetadataField'] => {
const {
controlImage,
processedControlImage,
beginStepPct,
endStepPct,
controlMode,
resizeMode,
model,
processorType,
weight,
} = controlNet;
assert(model, 'ControlNet model is required');
const processed_image =
processedControlImage && processorType !== 'none'
? {
image_name: processedControlImage,
}
: null;
assert(controlImage, 'ControlNet image is required');
return {
control_model: model,
control_weight: weight,
control_mode: controlMode,
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
resize_mode: resizeMode,
image: {
image_name: controlImage,
},
processed_image,
};
};
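
The extracted `buildControlImage` helper above centralizes the previous inline if/else: prefer the processed image when a processor ran, fall back to the raw control image, and fail loudly otherwise. The same precedence in a self-contained form, with a plain `throw` standing in for tsafe's `assert`:

```ts
type ImageField = { image_name: string };

const buildControlImageSketch = (
  controlImage: string | null,
  processedControlImage: string | null,
  processorType: string
): ImageField => {
  if (processedControlImage && processorType !== 'none') {
    // the app already ran the processor, so use its output
    return { image_name: processedControlImage };
  }
  if (controlImage) {
    // no processor ran; the control image is used as-is
    return { image_name: controlImage };
  }
  throw new Error('ControlNet image is required');
};

console.log(buildControlImageSketch('raw.png', 'canny.png', 'canny')); // { image_name: 'canny.png' }
console.log(buildControlImageSketch('raw.png', null, 'none')); // { image_name: 'raw.png' }
```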

View File

@ -1,25 +1,30 @@
import type { RootState } from 'app/store/store';
import { selectValidIPAdapters } from 'features/controlAdapters/store/controlAdaptersSlice';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import type { IPAdapterConfig } from 'features/controlAdapters/store/types';
import type { ImageField } from 'features/nodes/types/common';
import type {
CollectInvocation,
CoreMetadataInvocation,
IPAdapterInvocation,
NonNullableGraph,
S,
} from 'services/api/types';
import { isIPAdapterModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
import { IP_ADAPTER_COLLECT } from './constants';
import { getModelMetadataField, upsertMetadata } from './metadata';
import { upsertMetadata } from './metadata';
export const addIPAdapterToLinearGraph = async (
state: RootState,
graph: NonNullableGraph,
baseNodeId: string
): Promise<void> => {
const validIPAdapters = selectValidIPAdapters(state.controlAdapters).filter(
(ca) => ca.model?.base === state.generation.model?.base
);
const validIPAdapters = selectValidIPAdapters(state.controlAdapters).filter(({ model, controlImage, isEnabled }) => {
const hasModel = Boolean(model);
const doesBaseMatch = model?.base === state.generation.model?.base;
const hasControlImage = controlImage;
return isEnabled && hasModel && doesBaseMatch && hasControlImage;
});
if (validIPAdapters.length) {
// Even though denoise_latents' ip adapter input is collection or scalar, keep it simple and always use a collect
@ -39,11 +44,14 @@ export const addIPAdapterToLinearGraph = async (
const ipAdapterMetdata: CoreMetadataInvocation['ipAdapters'] = [];
validIPAdapters.forEach(async (ipAdapter) => {
for (const ipAdapter of validIPAdapters) {
if (!ipAdapter.model) {
return;
}
const { id, weight, model, beginStepPct, endStepPct } = ipAdapter;
const { id, weight, model, beginStepPct, endStepPct, controlImage } = ipAdapter;
assert(controlImage, 'IP Adapter image is required');
const ipAdapterNode: IPAdapterInvocation = {
id: `ip_adapter_${id}`,
type: 'ip_adapter',
@ -52,27 +60,14 @@ export const addIPAdapterToLinearGraph = async (
ip_adapter_model: model,
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
image: {
image_name: controlImage,
},
};
if (ipAdapter.controlImage) {
ipAdapterNode.image = {
image_name: ipAdapter.controlImage,
};
} else {
return;
}
graph.nodes[ipAdapterNode.id] = ipAdapterNode;
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isIPAdapterModelConfig);
ipAdapterMetdata.push({
weight: weight,
ip_adapter_model: getModelMetadataField(modelConfig),
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
image: ipAdapterNode.image,
});
ipAdapterMetdata.push(buildIPAdapterMetadata(ipAdapter));
graph.edges.push({
source: { node_id: ipAdapterNode.id, field: 'ip_adapter' },
@ -81,8 +76,32 @@ export const addIPAdapterToLinearGraph = async (
field: 'item',
},
});
});
}
upsertMetadata(graph, { ipAdapters: ipAdapterMetdata });
}
};
const buildIPAdapterMetadata = (ipAdapter: IPAdapterConfig): S['IPAdapterMetadataField'] => {
const { controlImage, beginStepPct, endStepPct, model, weight } = ipAdapter;
assert(model, 'IP Adapter model is required');
let image: ImageField | null = null;
if (controlImage) {
image = {
image_name: controlImage,
};
}
assert(image, 'IP Adapter image is required');
return {
ip_adapter_model: model,
weight,
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
image,
};
};
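
The rewritten filter above makes the validity rules explicit instead of only comparing bases. The same predicate in a self-contained form, with the config type reduced to the fields the filter reads:

```ts
type AdapterConfig = {
  isEnabled: boolean;
  model: { base: string } | null;
  controlImage: string | null;
};

// An adapter participates in the graph only when it is enabled, has a
// model on the same base as the main model, and has an image to work with.
const selectValidAdapters = (adapters: AdapterConfig[], mainBase: string): AdapterConfig[] =>
  adapters.filter(
    ({ isEnabled, model, controlImage }) =>
      isEnabled && model !== null && model.base === mainBase && Boolean(controlImage)
  );

const valid = selectValidAdapters(
  [{ isEnabled: true, model: { base: 'sdxl' }, controlImage: 'img.png' }],
  'sdxl'
);
console.log(valid.length); // 1
```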

View File

@ -1,16 +1,18 @@
import type { RootState } from 'app/store/store';
import { selectValidT2IAdapters } from 'features/controlAdapters/store/controlAdaptersSlice';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import {
type CollectInvocation,
type CoreMetadataInvocation,
isT2IAdapterModelConfig,
type NonNullableGraph,
type T2IAdapterInvocation,
import type { ControlAdapterProcessorType, T2IAdapterConfig } from 'features/controlAdapters/store/types';
import type { ImageField } from 'features/nodes/types/common';
import type {
CollectInvocation,
CoreMetadataInvocation,
NonNullableGraph,
S,
T2IAdapterInvocation,
} from 'services/api/types';
import { assert } from 'tsafe';
import { T2I_ADAPTER_COLLECT } from './constants';
import { getModelMetadataField, upsertMetadata } from './metadata';
import { upsertMetadata } from './metadata';
export const addT2IAdaptersToLinearGraph = async (
state: RootState,
@ -18,7 +20,13 @@ export const addT2IAdaptersToLinearGraph = async (
baseNodeId: string
): Promise<void> => {
const validT2IAdapters = selectValidT2IAdapters(state.controlAdapters).filter(
(ca) => ca.model?.base === state.generation.model?.base
({ model, processedControlImage, processorType, controlImage, isEnabled }) => {
const hasModel = Boolean(model);
const doesBaseMatch = model?.base === state.generation.model?.base;
const hasControlImage = (processedControlImage && processorType !== 'none') || controlImage;
return isEnabled && hasModel && doesBaseMatch && hasControlImage;
}
);
if (validT2IAdapters.length) {
@ -39,7 +47,7 @@ export const addT2IAdaptersToLinearGraph = async (
const t2iAdapterMetadata: CoreMetadataInvocation['t2iAdapters'] = [];
validT2IAdapters.forEach(async (t2iAdapter) => {
for (const t2iAdapter of validT2IAdapters) {
if (!t2iAdapter.model) {
return;
}
@ -64,35 +72,12 @@ export const addT2IAdaptersToLinearGraph = async (
resize_mode: resizeMode,
t2i_adapter_model: model,
weight: weight,
image: buildControlImage(controlImage, processedControlImage, processorType),
};
if (processedControlImage && processorType !== 'none') {
// We've already processed the image in the app, so we can just use the processed image
t2iAdapterNode.image = {
image_name: processedControlImage,
};
} else if (controlImage) {
// The control image is preprocessed
t2iAdapterNode.image = {
image_name: controlImage,
};
} else {
// Skip ControlNets without an unprocessed image - should never happen if everything is working correctly
return;
}
graph.nodes[t2iAdapterNode.id] = t2iAdapterNode;
const modelConfig = await fetchModelConfigWithTypeGuard(t2iAdapter.model.key, isT2IAdapterModelConfig);
t2iAdapterMetadata.push({
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
resize_mode: resizeMode,
t2i_adapter_model: getModelMetadataField(modelConfig),
weight: weight,
image: t2iAdapterNode.image,
});
t2iAdapterMetadata.push(buildT2IAdapterMetadata(t2iAdapter));
graph.edges.push({
source: { node_id: t2iAdapterNode.id, field: 't2i_adapter' },
@ -101,8 +86,57 @@ export const addT2IAdaptersToLinearGraph = async (
field: 'item',
},
});
});
}
upsertMetadata(graph, { t2iAdapters: t2iAdapterMetadata });
}
};
const buildControlImage = (
controlImage: string | null,
processedControlImage: string | null,
processorType: ControlAdapterProcessorType
): ImageField => {
let image: ImageField | null = null;
if (processedControlImage && processorType !== 'none') {
// We've already processed the image in the app, so we can just use the processed image
image = {
image_name: processedControlImage,
};
} else if (controlImage) {
// The control image is preprocessed
image = {
image_name: controlImage,
};
}
assert(image, 'T2I Adapter image is required');
return image;
};
const buildT2IAdapterMetadata = (t2iAdapter: T2IAdapterConfig): S['T2IAdapterMetadataField'] => {
const { controlImage, processedControlImage, beginStepPct, endStepPct, resizeMode, model, processorType, weight } =
t2iAdapter;
assert(model, 'T2I Adapter model is required');
const processed_image =
processedControlImage && processorType !== 'none'
? {
image_name: processedControlImage,
}
: null;
assert(controlImage, 'T2I Adapter image is required');
return {
t2i_adapter_model: model,
weight,
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
resize_mode: resizeMode,
image: {
image_name: controlImage,
},
processed_image,
};
};
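
All three graph builders replace `validControlNets.forEach(async ...)` / `validIPAdapters.forEach(async ...)` / `validT2IAdapters.forEach(async ...)` with `for...of` loops. The difference matters because `forEach` discards the callback's promise, so any awaited work inside it completes only after the surrounding function has already moved on. A self-contained demonstration:

```ts
const demo = async () => {
  const a: number[] = [];
  [1, 2, 3].forEach(async (n) => {
    await Promise.resolve(); // suspends; forEach does not wait for this
    a.push(n);
  });
  console.log('after forEach:', a.length); // 0 - pushes land in later microtasks

  const b: number[] = [];
  for (const n of [1, 2, 3]) {
    await Promise.resolve(); // suspends, then resumes in order
    b.push(n);
  }
  console.log('after for...of:', b.length); // 3 - all pushes have completed
};

demo();
```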

View File

@ -1,7 +1,5 @@
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import type { ModelMetadataField, NonNullableGraph } from 'services/api/types';
import { isVAEModelConfig } from 'services/api/types';
import type { NonNullableGraph } from 'services/api/types';
import {
CANVAS_IMAGE_TO_IMAGE_GRAPH,
@ -25,7 +23,7 @@ import {
TEXT_TO_IMAGE_GRAPH,
VAE_LOADER,
} from './constants';
import { getModelMetadataField, upsertMetadata } from './metadata';
import { upsertMetadata } from './metadata';
export const addVAEToGraph = async (
state: RootState,
@ -151,8 +149,6 @@ export const addVAEToGraph = async (
}
if (vae) {
const modelConfig = await fetchModelConfigWithTypeGuard(vae.key, isVAEModelConfig);
const vaeMetadata: ModelMetadataField = getModelMetadataField(modelConfig);
upsertMetadata(graph, { vae: vaeMetadata });
upsertMetadata(graph, { vae });
}
};

View File

@ -1,5 +1,6 @@
import type { JSONObject } from 'common/types';
import type { AnyModelConfig, CoreMetadataInvocation, ModelMetadataField, NonNullableGraph } from 'services/api/types';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import type { AnyModelConfig, CoreMetadataInvocation, NonNullableGraph } from 'services/api/types';
import { METADATA } from './constants';
@ -72,7 +73,7 @@ export const setMetadataReceivingNode = (graph: NonNullableGraph, nodeId: string
});
};
export const getModelMetadataField = ({ key, hash, name, base, type }: AnyModelConfig): ModelMetadataField => ({
export const getModelMetadataField = ({ key, hash, name, base, type }: AnyModelConfig): ModelIdentifierField => ({
key,
hash,
name,

View File

@ -8,8 +8,7 @@ import { modelSelected } from 'features/parameters/store/actions';
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import { useMainModels } from 'services/api/hooks/modelsByType';
import type { MainModelConfig } from 'services/api/types';
const selectModel = createMemoizedSelector(selectGenerationSlice, (generation) => generation.model);
@ -18,7 +17,7 @@ const ParamMainModelSelect = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const selectedModel = useAppSelector(selectModel);
const { data, isLoading } = useGetMainModelsQuery(NON_REFINER_BASE_MODELS);
const [modelConfigs, { isLoading }] = useMainModels();
const _onChange = useCallback(
(model: MainModelConfig | null) => {
@ -35,7 +34,7 @@ const ParamMainModelSelect = () => {
);
const { items, selectedItem, onChange, placeholder } = useModelCustomSelect({
data,
modelConfigs,
isLoading,
selectedModel,
onChange: _onChange,
@ -46,7 +45,13 @@ const ParamMainModelSelect = () => {
<InformationalPopover feature="paramModel">
<FormLabel>{t('modelManager.model')}</FormLabel>
</InformationalPopover>
<CustomSelect selectedItem={selectedItem} placeholder={placeholder} items={items} onChange={onChange} />
<CustomSelect
key={items.length}
selectedItem={selectedItem}
placeholder={placeholder}
items={items}
onChange={onChange}
/>
</FormControl>
);
};

View File

@ -7,7 +7,7 @@ import { zModelIdentifierField } from 'features/nodes/types/common';
import { selectGenerationSlice, vaeSelected } from 'features/parameters/store/generationSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
import { useVAEModels } from 'services/api/hooks/modelsByType';
import type { VAEModelConfig } from 'services/api/types';
const selector = createMemoizedSelector(selectGenerationSlice, (generation) => {
@ -19,7 +19,7 @@ const ParamVAEModelSelect = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const { model, vae } = useAppSelector(selector);
const { data, isLoading } = useGetVaeModelsQuery();
const [modelConfigs, { isLoading }] = useVAEModels();
const getIsDisabled = useCallback(
(vae: VAEModelConfig): boolean => {
const isCompatible = model?.base === vae.base;
@ -35,7 +35,7 @@ const ParamVAEModelSelect = () => {
[dispatch]
);
const { options, value, onChange, noOptionsMessage } = useGroupedModelCombobox({
modelEntities: data,
modelConfigs,
onChange: _onChange,
selectedModel: vae,
isLoading,

View File

@ -11,13 +11,8 @@ import { t } from 'i18next';
import { flatten, map } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import {
loraModelsAdapterSelectors,
textualInversionModelsAdapterSelectors,
useGetLoRAModelsQuery,
useGetModelConfigQuery,
useGetTextualInversionModelsQuery,
} from 'services/api/endpoints/models';
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
import { useEmbeddingModels, useLoRAModels } from 'services/api/hooks/modelsByType';
import { isNonRefinerMainModelConfig } from 'services/api/types';
const noOptionsMessage = () => t('prompt.noMatchingTriggers');
@ -33,8 +28,8 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel
const { data: mainModelConfig, isLoading: isLoadingMainModelConfig } = useGetModelConfigQuery(
mainModel?.key ?? skipToken
);
const { data: loraModels, isLoading: isLoadingLoRAs } = useGetLoRAModelsQuery();
const { data: tiModels, isLoading: isLoadingTIs } = useGetTextualInversionModelsQuery();
const [loraModels, { isLoading: isLoadingLoRAs }] = useLoRAModels();
const [tiModels, { isLoading: isLoadingTIs }] = useEmbeddingModels();
const _onChange = useCallback<ComboboxOnChange>(
(v) => {
@ -52,8 +47,7 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel
const _options: GroupBase<ComboboxOption>[] = [];
if (tiModels) {
const embeddingOptions = textualInversionModelsAdapterSelectors
.selectAll(tiModels)
const embeddingOptions = tiModels
.filter((ti) => ti.base === mainModelConfig?.base)
.map((model) => ({ label: model.name, value: `<${model.name}>` }));
@ -66,8 +60,7 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel
}
if (loraModels) {
const triggerPhraseOptions = loraModelsAdapterSelectors
.selectAll(loraModels)
const triggerPhraseOptions = loraModels
.filter((lora) => map(addedLoRAs, (l) => l.model.key).includes(lora.key))
.map((lora) => {
if (lora.trigger_phrases) {

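Because the hooks now hand back plain arrays instead of normalized EntityState caches, the `selectAll` indirection disappears and ordinary array chaining applies. A self-contained sketch of the reshaping done above (types simplified from the real configs):

// Simplified stand-in for TextualInversionModelConfig.
type TIModel = { name: string; base: string };

const buildEmbeddingOptions = (tiModels: TIModel[], base: string) =>
  tiModels
    .filter((ti) => ti.base === base)
    .map((ti) => ({ label: ti.name, value: `<${ti.name}>` }));

// buildEmbeddingOptions([{ name: 'easynegative', base: 'sd-1' }], 'sd-1')
// -> [{ label: 'easynegative', value: '<easynegative>' }]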

@@ -7,8 +7,7 @@ import { zModelIdentifierField } from 'features/nodes/types/common';
import { refinerModelChanged, selectSdxlSlice } from 'features/sdxl/store/sdxlSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { REFINER_BASE_MODELS } from 'services/api/constants';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import { useRefinerModels } from 'services/api/hooks/modelsByType';
import type { MainModelConfig } from 'services/api/types';
const selectModel = createMemoizedSelector(selectSdxlSlice, (sdxl) => sdxl.refinerModel);
@@ -19,7 +18,7 @@ const ParamSDXLRefinerModelSelect = () => {
const dispatch = useAppDispatch();
const model = useAppSelector(selectModel);
const { t } = useTranslation();
const { data, isLoading } = useGetMainModelsQuery(REFINER_BASE_MODELS);
const [modelConfigs, { isLoading }] = useRefinerModels();
const _onChange = useCallback(
(model: MainModelConfig | null) => {
if (!model) {
@@ -31,7 +30,7 @@ const ParamSDXLRefinerModelSelect = () => {
[dispatch]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useModelCombobox({
modelEntities: data,
modelConfigs,
onChange: _onChange,
selectedModel: model,
isLoading,


@@ -1,28 +1,11 @@
import type { EntityAdapter, EntityState, ThunkDispatch, UnknownAction } from '@reduxjs/toolkit';
import type { EntityState } from '@reduxjs/toolkit';
import { createEntityAdapter } from '@reduxjs/toolkit';
import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
import queryString from 'query-string';
import {
ALL_BASE_MODELS,
NON_REFINER_BASE_MODELS,
NON_SDXL_MAIN_MODELS,
REFINER_BASE_MODELS,
SDXL_MAIN_MODELS,
} from 'services/api/constants';
import type { operations, paths } from 'services/api/schema';
import type {
AnyModelConfig,
BaseModelType,
ControlNetModelConfig,
IPAdapterModelConfig,
LoRAModelConfig,
MainModelConfig,
T2IAdapterModelConfig,
TextualInversionModelConfig,
VAEModelConfig,
} from 'services/api/types';
import type { AnyModelConfig } from 'services/api/types';
import type { ApiTagDescription, tagTypes } from '..';
import type { ApiTagDescription } from '..';
import { api, buildV2Url, LIST_TAG } from '..';
export type UpdateModelArg = {
@@ -40,8 +23,9 @@ type UpdateModelImageResponse =
paths['/api/v2/models/i/{key}/image']['patch']['responses']['200']['content']['application/json'];
type GetModelConfigResponse = paths['/api/v2/models/i/{key}']['get']['responses']['200']['content']['application/json'];
type ListModelsArg = NonNullable<paths['/api/v2/models/']['get']['parameters']['query']>;
type GetModelConfigsResponse = NonNullable<
paths['/api/v2/models/']['get']['responses']['200']['content']['application/json']
>;
type DeleteModelArg = {
key: string;
@@ -71,74 +55,16 @@ export type ScanFolderResponse =
paths['/api/v2/models/scan_folder']['get']['responses']['200']['content']['application/json'];
type ScanFolderArg = operations['scan_for_models']['parameters']['query'];
type GetHuggingFaceModelsResponse =
paths['/api/v2/models/hugging_face']['get']['responses']['200']['content']['application/json'];
type GetByAttrsArg = operations['get_model_records_by_attrs']['parameters']['query'];
const mainModelsAdapter = createEntityAdapter<MainModelConfig, string>({
const modelConfigsAdapter = createEntityAdapter<AnyModelConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const mainModelsAdapterSelectors = mainModelsAdapter.getSelectors(undefined, getSelectorsOptions);
const loraModelsAdapter = createEntityAdapter<LoRAModelConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const loraModelsAdapterSelectors = loraModelsAdapter.getSelectors(undefined, getSelectorsOptions);
const controlNetModelsAdapter = createEntityAdapter<ControlNetModelConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const controlNetModelsAdapterSelectors = controlNetModelsAdapter.getSelectors(undefined, getSelectorsOptions);
const ipAdapterModelsAdapter = createEntityAdapter<IPAdapterModelConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const ipAdapterModelsAdapterSelectors = ipAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions);
const t2iAdapterModelsAdapter = createEntityAdapter<T2IAdapterModelConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const t2iAdapterModelsAdapterSelectors = t2iAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions);
const textualInversionModelsAdapter = createEntityAdapter<TextualInversionModelConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const textualInversionModelsAdapterSelectors = textualInversionModelsAdapter.getSelectors(
undefined,
getSelectorsOptions
);
const vaeModelsAdapter = createEntityAdapter<VAEModelConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const vaeModelsAdapterSelectors = vaeModelsAdapter.getSelectors(undefined, getSelectorsOptions);
const anyModelConfigAdapter = createEntityAdapter<AnyModelConfig, string>({
selectId: (entity) => entity.key,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
const anyModelConfigAdapterSelectors = anyModelConfigAdapter.getSelectors(undefined, getSelectorsOptions);
const buildProvidesTags =
<TEntity extends AnyModelConfig>(tagType: (typeof tagTypes)[number]) =>
(result: EntityState<TEntity, string> | undefined) => {
const tags: ApiTagDescription[] = [{ type: tagType, id: LIST_TAG }, 'Model'];
if (result) {
tags.push(
...result.ids.map((id) => ({
type: tagType,
id,
}))
);
}
return tags;
};
const buildTransformResponse =
<T extends AnyModelConfig>(adapter: EntityAdapter<T, string>) =>
(response: { models: T[] }) => {
return adapter.setAll(adapter.getInitialState(), response.models);
};
export const modelConfigsAdapterSelectors = modelConfigsAdapter.getSelectors(undefined, getSelectorsOptions);
/**
* Builds an endpoint URL for the models router
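The per-type entity adapters collapse into the single `modelConfigsAdapter` above. A standalone sketch of what such an adapter does with a response, per the @reduxjs/toolkit API:

import { createEntityAdapter } from '@reduxjs/toolkit';

type ModelLike = { key: string; name: string };

const adapter = createEntityAdapter<ModelLike, string>({
  selectId: (m) => m.key,
  sortComparer: (a, b) => a.name.localeCompare(b.name),
});

// setAll normalizes a list into { ids, entities }, keyed by selectId
// and with ids ordered by sortComparer.
const state = adapter.setAll(adapter.getInitialState(), [
  { key: 'b', name: 'beta' },
  { key: 'a', name: 'alpha' },
]);

const selectors = adapter.getSelectors();
const all = selectors.selectAll(state); // [alpha, beta] - denormalized, sorted
const beta = selectors.selectById(state, 'b'); // { key: 'b', name: 'beta' }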
@@ -159,9 +85,27 @@ export const modelsApi = api.injectEndpoints({
};
},
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
queryFulfilled.then(({ data }) => {
upsertSingleModelConfig(data, dispatch);
});
try {
const { data } = await queryFulfilled;
// Update the individual model query caches
dispatch(modelsApi.util.upsertQueryData('getModelConfig', data.key, data));
const { base, name, type } = data;
dispatch(modelsApi.util.upsertQueryData('getModelConfigByAttrs', { base, name, type }, data));
// Update the list query cache
dispatch(
modelsApi.util.updateQueryData('getModelConfigs', undefined, (draft) => {
modelConfigsAdapter.updateOne(draft, {
id: data.key,
changes: data,
});
})
);
} catch {
// no-op
}
},
}),
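The rewritten onQueryStarted above uses two different cache utilities: `upsertQueryData` creates or fully replaces the cache entry for a specific endpoint-plus-argument pair, while `updateQueryData` runs an Immer recipe against an already-cached entry and is a no-op when nothing is cached for that argument. A condensed restatement of the calls above, with identifiers as in this diff:

// Create-or-replace: seeds the single-model caches even if those
// queries have never run.
dispatch(modelsApi.util.upsertQueryData('getModelConfig', data.key, data));

// Patch-in-place: edits the existing getModelConfigs list cache via an
// Immer draft; does nothing if the list was never fetched.
dispatch(
  modelsApi.util.updateQueryData('getModelConfigs', undefined, (draft) => {
    modelConfigsAdapter.updateOne(draft, { id: data.key, changes: data });
  })
);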
updateModelImage: build.mutation<UpdateModelImageResponse, UpdateModelImageArg>({
@@ -258,6 +202,13 @@ export const modelsApi = api.injectEndpoints({
};
},
}),
getHuggingFaceModels: build.query<GetHuggingFaceModelsResponse, string>({
query: (hugging_face_repo) => {
return {
url: buildModelsUrl(`hugging_face?hugging_face_repo=${hugging_face_repo}`),
};
},
}),
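One hedged aside on the endpoint above: the repo id is interpolated into the URL by hand, so a value with characters beyond the usual owner/repo shape would not be URL-encoded. An alternative sketch (not what this commit does) that lets fetchBaseQuery serialize and encode the query string:

// Alternative only - this diff builds the query string manually.
getHuggingFaceModels: build.query<GetHuggingFaceModelsResponse, string>({
  query: (hugging_face_repo) => ({
    url: buildModelsUrl('hugging_face'),
    params: { hugging_face_repo },
  }),
}),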
listModelInstalls: build.query<ListModelInstallsResponse, void>({
query: () => {
return {
@@ -284,80 +235,27 @@ export const modelsApi = api.injectEndpoints({
},
invalidatesTags: ['ModelInstalls'],
}),
getMainModels: build.query<EntityState<MainModelConfig, string>, BaseModelType[]>({
query: (base_models) => {
const params: ListModelsArg = {
model_type: 'main',
base_models,
};
const query = queryString.stringify(params, { arrayFormat: 'none' });
return buildModelsUrl(`?${query}`);
getModelConfigs: build.query<EntityState<AnyModelConfig, string>, void>({
query: () => ({ url: buildModelsUrl() }),
providesTags: (result) => {
const tags: ApiTagDescription[] = [{ type: 'ModelConfig', id: LIST_TAG }];
if (result) {
const modelTags = result.ids.map((id) => ({ type: 'ModelConfig', id }) as const);
tags.push(...modelTags);
}
return tags;
},
keepUnusedDataFor: 60 * 60 * 1000 * 24, // keepUnusedDataFor is in seconds, so this is ~1,000 days - effectively infinite
transformResponse: (response: GetModelConfigsResponse) => {
return modelConfigsAdapter.setAll(modelConfigsAdapter.getInitialState(), response.models);
},
providesTags: buildProvidesTags<MainModelConfig>('MainModel'),
transformResponse: buildTransformResponse<MainModelConfig>(mainModelsAdapter),
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
queryFulfilled.then(({ data }) => {
upsertModelConfigs(data, dispatch);
});
},
}),
getLoRAModels: build.query<EntityState<LoRAModelConfig, string>, void>({
query: () => ({ url: buildModelsUrl(), params: { model_type: 'lora' } }),
providesTags: buildProvidesTags<LoRAModelConfig>('LoRAModel'),
transformResponse: buildTransformResponse<LoRAModelConfig>(loraModelsAdapter),
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
queryFulfilled.then(({ data }) => {
upsertModelConfigs(data, dispatch);
});
},
}),
getControlNetModels: build.query<EntityState<ControlNetModelConfig, string>, void>({
query: () => ({ url: buildModelsUrl(), params: { model_type: 'controlnet' } }),
providesTags: buildProvidesTags<ControlNetModelConfig>('ControlNetModel'),
transformResponse: buildTransformResponse<ControlNetModelConfig>(controlNetModelsAdapter),
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
queryFulfilled.then(({ data }) => {
upsertModelConfigs(data, dispatch);
});
},
}),
getIPAdapterModels: build.query<EntityState<IPAdapterModelConfig, string>, void>({
query: () => ({ url: buildModelsUrl(), params: { model_type: 'ip_adapter' } }),
providesTags: buildProvidesTags<IPAdapterModelConfig>('IPAdapterModel'),
transformResponse: buildTransformResponse<IPAdapterModelConfig>(ipAdapterModelsAdapter),
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
queryFulfilled.then(({ data }) => {
upsertModelConfigs(data, dispatch);
});
},
}),
getT2IAdapterModels: build.query<EntityState<T2IAdapterModelConfig, string>, void>({
query: () => ({ url: buildModelsUrl(), params: { model_type: 't2i_adapter' } }),
providesTags: buildProvidesTags<T2IAdapterModelConfig>('T2IAdapterModel'),
transformResponse: buildTransformResponse<T2IAdapterModelConfig>(t2iAdapterModelsAdapter),
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
queryFulfilled.then(({ data }) => {
upsertModelConfigs(data, dispatch);
});
},
}),
getVaeModels: build.query<EntityState<VAEModelConfig, string>, void>({
query: () => ({ url: buildModelsUrl(), params: { model_type: 'vae' } }),
providesTags: buildProvidesTags<VAEModelConfig>('VaeModel'),
transformResponse: buildTransformResponse<VAEModelConfig>(vaeModelsAdapter),
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
queryFulfilled.then(({ data }) => {
upsertModelConfigs(data, dispatch);
});
},
}),
getTextualInversionModels: build.query<EntityState<TextualInversionModelConfig, string>, void>({
query: () => ({ url: buildModelsUrl(), params: { model_type: 'embedding' } }),
providesTags: buildProvidesTags<TextualInversionModelConfig>('TextualInversionModel'),
transformResponse: buildTransformResponse<TextualInversionModelConfig>(textualInversionModelsAdapter),
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
queryFulfilled.then(({ data }) => {
upsertModelConfigs(data, dispatch);
modelConfigsAdapterSelectors.selectAll(data).forEach((modelConfig) => {
dispatch(modelsApi.util.upsertQueryData('getModelConfig', modelConfig.key, modelConfig));
const { base, name, type } = modelConfig;
dispatch(modelsApi.util.upsertQueryData('getModelConfigByAttrs', { base, name, type }, modelConfig));
});
});
},
}),
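The `providesTags` in getModelConfigs above registers one 'ModelConfig' tag per model id plus a LIST_TAG entry, so mutations can target either a single model or the whole list. A sketch of the matching invalidation side, assuming a hypothetical mutation whose argument carries the model `key` (not shown in this diff):

// Hypothetical pairing - invalidating the id refetches queries that
// provided that model; invalidating LIST_TAG refetches the whole list.
invalidatesTags: (result, error, { key }) => [
  { type: 'ModelConfig', id: key },
  { type: 'ModelConfig', id: LIST_TAG },
],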
@@ -365,14 +263,8 @@ export const modelsApi = api.injectEndpoints({
});
export const {
useGetModelConfigsQuery,
useGetModelConfigQuery,
useGetMainModelsQuery,
useGetControlNetModelsQuery,
useGetIPAdapterModelsQuery,
useGetT2IAdapterModelsQuery,
useGetLoRAModelsQuery,
useGetTextualInversionModelsQuery,
useGetVaeModelsQuery,
useDeleteModelsMutation,
useDeleteModelImageMutation,
useUpdateModelMutation,
@@ -381,131 +273,8 @@ export const {
useConvertModelMutation,
useSyncModelsMutation,
useLazyScanFolderQuery,
useLazyGetHuggingFaceModelsQuery,
useListModelInstallsQuery,
useCancelModelInstallMutation,
usePruneCompletedModelInstallsMutation,
} = modelsApi;
const upsertModelConfigs = (
modelConfigs: EntityState<AnyModelConfig, string>,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
dispatch: ThunkDispatch<any, any, UnknownAction>
) => {
/**
* Once a list of models of a specific type is received, fetching any of those models individually is a waste of a
* network request. This function takes the received list of models and upserts them into the individual query caches
* for each model type.
*/
// Iterate over all the models and upsert them into the individual query caches for each model type.
anyModelConfigAdapterSelectors.selectAll(modelConfigs).forEach((modelConfig) => {
dispatch(modelsApi.util.upsertQueryData('getModelConfig', modelConfig.key, modelConfig));
const { base, name, type } = modelConfig;
dispatch(modelsApi.util.upsertQueryData('getModelConfigByAttrs', { base, name, type }, modelConfig));
});
};
const upsertSingleModelConfig = (
modelConfig: AnyModelConfig,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
dispatch: ThunkDispatch<any, any, UnknownAction>
) => {
/**
* When a model is updated, the individual query caches for each model type need to be updated, as well as the list
* query caches of models of that type.
*/
// Update the individual model query caches.
dispatch(modelsApi.util.upsertQueryData('getModelConfig', modelConfig.key, modelConfig));
const { base, name, type } = modelConfig;
dispatch(modelsApi.util.upsertQueryData('getModelConfigByAttrs', { base, name, type }, modelConfig));
// Update the list query caches for each model type.
if (modelConfig.type === 'main') {
[ALL_BASE_MODELS, NON_REFINER_BASE_MODELS, SDXL_MAIN_MODELS, NON_SDXL_MAIN_MODELS, REFINER_BASE_MODELS].forEach(
(queryArg) => {
dispatch(
modelsApi.util.updateQueryData('getMainModels', queryArg, (draft) => {
mainModelsAdapter.updateOne(draft, {
id: modelConfig.key,
changes: modelConfig,
});
})
);
}
);
return;
}
if (modelConfig.type === 'controlnet') {
dispatch(
modelsApi.util.updateQueryData('getControlNetModels', undefined, (draft) => {
controlNetModelsAdapter.updateOne(draft, {
id: modelConfig.key,
changes: modelConfig,
});
})
);
return;
}
if (modelConfig.type === 'embedding') {
dispatch(
modelsApi.util.updateQueryData('getTextualInversionModels', undefined, (draft) => {
textualInversionModelsAdapter.updateOne(draft, {
id: modelConfig.key,
changes: modelConfig,
});
})
);
return;
}
if (modelConfig.type === 'ip_adapter') {
dispatch(
modelsApi.util.updateQueryData('getIPAdapterModels', undefined, (draft) => {
ipAdapterModelsAdapter.updateOne(draft, {
id: modelConfig.key,
changes: modelConfig,
});
})
);
return;
}
if (modelConfig.type === 'lora') {
dispatch(
modelsApi.util.updateQueryData('getLoRAModels', undefined, (draft) => {
loraModelsAdapter.updateOne(draft, {
id: modelConfig.key,
changes: modelConfig,
});
})
);
return;
}
if (modelConfig.type === 't2i_adapter') {
dispatch(
modelsApi.util.updateQueryData('getT2IAdapterModels', undefined, (draft) => {
t2iAdapterModelsAdapter.updateOne(draft, {
id: modelConfig.key,
changes: modelConfig,
});
})
);
return;
}
if (modelConfig.type === 'vae') {
dispatch(
modelsApi.util.updateQueryData('getVaeModels', undefined, (draft) => {
vaeModelsAdapter.updateOne(draft, {
id: modelConfig.key,
changes: modelConfig,
});
})
);
return;
}
};


@@ -0,0 +1,42 @@
import { EMPTY_ARRAY } from 'app/store/constants';
import { useMemo } from 'react';
import { modelConfigsAdapterSelectors, useGetModelConfigsQuery } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
import {
isControlNetModelConfig,
isIPAdapterModelConfig,
isLoRAModelConfig,
isNonRefinerMainModelConfig,
isNonSDXLMainModelConfig,
isRefinerMainModelModelConfig,
isSDXLMainModelModelConfig,
isT2IAdapterModelConfig,
isTIModelConfig,
isVAEModelConfig,
} from 'services/api/types';
const buildModelsHook =
<T extends AnyModelConfig>(typeGuard: (config: AnyModelConfig) => config is T) =>
() => {
const result = useGetModelConfigsQuery(undefined);
const modelConfigs = useMemo(() => {
if (!result.data) {
return EMPTY_ARRAY;
}
return modelConfigsAdapterSelectors.selectAll(result.data).filter(typeGuard);
}, [result]);
return [modelConfigs, result] as const;
};
export const useMainModels = buildModelsHook(isNonRefinerMainModelConfig);
export const useNonSDXLMainModels = buildModelsHook(isNonSDXLMainModelConfig);
export const useRefinerModels = buildModelsHook(isRefinerMainModelModelConfig);
export const useSDXLModels = buildModelsHook(isSDXLMainModelModelConfig);
export const useLoRAModels = buildModelsHook(isLoRAModelConfig);
export const useControlNetModels = buildModelsHook(isControlNetModelConfig);
export const useT2IAdapterModels = buildModelsHook(isT2IAdapterModelConfig);
export const useIPAdapterModels = buildModelsHook(isIPAdapterModelConfig);
export const useEmbeddingModels = buildModelsHook(isTIModelConfig);
export const useVAEModels = buildModelsHook(isVAEModelConfig);
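`buildModelsHook` pairs the single getModelConfigs cache with a type guard, so every per-type hook above is one line and a single network request serves all of them. The guards themselves are imported from `services/api/types` and are not shown in this diff; presumably they discriminate on the config's `type` (and `base` for the SDXL-specific variants), along these lines:

// Assumed shapes - the real guards live in services/api/types.
import type { AnyModelConfig, LoRAModelConfig, MainModelConfig } from 'services/api/types';

const isLoRAModelConfig = (config: AnyModelConfig): config is LoRAModelConfig =>
  config.type === 'lora';

const isRefinerMainModelModelConfig = (config: AnyModelConfig): config is MainModelConfig =>
  config.type === 'main' && config.base === 'sdxl-refiner';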


@@ -1,12 +1,7 @@
import { REFINER_BASE_MODELS } from 'services/api/constants';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import { useRefinerModels } from 'services/api/hooks/modelsByType';
export const useIsRefinerAvailable = () => {
const { isRefinerAvailable } = useGetMainModelsQuery(REFINER_BASE_MODELS, {
selectFromResult: ({ data }) => ({
isRefinerAvailable: data ? data.ids.length > 0 : false,
}),
});
const [refinerModels] = useRefinerModels();
return isRefinerAvailable;
return Boolean(refinerModels.length);
};

File diff suppressed because one or more lines are too long

Some files were not shown because too many files have changed in this diff.