mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Compare commits
108 Commits
v4.0.0rc1
...
bug-instal
Author | SHA1 | Date | |
---|---|---|---|
7d46e8430b | |||
b3f8f22998 | |||
6a75c5ba08 | |||
dcbb1ff894 | |||
5751455618 | |||
a2232c2e09 | |||
30da11998b | |||
1f000306f3 | |||
c778d74a42 | |||
97bcd54408 | |||
62d7e38030 | |||
9b6f3bded9 | |||
aec567179d | |||
28bfc1c935 | |||
39f62ac63c | |||
9d0952c2ef | |||
902e26507d | |||
b83427d7ce | |||
7387b0bdc9 | |||
7ea9cac9a3 | |||
ea5bc94b9c | |||
a1743647b7 | |||
a6d64f69e1 | |||
e74e78894f | |||
71a1740740 | |||
b79f2f337e | |||
a0420d1442 | |||
a17021ba0c | |||
faa1ffb06f | |||
8c04eec210 | |||
330e1354b4 | |||
21621eebf0 | |||
c24f2046e7 | |||
297408d67e | |||
0131e7d928 | |||
06ff105a1f | |||
bb8e6bbee6 | |||
328dc99f3a | |||
ef55077e84 | |||
ba3d8af161 | |||
b07b7af710 | |||
19d66d5ec7 | |||
ed20255abf | |||
fed1f983db | |||
a386544a1d | |||
0851de9090 | |||
1bd8e33f8c | |||
e3f29ed320 | |||
3fd824306c | |||
2584a950aa | |||
1adaf63253 | |||
b9f1a4bd65 | |||
731942dbed | |||
4117cea5bf | |||
21617f3bc1 | |||
9fcd67b5c0 | |||
a4be935458 | |||
eb6e6548ed | |||
8287fcf097 | |||
dd475e28ed | |||
24e741e2d1 | |||
e0bf9ce5c6 | |||
c66e8b395e | |||
4c417adc82 | |||
437a413ca3 | |||
4492bedd19 | |||
db12ce95a8 | |||
ee3a1a95ef | |||
4bb5aba70e | |||
cd55c23713 | |||
1d2743af1b | |||
99d2099ccd | |||
b64a693f16 | |||
9d523a3094 | |||
af660163ca | |||
7e4b462fca | |||
4468dd6948 | |||
4f39e248dd | |||
44b3e5d43f | |||
8894a9e48a | |||
c73f58e486 | |||
614fece147 | |||
8ef8082d65 | |||
d93d4afbb7 | |||
01207a2fa5 | |||
d0800c4888 | |||
2a300ecada | |||
90340a39c7 | |||
ee77abb4fe | |||
004bca5c42 | |||
5ad048a161 | |||
6369ccd05e | |||
3a5314f1ca | |||
4c0896e436 | |||
f7cd3cf1f4 | |||
efea1a8a7d | |||
d0d695c020 | |||
2a648da557 | |||
54f1a1f952 | |||
8d2a4db902 | |||
7b393656de | |||
43948e0758 | |||
cc03fcbcb6 | |||
d1e445fa49 | |||
adba8489f2 | |||
d919022ba5 | |||
e076898798 | |||
9f19b766a4 |
28
.github/workflows/frontend-checks.yml
vendored
28
.github/workflows/frontend-checks.yml
vendored
@ -1,7 +1,7 @@
|
|||||||
# Runs frontend code quality checks.
|
# Runs frontend code quality checks.
|
||||||
#
|
#
|
||||||
# Checks for changes to frontend files before running the checks.
|
# Checks for changes to frontend files before running the checks.
|
||||||
# When manually triggered or when called from another workflow, always runs the checks.
|
# If always_run is true, always runs the checks.
|
||||||
|
|
||||||
name: 'frontend checks'
|
name: 'frontend checks'
|
||||||
|
|
||||||
@ -16,7 +16,19 @@ on:
|
|||||||
- 'synchronize'
|
- 'synchronize'
|
||||||
merge_group:
|
merge_group:
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
|
inputs:
|
||||||
|
always_run:
|
||||||
|
description: 'Always run the checks'
|
||||||
|
required: true
|
||||||
|
type: boolean
|
||||||
|
default: true
|
||||||
workflow_call:
|
workflow_call:
|
||||||
|
inputs:
|
||||||
|
always_run:
|
||||||
|
description: 'Always run the checks'
|
||||||
|
required: true
|
||||||
|
type: boolean
|
||||||
|
default: true
|
||||||
|
|
||||||
defaults:
|
defaults:
|
||||||
run:
|
run:
|
||||||
@ -30,7 +42,7 @@ jobs:
|
|||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: check for changed frontend files
|
- name: check for changed frontend files
|
||||||
if: ${{ github.event_name != 'workflow_dispatch' && github.event_name != 'workflow_call' }}
|
if: ${{ inputs.always_run != true }}
|
||||||
id: changed-files
|
id: changed-files
|
||||||
uses: tj-actions/changed-files@v42
|
uses: tj-actions/changed-files@v42
|
||||||
with:
|
with:
|
||||||
@ -39,30 +51,30 @@ jobs:
|
|||||||
- 'invokeai/frontend/web/**'
|
- 'invokeai/frontend/web/**'
|
||||||
|
|
||||||
- name: install dependencies
|
- name: install dependencies
|
||||||
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
|
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || inputs.always_run == true }}
|
||||||
uses: ./.github/actions/install-frontend-deps
|
uses: ./.github/actions/install-frontend-deps
|
||||||
|
|
||||||
- name: tsc
|
- name: tsc
|
||||||
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
|
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || inputs.always_run == true }}
|
||||||
run: 'pnpm lint:tsc'
|
run: 'pnpm lint:tsc'
|
||||||
shell: bash
|
shell: bash
|
||||||
|
|
||||||
- name: dpdm
|
- name: dpdm
|
||||||
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
|
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || inputs.always_run == true }}
|
||||||
run: 'pnpm lint:dpdm'
|
run: 'pnpm lint:dpdm'
|
||||||
shell: bash
|
shell: bash
|
||||||
|
|
||||||
- name: eslint
|
- name: eslint
|
||||||
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
|
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || inputs.always_run == true }}
|
||||||
run: 'pnpm lint:eslint'
|
run: 'pnpm lint:eslint'
|
||||||
shell: bash
|
shell: bash
|
||||||
|
|
||||||
- name: prettier
|
- name: prettier
|
||||||
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
|
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || inputs.always_run == true }}
|
||||||
run: 'pnpm lint:prettier'
|
run: 'pnpm lint:prettier'
|
||||||
shell: bash
|
shell: bash
|
||||||
|
|
||||||
- name: knip
|
- name: knip
|
||||||
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
|
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || inputs.always_run == true }}
|
||||||
run: 'pnpm lint:knip'
|
run: 'pnpm lint:knip'
|
||||||
shell: bash
|
shell: bash
|
||||||
|
20
.github/workflows/frontend-tests.yml
vendored
20
.github/workflows/frontend-tests.yml
vendored
@ -1,7 +1,7 @@
|
|||||||
# Runs frontend tests.
|
# Runs frontend tests.
|
||||||
#
|
#
|
||||||
# Checks for changes to frontend files before running the tests.
|
# Checks for changes to frontend files before running the tests.
|
||||||
# When manually triggered or called from another workflow, always runs the tests.
|
# If always_run is true, always runs the tests.
|
||||||
|
|
||||||
name: 'frontend tests'
|
name: 'frontend tests'
|
||||||
|
|
||||||
@ -16,7 +16,19 @@ on:
|
|||||||
- 'synchronize'
|
- 'synchronize'
|
||||||
merge_group:
|
merge_group:
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
|
inputs:
|
||||||
|
always_run:
|
||||||
|
description: 'Always run the tests'
|
||||||
|
required: true
|
||||||
|
type: boolean
|
||||||
|
default: true
|
||||||
workflow_call:
|
workflow_call:
|
||||||
|
inputs:
|
||||||
|
always_run:
|
||||||
|
description: 'Always run the tests'
|
||||||
|
required: true
|
||||||
|
type: boolean
|
||||||
|
default: true
|
||||||
|
|
||||||
defaults:
|
defaults:
|
||||||
run:
|
run:
|
||||||
@ -30,7 +42,7 @@ jobs:
|
|||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: check for changed frontend files
|
- name: check for changed frontend files
|
||||||
if: ${{ github.event_name != 'workflow_dispatch' && github.event_name != 'workflow_call' }}
|
if: ${{ inputs.always_run != true }}
|
||||||
id: changed-files
|
id: changed-files
|
||||||
uses: tj-actions/changed-files@v42
|
uses: tj-actions/changed-files@v42
|
||||||
with:
|
with:
|
||||||
@ -39,10 +51,10 @@ jobs:
|
|||||||
- 'invokeai/frontend/web/**'
|
- 'invokeai/frontend/web/**'
|
||||||
|
|
||||||
- name: install dependencies
|
- name: install dependencies
|
||||||
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
|
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || inputs.always_run == true }}
|
||||||
uses: ./.github/actions/install-frontend-deps
|
uses: ./.github/actions/install-frontend-deps
|
||||||
|
|
||||||
- name: vitest
|
- name: vitest
|
||||||
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
|
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || inputs.always_run == true }}
|
||||||
run: 'pnpm test:no-watch'
|
run: 'pnpm test:no-watch'
|
||||||
shell: bash
|
shell: bash
|
||||||
|
24
.github/workflows/python-checks.yml
vendored
24
.github/workflows/python-checks.yml
vendored
@ -1,7 +1,7 @@
|
|||||||
# Runs python code quality checks.
|
# Runs python code quality checks.
|
||||||
#
|
#
|
||||||
# Checks for changes to python files before running the checks.
|
# Checks for changes to python files before running the checks.
|
||||||
# When manually triggered or called from another workflow, always runs the tests.
|
# If always_run is true, always runs the checks.
|
||||||
#
|
#
|
||||||
# TODO: Add mypy or pyright to the checks.
|
# TODO: Add mypy or pyright to the checks.
|
||||||
|
|
||||||
@ -18,7 +18,19 @@ on:
|
|||||||
- 'synchronize'
|
- 'synchronize'
|
||||||
merge_group:
|
merge_group:
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
|
inputs:
|
||||||
|
always_run:
|
||||||
|
description: 'Always run the checks'
|
||||||
|
required: true
|
||||||
|
type: boolean
|
||||||
|
default: true
|
||||||
workflow_call:
|
workflow_call:
|
||||||
|
inputs:
|
||||||
|
always_run:
|
||||||
|
description: 'Always run the checks'
|
||||||
|
required: true
|
||||||
|
type: boolean
|
||||||
|
default: true
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
python-checks:
|
python-checks:
|
||||||
@ -29,7 +41,7 @@ jobs:
|
|||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: check for changed python files
|
- name: check for changed python files
|
||||||
if: ${{ github.event_name != 'workflow_dispatch' && github.event_name != 'workflow_call' }}
|
if: ${{ inputs.always_run != true }}
|
||||||
id: changed-files
|
id: changed-files
|
||||||
uses: tj-actions/changed-files@v42
|
uses: tj-actions/changed-files@v42
|
||||||
with:
|
with:
|
||||||
@ -41,7 +53,7 @@ jobs:
|
|||||||
- 'tests/**'
|
- 'tests/**'
|
||||||
|
|
||||||
- name: setup python
|
- name: setup python
|
||||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
|
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
|
||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: '3.10'
|
python-version: '3.10'
|
||||||
@ -49,16 +61,16 @@ jobs:
|
|||||||
cache-dependency-path: pyproject.toml
|
cache-dependency-path: pyproject.toml
|
||||||
|
|
||||||
- name: install ruff
|
- name: install ruff
|
||||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
|
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
|
||||||
run: pip install ruff
|
run: pip install ruff
|
||||||
shell: bash
|
shell: bash
|
||||||
|
|
||||||
- name: ruff check
|
- name: ruff check
|
||||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
|
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
|
||||||
run: ruff check --output-format=github .
|
run: ruff check --output-format=github .
|
||||||
shell: bash
|
shell: bash
|
||||||
|
|
||||||
- name: ruff format
|
- name: ruff format
|
||||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
|
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
|
||||||
run: ruff format --check .
|
run: ruff format --check .
|
||||||
shell: bash
|
shell: bash
|
||||||
|
23
.github/workflows/python-tests.yml
vendored
23
.github/workflows/python-tests.yml
vendored
@ -1,7 +1,7 @@
|
|||||||
# Runs python tests on a matrix of python versions and platforms.
|
# Runs python tests on a matrix of python versions and platforms.
|
||||||
#
|
#
|
||||||
# Checks for changes to python files before running the tests.
|
# Checks for changes to python files before running the tests.
|
||||||
# When manually triggered or called from another workflow, always runs the tests.
|
# If always_run is true, always runs the tests.
|
||||||
|
|
||||||
name: 'python tests'
|
name: 'python tests'
|
||||||
|
|
||||||
@ -9,6 +9,7 @@ on:
|
|||||||
push:
|
push:
|
||||||
branches:
|
branches:
|
||||||
- 'main'
|
- 'main'
|
||||||
|
- 'bug-install-job-running-multiple-times'
|
||||||
pull_request:
|
pull_request:
|
||||||
types:
|
types:
|
||||||
- 'ready_for_review'
|
- 'ready_for_review'
|
||||||
@ -16,7 +17,19 @@ on:
|
|||||||
- 'synchronize'
|
- 'synchronize'
|
||||||
merge_group:
|
merge_group:
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
|
inputs:
|
||||||
|
always_run:
|
||||||
|
description: 'Always run the tests'
|
||||||
|
required: true
|
||||||
|
type: boolean
|
||||||
|
default: true
|
||||||
workflow_call:
|
workflow_call:
|
||||||
|
inputs:
|
||||||
|
always_run:
|
||||||
|
description: 'Always run the tests'
|
||||||
|
required: true
|
||||||
|
type: boolean
|
||||||
|
default: true
|
||||||
|
|
||||||
concurrency:
|
concurrency:
|
||||||
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
|
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
|
||||||
@ -63,7 +76,7 @@ jobs:
|
|||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: check for changed python files
|
- name: check for changed python files
|
||||||
if: ${{ github.event_name != 'workflow_dispatch' && github.event_name != 'workflow_call' }}
|
if: ${{ inputs.always_run != true }}
|
||||||
id: changed-files
|
id: changed-files
|
||||||
uses: tj-actions/changed-files@v42
|
uses: tj-actions/changed-files@v42
|
||||||
with:
|
with:
|
||||||
@ -75,7 +88,7 @@ jobs:
|
|||||||
- 'tests/**'
|
- 'tests/**'
|
||||||
|
|
||||||
- name: setup python
|
- name: setup python
|
||||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
|
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
|
||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: ${{ matrix.python-version }}
|
python-version: ${{ matrix.python-version }}
|
||||||
@ -83,12 +96,12 @@ jobs:
|
|||||||
cache-dependency-path: pyproject.toml
|
cache-dependency-path: pyproject.toml
|
||||||
|
|
||||||
- name: install dependencies
|
- name: install dependencies
|
||||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
|
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
|
||||||
env:
|
env:
|
||||||
PIP_EXTRA_INDEX_URL: ${{ matrix.extra-index-url }}
|
PIP_EXTRA_INDEX_URL: ${{ matrix.extra-index-url }}
|
||||||
run: >
|
run: >
|
||||||
pip3 install --editable=".[test]"
|
pip3 install --editable=".[test]"
|
||||||
|
|
||||||
- name: run pytest
|
- name: run pytest
|
||||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
|
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
|
||||||
run: pytest
|
run: pytest
|
||||||
|
8
.github/workflows/release.yml
vendored
8
.github/workflows/release.yml
vendored
@ -30,15 +30,23 @@ jobs:
|
|||||||
|
|
||||||
frontend-checks:
|
frontend-checks:
|
||||||
uses: ./.github/workflows/frontend-checks.yml
|
uses: ./.github/workflows/frontend-checks.yml
|
||||||
|
with:
|
||||||
|
always_run: true
|
||||||
|
|
||||||
frontend-tests:
|
frontend-tests:
|
||||||
uses: ./.github/workflows/frontend-tests.yml
|
uses: ./.github/workflows/frontend-tests.yml
|
||||||
|
with:
|
||||||
|
always_run: true
|
||||||
|
|
||||||
python-checks:
|
python-checks:
|
||||||
uses: ./.github/workflows/python-checks.yml
|
uses: ./.github/workflows/python-checks.yml
|
||||||
|
with:
|
||||||
|
always_run: true
|
||||||
|
|
||||||
python-tests:
|
python-tests:
|
||||||
uses: ./.github/workflows/python-tests.yml
|
uses: ./.github/workflows/python-tests.yml
|
||||||
|
with:
|
||||||
|
always_run: true
|
||||||
|
|
||||||
build:
|
build:
|
||||||
uses: ./.github/workflows/build-installer.yml
|
uses: ./.github/workflows/build-installer.yml
|
||||||
|
133
docs/contributing/frontend/OVERVIEW.md
Normal file
133
docs/contributing/frontend/OVERVIEW.md
Normal file
@ -0,0 +1,133 @@
|
|||||||
|
# Invoke UI
|
||||||
|
|
||||||
|
Invoke's UI is made possible by many contributors and open-source libraries. Thank you!
|
||||||
|
|
||||||
|
## Dev environment
|
||||||
|
|
||||||
|
### Setup
|
||||||
|
|
||||||
|
1. Install [node] and [pnpm].
|
||||||
|
1. Run `pnpm i` to install all packages.
|
||||||
|
|
||||||
|
#### Run in dev mode
|
||||||
|
|
||||||
|
1. From `invokeai/frontend/web/`, run `pnpm dev`.
|
||||||
|
1. From repo root, run `python scripts/invokeai-web.py`.
|
||||||
|
1. Point your browser to the dev server address, e.g. <http://localhost:5173/>
|
||||||
|
|
||||||
|
### Package scripts
|
||||||
|
|
||||||
|
- `dev`: run the frontend in dev mode, enabling hot reloading
|
||||||
|
- `build`: run all checks (madge, eslint, prettier, tsc) and then build the frontend
|
||||||
|
- `typegen`: generate types from the OpenAPI schema (see [Type generation])
|
||||||
|
- `lint:dpdm`: check circular dependencies
|
||||||
|
- `lint:eslint`: check code quality
|
||||||
|
- `lint:prettier`: check code formatting
|
||||||
|
- `lint:tsc`: check type issues
|
||||||
|
- `lint:knip`: check for unused exports or objects (failures here are just suggestions, not hard fails)
|
||||||
|
- `lint`: run all checks concurrently
|
||||||
|
- `fix`: run `eslint` and `prettier`, fixing fixable issues
|
||||||
|
|
||||||
|
### Type generation
|
||||||
|
|
||||||
|
We use [openapi-typescript] to generate types from the app's OpenAPI schema.
|
||||||
|
|
||||||
|
The generated types are committed to the repo in [schema.ts].
|
||||||
|
|
||||||
|
```sh
|
||||||
|
# from the repo root, start the server
|
||||||
|
python scripts/invokeai-web.py
|
||||||
|
# from invokeai/frontend/web/, run the script
|
||||||
|
pnpm typegen
|
||||||
|
```
|
||||||
|
|
||||||
|
### Localization
|
||||||
|
|
||||||
|
We use [i18next] for localization, but translation to languages other than English happens on our [Weblate] project.
|
||||||
|
|
||||||
|
Only the English source strings should be changed on this repo.
|
||||||
|
|
||||||
|
### VSCode
|
||||||
|
|
||||||
|
#### Example debugger config
|
||||||
|
|
||||||
|
```jsonc
|
||||||
|
{
|
||||||
|
"version": "0.2.0",
|
||||||
|
"configurations": [
|
||||||
|
{
|
||||||
|
"type": "chrome",
|
||||||
|
"request": "launch",
|
||||||
|
"name": "Invoke UI",
|
||||||
|
"url": "http://localhost:5173",
|
||||||
|
"webRoot": "${workspaceFolder}/invokeai/frontend/web"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Remote dev
|
||||||
|
|
||||||
|
We've noticed an intermittent timeout issue with the VSCode remote dev port forwarding.
|
||||||
|
|
||||||
|
We suggest disabling the editor's port forwarding feature and doing it manually via SSH:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
ssh -L 9090:localhost:9090 -L 5173:localhost:5173 user@host
|
||||||
|
```
|
||||||
|
|
||||||
|
## Contributing Guidelines
|
||||||
|
|
||||||
|
Thanks for your interest in contributing to the Invoke Web UI!
|
||||||
|
|
||||||
|
Please follow these guidelines when contributing.
|
||||||
|
|
||||||
|
### Check in before investing your time
|
||||||
|
|
||||||
|
Please check in before you invest your time on anything besides a trivial fix, in case it conflicts with ongoing work or isn't aligned with the vision for the app.
|
||||||
|
|
||||||
|
If a feature request or issue doesn't already exist for the thing you want to work on, please create one.
|
||||||
|
|
||||||
|
Ping `@psychedelicious` on [discord] in the `#frontend-dev` channel or in the feature request / issue you want to work on - we're happy to chat.
|
||||||
|
|
||||||
|
### Code conventions
|
||||||
|
|
||||||
|
- This is a fairly complex app with a deep component tree. Please use memoization (`useCallback`, `useMemo`, `memo`) with enthusiasm.
|
||||||
|
- If you need to add some global, ephemeral state, please use [nanostores] if possible.
|
||||||
|
- Be careful with your redux selectors. If they need to be parameterized, consider creating them inside a `useMemo`.
|
||||||
|
- Feel free to use `lodash` (via `lodash-es`) to make the intent of your code clear.
|
||||||
|
- Please add comments describing the "why", not the "how" (unless it is really arcane).
|
||||||
|
|
||||||
|
### Commit format
|
||||||
|
|
||||||
|
Please use the [conventional commits] spec for the web UI, with a scope of "ui":
|
||||||
|
|
||||||
|
- `chore(ui): bump deps`
|
||||||
|
- `chore(ui): lint`
|
||||||
|
- `feat(ui): add some cool new feature`
|
||||||
|
- `fix(ui): fix some bug`
|
||||||
|
|
||||||
|
### Submitting a PR
|
||||||
|
|
||||||
|
- Ensure your branch is tidy. Use an interactive rebase to clean up the commit history and reword the commit messages if they are not descriptive.
|
||||||
|
- Run `pnpm lint`. Some issues are auto-fixable with `pnpm fix`.
|
||||||
|
- Fill out the PR form when creating the PR.
|
||||||
|
- It doesn't need to be super detailed, but a screenshot or video is nice if you changed something visually.
|
||||||
|
- If a section isn't relevant, delete it. There are no UI tests at this time.
|
||||||
|
|
||||||
|
## Other docs
|
||||||
|
|
||||||
|
- [Workflows - Design and Implementation]
|
||||||
|
- [State Management]
|
||||||
|
|
||||||
|
[node]: https://nodejs.org/en/download/
|
||||||
|
[pnpm]: https://github.com/pnpm/pnpm
|
||||||
|
[discord]: https://discord.gg/ZmtBAhwWhy
|
||||||
|
[i18next]: https://github.com/i18next/react-i18next
|
||||||
|
[Weblate]: https://hosted.weblate.org/engage/invokeai/
|
||||||
|
[openapi-typescript]: https://github.com/drwpow/openapi-typescript
|
||||||
|
[Type generation]: #type-generation
|
||||||
|
[schema.ts]: https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/frontend/web/src/services/api/schema.ts
|
||||||
|
[conventional commits]: https://www.conventionalcommits.org/en/v1.0.0/
|
||||||
|
[Workflows - Design and Implementation]: ./WORKFLOWS.md
|
||||||
|
[State Management]: ./STATE_MGMT.md
|
@ -1,40 +1,5 @@
|
|||||||
# Workflows - Design and Implementation
|
# Workflows - Design and Implementation
|
||||||
|
|
||||||
<!-- @import "[TOC]" {cmd="toc" depthFrom=1 depthTo=6 orderedList=false} -->
|
|
||||||
|
|
||||||
<!-- code_chunk_output -->
|
|
||||||
|
|
||||||
- [Workflows - Design and Implementation](#workflows---design-and-implementation)
|
|
||||||
- [Design](#design)
|
|
||||||
- [Linear UI](#linear-ui)
|
|
||||||
- [Workflow Editor](#workflow-editor)
|
|
||||||
- [Workflows](#workflows)
|
|
||||||
- [Workflow -> reactflow state -> InvokeAI graph](#workflow---reactflow-state---invokeai-graph)
|
|
||||||
- [Nodes vs Invocations](#nodes-vs-invocations)
|
|
||||||
- [Workflow Linear View](#workflow-linear-view)
|
|
||||||
- [OpenAPI Schema](#openapi-schema)
|
|
||||||
- [Field Instances and Templates](#field-instances-and-templates)
|
|
||||||
- [Stateful vs Stateless Fields](#stateful-vs-stateless-fields)
|
|
||||||
- [Collection and Polymorphic Fields](#collection-and-polymorphic-fields)
|
|
||||||
- [Implementation](#implementation)
|
|
||||||
- [zod Schemas and Types](#zod-schemas-and-types)
|
|
||||||
- [OpenAPI Schema Parsing](#openapi-schema-parsing)
|
|
||||||
- [Parsing Field Types](#parsing-field-types)
|
|
||||||
- [Primitive Types](#primitive-types)
|
|
||||||
- [Complex Types](#complex-types)
|
|
||||||
- [Collection Types](#collection-types)
|
|
||||||
- [Collection or Scalar Types](#collection-or-scalar-types)
|
|
||||||
- [Optional Fields](#optional-fields)
|
|
||||||
- [Building Field Input Templates](#building-field-input-templates)
|
|
||||||
- [Building Field Output Templates](#building-field-output-templates)
|
|
||||||
- [Managing reactflow State](#managing-reactflow-state)
|
|
||||||
- [Building Nodes and Edges](#building-nodes-and-edges)
|
|
||||||
- [Building a Workflow](#building-a-workflow)
|
|
||||||
- [Loading a Workflow](#loading-a-workflow)
|
|
||||||
- [Workflow Migrations](#workflow-migrations)
|
|
||||||
|
|
||||||
<!-- /code_chunk_output -->
|
|
||||||
|
|
||||||
> This document describes, at a high level, the design and implementation of workflows in the InvokeAI frontend. There are a substantial number of implementation details not included, but which are hopefully clear from the code.
|
> This document describes, at a high level, the design and implementation of workflows in the InvokeAI frontend. There are a substantial number of implementation details not included, but which are hopefully clear from the code.
|
||||||
|
|
||||||
InvokeAI's backend uses graphs, composed of **nodes** and **edges**, to process data and generate images.
|
InvokeAI's backend uses graphs, composed of **nodes** and **edges**, to process data and generate images.
|
||||||
@ -152,13 +117,13 @@ Stateless fields do not store their value in the node, so their field instances
|
|||||||
|
|
||||||
"Custom" fields will always be treated as stateless fields.
|
"Custom" fields will always be treated as stateless fields.
|
||||||
|
|
||||||
##### Collection and Polymorphic Fields
|
##### Collection and Scalar Fields
|
||||||
|
|
||||||
Field types have a name and two flags which may identify it as a **collection** or **polymorphic** field.
|
Field types have a name and two flags which may identify it as a **collection** or **collection or scalar** field.
|
||||||
|
|
||||||
If a field is annotated in python as a list, its field type is parsed and flagged as a collection type (e.g. `list[int]`).
|
If a field is annotated in python as a list, its field type is parsed and flagged as a **collection** type (e.g. `list[int]`).
|
||||||
|
|
||||||
If it is annotated as a union of a type and list, the type will be flagged as a polymorphic type (e.g. `Union[int, list[int]]`). Fields may not be unions of different types (e.g. `Union[int, list[str]]` and `Union[int, str]` are not allowed).
|
If it is annotated as a union of a type and list, the type will be flagged as a **collection or scalar** type (e.g. `Union[int, list[int]]`). Fields may not be unions of different types (e.g. `Union[int, list[str]]` and `Union[int, str]` are not allowed).
|
||||||
|
|
||||||
## Implementation
|
## Implementation
|
||||||
|
|
||||||
@ -338,13 +303,13 @@ Migration logic is in [migrations.ts].
|
|||||||
[reactflow]: https://github.com/xyflow/xyflow 'reactflow'
|
[reactflow]: https://github.com/xyflow/xyflow 'reactflow'
|
||||||
[reactflow-concepts]: https://reactflow.dev/learn/concepts/terms-and-definitions
|
[reactflow-concepts]: https://reactflow.dev/learn/concepts/terms-and-definitions
|
||||||
[reactflow-events]: https://reactflow.dev/api-reference/react-flow#event-handlers
|
[reactflow-events]: https://reactflow.dev/api-reference/react-flow#event-handlers
|
||||||
[buildWorkflow.ts]: ../src/features/nodes/util/workflow/buildWorkflow.ts
|
[buildWorkflow.ts]: https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/frontend/web/src/features/nodes/util/workflow/buildWorkflow.ts
|
||||||
[nodesSlice.ts]: ../src/features/nodes/store/nodesSlice.ts
|
[nodesSlice.ts]: https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts
|
||||||
[buildLinearTextToImageGraph.ts]: ../src/features/nodes/util/graph/buildLinearTextToImageGraph.ts
|
[buildLinearTextToImageGraph.ts]: https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearTextToImageGraph.ts
|
||||||
[buildNodesGraph.ts]: ../src/features/nodes/util/graph/buildNodesGraph.ts
|
[buildNodesGraph.ts]: https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/frontend/web/src/features/nodes/util/graph/buildNodesGraph.ts
|
||||||
[buildInvocationNode.ts]: ../src/features/nodes/util/node/buildInvocationNode.ts
|
[buildInvocationNode.ts]: https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/frontend/web/src/features/nodes/util/node/buildInvocationNode.ts
|
||||||
[validateWorkflow.ts]: ../src/features/nodes/util/workflow/validateWorkflow.ts
|
[validateWorkflow.ts]: https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts
|
||||||
[migrations.ts]: ../src/features/nodes/util/workflow/migrations.ts
|
[migrations.ts]: https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/frontend/web/src/features/nodes/util/workflow/migrations.ts
|
||||||
[parseSchema.ts]: ../src/features/nodes/util/schema/parseSchema.ts
|
[parseSchema.ts]: https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.ts
|
||||||
[buildFieldInputTemplate.ts]: ../src/features/nodes/util/schema/buildFieldInputTemplate.ts
|
[buildFieldInputTemplate.ts]: https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts
|
||||||
[buildFieldOutputTemplate.ts]: ../src/features/nodes/util/schema/buildFieldOutputTemplate.ts
|
[buildFieldOutputTemplate.ts]: https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldOutputTemplate.ts
|
@ -184,16 +184,26 @@ The provided token will be added as a `Bearer` token to the network requests to
|
|||||||
|
|
||||||
### Model Hashing
|
### Model Hashing
|
||||||
|
|
||||||
Models are hashed during installation with the `BLAKE3` algorithm, providing a stable identifier for models across all platforms.
|
Models are hashed during installation, providing a stable identifier for models across all platforms. The default algorithm is `blake3`, with a multi-threaded implementation.
|
||||||
|
|
||||||
Model hashing is a one-time operation, but it may take a couple minutes to hash a large model collection. You may opt out of model hashing and instead have a random UUID assigned instead:
|
If your models are stored on a spinning hard drive, we suggest using `blake3_single`, the single-threaded implementation. The hashes are the same, but it's much faster on spinning disks.
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
InvokeAI:
|
InvokeAI:
|
||||||
Model Install:
|
Model Install:
|
||||||
skip_model_hash: true
|
hashing_algorithm: blake3_single
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Model hashing is a one-time operation, but it may take a couple minutes to hash a large model collection. You may opt out of model hashing entirely by setting the algorithm to `random`.
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
InvokeAI:
|
||||||
|
Model Install:
|
||||||
|
hashing_algorithm: random
|
||||||
|
```
|
||||||
|
|
||||||
|
Most common algorithms are supported, like `md5`, `sha256`, and `sha512`. These are typically much, much slower than `blake3`.
|
||||||
|
|
||||||
### Paths
|
### Paths
|
||||||
|
|
||||||
These options set the paths of various directories and files used by
|
These options set the paths of various directories and files used by
|
||||||
|
@ -22,6 +22,24 @@ class MyInvocation(BaseInvocation):
|
|||||||
...
|
...
|
||||||
```
|
```
|
||||||
|
|
||||||
|
The full API is documented below.
|
||||||
|
|
||||||
|
## Invocation Mixins
|
||||||
|
|
||||||
|
Two important mixins are provided to facilitate working with metadata and gallery boards.
|
||||||
|
|
||||||
|
### `WithMetadata`
|
||||||
|
|
||||||
|
Inherit from this class (in addition to `BaseInvocation`) to add a `metadata` input to your node. When you do this, you can access the metadata dict from `self.metadata` in the `invoke()` function.
|
||||||
|
|
||||||
|
The dict will be populated via the node's input, and you can add any metadata you'd like to it. When you call `context.images.save()`, if the metadata dict has any data, it be automatically embedded in the image.
|
||||||
|
|
||||||
|
### `WithBoard`
|
||||||
|
|
||||||
|
Inherit from this class (in addition to `BaseInvocation`) to add a `board` input to your node. This renders as a drop-down to select a board. The user's selection will be accessible from `self.board` in the `invoke()` function.
|
||||||
|
|
||||||
|
When you call `context.images.save()`, if a board was selected, the image will added to that board as it is saved.
|
||||||
|
|
||||||
<!-- prettier-ignore-start -->
|
<!-- prettier-ignore-start -->
|
||||||
::: invokeai.app.services.shared.invocation_context.InvocationContext
|
::: invokeai.app.services.shared.invocation_context.InvocationContext
|
||||||
options:
|
options:
|
||||||
|
@ -11,7 +11,7 @@ from fastapi import Body, Path, Query, Response, UploadFile
|
|||||||
from fastapi.responses import FileResponse
|
from fastapi.responses import FileResponse
|
||||||
from fastapi.routing import APIRouter
|
from fastapi.routing import APIRouter
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import AnyHttpUrl, BaseModel, ConfigDict, Field
|
||||||
from starlette.exceptions import HTTPException
|
from starlette.exceptions import HTTPException
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
@ -29,6 +29,8 @@ from invokeai.backend.model_manager.config import (
|
|||||||
ModelType,
|
ModelType,
|
||||||
SubModelType,
|
SubModelType,
|
||||||
)
|
)
|
||||||
|
from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
|
||||||
|
from invokeai.backend.model_manager.metadata.metadata_base import ModelMetadataWithFiles, UnknownMetadataException
|
||||||
from invokeai.backend.model_manager.search import ModelSearch
|
from invokeai.backend.model_manager.search import ModelSearch
|
||||||
|
|
||||||
from ..dependencies import ApiDependencies
|
from ..dependencies import ApiDependencies
|
||||||
@ -246,6 +248,40 @@ async def scan_for_models(
|
|||||||
return scan_results
|
return scan_results
|
||||||
|
|
||||||
|
|
||||||
|
class HuggingFaceModels(BaseModel):
|
||||||
|
urls: List[AnyHttpUrl] | None = Field(description="URLs for all checkpoint format models in the metadata")
|
||||||
|
is_diffusers: bool = Field(description="Whether the metadata is for a Diffusers format model")
|
||||||
|
|
||||||
|
|
||||||
|
@model_manager_router.get(
|
||||||
|
"/hugging_face",
|
||||||
|
operation_id="get_hugging_face_models",
|
||||||
|
responses={
|
||||||
|
200: {"description": "Hugging Face repo scanned successfully"},
|
||||||
|
400: {"description": "Invalid hugging face repo"},
|
||||||
|
},
|
||||||
|
status_code=200,
|
||||||
|
response_model=HuggingFaceModels,
|
||||||
|
)
|
||||||
|
async def get_hugging_face_models(
|
||||||
|
hugging_face_repo: str = Query(description="Hugging face repo to search for models", default=None),
|
||||||
|
) -> HuggingFaceModels:
|
||||||
|
try:
|
||||||
|
metadata = HuggingFaceMetadataFetch().from_id(hugging_face_repo)
|
||||||
|
except UnknownMetadataException:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail="No HuggingFace repository found",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(metadata, ModelMetadataWithFiles)
|
||||||
|
|
||||||
|
return HuggingFaceModels(
|
||||||
|
urls=metadata.ckpt_urls,
|
||||||
|
is_diffusers=metadata.is_diffusers,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@model_manager_router.patch(
|
@model_manager_router.patch(
|
||||||
"/i/{key}",
|
"/i/{key}",
|
||||||
operation_id="update_model_record",
|
operation_id="update_model_record",
|
||||||
|
@ -574,7 +574,7 @@ DEPTH_ANYTHING_MODEL_SIZES = Literal["large", "base", "small"]
|
|||||||
title="Depth Anything Processor",
|
title="Depth Anything Processor",
|
||||||
tags=["controlnet", "depth", "depth anything"],
|
tags=["controlnet", "depth", "depth anything"],
|
||||||
category="controlnet",
|
category="controlnet",
|
||||||
version="1.0.0",
|
version="1.0.1",
|
||||||
)
|
)
|
||||||
class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
|
class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Generates a depth map based on the Depth Anything algorithm"""
|
"""Generates a depth map based on the Depth Anything algorithm"""
|
||||||
@ -583,13 +583,12 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
|
|||||||
default="small", description="The size of the depth model to use"
|
default="small", description="The size of the depth model to use"
|
||||||
)
|
)
|
||||||
resolution: int = InputField(default=512, ge=64, multiple_of=64, description=FieldDescriptions.image_res)
|
resolution: int = InputField(default=512, ge=64, multiple_of=64, description=FieldDescriptions.image_res)
|
||||||
offload: bool = InputField(default=False)
|
|
||||||
|
|
||||||
def run_processor(self, image: Image.Image):
|
def run_processor(self, image: Image.Image):
|
||||||
depth_anything_detector = DepthAnythingDetector()
|
depth_anything_detector = DepthAnythingDetector()
|
||||||
depth_anything_detector.load_model(model_size=self.model_size)
|
depth_anything_detector.load_model(model_size=self.model_size)
|
||||||
|
|
||||||
processed_image = depth_anything_detector(image=image, resolution=self.resolution, offload=self.offload)
|
processed_image = depth_anything_detector(image=image, resolution=self.resolution)
|
||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
|
@ -15,7 +15,7 @@ from invokeai.app.invocations.model import ModelIdentifierField
|
|||||||
from invokeai.app.invocations.primitives import ImageField
|
from invokeai.app.invocations.primitives import ImageField
|
||||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
from invokeai.backend.model_manager.config import BaseModelType, IPAdapterConfig, ModelType
|
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, IPAdapterConfig, ModelType
|
||||||
|
|
||||||
|
|
||||||
class IPAdapterField(BaseModel):
|
class IPAdapterField(BaseModel):
|
||||||
@ -89,17 +89,32 @@ class IPAdapterInvocation(BaseInvocation):
|
|||||||
assert isinstance(ip_adapter_info, IPAdapterConfig)
|
assert isinstance(ip_adapter_info, IPAdapterConfig)
|
||||||
image_encoder_model_id = ip_adapter_info.image_encoder_model_id
|
image_encoder_model_id = ip_adapter_info.image_encoder_model_id
|
||||||
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
|
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
|
||||||
image_encoder_models = context.models.search_by_attrs(
|
image_encoder_model = self._get_image_encoder(context, image_encoder_model_name)
|
||||||
name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision
|
|
||||||
)
|
|
||||||
assert len(image_encoder_models) == 1
|
|
||||||
return IPAdapterOutput(
|
return IPAdapterOutput(
|
||||||
ip_adapter=IPAdapterField(
|
ip_adapter=IPAdapterField(
|
||||||
image=self.image,
|
image=self.image,
|
||||||
ip_adapter_model=self.ip_adapter_model,
|
ip_adapter_model=self.ip_adapter_model,
|
||||||
image_encoder_model=ModelIdentifierField.from_config(image_encoder_models[0]),
|
image_encoder_model=ModelIdentifierField.from_config(image_encoder_model),
|
||||||
weight=self.weight,
|
weight=self.weight,
|
||||||
begin_step_percent=self.begin_step_percent,
|
begin_step_percent=self.begin_step_percent,
|
||||||
end_step_percent=self.end_step_percent,
|
end_step_percent=self.end_step_percent,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _get_image_encoder(self, context: InvocationContext, image_encoder_model_name: str) -> AnyModelConfig:
|
||||||
|
found = False
|
||||||
|
while not found:
|
||||||
|
image_encoder_models = context.models.search_by_attrs(
|
||||||
|
name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision
|
||||||
|
)
|
||||||
|
found = len(image_encoder_models) > 0
|
||||||
|
if not found:
|
||||||
|
context.logger.warning(
|
||||||
|
f"The image encoder required by this IP Adapter ({image_encoder_model_name}) is not installed."
|
||||||
|
)
|
||||||
|
context.logger.warning("Downloading and installing now. This may take a while.")
|
||||||
|
installer = context._services.model_manager.install
|
||||||
|
job = installer.heuristic_import(f"InvokeAI/{image_encoder_model_name}")
|
||||||
|
installer.wait_for_job(job, timeout=600) # wait up to 10 minutes - then raise a TimeoutException
|
||||||
|
assert len(image_encoder_models) == 1
|
||||||
|
return image_encoder_models[0]
|
||||||
|
@ -837,14 +837,14 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
latents = context.tensors.load(self.latents.latents_name)
|
latents = context.tensors.load(self.latents.latents_name)
|
||||||
|
|
||||||
vae_info = context.models.load(self.vae.vae)
|
vae_info = context.models.load(self.vae.vae)
|
||||||
assert isinstance(vae_info.model, (UNet2DConditionModel, AutoencoderKL))
|
assert isinstance(vae_info.model, (UNet2DConditionModel, AutoencoderKL, AutoencoderTiny))
|
||||||
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
|
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
|
||||||
assert isinstance(vae, torch.nn.Module)
|
assert isinstance(vae, torch.nn.Module)
|
||||||
latents = latents.to(vae.device)
|
latents = latents.to(vae.device)
|
||||||
if self.fp32:
|
if self.fp32:
|
||||||
vae.to(dtype=torch.float32)
|
vae.to(dtype=torch.float32)
|
||||||
|
|
||||||
use_torch_2_0_or_xformers = isinstance(
|
use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance(
|
||||||
vae.decoder.mid_block.attentions[0].processor,
|
vae.decoder.mid_block.attentions[0].processor,
|
||||||
(
|
(
|
||||||
AttnProcessor2_0,
|
AttnProcessor2_0,
|
||||||
@ -1018,7 +1018,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
|||||||
if upcast:
|
if upcast:
|
||||||
vae.to(dtype=torch.float32)
|
vae.to(dtype=torch.float32)
|
||||||
|
|
||||||
use_torch_2_0_or_xformers = isinstance(
|
use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance(
|
||||||
vae.decoder.mid_block.attentions[0].processor,
|
vae.decoder.mid_block.attentions[0].processor,
|
||||||
(
|
(
|
||||||
AttnProcessor2_0,
|
AttnProcessor2_0,
|
||||||
|
@ -20,8 +20,8 @@ from invokeai.app.invocations.fields import (
|
|||||||
OutputField,
|
OutputField,
|
||||||
UIType,
|
UIType,
|
||||||
)
|
)
|
||||||
|
from invokeai.app.invocations.model import ModelIdentifierField
|
||||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
from invokeai.backend.model_manager.config import BaseModelType, ModelType
|
|
||||||
|
|
||||||
from ...version import __version__
|
from ...version import __version__
|
||||||
|
|
||||||
@ -31,20 +31,10 @@ class MetadataItemField(BaseModel):
|
|||||||
value: Any = Field(description=FieldDescriptions.metadata_item_value)
|
value: Any = Field(description=FieldDescriptions.metadata_item_value)
|
||||||
|
|
||||||
|
|
||||||
class ModelMetadataField(BaseModel):
|
|
||||||
"""Model Metadata Field"""
|
|
||||||
|
|
||||||
key: str
|
|
||||||
hash: str
|
|
||||||
name: str
|
|
||||||
base: BaseModelType
|
|
||||||
type: ModelType
|
|
||||||
|
|
||||||
|
|
||||||
class LoRAMetadataField(BaseModel):
|
class LoRAMetadataField(BaseModel):
|
||||||
"""LoRA Metadata Field"""
|
"""LoRA Metadata Field"""
|
||||||
|
|
||||||
model: ModelMetadataField = Field(description=FieldDescriptions.lora_model)
|
model: ModelIdentifierField = Field(description=FieldDescriptions.lora_model)
|
||||||
weight: float = Field(description=FieldDescriptions.lora_weight)
|
weight: float = Field(description=FieldDescriptions.lora_weight)
|
||||||
|
|
||||||
|
|
||||||
@ -52,19 +42,16 @@ class IPAdapterMetadataField(BaseModel):
|
|||||||
"""IP Adapter Field, minus the CLIP Vision Encoder model"""
|
"""IP Adapter Field, minus the CLIP Vision Encoder model"""
|
||||||
|
|
||||||
image: ImageField = Field(description="The IP-Adapter image prompt.")
|
image: ImageField = Field(description="The IP-Adapter image prompt.")
|
||||||
ip_adapter_model: ModelMetadataField = Field(
|
ip_adapter_model: ModelIdentifierField = Field(description="The IP-Adapter model.")
|
||||||
description="The IP-Adapter model.",
|
weight: Union[float, list[float]] = Field(description="The weight given to the IP-Adapter")
|
||||||
)
|
|
||||||
weight: Union[float, list[float]] = Field(
|
|
||||||
description="The weight given to the IP-Adapter",
|
|
||||||
)
|
|
||||||
begin_step_percent: float = Field(description="When the IP-Adapter is first applied (% of total steps)")
|
begin_step_percent: float = Field(description="When the IP-Adapter is first applied (% of total steps)")
|
||||||
end_step_percent: float = Field(description="When the IP-Adapter is last applied (% of total steps)")
|
end_step_percent: float = Field(description="When the IP-Adapter is last applied (% of total steps)")
|
||||||
|
|
||||||
|
|
||||||
class T2IAdapterMetadataField(BaseModel):
|
class T2IAdapterMetadataField(BaseModel):
|
||||||
image: ImageField = Field(description="The T2I-Adapter image prompt.")
|
image: ImageField = Field(description="The control image.")
|
||||||
t2i_adapter_model: ModelMetadataField = Field(description="The T2I-Adapter model to use.")
|
processed_image: Optional[ImageField] = Field(default=None, description="The control image, after processing.")
|
||||||
|
t2i_adapter_model: ModelIdentifierField = Field(description="The T2I-Adapter model to use.")
|
||||||
weight: Union[float, list[float]] = Field(default=1, description="The weight given to the T2I-Adapter")
|
weight: Union[float, list[float]] = Field(default=1, description="The weight given to the T2I-Adapter")
|
||||||
begin_step_percent: float = Field(
|
begin_step_percent: float = Field(
|
||||||
default=0, ge=0, le=1, description="When the T2I-Adapter is first applied (% of total steps)"
|
default=0, ge=0, le=1, description="When the T2I-Adapter is first applied (% of total steps)"
|
||||||
@ -77,7 +64,8 @@ class T2IAdapterMetadataField(BaseModel):
|
|||||||
|
|
||||||
class ControlNetMetadataField(BaseModel):
|
class ControlNetMetadataField(BaseModel):
|
||||||
image: ImageField = Field(description="The control image")
|
image: ImageField = Field(description="The control image")
|
||||||
control_model: ModelMetadataField = Field(description="The ControlNet model to use")
|
processed_image: Optional[ImageField] = Field(default=None, description="The control image, after processing.")
|
||||||
|
control_model: ModelIdentifierField = Field(description="The ControlNet model to use")
|
||||||
control_weight: Union[float, list[float]] = Field(default=1, description="The weight given to the ControlNet")
|
control_weight: Union[float, list[float]] = Field(default=1, description="The weight given to the ControlNet")
|
||||||
begin_step_percent: float = Field(
|
begin_step_percent: float = Field(
|
||||||
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
|
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
|
||||||
@ -178,7 +166,7 @@ class CoreMetadataInvocation(BaseInvocation):
|
|||||||
default=None,
|
default=None,
|
||||||
description="The number of skipped CLIP layers",
|
description="The number of skipped CLIP layers",
|
||||||
)
|
)
|
||||||
model: Optional[ModelMetadataField] = InputField(default=None, description="The main model used for inference")
|
model: Optional[ModelIdentifierField] = InputField(default=None, description="The main model used for inference")
|
||||||
controlnets: Optional[list[ControlNetMetadataField]] = InputField(
|
controlnets: Optional[list[ControlNetMetadataField]] = InputField(
|
||||||
default=None, description="The ControlNets used for inference"
|
default=None, description="The ControlNets used for inference"
|
||||||
)
|
)
|
||||||
@ -197,7 +185,7 @@ class CoreMetadataInvocation(BaseInvocation):
|
|||||||
default=None,
|
default=None,
|
||||||
description="The name of the initial image",
|
description="The name of the initial image",
|
||||||
)
|
)
|
||||||
vae: Optional[ModelMetadataField] = InputField(
|
vae: Optional[ModelIdentifierField] = InputField(
|
||||||
default=None,
|
default=None,
|
||||||
description="The VAE used for decoding, if the main model's default was not used",
|
description="The VAE used for decoding, if the main model's default was not used",
|
||||||
)
|
)
|
||||||
@ -228,7 +216,7 @@ class CoreMetadataInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# SDXL Refiner
|
# SDXL Refiner
|
||||||
refiner_model: Optional[ModelMetadataField] = InputField(
|
refiner_model: Optional[ModelIdentifierField] = InputField(
|
||||||
default=None,
|
default=None,
|
||||||
description="The SDXL Refiner model used",
|
description="The SDXL Refiner model used",
|
||||||
)
|
)
|
||||||
|
@ -179,6 +179,8 @@ from pydantic import BaseModel, Field, field_validator
|
|||||||
from pydantic.config import JsonDict
|
from pydantic.config import JsonDict
|
||||||
from pydantic_settings import SettingsConfigDict
|
from pydantic_settings import SettingsConfigDict
|
||||||
|
|
||||||
|
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
|
||||||
|
|
||||||
from .config_base import InvokeAISettings
|
from .config_base import InvokeAISettings
|
||||||
|
|
||||||
INIT_FILE = Path("invokeai.yaml")
|
INIT_FILE = Path("invokeai.yaml")
|
||||||
@ -259,7 +261,7 @@ class InvokeAIAppConfig(InvokeAISettings):
|
|||||||
profile_prefix: **Development**: An optional prefix for profile output files.
|
profile_prefix: **Development**: An optional prefix for profile output files.
|
||||||
profiles_dir: **Development**: Path to profiles output directory.
|
profiles_dir: **Development**: Path to profiles output directory.
|
||||||
version: **CLIArgs**: CLI arg - show InvokeAI version and exit.
|
version: **CLIArgs**: CLI arg - show InvokeAI version and exit.
|
||||||
skip_model_hash: **Model Install**: Skip model hashing, instead assigning a UUID to models. Useful when using a memory db to reduce model installation time, or if you don't care about storing stable hashes for models.
|
hashing_algorithm: **Model Install**: Model hashing algorthim for model installs. 'blake3' is best for SSDs. 'blake3_single' is best for spinning disk HDDs. 'random' disables hashing, instead assigning a UUID to models. Useful when using a memory db to reduce model installation time, or if you don't care about storing stable hashes for models. Alternatively, any other hashlib algorithm is accepted, though these are not nearly as performant as blake3.
|
||||||
remote_api_tokens: **Model Install**: List of regular expression and token pairs used when downloading models from URLs. The download URL is tested against the regex, and if it matches, the token is provided in as a Bearer token.
|
remote_api_tokens: **Model Install**: List of regular expression and token pairs used when downloading models from URLs. The download URL is tested against the regex, and if it matches, the token is provided in as a Bearer token.
|
||||||
ram: **Model Cache**: Maximum memory amount used by memory model cache for rapid switching (GB).
|
ram: **Model Cache**: Maximum memory amount used by memory model cache for rapid switching (GB).
|
||||||
vram: **Model Cache**: Amount of VRAM reserved for model storage (GB)
|
vram: **Model Cache**: Amount of VRAM reserved for model storage (GB)
|
||||||
@ -360,7 +362,7 @@ class InvokeAIAppConfig(InvokeAISettings):
|
|||||||
node_cache_size : int = Field(default=512, description="How many cached nodes to keep in memory.", json_schema_extra=Categories.Nodes)
|
node_cache_size : int = Field(default=512, description="How many cached nodes to keep in memory.", json_schema_extra=Categories.Nodes)
|
||||||
|
|
||||||
# MODEL INSTALL
|
# MODEL INSTALL
|
||||||
skip_model_hash : bool = Field(default=False, description="Skip model hashing, instead assigning a UUID to models. Useful when using a memory db to reduce model installation time, or if you don't care about storing stable hashes for models.", json_schema_extra=Categories.ModelInstall)
|
hashing_algorithm : HASHING_ALGORITHMS = Field(default="blake3", description="Model hashing algorthim for model installs. 'blake3' is best for SSDs. 'blake3_single' is best for spinning disk HDDs. 'random' disables hashing, instead assigning a UUID to models. Useful when using a memory db to reduce model installation time, or if you don't care about storing stable hashes for models. Alternatively, any other hashlib algorithm is accepted, though these are not nearly as performant as blake3.", json_schema_extra=Categories.ModelInstall)
|
||||||
remote_api_tokens : Optional[list[URLRegexToken]] = Field(
|
remote_api_tokens : Optional[list[URLRegexToken]] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="List of regular expression and token pairs used when downloading models from URLs. The download URL is tested against the regex, and if it matches, the token is provided in as a Bearer token.",
|
description="List of regular expression and token pairs used when downloading models from URLs. The download URL is tested against the regex, and if it matches, the token is provided in as a Bearer token.",
|
||||||
|
@ -12,6 +12,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
|||||||
)
|
)
|
||||||
from invokeai.app.util.misc import get_timestamp
|
from invokeai.app.util.misc import get_timestamp
|
||||||
from invokeai.backend.model_manager import AnyModelConfig
|
from invokeai.backend.model_manager import AnyModelConfig
|
||||||
|
from invokeai.backend.model_manager.config import SubModelType
|
||||||
|
|
||||||
|
|
||||||
class EventServiceBase:
|
class EventServiceBase:
|
||||||
@ -80,7 +81,7 @@ class EventServiceBase:
|
|||||||
"graph_execution_state_id": graph_execution_state_id,
|
"graph_execution_state_id": graph_execution_state_id,
|
||||||
"node_id": node_id,
|
"node_id": node_id,
|
||||||
"source_node_id": source_node_id,
|
"source_node_id": source_node_id,
|
||||||
"progress_image": progress_image.model_dump() if progress_image is not None else None,
|
"progress_image": progress_image.model_dump(mode="json") if progress_image is not None else None,
|
||||||
"step": step,
|
"step": step,
|
||||||
"order": order,
|
"order": order,
|
||||||
"total_steps": total_steps,
|
"total_steps": total_steps,
|
||||||
@ -180,6 +181,7 @@ class EventServiceBase:
|
|||||||
queue_batch_id: str,
|
queue_batch_id: str,
|
||||||
graph_execution_state_id: str,
|
graph_execution_state_id: str,
|
||||||
model_config: AnyModelConfig,
|
model_config: AnyModelConfig,
|
||||||
|
submodel_type: Optional[SubModelType] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Emitted when a model is requested"""
|
"""Emitted when a model is requested"""
|
||||||
self.__emit_queue_event(
|
self.__emit_queue_event(
|
||||||
@ -189,7 +191,8 @@ class EventServiceBase:
|
|||||||
"queue_item_id": queue_item_id,
|
"queue_item_id": queue_item_id,
|
||||||
"queue_batch_id": queue_batch_id,
|
"queue_batch_id": queue_batch_id,
|
||||||
"graph_execution_state_id": graph_execution_state_id,
|
"graph_execution_state_id": graph_execution_state_id,
|
||||||
"model_config": model_config.model_dump(),
|
"model_config": model_config.model_dump(mode="json"),
|
||||||
|
"submodel_type": submodel_type,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -200,6 +203,7 @@ class EventServiceBase:
|
|||||||
queue_batch_id: str,
|
queue_batch_id: str,
|
||||||
graph_execution_state_id: str,
|
graph_execution_state_id: str,
|
||||||
model_config: AnyModelConfig,
|
model_config: AnyModelConfig,
|
||||||
|
submodel_type: Optional[SubModelType] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Emitted when a model is correctly loaded (returns model info)"""
|
"""Emitted when a model is correctly loaded (returns model info)"""
|
||||||
self.__emit_queue_event(
|
self.__emit_queue_event(
|
||||||
@ -209,7 +213,8 @@ class EventServiceBase:
|
|||||||
"queue_item_id": queue_item_id,
|
"queue_item_id": queue_item_id,
|
||||||
"queue_batch_id": queue_batch_id,
|
"queue_batch_id": queue_batch_id,
|
||||||
"graph_execution_state_id": graph_execution_state_id,
|
"graph_execution_state_id": graph_execution_state_id,
|
||||||
"model_config": model_config.model_dump(),
|
"model_config": model_config.model_dump(mode="json"),
|
||||||
|
"submodel_type": submodel_type,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -254,8 +259,8 @@ class EventServiceBase:
|
|||||||
"started_at": str(session_queue_item.started_at) if session_queue_item.started_at else None,
|
"started_at": str(session_queue_item.started_at) if session_queue_item.started_at else None,
|
||||||
"completed_at": str(session_queue_item.completed_at) if session_queue_item.completed_at else None,
|
"completed_at": str(session_queue_item.completed_at) if session_queue_item.completed_at else None,
|
||||||
},
|
},
|
||||||
"batch_status": batch_status.model_dump(),
|
"batch_status": batch_status.model_dump(mode="json"),
|
||||||
"queue_status": queue_status.model_dump(),
|
"queue_status": queue_status.model_dump(mode="json"),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -405,7 +410,7 @@ class EventServiceBase:
|
|||||||
payload={"source": source, "total_bytes": total_bytes, "key": key, "id": id},
|
payload={"source": source, "total_bytes": total_bytes, "key": key, "id": id},
|
||||||
)
|
)
|
||||||
|
|
||||||
def emit_model_install_cancelled(self, source: str) -> None:
|
def emit_model_install_cancelled(self, source: str, id: int) -> None:
|
||||||
"""
|
"""
|
||||||
Emit when an install job is cancelled.
|
Emit when an install job is cancelled.
|
||||||
|
|
||||||
@ -413,7 +418,7 @@ class EventServiceBase:
|
|||||||
"""
|
"""
|
||||||
self.__emit_model_event(
|
self.__emit_model_event(
|
||||||
event_name="model_install_cancelled",
|
event_name="model_install_cancelled",
|
||||||
payload={"source": source},
|
payload={"source": source, "id": id},
|
||||||
)
|
)
|
||||||
|
|
||||||
def emit_model_install_error(self, source: str, error_type: str, error: str, id: int) -> None:
|
def emit_model_install_error(self, source: str, error_type: str, error: str, id: int) -> None:
|
||||||
|
@ -22,7 +22,6 @@ from invokeai.app.services.events.events_base import EventServiceBase
|
|||||||
from invokeai.app.services.invoker import Invoker
|
from invokeai.app.services.invoker import Invoker
|
||||||
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
|
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
|
||||||
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
|
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
|
||||||
from invokeai.app.util.misc import uuid_string
|
|
||||||
from invokeai.backend.model_manager.config import (
|
from invokeai.backend.model_manager.config import (
|
||||||
AnyModelConfig,
|
AnyModelConfig,
|
||||||
BaseModelType,
|
BaseModelType,
|
||||||
@ -134,6 +133,14 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
self._download_cache.clear()
|
self._download_cache.clear()
|
||||||
self._running = False
|
self._running = False
|
||||||
|
|
||||||
|
def _put_in_queue(self, job: ModelInstallJob) -> None:
|
||||||
|
print(f'DEBUG: in _put_in_queue(job={job.id})')
|
||||||
|
if self._stop_event.is_set():
|
||||||
|
self.cancel_job(job)
|
||||||
|
else:
|
||||||
|
print(f'DEBUG: putting {job.id} into the install queue')
|
||||||
|
self._install_queue.put(job)
|
||||||
|
|
||||||
def register_path(
|
def register_path(
|
||||||
self,
|
self,
|
||||||
model_path: Union[Path, str],
|
model_path: Union[Path, str],
|
||||||
@ -154,10 +161,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
model_path = Path(model_path)
|
model_path = Path(model_path)
|
||||||
config = config or {}
|
config = config or {}
|
||||||
|
|
||||||
if self._app_config.skip_model_hash:
|
info: AnyModelConfig = ModelProbe.probe(Path(model_path), config, hash_algo=self._app_config.hashing_algorithm)
|
||||||
config["hash"] = uuid_string()
|
|
||||||
|
|
||||||
info: AnyModelConfig = ModelProbe.probe(Path(model_path), config)
|
|
||||||
|
|
||||||
if preferred_name := config.get("name"):
|
if preferred_name := config.get("name"):
|
||||||
preferred_name = Path(preferred_name).with_suffix(model_path.suffix)
|
preferred_name = Path(preferred_name).with_suffix(model_path.suffix)
|
||||||
@ -222,7 +226,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
|
|
||||||
if isinstance(source, LocalModelSource):
|
if isinstance(source, LocalModelSource):
|
||||||
install_job = self._import_local_model(source, config)
|
install_job = self._import_local_model(source, config)
|
||||||
self._install_queue.put(install_job) # synchronously install
|
self._put_in_queue(install_job) # synchronously install
|
||||||
elif isinstance(source, HFModelSource):
|
elif isinstance(source, HFModelSource):
|
||||||
install_job = self._import_from_hf(source, config)
|
install_job = self._import_from_hf(source, config)
|
||||||
elif isinstance(source, URLModelSource):
|
elif isinstance(source, URLModelSource):
|
||||||
@ -332,7 +336,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
yaml_path.rename(yaml_path.with_suffix(".yaml.bak"))
|
yaml_path.rename(yaml_path.with_suffix(".yaml.bak"))
|
||||||
|
|
||||||
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
|
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
|
||||||
self._cached_model_paths = {Path(x.path).absolute() for x in self.record_store.all_models()}
|
self._cached_model_paths = {Path(x.path).resolve() for x in self.record_store.all_models()}
|
||||||
callback = self._scan_install if install else self._scan_register
|
callback = self._scan_install if install else self._scan_register
|
||||||
search = ModelSearch(on_model_found=callback)
|
search = ModelSearch(on_model_found=callback)
|
||||||
self._models_installed.clear()
|
self._models_installed.clear()
|
||||||
@ -346,7 +350,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
"""Unregister the model. Delete its files only if they are within our models directory."""
|
"""Unregister the model. Delete its files only if they are within our models directory."""
|
||||||
model = self.record_store.get_model(key)
|
model = self.record_store.get_model(key)
|
||||||
models_dir = self.app_config.models_path
|
models_dir = self.app_config.models_path
|
||||||
model_path = Path(model.path)
|
model_path = models_dir / Path(model.path) # handle legacy relative model paths
|
||||||
if model_path.is_relative_to(models_dir):
|
if model_path.is_relative_to(models_dir):
|
||||||
self.unconditionally_delete(key)
|
self.unconditionally_delete(key)
|
||||||
else:
|
else:
|
||||||
@ -354,7 +358,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
|
|
||||||
def unconditionally_delete(self, key: str) -> None: # noqa D102
|
def unconditionally_delete(self, key: str) -> None: # noqa D102
|
||||||
model = self.record_store.get_model(key)
|
model = self.record_store.get_model(key)
|
||||||
model_path = Path(model.path)
|
model_path = self.app_config.models_path / model.path
|
||||||
if model_path.is_dir():
|
if model_path.is_dir():
|
||||||
rmtree(model_path)
|
rmtree(model_path)
|
||||||
else:
|
else:
|
||||||
@ -407,10 +411,11 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
done = True
|
done = True
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
|
print(f'DEBUG: _install_next_item() checking for a job to install')
|
||||||
job = self._install_queue.get(timeout=1)
|
job = self._install_queue.get(timeout=1)
|
||||||
except Empty:
|
except Empty:
|
||||||
continue
|
continue
|
||||||
|
print(f'DEBUG: _install_next_item() got job {job.id}, status={job.status}')
|
||||||
assert job.local_path is not None
|
assert job.local_path is not None
|
||||||
try:
|
try:
|
||||||
if job.cancelled:
|
if job.cancelled:
|
||||||
@ -436,6 +441,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
else:
|
else:
|
||||||
key = self.install_path(job.local_path, job.config_in)
|
key = self.install_path(job.local_path, job.config_in)
|
||||||
job.config_out = self.record_store.get_model(key)
|
job.config_out = self.record_store.get_model(key)
|
||||||
|
print(f'DEBUG: _install_next_item() signaling completion for job={job.id}, status={job.status}')
|
||||||
self._signal_job_completed(job)
|
self._signal_job_completed(job)
|
||||||
|
|
||||||
except InvalidModelConfigException as excp:
|
except InvalidModelConfigException as excp:
|
||||||
@ -496,6 +502,8 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
for cur_base_model in BaseModelType:
|
for cur_base_model in BaseModelType:
|
||||||
for cur_model_type in ModelType:
|
for cur_model_type in ModelType:
|
||||||
models_dir = self._app_config.models_path / Path(cur_base_model.value, cur_model_type.value)
|
models_dir = self._app_config.models_path / Path(cur_base_model.value, cur_model_type.value)
|
||||||
|
if not models_dir.exists():
|
||||||
|
continue
|
||||||
installed.update(self.scan_directory(models_dir))
|
installed.update(self.scan_directory(models_dir))
|
||||||
self._logger.info(f"{len(installed)} new models registered; {len(defunct_models)} unregistered")
|
self._logger.info(f"{len(installed)} new models registered; {len(defunct_models)} unregistered")
|
||||||
|
|
||||||
@ -522,7 +530,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
|
|
||||||
new_path = models_dir / model.base.value / model.type.value / old_path.name
|
new_path = models_dir / model.base.value / model.type.value / old_path.name
|
||||||
|
|
||||||
if old_path == new_path:
|
if old_path == new_path or new_path.exists() and old_path == new_path.resolve():
|
||||||
return model
|
return model
|
||||||
|
|
||||||
self._logger.info(f"Moving {model.name} to {new_path}.")
|
self._logger.info(f"Moving {model.name} to {new_path}.")
|
||||||
@ -585,10 +593,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
) -> str:
|
) -> str:
|
||||||
config = config or {}
|
config = config or {}
|
||||||
|
|
||||||
if self._app_config.skip_model_hash:
|
info = info or ModelProbe.probe(model_path, config, hash_algo=self._app_config.hashing_algorithm)
|
||||||
config["hash"] = uuid_string()
|
|
||||||
|
|
||||||
info = info or ModelProbe.probe(model_path, config)
|
|
||||||
|
|
||||||
model_path = model_path.resolve()
|
model_path = model_path.resolve()
|
||||||
|
|
||||||
@ -786,14 +791,16 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
|
|
||||||
def _download_complete_callback(self, download_job: DownloadJob) -> None:
|
def _download_complete_callback(self, download_job: DownloadJob) -> None:
|
||||||
self._logger.info(f"{download_job.source}: model download complete")
|
self._logger.info(f"{download_job.source}: model download complete")
|
||||||
|
print(f'DEBUG: _download_complete_callback(download_job={download_job.source}')
|
||||||
with self._lock:
|
with self._lock:
|
||||||
install_job = self._download_cache[download_job.source]
|
install_job = self._download_cache.pop(download_job.source, None)
|
||||||
self._download_cache.pop(download_job.source, None)
|
print(f'DEBUG: download_job={download_job.source} / install_job={install_job}')
|
||||||
|
|
||||||
# are there any more active jobs left in this task?
|
# are there any more active jobs left in this task?
|
||||||
if install_job.downloading and all(x.complete for x in install_job.download_parts):
|
if install_job and install_job.downloading and all(x.complete for x in install_job.download_parts):
|
||||||
|
print(f'DEBUG: setting job {install_job.id} to DOWNLOADS_DONE')
|
||||||
install_job.status = InstallStatus.DOWNLOADS_DONE
|
install_job.status = InstallStatus.DOWNLOADS_DONE
|
||||||
self._install_queue.put(install_job)
|
print(f'DEBUG: putting {install_job.id} into the install queue')
|
||||||
|
self._put_in_queue(install_job)
|
||||||
|
|
||||||
# Let other threads know that the number of downloads has changed
|
# Let other threads know that the number of downloads has changed
|
||||||
self._downloads_changed_event.set()
|
self._downloads_changed_event.set()
|
||||||
@ -835,7 +842,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
|
|
||||||
if all(x.in_terminal_state for x in install_job.download_parts):
|
if all(x.in_terminal_state for x in install_job.download_parts):
|
||||||
# When all parts have reached their terminal state, we finalize the job to clean up the temporary directory and other resources
|
# When all parts have reached their terminal state, we finalize the job to clean up the temporary directory and other resources
|
||||||
self._install_queue.put(install_job)
|
self._put_in_queue(install_job)
|
||||||
|
|
||||||
# ------------------------------------------------------------------------------------------------
|
# ------------------------------------------------------------------------------------------------
|
||||||
# Internal methods that put events on the event bus
|
# Internal methods that put events on the event bus
|
||||||
@ -892,7 +899,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
def _signal_job_cancelled(self, job: ModelInstallJob) -> None:
|
def _signal_job_cancelled(self, job: ModelInstallJob) -> None:
|
||||||
self._logger.info(f"{job.source}: model installation was cancelled")
|
self._logger.info(f"{job.source}: model installation was cancelled")
|
||||||
if self._event_bus:
|
if self._event_bus:
|
||||||
self._event_bus.emit_model_install_cancelled(str(job.source))
|
self._event_bus.emit_model_install_cancelled(str(job.source), id=job.id)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_fetcher_from_url(url: str):
|
def get_fetcher_from_url(url: str):
|
||||||
|
@ -68,6 +68,7 @@ class ModelLoadService(ModelLoadServiceBase):
|
|||||||
self._emit_load_event(
|
self._emit_load_event(
|
||||||
context_data=context_data,
|
context_data=context_data,
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
|
submodel_type=submodel_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
implementation, model_config, submodel_type = self._registry.get_implementation(model_config, submodel_type) # type: ignore
|
implementation, model_config, submodel_type = self._registry.get_implementation(model_config, submodel_type) # type: ignore
|
||||||
@ -82,6 +83,7 @@ class ModelLoadService(ModelLoadServiceBase):
|
|||||||
self._emit_load_event(
|
self._emit_load_event(
|
||||||
context_data=context_data,
|
context_data=context_data,
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
|
submodel_type=submodel_type,
|
||||||
loaded=True,
|
loaded=True,
|
||||||
)
|
)
|
||||||
return loaded_model
|
return loaded_model
|
||||||
@ -91,6 +93,7 @@ class ModelLoadService(ModelLoadServiceBase):
|
|||||||
context_data: InvocationContextData,
|
context_data: InvocationContextData,
|
||||||
model_config: AnyModelConfig,
|
model_config: AnyModelConfig,
|
||||||
loaded: Optional[bool] = False,
|
loaded: Optional[bool] = False,
|
||||||
|
submodel_type: Optional[SubModelType] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
if not self._invoker:
|
if not self._invoker:
|
||||||
return
|
return
|
||||||
@ -102,6 +105,7 @@ class ModelLoadService(ModelLoadServiceBase):
|
|||||||
queue_batch_id=context_data.queue_item.batch_id,
|
queue_batch_id=context_data.queue_item.batch_id,
|
||||||
graph_execution_state_id=context_data.queue_item.session_id,
|
graph_execution_state_id=context_data.queue_item.session_id,
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
|
submodel_type=submodel_type,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self._invoker.services.events.emit_model_load_completed(
|
self._invoker.services.events.emit_model_load_completed(
|
||||||
@ -110,4 +114,5 @@ class ModelLoadService(ModelLoadServiceBase):
|
|||||||
queue_batch_id=context_data.queue_item.batch_id,
|
queue_batch_id=context_data.queue_item.batch_id,
|
||||||
graph_execution_state_id=context_data.queue_item.session_id,
|
graph_execution_state_id=context_data.queue_item.session_id,
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
|
submodel_type=submodel_type,
|
||||||
)
|
)
|
||||||
|
@ -13,9 +13,11 @@ from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
|||||||
from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
|
from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
|
||||||
from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize
|
from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize
|
||||||
from invokeai.backend.util.devices import choose_torch_device
|
from invokeai.backend.util.devices import choose_torch_device
|
||||||
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
from invokeai.backend.util.util import download_with_progress_bar
|
from invokeai.backend.util.util import download_with_progress_bar
|
||||||
|
|
||||||
config = InvokeAIAppConfig.get_config()
|
config = InvokeAIAppConfig.get_config()
|
||||||
|
logger = InvokeAILogger.get_logger(config=config)
|
||||||
|
|
||||||
DEPTH_ANYTHING_MODELS = {
|
DEPTH_ANYTHING_MODELS = {
|
||||||
"large": {
|
"large": {
|
||||||
@ -54,8 +56,9 @@ class DepthAnythingDetector:
|
|||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.model = None
|
self.model = None
|
||||||
self.model_size: Union[Literal["large", "base", "small"], None] = None
|
self.model_size: Union[Literal["large", "base", "small"], None] = None
|
||||||
|
self.device = choose_torch_device()
|
||||||
|
|
||||||
def load_model(self, model_size=Literal["large", "base", "small"]):
|
def load_model(self, model_size: Literal["large", "base", "small"] = "small"):
|
||||||
DEPTH_ANYTHING_MODEL_PATH = pathlib.Path(config.models_path / DEPTH_ANYTHING_MODELS[model_size]["local"])
|
DEPTH_ANYTHING_MODEL_PATH = pathlib.Path(config.models_path / DEPTH_ANYTHING_MODELS[model_size]["local"])
|
||||||
if not DEPTH_ANYTHING_MODEL_PATH.exists():
|
if not DEPTH_ANYTHING_MODEL_PATH.exists():
|
||||||
download_with_progress_bar(DEPTH_ANYTHING_MODELS[model_size]["url"], DEPTH_ANYTHING_MODEL_PATH)
|
download_with_progress_bar(DEPTH_ANYTHING_MODELS[model_size]["url"], DEPTH_ANYTHING_MODEL_PATH)
|
||||||
@ -71,8 +74,6 @@ class DepthAnythingDetector:
|
|||||||
self.model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])
|
self.model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])
|
||||||
case "large":
|
case "large":
|
||||||
self.model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
|
self.model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
|
||||||
case _:
|
|
||||||
raise TypeError("Not a supported model")
|
|
||||||
|
|
||||||
self.model.load_state_dict(torch.load(DEPTH_ANYTHING_MODEL_PATH.as_posix(), map_location="cpu"))
|
self.model.load_state_dict(torch.load(DEPTH_ANYTHING_MODEL_PATH.as_posix(), map_location="cpu"))
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
@ -80,20 +81,20 @@ class DepthAnythingDetector:
|
|||||||
self.model.to(choose_torch_device())
|
self.model.to(choose_torch_device())
|
||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
def to(self, device):
|
def __call__(self, image: Image.Image, resolution: int = 512) -> Image.Image:
|
||||||
self.model.to(device)
|
if not self.model:
|
||||||
return self
|
logger.warn("DepthAnything model was not loaded. Returning original image")
|
||||||
|
return image
|
||||||
|
|
||||||
def __call__(self, image, resolution=512, offload=False):
|
np_image = np.array(image, dtype=np.uint8)
|
||||||
image = np.array(image, dtype=np.uint8)
|
np_image = np_image[:, :, ::-1] / 255.0
|
||||||
image = image[:, :, ::-1] / 255.0
|
|
||||||
|
|
||||||
image_height, image_width = image.shape[:2]
|
image_height, image_width = np_image.shape[:2]
|
||||||
image = transform({"image": image})["image"]
|
np_image = transform({"image": np_image})["image"]
|
||||||
image = torch.from_numpy(image).unsqueeze(0).to(choose_torch_device())
|
tensor_image = torch.from_numpy(np_image).unsqueeze(0).to(choose_torch_device())
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
depth = self.model(image)
|
depth = self.model(tensor_image)
|
||||||
depth = F.interpolate(depth[None], (image_height, image_width), mode="bilinear", align_corners=False)[0, 0]
|
depth = F.interpolate(depth[None], (image_height, image_width), mode="bilinear", align_corners=False)[0, 0]
|
||||||
depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
|
depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
|
||||||
|
|
||||||
@ -103,7 +104,4 @@ class DepthAnythingDetector:
|
|||||||
new_height = int(image_height * (resolution / image_width))
|
new_height = int(image_height * (resolution / image_width))
|
||||||
depth_map = depth_map.resize((resolution, new_height))
|
depth_map = depth_map.resize((resolution, new_height))
|
||||||
|
|
||||||
if offload:
|
|
||||||
del self.model
|
|
||||||
|
|
||||||
return depth_map
|
return depth_map
|
||||||
|
@ -11,17 +11,6 @@ def check_invokeai_root(config: InvokeAIAppConfig):
|
|||||||
try:
|
try:
|
||||||
assert config.db_path.parent.exists(), f"{config.db_path.parent} not found"
|
assert config.db_path.parent.exists(), f"{config.db_path.parent} not found"
|
||||||
assert config.models_path.exists(), f"{config.models_path} not found"
|
assert config.models_path.exists(), f"{config.models_path} not found"
|
||||||
if not config.ignore_missing_core_models:
|
|
||||||
for model in [
|
|
||||||
"CLIP-ViT-bigG-14-laion2B-39B-b160k",
|
|
||||||
"bert-base-uncased",
|
|
||||||
"clip-vit-large-patch14",
|
|
||||||
"sd-vae-ft-mse",
|
|
||||||
"stable-diffusion-2-clip",
|
|
||||||
"stable-diffusion-safety-checker",
|
|
||||||
]:
|
|
||||||
path = config.models_path / f"core/convert/{model}"
|
|
||||||
assert path.exists(), f"{path} is missing"
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print()
|
print()
|
||||||
print(f"An exception has occurred: {str(e)}")
|
print(f"An exception has occurred: {str(e)}")
|
||||||
@ -32,10 +21,5 @@ def check_invokeai_root(config: InvokeAIAppConfig):
|
|||||||
print(
|
print(
|
||||||
'** From the command line, activate the virtual environment and run "invokeai-configure --yes --skip-sd-weights" **'
|
'** From the command line, activate the virtual environment and run "invokeai-configure --yes --skip-sd-weights" **'
|
||||||
)
|
)
|
||||||
print(
|
|
||||||
'** (To skip this check completely, add "--ignore_missing_core_models" to your CLI args. Not installing '
|
|
||||||
"these core models will prevent the loading of some or all .safetensors and .ckpt files. However, you can "
|
|
||||||
"always come back and install these core models in the future.)"
|
|
||||||
)
|
|
||||||
input("Press any key to continue...")
|
input("Press any key to continue...")
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
@ -25,20 +25,20 @@ import npyscreen
|
|||||||
import psutil
|
import psutil
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
from diffusers import AutoencoderKL, ModelMixin
|
from diffusers import ModelMixin
|
||||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||||
from huggingface_hub import HfFolder
|
from huggingface_hub import HfFolder
|
||||||
from huggingface_hub import login as hf_hub_login
|
from huggingface_hub import login as hf_hub_login
|
||||||
from omegaconf import DictConfig, OmegaConf
|
from omegaconf import DictConfig, OmegaConf
|
||||||
from pydantic.error_wrappers import ValidationError
|
from pydantic.error_wrappers import ValidationError
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
from transformers import AutoFeatureExtractor
|
||||||
|
|
||||||
import invokeai.configs as configs
|
import invokeai.configs as configs
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.backend.install.install_helper import InstallHelper, InstallSelections
|
from invokeai.backend.install.install_helper import InstallHelper, InstallSelections
|
||||||
from invokeai.backend.install.legacy_arg_parsing import legacy_parser
|
from invokeai.backend.install.legacy_arg_parsing import legacy_parser
|
||||||
from invokeai.backend.model_manager import BaseModelType, ModelType
|
from invokeai.backend.model_manager import ModelType
|
||||||
from invokeai.backend.util import choose_precision, choose_torch_device
|
from invokeai.backend.util import choose_precision, choose_torch_device
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
from invokeai.frontend.install.model_install import addModelsForm
|
from invokeai.frontend.install.model_install import addModelsForm
|
||||||
@ -210,51 +210,15 @@ def download_with_progress_bar(model_url: str, model_dest: str, label: str = "th
|
|||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
def download_conversion_models():
|
def download_safety_checker():
|
||||||
target_dir = config.models_path / "core/convert"
|
target_dir = config.models_path / "core/convert"
|
||||||
kwargs = {} # for future use
|
kwargs = {} # for future use
|
||||||
try:
|
try:
|
||||||
logger.info("Downloading core tokenizers and text encoders")
|
|
||||||
|
|
||||||
# bert
|
|
||||||
with warnings.catch_warnings():
|
|
||||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
|
||||||
bert = BertTokenizerFast.from_pretrained("bert-base-uncased", **kwargs)
|
|
||||||
bert.save_pretrained(target_dir / "bert-base-uncased", safe_serialization=True)
|
|
||||||
|
|
||||||
# sd-1
|
|
||||||
repo_id = "openai/clip-vit-large-patch14"
|
|
||||||
hf_download_from_pretrained(CLIPTokenizer, repo_id, target_dir / "clip-vit-large-patch14")
|
|
||||||
hf_download_from_pretrained(CLIPTextModel, repo_id, target_dir / "clip-vit-large-patch14")
|
|
||||||
|
|
||||||
# sd-2
|
|
||||||
repo_id = "stabilityai/stable-diffusion-2"
|
|
||||||
pipeline = CLIPTokenizer.from_pretrained(repo_id, subfolder="tokenizer", **kwargs)
|
|
||||||
pipeline.save_pretrained(target_dir / "stable-diffusion-2-clip" / "tokenizer", safe_serialization=True)
|
|
||||||
|
|
||||||
pipeline = CLIPTextModel.from_pretrained(repo_id, subfolder="text_encoder", **kwargs)
|
|
||||||
pipeline.save_pretrained(target_dir / "stable-diffusion-2-clip" / "text_encoder", safe_serialization=True)
|
|
||||||
|
|
||||||
# sd-xl - tokenizer_2
|
|
||||||
repo_id = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
|
||||||
_, model_name = repo_id.split("/")
|
|
||||||
pipeline = CLIPTokenizer.from_pretrained(repo_id, **kwargs)
|
|
||||||
pipeline.save_pretrained(target_dir / model_name, safe_serialization=True)
|
|
||||||
|
|
||||||
pipeline = CLIPTextConfig.from_pretrained(repo_id, **kwargs)
|
|
||||||
pipeline.save_pretrained(target_dir / model_name, safe_serialization=True)
|
|
||||||
|
|
||||||
# VAE
|
|
||||||
logger.info("Downloading stable diffusion VAE")
|
|
||||||
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", **kwargs)
|
|
||||||
vae.save_pretrained(target_dir / "sd-vae-ft-mse", safe_serialization=True)
|
|
||||||
|
|
||||||
# safety checking
|
# safety checking
|
||||||
logger.info("Downloading safety checker")
|
logger.info("Downloading safety checker")
|
||||||
repo_id = "CompVis/stable-diffusion-safety-checker"
|
repo_id = "CompVis/stable-diffusion-safety-checker"
|
||||||
pipeline = AutoFeatureExtractor.from_pretrained(repo_id, **kwargs)
|
pipeline = AutoFeatureExtractor.from_pretrained(repo_id, **kwargs)
|
||||||
pipeline.save_pretrained(target_dir / "stable-diffusion-safety-checker", safe_serialization=True)
|
pipeline.save_pretrained(target_dir / "stable-diffusion-safety-checker", safe_serialization=True)
|
||||||
|
|
||||||
pipeline = StableDiffusionSafetyChecker.from_pretrained(repo_id, **kwargs)
|
pipeline = StableDiffusionSafetyChecker.from_pretrained(repo_id, **kwargs)
|
||||||
pipeline.save_pretrained(target_dir / "stable-diffusion-safety-checker", safe_serialization=True)
|
pipeline.save_pretrained(target_dir / "stable-diffusion-safety-checker", safe_serialization=True)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
@ -307,7 +271,7 @@ def download_lama():
|
|||||||
def download_support_models() -> None:
|
def download_support_models() -> None:
|
||||||
download_realesrgan()
|
download_realesrgan()
|
||||||
download_lama()
|
download_lama()
|
||||||
download_conversion_models()
|
download_safety_checker()
|
||||||
|
|
||||||
|
|
||||||
# -------------------------------------
|
# -------------------------------------
|
||||||
@ -744,12 +708,7 @@ def initialize_rootdir(root: Path, yes_to_all: bool = False):
|
|||||||
shutil.copytree(configs_src, configs_dest, dirs_exist_ok=True)
|
shutil.copytree(configs_src, configs_dest, dirs_exist_ok=True)
|
||||||
|
|
||||||
dest = root / "models"
|
dest = root / "models"
|
||||||
for model_base in BaseModelType:
|
dest.mkdir(parents=True, exist_ok=True)
|
||||||
for model_type in ModelType:
|
|
||||||
path = dest / model_base.value / model_type.value
|
|
||||||
path.mkdir(parents=True, exist_ok=True)
|
|
||||||
path = dest / "core"
|
|
||||||
path.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
|
|
||||||
# -------------------------------------
|
# -------------------------------------
|
||||||
|
@ -1,12 +1,4 @@
|
|||||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||||
"""
|
|
||||||
Fast hashing of diffusers and checkpoint-style models.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
from invokeai.backend.model_managre.model_hash import FastModelHash
|
|
||||||
>>> FastModelHash.hash('/home/models/stable-diffusion-v1.5')
|
|
||||||
'a8e693a126ea5b831c96064dc569956f'
|
|
||||||
"""
|
|
||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import os
|
import os
|
||||||
@ -15,9 +7,9 @@ from typing import Callable, Literal, Optional, Union
|
|||||||
|
|
||||||
from blake3 import blake3
|
from blake3 import blake3
|
||||||
|
|
||||||
MODEL_FILE_EXTENSIONS = (".ckpt", ".safetensors", ".bin", ".pt", ".pth")
|
from invokeai.app.util.misc import uuid_string
|
||||||
|
|
||||||
ALGORITHM = Literal[
|
HASHING_ALGORITHMS = Literal[
|
||||||
"md5",
|
"md5",
|
||||||
"sha1",
|
"sha1",
|
||||||
"sha224",
|
"sha224",
|
||||||
@ -33,12 +25,15 @@ ALGORITHM = Literal[
|
|||||||
"shake_128",
|
"shake_128",
|
||||||
"shake_256",
|
"shake_256",
|
||||||
"blake3",
|
"blake3",
|
||||||
|
"blake3_single",
|
||||||
|
"random",
|
||||||
]
|
]
|
||||||
|
MODEL_FILE_EXTENSIONS = (".ckpt", ".safetensors", ".bin", ".pt", ".pth")
|
||||||
|
|
||||||
|
|
||||||
class ModelHash:
|
class ModelHash:
|
||||||
"""
|
"""
|
||||||
Creates a hash of a model using a specified algorithm.
|
Creates a hash of a model using a specified algorithm. The hash is prefixed by the algorithm used.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
algorithm: Hashing algorithm to use. Defaults to BLAKE3.
|
algorithm: Hashing algorithm to use. Defaults to BLAKE3.
|
||||||
@ -53,20 +48,29 @@ class ModelHash:
|
|||||||
The final hash is computed by hashing the hashes of all model files in the directory using BLAKE3, ensuring
|
The final hash is computed by hashing the hashes of all model files in the directory using BLAKE3, ensuring
|
||||||
that directory hashes are never weaker than the file hashes.
|
that directory hashes are never weaker than the file hashes.
|
||||||
|
|
||||||
|
A convenience algorithm choice of "random" is also available, which returns a random string. This is not a hash.
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
```py
|
```py
|
||||||
# BLAKE3 hash
|
# BLAKE3 hash
|
||||||
ModelHash().hash("path/to/some/model.safetensors")
|
ModelHash().hash("path/to/some/model.safetensors") # "blake3:ce3f0c5f3c05d119f4a5dcaf209b50d3149046a0d3a9adee9fed4c83cad6b4d0"
|
||||||
# MD5
|
# MD5
|
||||||
ModelHash("md5").hash("path/to/model/dir/")
|
ModelHash("md5").hash("path/to/model/dir/") # "md5:a0cd925fc063f98dbf029eee315060c3"
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, algorithm: ALGORITHM = "blake3", file_filter: Optional[Callable[[str], bool]] = None) -> None:
|
def __init__(
|
||||||
|
self, algorithm: HASHING_ALGORITHMS = "blake3", file_filter: Optional[Callable[[str], bool]] = None
|
||||||
|
) -> None:
|
||||||
|
self.algorithm: HASHING_ALGORITHMS = algorithm
|
||||||
if algorithm == "blake3":
|
if algorithm == "blake3":
|
||||||
self._hash_file = self._blake3
|
self._hash_file = self._blake3
|
||||||
|
elif algorithm == "blake3_single":
|
||||||
|
self._hash_file = self._blake3_single
|
||||||
elif algorithm in hashlib.algorithms_available:
|
elif algorithm in hashlib.algorithms_available:
|
||||||
self._hash_file = self._get_hashlib(algorithm)
|
self._hash_file = self._get_hashlib(algorithm)
|
||||||
|
elif algorithm == "random":
|
||||||
|
self._hash_file = self._random
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Algorithm {algorithm} not available")
|
raise ValueError(f"Algorithm {algorithm} not available")
|
||||||
|
|
||||||
@ -87,10 +91,12 @@ class ModelHash:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
model_path = Path(model_path)
|
model_path = Path(model_path)
|
||||||
|
# blake3_single is a single-threaded version of blake3, prefix should still be "blake3:"
|
||||||
|
prefix = self._get_prefix(self.algorithm)
|
||||||
if model_path.is_file():
|
if model_path.is_file():
|
||||||
return self._hash_file(model_path)
|
return prefix + self._hash_file(model_path)
|
||||||
elif model_path.is_dir():
|
elif model_path.is_dir():
|
||||||
return self._hash_dir(model_path)
|
return prefix + self._hash_dir(model_path)
|
||||||
else:
|
else:
|
||||||
raise OSError(f"Not a valid file or directory: {model_path}")
|
raise OSError(f"Not a valid file or directory: {model_path}")
|
||||||
|
|
||||||
@ -114,6 +120,7 @@ class ModelHash:
|
|||||||
composite_hasher = blake3()
|
composite_hasher = blake3()
|
||||||
for h in component_hashes:
|
for h in component_hashes:
|
||||||
composite_hasher.update(h.encode("utf-8"))
|
composite_hasher.update(h.encode("utf-8"))
|
||||||
|
|
||||||
return composite_hasher.hexdigest()
|
return composite_hasher.hexdigest()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -137,7 +144,7 @@ class ModelHash:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _blake3(file_path: Path) -> str:
|
def _blake3(file_path: Path) -> str:
|
||||||
"""Hashes a file using BLAKE3
|
"""Hashes a file using BLAKE3, using parallelized and memory-mapped I/O to avoid reading the entire file into memory.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_path: Path to the file to hash
|
file_path: Path to the file to hash
|
||||||
@ -150,7 +157,21 @@ class ModelHash:
|
|||||||
return file_hasher.hexdigest()
|
return file_hasher.hexdigest()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_hashlib(algorithm: ALGORITHM) -> Callable[[Path], str]:
|
def _blake3_single(file_path: Path) -> str:
|
||||||
|
"""Hashes a file using BLAKE3, without parallelism. Suitable for spinning hard drives.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: Path to the file to hash
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Hexdigest of the hash of the file
|
||||||
|
"""
|
||||||
|
file_hasher = blake3()
|
||||||
|
file_hasher.update_mmap(file_path)
|
||||||
|
return file_hasher.hexdigest()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_hashlib(algorithm: HASHING_ALGORITHMS) -> Callable[[Path], str]:
|
||||||
"""Factory function that returns a function to hash a file with the given algorithm.
|
"""Factory function that returns a function to hash a file with the given algorithm.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -172,6 +193,13 @@ class ModelHash:
|
|||||||
|
|
||||||
return hashlib_hasher
|
return hashlib_hasher
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _random(_file_path: Path) -> str:
|
||||||
|
"""Returns a random string. This is not a hash.
|
||||||
|
|
||||||
|
The string is a UUID, hashed with BLAKE3 to ensure that it is unique."""
|
||||||
|
return blake3(uuid_string().encode()).hexdigest()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _default_file_filter(file_path: str) -> bool:
|
def _default_file_filter(file_path: str) -> bool:
|
||||||
"""A default file filter that only includes files with the following extensions: .ckpt, .safetensors, .bin, .pt, .pth
|
"""A default file filter that only includes files with the following extensions: .ckpt, .safetensors, .bin, .pt, .pth
|
||||||
@ -183,3 +211,9 @@ class ModelHash:
|
|||||||
True if the file matches the given extensions, otherwise False
|
True if the file matches the given extensions, otherwise False
|
||||||
"""
|
"""
|
||||||
return file_path.endswith(MODEL_FILE_EXTENSIONS)
|
return file_path.endswith(MODEL_FILE_EXTENSIONS)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_prefix(algorithm: HASHING_ALGORITHMS) -> str:
|
||||||
|
"""Return the prefix for the given algorithm, e.g. \"blake3:\" or \"md5:\"."""
|
||||||
|
# blake3_single is a single-threaded version of blake3, prefix should still be "blake3:"
|
||||||
|
return "blake3:" if algorithm == "blake3_single" else f"{algorithm}:"
|
@ -131,13 +131,20 @@ class ModelSourceType(str, Enum):
|
|||||||
HFRepoID = "hf_repo_id"
|
HFRepoID = "hf_repo_id"
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULTS_PRECISION = Literal["fp16", "fp32"]
|
||||||
|
|
||||||
|
|
||||||
class MainModelDefaultSettings(BaseModel):
|
class MainModelDefaultSettings(BaseModel):
|
||||||
vae: str | None
|
vae: str | None = Field(default=None, description="Default VAE for this model (model key)")
|
||||||
vae_precision: str | None
|
vae_precision: DEFAULTS_PRECISION | None = Field(default=None, description="Default VAE precision for this model")
|
||||||
scheduler: SCHEDULER_NAME_VALUES | None
|
scheduler: SCHEDULER_NAME_VALUES | None = Field(default=None, description="Default scheduler for this model")
|
||||||
steps: int | None
|
steps: int | None = Field(default=None, gt=0, description="Default number of steps for this model")
|
||||||
cfg_scale: float | None
|
cfg_scale: float | None = Field(default=None, ge=1, description="Default CFG Scale for this model")
|
||||||
cfg_rescale_multiplier: float | None
|
cfg_rescale_multiplier: float | None = Field(
|
||||||
|
default=None, ge=0, lt=1, description="Default CFG Rescale Multiplier for this model"
|
||||||
|
)
|
||||||
|
width: int | None = Field(default=None, multiple_of=8, ge=64, description="Default width for this model")
|
||||||
|
height: int | None = Field(default=None, multiple_of=8, ge=64, description="Default height for this model")
|
||||||
|
|
||||||
|
|
||||||
class ControlAdapterDefaultSettings(BaseModel):
|
class ControlAdapterDefaultSettings(BaseModel):
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -3,9 +3,6 @@
|
|||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
|
||||||
from safetensors.torch import load_file as safetensors_load_file
|
|
||||||
|
|
||||||
from invokeai.backend.model_manager import (
|
from invokeai.backend.model_manager import (
|
||||||
AnyModelConfig,
|
AnyModelConfig,
|
||||||
BaseModelType,
|
BaseModelType,
|
||||||
@ -37,27 +34,25 @@ class ControlNetLoader(GenericDiffusersLoader):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
|
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
|
||||||
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
|
|
||||||
raise Exception(f"ControlNet conversion not supported for model type: {config.base}")
|
|
||||||
else:
|
|
||||||
assert isinstance(config, CheckpointConfigBase)
|
assert isinstance(config, CheckpointConfigBase)
|
||||||
config_file = config.config_path
|
config_file = config.config_path
|
||||||
|
|
||||||
if model_path.suffix == ".safetensors":
|
image_size = (
|
||||||
checkpoint = safetensors_load_file(model_path, device="cpu")
|
512
|
||||||
else:
|
if config.base == BaseModelType.StableDiffusion1
|
||||||
checkpoint = torch.load(model_path, map_location="cpu")
|
else 768
|
||||||
|
if config.base == BaseModelType.StableDiffusion2
|
||||||
# sometimes weights are hidden under "state_dict", and sometimes not
|
else 1024
|
||||||
if "state_dict" in checkpoint:
|
)
|
||||||
checkpoint = checkpoint["state_dict"]
|
|
||||||
|
|
||||||
|
self._logger.info(f"Converting {model_path} to diffusers format")
|
||||||
|
with open(self._app_config.root_path / config_file, "r") as config_stream:
|
||||||
convert_controlnet_to_diffusers(
|
convert_controlnet_to_diffusers(
|
||||||
model_path,
|
model_path,
|
||||||
output_path,
|
output_path,
|
||||||
original_config_file=self._app_config.root_path / config_file,
|
original_config_file=config_stream,
|
||||||
image_size=512,
|
image_size=image_size,
|
||||||
scan_needed=True,
|
precision=self._torch_dtype,
|
||||||
from_safetensors=model_path.suffix == ".safetensors",
|
from_safetensors=model_path.suffix == ".safetensors",
|
||||||
)
|
)
|
||||||
return output_path
|
return output_path
|
||||||
|
@ -4,9 +4,6 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
|
|
||||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
|
|
||||||
|
|
||||||
from invokeai.backend.model_manager import (
|
from invokeai.backend.model_manager import (
|
||||||
AnyModel,
|
AnyModel,
|
||||||
AnyModelConfig,
|
AnyModelConfig,
|
||||||
@ -14,7 +11,7 @@ from invokeai.backend.model_manager import (
|
|||||||
ModelFormat,
|
ModelFormat,
|
||||||
ModelRepoVariant,
|
ModelRepoVariant,
|
||||||
ModelType,
|
ModelType,
|
||||||
ModelVariantType,
|
SchedulerPredictionType,
|
||||||
SubModelType,
|
SubModelType,
|
||||||
)
|
)
|
||||||
from invokeai.backend.model_manager.config import CheckpointConfigBase, MainCheckpointConfig
|
from invokeai.backend.model_manager.config import CheckpointConfigBase, MainCheckpointConfig
|
||||||
@ -68,27 +65,31 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
|
|||||||
|
|
||||||
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
|
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
|
||||||
assert isinstance(config, MainCheckpointConfig)
|
assert isinstance(config, MainCheckpointConfig)
|
||||||
variant = config.variant
|
|
||||||
base = config.base
|
base = config.base
|
||||||
pipeline_class = (
|
|
||||||
StableDiffusionInpaintPipeline if variant == ModelVariantType.Inpaint else StableDiffusionPipeline
|
|
||||||
)
|
|
||||||
|
|
||||||
config_file = config.config_path
|
config_file = config.config_path
|
||||||
|
prediction_type = config.prediction_type.value
|
||||||
|
upcast_attention = config.upcast_attention
|
||||||
|
image_size = (
|
||||||
|
1024
|
||||||
|
if base == BaseModelType.StableDiffusionXL
|
||||||
|
else 768
|
||||||
|
if config.prediction_type == SchedulerPredictionType.VPrediction and base == BaseModelType.StableDiffusion2
|
||||||
|
else 512
|
||||||
|
)
|
||||||
|
|
||||||
self._logger.info(f"Converting {model_path} to diffusers format")
|
self._logger.info(f"Converting {model_path} to diffusers format")
|
||||||
convert_ckpt_to_diffusers(
|
convert_ckpt_to_diffusers(
|
||||||
model_path,
|
model_path,
|
||||||
output_path,
|
output_path,
|
||||||
model_type=self.model_base_to_model_type[base],
|
model_type=self.model_base_to_model_type[base],
|
||||||
model_version=base,
|
|
||||||
model_variant=variant,
|
|
||||||
original_config_file=self._app_config.root_path / config_file,
|
original_config_file=self._app_config.root_path / config_file,
|
||||||
extract_ema=True,
|
extract_ema=True,
|
||||||
scan_needed=True,
|
|
||||||
pipeline_class=pipeline_class,
|
|
||||||
from_safetensors=model_path.suffix == ".safetensors",
|
from_safetensors=model_path.suffix == ".safetensors",
|
||||||
precision=self._torch_dtype,
|
precision=self._torch_dtype,
|
||||||
|
prediction_type=prediction_type,
|
||||||
|
image_size=image_size,
|
||||||
|
upcast_attention=upcast_attention,
|
||||||
load_safety_checker=False,
|
load_safety_checker=False,
|
||||||
)
|
)
|
||||||
return output_path
|
return output_path
|
||||||
|
@ -57,12 +57,12 @@ class VAELoader(GenericDiffusersLoader):
|
|||||||
|
|
||||||
ckpt_config = OmegaConf.load(self._app_config.root_path / config_file)
|
ckpt_config = OmegaConf.load(self._app_config.root_path / config_file)
|
||||||
assert isinstance(ckpt_config, DictConfig)
|
assert isinstance(ckpt_config, DictConfig)
|
||||||
|
self._logger.info(f"Converting {model_path} to diffusers format")
|
||||||
vae_model = convert_ldm_vae_to_diffusers(
|
vae_model = convert_ldm_vae_to_diffusers(
|
||||||
checkpoint=checkpoint,
|
checkpoint=checkpoint,
|
||||||
vae_config=ckpt_config,
|
vae_config=ckpt_config,
|
||||||
image_size=512,
|
image_size=512,
|
||||||
|
precision=self._torch_dtype,
|
||||||
)
|
)
|
||||||
vae_model.to(self._torch_dtype) # set precision appropriately
|
|
||||||
vae_model.save_pretrained(output_path, safe_serialization=True)
|
vae_model.save_pretrained(output_path, safe_serialization=True)
|
||||||
return output_path
|
return output_path
|
||||||
|
@ -90,8 +90,35 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# diffusers models have a `model_index.json` or `config.json` file
|
||||||
|
is_diffusers = any(str(f.url).endswith(("model_index.json", "config.json")) for f in files)
|
||||||
|
|
||||||
|
# These URLs will be exposed to the user - I think these are the only file types we fully support
|
||||||
|
ckpt_urls = (
|
||||||
|
None
|
||||||
|
if is_diffusers
|
||||||
|
else [
|
||||||
|
f.url
|
||||||
|
for f in files
|
||||||
|
if str(f.url).endswith(
|
||||||
|
(
|
||||||
|
".safetensors",
|
||||||
|
".bin",
|
||||||
|
".pth",
|
||||||
|
".pt",
|
||||||
|
".ckpt",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
return HuggingFaceMetadata(
|
return HuggingFaceMetadata(
|
||||||
id=model_info.id, name=name, files=files, api_response=json.dumps(model_info.__dict__, default=str)
|
id=model_info.id,
|
||||||
|
name=name,
|
||||||
|
files=files,
|
||||||
|
api_response=json.dumps(model_info.__dict__, default=str),
|
||||||
|
is_diffusers=is_diffusers,
|
||||||
|
ckpt_urls=ckpt_urls,
|
||||||
)
|
)
|
||||||
|
|
||||||
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
|
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
|
||||||
|
@ -84,6 +84,10 @@ class HuggingFaceMetadata(ModelMetadataWithFiles):
|
|||||||
type: Literal["huggingface"] = "huggingface"
|
type: Literal["huggingface"] = "huggingface"
|
||||||
id: str = Field(description="The HF model id")
|
id: str = Field(description="The HF model id")
|
||||||
api_response: Optional[str] = Field(description="Response from the HF API as stringified JSON", default=None)
|
api_response: Optional[str] = Field(description="Response from the HF API as stringified JSON", default=None)
|
||||||
|
is_diffusers: bool = Field(description="Whether the metadata is for a Diffusers format model", default=False)
|
||||||
|
ckpt_urls: Optional[List[AnyHttpUrl]] = Field(
|
||||||
|
description="URLs for all checkpoint format models in the metadata", default=None
|
||||||
|
)
|
||||||
|
|
||||||
def download_urls(
|
def download_urls(
|
||||||
self,
|
self,
|
||||||
|
@ -9,6 +9,7 @@ from picklescan.scanner import scan_file_path
|
|||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.app.util.misc import uuid_string
|
from invokeai.app.util.misc import uuid_string
|
||||||
|
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
|
||||||
from invokeai.backend.util.util import SilenceWarnings
|
from invokeai.backend.util.util import SilenceWarnings
|
||||||
|
|
||||||
from .config import (
|
from .config import (
|
||||||
@ -16,6 +17,7 @@ from .config import (
|
|||||||
BaseModelType,
|
BaseModelType,
|
||||||
ControlAdapterDefaultSettings,
|
ControlAdapterDefaultSettings,
|
||||||
InvalidModelConfigException,
|
InvalidModelConfigException,
|
||||||
|
MainModelDefaultSettings,
|
||||||
ModelConfigFactory,
|
ModelConfigFactory,
|
||||||
ModelFormat,
|
ModelFormat,
|
||||||
ModelRepoVariant,
|
ModelRepoVariant,
|
||||||
@ -24,7 +26,6 @@ from .config import (
|
|||||||
ModelVariantType,
|
ModelVariantType,
|
||||||
SchedulerPredictionType,
|
SchedulerPredictionType,
|
||||||
)
|
)
|
||||||
from .hash import ModelHash
|
|
||||||
from .util.model_util import lora_token_vector_length, read_checkpoint_meta
|
from .util.model_util import lora_token_vector_length, read_checkpoint_meta
|
||||||
|
|
||||||
CkptType = Dict[str, Any]
|
CkptType = Dict[str, Any]
|
||||||
@ -113,9 +114,7 @@ class ModelProbe(object):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def probe(
|
def probe(
|
||||||
cls,
|
cls, model_path: Path, fields: Optional[Dict[str, Any]] = None, hash_algo: HASHING_ALGORITHMS = "blake3"
|
||||||
model_path: Path,
|
|
||||||
fields: Optional[Dict[str, Any]] = None,
|
|
||||||
) -> AnyModelConfig:
|
) -> AnyModelConfig:
|
||||||
"""
|
"""
|
||||||
Probe the model at model_path and return its configuration record.
|
Probe the model at model_path and return its configuration record.
|
||||||
@ -133,7 +132,8 @@ class ModelProbe(object):
|
|||||||
|
|
||||||
format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint
|
format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint
|
||||||
model_info = None
|
model_info = None
|
||||||
model_type = None
|
model_type = ModelType(fields["type"]) if "type" in fields and fields["type"] else None
|
||||||
|
if not model_type:
|
||||||
if format_type is ModelFormat.Diffusers:
|
if format_type is ModelFormat.Diffusers:
|
||||||
model_type = cls.get_model_type_from_folder(model_path)
|
model_type = cls.get_model_type_from_folder(model_path)
|
||||||
else:
|
else:
|
||||||
@ -157,16 +157,18 @@ class ModelProbe(object):
|
|||||||
fields["image_encoder_model_id"] = fields.get("image_encoder_model_id") or probe.get_image_encoder_model_id()
|
fields["image_encoder_model_id"] = fields.get("image_encoder_model_id") or probe.get_image_encoder_model_id()
|
||||||
fields["name"] = fields.get("name") or cls.get_model_name(model_path)
|
fields["name"] = fields.get("name") or cls.get_model_name(model_path)
|
||||||
fields["description"] = (
|
fields["description"] = (
|
||||||
fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
|
fields.get("description") or f"{fields['base'].value} {model_type.value} model {fields['name']}"
|
||||||
)
|
)
|
||||||
fields["format"] = fields.get("format") or probe.get_format()
|
fields["format"] = fields.get("format") or probe.get_format()
|
||||||
fields["hash"] = fields.get("hash") or ModelHash().hash(model_path)
|
fields["hash"] = fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path)
|
||||||
|
|
||||||
fields["default_settings"] = (
|
fields["default_settings"] = fields.get("default_settings")
|
||||||
fields.get("default_settings") or probe.get_default_settings(fields["name"])
|
|
||||||
if isinstance(probe, ControlAdapterProbe)
|
if not fields["default_settings"]:
|
||||||
else None
|
if fields["type"] in {ModelType.ControlNet, ModelType.T2IAdapter}:
|
||||||
)
|
fields["default_settings"] = get_default_settings_controlnet_t2i_adapter(fields["name"])
|
||||||
|
elif fields["type"] is ModelType.Main:
|
||||||
|
fields["default_settings"] = get_default_settings_main(fields["base"])
|
||||||
|
|
||||||
if format_type == ModelFormat.Diffusers and isinstance(probe, FolderProbeBase):
|
if format_type == ModelFormat.Diffusers and isinstance(probe, FolderProbeBase):
|
||||||
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()
|
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()
|
||||||
@ -318,7 +320,7 @@ class ModelProbe(object):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def _scan_and_load_checkpoint(cls, model_path: Path) -> CkptType:
|
def _scan_and_load_checkpoint(cls, model_path: Path) -> CkptType:
|
||||||
with SilenceWarnings():
|
with SilenceWarnings():
|
||||||
if model_path.suffix.endswith((".ckpt", ".pt", ".bin")):
|
if model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
|
||||||
cls._scan_model(model_path.name, model_path)
|
cls._scan_model(model_path.name, model_path)
|
||||||
model = torch.load(model_path)
|
model = torch.load(model_path)
|
||||||
assert isinstance(model, dict)
|
assert isinstance(model, dict)
|
||||||
@ -338,12 +340,8 @@ class ModelProbe(object):
|
|||||||
raise Exception("The model {model_name} is potentially infected by malware. Aborting import.")
|
raise Exception("The model {model_name} is potentially infected by malware. Aborting import.")
|
||||||
|
|
||||||
|
|
||||||
class ControlAdapterProbe(ProbeBase):
|
# Probing utilities
|
||||||
"""Adds `get_default_settings` for ControlNet and T2IAdapter probes"""
|
MODEL_NAME_TO_PREPROCESSOR = {
|
||||||
|
|
||||||
# TODO(psyche): It would be nice to get these from the invocations, but that creates circular dependencies.
|
|
||||||
# "canny": CannyImageProcessorInvocation.get_type()
|
|
||||||
MODEL_NAME_TO_PREPROCESSOR = {
|
|
||||||
"canny": "canny_image_processor",
|
"canny": "canny_image_processor",
|
||||||
"mlsd": "mlsd_image_processor",
|
"mlsd": "mlsd_image_processor",
|
||||||
"depth": "depth_anything_image_processor",
|
"depth": "depth_anything_image_processor",
|
||||||
@ -360,16 +358,25 @@ class ControlAdapterProbe(ProbeBase):
|
|||||||
"pidi": "pidi_image_processor",
|
"pidi": "pidi_image_processor",
|
||||||
"zoe": "zoe_depth_image_processor",
|
"zoe": "zoe_depth_image_processor",
|
||||||
"color": "color_map_image_processor",
|
"color": "color_map_image_processor",
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_default_settings(cls, model_name: str) -> Optional[ControlAdapterDefaultSettings]:
|
def get_default_settings_controlnet_t2i_adapter(model_name: str) -> Optional[ControlAdapterDefaultSettings]:
|
||||||
for k, v in cls.MODEL_NAME_TO_PREPROCESSOR.items():
|
for k, v in MODEL_NAME_TO_PREPROCESSOR.items():
|
||||||
if k in model_name:
|
if k in model_name:
|
||||||
return ControlAdapterDefaultSettings(preprocessor=v)
|
return ControlAdapterDefaultSettings(preprocessor=v)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_default_settings_main(model_base: BaseModelType) -> Optional[MainModelDefaultSettings]:
|
||||||
|
if model_base is BaseModelType.StableDiffusion1 or model_base is BaseModelType.StableDiffusion2:
|
||||||
|
return MainModelDefaultSettings(width=512, height=512)
|
||||||
|
elif model_base is BaseModelType.StableDiffusionXL:
|
||||||
|
return MainModelDefaultSettings(width=1024, height=1024)
|
||||||
|
# We don't provide defaults for BaseModelType.StableDiffusionXLRefiner, as they are not standalone models.
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
# ##################################################3
|
# ##################################################3
|
||||||
# Checkpoint probing
|
# Checkpoint probing
|
||||||
# ##################################################3
|
# ##################################################3
|
||||||
@ -493,7 +500,7 @@ class TextualInversionCheckpointProbe(CheckpointProbeBase):
|
|||||||
raise InvalidModelConfigException(f"{self.model_path}: Could not determine base type")
|
raise InvalidModelConfigException(f"{self.model_path}: Could not determine base type")
|
||||||
|
|
||||||
|
|
||||||
class ControlNetCheckpointProbe(CheckpointProbeBase, ControlAdapterProbe):
|
class ControlNetCheckpointProbe(CheckpointProbeBase):
|
||||||
"""Class for probing controlnets."""
|
"""Class for probing controlnets."""
|
||||||
|
|
||||||
def get_base_type(self) -> BaseModelType:
|
def get_base_type(self) -> BaseModelType:
|
||||||
@ -521,7 +528,7 @@ class CLIPVisionCheckpointProbe(CheckpointProbeBase):
|
|||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
class T2IAdapterCheckpointProbe(CheckpointProbeBase, ControlAdapterProbe):
|
class T2IAdapterCheckpointProbe(CheckpointProbeBase):
|
||||||
def get_base_type(self) -> BaseModelType:
|
def get_base_type(self) -> BaseModelType:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@ -659,7 +666,7 @@ class ONNXFolderProbe(PipelineFolderProbe):
|
|||||||
return ModelVariantType.Normal
|
return ModelVariantType.Normal
|
||||||
|
|
||||||
|
|
||||||
class ControlNetFolderProbe(FolderProbeBase, ControlAdapterProbe):
|
class ControlNetFolderProbe(FolderProbeBase):
|
||||||
def get_base_type(self) -> BaseModelType:
|
def get_base_type(self) -> BaseModelType:
|
||||||
config_file = self.model_path / "config.json"
|
config_file = self.model_path / "config.json"
|
||||||
if not config_file.exists():
|
if not config_file.exists():
|
||||||
@ -733,7 +740,7 @@ class CLIPVisionFolderProbe(FolderProbeBase):
|
|||||||
return BaseModelType.Any
|
return BaseModelType.Any
|
||||||
|
|
||||||
|
|
||||||
class T2IAdapterFolderProbe(FolderProbeBase, ControlAdapterProbe):
|
class T2IAdapterFolderProbe(FolderProbeBase):
|
||||||
def get_base_type(self) -> BaseModelType:
|
def get_base_type(self) -> BaseModelType:
|
||||||
config_file = self.model_path / "config.json"
|
config_file = self.model_path / "config.json"
|
||||||
if not config_file.exists():
|
if not config_file.exists():
|
||||||
|
@ -5,6 +5,7 @@ from typing import Callable, List, Union
|
|||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
|
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
|
||||||
|
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
|
||||||
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||||
|
|
||||||
|
|
||||||
@ -26,7 +27,7 @@ def _conv_forward_asymmetric(self, input, weight, bias):
|
|||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL], seamless_axes: List[str]):
|
def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL, AutoencoderTiny], seamless_axes: List[str]):
|
||||||
# Callable: (input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor
|
# Callable: (input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor
|
||||||
to_restore: list[tuple[nn.Conv2d | nn.ConvTranspose2d, Callable]] = []
|
to_restore: list[tuple[nn.Conv2d | nn.ConvTranspose2d, Callable]] = []
|
||||||
try:
|
try:
|
||||||
|
@ -62,40 +62,72 @@ sd-1/main/trinart_stable_diffusion_v2:
|
|||||||
recommended: False
|
recommended: False
|
||||||
sd-1/controlnet/qrcode_monster:
|
sd-1/controlnet/qrcode_monster:
|
||||||
source: monster-labs/control_v1p_sd15_qrcode_monster
|
source: monster-labs/control_v1p_sd15_qrcode_monster
|
||||||
|
description: Controlnet model that generates scannable creative QR codes
|
||||||
subfolder: v2
|
subfolder: v2
|
||||||
sd-1/controlnet/canny:
|
sd-1/controlnet/canny:
|
||||||
|
description: Controlnet weights trained on sd-1.5 with canny conditioning.
|
||||||
source: lllyasviel/control_v11p_sd15_canny
|
source: lllyasviel/control_v11p_sd15_canny
|
||||||
recommended: True
|
recommended: True
|
||||||
sd-1/controlnet/inpaint:
|
sd-1/controlnet/inpaint:
|
||||||
source: lllyasviel/control_v11p_sd15_inpaint
|
source: lllyasviel/control_v11p_sd15_inpaint
|
||||||
|
description: Controlnet weights trained on sd-1.5 with canny conditioning, inpaint version
|
||||||
sd-1/controlnet/mlsd:
|
sd-1/controlnet/mlsd:
|
||||||
|
description: Controlnet weights trained on sd-1.5 with canny conditioning, MLSD version
|
||||||
source: lllyasviel/control_v11p_sd15_mlsd
|
source: lllyasviel/control_v11p_sd15_mlsd
|
||||||
sd-1/controlnet/depth:
|
sd-1/controlnet/depth:
|
||||||
|
description: Controlnet weights trained on sd-1.5 with depth conditioning
|
||||||
source: lllyasviel/control_v11f1p_sd15_depth
|
source: lllyasviel/control_v11f1p_sd15_depth
|
||||||
recommended: True
|
recommended: True
|
||||||
sd-1/controlnet/normal_bae:
|
sd-1/controlnet/normal_bae:
|
||||||
|
description: Controlnet weights trained on sd-1.5 with normalbae image conditioning
|
||||||
source: lllyasviel/control_v11p_sd15_normalbae
|
source: lllyasviel/control_v11p_sd15_normalbae
|
||||||
sd-1/controlnet/seg:
|
sd-1/controlnet/seg:
|
||||||
|
description: Controlnet weights trained on sd-1.5 with seg image conditioning
|
||||||
source: lllyasviel/control_v11p_sd15_seg
|
source: lllyasviel/control_v11p_sd15_seg
|
||||||
sd-1/controlnet/lineart:
|
sd-1/controlnet/lineart:
|
||||||
|
description: Controlnet weights trained on sd-1.5 with lineart image conditioning
|
||||||
source: lllyasviel/control_v11p_sd15_lineart
|
source: lllyasviel/control_v11p_sd15_lineart
|
||||||
recommended: True
|
recommended: True
|
||||||
sd-1/controlnet/lineart_anime:
|
sd-1/controlnet/lineart_anime:
|
||||||
|
description: Controlnet weights trained on sd-1.5 with anime image conditioning
|
||||||
source: lllyasviel/control_v11p_sd15s2_lineart_anime
|
source: lllyasviel/control_v11p_sd15s2_lineart_anime
|
||||||
sd-1/controlnet/openpose:
|
sd-1/controlnet/openpose:
|
||||||
|
description: Controlnet weights trained on sd-1.5 with openpose image conditioning
|
||||||
source: lllyasviel/control_v11p_sd15_openpose
|
source: lllyasviel/control_v11p_sd15_openpose
|
||||||
recommended: True
|
recommended: True
|
||||||
sd-1/controlnet/scribble:
|
sd-1/controlnet/scribble:
|
||||||
source: lllyasviel/control_v11p_sd15_scribble
|
source: lllyasviel/control_v11p_sd15_scribble
|
||||||
|
description: Controlnet weights trained on sd-1.5 with scribble image conditioning
|
||||||
recommended: False
|
recommended: False
|
||||||
sd-1/controlnet/softedge:
|
sd-1/controlnet/softedge:
|
||||||
source: lllyasviel/control_v11p_sd15_softedge
|
source: lllyasviel/control_v11p_sd15_softedge
|
||||||
|
description: Controlnet weights trained on sd-1.5 with soft edge conditioning
|
||||||
sd-1/controlnet/shuffle:
|
sd-1/controlnet/shuffle:
|
||||||
source: lllyasviel/control_v11e_sd15_shuffle
|
source: lllyasviel/control_v11e_sd15_shuffle
|
||||||
|
description: Controlnet weights trained on sd-1.5 with shuffle image conditioning
|
||||||
sd-1/controlnet/tile:
|
sd-1/controlnet/tile:
|
||||||
source: lllyasviel/control_v11f1e_sd15_tile
|
source: lllyasviel/control_v11f1e_sd15_tile
|
||||||
|
description: Controlnet weights trained on sd-1.5 with tiled image conditioning
|
||||||
sd-1/controlnet/ip2p:
|
sd-1/controlnet/ip2p:
|
||||||
source: lllyasviel/control_v11e_sd15_ip2p
|
source: lllyasviel/control_v11e_sd15_ip2p
|
||||||
|
description: Controlnet weights trained on sd-1.5 with ip2p conditioning.
|
||||||
|
sdxl/controlnet/canny-sdxl:
|
||||||
|
description: Controlnet weights trained on sdxl-1.0 with canny conditioning.
|
||||||
|
source: diffusers/controlnet-canny-sdxl-1.0
|
||||||
|
recommended: True
|
||||||
|
sdxl/controlnet/depth-sdxl:
|
||||||
|
description: Controlnet weights trained on sdxl-1.0 with depth conditioning.
|
||||||
|
source: diffusers/controlnet-depth-sdxl-1.0
|
||||||
|
recommended: True
|
||||||
|
sdxl/controlnet/softedge-dexined-sdxl:
|
||||||
|
description: Controlnet weights trained on sdxl-1.0 with dexined soft edge preprocessing.
|
||||||
|
source: SargeZT/controlnet-sd-xl-1.0-softedge-dexined
|
||||||
|
sdxl/controlnet/depth-16bit-zoe-sdxl:
|
||||||
|
description: Controlnet weights trained on sdxl-1.0 with Zoe's preprocessor (16 bits).
|
||||||
|
source: SargeZT/controlnet-sd-xl-1.0-depth-16bit-zoe
|
||||||
|
sdxl/controlnet/depth-zoe-sdxl:
|
||||||
|
description: Controlnet weights trained on sdxl-1.0 with Zoe's preprocessor (32 bits).
|
||||||
|
source: diffusers/controlnet-zoe-depth-sdxl-1.0
|
||||||
sd-1/t2i_adapter/canny-sd15:
|
sd-1/t2i_adapter/canny-sd15:
|
||||||
source: TencentARC/t2iadapter_canny_sd15v2
|
source: TencentARC/t2iadapter_canny_sd15v2
|
||||||
sd-1/t2i_adapter/sketch-sd15:
|
sd-1/t2i_adapter/sketch-sd15:
|
||||||
|
@ -608,8 +608,9 @@ def main() -> None:
|
|||||||
config.parse_args(invoke_args)
|
config.parse_args(invoke_args)
|
||||||
logger = InvokeAILogger().get_logger(config=config)
|
logger = InvokeAILogger().get_logger(config=config)
|
||||||
|
|
||||||
if not config.model_conf_path.exists():
|
if not config.models_path.exists():
|
||||||
logger.info("Your InvokeAI root directory is not set up. Calling invokeai-configure.")
|
logger.info("Your InvokeAI root directory is not set up. Calling invokeai-configure.")
|
||||||
|
sys.argv = ["invokeai_configure", "--yes", "--skip-sd-weights"]
|
||||||
from invokeai.frontend.install.invokeai_configure import invokeai_configure
|
from invokeai.frontend.install.invokeai_configure import invokeai_configure
|
||||||
|
|
||||||
invokeai_configure()
|
invokeai_configure()
|
||||||
|
@ -1,150 +1,3 @@
|
|||||||
# Invoke UI
|
# Invoke UI
|
||||||
|
|
||||||
<!-- @import "[TOC]" {cmd="toc" depthFrom=2 depthTo=3 orderedList=false} -->
|
<https://invoke-ai.github.io/InvokeAI/contributing/frontend/OVERVIEW/>
|
||||||
|
|
||||||
<!-- code_chunk_output -->
|
|
||||||
|
|
||||||
- [Dev environment](#dev-environment)
|
|
||||||
- [Setup](#setup)
|
|
||||||
- [Package scripts](#package-scripts)
|
|
||||||
- [Type generation](#type-generation)
|
|
||||||
- [Localization](#localization)
|
|
||||||
- [VSCode](#vscode)
|
|
||||||
- [Contributing](#contributing)
|
|
||||||
- [Check in before investing your time](#check-in-before-investing-your-time)
|
|
||||||
- [Commit format](#commit-format)
|
|
||||||
- [Submitting a PR](#submitting-a-pr)
|
|
||||||
- [Other docs](#other-docs)
|
|
||||||
|
|
||||||
<!-- /code_chunk_output -->
|
|
||||||
|
|
||||||
Invoke's UI is made possible by many contributors and open-source libraries. Thank you!
|
|
||||||
|
|
||||||
## Dev environment
|
|
||||||
|
|
||||||
### Setup
|
|
||||||
|
|
||||||
1. Install [node] and [pnpm].
|
|
||||||
1. Run `pnpm i` to install all packages.
|
|
||||||
|
|
||||||
#### Run in dev mode
|
|
||||||
|
|
||||||
1. From `invokeai/frontend/web/`, run `pnpm dev`.
|
|
||||||
1. From repo root, run `python scripts/invokeai-web.py`.
|
|
||||||
1. Point your browser to the dev server address, e.g. <http://localhost:5173/>
|
|
||||||
|
|
||||||
### Package scripts
|
|
||||||
|
|
||||||
- `dev`: run the frontend in dev mode, enabling hot reloading
|
|
||||||
- `build`: run all checks (madge, eslint, prettier, tsc) and then build the frontend
|
|
||||||
- `typegen`: generate types from the OpenAPI schema (see [Type generation])
|
|
||||||
- `lint:madge`: check frontend for circular dependencies
|
|
||||||
- `lint:eslint`: check frontend for code quality
|
|
||||||
- `lint:prettier`: check frontend for code formatting
|
|
||||||
- `lint:tsc`: check frontend for type issues
|
|
||||||
- `lint`: run all checks concurrently
|
|
||||||
- `fix`: run `eslint` and `prettier`, fixing fixable issues
|
|
||||||
|
|
||||||
### Type generation
|
|
||||||
|
|
||||||
We use [openapi-typescript] to generate types from the app's OpenAPI schema.
|
|
||||||
|
|
||||||
The generated types are committed to the repo in [schema.ts].
|
|
||||||
|
|
||||||
```sh
|
|
||||||
# from the repo root, start the server
|
|
||||||
python scripts/invokeai-web.py
|
|
||||||
# from invokeai/frontend/web/, run the script
|
|
||||||
pnpm typegen
|
|
||||||
```
|
|
||||||
|
|
||||||
### Localization
|
|
||||||
|
|
||||||
We use [i18next] for localization, but translation to languages other than English happens on our [Weblate] project.
|
|
||||||
|
|
||||||
Only the English source strings should be changed on this repo.
|
|
||||||
|
|
||||||
### VSCode
|
|
||||||
|
|
||||||
#### Example debugger config
|
|
||||||
|
|
||||||
```jsonc
|
|
||||||
{
|
|
||||||
"version": "0.2.0",
|
|
||||||
"configurations": [
|
|
||||||
{
|
|
||||||
"type": "chrome",
|
|
||||||
"request": "launch",
|
|
||||||
"name": "Invoke UI",
|
|
||||||
"url": "http://localhost:5173",
|
|
||||||
"webRoot": "${workspaceFolder}/invokeai/frontend/web",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Remote dev
|
|
||||||
|
|
||||||
We've noticed an intermittent timeout issue with the VSCode remote dev port forwarding.
|
|
||||||
|
|
||||||
We suggest disabling the editor's port forwarding feature and doing it manually via SSH:
|
|
||||||
|
|
||||||
```sh
|
|
||||||
ssh -L 9090:localhost:9090 -L 5173:localhost:5173 user@host
|
|
||||||
```
|
|
||||||
|
|
||||||
## Contributing Guidelines
|
|
||||||
|
|
||||||
Thanks for your interest in contributing to the Invoke Web UI!
|
|
||||||
|
|
||||||
Please follow these guidelines when contributing.
|
|
||||||
|
|
||||||
### Check in before investing your time
|
|
||||||
|
|
||||||
Please check in before you invest your time on anything besides a trivial fix, in case it conflicts with ongoing work or isn't aligned with the vision for the app.
|
|
||||||
|
|
||||||
If a feature request or issue doesn't already exist for the thing you want to work on, please create one.
|
|
||||||
|
|
||||||
Ping `@psychedelicious` on [discord] in the `#frontend-dev` channel or in the feature request / issue you want to work on - we're happy chat.
|
|
||||||
|
|
||||||
### Code conventions
|
|
||||||
|
|
||||||
- This is a fairly complex app with a deep component tree. Please use memoization (`useCallback`, `useMemo`, `memo`) with enthusiasm.
|
|
||||||
- If you need to add some global, ephemeral state, please use [nanostores] if possible.
|
|
||||||
- Be careful with your redux selectors. If they need to be parameterized, consider creating them inside a `useMemo`.
|
|
||||||
- Feel free to use `lodash` (via `lodash-es`) to make the intent of your code clear.
|
|
||||||
- Please add comments describing the "why", not the "how" (unless it is really arcane).
|
|
||||||
|
|
||||||
### Commit format
|
|
||||||
|
|
||||||
Please use the [conventional commits] spec for the web UI, with a scope of "ui":
|
|
||||||
|
|
||||||
- `chore(ui): bump deps`
|
|
||||||
- `chore(ui): lint`
|
|
||||||
- `feat(ui): add some cool new feature`
|
|
||||||
- `fix(ui): fix some bug`
|
|
||||||
|
|
||||||
### Submitting a PR
|
|
||||||
|
|
||||||
- Ensure your branch is tidy. Use an interactive rebase to clean up the commit history and reword the commit messages if they are not descriptive.
|
|
||||||
- Run `pnpm lint`. Some issues are auto-fixable with `pnpm fix`.
|
|
||||||
- Fill out the PR form when creating the PR.
|
|
||||||
- It doesn't need to be super detailed, but a screenshot or video is nice if you changed something visually.
|
|
||||||
- If a section isn't relevant, delete it. There are no UI tests at this time.
|
|
||||||
|
|
||||||
## Other docs
|
|
||||||
|
|
||||||
- [Workflows - Design and Implementation]
|
|
||||||
- [State Management]
|
|
||||||
|
|
||||||
[node]: https://nodejs.org/en/download/
|
|
||||||
[pnpm]: https://github.com/pnpm/pnpm
|
|
||||||
[discord]: https://discord.gg/ZmtBAhwWhy
|
|
||||||
[i18next]: https://github.com/i18next/react-i18next
|
|
||||||
[Weblate]: https://hosted.weblate.org/engage/invokeai/
|
|
||||||
[openapi-typescript]: https://github.com/drwpow/openapi-typescript
|
|
||||||
[Type generation]: #type-generation
|
|
||||||
[schema.ts]: ../src/services/api/schema.ts
|
|
||||||
[conventional commits]: https://www.conventionalcommits.org/en/v1.0.0/
|
|
||||||
[Workflows - Design and Implementation]: ./docs/WORKFLOWS_DESIGN_IMPLEMENTATION.md
|
|
||||||
[State Management]: ./docs/STATE_MGMT.md
|
|
||||||
|
File diff suppressed because it is too large
Load Diff
88
invokeai/frontend/web/scripts/clean_translations.py
Normal file
88
invokeai/frontend/web/scripts/clean_translations.py
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
# Cleans translations by removing unused keys
|
||||||
|
# Usage: python clean_translations.py
|
||||||
|
# Note: Must be run from invokeai/frontend/web/scripts directory
|
||||||
|
#
|
||||||
|
# After running the script, open `en.json` and check for empty objects (`{}`) and remove them manually.
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from typing import TypeAlias, Union
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
RecursiveDict: TypeAlias = dict[str, Union["RecursiveDict", str]]
|
||||||
|
|
||||||
|
|
||||||
|
class TranslationCleaner:
|
||||||
|
file_cache: dict[str, str] = {}
|
||||||
|
|
||||||
|
def _get_keys(self, obj: RecursiveDict, current_path: str = "", keys: list[str] | None = None):
|
||||||
|
if keys is None:
|
||||||
|
keys = []
|
||||||
|
for key in obj:
|
||||||
|
new_path = f"{current_path}.{key}" if current_path else key
|
||||||
|
next_ = obj[key]
|
||||||
|
if isinstance(next_, dict):
|
||||||
|
self._get_keys(next_, new_path, keys)
|
||||||
|
elif "_" in key:
|
||||||
|
# This typically means its a pluralized key
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
keys.append(new_path)
|
||||||
|
return keys
|
||||||
|
|
||||||
|
def _search_codebase(self, key: str):
|
||||||
|
for root, _dirs, files in os.walk("../src"):
|
||||||
|
for file in files:
|
||||||
|
if file.endswith(".ts") or file.endswith(".tsx"):
|
||||||
|
full_path = os.path.join(root, file)
|
||||||
|
if full_path in self.file_cache:
|
||||||
|
content = self.file_cache[full_path]
|
||||||
|
else:
|
||||||
|
with open(full_path, "r") as f:
|
||||||
|
content = f.read()
|
||||||
|
self.file_cache[full_path] = content
|
||||||
|
|
||||||
|
# match the whole key, surrounding by quotes
|
||||||
|
if re.search(r"['\"`]" + re.escape(key) + r"['\"`]", self.file_cache[full_path]):
|
||||||
|
return True
|
||||||
|
# math the stem of the key, with quotes at the end
|
||||||
|
if re.search(re.escape(key.split(".")[-1]) + r"['\"`]", self.file_cache[full_path]):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _remove_key(self, obj: RecursiveDict, key: str):
|
||||||
|
path = key.split(".")
|
||||||
|
last_key = path[-1]
|
||||||
|
for k in path[:-1]:
|
||||||
|
obj = obj[k]
|
||||||
|
del obj[last_key]
|
||||||
|
|
||||||
|
def clean(self, obj: RecursiveDict) -> RecursiveDict:
|
||||||
|
keys = self._get_keys(obj)
|
||||||
|
pbar = tqdm(keys, desc="Checking keys")
|
||||||
|
for key in pbar:
|
||||||
|
if not self._search_codebase(key):
|
||||||
|
self._remove_key(obj, key)
|
||||||
|
return obj
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
try:
|
||||||
|
with open("../public/locales/en.json", "r") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
except FileNotFoundError as e:
|
||||||
|
raise FileNotFoundError(
|
||||||
|
"Unable to find en.json file - must be run from invokeai/frontend/web/scripts directory"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
cleaner = TranslationCleaner()
|
||||||
|
cleaned_data = cleaner.clean(data)
|
||||||
|
|
||||||
|
with open("../public/locales/en.json", "w") as f:
|
||||||
|
json.dump(cleaned_data, f, indent=4)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
@ -1,10 +1,10 @@
|
|||||||
import { logger } from 'app/logging/logger';
|
import { logger } from 'app/logging/logger';
|
||||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||||
|
import type { AppDispatch, RootState } from 'app/store/store';
|
||||||
|
import type { JSONObject } from 'common/types';
|
||||||
import {
|
import {
|
||||||
controlAdapterModelCleared,
|
controlAdapterModelCleared,
|
||||||
selectAllControlNets,
|
selectControlAdapterAll,
|
||||||
selectAllIPAdapters,
|
|
||||||
selectAllT2IAdapters,
|
|
||||||
} from 'features/controlAdapters/store/controlAdaptersSlice';
|
} from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||||
import { loraRemoved } from 'features/lora/store/loraSlice';
|
import { loraRemoved } from 'features/lora/store/loraSlice';
|
||||||
import { calculateNewSize } from 'features/parameters/components/ImageSize/calculateNewSize';
|
import { calculateNewSize } from 'features/parameters/components/ImageSize/calculateNewSize';
|
||||||
@ -12,39 +12,56 @@ import { heightChanged, modelChanged, vaeSelected, widthChanged } from 'features
|
|||||||
import { zParameterModel, zParameterVAEModel } from 'features/parameters/types/parameterSchemas';
|
import { zParameterModel, zParameterVAEModel } from 'features/parameters/types/parameterSchemas';
|
||||||
import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
|
import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
|
||||||
import { refinerModelChanged } from 'features/sdxl/store/sdxlSlice';
|
import { refinerModelChanged } from 'features/sdxl/store/sdxlSlice';
|
||||||
import { forEach, some } from 'lodash-es';
|
import { forEach } from 'lodash-es';
|
||||||
import { mainModelsAdapterSelectors, modelsApi, vaeModelsAdapterSelectors } from 'services/api/endpoints/models';
|
import type { Logger } from 'roarr';
|
||||||
import type { TypeGuardFor } from 'services/api/types';
|
import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models';
|
||||||
|
import type { AnyModelConfig } from 'services/api/types';
|
||||||
|
import { isNonRefinerMainModelConfig, isRefinerMainModelModelConfig, isVAEModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
export const addModelsLoadedListener = (startAppListening: AppStartListening) => {
|
export const addModelsLoadedListener = (startAppListening: AppStartListening) => {
|
||||||
startAppListening({
|
startAppListening({
|
||||||
predicate: (action): action is TypeGuardFor<typeof modelsApi.endpoints.getMainModels.matchFulfilled> =>
|
predicate: modelsApi.endpoints.getModelConfigs.matchFulfilled,
|
||||||
modelsApi.endpoints.getMainModels.matchFulfilled(action) &&
|
|
||||||
!action.meta.arg.originalArgs.includes('sdxl-refiner'),
|
|
||||||
effect: async (action, { getState, dispatch }) => {
|
effect: async (action, { getState, dispatch }) => {
|
||||||
// models loaded, we need to ensure the selected model is available and if not, select the first one
|
// models loaded, we need to ensure the selected model is available and if not, select the first one
|
||||||
const log = logger('models');
|
const log = logger('models');
|
||||||
log.info({ models: action.payload.entities }, `Main models loaded (${action.payload.ids.length})`);
|
log.info({ models: action.payload.entities }, `Models loaded (${action.payload.ids.length})`);
|
||||||
|
|
||||||
const state = getState();
|
const state = getState();
|
||||||
|
|
||||||
const currentModel = state.generation.model;
|
const models = modelConfigsAdapterSelectors.selectAll(action.payload);
|
||||||
const models = mainModelsAdapterSelectors.selectAll(action.payload);
|
|
||||||
|
|
||||||
if (models.length === 0) {
|
handleMainModels(models, state, dispatch, log);
|
||||||
|
handleRefinerModels(models, state, dispatch, log);
|
||||||
|
handleVAEModels(models, state, dispatch, log);
|
||||||
|
handleLoRAModels(models, state, dispatch, log);
|
||||||
|
handleControlAdapterModels(models, state, dispatch, log);
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
type ModelHandler = (
|
||||||
|
models: AnyModelConfig[],
|
||||||
|
state: RootState,
|
||||||
|
dispatch: AppDispatch,
|
||||||
|
log: Logger<JSONObject>
|
||||||
|
) => undefined;
|
||||||
|
|
||||||
|
const handleMainModels: ModelHandler = (models, state, dispatch, log) => {
|
||||||
|
const currentModel = state.generation.model;
|
||||||
|
const mainModels = models.filter(isNonRefinerMainModelConfig);
|
||||||
|
if (mainModels.length === 0) {
|
||||||
// No models loaded at all
|
// No models loaded at all
|
||||||
dispatch(modelChanged(null));
|
dispatch(modelChanged(null));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const isCurrentModelAvailable = currentModel ? models.some((m) => m.key === currentModel.key) : false;
|
const isCurrentMainModelAvailable = currentModel ? mainModels.some((m) => m.key === currentModel.key) : false;
|
||||||
|
if (isCurrentMainModelAvailable) {
|
||||||
if (isCurrentModelAvailable) {
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const defaultModel = state.config.sd.defaultModel;
|
const defaultModel = state.config.sd.defaultModel;
|
||||||
const defaultModelInList = defaultModel ? models.find((m) => m.key === defaultModel) : false;
|
const defaultModelInList = defaultModel ? mainModels.find((m) => m.key === defaultModel) : false;
|
||||||
|
|
||||||
if (defaultModelInList) {
|
if (defaultModelInList) {
|
||||||
const result = zParameterModel.safeParse(defaultModelInList);
|
const result = zParameterModel.safeParse(defaultModelInList);
|
||||||
@ -66,7 +83,7 @@ export const addModelsLoadedListener = (startAppListening: AppStartListening) =>
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const result = zParameterModel.safeParse(models[0]);
|
const result = zParameterModel.safeParse(mainModels[0]);
|
||||||
|
|
||||||
if (!result.success) {
|
if (!result.success) {
|
||||||
log.error({ error: result.error.format() }, 'Failed to parse main model');
|
log.error({ error: result.error.format() }, 'Failed to parse main model');
|
||||||
@ -74,54 +91,43 @@ export const addModelsLoadedListener = (startAppListening: AppStartListening) =>
|
|||||||
}
|
}
|
||||||
|
|
||||||
dispatch(modelChanged(result.data, currentModel));
|
dispatch(modelChanged(result.data, currentModel));
|
||||||
},
|
};
|
||||||
});
|
|
||||||
startAppListening({
|
|
||||||
predicate: (action): action is TypeGuardFor<typeof modelsApi.endpoints.getMainModels.matchFulfilled> =>
|
|
||||||
modelsApi.endpoints.getMainModels.matchFulfilled(action) && action.meta.arg.originalArgs.includes('sdxl-refiner'),
|
|
||||||
effect: async (action, { getState, dispatch }) => {
|
|
||||||
// models loaded, we need to ensure the selected model is available and if not, select the first one
|
|
||||||
const log = logger('models');
|
|
||||||
log.info({ models: action.payload.entities }, `SDXL Refiner models loaded (${action.payload.ids.length})`);
|
|
||||||
|
|
||||||
const currentModel = getState().sdxl.refinerModel;
|
|
||||||
const models = mainModelsAdapterSelectors.selectAll(action.payload);
|
|
||||||
|
|
||||||
|
const handleRefinerModels: ModelHandler = (models, state, dispatch, _log) => {
|
||||||
|
const currentRefinerModel = state.sdxl.refinerModel;
|
||||||
|
const refinerModels = models.filter(isRefinerMainModelModelConfig);
|
||||||
if (models.length === 0) {
|
if (models.length === 0) {
|
||||||
// No models loaded at all
|
// No models loaded at all
|
||||||
dispatch(refinerModelChanged(null));
|
dispatch(refinerModelChanged(null));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const isCurrentModelAvailable = currentModel ? models.some((m) => m.key === currentModel.key) : false;
|
const isCurrentRefinerModelAvailable = currentRefinerModel
|
||||||
|
? refinerModels.some((m) => m.key === currentRefinerModel.key)
|
||||||
|
: false;
|
||||||
|
|
||||||
if (!isCurrentModelAvailable) {
|
if (!isCurrentRefinerModelAvailable) {
|
||||||
dispatch(refinerModelChanged(null));
|
dispatch(refinerModelChanged(null));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
},
|
};
|
||||||
});
|
|
||||||
startAppListening({
|
|
||||||
matcher: modelsApi.endpoints.getVaeModels.matchFulfilled,
|
|
||||||
effect: async (action, { getState, dispatch }) => {
|
|
||||||
// VAEs loaded, need to reset the VAE is it's no longer available
|
|
||||||
const log = logger('models');
|
|
||||||
log.info({ models: action.payload.entities }, `VAEs loaded (${action.payload.ids.length})`);
|
|
||||||
|
|
||||||
const currentVae = getState().generation.vae;
|
const handleVAEModels: ModelHandler = (models, state, dispatch, log) => {
|
||||||
|
const currentVae = state.generation.vae;
|
||||||
|
|
||||||
if (currentVae === null) {
|
if (currentVae === null) {
|
||||||
// null is a valid VAE! it means "use the default with the main model"
|
// null is a valid VAE! it means "use the default with the main model"
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
const vaeModels = models.filter(isVAEModelConfig);
|
||||||
|
|
||||||
const isCurrentVAEAvailable = some(action.payload.entities, (m) => m?.key === currentVae?.key);
|
const isCurrentVAEAvailable = vaeModels.some((m) => m.key === currentVae.key);
|
||||||
|
|
||||||
if (isCurrentVAEAvailable) {
|
if (isCurrentVAEAvailable) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const firstModel = vaeModelsAdapterSelectors.selectAll(action.payload)[0];
|
const firstModel = vaeModels[0];
|
||||||
|
|
||||||
if (!firstModel) {
|
if (!firstModel) {
|
||||||
// No custom VAEs loaded at all; use the default
|
// No custom VAEs loaded at all; use the default
|
||||||
@ -137,19 +143,13 @@ export const addModelsLoadedListener = (startAppListening: AppStartListening) =>
|
|||||||
}
|
}
|
||||||
|
|
||||||
dispatch(vaeSelected(result.data));
|
dispatch(vaeSelected(result.data));
|
||||||
},
|
};
|
||||||
});
|
|
||||||
startAppListening({
|
|
||||||
matcher: modelsApi.endpoints.getLoRAModels.matchFulfilled,
|
|
||||||
effect: async (action, { getState, dispatch }) => {
|
|
||||||
// LoRA models loaded - need to remove missing LoRAs from state
|
|
||||||
const log = logger('models');
|
|
||||||
log.info({ models: action.payload.entities }, `LoRAs loaded (${action.payload.ids.length})`);
|
|
||||||
|
|
||||||
const loras = getState().lora.loras;
|
const handleLoRAModels: ModelHandler = (models, state, dispatch, _log) => {
|
||||||
|
const loras = state.lora.loras;
|
||||||
|
|
||||||
forEach(loras, (lora, id) => {
|
forEach(loras, (lora, id) => {
|
||||||
const isLoRAAvailable = some(action.payload.entities, (m) => m?.key === lora?.model.key);
|
const isLoRAAvailable = models.some((m) => m.key === lora.model.key);
|
||||||
|
|
||||||
if (isLoRAAvailable) {
|
if (isLoRAAvailable) {
|
||||||
return;
|
return;
|
||||||
@ -157,17 +157,11 @@ export const addModelsLoadedListener = (startAppListening: AppStartListening) =>
|
|||||||
|
|
||||||
dispatch(loraRemoved(id));
|
dispatch(loraRemoved(id));
|
||||||
});
|
});
|
||||||
},
|
};
|
||||||
});
|
|
||||||
startAppListening({
|
|
||||||
matcher: modelsApi.endpoints.getControlNetModels.matchFulfilled,
|
|
||||||
effect: async (action, { getState, dispatch }) => {
|
|
||||||
// ControlNet models loaded - need to remove missing ControlNets from state
|
|
||||||
const log = logger('models');
|
|
||||||
log.info({ models: action.payload.entities }, `ControlNet models loaded (${action.payload.ids.length})`);
|
|
||||||
|
|
||||||
selectAllControlNets(getState().controlAdapters).forEach((ca) => {
|
const handleControlAdapterModels: ModelHandler = (models, state, dispatch, _log) => {
|
||||||
const isModelAvailable = some(action.payload.entities, (m) => m?.key === ca?.model?.key);
|
selectControlAdapterAll(state.controlAdapters).forEach((ca) => {
|
||||||
|
const isModelAvailable = models.some((m) => m.key === ca.model?.key);
|
||||||
|
|
||||||
if (isModelAvailable) {
|
if (isModelAvailable) {
|
||||||
return;
|
return;
|
||||||
@ -175,49 +169,4 @@ export const addModelsLoadedListener = (startAppListening: AppStartListening) =>
|
|||||||
|
|
||||||
dispatch(controlAdapterModelCleared({ id: ca.id }));
|
dispatch(controlAdapterModelCleared({ id: ca.id }));
|
||||||
});
|
});
|
||||||
},
|
|
||||||
});
|
|
||||||
startAppListening({
|
|
||||||
matcher: modelsApi.endpoints.getT2IAdapterModels.matchFulfilled,
|
|
||||||
effect: async (action, { getState, dispatch }) => {
|
|
||||||
// ControlNet models loaded - need to remove missing ControlNets from state
|
|
||||||
const log = logger('models');
|
|
||||||
log.info({ models: action.payload.entities }, `T2I Adapter models loaded (${action.payload.ids.length})`);
|
|
||||||
|
|
||||||
selectAllT2IAdapters(getState().controlAdapters).forEach((ca) => {
|
|
||||||
const isModelAvailable = some(action.payload.entities, (m) => m?.key === ca?.model?.key);
|
|
||||||
|
|
||||||
if (isModelAvailable) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
dispatch(controlAdapterModelCleared({ id: ca.id }));
|
|
||||||
});
|
|
||||||
},
|
|
||||||
});
|
|
||||||
startAppListening({
|
|
||||||
matcher: modelsApi.endpoints.getIPAdapterModels.matchFulfilled,
|
|
||||||
effect: async (action, { getState, dispatch }) => {
|
|
||||||
// ControlNet models loaded - need to remove missing ControlNets from state
|
|
||||||
const log = logger('models');
|
|
||||||
log.info({ models: action.payload.entities }, `IP Adapter models loaded (${action.payload.ids.length})`);
|
|
||||||
|
|
||||||
selectAllIPAdapters(getState().controlAdapters).forEach((ca) => {
|
|
||||||
const isModelAvailable = some(action.payload.entities, (m) => m?.key === ca?.model?.key);
|
|
||||||
|
|
||||||
if (isModelAvailable) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
dispatch(controlAdapterModelCleared({ id: ca.id }));
|
|
||||||
});
|
|
||||||
},
|
|
||||||
});
|
|
||||||
startAppListening({
|
|
||||||
matcher: modelsApi.endpoints.getTextualInversionModels.matchFulfilled,
|
|
||||||
effect: async (action) => {
|
|
||||||
const log = logger('models');
|
|
||||||
log.info({ models: action.payload.entities }, `Embeddings loaded (${action.payload.ids.length})`);
|
|
||||||
},
|
|
||||||
});
|
|
||||||
};
|
};
|
||||||
|
@ -1,26 +1,29 @@
|
|||||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||||
import { setDefaultSettings } from 'features/parameters/store/actions';
|
import { setDefaultSettings } from 'features/parameters/store/actions';
|
||||||
import {
|
import {
|
||||||
|
heightChanged,
|
||||||
setCfgRescaleMultiplier,
|
setCfgRescaleMultiplier,
|
||||||
setCfgScale,
|
setCfgScale,
|
||||||
setScheduler,
|
setScheduler,
|
||||||
setSteps,
|
setSteps,
|
||||||
vaePrecisionChanged,
|
vaePrecisionChanged,
|
||||||
vaeSelected,
|
vaeSelected,
|
||||||
|
widthChanged,
|
||||||
} from 'features/parameters/store/generationSlice';
|
} from 'features/parameters/store/generationSlice';
|
||||||
import {
|
import {
|
||||||
isParameterCFGRescaleMultiplier,
|
isParameterCFGRescaleMultiplier,
|
||||||
isParameterCFGScale,
|
isParameterCFGScale,
|
||||||
|
isParameterHeight,
|
||||||
isParameterPrecision,
|
isParameterPrecision,
|
||||||
isParameterScheduler,
|
isParameterScheduler,
|
||||||
isParameterSteps,
|
isParameterSteps,
|
||||||
|
isParameterWidth,
|
||||||
zParameterVAEModel,
|
zParameterVAEModel,
|
||||||
} from 'features/parameters/types/parameterSchemas';
|
} from 'features/parameters/types/parameterSchemas';
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
import { makeToast } from 'features/system/util/makeToast';
|
import { makeToast } from 'features/system/util/makeToast';
|
||||||
import { t } from 'i18next';
|
import { t } from 'i18next';
|
||||||
import { map } from 'lodash-es';
|
import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models';
|
||||||
import { modelsApi } from 'services/api/endpoints/models';
|
|
||||||
import { isNonRefinerMainModelConfig } from 'services/api/types';
|
import { isNonRefinerMainModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
export const addSetDefaultSettingsListener = (startAppListening: AppStartListening) => {
|
export const addSetDefaultSettingsListener = (startAppListening: AppStartListening) => {
|
||||||
@ -35,14 +38,19 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const modelConfig = await dispatch(modelsApi.endpoints.getModelConfig.initiate(currentModel.key)).unwrap();
|
const request = dispatch(modelsApi.endpoints.getModelConfigs.initiate());
|
||||||
|
const data = await request.unwrap();
|
||||||
|
request.unsubscribe();
|
||||||
|
const models = modelConfigsAdapterSelectors.selectAll(data);
|
||||||
|
|
||||||
|
const modelConfig = models.find((model) => model.key === currentModel.key);
|
||||||
|
|
||||||
if (!modelConfig) {
|
if (!modelConfig) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (isNonRefinerMainModelConfig(modelConfig) && modelConfig.default_settings) {
|
if (isNonRefinerMainModelConfig(modelConfig) && modelConfig.default_settings) {
|
||||||
const { vae, vae_precision, cfg_scale, cfg_rescale_multiplier, steps, scheduler } =
|
const { vae, vae_precision, cfg_scale, cfg_rescale_multiplier, steps, scheduler, width, height } =
|
||||||
modelConfig.default_settings;
|
modelConfig.default_settings;
|
||||||
|
|
||||||
if (vae) {
|
if (vae) {
|
||||||
@ -51,11 +59,8 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni
|
|||||||
if (vae === 'default') {
|
if (vae === 'default') {
|
||||||
dispatch(vaeSelected(null));
|
dispatch(vaeSelected(null));
|
||||||
} else {
|
} else {
|
||||||
const { data } = modelsApi.endpoints.getVaeModels.select()(state);
|
const vaeModel = models.find((model) => model.key === vae);
|
||||||
const vaeArray = map(data?.entities);
|
const result = zParameterVAEModel.safeParse(vaeModel);
|
||||||
const validVae = vaeArray.find((model) => model.key === vae);
|
|
||||||
|
|
||||||
const result = zParameterVAEModel.safeParse(validVae);
|
|
||||||
if (!result.success) {
|
if (!result.success) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -93,6 +98,18 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (width) {
|
||||||
|
if (isParameterWidth(width)) {
|
||||||
|
dispatch(widthChanged(width));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (height) {
|
||||||
|
if (isParameterHeight(height)) {
|
||||||
|
dispatch(heightChanged(height));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
dispatch(addToast(makeToast({ title: t('toast.parameterSet', { parameter: 'Default settings' }) })));
|
dispatch(addToast(makeToast({ title: t('toast.parameterSet', { parameter: 'Default settings' }) })));
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -4,6 +4,7 @@ import { $baseUrl } from 'app/store/nanostores/baseUrl';
|
|||||||
import { isEqual } from 'lodash-es';
|
import { isEqual } from 'lodash-es';
|
||||||
import { atom } from 'nanostores';
|
import { atom } from 'nanostores';
|
||||||
import { api } from 'services/api';
|
import { api } from 'services/api';
|
||||||
|
import { modelsApi } from 'services/api/endpoints/models';
|
||||||
import { queueApi, selectQueueStatus } from 'services/api/endpoints/queue';
|
import { queueApi, selectQueueStatus } from 'services/api/endpoints/queue';
|
||||||
import { socketConnected } from 'services/events/actions';
|
import { socketConnected } from 'services/events/actions';
|
||||||
|
|
||||||
@ -29,6 +30,11 @@ export const addSocketConnectedEventListener = (startAppListening: AppStartListe
|
|||||||
|
|
||||||
// Bail on the recovery logic if this is the first connection - we don't need to recover anything
|
// Bail on the recovery logic if this is the first connection - we don't need to recover anything
|
||||||
if ($isFirstConnection.get()) {
|
if ($isFirstConnection.get()) {
|
||||||
|
// Populate the model configs on first connection. This query cache has a 24hr timeout, so we can immediately
|
||||||
|
// unsubscribe.
|
||||||
|
const request = dispatch(modelsApi.endpoints.getModelConfigs.initiate());
|
||||||
|
request.unsubscribe();
|
||||||
|
|
||||||
$isFirstConnection.set(false);
|
$isFirstConnection.set(false);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -2,6 +2,7 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'
|
|||||||
import { api } from 'services/api';
|
import { api } from 'services/api';
|
||||||
import { modelsApi } from 'services/api/endpoints/models';
|
import { modelsApi } from 'services/api/endpoints/models';
|
||||||
import {
|
import {
|
||||||
|
socketModelInstallCancelled,
|
||||||
socketModelInstallCompleted,
|
socketModelInstallCompleted,
|
||||||
socketModelInstallDownloading,
|
socketModelInstallDownloading,
|
||||||
socketModelInstallError,
|
socketModelInstallError,
|
||||||
@ -63,4 +64,21 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
|
|||||||
);
|
);
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: socketModelInstallCancelled,
|
||||||
|
effect: (action, { dispatch }) => {
|
||||||
|
const { id } = action.payload.data;
|
||||||
|
|
||||||
|
dispatch(
|
||||||
|
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
|
||||||
|
const modelImport = draft.find((m) => m.id === id);
|
||||||
|
if (modelImport) {
|
||||||
|
modelImport.status = 'cancelled';
|
||||||
|
}
|
||||||
|
return draft;
|
||||||
|
})
|
||||||
|
);
|
||||||
|
},
|
||||||
|
});
|
||||||
};
|
};
|
||||||
|
@ -8,14 +8,16 @@ export const addModelLoadEventListener = (startAppListening: AppStartListening)
|
|||||||
startAppListening({
|
startAppListening({
|
||||||
actionCreator: socketModelLoadStarted,
|
actionCreator: socketModelLoadStarted,
|
||||||
effect: (action) => {
|
effect: (action) => {
|
||||||
const { base_model, model_name, model_type, submodel } = action.payload.data;
|
const { model_config, submodel_type } = action.payload.data;
|
||||||
|
const { name, base, type } = model_config;
|
||||||
|
|
||||||
let message = `Model load started: ${base_model}/${model_type}/${model_name}`;
|
const extras: string[] = [base, type];
|
||||||
|
if (submodel_type) {
|
||||||
if (submodel) {
|
extras.push(submodel_type);
|
||||||
message = message.concat(`/${submodel}`);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const message = `Model load started: ${name} (${extras.join(', ')})`;
|
||||||
|
|
||||||
log.debug(action.payload, message);
|
log.debug(action.payload, message);
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
@ -23,14 +25,16 @@ export const addModelLoadEventListener = (startAppListening: AppStartListening)
|
|||||||
startAppListening({
|
startAppListening({
|
||||||
actionCreator: socketModelLoadCompleted,
|
actionCreator: socketModelLoadCompleted,
|
||||||
effect: (action) => {
|
effect: (action) => {
|
||||||
const { base_model, model_name, model_type, submodel } = action.payload.data;
|
const { model_config, submodel_type } = action.payload.data;
|
||||||
|
const { name, base, type } = model_config;
|
||||||
|
|
||||||
let message = `Model load complete: ${base_model}/${model_type}/${model_name}`;
|
const extras: string[] = [base, type];
|
||||||
|
if (submodel_type) {
|
||||||
if (submodel) {
|
extras.push(submodel_type);
|
||||||
message = message.concat(`/${submodel}`);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const message = `Model load complete: ${name} (${extras.join(', ')})`;
|
||||||
|
|
||||||
log.debug(action.payload, message);
|
log.debug(action.payload, message);
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
@ -1,15 +1,14 @@
|
|||||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||||
import type { EntityState } from '@reduxjs/toolkit';
|
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import type { GroupBase } from 'chakra-react-select';
|
import type { GroupBase } from 'chakra-react-select';
|
||||||
import type { ModelIdentifierField } from 'features/nodes/types/common';
|
import type { ModelIdentifierField } from 'features/nodes/types/common';
|
||||||
import { groupBy, map, reduce } from 'lodash-es';
|
import { groupBy, reduce } from 'lodash-es';
|
||||||
import { useCallback, useMemo } from 'react';
|
import { useCallback, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import type { AnyModelConfig } from 'services/api/types';
|
import type { AnyModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
type UseGroupedModelComboboxArg<T extends AnyModelConfig> = {
|
type UseGroupedModelComboboxArg<T extends AnyModelConfig> = {
|
||||||
modelEntities: EntityState<T, string> | undefined;
|
modelConfigs: T[];
|
||||||
selectedModel?: ModelIdentifierField | null;
|
selectedModel?: ModelIdentifierField | null;
|
||||||
onChange: (value: T | null) => void;
|
onChange: (value: T | null) => void;
|
||||||
getIsDisabled?: (model: T) => boolean;
|
getIsDisabled?: (model: T) => boolean;
|
||||||
@ -29,13 +28,12 @@ export const useGroupedModelCombobox = <T extends AnyModelConfig>(
|
|||||||
): UseGroupedModelComboboxReturn => {
|
): UseGroupedModelComboboxReturn => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const base_model = useAppSelector((s) => s.generation.model?.base ?? 'sdxl');
|
const base_model = useAppSelector((s) => s.generation.model?.base ?? 'sdxl');
|
||||||
const { modelEntities, selectedModel, getIsDisabled, onChange, isLoading } = arg;
|
const { modelConfigs, selectedModel, getIsDisabled, onChange, isLoading } = arg;
|
||||||
const options = useMemo<GroupBase<ComboboxOption>[]>(() => {
|
const options = useMemo<GroupBase<ComboboxOption>[]>(() => {
|
||||||
if (!modelEntities) {
|
if (!modelConfigs) {
|
||||||
return [];
|
return [];
|
||||||
}
|
}
|
||||||
const modelEntitiesArray = map(modelEntities.entities);
|
const groupedModels = groupBy(modelConfigs, 'base');
|
||||||
const groupedModels = groupBy(modelEntitiesArray, 'base');
|
|
||||||
const _options = reduce(
|
const _options = reduce(
|
||||||
groupedModels,
|
groupedModels,
|
||||||
(acc, val, label) => {
|
(acc, val, label) => {
|
||||||
@ -53,7 +51,7 @@ export const useGroupedModelCombobox = <T extends AnyModelConfig>(
|
|||||||
);
|
);
|
||||||
_options.sort((a) => (a.label === base_model ? -1 : 1));
|
_options.sort((a) => (a.label === base_model ? -1 : 1));
|
||||||
return _options;
|
return _options;
|
||||||
}, [getIsDisabled, modelEntities, base_model]);
|
}, [getIsDisabled, modelConfigs, base_model]);
|
||||||
|
|
||||||
const value = useMemo(
|
const value = useMemo(
|
||||||
() =>
|
() =>
|
||||||
@ -67,14 +65,14 @@ export const useGroupedModelCombobox = <T extends AnyModelConfig>(
|
|||||||
onChange(null);
|
onChange(null);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const model = modelEntities?.entities[v.value];
|
const model = modelConfigs.find((m) => m.key === v.value);
|
||||||
if (!model) {
|
if (!model) {
|
||||||
onChange(null);
|
onChange(null);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
onChange(model);
|
onChange(model);
|
||||||
},
|
},
|
||||||
[modelEntities?.entities, onChange]
|
[modelConfigs, onChange]
|
||||||
);
|
);
|
||||||
|
|
||||||
const placeholder = useMemo(() => {
|
const placeholder = useMemo(() => {
|
||||||
|
@ -1,13 +1,11 @@
|
|||||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||||
import type { EntityState } from '@reduxjs/toolkit';
|
|
||||||
import type { ModelIdentifierField } from 'features/nodes/types/common';
|
import type { ModelIdentifierField } from 'features/nodes/types/common';
|
||||||
import { map } from 'lodash-es';
|
|
||||||
import { useCallback, useMemo } from 'react';
|
import { useCallback, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import type { AnyModelConfig } from 'services/api/types';
|
import type { AnyModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
type UseModelComboboxArg<T extends AnyModelConfig> = {
|
type UseModelComboboxArg<T extends AnyModelConfig> = {
|
||||||
modelEntities: EntityState<T, string> | undefined;
|
modelConfigs: T[];
|
||||||
selectedModel?: ModelIdentifierField | null;
|
selectedModel?: ModelIdentifierField | null;
|
||||||
onChange: (value: T | null) => void;
|
onChange: (value: T | null) => void;
|
||||||
getIsDisabled?: (model: T) => boolean;
|
getIsDisabled?: (model: T) => boolean;
|
||||||
@ -25,19 +23,14 @@ type UseModelComboboxReturn = {
|
|||||||
|
|
||||||
export const useModelCombobox = <T extends AnyModelConfig>(arg: UseModelComboboxArg<T>): UseModelComboboxReturn => {
|
export const useModelCombobox = <T extends AnyModelConfig>(arg: UseModelComboboxArg<T>): UseModelComboboxReturn => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const { modelEntities, selectedModel, getIsDisabled, onChange, isLoading, optionsFilter = () => true } = arg;
|
const { modelConfigs, selectedModel, getIsDisabled, onChange, isLoading, optionsFilter = () => true } = arg;
|
||||||
const options = useMemo<ComboboxOption[]>(() => {
|
const options = useMemo<ComboboxOption[]>(() => {
|
||||||
if (!modelEntities) {
|
return modelConfigs.filter(optionsFilter).map((model) => ({
|
||||||
return [];
|
|
||||||
}
|
|
||||||
return map(modelEntities.entities)
|
|
||||||
.filter(optionsFilter)
|
|
||||||
.map((model) => ({
|
|
||||||
label: model.name,
|
label: model.name,
|
||||||
value: model.key,
|
value: model.key,
|
||||||
isDisabled: getIsDisabled ? getIsDisabled(model) : false,
|
isDisabled: getIsDisabled ? getIsDisabled(model) : false,
|
||||||
}));
|
}));
|
||||||
}, [optionsFilter, getIsDisabled, modelEntities]);
|
}, [optionsFilter, getIsDisabled, modelConfigs]);
|
||||||
|
|
||||||
const value = useMemo(
|
const value = useMemo(
|
||||||
() => options.find((m) => (selectedModel ? m.value === selectedModel.key : false)),
|
() => options.find((m) => (selectedModel ? m.value === selectedModel.key : false)),
|
||||||
@ -50,14 +43,14 @@ export const useModelCombobox = <T extends AnyModelConfig>(arg: UseModelCombobox
|
|||||||
onChange(null);
|
onChange(null);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const model = modelEntities?.entities[v.value];
|
const model = modelConfigs.find((m) => m.key === v.value);
|
||||||
if (!model) {
|
if (!model) {
|
||||||
onChange(null);
|
onChange(null);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
onChange(model);
|
onChange(model);
|
||||||
},
|
},
|
||||||
[modelEntities?.entities, onChange]
|
[modelConfigs, onChange]
|
||||||
);
|
);
|
||||||
|
|
||||||
const placeholder = useMemo(() => {
|
const placeholder = useMemo(() => {
|
||||||
|
@ -1,15 +1,12 @@
|
|||||||
import type { Item } from '@invoke-ai/ui-library';
|
import type { Item } from '@invoke-ai/ui-library';
|
||||||
import type { EntityState } from '@reduxjs/toolkit';
|
|
||||||
import { EMPTY_ARRAY } from 'app/store/constants';
|
|
||||||
import type { ModelIdentifierField } from 'features/nodes/types/common';
|
import type { ModelIdentifierField } from 'features/nodes/types/common';
|
||||||
import { MODEL_TYPE_SHORT_MAP } from 'features/parameters/types/constants';
|
import { MODEL_TYPE_SHORT_MAP } from 'features/parameters/types/constants';
|
||||||
import { filter } from 'lodash-es';
|
|
||||||
import { useCallback, useMemo } from 'react';
|
import { useCallback, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import type { AnyModelConfig } from 'services/api/types';
|
import type { AnyModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
type UseModelCustomSelectArg<T extends AnyModelConfig> = {
|
type UseModelCustomSelectArg<T extends AnyModelConfig> = {
|
||||||
data: EntityState<T, string> | undefined;
|
modelConfigs: T[];
|
||||||
isLoading: boolean;
|
isLoading: boolean;
|
||||||
selectedModel?: ModelIdentifierField | null;
|
selectedModel?: ModelIdentifierField | null;
|
||||||
onChange: (value: T | null) => void;
|
onChange: (value: T | null) => void;
|
||||||
@ -28,7 +25,7 @@ const modelFilterDefault = () => true;
|
|||||||
const isModelDisabledDefault = () => false;
|
const isModelDisabledDefault = () => false;
|
||||||
|
|
||||||
export const useModelCustomSelect = <T extends AnyModelConfig>({
|
export const useModelCustomSelect = <T extends AnyModelConfig>({
|
||||||
data,
|
modelConfigs,
|
||||||
isLoading,
|
isLoading,
|
||||||
selectedModel,
|
selectedModel,
|
||||||
onChange,
|
onChange,
|
||||||
@ -39,30 +36,28 @@ export const useModelCustomSelect = <T extends AnyModelConfig>({
|
|||||||
|
|
||||||
const items: Item[] = useMemo(
|
const items: Item[] = useMemo(
|
||||||
() =>
|
() =>
|
||||||
data
|
modelConfigs.filter(modelFilter).map<Item>((m) => ({
|
||||||
? filter(data.entities, modelFilter).map<Item>((m) => ({
|
|
||||||
label: m.name,
|
label: m.name,
|
||||||
value: m.key,
|
value: m.key,
|
||||||
description: m.description,
|
description: m.description,
|
||||||
group: MODEL_TYPE_SHORT_MAP[m.base],
|
group: MODEL_TYPE_SHORT_MAP[m.base],
|
||||||
isDisabled: isModelDisabled(m),
|
isDisabled: isModelDisabled(m),
|
||||||
}))
|
})),
|
||||||
: EMPTY_ARRAY,
|
[modelConfigs, isModelDisabled, modelFilter]
|
||||||
[data, isModelDisabled, modelFilter]
|
|
||||||
);
|
);
|
||||||
|
|
||||||
const _onChange = useCallback(
|
const _onChange = useCallback(
|
||||||
(item: Item | null) => {
|
(item: Item | null) => {
|
||||||
if (!item || !data) {
|
if (!item || !modelConfigs) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const model = data.entities[item.value];
|
const model = modelConfigs.find((m) => m.key === item.value);
|
||||||
if (!model) {
|
if (!model) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
onChange(model);
|
onChange(model);
|
||||||
},
|
},
|
||||||
[data, onChange]
|
[modelConfigs, onChange]
|
||||||
);
|
);
|
||||||
|
|
||||||
const selectedItem = useMemo(() => items.find((o) => o.value === selectedModel?.key) ?? null, [selectedModel, items]);
|
const selectedItem = useMemo(() => items.find((o) => o.value === selectedModel?.key) ?? null, [selectedModel, items]);
|
||||||
|
@ -3,7 +3,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|||||||
import { useModelCustomSelect } from 'common/hooks/useModelCustomSelect';
|
import { useModelCustomSelect } from 'common/hooks/useModelCustomSelect';
|
||||||
import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled';
|
import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled';
|
||||||
import { useControlAdapterModel } from 'features/controlAdapters/hooks/useControlAdapterModel';
|
import { useControlAdapterModel } from 'features/controlAdapters/hooks/useControlAdapterModel';
|
||||||
import { useControlAdapterModelQuery } from 'features/controlAdapters/hooks/useControlAdapterModelQuery';
|
import { useControlAdapterModels } from 'features/controlAdapters/hooks/useControlAdapterModels';
|
||||||
import { useControlAdapterType } from 'features/controlAdapters/hooks/useControlAdapterType';
|
import { useControlAdapterType } from 'features/controlAdapters/hooks/useControlAdapterType';
|
||||||
import { controlAdapterModelChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
|
import { controlAdapterModelChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||||
import { memo, useCallback, useMemo } from 'react';
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
@ -20,7 +20,7 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
|
|||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
|
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
|
||||||
|
|
||||||
const { data, isLoading } = useControlAdapterModelQuery(controlAdapterType);
|
const [modelConfigs, { isLoading }] = useControlAdapterModels(controlAdapterType);
|
||||||
|
|
||||||
const _onChange = useCallback(
|
const _onChange = useCallback(
|
||||||
(modelConfig: ControlNetModelConfig | IPAdapterModelConfig | T2IAdapterModelConfig | null) => {
|
(modelConfig: ControlNetModelConfig | IPAdapterModelConfig | T2IAdapterModelConfig | null) => {
|
||||||
@ -43,7 +43,7 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
const { items, selectedItem, onChange, placeholder } = useModelCustomSelect({
|
const { items, selectedItem, onChange, placeholder } = useModelCustomSelect({
|
||||||
data,
|
modelConfigs,
|
||||||
isLoading,
|
isLoading,
|
||||||
selectedModel,
|
selectedModel,
|
||||||
onChange: _onChange,
|
onChange: _onChange,
|
||||||
@ -52,7 +52,13 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
<FormControl isDisabled={!items.length || !isEnabled} isInvalid={!selectedItem || !items.length}>
|
<FormControl isDisabled={!items.length || !isEnabled} isInvalid={!selectedItem || !items.length}>
|
||||||
<CustomSelect selectedItem={selectedItem} placeholder={placeholder} items={items} onChange={onChange} />
|
<CustomSelect
|
||||||
|
key={items.length}
|
||||||
|
selectedItem={selectedItem}
|
||||||
|
placeholder={placeholder}
|
||||||
|
items={items}
|
||||||
|
onChange={onChange}
|
||||||
|
/>
|
||||||
</FormControl>
|
</FormControl>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -1,17 +1,16 @@
|
|||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { useControlAdapterModels } from 'features/controlAdapters/hooks/useControlAdapterModels';
|
||||||
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
|
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
|
||||||
import { controlAdapterAdded } from 'features/controlAdapters/store/controlAdaptersSlice';
|
import { controlAdapterAdded } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||||
import { type ControlAdapterType, isControlAdapterProcessorType } from 'features/controlAdapters/store/types';
|
import { type ControlAdapterType, isControlAdapterProcessorType } from 'features/controlAdapters/store/types';
|
||||||
import { useCallback, useMemo } from 'react';
|
import { useCallback, useMemo } from 'react';
|
||||||
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
|
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
import { useControlAdapterModels } from './useControlAdapterModels';
|
|
||||||
|
|
||||||
export const useAddControlAdapter = (type: ControlAdapterType) => {
|
export const useAddControlAdapter = (type: ControlAdapterType) => {
|
||||||
const baseModel = useAppSelector((s) => s.generation.model?.base);
|
const baseModel = useAppSelector((s) => s.generation.model?.base);
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
const models = useControlAdapterModels(type);
|
const [models] = useControlAdapterModels(type);
|
||||||
|
|
||||||
const firstModel: ControlNetModelConfig | T2IAdapterModelConfig | IPAdapterModelConfig | undefined = useMemo(() => {
|
const firstModel: ControlNetModelConfig | T2IAdapterModelConfig | IPAdapterModelConfig | undefined = useMemo(() => {
|
||||||
// prefer to use a model that matches the base model
|
// prefer to use a model that matches the base model
|
||||||
|
@ -1,26 +0,0 @@
|
|||||||
import type { ControlAdapterType } from 'features/controlAdapters/store/types';
|
|
||||||
import {
|
|
||||||
useGetControlNetModelsQuery,
|
|
||||||
useGetIPAdapterModelsQuery,
|
|
||||||
useGetT2IAdapterModelsQuery,
|
|
||||||
} from 'services/api/endpoints/models';
|
|
||||||
|
|
||||||
export const useControlAdapterModelQuery = (type: ControlAdapterType) => {
|
|
||||||
const controlNetModelsQuery = useGetControlNetModelsQuery();
|
|
||||||
const t2iAdapterModelsQuery = useGetT2IAdapterModelsQuery();
|
|
||||||
const ipAdapterModelsQuery = useGetIPAdapterModelsQuery();
|
|
||||||
|
|
||||||
if (type === 'controlnet') {
|
|
||||||
return controlNetModelsQuery;
|
|
||||||
}
|
|
||||||
if (type === 't2i_adapter') {
|
|
||||||
return t2iAdapterModelsQuery;
|
|
||||||
}
|
|
||||||
if (type === 'ip_adapter') {
|
|
||||||
return ipAdapterModelsQuery;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Assert that the end of the function is not reachable.
|
|
||||||
const exhaustiveCheck: never = type;
|
|
||||||
return exhaustiveCheck;
|
|
||||||
};
|
|
@ -1,31 +1,10 @@
|
|||||||
import type { ControlAdapterType } from 'features/controlAdapters/store/types';
|
import type { ControlAdapterType } from 'features/controlAdapters/store/types';
|
||||||
import { useMemo } from 'react';
|
import { useControlNetModels, useIPAdapterModels, useT2IAdapterModels } from 'services/api/hooks/modelsByType';
|
||||||
import {
|
|
||||||
controlNetModelsAdapterSelectors,
|
|
||||||
ipAdapterModelsAdapterSelectors,
|
|
||||||
t2iAdapterModelsAdapterSelectors,
|
|
||||||
useGetControlNetModelsQuery,
|
|
||||||
useGetIPAdapterModelsQuery,
|
|
||||||
useGetT2IAdapterModelsQuery,
|
|
||||||
} from 'services/api/endpoints/models';
|
|
||||||
|
|
||||||
export const useControlAdapterModels = (type?: ControlAdapterType) => {
|
export const useControlAdapterModels = (type: ControlAdapterType) => {
|
||||||
const { data: controlNetModelsData } = useGetControlNetModelsQuery();
|
const controlNetModels = useControlNetModels();
|
||||||
const controlNetModels = useMemo(
|
const t2iAdapterModels = useT2IAdapterModels();
|
||||||
() => (controlNetModelsData ? controlNetModelsAdapterSelectors.selectAll(controlNetModelsData) : []),
|
const ipAdapterModels = useIPAdapterModels();
|
||||||
[controlNetModelsData]
|
|
||||||
);
|
|
||||||
|
|
||||||
const { data: t2iAdapterModelsData } = useGetT2IAdapterModelsQuery();
|
|
||||||
const t2iAdapterModels = useMemo(
|
|
||||||
() => (t2iAdapterModelsData ? t2iAdapterModelsAdapterSelectors.selectAll(t2iAdapterModelsData) : []),
|
|
||||||
[t2iAdapterModelsData]
|
|
||||||
);
|
|
||||||
const { data: ipAdapterModelsData } = useGetIPAdapterModelsQuery();
|
|
||||||
const ipAdapterModels = useMemo(
|
|
||||||
() => (ipAdapterModelsData ? ipAdapterModelsAdapterSelectors.selectAll(ipAdapterModelsData) : []),
|
|
||||||
[ipAdapterModelsData]
|
|
||||||
);
|
|
||||||
|
|
||||||
if (type === 'controlnet') {
|
if (type === 'controlnet') {
|
||||||
return controlNetModels;
|
return controlNetModels;
|
||||||
@ -36,5 +15,8 @@ export const useControlAdapterModels = (type?: ControlAdapterType) => {
|
|||||||
if (type === 'ip_adapter') {
|
if (type === 'ip_adapter') {
|
||||||
return ipAdapterModels;
|
return ipAdapterModels;
|
||||||
}
|
}
|
||||||
return [];
|
|
||||||
|
// Assert that the end of the function is not reachable.
|
||||||
|
const exhaustiveCheck: never = type;
|
||||||
|
return exhaustiveCheck;
|
||||||
};
|
};
|
||||||
|
@ -93,7 +93,6 @@ export const CONTROLNET_PROCESSORS: ControlNetProcessorsDict = {
|
|||||||
type: 'depth_anything_image_processor',
|
type: 'depth_anything_image_processor',
|
||||||
model_size: 'small',
|
model_size: 'small',
|
||||||
resolution: 512,
|
resolution: 512,
|
||||||
offload: false,
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
hed_image_processor: {
|
hed_image_processor: {
|
||||||
|
@ -338,6 +338,21 @@ export const controlAdaptersSlice = createSlice({
|
|||||||
pendingControlImagesCleared: (state) => {
|
pendingControlImagesCleared: (state) => {
|
||||||
state.pendingControlImages = [];
|
state.pendingControlImages = [];
|
||||||
},
|
},
|
||||||
|
ipAdaptersReset: (state) => {
|
||||||
|
selectAllIPAdapters(state).forEach((ca) => {
|
||||||
|
caAdapter.removeOne(state, ca.id);
|
||||||
|
});
|
||||||
|
},
|
||||||
|
controlNetsReset: (state) => {
|
||||||
|
selectAllControlNets(state).forEach((ca) => {
|
||||||
|
caAdapter.removeOne(state, ca.id);
|
||||||
|
});
|
||||||
|
},
|
||||||
|
t2iAdaptersReset: (state) => {
|
||||||
|
selectAllT2IAdapters(state).forEach((ca) => {
|
||||||
|
caAdapter.removeOne(state, ca.id);
|
||||||
|
});
|
||||||
|
},
|
||||||
},
|
},
|
||||||
extraReducers: (builder) => {
|
extraReducers: (builder) => {
|
||||||
builder.addCase(controlAdapterImageProcessed, (state, action) => {
|
builder.addCase(controlAdapterImageProcessed, (state, action) => {
|
||||||
@ -376,6 +391,9 @@ export const {
|
|||||||
controlAdapterAutoConfigToggled,
|
controlAdapterAutoConfigToggled,
|
||||||
pendingControlImagesCleared,
|
pendingControlImagesCleared,
|
||||||
controlAdapterModelCleared,
|
controlAdapterModelCleared,
|
||||||
|
ipAdaptersReset,
|
||||||
|
controlNetsReset,
|
||||||
|
t2iAdaptersReset,
|
||||||
} = controlAdaptersSlice.actions;
|
} = controlAdaptersSlice.actions;
|
||||||
|
|
||||||
export const isAnyControlAdapterAdded = isAnyOf(controlAdapterAdded, controlAdapterRecalled);
|
export const isAnyControlAdapterAdded = isAnyOf(controlAdapterAdded, controlAdapterRecalled);
|
||||||
|
@ -15,6 +15,7 @@ import {
|
|||||||
} from '@invoke-ai/ui-library';
|
} from '@invoke-ai/ui-library';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import {
|
import {
|
||||||
|
alwaysShowImageSizeBadgeChanged,
|
||||||
autoAssignBoardOnClickChanged,
|
autoAssignBoardOnClickChanged,
|
||||||
setGalleryImageMinimumWidth,
|
setGalleryImageMinimumWidth,
|
||||||
shouldAutoSwitchChanged,
|
shouldAutoSwitchChanged,
|
||||||
@ -36,6 +37,7 @@ const GallerySettingsPopover = () => {
|
|||||||
const galleryImageMinimumWidth = useAppSelector((s) => s.gallery.galleryImageMinimumWidth);
|
const galleryImageMinimumWidth = useAppSelector((s) => s.gallery.galleryImageMinimumWidth);
|
||||||
const shouldAutoSwitch = useAppSelector((s) => s.gallery.shouldAutoSwitch);
|
const shouldAutoSwitch = useAppSelector((s) => s.gallery.shouldAutoSwitch);
|
||||||
const autoAssignBoardOnClick = useAppSelector((s) => s.gallery.autoAssignBoardOnClick);
|
const autoAssignBoardOnClick = useAppSelector((s) => s.gallery.autoAssignBoardOnClick);
|
||||||
|
const alwaysShowImageSizeBadge = useAppSelector((s) => s.gallery.alwaysShowImageSizeBadge);
|
||||||
|
|
||||||
const handleChangeGalleryImageMinimumWidth = useCallback(
|
const handleChangeGalleryImageMinimumWidth = useCallback(
|
||||||
(v: number) => {
|
(v: number) => {
|
||||||
@ -56,6 +58,11 @@ const GallerySettingsPopover = () => {
|
|||||||
[dispatch]
|
[dispatch]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
const handleChangeAlwaysShowImageSizeBadgeChanged = useCallback(
|
||||||
|
(e: ChangeEvent<HTMLInputElement>) => dispatch(alwaysShowImageSizeBadgeChanged(e.target.checked)),
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Popover isLazy>
|
<Popover isLazy>
|
||||||
<PopoverTrigger>
|
<PopoverTrigger>
|
||||||
@ -88,6 +95,10 @@ const GallerySettingsPopover = () => {
|
|||||||
<FormLabel>{t('gallery.autoAssignBoardOnClick')}</FormLabel>
|
<FormLabel>{t('gallery.autoAssignBoardOnClick')}</FormLabel>
|
||||||
<Checkbox isChecked={autoAssignBoardOnClick} onChange={handleChangeAutoAssignBoardOnClick} />
|
<Checkbox isChecked={autoAssignBoardOnClick} onChange={handleChangeAutoAssignBoardOnClick} />
|
||||||
</FormControl>
|
</FormControl>
|
||||||
|
<FormControl>
|
||||||
|
<FormLabel>{t('gallery.alwaysShowImageSizeBadge')}</FormLabel>
|
||||||
|
<Checkbox isChecked={alwaysShowImageSizeBadge} onChange={handleChangeAlwaysShowImageSizeBadgeChanged} />
|
||||||
|
</FormControl>
|
||||||
</FormControlGroup>
|
</FormControlGroup>
|
||||||
<BoardAutoAddSelect />
|
<BoardAutoAddSelect />
|
||||||
</Flex>
|
</Flex>
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||||
import { Box, Flex, useShiftModifier } from '@invoke-ai/ui-library';
|
import { Box, Flex, Text, useShiftModifier } from '@invoke-ai/ui-library';
|
||||||
import { useStore } from '@nanostores/react';
|
import { useStore } from '@nanostores/react';
|
||||||
import { $customStarUI } from 'app/store/nanostores/customStarUI';
|
import { $customStarUI } from 'app/store/nanostores/customStarUI';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
@ -22,6 +22,16 @@ const imageIconStyleOverrides: SystemStyleObject = {
|
|||||||
bottom: 2,
|
bottom: 2,
|
||||||
top: 'auto',
|
top: 'auto',
|
||||||
};
|
};
|
||||||
|
const boxSx: SystemStyleObject = {
|
||||||
|
containerType: 'inline-size',
|
||||||
|
};
|
||||||
|
|
||||||
|
const badgeSx: SystemStyleObject = {
|
||||||
|
'@container (max-width: 80px)': {
|
||||||
|
'&': { display: 'none' },
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
interface HoverableImageProps {
|
interface HoverableImageProps {
|
||||||
imageName: string;
|
imageName: string;
|
||||||
index: number;
|
index: number;
|
||||||
@ -34,6 +44,7 @@ const GalleryImage = (props: HoverableImageProps) => {
|
|||||||
const shift = useShiftModifier();
|
const shift = useShiftModifier();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const selectedBoardId = useAppSelector((s) => s.gallery.selectedBoardId);
|
const selectedBoardId = useAppSelector((s) => s.gallery.selectedBoardId);
|
||||||
|
const alwaysShowImageSizeBadge = useAppSelector((s) => s.gallery.alwaysShowImageSizeBadge);
|
||||||
const { handleClick, isSelected, areMultiplesSelected } = useMultiselect(imageDTO);
|
const { handleClick, isSelected, areMultiplesSelected } = useMultiselect(imageDTO);
|
||||||
|
|
||||||
const customStarUi = useStore($customStarUI);
|
const customStarUi = useStore($customStarUI);
|
||||||
@ -121,7 +132,7 @@ const GalleryImage = (props: HoverableImageProps) => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Box w="full" h="full" className="gallerygrid-image" data-testid={dataTestId}>
|
<Box w="full" h="full" className="gallerygrid-image" data-testid={dataTestId} sx={boxSx}>
|
||||||
<Flex
|
<Flex
|
||||||
ref={imageContainerRef}
|
ref={imageContainerRef}
|
||||||
userSelect="none"
|
userSelect="none"
|
||||||
@ -145,6 +156,23 @@ const GalleryImage = (props: HoverableImageProps) => {
|
|||||||
onMouseOut={handleMouseOut}
|
onMouseOut={handleMouseOut}
|
||||||
>
|
>
|
||||||
<>
|
<>
|
||||||
|
{(isHovered || alwaysShowImageSizeBadge) && (
|
||||||
|
<Text
|
||||||
|
position="absolute"
|
||||||
|
background="base.900"
|
||||||
|
color="base.50"
|
||||||
|
fontSize="sm"
|
||||||
|
fontWeight="semibold"
|
||||||
|
bottom={0}
|
||||||
|
left={0}
|
||||||
|
opacity={0.7}
|
||||||
|
px={2}
|
||||||
|
lineHeight={1.25}
|
||||||
|
borderTopEndRadius="base"
|
||||||
|
borderBottomStartRadius="base"
|
||||||
|
sx={badgeSx}
|
||||||
|
>{`${imageDTO.width}x${imageDTO.height}`}</Text>
|
||||||
|
)}
|
||||||
<IAIDndImageIcon onClick={toggleStarredState} icon={starIcon} tooltip={starTooltip} />
|
<IAIDndImageIcon onClick={toggleStarredState} icon={starIcon} tooltip={starTooltip} />
|
||||||
|
|
||||||
{isHovered && shift && (
|
{isHovered && shift && (
|
||||||
|
@ -15,6 +15,7 @@ const initialGalleryState: GalleryState = {
|
|||||||
autoAssignBoardOnClick: true,
|
autoAssignBoardOnClick: true,
|
||||||
autoAddBoardId: 'none',
|
autoAddBoardId: 'none',
|
||||||
galleryImageMinimumWidth: 90,
|
galleryImageMinimumWidth: 90,
|
||||||
|
alwaysShowImageSizeBadge: false,
|
||||||
selectedBoardId: 'none',
|
selectedBoardId: 'none',
|
||||||
galleryView: 'images',
|
galleryView: 'images',
|
||||||
boardSearchText: '',
|
boardSearchText: '',
|
||||||
@ -71,6 +72,9 @@ export const gallerySlice = createSlice({
|
|||||||
state.limit += IMAGE_LIMIT;
|
state.limit += IMAGE_LIMIT;
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
alwaysShowImageSizeBadgeChanged: (state, action: PayloadAction<boolean>) => {
|
||||||
|
state.alwaysShowImageSizeBadge = action.payload;
|
||||||
|
},
|
||||||
},
|
},
|
||||||
extraReducers: (builder) => {
|
extraReducers: (builder) => {
|
||||||
builder.addMatcher(isAnyBoardDeleted, (state, action) => {
|
builder.addMatcher(isAnyBoardDeleted, (state, action) => {
|
||||||
@ -107,6 +111,7 @@ export const {
|
|||||||
selectionChanged,
|
selectionChanged,
|
||||||
boardSearchTextChanged,
|
boardSearchTextChanged,
|
||||||
moreImagesLoaded,
|
moreImagesLoaded,
|
||||||
|
alwaysShowImageSizeBadgeChanged,
|
||||||
} = gallerySlice.actions;
|
} = gallerySlice.actions;
|
||||||
|
|
||||||
const isAnyBoardDeleted = isAnyOf(
|
const isAnyBoardDeleted = isAnyOf(
|
||||||
|
@ -19,4 +19,5 @@ export type GalleryState = {
|
|||||||
boardSearchText: string;
|
boardSearchText: string;
|
||||||
offset: number;
|
offset: number;
|
||||||
limit: number;
|
limit: number;
|
||||||
|
alwaysShowImageSizeBadge: boolean;
|
||||||
};
|
};
|
||||||
|
@ -7,14 +7,14 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
|||||||
import { loraAdded, selectLoraSlice } from 'features/lora/store/loraSlice';
|
import { loraAdded, selectLoraSlice } from 'features/lora/store/loraSlice';
|
||||||
import { memo, useCallback, useMemo } from 'react';
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
|
import { useLoRAModels } from 'services/api/hooks/modelsByType';
|
||||||
import type { LoRAModelConfig } from 'services/api/types';
|
import type { LoRAModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
const selectAddedLoRAs = createMemoizedSelector(selectLoraSlice, (lora) => lora.loras);
|
const selectAddedLoRAs = createMemoizedSelector(selectLoraSlice, (lora) => lora.loras);
|
||||||
|
|
||||||
const LoRASelect = () => {
|
const LoRASelect = () => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { data, isLoading } = useGetLoRAModelsQuery();
|
const [modelConfigs, { isLoading }] = useLoRAModels();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const addedLoRAs = useAppSelector(selectAddedLoRAs);
|
const addedLoRAs = useAppSelector(selectAddedLoRAs);
|
||||||
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
|
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
|
||||||
@ -37,7 +37,7 @@ const LoRASelect = () => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
const { options, onChange } = useGroupedModelCombobox({
|
const { options, onChange } = useGroupedModelCombobox({
|
||||||
modelEntities: data,
|
modelConfigs,
|
||||||
getIsDisabled,
|
getIsDisabled,
|
||||||
onChange: _onChange,
|
onChange: _onChange,
|
||||||
});
|
});
|
||||||
|
@ -3,6 +3,7 @@ import { createSlice } from '@reduxjs/toolkit';
|
|||||||
import type { PersistConfig, RootState } from 'app/store/store';
|
import type { PersistConfig, RootState } from 'app/store/store';
|
||||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||||
import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
|
import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
|
||||||
|
import { cloneDeep } from 'lodash-es';
|
||||||
import type { LoRAModelConfig } from 'services/api/types';
|
import type { LoRAModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
export type LoRA = {
|
export type LoRA = {
|
||||||
@ -57,10 +58,12 @@ export const loraSlice = createSlice({
|
|||||||
}
|
}
|
||||||
lora.isEnabled = isEnabled;
|
lora.isEnabled = isEnabled;
|
||||||
},
|
},
|
||||||
|
lorasReset: () => cloneDeep(initialLoraState),
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
export const { loraAdded, loraRemoved, loraWeightChanged, loraIsEnabledChanged, loraRecalled } = loraSlice.actions;
|
export const { loraAdded, loraRemoved, loraWeightChanged, loraIsEnabledChanged, loraRecalled, lorasReset } =
|
||||||
|
loraSlice.actions;
|
||||||
|
|
||||||
export const selectLoraSlice = (state: RootState) => state.lora;
|
export const selectLoraSlice = (state: RootState) => state.lora;
|
||||||
|
|
||||||
|
@ -225,28 +225,34 @@ const parseControlNet: MetadataParseFunc<ControlNetConfigMetadata> = async (meta
|
|||||||
const control_model = await getProperty(metadataItem, 'control_model');
|
const control_model = await getProperty(metadataItem, 'control_model');
|
||||||
const key = await getModelKey(control_model, 'controlnet');
|
const key = await getModelKey(control_model, 'controlnet');
|
||||||
const controlNetModel = await fetchModelConfigWithTypeGuard(key, isControlNetModelConfig);
|
const controlNetModel = await fetchModelConfigWithTypeGuard(key, isControlNetModelConfig);
|
||||||
|
const image = zControlField.shape.image
|
||||||
const image = zControlField.shape.image.nullish().catch(null).parse(getProperty(metadataItem, 'image'));
|
.nullish()
|
||||||
|
.catch(null)
|
||||||
|
.parse(await getProperty(metadataItem, 'image'));
|
||||||
|
const processedImage = zControlField.shape.image
|
||||||
|
.nullish()
|
||||||
|
.catch(null)
|
||||||
|
.parse(await getProperty(metadataItem, 'processed_image'));
|
||||||
const control_weight = zControlField.shape.control_weight
|
const control_weight = zControlField.shape.control_weight
|
||||||
.nullish()
|
.nullish()
|
||||||
.catch(null)
|
.catch(null)
|
||||||
.parse(getProperty(metadataItem, 'control_weight'));
|
.parse(await getProperty(metadataItem, 'control_weight'));
|
||||||
const begin_step_percent = zControlField.shape.begin_step_percent
|
const begin_step_percent = zControlField.shape.begin_step_percent
|
||||||
.nullish()
|
.nullish()
|
||||||
.catch(null)
|
.catch(null)
|
||||||
.parse(getProperty(metadataItem, 'begin_step_percent'));
|
.parse(await getProperty(metadataItem, 'begin_step_percent'));
|
||||||
const end_step_percent = zControlField.shape.end_step_percent
|
const end_step_percent = zControlField.shape.end_step_percent
|
||||||
.nullish()
|
.nullish()
|
||||||
.catch(null)
|
.catch(null)
|
||||||
.parse(getProperty(metadataItem, 'end_step_percent'));
|
.parse(await getProperty(metadataItem, 'end_step_percent'));
|
||||||
const control_mode = zControlField.shape.control_mode
|
const control_mode = zControlField.shape.control_mode
|
||||||
.nullish()
|
.nullish()
|
||||||
.catch(null)
|
.catch(null)
|
||||||
.parse(getProperty(metadataItem, 'control_mode'));
|
.parse(await getProperty(metadataItem, 'control_mode'));
|
||||||
const resize_mode = zControlField.shape.resize_mode
|
const resize_mode = zControlField.shape.resize_mode
|
||||||
.nullish()
|
.nullish()
|
||||||
.catch(null)
|
.catch(null)
|
||||||
.parse(getProperty(metadataItem, 'resize_mode'));
|
.parse(await getProperty(metadataItem, 'resize_mode'));
|
||||||
|
|
||||||
const { processorType, processorNode } = buildControlAdapterProcessor(controlNetModel);
|
const { processorType, processorNode } = buildControlAdapterProcessor(controlNetModel);
|
||||||
|
|
||||||
@ -260,7 +266,7 @@ const parseControlNet: MetadataParseFunc<ControlNetConfigMetadata> = async (meta
|
|||||||
controlMode: control_mode ?? initialControlNet.controlMode,
|
controlMode: control_mode ?? initialControlNet.controlMode,
|
||||||
resizeMode: resize_mode ?? initialControlNet.resizeMode,
|
resizeMode: resize_mode ?? initialControlNet.resizeMode,
|
||||||
controlImage: image?.image_name ?? null,
|
controlImage: image?.image_name ?? null,
|
||||||
processedControlImage: image?.image_name ?? null,
|
processedControlImage: processedImage?.image_name ?? null,
|
||||||
processorType,
|
processorType,
|
||||||
processorNode,
|
processorNode,
|
||||||
shouldAutoConfig: true,
|
shouldAutoConfig: true,
|
||||||
@ -284,20 +290,30 @@ const parseT2IAdapter: MetadataParseFunc<T2IAdapterConfigMetadata> = async (meta
|
|||||||
const key = await getModelKey(t2i_adapter_model, 't2i_adapter');
|
const key = await getModelKey(t2i_adapter_model, 't2i_adapter');
|
||||||
const t2iAdapterModel = await fetchModelConfigWithTypeGuard(key, isT2IAdapterModelConfig);
|
const t2iAdapterModel = await fetchModelConfigWithTypeGuard(key, isT2IAdapterModelConfig);
|
||||||
|
|
||||||
const image = zT2IAdapterField.shape.image.nullish().catch(null).parse(getProperty(metadataItem, 'image'));
|
const image = zT2IAdapterField.shape.image
|
||||||
const weight = zT2IAdapterField.shape.weight.nullish().catch(null).parse(getProperty(metadataItem, 'weight'));
|
.nullish()
|
||||||
|
.catch(null)
|
||||||
|
.parse(await getProperty(metadataItem, 'image'));
|
||||||
|
const processedImage = zT2IAdapterField.shape.image
|
||||||
|
.nullish()
|
||||||
|
.catch(null)
|
||||||
|
.parse(await getProperty(metadataItem, 'processed_image'));
|
||||||
|
const weight = zT2IAdapterField.shape.weight
|
||||||
|
.nullish()
|
||||||
|
.catch(null)
|
||||||
|
.parse(await getProperty(metadataItem, 'weight'));
|
||||||
const begin_step_percent = zT2IAdapterField.shape.begin_step_percent
|
const begin_step_percent = zT2IAdapterField.shape.begin_step_percent
|
||||||
.nullish()
|
.nullish()
|
||||||
.catch(null)
|
.catch(null)
|
||||||
.parse(getProperty(metadataItem, 'begin_step_percent'));
|
.parse(await getProperty(metadataItem, 'begin_step_percent'));
|
||||||
const end_step_percent = zT2IAdapterField.shape.end_step_percent
|
const end_step_percent = zT2IAdapterField.shape.end_step_percent
|
||||||
.nullish()
|
.nullish()
|
||||||
.catch(null)
|
.catch(null)
|
||||||
.parse(getProperty(metadataItem, 'end_step_percent'));
|
.parse(await getProperty(metadataItem, 'end_step_percent'));
|
||||||
const resize_mode = zT2IAdapterField.shape.resize_mode
|
const resize_mode = zT2IAdapterField.shape.resize_mode
|
||||||
.nullish()
|
.nullish()
|
||||||
.catch(null)
|
.catch(null)
|
||||||
.parse(getProperty(metadataItem, 'resize_mode'));
|
.parse(await getProperty(metadataItem, 'resize_mode'));
|
||||||
|
|
||||||
const { processorType, processorNode } = buildControlAdapterProcessor(t2iAdapterModel);
|
const { processorType, processorNode } = buildControlAdapterProcessor(t2iAdapterModel);
|
||||||
|
|
||||||
@ -310,7 +326,7 @@ const parseT2IAdapter: MetadataParseFunc<T2IAdapterConfigMetadata> = async (meta
|
|||||||
endStepPct: end_step_percent ?? initialT2IAdapter.endStepPct,
|
endStepPct: end_step_percent ?? initialT2IAdapter.endStepPct,
|
||||||
resizeMode: resize_mode ?? initialT2IAdapter.resizeMode,
|
resizeMode: resize_mode ?? initialT2IAdapter.resizeMode,
|
||||||
controlImage: image?.image_name ?? null,
|
controlImage: image?.image_name ?? null,
|
||||||
processedControlImage: image?.image_name ?? null,
|
processedControlImage: processedImage?.image_name ?? null,
|
||||||
processorType,
|
processorType,
|
||||||
processorNode,
|
processorNode,
|
||||||
shouldAutoConfig: true,
|
shouldAutoConfig: true,
|
||||||
@ -334,16 +350,22 @@ const parseIPAdapter: MetadataParseFunc<IPAdapterConfigMetadata> = async (metada
|
|||||||
const key = await getModelKey(ip_adapter_model, 'ip_adapter');
|
const key = await getModelKey(ip_adapter_model, 'ip_adapter');
|
||||||
const ipAdapterModel = await fetchModelConfigWithTypeGuard(key, isIPAdapterModelConfig);
|
const ipAdapterModel = await fetchModelConfigWithTypeGuard(key, isIPAdapterModelConfig);
|
||||||
|
|
||||||
const image = zIPAdapterField.shape.image.nullish().catch(null).parse(getProperty(metadataItem, 'image'));
|
const image = zIPAdapterField.shape.image
|
||||||
const weight = zIPAdapterField.shape.weight.nullish().catch(null).parse(getProperty(metadataItem, 'weight'));
|
.nullish()
|
||||||
|
.catch(null)
|
||||||
|
.parse(await getProperty(metadataItem, 'image'));
|
||||||
|
const weight = zIPAdapterField.shape.weight
|
||||||
|
.nullish()
|
||||||
|
.catch(null)
|
||||||
|
.parse(await getProperty(metadataItem, 'weight'));
|
||||||
const begin_step_percent = zIPAdapterField.shape.begin_step_percent
|
const begin_step_percent = zIPAdapterField.shape.begin_step_percent
|
||||||
.nullish()
|
.nullish()
|
||||||
.catch(null)
|
.catch(null)
|
||||||
.parse(getProperty(metadataItem, 'begin_step_percent'));
|
.parse(await getProperty(metadataItem, 'begin_step_percent'));
|
||||||
const end_step_percent = zIPAdapterField.shape.end_step_percent
|
const end_step_percent = zIPAdapterField.shape.end_step_percent
|
||||||
.nullish()
|
.nullish()
|
||||||
.catch(null)
|
.catch(null)
|
||||||
.parse(getProperty(metadataItem, 'end_step_percent'));
|
.parse(await getProperty(metadataItem, 'end_step_percent'));
|
||||||
|
|
||||||
const ipAdapter: IPAdapterConfigMetadata = {
|
const ipAdapter: IPAdapterConfigMetadata = {
|
||||||
id: uuidv4(),
|
id: uuidv4(),
|
||||||
|
@ -1,8 +1,13 @@
|
|||||||
import { getStore } from 'app/store/nanostores/store';
|
import { getStore } from 'app/store/nanostores/store';
|
||||||
import { controlAdapterRecalled } from 'features/controlAdapters/store/controlAdaptersSlice';
|
import {
|
||||||
|
controlAdapterRecalled,
|
||||||
|
controlNetsReset,
|
||||||
|
ipAdaptersReset,
|
||||||
|
t2iAdaptersReset,
|
||||||
|
} from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||||
import { setHrfEnabled, setHrfMethod, setHrfStrength } from 'features/hrf/store/hrfSlice';
|
import { setHrfEnabled, setHrfMethod, setHrfStrength } from 'features/hrf/store/hrfSlice';
|
||||||
import type { LoRA } from 'features/lora/store/loraSlice';
|
import type { LoRA } from 'features/lora/store/loraSlice';
|
||||||
import { loraRecalled } from 'features/lora/store/loraSlice';
|
import { loraRecalled, lorasReset } from 'features/lora/store/loraSlice';
|
||||||
import type {
|
import type {
|
||||||
ControlNetConfigMetadata,
|
ControlNetConfigMetadata,
|
||||||
IPAdapterConfigMetadata,
|
IPAdapterConfigMetadata,
|
||||||
@ -166,7 +171,11 @@ const recallLoRA: MetadataRecallFunc<LoRA> = (lora) => {
|
|||||||
};
|
};
|
||||||
|
|
||||||
const recallAllLoRAs: MetadataRecallFunc<LoRA[]> = (loras) => {
|
const recallAllLoRAs: MetadataRecallFunc<LoRA[]> = (loras) => {
|
||||||
|
if (!loras.length) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
const { dispatch } = getStore();
|
const { dispatch } = getStore();
|
||||||
|
dispatch(lorasReset());
|
||||||
loras.forEach((lora) => {
|
loras.forEach((lora) => {
|
||||||
dispatch(loraRecalled(lora));
|
dispatch(loraRecalled(lora));
|
||||||
});
|
});
|
||||||
@ -177,7 +186,11 @@ const recallControlNet: MetadataRecallFunc<ControlNetConfigMetadata> = (controlN
|
|||||||
};
|
};
|
||||||
|
|
||||||
const recallControlNets: MetadataRecallFunc<ControlNetConfigMetadata[]> = (controlNets) => {
|
const recallControlNets: MetadataRecallFunc<ControlNetConfigMetadata[]> = (controlNets) => {
|
||||||
|
if (!controlNets.length) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
const { dispatch } = getStore();
|
const { dispatch } = getStore();
|
||||||
|
dispatch(controlNetsReset());
|
||||||
controlNets.forEach((controlNet) => {
|
controlNets.forEach((controlNet) => {
|
||||||
dispatch(controlAdapterRecalled(controlNet));
|
dispatch(controlAdapterRecalled(controlNet));
|
||||||
});
|
});
|
||||||
@ -188,7 +201,11 @@ const recallT2IAdapter: MetadataRecallFunc<T2IAdapterConfigMetadata> = (t2iAdapt
|
|||||||
};
|
};
|
||||||
|
|
||||||
const recallT2IAdapters: MetadataRecallFunc<T2IAdapterConfigMetadata[]> = (t2iAdapters) => {
|
const recallT2IAdapters: MetadataRecallFunc<T2IAdapterConfigMetadata[]> = (t2iAdapters) => {
|
||||||
|
if (!t2iAdapters.length) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
const { dispatch } = getStore();
|
const { dispatch } = getStore();
|
||||||
|
dispatch(t2iAdaptersReset());
|
||||||
t2iAdapters.forEach((t2iAdapter) => {
|
t2iAdapters.forEach((t2iAdapter) => {
|
||||||
dispatch(controlAdapterRecalled(t2iAdapter));
|
dispatch(controlAdapterRecalled(t2iAdapter));
|
||||||
});
|
});
|
||||||
@ -199,7 +216,11 @@ const recallIPAdapter: MetadataRecallFunc<IPAdapterConfigMetadata> = (ipAdapter)
|
|||||||
};
|
};
|
||||||
|
|
||||||
const recallIPAdapters: MetadataRecallFunc<IPAdapterConfigMetadata[]> = (ipAdapters) => {
|
const recallIPAdapters: MetadataRecallFunc<IPAdapterConfigMetadata[]> = (ipAdapters) => {
|
||||||
|
if (!ipAdapters.length) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
const { dispatch } = getStore();
|
const { dispatch } = getStore();
|
||||||
|
dispatch(ipAdaptersReset());
|
||||||
ipAdapters.forEach((ipAdapter) => {
|
ipAdapters.forEach((ipAdapter) => {
|
||||||
dispatch(controlAdapterRecalled(ipAdapter));
|
dispatch(controlAdapterRecalled(ipAdapter));
|
||||||
});
|
});
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import { skipToken } from '@reduxjs/toolkit/query';
|
import { skipToken } from '@reduxjs/toolkit/query';
|
||||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { getOptimalDimension } from 'features/parameters/util/optimalDimension';
|
||||||
import { selectConfigSlice } from 'features/system/store/configSlice';
|
import { selectConfigSlice } from 'features/system/store/configSlice';
|
||||||
import { isNil } from 'lodash-es';
|
import { isNil } from 'lodash-es';
|
||||||
import { useMemo } from 'react';
|
import { useMemo } from 'react';
|
||||||
@ -8,7 +9,7 @@ import { useGetModelConfigWithTypeGuard } from 'services/api/hooks/useGetModelCo
|
|||||||
import { isNonRefinerMainModelConfig } from 'services/api/types';
|
import { isNonRefinerMainModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
const initialStatesSelector = createMemoizedSelector(selectConfigSlice, (config) => {
|
const initialStatesSelector = createMemoizedSelector(selectConfigSlice, (config) => {
|
||||||
const { steps, guidance, scheduler, cfgRescaleMultiplier, vaePrecision } = config.sd;
|
const { steps, guidance, scheduler, cfgRescaleMultiplier, vaePrecision, width, height } = config.sd;
|
||||||
|
|
||||||
return {
|
return {
|
||||||
initialSteps: steps.initial,
|
initialSteps: steps.initial,
|
||||||
@ -16,14 +17,23 @@ const initialStatesSelector = createMemoizedSelector(selectConfigSlice, (config)
|
|||||||
initialScheduler: scheduler,
|
initialScheduler: scheduler,
|
||||||
initialCfgRescaleMultiplier: cfgRescaleMultiplier.initial,
|
initialCfgRescaleMultiplier: cfgRescaleMultiplier.initial,
|
||||||
initialVaePrecision: vaePrecision,
|
initialVaePrecision: vaePrecision,
|
||||||
|
initialWidth: width.initial,
|
||||||
|
initialHeight: height.initial,
|
||||||
};
|
};
|
||||||
});
|
});
|
||||||
|
|
||||||
export const useMainModelDefaultSettings = (modelKey?: string | null) => {
|
export const useMainModelDefaultSettings = (modelKey?: string | null) => {
|
||||||
const { modelConfig, isLoading } = useGetModelConfigWithTypeGuard(modelKey ?? skipToken, isNonRefinerMainModelConfig);
|
const { modelConfig, isLoading } = useGetModelConfigWithTypeGuard(modelKey ?? skipToken, isNonRefinerMainModelConfig);
|
||||||
|
|
||||||
const { initialSteps, initialCfg, initialScheduler, initialCfgRescaleMultiplier, initialVaePrecision } =
|
const {
|
||||||
useAppSelector(initialStatesSelector);
|
initialSteps,
|
||||||
|
initialCfg,
|
||||||
|
initialScheduler,
|
||||||
|
initialCfgRescaleMultiplier,
|
||||||
|
initialVaePrecision,
|
||||||
|
initialWidth,
|
||||||
|
initialHeight,
|
||||||
|
} = useAppSelector(initialStatesSelector);
|
||||||
|
|
||||||
const defaultSettingsDefaults = useMemo(() => {
|
const defaultSettingsDefaults = useMemo(() => {
|
||||||
return {
|
return {
|
||||||
@ -51,15 +61,25 @@ export const useMainModelDefaultSettings = (modelKey?: string | null) => {
|
|||||||
isEnabled: !isNil(modelConfig?.default_settings?.cfg_rescale_multiplier),
|
isEnabled: !isNil(modelConfig?.default_settings?.cfg_rescale_multiplier),
|
||||||
value: modelConfig?.default_settings?.cfg_rescale_multiplier || initialCfgRescaleMultiplier,
|
value: modelConfig?.default_settings?.cfg_rescale_multiplier || initialCfgRescaleMultiplier,
|
||||||
},
|
},
|
||||||
|
width: {
|
||||||
|
isEnabled: !isNil(modelConfig?.default_settings?.width),
|
||||||
|
value: modelConfig?.default_settings?.width || initialWidth,
|
||||||
|
},
|
||||||
|
height: {
|
||||||
|
isEnabled: !isNil(modelConfig?.default_settings?.height),
|
||||||
|
value: modelConfig?.default_settings?.height || initialHeight,
|
||||||
|
},
|
||||||
};
|
};
|
||||||
}, [
|
}, [
|
||||||
modelConfig?.default_settings,
|
modelConfig,
|
||||||
|
initialVaePrecision,
|
||||||
|
initialScheduler,
|
||||||
initialSteps,
|
initialSteps,
|
||||||
initialCfg,
|
initialCfg,
|
||||||
initialScheduler,
|
|
||||||
initialCfgRescaleMultiplier,
|
initialCfgRescaleMultiplier,
|
||||||
initialVaePrecision,
|
initialWidth,
|
||||||
|
initialHeight,
|
||||||
]);
|
]);
|
||||||
|
|
||||||
return { defaultSettingsDefaults, isLoading };
|
return { defaultSettingsDefaults, isLoading, optimalDimension: getOptimalDimension(modelConfig) };
|
||||||
};
|
};
|
||||||
|
@ -1,13 +1,16 @@
|
|||||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||||
import { createSlice } from '@reduxjs/toolkit';
|
import { createSlice } from '@reduxjs/toolkit';
|
||||||
import type { PersistConfig } from 'app/store/store';
|
import type { PersistConfig } from 'app/store/store';
|
||||||
|
import type { ModelType } from 'services/api/types';
|
||||||
|
|
||||||
|
export type FilterableModelType = Exclude<ModelType, 'onnx' | 'clip_vision'>;
|
||||||
|
|
||||||
type ModelManagerState = {
|
type ModelManagerState = {
|
||||||
_version: 1;
|
_version: 1;
|
||||||
selectedModelKey: string | null;
|
selectedModelKey: string | null;
|
||||||
selectedModelMode: 'edit' | 'view';
|
selectedModelMode: 'edit' | 'view';
|
||||||
searchTerm: string;
|
searchTerm: string;
|
||||||
filteredModelType: string | null;
|
filteredModelType: FilterableModelType | null;
|
||||||
scanPath: string | undefined;
|
scanPath: string | undefined;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -35,7 +38,7 @@ export const modelManagerV2Slice = createSlice({
|
|||||||
state.searchTerm = action.payload;
|
state.searchTerm = action.payload;
|
||||||
},
|
},
|
||||||
|
|
||||||
setFilteredModelType: (state, action: PayloadAction<string | null>) => {
|
setFilteredModelType: (state, action: PayloadAction<FilterableModelType | null>) => {
|
||||||
state.filteredModelType = action.payload;
|
state.filteredModelType = action.payload;
|
||||||
},
|
},
|
||||||
setScanPath: (state, action: PayloadAction<string | undefined>) => {
|
setScanPath: (state, action: PayloadAction<string | undefined>) => {
|
||||||
|
@ -0,0 +1,102 @@
|
|||||||
|
import { Button, Flex, FormControl, FormErrorMessage, FormHelperText, FormLabel, Input } from '@invoke-ai/ui-library';
|
||||||
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
|
import { makeToast } from 'features/system/util/makeToast';
|
||||||
|
import type { ChangeEventHandler } from 'react';
|
||||||
|
import { useCallback, useState } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { useInstallModelMutation, useLazyGetHuggingFaceModelsQuery } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
|
import { HuggingFaceResults } from './HuggingFaceResults';
|
||||||
|
|
||||||
|
export const HuggingFaceForm = () => {
|
||||||
|
const [huggingFaceRepo, setHuggingFaceRepo] = useState('');
|
||||||
|
const [displayResults, setDisplayResults] = useState(false);
|
||||||
|
const [errorMessage, setErrorMessage] = useState('');
|
||||||
|
const { t } = useTranslation();
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
const [_getHuggingFaceModels, { isLoading, data }] = useLazyGetHuggingFaceModelsQuery();
|
||||||
|
const [installModel] = useInstallModelMutation();
|
||||||
|
|
||||||
|
const handleInstallModel = useCallback(
|
||||||
|
(source: string) => {
|
||||||
|
installModel({ source })
|
||||||
|
.unwrap()
|
||||||
|
.then((_) => {
|
||||||
|
dispatch(
|
||||||
|
addToast(
|
||||||
|
makeToast({
|
||||||
|
title: t('toast.modelAddedSimple'),
|
||||||
|
status: 'success',
|
||||||
|
})
|
||||||
|
)
|
||||||
|
);
|
||||||
|
})
|
||||||
|
.catch((error) => {
|
||||||
|
if (error) {
|
||||||
|
dispatch(
|
||||||
|
addToast(
|
||||||
|
makeToast({
|
||||||
|
title: `${error.data.detail} `,
|
||||||
|
status: 'error',
|
||||||
|
})
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
},
|
||||||
|
[installModel, dispatch, t]
|
||||||
|
);
|
||||||
|
|
||||||
|
const getModels = useCallback(async () => {
|
||||||
|
_getHuggingFaceModels(huggingFaceRepo)
|
||||||
|
.unwrap()
|
||||||
|
.then((response) => {
|
||||||
|
if (response.is_diffusers) {
|
||||||
|
handleInstallModel(huggingFaceRepo);
|
||||||
|
setDisplayResults(false);
|
||||||
|
} else if (response.urls?.length === 1 && response.urls[0]) {
|
||||||
|
handleInstallModel(response.urls[0]);
|
||||||
|
setDisplayResults(false);
|
||||||
|
} else {
|
||||||
|
setDisplayResults(true);
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.catch((error) => {
|
||||||
|
setErrorMessage(error.data.detail || '');
|
||||||
|
});
|
||||||
|
}, [_getHuggingFaceModels, handleInstallModel, huggingFaceRepo]);
|
||||||
|
|
||||||
|
const handleSetHuggingFaceRepo: ChangeEventHandler<HTMLInputElement> = useCallback((e) => {
|
||||||
|
setHuggingFaceRepo(e.target.value);
|
||||||
|
setErrorMessage('');
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex flexDir="column" height="100%" gap={3}>
|
||||||
|
<FormControl isInvalid={!!errorMessage.length} w="full" orientation="vertical" flexShrink={0}>
|
||||||
|
<FormLabel>{t('modelManager.huggingFaceRepoID')}</FormLabel>
|
||||||
|
<Flex gap={3} alignItems="center" w="full">
|
||||||
|
<Input
|
||||||
|
placeholder={t('modelManager.huggingFacePlaceholder')}
|
||||||
|
value={huggingFaceRepo}
|
||||||
|
onChange={handleSetHuggingFaceRepo}
|
||||||
|
/>
|
||||||
|
<Button
|
||||||
|
onClick={getModels}
|
||||||
|
isLoading={isLoading}
|
||||||
|
isDisabled={huggingFaceRepo.length === 0}
|
||||||
|
size="sm"
|
||||||
|
flexShrink={0}
|
||||||
|
>
|
||||||
|
{t('modelManager.installRepo')}
|
||||||
|
</Button>
|
||||||
|
</Flex>
|
||||||
|
<FormHelperText>{t('modelManager.huggingFaceHelper')}</FormHelperText>
|
||||||
|
{!!errorMessage.length && <FormErrorMessage>{errorMessage}</FormErrorMessage>}
|
||||||
|
</FormControl>
|
||||||
|
{data && data.urls && displayResults && <HuggingFaceResults results={data.urls} />}
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
@ -0,0 +1,57 @@
|
|||||||
|
import { Flex, IconButton, Text } from '@invoke-ai/ui-library';
|
||||||
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
|
import { makeToast } from 'features/system/util/makeToast';
|
||||||
|
import { useCallback } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { PiPlusBold } from 'react-icons/pi';
|
||||||
|
import { useInstallModelMutation } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
|
type Props = {
|
||||||
|
result: string;
|
||||||
|
};
|
||||||
|
export const HuggingFaceResultItem = ({ result }: Props) => {
|
||||||
|
const { t } = useTranslation();
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
const [installModel] = useInstallModelMutation();
|
||||||
|
|
||||||
|
const handleInstall = useCallback(() => {
|
||||||
|
installModel({ source: result })
|
||||||
|
.unwrap()
|
||||||
|
.then((_) => {
|
||||||
|
dispatch(
|
||||||
|
addToast(
|
||||||
|
makeToast({
|
||||||
|
title: t('toast.modelAddedSimple'),
|
||||||
|
status: 'success',
|
||||||
|
})
|
||||||
|
)
|
||||||
|
);
|
||||||
|
})
|
||||||
|
.catch((error) => {
|
||||||
|
if (error) {
|
||||||
|
dispatch(
|
||||||
|
addToast(
|
||||||
|
makeToast({
|
||||||
|
title: `${error.data.detail} `,
|
||||||
|
status: 'error',
|
||||||
|
})
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}, [installModel, result, dispatch, t]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex alignItems="center" justifyContent="space-between" w="100%" gap={3}>
|
||||||
|
<Flex fontSize="sm" flexDir="column">
|
||||||
|
<Text fontWeight="semibold">{result.split('/').slice(-1)[0]}</Text>
|
||||||
|
<Text variant="subtext" noOfLines={1} wordBreak="break-all">
|
||||||
|
{result}
|
||||||
|
</Text>
|
||||||
|
</Flex>
|
||||||
|
<IconButton aria-label={t('modelManager.install')} icon={<PiPlusBold />} onClick={handleInstall} size="sm" />
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
@ -0,0 +1,123 @@
|
|||||||
|
import {
|
||||||
|
Button,
|
||||||
|
Divider,
|
||||||
|
Flex,
|
||||||
|
Heading,
|
||||||
|
IconButton,
|
||||||
|
Input,
|
||||||
|
InputGroup,
|
||||||
|
InputRightElement,
|
||||||
|
} from '@invoke-ai/ui-library';
|
||||||
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
|
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||||
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
|
import { makeToast } from 'features/system/util/makeToast';
|
||||||
|
import type { ChangeEventHandler } from 'react';
|
||||||
|
import { useCallback, useMemo, useState } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { PiXBold } from 'react-icons/pi';
|
||||||
|
import { useInstallModelMutation } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
|
import { HuggingFaceResultItem } from './HuggingFaceResultItem';
|
||||||
|
|
||||||
|
type HuggingFaceResultsProps = {
|
||||||
|
results: string[];
|
||||||
|
};
|
||||||
|
|
||||||
|
export const HuggingFaceResults = ({ results }: HuggingFaceResultsProps) => {
|
||||||
|
const { t } = useTranslation();
|
||||||
|
const [searchTerm, setSearchTerm] = useState('');
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
const [installModel] = useInstallModelMutation();
|
||||||
|
|
||||||
|
const filteredResults = useMemo(() => {
|
||||||
|
return results.filter((result) => {
|
||||||
|
const modelName = result.split('/').slice(-1)[0];
|
||||||
|
return modelName?.toLowerCase().includes(searchTerm.toLowerCase());
|
||||||
|
});
|
||||||
|
}, [results, searchTerm]);
|
||||||
|
|
||||||
|
const handleSearch: ChangeEventHandler<HTMLInputElement> = useCallback((e) => {
|
||||||
|
setSearchTerm(e.target.value.trim());
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const clearSearch = useCallback(() => {
|
||||||
|
setSearchTerm('');
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const handleAddAll = useCallback(() => {
|
||||||
|
for (const result of filteredResults) {
|
||||||
|
installModel({ source: result })
|
||||||
|
.unwrap()
|
||||||
|
.then((_) => {
|
||||||
|
dispatch(
|
||||||
|
addToast(
|
||||||
|
makeToast({
|
||||||
|
title: t('toast.modelAddedSimple'),
|
||||||
|
status: 'success',
|
||||||
|
})
|
||||||
|
)
|
||||||
|
);
|
||||||
|
})
|
||||||
|
.catch((error) => {
|
||||||
|
if (error) {
|
||||||
|
dispatch(
|
||||||
|
addToast(
|
||||||
|
makeToast({
|
||||||
|
title: `${error.data.detail} `,
|
||||||
|
status: 'error',
|
||||||
|
})
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}, [filteredResults, installModel, dispatch, t]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
<Divider />
|
||||||
|
<Flex flexDir="column" gap={3} height="100%">
|
||||||
|
<Flex justifyContent="space-between" alignItems="center">
|
||||||
|
<Heading size="sm">{t('modelManager.availableModels')}</Heading>
|
||||||
|
<Flex alignItems="center" gap={3}>
|
||||||
|
<Button size="sm" onClick={handleAddAll} isDisabled={results.length === 0} flexShrink={0}>
|
||||||
|
{t('modelManager.installAll')}
|
||||||
|
</Button>
|
||||||
|
<InputGroup w={64} size="xs">
|
||||||
|
<Input
|
||||||
|
placeholder={t('modelManager.search')}
|
||||||
|
value={searchTerm}
|
||||||
|
data-testid="board-search-input"
|
||||||
|
onChange={handleSearch}
|
||||||
|
size="xs"
|
||||||
|
/>
|
||||||
|
|
||||||
|
{searchTerm && (
|
||||||
|
<InputRightElement h="full" pe={2}>
|
||||||
|
<IconButton
|
||||||
|
size="sm"
|
||||||
|
variant="link"
|
||||||
|
aria-label={t('boards.clearSearch')}
|
||||||
|
icon={<PiXBold />}
|
||||||
|
onClick={clearSearch}
|
||||||
|
/>
|
||||||
|
</InputRightElement>
|
||||||
|
)}
|
||||||
|
</InputGroup>
|
||||||
|
</Flex>
|
||||||
|
</Flex>
|
||||||
|
<Flex height="100%" layerStyle="third" borderRadius="base" p={3}>
|
||||||
|
<ScrollableContent>
|
||||||
|
<Flex flexDir="column" gap={3}>
|
||||||
|
{filteredResults.map((result) => (
|
||||||
|
<HuggingFaceResultItem key={result} result={result} />
|
||||||
|
))}
|
||||||
|
</Flex>
|
||||||
|
</ScrollableContent>
|
||||||
|
</Flex>
|
||||||
|
</Flex>
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
};
|
@ -67,20 +67,22 @@ export const InstallModelForm = () => {
|
|||||||
<Flex flexDir="column" gap={4}>
|
<Flex flexDir="column" gap={4}>
|
||||||
<Flex gap={2} alignItems="flex-end" justifyContent="space-between">
|
<Flex gap={2} alignItems="flex-end" justifyContent="space-between">
|
||||||
<FormControl orientation="vertical">
|
<FormControl orientation="vertical">
|
||||||
<FormLabel>{t('modelManager.modelLocation')}</FormLabel>
|
<FormLabel>{t('modelManager.urlOrLocalPath')}</FormLabel>
|
||||||
<Input {...register('location')} />
|
<Flex alignItems="center" gap={3} w="full">
|
||||||
</FormControl>
|
<Input placeholder={t('modelManager.simpleModelPlaceholder')} {...register('location')} />
|
||||||
<Button
|
<Button
|
||||||
onClick={handleSubmit(onSubmit)}
|
onClick={handleSubmit(onSubmit)}
|
||||||
isDisabled={!formState.dirtyFields.location}
|
isDisabled={!formState.dirtyFields.location}
|
||||||
isLoading={isLoading}
|
isLoading={isLoading}
|
||||||
type="submit"
|
type="submit"
|
||||||
size="sm"
|
size="sm"
|
||||||
mb={1}
|
|
||||||
>
|
>
|
||||||
{t('modelManager.addModel')}
|
{t('modelManager.install')}
|
||||||
</Button>
|
</Button>
|
||||||
</Flex>
|
</Flex>
|
||||||
|
<FormHelperText>{t('modelManager.urlOrLocalPathHelper')}</FormHelperText>
|
||||||
|
</FormControl>
|
||||||
|
</Flex>
|
||||||
|
|
||||||
<FormControl>
|
<FormControl>
|
||||||
<Flex flexDir="column" gap={2}>
|
<Flex flexDir="column" gap={2}>
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
import { Box, Button, Flex, Text } from '@invoke-ai/ui-library';
|
import { Box, Button, Flex, Heading } from '@invoke-ai/ui-library';
|
||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
@ -50,9 +50,9 @@ export const ModelInstallQueue = () => {
|
|||||||
}, [data]);
|
}, [data]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex flexDir="column" p={3} h="full">
|
<Flex flexDir="column" p={3} h="full" gap={3}>
|
||||||
<Flex justifyContent="space-between" alignItems="center">
|
<Flex justifyContent="space-between" alignItems="center">
|
||||||
<Text>{t('modelManager.importQueue')}</Text>
|
<Heading size="sm">{t('modelManager.installQueue')}</Heading>
|
||||||
<Button
|
<Button
|
||||||
size="sm"
|
size="sm"
|
||||||
isDisabled={!pruneAvailable}
|
isDisabled={!pruneAvailable}
|
||||||
@ -62,9 +62,9 @@ export const ModelInstallQueue = () => {
|
|||||||
{t('modelManager.prune')}
|
{t('modelManager.prune')}
|
||||||
</Button>
|
</Button>
|
||||||
</Flex>
|
</Flex>
|
||||||
<Box mt={3} layerStyle="first" p={3} borderRadius="base" w="full" h="full">
|
<Box layerStyle="first" p={3} borderRadius="base" w="full" h="full">
|
||||||
<ScrollableContent>
|
<ScrollableContent>
|
||||||
<Flex flexDir="column-reverse" gap="2">
|
<Flex flexDir="column-reverse" gap="2" w="full">
|
||||||
{data?.map((model) => <ModelInstallQueueItem key={model.id} installJob={model} />)}
|
{data?.map((model) => <ModelInstallQueueItem key={model.id} installJob={model} />)}
|
||||||
</Flex>
|
</Flex>
|
||||||
</ScrollableContent>
|
</ScrollableContent>
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
import { Badge, Tooltip } from '@invoke-ai/ui-library';
|
import { Badge } from '@invoke-ai/ui-library';
|
||||||
import { memo } from 'react';
|
import { memo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import type { ModelInstallStatus } from 'services/api/types';
|
import type { ModelInstallStatus } from 'services/api/types';
|
||||||
@ -13,13 +13,7 @@ const STATUSES = {
|
|||||||
cancelled: { colorScheme: 'orange', translationKey: 'queue.canceled' },
|
cancelled: { colorScheme: 'orange', translationKey: 'queue.canceled' },
|
||||||
};
|
};
|
||||||
|
|
||||||
const ModelInstallQueueBadge = ({
|
const ModelInstallQueueBadge = ({ status }: { status?: ModelInstallStatus }) => {
|
||||||
status,
|
|
||||||
errorReason,
|
|
||||||
}: {
|
|
||||||
status?: ModelInstallStatus;
|
|
||||||
errorReason?: string | null;
|
|
||||||
}) => {
|
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
if (!status || !Object.keys(STATUSES).includes(status)) {
|
if (!status || !Object.keys(STATUSES).includes(status)) {
|
||||||
@ -27,9 +21,9 @@ const ModelInstallQueueBadge = ({
|
|||||||
}
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Tooltip label={errorReason}>
|
<Badge textAlign="center" w="134px" colorScheme={STATUSES[status].colorScheme}>
|
||||||
<Badge colorScheme={STATUSES[status].colorScheme}>{t(STATUSES[status].translationKey)}</Badge>
|
{t(STATUSES[status].translationKey)}
|
||||||
</Tooltip>
|
</Badge>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
export default memo(ModelInstallQueueBadge);
|
export default memo(ModelInstallQueueBadge);
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
import { Box, Flex, IconButton, Progress, Text, Tooltip } from '@invoke-ai/ui-library';
|
import { Flex, IconButton, Progress, Text, Tooltip } from '@invoke-ai/ui-library';
|
||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
import { makeToast } from 'features/system/util/makeToast';
|
import { makeToast } from 'features/system/util/makeToast';
|
||||||
@ -7,7 +7,7 @@ import { isNil } from 'lodash-es';
|
|||||||
import { useCallback, useMemo } from 'react';
|
import { useCallback, useMemo } from 'react';
|
||||||
import { PiXBold } from 'react-icons/pi';
|
import { PiXBold } from 'react-icons/pi';
|
||||||
import { useCancelModelInstallMutation } from 'services/api/endpoints/models';
|
import { useCancelModelInstallMutation } from 'services/api/endpoints/models';
|
||||||
import type { HFModelSource, LocalModelSource, ModelInstallJob, URLModelSource } from 'services/api/types';
|
import type { ModelInstallJob } from 'services/api/types';
|
||||||
|
|
||||||
import ModelInstallQueueBadge from './ModelInstallQueueBadge';
|
import ModelInstallQueueBadge from './ModelInstallQueueBadge';
|
||||||
|
|
||||||
@ -16,7 +16,7 @@ type ModelListItemProps = {
|
|||||||
};
|
};
|
||||||
|
|
||||||
const formatBytes = (bytes: number) => {
|
const formatBytes = (bytes: number) => {
|
||||||
const units = ['b', 'kb', 'mb', 'gb', 'tb'];
|
const units = ['B', 'KB', 'MB', 'GB', 'TB'];
|
||||||
|
|
||||||
let i = 0;
|
let i = 0;
|
||||||
|
|
||||||
@ -33,18 +33,6 @@ export const ModelInstallQueueItem = (props: ModelListItemProps) => {
|
|||||||
|
|
||||||
const [deleteImportModel] = useCancelModelInstallMutation();
|
const [deleteImportModel] = useCancelModelInstallMutation();
|
||||||
|
|
||||||
const source = useMemo(() => {
|
|
||||||
if (installJob.source.type === 'hf') {
|
|
||||||
return installJob.source as HFModelSource;
|
|
||||||
} else if (installJob.source.type === 'local') {
|
|
||||||
return installJob.source as LocalModelSource;
|
|
||||||
} else if (installJob.source.type === 'url') {
|
|
||||||
return installJob.source as URLModelSource;
|
|
||||||
} else {
|
|
||||||
return installJob.source as LocalModelSource;
|
|
||||||
}
|
|
||||||
}, [installJob.source]);
|
|
||||||
|
|
||||||
const handleDeleteModelImport = useCallback(() => {
|
const handleDeleteModelImport = useCallback(() => {
|
||||||
deleteImportModel(installJob.id)
|
deleteImportModel(installJob.id)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
@ -72,18 +60,31 @@ export const ModelInstallQueueItem = (props: ModelListItemProps) => {
|
|||||||
});
|
});
|
||||||
}, [deleteImportModel, installJob, dispatch]);
|
}, [deleteImportModel, installJob, dispatch]);
|
||||||
|
|
||||||
const modelName = useMemo(() => {
|
const sourceLocation = useMemo(() => {
|
||||||
switch (source.type) {
|
switch (installJob.source.type) {
|
||||||
case 'hf':
|
case 'hf':
|
||||||
return source.repo_id;
|
return installJob.source.repo_id;
|
||||||
case 'url':
|
case 'url':
|
||||||
return source.url;
|
return installJob.source.url;
|
||||||
case 'local':
|
case 'local':
|
||||||
return source.path.split('\\').slice(-1)[0];
|
return installJob.source.path;
|
||||||
default:
|
default:
|
||||||
return '';
|
return t('common.unknown');
|
||||||
}
|
}
|
||||||
}, [source]);
|
}, [installJob.source]);
|
||||||
|
|
||||||
|
const modelName = useMemo(() => {
|
||||||
|
switch (installJob.source.type) {
|
||||||
|
case 'hf':
|
||||||
|
return installJob.source.repo_id;
|
||||||
|
case 'url':
|
||||||
|
return installJob.source.url.split('/').slice(-1)[0] ?? t('common.unknown');
|
||||||
|
case 'local':
|
||||||
|
return installJob.source.path.split('\\').slice(-1)[0] ?? t('common.unknown');
|
||||||
|
default:
|
||||||
|
return t('common.unknown');
|
||||||
|
}
|
||||||
|
}, [installJob.source]);
|
||||||
|
|
||||||
const progressValue = useMemo(() => {
|
const progressValue = useMemo(() => {
|
||||||
if (isNil(installJob.bytes) || isNil(installJob.total_bytes)) {
|
if (isNil(installJob.bytes) || isNil(installJob.total_bytes)) {
|
||||||
@ -97,48 +98,67 @@ export const ModelInstallQueueItem = (props: ModelListItemProps) => {
|
|||||||
return (installJob.bytes / installJob.total_bytes) * 100;
|
return (installJob.bytes / installJob.total_bytes) * 100;
|
||||||
}, [installJob.bytes, installJob.total_bytes]);
|
}, [installJob.bytes, installJob.total_bytes]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex gap={3} w="full" alignItems="center">
|
||||||
|
<Tooltip maxW={600} label={<TooltipLabel name={modelName} source={sourceLocation} installJob={installJob} />}>
|
||||||
|
<Flex gap={3} w="full" alignItems="center">
|
||||||
|
<Text w={96} whiteSpace="nowrap" overflow="hidden" textOverflow="ellipsis">
|
||||||
|
{modelName}
|
||||||
|
</Text>
|
||||||
|
<Progress
|
||||||
|
w="full"
|
||||||
|
flexGrow={1}
|
||||||
|
value={progressValue ?? 0}
|
||||||
|
isIndeterminate={progressValue === null}
|
||||||
|
aria-label={t('accessibility.invokeProgressBar')}
|
||||||
|
h={2}
|
||||||
|
/>
|
||||||
|
<ModelInstallQueueBadge status={installJob.status} />
|
||||||
|
</Flex>
|
||||||
|
</Tooltip>
|
||||||
|
<IconButton
|
||||||
|
isDisabled={
|
||||||
|
installJob.status !== 'downloading' && installJob.status !== 'waiting' && installJob.status !== 'running'
|
||||||
|
}
|
||||||
|
size="xs"
|
||||||
|
tooltip={t('modelManager.cancel')}
|
||||||
|
aria-label={t('modelManager.cancel')}
|
||||||
|
icon={<PiXBold />}
|
||||||
|
onClick={handleDeleteModelImport}
|
||||||
|
variant="ghost"
|
||||||
|
/>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
type TooltipLabelProps = {
|
||||||
|
installJob: ModelInstallJob;
|
||||||
|
name: string;
|
||||||
|
source: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
const TooltipLabel = ({ name, source, installJob }: TooltipLabelProps) => {
|
||||||
const progressString = useMemo(() => {
|
const progressString = useMemo(() => {
|
||||||
if (installJob.status !== 'downloading' || installJob.bytes === undefined || installJob.total_bytes === undefined) {
|
if (installJob.status === 'downloading' || installJob.bytes === undefined || installJob.total_bytes === undefined) {
|
||||||
return '';
|
return '';
|
||||||
}
|
}
|
||||||
return `${formatBytes(installJob.bytes)} / ${formatBytes(installJob.total_bytes)}`;
|
return `${formatBytes(installJob.bytes)} / ${formatBytes(installJob.total_bytes)}`;
|
||||||
}, [installJob.bytes, installJob.total_bytes, installJob.status]);
|
}, [installJob.bytes, installJob.total_bytes, installJob.status]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex gap="2" w="full" alignItems="center">
|
<>
|
||||||
<Tooltip label={modelName}>
|
<Flex gap={3} justifyContent="space-between">
|
||||||
<Text width="30%" whiteSpace="nowrap" overflow="hidden" textOverflow="ellipsis">
|
<Text fontWeight="semibold">{name}</Text>
|
||||||
{modelName}
|
{progressString && <Text>{progressString}</Text>}
|
||||||
|
</Flex>
|
||||||
|
<Text fontStyle="italic" wordBreak="break-all">
|
||||||
|
{source}
|
||||||
|
</Text>
|
||||||
|
{installJob.error_reason && (
|
||||||
|
<Text color="error.500">
|
||||||
|
{t('queue.failed')}: {installJob.error}
|
||||||
</Text>
|
</Text>
|
||||||
</Tooltip>
|
|
||||||
<Flex flexDir="column" flex={1}>
|
|
||||||
<Tooltip label={progressString}>
|
|
||||||
<Progress
|
|
||||||
value={progressValue ?? 0}
|
|
||||||
isIndeterminate={progressValue === null}
|
|
||||||
aria-label={t('accessibility.invokeProgressBar')}
|
|
||||||
h={2}
|
|
||||||
/>
|
|
||||||
</Tooltip>
|
|
||||||
</Flex>
|
|
||||||
<Box minW="100px" textAlign="center">
|
|
||||||
<ModelInstallQueueBadge status={installJob.status} errorReason={installJob.error_reason} />
|
|
||||||
</Box>
|
|
||||||
|
|
||||||
<Box minW="20px">
|
|
||||||
{(installJob.status === 'downloading' ||
|
|
||||||
installJob.status === 'waiting' ||
|
|
||||||
installJob.status === 'running') && (
|
|
||||||
<IconButton
|
|
||||||
isRound={true}
|
|
||||||
size="xs"
|
|
||||||
tooltip={t('modelManager.cancel')}
|
|
||||||
aria-label={t('modelManager.cancel')}
|
|
||||||
icon={<PiXBold />}
|
|
||||||
onClick={handleDeleteModelImport}
|
|
||||||
/>
|
|
||||||
)}
|
)}
|
||||||
</Box>
|
</>
|
||||||
</Flex>
|
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
import { Button, Flex, FormControl, FormErrorMessage, FormLabel, Input } from '@invoke-ai/ui-library';
|
import { Button, Flex, FormControl, FormErrorMessage, FormHelperText, FormLabel, Input } from '@invoke-ai/ui-library';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { setScanPath } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
import { setScanPath } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||||
import type { ChangeEventHandler } from 'react';
|
import type { ChangeEventHandler } from 'react';
|
||||||
@ -17,7 +17,9 @@ export const ScanModelsForm = () => {
|
|||||||
const [_scanFolder, { isLoading, data }] = useLazyScanFolderQuery();
|
const [_scanFolder, { isLoading, data }] = useLazyScanFolderQuery();
|
||||||
|
|
||||||
const scanFolder = useCallback(async () => {
|
const scanFolder = useCallback(async () => {
|
||||||
_scanFolder({ scan_path: scanPath }).catch((error) => {
|
_scanFolder({ scan_path: scanPath })
|
||||||
|
.unwrap()
|
||||||
|
.catch((error) => {
|
||||||
if (error) {
|
if (error) {
|
||||||
setErrorMessage(error.data.detail);
|
setErrorMessage(error.data.detail);
|
||||||
}
|
}
|
||||||
@ -33,25 +35,23 @@ export const ScanModelsForm = () => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex flexDir="column" height="100%">
|
<Flex flexDir="column" height="100%" gap={3}>
|
||||||
<FormControl isInvalid={!!errorMessage.length} w="full">
|
<FormControl isInvalid={!!errorMessage.length} w="full" orientation="vertical" flexShrink={0}>
|
||||||
<Flex flexDir="column" w="full">
|
|
||||||
<Flex gap={2} alignItems="flex-end" justifyContent="space-between">
|
|
||||||
<Flex direction="column" w="full">
|
|
||||||
<FormLabel>{t('common.folder')}</FormLabel>
|
<FormLabel>{t('common.folder')}</FormLabel>
|
||||||
<Input value={scanPath} onChange={handleSetScanPath} />
|
<Flex gap={3} alignItems="center" w="full">
|
||||||
</Flex>
|
<Input placeholder={t('modelManager.scanPlaceholder')} value={scanPath} onChange={handleSetScanPath} />
|
||||||
|
|
||||||
<Button
|
<Button
|
||||||
onClick={scanFolder}
|
onClick={scanFolder}
|
||||||
isLoading={isLoading}
|
isLoading={isLoading}
|
||||||
isDisabled={scanPath === undefined || scanPath.length === 0}
|
isDisabled={scanPath === undefined || scanPath.length === 0}
|
||||||
|
size="sm"
|
||||||
|
flexShrink={0}
|
||||||
>
|
>
|
||||||
{t('modelManager.scanFolder')}
|
{t('modelManager.scanFolder')}
|
||||||
</Button>
|
</Button>
|
||||||
</Flex>
|
</Flex>
|
||||||
|
<FormHelperText>{t('modelManager.scanFolderHelper')}</FormHelperText>
|
||||||
{!!errorMessage.length && <FormErrorMessage>{errorMessage}</FormErrorMessage>}
|
{!!errorMessage.length && <FormErrorMessage>{errorMessage}</FormErrorMessage>}
|
||||||
</Flex>
|
|
||||||
</FormControl>
|
</FormControl>
|
||||||
{data && <ScanModelsResults results={data} />}
|
{data && <ScanModelsResults results={data} />}
|
||||||
</Flex>
|
</Flex>
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
import { Badge, Box, Flex, IconButton, Text, Tooltip } from '@invoke-ai/ui-library';
|
import { Badge, Box, Flex, IconButton, Text } from '@invoke-ai/ui-library';
|
||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
import { makeToast } from 'features/system/util/makeToast';
|
import { makeToast } from 'features/system/util/makeToast';
|
||||||
import { useCallback } from 'react';
|
import { useCallback } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { IoAdd } from 'react-icons/io5';
|
import { PiPlusBold } from 'react-icons/pi';
|
||||||
import type { ScanFolderResponse } from 'services/api/endpoints/models';
|
import type { ScanFolderResponse } from 'services/api/endpoints/models';
|
||||||
import { useInstallModelMutation } from 'services/api/endpoints/models';
|
import { useInstallModelMutation } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
@ -45,7 +45,7 @@ export const ScanModelResultItem = ({ result }: Props) => {
|
|||||||
}, [installModel, result, dispatch, t]);
|
}, [installModel, result, dispatch, t]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex justifyContent="space-between">
|
<Flex alignItems="center" justifyContent="space-between" w="100%" gap={3}>
|
||||||
<Flex fontSize="sm" flexDir="column">
|
<Flex fontSize="sm" flexDir="column">
|
||||||
<Text fontWeight="semibold">{result.path.split('\\').slice(-1)[0]}</Text>
|
<Text fontWeight="semibold">{result.path.split('\\').slice(-1)[0]}</Text>
|
||||||
<Text variant="subtext">{result.path}</Text>
|
<Text variant="subtext">{result.path}</Text>
|
||||||
@ -54,9 +54,7 @@ export const ScanModelResultItem = ({ result }: Props) => {
|
|||||||
{result.is_installed ? (
|
{result.is_installed ? (
|
||||||
<Badge>{t('common.installed')}</Badge>
|
<Badge>{t('common.installed')}</Badge>
|
||||||
) : (
|
) : (
|
||||||
<Tooltip label={t('modelManager.quickAdd')}>
|
<IconButton aria-label={t('modelManager.install')} icon={<PiPlusBold />} onClick={handleQuickAdd} size="sm" />
|
||||||
<IconButton aria-label={t('modelManager.quickAdd')} icon={<IoAdd />} onClick={handleQuickAdd} />
|
|
||||||
</Tooltip>
|
|
||||||
)}
|
)}
|
||||||
</Box>
|
</Box>
|
||||||
</Flex>
|
</Flex>
|
||||||
|
@ -80,17 +80,15 @@ export const ScanModelsResults = ({ results }: ScanModelResultsProps) => {
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
<Divider mt={4} />
|
<Divider />
|
||||||
<Flex flexDir="column" gap={2} mt={4} height="100%">
|
<Flex flexDir="column" gap={3} height="100%">
|
||||||
<Flex justifyContent="space-between" alignItems="center">
|
<Flex justifyContent="space-between" alignItems="center">
|
||||||
<Heading fontSize="md" as="h4">
|
<Heading size="sm">{t('modelManager.scanResults')}</Heading>
|
||||||
{t('modelManager.scanResults')}
|
<Flex alignItems="center" gap={3}>
|
||||||
</Heading>
|
<Button size="sm" onClick={handleAddAll} isDisabled={filteredResults.length === 0}>
|
||||||
<Flex alignItems="center" gap="4">
|
{t('modelManager.installAll')}
|
||||||
<Button onClick={handleAddAll} isDisabled={filteredResults.length === 0}>
|
|
||||||
{t('modelManager.addAll')}
|
|
||||||
</Button>
|
</Button>
|
||||||
<InputGroup maxW="300px" size="xs">
|
<InputGroup w={64} size="xs">
|
||||||
<Input
|
<Input
|
||||||
placeholder={t('modelManager.search')}
|
placeholder={t('modelManager.search')}
|
||||||
value={searchTerm}
|
value={searchTerm}
|
||||||
@ -107,13 +105,14 @@ export const ScanModelsResults = ({ results }: ScanModelResultsProps) => {
|
|||||||
aria-label={t('boards.clearSearch')}
|
aria-label={t('boards.clearSearch')}
|
||||||
icon={<PiXBold />}
|
icon={<PiXBold />}
|
||||||
onClick={clearSearch}
|
onClick={clearSearch}
|
||||||
|
flexShrink={0}
|
||||||
/>
|
/>
|
||||||
</InputRightElement>
|
</InputRightElement>
|
||||||
)}
|
)}
|
||||||
</InputGroup>
|
</InputGroup>
|
||||||
</Flex>
|
</Flex>
|
||||||
</Flex>
|
</Flex>
|
||||||
<Flex height="100%" layerStyle="third" borderRadius="base" p={4} mt={4} mb={4}>
|
<Flex height="100%" layerStyle="third" borderRadius="base" p={3}>
|
||||||
<ScrollableContent>
|
<ScrollableContent>
|
||||||
<Flex flexDir="column" gap={3}>
|
<Flex flexDir="column" gap={3}>
|
||||||
{filteredResults.map((result) => (
|
{filteredResults.map((result) => (
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import { Box, Flex, Heading, Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library';
|
import { Box, Flex, Heading, Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
|
import { HuggingFaceForm } from './AddModelPanel/HuggingFaceFolder/HuggingFaceForm';
|
||||||
import { InstallModelForm } from './AddModelPanel/InstallModelForm';
|
import { InstallModelForm } from './AddModelPanel/InstallModelForm';
|
||||||
import { ModelInstallQueue } from './AddModelPanel/ModelInstallQueue/ModelInstallQueue';
|
import { ModelInstallQueue } from './AddModelPanel/ModelInstallQueue/ModelInstallQueue';
|
||||||
import { ScanModelsForm } from './AddModelPanel/ScanFolder/ScanFolderForm';
|
import { ScanModelsForm } from './AddModelPanel/ScanFolder/ScanFolderForm';
|
||||||
@ -8,27 +9,27 @@ import { ScanModelsForm } from './AddModelPanel/ScanFolder/ScanFolderForm';
|
|||||||
export const InstallModels = () => {
|
export const InstallModels = () => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
return (
|
return (
|
||||||
<Flex layerStyle="first" p={3} borderRadius="base" w="full" h="full" flexDir="column" gap={2}>
|
<Flex layerStyle="first" borderRadius="base" w="full" h="full" flexDir="column" gap={4}>
|
||||||
<Box w="full" p={2}>
|
|
||||||
<Heading fontSize="xl">{t('modelManager.addModel')}</Heading>
|
<Heading fontSize="xl">{t('modelManager.addModel')}</Heading>
|
||||||
</Box>
|
<Tabs variant="collapse" height="50%" display="flex" flexDir="column">
|
||||||
<Box layerStyle="second" borderRadius="base" w="full" h="50%" overflow="hidden">
|
|
||||||
<Tabs variant="collapse" height="100%">
|
|
||||||
<TabList>
|
<TabList>
|
||||||
<Tab>{t('common.simple')}</Tab>
|
<Tab>{t('modelManager.urlOrLocalPath')}</Tab>
|
||||||
<Tab>{t('modelManager.scan')}</Tab>
|
<Tab>{t('modelManager.huggingFace')}</Tab>
|
||||||
|
<Tab>{t('modelManager.scanFolder')}</Tab>
|
||||||
</TabList>
|
</TabList>
|
||||||
<TabPanels p={3} height="100%">
|
<TabPanels p={3} height="100%">
|
||||||
<TabPanel>
|
<TabPanel>
|
||||||
<InstallModelForm />
|
<InstallModelForm />
|
||||||
</TabPanel>
|
</TabPanel>
|
||||||
|
<TabPanel height="100%">
|
||||||
|
<HuggingFaceForm />
|
||||||
|
</TabPanel>
|
||||||
<TabPanel height="100%">
|
<TabPanel height="100%">
|
||||||
<ScanModelsForm />
|
<ScanModelsForm />
|
||||||
</TabPanel>
|
</TabPanel>
|
||||||
</TabPanels>
|
</TabPanels>
|
||||||
</Tabs>
|
</Tabs>
|
||||||
</Box>
|
<Box layerStyle="second" borderRadius="base" h="50%">
|
||||||
<Box layerStyle="second" borderRadius="base" w="full" h="50%">
|
|
||||||
<ModelInstallQueue />
|
<ModelInstallQueue />
|
||||||
</Box>
|
</Box>
|
||||||
</Flex>
|
</Flex>
|
||||||
|
@ -1,122 +1,105 @@
|
|||||||
import { Flex, Spinner, Text } from '@invoke-ai/ui-library';
|
import { Flex, Spinner, Text } from '@invoke-ai/ui-library';
|
||||||
import type { EntityState } from '@reduxjs/toolkit';
|
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||||
import { forEach } from 'lodash-es';
|
import { memo, useMemo } from 'react';
|
||||||
import { memo } from 'react';
|
|
||||||
import { ALL_BASE_MODELS } from 'services/api/constants';
|
|
||||||
import {
|
import {
|
||||||
useGetControlNetModelsQuery,
|
useControlNetModels,
|
||||||
useGetIPAdapterModelsQuery,
|
useEmbeddingModels,
|
||||||
useGetLoRAModelsQuery,
|
useIPAdapterModels,
|
||||||
useGetMainModelsQuery,
|
useLoRAModels,
|
||||||
useGetT2IAdapterModelsQuery,
|
useMainModels,
|
||||||
useGetTextualInversionModelsQuery,
|
useT2IAdapterModels,
|
||||||
useGetVaeModelsQuery,
|
useVAEModels,
|
||||||
} from 'services/api/endpoints/models';
|
} from 'services/api/hooks/modelsByType';
|
||||||
import type { AnyModelConfig } from 'services/api/types';
|
import type { AnyModelConfig, ModelType } from 'services/api/types';
|
||||||
|
|
||||||
import { ModelListWrapper } from './ModelListWrapper';
|
import { ModelListWrapper } from './ModelListWrapper';
|
||||||
|
|
||||||
const ModelList = () => {
|
const ModelList = () => {
|
||||||
const { searchTerm, filteredModelType } = useAppSelector((s) => s.modelmanagerV2);
|
const { searchTerm, filteredModelType } = useAppSelector((s) => s.modelmanagerV2);
|
||||||
|
|
||||||
const { filteredMainModels, isLoadingMainModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
|
const [mainModels, { isLoading: isLoadingMainModels }] = useMainModels();
|
||||||
selectFromResult: ({ data, isLoading }) => ({
|
const filteredMainModels = useMemo(
|
||||||
filteredMainModels: modelsFilter(data, searchTerm, filteredModelType),
|
() => modelsFilter(mainModels, searchTerm, filteredModelType),
|
||||||
isLoadingMainModels: isLoading,
|
[mainModels, searchTerm, filteredModelType]
|
||||||
}),
|
|
||||||
});
|
|
||||||
|
|
||||||
const { filteredLoraModels, isLoadingLoraModels } = useGetLoRAModelsQuery(undefined, {
|
|
||||||
selectFromResult: ({ data, isLoading }) => ({
|
|
||||||
filteredLoraModels: modelsFilter(data, searchTerm, filteredModelType),
|
|
||||||
isLoadingLoraModels: isLoading,
|
|
||||||
}),
|
|
||||||
});
|
|
||||||
|
|
||||||
const { filteredTextualInversionModels, isLoadingTextualInversionModels } = useGetTextualInversionModelsQuery(
|
|
||||||
undefined,
|
|
||||||
{
|
|
||||||
selectFromResult: ({ data, isLoading }) => ({
|
|
||||||
filteredTextualInversionModels: modelsFilter(data, searchTerm, filteredModelType),
|
|
||||||
isLoadingTextualInversionModels: isLoading,
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
);
|
);
|
||||||
|
|
||||||
const { filteredControlnetModels, isLoadingControlnetModels } = useGetControlNetModelsQuery(undefined, {
|
const [loraModels, { isLoading: isLoadingLoRAModels }] = useLoRAModels();
|
||||||
selectFromResult: ({ data, isLoading }) => ({
|
const filteredLoRAModels = useMemo(
|
||||||
filteredControlnetModels: modelsFilter(data, searchTerm, filteredModelType),
|
() => modelsFilter(loraModels, searchTerm, filteredModelType),
|
||||||
isLoadingControlnetModels: isLoading,
|
[loraModels, searchTerm, filteredModelType]
|
||||||
}),
|
);
|
||||||
});
|
|
||||||
|
|
||||||
const { filteredT2iAdapterModels, isLoadingT2IAdapterModels } = useGetT2IAdapterModelsQuery(undefined, {
|
const [embeddingModels, { isLoading: isLoadingEmbeddingModels }] = useEmbeddingModels();
|
||||||
selectFromResult: ({ data, isLoading }) => ({
|
const filteredEmbeddingModels = useMemo(
|
||||||
filteredT2iAdapterModels: modelsFilter(data, searchTerm, filteredModelType),
|
() => modelsFilter(embeddingModels, searchTerm, filteredModelType),
|
||||||
isLoadingT2IAdapterModels: isLoading,
|
[embeddingModels, searchTerm, filteredModelType]
|
||||||
}),
|
);
|
||||||
});
|
|
||||||
|
|
||||||
const { filteredIpAdapterModels, isLoadingIpAdapterModels } = useGetIPAdapterModelsQuery(undefined, {
|
const [controlNetModels, { isLoading: isLoadingControlNetModels }] = useControlNetModels();
|
||||||
selectFromResult: ({ data, isLoading }) => ({
|
const filteredControlNetModels = useMemo(
|
||||||
filteredIpAdapterModels: modelsFilter(data, searchTerm, filteredModelType),
|
() => modelsFilter(controlNetModels, searchTerm, filteredModelType),
|
||||||
isLoadingIpAdapterModels: isLoading,
|
[controlNetModels, searchTerm, filteredModelType]
|
||||||
}),
|
);
|
||||||
});
|
|
||||||
|
|
||||||
const { filteredVaeModels, isLoadingVaeModels } = useGetVaeModelsQuery(undefined, {
|
const [t2iAdapterModels, { isLoading: isLoadingT2IAdapterModels }] = useT2IAdapterModels();
|
||||||
selectFromResult: ({ data, isLoading }) => ({
|
const filteredT2IAdapterModels = useMemo(
|
||||||
filteredVaeModels: modelsFilter(data, searchTerm, filteredModelType),
|
() => modelsFilter(t2iAdapterModels, searchTerm, filteredModelType),
|
||||||
isLoadingVaeModels: isLoading,
|
[t2iAdapterModels, searchTerm, filteredModelType]
|
||||||
}),
|
);
|
||||||
});
|
|
||||||
|
const [ipAdapterModels, { isLoading: isLoadingIPAdapterModels }] = useIPAdapterModels();
|
||||||
|
const filteredIPAdapterModels = useMemo(
|
||||||
|
() => modelsFilter(ipAdapterModels, searchTerm, filteredModelType),
|
||||||
|
[ipAdapterModels, searchTerm, filteredModelType]
|
||||||
|
);
|
||||||
|
|
||||||
|
const [vaeModels, { isLoading: isLoadingVAEModels }] = useVAEModels();
|
||||||
|
const filteredVAEModels = useMemo(
|
||||||
|
() => modelsFilter(vaeModels, searchTerm, filteredModelType),
|
||||||
|
[vaeModels, searchTerm, filteredModelType]
|
||||||
|
);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<ScrollableContent>
|
<ScrollableContent>
|
||||||
<Flex flexDirection="column" w="full" h="full" gap={4}>
|
<Flex flexDirection="column" w="full" h="full" gap={4}>
|
||||||
{/* Main Model List */}
|
{/* Main Model List */}
|
||||||
{isLoadingMainModels && <FetchingModelsLoader loadingMessage="Loading Main..." />}
|
{isLoadingMainModels && <FetchingModelsLoader loadingMessage="Loading Main Models..." />}
|
||||||
{!isLoadingMainModels && filteredMainModels.length > 0 && (
|
{!isLoadingMainModels && filteredMainModels.length > 0 && (
|
||||||
<ModelListWrapper title="Main" modelList={filteredMainModels} key="main" />
|
<ModelListWrapper title="Main" modelList={filteredMainModels} key="main" />
|
||||||
)}
|
)}
|
||||||
{/* LoRAs List */}
|
{/* LoRAs List */}
|
||||||
{isLoadingLoraModels && <FetchingModelsLoader loadingMessage="Loading LoRAs..." />}
|
{isLoadingLoRAModels && <FetchingModelsLoader loadingMessage="Loading LoRAs..." />}
|
||||||
{!isLoadingLoraModels && filteredLoraModels.length > 0 && (
|
{!isLoadingLoRAModels && filteredLoRAModels.length > 0 && (
|
||||||
<ModelListWrapper title="LoRAs" modelList={filteredLoraModels} key="loras" />
|
<ModelListWrapper title="LoRA" modelList={filteredLoRAModels} key="loras" />
|
||||||
)}
|
)}
|
||||||
|
|
||||||
{/* TI List */}
|
{/* TI List */}
|
||||||
{isLoadingTextualInversionModels && <FetchingModelsLoader loadingMessage="Loading Textual Inversions..." />}
|
{isLoadingEmbeddingModels && <FetchingModelsLoader loadingMessage="Loading Embeddings..." />}
|
||||||
{!isLoadingTextualInversionModels && filteredTextualInversionModels.length > 0 && (
|
{!isLoadingEmbeddingModels && filteredEmbeddingModels.length > 0 && (
|
||||||
<ModelListWrapper
|
<ModelListWrapper title="Embedding" modelList={filteredEmbeddingModels} key="textual-inversions" />
|
||||||
title="Textual Inversions"
|
|
||||||
modelList={filteredTextualInversionModels}
|
|
||||||
key="textual-inversions"
|
|
||||||
/>
|
|
||||||
)}
|
)}
|
||||||
|
|
||||||
{/* VAE List */}
|
{/* VAE List */}
|
||||||
{isLoadingVaeModels && <FetchingModelsLoader loadingMessage="Loading VAEs..." />}
|
{isLoadingVAEModels && <FetchingModelsLoader loadingMessage="Loading VAEs..." />}
|
||||||
{!isLoadingVaeModels && filteredVaeModels.length > 0 && (
|
{!isLoadingVAEModels && filteredVAEModels.length > 0 && (
|
||||||
<ModelListWrapper title="VAEs" modelList={filteredVaeModels} key="vae" />
|
<ModelListWrapper title="VAE" modelList={filteredVAEModels} key="vae" />
|
||||||
)}
|
)}
|
||||||
|
|
||||||
{/* Controlnet List */}
|
{/* Controlnet List */}
|
||||||
{isLoadingControlnetModels && <FetchingModelsLoader loadingMessage="Loading Controlnets..." />}
|
{isLoadingControlNetModels && <FetchingModelsLoader loadingMessage="Loading ControlNets..." />}
|
||||||
{!isLoadingControlnetModels && filteredControlnetModels.length > 0 && (
|
{!isLoadingControlNetModels && filteredControlNetModels.length > 0 && (
|
||||||
<ModelListWrapper title="Controlnets" modelList={filteredControlnetModels} key="controlnets" />
|
<ModelListWrapper title="ControlNet" modelList={filteredControlNetModels} key="controlnets" />
|
||||||
)}
|
)}
|
||||||
{/* IP Adapter List */}
|
{/* IP Adapter List */}
|
||||||
{isLoadingIpAdapterModels && <FetchingModelsLoader loadingMessage="Loading IP Adapters..." />}
|
{isLoadingIPAdapterModels && <FetchingModelsLoader loadingMessage="Loading IP Adapters..." />}
|
||||||
{!isLoadingIpAdapterModels && filteredIpAdapterModels.length > 0 && (
|
{!isLoadingIPAdapterModels && filteredIPAdapterModels.length > 0 && (
|
||||||
<ModelListWrapper title="IP Adapters" modelList={filteredIpAdapterModels} key="ip-adapters" />
|
<ModelListWrapper title="IP Adapter" modelList={filteredIPAdapterModels} key="ip-adapters" />
|
||||||
)}
|
)}
|
||||||
{/* T2I Adapters List */}
|
{/* T2I Adapters List */}
|
||||||
{isLoadingT2IAdapterModels && <FetchingModelsLoader loadingMessage="Loading T2I Adapters..." />}
|
{isLoadingT2IAdapterModels && <FetchingModelsLoader loadingMessage="Loading T2I Adapters..." />}
|
||||||
{!isLoadingT2IAdapterModels && filteredT2iAdapterModels.length > 0 && (
|
{!isLoadingT2IAdapterModels && filteredT2IAdapterModels.length > 0 && (
|
||||||
<ModelListWrapper title="T2I Adapters" modelList={filteredT2iAdapterModels} key="t2i-adapters" />
|
<ModelListWrapper title="T2I Adapter" modelList={filteredT2IAdapterModels} key="t2i-adapters" />
|
||||||
)}
|
)}
|
||||||
</Flex>
|
</Flex>
|
||||||
</ScrollableContent>
|
</ScrollableContent>
|
||||||
@ -126,25 +109,16 @@ const ModelList = () => {
|
|||||||
export default memo(ModelList);
|
export default memo(ModelList);
|
||||||
|
|
||||||
const modelsFilter = <T extends AnyModelConfig>(
|
const modelsFilter = <T extends AnyModelConfig>(
|
||||||
data: EntityState<T, string> | undefined,
|
data: T[],
|
||||||
nameFilter: string,
|
nameFilter: string,
|
||||||
filteredModelType: string | null
|
filteredModelType: ModelType | null
|
||||||
): T[] => {
|
): T[] => {
|
||||||
const filteredModels: T[] = [];
|
return data.filter((model) => {
|
||||||
|
|
||||||
forEach(data?.entities, (model) => {
|
|
||||||
if (!model) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const matchesFilter = model.name.toLowerCase().includes(nameFilter.toLowerCase());
|
const matchesFilter = model.name.toLowerCase().includes(nameFilter.toLowerCase());
|
||||||
const matchesType = filteredModelType ? model.type === filteredModelType : true;
|
const matchesType = filteredModelType ? model.type === filteredModelType : true;
|
||||||
|
|
||||||
if (matchesFilter && matchesType) {
|
return matchesFilter && matchesType;
|
||||||
filteredModels.push(model);
|
|
||||||
}
|
|
||||||
});
|
});
|
||||||
return filteredModels;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
const FetchingModelsLoader = memo(({ loadingMessage }: { loadingMessage?: string }) => {
|
const FetchingModelsLoader = memo(({ loadingMessage }: { loadingMessage?: string }) => {
|
||||||
|
@ -1,11 +1,13 @@
|
|||||||
import { Button, Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
|
import { Button, Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import type { FilterableModelType } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||||
import { setFilteredModelType } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
import { setFilteredModelType } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||||
import { useCallback } from 'react';
|
import { useCallback } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { IoFilter } from 'react-icons/io5';
|
import { PiFunnelBold } from 'react-icons/pi';
|
||||||
|
import { objectKeys } from 'tsafe';
|
||||||
|
|
||||||
const MODEL_TYPE_LABELS: { [key: string]: string } = {
|
const MODEL_TYPE_LABELS: Record<FilterableModelType, string> = {
|
||||||
main: 'Main',
|
main: 'Main',
|
||||||
lora: 'LoRA',
|
lora: 'LoRA',
|
||||||
embedding: 'Textual Inversion',
|
embedding: 'Textual Inversion',
|
||||||
@ -13,7 +15,6 @@ const MODEL_TYPE_LABELS: { [key: string]: string } = {
|
|||||||
vae: 'VAE',
|
vae: 'VAE',
|
||||||
t2i_adapter: 'T2I Adapter',
|
t2i_adapter: 'T2I Adapter',
|
||||||
ip_adapter: 'IP Adapter',
|
ip_adapter: 'IP Adapter',
|
||||||
clip_vision: 'Clip Vision',
|
|
||||||
};
|
};
|
||||||
|
|
||||||
export const ModelTypeFilter = () => {
|
export const ModelTypeFilter = () => {
|
||||||
@ -22,7 +23,7 @@ export const ModelTypeFilter = () => {
|
|||||||
const filteredModelType = useAppSelector((s) => s.modelmanagerV2.filteredModelType);
|
const filteredModelType = useAppSelector((s) => s.modelmanagerV2.filteredModelType);
|
||||||
|
|
||||||
const selectModelType = useCallback(
|
const selectModelType = useCallback(
|
||||||
(option: string) => {
|
(option: FilterableModelType) => {
|
||||||
dispatch(setFilteredModelType(option));
|
dispatch(setFilteredModelType(option));
|
||||||
},
|
},
|
||||||
[dispatch]
|
[dispatch]
|
||||||
@ -34,12 +35,12 @@ export const ModelTypeFilter = () => {
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
<Menu>
|
<Menu>
|
||||||
<MenuButton as={Button} size="sm" leftIcon={<IoFilter />}>
|
<MenuButton as={Button} size="sm" leftIcon={<PiFunnelBold />}>
|
||||||
{filteredModelType ? MODEL_TYPE_LABELS[filteredModelType] : t('modelManager.allModels')}
|
{filteredModelType ? MODEL_TYPE_LABELS[filteredModelType] : t('modelManager.allModels')}
|
||||||
</MenuButton>
|
</MenuButton>
|
||||||
<MenuList>
|
<MenuList>
|
||||||
<MenuItem onClick={clearModelType}>{t('modelManager.allModels')}</MenuItem>
|
<MenuItem onClick={clearModelType}>{t('modelManager.allModels')}</MenuItem>
|
||||||
{Object.keys(MODEL_TYPE_LABELS).map((option) => (
|
{objectKeys(MODEL_TYPE_LABELS).map((option) => (
|
||||||
<MenuItem
|
<MenuItem
|
||||||
key={option}
|
key={option}
|
||||||
bg={filteredModelType === option ? 'base.700' : 'transparent'}
|
bg={filteredModelType === option ? 'base.700' : 'transparent'}
|
||||||
|
@ -0,0 +1,81 @@
|
|||||||
|
import { CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||||
|
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
|
||||||
|
import { useCallback, useMemo } from 'react';
|
||||||
|
import type { UseControllerProps } from 'react-hook-form';
|
||||||
|
import { useController } from 'react-hook-form';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
|
import type { MainModelDefaultSettingsFormData } from './MainModelDefaultSettings';
|
||||||
|
|
||||||
|
type DefaultHeightType = MainModelDefaultSettingsFormData['height'];
|
||||||
|
|
||||||
|
type Props = {
|
||||||
|
control: UseControllerProps<MainModelDefaultSettingsFormData>['control'];
|
||||||
|
optimalDimension: number;
|
||||||
|
};
|
||||||
|
|
||||||
|
export function DefaultHeight({ control, optimalDimension }: Props) {
|
||||||
|
const { field } = useController({ control, name: 'height' });
|
||||||
|
const sliderMin = useAppSelector((s) => s.config.sd.height.sliderMin);
|
||||||
|
const sliderMax = useAppSelector((s) => s.config.sd.height.sliderMax);
|
||||||
|
const numberInputMin = useAppSelector((s) => s.config.sd.height.numberInputMin);
|
||||||
|
const numberInputMax = useAppSelector((s) => s.config.sd.height.numberInputMax);
|
||||||
|
const coarseStep = useAppSelector((s) => s.config.sd.height.coarseStep);
|
||||||
|
const fineStep = useAppSelector((s) => s.config.sd.height.fineStep);
|
||||||
|
const { t } = useTranslation();
|
||||||
|
const marks = useMemo(() => [sliderMin, optimalDimension, sliderMax], [sliderMin, optimalDimension, sliderMax]);
|
||||||
|
|
||||||
|
const onChange = useCallback(
|
||||||
|
(v: number) => {
|
||||||
|
const updatedValue = {
|
||||||
|
...(field.value as DefaultHeightType),
|
||||||
|
value: v,
|
||||||
|
};
|
||||||
|
field.onChange(updatedValue);
|
||||||
|
},
|
||||||
|
[field]
|
||||||
|
);
|
||||||
|
|
||||||
|
const value = useMemo(() => {
|
||||||
|
return field.value.value;
|
||||||
|
}, [field.value]);
|
||||||
|
|
||||||
|
const isDisabled = useMemo(() => {
|
||||||
|
return !field.value.isEnabled;
|
||||||
|
}, [field.value]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<FormControl flexDir="column" gap={2} alignItems="flex-start">
|
||||||
|
<Flex justifyContent="space-between" w="full">
|
||||||
|
<InformationalPopover feature="paramHeight">
|
||||||
|
<FormLabel>{t('parameters.height')}</FormLabel>
|
||||||
|
</InformationalPopover>
|
||||||
|
<SettingToggle control={control} name="height" />
|
||||||
|
</Flex>
|
||||||
|
|
||||||
|
<Flex w="full" gap={4}>
|
||||||
|
<CompositeSlider
|
||||||
|
value={value}
|
||||||
|
min={sliderMin}
|
||||||
|
max={sliderMax}
|
||||||
|
step={coarseStep}
|
||||||
|
fineStep={fineStep}
|
||||||
|
onChange={onChange}
|
||||||
|
marks={marks}
|
||||||
|
isDisabled={isDisabled}
|
||||||
|
/>
|
||||||
|
<CompositeNumberInput
|
||||||
|
value={value}
|
||||||
|
min={numberInputMin}
|
||||||
|
max={numberInputMax}
|
||||||
|
step={coarseStep}
|
||||||
|
fineStep={fineStep}
|
||||||
|
onChange={onChange}
|
||||||
|
isDisabled={isDisabled}
|
||||||
|
/>
|
||||||
|
</Flex>
|
||||||
|
</FormControl>
|
||||||
|
);
|
||||||
|
}
|
@ -4,12 +4,12 @@ import { skipToken } from '@reduxjs/toolkit/query';
|
|||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||||
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
|
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
|
||||||
import { map } from 'lodash-es';
|
|
||||||
import { useCallback, useMemo } from 'react';
|
import { useCallback, useMemo } from 'react';
|
||||||
import type { UseControllerProps } from 'react-hook-form';
|
import type { UseControllerProps } from 'react-hook-form';
|
||||||
import { useController } from 'react-hook-form';
|
import { useController } from 'react-hook-form';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { useGetModelConfigQuery, useGetVaeModelsQuery } from 'services/api/endpoints/models';
|
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
|
||||||
|
import { useVAEModels } from 'services/api/hooks/modelsByType';
|
||||||
|
|
||||||
import type { MainModelDefaultSettingsFormData } from './MainModelDefaultSettings';
|
import type { MainModelDefaultSettingsFormData } from './MainModelDefaultSettings';
|
||||||
|
|
||||||
@ -21,18 +21,16 @@ export function DefaultVae(props: UseControllerProps<MainModelDefaultSettingsFor
|
|||||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||||
const { data: modelData } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
|
const { data: modelData } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
|
||||||
|
|
||||||
const { compatibleOptions } = useGetVaeModelsQuery(undefined, {
|
const [vaeModels] = useVAEModels();
|
||||||
selectFromResult: ({ data }) => {
|
const compatibleOptions = useMemo(() => {
|
||||||
const modelArray = map(data?.entities);
|
const compatibleOptions = vaeModels
|
||||||
const compatibleOptions = modelArray
|
|
||||||
.filter((vae) => vae.base === modelData?.base)
|
.filter((vae) => vae.base === modelData?.base)
|
||||||
.map((vae) => ({ label: vae.name, value: vae.key }));
|
.map((vae) => ({ label: vae.name, value: vae.key }));
|
||||||
|
|
||||||
const defaultOption = { label: 'Default VAE', value: 'default' };
|
const defaultOption = { label: 'Default VAE', value: 'default' };
|
||||||
|
|
||||||
return { compatibleOptions: [defaultOption, ...compatibleOptions] };
|
return [defaultOption, ...compatibleOptions];
|
||||||
},
|
}, [modelData?.base, vaeModels]);
|
||||||
});
|
|
||||||
|
|
||||||
const onChange = useCallback<ComboboxOnChange>(
|
const onChange = useCallback<ComboboxOnChange>(
|
||||||
(v) => {
|
(v) => {
|
||||||
|
@ -0,0 +1,81 @@
|
|||||||
|
import { CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||||
|
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
|
||||||
|
import { useCallback, useMemo } from 'react';
|
||||||
|
import type { UseControllerProps } from 'react-hook-form';
|
||||||
|
import { useController } from 'react-hook-form';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
|
import type { MainModelDefaultSettingsFormData } from './MainModelDefaultSettings';
|
||||||
|
|
||||||
|
type DefaultWidthType = MainModelDefaultSettingsFormData['width'];
|
||||||
|
|
||||||
|
type Props = {
|
||||||
|
control: UseControllerProps<MainModelDefaultSettingsFormData>['control'];
|
||||||
|
optimalDimension: number;
|
||||||
|
};
|
||||||
|
|
||||||
|
export function DefaultWidth({ control, optimalDimension }: Props) {
|
||||||
|
const { field } = useController({ control, name: 'width' });
|
||||||
|
const sliderMin = useAppSelector((s) => s.config.sd.width.sliderMin);
|
||||||
|
const sliderMax = useAppSelector((s) => s.config.sd.width.sliderMax);
|
||||||
|
const numberInputMin = useAppSelector((s) => s.config.sd.width.numberInputMin);
|
||||||
|
const numberInputMax = useAppSelector((s) => s.config.sd.width.numberInputMax);
|
||||||
|
const coarseStep = useAppSelector((s) => s.config.sd.width.coarseStep);
|
||||||
|
const fineStep = useAppSelector((s) => s.config.sd.width.fineStep);
|
||||||
|
const { t } = useTranslation();
|
||||||
|
const marks = useMemo(() => [sliderMin, optimalDimension, sliderMax], [sliderMin, optimalDimension, sliderMax]);
|
||||||
|
|
||||||
|
const onChange = useCallback(
|
||||||
|
(v: number) => {
|
||||||
|
const updatedValue = {
|
||||||
|
...(field.value as DefaultWidthType),
|
||||||
|
value: v,
|
||||||
|
};
|
||||||
|
field.onChange(updatedValue);
|
||||||
|
},
|
||||||
|
[field]
|
||||||
|
);
|
||||||
|
|
||||||
|
const value = useMemo(() => {
|
||||||
|
return (field.value as DefaultWidthType).value;
|
||||||
|
}, [field.value]);
|
||||||
|
|
||||||
|
const isDisabled = useMemo(() => {
|
||||||
|
return !(field.value as DefaultWidthType).isEnabled;
|
||||||
|
}, [field.value]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<FormControl flexDir="column" gap={2} alignItems="flex-start">
|
||||||
|
<Flex justifyContent="space-between" w="full">
|
||||||
|
<InformationalPopover feature="paramWidth">
|
||||||
|
<FormLabel>{t('parameters.width')}</FormLabel>
|
||||||
|
</InformationalPopover>
|
||||||
|
<SettingToggle control={control} name="width" />
|
||||||
|
</Flex>
|
||||||
|
|
||||||
|
<Flex w="full" gap={4}>
|
||||||
|
<CompositeSlider
|
||||||
|
value={value}
|
||||||
|
min={sliderMin}
|
||||||
|
max={sliderMax}
|
||||||
|
step={coarseStep}
|
||||||
|
fineStep={fineStep}
|
||||||
|
onChange={onChange}
|
||||||
|
marks={marks}
|
||||||
|
isDisabled={isDisabled}
|
||||||
|
/>
|
||||||
|
<CompositeNumberInput
|
||||||
|
value={value}
|
||||||
|
min={numberInputMin}
|
||||||
|
max={numberInputMax}
|
||||||
|
step={coarseStep}
|
||||||
|
fineStep={fineStep}
|
||||||
|
onChange={onChange}
|
||||||
|
isDisabled={isDisabled}
|
||||||
|
/>
|
||||||
|
</Flex>
|
||||||
|
</FormControl>
|
||||||
|
);
|
||||||
|
}
|
@ -1,6 +1,8 @@
|
|||||||
import { Button, Flex, Heading, Text } from '@invoke-ai/ui-library';
|
import { Button, Flex, Heading, Text } from '@invoke-ai/ui-library';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { useMainModelDefaultSettings } from 'features/modelManagerV2/hooks/useMainModelDefaultSettings';
|
import { useMainModelDefaultSettings } from 'features/modelManagerV2/hooks/useMainModelDefaultSettings';
|
||||||
|
import { DefaultHeight } from 'features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/DefaultHeight';
|
||||||
|
import { DefaultWidth } from 'features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/DefaultWidth';
|
||||||
import type { ParameterScheduler } from 'features/parameters/types/parameterSchemas';
|
import type { ParameterScheduler } from 'features/parameters/types/parameterSchemas';
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
import { makeToast } from 'features/system/util/makeToast';
|
import { makeToast } from 'features/system/util/makeToast';
|
||||||
@ -25,11 +27,13 @@ export interface FormField<T> {
|
|||||||
|
|
||||||
export type MainModelDefaultSettingsFormData = {
|
export type MainModelDefaultSettingsFormData = {
|
||||||
vae: FormField<string>;
|
vae: FormField<string>;
|
||||||
vaePrecision: FormField<string>;
|
vaePrecision: FormField<'fp16' | 'fp32'>;
|
||||||
scheduler: FormField<ParameterScheduler>;
|
scheduler: FormField<ParameterScheduler>;
|
||||||
steps: FormField<number>;
|
steps: FormField<number>;
|
||||||
cfgScale: FormField<number>;
|
cfgScale: FormField<number>;
|
||||||
cfgRescaleMultiplier: FormField<number>;
|
cfgRescaleMultiplier: FormField<number>;
|
||||||
|
width: FormField<number>;
|
||||||
|
height: FormField<number>;
|
||||||
};
|
};
|
||||||
|
|
||||||
export const MainModelDefaultSettings = () => {
|
export const MainModelDefaultSettings = () => {
|
||||||
@ -37,8 +41,11 @@ export const MainModelDefaultSettings = () => {
|
|||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
const { defaultSettingsDefaults, isLoading: isLoadingDefaultSettings } =
|
const {
|
||||||
useMainModelDefaultSettings(selectedModelKey);
|
defaultSettingsDefaults,
|
||||||
|
isLoading: isLoadingDefaultSettings,
|
||||||
|
optimalDimension,
|
||||||
|
} = useMainModelDefaultSettings(selectedModelKey);
|
||||||
|
|
||||||
const [updateModel, { isLoading: isLoadingUpdateModel }] = useUpdateModelMutation();
|
const [updateModel, { isLoading: isLoadingUpdateModel }] = useUpdateModelMutation();
|
||||||
|
|
||||||
@ -59,6 +66,8 @@ export const MainModelDefaultSettings = () => {
|
|||||||
cfg_rescale_multiplier: data.cfgRescaleMultiplier.isEnabled ? data.cfgRescaleMultiplier.value : null,
|
cfg_rescale_multiplier: data.cfgRescaleMultiplier.isEnabled ? data.cfgRescaleMultiplier.value : null,
|
||||||
steps: data.steps.isEnabled ? data.steps.value : null,
|
steps: data.steps.isEnabled ? data.steps.value : null,
|
||||||
scheduler: data.scheduler.isEnabled ? data.scheduler.value : null,
|
scheduler: data.scheduler.isEnabled ? data.scheduler.value : null,
|
||||||
|
width: data.width.isEnabled ? data.width.value : null,
|
||||||
|
height: data.height.isEnabled ? data.height.value : null,
|
||||||
};
|
};
|
||||||
|
|
||||||
updateModel({
|
updateModel({
|
||||||
@ -139,6 +148,14 @@ export const MainModelDefaultSettings = () => {
|
|||||||
<DefaultCfgRescaleMultiplier control={control} name="cfgRescaleMultiplier" />
|
<DefaultCfgRescaleMultiplier control={control} name="cfgRescaleMultiplier" />
|
||||||
</Flex>
|
</Flex>
|
||||||
</Flex>
|
</Flex>
|
||||||
|
<Flex gap={8}>
|
||||||
|
<Flex gap={4} w="full">
|
||||||
|
<DefaultWidth control={control} optimalDimension={optimalDimension} />
|
||||||
|
</Flex>
|
||||||
|
<Flex gap={4} w="full">
|
||||||
|
<DefaultHeight control={control} optimalDimension={optimalDimension} />
|
||||||
|
</Flex>
|
||||||
|
</Flex>
|
||||||
</Flex>
|
</Flex>
|
||||||
</>
|
</>
|
||||||
);
|
);
|
||||||
|
@ -90,7 +90,7 @@ export const TriggerPhrases = () => {
|
|||||||
size="sm"
|
size="sm"
|
||||||
type="submit"
|
type="submit"
|
||||||
onClick={addTriggerPhrase}
|
onClick={addTriggerPhrase}
|
||||||
isDisabled={Boolean(errors.length)}
|
isDisabled={!phrase || Boolean(errors.length)}
|
||||||
isLoading={isLoading}
|
isLoading={isLoading}
|
||||||
>
|
>
|
||||||
{t('common.add')}
|
{t('common.add')}
|
||||||
|
@ -4,7 +4,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
|||||||
import { fieldControlNetModelValueChanged } from 'features/nodes/store/nodesSlice';
|
import { fieldControlNetModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
import type { ControlNetModelFieldInputInstance, ControlNetModelFieldInputTemplate } from 'features/nodes/types/field';
|
import type { ControlNetModelFieldInputInstance, ControlNetModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
|
import { useControlNetModels } from 'services/api/hooks/modelsByType';
|
||||||
import type { ControlNetModelConfig } from 'services/api/types';
|
import type { ControlNetModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
import type { FieldComponentProps } from './types';
|
import type { FieldComponentProps } from './types';
|
||||||
@ -14,7 +14,7 @@ type Props = FieldComponentProps<ControlNetModelFieldInputInstance, ControlNetMo
|
|||||||
const ControlNetModelFieldInputComponent = (props: Props) => {
|
const ControlNetModelFieldInputComponent = (props: Props) => {
|
||||||
const { nodeId, field } = props;
|
const { nodeId, field } = props;
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { data, isLoading } = useGetControlNetModelsQuery();
|
const [modelConfigs, { isLoading }] = useControlNetModels();
|
||||||
|
|
||||||
const _onChange = useCallback(
|
const _onChange = useCallback(
|
||||||
(value: ControlNetModelConfig | null) => {
|
(value: ControlNetModelConfig | null) => {
|
||||||
@ -33,7 +33,7 @@ const ControlNetModelFieldInputComponent = (props: Props) => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||||
modelEntities: data,
|
modelConfigs,
|
||||||
onChange: _onChange,
|
onChange: _onChange,
|
||||||
selectedModel: field.value,
|
selectedModel: field.value,
|
||||||
isLoading,
|
isLoading,
|
||||||
|
@ -4,7 +4,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
|||||||
import { fieldIPAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
|
import { fieldIPAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
import type { IPAdapterModelFieldInputInstance, IPAdapterModelFieldInputTemplate } from 'features/nodes/types/field';
|
import type { IPAdapterModelFieldInputInstance, IPAdapterModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useGetIPAdapterModelsQuery } from 'services/api/endpoints/models';
|
import { useIPAdapterModels } from 'services/api/hooks/modelsByType';
|
||||||
import type { IPAdapterModelConfig } from 'services/api/types';
|
import type { IPAdapterModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
import type { FieldComponentProps } from './types';
|
import type { FieldComponentProps } from './types';
|
||||||
@ -14,7 +14,7 @@ const IPAdapterModelFieldInputComponent = (
|
|||||||
) => {
|
) => {
|
||||||
const { nodeId, field } = props;
|
const { nodeId, field } = props;
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { data: ipAdapterModels } = useGetIPAdapterModelsQuery();
|
const [modelConfigs, { isLoading }] = useIPAdapterModels();
|
||||||
|
|
||||||
const _onChange = useCallback(
|
const _onChange = useCallback(
|
||||||
(value: IPAdapterModelConfig | null) => {
|
(value: IPAdapterModelConfig | null) => {
|
||||||
@ -33,9 +33,10 @@ const IPAdapterModelFieldInputComponent = (
|
|||||||
);
|
);
|
||||||
|
|
||||||
const { options, value, onChange } = useGroupedModelCombobox({
|
const { options, value, onChange } = useGroupedModelCombobox({
|
||||||
modelEntities: ipAdapterModels,
|
modelConfigs,
|
||||||
onChange: _onChange,
|
onChange: _onChange,
|
||||||
selectedModel: field.value,
|
selectedModel: field.value,
|
||||||
|
isLoading,
|
||||||
});
|
});
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
@ -4,7 +4,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
|||||||
import { fieldLoRAModelValueChanged } from 'features/nodes/store/nodesSlice';
|
import { fieldLoRAModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
import type { LoRAModelFieldInputInstance, LoRAModelFieldInputTemplate } from 'features/nodes/types/field';
|
import type { LoRAModelFieldInputInstance, LoRAModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
|
import { useLoRAModels } from 'services/api/hooks/modelsByType';
|
||||||
import type { LoRAModelConfig } from 'services/api/types';
|
import type { LoRAModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
import type { FieldComponentProps } from './types';
|
import type { FieldComponentProps } from './types';
|
||||||
@ -14,7 +14,7 @@ type Props = FieldComponentProps<LoRAModelFieldInputInstance, LoRAModelFieldInpu
|
|||||||
const LoRAModelFieldInputComponent = (props: Props) => {
|
const LoRAModelFieldInputComponent = (props: Props) => {
|
||||||
const { nodeId, field } = props;
|
const { nodeId, field } = props;
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { data, isLoading } = useGetLoRAModelsQuery();
|
const [modelConfigs, { isLoading }] = useLoRAModels();
|
||||||
const _onChange = useCallback(
|
const _onChange = useCallback(
|
||||||
(value: LoRAModelConfig | null) => {
|
(value: LoRAModelConfig | null) => {
|
||||||
if (!value) {
|
if (!value) {
|
||||||
@ -32,7 +32,7 @@ const LoRAModelFieldInputComponent = (props: Props) => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||||
modelEntities: data,
|
modelConfigs,
|
||||||
onChange: _onChange,
|
onChange: _onChange,
|
||||||
selectedModel: field.value,
|
selectedModel: field.value,
|
||||||
isLoading,
|
isLoading,
|
||||||
|
@ -5,8 +5,7 @@ import { SyncModelsIconButton } from 'features/modelManagerV2/components/SyncMod
|
|||||||
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
|
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
import type { MainModelFieldInputInstance, MainModelFieldInputTemplate } from 'features/nodes/types/field';
|
import type { MainModelFieldInputInstance, MainModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { NON_SDXL_MAIN_MODELS } from 'services/api/constants';
|
import { useNonSDXLMainModels } from 'services/api/hooks/modelsByType';
|
||||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
|
||||||
import type { MainModelConfig } from 'services/api/types';
|
import type { MainModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
import type { FieldComponentProps } from './types';
|
import type { FieldComponentProps } from './types';
|
||||||
@ -16,7 +15,7 @@ type Props = FieldComponentProps<MainModelFieldInputInstance, MainModelFieldInpu
|
|||||||
const MainModelFieldInputComponent = (props: Props) => {
|
const MainModelFieldInputComponent = (props: Props) => {
|
||||||
const { nodeId, field } = props;
|
const { nodeId, field } = props;
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { data, isLoading } = useGetMainModelsQuery(NON_SDXL_MAIN_MODELS);
|
const [modelConfigs, { isLoading }] = useNonSDXLMainModels();
|
||||||
const _onChange = useCallback(
|
const _onChange = useCallback(
|
||||||
(value: MainModelConfig | null) => {
|
(value: MainModelConfig | null) => {
|
||||||
if (!value) {
|
if (!value) {
|
||||||
@ -33,7 +32,7 @@ const MainModelFieldInputComponent = (props: Props) => {
|
|||||||
[dispatch, field.name, nodeId]
|
[dispatch, field.name, nodeId]
|
||||||
);
|
);
|
||||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||||
modelEntities: data,
|
modelConfigs,
|
||||||
onChange: _onChange,
|
onChange: _onChange,
|
||||||
isLoading,
|
isLoading,
|
||||||
selectedModel: field.value,
|
selectedModel: field.value,
|
||||||
|
@ -8,8 +8,7 @@ import type {
|
|||||||
SDXLRefinerModelFieldInputTemplate,
|
SDXLRefinerModelFieldInputTemplate,
|
||||||
} from 'features/nodes/types/field';
|
} from 'features/nodes/types/field';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { REFINER_BASE_MODELS } from 'services/api/constants';
|
import { useRefinerModels } from 'services/api/hooks/modelsByType';
|
||||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
|
||||||
import type { MainModelConfig } from 'services/api/types';
|
import type { MainModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
import type { FieldComponentProps } from './types';
|
import type { FieldComponentProps } from './types';
|
||||||
@ -19,7 +18,7 @@ type Props = FieldComponentProps<SDXLRefinerModelFieldInputInstance, SDXLRefiner
|
|||||||
const RefinerModelFieldInputComponent = (props: Props) => {
|
const RefinerModelFieldInputComponent = (props: Props) => {
|
||||||
const { nodeId, field } = props;
|
const { nodeId, field } = props;
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { data, isLoading } = useGetMainModelsQuery(REFINER_BASE_MODELS);
|
const [modelConfigs, { isLoading }] = useRefinerModels();
|
||||||
const _onChange = useCallback(
|
const _onChange = useCallback(
|
||||||
(value: MainModelConfig | null) => {
|
(value: MainModelConfig | null) => {
|
||||||
if (!value) {
|
if (!value) {
|
||||||
@ -36,7 +35,7 @@ const RefinerModelFieldInputComponent = (props: Props) => {
|
|||||||
[dispatch, field.name, nodeId]
|
[dispatch, field.name, nodeId]
|
||||||
);
|
);
|
||||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||||
modelEntities: data,
|
modelConfigs,
|
||||||
onChange: _onChange,
|
onChange: _onChange,
|
||||||
isLoading,
|
isLoading,
|
||||||
selectedModel: field.value,
|
selectedModel: field.value,
|
||||||
|
@ -5,8 +5,7 @@ import { SyncModelsIconButton } from 'features/modelManagerV2/components/SyncMod
|
|||||||
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
|
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
import type { SDXLMainModelFieldInputInstance, SDXLMainModelFieldInputTemplate } from 'features/nodes/types/field';
|
import type { SDXLMainModelFieldInputInstance, SDXLMainModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { SDXL_MAIN_MODELS } from 'services/api/constants';
|
import { useSDXLModels } from 'services/api/hooks/modelsByType';
|
||||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
|
||||||
import type { MainModelConfig } from 'services/api/types';
|
import type { MainModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
import type { FieldComponentProps } from './types';
|
import type { FieldComponentProps } from './types';
|
||||||
@ -16,7 +15,7 @@ type Props = FieldComponentProps<SDXLMainModelFieldInputInstance, SDXLMainModelF
|
|||||||
const SDXLMainModelFieldInputComponent = (props: Props) => {
|
const SDXLMainModelFieldInputComponent = (props: Props) => {
|
||||||
const { nodeId, field } = props;
|
const { nodeId, field } = props;
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { data, isLoading } = useGetMainModelsQuery(SDXL_MAIN_MODELS);
|
const [modelConfigs, { isLoading }] = useSDXLModels();
|
||||||
const _onChange = useCallback(
|
const _onChange = useCallback(
|
||||||
(value: MainModelConfig | null) => {
|
(value: MainModelConfig | null) => {
|
||||||
if (!value) {
|
if (!value) {
|
||||||
@ -33,7 +32,7 @@ const SDXLMainModelFieldInputComponent = (props: Props) => {
|
|||||||
[dispatch, field.name, nodeId]
|
[dispatch, field.name, nodeId]
|
||||||
);
|
);
|
||||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||||
modelEntities: data,
|
modelConfigs,
|
||||||
onChange: _onChange,
|
onChange: _onChange,
|
||||||
isLoading,
|
isLoading,
|
||||||
selectedModel: field.value,
|
selectedModel: field.value,
|
||||||
|
@ -4,7 +4,7 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
|||||||
import { fieldT2IAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
|
import { fieldT2IAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
import type { T2IAdapterModelFieldInputInstance, T2IAdapterModelFieldInputTemplate } from 'features/nodes/types/field';
|
import type { T2IAdapterModelFieldInputInstance, T2IAdapterModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useGetT2IAdapterModelsQuery } from 'services/api/endpoints/models';
|
import { useT2IAdapterModels } from 'services/api/hooks/modelsByType';
|
||||||
import type { T2IAdapterModelConfig } from 'services/api/types';
|
import type { T2IAdapterModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
import type { FieldComponentProps } from './types';
|
import type { FieldComponentProps } from './types';
|
||||||
@ -15,7 +15,7 @@ const T2IAdapterModelFieldInputComponent = (
|
|||||||
const { nodeId, field } = props;
|
const { nodeId, field } = props;
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
const { data: t2iAdapterModels } = useGetT2IAdapterModelsQuery();
|
const [modelConfigs, { isLoading }] = useT2IAdapterModels();
|
||||||
|
|
||||||
const _onChange = useCallback(
|
const _onChange = useCallback(
|
||||||
(value: T2IAdapterModelConfig | null) => {
|
(value: T2IAdapterModelConfig | null) => {
|
||||||
@ -34,9 +34,10 @@ const T2IAdapterModelFieldInputComponent = (
|
|||||||
);
|
);
|
||||||
|
|
||||||
const { options, value, onChange } = useGroupedModelCombobox({
|
const { options, value, onChange } = useGroupedModelCombobox({
|
||||||
modelEntities: t2iAdapterModels,
|
modelConfigs,
|
||||||
onChange: _onChange,
|
onChange: _onChange,
|
||||||
selectedModel: field.value,
|
selectedModel: field.value,
|
||||||
|
isLoading,
|
||||||
});
|
});
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
@ -5,7 +5,7 @@ import { SyncModelsIconButton } from 'features/modelManagerV2/components/SyncMod
|
|||||||
import { fieldVaeModelValueChanged } from 'features/nodes/store/nodesSlice';
|
import { fieldVaeModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
import type { VAEModelFieldInputInstance, VAEModelFieldInputTemplate } from 'features/nodes/types/field';
|
import type { VAEModelFieldInputInstance, VAEModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
|
import { useVAEModels } from 'services/api/hooks/modelsByType';
|
||||||
import type { VAEModelConfig } from 'services/api/types';
|
import type { VAEModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
import type { FieldComponentProps } from './types';
|
import type { FieldComponentProps } from './types';
|
||||||
@ -15,7 +15,7 @@ type Props = FieldComponentProps<VAEModelFieldInputInstance, VAEModelFieldInputT
|
|||||||
const VAEModelFieldInputComponent = (props: Props) => {
|
const VAEModelFieldInputComponent = (props: Props) => {
|
||||||
const { nodeId, field } = props;
|
const { nodeId, field } = props;
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { data, isLoading } = useGetVaeModelsQuery();
|
const [modelConfigs, { isLoading }] = useVAEModels();
|
||||||
const _onChange = useCallback(
|
const _onChange = useCallback(
|
||||||
(value: VAEModelConfig | null) => {
|
(value: VAEModelConfig | null) => {
|
||||||
if (!value) {
|
if (!value) {
|
||||||
@ -32,7 +32,7 @@ const VAEModelFieldInputComponent = (props: Props) => {
|
|||||||
[dispatch, field.name, nodeId]
|
[dispatch, field.name, nodeId]
|
||||||
);
|
);
|
||||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||||
modelEntities: data,
|
modelConfigs,
|
||||||
onChange: _onChange,
|
onChange: _onChange,
|
||||||
selectedModel: field.value,
|
selectedModel: field.value,
|
||||||
isLoading,
|
isLoading,
|
||||||
|
@ -1,16 +1,18 @@
|
|||||||
import type { RootState } from 'app/store/store';
|
import type { RootState } from 'app/store/store';
|
||||||
import { selectValidControlNets } from 'features/controlAdapters/store/controlAdaptersSlice';
|
import { selectValidControlNets } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
import type { ControlAdapterProcessorType, ControlNetConfig } from 'features/controlAdapters/store/types';
|
||||||
|
import type { ImageField } from 'features/nodes/types/common';
|
||||||
import type {
|
import type {
|
||||||
CollectInvocation,
|
CollectInvocation,
|
||||||
ControlNetInvocation,
|
ControlNetInvocation,
|
||||||
CoreMetadataInvocation,
|
CoreMetadataInvocation,
|
||||||
NonNullableGraph,
|
NonNullableGraph,
|
||||||
|
S,
|
||||||
} from 'services/api/types';
|
} from 'services/api/types';
|
||||||
import { isControlNetModelConfig } from 'services/api/types';
|
import { assert } from 'tsafe';
|
||||||
|
|
||||||
import { CONTROL_NET_COLLECT } from './constants';
|
import { CONTROL_NET_COLLECT } from './constants';
|
||||||
import { getModelMetadataField, upsertMetadata } from './metadata';
|
import { upsertMetadata } from './metadata';
|
||||||
|
|
||||||
export const addControlNetToLinearGraph = async (
|
export const addControlNetToLinearGraph = async (
|
||||||
state: RootState,
|
state: RootState,
|
||||||
@ -18,12 +20,14 @@ export const addControlNetToLinearGraph = async (
|
|||||||
baseNodeId: string
|
baseNodeId: string
|
||||||
): Promise<void> => {
|
): Promise<void> => {
|
||||||
const validControlNets = selectValidControlNets(state.controlAdapters).filter(
|
const validControlNets = selectValidControlNets(state.controlAdapters).filter(
|
||||||
(ca) => ca.model?.base === state.generation.model?.base
|
({ model, processedControlImage, processorType, controlImage, isEnabled }) => {
|
||||||
);
|
const hasModel = Boolean(model);
|
||||||
|
const doesBaseMatch = model?.base === state.generation.model?.base;
|
||||||
|
const hasControlImage = (processedControlImage && processorType !== 'none') || controlImage;
|
||||||
|
|
||||||
// const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
|
return isEnabled && hasModel && doesBaseMatch && hasControlImage;
|
||||||
// | MetadataAccumulatorInvocation
|
}
|
||||||
// | undefined;
|
);
|
||||||
|
|
||||||
const controlNetMetadata: CoreMetadataInvocation['controlnets'] = [];
|
const controlNetMetadata: CoreMetadataInvocation['controlnets'] = [];
|
||||||
|
|
||||||
@ -43,7 +47,7 @@ export const addControlNetToLinearGraph = async (
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
validControlNets.forEach(async (controlNet) => {
|
for (const controlNet of validControlNets) {
|
||||||
if (!controlNet.model) {
|
if (!controlNet.model) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -70,36 +74,12 @@ export const addControlNetToLinearGraph = async (
|
|||||||
resize_mode: resizeMode,
|
resize_mode: resizeMode,
|
||||||
control_model: model,
|
control_model: model,
|
||||||
control_weight: weight,
|
control_weight: weight,
|
||||||
|
image: buildControlImage(controlImage, processedControlImage, processorType),
|
||||||
};
|
};
|
||||||
|
|
||||||
if (processedControlImage && processorType !== 'none') {
|
graph.nodes[controlNetNode.id] = controlNetNode;
|
||||||
// We've already processed the image in the app, so we can just use the processed image
|
|
||||||
controlNetNode.image = {
|
|
||||||
image_name: processedControlImage,
|
|
||||||
};
|
|
||||||
} else if (controlImage) {
|
|
||||||
// The control image is preprocessed
|
|
||||||
controlNetNode.image = {
|
|
||||||
image_name: controlImage,
|
|
||||||
};
|
|
||||||
} else {
|
|
||||||
// Skip ControlNets without an unprocessed image - should never happen if everything is working correctly
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
graph.nodes[controlNetNode.id] = controlNetNode as ControlNetInvocation;
|
controlNetMetadata.push(buildControlNetMetadata(controlNet));
|
||||||
|
|
||||||
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isControlNetModelConfig);
|
|
||||||
|
|
||||||
controlNetMetadata.push({
|
|
||||||
control_model: getModelMetadataField(modelConfig),
|
|
||||||
control_weight: weight,
|
|
||||||
control_mode: controlMode,
|
|
||||||
begin_step_percent: beginStepPct,
|
|
||||||
end_step_percent: endStepPct,
|
|
||||||
resize_mode: resizeMode,
|
|
||||||
image: controlNetNode.image,
|
|
||||||
});
|
|
||||||
|
|
||||||
graph.edges.push({
|
graph.edges.push({
|
||||||
source: { node_id: controlNetNode.id, field: 'control' },
|
source: { node_id: controlNetNode.id, field: 'control' },
|
||||||
@ -108,7 +88,66 @@ export const addControlNetToLinearGraph = async (
|
|||||||
field: 'item',
|
field: 'item',
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
});
|
}
|
||||||
upsertMetadata(graph, { controlnets: controlNetMetadata });
|
upsertMetadata(graph, { controlnets: controlNetMetadata });
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const buildControlImage = (
|
||||||
|
controlImage: string | null,
|
||||||
|
processedControlImage: string | null,
|
||||||
|
processorType: ControlAdapterProcessorType
|
||||||
|
): ImageField => {
|
||||||
|
let image: ImageField | null = null;
|
||||||
|
if (processedControlImage && processorType !== 'none') {
|
||||||
|
// We've already processed the image in the app, so we can just use the processed image
|
||||||
|
image = {
|
||||||
|
image_name: processedControlImage,
|
||||||
|
};
|
||||||
|
} else if (controlImage) {
|
||||||
|
// The control image is preprocessed
|
||||||
|
image = {
|
||||||
|
image_name: controlImage,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
assert(image, 'ControlNet image is required');
|
||||||
|
return image;
|
||||||
|
};
|
||||||
|
|
||||||
|
const buildControlNetMetadata = (controlNet: ControlNetConfig): S['ControlNetMetadataField'] => {
|
||||||
|
const {
|
||||||
|
controlImage,
|
||||||
|
processedControlImage,
|
||||||
|
beginStepPct,
|
||||||
|
endStepPct,
|
||||||
|
controlMode,
|
||||||
|
resizeMode,
|
||||||
|
model,
|
||||||
|
processorType,
|
||||||
|
weight,
|
||||||
|
} = controlNet;
|
||||||
|
|
||||||
|
assert(model, 'ControlNet model is required');
|
||||||
|
|
||||||
|
const processed_image =
|
||||||
|
processedControlImage && processorType !== 'none'
|
||||||
|
? {
|
||||||
|
image_name: processedControlImage,
|
||||||
|
}
|
||||||
|
: null;
|
||||||
|
|
||||||
|
assert(controlImage, 'ControlNet image is required');
|
||||||
|
|
||||||
|
return {
|
||||||
|
control_model: model,
|
||||||
|
control_weight: weight,
|
||||||
|
control_mode: controlMode,
|
||||||
|
begin_step_percent: beginStepPct,
|
||||||
|
end_step_percent: endStepPct,
|
||||||
|
resize_mode: resizeMode,
|
||||||
|
image: {
|
||||||
|
image_name: controlImage,
|
||||||
|
},
|
||||||
|
processed_image,
|
||||||
|
};
|
||||||
|
};
|
||||||
|
@ -1,25 +1,30 @@
|
|||||||
import type { RootState } from 'app/store/store';
|
import type { RootState } from 'app/store/store';
|
||||||
import { selectValidIPAdapters } from 'features/controlAdapters/store/controlAdaptersSlice';
|
import { selectValidIPAdapters } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
import type { IPAdapterConfig } from 'features/controlAdapters/store/types';
|
||||||
|
import type { ImageField } from 'features/nodes/types/common';
|
||||||
import type {
|
import type {
|
||||||
CollectInvocation,
|
CollectInvocation,
|
||||||
CoreMetadataInvocation,
|
CoreMetadataInvocation,
|
||||||
IPAdapterInvocation,
|
IPAdapterInvocation,
|
||||||
NonNullableGraph,
|
NonNullableGraph,
|
||||||
|
S,
|
||||||
} from 'services/api/types';
|
} from 'services/api/types';
|
||||||
import { isIPAdapterModelConfig } from 'services/api/types';
|
import { assert } from 'tsafe';
|
||||||
|
|
||||||
import { IP_ADAPTER_COLLECT } from './constants';
|
import { IP_ADAPTER_COLLECT } from './constants';
|
||||||
import { getModelMetadataField, upsertMetadata } from './metadata';
|
import { upsertMetadata } from './metadata';
|
||||||
|
|
||||||
export const addIPAdapterToLinearGraph = async (
|
export const addIPAdapterToLinearGraph = async (
|
||||||
state: RootState,
|
state: RootState,
|
||||||
graph: NonNullableGraph,
|
graph: NonNullableGraph,
|
||||||
baseNodeId: string
|
baseNodeId: string
|
||||||
): Promise<void> => {
|
): Promise<void> => {
|
||||||
const validIPAdapters = selectValidIPAdapters(state.controlAdapters).filter(
|
const validIPAdapters = selectValidIPAdapters(state.controlAdapters).filter(({ model, controlImage, isEnabled }) => {
|
||||||
(ca) => ca.model?.base === state.generation.model?.base
|
const hasModel = Boolean(model);
|
||||||
);
|
const doesBaseMatch = model?.base === state.generation.model?.base;
|
||||||
|
const hasControlImage = controlImage;
|
||||||
|
return isEnabled && hasModel && doesBaseMatch && hasControlImage;
|
||||||
|
});
|
||||||
|
|
||||||
if (validIPAdapters.length) {
|
if (validIPAdapters.length) {
|
||||||
// Even though denoise_latents' ip adapter input is collection or scalar, keep it simple and always use a collect
|
// Even though denoise_latents' ip adapter input is collection or scalar, keep it simple and always use a collect
|
||||||
@ -39,11 +44,14 @@ export const addIPAdapterToLinearGraph = async (
|
|||||||
|
|
||||||
const ipAdapterMetdata: CoreMetadataInvocation['ipAdapters'] = [];
|
const ipAdapterMetdata: CoreMetadataInvocation['ipAdapters'] = [];
|
||||||
|
|
||||||
validIPAdapters.forEach(async (ipAdapter) => {
|
for (const ipAdapter of validIPAdapters) {
|
||||||
if (!ipAdapter.model) {
|
if (!ipAdapter.model) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const { id, weight, model, beginStepPct, endStepPct } = ipAdapter;
|
const { id, weight, model, beginStepPct, endStepPct, controlImage } = ipAdapter;
|
||||||
|
|
||||||
|
assert(controlImage, 'IP Adapter image is required');
|
||||||
|
|
||||||
const ipAdapterNode: IPAdapterInvocation = {
|
const ipAdapterNode: IPAdapterInvocation = {
|
||||||
id: `ip_adapter_${id}`,
|
id: `ip_adapter_${id}`,
|
||||||
type: 'ip_adapter',
|
type: 'ip_adapter',
|
||||||
@ -52,27 +60,14 @@ export const addIPAdapterToLinearGraph = async (
|
|||||||
ip_adapter_model: model,
|
ip_adapter_model: model,
|
||||||
begin_step_percent: beginStepPct,
|
begin_step_percent: beginStepPct,
|
||||||
end_step_percent: endStepPct,
|
end_step_percent: endStepPct,
|
||||||
|
image: {
|
||||||
|
image_name: controlImage,
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
if (ipAdapter.controlImage) {
|
|
||||||
ipAdapterNode.image = {
|
|
||||||
image_name: ipAdapter.controlImage,
|
|
||||||
};
|
|
||||||
} else {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
graph.nodes[ipAdapterNode.id] = ipAdapterNode;
|
graph.nodes[ipAdapterNode.id] = ipAdapterNode;
|
||||||
|
|
||||||
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isIPAdapterModelConfig);
|
ipAdapterMetdata.push(buildIPAdapterMetadata(ipAdapter));
|
||||||
|
|
||||||
ipAdapterMetdata.push({
|
|
||||||
weight: weight,
|
|
||||||
ip_adapter_model: getModelMetadataField(modelConfig),
|
|
||||||
begin_step_percent: beginStepPct,
|
|
||||||
end_step_percent: endStepPct,
|
|
||||||
image: ipAdapterNode.image,
|
|
||||||
});
|
|
||||||
|
|
||||||
graph.edges.push({
|
graph.edges.push({
|
||||||
source: { node_id: ipAdapterNode.id, field: 'ip_adapter' },
|
source: { node_id: ipAdapterNode.id, field: 'ip_adapter' },
|
||||||
@ -81,8 +76,32 @@ export const addIPAdapterToLinearGraph = async (
|
|||||||
field: 'item',
|
field: 'item',
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
});
|
}
|
||||||
|
|
||||||
upsertMetadata(graph, { ipAdapters: ipAdapterMetdata });
|
upsertMetadata(graph, { ipAdapters: ipAdapterMetdata });
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const buildIPAdapterMetadata = (ipAdapter: IPAdapterConfig): S['IPAdapterMetadataField'] => {
|
||||||
|
const { controlImage, beginStepPct, endStepPct, model, weight } = ipAdapter;
|
||||||
|
|
||||||
|
assert(model, 'IP Adapter model is required');
|
||||||
|
|
||||||
|
let image: ImageField | null = null;
|
||||||
|
|
||||||
|
if (controlImage) {
|
||||||
|
image = {
|
||||||
|
image_name: controlImage,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
assert(image, 'IP Adapter image is required');
|
||||||
|
|
||||||
|
return {
|
||||||
|
ip_adapter_model: model,
|
||||||
|
weight,
|
||||||
|
begin_step_percent: beginStepPct,
|
||||||
|
end_step_percent: endStepPct,
|
||||||
|
image,
|
||||||
|
};
|
||||||
|
};
|
||||||
|
@ -1,16 +1,18 @@
|
|||||||
import type { RootState } from 'app/store/store';
|
import type { RootState } from 'app/store/store';
|
||||||
import { selectValidT2IAdapters } from 'features/controlAdapters/store/controlAdaptersSlice';
|
import { selectValidT2IAdapters } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
import type { ControlAdapterProcessorType, T2IAdapterConfig } from 'features/controlAdapters/store/types';
|
||||||
import {
|
import type { ImageField } from 'features/nodes/types/common';
|
||||||
type CollectInvocation,
|
import type {
|
||||||
type CoreMetadataInvocation,
|
CollectInvocation,
|
||||||
isT2IAdapterModelConfig,
|
CoreMetadataInvocation,
|
||||||
type NonNullableGraph,
|
NonNullableGraph,
|
||||||
type T2IAdapterInvocation,
|
S,
|
||||||
|
T2IAdapterInvocation,
|
||||||
} from 'services/api/types';
|
} from 'services/api/types';
|
||||||
|
import { assert } from 'tsafe';
|
||||||
|
|
||||||
import { T2I_ADAPTER_COLLECT } from './constants';
|
import { T2I_ADAPTER_COLLECT } from './constants';
|
||||||
import { getModelMetadataField, upsertMetadata } from './metadata';
|
import { upsertMetadata } from './metadata';
|
||||||
|
|
||||||
export const addT2IAdaptersToLinearGraph = async (
|
export const addT2IAdaptersToLinearGraph = async (
|
||||||
state: RootState,
|
state: RootState,
|
||||||
@ -18,7 +20,13 @@ export const addT2IAdaptersToLinearGraph = async (
|
|||||||
baseNodeId: string
|
baseNodeId: string
|
||||||
): Promise<void> => {
|
): Promise<void> => {
|
||||||
const validT2IAdapters = selectValidT2IAdapters(state.controlAdapters).filter(
|
const validT2IAdapters = selectValidT2IAdapters(state.controlAdapters).filter(
|
||||||
(ca) => ca.model?.base === state.generation.model?.base
|
({ model, processedControlImage, processorType, controlImage, isEnabled }) => {
|
||||||
|
const hasModel = Boolean(model);
|
||||||
|
const doesBaseMatch = model?.base === state.generation.model?.base;
|
||||||
|
const hasControlImage = (processedControlImage && processorType !== 'none') || controlImage;
|
||||||
|
|
||||||
|
return isEnabled && hasModel && doesBaseMatch && hasControlImage;
|
||||||
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
if (validT2IAdapters.length) {
|
if (validT2IAdapters.length) {
|
||||||
@ -39,7 +47,7 @@ export const addT2IAdaptersToLinearGraph = async (
|
|||||||
|
|
||||||
const t2iAdapterMetadata: CoreMetadataInvocation['t2iAdapters'] = [];
|
const t2iAdapterMetadata: CoreMetadataInvocation['t2iAdapters'] = [];
|
||||||
|
|
||||||
validT2IAdapters.forEach(async (t2iAdapter) => {
|
for (const t2iAdapter of validT2IAdapters) {
|
||||||
if (!t2iAdapter.model) {
|
if (!t2iAdapter.model) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -64,35 +72,12 @@ export const addT2IAdaptersToLinearGraph = async (
|
|||||||
resize_mode: resizeMode,
|
resize_mode: resizeMode,
|
||||||
t2i_adapter_model: model,
|
t2i_adapter_model: model,
|
||||||
weight: weight,
|
weight: weight,
|
||||||
|
image: buildControlImage(controlImage, processedControlImage, processorType),
|
||||||
};
|
};
|
||||||
|
|
||||||
if (processedControlImage && processorType !== 'none') {
|
|
||||||
// We've already processed the image in the app, so we can just use the processed image
|
|
||||||
t2iAdapterNode.image = {
|
|
||||||
image_name: processedControlImage,
|
|
||||||
};
|
|
||||||
} else if (controlImage) {
|
|
||||||
// The control image is preprocessed
|
|
||||||
t2iAdapterNode.image = {
|
|
||||||
image_name: controlImage,
|
|
||||||
};
|
|
||||||
} else {
|
|
||||||
// Skip ControlNets without an unprocessed image - should never happen if everything is working correctly
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
graph.nodes[t2iAdapterNode.id] = t2iAdapterNode;
|
graph.nodes[t2iAdapterNode.id] = t2iAdapterNode;
|
||||||
|
|
||||||
const modelConfig = await fetchModelConfigWithTypeGuard(t2iAdapter.model.key, isT2IAdapterModelConfig);
|
t2iAdapterMetadata.push(buildT2IAdapterMetadata(t2iAdapter));
|
||||||
|
|
||||||
t2iAdapterMetadata.push({
|
|
||||||
begin_step_percent: beginStepPct,
|
|
||||||
end_step_percent: endStepPct,
|
|
||||||
resize_mode: resizeMode,
|
|
||||||
t2i_adapter_model: getModelMetadataField(modelConfig),
|
|
||||||
weight: weight,
|
|
||||||
image: t2iAdapterNode.image,
|
|
||||||
});
|
|
||||||
|
|
||||||
graph.edges.push({
|
graph.edges.push({
|
||||||
source: { node_id: t2iAdapterNode.id, field: 't2i_adapter' },
|
source: { node_id: t2iAdapterNode.id, field: 't2i_adapter' },
|
||||||
@ -101,8 +86,57 @@ export const addT2IAdaptersToLinearGraph = async (
|
|||||||
field: 'item',
|
field: 'item',
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
});
|
}
|
||||||
|
|
||||||
upsertMetadata(graph, { t2iAdapters: t2iAdapterMetadata });
|
upsertMetadata(graph, { t2iAdapters: t2iAdapterMetadata });
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const buildControlImage = (
|
||||||
|
controlImage: string | null,
|
||||||
|
processedControlImage: string | null,
|
||||||
|
processorType: ControlAdapterProcessorType
|
||||||
|
): ImageField => {
|
||||||
|
let image: ImageField | null = null;
|
||||||
|
if (processedControlImage && processorType !== 'none') {
|
||||||
|
// We've already processed the image in the app, so we can just use the processed image
|
||||||
|
image = {
|
||||||
|
image_name: processedControlImage,
|
||||||
|
};
|
||||||
|
} else if (controlImage) {
|
||||||
|
// The control image is preprocessed
|
||||||
|
image = {
|
||||||
|
image_name: controlImage,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
assert(image, 'T2I Adapter image is required');
|
||||||
|
return image;
|
||||||
|
};
|
||||||
|
|
||||||
|
const buildT2IAdapterMetadata = (t2iAdapter: T2IAdapterConfig): S['T2IAdapterMetadataField'] => {
|
||||||
|
const { controlImage, processedControlImage, beginStepPct, endStepPct, resizeMode, model, processorType, weight } =
|
||||||
|
t2iAdapter;
|
||||||
|
|
||||||
|
assert(model, 'T2I Adapter model is required');
|
||||||
|
|
||||||
|
const processed_image =
|
||||||
|
processedControlImage && processorType !== 'none'
|
||||||
|
? {
|
||||||
|
image_name: processedControlImage,
|
||||||
|
}
|
||||||
|
: null;
|
||||||
|
|
||||||
|
assert(controlImage, 'T2I Adapter image is required');
|
||||||
|
|
||||||
|
return {
|
||||||
|
t2i_adapter_model: model,
|
||||||
|
weight,
|
||||||
|
begin_step_percent: beginStepPct,
|
||||||
|
end_step_percent: endStepPct,
|
||||||
|
resize_mode: resizeMode,
|
||||||
|
image: {
|
||||||
|
image_name: controlImage,
|
||||||
|
},
|
||||||
|
processed_image,
|
||||||
|
};
|
||||||
|
};
|
||||||
|
@ -1,7 +1,5 @@
|
|||||||
import type { RootState } from 'app/store/store';
|
import type { RootState } from 'app/store/store';
|
||||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
import type { NonNullableGraph } from 'services/api/types';
|
||||||
import type { ModelMetadataField, NonNullableGraph } from 'services/api/types';
|
|
||||||
import { isVAEModelConfig } from 'services/api/types';
|
|
||||||
|
|
||||||
import {
|
import {
|
||||||
CANVAS_IMAGE_TO_IMAGE_GRAPH,
|
CANVAS_IMAGE_TO_IMAGE_GRAPH,
|
||||||
@ -25,7 +23,7 @@ import {
|
|||||||
TEXT_TO_IMAGE_GRAPH,
|
TEXT_TO_IMAGE_GRAPH,
|
||||||
VAE_LOADER,
|
VAE_LOADER,
|
||||||
} from './constants';
|
} from './constants';
|
||||||
import { getModelMetadataField, upsertMetadata } from './metadata';
|
import { upsertMetadata } from './metadata';
|
||||||
|
|
||||||
export const addVAEToGraph = async (
|
export const addVAEToGraph = async (
|
||||||
state: RootState,
|
state: RootState,
|
||||||
@ -151,8 +149,6 @@ export const addVAEToGraph = async (
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (vae) {
|
if (vae) {
|
||||||
const modelConfig = await fetchModelConfigWithTypeGuard(vae.key, isVAEModelConfig);
|
upsertMetadata(graph, { vae });
|
||||||
const vaeMetadata: ModelMetadataField = getModelMetadataField(modelConfig);
|
|
||||||
upsertMetadata(graph, { vae: vaeMetadata });
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import type { JSONObject } from 'common/types';
|
import type { JSONObject } from 'common/types';
|
||||||
import type { AnyModelConfig, CoreMetadataInvocation, ModelMetadataField, NonNullableGraph } from 'services/api/types';
|
import type { ModelIdentifierField } from 'features/nodes/types/common';
|
||||||
|
import type { AnyModelConfig, CoreMetadataInvocation, NonNullableGraph } from 'services/api/types';
|
||||||
|
|
||||||
import { METADATA } from './constants';
|
import { METADATA } from './constants';
|
||||||
|
|
||||||
@ -72,7 +73,7 @@ export const setMetadataReceivingNode = (graph: NonNullableGraph, nodeId: string
|
|||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
|
||||||
export const getModelMetadataField = ({ key, hash, name, base, type }: AnyModelConfig): ModelMetadataField => ({
|
export const getModelMetadataField = ({ key, hash, name, base, type }: AnyModelConfig): ModelIdentifierField => ({
|
||||||
key,
|
key,
|
||||||
hash,
|
hash,
|
||||||
name,
|
name,
|
||||||
|
@ -8,8 +8,7 @@ import { modelSelected } from 'features/parameters/store/actions';
|
|||||||
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
|
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
|
import { useMainModels } from 'services/api/hooks/modelsByType';
|
||||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
|
||||||
import type { MainModelConfig } from 'services/api/types';
|
import type { MainModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
const selectModel = createMemoizedSelector(selectGenerationSlice, (generation) => generation.model);
|
const selectModel = createMemoizedSelector(selectGenerationSlice, (generation) => generation.model);
|
||||||
@ -18,7 +17,7 @@ const ParamMainModelSelect = () => {
|
|||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const selectedModel = useAppSelector(selectModel);
|
const selectedModel = useAppSelector(selectModel);
|
||||||
const { data, isLoading } = useGetMainModelsQuery(NON_REFINER_BASE_MODELS);
|
const [modelConfigs, { isLoading }] = useMainModels();
|
||||||
|
|
||||||
const _onChange = useCallback(
|
const _onChange = useCallback(
|
||||||
(model: MainModelConfig | null) => {
|
(model: MainModelConfig | null) => {
|
||||||
@ -35,7 +34,7 @@ const ParamMainModelSelect = () => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
const { items, selectedItem, onChange, placeholder } = useModelCustomSelect({
|
const { items, selectedItem, onChange, placeholder } = useModelCustomSelect({
|
||||||
data,
|
modelConfigs,
|
||||||
isLoading,
|
isLoading,
|
||||||
selectedModel,
|
selectedModel,
|
||||||
onChange: _onChange,
|
onChange: _onChange,
|
||||||
@ -46,7 +45,13 @@ const ParamMainModelSelect = () => {
|
|||||||
<InformationalPopover feature="paramModel">
|
<InformationalPopover feature="paramModel">
|
||||||
<FormLabel>{t('modelManager.model')}</FormLabel>
|
<FormLabel>{t('modelManager.model')}</FormLabel>
|
||||||
</InformationalPopover>
|
</InformationalPopover>
|
||||||
<CustomSelect selectedItem={selectedItem} placeholder={placeholder} items={items} onChange={onChange} />
|
<CustomSelect
|
||||||
|
key={items.length}
|
||||||
|
selectedItem={selectedItem}
|
||||||
|
placeholder={placeholder}
|
||||||
|
items={items}
|
||||||
|
onChange={onChange}
|
||||||
|
/>
|
||||||
</FormControl>
|
</FormControl>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -7,7 +7,7 @@ import { zModelIdentifierField } from 'features/nodes/types/common';
|
|||||||
import { selectGenerationSlice, vaeSelected } from 'features/parameters/store/generationSlice';
|
import { selectGenerationSlice, vaeSelected } from 'features/parameters/store/generationSlice';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
|
import { useVAEModels } from 'services/api/hooks/modelsByType';
|
||||||
import type { VAEModelConfig } from 'services/api/types';
|
import type { VAEModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
const selector = createMemoizedSelector(selectGenerationSlice, (generation) => {
|
const selector = createMemoizedSelector(selectGenerationSlice, (generation) => {
|
||||||
@ -19,7 +19,7 @@ const ParamVAEModelSelect = () => {
|
|||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const { model, vae } = useAppSelector(selector);
|
const { model, vae } = useAppSelector(selector);
|
||||||
const { data, isLoading } = useGetVaeModelsQuery();
|
const [modelConfigs, { isLoading }] = useVAEModels();
|
||||||
const getIsDisabled = useCallback(
|
const getIsDisabled = useCallback(
|
||||||
(vae: VAEModelConfig): boolean => {
|
(vae: VAEModelConfig): boolean => {
|
||||||
const isCompatible = model?.base === vae.base;
|
const isCompatible = model?.base === vae.base;
|
||||||
@ -35,7 +35,7 @@ const ParamVAEModelSelect = () => {
|
|||||||
[dispatch]
|
[dispatch]
|
||||||
);
|
);
|
||||||
const { options, value, onChange, noOptionsMessage } = useGroupedModelCombobox({
|
const { options, value, onChange, noOptionsMessage } = useGroupedModelCombobox({
|
||||||
modelEntities: data,
|
modelConfigs,
|
||||||
onChange: _onChange,
|
onChange: _onChange,
|
||||||
selectedModel: vae,
|
selectedModel: vae,
|
||||||
isLoading,
|
isLoading,
|
||||||
|
@ -11,13 +11,8 @@ import { t } from 'i18next';
|
|||||||
import { flatten, map } from 'lodash-es';
|
import { flatten, map } from 'lodash-es';
|
||||||
import { memo, useCallback, useMemo } from 'react';
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import {
|
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
|
||||||
loraModelsAdapterSelectors,
|
import { useEmbeddingModels, useLoRAModels } from 'services/api/hooks/modelsByType';
|
||||||
textualInversionModelsAdapterSelectors,
|
|
||||||
useGetLoRAModelsQuery,
|
|
||||||
useGetModelConfigQuery,
|
|
||||||
useGetTextualInversionModelsQuery,
|
|
||||||
} from 'services/api/endpoints/models';
|
|
||||||
import { isNonRefinerMainModelConfig } from 'services/api/types';
|
import { isNonRefinerMainModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
const noOptionsMessage = () => t('prompt.noMatchingTriggers');
|
const noOptionsMessage = () => t('prompt.noMatchingTriggers');
|
||||||
@ -33,8 +28,8 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel
|
|||||||
const { data: mainModelConfig, isLoading: isLoadingMainModelConfig } = useGetModelConfigQuery(
|
const { data: mainModelConfig, isLoading: isLoadingMainModelConfig } = useGetModelConfigQuery(
|
||||||
mainModel?.key ?? skipToken
|
mainModel?.key ?? skipToken
|
||||||
);
|
);
|
||||||
const { data: loraModels, isLoading: isLoadingLoRAs } = useGetLoRAModelsQuery();
|
const [loraModels, { isLoading: isLoadingLoRAs }] = useLoRAModels();
|
||||||
const { data: tiModels, isLoading: isLoadingTIs } = useGetTextualInversionModelsQuery();
|
const [tiModels, { isLoading: isLoadingTIs }] = useEmbeddingModels();
|
||||||
|
|
||||||
const _onChange = useCallback<ComboboxOnChange>(
|
const _onChange = useCallback<ComboboxOnChange>(
|
||||||
(v) => {
|
(v) => {
|
||||||
@ -52,8 +47,7 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel
|
|||||||
const _options: GroupBase<ComboboxOption>[] = [];
|
const _options: GroupBase<ComboboxOption>[] = [];
|
||||||
|
|
||||||
if (tiModels) {
|
if (tiModels) {
|
||||||
const embeddingOptions = textualInversionModelsAdapterSelectors
|
const embeddingOptions = tiModels
|
||||||
.selectAll(tiModels)
|
|
||||||
.filter((ti) => ti.base === mainModelConfig?.base)
|
.filter((ti) => ti.base === mainModelConfig?.base)
|
||||||
.map((model) => ({ label: model.name, value: `<${model.name}>` }));
|
.map((model) => ({ label: model.name, value: `<${model.name}>` }));
|
||||||
|
|
||||||
@ -66,8 +60,7 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (loraModels) {
|
if (loraModels) {
|
||||||
const triggerPhraseOptions = loraModelsAdapterSelectors
|
const triggerPhraseOptions = loraModels
|
||||||
.selectAll(loraModels)
|
|
||||||
.filter((lora) => map(addedLoRAs, (l) => l.model.key).includes(lora.key))
|
.filter((lora) => map(addedLoRAs, (l) => l.model.key).includes(lora.key))
|
||||||
.map((lora) => {
|
.map((lora) => {
|
||||||
if (lora.trigger_phrases) {
|
if (lora.trigger_phrases) {
|
||||||
|
@ -7,8 +7,7 @@ import { zModelIdentifierField } from 'features/nodes/types/common';
|
|||||||
import { refinerModelChanged, selectSdxlSlice } from 'features/sdxl/store/sdxlSlice';
|
import { refinerModelChanged, selectSdxlSlice } from 'features/sdxl/store/sdxlSlice';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { REFINER_BASE_MODELS } from 'services/api/constants';
|
import { useRefinerModels } from 'services/api/hooks/modelsByType';
|
||||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
|
||||||
import type { MainModelConfig } from 'services/api/types';
|
import type { MainModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
const selectModel = createMemoizedSelector(selectSdxlSlice, (sdxl) => sdxl.refinerModel);
|
const selectModel = createMemoizedSelector(selectSdxlSlice, (sdxl) => sdxl.refinerModel);
|
||||||
@ -19,7 +18,7 @@ const ParamSDXLRefinerModelSelect = () => {
|
|||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const model = useAppSelector(selectModel);
|
const model = useAppSelector(selectModel);
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const { data, isLoading } = useGetMainModelsQuery(REFINER_BASE_MODELS);
|
const [modelConfigs, { isLoading }] = useRefinerModels();
|
||||||
const _onChange = useCallback(
|
const _onChange = useCallback(
|
||||||
(model: MainModelConfig | null) => {
|
(model: MainModelConfig | null) => {
|
||||||
if (!model) {
|
if (!model) {
|
||||||
@ -31,7 +30,7 @@ const ParamSDXLRefinerModelSelect = () => {
|
|||||||
[dispatch]
|
[dispatch]
|
||||||
);
|
);
|
||||||
const { options, value, onChange, placeholder, noOptionsMessage } = useModelCombobox({
|
const { options, value, onChange, placeholder, noOptionsMessage } = useModelCombobox({
|
||||||
modelEntities: data,
|
modelConfigs,
|
||||||
onChange: _onChange,
|
onChange: _onChange,
|
||||||
selectedModel: model,
|
selectedModel: model,
|
||||||
isLoading,
|
isLoading,
|
||||||
|
@ -1,28 +1,11 @@
|
|||||||
import type { EntityAdapter, EntityState, ThunkDispatch, UnknownAction } from '@reduxjs/toolkit';
|
import type { EntityState } from '@reduxjs/toolkit';
|
||||||
import { createEntityAdapter } from '@reduxjs/toolkit';
|
import { createEntityAdapter } from '@reduxjs/toolkit';
|
||||||
import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
|
import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
|
||||||
import queryString from 'query-string';
|
import queryString from 'query-string';
|
||||||
import {
|
|
||||||
ALL_BASE_MODELS,
|
|
||||||
NON_REFINER_BASE_MODELS,
|
|
||||||
NON_SDXL_MAIN_MODELS,
|
|
||||||
REFINER_BASE_MODELS,
|
|
||||||
SDXL_MAIN_MODELS,
|
|
||||||
} from 'services/api/constants';
|
|
||||||
import type { operations, paths } from 'services/api/schema';
|
import type { operations, paths } from 'services/api/schema';
|
||||||
import type {
|
import type { AnyModelConfig } from 'services/api/types';
|
||||||
AnyModelConfig,
|
|
||||||
BaseModelType,
|
|
||||||
ControlNetModelConfig,
|
|
||||||
IPAdapterModelConfig,
|
|
||||||
LoRAModelConfig,
|
|
||||||
MainModelConfig,
|
|
||||||
T2IAdapterModelConfig,
|
|
||||||
TextualInversionModelConfig,
|
|
||||||
VAEModelConfig,
|
|
||||||
} from 'services/api/types';
|
|
||||||
|
|
||||||
import type { ApiTagDescription, tagTypes } from '..';
|
import type { ApiTagDescription } from '..';
|
||||||
import { api, buildV2Url, LIST_TAG } from '..';
|
import { api, buildV2Url, LIST_TAG } from '..';
|
||||||
|
|
||||||
export type UpdateModelArg = {
|
export type UpdateModelArg = {
|
||||||
@ -40,8 +23,9 @@ type UpdateModelImageResponse =
|
|||||||
paths['/api/v2/models/i/{key}/image']['patch']['responses']['200']['content']['application/json'];
|
paths['/api/v2/models/i/{key}/image']['patch']['responses']['200']['content']['application/json'];
|
||||||
|
|
||||||
type GetModelConfigResponse = paths['/api/v2/models/i/{key}']['get']['responses']['200']['content']['application/json'];
|
type GetModelConfigResponse = paths['/api/v2/models/i/{key}']['get']['responses']['200']['content']['application/json'];
|
||||||
|
type GetModelConfigsResponse = NonNullable<
|
||||||
type ListModelsArg = NonNullable<paths['/api/v2/models/']['get']['parameters']['query']>;
|
paths['/api/v2/models/']['get']['responses']['200']['content']['application/json']
|
||||||
|
>;
|
||||||
|
|
||||||
type DeleteModelArg = {
|
type DeleteModelArg = {
|
||||||
key: string;
|
key: string;
|
||||||
@ -71,74 +55,16 @@ export type ScanFolderResponse =
|
|||||||
paths['/api/v2/models/scan_folder']['get']['responses']['200']['content']['application/json'];
|
paths['/api/v2/models/scan_folder']['get']['responses']['200']['content']['application/json'];
|
||||||
type ScanFolderArg = operations['scan_for_models']['parameters']['query'];
|
type ScanFolderArg = operations['scan_for_models']['parameters']['query'];
|
||||||
|
|
||||||
|
type GetHuggingFaceModelsResponse =
|
||||||
|
paths['/api/v2/models/hugging_face']['get']['responses']['200']['content']['application/json'];
|
||||||
|
|
||||||
type GetByAttrsArg = operations['get_model_records_by_attrs']['parameters']['query'];
|
type GetByAttrsArg = operations['get_model_records_by_attrs']['parameters']['query'];
|
||||||
|
|
||||||
const mainModelsAdapter = createEntityAdapter<MainModelConfig, string>({
|
const modelConfigsAdapter = createEntityAdapter<AnyModelConfig, string>({
|
||||||
selectId: (entity) => entity.key,
|
selectId: (entity) => entity.key,
|
||||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||||
});
|
});
|
||||||
export const mainModelsAdapterSelectors = mainModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
export const modelConfigsAdapterSelectors = modelConfigsAdapter.getSelectors(undefined, getSelectorsOptions);
|
||||||
const loraModelsAdapter = createEntityAdapter<LoRAModelConfig, string>({
|
|
||||||
selectId: (entity) => entity.key,
|
|
||||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
|
||||||
});
|
|
||||||
export const loraModelsAdapterSelectors = loraModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
|
||||||
const controlNetModelsAdapter = createEntityAdapter<ControlNetModelConfig, string>({
|
|
||||||
selectId: (entity) => entity.key,
|
|
||||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
|
||||||
});
|
|
||||||
export const controlNetModelsAdapterSelectors = controlNetModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
|
||||||
const ipAdapterModelsAdapter = createEntityAdapter<IPAdapterModelConfig, string>({
|
|
||||||
selectId: (entity) => entity.key,
|
|
||||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
|
||||||
});
|
|
||||||
export const ipAdapterModelsAdapterSelectors = ipAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
|
||||||
const t2iAdapterModelsAdapter = createEntityAdapter<T2IAdapterModelConfig, string>({
|
|
||||||
selectId: (entity) => entity.key,
|
|
||||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
|
||||||
});
|
|
||||||
export const t2iAdapterModelsAdapterSelectors = t2iAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
|
||||||
const textualInversionModelsAdapter = createEntityAdapter<TextualInversionModelConfig, string>({
|
|
||||||
selectId: (entity) => entity.key,
|
|
||||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
|
||||||
});
|
|
||||||
export const textualInversionModelsAdapterSelectors = textualInversionModelsAdapter.getSelectors(
|
|
||||||
undefined,
|
|
||||||
getSelectorsOptions
|
|
||||||
);
|
|
||||||
const vaeModelsAdapter = createEntityAdapter<VAEModelConfig, string>({
|
|
||||||
selectId: (entity) => entity.key,
|
|
||||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
|
||||||
});
|
|
||||||
export const vaeModelsAdapterSelectors = vaeModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
|
||||||
|
|
||||||
const anyModelConfigAdapter = createEntityAdapter<AnyModelConfig, string>({
|
|
||||||
selectId: (entity) => entity.key,
|
|
||||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
|
||||||
});
|
|
||||||
const anyModelConfigAdapterSelectors = anyModelConfigAdapter.getSelectors(undefined, getSelectorsOptions);
|
|
||||||
|
|
||||||
const buildProvidesTags =
|
|
||||||
<TEntity extends AnyModelConfig>(tagType: (typeof tagTypes)[number]) =>
|
|
||||||
(result: EntityState<TEntity, string> | undefined) => {
|
|
||||||
const tags: ApiTagDescription[] = [{ type: tagType, id: LIST_TAG }, 'Model'];
|
|
||||||
if (result) {
|
|
||||||
tags.push(
|
|
||||||
...result.ids.map((id) => ({
|
|
||||||
type: tagType,
|
|
||||||
id,
|
|
||||||
}))
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
return tags;
|
|
||||||
};
|
|
||||||
|
|
||||||
const buildTransformResponse =
|
|
||||||
<T extends AnyModelConfig>(adapter: EntityAdapter<T, string>) =>
|
|
||||||
(response: { models: T[] }) => {
|
|
||||||
return adapter.setAll(adapter.getInitialState(), response.models);
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Builds an endpoint URL for the models router
|
* Builds an endpoint URL for the models router
|
||||||
@ -159,9 +85,27 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
};
|
};
|
||||||
},
|
},
|
||||||
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
|
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
|
||||||
queryFulfilled.then(({ data }) => {
|
try {
|
||||||
upsertSingleModelConfig(data, dispatch);
|
const { data } = await queryFulfilled;
|
||||||
|
|
||||||
|
// Update the individual model query caches
|
||||||
|
dispatch(modelsApi.util.upsertQueryData('getModelConfig', data.key, data));
|
||||||
|
|
||||||
|
const { base, name, type } = data;
|
||||||
|
dispatch(modelsApi.util.upsertQueryData('getModelConfigByAttrs', { base, name, type }, data));
|
||||||
|
|
||||||
|
// Update the list query cache
|
||||||
|
dispatch(
|
||||||
|
modelsApi.util.updateQueryData('getModelConfigs', undefined, (draft) => {
|
||||||
|
modelConfigsAdapter.updateOne(draft, {
|
||||||
|
id: data.key,
|
||||||
|
changes: data,
|
||||||
});
|
});
|
||||||
|
})
|
||||||
|
);
|
||||||
|
} catch {
|
||||||
|
// no-op
|
||||||
|
}
|
||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
updateModelImage: build.mutation<UpdateModelImageResponse, UpdateModelImageArg>({
|
updateModelImage: build.mutation<UpdateModelImageResponse, UpdateModelImageArg>({
|
||||||
@ -258,6 +202,13 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
};
|
};
|
||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
|
getHuggingFaceModels: build.query<GetHuggingFaceModelsResponse, string>({
|
||||||
|
query: (hugging_face_repo) => {
|
||||||
|
return {
|
||||||
|
url: buildModelsUrl(`hugging_face?hugging_face_repo=${hugging_face_repo}`),
|
||||||
|
};
|
||||||
|
},
|
||||||
|
}),
|
||||||
listModelInstalls: build.query<ListModelInstallsResponse, void>({
|
listModelInstalls: build.query<ListModelInstallsResponse, void>({
|
||||||
query: () => {
|
query: () => {
|
||||||
return {
|
return {
|
||||||
@ -284,80 +235,27 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
},
|
},
|
||||||
invalidatesTags: ['ModelInstalls'],
|
invalidatesTags: ['ModelInstalls'],
|
||||||
}),
|
}),
|
||||||
getMainModels: build.query<EntityState<MainModelConfig, string>, BaseModelType[]>({
|
getModelConfigs: build.query<EntityState<AnyModelConfig, string>, void>({
|
||||||
query: (base_models) => {
|
query: () => ({ url: buildModelsUrl() }),
|
||||||
const params: ListModelsArg = {
|
providesTags: (result) => {
|
||||||
model_type: 'main',
|
const tags: ApiTagDescription[] = [{ type: 'ModelConfig', id: LIST_TAG }];
|
||||||
base_models,
|
if (result) {
|
||||||
};
|
const modelTags = result.ids.map((id) => ({ type: 'ModelConfig', id }) as const);
|
||||||
const query = queryString.stringify(params, { arrayFormat: 'none' });
|
tags.push(...modelTags);
|
||||||
return buildModelsUrl(`?${query}`);
|
}
|
||||||
|
return tags;
|
||||||
|
},
|
||||||
|
keepUnusedDataFor: 60 * 60 * 1000 * 24, // 1 day (infinite)
|
||||||
|
transformResponse: (response: GetModelConfigsResponse) => {
|
||||||
|
return modelConfigsAdapter.setAll(modelConfigsAdapter.getInitialState(), response.models);
|
||||||
},
|
},
|
||||||
providesTags: buildProvidesTags<MainModelConfig>('MainModel'),
|
|
||||||
transformResponse: buildTransformResponse<MainModelConfig>(mainModelsAdapter),
|
|
||||||
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
|
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
|
||||||
queryFulfilled.then(({ data }) => {
|
queryFulfilled.then(({ data }) => {
|
||||||
upsertModelConfigs(data, dispatch);
|
modelConfigsAdapterSelectors.selectAll(data).forEach((modelConfig) => {
|
||||||
|
dispatch(modelsApi.util.upsertQueryData('getModelConfig', modelConfig.key, modelConfig));
|
||||||
|
const { base, name, type } = modelConfig;
|
||||||
|
dispatch(modelsApi.util.upsertQueryData('getModelConfigByAttrs', { base, name, type }, modelConfig));
|
||||||
});
|
});
|
||||||
},
|
|
||||||
}),
|
|
||||||
getLoRAModels: build.query<EntityState<LoRAModelConfig, string>, void>({
|
|
||||||
query: () => ({ url: buildModelsUrl(), params: { model_type: 'lora' } }),
|
|
||||||
providesTags: buildProvidesTags<LoRAModelConfig>('LoRAModel'),
|
|
||||||
transformResponse: buildTransformResponse<LoRAModelConfig>(loraModelsAdapter),
|
|
||||||
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
|
|
||||||
queryFulfilled.then(({ data }) => {
|
|
||||||
upsertModelConfigs(data, dispatch);
|
|
||||||
});
|
|
||||||
},
|
|
||||||
}),
|
|
||||||
getControlNetModels: build.query<EntityState<ControlNetModelConfig, string>, void>({
|
|
||||||
query: () => ({ url: buildModelsUrl(), params: { model_type: 'controlnet' } }),
|
|
||||||
providesTags: buildProvidesTags<ControlNetModelConfig>('ControlNetModel'),
|
|
||||||
transformResponse: buildTransformResponse<ControlNetModelConfig>(controlNetModelsAdapter),
|
|
||||||
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
|
|
||||||
queryFulfilled.then(({ data }) => {
|
|
||||||
upsertModelConfigs(data, dispatch);
|
|
||||||
});
|
|
||||||
},
|
|
||||||
}),
|
|
||||||
getIPAdapterModels: build.query<EntityState<IPAdapterModelConfig, string>, void>({
|
|
||||||
query: () => ({ url: buildModelsUrl(), params: { model_type: 'ip_adapter' } }),
|
|
||||||
providesTags: buildProvidesTags<IPAdapterModelConfig>('IPAdapterModel'),
|
|
||||||
transformResponse: buildTransformResponse<IPAdapterModelConfig>(ipAdapterModelsAdapter),
|
|
||||||
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
|
|
||||||
queryFulfilled.then(({ data }) => {
|
|
||||||
upsertModelConfigs(data, dispatch);
|
|
||||||
});
|
|
||||||
},
|
|
||||||
}),
|
|
||||||
getT2IAdapterModels: build.query<EntityState<T2IAdapterModelConfig, string>, void>({
|
|
||||||
query: () => ({ url: buildModelsUrl(), params: { model_type: 't2i_adapter' } }),
|
|
||||||
providesTags: buildProvidesTags<T2IAdapterModelConfig>('T2IAdapterModel'),
|
|
||||||
transformResponse: buildTransformResponse<T2IAdapterModelConfig>(t2iAdapterModelsAdapter),
|
|
||||||
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
|
|
||||||
queryFulfilled.then(({ data }) => {
|
|
||||||
upsertModelConfigs(data, dispatch);
|
|
||||||
});
|
|
||||||
},
|
|
||||||
}),
|
|
||||||
getVaeModels: build.query<EntityState<VAEModelConfig, string>, void>({
|
|
||||||
query: () => ({ url: buildModelsUrl(), params: { model_type: 'vae' } }),
|
|
||||||
providesTags: buildProvidesTags<VAEModelConfig>('VaeModel'),
|
|
||||||
transformResponse: buildTransformResponse<VAEModelConfig>(vaeModelsAdapter),
|
|
||||||
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
|
|
||||||
queryFulfilled.then(({ data }) => {
|
|
||||||
upsertModelConfigs(data, dispatch);
|
|
||||||
});
|
|
||||||
},
|
|
||||||
}),
|
|
||||||
getTextualInversionModels: build.query<EntityState<TextualInversionModelConfig, string>, void>({
|
|
||||||
query: () => ({ url: buildModelsUrl(), params: { model_type: 'embedding' } }),
|
|
||||||
providesTags: buildProvidesTags<TextualInversionModelConfig>('TextualInversionModel'),
|
|
||||||
transformResponse: buildTransformResponse<TextualInversionModelConfig>(textualInversionModelsAdapter),
|
|
||||||
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
|
|
||||||
queryFulfilled.then(({ data }) => {
|
|
||||||
upsertModelConfigs(data, dispatch);
|
|
||||||
});
|
});
|
||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
@ -365,14 +263,8 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
});
|
});
|
||||||
|
|
||||||
export const {
|
export const {
|
||||||
|
useGetModelConfigsQuery,
|
||||||
useGetModelConfigQuery,
|
useGetModelConfigQuery,
|
||||||
useGetMainModelsQuery,
|
|
||||||
useGetControlNetModelsQuery,
|
|
||||||
useGetIPAdapterModelsQuery,
|
|
||||||
useGetT2IAdapterModelsQuery,
|
|
||||||
useGetLoRAModelsQuery,
|
|
||||||
useGetTextualInversionModelsQuery,
|
|
||||||
useGetVaeModelsQuery,
|
|
||||||
useDeleteModelsMutation,
|
useDeleteModelsMutation,
|
||||||
useDeleteModelImageMutation,
|
useDeleteModelImageMutation,
|
||||||
useUpdateModelMutation,
|
useUpdateModelMutation,
|
||||||
@ -381,131 +273,8 @@ export const {
|
|||||||
useConvertModelMutation,
|
useConvertModelMutation,
|
||||||
useSyncModelsMutation,
|
useSyncModelsMutation,
|
||||||
useLazyScanFolderQuery,
|
useLazyScanFolderQuery,
|
||||||
|
useLazyGetHuggingFaceModelsQuery,
|
||||||
useListModelInstallsQuery,
|
useListModelInstallsQuery,
|
||||||
useCancelModelInstallMutation,
|
useCancelModelInstallMutation,
|
||||||
usePruneCompletedModelInstallsMutation,
|
usePruneCompletedModelInstallsMutation,
|
||||||
} = modelsApi;
|
} = modelsApi;
|
||||||
|
|
||||||
const upsertModelConfigs = (
|
|
||||||
modelConfigs: EntityState<AnyModelConfig, string>,
|
|
||||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
|
||||||
dispatch: ThunkDispatch<any, any, UnknownAction>
|
|
||||||
) => {
|
|
||||||
/**
|
|
||||||
* Once a list of models of a specific type is received, fetching any of those models individually is a waste of a
|
|
||||||
* network request. This function takes the received list of models and upserts them into the individual query caches
|
|
||||||
* for each model type.
|
|
||||||
*/
|
|
||||||
|
|
||||||
// Iterate over all the models and upsert them into the individual query caches for each model type.
|
|
||||||
anyModelConfigAdapterSelectors.selectAll(modelConfigs).forEach((modelConfig) => {
|
|
||||||
dispatch(modelsApi.util.upsertQueryData('getModelConfig', modelConfig.key, modelConfig));
|
|
||||||
const { base, name, type } = modelConfig;
|
|
||||||
dispatch(modelsApi.util.upsertQueryData('getModelConfigByAttrs', { base, name, type }, modelConfig));
|
|
||||||
});
|
|
||||||
};
|
|
||||||
|
|
||||||
const upsertSingleModelConfig = (
|
|
||||||
modelConfig: AnyModelConfig,
|
|
||||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
|
||||||
dispatch: ThunkDispatch<any, any, UnknownAction>
|
|
||||||
) => {
|
|
||||||
/**
|
|
||||||
* When a model is updated, the individual query caches for each model type need to be updated, as well as the list
|
|
||||||
* query caches of models of that type.
|
|
||||||
*/
|
|
||||||
|
|
||||||
// Update the individual model query caches.
|
|
||||||
dispatch(modelsApi.util.upsertQueryData('getModelConfig', modelConfig.key, modelConfig));
|
|
||||||
const { base, name, type } = modelConfig;
|
|
||||||
dispatch(modelsApi.util.upsertQueryData('getModelConfigByAttrs', { base, name, type }, modelConfig));
|
|
||||||
|
|
||||||
// Update the list query caches for each model type.
|
|
||||||
if (modelConfig.type === 'main') {
|
|
||||||
[ALL_BASE_MODELS, NON_REFINER_BASE_MODELS, SDXL_MAIN_MODELS, NON_SDXL_MAIN_MODELS, REFINER_BASE_MODELS].forEach(
|
|
||||||
(queryArg) => {
|
|
||||||
dispatch(
|
|
||||||
modelsApi.util.updateQueryData('getMainModels', queryArg, (draft) => {
|
|
||||||
mainModelsAdapter.updateOne(draft, {
|
|
||||||
id: modelConfig.key,
|
|
||||||
changes: modelConfig,
|
|
||||||
});
|
|
||||||
})
|
|
||||||
);
|
|
||||||
}
|
|
||||||
);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (modelConfig.type === 'controlnet') {
|
|
||||||
dispatch(
|
|
||||||
modelsApi.util.updateQueryData('getControlNetModels', undefined, (draft) => {
|
|
||||||
controlNetModelsAdapter.updateOne(draft, {
|
|
||||||
id: modelConfig.key,
|
|
||||||
changes: modelConfig,
|
|
||||||
});
|
|
||||||
})
|
|
||||||
);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (modelConfig.type === 'embedding') {
|
|
||||||
dispatch(
|
|
||||||
modelsApi.util.updateQueryData('getTextualInversionModels', undefined, (draft) => {
|
|
||||||
textualInversionModelsAdapter.updateOne(draft, {
|
|
||||||
id: modelConfig.key,
|
|
||||||
changes: modelConfig,
|
|
||||||
});
|
|
||||||
})
|
|
||||||
);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (modelConfig.type === 'ip_adapter') {
|
|
||||||
dispatch(
|
|
||||||
modelsApi.util.updateQueryData('getIPAdapterModels', undefined, (draft) => {
|
|
||||||
ipAdapterModelsAdapter.updateOne(draft, {
|
|
||||||
id: modelConfig.key,
|
|
||||||
changes: modelConfig,
|
|
||||||
});
|
|
||||||
})
|
|
||||||
);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (modelConfig.type === 'lora') {
|
|
||||||
dispatch(
|
|
||||||
modelsApi.util.updateQueryData('getLoRAModels', undefined, (draft) => {
|
|
||||||
loraModelsAdapter.updateOne(draft, {
|
|
||||||
id: modelConfig.key,
|
|
||||||
changes: modelConfig,
|
|
||||||
});
|
|
||||||
})
|
|
||||||
);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (modelConfig.type === 't2i_adapter') {
|
|
||||||
dispatch(
|
|
||||||
modelsApi.util.updateQueryData('getT2IAdapterModels', undefined, (draft) => {
|
|
||||||
t2iAdapterModelsAdapter.updateOne(draft, {
|
|
||||||
id: modelConfig.key,
|
|
||||||
changes: modelConfig,
|
|
||||||
});
|
|
||||||
})
|
|
||||||
);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (modelConfig.type === 'vae') {
|
|
||||||
dispatch(
|
|
||||||
modelsApi.util.updateQueryData('getVaeModels', undefined, (draft) => {
|
|
||||||
vaeModelsAdapter.updateOne(draft, {
|
|
||||||
id: modelConfig.key,
|
|
||||||
changes: modelConfig,
|
|
||||||
});
|
|
||||||
})
|
|
||||||
);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
42
invokeai/frontend/web/src/services/api/hooks/modelsByType.ts
Normal file
42
invokeai/frontend/web/src/services/api/hooks/modelsByType.ts
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
import { EMPTY_ARRAY } from 'app/store/constants';
|
||||||
|
import { useMemo } from 'react';
|
||||||
|
import { modelConfigsAdapterSelectors, useGetModelConfigsQuery } from 'services/api/endpoints/models';
|
||||||
|
import type { AnyModelConfig } from 'services/api/types';
|
||||||
|
import {
|
||||||
|
isControlNetModelConfig,
|
||||||
|
isIPAdapterModelConfig,
|
||||||
|
isLoRAModelConfig,
|
||||||
|
isNonRefinerMainModelConfig,
|
||||||
|
isNonSDXLMainModelConfig,
|
||||||
|
isRefinerMainModelModelConfig,
|
||||||
|
isSDXLMainModelModelConfig,
|
||||||
|
isT2IAdapterModelConfig,
|
||||||
|
isTIModelConfig,
|
||||||
|
isVAEModelConfig,
|
||||||
|
} from 'services/api/types';
|
||||||
|
|
||||||
|
const buildModelsHook =
|
||||||
|
<T extends AnyModelConfig>(typeGuard: (config: AnyModelConfig) => config is T) =>
|
||||||
|
() => {
|
||||||
|
const result = useGetModelConfigsQuery(undefined);
|
||||||
|
const modelConfigs = useMemo(() => {
|
||||||
|
if (!result.data) {
|
||||||
|
return EMPTY_ARRAY;
|
||||||
|
}
|
||||||
|
|
||||||
|
return modelConfigsAdapterSelectors.selectAll(result.data).filter(typeGuard);
|
||||||
|
}, [result]);
|
||||||
|
|
||||||
|
return [modelConfigs, result] as const;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const useMainModels = buildModelsHook(isNonRefinerMainModelConfig);
|
||||||
|
export const useNonSDXLMainModels = buildModelsHook(isNonSDXLMainModelConfig);
|
||||||
|
export const useRefinerModels = buildModelsHook(isRefinerMainModelModelConfig);
|
||||||
|
export const useSDXLModels = buildModelsHook(isSDXLMainModelModelConfig);
|
||||||
|
export const useLoRAModels = buildModelsHook(isLoRAModelConfig);
|
||||||
|
export const useControlNetModels = buildModelsHook(isControlNetModelConfig);
|
||||||
|
export const useT2IAdapterModels = buildModelsHook(isT2IAdapterModelConfig);
|
||||||
|
export const useIPAdapterModels = buildModelsHook(isIPAdapterModelConfig);
|
||||||
|
export const useEmbeddingModels = buildModelsHook(isTIModelConfig);
|
||||||
|
export const useVAEModels = buildModelsHook(isVAEModelConfig);
|
@ -1,12 +1,7 @@
|
|||||||
import { REFINER_BASE_MODELS } from 'services/api/constants';
|
import { useRefinerModels } from 'services/api/hooks/modelsByType';
|
||||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
|
||||||
|
|
||||||
export const useIsRefinerAvailable = () => {
|
export const useIsRefinerAvailable = () => {
|
||||||
const { isRefinerAvailable } = useGetMainModelsQuery(REFINER_BASE_MODELS, {
|
const [refinerModels] = useRefinerModels();
|
||||||
selectFromResult: ({ data }) => ({
|
|
||||||
isRefinerAvailable: data ? data.ids.length > 0 : false,
|
|
||||||
}),
|
|
||||||
});
|
|
||||||
|
|
||||||
return isRefinerAvailable;
|
return Boolean(refinerModels.length);
|
||||||
};
|
};
|
||||||
|
File diff suppressed because one or more lines are too long
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user