Compare commits

..

5 Commits

253 changed files with 6711 additions and 6797 deletions

View File

@ -1,7 +1,7 @@
# Runs frontend code quality checks.
#
# Checks for changes to frontend files before running the checks.
# If always_run is true, always runs the checks.
# When manually triggered or when called from another workflow, always runs the checks.
name: 'frontend checks'
@ -16,19 +16,7 @@ on:
- 'synchronize'
merge_group:
workflow_dispatch:
inputs:
always_run:
description: 'Always run the checks'
required: true
type: boolean
default: true
workflow_call:
inputs:
always_run:
description: 'Always run the checks'
required: true
type: boolean
default: true
defaults:
run:
@ -42,7 +30,7 @@ jobs:
- uses: actions/checkout@v4
- name: check for changed frontend files
if: ${{ inputs.always_run != true }}
if: ${{ github.event_name != 'workflow_dispatch' && github.event_name != 'workflow_call' }}
id: changed-files
uses: tj-actions/changed-files@v42
with:
@ -51,30 +39,30 @@ jobs:
- 'invokeai/frontend/web/**'
- name: install dependencies
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || inputs.always_run == true }}
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
uses: ./.github/actions/install-frontend-deps
- name: tsc
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || inputs.always_run == true }}
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
run: 'pnpm lint:tsc'
shell: bash
- name: dpdm
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || inputs.always_run == true }}
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
run: 'pnpm lint:dpdm'
shell: bash
- name: eslint
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || inputs.always_run == true }}
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
run: 'pnpm lint:eslint'
shell: bash
- name: prettier
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || inputs.always_run == true }}
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
run: 'pnpm lint:prettier'
shell: bash
- name: knip
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || inputs.always_run == true }}
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
run: 'pnpm lint:knip'
shell: bash

View File

@ -1,7 +1,7 @@
# Runs frontend tests.
#
# Checks for changes to frontend files before running the tests.
# If always_run is true, always runs the tests.
# When manually triggered or called from another workflow, always runs the tests.
name: 'frontend tests'
@ -16,19 +16,7 @@ on:
- 'synchronize'
merge_group:
workflow_dispatch:
inputs:
always_run:
description: 'Always run the tests'
required: true
type: boolean
default: true
workflow_call:
inputs:
always_run:
description: 'Always run the tests'
required: true
type: boolean
default: true
defaults:
run:
@ -42,7 +30,7 @@ jobs:
- uses: actions/checkout@v4
- name: check for changed frontend files
if: ${{ inputs.always_run != true }}
if: ${{ github.event_name != 'workflow_dispatch' && github.event_name != 'workflow_call' }}
id: changed-files
uses: tj-actions/changed-files@v42
with:
@ -51,10 +39,10 @@ jobs:
- 'invokeai/frontend/web/**'
- name: install dependencies
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || inputs.always_run == true }}
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
uses: ./.github/actions/install-frontend-deps
- name: vitest
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || inputs.always_run == true }}
if: ${{ steps.changed-files.outputs.frontend_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
run: 'pnpm test:no-watch'
shell: bash

View File

@ -1,7 +1,7 @@
# Runs python code quality checks.
#
# Checks for changes to python files before running the checks.
# If always_run is true, always runs the checks.
# When manually triggered or called from another workflow, always runs the tests.
#
# TODO: Add mypy or pyright to the checks.
@ -18,19 +18,7 @@ on:
- 'synchronize'
merge_group:
workflow_dispatch:
inputs:
always_run:
description: 'Always run the checks'
required: true
type: boolean
default: true
workflow_call:
inputs:
always_run:
description: 'Always run the checks'
required: true
type: boolean
default: true
jobs:
python-checks:
@ -41,7 +29,7 @@ jobs:
uses: actions/checkout@v4
- name: check for changed python files
if: ${{ inputs.always_run != true }}
if: ${{ github.event_name != 'workflow_dispatch' && github.event_name != 'workflow_call' }}
id: changed-files
uses: tj-actions/changed-files@v42
with:
@ -53,7 +41,7 @@ jobs:
- 'tests/**'
- name: setup python
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
uses: actions/setup-python@v5
with:
python-version: '3.10'
@ -61,16 +49,16 @@ jobs:
cache-dependency-path: pyproject.toml
- name: install ruff
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
run: pip install ruff
shell: bash
- name: ruff check
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
run: ruff check --output-format=github .
shell: bash
- name: ruff format
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
run: ruff format --check .
shell: bash

View File

@ -1,7 +1,7 @@
# Runs python tests on a matrix of python versions and platforms.
#
# Checks for changes to python files before running the tests.
# If always_run is true, always runs the tests.
# When manually triggered or called from another workflow, always runs the tests.
name: 'python tests'
@ -9,7 +9,6 @@ on:
push:
branches:
- 'main'
- 'bug-install-job-running-multiple-times'
pull_request:
types:
- 'ready_for_review'
@ -17,19 +16,7 @@ on:
- 'synchronize'
merge_group:
workflow_dispatch:
inputs:
always_run:
description: 'Always run the tests'
required: true
type: boolean
default: true
workflow_call:
inputs:
always_run:
description: 'Always run the tests'
required: true
type: boolean
default: true
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
@ -76,7 +63,7 @@ jobs:
uses: actions/checkout@v4
- name: check for changed python files
if: ${{ inputs.always_run != true }}
if: ${{ github.event_name != 'workflow_dispatch' && github.event_name != 'workflow_call' }}
id: changed-files
uses: tj-actions/changed-files@v42
with:
@ -88,7 +75,7 @@ jobs:
- 'tests/**'
- name: setup python
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
@ -96,12 +83,12 @@ jobs:
cache-dependency-path: pyproject.toml
- name: install dependencies
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
env:
PIP_EXTRA_INDEX_URL: ${{ matrix.extra-index-url }}
run: >
pip3 install --editable=".[test]"
- name: run pytest
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || github.event_name == 'workflow_dispatch' || github.event_name == 'workflow_call' }}
run: pytest

View File

@ -30,23 +30,15 @@ jobs:
frontend-checks:
uses: ./.github/workflows/frontend-checks.yml
with:
always_run: true
frontend-tests:
uses: ./.github/workflows/frontend-tests.yml
with:
always_run: true
python-checks:
uses: ./.github/workflows/python-checks.yml
with:
always_run: true
python-tests:
uses: ./.github/workflows/python-tests.yml
with:
always_run: true
build:
uses: ./.github/workflows/build-installer.yml
@ -66,8 +58,6 @@ jobs:
environment:
name: testpypi
url: https://test.pypi.org/p/invokeai
permissions:
id-token: write
steps:
- name: download distribution from build job
uses: actions/download-artifact@v4
@ -95,8 +85,6 @@ jobs:
environment:
name: pypi
url: https://pypi.org/p/invokeai
permissions:
id-token: write
steps:
- name: download distribution from build job
uses: actions/download-artifact@v4

View File

@ -6,18 +6,16 @@ default: help
help:
@echo Developer commands:
@echo
@echo "ruff Run ruff, fixing any safely-fixable errors and formatting"
@echo "ruff-unsafe Run ruff, fixing all fixable errors and formatting"
@echo "mypy Run mypy using the config in pyproject.toml to identify type mismatches and other coding errors"
@echo "mypy-all Run mypy ignoring the config in pyproject.tom but still ignoring missing imports"
@echo "test Run the unit tests."
@echo "update-config-docstring Update the app's config docstring so mkdocs can autogenerate it correctly."
@echo "frontend-install Install the pnpm modules needed for the front end"
@echo "frontend-build Build the frontend in order to run on localhost:9090"
@echo "frontend-dev Run the frontend in developer mode on localhost:5173"
@echo "frontend-typegen Generate types for the frontend from the OpenAPI schema"
@echo "installer-zip Build the installer .zip file for the current version"
@echo "tag-release Tag the GitHub repository with the current version (use at release time only!)"
@echo "ruff Run ruff, fixing any safely-fixable errors and formatting"
@echo "ruff-unsafe Run ruff, fixing all fixable errors and formatting"
@echo "mypy Run mypy using the config in pyproject.toml to identify type mismatches and other coding errors"
@echo "mypy-all Run mypy ignoring the config in pyproject.tom but still ignoring missing imports"
@echo "test" Run the unit tests.
@echo "frontend-install" Install the pnpm modules needed for the front end
@echo "frontend-build Build the frontend in order to run on localhost:9090"
@echo "frontend-dev Run the frontend in developer mode on localhost:5173"
@echo "installer-zip Build the installer .zip file for the current version"
@echo "tag-release Tag the GitHub repository with the current version (use at release time only!)"
# Runs ruff, fixing any safely-fixable errors and formatting
ruff:
@ -42,10 +40,6 @@ mypy-all:
test:
pytest ./tests
# Update config docstring
update-config-docstring:
python scripts/update_config_docstring.py
# Install the pnpm modules needed for the front end
frontend-install:
rm -rf invokeai/frontend/web/node_modules
@ -59,9 +53,6 @@ frontend-build:
frontend-dev:
cd invokeai/frontend/web && pnpm dev
frontend-typegen:
cd invokeai/frontend/web && python ../../../scripts/generate_openapi_schema.py | pnpm typegen
# Installer zip file
installer-zip:
cd installer && ./create_installer.sh

View File

@ -16,6 +16,11 @@ model. These are the:
information. It is also responsible for managing the InvokeAI
`models` directory and its contents.
* _ModelMetadataStore_ and _ModelMetaDataFetch_ Backend modules that
are able to retrieve metadata from online model repositories,
transform them into Pydantic models, and cache them to the InvokeAI
SQL database.
* _DownloadQueueServiceBase_
A multithreaded downloader responsible
for downloading models from a remote source to disk. The download
@ -378,13 +383,16 @@ functionality:
* Downloading a model from an arbitrary URL and installing it in
`models_dir`.
* Special handling for Civitai model URLs which allow the user to
paste in a model page's URL or download link
* Special handling for HuggingFace repo_ids to recursively download
the contents of the repository, paying attention to alternative
variants such as fp16.
* Saving tags and other metadata about the model into the invokeai database
when fetching from a repo that provides that type of information,
(currently only HuggingFace).
(currently only Civitai and HuggingFace).
### Initializing the installer
@ -428,6 +436,7 @@ required parameters:
| `app_config` | InvokeAIAppConfig | InvokeAI app configuration object |
| `record_store` | ModelRecordServiceBase | Config record storage database |
| `download_queue` | DownloadQueueServiceBase | Download queue object |
| `metadata_store` | Optional[ModelMetadataStore] | Metadata storage object |
|`session` | Optional[requests.Session] | Swap in a different Session object (usually for debugging) |
Once initialized, the installer will provide the following methods:
@ -571,7 +580,33 @@ The `AnyHttpUrl` class can be imported from `pydantic.networks`.
Ordinarily, no metadata is retrieved from these sources. However,
there is special-case code in the installer that looks for HuggingFace
and fetches the corresponding model metadata from the corresponding repo.
and Civitai URLs and fetches the corresponding model metadata from
the corresponding repo.
#### CivitaiModelSource
This is used for a model that is hosted by the Civitai web site.
| **Argument** | **Type** | **Default** | **Description** |
|------------------|------------------------------|-------------|-------------------------------------------|
| `version_id` | int | None | The ID of the particular version of the desired model. |
| `access_token` | str | None | An access token needed to gain access to a subscriber's-only model. |
Civitai has two model IDs, both of which are integers. The `model_id`
corresponds to a collection of model versions that may different in
arbitrary ways, such as derivation from different checkpoint training
steps, SFW vs NSFW generation, pruned vs non-pruned, etc. The
`version_id` points to a specific version. Please use the latter.
Some Civitai models require an access token to download. These can be
generated from the Civitai profile page of a logged-in
account. Somewhat annoyingly, if you fail to provide the access token
when downloading a model that needs it, Civitai generates a redirect
to a login page rather than a 403 Forbidden error. The installer
attempts to catch this event and issue an informative error
message. Otherwise you will get an "unrecognized model suffix" error
when the model prober tries to identify the type of the HTML login
page.
#### HFModelSource
@ -1218,9 +1253,9 @@ queue and have not yet reached a terminal state.
The modules found under `invokeai.backend.model_manager.metadata`
provide a straightforward API for fetching model metadatda from online
repositories. Currently only HuggingFace is supported. However, the
modules are easily extended for additional repos, provided that they
have defined APIs for metadata access.
repositories. Currently two repositories are supported: HuggingFace
and Civitai. However, the modules are easily extended for additional
repos, provided that they have defined APIs for metadata access.
Metadata comprises any descriptive information that is not essential
for getting the model to run. For example "author" is metadata, while
@ -1232,16 +1267,37 @@ model's config, as defined in `invokeai.backend.model_manager.config`.
```
from invokeai.backend.model_manager.metadata import (
AnyModelRepoMetadata,
CivitaiMetadataFetch,
CivitaiMetadata
ModelMetadataStore,
)
# to access the initialized sql database
from invokeai.app.api.dependencies import ApiDependencies
hf = HuggingFaceMetadataFetch()
civitai = CivitaiMetadataFetch()
# fetch the metadata
model_metadata = hf.from_id("<repo_id>")
model_metadata = civitai.from_url("https://civitai.com/models/215796")
assert isinstance(model_metadata, HuggingFaceMetadata)
# get some common metadata fields
author = model_metadata.author
tags = model_metadata.tags
# get some Civitai-specific fields
assert isinstance(model_metadata, CivitaiMetadata)
trained_words = model_metadata.trained_words
base_model = model_metadata.base_model_trained_on
thumbnail = model_metadata.thumbnail_url
# cache the metadata to the database using the key corresponding to
# an existing model config record in the `model_config` table
sql_cache = ModelMetadataStore(ApiDependencies.invoker.services.db)
sql_cache.add_metadata('fb237ace520b6716adc98bcb16e8462c', model_metadata)
# now we can search the database by tag, author or model name
# matches will contain a list of model keys that match the search
matches = sql_cache.search_by_tag({"tool", "turbo"})
```
### Structure of the Metadata objects
@ -1278,14 +1334,52 @@ This descends from `ModelMetadataBase` and adds the following fields:
| `last_modified`| datetime | Date of last commit of this model to the repo |
| `files` | List[Path] | List of the files in the model repo |
#### `CivitaiMetadata`
This descends from `ModelMetadataBase` and adds the following fields:
| **Field Name** | **Type** | **Description** |
|----------------|-----------------|------------------|
| `type` | Literal["civitai"] | Used for the discriminated union of metadata classes|
| `id` | int | Civitai model id |
| `version_name` | str | Name of this version of the model (distinct from model name) |
| `version_id` | int | Civitai model version id (distinct from model id) |
| `created` | datetime | Date this version of the model was created |
| `updated` | datetime | Date this version of the model was last updated |
| `published` | datetime | Date this version of the model was published to Civitai |
| `description` | str | Model description. Quite verbose and contains HTML tags |
| `version_description` | str | Model version description, usually describes changes to the model |
| `nsfw` | bool | Whether the model tends to generate NSFW content |
| `restrictions` | LicenseRestrictions | An object that describes what is and isn't allowed with this model |
| `trained_words`| Set[str] | Trigger words for this model, if any |
| `download_url` | AnyHttpUrl | URL for downloading this version of the model |
| `base_model_trained_on` | str | Name of the model that this version was trained on |
| `thumbnail_url` | AnyHttpUrl | URL to access a representative thumbnail image of the model's output |
| `weight_min` | int | For LoRA sliders, the minimum suggested weight to apply |
| `weight_max` | int | For LoRA sliders, the maximum suggested weight to apply |
Note that `weight_min` and `weight_max` are not currently populated
and take the default values of (-1.0, +2.0). The issue is that these
values aren't part of the structured data but appear in the text
description. Some regular expression or LLM coding may be able to
extract these values.
Also be aware that `base_model_trained_on` is free text and doesn't
correspond to our `ModelType` enum.
`CivitaiMetadata` also defines some convenience properties relating to
licensing restrictions: `credit_required`, `allow_commercial_use`,
`allow_derivatives` and `allow_different_license`.
#### `AnyModelRepoMetadata`
This is a discriminated Union of `HuggingFaceMetadata`.
This is a discriminated Union of `CivitaiMetadata` and
`HuggingFaceMetadata`.
### Fetching Metadata from Online Repos
The `HuggingFaceMetadataFetch` class will
retrieve metadata from its corresponding repository and return
The `HuggingFaceMetadataFetch` and `CivitaiMetadataFetch` classes will
retrieve metadata from their corresponding repositories and return
`AnyModelRepoMetadata` objects. Their base class
`ModelMetadataFetchBase` is an abstract class that defines two
methods: `from_url()` and `from_id()`. The former accepts the type of
@ -1303,17 +1397,96 @@ provide a `requests.Session` argument. This allows you to customize
the low-level HTTP fetch requests and is used, for instance, in the
testing suite to avoid hitting the internet.
The HuggingFace fetcher subclass add additional repo-specific fetching methods:
The HuggingFace and Civitai fetcher subclasses add additional
repo-specific fetching methods:
#### HuggingFaceMetadataFetch
This overrides its base class `from_json()` method to return a
`HuggingFaceMetadata` object directly.
#### CivitaiMetadataFetch
This adds the following methods:
`from_civitai_modelid()` This takes the ID of a model, finds the
default version of the model, and then retrieves the metadata for
that version, returning a `CivitaiMetadata` object directly.
`from_civitai_versionid()` This takes the ID of a model version and
retrieves its metadata. Functionally equivalent to `from_id()`, the
only difference is that it returna a `CivitaiMetadata` object rather
than an `AnyModelRepoMetadata`.
### Metadata Storage
The `ModelConfigBase` stores this response in the `source_api_response` field
as a JSON blob.
The `ModelMetadataStore` provides a simple facility to store model
metadata in the `invokeai.db` database. The data is stored as a JSON
blob, with a few common fields (`name`, `author`, `tags`) broken out
to be searchable.
When a metadata object is saved to the database, it is identified
using the model key, _and this key must correspond to an existing
model key in the model_config table_. There is a foreign key integrity
constraint between the `model_config.id` field and the
`model_metadata.id` field such that if you attempt to save metadata
under an unknown key, the attempt will result in an
`UnknownModelException`. Likewise, when a model is deleted from
`model_config`, the deletion of the corresponding metadata record will
be triggered.
Tags are stored in a normalized fashion in the tables `model_tags` and
`tags`. Triggers keep the tag table in sync with the `model_metadata`
table.
To create the storage object, initialize it with the InvokeAI
`SqliteDatabase` object. This is often done this way:
```
from invokeai.app.api.dependencies import ApiDependencies
metadata_store = ModelMetadataStore(ApiDependencies.invoker.services.db)
```
You can then access the storage with the following methods:
#### `add_metadata(key, metadata)`
Add the metadata using a previously-defined model key.
There is currently no `delete_metadata()` method. The metadata will
persist until the matching config is deleted from the `model_config`
table.
#### `get_metadata(key) -> AnyModelRepoMetadata`
Retrieve the metadata corresponding to the model key.
#### `update_metadata(key, new_metadata)`
Update an existing metadata record with new metadata.
#### `search_by_tag(tags: Set[str]) -> Set[str]`
Given a set of tags, find models that are tagged with them. If
multiple tags are provided then a matching model must be tagged with
*all* the tags in the set. This method returns a set of model keys and
is intended to be used in conjunction with the `ModelRecordService`:
```
model_config_store = ApiDependencies.invoker.services.model_records
matches = metadata_store.search_by_tag({'license:other'})
models = [model_config_store.get(x) for x in matches]
```
#### `search_by_name(name: str) -> Set[str]
Find all model metadata records that have the given name and return a
set of keys to the corresponding model config objects.
#### `search_by_author(author: str) -> Set[str]
Find all model metadata records that have the given author and return
a set of keys to the corresponding model config objects.
***

View File

@ -1,133 +0,0 @@
# Invoke UI
Invoke's UI is made possible by many contributors and open-source libraries. Thank you!
## Dev environment
### Setup
1. Install [node] and [pnpm].
1. Run `pnpm i` to install all packages.
#### Run in dev mode
1. From `invokeai/frontend/web/`, run `pnpm dev`.
1. From repo root, run `python scripts/invokeai-web.py`.
1. Point your browser to the dev server address, e.g. <http://localhost:5173/>
### Package scripts
- `dev`: run the frontend in dev mode, enabling hot reloading
- `build`: run all checks (madge, eslint, prettier, tsc) and then build the frontend
- `typegen`: generate types from the OpenAPI schema (see [Type generation])
- `lint:dpdm`: check circular dependencies
- `lint:eslint`: check code quality
- `lint:prettier`: check code formatting
- `lint:tsc`: check type issues
- `lint:knip`: check for unused exports or objects (failures here are just suggestions, not hard fails)
- `lint`: run all checks concurrently
- `fix`: run `eslint` and `prettier`, fixing fixable issues
### Type generation
We use [openapi-typescript] to generate types from the app's OpenAPI schema.
The generated types are committed to the repo in [schema.ts].
```sh
# from the repo root, start the server
python scripts/invokeai-web.py
# from invokeai/frontend/web/, run the script
pnpm typegen
```
### Localization
We use [i18next] for localization, but translation to languages other than English happens on our [Weblate] project.
Only the English source strings should be changed on this repo.
### VSCode
#### Example debugger config
```jsonc
{
"version": "0.2.0",
"configurations": [
{
"type": "chrome",
"request": "launch",
"name": "Invoke UI",
"url": "http://localhost:5173",
"webRoot": "${workspaceFolder}/invokeai/frontend/web"
}
]
}
```
#### Remote dev
We've noticed an intermittent timeout issue with the VSCode remote dev port forwarding.
We suggest disabling the editor's port forwarding feature and doing it manually via SSH:
```sh
ssh -L 9090:localhost:9090 -L 5173:localhost:5173 user@host
```
## Contributing Guidelines
Thanks for your interest in contributing to the Invoke Web UI!
Please follow these guidelines when contributing.
### Check in before investing your time
Please check in before you invest your time on anything besides a trivial fix, in case it conflicts with ongoing work or isn't aligned with the vision for the app.
If a feature request or issue doesn't already exist for the thing you want to work on, please create one.
Ping `@psychedelicious` on [discord] in the `#frontend-dev` channel or in the feature request / issue you want to work on - we're happy to chat.
### Code conventions
- This is a fairly complex app with a deep component tree. Please use memoization (`useCallback`, `useMemo`, `memo`) with enthusiasm.
- If you need to add some global, ephemeral state, please use [nanostores] if possible.
- Be careful with your redux selectors. If they need to be parameterized, consider creating them inside a `useMemo`.
- Feel free to use `lodash` (via `lodash-es`) to make the intent of your code clear.
- Please add comments describing the "why", not the "how" (unless it is really arcane).
### Commit format
Please use the [conventional commits] spec for the web UI, with a scope of "ui":
- `chore(ui): bump deps`
- `chore(ui): lint`
- `feat(ui): add some cool new feature`
- `fix(ui): fix some bug`
### Submitting a PR
- Ensure your branch is tidy. Use an interactive rebase to clean up the commit history and reword the commit messages if they are not descriptive.
- Run `pnpm lint`. Some issues are auto-fixable with `pnpm fix`.
- Fill out the PR form when creating the PR.
- It doesn't need to be super detailed, but a screenshot or video is nice if you changed something visually.
- If a section isn't relevant, delete it. There are no UI tests at this time.
## Other docs
- [Workflows - Design and Implementation]
- [State Management]
[node]: https://nodejs.org/en/download/
[pnpm]: https://github.com/pnpm/pnpm
[discord]: https://discord.gg/ZmtBAhwWhy
[i18next]: https://github.com/i18next/react-i18next
[Weblate]: https://hosted.weblate.org/engage/invokeai/
[openapi-typescript]: https://github.com/drwpow/openapi-typescript
[Type generation]: #type-generation
[schema.ts]: https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/frontend/web/src/services/api/schema.ts
[conventional commits]: https://www.conventionalcommits.org/en/v1.0.0/
[Workflows - Design and Implementation]: ./WORKFLOWS.md
[State Management]: ./STATE_MGMT.md

View File

