mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Compare commits
5 Commits
remove-abi
...
separate-g
Author | SHA1 | Date | |
---|---|---|---|
cd3f5f30dc | |||
71ee28ac12 | |||
46c904d08a | |||
7d5a88b69d | |||
afa4df1991 |
29
Makefile
29
Makefile
@ -6,18 +6,16 @@ default: help
|
|||||||
help:
|
help:
|
||||||
@echo Developer commands:
|
@echo Developer commands:
|
||||||
@echo
|
@echo
|
||||||
@echo "ruff Run ruff, fixing any safely-fixable errors and formatting"
|
@echo "ruff Run ruff, fixing any safely-fixable errors and formatting"
|
||||||
@echo "ruff-unsafe Run ruff, fixing all fixable errors and formatting"
|
@echo "ruff-unsafe Run ruff, fixing all fixable errors and formatting"
|
||||||
@echo "mypy Run mypy using the config in pyproject.toml to identify type mismatches and other coding errors"
|
@echo "mypy Run mypy using the config in pyproject.toml to identify type mismatches and other coding errors"
|
||||||
@echo "mypy-all Run mypy ignoring the config in pyproject.tom but still ignoring missing imports"
|
@echo "mypy-all Run mypy ignoring the config in pyproject.tom but still ignoring missing imports"
|
||||||
@echo "test Run the unit tests."
|
@echo "test" Run the unit tests.
|
||||||
@echo "update-config-docstring Update the app's config docstring so mkdocs can autogenerate it correctly."
|
@echo "frontend-install" Install the pnpm modules needed for the front end
|
||||||
@echo "frontend-install Install the pnpm modules needed for the front end"
|
@echo "frontend-build Build the frontend in order to run on localhost:9090"
|
||||||
@echo "frontend-build Build the frontend in order to run on localhost:9090"
|
@echo "frontend-dev Run the frontend in developer mode on localhost:5173"
|
||||||
@echo "frontend-dev Run the frontend in developer mode on localhost:5173"
|
@echo "installer-zip Build the installer .zip file for the current version"
|
||||||
@echo "frontend-typegen Generate types for the frontend from the OpenAPI schema"
|
@echo "tag-release Tag the GitHub repository with the current version (use at release time only!)"
|
||||||
@echo "installer-zip Build the installer .zip file for the current version"
|
|
||||||
@echo "tag-release Tag the GitHub repository with the current version (use at release time only!)"
|
|
||||||
|
|
||||||
# Runs ruff, fixing any safely-fixable errors and formatting
|
# Runs ruff, fixing any safely-fixable errors and formatting
|
||||||
ruff:
|
ruff:
|
||||||
@ -42,10 +40,6 @@ mypy-all:
|
|||||||
test:
|
test:
|
||||||
pytest ./tests
|
pytest ./tests
|
||||||
|
|
||||||
# Update config docstring
|
|
||||||
update-config-docstring:
|
|
||||||
python scripts/update_config_docstring.py
|
|
||||||
|
|
||||||
# Install the pnpm modules needed for the front end
|
# Install the pnpm modules needed for the front end
|
||||||
frontend-install:
|
frontend-install:
|
||||||
rm -rf invokeai/frontend/web/node_modules
|
rm -rf invokeai/frontend/web/node_modules
|
||||||
@ -59,9 +53,6 @@ frontend-build:
|
|||||||
frontend-dev:
|
frontend-dev:
|
||||||
cd invokeai/frontend/web && pnpm dev
|
cd invokeai/frontend/web && pnpm dev
|
||||||
|
|
||||||
frontend-typegen:
|
|
||||||
cd invokeai/frontend/web && python ../../../scripts/generate_openapi_schema.py | pnpm typegen
|
|
||||||
|
|
||||||
# Installer zip file
|
# Installer zip file
|
||||||
installer-zip:
|
installer-zip:
|
||||||
cd installer && ./create_installer.sh
|
cd installer && ./create_installer.sh
|
||||||
|
@ -16,6 +16,11 @@ model. These are the:
|
|||||||
information. It is also responsible for managing the InvokeAI
|
information. It is also responsible for managing the InvokeAI
|
||||||
`models` directory and its contents.
|
`models` directory and its contents.
|
||||||
|
|
||||||
|
* _ModelMetadataStore_ and _ModelMetaDataFetch_ Backend modules that
|
||||||
|
are able to retrieve metadata from online model repositories,
|
||||||
|
transform them into Pydantic models, and cache them to the InvokeAI
|
||||||
|
SQL database.
|
||||||
|
|
||||||
* _DownloadQueueServiceBase_
|
* _DownloadQueueServiceBase_
|
||||||
A multithreaded downloader responsible
|
A multithreaded downloader responsible
|
||||||
for downloading models from a remote source to disk. The download
|
for downloading models from a remote source to disk. The download
|
||||||
@ -377,14 +382,17 @@ functionality:
|
|||||||
|
|
||||||
* Downloading a model from an arbitrary URL and installing it in
|
* Downloading a model from an arbitrary URL and installing it in
|
||||||
`models_dir`.
|
`models_dir`.
|
||||||
|
|
||||||
|
* Special handling for Civitai model URLs which allow the user to
|
||||||
|
paste in a model page's URL or download link
|
||||||
|
|
||||||
* Special handling for HuggingFace repo_ids to recursively download
|
* Special handling for HuggingFace repo_ids to recursively download
|
||||||
the contents of the repository, paying attention to alternative
|
the contents of the repository, paying attention to alternative
|
||||||
variants such as fp16.
|
variants such as fp16.
|
||||||
|
|
||||||
* Saving tags and other metadata about the model into the invokeai database
|
* Saving tags and other metadata about the model into the invokeai database
|
||||||
when fetching from a repo that provides that type of information,
|
when fetching from a repo that provides that type of information,
|
||||||
(currently only HuggingFace).
|
(currently only Civitai and HuggingFace).
|
||||||
|
|
||||||
### Initializing the installer
|
### Initializing the installer
|
||||||
|
|
||||||
@ -428,6 +436,7 @@ required parameters:
|
|||||||
| `app_config` | InvokeAIAppConfig | InvokeAI app configuration object |
|
| `app_config` | InvokeAIAppConfig | InvokeAI app configuration object |
|
||||||
| `record_store` | ModelRecordServiceBase | Config record storage database |
|
| `record_store` | ModelRecordServiceBase | Config record storage database |
|
||||||
| `download_queue` | DownloadQueueServiceBase | Download queue object |
|
| `download_queue` | DownloadQueueServiceBase | Download queue object |
|
||||||
|
| `metadata_store` | Optional[ModelMetadataStore] | Metadata storage object |
|
||||||
|`session` | Optional[requests.Session] | Swap in a different Session object (usually for debugging) |
|
|`session` | Optional[requests.Session] | Swap in a different Session object (usually for debugging) |
|
||||||
|
|
||||||
Once initialized, the installer will provide the following methods:
|
Once initialized, the installer will provide the following methods:
|
||||||
@ -571,7 +580,33 @@ The `AnyHttpUrl` class can be imported from `pydantic.networks`.
|
|||||||
|
|
||||||
Ordinarily, no metadata is retrieved from these sources. However,
|
Ordinarily, no metadata is retrieved from these sources. However,
|
||||||
there is special-case code in the installer that looks for HuggingFace
|
there is special-case code in the installer that looks for HuggingFace
|
||||||
and fetches the corresponding model metadata from the corresponding repo.
|
and Civitai URLs and fetches the corresponding model metadata from
|
||||||
|
the corresponding repo.
|
||||||
|
|
||||||
|
#### CivitaiModelSource
|
||||||
|
|
||||||
|
This is used for a model that is hosted by the Civitai web site.
|
||||||
|
|
||||||
|
| **Argument** | **Type** | **Default** | **Description** |
|
||||||
|
|------------------|------------------------------|-------------|-------------------------------------------|
|
||||||
|
| `version_id` | int | None | The ID of the particular version of the desired model. |
|
||||||
|
| `access_token` | str | None | An access token needed to gain access to a subscriber's-only model. |
|
||||||
|
|
||||||
|
Civitai has two model IDs, both of which are integers. The `model_id`
|
||||||
|
corresponds to a collection of model versions that may different in
|
||||||
|
arbitrary ways, such as derivation from different checkpoint training
|
||||||
|
steps, SFW vs NSFW generation, pruned vs non-pruned, etc. The
|
||||||
|
`version_id` points to a specific version. Please use the latter.
|
||||||
|
|
||||||
|
Some Civitai models require an access token to download. These can be
|
||||||
|
generated from the Civitai profile page of a logged-in
|
||||||
|
account. Somewhat annoyingly, if you fail to provide the access token
|
||||||
|
when downloading a model that needs it, Civitai generates a redirect
|
||||||
|
to a login page rather than a 403 Forbidden error. The installer
|
||||||
|
attempts to catch this event and issue an informative error
|
||||||
|
message. Otherwise you will get an "unrecognized model suffix" error
|
||||||
|
when the model prober tries to identify the type of the HTML login
|
||||||
|
page.
|
||||||
|
|
||||||
#### HFModelSource
|
#### HFModelSource
|
||||||
|
|
||||||
@ -1218,9 +1253,9 @@ queue and have not yet reached a terminal state.
|
|||||||
|
|
||||||
The modules found under `invokeai.backend.model_manager.metadata`
|
The modules found under `invokeai.backend.model_manager.metadata`
|
||||||
provide a straightforward API for fetching model metadatda from online
|
provide a straightforward API for fetching model metadatda from online
|
||||||
repositories. Currently only HuggingFace is supported. However, the
|
repositories. Currently two repositories are supported: HuggingFace
|
||||||
modules are easily extended for additional repos, provided that they
|
and Civitai. However, the modules are easily extended for additional
|
||||||
have defined APIs for metadata access.
|
repos, provided that they have defined APIs for metadata access.
|
||||||
|
|
||||||
Metadata comprises any descriptive information that is not essential
|
Metadata comprises any descriptive information that is not essential
|
||||||
for getting the model to run. For example "author" is metadata, while
|
for getting the model to run. For example "author" is metadata, while
|
||||||
@ -1232,16 +1267,37 @@ model's config, as defined in `invokeai.backend.model_manager.config`.
|
|||||||
```
|
```
|
||||||
from invokeai.backend.model_manager.metadata import (
|
from invokeai.backend.model_manager.metadata import (
|
||||||
AnyModelRepoMetadata,
|
AnyModelRepoMetadata,
|
||||||
|
CivitaiMetadataFetch,
|
||||||
|
CivitaiMetadata
|
||||||
|
ModelMetadataStore,
|
||||||
)
|
)
|
||||||
# to access the initialized sql database
|
# to access the initialized sql database
|
||||||
from invokeai.app.api.dependencies import ApiDependencies
|
from invokeai.app.api.dependencies import ApiDependencies
|
||||||
|
|
||||||
hf = HuggingFaceMetadataFetch()
|
civitai = CivitaiMetadataFetch()
|
||||||
|
|
||||||
# fetch the metadata
|
# fetch the metadata
|
||||||
model_metadata = hf.from_id("<repo_id>")
|
model_metadata = civitai.from_url("https://civitai.com/models/215796")
|
||||||
|
|
||||||
assert isinstance(model_metadata, HuggingFaceMetadata)
|
# get some common metadata fields
|
||||||
|
author = model_metadata.author
|
||||||
|
tags = model_metadata.tags
|
||||||
|
|
||||||
|
# get some Civitai-specific fields
|
||||||
|
assert isinstance(model_metadata, CivitaiMetadata)
|
||||||
|
|
||||||
|
trained_words = model_metadata.trained_words
|
||||||
|
base_model = model_metadata.base_model_trained_on
|
||||||
|
thumbnail = model_metadata.thumbnail_url
|
||||||
|
|
||||||
|
# cache the metadata to the database using the key corresponding to
|
||||||
|
# an existing model config record in the `model_config` table
|
||||||
|
sql_cache = ModelMetadataStore(ApiDependencies.invoker.services.db)
|
||||||
|
sql_cache.add_metadata('fb237ace520b6716adc98bcb16e8462c', model_metadata)
|
||||||
|
|
||||||
|
# now we can search the database by tag, author or model name
|
||||||
|
# matches will contain a list of model keys that match the search
|
||||||
|
matches = sql_cache.search_by_tag({"tool", "turbo"})
|
||||||
```
|
```
|
||||||
|
|
||||||
### Structure of the Metadata objects
|
### Structure of the Metadata objects
|
||||||
@ -1278,14 +1334,52 @@ This descends from `ModelMetadataBase` and adds the following fields:
|
|||||||
| `last_modified`| datetime | Date of last commit of this model to the repo |
|
| `last_modified`| datetime | Date of last commit of this model to the repo |
|
||||||
| `files` | List[Path] | List of the files in the model repo |
|
| `files` | List[Path] | List of the files in the model repo |
|
||||||
|
|
||||||
|
#### `CivitaiMetadata`
|
||||||
|
|
||||||
|
This descends from `ModelMetadataBase` and adds the following fields:
|
||||||
|
|
||||||
|
| **Field Name** | **Type** | **Description** |
|
||||||
|
|----------------|-----------------|------------------|
|
||||||
|
| `type` | Literal["civitai"] | Used for the discriminated union of metadata classes|
|
||||||
|
| `id` | int | Civitai model id |
|
||||||
|
| `version_name` | str | Name of this version of the model (distinct from model name) |
|
||||||
|
| `version_id` | int | Civitai model version id (distinct from model id) |
|
||||||
|
| `created` | datetime | Date this version of the model was created |
|
||||||
|
| `updated` | datetime | Date this version of the model was last updated |
|
||||||
|
| `published` | datetime | Date this version of the model was published to Civitai |
|
||||||
|
| `description` | str | Model description. Quite verbose and contains HTML tags |
|
||||||
|
| `version_description` | str | Model version description, usually describes changes to the model |
|
||||||
|
| `nsfw` | bool | Whether the model tends to generate NSFW content |
|
||||||
|
| `restrictions` | LicenseRestrictions | An object that describes what is and isn't allowed with this model |
|
||||||
|
| `trained_words`| Set[str] | Trigger words for this model, if any |
|
||||||
|
| `download_url` | AnyHttpUrl | URL for downloading this version of the model |
|
||||||
|
| `base_model_trained_on` | str | Name of the model that this version was trained on |
|
||||||
|
| `thumbnail_url` | AnyHttpUrl | URL to access a representative thumbnail image of the model's output |
|
||||||
|
| `weight_min` | int | For LoRA sliders, the minimum suggested weight to apply |
|
||||||
|
| `weight_max` | int | For LoRA sliders, the maximum suggested weight to apply |
|
||||||
|
|
||||||
|
Note that `weight_min` and `weight_max` are not currently populated
|
||||||
|
and take the default values of (-1.0, +2.0). The issue is that these
|
||||||
|
values aren't part of the structured data but appear in the text
|
||||||
|
description. Some regular expression or LLM coding may be able to
|
||||||
|
extract these values.
|
||||||
|
|
||||||
|
Also be aware that `base_model_trained_on` is free text and doesn't
|
||||||
|
correspond to our `ModelType` enum.
|
||||||
|
|
||||||
|
`CivitaiMetadata` also defines some convenience properties relating to
|
||||||
|
licensing restrictions: `credit_required`, `allow_commercial_use`,
|
||||||
|
`allow_derivatives` and `allow_different_license`.
|
||||||
|
|
||||||
#### `AnyModelRepoMetadata`
|
#### `AnyModelRepoMetadata`
|
||||||
|
|
||||||
This is a discriminated Union of `HuggingFaceMetadata`.
|
This is a discriminated Union of `CivitaiMetadata` and
|
||||||
|
`HuggingFaceMetadata`.
|
||||||
|
|
||||||
### Fetching Metadata from Online Repos
|
### Fetching Metadata from Online Repos
|
||||||
|
|
||||||
The `HuggingFaceMetadataFetch` class will
|
The `HuggingFaceMetadataFetch` and `CivitaiMetadataFetch` classes will
|
||||||
retrieve metadata from its corresponding repository and return
|
retrieve metadata from their corresponding repositories and return
|
||||||
`AnyModelRepoMetadata` objects. Their base class
|
`AnyModelRepoMetadata` objects. Their base class
|
||||||
`ModelMetadataFetchBase` is an abstract class that defines two
|
`ModelMetadataFetchBase` is an abstract class that defines two
|
||||||
methods: `from_url()` and `from_id()`. The former accepts the type of
|
methods: `from_url()` and `from_id()`. The former accepts the type of
|
||||||
@ -1303,17 +1397,96 @@ provide a `requests.Session` argument. This allows you to customize
|
|||||||
the low-level HTTP fetch requests and is used, for instance, in the
|
the low-level HTTP fetch requests and is used, for instance, in the
|
||||||
testing suite to avoid hitting the internet.
|
testing suite to avoid hitting the internet.
|
||||||
|
|
||||||
The HuggingFace fetcher subclass add additional repo-specific fetching methods:
|
The HuggingFace and Civitai fetcher subclasses add additional
|
||||||
|
repo-specific fetching methods:
|
||||||
|
|
||||||
#### HuggingFaceMetadataFetch
|
#### HuggingFaceMetadataFetch
|
||||||
|
|
||||||
This overrides its base class `from_json()` method to return a
|
This overrides its base class `from_json()` method to return a
|
||||||
`HuggingFaceMetadata` object directly.
|
`HuggingFaceMetadata` object directly.
|
||||||
|
|
||||||
|
#### CivitaiMetadataFetch
|
||||||
|
|
||||||
|
This adds the following methods:
|
||||||
|
|
||||||
|
`from_civitai_modelid()` This takes the ID of a model, finds the
|
||||||
|
default version of the model, and then retrieves the metadata for
|
||||||
|
that version, returning a `CivitaiMetadata` object directly.
|
||||||
|
|
||||||
|
`from_civitai_versionid()` This takes the ID of a model version and
|
||||||
|
retrieves its metadata. Functionally equivalent to `from_id()`, the
|
||||||
|
only difference is that it returna a `CivitaiMetadata` object rather
|
||||||
|
than an `AnyModelRepoMetadata`.
|
||||||
|
|
||||||
### Metadata Storage
|
### Metadata Storage
|
||||||
|
|
||||||
The `ModelConfigBase` stores this response in the `source_api_response` field
|
The `ModelMetadataStore` provides a simple facility to store model
|
||||||
as a JSON blob.
|
metadata in the `invokeai.db` database. The data is stored as a JSON
|
||||||
|
blob, with a few common fields (`name`, `author`, `tags`) broken out
|
||||||
|
to be searchable.
|
||||||
|
|
||||||
|
When a metadata object is saved to the database, it is identified
|
||||||
|
using the model key, _and this key must correspond to an existing
|
||||||
|
model key in the model_config table_. There is a foreign key integrity
|
||||||
|
constraint between the `model_config.id` field and the
|
||||||
|
`model_metadata.id` field such that if you attempt to save metadata
|
||||||
|
under an unknown key, the attempt will result in an
|
||||||
|
`UnknownModelException`. Likewise, when a model is deleted from
|
||||||
|
`model_config`, the deletion of the corresponding metadata record will
|
||||||
|
be triggered.
|
||||||
|
|
||||||
|
Tags are stored in a normalized fashion in the tables `model_tags` and
|
||||||
|
`tags`. Triggers keep the tag table in sync with the `model_metadata`
|
||||||
|
table.
|
||||||
|
|
||||||
|
To create the storage object, initialize it with the InvokeAI
|
||||||
|
`SqliteDatabase` object. This is often done this way:
|
||||||
|
|
||||||
|
```
|
||||||
|
from invokeai.app.api.dependencies import ApiDependencies
|
||||||
|
metadata_store = ModelMetadataStore(ApiDependencies.invoker.services.db)
|
||||||
|
```
|
||||||
|
|
||||||
|
You can then access the storage with the following methods:
|
||||||
|
|
||||||
|
#### `add_metadata(key, metadata)`
|
||||||
|
|
||||||
|
Add the metadata using a previously-defined model key.
|
||||||
|
|
||||||
|
There is currently no `delete_metadata()` method. The metadata will
|
||||||
|
persist until the matching config is deleted from the `model_config`
|
||||||
|
table.
|
||||||
|
|
||||||
|
#### `get_metadata(key) -> AnyModelRepoMetadata`
|
||||||
|
|
||||||
|
Retrieve the metadata corresponding to the model key.
|
||||||
|
|
||||||
|
#### `update_metadata(key, new_metadata)`
|
||||||
|
|
||||||
|
Update an existing metadata record with new metadata.
|
||||||
|
|
||||||
|
#### `search_by_tag(tags: Set[str]) -> Set[str]`
|
||||||
|
|
||||||
|
Given a set of tags, find models that are tagged with them. If
|
||||||
|
multiple tags are provided then a matching model must be tagged with
|
||||||
|
*all* the tags in the set. This method returns a set of model keys and
|
||||||
|
is intended to be used in conjunction with the `ModelRecordService`:
|
||||||
|
|
||||||
|
```
|
||||||
|
model_config_store = ApiDependencies.invoker.services.model_records
|
||||||
|
matches = metadata_store.search_by_tag({'license:other'})
|
||||||
|
models = [model_config_store.get(x) for x in matches]
|
||||||
|
```
|
||||||
|
|
||||||
|
#### `search_by_name(name: str) -> Set[str]
|
||||||
|
|
||||||
|
Find all model metadata records that have the given name and return a
|
||||||
|
set of keys to the corresponding model config objects.
|
||||||
|
|
||||||
|
#### `search_by_author(author: str) -> Set[str]
|
||||||
|
|
||||||
|
Find all model metadata records that have the given author and return
|
||||||
|
a set of keys to the corresponding model config objects.
|
||||||
|
|
||||||
***
|
***
|
||||||
|
|
||||||
|
@ -31,18 +31,18 @@ be referred to as ROOT.
|
|||||||
To find its root directory, InvokeAI uses the following recipe:
|
To find its root directory, InvokeAI uses the following recipe:
|
||||||
|
|
||||||
1. It first looks for the argument `--root <path>` on the command line
|
1. It first looks for the argument `--root <path>` on the command line
|
||||||
it was launched from, and uses the indicated path if present.
|
it was launched from, and uses the indicated path if present.
|
||||||
|
|
||||||
2. Next it looks for the environment variable INVOKEAI_ROOT, and uses
|
2. Next it looks for the environment variable INVOKEAI_ROOT, and uses
|
||||||
the directory path found there if present.
|
the directory path found there if present.
|
||||||
|
|
||||||
3. If neither of these are present, then InvokeAI looks for the
|
3. If neither of these are present, then InvokeAI looks for the
|
||||||
folder containing the `.venv` Python virtual environment directory for
|
folder containing the `.venv` Python virtual environment directory for
|
||||||
the currently active environment. This directory is checked for files
|
the currently active environment. This directory is checked for files
|
||||||
expected inside the InvokeAI root before it is used.
|
expected inside the InvokeAI root before it is used.
|
||||||
|
|
||||||
4. Finally, InvokeAI looks for a directory in the current user's home
|
4. Finally, InvokeAI looks for a directory in the current user's home
|
||||||
directory named `invokeai`.
|
directory named `invokeai`.
|
||||||
|
|
||||||
#### Reading the InvokeAI Configuration File
|
#### Reading the InvokeAI Configuration File
|
||||||
|
|
||||||
@ -149,65 +149,104 @@ usage: InvokeAI [-h] [--host HOST] [--port PORT] [--allow_origins [ALLOW_ORIGINS
|
|||||||
|
|
||||||
## The Configuration Settings
|
## The Configuration Settings
|
||||||
|
|
||||||
The config is managed by the `InvokeAIAppConfig` class, which is a pydantic model. The below docs are autogenerated from the class.
|
The configuration settings are divided into several distinct
|
||||||
|
groups in `invokeia.yaml`:
|
||||||
|
|
||||||
When editing your `invokeai.yaml` file, you'll need to put settings under their appropriate group. The group for each setting is denoted in the table below.
|
### Web Server
|
||||||
|
|
||||||
Following the table are additional explanations for certain settings.
|
| Setting | Default Value | Description |
|
||||||
|
|---------------------|---------------|----------------------------------------------------------------------------------------------------------------------------|
|
||||||
|
| `host` | `localhost` | Name or IP address of the network interface that the web server will listen on |
|
||||||
|
| `port` | `9090` | Network port number that the web server will listen on |
|
||||||
|
| `allow_origins` | `[]` | A list of host names or IP addresses that are allowed to connect to the InvokeAI API in the format `['host1','host2',...]` |
|
||||||
|
| `allow_credentials` | `true` | Require credentials for a foreign host to access the InvokeAI API (don't change this) |
|
||||||
|
| `allow_methods` | `*` | List of HTTP methods ("GET", "POST") that the web server is allowed to use when accessing the API |
|
||||||
|
| `allow_headers` | `*` | List of HTTP headers that the web server will accept when accessing the API |
|
||||||
|
| `ssl_certfile` | null | Path to an SSL certificate file, used to enable HTTPS. |
|
||||||
|
| `ssl_keyfile` | null | Path to an SSL keyfile, if the key is not included in the certificate file. |
|
||||||
|
|
||||||
<!-- prettier-ignore-start -->
|
The documentation for InvokeAI's API can be accessed by browsing to the following URL: [http://localhost:9090/docs].
|
||||||
::: invokeai.app.services.config.config_default.InvokeAIAppConfig
|
|
||||||
options:
|
|
||||||
heading_level: 3
|
|
||||||
members: false
|
|
||||||
<!-- prettier-ignore-end -->
|
|
||||||
|
|
||||||
### Model Marketplace API Keys
|
### Features
|
||||||
|
|
||||||
Some model marketplaces require an API key to download models. You can provide a URL pattern and appropriate token in your `invokeai.yaml` file to provide that API key.
|
These configuration settings allow you to enable and disable various InvokeAI features:
|
||||||
|
|
||||||
The pattern can be any valid regex (you may need to surround the pattern with quotes):
|
| Setting | Default Value | Description |
|
||||||
|
|----------|----------------|--------------|
|
||||||
|
| `esrgan` | `true` | Activate the ESRGAN upscaling options|
|
||||||
|
| `internet_available` | `true` | When a resource is not available locally, try to fetch it via the internet |
|
||||||
|
| `log_tokenization` | `false` | Before each text2image generation, print a color-coded representation of the prompt to the console; this can help understand why a prompt is not working as expected |
|
||||||
|
| `patchmatch` | `true` | Activate the "patchmatch" algorithm for improved inpainting |
|
||||||
|
|
||||||
```yaml
|
### Generation
|
||||||
InvokeAI:
|
|
||||||
Model Install:
|
|
||||||
remote_api_tokens:
|
|
||||||
# Any URL containing `models.com` will automatically use `your_models_com_token`
|
|
||||||
- url_regex: models.com
|
|
||||||
token: your_models_com_token
|
|
||||||
# Any URL matching this contrived regex will use `some_other_token`
|
|
||||||
- url_regex: '^[a-z]{3}whatever.*\.com$'
|
|
||||||
token: some_other_token
|
|
||||||
```
|
|
||||||
|
|
||||||
The provided token will be added as a `Bearer` token to the network requests to download the model files. As far as we know, this works for all model marketplaces that require authorization.
|
These options tune InvokeAI's memory and performance characteristics.
|
||||||
|
|
||||||
### Model Hashing
|
| Setting | Default Value | Description |
|
||||||
|
|-----------------------|---------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||||
|
| `sequential_guidance` | `false` | Calculate guidance in serial rather than in parallel, lowering memory requirements at the cost of some performance loss |
|
||||||
|
| `attention_type` | `auto` | Select the type of attention to use. One of `auto`,`normal`,`xformers`,`sliced`, or `torch-sdp` |
|
||||||
|
| `attention_slice_size` | `auto` | When "sliced" attention is selected, set the slice size. One of `auto`, `balanced`, `max` or the integers 1-8|
|
||||||
|
| `force_tiled_decode` | `false` | Force the VAE step to decode in tiles, reducing memory consumption at the cost of performance |
|
||||||
|
|
||||||
Models are hashed during installation with the `BLAKE3` algorithm, providing a stable identifier for models across all platforms.
|
### Device
|
||||||
|
|
||||||
Model hashing is a one-time operation, but it may take a couple minutes to hash a large model collection. You may opt out of model hashing and instead have a random UUID assigned instead:
|
These options configure the generation execution device.
|
||||||
|
|
||||||
|
| Setting | Default Value | Description |
|
||||||
|
|-----------------------|---------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||||
|
| `device` | `auto` | Preferred execution device. One of `auto`, `cpu`, `cuda`, `cuda:1`, `mps`. `auto` will choose the device depending on the hardware platform and the installed torch capabilities. |
|
||||||
|
| `precision` | `auto` | Floating point precision. One of `auto`, `float16` or `float32`. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system |
|
||||||
|
|
||||||
```yaml
|
|
||||||
InvokeAI:
|
|
||||||
Model Install:
|
|
||||||
skip_model_hash: true
|
|
||||||
```
|
|
||||||
|
|
||||||
### Paths
|
### Paths
|
||||||
|
|
||||||
These options set the paths of various directories and files used by
|
These options set the paths of various directories and files used by
|
||||||
InvokeAI. Relative paths are interpreted relative to the root directory, so
|
InvokeAI. Relative paths are interpreted relative to INVOKEAI_ROOT, so
|
||||||
if root is `/home/fred/invokeai` and the path is
|
if INVOKEAI_ROOT is `/home/fred/invokeai` and the path is
|
||||||
`autoimport/main`, then the corresponding directory will be located at
|
`autoimport/main`, then the corresponding directory will be located at
|
||||||
`/home/fred/invokeai/autoimport/main`.
|
`/home/fred/invokeai/autoimport/main`.
|
||||||
|
|
||||||
Note that the autoimport directory will be searched recursively,
|
| Setting | Default Value | Description |
|
||||||
|
|----------|----------------|--------------|
|
||||||
|
| `autoimport_dir` | `autoimport/main` | At startup time, read and import any main model files found in this directory |
|
||||||
|
| `lora_dir` | `autoimport/lora` | At startup time, read and import any LoRA/LyCORIS models found in this directory |
|
||||||
|
| `embedding_dir` | `autoimport/embedding` | At startup time, read and import any textual inversion (embedding) models found in this directory |
|
||||||
|
| `controlnet_dir` | `autoimport/controlnet` | At startup time, read and import any ControlNet models found in this directory |
|
||||||
|
| `conf_path` | `configs/models.yaml` | Location of the `models.yaml` model configuration file |
|
||||||
|
| `models_dir` | `models` | Location of the directory containing models installed by InvokeAI's model manager |
|
||||||
|
| `legacy_conf_dir` | `configs/stable-diffusion` | Location of the directory containing the .yaml configuration files for legacy checkpoint models |
|
||||||
|
| `db_dir` | `databases` | Location of the directory containing InvokeAI's image, schema and session database |
|
||||||
|
| `outdir` | `outputs` | Location of the directory in which the gallery of generated and uploaded images will be stored |
|
||||||
|
| `use_memory_db` | `false` | Keep database information in memory rather than on disk; this will not preserve image gallery information across restarts |
|
||||||
|
|
||||||
|
Note that the autoimport directories will be searched recursively,
|
||||||
allowing you to organize the models into folders and subfolders in any
|
allowing you to organize the models into folders and subfolders in any
|
||||||
way you wish.
|
way you wish. In addition, while we have split up autoimport
|
||||||
|
directories by the type of model they contain, this isn't
|
||||||
|
necessary. You can combine different model types in the same folder
|
||||||
|
and InvokeAI will figure out what they are. So you can easily use just
|
||||||
|
one autoimport directory by commenting out the unneeded paths:
|
||||||
|
|
||||||
|
```
|
||||||
|
Paths:
|
||||||
|
autoimport_dir: autoimport
|
||||||
|
# lora_dir: null
|
||||||
|
# embedding_dir: null
|
||||||
|
# controlnet_dir: null
|
||||||
|
```
|
||||||
|
|
||||||
### Logging
|
### Logging
|
||||||
|
|
||||||
|
These settings control the information, warning, and debugging
|
||||||
|
messages printed to the console log while InvokeAI is running:
|
||||||
|
|
||||||
|
| Setting | Default Value | Description |
|
||||||
|
|----------|----------------|--------------|
|
||||||
|
| `log_handlers` | `console` | This controls where log messages are sent, and can be a list of one or more destinations. Values include `console`, `file`, `syslog` and `http`. These are described in more detail below |
|
||||||
|
| `log_format` | `color` | This controls the formatting of the log messages. Values are `plain`, `color`, `legacy` and `syslog` |
|
||||||
|
| `log_level` | `debug` | This filters messages according to the level of severity and can be one of `debug`, `info`, `warning`, `error` and `critical`. For example, setting to `warning` will display all messages at the warning level or higher, but won't display "debug" or "info" messages |
|
||||||
|
|
||||||
Several different log handler destinations are available, and multiple destinations are supported by providing a list:
|
Several different log handler destinations are available, and multiple destinations are supported by providing a list:
|
||||||
|
|
||||||
```
|
```
|
||||||
@ -217,9 +256,9 @@ Several different log handler destinations are available, and multiple destinati
|
|||||||
- file=/var/log/invokeai.log
|
- file=/var/log/invokeai.log
|
||||||
```
|
```
|
||||||
|
|
||||||
- `console` is the default. It prints log messages to the command-line window from which InvokeAI was launched.
|
* `console` is the default. It prints log messages to the command-line window from which InvokeAI was launched.
|
||||||
|
|
||||||
- `syslog` is only available on Linux and Macintosh systems. It uses
|
* `syslog` is only available on Linux and Macintosh systems. It uses
|
||||||
the operating system's "syslog" facility to write log file entries
|
the operating system's "syslog" facility to write log file entries
|
||||||
locally or to a remote logging machine. `syslog` offers a variety
|
locally or to a remote logging machine. `syslog` offers a variety
|
||||||
of configuration options:
|
of configuration options:
|
||||||
@ -232,7 +271,7 @@ Several different log handler destinations are available, and multiple destinati
|
|||||||
- Log to LAN-connected server "fredserver" using the facility LOG_USER and datagram packets.
|
- Log to LAN-connected server "fredserver" using the facility LOG_USER and datagram packets.
|
||||||
```
|
```
|
||||||
|
|
||||||
- `http` can be used to log to a remote web server. The server must be
|
* `http` can be used to log to a remote web server. The server must be
|
||||||
properly configured to receive and act on log messages. The option
|
properly configured to receive and act on log messages. The option
|
||||||
accepts the URL to the web server, and a `method` argument
|
accepts the URL to the web server, and a `method` argument
|
||||||
indicating whether the message should be submitted using the GET or
|
indicating whether the message should be submitted using the GET or
|
||||||
@ -244,7 +283,7 @@ Several different log handler destinations are available, and multiple destinati
|
|||||||
|
|
||||||
The `log_format` option provides several alternative formats:
|
The `log_format` option provides several alternative formats:
|
||||||
|
|
||||||
- `color` - default format providing time, date and a message, using text colors to distinguish different log severities
|
* `color` - default format providing time, date and a message, using text colors to distinguish different log severities
|
||||||
- `plain` - same as above, but monochrome text only
|
* `plain` - same as above, but monochrome text only
|
||||||
- `syslog` - the log level and error message only, allowing the syslog system to attach the time and date
|
* `syslog` - the log level and error message only, allowing the syslog system to attach the time and date
|
||||||
- `legacy` - a format similar to the one used by the legacy 2.3 InvokeAI releases.
|
* `legacy` - a format similar to the one used by the legacy 2.3 InvokeAI releases.
|
||||||
|
@ -1,35 +0,0 @@
|
|||||||
---
|
|
||||||
title: Database
|
|
||||||
---
|
|
||||||
|
|
||||||
# Invoke's SQLite Database
|
|
||||||
|
|
||||||
Invoke uses a SQLite database to store image, workflow, model, and execution data.
|
|
||||||
|
|
||||||
We take great care to ensure your data is safe, by utilizing transactions and a database migration system.
|
|
||||||
|
|
||||||
Even so, when testing an prerelease version of the app, we strongly suggest either backing up your database or using an in-memory database. This ensures any prelease hiccups or databases schema changes will not cause problems for your data.
|
|
||||||
|
|
||||||
## Database Backup
|
|
||||||
|
|
||||||
Backing up your database is very simple. Invoke's data is stored in an `$INVOKEAI_ROOT` directory - where your `invoke.sh`/`invoke.bat` and `invokeai.yaml` files live.
|
|
||||||
|
|
||||||
To back up your database, copy the `invokeai.db` file from `$INVOKEAI_ROOT/databases/invokeai.db` to somewhere safe.
|
|
||||||
|
|
||||||
If anything comes up during prelease testing, you can simply copy your backup back into `$INVOKEAI_ROOT/databases/`.
|
|
||||||
|
|
||||||
## In-Memory Database
|
|
||||||
|
|
||||||
SQLite can run on an in-memory database. Your existing database is untouched when this mode is enabled, but your existing data won't be accessible.
|
|
||||||
|
|
||||||
This is very useful for testing, as there is no chance of a database change modifying your "physical" database.
|
|
||||||
|
|
||||||
To run Invoke with a memory database, edit your `invokeai.yaml` file, and add `use_memory_db: true` to the `Paths:` stanza:
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
InvokeAI:
|
|
||||||
Development:
|
|
||||||
use_memory_db: true
|
|
||||||
```
|
|
||||||
|
|
||||||
Delete this line (or set it to `false`) to use your main database.
|
|
@ -25,7 +25,6 @@ from ..services.invocation_cache.invocation_cache_memory import MemoryInvocation
|
|||||||
from ..services.invocation_services import InvocationServices
|
from ..services.invocation_services import InvocationServices
|
||||||
from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
|
from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
|
||||||
from ..services.invoker import Invoker
|
from ..services.invoker import Invoker
|
||||||
from ..services.model_images.model_images_default import ModelImageFileStorageDisk
|
|
||||||
from ..services.model_manager.model_manager_default import ModelManagerService
|
from ..services.model_manager.model_manager_default import ModelManagerService
|
||||||
from ..services.model_records import ModelRecordServiceSQL
|
from ..services.model_records import ModelRecordServiceSQL
|
||||||
from ..services.names.names_default import SimpleNameService
|
from ..services.names.names_default import SimpleNameService
|
||||||
@ -72,8 +71,6 @@ class ApiDependencies:
|
|||||||
|
|
||||||
image_files = DiskImageFileStorage(f"{output_folder}/images")
|
image_files = DiskImageFileStorage(f"{output_folder}/images")
|
||||||
|
|
||||||
model_images_folder = config.models_path
|
|
||||||
|
|
||||||
db = init_db(config=config, logger=logger, image_files=image_files)
|
db = init_db(config=config, logger=logger, image_files=image_files)
|
||||||
|
|
||||||
configuration = config
|
configuration = config
|
||||||
@ -95,7 +92,6 @@ class ApiDependencies:
|
|||||||
ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
|
ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
|
||||||
)
|
)
|
||||||
download_queue_service = DownloadQueueService(event_bus=events)
|
download_queue_service = DownloadQueueService(event_bus=events)
|
||||||
model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images")
|
|
||||||
model_manager = ModelManagerService.build_model_manager(
|
model_manager = ModelManagerService.build_model_manager(
|
||||||
app_config=configuration,
|
app_config=configuration,
|
||||||
model_record_service=ModelRecordServiceSQL(db=db),
|
model_record_service=ModelRecordServiceSQL(db=db),
|
||||||
@ -122,7 +118,6 @@ class ApiDependencies:
|
|||||||
images=images,
|
images=images,
|
||||||
invocation_cache=invocation_cache,
|
invocation_cache=invocation_cache,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
model_images=model_images_service,
|
|
||||||
model_manager=model_manager,
|
model_manager=model_manager,
|
||||||
download_queue=download_queue_service,
|
download_queue=download_queue_service,
|
||||||
names=names,
|
names=names,
|
||||||
|
@ -1,16 +1,12 @@
|
|||||||
# Copyright (c) 2023 Lincoln D. Stein
|
# Copyright (c) 2023 Lincoln D. Stein
|
||||||
"""FastAPI route for model configuration records."""
|
"""FastAPI route for model configuration records."""
|
||||||
|
|
||||||
import io
|
|
||||||
import pathlib
|
import pathlib
|
||||||
import shutil
|
import shutil
|
||||||
import traceback
|
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from fastapi import Body, Path, Query, Response, UploadFile
|
from fastapi import Body, Path, Query, Response
|
||||||
from fastapi.responses import FileResponse
|
|
||||||
from fastapi.routing import APIRouter
|
from fastapi.routing import APIRouter
|
||||||
from PIL import Image
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
from starlette.exceptions import HTTPException
|
from starlette.exceptions import HTTPException
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
@ -35,9 +31,6 @@ from ..dependencies import ApiDependencies
|
|||||||
|
|
||||||
model_manager_router = APIRouter(prefix="/v2/models", tags=["model_manager"])
|
model_manager_router = APIRouter(prefix="/v2/models", tags=["model_manager"])
|
||||||
|
|
||||||
# images are immutable; set a high max-age
|
|
||||||
IMAGE_MAX_AGE = 31536000
|
|
||||||
|
|
||||||
|
|
||||||
class ModelsList(BaseModel):
|
class ModelsList(BaseModel):
|
||||||
"""Return list of configs."""
|
"""Return list of configs."""
|
||||||
@ -112,9 +105,6 @@ async def list_model_records(
|
|||||||
found_models.extend(
|
found_models.extend(
|
||||||
record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format)
|
record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format)
|
||||||
)
|
)
|
||||||
for model in found_models:
|
|
||||||
cover_image = ApiDependencies.invoker.services.model_images.get_url(model.key)
|
|
||||||
model.cover_image = cover_image
|
|
||||||
return ModelsList(models=found_models)
|
return ModelsList(models=found_models)
|
||||||
|
|
||||||
|
|
||||||
@ -158,8 +148,6 @@ async def get_model_record(
|
|||||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||||
try:
|
try:
|
||||||
config: AnyModelConfig = record_store.get_model(key)
|
config: AnyModelConfig = record_store.get_model(key)
|
||||||
cover_image = ApiDependencies.invoker.services.model_images.get_url(key)
|
|
||||||
config.cover_image = cover_image
|
|
||||||
return config
|
return config
|
||||||
except UnknownModelException as e:
|
except UnknownModelException as e:
|
||||||
raise HTTPException(status_code=404, detail=str(e))
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
@ -278,75 +266,6 @@ async def update_model_record(
|
|||||||
return model_response
|
return model_response
|
||||||
|
|
||||||
|
|
||||||
@model_manager_router.get(
|
|
||||||
"/i/{key}/image",
|
|
||||||
operation_id="get_model_image",
|
|
||||||
responses={
|
|
||||||
200: {
|
|
||||||
"description": "The model image was fetched successfully",
|
|
||||||
},
|
|
||||||
400: {"description": "Bad request"},
|
|
||||||
404: {"description": "The model image could not be found"},
|
|
||||||
},
|
|
||||||
status_code=200,
|
|
||||||
)
|
|
||||||
async def get_model_image(
|
|
||||||
key: str = Path(description="The name of model image file to get"),
|
|
||||||
) -> FileResponse:
|
|
||||||
"""Gets an image file that previews the model"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
path = ApiDependencies.invoker.services.model_images.get_path(key)
|
|
||||||
|
|
||||||
response = FileResponse(
|
|
||||||
path,
|
|
||||||
media_type="image/png",
|
|
||||||
filename=key + ".png",
|
|
||||||
content_disposition_type="inline",
|
|
||||||
)
|
|
||||||
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
|
|
||||||
return response
|
|
||||||
except Exception:
|
|
||||||
raise HTTPException(status_code=404)
|
|
||||||
|
|
||||||
|
|
||||||
@model_manager_router.patch(
|
|
||||||
"/i/{key}/image",
|
|
||||||
operation_id="update_model_image",
|
|
||||||
responses={
|
|
||||||
200: {
|
|
||||||
"description": "The model image was updated successfully",
|
|
||||||
},
|
|
||||||
400: {"description": "Bad request"},
|
|
||||||
},
|
|
||||||
status_code=200,
|
|
||||||
)
|
|
||||||
async def update_model_image(
|
|
||||||
key: Annotated[str, Path(description="Unique key of model")],
|
|
||||||
image: UploadFile,
|
|
||||||
) -> None:
|
|
||||||
if not image.content_type or not image.content_type.startswith("image"):
|
|
||||||
raise HTTPException(status_code=415, detail="Not an image")
|
|
||||||
|
|
||||||
contents = await image.read()
|
|
||||||
try:
|
|
||||||
pil_image = Image.open(io.BytesIO(contents))
|
|
||||||
|
|
||||||
except Exception:
|
|
||||||
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
|
|
||||||
raise HTTPException(status_code=415, detail="Failed to read image")
|
|
||||||
|
|
||||||
logger = ApiDependencies.invoker.services.logger
|
|
||||||
model_images = ApiDependencies.invoker.services.model_images
|
|
||||||
try:
|
|
||||||
model_images.save(pil_image, key)
|
|
||||||
logger.info(f"Updated image for model: {key}")
|
|
||||||
except ValueError as e:
|
|
||||||
logger.error(str(e))
|
|
||||||
raise HTTPException(status_code=409, detail=str(e))
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
@model_manager_router.delete(
|
@model_manager_router.delete(
|
||||||
"/i/{key}",
|
"/i/{key}",
|
||||||
operation_id="delete_model",
|
operation_id="delete_model",
|
||||||
@ -377,29 +296,6 @@ async def delete_model(
|
|||||||
raise HTTPException(status_code=404, detail=str(e))
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
@model_manager_router.delete(
|
|
||||||
"/i/{key}/image",
|
|
||||||
operation_id="delete_model_image",
|
|
||||||
responses={
|
|
||||||
204: {"description": "Model image deleted successfully"},
|
|
||||||
404: {"description": "Model image not found"},
|
|
||||||
},
|
|
||||||
status_code=204,
|
|
||||||
)
|
|
||||||
async def delete_model_image(
|
|
||||||
key: str = Path(description="Unique key of model image to remove from model_images directory."),
|
|
||||||
) -> None:
|
|
||||||
logger = ApiDependencies.invoker.services.logger
|
|
||||||
model_images = ApiDependencies.invoker.services.model_images
|
|
||||||
try:
|
|
||||||
model_images.delete(key)
|
|
||||||
logger.info(f"Deleted model image: {key}")
|
|
||||||
return
|
|
||||||
except UnknownModelException as e:
|
|
||||||
logger.error(str(e))
|
|
||||||
raise HTTPException(status_code=404, detail=str(e))
|
|
||||||
|
|
||||||
|
|
||||||
# @model_manager_router.post(
|
# @model_manager_router.post(
|
||||||
# "/i/",
|
# "/i/",
|
||||||
# operation_id="add_model_record",
|
# operation_id="add_model_record",
|
||||||
@ -643,7 +539,7 @@ async def convert_model(
|
|||||||
raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.")
|
raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.")
|
||||||
|
|
||||||
# loading the model will convert it into a cached diffusers file
|
# loading the model will convert it into a cached diffusers file
|
||||||
model_manager.load.load_model(model_config, submodel_type=SubModelType.Scheduler)
|
model_manager.load_model_by_config(model_config, submodel_type=SubModelType.Scheduler)
|
||||||
|
|
||||||
# Get the path of the converted model from the loader
|
# Get the path of the converted model from the loader
|
||||||
cache_path = loader.convert_cache.cache_path(key)
|
cache_path = loader.convert_cache.cache_path(key)
|
||||||
|
@ -2,11 +2,12 @@
|
|||||||
# which are imported/used before parse_args() is called will get the default config values instead of the
|
# which are imported/used before parse_args() is called will get the default config values instead of the
|
||||||
# values from the command line or config file.
|
# values from the command line or config file.
|
||||||
import sys
|
import sys
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
from invokeai.app.invocations.model import ModelIdentifierField
|
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
|
||||||
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
|
||||||
from invokeai.version.invokeai_version import __version__
|
from invokeai.version.invokeai_version import __version__
|
||||||
|
|
||||||
|
from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
|
||||||
from .services.config import InvokeAIAppConfig
|
from .services.config import InvokeAIAppConfig
|
||||||
|
|
||||||
app_config = InvokeAIAppConfig.get_config()
|
app_config = InvokeAIAppConfig.get_config()
|
||||||
@ -19,7 +20,6 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
|
|||||||
import asyncio
|
import asyncio
|
||||||
import mimetypes
|
import mimetypes
|
||||||
import socket
|
import socket
|
||||||
from contextlib import asynccontextmanager
|
|
||||||
from inspect import signature
|
from inspect import signature
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@ -40,7 +40,6 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
|
|||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
||||||
import invokeai.frontend.web as web_dir
|
import invokeai.frontend.web as web_dir
|
||||||
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
|
|
||||||
|
|
||||||
from ..backend.util.logging import InvokeAILogger
|
from ..backend.util.logging import InvokeAILogger
|
||||||
from .api.dependencies import ApiDependencies
|
from .api.dependencies import ApiDependencies
|
||||||
@ -60,7 +59,6 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
|
|||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
UIConfigBase,
|
UIConfigBase,
|
||||||
)
|
)
|
||||||
from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
|
|
||||||
|
|
||||||
if is_mps_available():
|
if is_mps_available():
|
||||||
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
|
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
|
||||||
@ -158,19 +156,17 @@ def custom_openapi() -> dict[str, Any]:
|
|||||||
openapi_schema["components"]["schemas"][schema_key] = output_schema
|
openapi_schema["components"]["schemas"][schema_key] = output_schema
|
||||||
openapi_schema["components"]["schemas"][schema_key]["class"] = "output"
|
openapi_schema["components"]["schemas"][schema_key]["class"] = "output"
|
||||||
|
|
||||||
# Some models don't end up in the schemas as standalone definitions
|
# Add Node Editor UI helper schemas
|
||||||
additional_schemas = models_json_schema(
|
ui_config_schemas = models_json_schema(
|
||||||
[
|
[
|
||||||
(UIConfigBase, "serialization"),
|
(UIConfigBase, "serialization"),
|
||||||
(InputFieldJSONSchemaExtra, "serialization"),
|
(InputFieldJSONSchemaExtra, "serialization"),
|
||||||
(OutputFieldJSONSchemaExtra, "serialization"),
|
(OutputFieldJSONSchemaExtra, "serialization"),
|
||||||
(ModelIdentifierField, "serialization"),
|
|
||||||
(ProgressImage, "serialization"),
|
|
||||||
],
|
],
|
||||||
ref_template="#/components/schemas/{model}",
|
ref_template="#/components/schemas/{model}",
|
||||||
)
|
)
|
||||||
for schema_key, schema_json in additional_schemas[1]["$defs"].items():
|
for schema_key, ui_config_schema in ui_config_schemas[1]["$defs"].items():
|
||||||
openapi_schema["components"]["schemas"][schema_key] = schema_json
|
openapi_schema["components"]["schemas"][schema_key] = ui_config_schema
|
||||||
|
|
||||||
# Add a reference to the output type to additionalProperties of the invoker schema
|
# Add a reference to the output type to additionalProperties of the invoker schema
|
||||||
for invoker in all_invocations:
|
for invoker in all_invocations:
|
||||||
|
@ -20,7 +20,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
|||||||
from invokeai.backend.util.devices import torch_dtype
|
from invokeai.backend.util.devices import torch_dtype
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
|
from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
|
||||||
from .model import CLIPField
|
from .model import ClipField
|
||||||
|
|
||||||
# unconditioned: Optional[torch.Tensor]
|
# unconditioned: Optional[torch.Tensor]
|
||||||
|
|
||||||
@ -46,7 +46,7 @@ class CompelInvocation(BaseInvocation):
|
|||||||
description=FieldDescriptions.compel_prompt,
|
description=FieldDescriptions.compel_prompt,
|
||||||
ui_component=UIComponent.Textarea,
|
ui_component=UIComponent.Textarea,
|
||||||
)
|
)
|
||||||
clip: CLIPField = InputField(
|
clip: ClipField = InputField(
|
||||||
title="CLIP",
|
title="CLIP",
|
||||||
description=FieldDescriptions.clip,
|
description=FieldDescriptions.clip,
|
||||||
input=Input.Connection,
|
input=Input.Connection,
|
||||||
@ -54,16 +54,16 @@ class CompelInvocation(BaseInvocation):
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||||
tokenizer_info = context.models.load(self.clip.tokenizer)
|
tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump())
|
||||||
tokenizer_model = tokenizer_info.model
|
tokenizer_model = tokenizer_info.model
|
||||||
assert isinstance(tokenizer_model, CLIPTokenizer)
|
assert isinstance(tokenizer_model, CLIPTokenizer)
|
||||||
text_encoder_info = context.models.load(self.clip.text_encoder)
|
text_encoder_info = context.models.load(**self.clip.text_encoder.model_dump())
|
||||||
text_encoder_model = text_encoder_info.model
|
text_encoder_model = text_encoder_info.model
|
||||||
assert isinstance(text_encoder_model, CLIPTextModel)
|
assert isinstance(text_encoder_model, CLIPTextModel)
|
||||||
|
|
||||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||||
for lora in self.clip.loras:
|
for lora in self.clip.loras:
|
||||||
lora_info = context.models.load(lora.lora)
|
lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
|
||||||
assert isinstance(lora_info.model, LoRAModelRaw)
|
assert isinstance(lora_info.model, LoRAModelRaw)
|
||||||
yield (lora_info.model, lora.weight)
|
yield (lora_info.model, lora.weight)
|
||||||
del lora_info
|
del lora_info
|
||||||
@ -127,16 +127,16 @@ class SDXLPromptInvocationBase:
|
|||||||
def run_clip_compel(
|
def run_clip_compel(
|
||||||
self,
|
self,
|
||||||
context: InvocationContext,
|
context: InvocationContext,
|
||||||
clip_field: CLIPField,
|
clip_field: ClipField,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
get_pooled: bool,
|
get_pooled: bool,
|
||||||
lora_prefix: str,
|
lora_prefix: str,
|
||||||
zero_on_empty: bool,
|
zero_on_empty: bool,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]:
|
||||||
tokenizer_info = context.models.load(clip_field.tokenizer)
|
tokenizer_info = context.models.load(**clip_field.tokenizer.model_dump())
|
||||||
tokenizer_model = tokenizer_info.model
|
tokenizer_model = tokenizer_info.model
|
||||||
assert isinstance(tokenizer_model, CLIPTokenizer)
|
assert isinstance(tokenizer_model, CLIPTokenizer)
|
||||||
text_encoder_info = context.models.load(clip_field.text_encoder)
|
text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump())
|
||||||
text_encoder_model = text_encoder_info.model
|
text_encoder_model = text_encoder_info.model
|
||||||
assert isinstance(text_encoder_model, (CLIPTextModel, CLIPTextModelWithProjection))
|
assert isinstance(text_encoder_model, (CLIPTextModel, CLIPTextModelWithProjection))
|
||||||
|
|
||||||
@ -163,7 +163,7 @@ class SDXLPromptInvocationBase:
|
|||||||
|
|
||||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||||
for lora in clip_field.loras:
|
for lora in clip_field.loras:
|
||||||
lora_info = context.models.load(lora.lora)
|
lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
|
||||||
lora_model = lora_info.model
|
lora_model = lora_info.model
|
||||||
assert isinstance(lora_model, LoRAModelRaw)
|
assert isinstance(lora_model, LoRAModelRaw)
|
||||||
yield (lora_model, lora.weight)
|
yield (lora_model, lora.weight)
|
||||||
@ -253,8 +253,8 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
|||||||
crop_left: int = InputField(default=0, description="")
|
crop_left: int = InputField(default=0, description="")
|
||||||
target_width: int = InputField(default=1024, description="")
|
target_width: int = InputField(default=1024, description="")
|
||||||
target_height: int = InputField(default=1024, description="")
|
target_height: int = InputField(default=1024, description="")
|
||||||
clip: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1")
|
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1")
|
||||||
clip2: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2")
|
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2")
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||||
@ -340,7 +340,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
|
|||||||
crop_top: int = InputField(default=0, description="")
|
crop_top: int = InputField(default=0, description="")
|
||||||
crop_left: int = InputField(default=0, description="")
|
crop_left: int = InputField(default=0, description="")
|
||||||
aesthetic_score: float = InputField(default=6.0, description=FieldDescriptions.sdxl_aesthetic)
|
aesthetic_score: float = InputField(default=6.0, description=FieldDescriptions.sdxl_aesthetic)
|
||||||
clip2: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
|
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||||
@ -370,10 +370,10 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
|
|||||||
|
|
||||||
|
|
||||||
@invocation_output("clip_skip_output")
|
@invocation_output("clip_skip_output")
|
||||||
class CLIPSkipInvocationOutput(BaseInvocationOutput):
|
class ClipSkipInvocationOutput(BaseInvocationOutput):
|
||||||
"""CLIP skip node output"""
|
"""Clip skip node output"""
|
||||||
|
|
||||||
clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
@ -383,15 +383,15 @@ class CLIPSkipInvocationOutput(BaseInvocationOutput):
|
|||||||
category="conditioning",
|
category="conditioning",
|
||||||
version="1.0.0",
|
version="1.0.0",
|
||||||
)
|
)
|
||||||
class CLIPSkipInvocation(BaseInvocation):
|
class ClipSkipInvocation(BaseInvocation):
|
||||||
"""Skip layers in clip text_encoder model."""
|
"""Skip layers in clip text_encoder model."""
|
||||||
|
|
||||||
clip: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP")
|
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP")
|
||||||
skipped_layers: int = InputField(default=0, ge=0, description=FieldDescriptions.skipped_layers)
|
skipped_layers: int = InputField(default=0, ge=0, description=FieldDescriptions.skipped_layers)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> CLIPSkipInvocationOutput:
|
def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput:
|
||||||
self.clip.skipped_layers += self.skipped_layers
|
self.clip.skipped_layers += self.skipped_layers
|
||||||
return CLIPSkipInvocationOutput(
|
return ClipSkipInvocationOutput(
|
||||||
clip=self.clip,
|
clip=self.clip,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -31,11 +31,9 @@ from invokeai.app.invocations.fields import (
|
|||||||
Input,
|
Input,
|
||||||
InputField,
|
InputField,
|
||||||
OutputField,
|
OutputField,
|
||||||
UIType,
|
|
||||||
WithBoard,
|
WithBoard,
|
||||||
WithMetadata,
|
WithMetadata,
|
||||||
)
|
)
|
||||||
from invokeai.app.invocations.model import ModelIdentifierField
|
|
||||||
from invokeai.app.invocations.primitives import ImageOutput
|
from invokeai.app.invocations.primitives import ImageOutput
|
||||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
@ -53,9 +51,15 @@ CONTROLNET_RESIZE_VALUES = Literal[
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class ControlNetModelField(BaseModel):
|
||||||
|
"""ControlNet model field"""
|
||||||
|
|
||||||
|
key: str = Field(description="Model config record key for the ControlNet model")
|
||||||
|
|
||||||
|
|
||||||
class ControlField(BaseModel):
|
class ControlField(BaseModel):
|
||||||
image: ImageField = Field(description="The control image")
|
image: ImageField = Field(description="The control image")
|
||||||
control_model: ModelIdentifierField = Field(description="The ControlNet model to use")
|
control_model: ControlNetModelField = Field(description="The ControlNet model to use")
|
||||||
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
|
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
|
||||||
begin_step_percent: float = Field(
|
begin_step_percent: float = Field(
|
||||||
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
|
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
|
||||||
@ -91,9 +95,7 @@ class ControlNetInvocation(BaseInvocation):
|
|||||||
"""Collects ControlNet info to pass to other nodes"""
|
"""Collects ControlNet info to pass to other nodes"""
|
||||||
|
|
||||||
image: ImageField = InputField(description="The control image")
|
image: ImageField = InputField(description="The control image")
|
||||||
control_model: ModelIdentifierField = InputField(
|
control_model: ControlNetModelField = InputField(description=FieldDescriptions.controlnet_model, input=Input.Direct)
|
||||||
description=FieldDescriptions.controlnet_model, input=Input.Direct, ui_type=UIType.ControlNetModel
|
|
||||||
)
|
|
||||||
control_weight: Union[float, List[float]] = InputField(
|
control_weight: Union[float, List[float]] = InputField(
|
||||||
default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"
|
default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"
|
||||||
)
|
)
|
||||||
|
@ -39,15 +39,13 @@ class UIType(str, Enum, metaclass=MetaEnum):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# region Model Field Types
|
# region Model Field Types
|
||||||
MainModel = "MainModelField"
|
|
||||||
SDXLMainModel = "SDXLMainModelField"
|
SDXLMainModel = "SDXLMainModelField"
|
||||||
SDXLRefinerModel = "SDXLRefinerModelField"
|
SDXLRefinerModel = "SDXLRefinerModelField"
|
||||||
ONNXModel = "ONNXModelField"
|
ONNXModel = "ONNXModelField"
|
||||||
VAEModel = "VAEModelField"
|
VaeModel = "VAEModelField"
|
||||||
LoRAModel = "LoRAModelField"
|
LoRAModel = "LoRAModelField"
|
||||||
ControlNetModel = "ControlNetModelField"
|
ControlNetModel = "ControlNetModelField"
|
||||||
IPAdapterModel = "IPAdapterModelField"
|
IPAdapterModel = "IPAdapterModelField"
|
||||||
T2IAdapterModel = "T2IAdapterModelField"
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
# region Misc Field Types
|
# region Misc Field Types
|
||||||
@ -88,6 +86,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
|
|||||||
IntegerPolymorphic = "DEPRECATED_IntegerPolymorphic"
|
IntegerPolymorphic = "DEPRECATED_IntegerPolymorphic"
|
||||||
LatentsPolymorphic = "DEPRECATED_LatentsPolymorphic"
|
LatentsPolymorphic = "DEPRECATED_LatentsPolymorphic"
|
||||||
StringPolymorphic = "DEPRECATED_StringPolymorphic"
|
StringPolymorphic = "DEPRECATED_StringPolymorphic"
|
||||||
|
MainModel = "DEPRECATED_MainModel"
|
||||||
UNet = "DEPRECATED_UNet"
|
UNet = "DEPRECATED_UNet"
|
||||||
Vae = "DEPRECATED_Vae"
|
Vae = "DEPRECATED_Vae"
|
||||||
CLIP = "DEPRECATED_CLIP"
|
CLIP = "DEPRECATED_CLIP"
|
||||||
@ -229,7 +228,7 @@ class ConditioningField(BaseModel):
|
|||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
|
|
||||||
class MetadataField(RootModel[dict[str, Any]]):
|
class MetadataField(RootModel):
|
||||||
"""
|
"""
|
||||||
Pydantic model for metadata with custom root of type dict[str, Any].
|
Pydantic model for metadata with custom root of type dict[str, Any].
|
||||||
Metadata is stored without a strict schema.
|
Metadata is stored without a strict schema.
|
||||||
|
@ -10,18 +10,26 @@ from invokeai.app.invocations.baseinvocation import (
|
|||||||
invocation,
|
invocation,
|
||||||
invocation_output,
|
invocation_output,
|
||||||
)
|
)
|
||||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
|
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
|
||||||
from invokeai.app.invocations.model import ModelIdentifierField
|
|
||||||
from invokeai.app.invocations.primitives import ImageField
|
from invokeai.app.invocations.primitives import ImageField
|
||||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
from invokeai.backend.model_manager.config import BaseModelType, IPAdapterConfig, ModelType
|
from invokeai.backend.model_manager.config import BaseModelType, ModelType
|
||||||
|
|
||||||
|
|
||||||
|
# LS: Consider moving these two classes into model.py
|
||||||
|
class IPAdapterModelField(BaseModel):
|
||||||
|
key: str = Field(description="Key to the IP-Adapter model")
|
||||||
|
|
||||||
|
|
||||||
|
class CLIPVisionModelField(BaseModel):
|
||||||
|
key: str = Field(description="Key to the CLIP Vision image encoder model")
|
||||||
|
|
||||||
|
|
||||||
class IPAdapterField(BaseModel):
|
class IPAdapterField(BaseModel):
|
||||||
image: Union[ImageField, List[ImageField]] = Field(description="The IP-Adapter image prompt(s).")
|
image: Union[ImageField, List[ImageField]] = Field(description="The IP-Adapter image prompt(s).")
|
||||||
ip_adapter_model: ModelIdentifierField = Field(description="The IP-Adapter model to use.")
|
ip_adapter_model: IPAdapterModelField = Field(description="The IP-Adapter model to use.")
|
||||||
image_encoder_model: ModelIdentifierField = Field(description="The name of the CLIP image encoder model.")
|
image_encoder_model: CLIPVisionModelField = Field(description="The name of the CLIP image encoder model.")
|
||||||
weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
|
weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
|
||||||
begin_step_percent: float = Field(
|
begin_step_percent: float = Field(
|
||||||
default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
|
default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
|
||||||
@ -54,12 +62,8 @@ class IPAdapterInvocation(BaseInvocation):
|
|||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: Union[ImageField, List[ImageField]] = InputField(description="The IP-Adapter image prompt(s).")
|
image: Union[ImageField, List[ImageField]] = InputField(description="The IP-Adapter image prompt(s).")
|
||||||
ip_adapter_model: ModelIdentifierField = InputField(
|
ip_adapter_model: IPAdapterModelField = InputField(
|
||||||
description="The IP-Adapter model.",
|
description="The IP-Adapter model.", title="IP-Adapter Model", input=Input.Direct, ui_order=-1
|
||||||
title="IP-Adapter Model",
|
|
||||||
input=Input.Direct,
|
|
||||||
ui_order=-1,
|
|
||||||
ui_type=UIType.IPAdapterModel,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
weight: Union[float, List[float]] = InputField(
|
weight: Union[float, List[float]] = InputField(
|
||||||
@ -86,18 +90,18 @@ class IPAdapterInvocation(BaseInvocation):
|
|||||||
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
|
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
|
||||||
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
|
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
|
||||||
ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
|
ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
|
||||||
assert isinstance(ip_adapter_info, IPAdapterConfig)
|
|
||||||
image_encoder_model_id = ip_adapter_info.image_encoder_model_id
|
image_encoder_model_id = ip_adapter_info.image_encoder_model_id
|
||||||
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
|
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
|
||||||
image_encoder_models = context.models.search_by_attrs(
|
image_encoder_models = context.models.search_by_attrs(
|
||||||
name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision
|
name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision
|
||||||
)
|
)
|
||||||
assert len(image_encoder_models) == 1
|
assert len(image_encoder_models) == 1
|
||||||
|
image_encoder_model = CLIPVisionModelField(key=image_encoder_models[0].key)
|
||||||
return IPAdapterOutput(
|
return IPAdapterOutput(
|
||||||
ip_adapter=IPAdapterField(
|
ip_adapter=IPAdapterField(
|
||||||
image=self.image,
|
image=self.image,
|
||||||
ip_adapter_model=self.ip_adapter_model,
|
ip_adapter_model=self.ip_adapter_model,
|
||||||
image_encoder_model=ModelIdentifierField.from_config(image_encoder_models[0]),
|
image_encoder_model=image_encoder_model,
|
||||||
weight=self.weight,
|
weight=self.weight,
|
||||||
begin_step_percent=self.begin_step_percent,
|
begin_step_percent=self.begin_step_percent,
|
||||||
end_step_percent=self.end_step_percent,
|
end_step_percent=self.end_step_percent,
|
||||||
|
@ -26,7 +26,6 @@ from diffusers.schedulers import SchedulerMixin as Scheduler
|
|||||||
from PIL import Image, ImageFilter
|
from PIL import Image, ImageFilter
|
||||||
from pydantic import field_validator
|
from pydantic import field_validator
|
||||||
from torchvision.transforms.functional import resize as tv_resize
|
from torchvision.transforms.functional import resize as tv_resize
|
||||||
from transformers import CLIPVisionModelWithProjection
|
|
||||||
|
|
||||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
|
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
|
||||||
from invokeai.app.invocations.fields import (
|
from invokeai.app.invocations.fields import (
|
||||||
@ -66,6 +65,7 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
|
|||||||
T2IAdapterData,
|
T2IAdapterData,
|
||||||
image_resized_to_grid_as_tensor,
|
image_resized_to_grid_as_tensor,
|
||||||
)
|
)
|
||||||
|
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
||||||
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||||
from ...backend.util.devices import choose_precision, choose_torch_device
|
from ...backend.util.devices import choose_precision, choose_torch_device
|
||||||
from .baseinvocation import (
|
from .baseinvocation import (
|
||||||
@ -75,7 +75,7 @@ from .baseinvocation import (
|
|||||||
invocation_output,
|
invocation_output,
|
||||||
)
|
)
|
||||||
from .controlnet_image_processors import ControlField
|
from .controlnet_image_processors import ControlField
|
||||||
from .model import ModelIdentifierField, UNetField, VAEField
|
from .model import ModelInfo, UNetField, VaeField
|
||||||
|
|
||||||
if choose_torch_device() == torch.device("mps"):
|
if choose_torch_device() == torch.device("mps"):
|
||||||
from torch import mps
|
from torch import mps
|
||||||
@ -118,7 +118,7 @@ class SchedulerInvocation(BaseInvocation):
|
|||||||
class CreateDenoiseMaskInvocation(BaseInvocation):
|
class CreateDenoiseMaskInvocation(BaseInvocation):
|
||||||
"""Creates mask for denoising model run."""
|
"""Creates mask for denoising model run."""
|
||||||
|
|
||||||
vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection, ui_order=0)
|
vae: VaeField = InputField(description=FieldDescriptions.vae, input=Input.Connection, ui_order=0)
|
||||||
image: Optional[ImageField] = InputField(default=None, description="Image which will be masked", ui_order=1)
|
image: Optional[ImageField] = InputField(default=None, description="Image which will be masked", ui_order=1)
|
||||||
mask: ImageField = InputField(description="The mask to use when pasting", ui_order=2)
|
mask: ImageField = InputField(description="The mask to use when pasting", ui_order=2)
|
||||||
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=3)
|
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=3)
|
||||||
@ -153,7 +153,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if image_tensor is not None:
|
if image_tensor is not None:
|
||||||
vae_info = context.models.load(self.vae.vae)
|
vae_info = context.models.load(**self.vae.vae.model_dump())
|
||||||
|
|
||||||
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
|
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
|
||||||
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
|
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
|
||||||
@ -244,12 +244,12 @@ class CreateGradientMaskInvocation(BaseInvocation):
|
|||||||
|
|
||||||
def get_scheduler(
|
def get_scheduler(
|
||||||
context: InvocationContext,
|
context: InvocationContext,
|
||||||
scheduler_info: ModelIdentifierField,
|
scheduler_info: ModelInfo,
|
||||||
scheduler_name: str,
|
scheduler_name: str,
|
||||||
seed: int,
|
seed: int,
|
||||||
) -> Scheduler:
|
) -> Scheduler:
|
||||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
|
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
|
||||||
orig_scheduler_info = context.models.load(scheduler_info)
|
orig_scheduler_info = context.models.load(**scheduler_info.model_dump())
|
||||||
with orig_scheduler_info as orig_scheduler:
|
with orig_scheduler_info as orig_scheduler:
|
||||||
scheduler_config = orig_scheduler.config
|
scheduler_config = orig_scheduler.config
|
||||||
|
|
||||||
@ -383,6 +383,12 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
text_embeddings=c,
|
text_embeddings=c,
|
||||||
guidance_scale=self.cfg_scale,
|
guidance_scale=self.cfg_scale,
|
||||||
guidance_rescale_multiplier=self.cfg_rescale_multiplier,
|
guidance_rescale_multiplier=self.cfg_rescale_multiplier,
|
||||||
|
postprocessing_settings=PostprocessingSettings(
|
||||||
|
threshold=0.0, # threshold,
|
||||||
|
warmup=0.2, # warmup,
|
||||||
|
h_symmetry_time_pct=None, # h_symmetry_time_pct,
|
||||||
|
v_symmetry_time_pct=None, # v_symmetry_time_pct,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
conditioning_data = conditioning_data.add_scheduler_args_if_applicable( # FIXME
|
conditioning_data = conditioning_data.add_scheduler_args_if_applicable( # FIXME
|
||||||
@ -455,7 +461,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
# and if weight is None, populate with default 1.0?
|
# and if weight is None, populate with default 1.0?
|
||||||
controlnet_data = []
|
controlnet_data = []
|
||||||
for control_info in control_list:
|
for control_info in control_list:
|
||||||
control_model = exit_stack.enter_context(context.models.load(control_info.control_model))
|
control_model = exit_stack.enter_context(context.models.load(key=control_info.control_model.key))
|
||||||
|
|
||||||
# control_models.append(control_model)
|
# control_models.append(control_model)
|
||||||
control_image_field = control_info.image
|
control_image_field = control_info.image
|
||||||
@ -517,10 +523,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
conditioning_data.ip_adapter_conditioning = []
|
conditioning_data.ip_adapter_conditioning = []
|
||||||
for single_ip_adapter in ip_adapter:
|
for single_ip_adapter in ip_adapter:
|
||||||
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
|
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
|
||||||
context.models.load(single_ip_adapter.ip_adapter_model)
|
context.models.load(key=single_ip_adapter.ip_adapter_model.key)
|
||||||
)
|
)
|
||||||
|
|
||||||
image_encoder_model_info = context.models.load(single_ip_adapter.image_encoder_model)
|
image_encoder_model_info = context.models.load(key=single_ip_adapter.image_encoder_model.key)
|
||||||
|
|
||||||
# `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
|
# `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
|
||||||
single_ipa_image_fields = single_ip_adapter.image
|
single_ipa_image_fields = single_ip_adapter.image
|
||||||
if not isinstance(single_ipa_image_fields, list):
|
if not isinstance(single_ipa_image_fields, list):
|
||||||
@ -531,7 +538,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
# TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
|
# TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
|
||||||
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
|
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
|
||||||
with image_encoder_model_info as image_encoder_model:
|
with image_encoder_model_info as image_encoder_model:
|
||||||
assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
|
|
||||||
# Get image embeddings from CLIP and ImageProjModel.
|
# Get image embeddings from CLIP and ImageProjModel.
|
||||||
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
|
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
|
||||||
single_ipa_images, image_encoder_model
|
single_ipa_images, image_encoder_model
|
||||||
@ -571,8 +577,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
t2i_adapter_data = []
|
t2i_adapter_data = []
|
||||||
for t2i_adapter_field in t2i_adapter:
|
for t2i_adapter_field in t2i_adapter:
|
||||||
t2i_adapter_model_config = context.models.get_config(t2i_adapter_field.t2i_adapter_model.key)
|
t2i_adapter_model_config = context.models.get_config(key=t2i_adapter_field.t2i_adapter_model.key)
|
||||||
t2i_adapter_loaded_model = context.models.load(t2i_adapter_field.t2i_adapter_model)
|
t2i_adapter_loaded_model = context.models.load(key=t2i_adapter_field.t2i_adapter_model.key)
|
||||||
image = context.images.get_pil(t2i_adapter_field.image.image_name)
|
image = context.images.get_pil(t2i_adapter_field.image.image_name)
|
||||||
|
|
||||||
# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
|
# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
|
||||||
@ -677,7 +683,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
if self.denoise_mask.masked_latents_name is not None:
|
if self.denoise_mask.masked_latents_name is not None:
|
||||||
masked_latents = context.tensors.load(self.denoise_mask.masked_latents_name)
|
masked_latents = context.tensors.load(self.denoise_mask.masked_latents_name)
|
||||||
else:
|
else:
|
||||||
masked_latents = torch.where(mask < 0.5, 0.0, latents)
|
masked_latents = None
|
||||||
|
|
||||||
return 1 - mask, masked_latents, self.denoise_mask.gradient
|
return 1 - mask, masked_latents, self.denoise_mask.gradient
|
||||||
|
|
||||||
@ -725,13 +731,12 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||||
for lora in self.unet.loras:
|
for lora in self.unet.loras:
|
||||||
lora_info = context.models.load(lora.lora)
|
lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
|
||||||
assert isinstance(lora_info.model, LoRAModelRaw)
|
|
||||||
yield (lora_info.model, lora.weight)
|
yield (lora_info.model, lora.weight)
|
||||||
del lora_info
|
del lora_info
|
||||||
return
|
return
|
||||||
|
|
||||||
unet_info = context.models.load(self.unet.unet)
|
unet_info = context.models.load(**self.unet.unet.model_dump())
|
||||||
assert isinstance(unet_info.model, UNet2DConditionModel)
|
assert isinstance(unet_info.model, UNet2DConditionModel)
|
||||||
with (
|
with (
|
||||||
ExitStack() as exit_stack,
|
ExitStack() as exit_stack,
|
||||||
@ -825,7 +830,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
description=FieldDescriptions.latents,
|
description=FieldDescriptions.latents,
|
||||||
input=Input.Connection,
|
input=Input.Connection,
|
||||||
)
|
)
|
||||||
vae: VAEField = InputField(
|
vae: VaeField = InputField(
|
||||||
description=FieldDescriptions.vae,
|
description=FieldDescriptions.vae,
|
||||||
input=Input.Connection,
|
input=Input.Connection,
|
||||||
)
|
)
|
||||||
@ -836,8 +841,8 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
latents = context.tensors.load(self.latents.latents_name)
|
latents = context.tensors.load(self.latents.latents_name)
|
||||||
|
|
||||||
vae_info = context.models.load(self.vae.vae)
|
vae_info = context.models.load(**self.vae.vae.model_dump())
|
||||||
assert isinstance(vae_info.model, (UNet2DConditionModel, AutoencoderKL))
|
|
||||||
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
|
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
|
||||||
assert isinstance(vae, torch.nn.Module)
|
assert isinstance(vae, torch.nn.Module)
|
||||||
latents = latents.to(vae.device)
|
latents = latents.to(vae.device)
|
||||||
@ -1003,7 +1008,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
|||||||
image: ImageField = InputField(
|
image: ImageField = InputField(
|
||||||
description="The image to encode",
|
description="The image to encode",
|
||||||
)
|
)
|
||||||
vae: VAEField = InputField(
|
vae: VaeField = InputField(
|
||||||
description=FieldDescriptions.vae,
|
description=FieldDescriptions.vae,
|
||||||
input=Input.Connection,
|
input=Input.Connection,
|
||||||
)
|
)
|
||||||
@ -1059,7 +1064,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
|||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
image = context.images.get_pil(self.image.image_name)
|
image = context.images.get_pil(self.image.image_name)
|
||||||
|
|
||||||
vae_info = context.models.load(self.vae.vae)
|
vae_info = context.models.load(**self.vae.vae.model_dump())
|
||||||
|
|
||||||
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||||
if image_tensor.dim() == 3:
|
if image_tensor.dim() == 3:
|
||||||
|
@ -8,10 +8,7 @@ from invokeai.app.invocations.baseinvocation import (
|
|||||||
invocation,
|
invocation,
|
||||||
invocation_output,
|
invocation_output,
|
||||||
)
|
)
|
||||||
from invokeai.app.invocations.controlnet_image_processors import (
|
from invokeai.app.invocations.controlnet_image_processors import ControlField
|
||||||
CONTROLNET_MODE_VALUES,
|
|
||||||
CONTROLNET_RESIZE_VALUES,
|
|
||||||
)
|
|
||||||
from invokeai.app.invocations.fields import (
|
from invokeai.app.invocations.fields import (
|
||||||
FieldDescriptions,
|
FieldDescriptions,
|
||||||
ImageField,
|
ImageField,
|
||||||
@ -20,8 +17,10 @@ from invokeai.app.invocations.fields import (
|
|||||||
OutputField,
|
OutputField,
|
||||||
UIType,
|
UIType,
|
||||||
)
|
)
|
||||||
|
from invokeai.app.invocations.ip_adapter import IPAdapterModelField
|
||||||
|
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
|
||||||
|
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
|
||||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
from invokeai.backend.model_manager.config import BaseModelType, ModelType
|
|
||||||
|
|
||||||
from ...version import __version__
|
from ...version import __version__
|
||||||
|
|
||||||
@ -31,20 +30,10 @@ class MetadataItemField(BaseModel):
|
|||||||
value: Any = Field(description=FieldDescriptions.metadata_item_value)
|
value: Any = Field(description=FieldDescriptions.metadata_item_value)
|
||||||
|
|
||||||
|
|
||||||
class ModelMetadataField(BaseModel):
|
|
||||||
"""Model Metadata Field"""
|
|
||||||
|
|
||||||
key: str
|
|
||||||
hash: str
|
|
||||||
name: str
|
|
||||||
base: BaseModelType
|
|
||||||
type: ModelType
|
|
||||||
|
|
||||||
|
|
||||||
class LoRAMetadataField(BaseModel):
|
class LoRAMetadataField(BaseModel):
|
||||||
"""LoRA Metadata Field"""
|
"""LoRA Metadata Field"""
|
||||||
|
|
||||||
model: ModelMetadataField = Field(description=FieldDescriptions.lora_model)
|
model: LoRAModelField = Field(description=FieldDescriptions.lora_model)
|
||||||
weight: float = Field(description=FieldDescriptions.lora_weight)
|
weight: float = Field(description=FieldDescriptions.lora_weight)
|
||||||
|
|
||||||
|
|
||||||
@ -52,7 +41,7 @@ class IPAdapterMetadataField(BaseModel):
|
|||||||
"""IP Adapter Field, minus the CLIP Vision Encoder model"""
|
"""IP Adapter Field, minus the CLIP Vision Encoder model"""
|
||||||
|
|
||||||
image: ImageField = Field(description="The IP-Adapter image prompt.")
|
image: ImageField = Field(description="The IP-Adapter image prompt.")
|
||||||
ip_adapter_model: ModelMetadataField = Field(
|
ip_adapter_model: IPAdapterModelField = Field(
|
||||||
description="The IP-Adapter model.",
|
description="The IP-Adapter model.",
|
||||||
)
|
)
|
||||||
weight: Union[float, list[float]] = Field(
|
weight: Union[float, list[float]] = Field(
|
||||||
@ -62,33 +51,6 @@ class IPAdapterMetadataField(BaseModel):
|
|||||||
end_step_percent: float = Field(description="When the IP-Adapter is last applied (% of total steps)")
|
end_step_percent: float = Field(description="When the IP-Adapter is last applied (% of total steps)")
|
||||||
|
|
||||||
|
|
||||||
class T2IAdapterMetadataField(BaseModel):
|
|
||||||
image: ImageField = Field(description="The T2I-Adapter image prompt.")
|
|
||||||
t2i_adapter_model: ModelMetadataField = Field(description="The T2I-Adapter model to use.")
|
|
||||||
weight: Union[float, list[float]] = Field(default=1, description="The weight given to the T2I-Adapter")
|
|
||||||
begin_step_percent: float = Field(
|
|
||||||
default=0, ge=0, le=1, description="When the T2I-Adapter is first applied (% of total steps)"
|
|
||||||
)
|
|
||||||
end_step_percent: float = Field(
|
|
||||||
default=1, ge=0, le=1, description="When the T2I-Adapter is last applied (% of total steps)"
|
|
||||||
)
|
|
||||||
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
|
|
||||||
|
|
||||||
|
|
||||||
class ControlNetMetadataField(BaseModel):
|
|
||||||
image: ImageField = Field(description="The control image")
|
|
||||||
control_model: ModelMetadataField = Field(description="The ControlNet model to use")
|
|
||||||
control_weight: Union[float, list[float]] = Field(default=1, description="The weight given to the ControlNet")
|
|
||||||
begin_step_percent: float = Field(
|
|
||||||
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
|
|
||||||
)
|
|
||||||
end_step_percent: float = Field(
|
|
||||||
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
|
|
||||||
)
|
|
||||||
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
|
|
||||||
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
|
|
||||||
|
|
||||||
|
|
||||||
@invocation_output("metadata_item_output")
|
@invocation_output("metadata_item_output")
|
||||||
class MetadataItemOutput(BaseInvocationOutput):
|
class MetadataItemOutput(BaseInvocationOutput):
|
||||||
"""Metadata Item Output"""
|
"""Metadata Item Output"""
|
||||||
@ -178,14 +140,14 @@ class CoreMetadataInvocation(BaseInvocation):
|
|||||||
default=None,
|
default=None,
|
||||||
description="The number of skipped CLIP layers",
|
description="The number of skipped CLIP layers",
|
||||||
)
|
)
|
||||||
model: Optional[ModelMetadataField] = InputField(default=None, description="The main model used for inference")
|
model: Optional[MainModelField] = InputField(default=None, description="The main model used for inference")
|
||||||
controlnets: Optional[list[ControlNetMetadataField]] = InputField(
|
controlnets: Optional[list[ControlField]] = InputField(
|
||||||
default=None, description="The ControlNets used for inference"
|
default=None, description="The ControlNets used for inference"
|
||||||
)
|
)
|
||||||
ipAdapters: Optional[list[IPAdapterMetadataField]] = InputField(
|
ipAdapters: Optional[list[IPAdapterMetadataField]] = InputField(
|
||||||
default=None, description="The IP Adapters used for inference"
|
default=None, description="The IP Adapters used for inference"
|
||||||
)
|
)
|
||||||
t2iAdapters: Optional[list[T2IAdapterMetadataField]] = InputField(
|
t2iAdapters: Optional[list[T2IAdapterField]] = InputField(
|
||||||
default=None, description="The IP Adapters used for inference"
|
default=None, description="The IP Adapters used for inference"
|
||||||
)
|
)
|
||||||
loras: Optional[list[LoRAMetadataField]] = InputField(default=None, description="The LoRAs used for inference")
|
loras: Optional[list[LoRAMetadataField]] = InputField(default=None, description="The LoRAs used for inference")
|
||||||
@ -197,7 +159,7 @@ class CoreMetadataInvocation(BaseInvocation):
|
|||||||
default=None,
|
default=None,
|
||||||
description="The name of the initial image",
|
description="The name of the initial image",
|
||||||
)
|
)
|
||||||
vae: Optional[ModelMetadataField] = InputField(
|
vae: Optional[VAEModelField] = InputField(
|
||||||
default=None,
|
default=None,
|
||||||
description="The VAE used for decoding, if the main model's default was not used",
|
description="The VAE used for decoding, if the main model's default was not used",
|
||||||
)
|
)
|
||||||
@ -228,7 +190,7 @@ class CoreMetadataInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# SDXL Refiner
|
# SDXL Refiner
|
||||||
refiner_model: Optional[ModelMetadataField] = InputField(
|
refiner_model: Optional[MainModelField] = InputField(
|
||||||
default=None,
|
default=None,
|
||||||
description="The SDXL Refiner model used",
|
description="The SDXL Refiner model used",
|
||||||
)
|
)
|
||||||
@ -260,9 +222,10 @@ class CoreMetadataInvocation(BaseInvocation):
|
|||||||
def invoke(self, context: InvocationContext) -> MetadataOutput:
|
def invoke(self, context: InvocationContext) -> MetadataOutput:
|
||||||
"""Collects and outputs a CoreMetadata object"""
|
"""Collects and outputs a CoreMetadata object"""
|
||||||
|
|
||||||
as_dict = self.model_dump(exclude_none=True, exclude={"id", "type", "is_intermediate", "use_cache"})
|
return MetadataOutput(
|
||||||
as_dict["app_version"] = __version__
|
metadata=MetadataField.model_validate(
|
||||||
|
self.model_dump(exclude_none=True, exclude={"id", "type", "is_intermediate", "use_cache"})
|
||||||
return MetadataOutput(metadata=MetadataField.model_validate(as_dict))
|
)
|
||||||
|
)
|
||||||
|
|
||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra="allow")
|
||||||
|
@ -3,11 +3,11 @@ from typing import List, Optional
|
|||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
|
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
|
||||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
from invokeai.app.shared.models import FreeUConfig
|
from invokeai.app.shared.models import FreeUConfig
|
||||||
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType
|
|
||||||
|
|
||||||
|
from ...backend.model_manager import SubModelType
|
||||||
from .baseinvocation import (
|
from .baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
BaseInvocationOutput,
|
BaseInvocationOutput,
|
||||||
@ -16,52 +16,33 @@ from .baseinvocation import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ModelIdentifierField(BaseModel):
|
class ModelInfo(BaseModel):
|
||||||
key: str = Field(description="The model's unique key")
|
key: str = Field(description="Key of model as returned by ModelRecordServiceBase.get_model()")
|
||||||
hash: str = Field(description="The model's BLAKE3 hash")
|
submodel_type: Optional[SubModelType] = Field(default=None, description="Info to load submodel")
|
||||||
name: str = Field(description="The model's name")
|
|
||||||
base: BaseModelType = Field(description="The model's base model type")
|
|
||||||
type: ModelType = Field(description="The model's type")
|
|
||||||
submodel_type: Optional[SubModelType] = Field(
|
|
||||||
description="The submodel to load, if this is a main model", default=None
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(
|
|
||||||
cls, config: "AnyModelConfig", submodel_type: Optional[SubModelType] = None
|
|
||||||
) -> "ModelIdentifierField":
|
|
||||||
return cls(
|
|
||||||
key=config.key,
|
|
||||||
hash=config.hash,
|
|
||||||
name=config.name,
|
|
||||||
base=config.base,
|
|
||||||
type=config.type,
|
|
||||||
submodel_type=submodel_type,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class LoRAField(BaseModel):
|
class LoraInfo(ModelInfo):
|
||||||
lora: ModelIdentifierField = Field(description="Info to load lora model")
|
weight: float = Field(description="Lora's weight which to use when apply to model")
|
||||||
weight: float = Field(description="Weight to apply to lora model")
|
|
||||||
|
|
||||||
|
|
||||||
class UNetField(BaseModel):
|
class UNetField(BaseModel):
|
||||||
unet: ModelIdentifierField = Field(description="Info to load unet submodel")
|
unet: ModelInfo = Field(description="Info to load unet submodel")
|
||||||
scheduler: ModelIdentifierField = Field(description="Info to load scheduler submodel")
|
scheduler: ModelInfo = Field(description="Info to load scheduler submodel")
|
||||||
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
|
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
|
||||||
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
|
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
|
||||||
freeu_config: Optional[FreeUConfig] = Field(default=None, description="FreeU configuration")
|
freeu_config: Optional[FreeUConfig] = Field(default=None, description="FreeU configuration")
|
||||||
|
|
||||||
|
|
||||||
class CLIPField(BaseModel):
|
class ClipField(BaseModel):
|
||||||
tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel")
|
tokenizer: ModelInfo = Field(description="Info to load tokenizer submodel")
|
||||||
text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")
|
text_encoder: ModelInfo = Field(description="Info to load text_encoder submodel")
|
||||||
skipped_layers: int = Field(description="Number of skipped layers in text_encoder")
|
skipped_layers: int = Field(description="Number of skipped layers in text_encoder")
|
||||||
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
|
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
|
||||||
|
|
||||||
|
|
||||||
class VAEField(BaseModel):
|
class VaeField(BaseModel):
|
||||||
vae: ModelIdentifierField = Field(description="Info to load vae submodel")
|
# TODO: better naming?
|
||||||
|
vae: ModelInfo = Field(description="Info to load vae submodel")
|
||||||
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
|
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
|
||||||
|
|
||||||
|
|
||||||
@ -76,14 +57,14 @@ class UNetOutput(BaseInvocationOutput):
|
|||||||
class VAEOutput(BaseInvocationOutput):
|
class VAEOutput(BaseInvocationOutput):
|
||||||
"""Base class for invocations that output a VAE field"""
|
"""Base class for invocations that output a VAE field"""
|
||||||
|
|
||||||
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||||
|
|
||||||
|
|
||||||
@invocation_output("clip_output")
|
@invocation_output("clip_output")
|
||||||
class CLIPOutput(BaseInvocationOutput):
|
class CLIPOutput(BaseInvocationOutput):
|
||||||
"""Base class for invocations that output a CLIP field"""
|
"""Base class for invocations that output a CLIP field"""
|
||||||
|
|
||||||
clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP")
|
clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP")
|
||||||
|
|
||||||
|
|
||||||
@invocation_output("model_loader_output")
|
@invocation_output("model_loader_output")
|
||||||
@ -93,6 +74,18 @@ class ModelLoaderOutput(UNetOutput, CLIPOutput, VAEOutput):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class MainModelField(BaseModel):
|
||||||
|
"""Main model field"""
|
||||||
|
|
||||||
|
key: str = Field(description="Model key")
|
||||||
|
|
||||||
|
|
||||||
|
class LoRAModelField(BaseModel):
|
||||||
|
"""LoRA model field"""
|
||||||
|
|
||||||
|
key: str = Field(description="LoRA model key")
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
"main_model_loader",
|
"main_model_loader",
|
||||||
title="Main Model",
|
title="Main Model",
|
||||||
@ -103,44 +96,62 @@ class ModelLoaderOutput(UNetOutput, CLIPOutput, VAEOutput):
|
|||||||
class MainModelLoaderInvocation(BaseInvocation):
|
class MainModelLoaderInvocation(BaseInvocation):
|
||||||
"""Loads a main model, outputting its submodels."""
|
"""Loads a main model, outputting its submodels."""
|
||||||
|
|
||||||
model: ModelIdentifierField = InputField(
|
model: MainModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct)
|
||||||
description=FieldDescriptions.main_model, input=Input.Direct, ui_type=UIType.MainModel
|
|
||||||
)
|
|
||||||
# TODO: precision?
|
# TODO: precision?
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
|
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
|
||||||
# TODO: not found exceptions
|
key = self.model.key
|
||||||
if not context.models.exists(self.model.key):
|
|
||||||
raise Exception(f"Unknown model {self.model.key}")
|
|
||||||
|
|
||||||
unet = self.model.model_copy(update={"submodel_type": SubModelType.UNet})
|
# TODO: not found exceptions
|
||||||
scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
|
if not context.models.exists(key):
|
||||||
tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
|
raise Exception(f"Unknown model {key}")
|
||||||
text_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
|
|
||||||
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
|
|
||||||
|
|
||||||
return ModelLoaderOutput(
|
return ModelLoaderOutput(
|
||||||
unet=UNetField(unet=unet, scheduler=scheduler, loras=[]),
|
unet=UNetField(
|
||||||
clip=CLIPField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[], skipped_layers=0),
|
unet=ModelInfo(
|
||||||
vae=VAEField(vae=vae),
|
key=key,
|
||||||
|
submodel_type=SubModelType.UNet,
|
||||||
|
),
|
||||||
|
scheduler=ModelInfo(
|
||||||
|
key=key,
|
||||||
|
submodel_type=SubModelType.Scheduler,
|
||||||
|
),
|
||||||
|
loras=[],
|
||||||
|
),
|
||||||
|
clip=ClipField(
|
||||||
|
tokenizer=ModelInfo(
|
||||||
|
key=key,
|
||||||
|
submodel_type=SubModelType.Tokenizer,
|
||||||
|
),
|
||||||
|
text_encoder=ModelInfo(
|
||||||
|
key=key,
|
||||||
|
submodel_type=SubModelType.TextEncoder,
|
||||||
|
),
|
||||||
|
loras=[],
|
||||||
|
skipped_layers=0,
|
||||||
|
),
|
||||||
|
vae=VaeField(
|
||||||
|
vae=ModelInfo(
|
||||||
|
key=key,
|
||||||
|
submodel_type=SubModelType.VAE,
|
||||||
|
),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation_output("lora_loader_output")
|
@invocation_output("lora_loader_output")
|
||||||
class LoRALoaderOutput(BaseInvocationOutput):
|
class LoraLoaderOutput(BaseInvocationOutput):
|
||||||
"""Model loader output"""
|
"""Model loader output"""
|
||||||
|
|
||||||
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
|
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
|
||||||
clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||||
|
|
||||||
|
|
||||||
@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.1")
|
@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.1")
|
||||||
class LoRALoaderInvocation(BaseInvocation):
|
class LoraLoaderInvocation(BaseInvocation):
|
||||||
"""Apply selected lora to unet and text_encoder."""
|
"""Apply selected lora to unet and text_encoder."""
|
||||||
|
|
||||||
lora: ModelIdentifierField = InputField(
|
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
|
||||||
description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel
|
|
||||||
)
|
|
||||||
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
||||||
unet: Optional[UNetField] = InputField(
|
unet: Optional[UNetField] = InputField(
|
||||||
default=None,
|
default=None,
|
||||||
@ -148,41 +159,46 @@ class LoRALoaderInvocation(BaseInvocation):
|
|||||||
input=Input.Connection,
|
input=Input.Connection,
|
||||||
title="UNet",
|
title="UNet",
|
||||||
)
|
)
|
||||||
clip: Optional[CLIPField] = InputField(
|
clip: Optional[ClipField] = InputField(
|
||||||
default=None,
|
default=None,
|
||||||
description=FieldDescriptions.clip,
|
description=FieldDescriptions.clip,
|
||||||
input=Input.Connection,
|
input=Input.Connection,
|
||||||
title="CLIP",
|
title="CLIP",
|
||||||
)
|
)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> LoRALoaderOutput:
|
def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
|
||||||
|
if self.lora is None:
|
||||||
|
raise Exception("No LoRA provided")
|
||||||
|
|
||||||
lora_key = self.lora.key
|
lora_key = self.lora.key
|
||||||
|
|
||||||
if not context.models.exists(lora_key):
|
if not context.models.exists(lora_key):
|
||||||
raise Exception(f"Unkown lora: {lora_key}!")
|
raise Exception(f"Unkown lora: {lora_key}!")
|
||||||
|
|
||||||
if self.unet is not None and any(lora.lora.key == lora_key for lora in self.unet.loras):
|
if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
|
||||||
raise Exception(f'LoRA "{lora_key}" already applied to unet')
|
raise Exception(f'Lora "{lora_key}" already applied to unet')
|
||||||
|
|
||||||
if self.clip is not None and any(lora.lora.key == lora_key for lora in self.clip.loras):
|
if self.clip is not None and any(lora.key == lora_key for lora in self.clip.loras):
|
||||||
raise Exception(f'LoRA "{lora_key}" already applied to clip')
|
raise Exception(f'Lora "{lora_key}" already applied to clip')
|
||||||
|
|
||||||
output = LoRALoaderOutput()
|
output = LoraLoaderOutput()
|
||||||
|
|
||||||
if self.unet is not None:
|
if self.unet is not None:
|
||||||
output.unet = self.unet.model_copy(deep=True)
|
output.unet = copy.deepcopy(self.unet)
|
||||||
output.unet.loras.append(
|
output.unet.loras.append(
|
||||||
LoRAField(
|
LoraInfo(
|
||||||
lora=self.lora,
|
key=lora_key,
|
||||||
|
submodel_type=None,
|
||||||
weight=self.weight,
|
weight=self.weight,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.clip is not None:
|
if self.clip is not None:
|
||||||
output.clip = self.clip.model_copy(deep=True)
|
output.clip = copy.deepcopy(self.clip)
|
||||||
output.clip.loras.append(
|
output.clip.loras.append(
|
||||||
LoRAField(
|
LoraInfo(
|
||||||
lora=self.lora,
|
key=lora_key,
|
||||||
|
submodel_type=None,
|
||||||
weight=self.weight,
|
weight=self.weight,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -191,12 +207,12 @@ class LoRALoaderInvocation(BaseInvocation):
|
|||||||
|
|
||||||
|
|
||||||
@invocation_output("sdxl_lora_loader_output")
|
@invocation_output("sdxl_lora_loader_output")
|
||||||
class SDXLLoRALoaderOutput(BaseInvocationOutput):
|
class SDXLLoraLoaderOutput(BaseInvocationOutput):
|
||||||
"""SDXL LoRA Loader Output"""
|
"""SDXL LoRA Loader Output"""
|
||||||
|
|
||||||
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
|
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
|
||||||
clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 1")
|
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 1")
|
||||||
clip2: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 2")
|
clip2: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 2")
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
@ -206,12 +222,10 @@ class SDXLLoRALoaderOutput(BaseInvocationOutput):
|
|||||||
category="model",
|
category="model",
|
||||||
version="1.0.1",
|
version="1.0.1",
|
||||||
)
|
)
|
||||||
class SDXLLoRALoaderInvocation(BaseInvocation):
|
class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||||
"""Apply selected lora to unet and text_encoder."""
|
"""Apply selected lora to unet and text_encoder."""
|
||||||
|
|
||||||
lora: ModelIdentifierField = InputField(
|
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
|
||||||
description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel
|
|
||||||
)
|
|
||||||
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
||||||
unet: Optional[UNetField] = InputField(
|
unet: Optional[UNetField] = InputField(
|
||||||
default=None,
|
default=None,
|
||||||
@ -219,59 +233,65 @@ class SDXLLoRALoaderInvocation(BaseInvocation):
|
|||||||
input=Input.Connection,
|
input=Input.Connection,
|
||||||
title="UNet",
|
title="UNet",
|
||||||
)
|
)
|
||||||
clip: Optional[CLIPField] = InputField(
|
clip: Optional[ClipField] = InputField(
|
||||||
default=None,
|
default=None,
|
||||||
description=FieldDescriptions.clip,
|
description=FieldDescriptions.clip,
|
||||||
input=Input.Connection,
|
input=Input.Connection,
|
||||||
title="CLIP 1",
|
title="CLIP 1",
|
||||||
)
|
)
|
||||||
clip2: Optional[CLIPField] = InputField(
|
clip2: Optional[ClipField] = InputField(
|
||||||
default=None,
|
default=None,
|
||||||
description=FieldDescriptions.clip,
|
description=FieldDescriptions.clip,
|
||||||
input=Input.Connection,
|
input=Input.Connection,
|
||||||
title="CLIP 2",
|
title="CLIP 2",
|
||||||
)
|
)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> SDXLLoRALoaderOutput:
|
def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput:
|
||||||
|
if self.lora is None:
|
||||||
|
raise Exception("No LoRA provided")
|
||||||
|
|
||||||
lora_key = self.lora.key
|
lora_key = self.lora.key
|
||||||
|
|
||||||
if not context.models.exists(lora_key):
|
if not context.models.exists(lora_key):
|
||||||
raise Exception(f"Unknown lora: {lora_key}!")
|
raise Exception(f"Unknown lora: {lora_key}!")
|
||||||
|
|
||||||
if self.unet is not None and any(lora.lora.key == lora_key for lora in self.unet.loras):
|
if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
|
||||||
raise Exception(f'LoRA "{lora_key}" already applied to unet')
|
raise Exception(f'Lora "{lora_key}" already applied to unet')
|
||||||
|
|
||||||
if self.clip is not None and any(lora.lora.key == lora_key for lora in self.clip.loras):
|
if self.clip is not None and any(lora.key == lora_key for lora in self.clip.loras):
|
||||||
raise Exception(f'LoRA "{lora_key}" already applied to clip')
|
raise Exception(f'Lora "{lora_key}" already applied to clip')
|
||||||
|
|
||||||
if self.clip2 is not None and any(lora.lora.key == lora_key for lora in self.clip2.loras):
|
if self.clip2 is not None and any(lora.key == lora_key for lora in self.clip2.loras):
|
||||||
raise Exception(f'LoRA "{lora_key}" already applied to clip2')
|
raise Exception(f'Lora "{lora_key}" already applied to clip2')
|
||||||
|
|
||||||
output = SDXLLoRALoaderOutput()
|
output = SDXLLoraLoaderOutput()
|
||||||
|
|
||||||
if self.unet is not None:
|
if self.unet is not None:
|
||||||
output.unet = self.unet.model_copy(deep=True)
|
output.unet = copy.deepcopy(self.unet)
|
||||||
output.unet.loras.append(
|
output.unet.loras.append(
|
||||||
LoRAField(
|
LoraInfo(
|
||||||
lora=self.lora,
|
key=lora_key,
|
||||||
|
submodel_type=None,
|
||||||
weight=self.weight,
|
weight=self.weight,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.clip is not None:
|
if self.clip is not None:
|
||||||
output.clip = self.clip.model_copy(deep=True)
|
output.clip = copy.deepcopy(self.clip)
|
||||||
output.clip.loras.append(
|
output.clip.loras.append(
|
||||||
LoRAField(
|
LoraInfo(
|
||||||
lora=self.lora,
|
key=lora_key,
|
||||||
|
submodel_type=None,
|
||||||
weight=self.weight,
|
weight=self.weight,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.clip2 is not None:
|
if self.clip2 is not None:
|
||||||
output.clip2 = self.clip2.model_copy(deep=True)
|
output.clip2 = copy.deepcopy(self.clip2)
|
||||||
output.clip2.loras.append(
|
output.clip2.loras.append(
|
||||||
LoRAField(
|
LoraInfo(
|
||||||
lora=self.lora,
|
key=lora_key,
|
||||||
|
submodel_type=None,
|
||||||
weight=self.weight,
|
weight=self.weight,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -279,12 +299,20 @@ class SDXLLoRALoaderInvocation(BaseInvocation):
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class VAEModelField(BaseModel):
|
||||||
|
"""Vae model field"""
|
||||||
|
|
||||||
|
key: str = Field(description="Model's key")
|
||||||
|
|
||||||
|
|
||||||
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.1")
|
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.1")
|
||||||
class VAELoaderInvocation(BaseInvocation):
|
class VaeLoaderInvocation(BaseInvocation):
|
||||||
"""Loads a VAE model, outputting a VaeLoaderOutput"""
|
"""Loads a VAE model, outputting a VaeLoaderOutput"""
|
||||||
|
|
||||||
vae_model: ModelIdentifierField = InputField(
|
vae_model: VAEModelField = InputField(
|
||||||
description=FieldDescriptions.vae_model, input=Input.Direct, title="VAE", ui_type=UIType.VAEModel
|
description=FieldDescriptions.vae_model,
|
||||||
|
input=Input.Direct,
|
||||||
|
title="VAE",
|
||||||
)
|
)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> VAEOutput:
|
def invoke(self, context: InvocationContext) -> VAEOutput:
|
||||||
@ -293,7 +321,7 @@ class VAELoaderInvocation(BaseInvocation):
|
|||||||
if not context.models.exists(key):
|
if not context.models.exists(key):
|
||||||
raise Exception(f"Unkown vae: {key}!")
|
raise Exception(f"Unkown vae: {key}!")
|
||||||
|
|
||||||
return VAEOutput(vae=VAEField(vae=self.vae_model))
|
return VAEOutput(vae=VaeField(vae=ModelInfo(key=key)))
|
||||||
|
|
||||||
|
|
||||||
@invocation_output("seamless_output")
|
@invocation_output("seamless_output")
|
||||||
@ -301,7 +329,7 @@ class SeamlessModeOutput(BaseInvocationOutput):
|
|||||||
"""Modified Seamless Model output"""
|
"""Modified Seamless Model output"""
|
||||||
|
|
||||||
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
|
unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
|
||||||
vae: Optional[VAEField] = OutputField(default=None, description=FieldDescriptions.vae, title="VAE")
|
vae: Optional[VaeField] = OutputField(default=None, description=FieldDescriptions.vae, title="VAE")
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
@ -320,7 +348,7 @@ class SeamlessModeInvocation(BaseInvocation):
|
|||||||
input=Input.Connection,
|
input=Input.Connection,
|
||||||
title="UNet",
|
title="UNet",
|
||||||
)
|
)
|
||||||
vae: Optional[VAEField] = InputField(
|
vae: Optional[VaeField] = InputField(
|
||||||
default=None,
|
default=None,
|
||||||
description=FieldDescriptions.vae_model,
|
description=FieldDescriptions.vae_model,
|
||||||
input=Input.Connection,
|
input=Input.Connection,
|
||||||
|
@ -8,7 +8,7 @@ from .baseinvocation import (
|
|||||||
invocation,
|
invocation,
|
||||||
invocation_output,
|
invocation_output,
|
||||||
)
|
)
|
||||||
from .model import CLIPField, ModelIdentifierField, UNetField, VAEField
|
from .model import ClipField, MainModelField, ModelInfo, UNetField, VaeField
|
||||||
|
|
||||||
|
|
||||||
@invocation_output("sdxl_model_loader_output")
|
@invocation_output("sdxl_model_loader_output")
|
||||||
@ -16,9 +16,9 @@ class SDXLModelLoaderOutput(BaseInvocationOutput):
|
|||||||
"""SDXL base model loader output"""
|
"""SDXL base model loader output"""
|
||||||
|
|
||||||
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
|
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
|
||||||
clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 1")
|
clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 1")
|
||||||
clip2: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
|
clip2: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
|
||||||
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||||
|
|
||||||
|
|
||||||
@invocation_output("sdxl_refiner_model_loader_output")
|
@invocation_output("sdxl_refiner_model_loader_output")
|
||||||
@ -26,15 +26,15 @@ class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
|
|||||||
"""SDXL refiner model loader output"""
|
"""SDXL refiner model loader output"""
|
||||||
|
|
||||||
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
|
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
|
||||||
clip2: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
|
clip2: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
|
||||||
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||||
|
|
||||||
|
|
||||||
@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.1")
|
@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.1")
|
||||||
class SDXLModelLoaderInvocation(BaseInvocation):
|
class SDXLModelLoaderInvocation(BaseInvocation):
|
||||||
"""Loads an sdxl base model, outputting its submodels."""
|
"""Loads an sdxl base model, outputting its submodels."""
|
||||||
|
|
||||||
model: ModelIdentifierField = InputField(
|
model: MainModelField = InputField(
|
||||||
description=FieldDescriptions.sdxl_main_model, input=Input.Direct, ui_type=UIType.SDXLMainModel
|
description=FieldDescriptions.sdxl_main_model, input=Input.Direct, ui_type=UIType.SDXLMainModel
|
||||||
)
|
)
|
||||||
# TODO: precision?
|
# TODO: precision?
|
||||||
@ -46,19 +46,48 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
|||||||
if not context.models.exists(model_key):
|
if not context.models.exists(model_key):
|
||||||
raise Exception(f"Unknown model: {model_key}")
|
raise Exception(f"Unknown model: {model_key}")
|
||||||
|
|
||||||
unet = self.model.model_copy(update={"submodel_type": SubModelType.UNet})
|
|
||||||
scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
|
|
||||||
tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
|
|
||||||
text_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
|
|
||||||
tokenizer2 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
|
|
||||||
text_encoder2 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
|
|
||||||
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
|
|
||||||
|
|
||||||
return SDXLModelLoaderOutput(
|
return SDXLModelLoaderOutput(
|
||||||
unet=UNetField(unet=unet, scheduler=scheduler, loras=[]),
|
unet=UNetField(
|
||||||
clip=CLIPField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[], skipped_layers=0),
|
unet=ModelInfo(
|
||||||
clip2=CLIPField(tokenizer=tokenizer2, text_encoder=text_encoder2, loras=[], skipped_layers=0),
|
key=model_key,
|
||||||
vae=VAEField(vae=vae),
|
submodel_type=SubModelType.UNet,
|
||||||
|
),
|
||||||
|
scheduler=ModelInfo(
|
||||||
|
key=model_key,
|
||||||
|
submodel_type=SubModelType.Scheduler,
|
||||||
|
),
|
||||||
|
loras=[],
|
||||||
|
),
|
||||||
|
clip=ClipField(
|
||||||
|
tokenizer=ModelInfo(
|
||||||
|
key=model_key,
|
||||||
|
submodel_type=SubModelType.Tokenizer,
|
||||||
|
),
|
||||||
|
text_encoder=ModelInfo(
|
||||||
|
key=model_key,
|
||||||
|
submodel_type=SubModelType.TextEncoder,
|
||||||
|
),
|
||||||
|
loras=[],
|
||||||
|
skipped_layers=0,
|
||||||
|
),
|
||||||
|
clip2=ClipField(
|
||||||
|
tokenizer=ModelInfo(
|
||||||
|
key=model_key,
|
||||||
|
submodel_type=SubModelType.Tokenizer2,
|
||||||
|
),
|
||||||
|
text_encoder=ModelInfo(
|
||||||
|
key=model_key,
|
||||||
|
submodel_type=SubModelType.TextEncoder2,
|
||||||
|
),
|
||||||
|
loras=[],
|
||||||
|
skipped_layers=0,
|
||||||
|
),
|
||||||
|
vae=VaeField(
|
||||||
|
vae=ModelInfo(
|
||||||
|
key=model_key,
|
||||||
|
submodel_type=SubModelType.VAE,
|
||||||
|
),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -72,8 +101,10 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
|||||||
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
||||||
"""Loads an sdxl refiner model, outputting its submodels."""
|
"""Loads an sdxl refiner model, outputting its submodels."""
|
||||||
|
|
||||||
model: ModelIdentifierField = InputField(
|
model: MainModelField = InputField(
|
||||||
description=FieldDescriptions.sdxl_refiner_model, input=Input.Direct, ui_type=UIType.SDXLRefinerModel
|
description=FieldDescriptions.sdxl_refiner_model,
|
||||||
|
input=Input.Direct,
|
||||||
|
ui_type=UIType.SDXLRefinerModel,
|
||||||
)
|
)
|
||||||
# TODO: precision?
|
# TODO: precision?
|
||||||
|
|
||||||
@ -84,14 +115,34 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
|||||||
if not context.models.exists(model_key):
|
if not context.models.exists(model_key):
|
||||||
raise Exception(f"Unknown model: {model_key}")
|
raise Exception(f"Unknown model: {model_key}")
|
||||||
|
|
||||||
unet = self.model.model_copy(update={"submodel_type": SubModelType.UNet})
|
|
||||||
scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
|
|
||||||
tokenizer2 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
|
|
||||||
text_encoder2 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
|
|
||||||
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
|
|
||||||
|
|
||||||
return SDXLRefinerModelLoaderOutput(
|
return SDXLRefinerModelLoaderOutput(
|
||||||
unet=UNetField(unet=unet, scheduler=scheduler, loras=[]),
|
unet=UNetField(
|
||||||
clip2=CLIPField(tokenizer=tokenizer2, text_encoder=text_encoder2, loras=[], skipped_layers=0),
|
unet=ModelInfo(
|
||||||
vae=VAEField(vae=vae),
|
key=model_key,
|
||||||
|
submodel_type=SubModelType.UNet,
|
||||||
|
),
|
||||||
|
scheduler=ModelInfo(
|
||||||
|
key=model_key,
|
||||||
|
submodel_type=SubModelType.Scheduler,
|
||||||
|
),
|
||||||
|
loras=[],
|
||||||
|
),
|
||||||
|
clip2=ClipField(
|
||||||
|
tokenizer=ModelInfo(
|
||||||
|
key=model_key,
|
||||||
|
submodel_type=SubModelType.Tokenizer2,
|
||||||
|
),
|
||||||
|
text_encoder=ModelInfo(
|
||||||
|
key=model_key,
|
||||||
|
submodel_type=SubModelType.TextEncoder2,
|
||||||
|
),
|
||||||
|
loras=[],
|
||||||
|
skipped_layers=0,
|
||||||
|
),
|
||||||
|
vae=VaeField(
|
||||||
|
vae=ModelInfo(
|
||||||
|
key=model_key,
|
||||||
|
submodel_type=SubModelType.VAE,
|
||||||
|
),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
@ -9,15 +9,18 @@ from invokeai.app.invocations.baseinvocation import (
|
|||||||
invocation_output,
|
invocation_output,
|
||||||
)
|
)
|
||||||
from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES
|
from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES
|
||||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField, UIType
|
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField
|
||||||
from invokeai.app.invocations.model import ModelIdentifierField
|
|
||||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
|
|
||||||
|
|
||||||
|
class T2IAdapterModelField(BaseModel):
|
||||||
|
key: str = Field(description="Model record key for the T2I-Adapter model")
|
||||||
|
|
||||||
|
|
||||||
class T2IAdapterField(BaseModel):
|
class T2IAdapterField(BaseModel):
|
||||||
image: ImageField = Field(description="The T2I-Adapter image prompt.")
|
image: ImageField = Field(description="The T2I-Adapter image prompt.")
|
||||||
t2i_adapter_model: ModelIdentifierField = Field(description="The T2I-Adapter model to use.")
|
t2i_adapter_model: T2IAdapterModelField = Field(description="The T2I-Adapter model to use.")
|
||||||
weight: Union[float, list[float]] = Field(default=1, description="The weight given to the T2I-Adapter")
|
weight: Union[float, list[float]] = Field(default=1, description="The weight given to the T2I-Adapter")
|
||||||
begin_step_percent: float = Field(
|
begin_step_percent: float = Field(
|
||||||
default=0, ge=0, le=1, description="When the T2I-Adapter is first applied (% of total steps)"
|
default=0, ge=0, le=1, description="When the T2I-Adapter is first applied (% of total steps)"
|
||||||
@ -52,12 +55,11 @@ class T2IAdapterInvocation(BaseInvocation):
|
|||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: ImageField = InputField(description="The IP-Adapter image prompt.")
|
image: ImageField = InputField(description="The IP-Adapter image prompt.")
|
||||||
t2i_adapter_model: ModelIdentifierField = InputField(
|
t2i_adapter_model: T2IAdapterModelField = InputField(
|
||||||
description="The T2I-Adapter model.",
|
description="The T2I-Adapter model.",
|
||||||
title="T2I-Adapter Model",
|
title="T2I-Adapter Model",
|
||||||
input=Input.Direct,
|
input=Input.Direct,
|
||||||
ui_order=-1,
|
ui_order=-1,
|
||||||
ui_type=UIType.T2IAdapterModel,
|
|
||||||
)
|
)
|
||||||
weight: Union[float, list[float]] = InputField(
|
weight: Union[float, list[float]] = InputField(
|
||||||
default=1, ge=0, description="The weight given to the T2I-Adapter", title="Weight"
|
default=1, ge=0, description="The weight given to the T2I-Adapter", title="Weight"
|
||||||
|
@ -17,8 +17,7 @@ from argparse import ArgumentParser
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, ClassVar, Dict, List, Literal, Optional, Union, get_args, get_origin, get_type_hints
|
from typing import Any, ClassVar, Dict, List, Literal, Optional, Union, get_args, get_origin, get_type_hints
|
||||||
|
|
||||||
from omegaconf import DictConfig, DictKeyType, ListConfig, OmegaConf
|
from omegaconf import DictConfig, ListConfig, OmegaConf
|
||||||
from pydantic import BaseModel
|
|
||||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
|
||||||
from invokeai.app.services.config.config_common import PagingArgumentParser, int_or_float_or_str
|
from invokeai.app.services.config.config_common import PagingArgumentParser, int_or_float_or_str
|
||||||
@ -63,22 +62,6 @@ class InvokeAISettings(BaseSettings):
|
|||||||
assert isinstance(category, str)
|
assert isinstance(category, str)
|
||||||
if category not in field_dict[type]:
|
if category not in field_dict[type]:
|
||||||
field_dict[type][category] = {}
|
field_dict[type][category] = {}
|
||||||
if isinstance(value, BaseModel):
|
|
||||||
dump = value.model_dump(exclude_defaults=True, exclude_unset=True, exclude_none=True)
|
|
||||||
field_dict[type][category][name] = dump
|
|
||||||
continue
|
|
||||||
if isinstance(value, list):
|
|
||||||
if not value or len(value) == 0:
|
|
||||||
continue
|
|
||||||
primitive = isinstance(value[0], get_args(DictKeyType))
|
|
||||||
if not primitive:
|
|
||||||
val_list: List[Dict[str, Any]] = []
|
|
||||||
for list_val in value:
|
|
||||||
if isinstance(list_val, BaseModel):
|
|
||||||
dump = list_val.model_dump(exclude_defaults=True, exclude_unset=True, exclude_none=True)
|
|
||||||
val_list.append(dump)
|
|
||||||
field_dict[type][category][name] = val_list
|
|
||||||
continue
|
|
||||||
# keep paths as strings to make it easier to read
|
# keep paths as strings to make it easier to read
|
||||||
field_dict[type][category][name] = str(value) if isinstance(value, Path) else value
|
field_dict[type][category][name] = str(value) if isinstance(value, Path) else value
|
||||||
conf = OmegaConf.create(field_dict)
|
conf = OmegaConf.create(field_dict)
|
||||||
@ -152,7 +135,7 @@ class InvokeAISettings(BaseSettings):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def _excluded(cls) -> List[str]:
|
def _excluded(cls) -> List[str]:
|
||||||
# internal fields that shouldn't be exposed as command line options
|
# internal fields that shouldn't be exposed as command line options
|
||||||
return ["type", "initconf", "remote_api_tokens"]
|
return ["type", "initconf"]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _excluded_from_yaml(cls) -> List[str]:
|
def _excluded_from_yaml(cls) -> List[str]:
|
||||||
|
@ -170,12 +170,11 @@ two configs are kept in separate sections of the config file:
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import re
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, ClassVar, Dict, List, Literal, Optional
|
from typing import Any, ClassVar, Dict, List, Literal, Optional
|
||||||
|
|
||||||
from omegaconf import DictConfig, OmegaConf
|
from omegaconf import DictConfig, OmegaConf
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import Field
|
||||||
from pydantic.config import JsonDict
|
from pydantic.config import JsonDict
|
||||||
from pydantic_settings import SettingsConfigDict
|
from pydantic_settings import SettingsConfigDict
|
||||||
|
|
||||||
@ -197,87 +196,17 @@ class Categories(object):
|
|||||||
Paths: JsonDict = {"category": "Paths"}
|
Paths: JsonDict = {"category": "Paths"}
|
||||||
Logging: JsonDict = {"category": "Logging"}
|
Logging: JsonDict = {"category": "Logging"}
|
||||||
Development: JsonDict = {"category": "Development"}
|
Development: JsonDict = {"category": "Development"}
|
||||||
CLIArgs: JsonDict = {"category": "CLIArgs"}
|
Other: JsonDict = {"category": "Other"}
|
||||||
ModelInstall: JsonDict = {"category": "Model Install"}
|
|
||||||
ModelCache: JsonDict = {"category": "Model Cache"}
|
ModelCache: JsonDict = {"category": "Model Cache"}
|
||||||
Device: JsonDict = {"category": "Device"}
|
Device: JsonDict = {"category": "Device"}
|
||||||
Generation: JsonDict = {"category": "Generation"}
|
Generation: JsonDict = {"category": "Generation"}
|
||||||
Queue: JsonDict = {"category": "Queue"}
|
Queue: JsonDict = {"category": "Queue"}
|
||||||
Nodes: JsonDict = {"category": "Nodes"}
|
Nodes: JsonDict = {"category": "Nodes"}
|
||||||
MemoryPerformance: JsonDict = {"category": "Memory/Performance"}
|
MemoryPerformance: JsonDict = {"category": "Memory/Performance"}
|
||||||
Deprecated: JsonDict = {"category": "Deprecated"}
|
|
||||||
|
|
||||||
|
|
||||||
class URLRegexToken(BaseModel):
|
|
||||||
url_regex: str = Field(description="Regular expression to match against the URL")
|
|
||||||
token: str = Field(description="Token to use when the URL matches the regex")
|
|
||||||
|
|
||||||
@field_validator("url_regex")
|
|
||||||
@classmethod
|
|
||||||
def validate_url_regex(cls, v: str) -> str:
|
|
||||||
"""Validate that the value is a valid regex."""
|
|
||||||
try:
|
|
||||||
re.compile(v)
|
|
||||||
except re.error as e:
|
|
||||||
raise ValueError(f"Invalid regex: {e}")
|
|
||||||
return v
|
|
||||||
|
|
||||||
|
|
||||||
class InvokeAIAppConfig(InvokeAISettings):
|
class InvokeAIAppConfig(InvokeAISettings):
|
||||||
"""Invoke App Configuration
|
"""Configuration object for InvokeAI App."""
|
||||||
|
|
||||||
Attributes:
|
|
||||||
host: **Web Server**: IP address to bind to. Use `0.0.0.0` to serve to your local network.
|
|
||||||
port: **Web Server**: Port to bind to.
|
|
||||||
allow_origins: **Web Server**: Allowed CORS origins.
|
|
||||||
allow_credentials: **Web Server**: Allow CORS credentials.
|
|
||||||
allow_methods: **Web Server**: Methods allowed for CORS.
|
|
||||||
allow_headers: **Web Server**: Headers allowed for CORS.
|
|
||||||
ssl_certfile: **Web Server**: SSL certificate file for HTTPS.
|
|
||||||
ssl_keyfile: **Web Server**: SSL key file for HTTPS.
|
|
||||||
esrgan: **Features**: Enables or disables the upscaling code.
|
|
||||||
internet_available: **Features**: If true, attempt to download models on the fly; otherwise only use local models.
|
|
||||||
log_tokenization: **Features**: Enable logging of parsed prompt tokens.
|
|
||||||
patchmatch: **Features**: Enable patchmatch inpaint code.
|
|
||||||
ignore_missing_core_models: **Features**: Ignore missing core models on startup. If `True`, the app will attempt to download missing models on startup.
|
|
||||||
root: **Paths**: The InvokeAI runtime root directory.
|
|
||||||
autoimport_dir: **Paths**: Path to a directory of models files to be imported on startup.
|
|
||||||
models_dir: **Paths**: Path to the models directory.
|
|
||||||
convert_cache_dir: **Paths**: Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and store on disk at this location.
|
|
||||||
legacy_conf_dir: **Paths**: Path to directory of legacy checkpoint config files.
|
|
||||||
db_dir: **Paths**: Path to InvokeAI databases directory.
|
|
||||||
outdir: **Paths**: Path to directory for outputs.
|
|
||||||
custom_nodes_dir: **Paths**: Path to directory for custom nodes.
|
|
||||||
from_file: **Paths**: Take command input from the indicated file (command-line client only).
|
|
||||||
log_handlers: **Logging**: Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>".
|
|
||||||
log_format: **Logging**: Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style.
|
|
||||||
log_level: **Logging**: Emit logging messages at this level or higher.
|
|
||||||
log_sql: **Logging**: Log SQL queries. `log_level` must be `debug` for this to do anything. Extremely verbose.
|
|
||||||
use_memory_db: **Development**: Use in-memory database. Useful for development.
|
|
||||||
dev_reload: **Development**: Automatically reload when Python sources are changed. Does not reload node definitions.
|
|
||||||
profile_graphs: **Development**: Enable graph profiling using `cProfile`.
|
|
||||||
profile_prefix: **Development**: An optional prefix for profile output files.
|
|
||||||
profiles_dir: **Development**: Path to profiles output directory.
|
|
||||||
version: **CLIArgs**: CLI arg - show InvokeAI version and exit.
|
|
||||||
skip_model_hash: **Model Install**: Skip model hashing, instead assigning a UUID to models. Useful when using a memory db to reduce model installation time, or if you don't care about storing stable hashes for models.
|
|
||||||
remote_api_tokens: **Model Install**: List of regular expression and token pairs used when downloading models from URLs. The download URL is tested against the regex, and if it matches, the token is provided in as a Bearer token.
|
|
||||||
ram: **Model Cache**: Maximum memory amount used by memory model cache for rapid switching (GB).
|
|
||||||
vram: **Model Cache**: Amount of VRAM reserved for model storage (GB)
|
|
||||||
convert_cache: **Model Cache**: Maximum size of on-disk converted models cache (GB)
|
|
||||||
lazy_offload: **Model Cache**: Keep models in VRAM until their space is needed.
|
|
||||||
log_memory_usage: **Model Cache**: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
|
|
||||||
device: **Device**: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.
|
|
||||||
precision: **Device**: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.
|
|
||||||
sequential_guidance: **Generation**: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
|
|
||||||
attention_type: **Generation**: Attention type.
|
|
||||||
attention_slice_size: **Generation**: Slice size, valid when attention_type=="sliced".
|
|
||||||
force_tiled_decode: **Generation**: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).
|
|
||||||
png_compress_level: **Generation**: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.
|
|
||||||
max_queue_size: **Queue**: Maximum number of items in the session queue.
|
|
||||||
allow_nodes: **Nodes**: List of nodes to allow. Omit to allow all.
|
|
||||||
deny_nodes: **Nodes**: List of nodes to deny. Omit to deny none.
|
|
||||||
node_cache_size: **Nodes**: How many cached nodes to keep in memory.
|
|
||||||
"""
|
|
||||||
|
|
||||||
singleton_config: ClassVar[Optional[InvokeAIAppConfig]] = None
|
singleton_config: ClassVar[Optional[InvokeAIAppConfig]] = None
|
||||||
singleton_init: ClassVar[Optional[Dict[str, Any]]] = None
|
singleton_init: ClassVar[Optional[Dict[str, Any]]] = None
|
||||||
@ -286,98 +215,91 @@ class InvokeAIAppConfig(InvokeAISettings):
|
|||||||
type: Literal["InvokeAI"] = "InvokeAI"
|
type: Literal["InvokeAI"] = "InvokeAI"
|
||||||
|
|
||||||
# WEB
|
# WEB
|
||||||
host : str = Field(default="127.0.0.1", description="IP address to bind to. Use `0.0.0.0` to serve to your local network.", json_schema_extra=Categories.WebServer)
|
host : str = Field(default="127.0.0.1", description="IP address to bind to", json_schema_extra=Categories.WebServer)
|
||||||
port : int = Field(default=9090, description="Port to bind to.", json_schema_extra=Categories.WebServer)
|
port : int = Field(default=9090, description="Port to bind to", json_schema_extra=Categories.WebServer)
|
||||||
allow_origins : List[str] = Field(default=[], description="Allowed CORS origins.", json_schema_extra=Categories.WebServer)
|
allow_origins : List[str] = Field(default=[], description="Allowed CORS origins", json_schema_extra=Categories.WebServer)
|
||||||
allow_credentials : bool = Field(default=True, description="Allow CORS credentials.", json_schema_extra=Categories.WebServer)
|
allow_credentials : bool = Field(default=True, description="Allow CORS credentials", json_schema_extra=Categories.WebServer)
|
||||||
allow_methods : List[str] = Field(default=["*"], description="Methods allowed for CORS.", json_schema_extra=Categories.WebServer)
|
allow_methods : List[str] = Field(default=["*"], description="Methods allowed for CORS", json_schema_extra=Categories.WebServer)
|
||||||
allow_headers : List[str] = Field(default=["*"], description="Headers allowed for CORS.", json_schema_extra=Categories.WebServer)
|
allow_headers : List[str] = Field(default=["*"], description="Headers allowed for CORS", json_schema_extra=Categories.WebServer)
|
||||||
# SSL options correspond to https://www.uvicorn.org/settings/#https
|
# SSL options correspond to https://www.uvicorn.org/settings/#https
|
||||||
ssl_certfile : Optional[Path] = Field(default=None, description="SSL certificate file for HTTPS.", json_schema_extra=Categories.WebServer)
|
ssl_certfile : Optional[Path] = Field(default=None, description="SSL certificate file (for HTTPS)", json_schema_extra=Categories.WebServer)
|
||||||
ssl_keyfile : Optional[Path] = Field(default=None, description="SSL key file for HTTPS.", json_schema_extra=Categories.WebServer)
|
ssl_keyfile : Optional[Path] = Field(default=None, description="SSL key file", json_schema_extra=Categories.WebServer)
|
||||||
|
|
||||||
# FEATURES
|
# FEATURES
|
||||||
esrgan : bool = Field(default=True, description="Enables or disables the upscaling code.", json_schema_extra=Categories.Features)
|
esrgan : bool = Field(default=True, description="Enable/disable upscaling code", json_schema_extra=Categories.Features)
|
||||||
# TODO(psyche): This is not used anywhere.
|
internet_available : bool = Field(default=True, description="If true, attempt to download models on the fly; otherwise only use local models", json_schema_extra=Categories.Features)
|
||||||
internet_available : bool = Field(default=True, description="If true, attempt to download models on the fly; otherwise only use local models.", json_schema_extra=Categories.Features)
|
|
||||||
log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", json_schema_extra=Categories.Features)
|
log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", json_schema_extra=Categories.Features)
|
||||||
patchmatch : bool = Field(default=True, description="Enable patchmatch inpaint code.", json_schema_extra=Categories.Features)
|
patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", json_schema_extra=Categories.Features)
|
||||||
ignore_missing_core_models : bool = Field(default=False, description='Ignore missing core models on startup. If `True`, the app will attempt to download missing models on startup.', json_schema_extra=Categories.Features)
|
ignore_missing_core_models : bool = Field(default=False, description='Ignore missing models in models/core/convert', json_schema_extra=Categories.Features)
|
||||||
|
|
||||||
# PATHS
|
# PATHS
|
||||||
root : Optional[Path] = Field(default=None, description='The InvokeAI runtime root directory.', json_schema_extra=Categories.Paths)
|
root : Optional[Path] = Field(default=None, description='InvokeAI runtime root directory', json_schema_extra=Categories.Paths)
|
||||||
autoimport_dir : Path = Field(default=Path('autoimport'), description='Path to a directory of models files to be imported on startup.', json_schema_extra=Categories.Paths)
|
autoimport_dir : Path = Field(default=Path('autoimport'), description='Path to a directory of models files to be imported on startup.', json_schema_extra=Categories.Paths)
|
||||||
models_dir : Path = Field(default=Path('models'), description='Path to the models directory.', json_schema_extra=Categories.Paths)
|
models_dir : Path = Field(default=Path('models'), description='Path to the models directory', json_schema_extra=Categories.Paths)
|
||||||
convert_cache_dir : Path = Field(default=Path('models/.cache'), description='Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and store on disk at this location.', json_schema_extra=Categories.Paths)
|
convert_cache_dir : Path = Field(default=Path('models/.cache'), description='Path to the converted models cache directory', json_schema_extra=Categories.Paths)
|
||||||
legacy_conf_dir : Path = Field(default=Path('configs/stable-diffusion'), description='Path to directory of legacy checkpoint config files.', json_schema_extra=Categories.Paths)
|
legacy_conf_dir : Path = Field(default=Path('configs/stable-diffusion'), description='Path to directory of legacy checkpoint config files', json_schema_extra=Categories.Paths)
|
||||||
db_dir : Path = Field(default=Path('databases'), description='Path to InvokeAI databases directory.', json_schema_extra=Categories.Paths)
|
db_dir : Path = Field(default=Path('databases'), description='Path to InvokeAI databases directory', json_schema_extra=Categories.Paths)
|
||||||
outdir : Path = Field(default=Path('outputs'), description='Path to directory for outputs.', json_schema_extra=Categories.Paths)
|
outdir : Path = Field(default=Path('outputs'), description='Default folder for output images', json_schema_extra=Categories.Paths)
|
||||||
custom_nodes_dir : Path = Field(default=Path('nodes'), description='Path to directory for custom nodes.', json_schema_extra=Categories.Paths)
|
use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', json_schema_extra=Categories.Paths)
|
||||||
# TODO(psyche): This is not used anywhere.
|
custom_nodes_dir : Path = Field(default=Path('nodes'), description='Path to directory for custom nodes', json_schema_extra=Categories.Paths)
|
||||||
from_file : Optional[Path] = Field(default=None, description='Take command input from the indicated file (command-line client only).', json_schema_extra=Categories.Paths)
|
from_file : Optional[Path] = Field(default=None, description='Take command input from the indicated file (command-line client only)', json_schema_extra=Categories.Paths)
|
||||||
|
|
||||||
# LOGGING
|
# LOGGING
|
||||||
log_handlers : List[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>".', json_schema_extra=Categories.Logging)
|
log_handlers : List[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>"', json_schema_extra=Categories.Logging)
|
||||||
# note - would be better to read the log_format values from logging.py, but this creates circular dependencies issues
|
# note - would be better to read the log_format values from logging.py, but this creates circular dependencies issues
|
||||||
log_format : Literal['plain', 'color', 'syslog', 'legacy'] = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style.', json_schema_extra=Categories.Logging)
|
log_format : Literal['plain', 'color', 'syslog', 'legacy'] = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style', json_schema_extra=Categories.Logging)
|
||||||
log_level : Literal["debug", "info", "warning", "error", "critical"] = Field(default="info", description="Emit logging messages at this level or higher.", json_schema_extra=Categories.Logging)
|
log_level : Literal["debug", "info", "warning", "error", "critical"] = Field(default="info", description="Emit logging messages at this level or higher", json_schema_extra=Categories.Logging)
|
||||||
log_sql : bool = Field(default=False, description="Log SQL queries. `log_level` must be `debug` for this to do anything. Extremely verbose.", json_schema_extra=Categories.Logging)
|
log_sql : bool = Field(default=False, description="Log SQL queries", json_schema_extra=Categories.Logging)
|
||||||
|
|
||||||
# Development
|
# Development
|
||||||
use_memory_db : bool = Field(default=False, description='Use in-memory database. Useful for development.', json_schema_extra=Categories.Development)
|
dev_reload : bool = Field(default=False, description="Automatically reload when Python sources are changed.", json_schema_extra=Categories.Development)
|
||||||
dev_reload : bool = Field(default=False, description="Automatically reload when Python sources are changed. Does not reload node definitions.", json_schema_extra=Categories.Development)
|
profile_graphs : bool = Field(default=False, description="Enable graph profiling", json_schema_extra=Categories.Development)
|
||||||
profile_graphs : bool = Field(default=False, description="Enable graph profiling using `cProfile`.", json_schema_extra=Categories.Development)
|
|
||||||
profile_prefix : Optional[str] = Field(default=None, description="An optional prefix for profile output files.", json_schema_extra=Categories.Development)
|
profile_prefix : Optional[str] = Field(default=None, description="An optional prefix for profile output files.", json_schema_extra=Categories.Development)
|
||||||
profiles_dir : Path = Field(default=Path('profiles'), description="Path to profiles output directory.", json_schema_extra=Categories.Development)
|
profiles_dir : Path = Field(default=Path('profiles'), description="Directory for graph profiles", json_schema_extra=Categories.Development)
|
||||||
|
skip_model_hash : bool = Field(default=False, description="Skip model hashing, instead assigning a UUID to models. Useful when using a memory db to reduce startup time.", json_schema_extra=Categories.Development)
|
||||||
|
|
||||||
version : bool = Field(default=False, description="CLI arg - show InvokeAI version and exit.", json_schema_extra=Categories.CLIArgs)
|
version : bool = Field(default=False, description="Show InvokeAI version and exit", json_schema_extra=Categories.Other)
|
||||||
|
|
||||||
# CACHE
|
# CACHE
|
||||||
ram : float = Field(default=DEFAULT_RAM_CACHE, gt=0, description="Maximum memory amount used by memory model cache for rapid switching (GB).", json_schema_extra=Categories.ModelCache, )
|
ram : float = Field(default=DEFAULT_RAM_CACHE, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number, GB)", json_schema_extra=Categories.ModelCache, )
|
||||||
vram : float = Field(default=DEFAULT_VRAM_CACHE, ge=0, description="Amount of VRAM reserved for model storage (GB)", json_schema_extra=Categories.ModelCache, )
|
vram : float = Field(default=DEFAULT_VRAM_CACHE, ge=0, description="Amount of VRAM reserved for model storage (floating point number, GB)", json_schema_extra=Categories.ModelCache, )
|
||||||
convert_cache : float = Field(default=DEFAULT_CONVERT_CACHE, ge=0, description="Maximum size of on-disk converted models cache (GB)", json_schema_extra=Categories.ModelCache)
|
convert_cache : float = Field(default=DEFAULT_CONVERT_CACHE, ge=0, description="Maximum size of on-disk converted models cache (GB)", json_schema_extra=Categories.ModelCache)
|
||||||
|
|
||||||
lazy_offload : bool = Field(default=True, description="Keep models in VRAM until their space is needed.", json_schema_extra=Categories.ModelCache, )
|
lazy_offload : bool = Field(default=True, description="Keep models in VRAM until their space is needed", json_schema_extra=Categories.ModelCache, )
|
||||||
log_memory_usage : bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.", json_schema_extra=Categories.ModelCache)
|
log_memory_usage : bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.", json_schema_extra=Categories.ModelCache)
|
||||||
|
|
||||||
# DEVICE
|
# DEVICE
|
||||||
device : Literal["auto", "cpu", "cuda", "cuda:1", "mps"] = Field(default="auto", description="Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.", json_schema_extra=Categories.Device)
|
device : Literal["auto", "cpu", "cuda", "cuda:1", "mps"] = Field(default="auto", description="Generation device", json_schema_extra=Categories.Device)
|
||||||
precision : Literal["auto", "float16", "bfloat16", "float32", "autocast"] = Field(default="auto", description="Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.", json_schema_extra=Categories.Device)
|
precision : Literal["auto", "float16", "bfloat16", "float32", "autocast"] = Field(default="auto", description="Floating point precision", json_schema_extra=Categories.Device)
|
||||||
|
|
||||||
# GENERATION
|
# GENERATION
|
||||||
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.", json_schema_extra=Categories.Generation)
|
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", json_schema_extra=Categories.Generation)
|
||||||
attention_type : Literal["auto", "normal", "xformers", "sliced", "torch-sdp"] = Field(default="auto", description="Attention type.", json_schema_extra=Categories.Generation)
|
attention_type : Literal["auto", "normal", "xformers", "sliced", "torch-sdp"] = Field(default="auto", description="Attention type", json_schema_extra=Categories.Generation)
|
||||||
attention_slice_size: Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8] = Field(default="auto", description='Slice size, valid when attention_type=="sliced".', json_schema_extra=Categories.Generation)
|
attention_slice_size: Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8] = Field(default="auto", description='Slice size, valid when attention_type=="sliced"', json_schema_extra=Categories.Generation)
|
||||||
force_tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).", json_schema_extra=Categories.Generation)
|
force_tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", json_schema_extra=Categories.Generation)
|
||||||
png_compress_level : int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.", json_schema_extra=Categories.Generation)
|
png_compress_level : int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = fastest, largest filesize, 9 = slowest, smallest filesize", json_schema_extra=Categories.Generation)
|
||||||
|
|
||||||
# QUEUE
|
# QUEUE
|
||||||
max_queue_size : int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.", json_schema_extra=Categories.Queue)
|
max_queue_size : int = Field(default=10000, gt=0, description="Maximum number of items in the session queue", json_schema_extra=Categories.Queue)
|
||||||
|
|
||||||
# NODES
|
# NODES
|
||||||
allow_nodes : Optional[List[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.", json_schema_extra=Categories.Nodes)
|
allow_nodes : Optional[List[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.", json_schema_extra=Categories.Nodes)
|
||||||
deny_nodes : Optional[List[str]] = Field(default=None, description="List of nodes to deny. Omit to deny none.", json_schema_extra=Categories.Nodes)
|
deny_nodes : Optional[List[str]] = Field(default=None, description="List of nodes to deny. Omit to deny none.", json_schema_extra=Categories.Nodes)
|
||||||
node_cache_size : int = Field(default=512, description="How many cached nodes to keep in memory.", json_schema_extra=Categories.Nodes)
|
node_cache_size : int = Field(default=512, description="How many cached nodes to keep in memory", json_schema_extra=Categories.Nodes)
|
||||||
|
|
||||||
# MODEL INSTALL
|
# MODEL IMPORT
|
||||||
skip_model_hash : bool = Field(default=False, description="Skip model hashing, instead assigning a UUID to models. Useful when using a memory db to reduce model installation time, or if you don't care about storing stable hashes for models.", json_schema_extra=Categories.ModelInstall)
|
civitai_api_key : Optional[str] = Field(default=os.environ.get("CIVITAI_API_KEY"), description="API key for CivitAI", json_schema_extra=Categories.Other)
|
||||||
remote_api_tokens : Optional[list[URLRegexToken]] = Field(
|
|
||||||
default=None,
|
|
||||||
description="List of regular expression and token pairs used when downloading models from URLs. The download URL is tested against the regex, and if it matches, the token is provided in as a Bearer token.",
|
|
||||||
json_schema_extra=Categories.ModelInstall
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO(psyche): Can we just remove these then?
|
|
||||||
# DEPRECATED FIELDS - STILL HERE IN ORDER TO OBTAN VALUES FROM PRE-3.1 CONFIG FILES
|
# DEPRECATED FIELDS - STILL HERE IN ORDER TO OBTAN VALUES FROM PRE-3.1 CONFIG FILES
|
||||||
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", json_schema_extra=Categories.Deprecated)
|
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", json_schema_extra=Categories.MemoryPerformance)
|
||||||
max_cache_size : Optional[float] = Field(default=None, gt=0, description="Maximum memory amount used by model cache for rapid switching", json_schema_extra=Categories.Deprecated)
|
max_cache_size : Optional[float] = Field(default=None, gt=0, description="Maximum memory amount used by model cache for rapid switching", json_schema_extra=Categories.MemoryPerformance)
|
||||||
max_vram_cache_size : Optional[float] = Field(default=None, ge=0, description="Amount of VRAM reserved for model storage", json_schema_extra=Categories.Deprecated)
|
max_vram_cache_size : Optional[float] = Field(default=None, ge=0, description="Amount of VRAM reserved for model storage", json_schema_extra=Categories.MemoryPerformance)
|
||||||
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", json_schema_extra=Categories.Deprecated)
|
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", json_schema_extra=Categories.MemoryPerformance)
|
||||||
tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", json_schema_extra=Categories.Deprecated)
|
tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", json_schema_extra=Categories.MemoryPerformance)
|
||||||
lora_dir : Optional[Path] = Field(default=None, description='Path to a directory of LoRA/LyCORIS models to be imported on startup.', json_schema_extra=Categories.Deprecated)
|
lora_dir : Optional[Path] = Field(default=None, description='Path to a directory of LoRA/LyCORIS models to be imported on startup.', json_schema_extra=Categories.Paths)
|
||||||
embedding_dir : Optional[Path] = Field(default=None, description='Path to a directory of Textual Inversion embeddings to be imported on startup.', json_schema_extra=Categories.Deprecated)
|
embedding_dir : Optional[Path] = Field(default=None, description='Path to a directory of Textual Inversion embeddings to be imported on startup.', json_schema_extra=Categories.Paths)
|
||||||
controlnet_dir : Optional[Path] = Field(default=None, description='Path to a directory of ControlNet embeddings to be imported on startup.', json_schema_extra=Categories.Deprecated)
|
controlnet_dir : Optional[Path] = Field(default=None, description='Path to a directory of ControlNet embeddings to be imported on startup.', json_schema_extra=Categories.Paths)
|
||||||
conf_path : Path = Field(default=Path('configs/models.yaml'), description='Path to models definition file', json_schema_extra=Categories.Deprecated)
|
conf_path : Path = Field(default=Path('configs/models.yaml'), description='Path to models definition file', json_schema_extra=Categories.Paths)
|
||||||
|
|
||||||
# this is not referred to in the source code and can be removed entirely
|
# this is not referred to in the source code and can be removed entirely
|
||||||
#free_gpu_mem : Optional[bool] = Field(default=None, description="If true, purge model from GPU after each generation.", json_schema_extra=Categories.MemoryPerformance)
|
#free_gpu_mem : Optional[bool] = Field(default=None, description="If true, purge model from GPU after each generation.", json_schema_extra=Categories.MemoryPerformance)
|
||||||
@ -555,53 +477,6 @@ class InvokeAIAppConfig(InvokeAISettings):
|
|||||||
"""Choose the runtime root directory when not specified on command line or init file."""
|
"""Choose the runtime root directory when not specified on command line or init file."""
|
||||||
return _find_root()
|
return _find_root()
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def generate_docstrings() -> str:
|
|
||||||
"""Helper function for mkdocs. Generates a docstring for the InvokeAIAppConfig class.
|
|
||||||
|
|
||||||
You shouldn't run this manually. Instead, run `scripts/update-config-docstring.py` to update the docstring.
|
|
||||||
A makefile target is also available: `make update-config-docstring`.
|
|
||||||
|
|
||||||
See that script for more information about why this is necessary.
|
|
||||||
"""
|
|
||||||
docstring = ' """Invoke App Configuration\n\n'
|
|
||||||
docstring += " Attributes:"
|
|
||||||
|
|
||||||
field_descriptions: dict[str, list[str]] = {}
|
|
||||||
|
|
||||||
for k, v in InvokeAIAppConfig.model_fields.items():
|
|
||||||
if not isinstance(v.json_schema_extra, dict):
|
|
||||||
# Should never happen
|
|
||||||
continue
|
|
||||||
|
|
||||||
category = v.json_schema_extra.get("category", None)
|
|
||||||
if not isinstance(category, str) or category == "Deprecated":
|
|
||||||
continue
|
|
||||||
if not field_descriptions.get(category):
|
|
||||||
field_descriptions[category] = []
|
|
||||||
field_descriptions[category].append(f" {k}: **{category}**: {v.description}")
|
|
||||||
|
|
||||||
for c in [
|
|
||||||
"Web Server",
|
|
||||||
"Features",
|
|
||||||
"Paths",
|
|
||||||
"Logging",
|
|
||||||
"Development",
|
|
||||||
"CLIArgs",
|
|
||||||
"Model Install",
|
|
||||||
"Model Cache",
|
|
||||||
"Device",
|
|
||||||
"Generation",
|
|
||||||
"Queue",
|
|
||||||
"Nodes",
|
|
||||||
]:
|
|
||||||
docstring += "\n"
|
|
||||||
docstring += "\n".join(field_descriptions[c])
|
|
||||||
|
|
||||||
docstring += '\n """'
|
|
||||||
|
|
||||||
return docstring
|
|
||||||
|
|
||||||
|
|
||||||
def get_invokeai_config(**kwargs: Any) -> InvokeAIAppConfig:
|
def get_invokeai_config(**kwargs: Any) -> InvokeAIAppConfig:
|
||||||
"""Legacy function which returns InvokeAIAppConfig.get_config()."""
|
"""Legacy function which returns InvokeAIAppConfig.get_config()."""
|
||||||
|
@ -41,9 +41,8 @@ class InvocationCacheBase(ABC):
|
|||||||
"""Clears the cache"""
|
"""Clears the cache"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def create_key(invocation: BaseInvocation) -> int:
|
def create_key(self, invocation: BaseInvocation) -> int:
|
||||||
"""Gets the key for the invocation's cache item"""
|
"""Gets the key for the invocation's cache item"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -61,7 +61,9 @@ class MemoryInvocationCache(InvocationCacheBase):
|
|||||||
self._delete_oldest_access(number_to_delete)
|
self._delete_oldest_access(number_to_delete)
|
||||||
self._cache[key] = CachedItem(
|
self._cache[key] = CachedItem(
|
||||||
invocation_output,
|
invocation_output,
|
||||||
invocation_output.model_dump_json(warnings=False, exclude_defaults=True, exclude_unset=True),
|
invocation_output.model_dump_json(
|
||||||
|
warnings=False, exclude_defaults=True, exclude_unset=True, include={"type"}
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _delete_oldest_access(self, number_to_delete: int) -> None:
|
def _delete_oldest_access(self, number_to_delete: int) -> None:
|
||||||
@ -79,7 +81,7 @@ class MemoryInvocationCache(InvocationCacheBase):
|
|||||||
with self._lock:
|
with self._lock:
|
||||||
return self._delete(key)
|
return self._delete(key)
|
||||||
|
|
||||||
def clear(self) -> None:
|
def clear(self, *args, **kwargs) -> None:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
if self._max_cache_size == 0:
|
if self._max_cache_size == 0:
|
||||||
return
|
return
|
||||||
|
@ -25,7 +25,6 @@ if TYPE_CHECKING:
|
|||||||
from .images.images_base import ImageServiceABC
|
from .images.images_base import ImageServiceABC
|
||||||
from .invocation_cache.invocation_cache_base import InvocationCacheBase
|
from .invocation_cache.invocation_cache_base import InvocationCacheBase
|
||||||
from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase
|
from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase
|
||||||
from .model_images.model_images_base import ModelImageFileStorageBase
|
|
||||||
from .model_manager.model_manager_base import ModelManagerServiceBase
|
from .model_manager.model_manager_base import ModelManagerServiceBase
|
||||||
from .names.names_base import NameServiceBase
|
from .names.names_base import NameServiceBase
|
||||||
from .session_processor.session_processor_base import SessionProcessorBase
|
from .session_processor.session_processor_base import SessionProcessorBase
|
||||||
@ -50,7 +49,6 @@ class InvocationServices:
|
|||||||
image_files: "ImageFileStorageBase",
|
image_files: "ImageFileStorageBase",
|
||||||
image_records: "ImageRecordStorageBase",
|
image_records: "ImageRecordStorageBase",
|
||||||
logger: "Logger",
|
logger: "Logger",
|
||||||
model_images: "ModelImageFileStorageBase",
|
|
||||||
model_manager: "ModelManagerServiceBase",
|
model_manager: "ModelManagerServiceBase",
|
||||||
download_queue: "DownloadQueueServiceBase",
|
download_queue: "DownloadQueueServiceBase",
|
||||||
performance_statistics: "InvocationStatsServiceBase",
|
performance_statistics: "InvocationStatsServiceBase",
|
||||||
@ -74,7 +72,6 @@ class InvocationServices:
|
|||||||
self.image_files = image_files
|
self.image_files = image_files
|
||||||
self.image_records = image_records
|
self.image_records = image_records
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
self.model_images = model_images
|
|
||||||
self.model_manager = model_manager
|
self.model_manager = model_manager
|
||||||
self.download_queue = download_queue
|
self.download_queue = download_queue
|
||||||
self.performance_statistics = performance_statistics
|
self.performance_statistics = performance_statistics
|
||||||
|
@ -1,33 +0,0 @@
|
|||||||
from abc import ABC, abstractmethod
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from PIL.Image import Image as PILImageType
|
|
||||||
|
|
||||||
|
|
||||||
class ModelImageFileStorageBase(ABC):
|
|
||||||
"""Low-level service responsible for storing and retrieving image files."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get(self, model_key: str) -> PILImageType:
|
|
||||||
"""Retrieves a model image as PIL Image."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_path(self, model_key: str) -> Path:
|
|
||||||
"""Gets the internal path to a model image."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_url(self, model_key: str) -> str | None:
|
|
||||||
"""Gets the URL to fetch a model image."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def save(self, image: PILImageType, model_key: str) -> None:
|
|
||||||
"""Saves a model image."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def delete(self, model_key: str) -> None:
|
|
||||||
"""Deletes a model image."""
|
|
||||||
pass
|
|
@ -1,20 +0,0 @@
|
|||||||
# TODO: Should these excpetions subclass existing python exceptions?
|
|
||||||
class ModelImageFileNotFoundException(Exception):
|
|
||||||
"""Raised when an image file is not found in storage."""
|
|
||||||
|
|
||||||
def __init__(self, message="Model image file not found"):
|
|
||||||
super().__init__(message)
|
|
||||||
|
|
||||||
|
|
||||||
class ModelImageFileSaveException(Exception):
|
|
||||||
"""Raised when an image cannot be saved."""
|
|
||||||
|
|
||||||
def __init__(self, message="Model image file not saved"):
|
|
||||||
super().__init__(message)
|
|
||||||
|
|
||||||
|
|
||||||
class ModelImageFileDeleteException(Exception):
|
|
||||||
"""Raised when an image cannot be deleted."""
|
|
||||||
|
|
||||||
def __init__(self, message="Model image file not deleted"):
|
|
||||||
super().__init__(message)
|
|
@ -1,85 +0,0 @@
|
|||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from PIL import Image
|
|
||||||
from PIL.Image import Image as PILImageType
|
|
||||||
from send2trash import send2trash
|
|
||||||
|
|
||||||
from invokeai.app.services.invoker import Invoker
|
|
||||||
from invokeai.app.util.misc import uuid_string
|
|
||||||
from invokeai.app.util.thumbnails import make_thumbnail
|
|
||||||
|
|
||||||
from .model_images_base import ModelImageFileStorageBase
|
|
||||||
from .model_images_common import (
|
|
||||||
ModelImageFileDeleteException,
|
|
||||||
ModelImageFileNotFoundException,
|
|
||||||
ModelImageFileSaveException,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ModelImageFileStorageDisk(ModelImageFileStorageBase):
|
|
||||||
"""Stores images on disk"""
|
|
||||||
|
|
||||||
def __init__(self, model_images_folder: Path):
|
|
||||||
self._model_images_folder = model_images_folder
|
|
||||||
self._validate_storage_folders()
|
|
||||||
|
|
||||||
def start(self, invoker: Invoker) -> None:
|
|
||||||
self._invoker = invoker
|
|
||||||
|
|
||||||
def get(self, model_key: str) -> PILImageType:
|
|
||||||
try:
|
|
||||||
path = self.get_path(model_key)
|
|
||||||
|
|
||||||
if not self._validate_path(path):
|
|
||||||
raise ModelImageFileNotFoundException
|
|
||||||
|
|
||||||
return Image.open(path)
|
|
||||||
except FileNotFoundError as e:
|
|
||||||
raise ModelImageFileNotFoundException from e
|
|
||||||
|
|
||||||
def save(self, image: PILImageType, model_key: str) -> None:
|
|
||||||
try:
|
|
||||||
self._validate_storage_folders()
|
|
||||||
image_path = self._model_images_folder / (model_key + ".webp")
|
|
||||||
thumbnail = make_thumbnail(image, 256)
|
|
||||||
thumbnail.save(image_path, format="webp")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
raise ModelImageFileSaveException from e
|
|
||||||
|
|
||||||
def get_path(self, model_key: str) -> Path:
|
|
||||||
path = self._model_images_folder / (model_key + ".webp")
|
|
||||||
|
|
||||||
return path
|
|
||||||
|
|
||||||
def get_url(self, model_key: str) -> str | None:
|
|
||||||
path = self.get_path(model_key)
|
|
||||||
if not self._validate_path(path):
|
|
||||||
return
|
|
||||||
|
|
||||||
url = self._invoker.services.urls.get_model_image_url(model_key)
|
|
||||||
|
|
||||||
# The image URL never changes, so we must add random query string to it to prevent caching
|
|
||||||
url += f"?{uuid_string()}"
|
|
||||||
|
|
||||||
return url
|
|
||||||
|
|
||||||
def delete(self, model_key: str) -> None:
|
|
||||||
try:
|
|
||||||
path = self.get_path(model_key)
|
|
||||||
|
|
||||||
if not self._validate_path(path):
|
|
||||||
raise ModelImageFileNotFoundException
|
|
||||||
|
|
||||||
send2trash(path)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
raise ModelImageFileDeleteException from e
|
|
||||||
|
|
||||||
def _validate_path(self, path: Path) -> bool:
|
|
||||||
"""Validates the path given for an image."""
|
|
||||||
return path.exists()
|
|
||||||
|
|
||||||
def _validate_storage_folders(self) -> None:
|
|
||||||
"""Checks if the required folders exist and create them if they don't"""
|
|
||||||
self._model_images_folder.mkdir(parents=True, exist_ok=True)
|
|
@ -1,6 +1,7 @@
|
|||||||
"""Initialization file for model install service package."""
|
"""Initialization file for model install service package."""
|
||||||
|
|
||||||
from .model_install_base import (
|
from .model_install_base import (
|
||||||
|
CivitaiModelSource,
|
||||||
HFModelSource,
|
HFModelSource,
|
||||||
InstallStatus,
|
InstallStatus,
|
||||||
LocalModelSource,
|
LocalModelSource,
|
||||||
@ -22,4 +23,5 @@ __all__ = [
|
|||||||
"LocalModelSource",
|
"LocalModelSource",
|
||||||
"HFModelSource",
|
"HFModelSource",
|
||||||
"URLModelSource",
|
"URLModelSource",
|
||||||
|
"CivitaiModelSource",
|
||||||
]
|
]
|
||||||
|
@ -91,6 +91,21 @@ class LocalModelSource(StringLikeSource):
|
|||||||
return Path(self.path).as_posix()
|
return Path(self.path).as_posix()
|
||||||
|
|
||||||
|
|
||||||
|
class CivitaiModelSource(StringLikeSource):
|
||||||
|
"""A Civitai version id, with optional variant and access token."""
|
||||||
|
|
||||||
|
version_id: int
|
||||||
|
variant: Optional[ModelRepoVariant] = None
|
||||||
|
access_token: Optional[str] = None
|
||||||
|
type: Literal["civitai"] = "civitai"
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
"""Return string version of repoid when string rep needed."""
|
||||||
|
base: str = str(self.version_id)
|
||||||
|
base += f" ({self.variant})" if self.variant else ""
|
||||||
|
return base
|
||||||
|
|
||||||
|
|
||||||
class HFModelSource(StringLikeSource):
|
class HFModelSource(StringLikeSource):
|
||||||
"""
|
"""
|
||||||
A HuggingFace repo_id with optional variant, sub-folder and access token.
|
A HuggingFace repo_id with optional variant, sub-folder and access token.
|
||||||
@ -131,11 +146,14 @@ class URLModelSource(StringLikeSource):
|
|||||||
return str(self.url)
|
return str(self.url)
|
||||||
|
|
||||||
|
|
||||||
ModelSource = Annotated[Union[LocalModelSource, HFModelSource, URLModelSource], Field(discriminator="type")]
|
ModelSource = Annotated[
|
||||||
|
Union[LocalModelSource, HFModelSource, CivitaiModelSource, URLModelSource], Field(discriminator="type")
|
||||||
|
]
|
||||||
|
|
||||||
MODEL_SOURCE_TO_TYPE_MAP = {
|
MODEL_SOURCE_TO_TYPE_MAP = {
|
||||||
URLModelSource: ModelSourceType.Url,
|
URLModelSource: ModelSourceType.Url,
|
||||||
HFModelSource: ModelSourceType.HFRepoID,
|
HFModelSource: ModelSourceType.HFRepoID,
|
||||||
|
CivitaiModelSource: ModelSourceType.CivitAI,
|
||||||
LocalModelSource: ModelSourceType.Path,
|
LocalModelSource: ModelSourceType.Path,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -12,7 +12,6 @@ from tempfile import mkdtemp
|
|||||||
from typing import Any, Dict, List, Optional, Set, Union
|
from typing import Any, Dict, List, Optional, Set, Union
|
||||||
|
|
||||||
from huggingface_hub import HfFolder
|
from huggingface_hub import HfFolder
|
||||||
from omegaconf import DictConfig, OmegaConf
|
|
||||||
from pydantic.networks import AnyHttpUrl
|
from pydantic.networks import AnyHttpUrl
|
||||||
from requests import Session
|
from requests import Session
|
||||||
|
|
||||||
@ -34,11 +33,12 @@ from invokeai.backend.model_manager.config import (
|
|||||||
)
|
)
|
||||||
from invokeai.backend.model_manager.metadata import (
|
from invokeai.backend.model_manager.metadata import (
|
||||||
AnyModelRepoMetadata,
|
AnyModelRepoMetadata,
|
||||||
|
CivitaiMetadataFetch,
|
||||||
HuggingFaceMetadataFetch,
|
HuggingFaceMetadataFetch,
|
||||||
ModelMetadataWithFiles,
|
ModelMetadataWithFiles,
|
||||||
RemoteModelFile,
|
RemoteModelFile,
|
||||||
)
|
)
|
||||||
from invokeai.backend.model_manager.metadata.metadata_base import HuggingFaceMetadata
|
from invokeai.backend.model_manager.metadata.metadata_base import CivitaiMetadata, HuggingFaceMetadata
|
||||||
from invokeai.backend.model_manager.probe import ModelProbe
|
from invokeai.backend.model_manager.probe import ModelProbe
|
||||||
from invokeai.backend.model_manager.search import ModelSearch
|
from invokeai.backend.model_manager.search import ModelSearch
|
||||||
from invokeai.backend.util import Chdir, InvokeAILogger
|
from invokeai.backend.util import Chdir, InvokeAILogger
|
||||||
@ -46,6 +46,7 @@ from invokeai.backend.util.devices import choose_precision, choose_torch_device
|
|||||||
|
|
||||||
from .model_install_base import (
|
from .model_install_base import (
|
||||||
MODEL_SOURCE_TO_TYPE_MAP,
|
MODEL_SOURCE_TO_TYPE_MAP,
|
||||||
|
CivitaiModelSource,
|
||||||
HFModelSource,
|
HFModelSource,
|
||||||
InstallStatus,
|
InstallStatus,
|
||||||
LocalModelSource,
|
LocalModelSource,
|
||||||
@ -116,7 +117,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
raise Exception("Attempt to start the installer service twice")
|
raise Exception("Attempt to start the installer service twice")
|
||||||
self._start_installer_thread()
|
self._start_installer_thread()
|
||||||
self._remove_dangling_install_dirs()
|
self._remove_dangling_install_dirs()
|
||||||
self._migrate_yaml()
|
|
||||||
self.sync_to_config()
|
self.sync_to_config()
|
||||||
|
|
||||||
def stop(self, invoker: Optional[Invoker] = None) -> None:
|
def stop(self, invoker: Optional[Invoker] = None) -> None:
|
||||||
@ -199,16 +199,9 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
access_token=access_token,
|
access_token=access_token,
|
||||||
)
|
)
|
||||||
elif re.match(r"^https?://[^/]+", source):
|
elif re.match(r"^https?://[^/]+", source):
|
||||||
# Pull the token from config if it exists and matches the URL
|
|
||||||
_token = access_token
|
|
||||||
if _token is None:
|
|
||||||
for pair in self.app_config.remote_api_tokens or []:
|
|
||||||
if re.search(pair.url_regex, source):
|
|
||||||
_token = pair.token
|
|
||||||
break
|
|
||||||
source_obj = URLModelSource(
|
source_obj = URLModelSource(
|
||||||
url=AnyHttpUrl(source),
|
url=AnyHttpUrl(source),
|
||||||
access_token=_token,
|
access_token=access_token,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported model source: '{source}'")
|
raise ValueError(f"Unsupported model source: '{source}'")
|
||||||
@ -223,6 +216,8 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
if isinstance(source, LocalModelSource):
|
if isinstance(source, LocalModelSource):
|
||||||
install_job = self._import_local_model(source, config)
|
install_job = self._import_local_model(source, config)
|
||||||
self._install_queue.put(install_job) # synchronously install
|
self._install_queue.put(install_job) # synchronously install
|
||||||
|
elif isinstance(source, CivitaiModelSource):
|
||||||
|
install_job = self._import_from_civitai(source, config)
|
||||||
elif isinstance(source, HFModelSource):
|
elif isinstance(source, HFModelSource):
|
||||||
install_job = self._import_from_hf(source, config)
|
install_job = self._import_from_hf(source, config)
|
||||||
elif isinstance(source, URLModelSource):
|
elif isinstance(source, URLModelSource):
|
||||||
@ -289,52 +284,10 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
self._logger.info(f"{len(installed)} new models registered")
|
self._logger.info(f"{len(installed)} new models registered")
|
||||||
self._logger.info("Model installer (re)initialized")
|
self._logger.info("Model installer (re)initialized")
|
||||||
|
|
||||||
def _migrate_yaml(self) -> None:
|
|
||||||
db_models = self.record_store.all_models()
|
|
||||||
try:
|
|
||||||
yaml = self._get_yaml()
|
|
||||||
except OSError:
|
|
||||||
return
|
|
||||||
|
|
||||||
yaml_metadata = yaml.pop("__metadata__")
|
|
||||||
yaml_version = yaml_metadata.get("version")
|
|
||||||
|
|
||||||
if yaml_version != "3.0.0":
|
|
||||||
raise ValueError(
|
|
||||||
f"Attempted migration of unsupported `models.yaml` v{yaml_version}. Only v3.0.0 is supported. Exiting."
|
|
||||||
)
|
|
||||||
|
|
||||||
self._logger.info(
|
|
||||||
f"Starting one-time migration of {len(yaml.items())} models from `models.yaml` to database. This may take a few minutes."
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(db_models) == 0 and len(yaml.items()) != 0:
|
|
||||||
for model_key, stanza in yaml.items():
|
|
||||||
_, _, model_name = str(model_key).split("/")
|
|
||||||
model_path = Path(stanza["path"])
|
|
||||||
if not model_path.is_absolute():
|
|
||||||
model_path = self._app_config.models_path / model_path
|
|
||||||
model_path = model_path.resolve()
|
|
||||||
|
|
||||||
config: dict[str, Any] = {}
|
|
||||||
config["name"] = model_name
|
|
||||||
config["description"] = stanza.get("description")
|
|
||||||
config["config_path"] = stanza.get("config")
|
|
||||||
|
|
||||||
try:
|
|
||||||
id = self.register_path(model_path=model_path, config=config)
|
|
||||||
self._logger.info(f"Migrated {model_name} with id {id}")
|
|
||||||
except Exception as e:
|
|
||||||
self._logger.warning(f"Model at {model_path} could not be migrated: {e}")
|
|
||||||
|
|
||||||
# Rename `models.yaml` to `models.yaml.bak` to prevent re-migration
|
|
||||||
yaml_path = self._app_config.model_conf_path
|
|
||||||
yaml_path.rename(yaml_path.with_suffix(".yaml.bak"))
|
|
||||||
|
|
||||||
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
|
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
|
||||||
self._cached_model_paths = {Path(x.path).absolute() for x in self.record_store.all_models()}
|
self._cached_model_paths = {Path(x.path).absolute() for x in self.record_store.all_models()}
|
||||||
callback = self._scan_install if install else self._scan_register
|
callback = self._scan_install if install else self._scan_register
|
||||||
search = ModelSearch(on_model_found=callback)
|
search = ModelSearch(on_model_found=callback, config=self._app_config)
|
||||||
self._models_installed.clear()
|
self._models_installed.clear()
|
||||||
search.search(scan_dir)
|
search.search(scan_dir)
|
||||||
return list(self._models_installed)
|
return list(self._models_installed)
|
||||||
@ -346,7 +299,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
"""Unregister the model. Delete its files only if they are within our models directory."""
|
"""Unregister the model. Delete its files only if they are within our models directory."""
|
||||||
model = self.record_store.get_model(key)
|
model = self.record_store.get_model(key)
|
||||||
models_dir = self.app_config.models_path
|
models_dir = self.app_config.models_path
|
||||||
model_path = Path(model.path)
|
model_path = models_dir / model.path
|
||||||
if model_path.is_relative_to(models_dir):
|
if model_path.is_relative_to(models_dir):
|
||||||
self.unconditionally_delete(key)
|
self.unconditionally_delete(key)
|
||||||
else:
|
else:
|
||||||
@ -354,11 +307,11 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
|
|
||||||
def unconditionally_delete(self, key: str) -> None: # noqa D102
|
def unconditionally_delete(self, key: str) -> None: # noqa D102
|
||||||
model = self.record_store.get_model(key)
|
model = self.record_store.get_model(key)
|
||||||
model_path = Path(model.path)
|
path = self.app_config.models_path / model.path
|
||||||
if model_path.is_dir():
|
if path.is_dir():
|
||||||
rmtree(model_path)
|
rmtree(path)
|
||||||
else:
|
else:
|
||||||
model_path.unlink()
|
path.unlink()
|
||||||
self.unregister(key)
|
self.unregister(key)
|
||||||
|
|
||||||
def download_and_cache(
|
def download_and_cache(
|
||||||
@ -428,8 +381,10 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
job.config_in["source"] = str(job.source)
|
job.config_in["source"] = str(job.source)
|
||||||
job.config_in["source_type"] = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__]
|
job.config_in["source_type"] = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__]
|
||||||
# enter the metadata, if there is any
|
# enter the metadata, if there is any
|
||||||
if isinstance(job.source_metadata, (HuggingFaceMetadata)):
|
if isinstance(job.source_metadata, (CivitaiMetadata, HuggingFaceMetadata)):
|
||||||
job.config_in["source_api_response"] = job.source_metadata.api_response
|
job.config_in["source_api_response"] = job.source_metadata.api_response
|
||||||
|
if isinstance(job.source_metadata, CivitaiMetadata) and job.source_metadata.trigger_phrases:
|
||||||
|
job.config_in["trigger_phrases"] = job.source_metadata.trigger_phrases
|
||||||
|
|
||||||
if job.inplace:
|
if job.inplace:
|
||||||
key = self.register_path(job.local_path, job.config_in)
|
key = self.register_path(job.local_path, job.config_in)
|
||||||
@ -495,7 +450,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
self._logger.info(f"Scanning {self._app_config.models_path} for new and orphaned models")
|
self._logger.info(f"Scanning {self._app_config.models_path} for new and orphaned models")
|
||||||
for cur_base_model in BaseModelType:
|
for cur_base_model in BaseModelType:
|
||||||
for cur_model_type in ModelType:
|
for cur_model_type in ModelType:
|
||||||
models_dir = self._app_config.models_path / Path(cur_base_model.value, cur_model_type.value)
|
models_dir = Path(cur_base_model.value, cur_model_type.value)
|
||||||
installed.update(self.scan_directory(models_dir))
|
installed.update(self.scan_directory(models_dir))
|
||||||
self._logger.info(f"{len(installed)} new models registered; {len(defunct_models)} unregistered")
|
self._logger.info(f"{len(installed)} new models registered; {len(defunct_models)} unregistered")
|
||||||
|
|
||||||
@ -514,20 +469,13 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
old_path = Path(model.path)
|
old_path = Path(model.path)
|
||||||
models_dir = self.app_config.models_path
|
models_dir = self.app_config.models_path
|
||||||
|
|
||||||
try:
|
if not old_path.is_relative_to(models_dir):
|
||||||
old_path.relative_to(models_dir)
|
|
||||||
return model
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
new_path = models_dir / model.base.value / model.type.value / old_path.name
|
|
||||||
|
|
||||||
if old_path == new_path:
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
new_path = models_dir / model.base.value / model.type.value / model.name
|
||||||
self._logger.info(f"Moving {model.name} to {new_path}.")
|
self._logger.info(f"Moving {model.name} to {new_path}.")
|
||||||
new_path = self._move_model(old_path, new_path)
|
new_path = self._move_model(old_path, new_path)
|
||||||
model.path = new_path.as_posix()
|
model.path = new_path.relative_to(models_dir).as_posix()
|
||||||
self.record_store.update_model(key, ModelRecordChanges(path=model.path))
|
self.record_store.update_model(key, ModelRecordChanges(path=model.path))
|
||||||
return model
|
return model
|
||||||
|
|
||||||
@ -590,14 +538,17 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
|
|
||||||
info = info or ModelProbe.probe(model_path, config)
|
info = info or ModelProbe.probe(model_path, config)
|
||||||
|
|
||||||
model_path = model_path.resolve()
|
model_path = model_path.absolute()
|
||||||
|
if model_path.is_relative_to(self.app_config.models_path):
|
||||||
|
model_path = model_path.relative_to(self.app_config.models_path)
|
||||||
|
|
||||||
info.path = model_path.as_posix()
|
info.path = model_path.as_posix()
|
||||||
|
|
||||||
# add 'main' specific fields
|
# add 'main' specific fields
|
||||||
if isinstance(info, CheckpointConfigBase):
|
if isinstance(info, CheckpointConfigBase):
|
||||||
|
# make config relative to our root
|
||||||
legacy_conf = (self.app_config.root_dir / self.app_config.legacy_conf_dir / info.config_path).resolve()
|
legacy_conf = (self.app_config.root_dir / self.app_config.legacy_conf_dir / info.config_path).resolve()
|
||||||
info.config_path = legacy_conf.as_posix()
|
info.config_path = legacy_conf.relative_to(self.app_config.root_dir).as_posix()
|
||||||
self.record_store.add_model(info)
|
self.record_store.add_model(info)
|
||||||
return info.key
|
return info.key
|
||||||
|
|
||||||
@ -607,16 +558,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
self._next_job_id += 1
|
self._next_job_id += 1
|
||||||
return id
|
return id
|
||||||
|
|
||||||
# --------------------------------------------------------------------------------------------
|
|
||||||
# Internal functions that manage the old yaml config
|
|
||||||
# --------------------------------------------------------------------------------------------
|
|
||||||
def _get_yaml(self) -> DictConfig:
|
|
||||||
"""Fetch the models.yaml DictConfig for this installation."""
|
|
||||||
yaml_path = self._app_config.model_conf_path
|
|
||||||
omegaconf = OmegaConf.load(yaml_path)
|
|
||||||
assert isinstance(omegaconf, DictConfig)
|
|
||||||
return omegaconf
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _guess_variant() -> Optional[ModelRepoVariant]:
|
def _guess_variant() -> Optional[ModelRepoVariant]:
|
||||||
"""Guess the best HuggingFace variant type to download."""
|
"""Guess the best HuggingFace variant type to download."""
|
||||||
@ -632,6 +573,16 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
inplace=source.inplace or False,
|
inplace=source.inplace or False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _import_from_civitai(self, source: CivitaiModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
|
||||||
|
if not source.access_token:
|
||||||
|
self._logger.info("No Civitai access token provided; some models may not be downloadable.")
|
||||||
|
metadata = CivitaiMetadataFetch(self._session, self.app_config.get_config().civitai_api_key).from_id(
|
||||||
|
str(source.version_id)
|
||||||
|
)
|
||||||
|
assert isinstance(metadata, ModelMetadataWithFiles)
|
||||||
|
remote_files = metadata.download_urls(session=self._session)
|
||||||
|
return self._import_remote_model(source=source, config=config, metadata=metadata, remote_files=remote_files)
|
||||||
|
|
||||||
def _import_from_hf(self, source: HFModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
|
def _import_from_hf(self, source: HFModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
|
||||||
# Add user's cached access token to HuggingFace requests
|
# Add user's cached access token to HuggingFace requests
|
||||||
source.access_token = source.access_token or HfFolder.get_token()
|
source.access_token = source.access_token or HfFolder.get_token()
|
||||||
@ -654,7 +605,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _import_from_url(self, source: URLModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
|
def _import_from_url(self, source: URLModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
|
||||||
# URLs from HuggingFace will be handled specially
|
# URLs from Civitai or HuggingFace will be handled specially
|
||||||
metadata = None
|
metadata = None
|
||||||
fetcher = None
|
fetcher = None
|
||||||
try:
|
try:
|
||||||
@ -662,6 +613,8 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
kwargs: dict[str, Any] = {"session": self._session}
|
kwargs: dict[str, Any] = {"session": self._session}
|
||||||
|
if fetcher is CivitaiMetadataFetch:
|
||||||
|
kwargs["api_key"] = self._app_config.get_config().civitai_api_key
|
||||||
if fetcher is not None:
|
if fetcher is not None:
|
||||||
metadata = fetcher(**kwargs).from_url(source.url)
|
metadata = fetcher(**kwargs).from_url(source.url)
|
||||||
self._logger.debug(f"metadata={metadata}")
|
self._logger.debug(f"metadata={metadata}")
|
||||||
@ -678,7 +631,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
|
|
||||||
def _import_remote_model(
|
def _import_remote_model(
|
||||||
self,
|
self,
|
||||||
source: HFModelSource | URLModelSource,
|
source: HFModelSource | CivitaiModelSource | URLModelSource,
|
||||||
remote_files: List[RemoteModelFile],
|
remote_files: List[RemoteModelFile],
|
||||||
metadata: Optional[AnyModelRepoMetadata],
|
metadata: Optional[AnyModelRepoMetadata],
|
||||||
config: Optional[Dict[str, Any]],
|
config: Optional[Dict[str, Any]],
|
||||||
@ -896,6 +849,8 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_fetcher_from_url(url: str):
|
def get_fetcher_from_url(url: str):
|
||||||
if re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()):
|
if re.match(r"^https?://civitai.com/", url.lower()):
|
||||||
|
return CivitaiMetadataFetch
|
||||||
|
elif re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()):
|
||||||
return HuggingFaceMetadataFetch
|
return HuggingFaceMetadataFetch
|
||||||
raise ValueError(f"Unsupported model source: '{url}'")
|
raise ValueError(f"Unsupported model source: '{url}'")
|
||||||
|
@ -1,11 +1,15 @@
|
|||||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
from invokeai.app.services.invoker import Invoker
|
from invokeai.app.services.invoker import Invoker
|
||||||
|
from invokeai.app.services.shared.invocation_context import InvocationContextData
|
||||||
|
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType
|
||||||
|
from invokeai.backend.model_manager.load.load_base import LoadedModel
|
||||||
|
|
||||||
from ..config import InvokeAIAppConfig
|
from ..config import InvokeAIAppConfig
|
||||||
from ..download import DownloadQueueServiceBase
|
from ..download import DownloadQueueServiceBase
|
||||||
@ -66,3 +70,32 @@ class ModelManagerServiceBase(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def stop(self, invoker: Invoker) -> None:
|
def stop(self, invoker: Invoker) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def load_model_by_config(
|
||||||
|
self,
|
||||||
|
model_config: AnyModelConfig,
|
||||||
|
submodel_type: Optional[SubModelType] = None,
|
||||||
|
context_data: Optional[InvocationContextData] = None,
|
||||||
|
) -> LoadedModel:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def load_model_by_key(
|
||||||
|
self,
|
||||||
|
key: str,
|
||||||
|
submodel_type: Optional[SubModelType] = None,
|
||||||
|
context_data: Optional[InvocationContextData] = None,
|
||||||
|
) -> LoadedModel:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def load_model_by_attr(
|
||||||
|
self,
|
||||||
|
model_name: str,
|
||||||
|
base_model: BaseModelType,
|
||||||
|
model_type: ModelType,
|
||||||
|
submodel: Optional[SubModelType] = None,
|
||||||
|
context_data: Optional[InvocationContextData] = None,
|
||||||
|
) -> LoadedModel:
|
||||||
|
pass
|
||||||
|
@ -1,10 +1,14 @@
|
|||||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
||||||
"""Implementation of ModelManagerServiceBase."""
|
"""Implementation of ModelManagerServiceBase."""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
from invokeai.app.services.invoker import Invoker
|
from invokeai.app.services.invoker import Invoker
|
||||||
|
from invokeai.app.services.shared.invocation_context import InvocationContextData
|
||||||
|
from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, LoadedModel, ModelType, SubModelType
|
||||||
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
|
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
|
||||||
from invokeai.backend.util.devices import choose_torch_device
|
from invokeai.backend.util.devices import choose_torch_device
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
@ -14,7 +18,7 @@ from ..download import DownloadQueueServiceBase
|
|||||||
from ..events.events_base import EventServiceBase
|
from ..events.events_base import EventServiceBase
|
||||||
from ..model_install import ModelInstallService, ModelInstallServiceBase
|
from ..model_install import ModelInstallService, ModelInstallServiceBase
|
||||||
from ..model_load import ModelLoadService, ModelLoadServiceBase
|
from ..model_load import ModelLoadService, ModelLoadServiceBase
|
||||||
from ..model_records import ModelRecordServiceBase
|
from ..model_records import ModelRecordServiceBase, UnknownModelException
|
||||||
from .model_manager_base import ModelManagerServiceBase
|
from .model_manager_base import ModelManagerServiceBase
|
||||||
|
|
||||||
|
|
||||||
@ -60,6 +64,56 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
if hasattr(service, "stop"):
|
if hasattr(service, "stop"):
|
||||||
service.stop(invoker)
|
service.stop(invoker)
|
||||||
|
|
||||||
|
def load_model_by_config(
|
||||||
|
self,
|
||||||
|
model_config: AnyModelConfig,
|
||||||
|
submodel_type: Optional[SubModelType] = None,
|
||||||
|
context_data: Optional[InvocationContextData] = None,
|
||||||
|
) -> LoadedModel:
|
||||||
|
return self.load.load_model(model_config, submodel_type, context_data)
|
||||||
|
|
||||||
|
def load_model_by_key(
|
||||||
|
self,
|
||||||
|
key: str,
|
||||||
|
submodel_type: Optional[SubModelType] = None,
|
||||||
|
context_data: Optional[InvocationContextData] = None,
|
||||||
|
) -> LoadedModel:
|
||||||
|
config = self.store.get_model(key)
|
||||||
|
return self.load.load_model(config, submodel_type, context_data)
|
||||||
|
|
||||||
|
def load_model_by_attr(
|
||||||
|
self,
|
||||||
|
model_name: str,
|
||||||
|
base_model: BaseModelType,
|
||||||
|
model_type: ModelType,
|
||||||
|
submodel: Optional[SubModelType] = None,
|
||||||
|
context_data: Optional[InvocationContextData] = None,
|
||||||
|
) -> LoadedModel:
|
||||||
|
"""
|
||||||
|
Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object.
|
||||||
|
|
||||||
|
This is provided for API compatability with the get_model() method
|
||||||
|
in the original model manager. However, note that LoadedModel is
|
||||||
|
not the same as the original ModelInfo that ws returned.
|
||||||
|
|
||||||
|
:param model_name: Name of to be fetched.
|
||||||
|
:param base_model: Base model
|
||||||
|
:param model_type: Type of the model
|
||||||
|
:param submodel: For main (pipeline models), the submodel to fetch
|
||||||
|
:param context: The invocation context.
|
||||||
|
|
||||||
|
Exceptions: UnknownModelException -- model with this key not known
|
||||||
|
NotImplementedException -- a model loader was not provided at initialization time
|
||||||
|
ValueError -- more than one model matches this combination
|
||||||
|
"""
|
||||||
|
configs = self.store.search_by_attr(model_name, base_model, model_type)
|
||||||
|
if len(configs) == 0:
|
||||||
|
raise UnknownModelException(f"{base_model}/{model_type}/{model_name}: Unknown model")
|
||||||
|
elif len(configs) > 1:
|
||||||
|
raise ValueError(f"{base_model}/{model_type}/{model_name}: More than one model matches.")
|
||||||
|
else:
|
||||||
|
return self.load.load_model(configs[0], submodel, context_data)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def build_model_manager(
|
def build_model_manager(
|
||||||
cls,
|
cls,
|
||||||
|
@ -18,12 +18,7 @@ from invokeai.backend.model_manager import (
|
|||||||
ModelFormat,
|
ModelFormat,
|
||||||
ModelType,
|
ModelType,
|
||||||
)
|
)
|
||||||
from invokeai.backend.model_manager.config import (
|
from invokeai.backend.model_manager.config import ModelDefaultSettings, ModelVariantType, SchedulerPredictionType
|
||||||
ControlAdapterDefaultSettings,
|
|
||||||
MainModelDefaultSettings,
|
|
||||||
ModelVariantType,
|
|
||||||
SchedulerPredictionType,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class DuplicateModelException(Exception):
|
class DuplicateModelException(Exception):
|
||||||
@ -73,7 +68,7 @@ class ModelRecordChanges(BaseModelExcludeNull):
|
|||||||
description: Optional[str] = Field(description="Model description", default=None)
|
description: Optional[str] = Field(description="Model description", default=None)
|
||||||
base: Optional[BaseModelType] = Field(description="The base model.", default=None)
|
base: Optional[BaseModelType] = Field(description="The base model.", default=None)
|
||||||
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
|
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
|
||||||
default_settings: Optional[MainModelDefaultSettings | ControlAdapterDefaultSettings] = Field(
|
default_settings: Optional[ModelDefaultSettings] = Field(
|
||||||
description="Default settings for this model", default=None
|
description="Default settings for this model", default=None
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -84,7 +79,6 @@ class ModelRecordChanges(BaseModelExcludeNull):
|
|||||||
description="The prediction type of the model.", default=None
|
description="The prediction type of the model.", default=None
|
||||||
)
|
)
|
||||||
upcast_attention: Optional[bool] = Field(description="Whether to upcast attention.", default=None)
|
upcast_attention: Optional[bool] = Field(description="Whether to upcast attention.", default=None)
|
||||||
config_path: Optional[str] = Field(description="Path to config file for model", default=None)
|
|
||||||
|
|
||||||
|
|
||||||
class ModelRecordServiceBase(ABC):
|
class ModelRecordServiceBase(ABC):
|
||||||
@ -135,17 +129,6 @@ class ModelRecordServiceBase(ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_model_by_hash(self, hash: str) -> AnyModelConfig:
|
|
||||||
"""
|
|
||||||
Retrieve the configuration for the indicated model.
|
|
||||||
|
|
||||||
:param hash: Hash of model config to be fetched.
|
|
||||||
|
|
||||||
Exceptions: UnknownModelException
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def list_models(
|
def list_models(
|
||||||
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
|
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
|
||||||
|
@ -203,21 +203,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
|
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def get_model_by_hash(self, hash: str) -> AnyModelConfig:
|
|
||||||
with self._db.lock:
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
SELECT config, strftime('%s',updated_at) FROM models
|
|
||||||
WHERE hash=?;
|
|
||||||
""",
|
|
||||||
(hash,),
|
|
||||||
)
|
|
||||||
rows = self._cursor.fetchone()
|
|
||||||
if not rows:
|
|
||||||
raise UnknownModelException("model not found")
|
|
||||||
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
|
|
||||||
return model
|
|
||||||
|
|
||||||
def exists(self, key: str) -> bool:
|
def exists(self, key: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Return True if a model with the indicated key exists in the databse.
|
Return True if a model with the indicated key exists in the databse.
|
||||||
@ -242,7 +227,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
base_model: Optional[BaseModelType] = None,
|
base_model: Optional[BaseModelType] = None,
|
||||||
model_type: Optional[ModelType] = None,
|
model_type: Optional[ModelType] = None,
|
||||||
model_format: Optional[ModelFormat] = None,
|
model_format: Optional[ModelFormat] = None,
|
||||||
order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default,
|
|
||||||
) -> List[AnyModelConfig]:
|
) -> List[AnyModelConfig]:
|
||||||
"""
|
"""
|
||||||
Return models matching name, base and/or type.
|
Return models matching name, base and/or type.
|
||||||
@ -251,21 +235,10 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
:param base_model: Filter by base model (optional)
|
:param base_model: Filter by base model (optional)
|
||||||
:param model_type: Filter by type of model (optional)
|
:param model_type: Filter by type of model (optional)
|
||||||
:param model_format: Filter by model format (e.g. "diffusers") (optional)
|
:param model_format: Filter by model format (e.g. "diffusers") (optional)
|
||||||
:param order_by: Result order
|
|
||||||
|
|
||||||
If none of the optional filters are passed, will return all
|
If none of the optional filters are passed, will return all
|
||||||
models in the database.
|
models in the database.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
assert isinstance(order_by, ModelRecordOrderBy)
|
|
||||||
ordering = {
|
|
||||||
ModelRecordOrderBy.Default: "type, base, name, format",
|
|
||||||
ModelRecordOrderBy.Type: "type",
|
|
||||||
ModelRecordOrderBy.Base: "base",
|
|
||||||
ModelRecordOrderBy.Name: "name",
|
|
||||||
ModelRecordOrderBy.Format: "format",
|
|
||||||
}
|
|
||||||
|
|
||||||
where_clause: list[str] = []
|
where_clause: list[str] = []
|
||||||
bindings: list[str] = []
|
bindings: list[str] = []
|
||||||
if model_name:
|
if model_name:
|
||||||
@ -284,10 +257,8 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
with self._db.lock:
|
with self._db.lock:
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
f"""--sql
|
f"""--sql
|
||||||
SELECT config, strftime('%s',updated_at)
|
SELECT config, strftime('%s',updated_at) FROM models
|
||||||
FROM models
|
{where};
|
||||||
{where}
|
|
||||||
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason;
|
|
||||||
""",
|
""",
|
||||||
tuple(bindings),
|
tuple(bindings),
|
||||||
)
|
)
|
||||||
@ -333,7 +304,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
"""Return a paginated summary listing of each model in the database."""
|
"""Return a paginated summary listing of each model in the database."""
|
||||||
assert isinstance(order_by, ModelRecordOrderBy)
|
assert isinstance(order_by, ModelRecordOrderBy)
|
||||||
ordering = {
|
ordering = {
|
||||||
ModelRecordOrderBy.Default: "type, base, name, format",
|
ModelRecordOrderBy.Default: "type, base, format, name",
|
||||||
ModelRecordOrderBy.Type: "type",
|
ModelRecordOrderBy.Type: "type",
|
||||||
ModelRecordOrderBy.Base: "base",
|
ModelRecordOrderBy.Base: "base",
|
||||||
ModelRecordOrderBy.Name: "name",
|
ModelRecordOrderBy.Name: "name",
|
||||||
|
@ -1,6 +1,35 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from threading import Event
|
||||||
|
|
||||||
|
from invokeai.app.services.invocation_services import InvocationServices
|
||||||
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
|
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
|
||||||
|
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
|
||||||
|
|
||||||
|
|
||||||
|
class SessionRunnerBase(ABC):
|
||||||
|
"""
|
||||||
|
Base class for session runner.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def start(self, services: InvocationServices, cancel_event: Event) -> None:
|
||||||
|
"""Starts the session runner"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def run(self, queue_item: SessionQueueItem) -> None:
|
||||||
|
"""Runs the session"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def complete(self, queue_item: SessionQueueItem) -> None:
|
||||||
|
"""Completes the session"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def run_node(self, node_id: str, queue_item: SessionQueueItem) -> None:
|
||||||
|
"""Runs an already prepared node on the session"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class SessionProcessorBase(ABC):
|
class SessionProcessorBase(ABC):
|
||||||
|
@ -2,13 +2,14 @@ import traceback
|
|||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
from threading import BoundedSemaphore, Thread
|
from threading import BoundedSemaphore, Thread
|
||||||
from threading import Event as ThreadEvent
|
from threading import Event as ThreadEvent
|
||||||
from typing import Optional
|
from typing import Callable, Optional, Union
|
||||||
|
|
||||||
from fastapi_events.handlers.local import local_handler
|
from fastapi_events.handlers.local import local_handler
|
||||||
from fastapi_events.typing import Event as FastAPIEvent
|
from fastapi_events.typing import Event as FastAPIEvent
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||||
from invokeai.app.services.events.events_base import EventServiceBase
|
from invokeai.app.services.events.events_base import EventServiceBase
|
||||||
|
from invokeai.app.services.invocation_services import InvocationServices
|
||||||
from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError
|
from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError
|
||||||
from invokeai.app.services.session_processor.session_processor_common import CanceledException
|
from invokeai.app.services.session_processor.session_processor_common import CanceledException
|
||||||
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
|
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
|
||||||
@ -16,15 +17,164 @@ from invokeai.app.services.shared.invocation_context import InvocationContextDat
|
|||||||
from invokeai.app.util.profiler import Profiler
|
from invokeai.app.util.profiler import Profiler
|
||||||
|
|
||||||
from ..invoker import Invoker
|
from ..invoker import Invoker
|
||||||
from .session_processor_base import SessionProcessorBase
|
from .session_processor_base import SessionProcessorBase, SessionRunnerBase
|
||||||
from .session_processor_common import SessionProcessorStatus
|
from .session_processor_common import SessionProcessorStatus
|
||||||
|
|
||||||
|
|
||||||
|
class DefaultSessionRunner(SessionRunnerBase):
|
||||||
|
"""Processes a single session's invocations"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
on_before_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
|
||||||
|
on_after_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
|
||||||
|
):
|
||||||
|
self.on_before_run_node = on_before_run_node
|
||||||
|
self.on_after_run_node = on_after_run_node
|
||||||
|
|
||||||
|
def start(self, services: InvocationServices, cancel_event: ThreadEvent):
|
||||||
|
"""Start the session runner"""
|
||||||
|
self.services = services
|
||||||
|
self.cancel_event = cancel_event
|
||||||
|
|
||||||
|
def run(self, queue_item: SessionQueueItem):
|
||||||
|
"""Run the graph"""
|
||||||
|
if not queue_item.session:
|
||||||
|
raise ValueError("Queue item has no session")
|
||||||
|
# Loop over invocations until the session is complete or canceled
|
||||||
|
while not (queue_item.session.is_complete() or self.cancel_event.is_set()):
|
||||||
|
# Prepare the next node
|
||||||
|
invocation = queue_item.session.next()
|
||||||
|
if invocation is None:
|
||||||
|
# If there are no more invocations, complete the graph
|
||||||
|
break
|
||||||
|
# Build invocation context (the node-facing API
|
||||||
|
self.run_node(invocation.id, queue_item)
|
||||||
|
self.complete(queue_item)
|
||||||
|
|
||||||
|
def complete(self, queue_item: SessionQueueItem):
|
||||||
|
"""Complete the graph"""
|
||||||
|
self.services.events.emit_graph_execution_complete(
|
||||||
|
queue_batch_id=queue_item.batch_id,
|
||||||
|
queue_item_id=queue_item.item_id,
|
||||||
|
queue_id=queue_item.queue_id,
|
||||||
|
graph_execution_state_id=queue_item.session.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
|
||||||
|
"""Run before a node is executed"""
|
||||||
|
# Send starting event
|
||||||
|
self.services.events.emit_invocation_started(
|
||||||
|
queue_batch_id=queue_item.batch_id,
|
||||||
|
queue_item_id=queue_item.item_id,
|
||||||
|
queue_id=queue_item.queue_id,
|
||||||
|
graph_execution_state_id=queue_item.session_id,
|
||||||
|
node=invocation.model_dump(),
|
||||||
|
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||||
|
)
|
||||||
|
if self.on_before_run_node is not None:
|
||||||
|
self.on_before_run_node(invocation, queue_item)
|
||||||
|
|
||||||
|
def _on_after_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
|
||||||
|
"""Run after a node is executed"""
|
||||||
|
if self.on_after_run_node is not None:
|
||||||
|
self.on_after_run_node(invocation, queue_item)
|
||||||
|
|
||||||
|
def run_node(self, node_id: str, queue_item: SessionQueueItem):
|
||||||
|
"""Run a single node in the graph"""
|
||||||
|
# If this error raises a NodeNotFoundError that's handled by the processor
|
||||||
|
invocation = queue_item.session.execution_graph.get_node(node_id)
|
||||||
|
try:
|
||||||
|
self._on_before_run_node(invocation, queue_item)
|
||||||
|
data = InvocationContextData(
|
||||||
|
invocation=invocation,
|
||||||
|
source_invocation_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||||
|
queue_item=queue_item,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
|
||||||
|
with self.services.performance_statistics.collect_stats(invocation, queue_item.session_id):
|
||||||
|
context = build_invocation_context(
|
||||||
|
data=data,
|
||||||
|
services=self.services,
|
||||||
|
cancel_event=self.cancel_event,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Invoke the node
|
||||||
|
outputs = invocation.invoke_internal(context=context, services=self.services)
|
||||||
|
|
||||||
|
# Save outputs and history
|
||||||
|
queue_item.session.complete(invocation.id, outputs)
|
||||||
|
|
||||||
|
self._on_after_run_node(invocation, queue_item)
|
||||||
|
# Send complete event on successful runs
|
||||||
|
self.services.events.emit_invocation_complete(
|
||||||
|
queue_batch_id=queue_item.batch_id,
|
||||||
|
queue_item_id=queue_item.item_id,
|
||||||
|
queue_id=queue_item.queue_id,
|
||||||
|
graph_execution_state_id=queue_item.session.id,
|
||||||
|
node=invocation.model_dump(),
|
||||||
|
source_node_id=data.source_invocation_id,
|
||||||
|
result=outputs.model_dump(),
|
||||||
|
)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
# TODO(MM2): Create an event for this
|
||||||
|
pass
|
||||||
|
except CanceledException:
|
||||||
|
# When the user cancels the graph, we first set the cancel event. The event is checked
|
||||||
|
# between invocations, in this loop. Some invocations are long-running, and we need to
|
||||||
|
# be able to cancel them mid-execution.
|
||||||
|
#
|
||||||
|
# For example, denoising is a long-running invocation with many steps. A step callback
|
||||||
|
# is executed after each step. This step callback checks if the canceled event is set,
|
||||||
|
# then raises a CanceledException to stop execution immediately.
|
||||||
|
#
|
||||||
|
# When we get a CanceledException, we don't need to do anything - just pass and let the
|
||||||
|
# loop go to its next iteration, and the cancel event will be handled correctly.
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
error = traceback.format_exc()
|
||||||
|
|
||||||
|
# Save error
|
||||||
|
queue_item.session.set_node_error(invocation.id, error)
|
||||||
|
self.services.logger.error(
|
||||||
|
f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}):\n{e}"
|
||||||
|
)
|
||||||
|
self.services.logger.error(error)
|
||||||
|
|
||||||
|
# Send error event
|
||||||
|
self.services.events.emit_invocation_error(
|
||||||
|
queue_batch_id=queue_item.session_id,
|
||||||
|
queue_item_id=queue_item.item_id,
|
||||||
|
queue_id=queue_item.queue_id,
|
||||||
|
graph_execution_state_id=queue_item.session.id,
|
||||||
|
node=invocation.model_dump(),
|
||||||
|
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||||
|
error_type=e.__class__.__name__,
|
||||||
|
error=error,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class DefaultSessionProcessor(SessionProcessorBase):
|
class DefaultSessionProcessor(SessionProcessorBase):
|
||||||
def start(self, invoker: Invoker, thread_limit: int = 1, polling_interval: int = 1) -> None:
|
"""Processes sessions from the session queue"""
|
||||||
|
|
||||||
|
def __init__(self, session_runner: Union[SessionRunnerBase, None] = None) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.session_runner = session_runner if session_runner else DefaultSessionRunner()
|
||||||
|
|
||||||
|
def start(
|
||||||
|
self,
|
||||||
|
invoker: Invoker,
|
||||||
|
thread_limit: int = 1,
|
||||||
|
polling_interval: int = 1,
|
||||||
|
on_before_run_session: Union[Callable[[SessionQueueItem], bool], None] = None,
|
||||||
|
on_after_run_session: Union[Callable[[SessionQueueItem], bool], None] = None,
|
||||||
|
) -> None:
|
||||||
self._invoker: Invoker = invoker
|
self._invoker: Invoker = invoker
|
||||||
self._queue_item: Optional[SessionQueueItem] = None
|
self._queue_item: Optional[SessionQueueItem] = None
|
||||||
self._invocation: Optional[BaseInvocation] = None
|
self._invocation: Optional[BaseInvocation] = None
|
||||||
|
self.on_before_run_session = on_before_run_session
|
||||||
|
self.on_after_run_session = on_after_run_session
|
||||||
|
|
||||||
self._resume_event = ThreadEvent()
|
self._resume_event = ThreadEvent()
|
||||||
self._stop_event = ThreadEvent()
|
self._stop_event = ThreadEvent()
|
||||||
@ -59,6 +209,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
|||||||
"cancel_event": self._cancel_event,
|
"cancel_event": self._cancel_event,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
self.session_runner.start(services=invoker.services, cancel_event=self._cancel_event)
|
||||||
self._thread.start()
|
self._thread.start()
|
||||||
|
|
||||||
def stop(self, *args, **kwargs) -> None:
|
def stop(self, *args, **kwargs) -> None:
|
||||||
@ -117,131 +268,34 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
|||||||
self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
|
self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
|
||||||
cancel_event.clear()
|
cancel_event.clear()
|
||||||
|
|
||||||
|
# If we have a on_before_run_session callback, call it
|
||||||
|
if self.on_before_run_session is not None:
|
||||||
|
self.on_before_run_session(self._queue_item)
|
||||||
|
|
||||||
# If profiling is enabled, start the profiler
|
# If profiling is enabled, start the profiler
|
||||||
if self._profiler is not None:
|
if self._profiler is not None:
|
||||||
self._profiler.start(profile_id=self._queue_item.session_id)
|
self._profiler.start(profile_id=self._queue_item.session_id)
|
||||||
|
|
||||||
# Prepare invocations and take the first
|
# Run the graph
|
||||||
self._invocation = self._queue_item.session.next()
|
self.session_runner.run(queue_item=self._queue_item)
|
||||||
|
|
||||||
# Loop over invocations until the session is complete or canceled
|
# If we are profiling, stop the profiler and dump the profile & stats
|
||||||
while self._invocation is not None and not cancel_event.is_set():
|
if self._profiler:
|
||||||
# get the source node id to provide to clients (the prepared node id is not as useful)
|
profile_path = self._profiler.stop()
|
||||||
source_invocation_id = self._queue_item.session.prepared_source_mapping[self._invocation.id]
|
stats_path = profile_path.with_suffix(".json")
|
||||||
|
self._invoker.services.performance_statistics.dump_stats(
|
||||||
# Send starting event
|
graph_execution_state_id=self._queue_item.session.id, output_path=stats_path
|
||||||
self._invoker.services.events.emit_invocation_started(
|
|
||||||
queue_batch_id=self._queue_item.batch_id,
|
|
||||||
queue_item_id=self._queue_item.item_id,
|
|
||||||
queue_id=self._queue_item.queue_id,
|
|
||||||
graph_execution_state_id=self._queue_item.session_id,
|
|
||||||
node=self._invocation.model_dump(),
|
|
||||||
source_node_id=source_invocation_id,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
|
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
|
||||||
try:
|
# we don't care about that - suppress the error.
|
||||||
with self._invoker.services.performance_statistics.collect_stats(
|
with suppress(GESStatsNotFoundError):
|
||||||
self._invocation, self._queue_item.session.id
|
self._invoker.services.performance_statistics.log_stats(self._queue_item.session.id)
|
||||||
):
|
self._invoker.services.performance_statistics.reset_stats()
|
||||||
# Build invocation context (the node-facing API)
|
|
||||||
data = InvocationContextData(
|
|
||||||
invocation=self._invocation,
|
|
||||||
source_invocation_id=source_invocation_id,
|
|
||||||
queue_item=self._queue_item,
|
|
||||||
)
|
|
||||||
context = build_invocation_context(
|
|
||||||
data=data,
|
|
||||||
services=self._invoker.services,
|
|
||||||
cancel_event=self._cancel_event,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Invoke the node
|
# If we have a on_after_run_session callback, call it
|
||||||
outputs = self._invocation.invoke_internal(
|
if self.on_after_run_session is not None:
|
||||||
context=context, services=self._invoker.services
|
self.on_after_run_session(self._queue_item)
|
||||||
)
|
|
||||||
|
|
||||||
# Save outputs and history
|
|
||||||
self._queue_item.session.complete(self._invocation.id, outputs)
|
|
||||||
|
|
||||||
# Send complete event
|
|
||||||
self._invoker.services.events.emit_invocation_complete(
|
|
||||||
queue_batch_id=self._queue_item.batch_id,
|
|
||||||
queue_item_id=self._queue_item.item_id,
|
|
||||||
queue_id=self._queue_item.queue_id,
|
|
||||||
graph_execution_state_id=self._queue_item.session.id,
|
|
||||||
node=self._invocation.model_dump(),
|
|
||||||
source_node_id=source_invocation_id,
|
|
||||||
result=outputs.model_dump(),
|
|
||||||
)
|
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
# TODO(MM2): Create an event for this
|
|
||||||
pass
|
|
||||||
|
|
||||||
except CanceledException:
|
|
||||||
# When the user cancels the graph, we first set the cancel event. The event is checked
|
|
||||||
# between invocations, in this loop. Some invocations are long-running, and we need to
|
|
||||||
# be able to cancel them mid-execution.
|
|
||||||
#
|
|
||||||
# For example, denoising is a long-running invocation with many steps. A step callback
|
|
||||||
# is executed after each step. This step callback checks if the canceled event is set,
|
|
||||||
# then raises a CanceledException to stop execution immediately.
|
|
||||||
#
|
|
||||||
# When we get a CanceledException, we don't need to do anything - just pass and let the
|
|
||||||
# loop go to its next iteration, and the cancel event will be handled correctly.
|
|
||||||
pass
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
error = traceback.format_exc()
|
|
||||||
|
|
||||||
# Save error
|
|
||||||
self._queue_item.session.set_node_error(self._invocation.id, error)
|
|
||||||
self._invoker.services.logger.error(
|
|
||||||
f"Error while invoking session {self._queue_item.session_id}, invocation {self._invocation.id} ({self._invocation.get_type()}):\n{e}"
|
|
||||||
)
|
|
||||||
self._invoker.services.logger.error(error)
|
|
||||||
|
|
||||||
# Send error event
|
|
||||||
self._invoker.services.events.emit_invocation_error(
|
|
||||||
queue_batch_id=self._queue_item.session_id,
|
|
||||||
queue_item_id=self._queue_item.item_id,
|
|
||||||
queue_id=self._queue_item.queue_id,
|
|
||||||
graph_execution_state_id=self._queue_item.session.id,
|
|
||||||
node=self._invocation.model_dump(),
|
|
||||||
source_node_id=source_invocation_id,
|
|
||||||
error_type=e.__class__.__name__,
|
|
||||||
error=error,
|
|
||||||
)
|
|
||||||
pass
|
|
||||||
|
|
||||||
# The session is complete if the all invocations are complete or there was an error
|
|
||||||
if self._queue_item.session.is_complete() or cancel_event.is_set():
|
|
||||||
# Send complete event
|
|
||||||
self._invoker.services.events.emit_graph_execution_complete(
|
|
||||||
queue_batch_id=self._queue_item.batch_id,
|
|
||||||
queue_item_id=self._queue_item.item_id,
|
|
||||||
queue_id=self._queue_item.queue_id,
|
|
||||||
graph_execution_state_id=self._queue_item.session.id,
|
|
||||||
)
|
|
||||||
# If we are profiling, stop the profiler and dump the profile & stats
|
|
||||||
if self._profiler:
|
|
||||||
profile_path = self._profiler.stop()
|
|
||||||
stats_path = profile_path.with_suffix(".json")
|
|
||||||
self._invoker.services.performance_statistics.dump_stats(
|
|
||||||
graph_execution_state_id=self._queue_item.session.id, output_path=stats_path
|
|
||||||
)
|
|
||||||
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
|
|
||||||
# we don't care about that - suppress the error.
|
|
||||||
with suppress(GESStatsNotFoundError):
|
|
||||||
self._invoker.services.performance_statistics.log_stats(self._queue_item.session.id)
|
|
||||||
self._invoker.services.performance_statistics.reset_stats()
|
|
||||||
|
|
||||||
# Set the invocation to None to prepare for the next session
|
|
||||||
self._invocation = None
|
|
||||||
else:
|
|
||||||
# Prepare the next invocation
|
|
||||||
self._invocation = self._queue_item.session.next()
|
|
||||||
|
|
||||||
# The session is complete, immediately poll for next session
|
# The session is complete, immediately poll for next session
|
||||||
self._queue_item = None
|
self._queue_item = None
|
||||||
@ -275,3 +329,4 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
|||||||
poll_now_event.clear()
|
poll_now_event.clear()
|
||||||
self._queue_item = None
|
self._queue_item = None
|
||||||
self._thread_semaphore.release()
|
self._thread_semaphore.release()
|
||||||
|
self._invoker.services.logger.debug("Session processor stopped")
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import threading
|
import threading
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Optional, Union
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
from PIL.Image import Image
|
from PIL.Image import Image
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
@ -13,16 +13,15 @@ from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
|||||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||||
from invokeai.app.services.images.images_common import ImageDTO
|
from invokeai.app.services.images.images_common import ImageDTO
|
||||||
from invokeai.app.services.invocation_services import InvocationServices
|
from invokeai.app.services.invocation_services import InvocationServices
|
||||||
from invokeai.app.services.model_records.model_records_base import UnknownModelException
|
|
||||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||||
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
|
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
|
||||||
from invokeai.backend.model_manager.load.load_base import LoadedModel
|
from invokeai.backend.model_manager.load.load_base import LoadedModel
|
||||||
|
from invokeai.backend.model_manager.metadata.metadata_base import AnyModelRepoMetadata
|
||||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||||
from invokeai.app.invocations.model import ModelIdentifierField
|
|
||||||
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
|
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@ -300,27 +299,22 @@ class ConditioningInterface(InvocationContextInterface):
|
|||||||
|
|
||||||
|
|
||||||
class ModelsInterface(InvocationContextInterface):
|
class ModelsInterface(InvocationContextInterface):
|
||||||
def exists(self, identifier: Union[str, "ModelIdentifierField"]) -> bool:
|
def exists(self, key: str) -> bool:
|
||||||
"""Checks if a model exists.
|
"""Checks if a model exists.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
identifier: The key or ModelField representing the model.
|
key: The key of the model.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if the model exists, False if not.
|
True if the model exists, False if not.
|
||||||
"""
|
"""
|
||||||
if isinstance(identifier, str):
|
return self._services.model_manager.store.exists(key)
|
||||||
return self._services.model_manager.store.exists(identifier)
|
|
||||||
|
|
||||||
return self._services.model_manager.store.exists(identifier.key)
|
def load(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||||
|
|
||||||
def load(
|
|
||||||
self, identifier: Union[str, "ModelIdentifierField"], submodel_type: Optional[SubModelType] = None
|
|
||||||
) -> LoadedModel:
|
|
||||||
"""Loads a model.
|
"""Loads a model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
identifier: The key or ModelField representing the model.
|
key: The key of the model.
|
||||||
submodel_type: The submodel of the model to get.
|
submodel_type: The submodel of the model to get.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -330,13 +324,9 @@ class ModelsInterface(InvocationContextInterface):
|
|||||||
# The model manager emits events as it loads the model. It needs the context data to build
|
# The model manager emits events as it loads the model. It needs the context data to build
|
||||||
# the event payloads.
|
# the event payloads.
|
||||||
|
|
||||||
if isinstance(identifier, str):
|
return self._services.model_manager.load_model_by_key(
|
||||||
model = self._services.model_manager.store.get_model(identifier)
|
key=key, submodel_type=submodel_type, context_data=self._data
|
||||||
return self._services.model_manager.load.load_model(model, submodel_type, self._data)
|
)
|
||||||
else:
|
|
||||||
_submodel_type = submodel_type or identifier.submodel_type
|
|
||||||
model = self._services.model_manager.store.get_model(identifier.key)
|
|
||||||
return self._services.model_manager.load.load_model(model, _submodel_type, self._data)
|
|
||||||
|
|
||||||
def load_by_attrs(
|
def load_by_attrs(
|
||||||
self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
|
self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
|
||||||
@ -353,29 +343,35 @@ class ModelsInterface(InvocationContextInterface):
|
|||||||
Returns:
|
Returns:
|
||||||
An object representing the loaded model.
|
An object representing the loaded model.
|
||||||
"""
|
"""
|
||||||
|
return self._services.model_manager.load_model_by_attr(
|
||||||
|
model_name=name,
|
||||||
|
base_model=base,
|
||||||
|
model_type=type,
|
||||||
|
submodel=submodel_type,
|
||||||
|
context_data=self._data,
|
||||||
|
)
|
||||||
|
|
||||||
configs = self._services.model_manager.store.search_by_attr(model_name=name, base_model=base, model_type=type)
|
def get_config(self, key: str) -> AnyModelConfig:
|
||||||
if len(configs) == 0:
|
|
||||||
raise UnknownModelException(f"No model found with name {name}, base {base}, and type {type}")
|
|
||||||
|
|
||||||
if len(configs) > 1:
|
|
||||||
raise ValueError(f"More than one model found with name {name}, base {base}, and type {type}")
|
|
||||||
|
|
||||||
return self._services.model_manager.load.load_model(configs[0], submodel_type, self._data)
|
|
||||||
|
|
||||||
def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig:
|
|
||||||
"""Gets a model's config.
|
"""Gets a model's config.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
identifier: The key or ModelField representing the model.
|
key: The key of the model.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The model's config.
|
The model's config.
|
||||||
"""
|
"""
|
||||||
if isinstance(identifier, str):
|
return self._services.model_manager.store.get_model(key=key)
|
||||||
return self._services.model_manager.store.get_model(identifier)
|
|
||||||
|
|
||||||
return self._services.model_manager.store.get_model(identifier.key)
|
def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
|
||||||
|
"""Gets a model's metadata, if it has any.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: The key of the model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The model's metadata, if it has any.
|
||||||
|
"""
|
||||||
|
return self._services.model_manager.store.get_metadata(key=key)
|
||||||
|
|
||||||
def search_by_path(self, path: Path) -> list[AnyModelConfig]:
|
def search_by_path(self, path: Path) -> list[AnyModelConfig]:
|
||||||
"""Searches for models by path.
|
"""Searches for models by path.
|
||||||
|
@ -4,6 +4,8 @@ from logging import Logger
|
|||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
|
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
|
||||||
|
|
||||||
|
from .util.migrate_yaml_config_1 import MigrateModelYamlToDb1
|
||||||
|
|
||||||
|
|
||||||
class Migration3Callback:
|
class Migration3Callback:
|
||||||
def __init__(self, app_config: InvokeAIAppConfig, logger: Logger) -> None:
|
def __init__(self, app_config: InvokeAIAppConfig, logger: Logger) -> None:
|
||||||
@ -13,6 +15,7 @@ class Migration3Callback:
|
|||||||
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
||||||
self._drop_model_manager_metadata(cursor)
|
self._drop_model_manager_metadata(cursor)
|
||||||
self._recreate_model_config(cursor)
|
self._recreate_model_config(cursor)
|
||||||
|
self._migrate_model_config_records(cursor)
|
||||||
|
|
||||||
def _drop_model_manager_metadata(self, cursor: sqlite3.Cursor) -> None:
|
def _drop_model_manager_metadata(self, cursor: sqlite3.Cursor) -> None:
|
||||||
"""Drops the `model_manager_metadata` table."""
|
"""Drops the `model_manager_metadata` table."""
|
||||||
@ -52,6 +55,12 @@ class Migration3Callback:
|
|||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _migrate_model_config_records(self, cursor: sqlite3.Cursor) -> None:
|
||||||
|
"""After updating the model config table, we repopulate it."""
|
||||||
|
self._logger.info("Migrating model config records from models.yaml to database")
|
||||||
|
model_record_migrator = MigrateModelYamlToDb1(self._app_config, self._logger, cursor)
|
||||||
|
model_record_migrator.migrate()
|
||||||
|
|
||||||
|
|
||||||
def build_migration_3(app_config: InvokeAIAppConfig, logger: Logger) -> Migration:
|
def build_migration_3(app_config: InvokeAIAppConfig, logger: Logger) -> Migration:
|
||||||
"""
|
"""
|
||||||
|
@ -0,0 +1,163 @@
|
|||||||
|
# Copyright (c) 2023 Lincoln D. Stein
|
||||||
|
"""Migrate from the InvokeAI v2 models.yaml format to the v3 sqlite format."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import sqlite3
|
||||||
|
from logging import Logger
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from omegaconf import DictConfig, OmegaConf
|
||||||
|
from pydantic import TypeAdapter
|
||||||
|
|
||||||
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
from invokeai.app.services.model_records import (
|
||||||
|
DuplicateModelException,
|
||||||
|
UnknownModelException,
|
||||||
|
)
|
||||||
|
from invokeai.backend.model_manager.config import (
|
||||||
|
AnyModelConfig,
|
||||||
|
BaseModelType,
|
||||||
|
ModelConfigFactory,
|
||||||
|
ModelType,
|
||||||
|
)
|
||||||
|
from invokeai.backend.model_manager.hash import ModelHash
|
||||||
|
|
||||||
|
ModelsValidator = TypeAdapter(AnyModelConfig)
|
||||||
|
|
||||||
|
|
||||||
|
class MigrateModelYamlToDb1:
|
||||||
|
"""
|
||||||
|
Migrate the InvokeAI models.yaml format (VERSION 3.0.0) to SQL3 database format (VERSION 3.5.0).
|
||||||
|
|
||||||
|
The class has one externally useful method, migrate(), which scans the
|
||||||
|
currently models.yaml file and imports all its entries into invokeai.db.
|
||||||
|
|
||||||
|
Use this way:
|
||||||
|
|
||||||
|
from invokeai.backend.model_manager/migrate_to_db import MigrateModelYamlToDb
|
||||||
|
MigrateModelYamlToDb().migrate()
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
config: InvokeAIAppConfig
|
||||||
|
logger: Logger
|
||||||
|
cursor: sqlite3.Cursor
|
||||||
|
|
||||||
|
def __init__(self, config: InvokeAIAppConfig, logger: Logger, cursor: sqlite3.Cursor = None) -> None:
|
||||||
|
self.config = config
|
||||||
|
self.logger = logger
|
||||||
|
self.cursor = cursor
|
||||||
|
|
||||||
|
def get_yaml(self) -> DictConfig:
|
||||||
|
"""Fetch the models.yaml DictConfig for this installation."""
|
||||||
|
yaml_path = self.config.model_conf_path
|
||||||
|
omegaconf = OmegaConf.load(yaml_path)
|
||||||
|
assert isinstance(omegaconf, DictConfig)
|
||||||
|
return omegaconf
|
||||||
|
|
||||||
|
def migrate(self) -> None:
|
||||||
|
"""Do the migration from models.yaml to invokeai.db."""
|
||||||
|
try:
|
||||||
|
yaml = self.get_yaml()
|
||||||
|
except OSError:
|
||||||
|
return
|
||||||
|
|
||||||
|
for model_key, stanza in yaml.items():
|
||||||
|
if model_key == "__metadata__":
|
||||||
|
assert (
|
||||||
|
stanza["version"] == "3.0.0"
|
||||||
|
), f"This script works on version 3.0.0 yaml files, but your configuration points to a {stanza['version']} version"
|
||||||
|
continue
|
||||||
|
|
||||||
|
base_type, model_type, model_name = str(model_key).split("/")
|
||||||
|
try:
|
||||||
|
hash = ModelHash().hash(self.config.models_path / stanza.path)
|
||||||
|
except OSError:
|
||||||
|
self.logger.warning(f"The model at {stanza.path} is not a valid file or directory. Skipping migration.")
|
||||||
|
continue
|
||||||
|
|
||||||
|
stanza["base"] = BaseModelType(base_type)
|
||||||
|
stanza["type"] = ModelType(model_type)
|
||||||
|
stanza["name"] = model_name
|
||||||
|
stanza["original_hash"] = hash
|
||||||
|
stanza["current_hash"] = hash
|
||||||
|
new_key = hash # deterministic key assignment
|
||||||
|
|
||||||
|
# special case for ip adapters, which need the new `image_encoder_model_id` field
|
||||||
|
if stanza["type"] == ModelType.IPAdapter:
|
||||||
|
try:
|
||||||
|
stanza["image_encoder_model_id"] = self._get_image_encoder_model_id(
|
||||||
|
self.config.models_path / stanza.path
|
||||||
|
)
|
||||||
|
except OSError:
|
||||||
|
self.logger.warning(f"Could not determine image encoder for {stanza.path}. Skipping.")
|
||||||
|
continue
|
||||||
|
|
||||||
|
new_config: AnyModelConfig = ModelsValidator.validate_python(stanza) # type: ignore # see https://github.com/pydantic/pydantic/discussions/7094
|
||||||
|
|
||||||
|
try:
|
||||||
|
if original_record := self._search_by_path(stanza.path):
|
||||||
|
key = original_record.key
|
||||||
|
self.logger.info(f"Updating model {model_name} with information from models.yaml using key {key}")
|
||||||
|
self._update_model(key, new_config)
|
||||||
|
else:
|
||||||
|
self.logger.info(f"Adding model {model_name} with key {new_key}")
|
||||||
|
self._add_model(new_key, new_config)
|
||||||
|
except DuplicateModelException:
|
||||||
|
self.logger.warning(f"Model {model_name} is already in the database")
|
||||||
|
except UnknownModelException:
|
||||||
|
self.logger.warning(f"Model at {stanza.path} could not be found in database")
|
||||||
|
|
||||||
|
def _search_by_path(self, path: Path) -> Optional[AnyModelConfig]:
|
||||||
|
self.cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
SELECT config FROM model_config
|
||||||
|
WHERE path=?;
|
||||||
|
""",
|
||||||
|
(str(path),),
|
||||||
|
)
|
||||||
|
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self.cursor.fetchall()]
|
||||||
|
return results[0] if results else None
|
||||||
|
|
||||||
|
def _update_model(self, key: str, config: AnyModelConfig) -> None:
|
||||||
|
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect
|
||||||
|
json_serialized = record.model_dump_json() # and turn it into a json string.
|
||||||
|
self.cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
UPDATE model_config
|
||||||
|
SET
|
||||||
|
config=?
|
||||||
|
WHERE id=?;
|
||||||
|
""",
|
||||||
|
(json_serialized, key),
|
||||||
|
)
|
||||||
|
if self.cursor.rowcount == 0:
|
||||||
|
raise UnknownModelException("model not found")
|
||||||
|
|
||||||
|
def _add_model(self, key: str, config: AnyModelConfig) -> None:
|
||||||
|
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect.
|
||||||
|
json_serialized = record.model_dump_json() # and turn it into a json string.
|
||||||
|
try:
|
||||||
|
self.cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
INSERT INTO model_config (
|
||||||
|
id,
|
||||||
|
original_hash,
|
||||||
|
config
|
||||||
|
)
|
||||||
|
VALUES (?,?,?);
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
key,
|
||||||
|
record.hash,
|
||||||
|
json_serialized,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
except sqlite3.IntegrityError as exc:
|
||||||
|
raise DuplicateModelException(f"{record.name}: model is already in database") from exc
|
||||||
|
|
||||||
|
def _get_image_encoder_model_id(self, model_path: Path) -> str:
|
||||||
|
with open(model_path / "image_encoder.txt") as f:
|
||||||
|
encoder = f.read()
|
||||||
|
return encoder.strip()
|
@ -17,7 +17,8 @@ class MigrateCallback(Protocol):
|
|||||||
See :class:`Migration` for an example.
|
See :class:`Migration` for an example.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __call__(self, cursor: sqlite3.Cursor) -> None: ...
|
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
class MigrationError(RuntimeError):
|
class MigrationError(RuntimeError):
|
||||||
|
@ -8,8 +8,3 @@ class UrlServiceBase(ABC):
|
|||||||
def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
|
def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
|
||||||
"""Gets the URL for an image or thumbnail."""
|
"""Gets the URL for an image or thumbnail."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_model_image_url(self, model_key: str) -> str:
|
|
||||||
"""Gets the URL for a model image"""
|
|
||||||
pass
|
|
||||||
|
@ -4,9 +4,8 @@ from .urls_base import UrlServiceBase
|
|||||||
|
|
||||||
|
|
||||||
class LocalUrlService(UrlServiceBase):
|
class LocalUrlService(UrlServiceBase):
|
||||||
def __init__(self, base_url: str = "api/v1", base_url_v2: str = "api/v2"):
|
def __init__(self, base_url: str = "api/v1"):
|
||||||
self._base_url = base_url
|
self._base_url = base_url
|
||||||
self._base_url_v2 = base_url_v2
|
|
||||||
|
|
||||||
def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
|
def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
|
||||||
image_basename = os.path.basename(image_name)
|
image_basename = os.path.basename(image_name)
|
||||||
@ -16,6 +15,3 @@ class LocalUrlService(UrlServiceBase):
|
|||||||
return f"{self._base_url}/images/i/{image_basename}/thumbnail"
|
return f"{self._base_url}/images/i/{image_basename}/thumbnail"
|
||||||
|
|
||||||
return f"{self._base_url}/images/i/{image_basename}/full"
|
return f"{self._base_url}/images/i/{image_basename}/full"
|
||||||
|
|
||||||
def get_model_image_url(self, model_key: str) -> str:
|
|
||||||
return f"{self._base_url_v2}/models/i/{model_key}/image"
|
|
||||||
|
@ -22,7 +22,7 @@ def generate_ti_list(
|
|||||||
for trigger in extract_ti_triggers_from_prompt(prompt):
|
for trigger in extract_ti_triggers_from_prompt(prompt):
|
||||||
name_or_key = trigger[1:-1]
|
name_or_key = trigger[1:-1]
|
||||||
try:
|
try:
|
||||||
loaded_model = context.models.load(name_or_key)
|
loaded_model = context.models.load(key=name_or_key)
|
||||||
model = loaded_model.model
|
model = loaded_model.model
|
||||||
assert isinstance(model, TextualInversionModelRaw)
|
assert isinstance(model, TextualInversionModelRaw)
|
||||||
assert loaded_model.config.base == base
|
assert loaded_model.config.base == base
|
||||||
|
@ -19,6 +19,7 @@ from invokeai.app.services.model_install import (
|
|||||||
ModelInstallService,
|
ModelInstallService,
|
||||||
ModelInstallServiceBase,
|
ModelInstallServiceBase,
|
||||||
)
|
)
|
||||||
|
from invokeai.app.services.model_metadata import ModelMetadataStoreSQL
|
||||||
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
|
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
|
||||||
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
|
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
|
||||||
from invokeai.backend.model_manager import (
|
from invokeai.backend.model_manager import (
|
||||||
@ -38,7 +39,7 @@ def initialize_record_store(app_config: InvokeAIAppConfig) -> ModelRecordService
|
|||||||
logger = InvokeAILogger.get_logger(config=app_config)
|
logger = InvokeAILogger.get_logger(config=app_config)
|
||||||
image_files = DiskImageFileStorage(f"{app_config.output_path}/images")
|
image_files = DiskImageFileStorage(f"{app_config.output_path}/images")
|
||||||
db = init_db(config=app_config, logger=logger, image_files=image_files)
|
db = init_db(config=app_config, logger=logger, image_files=image_files)
|
||||||
obj: ModelRecordServiceBase = ModelRecordServiceSQL(db)
|
obj: ModelRecordServiceBase = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
|
||||||
|
@ -17,7 +17,7 @@ import warnings
|
|||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from shutil import copy, get_terminal_size, move
|
from shutil import get_terminal_size
|
||||||
from typing import Any, Optional, Set, Tuple, Type, get_args, get_type_hints
|
from typing import Any, Optional, Set, Tuple, Type, get_args, get_type_hints
|
||||||
from urllib import request
|
from urllib import request
|
||||||
|
|
||||||
@ -929,10 +929,6 @@ def main() -> None:
|
|||||||
|
|
||||||
errors = set()
|
errors = set()
|
||||||
FORCE_FULL_PRECISION = opt.full_precision # FIXME global
|
FORCE_FULL_PRECISION = opt.full_precision # FIXME global
|
||||||
new_init_file = config.root_path / "invokeai.yaml"
|
|
||||||
backup_init_file = new_init_file.with_suffix(".bak")
|
|
||||||
if new_init_file.exists():
|
|
||||||
copy(new_init_file, backup_init_file)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# if we do a root migration/upgrade, then we are keeping previous
|
# if we do a root migration/upgrade, then we are keeping previous
|
||||||
@ -947,6 +943,7 @@ def main() -> None:
|
|||||||
install_helper = InstallHelper(config, logger)
|
install_helper = InstallHelper(config, logger)
|
||||||
|
|
||||||
models_to_download = default_user_selections(opt, install_helper)
|
models_to_download = default_user_selections(opt, install_helper)
|
||||||
|
new_init_file = config.root_path / "invokeai.yaml"
|
||||||
|
|
||||||
if opt.yes_to_all:
|
if opt.yes_to_all:
|
||||||
write_default_options(opt, new_init_file)
|
write_default_options(opt, new_init_file)
|
||||||
@ -978,17 +975,8 @@ def main() -> None:
|
|||||||
input("Press any key to continue...")
|
input("Press any key to continue...")
|
||||||
except WindowTooSmallException as e:
|
except WindowTooSmallException as e:
|
||||||
logger.error(str(e))
|
logger.error(str(e))
|
||||||
if backup_init_file.exists():
|
|
||||||
move(backup_init_file, new_init_file)
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
print("\nGoodbye! Come back soon.")
|
print("\nGoodbye! Come back soon.")
|
||||||
if backup_init_file.exists():
|
|
||||||
move(backup_init_file, new_init_file)
|
|
||||||
except Exception:
|
|
||||||
print("An error occurred during installation.")
|
|
||||||
if backup_init_file.exists():
|
|
||||||
move(backup_init_file, new_init_file)
|
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
|
||||||
|
|
||||||
# -------------------------------------
|
# -------------------------------------
|
||||||
|
@ -22,7 +22,7 @@ Validation errors will raise an InvalidModelConfigException error.
|
|||||||
|
|
||||||
import time
|
import time
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Literal, Optional, Type, TypeAlias, Union
|
from typing import Literal, Optional, Type, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from diffusers.models.modeling_utils import ModelMixin
|
from diffusers.models.modeling_utils import ModelMixin
|
||||||
@ -129,9 +129,10 @@ class ModelSourceType(str, Enum):
|
|||||||
Path = "path"
|
Path = "path"
|
||||||
Url = "url"
|
Url = "url"
|
||||||
HFRepoID = "hf_repo_id"
|
HFRepoID = "hf_repo_id"
|
||||||
|
CivitAI = "civitai"
|
||||||
|
|
||||||
|
|
||||||
class MainModelDefaultSettings(BaseModel):
|
class ModelDefaultSettings(BaseModel):
|
||||||
vae: str | None
|
vae: str | None
|
||||||
vae_precision: str | None
|
vae_precision: str | None
|
||||||
scheduler: SCHEDULER_NAME_VALUES | None
|
scheduler: SCHEDULER_NAME_VALUES | None
|
||||||
@ -140,11 +141,6 @@ class MainModelDefaultSettings(BaseModel):
|
|||||||
cfg_rescale_multiplier: float | None
|
cfg_rescale_multiplier: float | None
|
||||||
|
|
||||||
|
|
||||||
class ControlAdapterDefaultSettings(BaseModel):
|
|
||||||
# This could be narrowed to controlnet processor nodes, but they change. Leaving this a string is safer.
|
|
||||||
preprocessor: str | None
|
|
||||||
|
|
||||||
|
|
||||||
class ModelConfigBase(BaseModel):
|
class ModelConfigBase(BaseModel):
|
||||||
"""Base class for model configuration information."""
|
"""Base class for model configuration information."""
|
||||||
|
|
||||||
@ -161,7 +157,10 @@ class ModelConfigBase(BaseModel):
|
|||||||
source_api_response: Optional[str] = Field(
|
source_api_response: Optional[str] = Field(
|
||||||
description="The original API response from the source, as stringified JSON.", default=None
|
description="The original API response from the source, as stringified JSON.", default=None
|
||||||
)
|
)
|
||||||
cover_image: Optional[str] = Field(description="Url for image to preview model", default=None)
|
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
|
||||||
|
default_settings: Optional[ModelDefaultSettings] = Field(
|
||||||
|
description="Default settings for this model", default=None
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
||||||
@ -187,14 +186,10 @@ class DiffusersConfigBase(ModelConfigBase):
|
|||||||
repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.Default
|
repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.Default
|
||||||
|
|
||||||
|
|
||||||
class LoRAConfigBase(ModelConfigBase):
|
class LoRALyCORISConfig(ModelConfigBase):
|
||||||
type: Literal[ModelType.LoRA] = ModelType.LoRA
|
|
||||||
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
|
|
||||||
|
|
||||||
|
|
||||||
class LoRALyCORISConfig(LoRAConfigBase):
|
|
||||||
"""Model config for LoRA/Lycoris models."""
|
"""Model config for LoRA/Lycoris models."""
|
||||||
|
|
||||||
|
type: Literal[ModelType.LoRA] = ModelType.LoRA
|
||||||
format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS
|
format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -202,9 +197,10 @@ class LoRALyCORISConfig(LoRAConfigBase):
|
|||||||
return Tag(f"{ModelType.LoRA.value}.{ModelFormat.LyCORIS.value}")
|
return Tag(f"{ModelType.LoRA.value}.{ModelFormat.LyCORIS.value}")
|
||||||
|
|
||||||
|
|
||||||
class LoRADiffusersConfig(LoRAConfigBase):
|
class LoRADiffusersConfig(ModelConfigBase):
|
||||||
"""Model config for LoRA/Diffusers models."""
|
"""Model config for LoRA/Diffusers models."""
|
||||||
|
|
||||||
|
type: Literal[ModelType.LoRA] = ModelType.LoRA
|
||||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -234,13 +230,7 @@ class VAEDiffusersConfig(ModelConfigBase):
|
|||||||
return Tag(f"{ModelType.VAE.value}.{ModelFormat.Diffusers.value}")
|
return Tag(f"{ModelType.VAE.value}.{ModelFormat.Diffusers.value}")
|
||||||
|
|
||||||
|
|
||||||
class ControlAdapterConfigBase(BaseModel):
|
class ControlNetDiffusersConfig(DiffusersConfigBase):
|
||||||
default_settings: Optional[ControlAdapterDefaultSettings] = Field(
|
|
||||||
description="Default settings for this model", default=None
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase):
|
|
||||||
"""Model config for ControlNet models (diffusers version)."""
|
"""Model config for ControlNet models (diffusers version)."""
|
||||||
|
|
||||||
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
|
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
|
||||||
@ -251,7 +241,7 @@ class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase):
|
|||||||
return Tag(f"{ModelType.ControlNet.value}.{ModelFormat.Diffusers.value}")
|
return Tag(f"{ModelType.ControlNet.value}.{ModelFormat.Diffusers.value}")
|
||||||
|
|
||||||
|
|
||||||
class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase):
|
class ControlNetCheckpointConfig(CheckpointConfigBase):
|
||||||
"""Model config for ControlNet models (diffusers version)."""
|
"""Model config for ControlNet models (diffusers version)."""
|
||||||
|
|
||||||
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
|
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
|
||||||
@ -284,17 +274,10 @@ class TextualInversionFolderConfig(ModelConfigBase):
|
|||||||
return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFolder.value}")
|
return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFolder.value}")
|
||||||
|
|
||||||
|
|
||||||
class MainConfigBase(ModelConfigBase):
|
class MainCheckpointConfig(CheckpointConfigBase):
|
||||||
type: Literal[ModelType.Main] = ModelType.Main
|
|
||||||
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
|
|
||||||
default_settings: Optional[MainModelDefaultSettings] = Field(
|
|
||||||
description="Default settings for this model", default=None
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase):
|
|
||||||
"""Model config for main checkpoint models."""
|
"""Model config for main checkpoint models."""
|
||||||
|
|
||||||
|
type: Literal[ModelType.Main] = ModelType.Main
|
||||||
variant: ModelVariantType = ModelVariantType.Normal
|
variant: ModelVariantType = ModelVariantType.Normal
|
||||||
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
|
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
|
||||||
upcast_attention: bool = False
|
upcast_attention: bool = False
|
||||||
@ -304,9 +287,11 @@ class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase):
|
|||||||
return Tag(f"{ModelType.Main.value}.{ModelFormat.Checkpoint.value}")
|
return Tag(f"{ModelType.Main.value}.{ModelFormat.Checkpoint.value}")
|
||||||
|
|
||||||
|
|
||||||
class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase):
|
class MainDiffusersConfig(DiffusersConfigBase):
|
||||||
"""Model config for main diffusers models."""
|
"""Model config for main diffusers models."""
|
||||||
|
|
||||||
|
type: Literal[ModelType.Main] = ModelType.Main
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_tag() -> Tag:
|
def get_tag() -> Tag:
|
||||||
return Tag(f"{ModelType.Main.value}.{ModelFormat.Diffusers.value}")
|
return Tag(f"{ModelType.Main.value}.{ModelFormat.Diffusers.value}")
|
||||||
@ -325,7 +310,7 @@ class IPAdapterConfig(ModelConfigBase):
|
|||||||
|
|
||||||
|
|
||||||
class CLIPVisionDiffusersConfig(ModelConfigBase):
|
class CLIPVisionDiffusersConfig(ModelConfigBase):
|
||||||
"""Model config for CLIPVision."""
|
"""Model config for ClipVision."""
|
||||||
|
|
||||||
type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision
|
type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision
|
||||||
format: Literal[ModelFormat.Diffusers]
|
format: Literal[ModelFormat.Diffusers]
|
||||||
@ -335,7 +320,7 @@ class CLIPVisionDiffusersConfig(ModelConfigBase):
|
|||||||
return Tag(f"{ModelType.CLIPVision.value}.{ModelFormat.Diffusers.value}")
|
return Tag(f"{ModelType.CLIPVision.value}.{ModelFormat.Diffusers.value}")
|
||||||
|
|
||||||
|
|
||||||
class T2IAdapterConfig(ModelConfigBase, ControlAdapterConfigBase):
|
class T2IAdapterConfig(ModelConfigBase):
|
||||||
"""Model config for T2I."""
|
"""Model config for T2I."""
|
||||||
|
|
||||||
type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter
|
type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter
|
||||||
@ -387,7 +372,6 @@ AnyModelConfig = Annotated[
|
|||||||
]
|
]
|
||||||
|
|
||||||
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
|
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
|
||||||
AnyDefaultSettings: TypeAlias = Union[MainModelDefaultSettings, ControlAdapterDefaultSettings]
|
|
||||||
|
|
||||||
|
|
||||||
class ModelConfigFactory(object):
|
class ModelConfigFactory(object):
|
||||||
|
@ -60,7 +60,7 @@ class ModelLoaderRegistryBase(ABC):
|
|||||||
TModelLoader = TypeVar("TModelLoader", bound=ModelLoaderBase)
|
TModelLoader = TypeVar("TModelLoader", bound=ModelLoaderBase)
|
||||||
|
|
||||||
|
|
||||||
class ModelLoaderRegistry(ModelLoaderRegistryBase):
|
class ModelLoaderRegistry:
|
||||||
"""
|
"""
|
||||||
This class allows model loaders to register their type, base and format.
|
This class allows model loaders to register their type, base and format.
|
||||||
"""
|
"""
|
||||||
|
@ -24,7 +24,7 @@ from .. import ModelLoader, ModelLoaderRegistry
|
|||||||
|
|
||||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.Diffusers)
|
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.Diffusers)
|
||||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.LyCORIS)
|
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.LyCORIS)
|
||||||
class LoRALoader(ModelLoader):
|
class LoraLoader(ModelLoader):
|
||||||
"""Class to load LoRA models."""
|
"""Class to load LoRA models."""
|
||||||
|
|
||||||
# We cheat a little bit to get access to the model base
|
# We cheat a little bit to get access to the model base
|
||||||
|
@ -23,7 +23,7 @@ from .generic_diffusers import GenericDiffusersLoader
|
|||||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Diffusers)
|
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Diffusers)
|
||||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.VAE, format=ModelFormat.Checkpoint)
|
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.VAE, format=ModelFormat.Checkpoint)
|
||||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.VAE, format=ModelFormat.Checkpoint)
|
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.VAE, format=ModelFormat.Checkpoint)
|
||||||
class VAELoader(GenericDiffusersLoader):
|
class VaeLoader(GenericDiffusersLoader):
|
||||||
"""Class to load VAE models."""
|
"""Class to load VAE models."""
|
||||||
|
|
||||||
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
|
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
|
||||||
|
@ -8,19 +8,23 @@ from invokeai.backend.model_manager.metadata import(
|
|||||||
CommercialUsage,
|
CommercialUsage,
|
||||||
LicenseRestrictions,
|
LicenseRestrictions,
|
||||||
HuggingFaceMetadata,
|
HuggingFaceMetadata,
|
||||||
|
CivitaiMetadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
from invokeai.backend.model_manager.metadata.fetch import HuggingFaceMetadataFetch
|
from invokeai.backend.model_manager.metadata.fetch import CivitaiMetadataFetch
|
||||||
|
|
||||||
data = HuggingFaceMetadataFetch().from_id("<REPO_ID>")
|
data = CivitaiMetadataFetch().from_url("https://civitai.com/models/206883/split")
|
||||||
assert isinstance(data, HuggingFaceMetadata)
|
assert isinstance(data, CivitaiMetadata)
|
||||||
|
if data.allow_commercial_use:
|
||||||
|
print("Commercial use of this model is allowed")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .fetch import HuggingFaceMetadataFetch, ModelMetadataFetchBase
|
from .fetch import CivitaiMetadataFetch, HuggingFaceMetadataFetch, ModelMetadataFetchBase
|
||||||
from .metadata_base import (
|
from .metadata_base import (
|
||||||
AnyModelRepoMetadata,
|
AnyModelRepoMetadata,
|
||||||
AnyModelRepoMetadataValidator,
|
AnyModelRepoMetadataValidator,
|
||||||
BaseMetadata,
|
BaseMetadata,
|
||||||
|
CivitaiMetadata,
|
||||||
HuggingFaceMetadata,
|
HuggingFaceMetadata,
|
||||||
ModelMetadataWithFiles,
|
ModelMetadataWithFiles,
|
||||||
RemoteModelFile,
|
RemoteModelFile,
|
||||||
@ -30,6 +34,8 @@ from .metadata_base import (
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
"AnyModelRepoMetadata",
|
"AnyModelRepoMetadata",
|
||||||
"AnyModelRepoMetadataValidator",
|
"AnyModelRepoMetadataValidator",
|
||||||
|
"CivitaiMetadata",
|
||||||
|
"CivitaiMetadataFetch",
|
||||||
"HuggingFaceMetadata",
|
"HuggingFaceMetadata",
|
||||||
"HuggingFaceMetadataFetch",
|
"HuggingFaceMetadataFetch",
|
||||||
"ModelMetadataFetchBase",
|
"ModelMetadataFetchBase",
|
||||||
|
@ -3,14 +3,19 @@ Initialization file for invokeai.backend.model_manager.metadata.fetch
|
|||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
from invokeai.backend.model_manager.metadata.fetch import (
|
from invokeai.backend.model_manager.metadata.fetch import (
|
||||||
|
CivitaiMetadataFetch,
|
||||||
HuggingFaceMetadataFetch,
|
HuggingFaceMetadataFetch,
|
||||||
)
|
)
|
||||||
|
from invokeai.backend.model_manager.metadata import CivitaiMetadata
|
||||||
|
|
||||||
data = HuggingFaceMetadataFetch().from_id("<repo_id>")
|
data = CivitaiMetadataFetch().from_url("https://civitai.com/models/206883/split")
|
||||||
assert isinstance(data, HuggingFaceMetadata)
|
assert isinstance(data, CivitaiMetadata)
|
||||||
|
if data.allow_commercial_use:
|
||||||
|
print("Commercial use of this model is allowed")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from .civitai import CivitaiMetadataFetch
|
||||||
from .fetch_base import ModelMetadataFetchBase
|
from .fetch_base import ModelMetadataFetchBase
|
||||||
from .huggingface import HuggingFaceMetadataFetch
|
from .huggingface import HuggingFaceMetadataFetch
|
||||||
|
|
||||||
__all__ = ["ModelMetadataFetchBase", "HuggingFaceMetadataFetch"]
|
__all__ = ["ModelMetadataFetchBase", "CivitaiMetadataFetch", "HuggingFaceMetadataFetch"]
|
||||||
|
188
invokeai/backend/model_manager/metadata/fetch/civitai.py
Normal file
188
invokeai/backend/model_manager/metadata/fetch/civitai.py
Normal file
@ -0,0 +1,188 @@
|
|||||||
|
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||||
|
|
||||||
|
"""
|
||||||
|
This module fetches model metadata objects from the Civitai model repository.
|
||||||
|
In addition to the `from_url()` and `from_id()` methods inherited from the
|
||||||
|
`ModelMetadataFetchBase` base class.
|
||||||
|
|
||||||
|
Civitai has two separate ID spaces: a model ID and a version ID. The
|
||||||
|
version ID corresponds to a specific model, and is the ID accepted by
|
||||||
|
`from_id()`. The model ID corresponds to a family of related models,
|
||||||
|
such as different training checkpoints or 16 vs 32-bit versions. The
|
||||||
|
`from_civitai_modelid()` method will accept a model ID and return the
|
||||||
|
metadata from the default version within this model set. The default
|
||||||
|
version is the same as what the user sees when they click on a model's
|
||||||
|
thumbnail.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
|
||||||
|
from invokeai.backend.model_manager.metadata.fetch import CivitaiMetadataFetch
|
||||||
|
|
||||||
|
fetcher = CivitaiMetadataFetch()
|
||||||
|
metadata = fetcher.from_url("https://civitai.com/models/206883/split")
|
||||||
|
print(metadata.trained_words)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from pydantic import TypeAdapter, ValidationError
|
||||||
|
from pydantic.networks import AnyHttpUrl
|
||||||
|
from requests.sessions import Session
|
||||||
|
|
||||||
|
from invokeai.backend.model_manager.config import ModelRepoVariant
|
||||||
|
|
||||||
|
from ..metadata_base import (
|
||||||
|
AnyModelRepoMetadata,
|
||||||
|
CivitaiMetadata,
|
||||||
|
RemoteModelFile,
|
||||||
|
UnknownMetadataException,
|
||||||
|
)
|
||||||
|
from .fetch_base import ModelMetadataFetchBase
|
||||||
|
|
||||||
|
CIVITAI_MODEL_PAGE_RE = r"https?://civitai.com/models/(\d+)"
|
||||||
|
CIVITAI_VERSION_PAGE_RE = r"https?://civitai.com/models/(\d+)\?modelVersionId=(\d+)"
|
||||||
|
CIVITAI_DOWNLOAD_RE = r"https?://civitai.com/api/download/models/(\d+)"
|
||||||
|
|
||||||
|
CIVITAI_VERSION_ENDPOINT = "https://civitai.com/api/v1/model-versions/"
|
||||||
|
CIVITAI_MODEL_ENDPOINT = "https://civitai.com/api/v1/models/"
|
||||||
|
|
||||||
|
|
||||||
|
StringSetAdapter = TypeAdapter(set[str])
|
||||||
|
|
||||||
|
|
||||||
|
class CivitaiMetadataFetch(ModelMetadataFetchBase):
|
||||||
|
"""Fetch model metadata from Civitai."""
|
||||||
|
|
||||||
|
def __init__(self, session: Optional[Session] = None, api_key: Optional[str] = None):
|
||||||
|
"""
|
||||||
|
Initialize the fetcher with an optional requests.sessions.Session object.
|
||||||
|
|
||||||
|
By providing a configurable Session object, we can support unit tests on
|
||||||
|
this module without an internet connection.
|
||||||
|
"""
|
||||||
|
self._requests = session or requests.Session()
|
||||||
|
self._api_key = api_key
|
||||||
|
|
||||||
|
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
|
||||||
|
"""
|
||||||
|
Given a URL to a CivitAI model or version page, return a ModelMetadata object.
|
||||||
|
|
||||||
|
In the event that the URL points to a model page without the particular version
|
||||||
|
indicated, the default model version is returned. Otherwise, the requested version
|
||||||
|
is returned.
|
||||||
|
"""
|
||||||
|
if match := re.match(CIVITAI_VERSION_PAGE_RE, str(url), re.IGNORECASE):
|
||||||
|
model_id = match.group(1)
|
||||||
|
version_id = match.group(2)
|
||||||
|
return self.from_civitai_versionid(int(version_id), int(model_id))
|
||||||
|
elif match := re.match(CIVITAI_MODEL_PAGE_RE, str(url), re.IGNORECASE):
|
||||||
|
model_id = match.group(1)
|
||||||
|
return self.from_civitai_modelid(int(model_id))
|
||||||
|
elif match := re.match(CIVITAI_DOWNLOAD_RE, str(url), re.IGNORECASE):
|
||||||
|
version_id = match.group(1)
|
||||||
|
return self.from_civitai_versionid(int(version_id))
|
||||||
|
raise UnknownMetadataException("The url '{url}' does not match any known Civitai URL patterns")
|
||||||
|
|
||||||
|
def from_id(self, id: str, variant: Optional[ModelRepoVariant] = None) -> AnyModelRepoMetadata:
|
||||||
|
"""
|
||||||
|
Given a Civitai model version ID, return a ModelRepoMetadata object.
|
||||||
|
|
||||||
|
:param id: An ID.
|
||||||
|
:param variant: A model variant from the ModelRepoVariant enum (currently ignored)
|
||||||
|
|
||||||
|
May raise an `UnknownMetadataException`.
|
||||||
|
"""
|
||||||
|
return self.from_civitai_versionid(int(id))
|
||||||
|
|
||||||
|
def from_civitai_modelid(self, model_id: int) -> CivitaiMetadata:
|
||||||
|
"""
|
||||||
|
Return metadata from the default version of the indicated model.
|
||||||
|
|
||||||
|
May raise an `UnknownMetadataException`.
|
||||||
|
"""
|
||||||
|
model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
|
||||||
|
model_json = self._requests.get(self._get_url_with_api_key(model_url)).json()
|
||||||
|
return self._from_api_response(model_json)
|
||||||
|
|
||||||
|
def _from_api_response(self, api_response: dict[str, Any], version_id: Optional[int] = None) -> CivitaiMetadata:
|
||||||
|
try:
|
||||||
|
version_id = version_id or api_response["modelVersions"][0]["id"]
|
||||||
|
except TypeError as excp:
|
||||||
|
raise UnknownMetadataException from excp
|
||||||
|
|
||||||
|
# loop till we find the section containing the version requested
|
||||||
|
version_sections = [x for x in api_response["modelVersions"] if x["id"] == version_id]
|
||||||
|
if not version_sections:
|
||||||
|
raise UnknownMetadataException(f"Version {version_id} not found in model metadata")
|
||||||
|
|
||||||
|
version_json = version_sections[0]
|
||||||
|
|
||||||
|
# Civitai has one "primary" file plus others such as VAEs. We only fetch the primary.
|
||||||
|
primary = [x for x in version_json["files"] if x.get("primary")]
|
||||||
|
assert len(primary) == 1
|
||||||
|
primary_file = primary[0]
|
||||||
|
|
||||||
|
url = primary_file["downloadUrl"]
|
||||||
|
if "?" not in url: # work around apparent bug in civitai api
|
||||||
|
metadata_string = ""
|
||||||
|
for key, value in primary_file["metadata"].items():
|
||||||
|
if not value:
|
||||||
|
continue
|
||||||
|
metadata_string += f"&{key}={value}"
|
||||||
|
url = url + f"?type={primary_file['type']}{metadata_string}"
|
||||||
|
model_files = [
|
||||||
|
RemoteModelFile(
|
||||||
|
url=self._get_url_with_api_key(url),
|
||||||
|
path=Path(primary_file["name"]),
|
||||||
|
size=int(primary_file["sizeKB"] * 1024),
|
||||||
|
sha256=primary_file["hashes"]["SHA256"],
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
try:
|
||||||
|
trigger_phrases = StringSetAdapter.validate_python(version_json.get("trainedWords"))
|
||||||
|
except ValidationError:
|
||||||
|
trigger_phrases: set[str] = set()
|
||||||
|
|
||||||
|
return CivitaiMetadata(
|
||||||
|
name=version_json["name"],
|
||||||
|
files=model_files,
|
||||||
|
trigger_phrases=trigger_phrases,
|
||||||
|
api_response=json.dumps(version_json),
|
||||||
|
)
|
||||||
|
|
||||||
|
def from_civitai_versionid(self, version_id: int, model_id: Optional[int] = None) -> CivitaiMetadata:
|
||||||
|
"""
|
||||||
|
Return a CivitaiMetadata object given a model version id.
|
||||||
|
|
||||||
|
May raise an `UnknownMetadataException`.
|
||||||
|
"""
|
||||||
|
if model_id is None:
|
||||||
|
version_url = CIVITAI_VERSION_ENDPOINT + str(version_id)
|
||||||
|
version = self._requests.get(self._get_url_with_api_key(version_url)).json()
|
||||||
|
if error := version.get("error"):
|
||||||
|
raise UnknownMetadataException(error)
|
||||||
|
model_id = version["modelId"]
|
||||||
|
|
||||||
|
model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
|
||||||
|
model_json = self._requests.get(self._get_url_with_api_key(model_url)).json()
|
||||||
|
return self._from_api_response(model_json, version_id)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_json(cls, json: str) -> CivitaiMetadata:
|
||||||
|
"""Given the JSON representation of the metadata, return the corresponding Pydantic object."""
|
||||||
|
metadata = CivitaiMetadata.model_validate_json(json)
|
||||||
|
return metadata
|
||||||
|
|
||||||
|
def _get_url_with_api_key(self, url: str) -> str:
|
||||||
|
if not self._api_key:
|
||||||
|
return url
|
||||||
|
|
||||||
|
if "?" in url:
|
||||||
|
return f"{url}&token={self._api_key}"
|
||||||
|
|
||||||
|
return f"{url}?token={self._api_key}"
|
@ -5,10 +5,11 @@ This module is the base class for subclasses that fetch metadata from model repo
|
|||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
|
|
||||||
from invokeai.backend.model_manager.metadata.fetch import HuggingFaceMetadataFetch
|
from invokeai.backend.model_manager.metadata.fetch import CivitAIMetadataFetch
|
||||||
|
|
||||||
data = HuggingFaceMetadataFetch().from_id("<REPO_ID>")
|
fetcher = CivitaiMetadataFetch()
|
||||||
assert isinstance(data, HuggingFaceMetadata)
|
metadata = fetcher.from_url("https://civitai.com/models/206883/split")
|
||||||
|
print(metadata.trained_words)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
@ -78,6 +78,14 @@ class ModelMetadataWithFiles(ModelMetadataBase):
|
|||||||
return self.files
|
return self.files
|
||||||
|
|
||||||
|
|
||||||
|
class CivitaiMetadata(ModelMetadataWithFiles):
|
||||||
|
"""Extended metadata fields provided by Civitai."""
|
||||||
|
|
||||||
|
type: Literal["civitai"] = "civitai"
|
||||||
|
trigger_phrases: set[str] = Field(description="Trigger phrases extracted from the API response")
|
||||||
|
api_response: Optional[str] = Field(description="Response from the Civitai API as stringified JSON", default=None)
|
||||||
|
|
||||||
|
|
||||||
class HuggingFaceMetadata(ModelMetadataWithFiles):
|
class HuggingFaceMetadata(ModelMetadataWithFiles):
|
||||||
"""Extended metadata fields provided by HuggingFace."""
|
"""Extended metadata fields provided by HuggingFace."""
|
||||||
|
|
||||||
@ -122,5 +130,5 @@ class HuggingFaceMetadata(ModelMetadataWithFiles):
|
|||||||
return [x for x in self.files if x.path in paths]
|
return [x for x in self.files if x.path in paths]
|
||||||
|
|
||||||
|
|
||||||
AnyModelRepoMetadata = Annotated[Union[BaseMetadata, HuggingFaceMetadata], Field(discriminator="type")]
|
AnyModelRepoMetadata = Annotated[Union[BaseMetadata, HuggingFaceMetadata, CivitaiMetadata], Field(discriminator="type")]
|
||||||
AnyModelRepoMetadataValidator = TypeAdapter(AnyModelRepoMetadata)
|
AnyModelRepoMetadataValidator = TypeAdapter(AnyModelRepoMetadata)
|
||||||
|
@ -14,7 +14,6 @@ from invokeai.backend.util.util import SilenceWarnings
|
|||||||
from .config import (
|
from .config import (
|
||||||
AnyModelConfig,
|
AnyModelConfig,
|
||||||
BaseModelType,
|
BaseModelType,
|
||||||
ControlAdapterDefaultSettings,
|
|
||||||
InvalidModelConfigException,
|
InvalidModelConfigException,
|
||||||
ModelConfigFactory,
|
ModelConfigFactory,
|
||||||
ModelFormat,
|
ModelFormat,
|
||||||
@ -129,8 +128,6 @@ class ModelProbe(object):
|
|||||||
if fields is None:
|
if fields is None:
|
||||||
fields = {}
|
fields = {}
|
||||||
|
|
||||||
model_path = model_path.resolve()
|
|
||||||
|
|
||||||
format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint
|
format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint
|
||||||
model_info = None
|
model_info = None
|
||||||
model_type = None
|
model_type = None
|
||||||
@ -162,12 +159,6 @@ class ModelProbe(object):
|
|||||||
fields["format"] = fields.get("format") or probe.get_format()
|
fields["format"] = fields.get("format") or probe.get_format()
|
||||||
fields["hash"] = fields.get("hash") or ModelHash().hash(model_path)
|
fields["hash"] = fields.get("hash") or ModelHash().hash(model_path)
|
||||||
|
|
||||||
fields["default_settings"] = (
|
|
||||||
fields.get("default_settings") or probe.get_default_settings(fields["name"])
|
|
||||||
if isinstance(probe, ControlAdapterProbe)
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
if format_type == ModelFormat.Diffusers and isinstance(probe, FolderProbeBase):
|
if format_type == ModelFormat.Diffusers and isinstance(probe, FolderProbeBase):
|
||||||
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()
|
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()
|
||||||
|
|
||||||
@ -338,38 +329,6 @@ class ModelProbe(object):
|
|||||||
raise Exception("The model {model_name} is potentially infected by malware. Aborting import.")
|
raise Exception("The model {model_name} is potentially infected by malware. Aborting import.")
|
||||||
|
|
||||||
|
|
||||||
class ControlAdapterProbe(ProbeBase):
|
|
||||||
"""Adds `get_default_settings` for ControlNet and T2IAdapter probes"""
|
|
||||||
|
|
||||||
# TODO(psyche): It would be nice to get these from the invocations, but that creates circular dependencies.
|
|
||||||
# "canny": CannyImageProcessorInvocation.get_type()
|
|
||||||
MODEL_NAME_TO_PREPROCESSOR = {
|
|
||||||
"canny": "canny_image_processor",
|
|
||||||
"mlsd": "mlsd_image_processor",
|
|
||||||
"depth": "depth_anything_image_processor",
|
|
||||||
"bae": "normalbae_image_processor",
|
|
||||||
"normal": "normalbae_image_processor",
|
|
||||||
"sketch": "pidi_image_processor",
|
|
||||||
"scribble": "lineart_image_processor",
|
|
||||||
"lineart": "lineart_image_processor",
|
|
||||||
"lineart_anime": "lineart_anime_image_processor",
|
|
||||||
"softedge": "hed_image_processor",
|
|
||||||
"shuffle": "content_shuffle_image_processor",
|
|
||||||
"pose": "dw_openpose_image_processor",
|
|
||||||
"mediapipe": "mediapipe_face_processor",
|
|
||||||
"pidi": "pidi_image_processor",
|
|
||||||
"zoe": "zoe_depth_image_processor",
|
|
||||||
"color": "color_map_image_processor",
|
|
||||||
}
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_default_settings(cls, model_name: str) -> Optional[ControlAdapterDefaultSettings]:
|
|
||||||
for k, v in cls.MODEL_NAME_TO_PREPROCESSOR.items():
|
|
||||||
if k in model_name:
|
|
||||||
return ControlAdapterDefaultSettings(preprocessor=v)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
# ##################################################3
|
# ##################################################3
|
||||||
# Checkpoint probing
|
# Checkpoint probing
|
||||||
# ##################################################3
|
# ##################################################3
|
||||||
@ -493,7 +452,7 @@ class TextualInversionCheckpointProbe(CheckpointProbeBase):
|
|||||||
raise InvalidModelConfigException(f"{self.model_path}: Could not determine base type")
|
raise InvalidModelConfigException(f"{self.model_path}: Could not determine base type")
|
||||||
|
|
||||||
|
|
||||||
class ControlNetCheckpointProbe(CheckpointProbeBase, ControlAdapterProbe):
|
class ControlNetCheckpointProbe(CheckpointProbeBase):
|
||||||
"""Class for probing controlnets."""
|
"""Class for probing controlnets."""
|
||||||
|
|
||||||
def get_base_type(self) -> BaseModelType:
|
def get_base_type(self) -> BaseModelType:
|
||||||
@ -521,7 +480,7 @@ class CLIPVisionCheckpointProbe(CheckpointProbeBase):
|
|||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
class T2IAdapterCheckpointProbe(CheckpointProbeBase, ControlAdapterProbe):
|
class T2IAdapterCheckpointProbe(CheckpointProbeBase):
|
||||||
def get_base_type(self) -> BaseModelType:
|
def get_base_type(self) -> BaseModelType:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@ -659,7 +618,7 @@ class ONNXFolderProbe(PipelineFolderProbe):
|
|||||||
return ModelVariantType.Normal
|
return ModelVariantType.Normal
|
||||||
|
|
||||||
|
|
||||||
class ControlNetFolderProbe(FolderProbeBase, ControlAdapterProbe):
|
class ControlNetFolderProbe(FolderProbeBase):
|
||||||
def get_base_type(self) -> BaseModelType:
|
def get_base_type(self) -> BaseModelType:
|
||||||
config_file = self.model_path / "config.json"
|
config_file = self.model_path / "config.json"
|
||||||
if not config_file.exists():
|
if not config_file.exists():
|
||||||
@ -733,7 +692,7 @@ class CLIPVisionFolderProbe(FolderProbeBase):
|
|||||||
return BaseModelType.Any
|
return BaseModelType.Any
|
||||||
|
|
||||||
|
|
||||||
class T2IAdapterFolderProbe(FolderProbeBase, ControlAdapterProbe):
|
class T2IAdapterFolderProbe(FolderProbeBase):
|
||||||
def get_base_type(self) -> BaseModelType:
|
def get_base_type(self) -> BaseModelType:
|
||||||
config_file = self.model_path / "config.json"
|
config_file = self.model_path / "config.json"
|
||||||
if not config_file.exists():
|
if not config_file.exists():
|
||||||
|
@ -4,75 +4,121 @@ Abstract base class and implementation for recursive directory search for models
|
|||||||
|
|
||||||
Example usage:
|
Example usage:
|
||||||
```
|
```
|
||||||
from invokeai.backend.model_manager import ModelSearch, ModelProbe
|
from invokeai.backend.model_manager import ModelSearch, ModelProbe
|
||||||
|
|
||||||
def find_main_models(model: Path) -> bool:
|
def find_main_models(model: Path) -> bool:
|
||||||
info = ModelProbe.probe(model)
|
info = ModelProbe.probe(model)
|
||||||
if info.model_type == 'main' and info.base_type == 'sd-1':
|
if info.model_type == 'main' and info.base_type == 'sd-1':
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
search = ModelSearch(on_model_found=report_it)
|
search = ModelSearch(on_model_found=report_it)
|
||||||
found = search.search('/tmp/models')
|
found = search.search('/tmp/models')
|
||||||
print(found) # list of matching model paths
|
print(found) # list of matching model paths
|
||||||
print(search.stats) # search stats
|
print(search.stats) # search stats
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from abc import ABC, abstractmethod
|
||||||
|
from logging import Logger
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional, Set, Union
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
|
|
||||||
|
default_logger: Logger = InvokeAILogger.get_logger()
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class SearchStats:
|
|
||||||
"""Statistics about the search.
|
|
||||||
|
|
||||||
Attributes:
|
class SearchStats(BaseModel):
|
||||||
items_scanned: number of items scanned
|
items_scanned: int = 0
|
||||||
models_found: number of models found
|
models_found: int = 0
|
||||||
models_filtered: number of models that passed the filter
|
models_filtered: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
class ModelSearchBase(ABC, BaseModel):
|
||||||
"""
|
"""
|
||||||
|
Abstract directory traversal model search class
|
||||||
items_scanned = 0
|
|
||||||
models_found = 0
|
|
||||||
models_filtered = 0
|
|
||||||
|
|
||||||
|
|
||||||
class ModelSearch:
|
|
||||||
"""Searches a directory tree for models, using a callback to filter the results.
|
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
search = ModelSearch()
|
search = ModelSearchBase(
|
||||||
search.model_found = lambda path : 'anime' in path.as_posix()
|
on_search_started = search_started_callback,
|
||||||
found = search.list_models(['/tmp/models1','/tmp/models2'])
|
on_search_completed = search_completed_callback,
|
||||||
# returns all models that have 'anime' in the path
|
on_model_found = model_found_callback,
|
||||||
|
)
|
||||||
|
models_found = search.search('/path/to/directory')
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
# fmt: off
|
||||||
self,
|
on_search_started : Optional[Callable[[Path], None]] = Field(default=None, description="Called just before the search starts.") # noqa E221
|
||||||
on_search_started: Optional[Callable[[Path], None]] = None,
|
on_model_found : Optional[Callable[[Path], bool]] = Field(default=None, description="Called when a model is found.") # noqa E221
|
||||||
on_model_found: Optional[Callable[[Path], bool]] = None,
|
on_search_completed : Optional[Callable[[Set[Path]], None]] = Field(default=None, description="Called when search is complete.") # noqa E221
|
||||||
on_search_completed: Optional[Callable[[set[Path]], None]] = None,
|
stats : SearchStats = Field(default_factory=SearchStats, description="Summary statistics after search") # noqa E221
|
||||||
) -> None:
|
logger : Logger = Field(default=default_logger, description="Logger instance.") # noqa E221
|
||||||
"""Create a new ModelSearch object.
|
# fmt: on
|
||||||
|
|
||||||
Args:
|
class Config:
|
||||||
on_search_started: callback to be invoked when the search starts
|
arbitrary_types_allowed = True
|
||||||
on_model_found: callback to be invoked when a model is found. The callback should return True if the model
|
|
||||||
should be included in the results.
|
@abstractmethod
|
||||||
on_search_completed: callback to be invoked when the search is completed
|
def search_started(self) -> None:
|
||||||
"""
|
"""
|
||||||
self.stats = SearchStats()
|
Called before the scan starts.
|
||||||
self.logger = InvokeAILogger.get_logger()
|
|
||||||
self.on_search_started = on_search_started
|
Passes the root search directory to the Callable `on_search_started`.
|
||||||
self.on_model_found = on_model_found
|
"""
|
||||||
self.on_search_completed = on_search_completed
|
pass
|
||||||
self.models_found: set[Path] = set()
|
|
||||||
|
@abstractmethod
|
||||||
|
def model_found(self, model: Path) -> None:
|
||||||
|
"""
|
||||||
|
Called when a model is found during search.
|
||||||
|
|
||||||
|
:param model: Model to process - could be a directory or checkpoint.
|
||||||
|
|
||||||
|
Passes the model's Path to the Callable `on_model_found`.
|
||||||
|
This Callable receives the path to the model and returns a boolean
|
||||||
|
to indicate whether the model should be returned in the search
|
||||||
|
results.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def search_completed(self) -> None:
|
||||||
|
"""
|
||||||
|
Called before the scan starts.
|
||||||
|
|
||||||
|
Passes the Set of found model Paths to the Callable `on_search_completed`.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def search(self, directory: Union[Path, str]) -> Set[Path]:
|
||||||
|
"""
|
||||||
|
Recursively search for models in `directory` and return a set of model paths.
|
||||||
|
|
||||||
|
If provided, the `on_search_started`, `on_model_found` and `on_search_completed`
|
||||||
|
Callables will be invoked during the search.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ModelSearch(ModelSearchBase):
|
||||||
|
"""
|
||||||
|
Implementation of ModelSearch with callbacks.
|
||||||
|
Usage:
|
||||||
|
search = ModelSearch()
|
||||||
|
search.model_found = lambda path : 'anime' in path.as_posix()
|
||||||
|
found = search.list_models(['/tmp/models1','/tmp/models2'])
|
||||||
|
# returns all models that have 'anime' in the path
|
||||||
|
"""
|
||||||
|
|
||||||
|
models_found: Set[Path] = Field(default_factory=set)
|
||||||
|
config: InvokeAIAppConfig = InvokeAIAppConfig.get_config()
|
||||||
|
|
||||||
def search_started(self) -> None:
|
def search_started(self) -> None:
|
||||||
self.models_found = set()
|
self.models_found = set()
|
||||||
@ -89,17 +135,17 @@ class ModelSearch:
|
|||||||
if self.on_search_completed is not None:
|
if self.on_search_completed is not None:
|
||||||
self.on_search_completed(self.models_found)
|
self.on_search_completed(self.models_found)
|
||||||
|
|
||||||
def search(self, directory: Path) -> set[Path]:
|
def search(self, directory: Union[Path, str]) -> Set[Path]:
|
||||||
self._directory = Path(directory)
|
self._directory = Path(directory)
|
||||||
self._directory = self._directory.resolve()
|
if not self._directory.is_absolute():
|
||||||
|
self._directory = self.config.models_path / self._directory
|
||||||
self.stats = SearchStats() # zero out
|
self.stats = SearchStats() # zero out
|
||||||
self.search_started() # This will initialize _models_found to empty
|
self.search_started() # This will initialize _models_found to empty
|
||||||
self._walk_directory(self._directory)
|
self._walk_directory(self._directory)
|
||||||
self.search_completed()
|
self.search_completed()
|
||||||
return self.models_found
|
return self.models_found
|
||||||
|
|
||||||
def _walk_directory(self, path: Path, max_depth: int = 20) -> None:
|
def _walk_directory(self, path: Union[Path, str], max_depth: int = 20) -> None:
|
||||||
"""Recursively walk the directory tree, looking for models."""
|
|
||||||
absolute_path = Path(path)
|
absolute_path = Path(path)
|
||||||
if (
|
if (
|
||||||
len(absolute_path.parts) - len(self._directory.parts) > max_depth
|
len(absolute_path.parts) - len(self._directory.parts) > max_depth
|
||||||
|
@ -455,6 +455,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
ip_adapter_unet_patcher=ip_adapter_unet_patcher,
|
ip_adapter_unet_patcher=ip_adapter_unet_patcher,
|
||||||
)
|
)
|
||||||
latents = step_output.prev_sample
|
latents = step_output.prev_sample
|
||||||
|
|
||||||
|
latents = self.invokeai_diffuser.do_latent_postprocessing(
|
||||||
|
postprocessing_settings=conditioning_data.postprocessing_settings,
|
||||||
|
latents=latents,
|
||||||
|
sigma=batched_t,
|
||||||
|
step_index=i,
|
||||||
|
total_step_count=len(timesteps),
|
||||||
|
)
|
||||||
|
|
||||||
predicted_original = getattr(step_output, "pred_original_sample", None)
|
predicted_original = getattr(step_output, "pred_original_sample", None)
|
||||||
|
|
||||||
if callback is not None:
|
if callback is not None:
|
||||||
|
@ -44,6 +44,14 @@ class SDXLConditioningInfo(BasicConditioningInfo):
|
|||||||
return super().to(device=device, dtype=dtype)
|
return super().to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class PostprocessingSettings:
|
||||||
|
threshold: float
|
||||||
|
warmup: float
|
||||||
|
h_symmetry_time_pct: Optional[float]
|
||||||
|
v_symmetry_time_pct: Optional[float]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class IPAdapterConditioningInfo:
|
class IPAdapterConditioningInfo:
|
||||||
cond_image_prompt_embeds: torch.Tensor
|
cond_image_prompt_embeds: torch.Tensor
|
||||||
@ -72,6 +80,10 @@ class ConditioningData:
|
|||||||
"""
|
"""
|
||||||
guidance_rescale_multiplier: float = 0
|
guidance_rescale_multiplier: float = 0
|
||||||
scheduler_args: dict[str, Any] = field(default_factory=dict)
|
scheduler_args: dict[str, Any] = field(default_factory=dict)
|
||||||
|
"""
|
||||||
|
Additional arguments to pass to invokeai_diffuser.do_latent_postprocessing().
|
||||||
|
"""
|
||||||
|
postprocessing_settings: Optional[PostprocessingSettings] = None
|
||||||
|
|
||||||
ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]] = None
|
ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]] = None
|
||||||
|
|
||||||
|
@ -12,6 +12,7 @@ from invokeai.app.services.config import InvokeAIAppConfig
|
|||||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||||
ConditioningData,
|
ConditioningData,
|
||||||
ExtraConditioningInfo,
|
ExtraConditioningInfo,
|
||||||
|
PostprocessingSettings,
|
||||||
SDXLConditioningInfo,
|
SDXLConditioningInfo,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -243,6 +244,19 @@ class InvokeAIDiffuserComponent:
|
|||||||
|
|
||||||
return unconditioned_next_x, conditioned_next_x
|
return unconditioned_next_x, conditioned_next_x
|
||||||
|
|
||||||
|
def do_latent_postprocessing(
|
||||||
|
self,
|
||||||
|
postprocessing_settings: PostprocessingSettings,
|
||||||
|
latents: torch.Tensor,
|
||||||
|
sigma,
|
||||||
|
step_index,
|
||||||
|
total_step_count,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if postprocessing_settings is not None:
|
||||||
|
percent_through = step_index / total_step_count
|
||||||
|
latents = self.apply_symmetry(postprocessing_settings, latents, percent_through)
|
||||||
|
return latents
|
||||||
|
|
||||||
def _concat_conditionings_for_batch(self, unconditioning, conditioning):
|
def _concat_conditionings_for_batch(self, unconditioning, conditioning):
|
||||||
def _pad_conditioning(cond, target_len, encoder_attention_mask):
|
def _pad_conditioning(cond, target_len, encoder_attention_mask):
|
||||||
conditioning_attention_mask = torch.ones(
|
conditioning_attention_mask = torch.ones(
|
||||||
@ -492,3 +506,64 @@ class InvokeAIDiffuserComponent:
|
|||||||
scaled_delta = (conditioned_next_x - unconditioned_next_x) * guidance_scale
|
scaled_delta = (conditioned_next_x - unconditioned_next_x) * guidance_scale
|
||||||
combined_next_x = unconditioned_next_x + scaled_delta
|
combined_next_x = unconditioned_next_x + scaled_delta
|
||||||
return combined_next_x
|
return combined_next_x
|
||||||
|
|
||||||
|
def apply_symmetry(
|
||||||
|
self,
|
||||||
|
postprocessing_settings: PostprocessingSettings,
|
||||||
|
latents: torch.Tensor,
|
||||||
|
percent_through: float,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# Reset our last percent through if this is our first step.
|
||||||
|
if percent_through == 0.0:
|
||||||
|
self.last_percent_through = 0.0
|
||||||
|
|
||||||
|
if postprocessing_settings is None:
|
||||||
|
return latents
|
||||||
|
|
||||||
|
# Check for out of bounds
|
||||||
|
h_symmetry_time_pct = postprocessing_settings.h_symmetry_time_pct
|
||||||
|
if h_symmetry_time_pct is not None and (h_symmetry_time_pct <= 0.0 or h_symmetry_time_pct > 1.0):
|
||||||
|
h_symmetry_time_pct = None
|
||||||
|
|
||||||
|
v_symmetry_time_pct = postprocessing_settings.v_symmetry_time_pct
|
||||||
|
if v_symmetry_time_pct is not None and (v_symmetry_time_pct <= 0.0 or v_symmetry_time_pct > 1.0):
|
||||||
|
v_symmetry_time_pct = None
|
||||||
|
|
||||||
|
dev = latents.device.type
|
||||||
|
|
||||||
|
latents.to(device="cpu")
|
||||||
|
|
||||||
|
if (
|
||||||
|
h_symmetry_time_pct is not None
|
||||||
|
and self.last_percent_through < h_symmetry_time_pct
|
||||||
|
and percent_through >= h_symmetry_time_pct
|
||||||
|
):
|
||||||
|
# Horizontal symmetry occurs on the 3rd dimension of the latent
|
||||||
|
width = latents.shape[3]
|
||||||
|
x_flipped = torch.flip(latents, dims=[3])
|
||||||
|
latents = torch.cat(
|
||||||
|
[
|
||||||
|
latents[:, :, :, 0 : int(width / 2)],
|
||||||
|
x_flipped[:, :, :, int(width / 2) : int(width)],
|
||||||
|
],
|
||||||
|
dim=3,
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
v_symmetry_time_pct is not None
|
||||||
|
and self.last_percent_through < v_symmetry_time_pct
|
||||||
|
and percent_through >= v_symmetry_time_pct
|
||||||
|
):
|
||||||
|
# Vertical symmetry occurs on the 2nd dimension of the latent
|
||||||
|
height = latents.shape[2]
|
||||||
|
y_flipped = torch.flip(latents, dims=[2])
|
||||||
|
latents = torch.cat(
|
||||||
|
[
|
||||||
|
latents[:, :, 0 : int(height / 2)],
|
||||||
|
y_flipped[:, :, int(height / 2) : int(height)],
|
||||||
|
],
|
||||||
|
dim=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.last_percent_through = percent_through
|
||||||
|
return latents.to(device=dev)
|
||||||
|
@ -858,9 +858,9 @@ def do_textual_inversion_training(
|
|||||||
# Let's make sure we don't update any embedding weights besides the newly added token
|
# Let's make sure we don't update any embedding weights besides the newly added token
|
||||||
index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id
|
index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
|
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
|
||||||
orig_embeds_params[index_no_updates]
|
index_no_updates
|
||||||
)
|
] = orig_embeds_params[index_no_updates]
|
||||||
|
|
||||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||||
if accelerator.sync_gradients:
|
if accelerator.sync_gradients:
|
||||||
|
@ -42,10 +42,9 @@ def install_and_load_model(
|
|||||||
# If the requested model is already installed, return its LoadedModel
|
# If the requested model is already installed, return its LoadedModel
|
||||||
with contextlib.suppress(UnknownModelException):
|
with contextlib.suppress(UnknownModelException):
|
||||||
# TODO: Replace with wrapper call
|
# TODO: Replace with wrapper call
|
||||||
configs = model_manager.store.search_by_attr(
|
loaded_model: LoadedModel = model_manager.load_model_by_attr(
|
||||||
model_name=model_name, base_model=base_model, model_type=model_type
|
model_name=model_name, base_model=base_model, model_type=model_type
|
||||||
)
|
)
|
||||||
loaded_model: LoadedModel = model_manager.load.load_model(configs[0])
|
|
||||||
return loaded_model
|
return loaded_model
|
||||||
|
|
||||||
# Install the requested model.
|
# Install the requested model.
|
||||||
@ -54,7 +53,7 @@ def install_and_load_model(
|
|||||||
assert job.complete
|
assert job.complete
|
||||||
|
|
||||||
try:
|
try:
|
||||||
loaded_model = model_manager.load.load_model(job.config_out)
|
loaded_model = model_manager.load_model_by_config(job.config_out)
|
||||||
return loaded_model
|
return loaded_model
|
||||||
except UnknownModelException as e:
|
except UnknownModelException as e:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
|
@ -20,6 +20,7 @@ from invokeai.app.services.config import InvokeAIAppConfig
|
|||||||
from invokeai.app.services.download import DownloadQueueService
|
from invokeai.app.services.download import DownloadQueueService
|
||||||
from invokeai.app.services.image_files.image_files_disk import DiskImageFileStorage
|
from invokeai.app.services.image_files.image_files_disk import DiskImageFileStorage
|
||||||
from invokeai.app.services.model_install import ModelInstallService
|
from invokeai.app.services.model_install import ModelInstallService
|
||||||
|
from invokeai.app.services.model_metadata import ModelMetadataStoreSQL
|
||||||
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
|
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
|
||||||
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
|
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
|
||||||
from invokeai.backend.model_manager import (
|
from invokeai.backend.model_manager import (
|
||||||
@ -412,7 +413,7 @@ def get_config_store() -> ModelRecordServiceSQL:
|
|||||||
assert output_path is not None
|
assert output_path is not None
|
||||||
image_files = DiskImageFileStorage(output_path / "images")
|
image_files = DiskImageFileStorage(output_path / "images")
|
||||||
db = init_db(config=config, logger=InvokeAILogger.get_logger(), image_files=image_files)
|
db = init_db(config=config, logger=InvokeAILogger.get_logger(), image_files=image_files)
|
||||||
return ModelRecordServiceSQL(db)
|
return ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
|
||||||
|
|
||||||
|
|
||||||
def get_model_merger(record_store: ModelRecordServiceBase) -> ModelMerger:
|
def get_model_merger(record_store: ModelRecordServiceBase) -> ModelMerger:
|
||||||
|
@ -10,7 +10,7 @@ export const ReduxInit = memo((props: PropsWithChildren) => {
|
|||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
useGlobalModifiersInit();
|
useGlobalModifiersInit();
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
dispatch(modelChanged({ key: 'test_model', hash: 'some_hash', name: 'some name', base: 'sd-1', type: 'main' }));
|
dispatch(modelChanged({ key: 'test_model', base: 'sd-1' }));
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
return props.children;
|
return props.children;
|
||||||
|
@ -746,7 +746,6 @@
|
|||||||
"delete": "Delete",
|
"delete": "Delete",
|
||||||
"deleteConfig": "Delete Config",
|
"deleteConfig": "Delete Config",
|
||||||
"deleteModel": "Delete Model",
|
"deleteModel": "Delete Model",
|
||||||
"deleteModelImage": "Delete Model Image",
|
|
||||||
"deleteMsg1": "Are you sure you want to delete this model from InvokeAI?",
|
"deleteMsg1": "Are you sure you want to delete this model from InvokeAI?",
|
||||||
"deleteMsg2": "This WILL delete the model from disk if it is in the InvokeAI root folder. If you are using a custom location, then the model WILL NOT be deleted from disk.",
|
"deleteMsg2": "This WILL delete the model from disk if it is in the InvokeAI root folder. If you are using a custom location, then the model WILL NOT be deleted from disk.",
|
||||||
"description": "Description",
|
"description": "Description",
|
||||||
@ -766,14 +765,11 @@
|
|||||||
"importModels": "Import Models",
|
"importModels": "Import Models",
|
||||||
"importQueue": "Import Queue",
|
"importQueue": "Import Queue",
|
||||||
"inpainting": "v1 Inpainting",
|
"inpainting": "v1 Inpainting",
|
||||||
"inplaceInstall": "In-place install",
|
|
||||||
"inplaceInstallDesc": "Install models without copying the files. When using the model, it will be loaded from its this location. If disabled, the model file(s) will be copied into the Invoke-managed models directory during installation.",
|
|
||||||
"interpolationType": "Interpolation Type",
|
"interpolationType": "Interpolation Type",
|
||||||
"inverseSigmoid": "Inverse Sigmoid",
|
"inverseSigmoid": "Inverse Sigmoid",
|
||||||
"invokeAIFolder": "Invoke AI Folder",
|
"invokeAIFolder": "Invoke AI Folder",
|
||||||
"invokeRoot": "InvokeAI folder",
|
"invokeRoot": "InvokeAI folder",
|
||||||
"load": "Load",
|
"load": "Load",
|
||||||
"localOnly": "local only",
|
|
||||||
"loraModels": "LoRAs",
|
"loraModels": "LoRAs",
|
||||||
"manual": "Manual",
|
"manual": "Manual",
|
||||||
"merge": "Merge",
|
"merge": "Merge",
|
||||||
@ -790,10 +786,6 @@
|
|||||||
"modelDeleteFailed": "Failed to delete model",
|
"modelDeleteFailed": "Failed to delete model",
|
||||||
"modelEntryDeleted": "Model Entry Deleted",
|
"modelEntryDeleted": "Model Entry Deleted",
|
||||||
"modelExists": "Model Exists",
|
"modelExists": "Model Exists",
|
||||||
"modelImageDeleted": "Model Image Deleted",
|
|
||||||
"modelImageDeleteFailed": "Model Image Delete Failed",
|
|
||||||
"modelImageUpdated": "Model Image Updated",
|
|
||||||
"modelImageUpdateFailed": "Model Image Update Failed",
|
|
||||||
"modelLocation": "Model Location",
|
"modelLocation": "Model Location",
|
||||||
"modelLocationValidationMsg": "Provide the path to a local folder where your Diffusers Model is stored",
|
"modelLocationValidationMsg": "Provide the path to a local folder where your Diffusers Model is stored",
|
||||||
"modelManager": "Model Manager",
|
"modelManager": "Model Manager",
|
||||||
@ -826,7 +818,6 @@
|
|||||||
"oliveModels": "Olives",
|
"oliveModels": "Olives",
|
||||||
"onnxModels": "Onnx",
|
"onnxModels": "Onnx",
|
||||||
"path": "Path",
|
"path": "Path",
|
||||||
"pathToConfig": "Path To Config",
|
|
||||||
"pathToCustomConfig": "Path To Custom Config",
|
"pathToCustomConfig": "Path To Custom Config",
|
||||||
"pickModelType": "Pick Model Type",
|
"pickModelType": "Pick Model Type",
|
||||||
"predictionType": "Prediction Type",
|
"predictionType": "Prediction Type",
|
||||||
@ -859,11 +850,8 @@
|
|||||||
"syncModels": "Sync Models",
|
"syncModels": "Sync Models",
|
||||||
"syncModelsDesc": "If your models are out of sync with the backend, you can refresh them up using this option. This is generally handy in cases where you add models to the InvokeAI root folder or autoimport directory after the application has booted.",
|
"syncModelsDesc": "If your models are out of sync with the backend, you can refresh them up using this option. This is generally handy in cases where you add models to the InvokeAI root folder or autoimport directory after the application has booted.",
|
||||||
"triggerPhrases": "Trigger Phrases",
|
"triggerPhrases": "Trigger Phrases",
|
||||||
"loraTriggerPhrases": "LoRA Trigger Phrases",
|
|
||||||
"mainModelTriggerPhrases": "Main Model Trigger Phrases",
|
|
||||||
"typePhraseHere": "Type phrase here",
|
"typePhraseHere": "Type phrase here",
|
||||||
"upcastAttention": "Upcast Attention",
|
"upcastAttention": "Upcast Attention",
|
||||||
"uploadImage": "Upload Image",
|
|
||||||
"updateModel": "Update Model",
|
"updateModel": "Update Model",
|
||||||
"useCustomConfig": "Use Custom Config",
|
"useCustomConfig": "Use Custom Config",
|
||||||
"useDefaultSettings": "Use Default Settings",
|
"useDefaultSettings": "Use Default Settings",
|
||||||
@ -956,7 +944,6 @@
|
|||||||
"doesNotExist": "does not exist",
|
"doesNotExist": "does not exist",
|
||||||
"downloadWorkflow": "Download Workflow JSON",
|
"downloadWorkflow": "Download Workflow JSON",
|
||||||
"edge": "Edge",
|
"edge": "Edge",
|
||||||
"edit": "Edit",
|
|
||||||
"editMode": "Edit in Workflow Editor",
|
"editMode": "Edit in Workflow Editor",
|
||||||
"enum": "Enum",
|
"enum": "Enum",
|
||||||
"enumDescription": "Enums are values that may be one of a number of options.",
|
"enumDescription": "Enums are values that may be one of a number of options.",
|
||||||
@ -1032,7 +1019,6 @@
|
|||||||
"nodeTemplate": "Node Template",
|
"nodeTemplate": "Node Template",
|
||||||
"nodeType": "Node Type",
|
"nodeType": "Node Type",
|
||||||
"noFieldsLinearview": "No fields added to Linear View",
|
"noFieldsLinearview": "No fields added to Linear View",
|
||||||
"noFieldsViewMode": "This workflow has no selected fields to display. View the full workflow to configure values.",
|
|
||||||
"noFieldType": "No field type",
|
"noFieldType": "No field type",
|
||||||
"noImageFoundState": "No initial image found in state",
|
"noImageFoundState": "No initial image found in state",
|
||||||
"noMatchingNodes": "No matching nodes",
|
"noMatchingNodes": "No matching nodes",
|
||||||
@ -1820,7 +1806,6 @@
|
|||||||
"cursorPosition": "Cursor Position",
|
"cursorPosition": "Cursor Position",
|
||||||
"darkenOutsideSelection": "Darken Outside Selection",
|
"darkenOutsideSelection": "Darken Outside Selection",
|
||||||
"discardAll": "Discard All",
|
"discardAll": "Discard All",
|
||||||
"discardCurrent": "Discard Current",
|
|
||||||
"downloadAsImage": "Download As Image",
|
"downloadAsImage": "Download As Image",
|
||||||
"emptyFolder": "Empty Folder",
|
"emptyFolder": "Empty Folder",
|
||||||
"emptyTempImageFolder": "Empty Temp Image Folder",
|
"emptyTempImageFolder": "Empty Temp Image Folder",
|
||||||
@ -1830,7 +1815,6 @@
|
|||||||
"eraseBoundingBox": "Erase Bounding Box",
|
"eraseBoundingBox": "Erase Bounding Box",
|
||||||
"eraser": "Eraser",
|
"eraser": "Eraser",
|
||||||
"fillBoundingBox": "Fill Bounding Box",
|
"fillBoundingBox": "Fill Bounding Box",
|
||||||
"invertBrushSizeScrollDirection": "Invert Scroll for Brush Size",
|
|
||||||
"layer": "Layer",
|
"layer": "Layer",
|
||||||
"limitStrokesToBox": "Limit Strokes to Box",
|
"limitStrokesToBox": "Limit Strokes to Box",
|
||||||
"mask": "Mask",
|
"mask": "Mask",
|
||||||
|
@ -115,8 +115,7 @@
|
|||||||
"safetensors": "Safetensors",
|
"safetensors": "Safetensors",
|
||||||
"ai": "ia",
|
"ai": "ia",
|
||||||
"file": "File",
|
"file": "File",
|
||||||
"toResolve": "Da risolvere",
|
"toResolve": "Da risolvere"
|
||||||
"add": "Aggiungi"
|
|
||||||
},
|
},
|
||||||
"gallery": {
|
"gallery": {
|
||||||
"generations": "Generazioni",
|
"generations": "Generazioni",
|
||||||
@ -154,12 +153,7 @@
|
|||||||
"starImage": "Immagine preferita",
|
"starImage": "Immagine preferita",
|
||||||
"dropToUpload": "$t(gallery.drop) per aggiornare",
|
"dropToUpload": "$t(gallery.drop) per aggiornare",
|
||||||
"problemDeletingImagesDesc": "Impossibile eliminare una o più immagini",
|
"problemDeletingImagesDesc": "Impossibile eliminare una o più immagini",
|
||||||
"problemDeletingImages": "Problema durante l'eliminazione delle immagini",
|
"problemDeletingImages": "Problema durante l'eliminazione delle immagini"
|
||||||
"bulkDownloadRequested": "Preparazione del download",
|
|
||||||
"bulkDownloadRequestedDesc": "La tua richiesta di download è in preparazione. L'operazione potrebbe richiedere alcuni istanti.",
|
|
||||||
"bulkDownloadRequestFailed": "Problema durante la preparazione del download",
|
|
||||||
"bulkDownloadStarting": "Avvio scaricamento",
|
|
||||||
"bulkDownloadFailed": "Scaricamento fallito"
|
|
||||||
},
|
},
|
||||||
"hotkeys": {
|
"hotkeys": {
|
||||||
"keyboardShortcuts": "Tasti di scelta rapida",
|
"keyboardShortcuts": "Tasti di scelta rapida",
|
||||||
@ -511,12 +505,12 @@
|
|||||||
"modelSyncFailed": "Sincronizzazione modello non riuscita",
|
"modelSyncFailed": "Sincronizzazione modello non riuscita",
|
||||||
"settings": "Impostazioni",
|
"settings": "Impostazioni",
|
||||||
"syncModels": "Sincronizza Modelli",
|
"syncModels": "Sincronizza Modelli",
|
||||||
"syncModelsDesc": "Se i tuoi modelli non sono sincronizzati con il back-end, puoi aggiornarli utilizzando questa opzione. Questo è generalmente utile nei casi in cui aggiungi modelli alla cartella principale di InvokeAI dopo l'avvio dell'applicazione.",
|
"syncModelsDesc": "Se i tuoi modelli non sono sincronizzati con il back-end, puoi aggiornarli utilizzando questa opzione. Questo è generalmente utile nei casi in cui aggiorni manualmente il tuo file models.yaml o aggiungi modelli alla cartella principale di InvokeAI dopo l'avvio dell'applicazione.",
|
||||||
"loraModels": "LoRA",
|
"loraModels": "LoRA",
|
||||||
"oliveModels": "Olive",
|
"oliveModels": "Olive",
|
||||||
"onnxModels": "ONNX",
|
"onnxModels": "ONNX",
|
||||||
"noModels": "Nessun modello trovato",
|
"noModels": "Nessun modello trovato",
|
||||||
"predictionType": "Tipo di previsione",
|
"predictionType": "Tipo di previsione (per modelli Stable Diffusion 2.x ed alcuni modelli Stable Diffusion 1.x)",
|
||||||
"quickAdd": "Aggiunta rapida",
|
"quickAdd": "Aggiunta rapida",
|
||||||
"simpleModelDesc": "Fornire un percorso a un modello diffusori locale, un modello checkpoint/safetensor locale, un ID repository HuggingFace o un URL del modello checkpoint/diffusori.",
|
"simpleModelDesc": "Fornire un percorso a un modello diffusori locale, un modello checkpoint/safetensor locale, un ID repository HuggingFace o un URL del modello checkpoint/diffusori.",
|
||||||
"advanced": "Avanzate",
|
"advanced": "Avanzate",
|
||||||
@ -527,34 +521,7 @@
|
|||||||
"vaePrecision": "Precisione VAE",
|
"vaePrecision": "Precisione VAE",
|
||||||
"noModelSelected": "Nessun modello selezionato",
|
"noModelSelected": "Nessun modello selezionato",
|
||||||
"conversionNotSupported": "Conversione non supportata",
|
"conversionNotSupported": "Conversione non supportata",
|
||||||
"configFile": "File di configurazione",
|
"configFile": "File di configurazione"
|
||||||
"modelName": "Nome del modello",
|
|
||||||
"modelSettings": "Impostazioni del modello",
|
|
||||||
"advancedImportInfo": "La scheda opzioni avanzate consente la configurazione manuale delle impostazioni del modello principale. Utilizza questa scheda solo se sei sicuro di conoscere il tipo di modello e la configurazione corretti per il modello selezionato.",
|
|
||||||
"addAll": "Aggiungi tutto",
|
|
||||||
"addModels": "Aggiungi modelli",
|
|
||||||
"cancel": "Annulla",
|
|
||||||
"edit": "Modifica",
|
|
||||||
"imageEncoderModelId": "ID modello codificatore di immagini",
|
|
||||||
"importQueue": "Coda di importazione",
|
|
||||||
"modelMetadata": "Metadati del modello",
|
|
||||||
"path": "Percorso",
|
|
||||||
"prune": "Elimina",
|
|
||||||
"pruneTooltip": "Elimina dalla coda le importazioni completate",
|
|
||||||
"removeFromQueue": "Rimuovi dalla coda",
|
|
||||||
"repoVariant": "Variante del repository",
|
|
||||||
"scan": "Scansiona",
|
|
||||||
"scanFolder": "Scansione cartella",
|
|
||||||
"scanResults": "Risultati della scansione",
|
|
||||||
"source": "Sorgente",
|
|
||||||
"upcastAttention": "Eleva l'attenzione",
|
|
||||||
"ztsnrTraining": "Addestramento ZTSNR",
|
|
||||||
"typePhraseHere": "Digita la frase qui",
|
|
||||||
"defaultSettingsSaved": "Impostazioni predefinite salvate",
|
|
||||||
"defaultSettings": "Impostazioni predefinite",
|
|
||||||
"metadata": "Metadati",
|
|
||||||
"useDefaultSettings": "Usa le impostazioni predefinite",
|
|
||||||
"triggerPhrases": "Frasi trigger"
|
|
||||||
},
|
},
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"images": "Immagini",
|
"images": "Immagini",
|
||||||
@ -636,8 +603,8 @@
|
|||||||
"clipSkip": "CLIP Skip",
|
"clipSkip": "CLIP Skip",
|
||||||
"aspectRatio": "Proporzioni",
|
"aspectRatio": "Proporzioni",
|
||||||
"maskAdjustmentsHeader": "Regolazioni della maschera",
|
"maskAdjustmentsHeader": "Regolazioni della maschera",
|
||||||
"maskBlur": "Sfocatura maschera",
|
"maskBlur": "Sfocatura",
|
||||||
"maskBlurMethod": "Metodo sfocatura maschera",
|
"maskBlurMethod": "Metodo di sfocatura",
|
||||||
"seamLowThreshold": "Basso",
|
"seamLowThreshold": "Basso",
|
||||||
"seamHighThreshold": "Alto",
|
"seamHighThreshold": "Alto",
|
||||||
"coherencePassHeader": "Passaggio di coerenza",
|
"coherencePassHeader": "Passaggio di coerenza",
|
||||||
@ -694,8 +661,7 @@
|
|||||||
"setToOptimalSizeTooLarge": "$t(parameters.setToOptimalSize) (potrebbe essere troppo grande)",
|
"setToOptimalSizeTooLarge": "$t(parameters.setToOptimalSize) (potrebbe essere troppo grande)",
|
||||||
"boxBlur": "Box",
|
"boxBlur": "Box",
|
||||||
"gaussianBlur": "Gaussian",
|
"gaussianBlur": "Gaussian",
|
||||||
"remixImage": "Remixa l'immagine",
|
"remixImage": "Remixa l'immagine"
|
||||||
"coherenceEdgeSize": "Dimensione bordo"
|
|
||||||
},
|
},
|
||||||
"settings": {
|
"settings": {
|
||||||
"models": "Modelli",
|
"models": "Modelli",
|
||||||
@ -778,8 +744,8 @@
|
|||||||
"canceled": "Elaborazione annullata",
|
"canceled": "Elaborazione annullata",
|
||||||
"problemCopyingImageLink": "Impossibile copiare il collegamento dell'immagine",
|
"problemCopyingImageLink": "Impossibile copiare il collegamento dell'immagine",
|
||||||
"uploadFailedInvalidUploadDesc": "Deve essere una singola immagine PNG o JPEG",
|
"uploadFailedInvalidUploadDesc": "Deve essere una singola immagine PNG o JPEG",
|
||||||
"parameterSet": "{{parameter}} impostato",
|
"parameterSet": "Parametro impostato",
|
||||||
"parameterNotSet": "{{parameter}} non impostato",
|
"parameterNotSet": "Parametro non impostato",
|
||||||
"nodesLoadedFailed": "Impossibile caricare i nodi",
|
"nodesLoadedFailed": "Impossibile caricare i nodi",
|
||||||
"nodesSaved": "Nodi salvati",
|
"nodesSaved": "Nodi salvati",
|
||||||
"nodesLoaded": "Nodi caricati",
|
"nodesLoaded": "Nodi caricati",
|
||||||
@ -832,10 +798,7 @@
|
|||||||
"problemRetrievingWorkflow": "Problema nel recupero del flusso di lavoro",
|
"problemRetrievingWorkflow": "Problema nel recupero del flusso di lavoro",
|
||||||
"resetInitialImage": "Reimposta l'immagine iniziale",
|
"resetInitialImage": "Reimposta l'immagine iniziale",
|
||||||
"uploadInitialImage": "Carica l'immagine iniziale",
|
"uploadInitialImage": "Carica l'immagine iniziale",
|
||||||
"problemDownloadingImage": "Impossibile scaricare l'immagine",
|
"problemDownloadingImage": "Impossibile scaricare l'immagine"
|
||||||
"prunedQueue": "Coda ripulita",
|
|
||||||
"modelImportCanceled": "Importazione del modello annullata",
|
|
||||||
"modelImportRemoved": "Importazione del modello rimossa"
|
|
||||||
},
|
},
|
||||||
"tooltip": {
|
"tooltip": {
|
||||||
"feature": {
|
"feature": {
|
||||||
@ -913,10 +876,7 @@
|
|||||||
"antialiasing": "Anti aliasing",
|
"antialiasing": "Anti aliasing",
|
||||||
"showResultsOn": "Mostra i risultati (attivato)",
|
"showResultsOn": "Mostra i risultati (attivato)",
|
||||||
"showResultsOff": "Mostra i risultati (disattivato)",
|
"showResultsOff": "Mostra i risultati (disattivato)",
|
||||||
"saveMask": "Salva $t(unifiedCanvas.mask)",
|
"saveMask": "Salva $t(unifiedCanvas.mask)"
|
||||||
"coherenceModeGaussianBlur": "Sfocatura Gaussiana",
|
|
||||||
"coherenceModeBoxBlur": "Sfocatura Box",
|
|
||||||
"coherenceModeStaged": "Maschera espansa"
|
|
||||||
},
|
},
|
||||||
"accessibility": {
|
"accessibility": {
|
||||||
"modelSelect": "Seleziona modello",
|
"modelSelect": "Seleziona modello",
|
||||||
@ -1385,8 +1345,7 @@
|
|||||||
"allLoRAsAdded": "Tutti i LoRA aggiunti",
|
"allLoRAsAdded": "Tutti i LoRA aggiunti",
|
||||||
"defaultVAE": "VAE predefinito",
|
"defaultVAE": "VAE predefinito",
|
||||||
"incompatibleBaseModel": "Modello base incompatibile",
|
"incompatibleBaseModel": "Modello base incompatibile",
|
||||||
"loraAlreadyAdded": "LoRA già aggiunto",
|
"loraAlreadyAdded": "LoRA già aggiunto"
|
||||||
"concepts": "Concetti"
|
|
||||||
},
|
},
|
||||||
"invocationCache": {
|
"invocationCache": {
|
||||||
"disable": "Disabilita",
|
"disable": "Disabilita",
|
||||||
@ -1739,25 +1698,6 @@
|
|||||||
"paragraphs": [
|
"paragraphs": [
|
||||||
"Valuta le generazioni in modo che siano più simili alle immagini con un punteggio estetico elevato, in base ai dati di addestramento."
|
"Valuta le generazioni in modo che siano più simili alle immagini con un punteggio estetico elevato, in base ai dati di addestramento."
|
||||||
]
|
]
|
||||||
},
|
|
||||||
"compositingCoherenceMinDenoise": {
|
|
||||||
"heading": "Livello minimo di riduzione del rumore",
|
|
||||||
"paragraphs": [
|
|
||||||
"Intensità minima di riduzione rumore per la modalità di Coerenza",
|
|
||||||
"L'intensità minima di riduzione del rumore per la regione di coerenza durante l'inpainting o l'outpainting"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"compositingMaskBlur": {
|
|
||||||
"paragraphs": [
|
|
||||||
"Il raggio di sfocatura della maschera."
|
|
||||||
],
|
|
||||||
"heading": "Sfocatura maschera"
|
|
||||||
},
|
|
||||||
"compositingCoherenceEdgeSize": {
|
|
||||||
"heading": "Dimensione del bordo",
|
|
||||||
"paragraphs": [
|
|
||||||
"La dimensione del bordo del passaggio di coerenza."
|
|
||||||
]
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"sdxl": {
|
"sdxl": {
|
||||||
@ -1806,12 +1746,7 @@
|
|||||||
"scheduler": "Campionatore",
|
"scheduler": "Campionatore",
|
||||||
"recallParameters": "Richiama i parametri",
|
"recallParameters": "Richiama i parametri",
|
||||||
"noRecallParameters": "Nessun parametro da richiamare trovato",
|
"noRecallParameters": "Nessun parametro da richiamare trovato",
|
||||||
"cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)",
|
"cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)"
|
||||||
"allPrompts": "Tutti i prompt",
|
|
||||||
"imageDimensions": "Dimensioni dell'immagine",
|
|
||||||
"parameterSet": "Parametro {{parameter}} impostato",
|
|
||||||
"parsingFailed": "Analisi non riuscita",
|
|
||||||
"recallParameter": "Richiama {{label}}"
|
|
||||||
},
|
},
|
||||||
"hrf": {
|
"hrf": {
|
||||||
"enableHrf": "Abilita Correzione Alta Risoluzione",
|
"enableHrf": "Abilita Correzione Alta Risoluzione",
|
||||||
@ -1883,11 +1818,5 @@
|
|||||||
"image": {
|
"image": {
|
||||||
"title": "Immagine"
|
"title": "Immagine"
|
||||||
}
|
}
|
||||||
},
|
|
||||||
"prompt": {
|
|
||||||
"compatibleEmbeddings": "Incorporamenti compatibili",
|
|
||||||
"addPromptTrigger": "Aggiungi parola chiave nel prompt",
|
|
||||||
"noPromptTriggers": "Nessuna parola chiave disponibile",
|
|
||||||
"noMatchingTriggers": "Nessuna parola chiave corrispondente"
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -52,7 +52,7 @@
|
|||||||
"accept": "Принять",
|
"accept": "Принять",
|
||||||
"postprocessing": "Постобработка",
|
"postprocessing": "Постобработка",
|
||||||
"txt2img": "Текст в изображение (txt2img)",
|
"txt2img": "Текст в изображение (txt2img)",
|
||||||
"linear": "Линейный вид",
|
"linear": "Линейная обработка",
|
||||||
"dontAskMeAgain": "Больше не спрашивать",
|
"dontAskMeAgain": "Больше не спрашивать",
|
||||||
"areYouSure": "Вы уверены?",
|
"areYouSure": "Вы уверены?",
|
||||||
"random": "Случайное",
|
"random": "Случайное",
|
||||||
@ -117,8 +117,7 @@
|
|||||||
"toResolve": "Чтоб решить",
|
"toResolve": "Чтоб решить",
|
||||||
"copy": "Копировать",
|
"copy": "Копировать",
|
||||||
"localSystem": "Локальная система",
|
"localSystem": "Локальная система",
|
||||||
"aboutDesc": "Используя Invoke для работы? Проверьте это:",
|
"aboutDesc": "Используя Invoke для работы? Проверьте это:"
|
||||||
"add": "Добавить"
|
|
||||||
},
|
},
|
||||||
"gallery": {
|
"gallery": {
|
||||||
"generations": "Генерации",
|
"generations": "Генерации",
|
||||||
@ -156,12 +155,7 @@
|
|||||||
"noImageSelected": "Изображение не выбрано",
|
"noImageSelected": "Изображение не выбрано",
|
||||||
"setCurrentImage": "Установить как текущее изображение",
|
"setCurrentImage": "Установить как текущее изображение",
|
||||||
"starImage": "Добавить в избранное",
|
"starImage": "Добавить в избранное",
|
||||||
"dropToUpload": "$t(gallery.drop) чтоб загрузить",
|
"dropToUpload": "$t(gallery.drop) чтоб загрузить"
|
||||||
"bulkDownloadFailed": "Загрузка не удалась",
|
|
||||||
"bulkDownloadStarting": "Начало загрузки",
|
|
||||||
"bulkDownloadRequested": "Подготовка к скачиванию",
|
|
||||||
"bulkDownloadRequestedDesc": "Ваш запрос на скачивание готовится. Это может занять несколько минут.",
|
|
||||||
"bulkDownloadRequestFailed": "Возникла проблема при подготовке скачивания"
|
|
||||||
},
|
},
|
||||||
"hotkeys": {
|
"hotkeys": {
|
||||||
"keyboardShortcuts": "Горячие клавиши",
|
"keyboardShortcuts": "Горячие клавиши",
|
||||||
@ -510,7 +504,7 @@
|
|||||||
"settings": "Настройки",
|
"settings": "Настройки",
|
||||||
"selectModel": "Выберите модель",
|
"selectModel": "Выберите модель",
|
||||||
"syncModels": "Синхронизация моделей",
|
"syncModels": "Синхронизация моделей",
|
||||||
"syncModelsDesc": "Если ваши модели не синхронизированы с серверной частью, вы можете обновить их с помощью этой опции. Обычно это удобно в тех случаях, когда вы добавляете модели в корневую папку InvokeAI или каталог автоимпорта после загрузки приложения.",
|
"syncModelsDesc": "Если ваши модели не синхронизированы с серверной частью, вы можете обновить их, используя эту опцию. Обычно это удобно в тех случаях, когда вы вручную обновляете свой файл \"models.yaml\" или добавляете модели в корневую папку InvokeAI после загрузки приложения.",
|
||||||
"modelUpdateFailed": "Не удалось обновить модель",
|
"modelUpdateFailed": "Не удалось обновить модель",
|
||||||
"modelConversionFailed": "Не удалось сконвертировать модель",
|
"modelConversionFailed": "Не удалось сконвертировать модель",
|
||||||
"modelsMergeFailed": "Не удалось выполнить слияние моделей",
|
"modelsMergeFailed": "Не удалось выполнить слияние моделей",
|
||||||
@ -519,7 +513,7 @@
|
|||||||
"oliveModels": "Модели Olives",
|
"oliveModels": "Модели Olives",
|
||||||
"conversionNotSupported": "Преобразование не поддерживается",
|
"conversionNotSupported": "Преобразование не поддерживается",
|
||||||
"noModels": "Нет моделей",
|
"noModels": "Нет моделей",
|
||||||
"predictionType": "Тип прогноза",
|
"predictionType": "Тип прогноза (для моделей Stable Diffusion 2.x и периодических моделей Stable Diffusion 1.x)",
|
||||||
"quickAdd": "Быстрое добавление",
|
"quickAdd": "Быстрое добавление",
|
||||||
"simpleModelDesc": "Укажите путь к локальной модели Diffusers , локальной модели checkpoint / safetensors, идентификатор репозитория HuggingFace или URL-адрес модели контрольной checkpoint / diffusers.",
|
"simpleModelDesc": "Укажите путь к локальной модели Diffusers , локальной модели checkpoint / safetensors, идентификатор репозитория HuggingFace или URL-адрес модели контрольной checkpoint / diffusers.",
|
||||||
"advanced": "Продвинутый",
|
"advanced": "Продвинутый",
|
||||||
@ -530,33 +524,7 @@
|
|||||||
"customConfigFileLocation": "Расположение пользовательского файла конфигурации",
|
"customConfigFileLocation": "Расположение пользовательского файла конфигурации",
|
||||||
"vaePrecision": "Точность VAE",
|
"vaePrecision": "Точность VAE",
|
||||||
"noModelSelected": "Модель не выбрана",
|
"noModelSelected": "Модель не выбрана",
|
||||||
"configFile": "Файл конфигурации",
|
"configFile": "Файл конфигурации"
|
||||||
"addAll": "Добавить всё",
|
|
||||||
"addModels": "Добавить модели",
|
|
||||||
"cancel": "Отмена",
|
|
||||||
"defaultSettings": "Стандартные настройки",
|
|
||||||
"importQueue": "Импортировать очередь",
|
|
||||||
"metadata": "Метаданные",
|
|
||||||
"imageEncoderModelId": "ID модели-энкодера изображений",
|
|
||||||
"typePhraseHere": "Введите фразы здесь",
|
|
||||||
"advancedImportInfo": "Вкладка «Дополнительно» позволяет вручную настроить основные параметры модели. Используйте эту вкладку только в том случае, если вы уверены, что знаете правильный тип модели и конфигурацию выбранной модели.",
|
|
||||||
"defaultSettingsSaved": "Стандартные настройки сохранены",
|
|
||||||
"edit": "Редактировать",
|
|
||||||
"path": "Путь",
|
|
||||||
"prune": "Удалить",
|
|
||||||
"pruneTooltip": "Удалить готовые импорты из очереди",
|
|
||||||
"removeFromQueue": "Удалить из очереди",
|
|
||||||
"repoVariant": "Вариант репозитория",
|
|
||||||
"scan": "Сканировать",
|
|
||||||
"scanFolder": "Сканировать папку",
|
|
||||||
"scanResults": "Результаты сканирования",
|
|
||||||
"source": "Источник",
|
|
||||||
"triggerPhrases": "Триггерные фразы",
|
|
||||||
"useDefaultSettings": "Использовать стандартные настройки",
|
|
||||||
"modelMetadata": "Метаданные модели",
|
|
||||||
"modelName": "Название модели",
|
|
||||||
"modelSettings": "Настройки модели",
|
|
||||||
"upcastAttention": "Внимание"
|
|
||||||
},
|
},
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"images": "Изображения",
|
"images": "Изображения",
|
||||||
@ -623,7 +591,7 @@
|
|||||||
"hSymmetryStep": "Шаг гор. симметрии",
|
"hSymmetryStep": "Шаг гор. симметрии",
|
||||||
"hidePreview": "Скрыть предпросмотр",
|
"hidePreview": "Скрыть предпросмотр",
|
||||||
"imageToImage": "Изображение в изображение",
|
"imageToImage": "Изображение в изображение",
|
||||||
"denoisingStrength": "Сила зашумления",
|
"denoisingStrength": "Сила шумоподавления",
|
||||||
"copyImage": "Скопировать изображение",
|
"copyImage": "Скопировать изображение",
|
||||||
"showPreview": "Показать предпросмотр",
|
"showPreview": "Показать предпросмотр",
|
||||||
"noiseSettings": "Шум",
|
"noiseSettings": "Шум",
|
||||||
@ -638,8 +606,8 @@
|
|||||||
"clipSkip": "CLIP Пропуск",
|
"clipSkip": "CLIP Пропуск",
|
||||||
"aspectRatio": "Соотношение",
|
"aspectRatio": "Соотношение",
|
||||||
"maskAdjustmentsHeader": "Настройка маски",
|
"maskAdjustmentsHeader": "Настройка маски",
|
||||||
"maskBlur": "Размытие маски",
|
"maskBlur": "Размытие",
|
||||||
"maskBlurMethod": "Метод размытия маски",
|
"maskBlurMethod": "Метод размытия",
|
||||||
"seamLowThreshold": "Низкий",
|
"seamLowThreshold": "Низкий",
|
||||||
"seamHighThreshold": "Высокий",
|
"seamHighThreshold": "Высокий",
|
||||||
"coherencePassHeader": "Порог Coherence",
|
"coherencePassHeader": "Порог Coherence",
|
||||||
@ -698,9 +666,7 @@
|
|||||||
"lockAspectRatio": "Заблокировать соотношение",
|
"lockAspectRatio": "Заблокировать соотношение",
|
||||||
"boxBlur": "Размытие прямоугольника",
|
"boxBlur": "Размытие прямоугольника",
|
||||||
"gaussianBlur": "Размытие по Гауссу",
|
"gaussianBlur": "Размытие по Гауссу",
|
||||||
"remixImage": "Ремикс изображения",
|
"remixImage": "Ремикс изображения"
|
||||||
"coherenceMinDenoise": "Мин. шумоподавление",
|
|
||||||
"coherenceEdgeSize": "Размер края"
|
|
||||||
},
|
},
|
||||||
"settings": {
|
"settings": {
|
||||||
"models": "Модели",
|
"models": "Модели",
|
||||||
@ -783,8 +749,8 @@
|
|||||||
"canceled": "Обработка отменена",
|
"canceled": "Обработка отменена",
|
||||||
"problemCopyingImageLink": "Не удалось скопировать ссылку на изображение",
|
"problemCopyingImageLink": "Не удалось скопировать ссылку на изображение",
|
||||||
"uploadFailedInvalidUploadDesc": "Должно быть одно изображение в формате PNG или JPEG",
|
"uploadFailedInvalidUploadDesc": "Должно быть одно изображение в формате PNG или JPEG",
|
||||||
"parameterNotSet": "Параметр {{parameter}} не задан",
|
"parameterNotSet": "Параметр не задан",
|
||||||
"parameterSet": "Параметр {{parameter}} задан",
|
"parameterSet": "Параметр задан",
|
||||||
"nodesLoaded": "Узлы загружены",
|
"nodesLoaded": "Узлы загружены",
|
||||||
"problemCopyingImage": "Не удается скопировать изображение",
|
"problemCopyingImage": "Не удается скопировать изображение",
|
||||||
"nodesLoadedFailed": "Не удалось загрузить Узлы",
|
"nodesLoadedFailed": "Не удалось загрузить Узлы",
|
||||||
@ -837,10 +803,7 @@
|
|||||||
"problemImportingMask": "Проблема с импортом маски",
|
"problemImportingMask": "Проблема с импортом маски",
|
||||||
"problemDownloadingImage": "Не удается скачать изображение",
|
"problemDownloadingImage": "Не удается скачать изображение",
|
||||||
"uploadInitialImage": "Загрузить начальное изображение",
|
"uploadInitialImage": "Загрузить начальное изображение",
|
||||||
"resetInitialImage": "Сбросить начальное изображение",
|
"resetInitialImage": "Сбросить начальное изображение"
|
||||||
"prunedQueue": "Урезанная очередь",
|
|
||||||
"modelImportCanceled": "Импорт модели отменен",
|
|
||||||
"modelImportRemoved": "Импорт модели удален"
|
|
||||||
},
|
},
|
||||||
"tooltip": {
|
"tooltip": {
|
||||||
"feature": {
|
"feature": {
|
||||||
@ -1182,11 +1145,7 @@
|
|||||||
"reorderLinearView": "Изменить порядок линейного просмотра",
|
"reorderLinearView": "Изменить порядок линейного просмотра",
|
||||||
"viewMode": "Использовать в линейном представлении",
|
"viewMode": "Использовать в линейном представлении",
|
||||||
"editMode": "Открыть в редакторе узлов",
|
"editMode": "Открыть в редакторе узлов",
|
||||||
"resetToDefaultValue": "Сбросить к стандартному значкнию",
|
"resetToDefaultValue": "Сбросить к стандартному значкнию"
|
||||||
"latentsField": "Латенты",
|
|
||||||
"latentsCollectionDescription": "Латенты могут передаваться между узлами.",
|
|
||||||
"latentsPolymorphicDescription": "Латенты могут передаваться между узлами.",
|
|
||||||
"latentsFieldDescription": "Латенты могут передаваться между узлами."
|
|
||||||
},
|
},
|
||||||
"controlnet": {
|
"controlnet": {
|
||||||
"amult": "a_mult",
|
"amult": "a_mult",
|
||||||
@ -1335,8 +1294,7 @@
|
|||||||
},
|
},
|
||||||
"paramScheduler": {
|
"paramScheduler": {
|
||||||
"paragraphs": [
|
"paragraphs": [
|
||||||
"Планировщик, используемый в процессе генерации.",
|
"Планировщик определяет, как итеративно добавлять шум к изображению или как обновлять образец на основе выходных данных модели."
|
||||||
"Каждый планировщик определяет, как итеративно добавлять шум к изображению или как обновлять образец на основе выходных данных модели."
|
|
||||||
],
|
],
|
||||||
"heading": "Планировщик"
|
"heading": "Планировщик"
|
||||||
},
|
},
|
||||||
@ -1389,7 +1347,7 @@
|
|||||||
"compositingCoherenceMode": {
|
"compositingCoherenceMode": {
|
||||||
"heading": "Режим",
|
"heading": "Режим",
|
||||||
"paragraphs": [
|
"paragraphs": [
|
||||||
"Метод, используемый для создания связного изображения с вновь созданной замаскированной областью."
|
"Режим прохождения когерентности."
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"paramSeed": {
|
"paramSeed": {
|
||||||
@ -1407,7 +1365,7 @@
|
|||||||
},
|
},
|
||||||
"controlNetBeginEnd": {
|
"controlNetBeginEnd": {
|
||||||
"paragraphs": [
|
"paragraphs": [
|
||||||
"Часть процесса шумоподавления, к которой будет применен адаптер контроля.",
|
"На каких этапах процесса шумоподавления будет применена ControlNet.",
|
||||||
"ControlNet, применяемые в начале процесса, направляют композицию, а ControlNet, применяемые в конце, направляют детали."
|
"ControlNet, применяемые в начале процесса, направляют композицию, а ControlNet, применяемые в конце, направляют детали."
|
||||||
],
|
],
|
||||||
"heading": "Процент начала/конца шага"
|
"heading": "Процент начала/конца шага"
|
||||||
@ -1423,8 +1381,8 @@
|
|||||||
},
|
},
|
||||||
"clipSkip": {
|
"clipSkip": {
|
||||||
"paragraphs": [
|
"paragraphs": [
|
||||||
"Сколько слоев модели CLIP пропустить.",
|
"Выберите, сколько слоев модели CLIP нужно пропустить.",
|
||||||
"Некоторые модели лучше подходят для использования с CLIP Skip."
|
"Некоторые модели работают лучше с определенными настройками пропуска CLIP."
|
||||||
],
|
],
|
||||||
"heading": "CLIP пропуск"
|
"heading": "CLIP пропуск"
|
||||||
},
|
},
|
||||||
@ -1521,25 +1479,6 @@
|
|||||||
"paragraphs": [
|
"paragraphs": [
|
||||||
"Более высокий вес LoRA приведет к большему влиянию на конечное изображение."
|
"Более высокий вес LoRA приведет к большему влиянию на конечное изображение."
|
||||||
]
|
]
|
||||||
},
|
|
||||||
"compositingMaskBlur": {
|
|
||||||
"heading": "Размытие маски",
|
|
||||||
"paragraphs": [
|
|
||||||
"Радиус размытия маски."
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"compositingCoherenceMinDenoise": {
|
|
||||||
"heading": "Минимальное шумоподавление",
|
|
||||||
"paragraphs": [
|
|
||||||
"Минимальный уровень шумоподавления для режима Coherence",
|
|
||||||
"Минимальный уровень шумоподавления для области когерентности при перерисовывании или дорисовке"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"compositingCoherenceEdgeSize": {
|
|
||||||
"heading": "Размер края",
|
|
||||||
"paragraphs": [
|
|
||||||
"Размер края прохода когерентности."
|
|
||||||
]
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"metadata": {
|
"metadata": {
|
||||||
@ -1570,12 +1509,7 @@
|
|||||||
"steps": "Шаги",
|
"steps": "Шаги",
|
||||||
"scheduler": "Планировщик",
|
"scheduler": "Планировщик",
|
||||||
"noRecallParameters": "Параметры для вызова не найдены",
|
"noRecallParameters": "Параметры для вызова не найдены",
|
||||||
"cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)",
|
"cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)"
|
||||||
"parameterSet": "Параметр {{parameter}} установлен",
|
|
||||||
"parsingFailed": "Не удалось выполнить синтаксический анализ",
|
|
||||||
"recallParameter": "Отозвать {{label}}",
|
|
||||||
"allPrompts": "Все запросы",
|
|
||||||
"imageDimensions": "Размеры изображения"
|
|
||||||
},
|
},
|
||||||
"queue": {
|
"queue": {
|
||||||
"status": "Статус",
|
"status": "Статус",
|
||||||
@ -1654,11 +1588,10 @@
|
|||||||
"denoisingStrength": "Шумоподавление",
|
"denoisingStrength": "Шумоподавление",
|
||||||
"refinermodel": "Модель перерисовщик",
|
"refinermodel": "Модель перерисовщик",
|
||||||
"posAestheticScore": "Положительная эстетическая оценка",
|
"posAestheticScore": "Положительная эстетическая оценка",
|
||||||
"concatPromptStyle": "Связывание запроса и стиля",
|
"concatPromptStyle": "Объединение запроса и стиля",
|
||||||
"loading": "Загрузка...",
|
"loading": "Загрузка...",
|
||||||
"steps": "Шаги",
|
"steps": "Шаги",
|
||||||
"posStylePrompt": "Запрос стиля",
|
"posStylePrompt": "Запрос стиля"
|
||||||
"freePromptStyle": "Ручной запрос стиля"
|
|
||||||
},
|
},
|
||||||
"invocationCache": {
|
"invocationCache": {
|
||||||
"useCache": "Использовать кэш",
|
"useCache": "Использовать кэш",
|
||||||
@ -1745,8 +1678,7 @@
|
|||||||
"allLoRAsAdded": "Все LoRA добавлены",
|
"allLoRAsAdded": "Все LoRA добавлены",
|
||||||
"defaultVAE": "Стандартное VAE",
|
"defaultVAE": "Стандартное VAE",
|
||||||
"incompatibleBaseModel": "Несовместимая базовая модель",
|
"incompatibleBaseModel": "Несовместимая базовая модель",
|
||||||
"loraAlreadyAdded": "LoRA уже добавлена",
|
"loraAlreadyAdded": "LoRA уже добавлена"
|
||||||
"concepts": "Концепты"
|
|
||||||
},
|
},
|
||||||
"app": {
|
"app": {
|
||||||
"storeNotInitialized": "Магазин не инициализирован"
|
"storeNotInitialized": "Магазин не инициализирован"
|
||||||
@ -1764,7 +1696,7 @@
|
|||||||
},
|
},
|
||||||
"generation": {
|
"generation": {
|
||||||
"title": "Генерация",
|
"title": "Генерация",
|
||||||
"conceptsTab": "LoRA",
|
"conceptsTab": "Концепты",
|
||||||
"modelTab": "Модель"
|
"modelTab": "Модель"
|
||||||
},
|
},
|
||||||
"advanced": {
|
"advanced": {
|
||||||
|
@ -5,55 +5,18 @@ import openapiTS from 'openapi-typescript';
|
|||||||
const OPENAPI_URL = 'http://127.0.0.1:9090/openapi.json';
|
const OPENAPI_URL = 'http://127.0.0.1:9090/openapi.json';
|
||||||
const OUTPUT_FILE = 'src/services/api/schema.ts';
|
const OUTPUT_FILE = 'src/services/api/schema.ts';
|
||||||
|
|
||||||
async function generateTypes(schema) {
|
async function main() {
|
||||||
process.stdout.write(`Generating types ${OUTPUT_FILE}...`);
|
process.stdout.write(`Generating types "${OPENAPI_URL}" --> "${OUTPUT_FILE}"...`);
|
||||||
const types = await openapiTS(schema, {
|
const types = await openapiTS(OPENAPI_URL, {
|
||||||
exportType: true,
|
exportType: true,
|
||||||
transform: (schemaObject) => {
|
transform: (schemaObject) => {
|
||||||
if ('format' in schemaObject && schemaObject.format === 'binary') {
|
if ('format' in schemaObject && schemaObject.format === 'binary') {
|
||||||
return schemaObject.nullable ? 'Blob | null' : 'Blob';
|
return schemaObject.nullable ? 'Blob | null' : 'Blob';
|
||||||
}
|
}
|
||||||
if (schemaObject.title === 'MetadataField') {
|
|
||||||
// This is `Record<string, never>` by default, but it actually accepts any a dict of any valid JSON value.
|
|
||||||
return 'Record<string, unknown>';
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
fs.writeFileSync(OUTPUT_FILE, types);
|
fs.writeFileSync(OUTPUT_FILE, types);
|
||||||
process.stdout.write(`\nOK!\r\n`);
|
process.stdout.write(`\nOK!\r\n`);
|
||||||
}
|
}
|
||||||
|
|
||||||
async function main() {
|
|
||||||
const encoding = 'utf-8';
|
|
||||||
|
|
||||||
if (process.stdin.isTTY) {
|
|
||||||
// Handle generating types with an arg (e.g. URL or path to file)
|
|
||||||
if (process.argv.length > 3) {
|
|
||||||
console.error('Usage: typegen.js <openapi.json>');
|
|
||||||
process.exit(1);
|
|
||||||
}
|
|
||||||
if (process.argv[2]) {
|
|
||||||
const schema = new Buffer.from(process.argv[2], encoding);
|
|
||||||
generateTypes(schema);
|
|
||||||
} else {
|
|
||||||
generateTypes(OPENAPI_URL);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Handle generating types from stdin
|
|
||||||
let schema = '';
|
|
||||||
process.stdin.setEncoding(encoding);
|
|
||||||
|
|
||||||
process.stdin.on('readable', function () {
|
|
||||||
const chunk = process.stdin.read();
|
|
||||||
if (chunk !== null) {
|
|
||||||
schema += chunk;
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
process.stdin.on('end', function () {
|
|
||||||
generateTypes(JSON.parse(schema));
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
main();
|
main();
|
||||||
|
@ -38,7 +38,7 @@ export const addCanvasImageToControlNetListener = (startAppListening: AppStartLi
|
|||||||
type: 'image/png',
|
type: 'image/png',
|
||||||
}),
|
}),
|
||||||
image_category: 'control',
|
image_category: 'control',
|
||||||
is_intermediate: true,
|
is_intermediate: false,
|
||||||
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
|
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
|
||||||
crop_visible: false,
|
crop_visible: false,
|
||||||
postUploadAction: {
|
postUploadAction: {
|
||||||
|
@ -48,7 +48,7 @@ export const addCanvasMaskToControlNetListener = (startAppListening: AppStartLis
|
|||||||
type: 'image/png',
|
type: 'image/png',
|
||||||
}),
|
}),
|
||||||
image_category: 'mask',
|
image_category: 'mask',
|
||||||
is_intermediate: true,
|
is_intermediate: false,
|
||||||
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
|
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
|
||||||
crop_visible: false,
|
crop_visible: false,
|
||||||
postUploadAction: {
|
postUploadAction: {
|
||||||
|
@ -101,7 +101,7 @@ export const addEnqueueRequestedCanvasListener = (startAppListening: AppStartLis
|
|||||||
).unwrap();
|
).unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
const graph = await buildCanvasGraph(state, generationMode, canvasInitImage, canvasMaskImage);
|
const graph = buildCanvasGraph(state, generationMode, canvasInitImage, canvasMaskImage);
|
||||||
|
|
||||||
log.debug({ graph: parseify(graph) }, `Canvas graph built`);
|
log.debug({ graph: parseify(graph) }, `Canvas graph built`);
|
||||||
|
|
||||||
|
@ -20,15 +20,15 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
|
|||||||
|
|
||||||
if (model && model.base === 'sdxl') {
|
if (model && model.base === 'sdxl') {
|
||||||
if (action.payload.tabName === 'txt2img') {
|
if (action.payload.tabName === 'txt2img') {
|
||||||
graph = await buildLinearSDXLTextToImageGraph(state);
|
graph = buildLinearSDXLTextToImageGraph(state);
|
||||||
} else {
|
} else {
|
||||||
graph = await buildLinearSDXLImageToImageGraph(state);
|
graph = buildLinearSDXLImageToImageGraph(state);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (action.payload.tabName === 'txt2img') {
|
if (action.payload.tabName === 'txt2img') {
|
||||||
graph = await buildLinearTextToImageGraph(state);
|
graph = buildLinearTextToImageGraph(state);
|
||||||
} else {
|
} else {
|
||||||
graph = await buildLinearImageToImageGraph(state);
|
graph = buildLinearImageToImageGraph(state);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -21,7 +21,6 @@ import { makeToast } from 'features/system/util/makeToast';
|
|||||||
import { t } from 'i18next';
|
import { t } from 'i18next';
|
||||||
import { map } from 'lodash-es';
|
import { map } from 'lodash-es';
|
||||||
import { modelsApi } from 'services/api/endpoints/models';
|
import { modelsApi } from 'services/api/endpoints/models';
|
||||||
import { isNonRefinerMainModelConfig } from 'services/api/types';
|
|
||||||
|
|
||||||
export const addSetDefaultSettingsListener = (startAppListening: AppStartListening) => {
|
export const addSetDefaultSettingsListener = (startAppListening: AppStartListening) => {
|
||||||
startAppListening({
|
startAppListening({
|
||||||
@ -37,64 +36,61 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni
|
|||||||
|
|
||||||
const modelConfig = await dispatch(modelsApi.endpoints.getModelConfig.initiate(currentModel.key)).unwrap();
|
const modelConfig = await dispatch(modelsApi.endpoints.getModelConfig.initiate(currentModel.key)).unwrap();
|
||||||
|
|
||||||
if (!modelConfig) {
|
if (!modelConfig || !modelConfig.default_settings) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (isNonRefinerMainModelConfig(modelConfig) && modelConfig.default_settings) {
|
const { vae, vae_precision, cfg_scale, cfg_rescale_multiplier, steps, scheduler } = modelConfig.default_settings;
|
||||||
const { vae, vae_precision, cfg_scale, cfg_rescale_multiplier, steps, scheduler } =
|
|
||||||
modelConfig.default_settings;
|
|
||||||
|
|
||||||
if (vae) {
|
if (vae) {
|
||||||
// we store this as "default" within default settings
|
// we store this as "default" within default settings
|
||||||
// to distinguish it from no default set
|
// to distinguish it from no default set
|
||||||
if (vae === 'default') {
|
if (vae === 'default') {
|
||||||
dispatch(vaeSelected(null));
|
dispatch(vaeSelected(null));
|
||||||
} else {
|
} else {
|
||||||
const { data } = modelsApi.endpoints.getVaeModels.select()(state);
|
const { data } = modelsApi.endpoints.getVaeModels.select()(state);
|
||||||
const vaeArray = map(data?.entities);
|
const vaeArray = map(data?.entities);
|
||||||
const validVae = vaeArray.find((model) => model.key === vae);
|
const validVae = vaeArray.find((model) => model.key === vae);
|
||||||
|
|
||||||
const result = zParameterVAEModel.safeParse(validVae);
|
const result = zParameterVAEModel.safeParse(validVae);
|
||||||
if (!result.success) {
|
if (!result.success) {
|
||||||
return;
|
return;
|
||||||
}
|
|
||||||
dispatch(vaeSelected(result.data));
|
|
||||||
}
|
}
|
||||||
|
dispatch(vaeSelected(result.data));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (vae_precision) {
|
|
||||||
if (isParameterPrecision(vae_precision)) {
|
|
||||||
dispatch(vaePrecisionChanged(vae_precision));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (cfg_scale) {
|
|
||||||
if (isParameterCFGScale(cfg_scale)) {
|
|
||||||
dispatch(setCfgScale(cfg_scale));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (cfg_rescale_multiplier) {
|
|
||||||
if (isParameterCFGRescaleMultiplier(cfg_rescale_multiplier)) {
|
|
||||||
dispatch(setCfgRescaleMultiplier(cfg_rescale_multiplier));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (steps) {
|
|
||||||
if (isParameterSteps(steps)) {
|
|
||||||
dispatch(setSteps(steps));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (scheduler) {
|
|
||||||
if (isParameterScheduler(scheduler)) {
|
|
||||||
dispatch(setScheduler(scheduler));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
dispatch(addToast(makeToast({ title: t('toast.parameterSet', { parameter: 'Default settings' }) })));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (vae_precision) {
|
||||||
|
if (isParameterPrecision(vae_precision)) {
|
||||||
|
dispatch(vaePrecisionChanged(vae_precision));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (cfg_scale) {
|
||||||
|
if (isParameterCFGScale(cfg_scale)) {
|
||||||
|
dispatch(setCfgScale(cfg_scale));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (cfg_rescale_multiplier) {
|
||||||
|
if (isParameterCFGRescaleMultiplier(cfg_rescale_multiplier)) {
|
||||||
|
dispatch(setCfgRescaleMultiplier(cfg_rescale_multiplier));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (steps) {
|
||||||
|
if (isParameterSteps(steps)) {
|
||||||
|
dispatch(setSteps(steps));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (scheduler) {
|
||||||
|
if (isParameterScheduler(scheduler)) {
|
||||||
|
dispatch(setScheduler(scheduler));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatch(addToast(makeToast({ title: t('toast.parameterSet', { parameter: 'Default settings' }) })));
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
@ -20,7 +20,7 @@ const sx: ChakraProps['sx'] = {
|
|||||||
'.react-colorful__hue-pointer': colorPickerPointerStyles,
|
'.react-colorful__hue-pointer': colorPickerPointerStyles,
|
||||||
'.react-colorful__saturation-pointer': colorPickerPointerStyles,
|
'.react-colorful__saturation-pointer': colorPickerPointerStyles,
|
||||||
'.react-colorful__alpha-pointer': colorPickerPointerStyles,
|
'.react-colorful__alpha-pointer': colorPickerPointerStyles,
|
||||||
gap: 5,
|
gap: 2,
|
||||||
flexDir: 'column',
|
flexDir: 'column',
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -39,8 +39,8 @@ const IAIColorPicker = (props: IAIColorPickerProps) => {
|
|||||||
<Flex sx={sx}>
|
<Flex sx={sx}>
|
||||||
<RgbaColorPicker color={color} onChange={onChange} style={colorPickerStyles} {...rest} />
|
<RgbaColorPicker color={color} onChange={onChange} style={colorPickerStyles} {...rest} />
|
||||||
{withNumberInput && (
|
{withNumberInput && (
|
||||||
<Flex gap={5}>
|
<Flex>
|
||||||
<FormControl gap={0}>
|
<FormControl>
|
||||||
<FormLabel>{t('common.red')}</FormLabel>
|
<FormLabel>{t('common.red')}</FormLabel>
|
||||||
<CompositeNumberInput
|
<CompositeNumberInput
|
||||||
value={color.r}
|
value={color.r}
|
||||||
@ -52,7 +52,7 @@ const IAIColorPicker = (props: IAIColorPickerProps) => {
|
|||||||
defaultValue={90}
|
defaultValue={90}
|
||||||
/>
|
/>
|
||||||
</FormControl>
|
</FormControl>
|
||||||
<FormControl gap={0}>
|
<FormControl>
|
||||||
<FormLabel>{t('common.green')}</FormLabel>
|
<FormLabel>{t('common.green')}</FormLabel>
|
||||||
<CompositeNumberInput
|
<CompositeNumberInput
|
||||||
value={color.g}
|
value={color.g}
|
||||||
@ -64,7 +64,7 @@ const IAIColorPicker = (props: IAIColorPickerProps) => {
|
|||||||
defaultValue={90}
|
defaultValue={90}
|
||||||
/>
|
/>
|
||||||
</FormControl>
|
</FormControl>
|
||||||
<FormControl gap={0}>
|
<FormControl>
|
||||||
<FormLabel>{t('common.blue')}</FormLabel>
|
<FormLabel>{t('common.blue')}</FormLabel>
|
||||||
<CompositeNumberInput
|
<CompositeNumberInput
|
||||||
value={color.b}
|
value={color.b}
|
||||||
@ -76,7 +76,7 @@ const IAIColorPicker = (props: IAIColorPickerProps) => {
|
|||||||
defaultValue={255}
|
defaultValue={255}
|
||||||
/>
|
/>
|
||||||
</FormControl>
|
</FormControl>
|
||||||
<FormControl gap={0}>
|
<FormControl>
|
||||||
<FormLabel>{t('common.alpha')}</FormLabel>
|
<FormLabel>{t('common.alpha')}</FormLabel>
|
||||||
<CompositeNumberInput
|
<CompositeNumberInput
|
||||||
value={color.a}
|
value={color.a}
|
||||||
|
@ -2,7 +2,7 @@ import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
|||||||
import type { EntityState } from '@reduxjs/toolkit';
|
import type { EntityState } from '@reduxjs/toolkit';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import type { GroupBase } from 'chakra-react-select';
|
import type { GroupBase } from 'chakra-react-select';
|
||||||
import type { ModelIdentifierField } from 'features/nodes/types/common';
|
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
|
||||||
import { groupBy, map, reduce } from 'lodash-es';
|
import { groupBy, map, reduce } from 'lodash-es';
|
||||||
import { useCallback, useMemo } from 'react';
|
import { useCallback, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
@ -10,7 +10,7 @@ import type { AnyModelConfig } from 'services/api/types';
|
|||||||
|
|
||||||
type UseGroupedModelComboboxArg<T extends AnyModelConfig> = {
|
type UseGroupedModelComboboxArg<T extends AnyModelConfig> = {
|
||||||
modelEntities: EntityState<T, string> | undefined;
|
modelEntities: EntityState<T, string> | undefined;
|
||||||
selectedModel?: ModelIdentifierField | null;
|
selectedModel?: ModelIdentifierWithBase | null;
|
||||||
onChange: (value: T | null) => void;
|
onChange: (value: T | null) => void;
|
||||||
getIsDisabled?: (model: T) => boolean;
|
getIsDisabled?: (model: T) => boolean;
|
||||||
isLoading?: boolean;
|
isLoading?: boolean;
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||||
import type { EntityState } from '@reduxjs/toolkit';
|
import type { EntityState } from '@reduxjs/toolkit';
|
||||||
import type { ModelIdentifierField } from 'features/nodes/types/common';
|
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
|
||||||
import { map } from 'lodash-es';
|
import { map } from 'lodash-es';
|
||||||
import { useCallback, useMemo } from 'react';
|
import { useCallback, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
@ -8,7 +8,7 @@ import type { AnyModelConfig } from 'services/api/types';
|
|||||||
|
|
||||||
type UseModelComboboxArg<T extends AnyModelConfig> = {
|
type UseModelComboboxArg<T extends AnyModelConfig> = {
|
||||||
modelEntities: EntityState<T, string> | undefined;
|
modelEntities: EntityState<T, string> | undefined;
|
||||||
selectedModel?: ModelIdentifierField | null;
|
selectedModel?: ModelIdentifierWithBase | null;
|
||||||
onChange: (value: T | null) => void;
|
onChange: (value: T | null) => void;
|
||||||
getIsDisabled?: (model: T) => boolean;
|
getIsDisabled?: (model: T) => boolean;
|
||||||
optionsFilter?: (model: T) => boolean;
|
optionsFilter?: (model: T) => boolean;
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import type { Item } from '@invoke-ai/ui-library';
|
import type { Item } from '@invoke-ai/ui-library';
|
||||||
import type { EntityState } from '@reduxjs/toolkit';
|
import type { EntityState } from '@reduxjs/toolkit';
|
||||||
import { EMPTY_ARRAY } from 'app/store/constants';
|
import { EMPTY_ARRAY } from 'app/store/constants';
|
||||||
import type { ModelIdentifierField } from 'features/nodes/types/common';
|
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
|
||||||
import { MODEL_TYPE_SHORT_MAP } from 'features/parameters/types/constants';
|
import { MODEL_TYPE_SHORT_MAP } from 'features/parameters/types/constants';
|
||||||
import { filter } from 'lodash-es';
|
import { filter } from 'lodash-es';
|
||||||
import { useCallback, useMemo } from 'react';
|
import { useCallback, useMemo } from 'react';
|
||||||
@ -11,7 +11,7 @@ import type { AnyModelConfig } from 'services/api/types';
|
|||||||
type UseModelCustomSelectArg<T extends AnyModelConfig> = {
|
type UseModelCustomSelectArg<T extends AnyModelConfig> = {
|
||||||
data: EntityState<T, string> | undefined;
|
data: EntityState<T, string> | undefined;
|
||||||
isLoading: boolean;
|
isLoading: boolean;
|
||||||
selectedModel?: ModelIdentifierField | null;
|
selectedModel?: ModelIdentifierWithBase | null;
|
||||||
onChange: (value: T | null) => void;
|
onChange: (value: T | null) => void;
|
||||||
modelFilter?: (model: T) => boolean;
|
modelFilter?: (model: T) => boolean;
|
||||||
isModelDisabled?: (model: T) => boolean;
|
isModelDisabled?: (model: T) => boolean;
|
||||||
|
@ -29,7 +29,7 @@ import { Layer, Stage } from 'react-konva';
|
|||||||
import IAICanvasBoundingBoxOverlay from './IAICanvasBoundingBoxOverlay';
|
import IAICanvasBoundingBoxOverlay from './IAICanvasBoundingBoxOverlay';
|
||||||
import IAICanvasGrid from './IAICanvasGrid';
|
import IAICanvasGrid from './IAICanvasGrid';
|
||||||
import IAICanvasIntermediateImage from './IAICanvasIntermediateImage';
|
import IAICanvasIntermediateImage from './IAICanvasIntermediateImage';
|
||||||
import IAICanvasMaskCompositor from './IAICanvasMaskCompositor';
|
import IAICanvasMaskCompositer from './IAICanvasMaskCompositer';
|
||||||
import IAICanvasMaskLines from './IAICanvasMaskLines';
|
import IAICanvasMaskLines from './IAICanvasMaskLines';
|
||||||
import IAICanvasObjectRenderer from './IAICanvasObjectRenderer';
|
import IAICanvasObjectRenderer from './IAICanvasObjectRenderer';
|
||||||
import IAICanvasStagingArea from './IAICanvasStagingArea';
|
import IAICanvasStagingArea from './IAICanvasStagingArea';
|
||||||
@ -176,7 +176,7 @@ const IAICanvas = () => {
|
|||||||
</Layer>
|
</Layer>
|
||||||
<Layer id="mask" visible={isMaskEnabled && !isStaging} listening={false}>
|
<Layer id="mask" visible={isMaskEnabled && !isStaging} listening={false}>
|
||||||
<IAICanvasMaskLines visible={true} listening={false} />
|
<IAICanvasMaskLines visible={true} listening={false} />
|
||||||
<IAICanvasMaskCompositor listening={false} />
|
<IAICanvasMaskCompositer listening={false} />
|
||||||
</Layer>
|
</Layer>
|
||||||
<Layer listening={false}>
|
<Layer listening={false}>
|
||||||
<IAICanvasBoundingBoxOverlay />
|
<IAICanvasBoundingBoxOverlay />
|
||||||
|
@ -16,9 +16,9 @@ const canvasMaskCompositerSelector = createMemoizedSelector(selectCanvasSlice, (
|
|||||||
};
|
};
|
||||||
});
|
});
|
||||||
|
|
||||||
type IAICanvasMaskCompositorProps = RectConfig;
|
type IAICanvasMaskCompositerProps = RectConfig;
|
||||||
|
|
||||||
const IAICanvasMaskCompositor = (props: IAICanvasMaskCompositorProps) => {
|
const IAICanvasMaskCompositer = (props: IAICanvasMaskCompositerProps) => {
|
||||||
const { ...rest } = props;
|
const { ...rest } = props;
|
||||||
|
|
||||||
const { stageCoordinates, stageDimensions } = useAppSelector(canvasMaskCompositerSelector);
|
const { stageCoordinates, stageDimensions } = useAppSelector(canvasMaskCompositerSelector);
|
||||||
@ -89,4 +89,4 @@ const IAICanvasMaskCompositor = (props: IAICanvasMaskCompositorProps) => {
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export default memo(IAICanvasMaskCompositor);
|
export default memo(IAICanvasMaskCompositer);
|
@ -5,7 +5,6 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|||||||
import { stagingAreaImageSaved } from 'features/canvas/store/actions';
|
import { stagingAreaImageSaved } from 'features/canvas/store/actions';
|
||||||
import {
|
import {
|
||||||
commitStagingAreaImage,
|
commitStagingAreaImage,
|
||||||
discardStagedImage,
|
|
||||||
discardStagedImages,
|
discardStagedImages,
|
||||||
nextStagingAreaImage,
|
nextStagingAreaImage,
|
||||||
prevStagingAreaImage,
|
prevStagingAreaImage,
|
||||||
@ -23,7 +22,6 @@ import {
|
|||||||
PiEyeBold,
|
PiEyeBold,
|
||||||
PiEyeSlashBold,
|
PiEyeSlashBold,
|
||||||
PiFloppyDiskBold,
|
PiFloppyDiskBold,
|
||||||
PiTrashSimpleBold,
|
|
||||||
PiXBold,
|
PiXBold,
|
||||||
} from 'react-icons/pi';
|
} from 'react-icons/pi';
|
||||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||||
@ -46,40 +44,6 @@ const selector = createMemoizedSelector(selectCanvasSlice, (canvas) => {
|
|||||||
};
|
};
|
||||||
});
|
});
|
||||||
|
|
||||||
const ClearStagingIntermediatesIconButton = () => {
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
const { t } = useTranslation();
|
|
||||||
|
|
||||||
const handleDiscardStagingArea = useCallback(() => {
|
|
||||||
dispatch(discardStagedImages());
|
|
||||||
}, [dispatch]);
|
|
||||||
|
|
||||||
const handleDiscardStagingImage = useCallback(() => {
|
|
||||||
dispatch(discardStagedImage());
|
|
||||||
}, [dispatch]);
|
|
||||||
|
|
||||||
return (
|
|
||||||
<>
|
|
||||||
<IconButton
|
|
||||||
tooltip={`${t('unifiedCanvas.discardCurrent')}`}
|
|
||||||
aria-label={t('unifiedCanvas.discardCurrent')}
|
|
||||||
icon={<PiXBold />}
|
|
||||||
onClick={handleDiscardStagingImage}
|
|
||||||
colorScheme="invokeBlue"
|
|
||||||
fontSize={16}
|
|
||||||
/>
|
|
||||||
<IconButton
|
|
||||||
tooltip={`${t('unifiedCanvas.discardAll')} (Esc)`}
|
|
||||||
aria-label={t('unifiedCanvas.discardAll')}
|
|
||||||
icon={<PiTrashSimpleBold />}
|
|
||||||
onClick={handleDiscardStagingArea}
|
|
||||||
colorScheme="error"
|
|
||||||
fontSize={16}
|
|
||||||
/>
|
|
||||||
</>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
const IAICanvasStagingAreaToolbar = () => {
|
const IAICanvasStagingAreaToolbar = () => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { currentStagingAreaImage, shouldShowStagingImage, currentIndex, total } = useAppSelector(selector);
|
const { currentStagingAreaImage, shouldShowStagingImage, currentIndex, total } = useAppSelector(selector);
|
||||||
@ -221,7 +185,14 @@ const IAICanvasStagingAreaToolbar = () => {
|
|||||||
onClick={handleSaveToGallery}
|
onClick={handleSaveToGallery}
|
||||||
colorScheme="invokeBlue"
|
colorScheme="invokeBlue"
|
||||||
/>
|
/>
|
||||||
<ClearStagingIntermediatesIconButton />
|
<IconButton
|
||||||
|
tooltip={`${t('unifiedCanvas.discardAll')} (Esc)`}
|
||||||
|
aria-label={t('unifiedCanvas.discardAll')}
|
||||||
|
icon={<PiXBold />}
|
||||||
|
onClick={handleDiscardStagingArea}
|
||||||
|
colorScheme="error"
|
||||||
|
fontSize={20}
|
||||||
|
/>
|
||||||
</ButtonGroup>
|
</ButtonGroup>
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
|
@ -18,7 +18,6 @@ import {
|
|||||||
setShouldAutoSave,
|
setShouldAutoSave,
|
||||||
setShouldCropToBoundingBoxOnSave,
|
setShouldCropToBoundingBoxOnSave,
|
||||||
setShouldDarkenOutsideBoundingBox,
|
setShouldDarkenOutsideBoundingBox,
|
||||||
setShouldInvertBrushSizeScrollDirection,
|
|
||||||
setShouldRestrictStrokesToBox,
|
setShouldRestrictStrokesToBox,
|
||||||
setShouldShowCanvasDebugInfo,
|
setShouldShowCanvasDebugInfo,
|
||||||
setShouldShowGrid,
|
setShouldShowGrid,
|
||||||
@ -41,7 +40,6 @@ const IAICanvasSettingsButtonPopover = () => {
|
|||||||
const shouldAutoSave = useAppSelector((s) => s.canvas.shouldAutoSave);
|
const shouldAutoSave = useAppSelector((s) => s.canvas.shouldAutoSave);
|
||||||
const shouldCropToBoundingBoxOnSave = useAppSelector((s) => s.canvas.shouldCropToBoundingBoxOnSave);
|
const shouldCropToBoundingBoxOnSave = useAppSelector((s) => s.canvas.shouldCropToBoundingBoxOnSave);
|
||||||
const shouldDarkenOutsideBoundingBox = useAppSelector((s) => s.canvas.shouldDarkenOutsideBoundingBox);
|
const shouldDarkenOutsideBoundingBox = useAppSelector((s) => s.canvas.shouldDarkenOutsideBoundingBox);
|
||||||
const shouldInvertBrushSizeScrollDirection = useAppSelector((s) => s.canvas.shouldInvertBrushSizeScrollDirection);
|
|
||||||
const shouldShowCanvasDebugInfo = useAppSelector((s) => s.canvas.shouldShowCanvasDebugInfo);
|
const shouldShowCanvasDebugInfo = useAppSelector((s) => s.canvas.shouldShowCanvasDebugInfo);
|
||||||
const shouldShowGrid = useAppSelector((s) => s.canvas.shouldShowGrid);
|
const shouldShowGrid = useAppSelector((s) => s.canvas.shouldShowGrid);
|
||||||
const shouldShowIntermediates = useAppSelector((s) => s.canvas.shouldShowIntermediates);
|
const shouldShowIntermediates = useAppSelector((s) => s.canvas.shouldShowIntermediates);
|
||||||
@ -78,10 +76,6 @@ const IAICanvasSettingsButtonPopover = () => {
|
|||||||
(e: ChangeEvent<HTMLInputElement>) => dispatch(setShouldDarkenOutsideBoundingBox(e.target.checked)),
|
(e: ChangeEvent<HTMLInputElement>) => dispatch(setShouldDarkenOutsideBoundingBox(e.target.checked)),
|
||||||
[dispatch]
|
[dispatch]
|
||||||
);
|
);
|
||||||
const handleChangeShouldInvertBrushSizeScrollDirection = useCallback(
|
|
||||||
(e: ChangeEvent<HTMLInputElement>) => dispatch(setShouldInvertBrushSizeScrollDirection(e.target.checked)),
|
|
||||||
[dispatch]
|
|
||||||
);
|
|
||||||
const handleChangeShouldAutoSave = useCallback(
|
const handleChangeShouldAutoSave = useCallback(
|
||||||
(e: ChangeEvent<HTMLInputElement>) => dispatch(setShouldAutoSave(e.target.checked)),
|
(e: ChangeEvent<HTMLInputElement>) => dispatch(setShouldAutoSave(e.target.checked)),
|
||||||
[dispatch]
|
[dispatch]
|
||||||
@ -150,13 +144,6 @@ const IAICanvasSettingsButtonPopover = () => {
|
|||||||
<FormLabel>{t('unifiedCanvas.limitStrokesToBox')}</FormLabel>
|
<FormLabel>{t('unifiedCanvas.limitStrokesToBox')}</FormLabel>
|
||||||
<Checkbox isChecked={shouldRestrictStrokesToBox} onChange={handleChangeShouldRestrictStrokesToBox} />
|
<Checkbox isChecked={shouldRestrictStrokesToBox} onChange={handleChangeShouldRestrictStrokesToBox} />
|
||||||
</FormControl>
|
</FormControl>
|
||||||
<FormControl>
|
|
||||||
<FormLabel>{t('unifiedCanvas.invertBrushSizeScrollDirection')}</FormLabel>
|
|
||||||
<Checkbox
|
|
||||||
isChecked={shouldInvertBrushSizeScrollDirection}
|
|
||||||
onChange={handleChangeShouldInvertBrushSizeScrollDirection}
|
|
||||||
/>
|
|
||||||
</FormControl>
|
|
||||||
<FormControl>
|
<FormControl>
|
||||||
<FormLabel>{t('unifiedCanvas.showCanvasDebugInfo')}</FormLabel>
|
<FormLabel>{t('unifiedCanvas.showCanvasDebugInfo')}</FormLabel>
|
||||||
<Checkbox isChecked={shouldShowCanvasDebugInfo} onChange={handleChangeShouldShowCanvasDebugInfo} />
|
<Checkbox isChecked={shouldShowCanvasDebugInfo} onChange={handleChangeShouldShowCanvasDebugInfo} />
|
||||||
|
@ -15,7 +15,6 @@ const useCanvasWheel = (stageRef: MutableRefObject<Konva.Stage | null>) => {
|
|||||||
const stageScale = useAppSelector((s) => s.canvas.stageScale);
|
const stageScale = useAppSelector((s) => s.canvas.stageScale);
|
||||||
const isMoveStageKeyHeld = useStore($isMoveStageKeyHeld);
|
const isMoveStageKeyHeld = useStore($isMoveStageKeyHeld);
|
||||||
const brushSize = useAppSelector((s) => s.canvas.brushSize);
|
const brushSize = useAppSelector((s) => s.canvas.brushSize);
|
||||||
const shouldInvertBrushSizeScrollDirection = useAppSelector((s) => s.canvas.shouldInvertBrushSizeScrollDirection);
|
|
||||||
|
|
||||||
return useCallback(
|
return useCallback(
|
||||||
(e: KonvaEventObject<WheelEvent>) => {
|
(e: KonvaEventObject<WheelEvent>) => {
|
||||||
@ -29,16 +28,10 @@ const useCanvasWheel = (stageRef: MutableRefObject<Konva.Stage | null>) => {
|
|||||||
// checking for ctrl key is pressed or not,
|
// checking for ctrl key is pressed or not,
|
||||||
// so that brush size can be controlled using ctrl + scroll up/down
|
// so that brush size can be controlled using ctrl + scroll up/down
|
||||||
|
|
||||||
// Invert the delta if the property is set to true
|
|
||||||
let delta = e.evt.deltaY;
|
|
||||||
if (shouldInvertBrushSizeScrollDirection) {
|
|
||||||
delta = -delta;
|
|
||||||
}
|
|
||||||
|
|
||||||
if ($ctrl.get() || $meta.get()) {
|
if ($ctrl.get() || $meta.get()) {
|
||||||
// This equation was derived by fitting a curve to the desired brush sizes and deltas
|
// This equation was derived by fitting a curve to the desired brush sizes and deltas
|
||||||
// see https://github.com/invoke-ai/InvokeAI/pull/5542#issuecomment-1915847565
|
// see https://github.com/invoke-ai/InvokeAI/pull/5542#issuecomment-1915847565
|
||||||
const targetDelta = Math.sign(delta) * 0.7363 * Math.pow(1.0394, brushSize);
|
const targetDelta = Math.sign(e.evt.deltaY) * 0.7363 * Math.pow(1.0394, brushSize);
|
||||||
// This needs to be clamped to prevent the delta from getting too large
|
// This needs to be clamped to prevent the delta from getting too large
|
||||||
const finalDelta = clamp(targetDelta, -20, 20);
|
const finalDelta = clamp(targetDelta, -20, 20);
|
||||||
// The new brush size is also clamped to prevent it from getting too large or small
|
// The new brush size is also clamped to prevent it from getting too large or small
|
||||||
@ -74,7 +67,7 @@ const useCanvasWheel = (stageRef: MutableRefObject<Konva.Stage | null>) => {
|
|||||||
dispatch(setStageCoordinates(newCoordinates));
|
dispatch(setStageCoordinates(newCoordinates));
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
[stageRef, isMoveStageKeyHeld, brushSize, dispatch, stageScale, shouldInvertBrushSizeScrollDirection]
|
[stageRef, isMoveStageKeyHeld, stageScale, dispatch, brushSize]
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -65,7 +65,6 @@ const initialCanvasState: CanvasState = {
|
|||||||
shouldAutoSave: false,
|
shouldAutoSave: false,
|
||||||
shouldCropToBoundingBoxOnSave: false,
|
shouldCropToBoundingBoxOnSave: false,
|
||||||
shouldDarkenOutsideBoundingBox: false,
|
shouldDarkenOutsideBoundingBox: false,
|
||||||
shouldInvertBrushSizeScrollDirection: false,
|
|
||||||
shouldLockBoundingBox: false,
|
shouldLockBoundingBox: false,
|
||||||
shouldPreserveMaskedArea: false,
|
shouldPreserveMaskedArea: false,
|
||||||
shouldRestrictStrokesToBox: true,
|
shouldRestrictStrokesToBox: true,
|
||||||
@ -221,9 +220,6 @@ export const canvasSlice = createSlice({
|
|||||||
setShouldDarkenOutsideBoundingBox: (state, action: PayloadAction<boolean>) => {
|
setShouldDarkenOutsideBoundingBox: (state, action: PayloadAction<boolean>) => {
|
||||||
state.shouldDarkenOutsideBoundingBox = action.payload;
|
state.shouldDarkenOutsideBoundingBox = action.payload;
|
||||||
},
|
},
|
||||||
setShouldInvertBrushSizeScrollDirection: (state, action: PayloadAction<boolean>) => {
|
|
||||||
state.shouldInvertBrushSizeScrollDirection = action.payload;
|
|
||||||
},
|
|
||||||
clearCanvasHistory: (state) => {
|
clearCanvasHistory: (state) => {
|
||||||
state.pastLayerStates = [];
|
state.pastLayerStates = [];
|
||||||
state.futureLayerStates = [];
|
state.futureLayerStates = [];
|
||||||
@ -292,31 +288,6 @@ export const canvasSlice = createSlice({
|
|||||||
state.shouldShowStagingImage = true;
|
state.shouldShowStagingImage = true;
|
||||||
state.batchIds = [];
|
state.batchIds = [];
|
||||||
},
|
},
|
||||||
discardStagedImage: (state) => {
|
|
||||||
const { images, selectedImageIndex } = state.layerState.stagingArea;
|
|
||||||
state.pastLayerStates.push(cloneDeep(state.layerState));
|
|
||||||
|
|
||||||
if (state.pastLayerStates.length > MAX_HISTORY) {
|
|
||||||
state.pastLayerStates.shift();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!images.length) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
images.splice(selectedImageIndex, 1);
|
|
||||||
|
|
||||||
if (selectedImageIndex >= images.length) {
|
|
||||||
state.layerState.stagingArea.selectedImageIndex = images.length - 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!images.length) {
|
|
||||||
state.shouldShowStagingImage = false;
|
|
||||||
state.shouldShowStagingOutline = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
state.futureLayerStates = [];
|
|
||||||
},
|
|
||||||
addFillRect: (state) => {
|
addFillRect: (state) => {
|
||||||
const { boundingBoxCoordinates, boundingBoxDimensions, brushColor } = state;
|
const { boundingBoxCoordinates, boundingBoxDimensions, brushColor } = state;
|
||||||
|
|
||||||
@ -684,7 +655,6 @@ export const {
|
|||||||
commitColorPickerColor,
|
commitColorPickerColor,
|
||||||
commitStagingAreaImage,
|
commitStagingAreaImage,
|
||||||
discardStagedImages,
|
discardStagedImages,
|
||||||
discardStagedImage,
|
|
||||||
nextStagingAreaImage,
|
nextStagingAreaImage,
|
||||||
prevStagingAreaImage,
|
prevStagingAreaImage,
|
||||||
redo,
|
redo,
|
||||||
@ -704,7 +674,6 @@ export const {
|
|||||||
setShouldAutoSave,
|
setShouldAutoSave,
|
||||||
setShouldCropToBoundingBoxOnSave,
|
setShouldCropToBoundingBoxOnSave,
|
||||||
setShouldDarkenOutsideBoundingBox,
|
setShouldDarkenOutsideBoundingBox,
|
||||||
setShouldInvertBrushSizeScrollDirection,
|
|
||||||
setShouldPreserveMaskedArea,
|
setShouldPreserveMaskedArea,
|
||||||
setShouldShowBoundingBox,
|
setShouldShowBoundingBox,
|
||||||
setShouldShowCanvasDebugInfo,
|
setShouldShowCanvasDebugInfo,
|
||||||
|
@ -120,7 +120,6 @@ export interface CanvasState {
|
|||||||
shouldAutoSave: boolean;
|
shouldAutoSave: boolean;
|
||||||
shouldCropToBoundingBoxOnSave: boolean;
|
shouldCropToBoundingBoxOnSave: boolean;
|
||||||
shouldDarkenOutsideBoundingBox: boolean;
|
shouldDarkenOutsideBoundingBox: boolean;
|
||||||
shouldInvertBrushSizeScrollDirection: boolean;
|
|
||||||
shouldLockBoundingBox: boolean;
|
shouldLockBoundingBox: boolean;
|
||||||
shouldPreserveMaskedArea: boolean;
|
shouldPreserveMaskedArea: boolean;
|
||||||
shouldRestrictStrokesToBox: boolean;
|
shouldRestrictStrokesToBox: boolean;
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
import { FormControl, FormLabel, Switch } from '@invoke-ai/ui-library';
|
import { FormControl, FormLabel, Switch } from '@invoke-ai/ui-library';
|
||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled';
|
import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled';
|
||||||
import { useControlAdapterModel } from 'features/controlAdapters/hooks/useControlAdapterModel';
|
|
||||||
import { useControlAdapterShouldAutoConfig } from 'features/controlAdapters/hooks/useControlAdapterShouldAutoConfig';
|
import { useControlAdapterShouldAutoConfig } from 'features/controlAdapters/hooks/useControlAdapterShouldAutoConfig';
|
||||||
import { controlAdapterAutoConfigToggled } from 'features/controlAdapters/store/controlAdaptersSlice';
|
import { controlAdapterAutoConfigToggled } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||||
import { isNil } from 'lodash-es';
|
import { isNil } from 'lodash-es';
|
||||||
@ -15,13 +14,12 @@ type Props = {
|
|||||||
const ControlAdapterShouldAutoConfig = ({ id }: Props) => {
|
const ControlAdapterShouldAutoConfig = ({ id }: Props) => {
|
||||||
const isEnabled = useControlAdapterIsEnabled(id);
|
const isEnabled = useControlAdapterIsEnabled(id);
|
||||||
const shouldAutoConfig = useControlAdapterShouldAutoConfig(id);
|
const shouldAutoConfig = useControlAdapterShouldAutoConfig(id);
|
||||||
const { modelConfig } = useControlAdapterModel(id);
|
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const handleShouldAutoConfigChanged = useCallback(() => {
|
const handleShouldAutoConfigChanged = useCallback(() => {
|
||||||
dispatch(controlAdapterAutoConfigToggled({ id, modelConfig }));
|
dispatch(controlAdapterAutoConfigToggled({ id }));
|
||||||
}, [id, dispatch, modelConfig]);
|
}, [id, dispatch]);
|
||||||
|
|
||||||
if (isNil(shouldAutoConfig)) {
|
if (isNil(shouldAutoConfig)) {
|
||||||
return null;
|
return null;
|
||||||
|
@ -6,6 +6,7 @@ import { useControlAdapterModel } from 'features/controlAdapters/hooks/useContro
|
|||||||
import { useControlAdapterModelQuery } from 'features/controlAdapters/hooks/useControlAdapterModelQuery';
|
import { useControlAdapterModelQuery } from 'features/controlAdapters/hooks/useControlAdapterModelQuery';
|
||||||
import { useControlAdapterType } from 'features/controlAdapters/hooks/useControlAdapterType';
|
import { useControlAdapterType } from 'features/controlAdapters/hooks/useControlAdapterType';
|
||||||
import { controlAdapterModelChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
|
import { controlAdapterModelChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||||
|
import { getModelKeyAndBase } from 'features/metadata/util/modelFetchingHelpers';
|
||||||
import { memo, useCallback, useMemo } from 'react';
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
|
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
@ -16,21 +17,21 @@ type ParamControlAdapterModelProps = {
|
|||||||
const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
|
const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
|
||||||
const isEnabled = useControlAdapterIsEnabled(id);
|
const isEnabled = useControlAdapterIsEnabled(id);
|
||||||
const controlAdapterType = useControlAdapterType(id);
|
const controlAdapterType = useControlAdapterType(id);
|
||||||
const { modelConfig } = useControlAdapterModel(id);
|
const model = useControlAdapterModel(id);
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
|
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
|
||||||
|
|
||||||
const { data, isLoading } = useControlAdapterModelQuery(controlAdapterType);
|
const { data, isLoading } = useControlAdapterModelQuery(controlAdapterType);
|
||||||
|
|
||||||
const _onChange = useCallback(
|
const _onChange = useCallback(
|
||||||
(modelConfig: ControlNetModelConfig | IPAdapterModelConfig | T2IAdapterModelConfig | null) => {
|
(model: ControlNetModelConfig | IPAdapterModelConfig | T2IAdapterModelConfig | null) => {
|
||||||
if (!modelConfig) {
|
if (!model) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
dispatch(
|
dispatch(
|
||||||
controlAdapterModelChanged({
|
controlAdapterModelChanged({
|
||||||
id,
|
id,
|
||||||
modelConfig,
|
model: getModelKeyAndBase(model),
|
||||||
})
|
})
|
||||||
);
|
);
|
||||||
},
|
},
|
||||||
@ -38,8 +39,8 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
const selectedModel = useMemo(
|
const selectedModel = useMemo(
|
||||||
() => (modelConfig && controlAdapterType ? { ...modelConfig, model_type: controlAdapterType } : null),
|
() => (model && controlAdapterType ? { ...model, model_type: controlAdapterType } : null),
|
||||||
[controlAdapterType, modelConfig]
|
[controlAdapterType, model]
|
||||||
);
|
);
|
||||||
|
|
||||||
const { items, selectedItem, onChange, placeholder } = useModelCustomSelect({
|
const { items, selectedItem, onChange, placeholder } = useModelCustomSelect({
|
||||||
|
@ -1,9 +1,7 @@
|
|||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
|
|
||||||
import { controlAdapterAdded } from 'features/controlAdapters/store/controlAdaptersSlice';
|
import { controlAdapterAdded } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||||
import { type ControlAdapterType, isControlAdapterProcessorType } from 'features/controlAdapters/store/types';
|
import type { ControlAdapterType } from 'features/controlAdapters/store/types';
|
||||||
import { useCallback, useMemo } from 'react';
|
import { useCallback, useMemo } from 'react';
|
||||||
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
|
|
||||||
|
|
||||||
import { useControlAdapterModels } from './useControlAdapterModels';
|
import { useControlAdapterModels } from './useControlAdapterModels';
|
||||||
|
|
||||||
@ -13,7 +11,7 @@ export const useAddControlAdapter = (type: ControlAdapterType) => {
|
|||||||
|
|
||||||
const models = useControlAdapterModels(type);
|
const models = useControlAdapterModels(type);
|
||||||
|
|
||||||
const firstModel: ControlNetModelConfig | T2IAdapterModelConfig | IPAdapterModelConfig | undefined = useMemo(() => {
|
const firstModel = useMemo(() => {
|
||||||
// prefer to use a model that matches the base model
|
// prefer to use a model that matches the base model
|
||||||
const firstCompatibleModel = models.filter((m) => (baseModel ? m.base === baseModel : true))[0];
|
const firstCompatibleModel = models.filter((m) => (baseModel ? m.base === baseModel : true))[0];
|
||||||
|
|
||||||
@ -30,26 +28,6 @@ export const useAddControlAdapter = (type: ControlAdapterType) => {
|
|||||||
if (isDisabled) {
|
if (isDisabled) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (
|
|
||||||
(type === 'controlnet' || type === 't2i_adapter') &&
|
|
||||||
(firstModel?.type === 'controlnet' || firstModel?.type === 't2i_adapter')
|
|
||||||
) {
|
|
||||||
const defaultPreprocessor = firstModel.default_settings?.preprocessor;
|
|
||||||
const processorType = isControlAdapterProcessorType(defaultPreprocessor) ? defaultPreprocessor : 'none';
|
|
||||||
const processorNode = CONTROLNET_PROCESSORS[processorType].default;
|
|
||||||
dispatch(
|
|
||||||
controlAdapterAdded({
|
|
||||||
type,
|
|
||||||
overrides: {
|
|
||||||
model: firstModel,
|
|
||||||
processorType,
|
|
||||||
processorNode,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
dispatch(
|
dispatch(
|
||||||
controlAdapterAdded({
|
controlAdapterAdded({
|
||||||
type,
|
type,
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
import { skipToken } from '@reduxjs/toolkit/query';
|
|
||||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import {
|
import {
|
||||||
@ -6,22 +5,18 @@ import {
|
|||||||
selectControlAdaptersSlice,
|
selectControlAdaptersSlice,
|
||||||
} from 'features/controlAdapters/store/controlAdaptersSlice';
|
} from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||||
import { useMemo } from 'react';
|
import { useMemo } from 'react';
|
||||||
import { useGetModelConfigWithTypeGuard } from 'services/api/hooks/useGetModelConfigWithTypeGuard';
|
|
||||||
import { isControlAdapterModelConfig } from 'services/api/types';
|
|
||||||
|
|
||||||
export const useControlAdapterModel = (id: string) => {
|
export const useControlAdapterModel = (id: string) => {
|
||||||
const selector = useMemo(
|
const selector = useMemo(
|
||||||
() =>
|
() =>
|
||||||
createMemoizedSelector(
|
createMemoizedSelector(
|
||||||
selectControlAdaptersSlice,
|
selectControlAdaptersSlice,
|
||||||
(controlAdapters) => selectControlAdapterById(controlAdapters, id)?.model?.key
|
(controlAdapters) => selectControlAdapterById(controlAdapters, id)?.model
|
||||||
),
|
),
|
||||||
[id]
|
[id]
|
||||||
);
|
);
|
||||||
|
|
||||||
const key = useAppSelector(selector);
|
const model = useAppSelector(selector);
|
||||||
|
|
||||||
const result = useGetModelConfigWithTypeGuard(key ?? skipToken, isControlAdapterModelConfig);
|
return model;
|
||||||
|
|
||||||
return result;
|
|
||||||
};
|
};
|
||||||
|
@ -253,3 +253,23 @@ export const CONTROLNET_PROCESSORS: ControlNetProcessorsDict = {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export const CONTROLNET_MODEL_DEFAULT_PROCESSORS: {
|
||||||
|
[key: string]: ControlAdapterProcessorType;
|
||||||
|
} = {
|
||||||
|
canny: 'canny_image_processor',
|
||||||
|
mlsd: 'mlsd_image_processor',
|
||||||
|
depth: 'depth_anything_image_processor',
|
||||||
|
bae: 'normalbae_image_processor',
|
||||||
|
sketch: 'pidi_image_processor',
|
||||||
|
scribble: 'lineart_image_processor',
|
||||||
|
lineart: 'lineart_image_processor',
|
||||||
|
lineart_anime: 'lineart_anime_image_processor',
|
||||||
|
softedge: 'hed_image_processor',
|
||||||
|
shuffle: 'content_shuffle_image_processor',
|
||||||
|
openpose: 'dw_openpose_image_processor',
|
||||||
|
mediapipe: 'mediapipe_face_processor',
|
||||||
|
pidi: 'pidi_image_processor',
|
||||||
|
zoe: 'zoe_depth_image_processor',
|
||||||
|
color: 'color_map_image_processor',
|
||||||
|
};
|
||||||
|
@ -3,15 +3,20 @@ import { createEntityAdapter, createSlice, isAnyOf } from '@reduxjs/toolkit';
|
|||||||
import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
|
import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
|
||||||
import type { PersistConfig, RootState } from 'app/store/store';
|
import type { PersistConfig, RootState } from 'app/store/store';
|
||||||
import { buildControlAdapter } from 'features/controlAdapters/util/buildControlAdapter';
|
import { buildControlAdapter } from 'features/controlAdapters/util/buildControlAdapter';
|
||||||
import { buildControlAdapterProcessor } from 'features/controlAdapters/util/buildControlAdapterProcessor';
|
import type {
|
||||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
ParameterControlNetModel,
|
||||||
|
ParameterIPAdapterModel,
|
||||||
|
ParameterT2IAdapterModel,
|
||||||
|
} from 'features/parameters/types/parameterSchemas';
|
||||||
import { cloneDeep, merge, uniq } from 'lodash-es';
|
import { cloneDeep, merge, uniq } from 'lodash-es';
|
||||||
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
|
|
||||||
import { socketInvocationError } from 'services/events/actions';
|
import { socketInvocationError } from 'services/events/actions';
|
||||||
import { v4 as uuidv4 } from 'uuid';
|
import { v4 as uuidv4 } from 'uuid';
|
||||||
|
|
||||||
import { controlAdapterImageProcessed } from './actions';
|
import { controlAdapterImageProcessed } from './actions';
|
||||||
import { CONTROLNET_PROCESSORS } from './constants';
|
import {
|
||||||
|
CONTROLNET_MODEL_DEFAULT_PROCESSORS as CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS,
|
||||||
|
CONTROLNET_PROCESSORS,
|
||||||
|
} from './constants';
|
||||||
import type {
|
import type {
|
||||||
ControlAdapterConfig,
|
ControlAdapterConfig,
|
||||||
ControlAdapterProcessorType,
|
ControlAdapterProcessorType,
|
||||||
@ -189,17 +194,15 @@ export const controlAdaptersSlice = createSlice({
|
|||||||
state,
|
state,
|
||||||
action: PayloadAction<{
|
action: PayloadAction<{
|
||||||
id: string;
|
id: string;
|
||||||
modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | IPAdapterModelConfig;
|
model: ParameterControlNetModel | ParameterT2IAdapterModel | ParameterIPAdapterModel;
|
||||||
}>
|
}>
|
||||||
) => {
|
) => {
|
||||||
const { id, modelConfig } = action.payload;
|
const { id, model } = action.payload;
|
||||||
const cn = selectControlAdapterById(state, id);
|
const cn = selectControlAdapterById(state, id);
|
||||||
if (!cn) {
|
if (!cn) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const model = zModelIdentifierField.parse(modelConfig);
|
|
||||||
|
|
||||||
if (!isControlNetOrT2IAdapter(cn)) {
|
if (!isControlNetOrT2IAdapter(cn)) {
|
||||||
caAdapter.updateOne(state, { id, changes: { model } });
|
caAdapter.updateOne(state, { id, changes: { model } });
|
||||||
return;
|
return;
|
||||||
@ -212,14 +215,24 @@ export const controlAdaptersSlice = createSlice({
|
|||||||
|
|
||||||
update.changes.processedControlImage = null;
|
update.changes.processedControlImage = null;
|
||||||
|
|
||||||
if (modelConfig.type === 'ip_adapter') {
|
let processorType: ControlAdapterProcessorType | undefined = undefined;
|
||||||
// should never happen...
|
|
||||||
return;
|
for (const modelSubstring in CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS) {
|
||||||
|
// TODO(MM2): matching modelSubstring to the model key is no longer a valid way to figure out the default processorType
|
||||||
|
if (model.key.includes(modelSubstring)) {
|
||||||
|
processorType = CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS[modelSubstring];
|
||||||
|
break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const processor = buildControlAdapterProcessor(modelConfig);
|
if (processorType) {
|
||||||
update.changes.processorType = processor.processorType;
|
update.changes.processorType = processorType;
|
||||||
update.changes.processorNode = processor.processorNode;
|
update.changes.processorNode = CONTROLNET_PROCESSORS[processorType]
|
||||||
|
.default as RequiredControlAdapterProcessorNode;
|
||||||
|
} else {
|
||||||
|
update.changes.processorType = 'none';
|
||||||
|
update.changes.processorNode = CONTROLNET_PROCESSORS.none.default as RequiredControlAdapterProcessorNode;
|
||||||
|
}
|
||||||
|
|
||||||
caAdapter.updateOne(state, update);
|
caAdapter.updateOne(state, update);
|
||||||
},
|
},
|
||||||
@ -311,23 +324,39 @@ export const controlAdaptersSlice = createSlice({
|
|||||||
state,
|
state,
|
||||||
action: PayloadAction<{
|
action: PayloadAction<{
|
||||||
id: string;
|
id: string;
|
||||||
modelConfig?: ControlNetModelConfig | T2IAdapterModelConfig | IPAdapterModelConfig;
|
|
||||||
}>
|
}>
|
||||||
) => {
|
) => {
|
||||||
const { id, modelConfig } = action.payload;
|
const { id } = action.payload;
|
||||||
const cn = selectControlAdapterById(state, id);
|
const cn = selectControlAdapterById(state, id);
|
||||||
if (!cn || !isControlNetOrT2IAdapter(cn) || modelConfig?.type === 'ip_adapter') {
|
if (!cn || !isControlNetOrT2IAdapter(cn)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const update: Update<ControlNetConfig | T2IAdapterConfig, string> = {
|
const update: Update<ControlNetConfig | T2IAdapterConfig, string> = {
|
||||||
id,
|
id,
|
||||||
changes: { shouldAutoConfig: !cn.shouldAutoConfig },
|
changes: { shouldAutoConfig: !cn.shouldAutoConfig },
|
||||||
};
|
};
|
||||||
|
|
||||||
if (update.changes.shouldAutoConfig && modelConfig) {
|
if (update.changes.shouldAutoConfig) {
|
||||||
const processor = buildControlAdapterProcessor(modelConfig);
|
// manage the processor for the user
|
||||||
update.changes.processorType = processor.processorType;
|
let processorType: ControlAdapterProcessorType | undefined = undefined;
|
||||||
update.changes.processorNode = processor.processorNode;
|
|
||||||
|
for (const modelSubstring in CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS) {
|
||||||
|
// TODO(MM2): matching modelSubstring to the model key is no longer a valid way to figure out the default processorType
|
||||||
|
if (cn.model?.key.includes(modelSubstring)) {
|
||||||
|
processorType = CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS[modelSubstring];
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (processorType) {
|
||||||
|
update.changes.processorType = processorType;
|
||||||
|
update.changes.processorNode = CONTROLNET_PROCESSORS[processorType]
|
||||||
|
.default as RequiredControlAdapterProcessorNode;
|
||||||
|
} else {
|
||||||
|
update.changes.processorType = 'none';
|
||||||
|
update.changes.processorNode = CONTROLNET_PROCESSORS.none.default as RequiredControlAdapterProcessorNode;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
caAdapter.updateOne(state, update);
|
caAdapter.updateOne(state, update);
|
||||||
|
@ -1,10 +0,0 @@
|
|||||||
import type { ControlAdapterProcessorType, zControlAdapterProcessorType } from 'features/controlAdapters/store/types';
|
|
||||||
import type { Equals } from 'tsafe';
|
|
||||||
import { assert } from 'tsafe';
|
|
||||||
import { describe, test } from 'vitest';
|
|
||||||
import type { z } from 'zod';
|
|
||||||
|
|
||||||
describe('Control Adapter Types', () => {
|
|
||||||
test('ControlAdapterProcessorType', () =>
|
|
||||||
assert<Equals<ControlAdapterProcessorType, z.infer<typeof zControlAdapterProcessorType>>>());
|
|
||||||
});
|
|
@ -47,25 +47,6 @@ export type ControlAdapterProcessorNode =
|
|||||||
* Any ControlNet processor type
|
* Any ControlNet processor type
|
||||||
*/
|
*/
|
||||||
export type ControlAdapterProcessorType = NonNullable<ControlAdapterProcessorNode['type'] | 'none'>;
|
export type ControlAdapterProcessorType = NonNullable<ControlAdapterProcessorNode['type'] | 'none'>;
|
||||||
export const zControlAdapterProcessorType = z.enum([
|
|
||||||
'canny_image_processor',
|
|
||||||
'color_map_image_processor',
|
|
||||||
'content_shuffle_image_processor',
|
|
||||||
'depth_anything_image_processor',
|
|
||||||
'hed_image_processor',
|
|
||||||
'lineart_anime_image_processor',
|
|
||||||
'lineart_image_processor',
|
|
||||||
'mediapipe_face_processor',
|
|
||||||
'midas_depth_image_processor',
|
|
||||||
'mlsd_image_processor',
|
|
||||||
'normalbae_image_processor',
|
|
||||||
'dw_openpose_image_processor',
|
|
||||||
'pidi_image_processor',
|
|
||||||
'zoe_depth_image_processor',
|
|
||||||
'none',
|
|
||||||
]);
|
|
||||||
export const isControlAdapterProcessorType = (v: unknown): v is ControlAdapterProcessorType =>
|
|
||||||
zControlAdapterProcessorType.safeParse(v).success;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The Canny processor node, with parameters flagged as required
|
* The Canny processor node, with parameters flagged as required
|
||||||
|
@ -1,11 +0,0 @@
|
|||||||
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
|
|
||||||
import { isControlAdapterProcessorType } from 'features/controlAdapters/store/types';
|
|
||||||
import type { ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
|
|
||||||
|
|
||||||
export const buildControlAdapterProcessor = (modelConfig: ControlNetModelConfig | T2IAdapterModelConfig) => {
|
|
||||||
const defaultPreprocessor = modelConfig.default_settings?.preprocessor;
|
|
||||||
const processorType = isControlAdapterProcessorType(defaultPreprocessor) ? defaultPreprocessor : 'none';
|
|
||||||
const processorNode = CONTROLNET_PROCESSORS[processorType].default;
|
|
||||||
|
|
||||||
return { processorType, processorNode };
|
|
||||||
};
|
|
@ -6,7 +6,7 @@ const AutoAddIcon = () => {
|
|||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
return (
|
return (
|
||||||
<Flex position="absolute" insetInlineEnd={0} top={0} p={1}>
|
<Flex position="absolute" insetInlineEnd={0} top={0} p={1}>
|
||||||
<Badge variant="solid" bg="invokeBlue.400">
|
<Badge variant="solid" bg="invokeBlue.500">
|
||||||
{t('common.auto')}
|
{t('common.auto')}
|
||||||
</Badge>
|
</Badge>
|
||||||
</Flex>
|
</Flex>
|
||||||
|
@ -173,8 +173,8 @@ const GalleryBoard = ({ board, isSelected, setBoardToDelete }: GalleryBoardProps
|
|||||||
w="full"
|
w="full"
|
||||||
maxW="full"
|
maxW="full"
|
||||||
borderBottomRadius="base"
|
borderBottomRadius="base"
|
||||||
bg={isSelected ? 'invokeBlue.400' : 'base.600'}
|
bg={isSelected ? 'invokeBlue.500' : 'base.600'}
|
||||||
color={isSelected ? 'base.800' : 'base.100'}
|
color={isSelected ? 'base.50' : 'base.100'}
|
||||||
lineHeight="short"
|
lineHeight="short"
|
||||||
fontSize="xs"
|
fontSize="xs"
|
||||||
>
|
>
|
||||||
@ -193,7 +193,6 @@ const GalleryBoard = ({ board, isSelected, setBoardToDelete }: GalleryBoardProps
|
|||||||
overflow="hidden"
|
overflow="hidden"
|
||||||
textOverflow="ellipsis"
|
textOverflow="ellipsis"
|
||||||
noOfLines={1}
|
noOfLines={1}
|
||||||
color="inherit"
|
|
||||||
/>
|
/>
|
||||||
<EditableInput sx={editableInputStyles} />
|
<EditableInput sx={editableInputStyles} />
|
||||||
</Editable>
|
</Editable>
|
||||||
|
@ -109,8 +109,8 @@ const NoBoardBoard = memo(({ isSelected }: Props) => {
|
|||||||
w="full"
|
w="full"
|
||||||
maxW="full"
|
maxW="full"
|
||||||
borderBottomRadius="base"
|
borderBottomRadius="base"
|
||||||
bg={isSelected ? 'invokeBlue.400' : 'base.600'}
|
bg={isSelected ? 'invokeBlue.500' : 'base.600'}
|
||||||
color={isSelected ? 'base.800' : 'base.100'}
|
color={isSelected ? 'base.50' : 'base.100'}
|
||||||
lineHeight="short"
|
lineHeight="short"
|
||||||
fontSize="xs"
|
fontSize="xs"
|
||||||
fontWeight={isSelected ? 'bold' : 'normal'}
|
fontWeight={isSelected ? 'bold' : 'normal'}
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||||
import { createSlice } from '@reduxjs/toolkit';
|
import { createSlice } from '@reduxjs/toolkit';
|
||||||
import type { PersistConfig, RootState } from 'app/store/store';
|
import type { PersistConfig, RootState } from 'app/store/store';
|
||||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
import { getModelKeyAndBase } from 'features/metadata/util/modelFetchingHelpers';
|
||||||
import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
|
import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
|
||||||
import type { LoRAModelConfig } from 'services/api/types';
|
import type { LoRAModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
@ -31,7 +31,7 @@ export const loraSlice = createSlice({
|
|||||||
initialState: initialLoraState,
|
initialState: initialLoraState,
|
||||||
reducers: {
|
reducers: {
|
||||||
loraAdded: (state, action: PayloadAction<LoRAModelConfig>) => {
|
loraAdded: (state, action: PayloadAction<LoRAModelConfig>) => {
|
||||||
const model = zModelIdentifierField.parse(action.payload);
|
const model = getModelKeyAndBase(action.payload);
|
||||||
state.loras[model.key] = { ...defaultLoRAConfig, model };
|
state.loras[model.key] = { ...defaultLoRAConfig, model };
|
||||||
},
|
},
|
||||||
loraRecalled: (state, action: PayloadAction<LoRA>) => {
|
loraRecalled: (state, action: PayloadAction<LoRA>) => {
|
||||||
|
@ -15,7 +15,7 @@ export const MetadataItemView = memo(
|
|||||||
return (
|
return (
|
||||||
<Flex gap={2}>
|
<Flex gap={2}>
|
||||||
{onRecall && <RecallButton label={label} onClick={onRecall} isDisabled={isDisabled} />}
|
{onRecall && <RecallButton label={label} onClick={onRecall} isDisabled={isDisabled} />}
|
||||||
<Flex direction={direction} fontSize="sm">
|
<Flex direction={direction}>
|
||||||
<Text fontWeight="semibold" whiteSpace="pre-wrap" pr={2}>
|
<Text fontWeight="semibold" whiteSpace="pre-wrap" pr={2}>
|
||||||
{label}:
|
{label}:
|
||||||
</Text>
|
</Text>
|
||||||
|
@ -13,13 +13,13 @@ import type {
|
|||||||
} from 'features/metadata/types';
|
} from 'features/metadata/types';
|
||||||
import { fetchModelConfig } from 'features/metadata/util/modelFetchingHelpers';
|
import { fetchModelConfig } from 'features/metadata/util/modelFetchingHelpers';
|
||||||
import { validators } from 'features/metadata/util/validators';
|
import { validators } from 'features/metadata/util/validators';
|
||||||
import type { ModelIdentifierField } from 'features/nodes/types/common';
|
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
|
||||||
import { t } from 'i18next';
|
import { t } from 'i18next';
|
||||||
|
|
||||||
import { parsers } from './parsers';
|
import { parsers } from './parsers';
|
||||||
import { recallers } from './recallers';
|
import { recallers } from './recallers';
|
||||||
|
|
||||||
const renderModelConfigValue: MetadataRenderValueFunc<ModelIdentifierField> = async (value) => {
|
const renderModelConfigValue: MetadataRenderValueFunc<ModelIdentifierWithBase> = async (value) => {
|
||||||
try {
|
try {
|
||||||
const modelConfig = await fetchModelConfig(value.key);
|
const modelConfig = await fetchModelConfig(value.key);
|
||||||
return `${modelConfig.name} (${modelConfig.base.toUpperCase()})`;
|
return `${modelConfig.name} (${modelConfig.base.toUpperCase()})`;
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import { getStore } from 'app/store/nanostores/store';
|
import { getStore } from 'app/store/nanostores/store';
|
||||||
|
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
|
||||||
import { isModelIdentifier, isModelIdentifierV2 } from 'features/nodes/types/common';
|
import { isModelIdentifier, isModelIdentifierV2 } from 'features/nodes/types/common';
|
||||||
import { modelsApi } from 'services/api/endpoints/models';
|
import { modelsApi } from 'services/api/endpoints/models';
|
||||||
import type { AnyModelConfig, BaseModelType, ModelType } from 'services/api/types';
|
import type { AnyModelConfig, BaseModelType, ModelType } from 'services/api/types';
|
||||||
@ -104,3 +105,8 @@ export const getModelKey = async (modelIdentifier: unknown, type: ModelType, mes
|
|||||||
}
|
}
|
||||||
throw new InvalidModelConfigError(message || `Invalid model identifier: ${modelIdentifier}`);
|
throw new InvalidModelConfigError(message || `Invalid model identifier: ${modelIdentifier}`);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export const getModelKeyAndBase = (modelConfig: AnyModelConfig): ModelIdentifierWithBase => ({
|
||||||
|
key: modelConfig.key,
|
||||||
|
base: modelConfig.base,
|
||||||
|
});
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
|
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
|
||||||
import {
|
import {
|
||||||
initialControlNet,
|
initialControlNet,
|
||||||
initialIPAdapter,
|
initialIPAdapter,
|
||||||
initialT2IAdapter,
|
initialT2IAdapter,
|
||||||
} from 'features/controlAdapters/util/buildControlAdapter';
|
} from 'features/controlAdapters/util/buildControlAdapter';
|
||||||
import { buildControlAdapterProcessor } from 'features/controlAdapters/util/buildControlAdapterProcessor';
|
|
||||||
import type { LoRA } from 'features/lora/store/loraSlice';
|
import type { LoRA } from 'features/lora/store/loraSlice';
|
||||||
import { defaultLoRAConfig } from 'features/lora/store/loraSlice';
|
import { defaultLoRAConfig } from 'features/lora/store/loraSlice';
|
||||||
import type {
|
import type {
|
||||||
@ -13,7 +13,12 @@ import type {
|
|||||||
T2IAdapterConfigMetadata,
|
T2IAdapterConfigMetadata,
|
||||||
} from 'features/metadata/types';
|
} from 'features/metadata/types';
|
||||||
import { fetchModelConfigWithTypeGuard, getModelKey } from 'features/metadata/util/modelFetchingHelpers';
|
import { fetchModelConfigWithTypeGuard, getModelKey } from 'features/metadata/util/modelFetchingHelpers';
|
||||||
import { zControlField, zIPAdapterField, zModelIdentifierField, zT2IAdapterField } from 'features/nodes/types/common';
|
import {
|
||||||
|
zControlField,
|
||||||
|
zIPAdapterField,
|
||||||
|
zModelIdentifierWithBase,
|
||||||
|
zT2IAdapterField,
|
||||||
|
} from 'features/nodes/types/common';
|
||||||
import type {
|
import type {
|
||||||
ParameterCFGRescaleMultiplier,
|
ParameterCFGRescaleMultiplier,
|
||||||
ParameterCFGScale,
|
ParameterCFGScale,
|
||||||
@ -176,7 +181,7 @@ const parseMainModel: MetadataParseFunc<ParameterModel> = async (metadata) => {
|
|||||||
const model = await getProperty(metadata, 'model', undefined);
|
const model = await getProperty(metadata, 'model', undefined);
|
||||||
const key = await getModelKey(model, 'main');
|
const key = await getModelKey(model, 'main');
|
||||||
const mainModelConfig = await fetchModelConfigWithTypeGuard(key, isNonRefinerMainModelConfig);
|
const mainModelConfig = await fetchModelConfigWithTypeGuard(key, isNonRefinerMainModelConfig);
|
||||||
const modelIdentifier = zModelIdentifierField.parse(mainModelConfig);
|
const modelIdentifier = zModelIdentifierWithBase.parse(mainModelConfig);
|
||||||
return modelIdentifier;
|
return modelIdentifier;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -184,7 +189,7 @@ const parseRefinerModel: MetadataParseFunc<ParameterSDXLRefinerModel> = async (m
|
|||||||
const refiner_model = await getProperty(metadata, 'refiner_model', undefined);
|
const refiner_model = await getProperty(metadata, 'refiner_model', undefined);
|
||||||
const key = await getModelKey(refiner_model, 'main');
|
const key = await getModelKey(refiner_model, 'main');
|
||||||
const refinerModelConfig = await fetchModelConfigWithTypeGuard(key, isRefinerMainModelModelConfig);
|
const refinerModelConfig = await fetchModelConfigWithTypeGuard(key, isRefinerMainModelModelConfig);
|
||||||
const modelIdentifier = zModelIdentifierField.parse(refinerModelConfig);
|
const modelIdentifier = zModelIdentifierWithBase.parse(refinerModelConfig);
|
||||||
return modelIdentifier;
|
return modelIdentifier;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -192,7 +197,7 @@ const parseVAEModel: MetadataParseFunc<ParameterVAEModel> = async (metadata) =>
|
|||||||
const vae = await getProperty(metadata, 'vae', undefined);
|
const vae = await getProperty(metadata, 'vae', undefined);
|
||||||
const key = await getModelKey(vae, 'vae');
|
const key = await getModelKey(vae, 'vae');
|
||||||
const vaeModelConfig = await fetchModelConfigWithTypeGuard(key, isVAEModelConfig);
|
const vaeModelConfig = await fetchModelConfigWithTypeGuard(key, isVAEModelConfig);
|
||||||
const modelIdentifier = zModelIdentifierField.parse(vaeModelConfig);
|
const modelIdentifier = zModelIdentifierWithBase.parse(vaeModelConfig);
|
||||||
return modelIdentifier;
|
return modelIdentifier;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -206,7 +211,7 @@ const parseLoRA: MetadataParseFunc<LoRA> = async (metadataItem) => {
|
|||||||
const loraModelConfig = await fetchModelConfigWithTypeGuard(key, isLoRAModelConfig);
|
const loraModelConfig = await fetchModelConfigWithTypeGuard(key, isLoRAModelConfig);
|
||||||
|
|
||||||
return {
|
return {
|
||||||
model: zModelIdentifierField.parse(loraModelConfig),
|
model: zModelIdentifierWithBase.parse(loraModelConfig),
|
||||||
weight: isParameterLoRAWeight(weight) ? weight : defaultLoRAConfig.weight,
|
weight: isParameterLoRAWeight(weight) ? weight : defaultLoRAConfig.weight,
|
||||||
isEnabled: true,
|
isEnabled: true,
|
||||||
};
|
};
|
||||||
@ -248,12 +253,13 @@ const parseControlNet: MetadataParseFunc<ControlNetConfigMetadata> = async (meta
|
|||||||
.catch(null)
|
.catch(null)
|
||||||
.parse(getProperty(metadataItem, 'resize_mode'));
|
.parse(getProperty(metadataItem, 'resize_mode'));
|
||||||
|
|
||||||
const { processorType, processorNode } = buildControlAdapterProcessor(controlNetModel);
|
const processorType = 'none';
|
||||||
|
const processorNode = CONTROLNET_PROCESSORS.none.default;
|
||||||
|
|
||||||
const controlNet: ControlNetConfigMetadata = {
|
const controlNet: ControlNetConfigMetadata = {
|
||||||
type: 'controlnet',
|
type: 'controlnet',
|
||||||
isEnabled: true,
|
isEnabled: true,
|
||||||
model: zModelIdentifierField.parse(controlNetModel),
|
model: zModelIdentifierWithBase.parse(controlNetModel),
|
||||||
weight: typeof control_weight === 'number' ? control_weight : initialControlNet.weight,
|
weight: typeof control_weight === 'number' ? control_weight : initialControlNet.weight,
|
||||||
beginStepPct: begin_step_percent ?? initialControlNet.beginStepPct,
|
beginStepPct: begin_step_percent ?? initialControlNet.beginStepPct,
|
||||||
endStepPct: end_step_percent ?? initialControlNet.endStepPct,
|
endStepPct: end_step_percent ?? initialControlNet.endStepPct,
|
||||||
@ -299,12 +305,13 @@ const parseT2IAdapter: MetadataParseFunc<T2IAdapterConfigMetadata> = async (meta
|
|||||||
.catch(null)
|
.catch(null)
|
||||||
.parse(getProperty(metadataItem, 'resize_mode'));
|
.parse(getProperty(metadataItem, 'resize_mode'));
|
||||||
|
|
||||||
const { processorType, processorNode } = buildControlAdapterProcessor(t2iAdapterModel);
|
const processorType = 'none';
|
||||||
|
const processorNode = CONTROLNET_PROCESSORS.none.default;
|
||||||
|
|
||||||
const t2iAdapter: T2IAdapterConfigMetadata = {
|
const t2iAdapter: T2IAdapterConfigMetadata = {
|
||||||
type: 't2i_adapter',
|
type: 't2i_adapter',
|
||||||
isEnabled: true,
|
isEnabled: true,
|
||||||
model: zModelIdentifierField.parse(t2iAdapterModel),
|
model: zModelIdentifierWithBase.parse(t2iAdapterModel),
|
||||||
weight: typeof weight === 'number' ? weight : initialT2IAdapter.weight,
|
weight: typeof weight === 'number' ? weight : initialT2IAdapter.weight,
|
||||||
beginStepPct: begin_step_percent ?? initialT2IAdapter.beginStepPct,
|
beginStepPct: begin_step_percent ?? initialT2IAdapter.beginStepPct,
|
||||||
endStepPct: end_step_percent ?? initialT2IAdapter.endStepPct,
|
endStepPct: end_step_percent ?? initialT2IAdapter.endStepPct,
|
||||||
@ -349,7 +356,7 @@ const parseIPAdapter: MetadataParseFunc<IPAdapterConfigMetadata> = async (metada
|
|||||||
id: uuidv4(),
|
id: uuidv4(),
|
||||||
type: 'ip_adapter',
|
type: 'ip_adapter',
|
||||||
isEnabled: true,
|
isEnabled: true,
|
||||||
model: zModelIdentifierField.parse(ipAdapterModel),
|
model: zModelIdentifierWithBase.parse(ipAdapterModel),
|
||||||
controlImage: image?.image_name ?? null,
|
controlImage: image?.image_name ?? null,
|
||||||
weight: weight ?? initialIPAdapter.weight,
|
weight: weight ?? initialIPAdapter.weight,
|
||||||
beginStepPct: begin_step_percent ?? initialIPAdapter.beginStepPct,
|
beginStepPct: begin_step_percent ?? initialIPAdapter.beginStepPct,
|
||||||
|
@ -1,33 +0,0 @@
|
|||||||
import type { ButtonProps } from '@invoke-ai/ui-library';
|
|
||||||
import { Button } from '@invoke-ai/ui-library';
|
|
||||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
|
||||||
import { memo } from 'react';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
import { PiArrowsClockwiseBold } from 'react-icons/pi';
|
|
||||||
|
|
||||||
import { useSyncModels } from './useSyncModels';
|
|
||||||
|
|
||||||
export const SyncModelsButton = memo((props: Omit<ButtonProps, 'aria-label'>) => {
|
|
||||||
const { t } = useTranslation();
|
|
||||||
const { syncModels, isLoading } = useSyncModels();
|
|
||||||
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;
|
|
||||||
|
|
||||||
if (!isSyncModelEnabled) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Button
|
|
||||||
leftIcon={<PiArrowsClockwiseBold />}
|
|
||||||
isLoading={isLoading}
|
|
||||||
onClick={syncModels}
|
|
||||||
size="sm"
|
|
||||||
variant="ghost"
|
|
||||||
{...props}
|
|
||||||
>
|
|
||||||
{t('modelManager.syncModels')}
|
|
||||||
</Button>
|
|
||||||
);
|
|
||||||
});
|
|
||||||
|
|
||||||
SyncModelsButton.displayName = 'SyncModelsButton';
|
|
@ -1,23 +0,0 @@
|
|||||||
import { skipToken } from '@reduxjs/toolkit/query';
|
|
||||||
import { isNil } from 'lodash-es';
|
|
||||||
import { useMemo } from 'react';
|
|
||||||
import { useGetModelConfigWithTypeGuard } from 'services/api/hooks/useGetModelConfigWithTypeGuard';
|
|
||||||
import { isControlNetOrT2IAdapterModelConfig } from 'services/api/types';
|
|
||||||
|
|
||||||
export const useControlNetOrT2IAdapterDefaultSettings = (modelKey?: string | null) => {
|
|
||||||
const { modelConfig, isLoading } = useGetModelConfigWithTypeGuard(
|
|
||||||
modelKey ?? skipToken,
|
|
||||||
isControlNetOrT2IAdapterModelConfig
|
|
||||||
);
|
|
||||||
|
|
||||||
const defaultSettingsDefaults = useMemo(() => {
|
|
||||||
return {
|
|
||||||
preprocessor: {
|
|
||||||
isEnabled: !isNil(modelConfig?.default_settings?.preprocessor),
|
|
||||||
value: modelConfig?.default_settings?.preprocessor || 'none',
|
|
||||||
},
|
|
||||||
};
|
|
||||||
}, [modelConfig?.default_settings]);
|
|
||||||
|
|
||||||
return { defaultSettingsDefaults, isLoading };
|
|
||||||
};
|
|
@ -1,65 +0,0 @@
|
|||||||
import { skipToken } from '@reduxjs/toolkit/query';
|
|
||||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import { selectConfigSlice } from 'features/system/store/configSlice';
|
|
||||||
import { isNil } from 'lodash-es';
|
|
||||||
import { useMemo } from 'react';
|
|
||||||
import { useGetModelConfigWithTypeGuard } from 'services/api/hooks/useGetModelConfigWithTypeGuard';
|
|
||||||
import { isNonRefinerMainModelConfig } from 'services/api/types';
|
|
||||||
|
|
||||||
const initialStatesSelector = createMemoizedSelector(selectConfigSlice, (config) => {
|
|
||||||
const { steps, guidance, scheduler, cfgRescaleMultiplier, vaePrecision } = config.sd;
|
|
||||||
|
|
||||||
return {
|
|
||||||
initialSteps: steps.initial,
|
|
||||||
initialCfg: guidance.initial,
|
|
||||||
initialScheduler: scheduler,
|
|
||||||
initialCfgRescaleMultiplier: cfgRescaleMultiplier.initial,
|
|
||||||
initialVaePrecision: vaePrecision,
|
|
||||||
};
|
|
||||||
});
|
|
||||||
|
|
||||||
export const useMainModelDefaultSettings = (modelKey?: string | null) => {
|
|
||||||
const { modelConfig, isLoading } = useGetModelConfigWithTypeGuard(modelKey ?? skipToken, isNonRefinerMainModelConfig);
|
|
||||||
|
|
||||||
const { initialSteps, initialCfg, initialScheduler, initialCfgRescaleMultiplier, initialVaePrecision } =
|
|
||||||
useAppSelector(initialStatesSelector);
|
|
||||||
|
|
||||||
const defaultSettingsDefaults = useMemo(() => {
|
|
||||||
return {
|
|
||||||
vae: {
|
|
||||||
isEnabled: !isNil(modelConfig?.default_settings?.vae),
|
|
||||||
value: modelConfig?.default_settings?.vae || 'default',
|
|
||||||
},
|
|
||||||
vaePrecision: {
|
|
||||||
isEnabled: !isNil(modelConfig?.default_settings?.vae_precision),
|
|
||||||
value: modelConfig?.default_settings?.vae_precision || initialVaePrecision || 'fp32',
|
|
||||||
},
|
|
||||||
scheduler: {
|
|
||||||
isEnabled: !isNil(modelConfig?.default_settings?.scheduler),
|
|
||||||
value: modelConfig?.default_settings?.scheduler || initialScheduler || 'euler',
|
|
||||||
},
|
|
||||||
steps: {
|
|
||||||
isEnabled: !isNil(modelConfig?.default_settings?.steps),
|
|
||||||
value: modelConfig?.default_settings?.steps || initialSteps,
|
|
||||||
},
|
|
||||||
cfgScale: {
|
|
||||||
isEnabled: !isNil(modelConfig?.default_settings?.cfg_scale),
|
|
||||||
value: modelConfig?.default_settings?.cfg_scale || initialCfg,
|
|
||||||
},
|
|
||||||
cfgRescaleMultiplier: {
|
|
||||||
isEnabled: !isNil(modelConfig?.default_settings?.cfg_rescale_multiplier),
|
|
||||||
value: modelConfig?.default_settings?.cfg_rescale_multiplier || initialCfgRescaleMultiplier,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
}, [
|
|
||||||
modelConfig?.default_settings,
|
|
||||||
initialSteps,
|
|
||||||
initialCfg,
|
|
||||||
initialScheduler,
|
|
||||||
initialCfgRescaleMultiplier,
|
|
||||||
initialVaePrecision,
|
|
||||||
]);
|
|
||||||
|
|
||||||
return { defaultSettingsDefaults, isLoading };
|
|
||||||
};
|
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user