@ -31,18 +31,18 @@ be referred to as ROOT.
To find its root directory, InvokeAI uses the following recipe:
1. It first looks for the argument `--root <path>` on the command line
it was launched from, and uses the indicated path if present.
it was launched from, and uses the indicated path if present.
2. Next it looks for the environment variable INVOKEAI_ROOT, and uses
the directory path found there if present.
the directory path found there if present.
3. If neither of these are present, then InvokeAI looks for the
folder containing the `.venv` Python virtual environment directory for
the currently active environment. This directory is checked for files
expected inside the InvokeAI root before it is used.
folder containing the `.venv` Python virtual environment directory for
the currently active environment. This directory is checked for files
expected inside the InvokeAI root before it is used.
4. Finally, InvokeAI looks for a directory in the current user's home
directory named `invokeai`.
directory named `invokeai`.
#### Reading the InvokeAI Configuration File
@ -149,75 +149,104 @@ usage: InvokeAI [-h] [--host HOST] [--port PORT] [--allow_origins [ALLOW_ORIGINS
## The Configuration Settings
The config is managed by the `InvokeAIAppConfig` class, which is a pydantic model. The below docs are autogenerated from the class.
The configuration settings are divided into several distinct
groups in `invokeia.yaml`:
When editing your `invokeai.yaml` file, you'll need to put settings under their appropriate group. The group for each setting is denoted in the table below.
### Web Server
Following the table are additional explanations for certain settings.
| Setting | Default Value | Description |
|---------------------|---------------|----------------------------------------------------------------------------------------------------------------------------|
| `host` | `localhost` | Name or IP address of the network interface that the web server will listen on |
| `port` | `9090` | Network port number that the web server will listen on |
| `allow_origins` | `[]` | A list of host names or IP addresses that are allowed to connect to the InvokeAI API in the format `['host1','host2',...]` |
| `allow_credentials` | `true` | Require credentials for a foreign host to access the InvokeAI API (don't change this) |
| `allow_methods` | `*` | List of HTTP methods ("GET", "POST") that the web server is allowed to use when accessing the API |
| `allow_headers` | `*` | List of HTTP headers that the web server will accept when accessing the API |
| `ssl_certfile` | null | Path to an SSL certificate file, used to enable HTTPS. |
| `ssl_keyfile` | null | Path to an SSL keyfile, if the key is not included in the certificate file. |
<!-- prettier-ignore-start -->
::: invokeai.app.services.config.config_default.InvokeAIAppConfig
options:
heading_level: 3
members: false
<!-- prettier-ignore-end -->
The documentation for InvokeAI's API can be accessed by browsing to the following URL: [http://localhost:9090/docs].
### Model Marketplace API Keys
### Features
Some model marketplaces require an API key to download models. You can provide a URL pattern and appropriate token in your `invokeai.yaml` file to provide that API key.
These configuration settings allow you to enable and disable various InvokeAI features:
The pattern can be any valid regex (you may need to surround the pattern with quotes):
| Setting | Default Value | Description |
|----------|----------------|--------------|
| `esrgan` | `true` | Activate the ESRGAN upscaling options|
| `internet_available` | `true` | When a resource is not available locally, try to fetch it via the internet |
| `log_tokenization` | `false` | Before each text2image generation, print a color-coded representation of the prompt to the console; this can help understand why a prompt is not working as expected |
| `patchmatch` | `true` | Activate the "patchmatch" algorithm for improved inpainting |
```yaml
InvokeAI:
Model Install:
remote_api_tokens:
# Any URL containing `models.com` will automatically use `your_models_com_token`
- url_regex: models.com
token: your_models_com_token
# Any URL matching this contrived regex will use `some_other_token`
- url_regex: '^[a-z]{3}whatever.*\.com$'
token: some_other_token
```
### Generation
The provided token will be added as a `Bearer` token to the network requests to download the model files. As far as we know, this works for all model marketplaces that require authorization.
These options tune InvokeAI's memory and performance characteristics.
### Model Hashing
| Setting | Default Value | Description |
|-----------------------|---------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `sequential_guidance` | `false` | Calculate guidance in serial rather than in parallel, lowering memory requirements at the cost of some performance loss |
| `attention_type` | `auto` | Select the type of attention to use. One of `auto`,`normal`,`xformers`,`sliced`, or `torch-sdp` |
| `attention_slice_size` | `auto` | When "sliced" attention is selected, set the slice size. One of `auto`, `balanced`, `max` or the integers 1-8|
| `force_tiled_decode` | `false` | Force the VAE step to decode in tiles, reducing memory consumption at the cost of performance |
Models are hashed during installation, providing a stable identifier for models across all platforms. The default algorithm is `blake3`, with a multi-threaded implementation.
### Device
If your models are stored on a spinning hard drive, we suggest using `blake3_single`, the single-threaded implementation. The hashes are the same, but it's much faster on spinning disks.
These options configure the generation execution device.
```yaml
InvokeAI:
Model Install:
hashing_algorithm: blake3_single
```
| Setting | Default Value | Description |
|-----------------------|---------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `device` | `auto` | Preferred execution device. One of `auto`, `cpu`, `cuda`, `cuda:1`, `mps`. `auto` will choose the device depending on the hardware platform and the installed torch capabilities. |
| `precision` | `auto` | Floating point precision. One of `auto`, `float16` or `float32`. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system |
Model hashing is a one-time operation, but it may take a couple minutes to hash a large model collection. You may opt out of model hashing entirely by setting the algorithm to `random`.
```yaml
InvokeAI:
Model Install:
hashing_algorithm: random
```
Most common algorithms are supported, like `md5`, `sha256`, and `sha512`. These are typically much, much slower than `blake3`.
### Paths
These options set the paths of various directories and files used by
InvokeAI. Relative paths are interpreted relative to the root directory, so
if root is `/home/fred/invokeai` and the path is
InvokeAI. Relative paths are interpreted relative to INVOKEAI_ROOT, so
if INVOKEAI_ROOT is `/home/fred/invokeai` and the path is
`autoimport/main`, then the corresponding directory will be located at
`/home/fred/invokeai/autoimport/main`.
Note that the autoimport directory will be searched recursively,
| Setting | Default Value | Description |
|----------|----------------|--------------|
| `autoimport_dir` | `autoimport/main` | At startup time, read and import any main model files found in this directory |
| `lora_dir` | `autoimport/lora` | At startup time, read and import any LoRA/LyCORIS models found in this directory |
| `embedding_dir` | `autoimport/embedding` | At startup time, read and import any textual inversion (embedding) models found in this directory |
| `controlnet_dir` | `autoimport/controlnet` | At startup time, read and import any ControlNet models found in this directory |
| `conf_path` | `configs/models.yaml` | Location of the `models.yaml` model configuration file |
| `models_dir` | `models` | Location of the directory containing models installed by InvokeAI's model manager |
| `legacy_conf_dir` | `configs/stable-diffusion` | Location of the directory containing the .yaml configuration files for legacy checkpoint models |
| `db_dir` | `databases` | Location of the directory containing InvokeAI's image, schema and session database |
| `outdir` | `outputs` | Location of the directory in which the gallery of generated and uploaded images will be stored |
| `use_memory_db` | `false` | Keep database information in memory rather than on disk; this will not preserve image gallery information across restarts |
Note that the autoimport directories will be searched recursively,
allowing you to organize the models into folders and subfolders in any
way you wish.
way you wish. In addition, while we have split up autoimport
directories by the type of model they contain, this isn't
necessary. You can combine different model types in the same folder
and InvokeAI will figure out what they are. So you can easily use just
one autoimport directory by commenting out the unneeded paths:
```
Paths:
autoimport_dir: autoimport
# lora_dir: null
# embedding_dir: null
# controlnet_dir: null
```
### Logging
These settings control the information, warning, and debugging
messages printed to the console log while InvokeAI is running:
| Setting | Default Value | Description |
|----------|----------------|--------------|
| `log_handlers` | `console` | This controls where log messages are sent, and can be a list of one or more destinations. Values include `console`, `file`, `syslog` and `http`. These are described in more detail below |
| `log_format` | `color` | This controls the formatting of the log messages. Values are `plain`, `color`, `legacy` and `syslog` |
| `log_level` | `debug` | This filters messages according to the level of severity and can be one of `debug`, `info`, `warning`, `error` and `critical`. For example, setting to `warning` will display all messages at the warning level or higher, but won't display "debug" or "info" messages |
Several different log handler destinations are available, and multiple destinations are supported by providing a list:
```
@ -227,9 +256,9 @@ Several different log handler destinations are available, and multiple destinati
- file=/var/log/invokeai.log
```
- `console` is the default. It prints log messages to the command-line window from which InvokeAI was launched.
* `console` is the default. It prints log messages to the command-line window from which InvokeAI was launched.
- `syslog` is only available on Linux and Macintosh systems. It uses
* `syslog` is only available on Linux and Macintosh systems. It uses
the operating system's "syslog" facility to write log file entries
locally or to a remote logging machine. `syslog` offers a variety
of configuration options:
@ -242,7 +271,7 @@ Several different log handler destinations are available, and multiple destinati
- Log to LAN-connected server "fredserver" using the facility LOG_USER and datagram packets.
```
- `http` can be used to log to a remote web server. The server must be
* `http` can be used to log to a remote web server. The server must be
properly configured to receive and act on log messages. The option
accepts the URL to the web server, and a `method` argument
indicating whether the message should be submitted using the GET or
@ -254,7 +283,7 @@ Several different log handler destinations are available, and multiple destinati
The `log_format` option provides several alternative formats:
- `color` - default format providing time, date and a message, using text colors to distinguish different log severities
- `plain` - same as above, but monochrome text only
- `syslog` - the log level and error message only, allowing the syslog system to attach the time and date
- `legacy` - a format similar to the one used by the legacy 2.3 InvokeAI releases.
* `color` - default format providing time, date and a message, using text colors to distinguish different log severities
* `plain` - same as above, but monochrome text only
* `syslog` - the log level and error message only, allowing the syslog system to attach the time and date
* `legacy` - a format similar to the one used by the legacy 2.3 InvokeAI releases.

View File

@ -1,35 +0,0 @@
---
title: Database
---
# Invoke's SQLite Database
Invoke uses a SQLite database to store image, workflow, model, and execution data.
We take great care to ensure your data is safe, by utilizing transactions and a database migration system.
Even so, when testing an prerelease version of the app, we strongly suggest either backing up your database or using an in-memory database. This ensures any prelease hiccups or databases schema changes will not cause problems for your data.
## Database Backup
Backing up your database is very simple. Invoke's data is stored in an `$INVOKEAI_ROOT` directory - where your `invoke.sh`/`invoke.bat` and `invokeai.yaml` files live.
To back up your database, copy the `invokeai.db` file from `$INVOKEAI_ROOT/databases/invokeai.db` to somewhere safe.
If anything comes up during prelease testing, you can simply copy your backup back into `$INVOKEAI_ROOT/databases/`.
## In-Memory Database
SQLite can run on an in-memory database. Your existing database is untouched when this mode is enabled, but your existing data won't be accessible.
This is very useful for testing, as there is no chance of a database change modifying your "physical" database.
To run Invoke with a memory database, edit your `invokeai.yaml` file, and add `use_memory_db: true` to the `Paths:` stanza:
```yaml
InvokeAI:
Development:
use_memory_db: true
```
Delete this line (or set it to `false`) to use your main database.

View File

@ -22,24 +22,6 @@ class MyInvocation(BaseInvocation):
...
```
The full API is documented below.
## Invocation Mixins
Two important mixins are provided to facilitate working with metadata and gallery boards.
### `WithMetadata`
Inherit from this class (in addition to `BaseInvocation`) to add a `metadata` input to your node. When you do this, you can access the metadata dict from `self.metadata` in the `invoke()` function.
The dict will be populated via the node's input, and you can add any metadata you'd like to it. When you call `context.images.save()`, if the metadata dict has any data, it be automatically embedded in the image.
### `WithBoard`
Inherit from this class (in addition to `BaseInvocation`) to add a `board` input to your node. This renders as a drop-down to select a board. The user's selection will be accessible from `self.board` in the `invoke()` function.
When you call `context.images.save()`, if a board was selected, the image will added to that board as it is saved.
<!-- prettier-ignore-start -->
::: invokeai.app.services.shared.invocation_context.InvocationContext
options:

View File

@ -25,7 +25,6 @@ from ..services.invocation_cache.invocation_cache_memory import MemoryInvocation
from ..services.invocation_services import InvocationServices
from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
from ..services.invoker import Invoker
from ..services.model_images.model_images_default import ModelImageFileStorageDisk
from ..services.model_manager.model_manager_default import ModelManagerService
from ..services.model_records import ModelRecordServiceSQL
from ..services.names.names_default import SimpleNameService
@ -72,8 +71,6 @@ class ApiDependencies:
image_files = DiskImageFileStorage(f"{output_folder}/images")
model_images_folder = config.models_path
db = init_db(config=config, logger=logger, image_files=image_files)
configuration = config
@ -95,7 +92,6 @@ class ApiDependencies:
ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
)
download_queue_service = DownloadQueueService(event_bus=events)
model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images")
model_manager = ModelManagerService.build_model_manager(
app_config=configuration,
model_record_service=ModelRecordServiceSQL(db=db),
@ -122,7 +118,6 @@ class ApiDependencies:
images=images,
invocation_cache=invocation_cache,
logger=logger,
model_images=model_images_service,
model_manager=model_manager,
download_queue=download_queue_service,
names=names,

View File

@ -1,17 +1,13 @@
# Copyright (c) 2023 Lincoln D. Stein
"""FastAPI route for model configuration records."""
import io
import pathlib
import shutil
import traceback
from typing import Any, Dict, List, Optional
from fastapi import Body, Path, Query, Response, UploadFile
from fastapi.responses import FileResponse
from fastapi import Body, Path, Query, Response
from fastapi.routing import APIRouter
from PIL import Image
from pydantic import AnyHttpUrl, BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field
from starlette.exceptions import HTTPException
from typing_extensions import Annotated
@ -29,17 +25,12 @@ from invokeai.backend.model_manager.config import (
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
from invokeai.backend.model_manager.metadata.metadata_base import ModelMetadataWithFiles, UnknownMetadataException
from invokeai.backend.model_manager.search import ModelSearch
from ..dependencies import ApiDependencies
model_manager_router = APIRouter(prefix="/v2/models", tags=["model_manager"])
# images are immutable; set a high max-age
IMAGE_MAX_AGE = 31536000
class ModelsList(BaseModel):
"""Return list of configs."""
@ -114,9 +105,6 @@ async def list_model_records(
found_models.extend(
record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format)
)
for model in found_models:
cover_image = ApiDependencies.invoker.services.model_images.get_url(model.key)
model.cover_image = cover_image
return ModelsList(models=found_models)
@ -160,8 +148,6 @@ async def get_model_record(
record_store = ApiDependencies.invoker.services.model_manager.store
try:
config: AnyModelConfig = record_store.get_model(key)
cover_image = ApiDependencies.invoker.services.model_images.get_url(key)
config.cover_image = cover_image
return config
except UnknownModelException as e:
raise HTTPException(status_code=404, detail=str(e))
@ -248,40 +234,6 @@ async def scan_for_models(
return scan_results
class HuggingFaceModels(BaseModel):
urls: List[AnyHttpUrl] | None = Field(description="URLs for all checkpoint format models in the metadata")
is_diffusers: bool = Field(description="Whether the metadata is for a Diffusers format model")
@model_manager_router.get(
"/hugging_face",
operation_id="get_hugging_face_models",
responses={
200: {"description": "Hugging Face repo scanned successfully"},
400: {"description": "Invalid hugging face repo"},
},
status_code=200,
response_model=HuggingFaceModels,
)
async def get_hugging_face_models(
hugging_face_repo: str = Query(description="Hugging face repo to search for models", default=None),
) -> HuggingFaceModels:
try:
metadata = HuggingFaceMetadataFetch().from_id(hugging_face_repo)
except UnknownMetadataException:
raise HTTPException(
status_code=400,
detail="No HuggingFace repository found",
)
assert isinstance(metadata, ModelMetadataWithFiles)
return HuggingFaceModels(
urls=metadata.ckpt_urls,
is_diffusers=metadata.is_diffusers,
)
@model_manager_router.patch(
"/i/{key}",
operation_id="update_model_record",
@ -314,75 +266,6 @@ async def update_model_record(
return model_response
@model_manager_router.get(
"/i/{key}/image",
operation_id="get_model_image",
responses={
200: {
"description": "The model image was fetched successfully",
},
400: {"description": "Bad request"},
404: {"description": "The model image could not be found"},
},
status_code=200,
)
async def get_model_image(
key: str = Path(description="The name of model image file to get"),
) -> FileResponse:
"""Gets an image file that previews the model"""
try:
path = ApiDependencies.invoker.services.model_images.get_path(key)
response = FileResponse(
path,
media_type="image/png",
filename=key + ".png",
content_disposition_type="inline",
)
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
return response
except Exception:
raise HTTPException(status_code=404)
@model_manager_router.patch(
"/i/{key}/image",
operation_id="update_model_image",
responses={
200: {
"description": "The model image was updated successfully",
},
400: {"description": "Bad request"},
},
status_code=200,
)
async def update_model_image(
key: Annotated[str, Path(description="Unique key of model")],
image: UploadFile,
) -> None:
if not image.content_type or not image.content_type.startswith("image"):
raise HTTPException(status_code=415, detail="Not an image")
contents = await image.read()
try:
pil_image = Image.open(io.BytesIO(contents))
except Exception:
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
raise HTTPException(status_code=415, detail="Failed to read image")
logger = ApiDependencies.invoker.services.logger
model_images = ApiDependencies.invoker.services.model_images
try:
model_images.save(pil_image, key)
logger.info(f"Updated image for model: {key}")
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
return
@model_manager_router.delete(
"/i/{key}",
operation_id="delete_model",
@ -413,29 +296,6 @@ async def delete_model(
raise HTTPException(status_code=404, detail=str(e))
@model_manager_router.delete(
"/i/{key}/image",
operation_id="delete_model_image",
responses={
204: {"description": "Model image deleted successfully"},
404: {"description": "Model image not found"},
},
status_code=204,
)
async def delete_model_image(
key: str = Path(description="Unique key of model image to remove from model_images directory."),
) -> None:
logger = ApiDependencies.invoker.services.logger
model_images = ApiDependencies.invoker.services.model_images
try:
model_images.delete(key)
logger.info(f"Deleted model image: {key}")
return
except UnknownModelException as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
# @model_manager_router.post(
# "/i/",
# operation_id="add_model_record",
@ -679,7 +539,7 @@ async def convert_model(
raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.")
# loading the model will convert it into a cached diffusers file
model_manager.load.load_model(model_config, submodel_type=SubModelType.Scheduler)
model_manager.load_model_by_config(model_config, submodel_type=SubModelType.Scheduler)
# Get the path of the converted model from the loader
cache_path = loader.convert_cache.cache_path(key)

View File

@ -2,11 +2,12 @@
# which are imported/used before parse_args() is called will get the default config values instead of the
# values from the command line or config file.
import sys
from contextlib import asynccontextmanager
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
from invokeai.version.invokeai_version import __version__
from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
from .services.config import InvokeAIAppConfig
app_config = InvokeAIAppConfig.get_config()
@ -19,7 +20,6 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
import asyncio
import mimetypes
import socket
from contextlib import asynccontextmanager
from inspect import signature
from pathlib import Path
from typing import Any
@ -40,7 +40,6 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
# noinspection PyUnresolvedReferences
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
import invokeai.frontend.web as web_dir
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
from ..backend.util.logging import InvokeAILogger
from .api.dependencies import ApiDependencies
@ -60,7 +59,6 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
BaseInvocation,
UIConfigBase,
)
from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
if is_mps_available():
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
@ -158,19 +156,17 @@ def custom_openapi() -> dict[str, Any]:
openapi_schema["components"]["schemas"][schema_key] = output_schema
openapi_schema["components"]["schemas"][schema_key]["class"] = "output"
# Some models don't end up in the schemas as standalone definitions
additional_schemas = models_json_schema(
# Add Node Editor UI helper schemas
ui_config_schemas = models_json_schema(
[
(UIConfigBase, "serialization"),
(InputFieldJSONSchemaExtra, "serialization"),
(OutputFieldJSONSchemaExtra, "serialization"),
(ModelIdentifierField, "serialization"),
(ProgressImage, "serialization"),
],
ref_template="#/components/schemas/{model}",
)
for schema_key, schema_json in additional_schemas[1]["$defs"].items():
openapi_schema["components"]["schemas"][schema_key] = schema_json
for schema_key, ui_config_schema in ui_config_schemas[1]["$defs"].items():
openapi_schema["components"]["schemas"][schema_key] = ui_config_schema
# Add a reference to the output type to additionalProperties of the invoker schema
for invoker in all_invocations:

View File

@ -20,7 +20,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
from invokeai.backend.util.devices import torch_dtype
from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from .model import CLIPField
from .model import ClipField
# unconditioned: Optional[torch.Tensor]
@ -46,7 +46,7 @@ class CompelInvocation(BaseInvocation):
description=FieldDescriptions.compel_prompt,
ui_component=UIComponent.Textarea,
)
clip: CLIPField = InputField(
clip: ClipField = InputField(
title="CLIP",
description=FieldDescriptions.clip,
input=Input.Connection,
@ -54,16 +54,16 @@ class CompelInvocation(BaseInvocation):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput:
tokenizer_info = context.models.load(self.clip.tokenizer)
tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump())
tokenizer_model = tokenizer_info.model
assert isinstance(tokenizer_model, CLIPTokenizer)
text_encoder_info = context.models.load(self.clip.text_encoder)
text_encoder_info = context.models.load(**self.clip.text_encoder.model_dump())
text_encoder_model = text_encoder_info.model
assert isinstance(text_encoder_model, CLIPTextModel)
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.clip.loras:
lora_info = context.models.load(lora.lora)
lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info
@ -127,16 +127,16 @@ class SDXLPromptInvocationBase:
def run_clip_compel(
self,
context: InvocationContext,
clip_field: CLIPField,
clip_field: ClipField,
prompt: str,
get_pooled: bool,
lora_prefix: str,
zero_on_empty: bool,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]:
tokenizer_info = context.models.load(clip_field.tokenizer)
tokenizer_info = context.models.load(**clip_field.tokenizer.model_dump())
tokenizer_model = tokenizer_info.model
assert isinstance(tokenizer_model, CLIPTokenizer)
text_encoder_info = context.models.load(clip_field.text_encoder)
text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump())
text_encoder_model = text_encoder_info.model
assert isinstance(text_encoder_model, (CLIPTextModel, CLIPTextModelWithProjection))
@ -163,7 +163,7 @@ class SDXLPromptInvocationBase:
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in clip_field.loras:
lora_info = context.models.load(lora.lora)
lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
lora_model = lora_info.model
assert isinstance(lora_model, LoRAModelRaw)
yield (lora_model, lora.weight)
@ -253,8 +253,8 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
crop_left: int = InputField(default=0, description="")
target_width: int = InputField(default=1024, description="")
target_height: int = InputField(default=1024, description="")
clip: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1")
clip2: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2")
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1")
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2")
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput:
@ -340,7 +340,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
crop_top: int = InputField(default=0, description="")
crop_left: int = InputField(default=0, description="")
aesthetic_score: float = InputField(default=6.0, description=FieldDescriptions.sdxl_aesthetic)
clip2: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput:
@ -370,10 +370,10 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
@invocation_output("clip_skip_output")
class CLIPSkipInvocationOutput(BaseInvocationOutput):
"""CLIP skip node output"""
class ClipSkipInvocationOutput(BaseInvocationOutput):
"""Clip skip node output"""
clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
@invocation(
@ -383,15 +383,15 @@ class CLIPSkipInvocationOutput(BaseInvocationOutput):
category="conditioning",
version="1.0.0",
)
class CLIPSkipInvocation(BaseInvocation):
class ClipSkipInvocation(BaseInvocation):
"""Skip layers in clip text_encoder model."""
clip: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP")
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP")
skipped_layers: int = InputField(default=0, ge=0, description=FieldDescriptions.skipped_layers)
def invoke(self, context: InvocationContext) -> CLIPSkipInvocationOutput:
def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput:
self.clip.skipped_layers += self.skipped_layers
return CLIPSkipInvocationOutput(
return ClipSkipInvocationOutput(
clip=self.clip,
)

View File

@ -31,11 +31,9 @@ from invokeai.app.invocations.fields import (
Input,
InputField,
OutputField,
UIType,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
@ -53,9 +51,15 @@ CONTROLNET_RESIZE_VALUES = Literal[
]
class ControlNetModelField(BaseModel):
"""ControlNet model field"""
key: str = Field(description="Model config record key for the ControlNet model")
class ControlField(BaseModel):
image: ImageField = Field(description="The control image")
control_model: ModelIdentifierField = Field(description="The ControlNet model to use")
control_model: ControlNetModelField = Field(description="The ControlNet model to use")
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
@ -91,9 +95,7 @@ class ControlNetInvocation(BaseInvocation):
"""Collects ControlNet info to pass to other nodes"""
image: ImageField = InputField(description="The control image")
control_model: ModelIdentifierField = InputField(
description=FieldDescriptions.controlnet_model, input=Input.Direct, ui_type=UIType.ControlNetModel
)
control_model: ControlNetModelField = InputField(description=FieldDescriptions.controlnet_model, input=Input.Direct)
control_weight: Union[float, List[float]] = InputField(
default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"
)
@ -574,7 +576,7 @@ DEPTH_ANYTHING_MODEL_SIZES = Literal["large", "base", "small"]
title="Depth Anything Processor",
tags=["controlnet", "depth", "depth anything"],
category="controlnet",
version="1.0.1",
version="1.0.0",
)
class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
"""Generates a depth map based on the Depth Anything algorithm"""
@ -583,12 +585,13 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
default="small", description="The size of the depth model to use"
)
resolution: int = InputField(default=512, ge=64, multiple_of=64, description=FieldDescriptions.image_res)
offload: bool = InputField(default=False)
def run_processor(self, image: Image.Image):
depth_anything_detector = DepthAnythingDetector()
depth_anything_detector.load_model(model_size=self.model_size)
processed_image = depth_anything_detector(image=image, resolution=self.resolution)
processed_image = depth_anything_detector(image=image, resolution=self.resolution, offload=self.offload)
return processed_image

View File

@ -39,15 +39,13 @@ class UIType(str, Enum, metaclass=MetaEnum):
"""
# region Model Field Types
MainModel = "MainModelField"
SDXLMainModel = "SDXLMainModelField"
SDXLRefinerModel = "SDXLRefinerModelField"
ONNXModel = "ONNXModelField"
VAEModel = "VAEModelField"
VaeModel = "VAEModelField"
LoRAModel = "LoRAModelField"
ControlNetModel = "ControlNetModelField"
IPAdapterModel = "IPAdapterModelField"
T2IAdapterModel = "T2IAdapterModelField"
# endregion
# region Misc Field Types
@ -88,6 +86,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
IntegerPolymorphic = "DEPRECATED_IntegerPolymorphic"
LatentsPolymorphic = "DEPRECATED_LatentsPolymorphic"
StringPolymorphic = "DEPRECATED_StringPolymorphic"
MainModel = "DEPRECATED_MainModel"
UNet = "DEPRECATED_UNet"
Vae = "DEPRECATED_Vae"
CLIP = "DEPRECATED_CLIP"
@ -229,7 +228,7 @@ class ConditioningField(BaseModel):
# endregion
class MetadataField(RootModel[dict[str, Any]]):
class MetadataField(RootModel):
"""
Pydantic model for metadata with custom root of type dict[str, Any].
Metadata is stored without a strict schema.

View File

@ -10,18 +10,26 @@ from invokeai.app.invocations.baseinvocation import (
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, IPAdapterConfig, ModelType
from invokeai.backend.model_manager.config import BaseModelType, ModelType
# LS: Consider moving these two classes into model.py
class IPAdapterModelField(BaseModel):
key: str = Field(description="Key to the IP-Adapter model")
class CLIPVisionModelField(BaseModel):
key: str = Field(description="Key to the CLIP Vision image encoder model")
class IPAdapterField(BaseModel):
image: Union[ImageField, List[ImageField]] = Field(description="The IP-Adapter image prompt(s).")
ip_adapter_model: ModelIdentifierField = Field(description="The IP-Adapter model to use.")
image_encoder_model: ModelIdentifierField = Field(description="The name of the CLIP image encoder model.")
ip_adapter_model: IPAdapterModelField = Field(description="The IP-Adapter model to use.")
image_encoder_model: CLIPVisionModelField = Field(description="The name of the CLIP image encoder model.")
weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
@ -54,12 +62,8 @@ class IPAdapterInvocation(BaseInvocation):
# Inputs
image: Union[ImageField, List[ImageField]] = InputField(description="The IP-Adapter image prompt(s).")
ip_adapter_model: ModelIdentifierField = InputField(
description="The IP-Adapter model.",
title="IP-Adapter Model",
input=Input.Direct,
ui_order=-1,
ui_type=UIType.IPAdapterModel,
ip_adapter_model: IPAdapterModelField = InputField(
description="The IP-Adapter model.", title="IP-Adapter Model", input=Input.Direct, ui_order=-1
)
weight: Union[float, List[float]] = InputField(
@ -86,35 +90,20 @@ class IPAdapterInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
assert isinstance(ip_adapter_info, IPAdapterConfig)
image_encoder_model_id = ip_adapter_info.image_encoder_model_id
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
image_encoder_model = self._get_image_encoder(context, image_encoder_model_name)
image_encoder_models = context.models.search_by_attrs(
name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision
)
assert len(image_encoder_models) == 1
image_encoder_model = CLIPVisionModelField(key=image_encoder_models[0].key)
return IPAdapterOutput(
ip_adapter=IPAdapterField(
image=self.image,
ip_adapter_model=self.ip_adapter_model,
image_encoder_model=ModelIdentifierField.from_config(image_encoder_model),
image_encoder_model=image_encoder_model,
weight=self.weight,
begin_step_percent=self.begin_step_percent,
end_step_percent=self.end_step_percent,
),
)
def _get_image_encoder(self, context: InvocationContext, image_encoder_model_name: str) -> AnyModelConfig:
found = False
while not found:
image_encoder_models = context.models.search_by_attrs(
name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision
)
found = len(image_encoder_models) > 0
if not found:
context.logger.warning(
f"The image encoder required by this IP Adapter ({image_encoder_model_name}) is not installed."
)
context.logger.warning("Downloading and installing now. This may take a while.")
installer = context._services.model_manager.install
job = installer.heuristic_import(f"InvokeAI/{image_encoder_model_name}")
installer.wait_for_job(job, timeout=600) # wait up to 10 minutes - then raise a TimeoutException
assert len(image_encoder_models) == 1
return image_encoder_models[0]

View File

@ -26,7 +26,6 @@ from diffusers.schedulers import SchedulerMixin as Scheduler
from PIL import Image, ImageFilter
from pydantic import field_validator
from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPVisionModelWithProjection
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
from invokeai.app.invocations.fields import (
@ -66,6 +65,7 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
T2IAdapterData,
image_resized_to_grid_as_tensor,
)
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from ...backend.util.devices import choose_precision, choose_torch_device
from .baseinvocation import (
@ -75,7 +75,7 @@ from .baseinvocation import (
invocation_output,
)
from .controlnet_image_processors import ControlField
from .model import ModelIdentifierField, UNetField, VAEField
from .model import ModelInfo, UNetField, VaeField
if choose_torch_device() == torch.device("mps"):
from torch import mps
@ -118,7 +118,7 @@ class SchedulerInvocation(BaseInvocation):
class CreateDenoiseMaskInvocation(BaseInvocation):
"""Creates mask for denoising model run."""
vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection, ui_order=0)
vae: VaeField = InputField(description=FieldDescriptions.vae, input=Input.Connection, ui_order=0)
image: Optional[ImageField] = InputField(default=None, description="Image which will be masked", ui_order=1)
mask: ImageField = InputField(description="The mask to use when pasting", ui_order=2)
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=3)
@ -153,7 +153,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
)
if image_tensor is not None:
vae_info = context.models.load(self.vae.vae)
vae_info = context.models.load(**self.vae.vae.model_dump())
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
@ -244,12 +244,12 @@ class CreateGradientMaskInvocation(BaseInvocation):
def get_scheduler(
context: InvocationContext,
scheduler_info: ModelIdentifierField,
scheduler_info: ModelInfo,
scheduler_name: str,
seed: int,
) -> Scheduler:
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
orig_scheduler_info = context.models.load(scheduler_info)
orig_scheduler_info = context.models.load(**scheduler_info.model_dump())
with orig_scheduler_info as orig_scheduler:
scheduler_config = orig_scheduler.config
@ -383,6 +383,12 @@ class DenoiseLatentsInvocation(BaseInvocation):
text_embeddings=c,
guidance_scale=self.cfg_scale,
guidance_rescale_multiplier=self.cfg_rescale_multiplier,
postprocessing_settings=PostprocessingSettings(
threshold=0.0, # threshold,
warmup=0.2, # warmup,
h_symmetry_time_pct=None, # h_symmetry_time_pct,
v_symmetry_time_pct=None, # v_symmetry_time_pct,
),
)
conditioning_data = conditioning_data.add_scheduler_args_if_applicable( # FIXME
@ -455,7 +461,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
# and if weight is None, populate with default 1.0?
controlnet_data = []
for control_info in control_list:
control_model = exit_stack.enter_context(context.models.load(control_info.control_model))
control_model = exit_stack.enter_context(context.models.load(key=control_info.control_model.key))
# control_models.append(control_model)
control_image_field = control_info.image
@ -517,10 +523,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
conditioning_data.ip_adapter_conditioning = []
for single_ip_adapter in ip_adapter:
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
context.models.load(single_ip_adapter.ip_adapter_model)
context.models.load(key=single_ip_adapter.ip_adapter_model.key)
)
image_encoder_model_info = context.models.load(single_ip_adapter.image_encoder_model)
image_encoder_model_info = context.models.load(key=single_ip_adapter.image_encoder_model.key)
# `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
single_ipa_image_fields = single_ip_adapter.image
if not isinstance(single_ipa_image_fields, list):
@ -531,7 +538,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
# TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
with image_encoder_model_info as image_encoder_model:
assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
# Get image embeddings from CLIP and ImageProjModel.
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
single_ipa_images, image_encoder_model
@ -571,8 +577,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
t2i_adapter_data = []
for t2i_adapter_field in t2i_adapter:
t2i_adapter_model_config = context.models.get_config(t2i_adapter_field.t2i_adapter_model.key)
t2i_adapter_loaded_model = context.models.load(t2i_adapter_field.t2i_adapter_model)
t2i_adapter_model_config = context.models.get_config(key=t2i_adapter_field.t2i_adapter_model.key)
t2i_adapter_loaded_model = context.models.load(key=t2i_adapter_field.t2i_adapter_model.key)
image = context.images.get_pil(t2i_adapter_field.image.image_name)
# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
@ -677,7 +683,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
if self.denoise_mask.masked_latents_name is not None:
masked_latents = context.tensors.load(self.denoise_mask.masked_latents_name)
else:
masked_latents = torch.where(mask < 0.5, 0.0, latents)
masked_latents = None
return 1 - mask, masked_latents, self.denoise_mask.gradient
@ -725,13 +731,12 @@ class DenoiseLatentsInvocation(BaseInvocation):
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.unet.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
yield (lora_info.model, lora.weight)
del lora_info
return
unet_info = context.models.load(self.unet.unet)
unet_info = context.models.load(**self.unet.unet.model_dump())
assert isinstance(unet_info.model, UNet2DConditionModel)
with (
ExitStack() as exit_stack,
@ -825,7 +830,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
description=FieldDescriptions.latents,
input=Input.Connection,
)
vae: VAEField = InputField(
vae: VaeField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
)
@ -836,15 +841,15 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, (UNet2DConditionModel, AutoencoderKL, AutoencoderTiny))
vae_info = context.models.load(**self.vae.vae.model_dump())
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
assert isinstance(vae, torch.nn.Module)
latents = latents.to(vae.device)
if self.fp32:
vae.to(dtype=torch.float32)
use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance(
use_torch_2_0_or_xformers = isinstance(
vae.decoder.mid_block.attentions[0].processor,
(
AttnProcessor2_0,
@ -1003,7 +1008,7 @@ class ImageToLatentsInvocation(BaseInvocation):
image: ImageField = InputField(
description="The image to encode",
)
vae: VAEField = InputField(
vae: VaeField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
)
@ -1018,7 +1023,7 @@ class ImageToLatentsInvocation(BaseInvocation):
if upcast:
vae.to(dtype=torch.float32)
use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance(
use_torch_2_0_or_xformers = isinstance(
vae.decoder.mid_block.attentions[0].processor,
(
AttnProcessor2_0,
@ -1059,7 +1064,7 @@ class ImageToLatentsInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> LatentsOutput:
image = context.images.get_pil(self.image.image_name)
vae_info = context.models.load(self.vae.vae)
vae_info = context.models.load(**self.vae.vae.model_dump())
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:

View File

@ -8,10 +8,7 @@ from invokeai.app.invocations.baseinvocation import (
invocation,
invocation_output,
)
from invokeai.app.invocations.controlnet_image_processors import (
CONTROLNET_MODE_VALUES,
CONTROLNET_RESIZE_VALUES,
)
from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
@ -20,7 +17,9 @@ from invokeai.app.invocations.fields import (
OutputField,
UIType,
)
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.ip_adapter import IPAdapterModelField
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
from invokeai.app.services.shared.invocation_context import InvocationContext
from ...version import __version__
@ -34,7 +33,7 @@ class MetadataItemField(BaseModel):
class LoRAMetadataField(BaseModel):
"""LoRA Metadata Field"""
model: ModelIdentifierField = Field(description=FieldDescriptions.lora_model)
model: LoRAModelField = Field(description=FieldDescriptions.lora_model)
weight: float = Field(description=FieldDescriptions.lora_weight)
@ -42,41 +41,16 @@ class IPAdapterMetadataField(BaseModel):
"""IP Adapter Field, minus the CLIP Vision Encoder model"""
image: ImageField = Field(description="The IP-Adapter image prompt.")
ip_adapter_model: ModelIdentifierField = Field(description="The IP-Adapter model.")
weight: Union[float, list[float]] = Field(description="The weight given to the IP-Adapter")
ip_adapter_model: IPAdapterModelField = Field(
description="The IP-Adapter model.",
)
weight: Union[float, list[float]] = Field(
description="The weight given to the IP-Adapter",
)
begin_step_percent: float = Field(description="When the IP-Adapter is first applied (% of total steps)")
end_step_percent: float = Field(description="When the IP-Adapter is last applied (% of total steps)")
class T2IAdapterMetadataField(BaseModel):
image: ImageField = Field(description="The control image.")
processed_image: Optional[ImageField] = Field(default=None, description="The control image, after processing.")
t2i_adapter_model: ModelIdentifierField = Field(description="The T2I-Adapter model to use.")
weight: Union[float, list[float]] = Field(default=1, description="The weight given to the T2I-Adapter")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the T2I-Adapter is first applied (% of total steps)"
)
end_step_percent: float = Field(
default=1, ge=0, le=1, description="When the T2I-Adapter is last applied (% of total steps)"
)
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
class ControlNetMetadataField(BaseModel):
image: ImageField = Field(description="The control image")
processed_image: Optional[ImageField] = Field(default=None, description="The control image, after processing.")
control_model: ModelIdentifierField = Field(description="The ControlNet model to use")
control_weight: Union[float, list[float]] = Field(default=1, description="The weight given to the ControlNet")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
)
end_step_percent: float = Field(
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
)
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
@invocation_output("metadata_item_output")
class MetadataItemOutput(BaseInvocationOutput):
"""Metadata Item Output"""
@ -166,14 +140,14 @@ class CoreMetadataInvocation(BaseInvocation):
default=None,
description="The number of skipped CLIP layers",
)
model: Optional[ModelIdentifierField] = InputField(default=None, description="The main model used for inference")
controlnets: Optional[list[ControlNetMetadataField]] = InputField(
model: Optional[MainModelField] = InputField(default=None, description="The main model used for inference")
controlnets: Optional[list[ControlField]] = InputField(
default=None, description="The ControlNets used for inference"
)
ipAdapters: Optional[list[IPAdapterMetadataField]] = InputField(
default=None, description="The IP Adapters used for inference"
)
t2iAdapters: Optional[list[T2IAdapterMetadataField]] = InputField(
t2iAdapters: Optional[list[T2IAdapterField]] = InputField(
default=None, description="The IP Adapters used for inference"
)
loras: Optional[list[LoRAMetadataField]] = InputField(default=None, description="The LoRAs used for inference")
@ -185,7 +159,7 @@ class CoreMetadataInvocation(BaseInvocation):
default=None,
description="The name of the initial image",
)
vae: Optional[ModelIdentifierField] = InputField(
vae: Optional[VAEModelField] = InputField(
default=None,
description="The VAE used for decoding, if the main model's default was not used",
)
@ -216,7 +190,7 @@ class CoreMetadataInvocation(BaseInvocation):
)
# SDXL Refiner
refiner_model: Optional[ModelIdentifierField] = InputField(
refiner_model: Optional[MainModelField] = InputField(
default=None,
description="The SDXL Refiner model used",
)
@ -248,9 +222,10 @@ class CoreMetadataInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> MetadataOutput:
"""Collects and outputs a CoreMetadata object"""
as_dict = self.model_dump(exclude_none=True, exclude={"id", "type", "is_intermediate", "use_cache"})
as_dict["app_version"] = __version__
return MetadataOutput(metadata=MetadataField.model_validate(as_dict))
return MetadataOutput(
metadata=MetadataField.model_validate(
self.model_dump(exclude_none=True, exclude={"id", "type", "is_intermediate", "use_cache"})
)
)
model_config = ConfigDict(extra="allow")

View File

@ -3,11 +3,11 @@ from typing import List, Optional
from pydantic import BaseModel, Field
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType
from ...backend.model_manager import SubModelType
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
@ -16,52 +16,33 @@ from .baseinvocation import (
)
class ModelIdentifierField(BaseModel):
key: str = Field(description="The model's unique key")
hash: str = Field(description="The model's BLAKE3 hash")
name: str = Field(description="The model's name")
base: BaseModelType = Field(description="The model's base model type")
type: ModelType = Field(description="The model's type")
submodel_type: Optional[SubModelType] = Field(
description="The submodel to load, if this is a main model", default=None
)
@classmethod
def from_config(
cls, config: "AnyModelConfig", submodel_type: Optional[SubModelType] = None
) -> "ModelIdentifierField":
return cls(
key=config.key,
hash=config.hash,
name=config.name,
base=config.base,
type=config.type,
submodel_type=submodel_type,
)
class ModelInfo(BaseModel):
key: str = Field(description="Key of model as returned by ModelRecordServiceBase.get_model()")
submodel_type: Optional[SubModelType] = Field(default=None, description="Info to load submodel")
class LoRAField(BaseModel):
lora: ModelIdentifierField = Field(description="Info to load lora model")
weight: float = Field(description="Weight to apply to lora model")
class LoraInfo(ModelInfo):
weight: float = Field(description="Lora's weight which to use when apply to model")
class UNetField(BaseModel):
unet: ModelIdentifierField = Field(description="Info to load unet submodel")
scheduler: ModelIdentifierField = Field(description="Info to load scheduler submodel")
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
unet: ModelInfo = Field(description="Info to load unet submodel")
scheduler: ModelInfo = Field(description="Info to load scheduler submodel")
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
freeu_config: Optional[FreeUConfig] = Field(default=None, description="FreeU configuration")
class CLIPField(BaseModel):
tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel")
text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")
class ClipField(BaseModel):
tokenizer: ModelInfo = Field(description="Info to load tokenizer submodel")
text_encoder: ModelInfo = Field(description="Info to load text_encoder submodel")
skipped_layers: int = Field(description="Number of skipped layers in text_encoder")
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
class VAEField(BaseModel):
vae: ModelIdentifierField = Field(description="Info to load vae submodel")
class VaeField(BaseModel):
# TODO: better naming?
vae: ModelInfo = Field(description="Info to load vae submodel")
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
@ -76,14 +57,14 @@ class UNetOutput(BaseInvocationOutput):
class VAEOutput(BaseInvocationOutput):
"""Base class for invocations that output a VAE field"""
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
@invocation_output("clip_output")
class CLIPOutput(BaseInvocationOutput):
"""Base class for invocations that output a CLIP field"""
clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP")
clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP")
@invocation_output("model_loader_output")
@ -93,6 +74,18 @@ class ModelLoaderOutput(UNetOutput, CLIPOutput, VAEOutput):
pass
class MainModelField(BaseModel):
"""Main model field"""
key: str = Field(description="Model key")
class LoRAModelField(BaseModel):
"""LoRA model field"""
key: str = Field(description="LoRA model key")
@invocation(
"main_model_loader",
title="Main Model",
@ -103,44 +96,62 @@ class ModelLoaderOutput(UNetOutput, CLIPOutput, VAEOutput):
class MainModelLoaderInvocation(BaseInvocation):
"""Loads a main model, outputting its submodels."""
model: ModelIdentifierField = InputField(
description=FieldDescriptions.main_model, input=Input.Direct, ui_type=UIType.MainModel
)
model: MainModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct)
# TODO: precision?
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
# TODO: not found exceptions
if not context.models.exists(self.model.key):
raise Exception(f"Unknown model {self.model.key}")
key = self.model.key
unet = self.model.model_copy(update={"submodel_type": SubModelType.UNet})
scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
text_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
# TODO: not found exceptions
if not context.models.exists(key):
raise Exception(f"Unknown model {key}")
return ModelLoaderOutput(
unet=UNetField(unet=unet, scheduler=scheduler, loras=[]),
clip=CLIPField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[], skipped_layers=0),
vae=VAEField(vae=vae),
unet=UNetField(
unet=ModelInfo(
key=key,
submodel_type=SubModelType.UNet,
),
scheduler=ModelInfo(
key=key,
submodel_type=SubModelType.Scheduler,
),
loras=[],
),
clip=ClipField(
tokenizer=ModelInfo(
key=key,
submodel_type=SubModelType.Tokenizer,
),
text_encoder=ModelInfo(
key=key,
submodel_type=SubModelType.TextEncoder,
),
loras=[],
skipped_layers=0,
),
vae=VaeField(
vae=ModelInfo(
key=key,
submodel_type=SubModelType.VAE,
),
),
)
@invocation_output("lora_loader_output")
class LoRALoaderOutput(BaseInvocationOutput):
class LoraLoaderOutput(BaseInvocationOutput):
"""Model loader output"""
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.1")
class LoRALoaderInvocation(BaseInvocation):
class LoraLoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""
lora: ModelIdentifierField = InputField(
description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel
)
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
unet: Optional[UNetField] = InputField(
default=None,
@ -148,41 +159,46 @@ class LoRALoaderInvocation(BaseInvocation):
input=Input.Connection,
title="UNet",
)
clip: Optional[CLIPField] = InputField(
clip: Optional[ClipField] = InputField(
default=None,
description=FieldDescriptions.clip,
input=Input.Connection,
title="CLIP",
)
def invoke(self, context: InvocationContext) -> LoRALoaderOutput:
def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
if self.lora is None:
raise Exception("No LoRA provided")
lora_key = self.lora.key
if not context.models.exists(lora_key):
raise Exception(f"Unkown lora: {lora_key}!")
if self.unet is not None and any(lora.lora.key == lora_key for lora in self.unet.loras):
raise Exception(f'LoRA "{lora_key}" already applied to unet')
if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
raise Exception(f'Lora "{lora_key}" already applied to unet')
if self.clip is not None and any(lora.lora.key == lora_key for lora in self.clip.loras):
raise Exception(f'LoRA "{lora_key}" already applied to clip')
if self.clip is not None and any(lora.key == lora_key for lora in self.clip.loras):
raise Exception(f'Lora "{lora_key}" already applied to clip')
output = LoRALoaderOutput()
output = LoraLoaderOutput()
if self.unet is not None:
output.unet = self.unet.model_copy(deep=True)
output.unet = copy.deepcopy(self.unet)
output.unet.loras.append(
LoRAField(
lora=self.lora,
LoraInfo(
key=lora_key,
submodel_type=None,
weight=self.weight,
)
)
if self.clip is not None:
output.clip = self.clip.model_copy(deep=True)
output.clip = copy.deepcopy(self.clip)
output.clip.loras.append(
LoRAField(
lora=self.lora,
LoraInfo(
key=lora_key,
submodel_type=None,
weight=self.weight,
)
)
@ -191,12 +207,12 @@ class LoRALoaderInvocation(BaseInvocation):
@invocation_output("sdxl_lora_loader_output")
class SDXLLoRALoaderOutput(BaseInvocationOutput):
class SDXLLoraLoaderOutput(BaseInvocationOutput):
"""SDXL LoRA Loader Output"""
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 1")
clip2: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 2")
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 1")
clip2: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 2")
@invocation(
@ -206,12 +222,10 @@ class SDXLLoRALoaderOutput(BaseInvocationOutput):
category="model",
version="1.0.1",
)
class SDXLLoRALoaderInvocation(BaseInvocation):
class SDXLLoraLoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""
lora: ModelIdentifierField = InputField(
description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel
)
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
unet: Optional[UNetField] = InputField(
default=None,
@ -219,59 +233,65 @@ class SDXLLoRALoaderInvocation(BaseInvocation):
input=Input.Connection,
title="UNet",
)
clip: Optional[CLIPField] = InputField(
clip: Optional[ClipField] = InputField(
default=None,
description=FieldDescriptions.clip,
input=Input.Connection,
title="CLIP 1",
)
clip2: Optional[CLIPField] = InputField(
clip2: Optional[ClipField] = InputField(
default=None,
description=FieldDescriptions.clip,
input=Input.Connection,
title="CLIP 2",
)
def invoke(self, context: InvocationContext) -> SDXLLoRALoaderOutput:
def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput:
if self.lora is None:
raise Exception("No LoRA provided")
lora_key = self.lora.key
if not context.models.exists(lora_key):
raise Exception(f"Unknown lora: {lora_key}!")
if self.unet is not None and any(lora.lora.key == lora_key for lora in self.unet.loras):
raise Exception(f'LoRA "{lora_key}" already applied to unet')
if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
raise Exception(f'Lora "{lora_key}" already applied to unet')
if self.clip is not None and any(lora.lora.key == lora_key for lora in self.clip.loras):
raise Exception(f'LoRA "{lora_key}" already applied to clip')
if self.clip is not None and any(lora.key == lora_key for lora in self.clip.loras):
raise Exception(f'Lora "{lora_key}" already applied to clip')
if self.clip2 is not None and any(lora.lora.key == lora_key for lora in self.clip2.loras):
raise Exception(f'LoRA "{lora_key}" already applied to clip2')
if self.clip2 is not None and any(lora.key == lora_key for lora in self.clip2.loras):
raise Exception(f'Lora "{lora_key}" already applied to clip2')
output = SDXLLoRALoaderOutput()
output = SDXLLoraLoaderOutput()
if self.unet is not None:
output.unet = self.unet.model_copy(deep=True)
output.unet = copy.deepcopy(self.unet)
output.unet.loras.append(
LoRAField(
lora=self.lora,
LoraInfo(
key=lora_key,
submodel_type=None,
weight=self.weight,
)
)
if self.clip is not None:
output.clip = self.clip.model_copy(deep=True)
output.clip = copy.deepcopy(self.clip)
output.clip.loras.append(
LoRAField(
lora=self.lora,
LoraInfo(
key=lora_key,
submodel_type=None,
weight=self.weight,
)
)
if self.clip2 is not None:
output.clip2 = self.clip2.model_copy(deep=True)
output.clip2 = copy.deepcopy(self.clip2)
output.clip2.loras.append(
LoRAField(
lora=self.lora,
LoraInfo(
key=lora_key,
submodel_type=None,
weight=self.weight,
)
)
@ -279,12 +299,20 @@ class SDXLLoRALoaderInvocation(BaseInvocation):
return output
class VAEModelField(BaseModel):
"""Vae model field"""
key: str = Field(description="Model's key")
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.1")
class VAELoaderInvocation(BaseInvocation):
class VaeLoaderInvocation(BaseInvocation):
"""Loads a VAE model, outputting a VaeLoaderOutput"""
vae_model: ModelIdentifierField = InputField(
description=FieldDescriptions.vae_model, input=Input.Direct, title="VAE", ui_type=UIType.VAEModel
vae_model: VAEModelField = InputField(
description=FieldDescriptions.vae_model,
input=Input.Direct,
title="VAE",
)
def invoke(self, context: InvocationContext) -> VAEOutput:
@ -293,7 +321,7 @@ class VAELoaderInvocation(BaseInvocation):
if not context.models.exists(key):
raise Exception(f"Unkown vae: {key}!")
return VAEOutput(vae=VAEField(vae=self.vae_model))
return VAEOutput(vae=VaeField(vae=ModelInfo(key=key)))
@invocation_output("seamless_output")
@ -301,7 +329,7 @@ class SeamlessModeOutput(BaseInvocationOutput):
"""Modified Seamless Model output"""
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
vae: Optional[VAEField] = OutputField(default=None, description=FieldDescriptions.vae, title="VAE")
vae: Optional[VaeField] = OutputField(default=None, description=FieldDescriptions.vae, title="VAE")
@invocation(
@ -320,7 +348,7 @@ class SeamlessModeInvocation(BaseInvocation):
input=Input.Connection,
title="UNet",
)
vae: Optional[VAEField] = InputField(
vae: Optional[VaeField] = InputField(
default=None,
description=FieldDescriptions.vae_model,
input=Input.Connection,

View File

@ -8,7 +8,7 @@ from .baseinvocation import (
invocation,
invocation_output,
)
from .model import CLIPField, ModelIdentifierField, UNetField, VAEField
from .model import ClipField, MainModelField, ModelInfo, UNetField, VaeField
@invocation_output("sdxl_model_loader_output")
@ -16,9 +16,9 @@ class SDXLModelLoaderOutput(BaseInvocationOutput):
"""SDXL base model loader output"""
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 1")
clip2: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 1")
clip2: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
@invocation_output("sdxl_refiner_model_loader_output")
@ -26,15 +26,15 @@ class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
"""SDXL refiner model loader output"""
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
clip2: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
clip2: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.1")
class SDXLModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl base model, outputting its submodels."""
model: ModelIdentifierField = InputField(
model: MainModelField = InputField(
description=FieldDescriptions.sdxl_main_model, input=Input.Direct, ui_type=UIType.SDXLMainModel
)
# TODO: precision?
@ -46,19 +46,48 @@ class SDXLModelLoaderInvocation(BaseInvocation):
if not context.models.exists(model_key):
raise Exception(f"Unknown model: {model_key}")
unet = self.model.model_copy(update={"submodel_type": SubModelType.UNet})
scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
text_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
tokenizer2 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
text_encoder2 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
return SDXLModelLoaderOutput(
unet=UNetField(unet=unet, scheduler=scheduler, loras=[]),
clip=CLIPField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[], skipped_layers=0),
clip2=CLIPField(tokenizer=tokenizer2, text_encoder=text_encoder2, loras=[], skipped_layers=0),
vae=VAEField(vae=vae),
unet=UNetField(
unet=ModelInfo(
key=model_key,
submodel_type=SubModelType.UNet,
),
scheduler=ModelInfo(
key=model_key,
submodel_type=SubModelType.Scheduler,
),
loras=[],
),
clip=ClipField(
tokenizer=ModelInfo(
key=model_key,
submodel_type=SubModelType.Tokenizer,
),
text_encoder=ModelInfo(
key=model_key,
submodel_type=SubModelType.TextEncoder,
),
loras=[],
skipped_layers=0,
),
clip2=ClipField(
tokenizer=ModelInfo(
key=model_key,
submodel_type=SubModelType.Tokenizer2,
),
text_encoder=ModelInfo(
key=model_key,
submodel_type=SubModelType.TextEncoder2,
),
loras=[],
skipped_layers=0,
),
vae=VaeField(
vae=ModelInfo(
key=model_key,
submodel_type=SubModelType.VAE,
),
),
)
@ -72,8 +101,10 @@ class SDXLModelLoaderInvocation(BaseInvocation):
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl refiner model, outputting its submodels."""
model: ModelIdentifierField = InputField(
description=FieldDescriptions.sdxl_refiner_model, input=Input.Direct, ui_type=UIType.SDXLRefinerModel
model: MainModelField = InputField(
description=FieldDescriptions.sdxl_refiner_model,
input=Input.Direct,
ui_type=UIType.SDXLRefinerModel,
)
# TODO: precision?
@ -84,14 +115,34 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
if not context.models.exists(model_key):
raise Exception(f"Unknown model: {model_key}")
unet = self.model.model_copy(update={"submodel_type": SubModelType.UNet})
scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
tokenizer2 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
text_encoder2 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
return SDXLRefinerModelLoaderOutput(
unet=UNetField(unet=unet, scheduler=scheduler, loras=[]),
clip2=CLIPField(tokenizer=tokenizer2, text_encoder=text_encoder2, loras=[], skipped_layers=0),
vae=VAEField(vae=vae),
unet=UNetField(
unet=ModelInfo(
key=model_key,
submodel_type=SubModelType.UNet,
),
scheduler=ModelInfo(
key=model_key,
submodel_type=SubModelType.Scheduler,
),
loras=[],
),
clip2=ClipField(
tokenizer=ModelInfo(
key=model_key,
submodel_type=SubModelType.Tokenizer2,
),
text_encoder=ModelInfo(
key=model_key,
submodel_type=SubModelType.TextEncoder2,
),
loras=[],
skipped_layers=0,
),
vae=VaeField(
vae=ModelInfo(
key=model_key,
submodel_type=SubModelType.VAE,
),
),
)

View File

@ -9,15 +9,18 @@ from invokeai.app.invocations.baseinvocation import (
invocation_output,
)
from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
class T2IAdapterModelField(BaseModel):
key: str = Field(description="Model record key for the T2I-Adapter model")
class T2IAdapterField(BaseModel):
image: ImageField = Field(description="The T2I-Adapter image prompt.")
t2i_adapter_model: ModelIdentifierField = Field(description="The T2I-Adapter model to use.")
t2i_adapter_model: T2IAdapterModelField = Field(description="The T2I-Adapter model to use.")
weight: Union[float, list[float]] = Field(default=1, description="The weight given to the T2I-Adapter")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the T2I-Adapter is first applied (% of total steps)"
@ -52,12 +55,11 @@ class T2IAdapterInvocation(BaseInvocation):
# Inputs
image: ImageField = InputField(description="The IP-Adapter image prompt.")
t2i_adapter_model: ModelIdentifierField = InputField(
t2i_adapter_model: T2IAdapterModelField = InputField(
description="The T2I-Adapter model.",
title="T2I-Adapter Model",
input=Input.Direct,
ui_order=-1,
ui_type=UIType.T2IAdapterModel,
)
weight: Union[float, list[float]] = InputField(
default=1, ge=0, description="The weight given to the T2I-Adapter", title="Weight"

View File

@ -17,8 +17,7 @@ from argparse import ArgumentParser
from pathlib import Path
from typing import Any, ClassVar, Dict, List, Literal, Optional, Union, get_args, get_origin, get_type_hints
from omegaconf import DictConfig, DictKeyType, ListConfig, OmegaConf
from pydantic import BaseModel
from omegaconf import DictConfig, ListConfig, OmegaConf
from pydantic_settings import BaseSettings, SettingsConfigDict
from invokeai.app.services.config.config_common import PagingArgumentParser, int_or_float_or_str
@ -63,22 +62,6 @@ class InvokeAISettings(BaseSettings):
assert isinstance(category, str)
if category not in field_dict[type]:
field_dict[type][category] = {}
if isinstance(value, BaseModel):
dump = value.model_dump(exclude_defaults=True, exclude_unset=True, exclude_none=True)
field_dict[type][category][name] = dump
continue
if isinstance(value, list):
if not value or len(value) == 0:
continue
primitive = isinstance(value[0], get_args(DictKeyType))
if not primitive:
val_list: List[Dict[str, Any]] = []
for list_val in value:
if isinstance(list_val, BaseModel):
dump = list_val.model_dump(exclude_defaults=True, exclude_unset=True, exclude_none=True)
val_list.append(dump)
field_dict[type][category][name] = val_list
continue
# keep paths as strings to make it easier to read
field_dict[type][category][name] = str(value) if isinstance(value, Path) else value
conf = OmegaConf.create(field_dict)

View File

@ -170,17 +170,14 @@ two configs are kept in separate sections of the config file:
from __future__ import annotations
import os
import re
from pathlib import Path
from typing import Any, ClassVar, Dict, List, Literal, Optional
from omegaconf import DictConfig, OmegaConf
from pydantic import BaseModel, Field, field_validator
from pydantic import Field
from pydantic.config import JsonDict
from pydantic_settings import SettingsConfigDict
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
from .config_base import InvokeAISettings
INIT_FILE = Path("invokeai.yaml")
@ -199,87 +196,17 @@ class Categories(object):
Paths: JsonDict = {"category": "Paths"}
Logging: JsonDict = {"category": "Logging"}
Development: JsonDict = {"category": "Development"}
CLIArgs: JsonDict = {"category": "CLIArgs"}
ModelInstall: JsonDict = {"category": "Model Install"}
Other: JsonDict = {"category": "Other"}
ModelCache: JsonDict = {"category": "Model Cache"}
Device: JsonDict = {"category": "Device"}
Generation: JsonDict = {"category": "Generation"}
Queue: JsonDict = {"category": "Queue"}
Nodes: JsonDict = {"category": "Nodes"}
MemoryPerformance: JsonDict = {"category": "Memory/Performance"}
Deprecated: JsonDict = {"category": "Deprecated"}
class URLRegexToken(BaseModel):
url_regex: str = Field(description="Regular expression to match against the URL")
token: str = Field(description="Token to use when the URL matches the regex")
@field_validator("url_regex")
@classmethod
def validate_url_regex(cls, v: str) -> str:
"""Validate that the value is a valid regex."""
try:
re.compile(v)
except re.error as e:
raise ValueError(f"Invalid regex: {e}")
return v
class InvokeAIAppConfig(InvokeAISettings):
"""Invoke App Configuration
Attributes:
host: **Web Server**: IP address to bind to. Use `0.0.0.0` to serve to your local network.
port: **Web Server**: Port to bind to.
allow_origins: **Web Server**: Allowed CORS origins.
allow_credentials: **Web Server**: Allow CORS credentials.
allow_methods: **Web Server**: Methods allowed for CORS.
allow_headers: **Web Server**: Headers allowed for CORS.
ssl_certfile: **Web Server**: SSL certificate file for HTTPS.
ssl_keyfile: **Web Server**: SSL key file for HTTPS.
esrgan: **Features**: Enables or disables the upscaling code.
internet_available: **Features**: If true, attempt to download models on the fly; otherwise only use local models.
log_tokenization: **Features**: Enable logging of parsed prompt tokens.
patchmatch: **Features**: Enable patchmatch inpaint code.
ignore_missing_core_models: **Features**: Ignore missing core models on startup. If `True`, the app will attempt to download missing models on startup.
root: **Paths**: The InvokeAI runtime root directory.
autoimport_dir: **Paths**: Path to a directory of models files to be imported on startup.
models_dir: **Paths**: Path to the models directory.
convert_cache_dir: **Paths**: Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and store on disk at this location.
legacy_conf_dir: **Paths**: Path to directory of legacy checkpoint config files.
db_dir: **Paths**: Path to InvokeAI databases directory.
outdir: **Paths**: Path to directory for outputs.
custom_nodes_dir: **Paths**: Path to directory for custom nodes.
from_file: **Paths**: Take command input from the indicated file (command-line client only).
log_handlers: **Logging**: Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>".
log_format: **Logging**: Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style.
log_level: **Logging**: Emit logging messages at this level or higher.
log_sql: **Logging**: Log SQL queries. `log_level` must be `debug` for this to do anything. Extremely verbose.
use_memory_db: **Development**: Use in-memory database. Useful for development.
dev_reload: **Development**: Automatically reload when Python sources are changed. Does not reload node definitions.
profile_graphs: **Development**: Enable graph profiling using `cProfile`.
profile_prefix: **Development**: An optional prefix for profile output files.
profiles_dir: **Development**: Path to profiles output directory.
version: **CLIArgs**: CLI arg - show InvokeAI version and exit.
hashing_algorithm: **Model Install**: Model hashing algorthim for model installs. 'blake3' is best for SSDs. 'blake3_single' is best for spinning disk HDDs. 'random' disables hashing, instead assigning a UUID to models. Useful when using a memory db to reduce model installation time, or if you don't care about storing stable hashes for models. Alternatively, any other hashlib algorithm is accepted, though these are not nearly as performant as blake3.
remote_api_tokens: **Model Install**: List of regular expression and token pairs used when downloading models from URLs. The download URL is tested against the regex, and if it matches, the token is provided in as a Bearer token.
ram: **Model Cache**: Maximum memory amount used by memory model cache for rapid switching (GB).
vram: **Model Cache**: Amount of VRAM reserved for model storage (GB)
convert_cache: **Model Cache**: Maximum size of on-disk converted models cache (GB)
lazy_offload: **Model Cache**: Keep models in VRAM until their space is needed.
log_memory_usage: **Model Cache**: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
device: **Device**: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.
precision: **Device**: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.
sequential_guidance: **Generation**: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
attention_type: **Generation**: Attention type.
attention_slice_size: **Generation**: Slice size, valid when attention_type=="sliced".
force_tiled_decode: **Generation**: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).
png_compress_level: **Generation**: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.
max_queue_size: **Queue**: Maximum number of items in the session queue.
allow_nodes: **Nodes**: List of nodes to allow. Omit to allow all.
deny_nodes: **Nodes**: List of nodes to deny. Omit to deny none.
node_cache_size: **Nodes**: How many cached nodes to keep in memory.
"""
"""Configuration object for InvokeAI App."""
singleton_config: ClassVar[Optional[InvokeAIAppConfig]] = None
singleton_init: ClassVar[Optional[Dict[str, Any]]] = None
@ -288,98 +215,91 @@ class InvokeAIAppConfig(InvokeAISettings):
type: Literal["InvokeAI"] = "InvokeAI"
# WEB
host : str = Field(default="127.0.0.1", description="IP address to bind to. Use `0.0.0.0` to serve to your local network.", json_schema_extra=Categories.WebServer)
port : int = Field(default=9090, description="Port to bind to.", json_schema_extra=Categories.WebServer)
allow_origins : List[str] = Field(default=[], description="Allowed CORS origins.", json_schema_extra=Categories.WebServer)
allow_credentials : bool = Field(default=True, description="Allow CORS credentials.", json_schema_extra=Categories.WebServer)
allow_methods : List[str] = Field(default=["*"], description="Methods allowed for CORS.", json_schema_extra=Categories.WebServer)
allow_headers : List[str] = Field(default=["*"], description="Headers allowed for CORS.", json_schema_extra=Categories.WebServer)
host : str = Field(default="127.0.0.1", description="IP address to bind to", json_schema_extra=Categories.WebServer)
port : int = Field(default=9090, description="Port to bind to", json_schema_extra=Categories.WebServer)
allow_origins : List[str] = Field(default=[], description="Allowed CORS origins", json_schema_extra=Categories.WebServer)
allow_credentials : bool = Field(default=True, description="Allow CORS credentials", json_schema_extra=Categories.WebServer)
allow_methods : List[str] = Field(default=["*"], description="Methods allowed for CORS", json_schema_extra=Categories.WebServer)
allow_headers : List[str] = Field(default=["*"], description="Headers allowed for CORS", json_schema_extra=Categories.WebServer)
# SSL options correspond to https://www.uvicorn.org/settings/#https
ssl_certfile : Optional[Path] = Field(default=None, description="SSL certificate file for HTTPS.", json_schema_extra=Categories.WebServer)
ssl_keyfile : Optional[Path] = Field(default=None, description="SSL key file for HTTPS.", json_schema_extra=Categories.WebServer)
ssl_certfile : Optional[Path] = Field(default=None, description="SSL certificate file (for HTTPS)", json_schema_extra=Categories.WebServer)
ssl_keyfile : Optional[Path] = Field(default=None, description="SSL key file", json_schema_extra=Categories.WebServer)
# FEATURES
esrgan : bool = Field(default=True, description="Enables or disables the upscaling code.", json_schema_extra=Categories.Features)
# TODO(psyche): This is not used anywhere.
internet_available : bool = Field(default=True, description="If true, attempt to download models on the fly; otherwise only use local models.", json_schema_extra=Categories.Features)
esrgan : bool = Field(default=True, description="Enable/disable upscaling code", json_schema_extra=Categories.Features)
internet_available : bool = Field(default=True, description="If true, attempt to download models on the fly; otherwise only use local models", json_schema_extra=Categories.Features)
log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", json_schema_extra=Categories.Features)
patchmatch : bool = Field(default=True, description="Enable patchmatch inpaint code.", json_schema_extra=Categories.Features)
ignore_missing_core_models : bool = Field(default=False, description='Ignore missing core models on startup. If `True`, the app will attempt to download missing models on startup.', json_schema_extra=Categories.Features)
patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", json_schema_extra=Categories.Features)
ignore_missing_core_models : bool = Field(default=False, description='Ignore missing models in models/core/convert', json_schema_extra=Categories.Features)
# PATHS
root : Optional[Path] = Field(default=None, description='The InvokeAI runtime root directory.', json_schema_extra=Categories.Paths)
root : Optional[Path] = Field(default=None, description='InvokeAI runtime root directory', json_schema_extra=Categories.Paths)
autoimport_dir : Path = Field(default=Path('autoimport'), description='Path to a directory of models files to be imported on startup.', json_schema_extra=Categories.Paths)
models_dir : Path = Field(default=Path('models'), description='Path to the models directory.', json_schema_extra=Categories.Paths)
convert_cache_dir : Path = Field(default=Path('models/.cache'), description='Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and store on disk at this location.', json_schema_extra=Categories.Paths)
legacy_conf_dir : Path = Field(default=Path('configs/stable-diffusion'), description='Path to directory of legacy checkpoint config files.', json_schema_extra=Categories.Paths)
db_dir : Path = Field(default=Path('databases'), description='Path to InvokeAI databases directory.', json_schema_extra=Categories.Paths)
outdir : Path = Field(default=Path('outputs'), description='Path to directory for outputs.', json_schema_extra=Categories.Paths)
custom_nodes_dir : Path = Field(default=Path('nodes'), description='Path to directory for custom nodes.', json_schema_extra=Categories.Paths)
# TODO(psyche): This is not used anywhere.
from_file : Optional[Path] = Field(default=None, description='Take command input from the indicated file (command-line client only).', json_schema_extra=Categories.Paths)
models_dir : Path = Field(default=Path('models'), description='Path to the models directory', json_schema_extra=Categories.Paths)
convert_cache_dir : Path = Field(default=Path('models/.cache'), description='Path to the converted models cache directory', json_schema_extra=Categories.Paths)
legacy_conf_dir : Path = Field(default=Path('configs/stable-diffusion'), description='Path to directory of legacy checkpoint config files', json_schema_extra=Categories.Paths)
db_dir : Path = Field(default=Path('databases'), description='Path to InvokeAI databases directory', json_schema_extra=Categories.Paths)
outdir : Path = Field(default=Path('outputs'), description='Default folder for output images', json_schema_extra=Categories.Paths)
use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', json_schema_extra=Categories.Paths)
custom_nodes_dir : Path = Field(default=Path('nodes'), description='Path to directory for custom nodes', json_schema_extra=Categories.Paths)
from_file : Optional[Path] = Field(default=None, description='Take command input from the indicated file (command-line client only)', json_schema_extra=Categories.Paths)
# LOGGING
log_handlers : List[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>".', json_schema_extra=Categories.Logging)
log_handlers : List[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>"', json_schema_extra=Categories.Logging)
# note - would be better to read the log_format values from logging.py, but this creates circular dependencies issues
log_format : Literal['plain', 'color', 'syslog', 'legacy'] = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style.', json_schema_extra=Categories.Logging)
log_level : Literal["debug", "info", "warning", "error", "critical"] = Field(default="info", description="Emit logging messages at this level or higher.", json_schema_extra=Categories.Logging)
log_sql : bool = Field(default=False, description="Log SQL queries. `log_level` must be `debug` for this to do anything. Extremely verbose.", json_schema_extra=Categories.Logging)
log_format : Literal['plain', 'color', 'syslog', 'legacy'] = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style', json_schema_extra=Categories.Logging)
log_level : Literal["debug", "info", "warning", "error", "critical"] = Field(default="info", description="Emit logging messages at this level or higher", json_schema_extra=Categories.Logging)
log_sql : bool = Field(default=False, description="Log SQL queries", json_schema_extra=Categories.Logging)
# Development
use_memory_db : bool = Field(default=False, description='Use in-memory database. Useful for development.', json_schema_extra=Categories.Development)
dev_reload : bool = Field(default=False, description="Automatically reload when Python sources are changed. Does not reload node definitions.", json_schema_extra=Categories.Development)
profile_graphs : bool = Field(default=False, description="Enable graph profiling using `cProfile`.", json_schema_extra=Categories.Development)
dev_reload : bool = Field(default=False, description="Automatically reload when Python sources are changed.", json_schema_extra=Categories.Development)
profile_graphs : bool = Field(default=False, description="Enable graph profiling", json_schema_extra=Categories.Development)
profile_prefix : Optional[str] = Field(default=None, description="An optional prefix for profile output files.", json_schema_extra=Categories.Development)
profiles_dir : Path = Field(default=Path('profiles'), description="Path to profiles output directory.", json_schema_extra=Categories.Development)
profiles_dir : Path = Field(default=Path('profiles'), description="Directory for graph profiles", json_schema_extra=Categories.Development)
skip_model_hash : bool = Field(default=False, description="Skip model hashing, instead assigning a UUID to models. Useful when using a memory db to reduce startup time.", json_schema_extra=Categories.Development)
version : bool = Field(default=False, description="CLI arg - show InvokeAI version and exit.", json_schema_extra=Categories.CLIArgs)
version : bool = Field(default=False, description="Show InvokeAI version and exit", json_schema_extra=Categories.Other)
# CACHE
ram : float = Field(default=DEFAULT_RAM_CACHE, gt=0, description="Maximum memory amount used by memory model cache for rapid switching (GB).", json_schema_extra=Categories.ModelCache, )
vram : float = Field(default=DEFAULT_VRAM_CACHE, ge=0, description="Amount of VRAM reserved for model storage (GB)", json_schema_extra=Categories.ModelCache, )
ram : float = Field(default=DEFAULT_RAM_CACHE, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number, GB)", json_schema_extra=Categories.ModelCache, )
vram : float = Field(default=DEFAULT_VRAM_CACHE, ge=0, description="Amount of VRAM reserved for model storage (floating point number, GB)", json_schema_extra=Categories.ModelCache, )
convert_cache : float = Field(default=DEFAULT_CONVERT_CACHE, ge=0, description="Maximum size of on-disk converted models cache (GB)", json_schema_extra=Categories.ModelCache)
lazy_offload : bool = Field(default=True, description="Keep models in VRAM until their space is needed.", json_schema_extra=Categories.ModelCache, )
lazy_offload : bool = Field(default=True, description="Keep models in VRAM until their space is needed", json_schema_extra=Categories.ModelCache, )
log_memory_usage : bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.", json_schema_extra=Categories.ModelCache)
# DEVICE
device : Literal["auto", "cpu", "cuda", "cuda:1", "mps"] = Field(default="auto", description="Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.", json_schema_extra=Categories.Device)
precision : Literal["auto", "float16", "bfloat16", "float32", "autocast"] = Field(default="auto", description="Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.", json_schema_extra=Categories.Device)
device : Literal["auto", "cpu", "cuda", "cuda:1", "mps"] = Field(default="auto", description="Generation device", json_schema_extra=Categories.Device)
precision : Literal["auto", "float16", "bfloat16", "float32", "autocast"] = Field(default="auto", description="Floating point precision", json_schema_extra=Categories.Device)
# GENERATION
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.", json_schema_extra=Categories.Generation)
attention_type : Literal["auto", "normal", "xformers", "sliced", "torch-sdp"] = Field(default="auto", description="Attention type.", json_schema_extra=Categories.Generation)
attention_slice_size: Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8] = Field(default="auto", description='Slice size, valid when attention_type=="sliced".', json_schema_extra=Categories.Generation)
force_tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).", json_schema_extra=Categories.Generation)
png_compress_level : int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.", json_schema_extra=Categories.Generation)
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", json_schema_extra=Categories.Generation)
attention_type : Literal["auto", "normal", "xformers", "sliced", "torch-sdp"] = Field(default="auto", description="Attention type", json_schema_extra=Categories.Generation)
attention_slice_size: Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8] = Field(default="auto", description='Slice size, valid when attention_type=="sliced"', json_schema_extra=Categories.Generation)
force_tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", json_schema_extra=Categories.Generation)
png_compress_level : int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = fastest, largest filesize, 9 = slowest, smallest filesize", json_schema_extra=Categories.Generation)
# QUEUE
max_queue_size : int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.", json_schema_extra=Categories.Queue)
max_queue_size : int = Field(default=10000, gt=0, description="Maximum number of items in the session queue", json_schema_extra=Categories.Queue)
# NODES
allow_nodes : Optional[List[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.", json_schema_extra=Categories.Nodes)
deny_nodes : Optional[List[str]] = Field(default=None, description="List of nodes to deny. Omit to deny none.", json_schema_extra=Categories.Nodes)
node_cache_size : int = Field(default=512, description="How many cached nodes to keep in memory.", json_schema_extra=Categories.Nodes)
node_cache_size : int = Field(default=512, description="How many cached nodes to keep in memory", json_schema_extra=Categories.Nodes)
# MODEL INSTALL
hashing_algorithm : HASHING_ALGORITHMS = Field(default="blake3", description="Model hashing algorthim for model installs. 'blake3' is best for SSDs. 'blake3_single' is best for spinning disk HDDs. 'random' disables hashing, instead assigning a UUID to models. Useful when using a memory db to reduce model installation time, or if you don't care about storing stable hashes for models. Alternatively, any other hashlib algorithm is accepted, though these are not nearly as performant as blake3.", json_schema_extra=Categories.ModelInstall)
remote_api_tokens : Optional[list[URLRegexToken]] = Field(
default=None,
description="List of regular expression and token pairs used when downloading models from URLs. The download URL is tested against the regex, and if it matches, the token is provided in as a Bearer token.",
json_schema_extra=Categories.ModelInstall
)
# MODEL IMPORT
civitai_api_key : Optional[str] = Field(default=os.environ.get("CIVITAI_API_KEY"), description="API key for CivitAI", json_schema_extra=Categories.Other)
# TODO(psyche): Can we just remove these then?
# DEPRECATED FIELDS - STILL HERE IN ORDER TO OBTAN VALUES FROM PRE-3.1 CONFIG FILES
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", json_schema_extra=Categories.Deprecated)
max_cache_size : Optional[float] = Field(default=None, gt=0, description="Maximum memory amount used by model cache for rapid switching", json_schema_extra=Categories.Deprecated)
max_vram_cache_size : Optional[float] = Field(default=None, ge=0, description="Amount of VRAM reserved for model storage", json_schema_extra=Categories.Deprecated)
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", json_schema_extra=Categories.Deprecated)
tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", json_schema_extra=Categories.Deprecated)
lora_dir : Optional[Path] = Field(default=None, description='Path to a directory of LoRA/LyCORIS models to be imported on startup.', json_schema_extra=Categories.Deprecated)
embedding_dir : Optional[Path] = Field(default=None, description='Path to a directory of Textual Inversion embeddings to be imported on startup.', json_schema_extra=Categories.Deprecated)
controlnet_dir : Optional[Path] = Field(default=None, description='Path to a directory of ControlNet embeddings to be imported on startup.', json_schema_extra=Categories.Deprecated)
conf_path : Path = Field(default=Path('configs/models.yaml'), description='Path to models definition file', json_schema_extra=Categories.Deprecated)
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", json_schema_extra=Categories.MemoryPerformance)
max_cache_size : Optional[float] = Field(default=None, gt=0, description="Maximum memory amount used by model cache for rapid switching", json_schema_extra=Categories.MemoryPerformance)
max_vram_cache_size : Optional[float] = Field(default=None, ge=0, description="Amount of VRAM reserved for model storage", json_schema_extra=Categories.MemoryPerformance)
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", json_schema_extra=Categories.MemoryPerformance)
tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", json_schema_extra=Categories.MemoryPerformance)
lora_dir : Optional[Path] = Field(default=None, description='Path to a directory of LoRA/LyCORIS models to be imported on startup.', json_schema_extra=Categories.Paths)
embedding_dir : Optional[Path] = Field(default=None, description='Path to a directory of Textual Inversion embeddings to be imported on startup.', json_schema_extra=Categories.Paths)
controlnet_dir : Optional[Path] = Field(default=None, description='Path to a directory of ControlNet embeddings to be imported on startup.', json_schema_extra=Categories.Paths)
conf_path : Path = Field(default=Path('configs/models.yaml'), description='Path to models definition file', json_schema_extra=Categories.Paths)
# this is not referred to in the source code and can be removed entirely
#free_gpu_mem : Optional[bool] = Field(default=None, description="If true, purge model from GPU after each generation.", json_schema_extra=Categories.MemoryPerformance)
@ -557,53 +477,6 @@ class InvokeAIAppConfig(InvokeAISettings):
"""Choose the runtime root directory when not specified on command line or init file."""
return _find_root()
@staticmethod
def generate_docstrings() -> str:
"""Helper function for mkdocs. Generates a docstring for the InvokeAIAppConfig class.
You shouldn't run this manually. Instead, run `scripts/update-config-docstring.py` to update the docstring.
A makefile target is also available: `make update-config-docstring`.
See that script for more information about why this is necessary.
"""
docstring = ' """Invoke App Configuration\n\n'
docstring += " Attributes:"
field_descriptions: dict[str, list[str]] = {}
for k, v in InvokeAIAppConfig.model_fields.items():
if not isinstance(v.json_schema_extra, dict):
# Should never happen
continue
category = v.json_schema_extra.get("category", None)
if not isinstance(category, str) or category == "Deprecated":
continue
if not field_descriptions.get(category):
field_descriptions[category] = []
field_descriptions[category].append(f" {k}: **{category}**: {v.description}")
for c in [
"Web Server",
"Features",
"Paths",
"Logging",
"Development",
"CLIArgs",
"Model Install",
"Model Cache",
"Device",
"Generation",
"Queue",
"Nodes",
]:
docstring += "\n"
docstring += "\n".join(field_descriptions[c])
docstring += '\n """'
return docstring
def get_invokeai_config(**kwargs: Any) -> InvokeAIAppConfig:
"""Legacy function which returns InvokeAIAppConfig.get_config()."""

View File

@ -12,7 +12,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
)
from invokeai.app.util.misc import get_timestamp
from invokeai.backend.model_manager import AnyModelConfig
from invokeai.backend.model_manager.config import SubModelType
class EventServiceBase:
@ -81,7 +80,7 @@ class EventServiceBase:
"graph_execution_state_id": graph_execution_state_id,
"node_id": node_id,
"source_node_id": source_node_id,
"progress_image": progress_image.model_dump(mode="json") if progress_image is not None else None,
"progress_image": progress_image.model_dump() if progress_image is not None else None,
"step": step,
"order": order,
"total_steps": total_steps,
@ -181,7 +180,6 @@ class EventServiceBase:
queue_batch_id: str,
graph_execution_state_id: str,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> None:
"""Emitted when a model is requested"""
self.__emit_queue_event(
@ -191,8 +189,7 @@ class EventServiceBase:
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"model_config": model_config.model_dump(mode="json"),
"submodel_type": submodel_type,
"model_config": model_config.model_dump(),
},
)
@ -203,7 +200,6 @@ class EventServiceBase:
queue_batch_id: str,
graph_execution_state_id: str,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> None:
"""Emitted when a model is correctly loaded (returns model info)"""
self.__emit_queue_event(
@ -213,8 +209,7 @@ class EventServiceBase:
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"model_config": model_config.model_dump(mode="json"),
"submodel_type": submodel_type,
"model_config": model_config.model_dump(),
},
)
@ -259,8 +254,8 @@ class EventServiceBase:
"started_at": str(session_queue_item.started_at) if session_queue_item.started_at else None,
"completed_at": str(session_queue_item.completed_at) if session_queue_item.completed_at else None,
},
"batch_status": batch_status.model_dump(mode="json"),
"queue_status": queue_status.model_dump(mode="json"),
"batch_status": batch_status.model_dump(),
"queue_status": queue_status.model_dump(),
},
)
@ -410,7 +405,7 @@ class EventServiceBase:
payload={"source": source, "total_bytes": total_bytes, "key": key, "id": id},
)
def emit_model_install_cancelled(self, source: str, id: int) -> None:
def emit_model_install_cancelled(self, source: str) -> None:
"""
Emit when an install job is cancelled.
@ -418,7 +413,7 @@ class EventServiceBase:
"""
self.__emit_model_event(
event_name="model_install_cancelled",
payload={"source": source, "id": id},
payload={"source": source},
)
def emit_model_install_error(self, source: str, error_type: str, error: str, id: int) -> None:

View File

@ -41,9 +41,8 @@ class InvocationCacheBase(ABC):
"""Clears the cache"""
pass
@staticmethod
@abstractmethod
def create_key(invocation: BaseInvocation) -> int:
def create_key(self, invocation: BaseInvocation) -> int:
"""Gets the key for the invocation's cache item"""
pass

View File

@ -61,7 +61,9 @@ class MemoryInvocationCache(InvocationCacheBase):
self._delete_oldest_access(number_to_delete)
self._cache[key] = CachedItem(
invocation_output,
invocation_output.model_dump_json(warnings=False, exclude_defaults=True, exclude_unset=True),
invocation_output.model_dump_json(
warnings=False, exclude_defaults=True, exclude_unset=True, include={"type"}
),
)
def _delete_oldest_access(self, number_to_delete: int) -> None:
@ -79,7 +81,7 @@ class MemoryInvocationCache(InvocationCacheBase):
with self._lock:
return self._delete(key)
def clear(self) -> None:
def clear(self, *args, **kwargs) -> None:
with self._lock:
if self._max_cache_size == 0:
return

View File

@ -25,7 +25,6 @@ if TYPE_CHECKING:
from .images.images_base import ImageServiceABC
from .invocation_cache.invocation_cache_base import InvocationCacheBase
from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase
from .model_images.model_images_base import ModelImageFileStorageBase
from .model_manager.model_manager_base import ModelManagerServiceBase
from .names.names_base import NameServiceBase
from .session_processor.session_processor_base import SessionProcessorBase
@ -50,7 +49,6 @@ class InvocationServices:
image_files: "ImageFileStorageBase",
image_records: "ImageRecordStorageBase",
logger: "Logger",
model_images: "ModelImageFileStorageBase",
model_manager: "ModelManagerServiceBase",
download_queue: "DownloadQueueServiceBase",
performance_statistics: "InvocationStatsServiceBase",
@ -74,7 +72,6 @@ class InvocationServices:
self.image_files = image_files
self.image_records = image_records
self.logger = logger
self.model_images = model_images
self.model_manager = model_manager
self.download_queue = download_queue
self.performance_statistics = performance_statistics

View File

@ -1,33 +0,0 @@
from abc import ABC, abstractmethod
from pathlib import Path
from PIL.Image import Image as PILImageType
class ModelImageFileStorageBase(ABC):
"""Low-level service responsible for storing and retrieving image files."""
@abstractmethod
def get(self, model_key: str) -> PILImageType:
"""Retrieves a model image as PIL Image."""
pass
@abstractmethod
def get_path(self, model_key: str) -> Path:
"""Gets the internal path to a model image."""
pass
@abstractmethod
def get_url(self, model_key: str) -> str | None:
"""Gets the URL to fetch a model image."""
pass
@abstractmethod
def save(self, image: PILImageType, model_key: str) -> None:
"""Saves a model image."""
pass
@abstractmethod
def delete(self, model_key: str) -> None:
"""Deletes a model image."""
pass

View File

@ -1,20 +0,0 @@
# TODO: Should these excpetions subclass existing python exceptions?
class ModelImageFileNotFoundException(Exception):
"""Raised when an image file is not found in storage."""
def __init__(self, message="Model image file not found"):
super().__init__(message)
class ModelImageFileSaveException(Exception):
"""Raised when an image cannot be saved."""
def __init__(self, message="Model image file not saved"):
super().__init__(message)
class ModelImageFileDeleteException(Exception):
"""Raised when an image cannot be deleted."""
def __init__(self, message="Model image file not deleted"):
super().__init__(message)

View File

@ -1,85 +0,0 @@
from pathlib import Path
from PIL import Image
from PIL.Image import Image as PILImageType
from send2trash import send2trash
from invokeai.app.services.invoker import Invoker
from invokeai.app.util.misc import uuid_string
from invokeai.app.util.thumbnails import make_thumbnail
from .model_images_base import ModelImageFileStorageBase
from .model_images_common import (
ModelImageFileDeleteException,
ModelImageFileNotFoundException,
ModelImageFileSaveException,
)
class ModelImageFileStorageDisk(ModelImageFileStorageBase):
"""Stores images on disk"""
def __init__(self, model_images_folder: Path):
self._model_images_folder = model_images_folder
self._validate_storage_folders()
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
def get(self, model_key: str) -> PILImageType:
try:
path = self.get_path(model_key)
if not self._validate_path(path):
raise ModelImageFileNotFoundException
return Image.open(path)
except FileNotFoundError as e:
raise ModelImageFileNotFoundException from e
def save(self, image: PILImageType, model_key: str) -> None:
try:
self._validate_storage_folders()
image_path = self._model_images_folder / (model_key + ".webp")
thumbnail = make_thumbnail(image, 256)
thumbnail.save(image_path, format="webp")
except Exception as e:
raise ModelImageFileSaveException from e
def get_path(self, model_key: str) -> Path:
path = self._model_images_folder / (model_key + ".webp")
return path
def get_url(self, model_key: str) -> str | None:
path = self.get_path(model_key)
if not self._validate_path(path):
return
url = self._invoker.services.urls.get_model_image_url(model_key)
# The image URL never changes, so we must add random query string to it to prevent caching
url += f"?{uuid_string()}"
return url
def delete(self, model_key: str) -> None:
try:
path = self.get_path(model_key)
if not self._validate_path(path):
raise ModelImageFileNotFoundException
send2trash(path)
except Exception as e:
raise ModelImageFileDeleteException from e
def _validate_path(self, path: Path) -> bool:
"""Validates the path given for an image."""
return path.exists()
def _validate_storage_folders(self) -> None:
"""Checks if the required folders exist and create them if they don't"""
self._model_images_folder.mkdir(parents=True, exist_ok=True)

View File

@ -1,6 +1,7 @@
"""Initialization file for model install service package."""
from .model_install_base import (
CivitaiModelSource,
HFModelSource,
InstallStatus,
LocalModelSource,
@ -22,4 +23,5 @@ __all__ = [
"LocalModelSource",
"HFModelSource",
"URLModelSource",
"CivitaiModelSource",
]

View File

@ -91,6 +91,21 @@ class LocalModelSource(StringLikeSource):
return Path(self.path).as_posix()
class CivitaiModelSource(StringLikeSource):
"""A Civitai version id, with optional variant and access token."""
version_id: int
variant: Optional[ModelRepoVariant] = None
access_token: Optional[str] = None
type: Literal["civitai"] = "civitai"
def __str__(self) -> str:
"""Return string version of repoid when string rep needed."""
base: str = str(self.version_id)
base += f" ({self.variant})" if self.variant else ""
return base
class HFModelSource(StringLikeSource):
"""
A HuggingFace repo_id with optional variant, sub-folder and access token.
@ -131,11 +146,14 @@ class URLModelSource(StringLikeSource):
return str(self.url)
ModelSource = Annotated[Union[LocalModelSource, HFModelSource, URLModelSource], Field(discriminator="type")]
ModelSource = Annotated[
Union[LocalModelSource, HFModelSource, CivitaiModelSource, URLModelSource], Field(discriminator="type")
]
MODEL_SOURCE_TO_TYPE_MAP = {
URLModelSource: ModelSourceType.Url,
HFModelSource: ModelSourceType.HFRepoID,
CivitaiModelSource: ModelSourceType.CivitAI,
LocalModelSource: ModelSourceType.Path,
}

View File

@ -12,7 +12,6 @@ from tempfile import mkdtemp
from typing import Any, Dict, List, Optional, Set, Union
from huggingface_hub import HfFolder
from omegaconf import DictConfig, OmegaConf
from pydantic.networks import AnyHttpUrl
from requests import Session
@ -22,6 +21,7 @@ from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
from invokeai.app.util.misc import uuid_string
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
@ -33,11 +33,12 @@ from invokeai.backend.model_manager.config import (
)
from invokeai.backend.model_manager.metadata import (
AnyModelRepoMetadata,
CivitaiMetadataFetch,
HuggingFaceMetadataFetch,
ModelMetadataWithFiles,
RemoteModelFile,
)
from invokeai.backend.model_manager.metadata.metadata_base import HuggingFaceMetadata
from invokeai.backend.model_manager.metadata.metadata_base import CivitaiMetadata, HuggingFaceMetadata
from invokeai.backend.model_manager.probe import ModelProbe
from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.util import Chdir, InvokeAILogger
@ -45,6 +46,7 @@ from invokeai.backend.util.devices import choose_precision, choose_torch_device
from .model_install_base import (
MODEL_SOURCE_TO_TYPE_MAP,
CivitaiModelSource,
HFModelSource,
InstallStatus,
LocalModelSource,
@ -115,7 +117,6 @@ class ModelInstallService(ModelInstallServiceBase):
raise Exception("Attempt to start the installer service twice")
self._start_installer_thread()
self._remove_dangling_install_dirs()
self._migrate_yaml()
self.sync_to_config()
def stop(self, invoker: Optional[Invoker] = None) -> None:
@ -133,14 +134,6 @@ class ModelInstallService(ModelInstallServiceBase):
self._download_cache.clear()
self._running = False
def _put_in_queue(self, job: ModelInstallJob) -> None:
print(f'DEBUG: in _put_in_queue(job={job.id})')
if self._stop_event.is_set():
self.cancel_job(job)
else:
print(f'DEBUG: putting {job.id} into the install queue')
self._install_queue.put(job)
def register_path(
self,
model_path: Union[Path, str],
@ -161,7 +154,10 @@ class ModelInstallService(ModelInstallServiceBase):
model_path = Path(model_path)
config = config or {}
info: AnyModelConfig = ModelProbe.probe(Path(model_path), config, hash_algo=self._app_config.hashing_algorithm)
if self._app_config.skip_model_hash:
config["hash"] = uuid_string()
info: AnyModelConfig = ModelProbe.probe(Path(model_path), config)
if preferred_name := config.get("name"):
preferred_name = Path(preferred_name).with_suffix(model_path.suffix)
@ -203,16 +199,9 @@ class ModelInstallService(ModelInstallServiceBase):
access_token=access_token,
)
elif re.match(r"^https?://[^/]+", source):
# Pull the token from config if it exists and matches the URL
_token = access_token
if _token is None:
for pair in self.app_config.remote_api_tokens or []:
if re.search(pair.url_regex, source):
_token = pair.token
break
source_obj = URLModelSource(
url=AnyHttpUrl(source),
access_token=_token,
access_token=access_token,
)
else:
raise ValueError(f"Unsupported model source: '{source}'")
@ -226,7 +215,9 @@ class ModelInstallService(ModelInstallServiceBase):
if isinstance(source, LocalModelSource):
install_job = self._import_local_model(source, config)
self._put_in_queue(install_job) # synchronously install
self._install_queue.put(install_job) # synchronously install
elif isinstance(source, CivitaiModelSource):
install_job = self._import_from_civitai(source, config)
elif isinstance(source, HFModelSource):
install_job = self._import_from_hf(source, config)
elif isinstance(source, URLModelSource):
@ -293,52 +284,10 @@ class ModelInstallService(ModelInstallServiceBase):
self._logger.info(f"{len(installed)} new models registered")
self._logger.info("Model installer (re)initialized")
def _migrate_yaml(self) -> None:
db_models = self.record_store.all_models()
try:
yaml = self._get_yaml()
except OSError:
return
yaml_metadata = yaml.pop("__metadata__")
yaml_version = yaml_metadata.get("version")
if yaml_version != "3.0.0":
raise ValueError(
f"Attempted migration of unsupported `models.yaml` v{yaml_version}. Only v3.0.0 is supported. Exiting."
)
self._logger.info(
f"Starting one-time migration of {len(yaml.items())} models from `models.yaml` to database. This may take a few minutes."
)
if len(db_models) == 0 and len(yaml.items()) != 0:
for model_key, stanza in yaml.items():
_, _, model_name = str(model_key).split("/")
model_path = Path(stanza["path"])
if not model_path.is_absolute():
model_path = self._app_config.models_path / model_path
model_path = model_path.resolve()
config: dict[str, Any] = {}
config["name"] = model_name
config["description"] = stanza.get("description")
config["config_path"] = stanza.get("config")
try:
id = self.register_path(model_path=model_path, config=config)
self._logger.info(f"Migrated {model_name} with id {id}")
except Exception as e:
self._logger.warning(f"Model at {model_path} could not be migrated: {e}")
# Rename `models.yaml` to `models.yaml.bak` to prevent re-migration
yaml_path = self._app_config.model_conf_path
yaml_path.rename(yaml_path.with_suffix(".yaml.bak"))
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
self._cached_model_paths = {Path(x.path).resolve() for x in self.record_store.all_models()}
self._cached_model_paths = {Path(x.path).absolute() for x in self.record_store.all_models()}
callback = self._scan_install if install else self._scan_register
search = ModelSearch(on_model_found=callback)
search = ModelSearch(on_model_found=callback, config=self._app_config)
self._models_installed.clear()
search.search(scan_dir)
return list(self._models_installed)
@ -350,7 +299,7 @@ class ModelInstallService(ModelInstallServiceBase):
"""Unregister the model. Delete its files only if they are within our models directory."""
model = self.record_store.get_model(key)
models_dir = self.app_config.models_path
model_path = models_dir / Path(model.path) # handle legacy relative model paths
model_path = models_dir / model.path
if model_path.is_relative_to(models_dir):
self.unconditionally_delete(key)
else:
@ -358,11 +307,11 @@ class ModelInstallService(ModelInstallServiceBase):
def unconditionally_delete(self, key: str) -> None: # noqa D102
model = self.record_store.get_model(key)
model_path = self.app_config.models_path / model.path
if model_path.is_dir():
rmtree(model_path)
path = self.app_config.models_path / model.path
if path.is_dir():
rmtree(path)
else:
model_path.unlink()
path.unlink()
self.unregister(key)
def download_and_cache(
@ -411,11 +360,10 @@ class ModelInstallService(ModelInstallServiceBase):
done = True
continue
try:
print(f'DEBUG: _install_next_item() checking for a job to install')
job = self._install_queue.get(timeout=1)
except Empty:
continue
print(f'DEBUG: _install_next_item() got job {job.id}, status={job.status}')
assert job.local_path is not None
try:
if job.cancelled:
@ -433,15 +381,16 @@ class ModelInstallService(ModelInstallServiceBase):
job.config_in["source"] = str(job.source)
job.config_in["source_type"] = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__]
# enter the metadata, if there is any
if isinstance(job.source_metadata, (HuggingFaceMetadata)):
if isinstance(job.source_metadata, (CivitaiMetadata, HuggingFaceMetadata)):
job.config_in["source_api_response"] = job.source_metadata.api_response
if isinstance(job.source_metadata, CivitaiMetadata) and job.source_metadata.trigger_phrases:
job.config_in["trigger_phrases"] = job.source_metadata.trigger_phrases
if job.inplace:
key = self.register_path(job.local_path, job.config_in)
else:
key = self.install_path(job.local_path, job.config_in)
job.config_out = self.record_store.get_model(key)
print(f'DEBUG: _install_next_item() signaling completion for job={job.id}, status={job.status}')
self._signal_job_completed(job)
except InvalidModelConfigException as excp:
@ -501,9 +450,7 @@ class ModelInstallService(ModelInstallServiceBase):
self._logger.info(f"Scanning {self._app_config.models_path} for new and orphaned models")
for cur_base_model in BaseModelType:
for cur_model_type in ModelType:
models_dir = self._app_config.models_path / Path(cur_base_model.value, cur_model_type.value)
if not models_dir.exists():
continue
models_dir = Path(cur_base_model.value, cur_model_type.value)
installed.update(self.scan_directory(models_dir))
self._logger.info(f"{len(installed)} new models registered; {len(defunct_models)} unregistered")
@ -522,20 +469,13 @@ class ModelInstallService(ModelInstallServiceBase):
old_path = Path(model.path)
models_dir = self.app_config.models_path
try:
old_path.relative_to(models_dir)
return model
except ValueError:
pass
new_path = models_dir / model.base.value / model.type.value / old_path.name
if old_path == new_path or new_path.exists() and old_path == new_path.resolve():
if not old_path.is_relative_to(models_dir):
return model
new_path = models_dir / model.base.value / model.type.value / model.name
self._logger.info(f"Moving {model.name} to {new_path}.")
new_path = self._move_model(old_path, new_path)
model.path = new_path.as_posix()
model.path = new_path.relative_to(models_dir).as_posix()
self.record_store.update_model(key, ModelRecordChanges(path=model.path))
return model
@ -593,16 +533,22 @@ class ModelInstallService(ModelInstallServiceBase):
) -> str:
config = config or {}
info = info or ModelProbe.probe(model_path, config, hash_algo=self._app_config.hashing_algorithm)
if self._app_config.skip_model_hash:
config["hash"] = uuid_string()
model_path = model_path.resolve()
info = info or ModelProbe.probe(model_path, config)
model_path = model_path.absolute()
if model_path.is_relative_to(self.app_config.models_path):
model_path = model_path.relative_to(self.app_config.models_path)
info.path = model_path.as_posix()
# add 'main' specific fields
if isinstance(info, CheckpointConfigBase):
# make config relative to our root
legacy_conf = (self.app_config.root_dir / self.app_config.legacy_conf_dir / info.config_path).resolve()
info.config_path = legacy_conf.as_posix()
info.config_path = legacy_conf.relative_to(self.app_config.root_dir).as_posix()
self.record_store.add_model(info)
return info.key
@ -612,16 +558,6 @@ class ModelInstallService(ModelInstallServiceBase):
self._next_job_id += 1
return id
# --------------------------------------------------------------------------------------------
# Internal functions that manage the old yaml config
# --------------------------------------------------------------------------------------------
def _get_yaml(self) -> DictConfig:
"""Fetch the models.yaml DictConfig for this installation."""
yaml_path = self._app_config.model_conf_path
omegaconf = OmegaConf.load(yaml_path)
assert isinstance(omegaconf, DictConfig)
return omegaconf
@staticmethod
def _guess_variant() -> Optional[ModelRepoVariant]:
"""Guess the best HuggingFace variant type to download."""
@ -637,6 +573,16 @@ class ModelInstallService(ModelInstallServiceBase):
inplace=source.inplace or False,
)
def _import_from_civitai(self, source: CivitaiModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
if not source.access_token:
self._logger.info("No Civitai access token provided; some models may not be downloadable.")
metadata = CivitaiMetadataFetch(self._session, self.app_config.get_config().civitai_api_key).from_id(
str(source.version_id)
)
assert isinstance(metadata, ModelMetadataWithFiles)
remote_files = metadata.download_urls(session=self._session)
return self._import_remote_model(source=source, config=config, metadata=metadata, remote_files=remote_files)
def _import_from_hf(self, source: HFModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
# Add user's cached access token to HuggingFace requests
source.access_token = source.access_token or HfFolder.get_token()
@ -659,7 +605,7 @@ class ModelInstallService(ModelInstallServiceBase):
)
def _import_from_url(self, source: URLModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
# URLs from HuggingFace will be handled specially
# URLs from Civitai or HuggingFace will be handled specially
metadata = None
fetcher = None
try:
@ -667,6 +613,8 @@ class ModelInstallService(ModelInstallServiceBase):
except ValueError:
pass
kwargs: dict[str, Any] = {"session": self._session}
if fetcher is CivitaiMetadataFetch:
kwargs["api_key"] = self._app_config.get_config().civitai_api_key
if fetcher is not None:
metadata = fetcher(**kwargs).from_url(source.url)
self._logger.debug(f"metadata={metadata}")
@ -683,7 +631,7 @@ class ModelInstallService(ModelInstallServiceBase):
def _import_remote_model(
self,
source: HFModelSource | URLModelSource,
source: HFModelSource | CivitaiModelSource | URLModelSource,
remote_files: List[RemoteModelFile],
metadata: Optional[AnyModelRepoMetadata],
config: Optional[Dict[str, Any]],
@ -791,16 +739,14 @@ class ModelInstallService(ModelInstallServiceBase):
def _download_complete_callback(self, download_job: DownloadJob) -> None:
self._logger.info(f"{download_job.source}: model download complete")
print(f'DEBUG: _download_complete_callback(download_job={download_job.source}')
with self._lock:
install_job = self._download_cache.pop(download_job.source, None)
print(f'DEBUG: download_job={download_job.source} / install_job={install_job}')
install_job = self._download_cache[download_job.source]
self._download_cache.pop(download_job.source, None)
# are there any more active jobs left in this task?
if install_job and install_job.downloading and all(x.complete for x in install_job.download_parts):
print(f'DEBUG: setting job {install_job.id} to DOWNLOADS_DONE')
if install_job.downloading and all(x.complete for x in install_job.download_parts):
install_job.status = InstallStatus.DOWNLOADS_DONE
print(f'DEBUG: putting {install_job.id} into the install queue')
self._put_in_queue(install_job)
self._install_queue.put(install_job)
# Let other threads know that the number of downloads has changed
self._downloads_changed_event.set()
@ -842,7 +788,7 @@ class ModelInstallService(ModelInstallServiceBase):
if all(x.in_terminal_state for x in install_job.download_parts):
# When all parts have reached their terminal state, we finalize the job to clean up the temporary directory and other resources
self._put_in_queue(install_job)
self._install_queue.put(install_job)
# ------------------------------------------------------------------------------------------------
# Internal methods that put events on the event bus
@ -899,10 +845,12 @@ class ModelInstallService(ModelInstallServiceBase):
def _signal_job_cancelled(self, job: ModelInstallJob) -> None:
self._logger.info(f"{job.source}: model installation was cancelled")
if self._event_bus:
self._event_bus.emit_model_install_cancelled(str(job.source), id=job.id)
self._event_bus.emit_model_install_cancelled(str(job.source))
@staticmethod
def get_fetcher_from_url(url: str):
if re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()):
if re.match(r"^https?://civitai.com/", url.lower()):
return CivitaiMetadataFetch
elif re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()):
return HuggingFaceMetadataFetch
raise ValueError(f"Unsupported model source: '{url}'")

View File

@ -68,7 +68,6 @@ class ModelLoadService(ModelLoadServiceBase):
self._emit_load_event(
context_data=context_data,
model_config=model_config,
submodel_type=submodel_type,
)
implementation, model_config, submodel_type = self._registry.get_implementation(model_config, submodel_type) # type: ignore
@ -83,7 +82,6 @@ class ModelLoadService(ModelLoadServiceBase):
self._emit_load_event(
context_data=context_data,
model_config=model_config,
submodel_type=submodel_type,
loaded=True,
)
return loaded_model
@ -93,7 +91,6 @@ class ModelLoadService(ModelLoadServiceBase):
context_data: InvocationContextData,
model_config: AnyModelConfig,
loaded: Optional[bool] = False,
submodel_type: Optional[SubModelType] = None,
) -> None:
if not self._invoker:
return
@ -105,7 +102,6 @@ class ModelLoadService(ModelLoadServiceBase):
queue_batch_id=context_data.queue_item.batch_id,
graph_execution_state_id=context_data.queue_item.session_id,
model_config=model_config,
submodel_type=submodel_type,
)
else:
self._invoker.services.events.emit_model_load_completed(
@ -114,5 +110,4 @@ class ModelLoadService(ModelLoadServiceBase):
queue_batch_id=context_data.queue_item.batch_id,
graph_execution_state_id=context_data.queue_item.session_id,
model_config=model_config,
submodel_type=submodel_type,
)

View File

@ -1,11 +1,15 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
from abc import ABC, abstractmethod
from typing import Optional
import torch
from typing_extensions import Self
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.invocation_context import InvocationContextData
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType
from invokeai.backend.model_manager.load.load_base import LoadedModel
from ..config import InvokeAIAppConfig
from ..download import DownloadQueueServiceBase
@ -66,3 +70,32 @@ class ModelManagerServiceBase(ABC):
@abstractmethod
def stop(self, invoker: Invoker) -> None:
pass
@abstractmethod
def load_model_by_config(
self,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
pass
@abstractmethod
def load_model_by_key(
self,
key: str,
submodel_type: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
pass
@abstractmethod
def load_model_by_attr(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
pass

View File

@ -1,10 +1,14 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
"""Implementation of ModelManagerServiceBase."""
from typing import Optional
import torch
from typing_extensions import Self
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.invocation_context import InvocationContextData
from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, LoadedModel, ModelType, SubModelType
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.logging import InvokeAILogger
@ -14,7 +18,7 @@ from ..download import DownloadQueueServiceBase
from ..events.events_base import EventServiceBase
from ..model_install import ModelInstallService, ModelInstallServiceBase
from ..model_load import ModelLoadService, ModelLoadServiceBase
from ..model_records import ModelRecordServiceBase
from ..model_records import ModelRecordServiceBase, UnknownModelException
from .model_manager_base import ModelManagerServiceBase
@ -60,6 +64,56 @@ class ModelManagerService(ModelManagerServiceBase):
if hasattr(service, "stop"):
service.stop(invoker)
def load_model_by_config(
self,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
return self.load.load_model(model_config, submodel_type, context_data)
def load_model_by_key(
self,
key: str,
submodel_type: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
config = self.store.get_model(key)
return self.load.load_model(config, submodel_type, context_data)
def load_model_by_attr(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
"""
Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object.
This is provided for API compatability with the get_model() method
in the original model manager. However, note that LoadedModel is
not the same as the original ModelInfo that ws returned.
:param model_name: Name of to be fetched.
:param base_model: Base model
:param model_type: Type of the model
:param submodel: For main (pipeline models), the submodel to fetch
:param context: The invocation context.
Exceptions: UnknownModelException -- model with this key not known
NotImplementedException -- a model loader was not provided at initialization time
ValueError -- more than one model matches this combination
"""
configs = self.store.search_by_attr(model_name, base_model, model_type)
if len(configs) == 0:
raise UnknownModelException(f"{base_model}/{model_type}/{model_name}: Unknown model")
elif len(configs) > 1:
raise ValueError(f"{base_model}/{model_type}/{model_name}: More than one model matches.")
else:
return self.load.load_model(configs[0], submodel, context_data)
@classmethod
def build_model_manager(
cls,

View File

@ -18,12 +18,7 @@ from invokeai.backend.model_manager import (
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.config import (
ControlAdapterDefaultSettings,
MainModelDefaultSettings,
ModelVariantType,
SchedulerPredictionType,
)
from invokeai.backend.model_manager.config import ModelDefaultSettings, ModelVariantType, SchedulerPredictionType
class DuplicateModelException(Exception):
@ -73,7 +68,7 @@ class ModelRecordChanges(BaseModelExcludeNull):
description: Optional[str] = Field(description="Model description", default=None)
base: Optional[BaseModelType] = Field(description="The base model.", default=None)
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
default_settings: Optional[MainModelDefaultSettings | ControlAdapterDefaultSettings] = Field(
default_settings: Optional[ModelDefaultSettings] = Field(
description="Default settings for this model", default=None
)
@ -84,7 +79,6 @@ class ModelRecordChanges(BaseModelExcludeNull):
description="The prediction type of the model.", default=None
)
upcast_attention: Optional[bool] = Field(description="Whether to upcast attention.", default=None)
config_path: Optional[str] = Field(description="Path to config file for model", default=None)
class ModelRecordServiceBase(ABC):
@ -135,17 +129,6 @@ class ModelRecordServiceBase(ABC):
"""
pass
@abstractmethod
def get_model_by_hash(self, hash: str) -> AnyModelConfig:
"""
Retrieve the configuration for the indicated model.
:param hash: Hash of model config to be fetched.
Exceptions: UnknownModelException
"""
pass
@abstractmethod
def list_models(
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default

View File

@ -203,21 +203,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
return model
def get_model_by_hash(self, hash: str) -> AnyModelConfig:
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE hash=?;
""",
(hash,),
)
rows = self._cursor.fetchone()
if not rows:
raise UnknownModelException("model not found")
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
return model
def exists(self, key: str) -> bool:
"""
Return True if a model with the indicated key exists in the databse.
@ -242,7 +227,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None,
model_format: Optional[ModelFormat] = None,
order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default,
) -> List[AnyModelConfig]:
"""
Return models matching name, base and/or type.
@ -251,21 +235,10 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
:param base_model: Filter by base model (optional)
:param model_type: Filter by type of model (optional)
:param model_format: Filter by model format (e.g. "diffusers") (optional)
:param order_by: Result order
If none of the optional filters are passed, will return all
models in the database.
"""
assert isinstance(order_by, ModelRecordOrderBy)
ordering = {
ModelRecordOrderBy.Default: "type, base, name, format",
ModelRecordOrderBy.Type: "type",
ModelRecordOrderBy.Base: "base",
ModelRecordOrderBy.Name: "name",
ModelRecordOrderBy.Format: "format",
}
where_clause: list[str] = []
bindings: list[str] = []
if model_name:
@ -284,10 +257,8 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
with self._db.lock:
self._cursor.execute(
f"""--sql
SELECT config, strftime('%s',updated_at)
FROM models
{where}
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason;
SELECT config, strftime('%s',updated_at) FROM models
{where};
""",
tuple(bindings),
)
@ -333,7 +304,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
"""Return a paginated summary listing of each model in the database."""
assert isinstance(order_by, ModelRecordOrderBy)
ordering = {
ModelRecordOrderBy.Default: "type, base, name, format",
ModelRecordOrderBy.Default: "type, base, format, name",
ModelRecordOrderBy.Type: "type",
ModelRecordOrderBy.Base: "base",
ModelRecordOrderBy.Name: "name",

View File

@ -1,6 +1,35 @@
from abc import ABC, abstractmethod
from threading import Event
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
class SessionRunnerBase(ABC):
"""
Base class for session runner.
"""
@abstractmethod
def start(self, services: InvocationServices, cancel_event: Event) -> None:
"""Starts the session runner"""
pass
@abstractmethod
def run(self, queue_item: SessionQueueItem) -> None:
"""Runs the session"""
pass
@abstractmethod
def complete(self, queue_item: SessionQueueItem) -> None:
"""Completes the session"""
pass
@abstractmethod
def run_node(self, node_id: str, queue_item: SessionQueueItem) -> None:
"""Runs an already prepared node on the session"""
pass
class SessionProcessorBase(ABC):

View File

@ -2,13 +2,14 @@ import traceback
from contextlib import suppress
from threading import BoundedSemaphore, Thread
from threading import Event as ThreadEvent
from typing import Optional
from typing import Callable, Optional, Union
from fastapi_events.handlers.local import local_handler
from fastapi_events.typing import Event as FastAPIEvent
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError
from invokeai.app.services.session_processor.session_processor_common import CanceledException
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
@ -16,15 +17,164 @@ from invokeai.app.services.shared.invocation_context import InvocationContextDat
from invokeai.app.util.profiler import Profiler
from ..invoker import Invoker
from .session_processor_base import SessionProcessorBase
from .session_processor_base import SessionProcessorBase, SessionRunnerBase
from .session_processor_common import SessionProcessorStatus
class DefaultSessionRunner(SessionRunnerBase):
"""Processes a single session's invocations"""
def __init__(
self,
on_before_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
on_after_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
):
self.on_before_run_node = on_before_run_node
self.on_after_run_node = on_after_run_node
def start(self, services: InvocationServices, cancel_event: ThreadEvent):
"""Start the session runner"""
self.services = services
self.cancel_event = cancel_event
def run(self, queue_item: SessionQueueItem):
"""Run the graph"""
if not queue_item.session:
raise ValueError("Queue item has no session")
# Loop over invocations until the session is complete or canceled
while not (queue_item.session.is_complete() or self.cancel_event.is_set()):
# Prepare the next node
invocation = queue_item.session.next()
if invocation is None:
# If there are no more invocations, complete the graph
break
# Build invocation context (the node-facing API
self.run_node(invocation.id, queue_item)
self.complete(queue_item)
def complete(self, queue_item: SessionQueueItem):
"""Complete the graph"""
self.services.events.emit_graph_execution_complete(
queue_batch_id=queue_item.batch_id,
queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id,
graph_execution_state_id=queue_item.session.id,
)
def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
"""Run before a node is executed"""
# Send starting event
self.services.events.emit_invocation_started(
queue_batch_id=queue_item.batch_id,
queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id,
graph_execution_state_id=queue_item.session_id,
node=invocation.model_dump(),
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
)
if self.on_before_run_node is not None:
self.on_before_run_node(invocation, queue_item)
def _on_after_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
"""Run after a node is executed"""
if self.on_after_run_node is not None:
self.on_after_run_node(invocation, queue_item)
def run_node(self, node_id: str, queue_item: SessionQueueItem):
"""Run a single node in the graph"""
# If this error raises a NodeNotFoundError that's handled by the processor
invocation = queue_item.session.execution_graph.get_node(node_id)
try:
self._on_before_run_node(invocation, queue_item)
data = InvocationContextData(
invocation=invocation,
source_invocation_id=queue_item.session.prepared_source_mapping[invocation.id],
queue_item=queue_item,
)
# Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
with self.services.performance_statistics.collect_stats(invocation, queue_item.session_id):
context = build_invocation_context(
data=data,
services=self.services,
cancel_event=self.cancel_event,
)
# Invoke the node
outputs = invocation.invoke_internal(context=context, services=self.services)
# Save outputs and history
queue_item.session.complete(invocation.id, outputs)
self._on_after_run_node(invocation, queue_item)
# Send complete event on successful runs
self.services.events.emit_invocation_complete(
queue_batch_id=queue_item.batch_id,
queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id,
graph_execution_state_id=queue_item.session.id,
node=invocation.model_dump(),
source_node_id=data.source_invocation_id,
result=outputs.model_dump(),
)
except KeyboardInterrupt:
# TODO(MM2): Create an event for this
pass
except CanceledException:
# When the user cancels the graph, we first set the cancel event. The event is checked
# between invocations, in this loop. Some invocations are long-running, and we need to
# be able to cancel them mid-execution.
#
# For example, denoising is a long-running invocation with many steps. A step callback
# is executed after each step. This step callback checks if the canceled event is set,
# then raises a CanceledException to stop execution immediately.
#
# When we get a CanceledException, we don't need to do anything - just pass and let the
# loop go to its next iteration, and the cancel event will be handled correctly.
pass
except Exception as e:
error = traceback.format_exc()
# Save error
queue_item.session.set_node_error(invocation.id, error)
self.services.logger.error(
f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}):\n{e}"
)
self.services.logger.error(error)
# Send error event
self.services.events.emit_invocation_error(
queue_batch_id=queue_item.session_id,
queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id,
graph_execution_state_id=queue_item.session.id,
node=invocation.model_dump(),
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
error_type=e.__class__.__name__,
error=error,
)
class DefaultSessionProcessor(SessionProcessorBase):
def start(self, invoker: Invoker, thread_limit: int = 1, polling_interval: int = 1) -> None:
"""Processes sessions from the session queue"""
def __init__(self, session_runner: Union[SessionRunnerBase, None] = None) -> None:
super().__init__()
self.session_runner = session_runner if session_runner else DefaultSessionRunner()
def start(
self,
invoker: Invoker,
thread_limit: int = 1,
polling_interval: int = 1,
on_before_run_session: Union[Callable[[SessionQueueItem], bool], None] = None,
on_after_run_session: Union[Callable[[SessionQueueItem], bool], None] = None,
) -> None:
self._invoker: Invoker = invoker
self._queue_item: Optional[SessionQueueItem] = None
self._invocation: Optional[BaseInvocation] = None
self.on_before_run_session = on_before_run_session
self.on_after_run_session = on_after_run_session
self._resume_event = ThreadEvent()
self._stop_event = ThreadEvent()
@ -59,6 +209,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
"cancel_event": self._cancel_event,
},
)
self.session_runner.start(services=invoker.services, cancel_event=self._cancel_event)
self._thread.start()
def stop(self, *args, **kwargs) -> None:
@ -117,131 +268,34 @@ class DefaultSessionProcessor(SessionProcessorBase):
self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
cancel_event.clear()
# If we have a on_before_run_session callback, call it
if self.on_before_run_session is not None:
self.on_before_run_session(self._queue_item)
# If profiling is enabled, start the profiler
if self._profiler is not None:
self._profiler.start(profile_id=self._queue_item.session_id)
# Prepare invocations and take the first
self._invocation = self._queue_item.session.next()
# Run the graph
self.session_runner.run(queue_item=self._queue_item)
# Loop over invocations until the session is complete or canceled
while self._invocation is not None and not cancel_event.is_set():
# get the source node id to provide to clients (the prepared node id is not as useful)
source_invocation_id = self._queue_item.session.prepared_source_mapping[self._invocation.id]
# Send starting event
self._invoker.services.events.emit_invocation_started(
queue_batch_id=self._queue_item.batch_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session_id,
node=self._invocation.model_dump(),
source_node_id=source_invocation_id,
# If we are profiling, stop the profiler and dump the profile & stats
if self._profiler:
profile_path = self._profiler.stop()
stats_path = profile_path.with_suffix(".json")
self._invoker.services.performance_statistics.dump_stats(
graph_execution_state_id=self._queue_item.session.id, output_path=stats_path
)
# Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
try:
with self._invoker.services.performance_statistics.collect_stats(
self._invocation, self._queue_item.session.id
):
# Build invocation context (the node-facing API)
data = InvocationContextData(
invocation=self._invocation,
source_invocation_id=source_invocation_id,
queue_item=self._queue_item,
)
context = build_invocation_context(
data=data,
services=self._invoker.services,
cancel_event=self._cancel_event,
)
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
# we don't care about that - suppress the error.
with suppress(GESStatsNotFoundError):
self._invoker.services.performance_statistics.log_stats(self._queue_item.session.id)
self._invoker.services.performance_statistics.reset_stats()
# Invoke the node
outputs = self._invocation.invoke_internal(
context=context, services=self._invoker.services
)
# Save outputs and history
self._queue_item.session.complete(self._invocation.id, outputs)
# Send complete event
self._invoker.services.events.emit_invocation_complete(
queue_batch_id=self._queue_item.batch_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session.id,
node=self._invocation.model_dump(),
source_node_id=source_invocation_id,
result=outputs.model_dump(),
)
except KeyboardInterrupt:
# TODO(MM2): Create an event for this
pass
except CanceledException:
# When the user cancels the graph, we first set the cancel event. The event is checked
# between invocations, in this loop. Some invocations are long-running, and we need to
# be able to cancel them mid-execution.
#
# For example, denoising is a long-running invocation with many steps. A step callback
# is executed after each step. This step callback checks if the canceled event is set,
# then raises a CanceledException to stop execution immediately.
#
# When we get a CanceledException, we don't need to do anything - just pass and let the
# loop go to its next iteration, and the cancel event will be handled correctly.
pass
except Exception as e:
error = traceback.format_exc()
# Save error
self._queue_item.session.set_node_error(self._invocation.id, error)
self._invoker.services.logger.error(
f"Error while invoking session {self._queue_item.session_id}, invocation {self._invocation.id} ({self._invocation.get_type()}):\n{e}"
)
self._invoker.services.logger.error(error)
# Send error event
self._invoker.services.events.emit_invocation_error(
queue_batch_id=self._queue_item.session_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session.id,
node=self._invocation.model_dump(),
source_node_id=source_invocation_id,
error_type=e.__class__.__name__,
error=error,
)
pass
# The session is complete if the all invocations are complete or there was an error
if self._queue_item.session.is_complete() or cancel_event.is_set():
# Send complete event
self._invoker.services.events.emit_graph_execution_complete(
queue_batch_id=self._queue_item.batch_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session.id,
)
# If we are profiling, stop the profiler and dump the profile & stats
if self._profiler:
profile_path = self._profiler.stop()
stats_path = profile_path.with_suffix(".json")
self._invoker.services.performance_statistics.dump_stats(
graph_execution_state_id=self._queue_item.session.id, output_path=stats_path
)
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
# we don't care about that - suppress the error.
with suppress(GESStatsNotFoundError):
self._invoker.services.performance_statistics.log_stats(self._queue_item.session.id)
self._invoker.services.performance_statistics.reset_stats()
# Set the invocation to None to prepare for the next session
self._invocation = None
else:
# Prepare the next invocation
self._invocation = self._queue_item.session.next()
# If we have a on_after_run_session callback, call it
if self.on_after_run_session is not None:
self.on_after_run_session(self._queue_item)
# The session is complete, immediately poll for next session
self._queue_item = None
@ -275,3 +329,4 @@ class DefaultSessionProcessor(SessionProcessorBase):
poll_now_event.clear()
self._queue_item = None
self._thread_semaphore.release()
self._invoker.services.logger.debug("Session processor stopped")

View File

@ -1,7 +1,7 @@
import threading
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING, Optional
from PIL.Image import Image
from torch import Tensor
@ -13,16 +13,15 @@ from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.model_records.model_records_base import UnknownModelException
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.model_manager.metadata.metadata_base import AnyModelRepoMetadata
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
if TYPE_CHECKING:
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
"""
@ -300,27 +299,22 @@ class ConditioningInterface(InvocationContextInterface):
class ModelsInterface(InvocationContextInterface):
def exists(self, identifier: Union[str, "ModelIdentifierField"]) -> bool:
def exists(self, key: str) -> bool:
"""Checks if a model exists.
Args:
identifier: The key or ModelField representing the model.
key: The key of the model.
Returns:
True if the model exists, False if not.
"""
if isinstance(identifier, str):
return self._services.model_manager.store.exists(identifier)
return self._services.model_manager.store.exists(key)
return self._services.model_manager.store.exists(identifier.key)
def load(
self, identifier: Union[str, "ModelIdentifierField"], submodel_type: Optional[SubModelType] = None
) -> LoadedModel:
def load(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""Loads a model.
Args:
identifier: The key or ModelField representing the model.
key: The key of the model.
submodel_type: The submodel of the model to get.
Returns:
@ -330,13 +324,9 @@ class ModelsInterface(InvocationContextInterface):
# The model manager emits events as it loads the model. It needs the context data to build
# the event payloads.
if isinstance(identifier, str):
model = self._services.model_manager.store.get_model(identifier)
return self._services.model_manager.load.load_model(model, submodel_type, self._data)
else:
_submodel_type = submodel_type or identifier.submodel_type
model = self._services.model_manager.store.get_model(identifier.key)
return self._services.model_manager.load.load_model(model, _submodel_type, self._data)
return self._services.model_manager.load_model_by_key(
key=key, submodel_type=submodel_type, context_data=self._data
)
def load_by_attrs(
self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
@ -353,29 +343,35 @@ class ModelsInterface(InvocationContextInterface):
Returns:
An object representing the loaded model.
"""
return self._services.model_manager.load_model_by_attr(
model_name=name,
base_model=base,
model_type=type,
submodel=submodel_type,
context_data=self._data,
)
configs = self._services.model_manager.store.search_by_attr(model_name=name, base_model=base, model_type=type)
if len(configs) == 0:
raise UnknownModelException(f"No model found with name {name}, base {base}, and type {type}")
if len(configs) > 1:
raise ValueError(f"More than one model found with name {name}, base {base}, and type {type}")
return self._services.model_manager.load.load_model(configs[0], submodel_type, self._data)
def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig:
def get_config(self, key: str) -> AnyModelConfig:
"""Gets a model's config.
Args:
identifier: The key or ModelField representing the model.
key: The key of the model.
Returns:
The model's config.
"""
if isinstance(identifier, str):
return self._services.model_manager.store.get_model(identifier)
return self._services.model_manager.store.get_model(key=key)
return self._services.model_manager.store.get_model(identifier.key)
def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
"""Gets a model's metadata, if it has any.
Args:
key: The key of the model.
Returns:
The model's metadata, if it has any.
"""
return self._services.model_manager.store.get_metadata(key=key)
def search_by_path(self, path: Path) -> list[AnyModelConfig]:
"""Searches for models by path.

View File

@ -4,6 +4,8 @@ from logging import Logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
from .util.migrate_yaml_config_1 import MigrateModelYamlToDb1
class Migration3Callback:
def __init__(self, app_config: InvokeAIAppConfig, logger: Logger) -> None:
@ -13,6 +15,7 @@ class Migration3Callback:
def __call__(self, cursor: sqlite3.Cursor) -> None:
self._drop_model_manager_metadata(cursor)
self._recreate_model_config(cursor)
self._migrate_model_config_records(cursor)
def _drop_model_manager_metadata(self, cursor: sqlite3.Cursor) -> None:
"""Drops the `model_manager_metadata` table."""
@ -52,6 +55,12 @@ class Migration3Callback:
"""
)
def _migrate_model_config_records(self, cursor: sqlite3.Cursor) -> None:
"""After updating the model config table, we repopulate it."""
self._logger.info("Migrating model config records from models.yaml to database")
model_record_migrator = MigrateModelYamlToDb1(self._app_config, self._logger, cursor)
model_record_migrator.migrate()
def build_migration_3(app_config: InvokeAIAppConfig, logger: Logger) -> Migration:
"""

View File

@ -0,0 +1,163 @@
# Copyright (c) 2023 Lincoln D. Stein
"""Migrate from the InvokeAI v2 models.yaml format to the v3 sqlite format."""
import json
import sqlite3
from logging import Logger
from pathlib import Path
from typing import Optional
from omegaconf import DictConfig, OmegaConf
from pydantic import TypeAdapter
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_records import (
DuplicateModelException,
UnknownModelException,
)
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
ModelConfigFactory,
ModelType,
)
from invokeai.backend.model_manager.hash import ModelHash
ModelsValidator = TypeAdapter(AnyModelConfig)
class MigrateModelYamlToDb1:
"""
Migrate the InvokeAI models.yaml format (VERSION 3.0.0) to SQL3 database format (VERSION 3.5.0).
The class has one externally useful method, migrate(), which scans the
currently models.yaml file and imports all its entries into invokeai.db.
Use this way:
from invokeai.backend.model_manager/migrate_to_db import MigrateModelYamlToDb
MigrateModelYamlToDb().migrate()
"""
config: InvokeAIAppConfig
logger: Logger
cursor: sqlite3.Cursor
def __init__(self, config: InvokeAIAppConfig, logger: Logger, cursor: sqlite3.Cursor = None) -> None:
self.config = config
self.logger = logger
self.cursor = cursor
def get_yaml(self) -> DictConfig:
"""Fetch the models.yaml DictConfig for this installation."""
yaml_path = self.config.model_conf_path
omegaconf = OmegaConf.load(yaml_path)
assert isinstance(omegaconf, DictConfig)
return omegaconf
def migrate(self) -> None:
"""Do the migration from models.yaml to invokeai.db."""
try:
yaml = self.get_yaml()
except OSError:
return
for model_key, stanza in yaml.items():
if model_key == "__metadata__":
assert (
stanza["version"] == "3.0.0"
), f"This script works on version 3.0.0 yaml files, but your configuration points to a {stanza['version']} version"
continue
base_type, model_type, model_name = str(model_key).split("/")
try:
hash = ModelHash().hash(self.config.models_path / stanza.path)
except OSError:
self.logger.warning(f"The model at {stanza.path} is not a valid file or directory. Skipping migration.")
continue
stanza["base"] = BaseModelType(base_type)
stanza["type"] = ModelType(model_type)
stanza["name"] = model_name
stanza["original_hash"] = hash
stanza["current_hash"] = hash
new_key = hash # deterministic key assignment
# special case for ip adapters, which need the new `image_encoder_model_id` field
if stanza["type"] == ModelType.IPAdapter:
try:
stanza["image_encoder_model_id"] = self._get_image_encoder_model_id(
self.config.models_path / stanza.path
)
except OSError:
self.logger.warning(f"Could not determine image encoder for {stanza.path}. Skipping.")
continue
new_config: AnyModelConfig = ModelsValidator.validate_python(stanza) # type: ignore # see https://github.com/pydantic/pydantic/discussions/7094
try:
if original_record := self._search_by_path(stanza.path):
key = original_record.key
self.logger.info(f"Updating model {model_name} with information from models.yaml using key {key}")
self._update_model(key, new_config)
else:
self.logger.info(f"Adding model {model_name} with key {new_key}")
self._add_model(new_key, new_config)
except DuplicateModelException:
self.logger.warning(f"Model {model_name} is already in the database")
except UnknownModelException:
self.logger.warning(f"Model at {stanza.path} could not be found in database")
def _search_by_path(self, path: Path) -> Optional[AnyModelConfig]:
self.cursor.execute(
"""--sql
SELECT config FROM model_config
WHERE path=?;
""",
(str(path),),
)
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self.cursor.fetchall()]
return results[0] if results else None
def _update_model(self, key: str, config: AnyModelConfig) -> None:
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect
json_serialized = record.model_dump_json() # and turn it into a json string.
self.cursor.execute(
"""--sql
UPDATE model_config
SET
config=?
WHERE id=?;
""",
(json_serialized, key),
)
if self.cursor.rowcount == 0:
raise UnknownModelException("model not found")
def _add_model(self, key: str, config: AnyModelConfig) -> None:
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect.
json_serialized = record.model_dump_json() # and turn it into a json string.
try:
self.cursor.execute(
"""--sql
INSERT INTO model_config (
id,
original_hash,
config
)
VALUES (?,?,?);
""",
(
key,
record.hash,
json_serialized,
),
)
except sqlite3.IntegrityError as exc:
raise DuplicateModelException(f"{record.name}: model is already in database") from exc
def _get_image_encoder_model_id(self, model_path: Path) -> str:
with open(model_path / "image_encoder.txt") as f:
encoder = f.read()
return encoder.strip()

View File

@ -17,7 +17,8 @@ class MigrateCallback(Protocol):
See :class:`Migration` for an example.
"""
def __call__(self, cursor: sqlite3.Cursor) -> None: ...
def __call__(self, cursor: sqlite3.Cursor) -> None:
...
class MigrationError(RuntimeError):

View File

@ -8,8 +8,3 @@ class UrlServiceBase(ABC):
def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
"""Gets the URL for an image or thumbnail."""
pass
@abstractmethod
def get_model_image_url(self, model_key: str) -> str:
"""Gets the URL for a model image"""
pass

View File

@ -4,9 +4,8 @@ from .urls_base import UrlServiceBase
class LocalUrlService(UrlServiceBase):
def __init__(self, base_url: str = "api/v1", base_url_v2: str = "api/v2"):
def __init__(self, base_url: str = "api/v1"):
self._base_url = base_url
self._base_url_v2 = base_url_v2
def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
image_basename = os.path.basename(image_name)
@ -16,6 +15,3 @@ class LocalUrlService(UrlServiceBase):
return f"{self._base_url}/images/i/{image_basename}/thumbnail"
return f"{self._base_url}/images/i/{image_basename}/full"
def get_model_image_url(self, model_key: str) -> str:
return f"{self._base_url_v2}/models/i/{model_key}/image"

View File

@ -22,7 +22,7 @@ def generate_ti_list(
for trigger in extract_ti_triggers_from_prompt(prompt):
name_or_key = trigger[1:-1]
try:
loaded_model = context.models.load(name_or_key)
loaded_model = context.models.load(key=name_or_key)
model = loaded_model.model
assert isinstance(model, TextualInversionModelRaw)
assert loaded_model.config.base == base

View File

@ -13,11 +13,9 @@ from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.backend.util.util import download_with_progress_bar
config = InvokeAIAppConfig.get_config()
logger = InvokeAILogger.get_logger(config=config)
DEPTH_ANYTHING_MODELS = {
"large": {
@ -56,9 +54,8 @@ class DepthAnythingDetector:
def __init__(self) -> None:
self.model = None
self.model_size: Union[Literal["large", "base", "small"], None] = None
self.device = choose_torch_device()
def load_model(self, model_size: Literal["large", "base", "small"] = "small"):
def load_model(self, model_size=Literal["large", "base", "small"]):
DEPTH_ANYTHING_MODEL_PATH = pathlib.Path(config.models_path / DEPTH_ANYTHING_MODELS[model_size]["local"])
if not DEPTH_ANYTHING_MODEL_PATH.exists():
download_with_progress_bar(DEPTH_ANYTHING_MODELS[model_size]["url"], DEPTH_ANYTHING_MODEL_PATH)
@ -74,6 +71,8 @@ class DepthAnythingDetector:
self.model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])
case "large":
self.model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
case _:
raise TypeError("Not a supported model")
self.model.load_state_dict(torch.load(DEPTH_ANYTHING_MODEL_PATH.as_posix(), map_location="cpu"))
self.model.eval()
@ -81,20 +80,20 @@ class DepthAnythingDetector:
self.model.to(choose_torch_device())
return self.model
def __call__(self, image: Image.Image, resolution: int = 512) -> Image.Image:
if not self.model:
logger.warn("DepthAnything model was not loaded. Returning original image")
return image
def to(self, device):
self.model.to(device)
return self
np_image = np.array(image, dtype=np.uint8)
np_image = np_image[:, :, ::-1] / 255.0
def __call__(self, image, resolution=512, offload=False):
image = np.array(image, dtype=np.uint8)
image = image[:, :, ::-1] / 255.0
image_height, image_width = np_image.shape[:2]
np_image = transform({"image": np_image})["image"]
tensor_image = torch.from_numpy(np_image).unsqueeze(0).to(choose_torch_device())
image_height, image_width = image.shape[:2]
image = transform({"image": image})["image"]
image = torch.from_numpy(image).unsqueeze(0).to(choose_torch_device())
with torch.no_grad():
depth = self.model(tensor_image)
depth = self.model(image)
depth = F.interpolate(depth[None], (image_height, image_width), mode="bilinear", align_corners=False)[0, 0]
depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
@ -104,4 +103,7 @@ class DepthAnythingDetector:
new_height = int(image_height * (resolution / image_width))
depth_map = depth_map.resize((resolution, new_height))
if offload:
del self.model
return depth_map

View File

@ -11,6 +11,17 @@ def check_invokeai_root(config: InvokeAIAppConfig):
try:
assert config.db_path.parent.exists(), f"{config.db_path.parent} not found"
assert config.models_path.exists(), f"{config.models_path} not found"
if not config.ignore_missing_core_models:
for model in [
"CLIP-ViT-bigG-14-laion2B-39B-b160k",
"bert-base-uncased",
"clip-vit-large-patch14",
"sd-vae-ft-mse",
"stable-diffusion-2-clip",
"stable-diffusion-safety-checker",
]:
path = config.models_path / f"core/convert/{model}"
assert path.exists(), f"{path} is missing"
except Exception as e:
print()
print(f"An exception has occurred: {str(e)}")
@ -21,5 +32,10 @@ def check_invokeai_root(config: InvokeAIAppConfig):
print(
'** From the command line, activate the virtual environment and run "invokeai-configure --yes --skip-sd-weights" **'
)
print(
'** (To skip this check completely, add "--ignore_missing_core_models" to your CLI args. Not installing '
"these core models will prevent the loading of some or all .safetensors and .ckpt files. However, you can "
"always come back and install these core models in the future.)"
)
input("Press any key to continue...")
sys.exit(0)

View File

@ -19,6 +19,7 @@ from invokeai.app.services.model_install import (
ModelInstallService,
ModelInstallServiceBase,
)
from invokeai.app.services.model_metadata import ModelMetadataStoreSQL
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
from invokeai.backend.model_manager import (
@ -38,7 +39,7 @@ def initialize_record_store(app_config: InvokeAIAppConfig) -> ModelRecordService
logger = InvokeAILogger.get_logger(config=app_config)
image_files = DiskImageFileStorage(f"{app_config.output_path}/images")
db = init_db(config=app_config, logger=logger, image_files=image_files)
obj: ModelRecordServiceBase = ModelRecordServiceSQL(db)
obj: ModelRecordServiceBase = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
return obj

View File

@ -17,7 +17,7 @@ import warnings
from argparse import Namespace
from enum import Enum
from pathlib import Path
from shutil import copy, get_terminal_size, move
from shutil import get_terminal_size
from typing import Any, Optional, Set, Tuple, Type, get_args, get_type_hints
from urllib import request
@ -25,20 +25,20 @@ import npyscreen
import psutil
import torch
import transformers
from diffusers import ModelMixin
from diffusers import AutoencoderKL, ModelMixin
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from huggingface_hub import HfFolder
from huggingface_hub import login as hf_hub_login
from omegaconf import DictConfig, OmegaConf
from pydantic.error_wrappers import ValidationError
from tqdm import tqdm
from transformers import AutoFeatureExtractor
from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
import invokeai.configs as configs
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.install.install_helper import InstallHelper, InstallSelections
from invokeai.backend.install.legacy_arg_parsing import legacy_parser
from invokeai.backend.model_manager import ModelType
from invokeai.backend.model_manager import BaseModelType, ModelType
from invokeai.backend.util import choose_precision, choose_torch_device
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.frontend.install.model_install import addModelsForm
@ -210,15 +210,51 @@ def download_with_progress_bar(model_url: str, model_dest: str, label: str = "th
print(traceback.format_exc(), file=sys.stderr)
def download_safety_checker():
def download_conversion_models():
target_dir = config.models_path / "core/convert"
kwargs = {} # for future use
try:
logger.info("Downloading core tokenizers and text encoders")
# bert
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning)
bert = BertTokenizerFast.from_pretrained("bert-base-uncased", **kwargs)
bert.save_pretrained(target_dir / "bert-base-uncased", safe_serialization=True)
# sd-1
repo_id = "openai/clip-vit-large-patch14"
hf_download_from_pretrained(CLIPTokenizer, repo_id, target_dir / "clip-vit-large-patch14")
hf_download_from_pretrained(CLIPTextModel, repo_id, target_dir / "clip-vit-large-patch14")
# sd-2
repo_id = "stabilityai/stable-diffusion-2"
pipeline = CLIPTokenizer.from_pretrained(repo_id, subfolder="tokenizer", **kwargs)
pipeline.save_pretrained(target_dir / "stable-diffusion-2-clip" / "tokenizer", safe_serialization=True)
pipeline = CLIPTextModel.from_pretrained(repo_id, subfolder="text_encoder", **kwargs)
pipeline.save_pretrained(target_dir / "stable-diffusion-2-clip" / "text_encoder", safe_serialization=True)
# sd-xl - tokenizer_2
repo_id = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
_, model_name = repo_id.split("/")
pipeline = CLIPTokenizer.from_pretrained(repo_id, **kwargs)
pipeline.save_pretrained(target_dir / model_name, safe_serialization=True)
pipeline = CLIPTextConfig.from_pretrained(repo_id, **kwargs)
pipeline.save_pretrained(target_dir / model_name, safe_serialization=True)
# VAE
logger.info("Downloading stable diffusion VAE")
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", **kwargs)
vae.save_pretrained(target_dir / "sd-vae-ft-mse", safe_serialization=True)
# safety checking
logger.info("Downloading safety checker")
repo_id = "CompVis/stable-diffusion-safety-checker"
pipeline = AutoFeatureExtractor.from_pretrained(repo_id, **kwargs)
pipeline.save_pretrained(target_dir / "stable-diffusion-safety-checker", safe_serialization=True)
pipeline = StableDiffusionSafetyChecker.from_pretrained(repo_id, **kwargs)
pipeline.save_pretrained(target_dir / "stable-diffusion-safety-checker", safe_serialization=True)
except KeyboardInterrupt:
@ -271,7 +307,7 @@ def download_lama():
def download_support_models() -> None:
download_realesrgan()
download_lama()
download_safety_checker()
download_conversion_models()
# -------------------------------------
@ -708,7 +744,12 @@ def initialize_rootdir(root: Path, yes_to_all: bool = False):
shutil.copytree(configs_src, configs_dest, dirs_exist_ok=True)
dest = root / "models"
dest.mkdir(parents=True, exist_ok=True)
for model_base in BaseModelType:
for model_type in ModelType:
path = dest / model_base.value / model_type.value
path.mkdir(parents=True, exist_ok=True)
path = dest / "core"
path.mkdir(parents=True, exist_ok=True)
# -------------------------------------
@ -888,10 +929,6 @@ def main() -> None:
errors = set()
FORCE_FULL_PRECISION = opt.full_precision # FIXME global
new_init_file = config.root_path / "invokeai.yaml"
backup_init_file = new_init_file.with_suffix(".bak")
if new_init_file.exists():
copy(new_init_file, backup_init_file)
try:
# if we do a root migration/upgrade, then we are keeping previous
@ -906,6 +943,7 @@ def main() -> None:
install_helper = InstallHelper(config, logger)
models_to_download = default_user_selections(opt, install_helper)
new_init_file = config.root_path / "invokeai.yaml"
if opt.yes_to_all:
write_default_options(opt, new_init_file)
@ -937,17 +975,8 @@ def main() -> None:
input("Press any key to continue...")
except WindowTooSmallException as e:
logger.error(str(e))
if backup_init_file.exists():
move(backup_init_file, new_init_file)
except KeyboardInterrupt:
print("\nGoodbye! Come back soon.")
if backup_init_file.exists():
move(backup_init_file, new_init_file)
except Exception:
print("An error occurred during installation.")
if backup_init_file.exists():
move(backup_init_file, new_init_file)
print(traceback.format_exc(), file=sys.stderr)
# -------------------------------------

View File

@ -22,7 +22,7 @@ Validation errors will raise an InvalidModelConfigException error.
import time
from enum import Enum
from typing import Literal, Optional, Type, TypeAlias, Union
from typing import Literal, Optional, Type, Union
import torch
from diffusers.models.modeling_utils import ModelMixin
@ -129,27 +129,16 @@ class ModelSourceType(str, Enum):
Path = "path"
Url = "url"
HFRepoID = "hf_repo_id"
CivitAI = "civitai"
DEFAULTS_PRECISION = Literal["fp16", "fp32"]
class MainModelDefaultSettings(BaseModel):
vae: str | None = Field(default=None, description="Default VAE for this model (model key)")
vae_precision: DEFAULTS_PRECISION | None = Field(default=None, description="Default VAE precision for this model")
scheduler: SCHEDULER_NAME_VALUES | None = Field(default=None, description="Default scheduler for this model")
steps: int | None = Field(default=None, gt=0, description="Default number of steps for this model")
cfg_scale: float | None = Field(default=None, ge=1, description="Default CFG Scale for this model")
cfg_rescale_multiplier: float | None = Field(
default=None, ge=0, lt=1, description="Default CFG Rescale Multiplier for this model"
)
width: int | None = Field(default=None, multiple_of=8, ge=64, description="Default width for this model")
height: int | None = Field(default=None, multiple_of=8, ge=64, description="Default height for this model")
class ControlAdapterDefaultSettings(BaseModel):
# This could be narrowed to controlnet processor nodes, but they change. Leaving this a string is safer.
preprocessor: str | None
class ModelDefaultSettings(BaseModel):
vae: str | None
vae_precision: str | None
scheduler: SCHEDULER_NAME_VALUES | None
steps: int | None
cfg_scale: float | None
cfg_rescale_multiplier: float | None
class ModelConfigBase(BaseModel):
@ -168,7 +157,10 @@ class ModelConfigBase(BaseModel):
source_api_response: Optional[str] = Field(
description="The original API response from the source, as stringified JSON.", default=None
)
cover_image: Optional[str] = Field(description="Url for image to preview model", default=None)
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
default_settings: Optional[ModelDefaultSettings] = Field(
description="Default settings for this model", default=None
)
@staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
@ -194,14 +186,10 @@ class DiffusersConfigBase(ModelConfigBase):
repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.Default
class LoRAConfigBase(ModelConfigBase):
type: Literal[ModelType.LoRA] = ModelType.LoRA
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
class LoRALyCORISConfig(LoRAConfigBase):
class LoRALyCORISConfig(ModelConfigBase):
"""Model config for LoRA/Lycoris models."""
type: Literal[ModelType.LoRA] = ModelType.LoRA
format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS
@staticmethod
@ -209,9 +197,10 @@ class LoRALyCORISConfig(LoRAConfigBase):
return Tag(f"{ModelType.LoRA.value}.{ModelFormat.LyCORIS.value}")
class LoRADiffusersConfig(LoRAConfigBase):
class LoRADiffusersConfig(ModelConfigBase):
"""Model config for LoRA/Diffusers models."""
type: Literal[ModelType.LoRA] = ModelType.LoRA
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@staticmethod
@ -241,13 +230,7 @@ class VAEDiffusersConfig(ModelConfigBase):
return Tag(f"{ModelType.VAE.value}.{ModelFormat.Diffusers.value}")
class ControlAdapterConfigBase(BaseModel):
default_settings: Optional[ControlAdapterDefaultSettings] = Field(
description="Default settings for this model", default=None
)
class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase):
class ControlNetDiffusersConfig(DiffusersConfigBase):
"""Model config for ControlNet models (diffusers version)."""
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
@ -258,7 +241,7 @@ class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase):
return Tag(f"{ModelType.ControlNet.value}.{ModelFormat.Diffusers.value}")
class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase):
class ControlNetCheckpointConfig(CheckpointConfigBase):
"""Model config for ControlNet models (diffusers version)."""
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
@ -291,17 +274,10 @@ class TextualInversionFolderConfig(ModelConfigBase):
return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFolder.value}")
class MainConfigBase(ModelConfigBase):
type: Literal[ModelType.Main] = ModelType.Main
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
default_settings: Optional[MainModelDefaultSettings] = Field(
description="Default settings for this model", default=None
)
class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase):
class MainCheckpointConfig(CheckpointConfigBase):
"""Model config for main checkpoint models."""
type: Literal[ModelType.Main] = ModelType.Main
variant: ModelVariantType = ModelVariantType.Normal
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False
@ -311,9 +287,11 @@ class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase):
return Tag(f"{ModelType.Main.value}.{ModelFormat.Checkpoint.value}")
class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase):
class MainDiffusersConfig(DiffusersConfigBase):
"""Model config for main diffusers models."""
type: Literal[ModelType.Main] = ModelType.Main
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.Main.value}.{ModelFormat.Diffusers.value}")
@ -332,7 +310,7 @@ class IPAdapterConfig(ModelConfigBase):
class CLIPVisionDiffusersConfig(ModelConfigBase):
"""Model config for CLIPVision."""
"""Model config for ClipVision."""
type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision
format: Literal[ModelFormat.Diffusers]
@ -342,7 +320,7 @@ class CLIPVisionDiffusersConfig(ModelConfigBase):
return Tag(f"{ModelType.CLIPVision.value}.{ModelFormat.Diffusers.value}")
class T2IAdapterConfig(ModelConfigBase, ControlAdapterConfigBase):
class T2IAdapterConfig(ModelConfigBase):
"""Model config for T2I."""
type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter
@ -394,7 +372,6 @@ AnyModelConfig = Annotated[
]
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
AnyDefaultSettings: TypeAlias = Union[MainModelDefaultSettings, ControlAdapterDefaultSettings]
class ModelConfigFactory(object):

File diff suppressed because it is too large Load Diff

View File

@ -1,4 +1,12 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Fast hashing of diffusers and checkpoint-style models.
Usage:
from invokeai.backend.model_managre.model_hash import FastModelHash
>>> FastModelHash.hash('/home/models/stable-diffusion-v1.5')
'a8e693a126ea5b831c96064dc569956f'
"""
import hashlib
import os
@ -7,9 +15,9 @@ from typing import Callable, Literal, Optional, Union
from blake3 import blake3
from invokeai.app.util.misc import uuid_string
MODEL_FILE_EXTENSIONS = (".ckpt", ".safetensors", ".bin", ".pt", ".pth")
HASHING_ALGORITHMS = Literal[
ALGORITHM = Literal[
"md5",
"sha1",
"sha224",
@ -25,15 +33,12 @@ HASHING_ALGORITHMS = Literal[
"shake_128",
"shake_256",
"blake3",
"blake3_single",
"random",
]
MODEL_FILE_EXTENSIONS = (".ckpt", ".safetensors", ".bin", ".pt", ".pth")
class ModelHash:
"""
Creates a hash of a model using a specified algorithm. The hash is prefixed by the algorithm used.
Creates a hash of a model using a specified algorithm.
Args:
algorithm: Hashing algorithm to use. Defaults to BLAKE3.
@ -48,29 +53,20 @@ class ModelHash:
The final hash is computed by hashing the hashes of all model files in the directory using BLAKE3, ensuring
that directory hashes are never weaker than the file hashes.
A convenience algorithm choice of "random" is also available, which returns a random string. This is not a hash.
Usage:
```py
# BLAKE3 hash
ModelHash().hash("path/to/some/model.safetensors") # "blake3:ce3f0c5f3c05d119f4a5dcaf209b50d3149046a0d3a9adee9fed4c83cad6b4d0"
ModelHash().hash("path/to/some/model.safetensors")
# MD5
ModelHash("md5").hash("path/to/model/dir/") # "md5:a0cd925fc063f98dbf029eee315060c3"
ModelHash("md5").hash("path/to/model/dir/")
```
"""
def __init__(
self, algorithm: HASHING_ALGORITHMS = "blake3", file_filter: Optional[Callable[[str], bool]] = None
) -> None:
self.algorithm: HASHING_ALGORITHMS = algorithm
def __init__(self, algorithm: ALGORITHM = "blake3", file_filter: Optional[Callable[[str], bool]] = None) -> None:
if algorithm == "blake3":
self._hash_file = self._blake3
elif algorithm == "blake3_single":
self._hash_file = self._blake3_single
elif algorithm in hashlib.algorithms_available:
self._hash_file = self._get_hashlib(algorithm)
elif algorithm == "random":
self._hash_file = self._random
else:
raise ValueError(f"Algorithm {algorithm} not available")
@ -91,12 +87,10 @@ class ModelHash:
"""
model_path = Path(model_path)
# blake3_single is a single-threaded version of blake3, prefix should still be "blake3:"
prefix = self._get_prefix(self.algorithm)
if model_path.is_file():
return prefix + self._hash_file(model_path)
return self._hash_file(model_path)
elif model_path.is_dir():
return prefix + self._hash_dir(model_path)
return self._hash_dir(model_path)
else:
raise OSError(f"Not a valid file or directory: {model_path}")
@ -120,7 +114,6 @@ class ModelHash:
composite_hasher = blake3()
for h in component_hashes:
composite_hasher.update(h.encode("utf-8"))
return composite_hasher.hexdigest()
@staticmethod
@ -144,7 +137,7 @@ class ModelHash:
@staticmethod
def _blake3(file_path: Path) -> str:
"""Hashes a file using BLAKE3, using parallelized and memory-mapped I/O to avoid reading the entire file into memory.
"""Hashes a file using BLAKE3
Args:
file_path: Path to the file to hash
@ -157,21 +150,7 @@ class ModelHash:
return file_hasher.hexdigest()
@staticmethod
def _blake3_single(file_path: Path) -> str:
"""Hashes a file using BLAKE3, without parallelism. Suitable for spinning hard drives.
Args:
file_path: Path to the file to hash
Returns:
Hexdigest of the hash of the file
"""
file_hasher = blake3()
file_hasher.update_mmap(file_path)
return file_hasher.hexdigest()
@staticmethod
def _get_hashlib(algorithm: HASHING_ALGORITHMS) -> Callable[[Path], str]:
def _get_hashlib(algorithm: ALGORITHM) -> Callable[[Path], str]:
"""Factory function that returns a function to hash a file with the given algorithm.
Args:
@ -193,13 +172,6 @@ class ModelHash:
return hashlib_hasher
@staticmethod
def _random(_file_path: Path) -> str:
"""Returns a random string. This is not a hash.
The string is a UUID, hashed with BLAKE3 to ensure that it is unique."""
return blake3(uuid_string().encode()).hexdigest()
@staticmethod
def _default_file_filter(file_path: str) -> bool:
"""A default file filter that only includes files with the following extensions: .ckpt, .safetensors, .bin, .pt, .pth
@ -211,9 +183,3 @@ class ModelHash:
True if the file matches the given extensions, otherwise False
"""
return file_path.endswith(MODEL_FILE_EXTENSIONS)
@staticmethod
def _get_prefix(algorithm: HASHING_ALGORITHMS) -> str:
"""Return the prefix for the given algorithm, e.g. \"blake3:\" or \"md5:\"."""
# blake3_single is a single-threaded version of blake3, prefix should still be "blake3:"
return "blake3:" if algorithm == "blake3_single" else f"{algorithm}:"

View File

@ -19,6 +19,7 @@ context. Use like this:
"""
import gc
import logging
import math
import sys
import time
@ -91,7 +92,8 @@ class ModelCache(ModelCacheBase[AnyModel]):
self._execution_device: torch.device = execution_device
self._storage_device: torch.device = storage_device
self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
self._log_memory_usage = log_memory_usage
self._log_memory_usage = log_memory_usage or self._logger.level == logging.DEBUG
# used for stats collection
self._stats: Optional[CacheStats] = None
self._cached_models: Dict[str, CacheRecord[AnyModel]] = {}

View File

@ -60,7 +60,7 @@ class ModelLoaderRegistryBase(ABC):
TModelLoader = TypeVar("TModelLoader", bound=ModelLoaderBase)
class ModelLoaderRegistry(ModelLoaderRegistryBase):
class ModelLoaderRegistry:
"""
This class allows model loaders to register their type, base and format.
"""

View File

@ -3,6 +3,9 @@
from pathlib import Path
import torch
from safetensors.torch import load_file as safetensors_load_file
from invokeai.backend.model_manager import (
AnyModelConfig,
BaseModelType,
@ -34,25 +37,27 @@ class ControlNetLoader(GenericDiffusersLoader):
return True
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
assert isinstance(config, CheckpointConfigBase)
config_file = config.config_path
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
raise Exception(f"ControlNet conversion not supported for model type: {config.base}")
else:
assert isinstance(config, CheckpointConfigBase)
config_file = config.config_path
image_size = (
512
if config.base == BaseModelType.StableDiffusion1
else 768
if config.base == BaseModelType.StableDiffusion2
else 1024
if model_path.suffix == ".safetensors":
checkpoint = safetensors_load_file(model_path, device="cpu")
else:
checkpoint = torch.load(model_path, map_location="cpu")
# sometimes weights are hidden under "state_dict", and sometimes not
if "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"]
convert_controlnet_to_diffusers(
model_path,
output_path,
original_config_file=self._app_config.root_path / config_file,
image_size=512,
scan_needed=True,
from_safetensors=model_path.suffix == ".safetensors",
)
self._logger.info(f"Converting {model_path} to diffusers format")
with open(self._app_config.root_path / config_file, "r") as config_stream:
convert_controlnet_to_diffusers(
model_path,
output_path,
original_config_file=config_stream,
image_size=image_size,
precision=self._torch_dtype,
from_safetensors=model_path.suffix == ".safetensors",
)
return output_path

View File

@ -24,7 +24,7 @@ from .. import ModelLoader, ModelLoaderRegistry
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.LyCORIS)
class LoRALoader(ModelLoader):
class LoraLoader(ModelLoader):
"""Class to load LoRA models."""
# We cheat a little bit to get access to the model base

View File

@ -4,6 +4,9 @@
from pathlib import Path
from typing import Optional
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
@ -11,7 +14,7 @@ from invokeai.backend.model_manager import (
ModelFormat,
ModelRepoVariant,
ModelType,
SchedulerPredictionType,
ModelVariantType,
SubModelType,
)
from invokeai.backend.model_manager.config import CheckpointConfigBase, MainCheckpointConfig
@ -65,31 +68,27 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
assert isinstance(config, MainCheckpointConfig)
variant = config.variant
base = config.base
pipeline_class = (
StableDiffusionInpaintPipeline if variant == ModelVariantType.Inpaint else StableDiffusionPipeline
)
config_file = config.config_path
prediction_type = config.prediction_type.value
upcast_attention = config.upcast_attention
image_size = (
1024
if base == BaseModelType.StableDiffusionXL
else 768
if config.prediction_type == SchedulerPredictionType.VPrediction and base == BaseModelType.StableDiffusion2
else 512
)
self._logger.info(f"Converting {model_path} to diffusers format")
convert_ckpt_to_diffusers(
model_path,
output_path,
model_type=self.model_base_to_model_type[base],
model_version=base,
model_variant=variant,
original_config_file=self._app_config.root_path / config_file,
extract_ema=True,
scan_needed=True,
pipeline_class=pipeline_class,
from_safetensors=model_path.suffix == ".safetensors",
precision=self._torch_dtype,
prediction_type=prediction_type,
image_size=image_size,
upcast_attention=upcast_attention,
load_safety_checker=False,
)
return output_path

View File

@ -23,7 +23,7 @@ from .generic_diffusers import GenericDiffusersLoader
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.VAE, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.VAE, format=ModelFormat.Checkpoint)
class VAELoader(GenericDiffusersLoader):
class VaeLoader(GenericDiffusersLoader):
"""Class to load VAE models."""
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
@ -57,12 +57,12 @@ class VAELoader(GenericDiffusersLoader):
ckpt_config = OmegaConf.load(self._app_config.root_path / config_file)
assert isinstance(ckpt_config, DictConfig)
self._logger.info(f"Converting {model_path} to diffusers format")
vae_model = convert_ldm_vae_to_diffusers(
checkpoint=checkpoint,
vae_config=ckpt_config,
image_size=512,
precision=self._torch_dtype,
)
vae_model.to(self._torch_dtype) # set precision appropriately
vae_model.save_pretrained(output_path, safe_serialization=True)
return output_path

View File

@ -8,19 +8,23 @@ from invokeai.backend.model_manager.metadata import(
CommercialUsage,
LicenseRestrictions,
HuggingFaceMetadata,
CivitaiMetadata,
)
from invokeai.backend.model_manager.metadata.fetch import HuggingFaceMetadataFetch
from invokeai.backend.model_manager.metadata.fetch import CivitaiMetadataFetch
data = HuggingFaceMetadataFetch().from_id("<REPO_ID>")
assert isinstance(data, HuggingFaceMetadata)
data = CivitaiMetadataFetch().from_url("https://civitai.com/models/206883/split")
assert isinstance(data, CivitaiMetadata)
if data.allow_commercial_use:
print("Commercial use of this model is allowed")
"""
from .fetch import HuggingFaceMetadataFetch, ModelMetadataFetchBase
from .fetch import CivitaiMetadataFetch, HuggingFaceMetadataFetch, ModelMetadataFetchBase
from .metadata_base import (
AnyModelRepoMetadata,
AnyModelRepoMetadataValidator,
BaseMetadata,
CivitaiMetadata,
HuggingFaceMetadata,
ModelMetadataWithFiles,
RemoteModelFile,
@ -30,6 +34,8 @@ from .metadata_base import (
__all__ = [
"AnyModelRepoMetadata",
"AnyModelRepoMetadataValidator",
"CivitaiMetadata",
"CivitaiMetadataFetch",
"HuggingFaceMetadata",
"HuggingFaceMetadataFetch",
"ModelMetadataFetchBase",

View File

@ -3,14 +3,19 @@ Initialization file for invokeai.backend.model_manager.metadata.fetch
Usage:
from invokeai.backend.model_manager.metadata.fetch import (
CivitaiMetadataFetch,
HuggingFaceMetadataFetch,
)
from invokeai.backend.model_manager.metadata import CivitaiMetadata
data = HuggingFaceMetadataFetch().from_id("<repo_id>")
assert isinstance(data, HuggingFaceMetadata)
data = CivitaiMetadataFetch().from_url("https://civitai.com/models/206883/split")
assert isinstance(data, CivitaiMetadata)
if data.allow_commercial_use:
print("Commercial use of this model is allowed")
"""
from .civitai import CivitaiMetadataFetch
from .fetch_base import ModelMetadataFetchBase
from .huggingface import HuggingFaceMetadataFetch
__all__ = ["ModelMetadataFetchBase", "HuggingFaceMetadataFetch"]
__all__ = ["ModelMetadataFetchBase", "CivitaiMetadataFetch", "HuggingFaceMetadataFetch"]

View File

@ -0,0 +1,188 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
This module fetches model metadata objects from the Civitai model repository.
In addition to the `from_url()` and `from_id()` methods inherited from the
`ModelMetadataFetchBase` base class.
Civitai has two separate ID spaces: a model ID and a version ID. The
version ID corresponds to a specific model, and is the ID accepted by
`from_id()`. The model ID corresponds to a family of related models,
such as different training checkpoints or 16 vs 32-bit versions. The
`from_civitai_modelid()` method will accept a model ID and return the
metadata from the default version within this model set. The default
version is the same as what the user sees when they click on a model's
thumbnail.
Usage:
from invokeai.backend.model_manager.metadata.fetch import CivitaiMetadataFetch
fetcher = CivitaiMetadataFetch()
metadata = fetcher.from_url("https://civitai.com/models/206883/split")
print(metadata.trained_words)
"""
import json
import re
from pathlib import Path
from typing import Any, Optional
import requests
from pydantic import TypeAdapter, ValidationError
from pydantic.networks import AnyHttpUrl
from requests.sessions import Session
from invokeai.backend.model_manager.config import ModelRepoVariant
from ..metadata_base import (
AnyModelRepoMetadata,
CivitaiMetadata,
RemoteModelFile,
UnknownMetadataException,
)
from .fetch_base import ModelMetadataFetchBase
CIVITAI_MODEL_PAGE_RE = r"https?://civitai.com/models/(\d+)"
CIVITAI_VERSION_PAGE_RE = r"https?://civitai.com/models/(\d+)\?modelVersionId=(\d+)"
CIVITAI_DOWNLOAD_RE = r"https?://civitai.com/api/download/models/(\d+)"
CIVITAI_VERSION_ENDPOINT = "https://civitai.com/api/v1/model-versions/"
CIVITAI_MODEL_ENDPOINT = "https://civitai.com/api/v1/models/"
StringSetAdapter = TypeAdapter(set[str])
class CivitaiMetadataFetch(ModelMetadataFetchBase):
"""Fetch model metadata from Civitai."""
def __init__(self, session: Optional[Session] = None, api_key: Optional[str] = None):
"""
Initialize the fetcher with an optional requests.sessions.Session object.
By providing a configurable Session object, we can support unit tests on
this module without an internet connection.
"""
self._requests = session or requests.Session()
self._api_key = api_key
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
"""
Given a URL to a CivitAI model or version page, return a ModelMetadata object.
In the event that the URL points to a model page without the particular version
indicated, the default model version is returned. Otherwise, the requested version
is returned.
"""
if match := re.match(CIVITAI_VERSION_PAGE_RE, str(url), re.IGNORECASE):
model_id = match.group(1)
version_id = match.group(2)
return self.from_civitai_versionid(int(version_id), int(model_id))
elif match := re.match(CIVITAI_MODEL_PAGE_RE, str(url), re.IGNORECASE):
model_id = match.group(1)
return self.from_civitai_modelid(int(model_id))
elif match := re.match(CIVITAI_DOWNLOAD_RE, str(url), re.IGNORECASE):
version_id = match.group(1)
return self.from_civitai_versionid(int(version_id))
raise UnknownMetadataException("The url '{url}' does not match any known Civitai URL patterns")
def from_id(self, id: str, variant: Optional[ModelRepoVariant] = None) -> AnyModelRepoMetadata:
"""
Given a Civitai model version ID, return a ModelRepoMetadata object.
:param id: An ID.
:param variant: A model variant from the ModelRepoVariant enum (currently ignored)
May raise an `UnknownMetadataException`.
"""
return self.from_civitai_versionid(int(id))
def from_civitai_modelid(self, model_id: int) -> CivitaiMetadata:
"""
Return metadata from the default version of the indicated model.
May raise an `UnknownMetadataException`.
"""
model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
model_json = self._requests.get(self._get_url_with_api_key(model_url)).json()
return self._from_api_response(model_json)
def _from_api_response(self, api_response: dict[str, Any], version_id: Optional[int] = None) -> CivitaiMetadata:
try:
version_id = version_id or api_response["modelVersions"][0]["id"]
except TypeError as excp:
raise UnknownMetadataException from excp
# loop till we find the section containing the version requested
version_sections = [x for x in api_response["modelVersions"] if x["id"] == version_id]
if not version_sections:
raise UnknownMetadataException(f"Version {version_id} not found in model metadata")
version_json = version_sections[0]
# Civitai has one "primary" file plus others such as VAEs. We only fetch the primary.
primary = [x for x in version_json["files"] if x.get("primary")]
assert len(primary) == 1
primary_file = primary[0]
url = primary_file["downloadUrl"]
if "?" not in url: # work around apparent bug in civitai api
metadata_string = ""
for key, value in primary_file["metadata"].items():
if not value:
continue
metadata_string += f"&{key}={value}"
url = url + f"?type={primary_file['type']}{metadata_string}"
model_files = [
RemoteModelFile(
url=self._get_url_with_api_key(url),
path=Path(primary_file["name"]),
size=int(primary_file["sizeKB"] * 1024),
sha256=primary_file["hashes"]["SHA256"],
)
]
try:
trigger_phrases = StringSetAdapter.validate_python(version_json.get("trainedWords"))
except ValidationError:
trigger_phrases: set[str] = set()
return CivitaiMetadata(
name=version_json["name"],
files=model_files,
trigger_phrases=trigger_phrases,
api_response=json.dumps(version_json),
)
def from_civitai_versionid(self, version_id: int, model_id: Optional[int] = None) -> CivitaiMetadata:
"""
Return a CivitaiMetadata object given a model version id.
May raise an `UnknownMetadataException`.
"""
if model_id is None:
version_url = CIVITAI_VERSION_ENDPOINT + str(version_id)
version = self._requests.get(self._get_url_with_api_key(version_url)).json()
if error := version.get("error"):
raise UnknownMetadataException(error)
model_id = version["modelId"]
model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
model_json = self._requests.get(self._get_url_with_api_key(model_url)).json()
return self._from_api_response(model_json, version_id)
@classmethod
def from_json(cls, json: str) -> CivitaiMetadata:
"""Given the JSON representation of the metadata, return the corresponding Pydantic object."""
metadata = CivitaiMetadata.model_validate_json(json)
return metadata
def _get_url_with_api_key(self, url: str) -> str:
if not self._api_key:
return url
if "?" in url:
return f"{url}&token={self._api_key}"
return f"{url}?token={self._api_key}"

View File

@ -5,10 +5,11 @@ This module is the base class for subclasses that fetch metadata from model repo
Usage:
from invokeai.backend.model_manager.metadata.fetch import HuggingFaceMetadataFetch
from invokeai.backend.model_manager.metadata.fetch import CivitAIMetadataFetch
data = HuggingFaceMetadataFetch().from_id("<REPO_ID>")
assert isinstance(data, HuggingFaceMetadata)
fetcher = CivitaiMetadataFetch()
metadata = fetcher.from_url("https://civitai.com/models/206883/split")
print(metadata.trained_words)
"""
from abc import ABC, abstractmethod

View File

@ -90,35 +90,8 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
)
)
# diffusers models have a `model_index.json` or `config.json` file
is_diffusers = any(str(f.url).endswith(("model_index.json", "config.json")) for f in files)
# These URLs will be exposed to the user - I think these are the only file types we fully support
ckpt_urls = (
None
if is_diffusers
else [
f.url
for f in files
if str(f.url).endswith(
(
".safetensors",
".bin",
".pth",
".pt",
".ckpt",
)
)
]
)
return HuggingFaceMetadata(
id=model_info.id,
name=name,
files=files,
api_response=json.dumps(model_info.__dict__, default=str),
is_diffusers=is_diffusers,
ckpt_urls=ckpt_urls,
id=model_info.id, name=name, files=files, api_response=json.dumps(model_info.__dict__, default=str)
)
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:

View File

@ -78,16 +78,20 @@ class ModelMetadataWithFiles(ModelMetadataBase):
return self.files
class CivitaiMetadata(ModelMetadataWithFiles):
"""Extended metadata fields provided by Civitai."""
type: Literal["civitai"] = "civitai"
trigger_phrases: set[str] = Field(description="Trigger phrases extracted from the API response")
api_response: Optional[str] = Field(description="Response from the Civitai API as stringified JSON", default=None)
class HuggingFaceMetadata(ModelMetadataWithFiles):
"""Extended metadata fields provided by HuggingFace."""
type: Literal["huggingface"] = "huggingface"
id: str = Field(description="The HF model id")
api_response: Optional[str] = Field(description="Response from the HF API as stringified JSON", default=None)
is_diffusers: bool = Field(description="Whether the metadata is for a Diffusers format model", default=False)
ckpt_urls: Optional[List[AnyHttpUrl]] = Field(
description="URLs for all checkpoint format models in the metadata", default=None
)
def download_urls(
self,
@ -126,5 +130,5 @@ class HuggingFaceMetadata(ModelMetadataWithFiles):
return [x for x in self.files if x.path in paths]
AnyModelRepoMetadata = Annotated[Union[BaseMetadata, HuggingFaceMetadata], Field(discriminator="type")]
AnyModelRepoMetadata = Annotated[Union[BaseMetadata, HuggingFaceMetadata, CivitaiMetadata], Field(discriminator="type")]
AnyModelRepoMetadataValidator = TypeAdapter(AnyModelRepoMetadata)

View File

@ -9,15 +9,12 @@ from picklescan.scanner import scan_file_path
import invokeai.backend.util.logging as logger
from invokeai.app.util.misc import uuid_string
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
from invokeai.backend.util.util import SilenceWarnings
from .config import (
AnyModelConfig,
BaseModelType,
ControlAdapterDefaultSettings,
InvalidModelConfigException,
MainModelDefaultSettings,
ModelConfigFactory,
ModelFormat,
ModelRepoVariant,
@ -26,6 +23,7 @@ from .config import (
ModelVariantType,
SchedulerPredictionType,
)
from .hash import ModelHash
from .util.model_util import lora_token_vector_length, read_checkpoint_meta
CkptType = Dict[str, Any]
@ -114,7 +112,9 @@ class ModelProbe(object):
@classmethod
def probe(
cls, model_path: Path, fields: Optional[Dict[str, Any]] = None, hash_algo: HASHING_ALGORITHMS = "blake3"
cls,
model_path: Path,
fields: Optional[Dict[str, Any]] = None,
) -> AnyModelConfig:
"""
Probe the model at model_path and return its configuration record.
@ -128,16 +128,13 @@ class ModelProbe(object):
if fields is None:
fields = {}
model_path = model_path.resolve()
format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint
model_info = None
model_type = ModelType(fields["type"]) if "type" in fields and fields["type"] else None
if not model_type:
if format_type is ModelFormat.Diffusers:
model_type = cls.get_model_type_from_folder(model_path)
else:
model_type = cls.get_model_type_from_checkpoint(model_path)
model_type = None
if format_type is ModelFormat.Diffusers:
model_type = cls.get_model_type_from_folder(model_path)
else:
model_type = cls.get_model_type_from_checkpoint(model_path)
format_type = ModelFormat.ONNX if model_type == ModelType.ONNX else format_type
probe_class = cls.PROBES[format_type].get(model_type)
@ -157,18 +154,10 @@ class ModelProbe(object):
fields["image_encoder_model_id"] = fields.get("image_encoder_model_id") or probe.get_image_encoder_model_id()
fields["name"] = fields.get("name") or cls.get_model_name(model_path)
fields["description"] = (
fields.get("description") or f"{fields['base'].value} {model_type.value} model {fields['name']}"
fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
)
fields["format"] = fields.get("format") or probe.get_format()
fields["hash"] = fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path)
fields["default_settings"] = fields.get("default_settings")
if not fields["default_settings"]:
if fields["type"] in {ModelType.ControlNet, ModelType.T2IAdapter}:
fields["default_settings"] = get_default_settings_controlnet_t2i_adapter(fields["name"])
elif fields["type"] is ModelType.Main:
fields["default_settings"] = get_default_settings_main(fields["base"])
fields["hash"] = fields.get("hash") or ModelHash().hash(model_path)
if format_type == ModelFormat.Diffusers and isinstance(probe, FolderProbeBase):
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()
@ -320,7 +309,7 @@ class ModelProbe(object):
@classmethod
def _scan_and_load_checkpoint(cls, model_path: Path) -> CkptType:
with SilenceWarnings():
if model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
if model_path.suffix.endswith((".ckpt", ".pt", ".bin")):
cls._scan_model(model_path.name, model_path)
model = torch.load(model_path)
assert isinstance(model, dict)
@ -340,43 +329,6 @@ class ModelProbe(object):
raise Exception("The model {model_name} is potentially infected by malware. Aborting import.")
# Probing utilities
MODEL_NAME_TO_PREPROCESSOR = {
"canny": "canny_image_processor",
"mlsd": "mlsd_image_processor",
"depth": "depth_anything_image_processor",
"bae": "normalbae_image_processor",
"normal": "normalbae_image_processor",
"sketch": "pidi_image_processor",
"scribble": "lineart_image_processor",
"lineart": "lineart_image_processor",
"lineart_anime": "lineart_anime_image_processor",
"softedge": "hed_image_processor",
"shuffle": "content_shuffle_image_processor",
"pose": "dw_openpose_image_processor",
"mediapipe": "mediapipe_face_processor",
"pidi": "pidi_image_processor",
"zoe": "zoe_depth_image_processor",
"color": "color_map_image_processor",
}
def get_default_settings_controlnet_t2i_adapter(model_name: str) -> Optional[ControlAdapterDefaultSettings]:
for k, v in MODEL_NAME_TO_PREPROCESSOR.items():
if k in model_name:
return ControlAdapterDefaultSettings(preprocessor=v)
return None
def get_default_settings_main(model_base: BaseModelType) -> Optional[MainModelDefaultSettings]:
if model_base is BaseModelType.StableDiffusion1 or model_base is BaseModelType.StableDiffusion2:
return MainModelDefaultSettings(width=512, height=512)
elif model_base is BaseModelType.StableDiffusionXL:
return MainModelDefaultSettings(width=1024, height=1024)
# We don't provide defaults for BaseModelType.StableDiffusionXLRefiner, as they are not standalone models.
return None
# ##################################################3
# Checkpoint probing
# ##################################################3

View File

@ -4,75 +4,121 @@ Abstract base class and implementation for recursive directory search for models
Example usage:
```
from invokeai.backend.model_manager import ModelSearch, ModelProbe
from invokeai.backend.model_manager import ModelSearch, ModelProbe
def find_main_models(model: Path) -> bool:
info = ModelProbe.probe(model)
if info.model_type == 'main' and info.base_type == 'sd-1':
return True
else:
return False
def find_main_models(model: Path) -> bool:
info = ModelProbe.probe(model)
if info.model_type == 'main' and info.base_type == 'sd-1':
return True
else:
return False
search = ModelSearch(on_model_found=report_it)
found = search.search('/tmp/models')
print(found) # list of matching model paths
print(search.stats) # search stats
search = ModelSearch(on_model_found=report_it)
found = search.search('/tmp/models')
print(found) # list of matching model paths
print(search.stats) # search stats
```
"""
import os
from dataclasses import dataclass
from abc import ABC, abstractmethod
from logging import Logger
from pathlib import Path
from typing import Callable, Optional
from typing import Callable, Optional, Set, Union
from pydantic import BaseModel, Field
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util.logging import InvokeAILogger
default_logger: Logger = InvokeAILogger.get_logger()
@dataclass
class SearchStats:
"""Statistics about the search.
Attributes:
items_scanned: number of items scanned
models_found: number of models found
models_filtered: number of models that passed the filter
class SearchStats(BaseModel):
items_scanned: int = 0
models_found: int = 0
models_filtered: int = 0
class ModelSearchBase(ABC, BaseModel):
"""
items_scanned = 0
models_found = 0
models_filtered = 0
class ModelSearch:
"""Searches a directory tree for models, using a callback to filter the results.
Abstract directory traversal model search class
Usage:
search = ModelSearch()
search.model_found = lambda path : 'anime' in path.as_posix()
found = search.list_models(['/tmp/models1','/tmp/models2'])
# returns all models that have 'anime' in the path
search = ModelSearchBase(
on_search_started = search_started_callback,
on_search_completed = search_completed_callback,
on_model_found = model_found_callback,
)
models_found = search.search('/path/to/directory')
"""
def __init__(
self,
on_search_started: Optional[Callable[[Path], None]] = None,
on_model_found: Optional[Callable[[Path], bool]] = None,
on_search_completed: Optional[Callable[[set[Path]], None]] = None,
) -> None:
"""Create a new ModelSearch object.
# fmt: off
on_search_started : Optional[Callable[[Path], None]] = Field(default=None, description="Called just before the search starts.") # noqa E221
on_model_found : Optional[Callable[[Path], bool]] = Field(default=None, description="Called when a model is found.") # noqa E221
on_search_completed : Optional[Callable[[Set[Path]], None]] = Field(default=None, description="Called when search is complete.") # noqa E221
stats : SearchStats = Field(default_factory=SearchStats, description="Summary statistics after search") # noqa E221
logger : Logger = Field(default=default_logger, description="Logger instance.") # noqa E221
# fmt: on
Args:
on_search_started: callback to be invoked when the search starts
on_model_found: callback to be invoked when a model is found. The callback should return True if the model
should be included in the results.
on_search_completed: callback to be invoked when the search is completed
class Config:
arbitrary_types_allowed = True
@abstractmethod
def search_started(self) -> None:
"""
self.stats = SearchStats()
self.logger = InvokeAILogger.get_logger()
self.on_search_started = on_search_started
self.on_model_found = on_model_found
self.on_search_completed = on_search_completed
self.models_found: set[Path] = set()
Called before the scan starts.
Passes the root search directory to the Callable `on_search_started`.
"""
pass
@abstractmethod
def model_found(self, model: Path) -> None:
"""
Called when a model is found during search.
:param model: Model to process - could be a directory or checkpoint.
Passes the model's Path to the Callable `on_model_found`.
This Callable receives the path to the model and returns a boolean
to indicate whether the model should be returned in the search
results.
"""
pass
@abstractmethod
def search_completed(self) -> None:
"""
Called before the scan starts.
Passes the Set of found model Paths to the Callable `on_search_completed`.
"""
pass
@abstractmethod
def search(self, directory: Union[Path, str]) -> Set[Path]:
"""
Recursively search for models in `directory` and return a set of model paths.
If provided, the `on_search_started`, `on_model_found` and `on_search_completed`
Callables will be invoked during the search.
"""
pass
class ModelSearch(ModelSearchBase):
"""
Implementation of ModelSearch with callbacks.
Usage:
search = ModelSearch()
search.model_found = lambda path : 'anime' in path.as_posix()
found = search.list_models(['/tmp/models1','/tmp/models2'])
# returns all models that have 'anime' in the path
"""
models_found: Set[Path] = Field(default_factory=set)
config: InvokeAIAppConfig = InvokeAIAppConfig.get_config()
def search_started(self) -> None:
self.models_found = set()
@ -89,17 +135,17 @@ class ModelSearch:
if self.on_search_completed is not None:
self.on_search_completed(self.models_found)
def search(self, directory: Path) -> set[Path]:
def search(self, directory: Union[Path, str]) -> Set[Path]:
self._directory = Path(directory)
self._directory = self._directory.resolve()
if not self._directory.is_absolute():
self._directory = self.config.models_path / self._directory
self.stats = SearchStats() # zero out
self.search_started() # This will initialize _models_found to empty
self._walk_directory(self._directory)
self.search_completed()
return self.models_found
def _walk_directory(self, path: Path, max_depth: int = 20) -> None:
"""Recursively walk the directory tree, looking for models."""
def _walk_directory(self, path: Union[Path, str], max_depth: int = 20) -> None:
absolute_path = Path(path)
if (
len(absolute_path.parts) - len(self._directory.parts) > max_depth

View File

@ -455,6 +455,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
ip_adapter_unet_patcher=ip_adapter_unet_patcher,
)
latents = step_output.prev_sample
latents = self.invokeai_diffuser.do_latent_postprocessing(
postprocessing_settings=conditioning_data.postprocessing_settings,
latents=latents,
sigma=batched_t,
step_index=i,
total_step_count=len(timesteps),
)
predicted_original = getattr(step_output, "pred_original_sample", None)
if callback is not None:

View File

@ -44,6 +44,14 @@ class SDXLConditioningInfo(BasicConditioningInfo):
return super().to(device=device, dtype=dtype)
@dataclass(frozen=True)
class PostprocessingSettings:
threshold: float
warmup: float
h_symmetry_time_pct: Optional[float]
v_symmetry_time_pct: Optional[float]
@dataclass
class IPAdapterConditioningInfo:
cond_image_prompt_embeds: torch.Tensor
@ -72,6 +80,10 @@ class ConditioningData:
"""
guidance_rescale_multiplier: float = 0
scheduler_args: dict[str, Any] = field(default_factory=dict)
"""
Additional arguments to pass to invokeai_diffuser.do_latent_postprocessing().
"""
postprocessing_settings: Optional[PostprocessingSettings] = None
ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]] = None

View File

@ -12,6 +12,7 @@ from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
ConditioningData,
ExtraConditioningInfo,
PostprocessingSettings,
SDXLConditioningInfo,
)
@ -243,6 +244,19 @@ class InvokeAIDiffuserComponent:
return unconditioned_next_x, conditioned_next_x
def do_latent_postprocessing(
self,
postprocessing_settings: PostprocessingSettings,
latents: torch.Tensor,
sigma,
step_index,
total_step_count,
) -> torch.Tensor:
if postprocessing_settings is not None:
percent_through = step_index / total_step_count
latents = self.apply_symmetry(postprocessing_settings, latents, percent_through)
return latents
def _concat_conditionings_for_batch(self, unconditioning, conditioning):
def _pad_conditioning(cond, target_len, encoder_attention_mask):
conditioning_attention_mask = torch.ones(
@ -492,3 +506,64 @@ class InvokeAIDiffuserComponent:
scaled_delta = (conditioned_next_x - unconditioned_next_x) * guidance_scale
combined_next_x = unconditioned_next_x + scaled_delta
return combined_next_x
def apply_symmetry(
self,
postprocessing_settings: PostprocessingSettings,
latents: torch.Tensor,
percent_through: float,
) -> torch.Tensor:
# Reset our last percent through if this is our first step.
if percent_through == 0.0:
self.last_percent_through = 0.0
if postprocessing_settings is None:
return latents
# Check for out of bounds
h_symmetry_time_pct = postprocessing_settings.h_symmetry_time_pct
if h_symmetry_time_pct is not None and (h_symmetry_time_pct <= 0.0 or h_symmetry_time_pct > 1.0):
h_symmetry_time_pct = None
v_symmetry_time_pct = postprocessing_settings.v_symmetry_time_pct
if v_symmetry_time_pct is not None and (v_symmetry_time_pct <= 0.0 or v_symmetry_time_pct > 1.0):
v_symmetry_time_pct = None
dev = latents.device.type
latents.to(device="cpu")
if (
h_symmetry_time_pct is not None
and self.last_percent_through < h_symmetry_time_pct
and percent_through >= h_symmetry_time_pct
):
# Horizontal symmetry occurs on the 3rd dimension of the latent
width = latents.shape[3]
x_flipped = torch.flip(latents, dims=[3])
latents = torch.cat(
[
latents[:, :, :, 0 : int(width / 2)],
x_flipped[:, :, :, int(width / 2) : int(width)],
],
dim=3,
)
if (
v_symmetry_time_pct is not None
and self.last_percent_through < v_symmetry_time_pct
and percent_through >= v_symmetry_time_pct
):
# Vertical symmetry occurs on the 2nd dimension of the latent
height = latents.shape[2]
y_flipped = torch.flip(latents, dims=[2])
latents = torch.cat(
[
latents[:, :, 0 : int(height / 2)],
y_flipped[:, :, int(height / 2) : int(height)],
],
dim=2,
)
self.last_percent_through = percent_through
return latents.to(device=dev)

View File

@ -5,7 +5,6 @@ from typing import Callable, List, Union
import torch.nn as nn
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
@ -27,7 +26,7 @@ def _conv_forward_asymmetric(self, input, weight, bias):
@contextmanager
def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL, AutoencoderTiny], seamless_axes: List[str]):
def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL], seamless_axes: List[str]):
# Callable: (input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor
to_restore: list[tuple[nn.Conv2d | nn.ConvTranspose2d, Callable]] = []
try:

View File

@ -858,9 +858,9 @@ def do_textual_inversion_training(
# Let's make sure we don't update any embedding weights besides the newly added token
index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id
with torch.no_grad():
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
orig_embeds_params[index_no_updates]
)
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
index_no_updates
] = orig_embeds_params[index_no_updates]
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:

View File

@ -42,10 +42,9 @@ def install_and_load_model(
# If the requested model is already installed, return its LoadedModel
with contextlib.suppress(UnknownModelException):
# TODO: Replace with wrapper call
configs = model_manager.store.search_by_attr(
loaded_model: LoadedModel = model_manager.load_model_by_attr(
model_name=model_name, base_model=base_model, model_type=model_type
)
loaded_model: LoadedModel = model_manager.load.load_model(configs[0])
return loaded_model
# Install the requested model.
@ -54,7 +53,7 @@ def install_and_load_model(
assert job.complete
try:
loaded_model = model_manager.load.load_model(job.config_out)
loaded_model = model_manager.load_model_by_config(job.config_out)
return loaded_model
except UnknownModelException as e:
raise Exception(

View File

@ -62,72 +62,40 @@ sd-1/main/trinart_stable_diffusion_v2:
recommended: False
sd-1/controlnet/qrcode_monster:
source: monster-labs/control_v1p_sd15_qrcode_monster
description: Controlnet model that generates scannable creative QR codes
subfolder: v2
sd-1/controlnet/canny:
description: Controlnet weights trained on sd-1.5 with canny conditioning.
source: lllyasviel/control_v11p_sd15_canny
recommended: True
sd-1/controlnet/inpaint:
source: lllyasviel/control_v11p_sd15_inpaint
description: Controlnet weights trained on sd-1.5 with canny conditioning, inpaint version
sd-1/controlnet/mlsd:
description: Controlnet weights trained on sd-1.5 with canny conditioning, MLSD version
source: lllyasviel/control_v11p_sd15_mlsd
sd-1/controlnet/depth:
description: Controlnet weights trained on sd-1.5 with depth conditioning
source: lllyasviel/control_v11f1p_sd15_depth
recommended: True
sd-1/controlnet/normal_bae:
description: Controlnet weights trained on sd-1.5 with normalbae image conditioning
source: lllyasviel/control_v11p_sd15_normalbae
sd-1/controlnet/seg:
description: Controlnet weights trained on sd-1.5 with seg image conditioning
source: lllyasviel/control_v11p_sd15_seg
sd-1/controlnet/lineart:
description: Controlnet weights trained on sd-1.5 with lineart image conditioning
source: lllyasviel/control_v11p_sd15_lineart
recommended: True
sd-1/controlnet/lineart_anime:
description: Controlnet weights trained on sd-1.5 with anime image conditioning
source: lllyasviel/control_v11p_sd15s2_lineart_anime
sd-1/controlnet/openpose:
description: Controlnet weights trained on sd-1.5 with openpose image conditioning
source: lllyasviel/control_v11p_sd15_openpose
recommended: True
sd-1/controlnet/scribble:
source: lllyasviel/control_v11p_sd15_scribble
description: Controlnet weights trained on sd-1.5 with scribble image conditioning
recommended: False
sd-1/controlnet/softedge:
source: lllyasviel/control_v11p_sd15_softedge
description: Controlnet weights trained on sd-1.5 with soft edge conditioning
sd-1/controlnet/shuffle:
source: lllyasviel/control_v11e_sd15_shuffle
description: Controlnet weights trained on sd-1.5 with shuffle image conditioning
sd-1/controlnet/tile:
source: lllyasviel/control_v11f1e_sd15_tile
description: Controlnet weights trained on sd-1.5 with tiled image conditioning
sd-1/controlnet/ip2p:
source: lllyasviel/control_v11e_sd15_ip2p
description: Controlnet weights trained on sd-1.5 with ip2p conditioning.
sdxl/controlnet/canny-sdxl:
description: Controlnet weights trained on sdxl-1.0 with canny conditioning.
source: diffusers/controlnet-canny-sdxl-1.0
recommended: True
sdxl/controlnet/depth-sdxl:
description: Controlnet weights trained on sdxl-1.0 with depth conditioning.
source: diffusers/controlnet-depth-sdxl-1.0
recommended: True
sdxl/controlnet/softedge-dexined-sdxl:
description: Controlnet weights trained on sdxl-1.0 with dexined soft edge preprocessing.
source: SargeZT/controlnet-sd-xl-1.0-softedge-dexined
sdxl/controlnet/depth-16bit-zoe-sdxl:
description: Controlnet weights trained on sdxl-1.0 with Zoe's preprocessor (16 bits).
source: SargeZT/controlnet-sd-xl-1.0-depth-16bit-zoe
sdxl/controlnet/depth-zoe-sdxl:
description: Controlnet weights trained on sdxl-1.0 with Zoe's preprocessor (32 bits).
source: diffusers/controlnet-zoe-depth-sdxl-1.0
sd-1/t2i_adapter/canny-sd15:
source: TencentARC/t2iadapter_canny_sd15v2
sd-1/t2i_adapter/sketch-sd15:

View File

@ -608,9 +608,8 @@ def main() -> None:
config.parse_args(invoke_args)
logger = InvokeAILogger().get_logger(config=config)
if not config.models_path.exists():
if not config.model_conf_path.exists():
logger.info("Your InvokeAI root directory is not set up. Calling invokeai-configure.")
sys.argv = ["invokeai_configure", "--yes", "--skip-sd-weights"]
from invokeai.frontend.install.invokeai_configure import invokeai_configure
invokeai_configure()

View File

@ -20,6 +20,7 @@ from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadQueueService
from invokeai.app.services.image_files.image_files_disk import DiskImageFileStorage
from invokeai.app.services.model_install import ModelInstallService
from invokeai.app.services.model_metadata import ModelMetadataStoreSQL
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
from invokeai.backend.model_manager import (
@ -412,7 +413,7 @@ def get_config_store() -> ModelRecordServiceSQL:
assert output_path is not None
image_files = DiskImageFileStorage(output_path / "images")
db = init_db(config=config, logger=InvokeAILogger.get_logger(), image_files=image_files)
return ModelRecordServiceSQL(db)
return ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
def get_model_merger(record_store: ModelRecordServiceBase) -> ModelMerger:

View File

@ -10,7 +10,7 @@ export const ReduxInit = memo((props: PropsWithChildren) => {
const dispatch = useAppDispatch();
useGlobalModifiersInit();
useEffect(() => {
dispatch(modelChanged({ key: 'test_model', hash: 'some_hash', name: 'some name', base: 'sd-1', type: 'main' }));
dispatch(modelChanged({ key: 'test_model', base: 'sd-1' }));
}, []);
return props.children;

View File

@ -1,3 +1,150 @@
# Invoke UI
<https://invoke-ai.github.io/InvokeAI/contributing/frontend/OVERVIEW/>
<!-- @import "[TOC]" {cmd="toc" depthFrom=2 depthTo=3 orderedList=false} -->
<!-- code_chunk_output -->
- [Dev environment](#dev-environment)
- [Setup](#setup)
- [Package scripts](#package-scripts)
- [Type generation](#type-generation)
- [Localization](#localization)
- [VSCode](#vscode)
- [Contributing](#contributing)
- [Check in before investing your time](#check-in-before-investing-your-time)
- [Commit format](#commit-format)
- [Submitting a PR](#submitting-a-pr)
- [Other docs](#other-docs)
<!-- /code_chunk_output -->
Invoke's UI is made possible by many contributors and open-source libraries. Thank you!
## Dev environment
### Setup
1. Install [node] and [pnpm].
1. Run `pnpm i` to install all packages.
#### Run in dev mode
1. From `invokeai/frontend/web/`, run `pnpm dev`.
1. From repo root, run `python scripts/invokeai-web.py`.
1. Point your browser to the dev server address, e.g. <http://localhost:5173/>
### Package scripts
- `dev`: run the frontend in dev mode, enabling hot reloading
- `build`: run all checks (madge, eslint, prettier, tsc) and then build the frontend
- `typegen`: generate types from the OpenAPI schema (see [Type generation])
- `lint:madge`: check frontend for circular dependencies
- `lint:eslint`: check frontend for code quality
- `lint:prettier`: check frontend for code formatting
- `lint:tsc`: check frontend for type issues
- `lint`: run all checks concurrently
- `fix`: run `eslint` and `prettier`, fixing fixable issues
### Type generation
We use [openapi-typescript] to generate types from the app's OpenAPI schema.
The generated types are committed to the repo in [schema.ts].
```sh
# from the repo root, start the server
python scripts/invokeai-web.py
# from invokeai/frontend/web/, run the script
pnpm typegen
```
### Localization
We use [i18next] for localization, but translation to languages other than English happens on our [Weblate] project.
Only the English source strings should be changed on this repo.
### VSCode
#### Example debugger config
```jsonc
{
"version": "0.2.0",
"configurations": [
{
"type": "chrome",
"request": "launch",
"name": "Invoke UI",
"url": "http://localhost:5173",
"webRoot": "${workspaceFolder}/invokeai/frontend/web",
},
],
}
```
#### Remote dev
We've noticed an intermittent timeout issue with the VSCode remote dev port forwarding.
We suggest disabling the editor's port forwarding feature and doing it manually via SSH:
```sh
ssh -L 9090:localhost:9090 -L 5173:localhost:5173 user@host
```
## Contributing Guidelines
Thanks for your interest in contributing to the Invoke Web UI!
Please follow these guidelines when contributing.
### Check in before investing your time
Please check in before you invest your time on anything besides a trivial fix, in case it conflicts with ongoing work or isn't aligned with the vision for the app.
If a feature request or issue doesn't already exist for the thing you want to work on, please create one.
Ping `@psychedelicious` on [discord] in the `#frontend-dev` channel or in the feature request / issue you want to work on - we're happy chat.
### Code conventions
- This is a fairly complex app with a deep component tree. Please use memoization (`useCallback`, `useMemo`, `memo`) with enthusiasm.
- If you need to add some global, ephemeral state, please use [nanostores] if possible.
- Be careful with your redux selectors. If they need to be parameterized, consider creating them inside a `useMemo`.
- Feel free to use `lodash` (via `lodash-es`) to make the intent of your code clear.
- Please add comments describing the "why", not the "how" (unless it is really arcane).
### Commit format
Please use the [conventional commits] spec for the web UI, with a scope of "ui":
- `chore(ui): bump deps`
- `chore(ui): lint`
- `feat(ui): add some cool new feature`
- `fix(ui): fix some bug`
### Submitting a PR
- Ensure your branch is tidy. Use an interactive rebase to clean up the commit history and reword the commit messages if they are not descriptive.
- Run `pnpm lint`. Some issues are auto-fixable with `pnpm fix`.
- Fill out the PR form when creating the PR.
- It doesn't need to be super detailed, but a screenshot or video is nice if you changed something visually.
- If a section isn't relevant, delete it. There are no UI tests at this time.
## Other docs
- [Workflows - Design and Implementation]
- [State Management]
[node]: https://nodejs.org/en/download/
[pnpm]: https://github.com/pnpm/pnpm
[discord]: https://discord.gg/ZmtBAhwWhy
[i18next]: https://github.com/i18next/react-i18next
[Weblate]: https://hosted.weblate.org/engage/invokeai/
[openapi-typescript]: https://github.com/drwpow/openapi-typescript
[Type generation]: #type-generation
[schema.ts]: ../src/services/api/schema.ts
[conventional commits]: https://www.conventionalcommits.org/en/v1.0.0/
[Workflows - Design and Implementation]: ./docs/WORKFLOWS_DESIGN_IMPLEMENTATION.md
[State Management]: ./docs/STATE_MGMT.md

View File

@ -1,5 +1,40 @@
# Workflows - Design and Implementation
<!-- @import "[TOC]" {cmd="toc" depthFrom=1 depthTo=6 orderedList=false} -->
<!-- code_chunk_output -->
- [Workflows - Design and Implementation](#workflows---design-and-implementation)
- [Design](#design)
- [Linear UI](#linear-ui)
- [Workflow Editor](#workflow-editor)
- [Workflows](#workflows)
- [Workflow -> reactflow state -> InvokeAI graph](#workflow---reactflow-state---invokeai-graph)
- [Nodes vs Invocations](#nodes-vs-invocations)
- [Workflow Linear View](#workflow-linear-view)
- [OpenAPI Schema](#openapi-schema)
- [Field Instances and Templates](#field-instances-and-templates)
- [Stateful vs Stateless Fields](#stateful-vs-stateless-fields)
- [Collection and Polymorphic Fields](#collection-and-polymorphic-fields)
- [Implementation](#implementation)
- [zod Schemas and Types](#zod-schemas-and-types)
- [OpenAPI Schema Parsing](#openapi-schema-parsing)
- [Parsing Field Types](#parsing-field-types)
- [Primitive Types](#primitive-types)
- [Complex Types](#complex-types)
- [Collection Types](#collection-types)
- [Collection or Scalar Types](#collection-or-scalar-types)
- [Optional Fields](#optional-fields)
- [Building Field Input Templates](#building-field-input-templates)
- [Building Field Output Templates](#building-field-output-templates)
- [Managing reactflow State](#managing-reactflow-state)
- [Building Nodes and Edges](#building-nodes-and-edges)
- [Building a Workflow](#building-a-workflow)
- [Loading a Workflow](#loading-a-workflow)
- [Workflow Migrations](#workflow-migrations)
<!-- /code_chunk_output -->
> This document describes, at a high level, the design and implementation of workflows in the InvokeAI frontend. There are a substantial number of implementation details not included, but which are hopefully clear from the code.
InvokeAI's backend uses graphs, composed of **nodes** and **edges**, to process data and generate images.
@ -117,13 +152,13 @@ Stateless fields do not store their value in the node, so their field instances
"Custom" fields will always be treated as stateless fields.
##### Collection and Scalar Fields
##### Collection and Polymorphic Fields
Field types have a name and two flags which may identify it as a **collection** or **collection or scalar** field.
Field types have a name and two flags which may identify it as a **collection** or **polymorphic** field.
If a field is annotated in python as a list, its field type is parsed and flagged as a **collection** type (e.g. `list[int]`).
If a field is annotated in python as a list, its field type is parsed and flagged as a collection type (e.g. `list[int]`).
If it is annotated as a union of a type and list, the type will be flagged as a **collection or scalar** type (e.g. `Union[int, list[int]]`). Fields may not be unions of different types (e.g. `Union[int, list[str]]` and `Union[int, str]` are not allowed).
If it is annotated as a union of a type and list, the type will be flagged as a polymorphic type (e.g. `Union[int, list[int]]`). Fields may not be unions of different types (e.g. `Union[int, list[str]]` and `Union[int, str]` are not allowed).
## Implementation
@ -303,13 +338,13 @@ Migration logic is in [migrations.ts].
[reactflow]: https://github.com/xyflow/xyflow 'reactflow'
[reactflow-concepts]: https://reactflow.dev/learn/concepts/terms-and-definitions
[reactflow-events]: https://reactflow.dev/api-reference/react-flow#event-handlers
[buildWorkflow.ts]: https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/frontend/web/src/features/nodes/util/workflow/buildWorkflow.ts
[nodesSlice.ts]: https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts
[buildLinearTextToImageGraph.ts]: https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearTextToImageGraph.ts
[buildNodesGraph.ts]: https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/frontend/web/src/features/nodes/util/graph/buildNodesGraph.ts
[buildInvocationNode.ts]: https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/frontend/web/src/features/nodes/util/node/buildInvocationNode.ts
[validateWorkflow.ts]: https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts
[migrations.ts]: https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/frontend/web/src/features/nodes/util/workflow/migrations.ts
[parseSchema.ts]: https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.ts
[buildFieldInputTemplate.ts]: https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts
[buildFieldOutputTemplate.ts]: https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldOutputTemplate.ts
[buildWorkflow.ts]: ../src/features/nodes/util/workflow/buildWorkflow.ts
[nodesSlice.ts]: ../src/features/nodes/store/nodesSlice.ts
[buildLinearTextToImageGraph.ts]: ../src/features/nodes/util/graph/buildLinearTextToImageGraph.ts
[buildNodesGraph.ts]: ../src/features/nodes/util/graph/buildNodesGraph.ts
[buildInvocationNode.ts]: ../src/features/nodes/util/node/buildInvocationNode.ts
[validateWorkflow.ts]: ../src/features/nodes/util/workflow/validateWorkflow.ts
[migrations.ts]: ../src/features/nodes/util/workflow/migrations.ts
[parseSchema.ts]: ../src/features/nodes/util/schema/parseSchema.ts
[buildFieldInputTemplate.ts]: ../src/features/nodes/util/schema/buildFieldInputTemplate.ts
[buildFieldOutputTemplate.ts]: ../src/features/nodes/util/schema/buildFieldOutputTemplate.ts

File diff suppressed because it is too large Load Diff

View File

@ -115,8 +115,7 @@
"safetensors": "Safetensors",
"ai": "ia",
"file": "File",
"toResolve": "Da risolvere",
"add": "Aggiungi"
"toResolve": "Da risolvere"
},
"gallery": {
"generations": "Generazioni",
@ -154,12 +153,7 @@
"starImage": "Immagine preferita",
"dropToUpload": "$t(gallery.drop) per aggiornare",
"problemDeletingImagesDesc": "Impossibile eliminare una o più immagini",
"problemDeletingImages": "Problema durante l'eliminazione delle immagini",
"bulkDownloadRequested": "Preparazione del download",
"bulkDownloadRequestedDesc": "La tua richiesta di download è in preparazione. L'operazione potrebbe richiedere alcuni istanti.",
"bulkDownloadRequestFailed": "Problema durante la preparazione del download",
"bulkDownloadStarting": "Avvio scaricamento",
"bulkDownloadFailed": "Scaricamento fallito"
"problemDeletingImages": "Problema durante l'eliminazione delle immagini"
},
"hotkeys": {
"keyboardShortcuts": "Tasti di scelta rapida",
@ -511,12 +505,12 @@
"modelSyncFailed": "Sincronizzazione modello non riuscita",
"settings": "Impostazioni",
"syncModels": "Sincronizza Modelli",
"syncModelsDesc": "Se i tuoi modelli non sono sincronizzati con il back-end, puoi aggiornarli utilizzando questa opzione. Questo è generalmente utile nei casi in cui aggiungi modelli alla cartella principale di InvokeAI dopo l'avvio dell'applicazione.",
"syncModelsDesc": "Se i tuoi modelli non sono sincronizzati con il back-end, puoi aggiornarli utilizzando questa opzione. Questo è generalmente utile nei casi in cui aggiorni manualmente il tuo file models.yaml o aggiungi modelli alla cartella principale di InvokeAI dopo l'avvio dell'applicazione.",
"loraModels": "LoRA",
"oliveModels": "Olive",
"onnxModels": "ONNX",
"noModels": "Nessun modello trovato",
"predictionType": "Tipo di previsione",
"predictionType": "Tipo di previsione (per modelli Stable Diffusion 2.x ed alcuni modelli Stable Diffusion 1.x)",
"quickAdd": "Aggiunta rapida",
"simpleModelDesc": "Fornire un percorso a un modello diffusori locale, un modello checkpoint/safetensor locale, un ID repository HuggingFace o un URL del modello checkpoint/diffusori.",
"advanced": "Avanzate",
@ -527,34 +521,7 @@
"vaePrecision": "Precisione VAE",
"noModelSelected": "Nessun modello selezionato",
"conversionNotSupported": "Conversione non supportata",
"configFile": "File di configurazione",
"modelName": "Nome del modello",
"modelSettings": "Impostazioni del modello",
"advancedImportInfo": "La scheda opzioni avanzate consente la configurazione manuale delle impostazioni del modello principale. Utilizza questa scheda solo se sei sicuro di conoscere il tipo di modello e la configurazione corretti per il modello selezionato.",
"addAll": "Aggiungi tutto",
"addModels": "Aggiungi modelli",
"cancel": "Annulla",
"edit": "Modifica",
"imageEncoderModelId": "ID modello codificatore di immagini",
"importQueue": "Coda di importazione",
"modelMetadata": "Metadati del modello",
"path": "Percorso",
"prune": "Elimina",
"pruneTooltip": "Elimina dalla coda le importazioni completate",
"removeFromQueue": "Rimuovi dalla coda",
"repoVariant": "Variante del repository",
"scan": "Scansiona",
"scanFolder": "Scansione cartella",
"scanResults": "Risultati della scansione",
"source": "Sorgente",
"upcastAttention": "Eleva l'attenzione",
"ztsnrTraining": "Addestramento ZTSNR",
"typePhraseHere": "Digita la frase qui",
"defaultSettingsSaved": "Impostazioni predefinite salvate",
"defaultSettings": "Impostazioni predefinite",
"metadata": "Metadati",
"useDefaultSettings": "Usa le impostazioni predefinite",
"triggerPhrases": "Frasi trigger"
"configFile": "File di configurazione"
},
"parameters": {
"images": "Immagini",
@ -636,8 +603,8 @@
"clipSkip": "CLIP Skip",
"aspectRatio": "Proporzioni",
"maskAdjustmentsHeader": "Regolazioni della maschera",
"maskBlur": "Sfocatura maschera",
"maskBlurMethod": "Metodo sfocatura maschera",
"maskBlur": "Sfocatura",
"maskBlurMethod": "Metodo di sfocatura",
"seamLowThreshold": "Basso",
"seamHighThreshold": "Alto",
"coherencePassHeader": "Passaggio di coerenza",
@ -694,8 +661,7 @@
"setToOptimalSizeTooLarge": "$t(parameters.setToOptimalSize) (potrebbe essere troppo grande)",
"boxBlur": "Box",
"gaussianBlur": "Gaussian",
"remixImage": "Remixa l'immagine",
"coherenceEdgeSize": "Dimensione bordo"
"remixImage": "Remixa l'immagine"
},
"settings": {
"models": "Modelli",
@ -778,8 +744,8 @@
"canceled": "Elaborazione annullata",
"problemCopyingImageLink": "Impossibile copiare il collegamento dell'immagine",
"uploadFailedInvalidUploadDesc": "Deve essere una singola immagine PNG o JPEG",
"parameterSet": "{{parameter}} impostato",
"parameterNotSet": "{{parameter}} non impostato",
"parameterSet": "Parametro impostato",
"parameterNotSet": "Parametro non impostato",
"nodesLoadedFailed": "Impossibile caricare i nodi",
"nodesSaved": "Nodi salvati",
"nodesLoaded": "Nodi caricati",
@ -832,10 +798,7 @@
"problemRetrievingWorkflow": "Problema nel recupero del flusso di lavoro",
"resetInitialImage": "Reimposta l'immagine iniziale",
"uploadInitialImage": "Carica l'immagine iniziale",
"problemDownloadingImage": "Impossibile scaricare l'immagine",
"prunedQueue": "Coda ripulita",
"modelImportCanceled": "Importazione del modello annullata",
"modelImportRemoved": "Importazione del modello rimossa"
"problemDownloadingImage": "Impossibile scaricare l'immagine"
},
"tooltip": {
"feature": {
@ -913,10 +876,7 @@
"antialiasing": "Anti aliasing",
"showResultsOn": "Mostra i risultati (attivato)",
"showResultsOff": "Mostra i risultati (disattivato)",
"saveMask": "Salva $t(unifiedCanvas.mask)",
"coherenceModeGaussianBlur": "Sfocatura Gaussiana",
"coherenceModeBoxBlur": "Sfocatura Box",
"coherenceModeStaged": "Maschera espansa"
"saveMask": "Salva $t(unifiedCanvas.mask)"
},
"accessibility": {
"modelSelect": "Seleziona modello",
@ -1385,8 +1345,7 @@
"allLoRAsAdded": "Tutti i LoRA aggiunti",
"defaultVAE": "VAE predefinito",
"incompatibleBaseModel": "Modello base incompatibile",
"loraAlreadyAdded": "LoRA già aggiunto",
"concepts": "Concetti"
"loraAlreadyAdded": "LoRA già aggiunto"
},
"invocationCache": {
"disable": "Disabilita",
@ -1739,25 +1698,6 @@
"paragraphs": [
"Valuta le generazioni in modo che siano più simili alle immagini con un punteggio estetico elevato, in base ai dati di addestramento."
]
},
"compositingCoherenceMinDenoise": {
"heading": "Livello minimo di riduzione del rumore",
"paragraphs": [
"Intensità minima di riduzione rumore per la modalità di Coerenza",
"L'intensità minima di riduzione del rumore per la regione di coerenza durante l'inpainting o l'outpainting"
]
},
"compositingMaskBlur": {
"paragraphs": [
"Il raggio di sfocatura della maschera."
],
"heading": "Sfocatura maschera"
},
"compositingCoherenceEdgeSize": {
"heading": "Dimensione del bordo",
"paragraphs": [
"La dimensione del bordo del passaggio di coerenza."
]
}
},
"sdxl": {
@ -1806,12 +1746,7 @@
"scheduler": "Campionatore",
"recallParameters": "Richiama i parametri",
"noRecallParameters": "Nessun parametro da richiamare trovato",
"cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)",
"allPrompts": "Tutti i prompt",
"imageDimensions": "Dimensioni dell'immagine",
"parameterSet": "Parametro {{parameter}} impostato",
"parsingFailed": "Analisi non riuscita",
"recallParameter": "Richiama {{label}}"
"cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)"
},
"hrf": {
"enableHrf": "Abilita Correzione Alta Risoluzione",
@ -1883,11 +1818,5 @@
"image": {
"title": "Immagine"
}
},
"prompt": {
"compatibleEmbeddings": "Incorporamenti compatibili",
"addPromptTrigger": "Aggiungi parola chiave nel prompt",
"noPromptTriggers": "Nessuna parola chiave disponibile",
"noMatchingTriggers": "Nessuna parola chiave corrispondente"
}
}

View File

@ -52,7 +52,7 @@
"accept": "Принять",
"postprocessing": "Постобработка",
"txt2img": "Текст в изображение (txt2img)",
"linear": "Линейный вид",
"linear": "Линейная обработка",
"dontAskMeAgain": "Больше не спрашивать",
"areYouSure": "Вы уверены?",
"random": "Случайное",
@ -117,8 +117,7 @@
"toResolve": "Чтоб решить",
"copy": "Копировать",
"localSystem": "Локальная система",
"aboutDesc": "Используя Invoke для работы? Проверьте это:",
"add": "Добавить"
"aboutDesc": "Используя Invoke для работы? Проверьте это:"
},
"gallery": {
"generations": "Генерации",
@ -156,12 +155,7 @@
"noImageSelected": "Изображение не выбрано",
"setCurrentImage": "Установить как текущее изображение",
"starImage": "Добавить в избранное",
"dropToUpload": "$t(gallery.drop) чтоб загрузить",
"bulkDownloadFailed": "Загрузка не удалась",
"bulkDownloadStarting": "Начало загрузки",
"bulkDownloadRequested": "Подготовка к скачиванию",
"bulkDownloadRequestedDesc": "Ваш запрос на скачивание готовится. Это может занять несколько минут.",
"bulkDownloadRequestFailed": "Возникла проблема при подготовке скачивания"
"dropToUpload": "$t(gallery.drop) чтоб загрузить"
},
"hotkeys": {
"keyboardShortcuts": "Горячие клавиши",
@ -510,7 +504,7 @@
"settings": "Настройки",
"selectModel": "Выберите модель",
"syncModels": "Синхронизация моделей",
"syncModelsDesc": "Если ваши модели не синхронизированы с серверной частью, вы можете обновить их с помощью этой опции. Обычно это удобно в тех случаях, когда вы добавляете модели в корневую папку InvokeAI или каталог автоимпорта после загрузки приложения.",
"syncModelsDesc": "Если ваши модели не синхронизированы с серверной частью, вы можете обновить их, используя эту опцию. Обычно это удобно в тех случаях, когда вы вручную обновляете свой файл \"models.yaml\" или добавляете модели в корневую папку InvokeAI после загрузки приложения.",
"modelUpdateFailed": "Не удалось обновить модель",
"modelConversionFailed": "Не удалось сконвертировать модель",
"modelsMergeFailed": "Не удалось выполнить слияние моделей",
@ -519,7 +513,7 @@
"oliveModels": "Модели Olives",
"conversionNotSupported": "Преобразование не поддерживается",
"noModels": "Нет моделей",
"predictionType": "Тип прогноза",
"predictionType": "Тип прогноза (для моделей Stable Diffusion 2.x и периодических моделей Stable Diffusion 1.x)",
"quickAdd": "Быстрое добавление",
"simpleModelDesc": "Укажите путь к локальной модели Diffusers , локальной модели checkpoint / safetensors, идентификатор репозитория HuggingFace или URL-адрес модели контрольной checkpoint / diffusers.",
"advanced": "Продвинутый",
@ -530,33 +524,7 @@
"customConfigFileLocation": "Расположение пользовательского файла конфигурации",
"vaePrecision": "Точность VAE",
"noModelSelected": "Модель не выбрана",
"configFile": "Файл конфигурации",
"addAll": "Добавить всё",
"addModels": "Добавить модели",
"cancel": "Отмена",
"defaultSettings": "Стандартные настройки",
"importQueue": "Импортировать очередь",
"metadata": "Метаданные",
"imageEncoderModelId": "ID модели-энкодера изображений",
"typePhraseHere": "Введите фразы здесь",
"advancedImportInfo": "Вкладка «Дополнительно» позволяет вручную настроить основные параметры модели. Используйте эту вкладку только в том случае, если вы уверены, что знаете правильный тип модели и конфигурацию выбранной модели.",
"defaultSettingsSaved": "Стандартные настройки сохранены",
"edit": "Редактировать",
"path": "Путь",
"prune": "Удалить",
"pruneTooltip": "Удалить готовые импорты из очереди",
"removeFromQueue": "Удалить из очереди",
"repoVariant": "Вариант репозитория",
"scan": "Сканировать",
"scanFolder": "Сканировать папку",
"scanResults": "Результаты сканирования",
"source": "Источник",
"triggerPhrases": "Триггерные фразы",
"useDefaultSettings": "Использовать стандартные настройки",
"modelMetadata": "Метаданные модели",
"modelName": "Название модели",
"modelSettings": "Настройки модели",
"upcastAttention": "Внимание"
"configFile": "Файл конфигурации"
},
"parameters": {
"images": "Изображения",
@ -623,7 +591,7 @@
"hSymmetryStep": "Шаг гор. симметрии",
"hidePreview": "Скрыть предпросмотр",
"imageToImage": "Изображение в изображение",
"denoisingStrength": "Сила зашумления",
"denoisingStrength": "Сила шумоподавления",
"copyImage": "Скопировать изображение",
"showPreview": "Показать предпросмотр",
"noiseSettings": "Шум",
@ -638,8 +606,8 @@
"clipSkip": "CLIP Пропуск",
"aspectRatio": "Соотношение",
"maskAdjustmentsHeader": "Настройка маски",
"maskBlur": "Размытие маски",
"maskBlurMethod": "Метод размытия маски",
"maskBlur": "Размытие",
"maskBlurMethod": "Метод размытия",
"seamLowThreshold": "Низкий",
"seamHighThreshold": "Высокий",
"coherencePassHeader": "Порог Coherence",
@ -698,9 +666,7 @@
"lockAspectRatio": "Заблокировать соотношение",
"boxBlur": "Размытие прямоугольника",
"gaussianBlur": "Размытие по Гауссу",
"remixImage": "Ремикс изображения",
"coherenceMinDenoise": "Мин. шумоподавление",
"coherenceEdgeSize": "Размер края"
"remixImage": "Ремикс изображения"
},
"settings": {
"models": "Модели",
@ -783,8 +749,8 @@
"canceled": "Обработка отменена",
"problemCopyingImageLink": "Не удалось скопировать ссылку на изображение",
"uploadFailedInvalidUploadDesc": "Должно быть одно изображение в формате PNG или JPEG",
"parameterNotSet": "Параметр {{parameter}} не задан",
"parameterSet": "Параметр {{parameter}} задан",
"parameterNotSet": "Параметр не задан",
"parameterSet": "Параметр задан",
"nodesLoaded": "Узлы загружены",
"problemCopyingImage": "Не удается скопировать изображение",
"nodesLoadedFailed": "Не удалось загрузить Узлы",
@ -837,10 +803,7 @@
"problemImportingMask": "Проблема с импортом маски",
"problemDownloadingImage": "Не удается скачать изображение",
"uploadInitialImage": "Загрузить начальное изображение",
"resetInitialImage": "Сбросить начальное изображение",
"prunedQueue": "Урезанная очередь",
"modelImportCanceled": "Импорт модели отменен",
"modelImportRemoved": "Импорт модели удален"
"resetInitialImage": "Сбросить начальное изображение"
},
"tooltip": {
"feature": {
@ -1182,11 +1145,7 @@
"reorderLinearView": "Изменить порядок линейного просмотра",
"viewMode": "Использовать в линейном представлении",
"editMode": "Открыть в редакторе узлов",
"resetToDefaultValue": "Сбросить к стандартному значкнию",
"latentsField": "Латенты",
"latentsCollectionDescription": "Латенты могут передаваться между узлами.",
"latentsPolymorphicDescription": "Латенты могут передаваться между узлами.",
"latentsFieldDescription": "Латенты могут передаваться между узлами."
"resetToDefaultValue": "Сбросить к стандартному значкнию"
},
"controlnet": {
"amult": "a_mult",
@ -1335,8 +1294,7 @@
},
"paramScheduler": {
"paragraphs": [
"Планировщик, используемый в процессе генерации.",
"Каждый планировщик определяет, как итеративно добавлять шум к изображению или как обновлять образец на основе выходных данных модели."
"Планировщик определяет, как итеративно добавлять шум к изображению или как обновлять образец на основе выходных данных модели."
],
"heading": "Планировщик"
},
@ -1389,7 +1347,7 @@
"compositingCoherenceMode": {
"heading": "Режим",
"paragraphs": [
"Метод, используемый для создания связного изображения с вновь созданной замаскированной областью."
"Режим прохождения когерентности."
]
},
"paramSeed": {
@ -1407,7 +1365,7 @@
},
"controlNetBeginEnd": {
"paragraphs": [
"Часть процесса шумоподавления, к которой будет применен адаптер контроля.",
"На каких этапах процесса шумоподавления будет применена ControlNet.",
"ControlNet, применяемые в начале процесса, направляют композицию, а ControlNet, применяемые в конце, направляют детали."
],
"heading": "Процент начала/конца шага"
@ -1423,8 +1381,8 @@
},
"clipSkip": {
"paragraphs": [
"Сколько слоев модели CLIP пропустить.",
"Некоторые модели лучше подходят для использования с CLIP Skip."
"Выберите, сколько слоев модели CLIP нужно пропустить.",
"Некоторые модели работают лучше с определенными настройками пропуска CLIP."
],
"heading": "CLIP пропуск"
},
@ -1521,25 +1479,6 @@
"paragraphs": [
"Более высокий вес LoRA приведет к большему влиянию на конечное изображение."
]
},
"compositingMaskBlur": {
"heading": "Размытие маски",
"paragraphs": [
"Радиус размытия маски."
]
},
"compositingCoherenceMinDenoise": {
"heading": "Минимальное шумоподавление",
"paragraphs": [
"Минимальный уровень шумоподавления для режима Coherence",
"Минимальный уровень шумоподавления для области когерентности при перерисовывании или дорисовке"
]
},
"compositingCoherenceEdgeSize": {
"heading": "Размер края",
"paragraphs": [
"Размер края прохода когерентности."
]
}
},
"metadata": {
@ -1570,12 +1509,7 @@
"steps": "Шаги",
"scheduler": "Планировщик",
"noRecallParameters": "Параметры для вызова не найдены",
"cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)",
"parameterSet": "Параметр {{parameter}} установлен",
"parsingFailed": "Не удалось выполнить синтаксический анализ",
"recallParameter": "Отозвать {{label}}",
"allPrompts": "Все запросы",
"imageDimensions": "Размеры изображения"
"cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)"
},
"queue": {
"status": "Статус",
@ -1654,11 +1588,10 @@
"denoisingStrength": "Шумоподавление",
"refinermodel": "Модель перерисовщик",
"posAestheticScore": "Положительная эстетическая оценка",
"concatPromptStyle": "Связывание запроса и стиля",
"concatPromptStyle": "Объединение запроса и стиля",
"loading": "Загрузка...",
"steps": "Шаги",
"posStylePrompt": "Запрос стиля",
"freePromptStyle": "Ручной запрос стиля"
"posStylePrompt": "Запрос стиля"
},
"invocationCache": {
"useCache": "Использовать кэш",
@ -1745,8 +1678,7 @@
"allLoRAsAdded": "Все LoRA добавлены",
"defaultVAE": "Стандартное VAE",
"incompatibleBaseModel": "Несовместимая базовая модель",
"loraAlreadyAdded": "LoRA уже добавлена",
"concepts": "Концепты"
"loraAlreadyAdded": "LoRA уже добавлена"
},
"app": {
"storeNotInitialized": "Магазин не инициализирован"
@ -1764,7 +1696,7 @@
},
"generation": {
"title": "Генерация",
"conceptsTab": "LoRA",
"conceptsTab": "Концепты",
"modelTab": "Модель"
},
"advanced": {

View File

@ -1,88 +0,0 @@
# Cleans translations by removing unused keys
# Usage: python clean_translations.py
# Note: Must be run from invokeai/frontend/web/scripts directory
#
# After running the script, open `en.json` and check for empty objects (`{}`) and remove them manually.
import json
import os
import re
from typing import TypeAlias, Union
from tqdm import tqdm
RecursiveDict: TypeAlias = dict[str, Union["RecursiveDict", str]]
class TranslationCleaner:
file_cache: dict[str, str] = {}
def _get_keys(self, obj: RecursiveDict, current_path: str = "", keys: list[str] | None = None):
if keys is None:
keys = []
for key in obj:
new_path = f"{current_path}.{key}" if current_path else key
next_ = obj[key]
if isinstance(next_, dict):
self._get_keys(next_, new_path, keys)
elif "_" in key:
# This typically means its a pluralized key
continue
else:
keys.append(new_path)
return keys
def _search_codebase(self, key: str):
for root, _dirs, files in os.walk("../src"):
for file in files:
if file.endswith(".ts") or file.endswith(".tsx"):
full_path = os.path.join(root, file)
if full_path in self.file_cache:
content = self.file_cache[full_path]
else:
with open(full_path, "r") as f:
content = f.read()
self.file_cache[full_path] = content
# match the whole key, surrounding by quotes
if re.search(r"['\"`]" + re.escape(key) + r"['\"`]", self.file_cache[full_path]):
return True
# math the stem of the key, with quotes at the end
if re.search(re.escape(key.split(".")[-1]) + r"['\"`]", self.file_cache[full_path]):
return True
return False
def _remove_key(self, obj: RecursiveDict, key: str):
path = key.split(".")
last_key = path[-1]
for k in path[:-1]:
obj = obj[k]
del obj[last_key]
def clean(self, obj: RecursiveDict) -> RecursiveDict:
keys = self._get_keys(obj)
pbar = tqdm(keys, desc="Checking keys")
for key in pbar:
if not self._search_codebase(key):
self._remove_key(obj, key)
return obj
def main():
try:
with open("../public/locales/en.json", "r") as f:
data = json.load(f)
except FileNotFoundError as e:
raise FileNotFoundError(
"Unable to find en.json file - must be run from invokeai/frontend/web/scripts directory"
) from e
cleaner = TranslationCleaner()
cleaned_data = cleaner.clean(data)
with open("../public/locales/en.json", "w") as f:
json.dump(cleaned_data, f, indent=4)
if __name__ == "__main__":
main()

View File

@ -5,55 +5,18 @@ import openapiTS from 'openapi-typescript';
const OPENAPI_URL = 'http://127.0.0.1:9090/openapi.json';
const OUTPUT_FILE = 'src/services/api/schema.ts';
async function generateTypes(schema) {
process.stdout.write(`Generating types ${OUTPUT_FILE}...`);
const types = await openapiTS(schema, {
async function main() {
process.stdout.write(`Generating types "${OPENAPI_URL}" --> "${OUTPUT_FILE}"...`);
const types = await openapiTS(OPENAPI_URL, {
exportType: true,
transform: (schemaObject) => {
if ('format' in schemaObject && schemaObject.format === 'binary') {
return schemaObject.nullable ? 'Blob | null' : 'Blob';
}
if (schemaObject.title === 'MetadataField') {
// This is `Record<string, never>` by default, but it actually accepts any a dict of any valid JSON value.
return 'Record<string, unknown>';
}
},
});
fs.writeFileSync(OUTPUT_FILE, types);
process.stdout.write(`\nOK!\r\n`);
}
async function main() {
const encoding = 'utf-8';
if (process.stdin.isTTY) {
// Handle generating types with an arg (e.g. URL or path to file)
if (process.argv.length > 3) {
console.error('Usage: typegen.js <openapi.json>');
process.exit(1);
}
if (process.argv[2]) {
const schema = new Buffer.from(process.argv[2], encoding);
generateTypes(schema);
} else {
generateTypes(OPENAPI_URL);
}
} else {
// Handle generating types from stdin
let schema = '';
process.stdin.setEncoding(encoding);
process.stdin.on('readable', function () {
const chunk = process.stdin.read();
if (chunk !== null) {
schema += chunk;
}
});
process.stdin.on('end', function () {
generateTypes(JSON.parse(schema));
});
}
}
main();

View File

@ -38,7 +38,7 @@ export const addCanvasImageToControlNetListener = (startAppListening: AppStartLi
type: 'image/png',
}),
image_category: 'control',
is_intermediate: true,
is_intermediate: false,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
crop_visible: false,
postUploadAction: {

View File

@ -48,7 +48,7 @@ export const addCanvasMaskToControlNetListener = (startAppListening: AppStartLis
type: 'image/png',
}),
image_category: 'mask',
is_intermediate: true,
is_intermediate: false,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
crop_visible: false,
postUploadAction: {

View File

@ -101,7 +101,7 @@ export const addEnqueueRequestedCanvasListener = (startAppListening: AppStartLis
).unwrap();
}
const graph = await buildCanvasGraph(state, generationMode, canvasInitImage, canvasMaskImage);
const graph = buildCanvasGraph(state, generationMode, canvasInitImage, canvasMaskImage);
log.debug({ graph: parseify(graph) }, `Canvas graph built`);

View File

@ -20,15 +20,15 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
if (model && model.base === 'sdxl') {
if (action.payload.tabName === 'txt2img') {
graph = await buildLinearSDXLTextToImageGraph(state);
graph = buildLinearSDXLTextToImageGraph(state);
} else {
graph = await buildLinearSDXLImageToImageGraph(state);
graph = buildLinearSDXLImageToImageGraph(state);
}
} else {
if (action.payload.tabName === 'txt2img') {
graph = await buildLinearTextToImageGraph(state);
graph = buildLinearTextToImageGraph(state);
} else {
graph = await buildLinearImageToImageGraph(state);
graph = buildLinearImageToImageGraph(state);
}
}

View File

@ -1,10 +1,10 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppDispatch, RootState } from 'app/store/store';
import type { JSONObject } from 'common/types';
import {
controlAdapterModelCleared,
selectControlAdapterAll,
selectAllControlNets,
selectAllIPAdapters,
selectAllT2IAdapters,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { loraRemoved } from 'features/lora/store/loraSlice';
import { calculateNewSize } from 'features/parameters/components/ImageSize/calculateNewSize';
@ -12,161 +12,212 @@ import { heightChanged, modelChanged, vaeSelected, widthChanged } from 'features
import { zParameterModel, zParameterVAEModel } from 'features/parameters/types/parameterSchemas';
import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
import { refinerModelChanged } from 'features/sdxl/store/sdxlSlice';
import { forEach } from 'lodash-es';
import type { Logger } from 'roarr';
import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
import { isNonRefinerMainModelConfig, isRefinerMainModelModelConfig, isVAEModelConfig } from 'services/api/types';
import { forEach, some } from 'lodash-es';
import { mainModelsAdapterSelectors, modelsApi, vaeModelsAdapterSelectors } from 'services/api/endpoints/models';
import type { TypeGuardFor } from 'services/api/types';
export const addModelsLoadedListener = (startAppListening: AppStartListening) => {
startAppListening({
predicate: modelsApi.endpoints.getModelConfigs.matchFulfilled,
predicate: (action): action is TypeGuardFor<typeof modelsApi.endpoints.getMainModels.matchFulfilled> =>
modelsApi.endpoints.getMainModels.matchFulfilled(action) &&
!action.meta.arg.originalArgs.includes('sdxl-refiner'),
effect: async (action, { getState, dispatch }) => {
// models loaded, we need to ensure the selected model is available and if not, select the first one
const log = logger('models');
log.info({ models: action.payload.entities }, `Models loaded (${action.payload.ids.length})`);
log.info({ models: action.payload.entities }, `Main models loaded (${action.payload.ids.length})`);
const state = getState();
const models = modelConfigsAdapterSelectors.selectAll(action.payload);
const currentModel = state.generation.model;
const models = mainModelsAdapterSelectors.selectAll(action.payload);
handleMainModels(models, state, dispatch, log);
handleRefinerModels(models, state, dispatch, log);
handleVAEModels(models, state, dispatch, log);
handleLoRAModels(models, state, dispatch, log);
handleControlAdapterModels(models, state, dispatch, log);
if (models.length === 0) {
// No models loaded at all
dispatch(modelChanged(null));
return;
}
const isCurrentModelAvailable = currentModel ? models.some((m) => m.key === currentModel.key) : false;
if (isCurrentModelAvailable) {
return;
}
const defaultModel = state.config.sd.defaultModel;
const defaultModelInList = defaultModel ? models.find((m) => m.key === defaultModel) : false;
if (defaultModelInList) {
const result = zParameterModel.safeParse(defaultModelInList);
if (result.success) {
dispatch(modelChanged(defaultModelInList, currentModel));
const optimalDimension = getOptimalDimension(defaultModelInList);
if (getIsSizeOptimal(state.generation.width, state.generation.height, optimalDimension)) {
return;
}
const { width, height } = calculateNewSize(
state.generation.aspectRatio.value,
optimalDimension * optimalDimension
);
dispatch(widthChanged(width));
dispatch(heightChanged(height));
return;
}
}
const result = zParameterModel.safeParse(models[0]);
if (!result.success) {
log.error({ error: result.error.format() }, 'Failed to parse main model');
return;
}
dispatch(modelChanged(result.data, currentModel));
},
});
startAppListening({
predicate: (action): action is TypeGuardFor<typeof modelsApi.endpoints.getMainModels.matchFulfilled> =>
modelsApi.endpoints.getMainModels.matchFulfilled(action) && action.meta.arg.originalArgs.includes('sdxl-refiner'),
effect: async (action, { getState, dispatch }) => {
// models loaded, we need to ensure the selected model is available and if not, select the first one
const log = logger('models');
log.info({ models: action.payload.entities }, `SDXL Refiner models loaded (${action.payload.ids.length})`);
const currentModel = getState().sdxl.refinerModel;
const models = mainModelsAdapterSelectors.selectAll(action.payload);
if (models.length === 0) {
// No models loaded at all
dispatch(refinerModelChanged(null));
return;
}
const isCurrentModelAvailable = currentModel ? models.some((m) => m.key === currentModel.key) : false;
if (!isCurrentModelAvailable) {
dispatch(refinerModelChanged(null));
return;
}
},
});
startAppListening({
matcher: modelsApi.endpoints.getVaeModels.matchFulfilled,
effect: async (action, { getState, dispatch }) => {
// VAEs loaded, need to reset the VAE is it's no longer available
const log = logger('models');
log.info({ models: action.payload.entities }, `VAEs loaded (${action.payload.ids.length})`);
const currentVae = getState().generation.vae;
if (currentVae === null) {
// null is a valid VAE! it means "use the default with the main model"
return;
}
const isCurrentVAEAvailable = some(action.payload.entities, (m) => m?.key === currentVae?.key);
if (isCurrentVAEAvailable) {
return;
}
const firstModel = vaeModelsAdapterSelectors.selectAll(action.payload)[0];
if (!firstModel) {
// No custom VAEs loaded at all; use the default
dispatch(vaeSelected(null));
return;
}
const result = zParameterVAEModel.safeParse(firstModel);
if (!result.success) {
log.error({ error: result.error.format() }, 'Failed to parse VAE model');
return;
}
dispatch(vaeSelected(result.data));
},
});
startAppListening({
matcher: modelsApi.endpoints.getLoRAModels.matchFulfilled,
effect: async (action, { getState, dispatch }) => {
// LoRA models loaded - need to remove missing LoRAs from state
const log = logger('models');
log.info({ models: action.payload.entities }, `LoRAs loaded (${action.payload.ids.length})`);
const loras = getState().lora.loras;
forEach(loras, (lora, id) => {
const isLoRAAvailable = some(action.payload.entities, (m) => m?.key === lora?.model.key);
if (isLoRAAvailable) {
return;
}
dispatch(loraRemoved(id));
});
},
});
startAppListening({
matcher: modelsApi.endpoints.getControlNetModels.matchFulfilled,
effect: async (action, { getState, dispatch }) => {
// ControlNet models loaded - need to remove missing ControlNets from state
const log = logger('models');
log.info({ models: action.payload.entities }, `ControlNet models loaded (${action.payload.ids.length})`);
selectAllControlNets(getState().controlAdapters).forEach((ca) => {
const isModelAvailable = some(action.payload.entities, (m) => m?.key === ca?.model?.key);
if (isModelAvailable) {
return;
}
dispatch(controlAdapterModelCleared({ id: ca.id }));
});
},
});
startAppListening({
matcher: modelsApi.endpoints.getT2IAdapterModels.matchFulfilled,
effect: async (action, { getState, dispatch }) => {
// ControlNet models loaded - need to remove missing ControlNets from state
const log = logger('models');
log.info({ models: action.payload.entities }, `T2I Adapter models loaded (${action.payload.ids.length})`);
selectAllT2IAdapters(getState().controlAdapters).forEach((ca) => {
const isModelAvailable = some(action.payload.entities, (m) => m?.key === ca?.model?.key);
if (isModelAvailable) {
return;
}
dispatch(controlAdapterModelCleared({ id: ca.id }));
});
},
});
startAppListening({
matcher: modelsApi.endpoints.getIPAdapterModels.matchFulfilled,
effect: async (action, { getState, dispatch }) => {
// ControlNet models loaded - need to remove missing ControlNets from state
const log = logger('models');
log.info({ models: action.payload.entities }, `IP Adapter models loaded (${action.payload.ids.length})`);
selectAllIPAdapters(getState().controlAdapters).forEach((ca) => {
const isModelAvailable = some(action.payload.entities, (m) => m?.key === ca?.model?.key);
if (isModelAvailable) {
return;
}
dispatch(controlAdapterModelCleared({ id: ca.id }));
});
},
});
startAppListening({
matcher: modelsApi.endpoints.getTextualInversionModels.matchFulfilled,
effect: async (action) => {
const log = logger('models');
log.info({ models: action.payload.entities }, `Embeddings loaded (${action.payload.ids.length})`);
},
});
};
type ModelHandler = (
models: AnyModelConfig[],
state: RootState,
dispatch: AppDispatch,
log: Logger<JSONObject>
) => undefined;
const handleMainModels: ModelHandler = (models, state, dispatch, log) => {
const currentModel = state.generation.model;
const mainModels = models.filter(isNonRefinerMainModelConfig);
if (mainModels.length === 0) {
// No models loaded at all
dispatch(modelChanged(null));
return;
}
const isCurrentMainModelAvailable = currentModel ? mainModels.some((m) => m.key === currentModel.key) : false;
if (isCurrentMainModelAvailable) {
return;
}
const defaultModel = state.config.sd.defaultModel;
const defaultModelInList = defaultModel ? mainModels.find((m) => m.key === defaultModel) : false;
if (defaultModelInList) {
const result = zParameterModel.safeParse(defaultModelInList);
if (result.success) {
dispatch(modelChanged(defaultModelInList, currentModel));
const optimalDimension = getOptimalDimension(defaultModelInList);
if (getIsSizeOptimal(state.generation.width, state.generation.height, optimalDimension)) {
return;
}
const { width, height } = calculateNewSize(
state.generation.aspectRatio.value,
optimalDimension * optimalDimension
);
dispatch(widthChanged(width));
dispatch(heightChanged(height));
return;
}
}
const result = zParameterModel.safeParse(mainModels[0]);
if (!result.success) {
log.error({ error: result.error.format() }, 'Failed to parse main model');
return;
}
dispatch(modelChanged(result.data, currentModel));
};
const handleRefinerModels: ModelHandler = (models, state, dispatch, _log) => {
const currentRefinerModel = state.sdxl.refinerModel;
const refinerModels = models.filter(isRefinerMainModelModelConfig);
if (models.length === 0) {
// No models loaded at all
dispatch(refinerModelChanged(null));
return;
}
const isCurrentRefinerModelAvailable = currentRefinerModel
? refinerModels.some((m) => m.key === currentRefinerModel.key)
: false;
if (!isCurrentRefinerModelAvailable) {
dispatch(refinerModelChanged(null));
return;
}
};
const handleVAEModels: ModelHandler = (models, state, dispatch, log) => {
const currentVae = state.generation.vae;
if (currentVae === null) {
// null is a valid VAE! it means "use the default with the main model"
return;
}
const vaeModels = models.filter(isVAEModelConfig);
const isCurrentVAEAvailable = vaeModels.some((m) => m.key === currentVae.key);
if (isCurrentVAEAvailable) {
return;
}
const firstModel = vaeModels[0];
if (!firstModel) {
// No custom VAEs loaded at all; use the default
dispatch(vaeSelected(null));
return;
}
const result = zParameterVAEModel.safeParse(firstModel);
if (!result.success) {
log.error({ error: result.error.format() }, 'Failed to parse VAE model');
return;
}
dispatch(vaeSelected(result.data));
};
const handleLoRAModels: ModelHandler = (models, state, dispatch, _log) => {
const loras = state.lora.loras;
forEach(loras, (lora, id) => {
const isLoRAAvailable = models.some((m) => m.key === lora.model.key);
if (isLoRAAvailable) {
return;
}
dispatch(loraRemoved(id));
});
};
const handleControlAdapterModels: ModelHandler = (models, state, dispatch, _log) => {
selectControlAdapterAll(state.controlAdapters).forEach((ca) => {
const isModelAvailable = models.some((m) => m.key === ca.model?.key);
if (isModelAvailable) {
return;
}
dispatch(controlAdapterModelCleared({ id: ca.id }));
});
};

View File

@ -1,30 +1,26 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { setDefaultSettings } from 'features/parameters/store/actions';
import {
heightChanged,
setCfgRescaleMultiplier,
setCfgScale,
setScheduler,
setSteps,
vaePrecisionChanged,
vaeSelected,
widthChanged,
} from 'features/parameters/store/generationSlice';
import {
isParameterCFGRescaleMultiplier,
isParameterCFGScale,
isParameterHeight,
isParameterPrecision,
isParameterScheduler,
isParameterSteps,
isParameterWidth,
zParameterVAEModel,
} from 'features/parameters/types/parameterSchemas';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next';
import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models';
import { isNonRefinerMainModelConfig } from 'services/api/types';
import { map } from 'lodash-es';
import { modelsApi } from 'services/api/endpoints/models';
export const addSetDefaultSettingsListener = (startAppListening: AppStartListening) => {
startAppListening({
@ -38,80 +34,63 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni
return;
}
const request = dispatch(modelsApi.endpoints.getModelConfigs.initiate());
const data = await request.unwrap();
request.unsubscribe();
const models = modelConfigsAdapterSelectors.selectAll(data);
const modelConfig = await dispatch(modelsApi.endpoints.getModelConfig.initiate(currentModel.key)).unwrap();
const modelConfig = models.find((model) => model.key === currentModel.key);
if (!modelConfig) {
if (!modelConfig || !modelConfig.default_settings) {
return;
}
if (isNonRefinerMainModelConfig(modelConfig) && modelConfig.default_settings) {
const { vae, vae_precision, cfg_scale, cfg_rescale_multiplier, steps, scheduler, width, height } =
modelConfig.default_settings;
const { vae, vae_precision, cfg_scale, cfg_rescale_multiplier, steps, scheduler } = modelConfig.default_settings;
if (vae) {
// we store this as "default" within default settings
// to distinguish it from no default set
if (vae === 'default') {
dispatch(vaeSelected(null));
} else {
const vaeModel = models.find((model) => model.key === vae);
const result = zParameterVAEModel.safeParse(vaeModel);
if (!result.success) {
return;
}
dispatch(vaeSelected(result.data));
if (vae) {
// we store this as "default" within default settings
// to distinguish it from no default set
if (vae === 'default') {
dispatch(vaeSelected(null));
} else {
const { data } = modelsApi.endpoints.getVaeModels.select()(state);
const vaeArray = map(data?.entities);
const validVae = vaeArray.find((model) => model.key === vae);
const result = zParameterVAEModel.safeParse(validVae);
if (!result.success) {
return;
}
dispatch(vaeSelected(result.data));
}
if (vae_precision) {
if (isParameterPrecision(vae_precision)) {
dispatch(vaePrecisionChanged(vae_precision));
}
}
if (cfg_scale) {
if (isParameterCFGScale(cfg_scale)) {
dispatch(setCfgScale(cfg_scale));
}
}
if (cfg_rescale_multiplier) {
if (isParameterCFGRescaleMultiplier(cfg_rescale_multiplier)) {
dispatch(setCfgRescaleMultiplier(cfg_rescale_multiplier));
}
}
if (steps) {
if (isParameterSteps(steps)) {
dispatch(setSteps(steps));
}
}
if (scheduler) {
if (isParameterScheduler(scheduler)) {
dispatch(setScheduler(scheduler));
}
}
if (width) {
if (isParameterWidth(width)) {
dispatch(widthChanged(width));
}
}
if (height) {
if (isParameterHeight(height)) {
dispatch(heightChanged(height));
}
}
dispatch(addToast(makeToast({ title: t('toast.parameterSet', { parameter: 'Default settings' }) })));
}
if (vae_precision) {
if (isParameterPrecision(vae_precision)) {
dispatch(vaePrecisionChanged(vae_precision));
}
}
if (cfg_scale) {
if (isParameterCFGScale(cfg_scale)) {
dispatch(setCfgScale(cfg_scale));
}
}
if (cfg_rescale_multiplier) {
if (isParameterCFGRescaleMultiplier(cfg_rescale_multiplier)) {
dispatch(setCfgRescaleMultiplier(cfg_rescale_multiplier));
}
}
if (steps) {
if (isParameterSteps(steps)) {
dispatch(setSteps(steps));
}
}
if (scheduler) {
if (isParameterScheduler(scheduler)) {
dispatch(setScheduler(scheduler));
}
}
dispatch(addToast(makeToast({ title: t('toast.parameterSet', { parameter: 'Default settings' }) })));
},
});
};

View File

@ -4,7 +4,6 @@ import { $baseUrl } from 'app/store/nanostores/baseUrl';
import { isEqual } from 'lodash-es';
import { atom } from 'nanostores';
import { api } from 'services/api';
import { modelsApi } from 'services/api/endpoints/models';
import { queueApi, selectQueueStatus } from 'services/api/endpoints/queue';
import { socketConnected } from 'services/events/actions';
@ -30,11 +29,6 @@ export const addSocketConnectedEventListener = (startAppListening: AppStartListe
// Bail on the recovery logic if this is the first connection - we don't need to recover anything
if ($isFirstConnection.get()) {
// Populate the model configs on first connection. This query cache has a 24hr timeout, so we can immediately
// unsubscribe.
const request = dispatch(modelsApi.endpoints.getModelConfigs.initiate());
request.unsubscribe();
$isFirstConnection.set(false);
return;
}

View File

@ -2,7 +2,6 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'
import { api } from 'services/api';
import { modelsApi } from 'services/api/endpoints/models';
import {
socketModelInstallCancelled,
socketModelInstallCompleted,
socketModelInstallDownloading,
socketModelInstallError,
@ -64,21 +63,4 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
);
},
});
startAppListening({
actionCreator: socketModelInstallCancelled,
effect: (action, { dispatch }) => {
const { id } = action.payload.data;
dispatch(
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id);
if (modelImport) {
modelImport.status = 'cancelled';
}
return draft;
})
);
},
});
};

View File

@ -8,16 +8,14 @@ export const addModelLoadEventListener = (startAppListening: AppStartListening)
startAppListening({
actionCreator: socketModelLoadStarted,
effect: (action) => {
const { model_config, submodel_type } = action.payload.data;
const { name, base, type } = model_config;
const { base_model, model_name, model_type, submodel } = action.payload.data;
const extras: string[] = [base, type];
if (submodel_type) {
extras.push(submodel_type);
let message = `Model load started: ${base_model}/${model_type}/${model_name}`;
if (submodel) {
message = message.concat(`/${submodel}`);
}
const message = `Model load started: ${name} (${extras.join(', ')})`;
log.debug(action.payload, message);
},
});
@ -25,16 +23,14 @@ export const addModelLoadEventListener = (startAppListening: AppStartListening)
startAppListening({
actionCreator: socketModelLoadCompleted,
effect: (action) => {
const { model_config, submodel_type } = action.payload.data;
const { name, base, type } = model_config;
const { base_model, model_name, model_type, submodel } = action.payload.data;
const extras: string[] = [base, type];
if (submodel_type) {
extras.push(submodel_type);
let message = `Model load complete: ${base_model}/${model_type}/${model_name}`;
if (submodel) {
message = message.concat(`/${submodel}`);
}
const message = `Model load complete: ${name} (${extras.join(', ')})`;
log.debug(action.payload, message);
},
});

View File

@ -20,7 +20,7 @@ const sx: ChakraProps['sx'] = {
'.react-colorful__hue-pointer': colorPickerPointerStyles,
'.react-colorful__saturation-pointer': colorPickerPointerStyles,
'.react-colorful__alpha-pointer': colorPickerPointerStyles,
gap: 5,
gap: 2,
flexDir: 'column',
};
@ -39,8 +39,8 @@ const IAIColorPicker = (props: IAIColorPickerProps) => {
<Flex sx={sx}>
<RgbaColorPicker color={color} onChange={onChange} style={colorPickerStyles} {...rest} />
{withNumberInput && (
<Flex gap={5}>
<FormControl gap={0}>
<Flex>
<FormControl>
<FormLabel>{t('common.red')}</FormLabel>
<CompositeNumberInput
value={color.r}
@ -52,7 +52,7 @@ const IAIColorPicker = (props: IAIColorPickerProps) => {
defaultValue={90}
/>
</FormControl>
<FormControl gap={0}>
<FormControl>
<FormLabel>{t('common.green')}</FormLabel>
<CompositeNumberInput
value={color.g}
@ -64,7 +64,7 @@ const IAIColorPicker = (props: IAIColorPickerProps) => {
defaultValue={90}
/>
</FormControl>
<FormControl gap={0}>
<FormControl>
<FormLabel>{t('common.blue')}</FormLabel>
<CompositeNumberInput
value={color.b}
@ -76,7 +76,7 @@ const IAIColorPicker = (props: IAIColorPickerProps) => {
defaultValue={255}
/>
</FormControl>
<FormControl gap={0}>
<FormControl>
<FormLabel>{t('common.alpha')}</FormLabel>
<CompositeNumberInput
value={color.a}

View File

@ -1,15 +1,16 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import type { EntityState } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import type { GroupBase } from 'chakra-react-select';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import { groupBy, reduce } from 'lodash-es';
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
import { groupBy, map, reduce } from 'lodash-es';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import type { AnyModelConfig } from 'services/api/types';
type UseGroupedModelComboboxArg<T extends AnyModelConfig> = {
modelConfigs: T[];
selectedModel?: ModelIdentifierField | null;
modelEntities: EntityState<T, string> | undefined;
selectedModel?: ModelIdentifierWithBase | null;
onChange: (value: T | null) => void;
getIsDisabled?: (model: T) => boolean;
isLoading?: boolean;
@ -28,12 +29,13 @@ export const useGroupedModelCombobox = <T extends AnyModelConfig>(
): UseGroupedModelComboboxReturn => {
const { t } = useTranslation();
const base_model = useAppSelector((s) => s.generation.model?.base ?? 'sdxl');
const { modelConfigs, selectedModel, getIsDisabled, onChange, isLoading } = arg;
const { modelEntities, selectedModel, getIsDisabled, onChange, isLoading } = arg;
const options = useMemo<GroupBase<ComboboxOption>[]>(() => {
if (!modelConfigs) {
if (!modelEntities) {
return [];
}
const groupedModels = groupBy(modelConfigs, 'base');
const modelEntitiesArray = map(modelEntities.entities);
const groupedModels = groupBy(modelEntitiesArray, 'base');
const _options = reduce(
groupedModels,
(acc, val, label) => {
@ -51,7 +53,7 @@ export const useGroupedModelCombobox = <T extends AnyModelConfig>(
);
_options.sort((a) => (a.label === base_model ? -1 : 1));
return _options;
}, [getIsDisabled, modelConfigs, base_model]);
}, [getIsDisabled, modelEntities, base_model]);
const value = useMemo(
() =>
@ -65,14 +67,14 @@ export const useGroupedModelCombobox = <T extends AnyModelConfig>(
onChange(null);
return;
}
const model = modelConfigs.find((m) => m.key === v.value);
const model = modelEntities?.entities[v.value];
if (!model) {
onChange(null);
return;
}
onChange(model);
},
[modelConfigs, onChange]
[modelEntities?.entities, onChange]
);
const placeholder = useMemo(() => {

View File

@ -1,12 +1,14 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import type { EntityState } from '@reduxjs/toolkit';
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
import { map } from 'lodash-es';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import type { AnyModelConfig } from 'services/api/types';
type UseModelComboboxArg<T extends AnyModelConfig> = {
modelConfigs: T[];
selectedModel?: ModelIdentifierField | null;
modelEntities: EntityState<T, string> | undefined;
selectedModel?: ModelIdentifierWithBase | null;
onChange: (value: T | null) => void;
getIsDisabled?: (model: T) => boolean;
optionsFilter?: (model: T) => boolean;
@ -23,14 +25,19 @@ type UseModelComboboxReturn = {
export const useModelCombobox = <T extends AnyModelConfig>(arg: UseModelComboboxArg<T>): UseModelComboboxReturn => {
const { t } = useTranslation();
const { modelConfigs, selectedModel, getIsDisabled, onChange, isLoading, optionsFilter = () => true } = arg;
const { modelEntities, selectedModel, getIsDisabled, onChange, isLoading, optionsFilter = () => true } = arg;
const options = useMemo<ComboboxOption[]>(() => {
return modelConfigs.filter(optionsFilter).map((model) => ({
label: model.name,
value: model.key,
isDisabled: getIsDisabled ? getIsDisabled(model) : false,
}));
}, [optionsFilter, getIsDisabled, modelConfigs]);
if (!modelEntities) {
return [];
}
return map(modelEntities.entities)
.filter(optionsFilter)
.map((model) => ({
label: model.name,
value: model.key,
isDisabled: getIsDisabled ? getIsDisabled(model) : false,
}));
}, [optionsFilter, getIsDisabled, modelEntities]);
const value = useMemo(
() => options.find((m) => (selectedModel ? m.value === selectedModel.key : false)),
@ -43,14 +50,14 @@ export const useModelCombobox = <T extends AnyModelConfig>(arg: UseModelCombobox
onChange(null);
return;
}
const model = modelConfigs.find((m) => m.key === v.value);
const model = modelEntities?.entities[v.value];
if (!model) {
onChange(null);
return;
}
onChange(model);
},
[modelConfigs, onChange]
[modelEntities?.entities, onChange]
);
const placeholder = useMemo(() => {

Some files were not shown because too many files have changed in this diff Show More