mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into update-textual-inversion-training
This commit is contained in:
commit
f66ead0819
4
.github/CODEOWNERS
vendored
4
.github/CODEOWNERS
vendored
@ -6,7 +6,7 @@
|
||||
/mkdocs.yml @lstein @blessedcoolant
|
||||
|
||||
# nodes
|
||||
/invokeai/app/ @Kyle0654 @blessedcoolant
|
||||
/invokeai/app/ @Kyle0654 @blessedcoolant @psychedelicious @brandonrising
|
||||
|
||||
# installation and configuration
|
||||
/pyproject.toml @lstein @blessedcoolant
|
||||
@ -22,7 +22,7 @@
|
||||
/invokeai/backend @blessedcoolant @psychedelicious @lstein @maryhipp
|
||||
|
||||
# generation, model management, postprocessing
|
||||
/invokeai/backend @damian0815 @lstein @blessedcoolant @jpphoto @gregghelt2 @StAlKeR7779
|
||||
/invokeai/backend @damian0815 @lstein @blessedcoolant @gregghelt2 @StAlKeR7779 @brandonrising
|
||||
|
||||
# front ends
|
||||
/invokeai/frontend/CLI @lstein
|
||||
|
287
docs/features/CONFIGURATION.md
Normal file
287
docs/features/CONFIGURATION.md
Normal file
@ -0,0 +1,287 @@
|
||||
---
|
||||
title: Configuration
|
||||
---
|
||||
|
||||
# :material-tune-variant: InvokeAI Configuration
|
||||
|
||||
## Intro
|
||||
|
||||
InvokeAI has numerous runtime settings which can be used to adjust
|
||||
many aspects of its operations, including the location of files and
|
||||
directories, memory usage, and performance. These settings can be
|
||||
viewed and customized in several ways:
|
||||
|
||||
1. By editing settings in the `invokeai.yaml` file.
|
||||
2. By setting environment variables.
|
||||
3. On the command-line, when InvokeAI is launched.
|
||||
|
||||
In addition, the most commonly changed settings are accessible
|
||||
graphically via the `invokeai-configure` script.
|
||||
|
||||
### How the Configuration System Works
|
||||
|
||||
When InvokeAI is launched, the very first thing it needs to do is to
|
||||
find its "root" directory, which contains its configuration files,
|
||||
installed models, its database of images, and the folder(s) of
|
||||
generated images themselves. In this document, the root directory will
|
||||
be referred to as ROOT.
|
||||
|
||||
#### Finding the Root Directory
|
||||
|
||||
To find its root directory, InvokeAI uses the following recipe:
|
||||
|
||||
1. It first looks for the argument `--root <path>` on the command line
|
||||
it was launched from, and uses the indicated path if present.
|
||||
|
||||
2. Next it looks for the environment variable INVOKEAI_ROOT, and uses
|
||||
the directory path found there if present.
|
||||
|
||||
3. If neither of these are present, then InvokeAI looks for the
|
||||
folder containing the `.venv` Python virtual environment directory for
|
||||
the currently active environment. This directory is checked for files
|
||||
expected inside the InvokeAI root before it is used.
|
||||
|
||||
4. Finally, InvokeAI looks for a directory in the current user's home
|
||||
directory named `invokeai`.
|
||||
|
||||
#### Reading the InvokeAI Configuration File
|
||||
|
||||
Once the root directory has been located, InvokeAI looks for a file
|
||||
named `ROOT/invokeai.yaml`, and if present reads configuration values
|
||||
from it. The top of this file looks like this:
|
||||
|
||||
```
|
||||
InvokeAI:
|
||||
Web Server:
|
||||
host: localhost
|
||||
port: 9090
|
||||
allow_origins: []
|
||||
allow_credentials: true
|
||||
allow_methods:
|
||||
- '*'
|
||||
allow_headers:
|
||||
- '*'
|
||||
Features:
|
||||
esrgan: true
|
||||
internet_available: true
|
||||
log_tokenization: false
|
||||
nsfw_checker: false
|
||||
patchmatch: true
|
||||
restore: true
|
||||
...
|
||||
```
|
||||
|
||||
This lines in this file are used to establish default values for
|
||||
Invoke's settings. In the above fragment, the Web Server's listening
|
||||
port is set to 9090 by the `port` setting.
|
||||
|
||||
You can edit this file with a text editor such as "Notepad" (do not
|
||||
use Word or any other word processor). When editing, be careful to
|
||||
maintain the indentation, and do not add extraneous text, as syntax
|
||||
errors will prevent InvokeAI from launching. A basic guide to the
|
||||
format of YAML files can be found
|
||||
[here](https://circleci.com/blog/what-is-yaml-a-beginner-s-guide/).
|
||||
|
||||
You can fix a broken `invokeai.yaml` by deleting it and running the
|
||||
configuration script again -- option [7] in the launcher, "Re-run the
|
||||
configure script".
|
||||
|
||||
#### Reading Environment Variables
|
||||
|
||||
Next InvokeAI looks for defined environment variables in the format
|
||||
`INVOKEAI_<setting_name>`, for example `INVOKEAI_port`. Environment
|
||||
variable values take precedence over configuration file variables. On
|
||||
a Macintosh system, for example, you could change the port that the
|
||||
web server listens on by setting the environment variable this way:
|
||||
|
||||
```
|
||||
export INVOKEAI_port=8000
|
||||
invokeai-web
|
||||
```
|
||||
|
||||
Please check out these
|
||||
[Macintosh](https://phoenixnap.com/kb/set-environment-variable-mac)
|
||||
and
|
||||
[Windows](https://phoenixnap.com/kb/windows-set-environment-variable)
|
||||
guides for setting temporary and permanent environment variables.
|
||||
|
||||
#### Reading the Command Line
|
||||
|
||||
Lastly, InvokeAI takes settings from the command line, which override
|
||||
everything else. The command-line settings have the same name as the
|
||||
corresponding configuration file settings, preceded by a `--`, for
|
||||
example `--port 8000`.
|
||||
|
||||
If you are using the launcher (`invoke.sh` or `invoke.bat`) to launch
|
||||
InvokeAI, then just pass the command-line arguments to the launcher:
|
||||
|
||||
```
|
||||
invoke.bat --port 8000 --host 0.0.0.0
|
||||
```
|
||||
|
||||
The arguments will be applied when you select the web server option
|
||||
(and the other options as well).
|
||||
|
||||
If, on the other hand, you prefer to launch InvokeAI directly from the
|
||||
command line, you would first activate the virtual environment (known
|
||||
as the "developer's console" in the launcher), and run `invokeai-web`:
|
||||
|
||||
```
|
||||
> C:\Users\Fred\invokeai\.venv\scripts\activate
|
||||
(.venv) > invokeai-web --port 8000 --host 0.0.0.0
|
||||
```
|
||||
|
||||
You can get a listing and brief instructions for each of the
|
||||
command-line options by giving the `--help` argument:
|
||||
|
||||
```
|
||||
(.venv) > invokeai-web --help
|
||||
usage: InvokeAI [-h] [--host HOST] [--port PORT] [--allow_origins [ALLOW_ORIGINS ...]] [--allow_credentials | --no-allow_credentials]
|
||||
[--allow_methods [ALLOW_METHODS ...]] [--allow_headers [ALLOW_HEADERS ...]] [--esrgan | --no-esrgan]
|
||||
[--internet_available | --no-internet_available] [--log_tokenization | --no-log_tokenization]
|
||||
[--nsfw_checker | --no-nsfw_checker] [--patchmatch | --no-patchmatch] [--restore | --no-restore]
|
||||
[--always_use_cpu | --no-always_use_cpu] [--free_gpu_mem | --no-free_gpu_mem] [--max_cache_size MAX_CACHE_SIZE]
|
||||
[--max_vram_cache_size MAX_VRAM_CACHE_SIZE] [--precision {auto,float16,float32,autocast}]
|
||||
[--sequential_guidance | --no-sequential_guidance] [--xformers_enabled | --no-xformers_enabled]
|
||||
[--tiled_decode | --no-tiled_decode] [--root ROOT] [--autoimport_dir AUTOIMPORT_DIR] [--lora_dir LORA_DIR]
|
||||
[--embedding_dir EMBEDDING_DIR] [--controlnet_dir CONTROLNET_DIR] [--conf_path CONF_PATH] [--models_dir MODELS_DIR]
|
||||
[--legacy_conf_dir LEGACY_CONF_DIR] [--db_dir DB_DIR] [--outdir OUTDIR] [--from_file FROM_FILE]
|
||||
[--use_memory_db | --no-use_memory_db] [--model MODEL] [--log_handlers [LOG_HANDLERS ...]]
|
||||
[--log_format {plain,color,syslog,legacy}] [--log_level {debug,info,warning,error,critical}]
|
||||
...
|
||||
```
|
||||
|
||||
## The Configuration Settings
|
||||
|
||||
The configuration settings are divided into several distinct
|
||||
groups in `invokeia.yaml`:
|
||||
|
||||
### Web Server
|
||||
|
||||
| Setting | Default Value | Description |
|
||||
|----------|----------------|--------------|
|
||||
| `host` | `localhost` | Name or IP address of the network interface that the web server will listen on |
|
||||
| `port` | `9090` | Network port number that the web server will listen on |
|
||||
| `allow_origins` | `[]` | A list of host names or IP addresses that are allowed to connect to the InvokeAI API in the format `['host1','host2',...]` |
|
||||
| `allow_credentials | `true` | Require credentials for a foreign host to access the InvokeAI API (don't change this) |
|
||||
| `allow_methods` | `*` | List of HTTP methods ("GET", "POST") that the web server is allowed to use when accessing the API |
|
||||
| `allow_headers` | `*` | List of HTTP headers that the web server will accept when accessing the API |
|
||||
|
||||
The documentation for InvokeAI's API can be accessed by browsing to the following URL: [http://localhost:9090/docs].
|
||||
|
||||
### Features
|
||||
|
||||
These configuration settings allow you to enable and disable various InvokeAI features:
|
||||
|
||||
| Setting | Default Value | Description |
|
||||
|----------|----------------|--------------|
|
||||
| `esrgan` | `true` | Activate the ESRGAN upscaling options|
|
||||
| `internet_available` | `true` | When a resource is not available locally, try to fetch it via the internet |
|
||||
| `log_tokenization` | `false` | Before each text2image generation, print a color-coded representation of the prompt to the console; this can help understand why a prompt is not working as expected |
|
||||
| `nsfw_checker` | `true` | Activate the NSFW checker to blur out risque images |
|
||||
| `patchmatch` | `true` | Activate the "patchmatch" algorithm for improved inpainting |
|
||||
| `restore` | `true` | Activate the facial restoration features (DEPRECATED; restoration features will be removed in 3.0.0) |
|
||||
|
||||
### Memory/Performance
|
||||
|
||||
These options tune InvokeAI's memory and performance characteristics.
|
||||
|
||||
| Setting | Default Value | Description |
|
||||
|----------|----------------|--------------|
|
||||
| `always_use_cpu` | `false` | Use the CPU to generate images, even if a GPU is available |
|
||||
| `free_gpu_mem` | `false` | Aggressively free up GPU memory after each operation; this will allow you to run in low-VRAM environments with some performance penalties |
|
||||
| `max_cache_size` | `6` | Amount of CPU RAM (in GB) to reserve for caching models in memory; more cache allows you to keep models in memory and switch among them quickly |
|
||||
| `max_vram_cache_size` | `2.75` | Amount of GPU VRAM (in GB) to reserve for caching models in VRAM; more cache speeds up generation but reduces the size of the images that can be generated. This can be set to zero to maximize the amount of memory available for generation. |
|
||||
| `precision` | `auto` | Floating point precision. One of `auto`, `float16` or `float32`. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system |
|
||||
| `sequential_guidance` | `false` | Calculate guidance in serial rather than in parallel, lowering memory requirements at the cost of some performance loss |
|
||||
| `xformers_enabled` | `true` | If the x-formers memory-efficient attention module is installed, activate it for better memory usage and generation speed|
|
||||
| `tiled_decode` | `false` | If true, then during the VAE decoding phase the image will be decoded a section at a time, reducing memory consumption at the cost of a performance hit |
|
||||
|
||||
### Paths
|
||||
|
||||
These options set the paths of various directories and files used by
|
||||
InvokeAI. Relative paths are interpreted relative to INVOKEAI_ROOT, so
|
||||
if INVOKEAI_ROOT is `/home/fred/invokeai` and the path is
|
||||
`autoimport/main`, then the corresponding directory will be located at
|
||||
`/home/fred/invokeai/autoimport/main`.
|
||||
|
||||
| Setting | Default Value | Description |
|
||||
|----------|----------------|--------------|
|
||||
| `autoimport_dir` | `autoimport/main` | At startup time, read and import any main model files found in this directory |
|
||||
| `lora_dir` | `autoimport/lora` | At startup time, read and import any LoRA/LyCORIS models found in this directory |
|
||||
| `embedding_dir` | `autoimport/embedding` | At startup time, read and import any textual inversion (embedding) models found in this directory |
|
||||
| `controlnet_dir` | `autoimport/controlnet` | At startup time, read and import any ControlNet models found in this directory |
|
||||
| `conf_path` | `configs/models.yaml` | Location of the `models.yaml` model configuration file |
|
||||
| `models_dir` | `models` | Location of the directory containing models installed by InvokeAI's model manager |
|
||||
| `legacy_conf_dir` | `configs/stable-diffusion` | Location of the directory containing the .yaml configuration files for legacy checkpoint models |
|
||||
| `db_dir` | `databases` | Location of the directory containing InvokeAI's image, schema and session database |
|
||||
| `outdir` | `outputs` | Location of the directory in which the gallery of generated and uploaded images will be stored |
|
||||
| `use_memory_db` | `false` | Keep database information in memory rather than on disk; this will not preserve image gallery information across restarts |
|
||||
|
||||
Note that the autoimport directories will be searched recursively,
|
||||
allowing you to organize the models into folders and subfolders in any
|
||||
way you wish. In addition, while we have split up autoimport
|
||||
directories by the type of model they contain, this isn't
|
||||
necessary. You can combine different model types in the same folder
|
||||
and InvokeAI will figure out what they are. So you can easily use just
|
||||
one autoimport directory by commenting out the unneeded paths:
|
||||
|
||||
```
|
||||
Paths:
|
||||
autoimport_dir: autoimport
|
||||
# lora_dir: null
|
||||
# embedding_dir: null
|
||||
# controlnet_dir: null
|
||||
```
|
||||
|
||||
### Logging
|
||||
|
||||
These settings control the information, warning, and debugging
|
||||
messages printed to the console log while InvokeAI is running:
|
||||
|
||||
| Setting | Default Value | Description |
|
||||
|----------|----------------|--------------|
|
||||
| `log_handlers` | `console` | This controls where log messages are sent, and can be a list of one or more destinations. Values include `console`, `file`, `syslog` and `http`. These are described in more detail below |
|
||||
| `log_format` | `color` | This controls the formatting of the log messages. Values are `plain`, `color`, `legacy` and `syslog` |
|
||||
| `log_level` | `debug` | This filters messages according to the level of severity and can be one of `debug`, `info`, `warning`, `error` and `critical`. For example, setting to `warning` will display all messages at the warning level or higher, but won't display "debug" or "info" messages |
|
||||
|
||||
Several different log handler destinations are available, and multiple destinations are supported by providing a list:
|
||||
|
||||
```
|
||||
log_handlers:
|
||||
- console
|
||||
- syslog=localhost
|
||||
- file=/var/log/invokeai.log
|
||||
```
|
||||
|
||||
* `console` is the default. It prints log messages to the command-line window from which InvokeAI was launched.
|
||||
|
||||
* `syslog` is only available on Linux and Macintosh systems. It uses
|
||||
the operating system's "syslog" facility to write log file entries
|
||||
locally or to a remote logging machine. `syslog` offers a variety
|
||||
of configuration options:
|
||||
|
||||
```
|
||||
syslog=/dev/log` - log to the /dev/log device
|
||||
syslog=localhost` - log to the network logger running on the local machine
|
||||
syslog=localhost:512` - same as above, but using a non-standard port
|
||||
syslog=fredserver,facility=LOG_USER,socktype=SOCK_DRAM`
|
||||
- Log to LAN-connected server "fredserver" using the facility LOG_USER and datagram packets.
|
||||
```
|
||||
|
||||
* `http` can be used to log to a remote web server. The server must be
|
||||
properly configured to receive and act on log messages. The option
|
||||
accepts the URL to the web server, and a `method` argument
|
||||
indicating whether the message should be submitted using the GET or
|
||||
POST method.
|
||||
|
||||
```
|
||||
http=http://my.server/path/to/logger,method=POST
|
||||
```
|
||||
|
||||
The `log_format` option provides several alternative formats:
|
||||
|
||||
* `color` - default format providing time, date and a message, using text colors to distinguish different log severities
|
||||
* `plain` - same as above, but monochrome text only
|
||||
* `syslog` - the log level and error message only, allowing the syslog system to attach the time and date
|
||||
* `legacy` - a format similar to the one used by the legacy 2.3 InvokeAI releases.
|
@ -153,6 +153,9 @@ This method is recommended for those familiar with running Docker containers
|
||||
- [Prompt Syntax](features/PROMPTS.md)
|
||||
- [Generating Variations](features/VARIATIONS.md)
|
||||
|
||||
### InvokeAI Configuration
|
||||
- [Guide to InvokeAI Runtime Settings](features/CONFIGURATION.md)
|
||||
|
||||
## :octicons-log-16: Important Changes Since Version 2.3
|
||||
|
||||
### Nodes
|
||||
|
@ -1,6 +1,7 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2024 Lincoln Stein
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2023 Lincoln D. Stein
|
||||
|
||||
|
||||
import pathlib
|
||||
from typing import Literal, List, Optional, Union
|
||||
|
||||
from fastapi import Body, Path, Query, Response
|
||||
@ -22,6 +23,7 @@ UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
|
||||
class ModelsList(BaseModel):
|
||||
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
|
||||
@ -78,7 +80,7 @@ async def update_model(
|
||||
return model_response
|
||||
|
||||
@models_router.post(
|
||||
"/",
|
||||
"/import",
|
||||
operation_id="import_model",
|
||||
responses= {
|
||||
201: {"description" : "The model imported successfully"},
|
||||
@ -94,7 +96,7 @@ async def import_model(
|
||||
prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = \
|
||||
Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
|
||||
) -> ImportModelResponse:
|
||||
""" Add a model using its local path, repo_id, or remote URL """
|
||||
""" Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically """
|
||||
|
||||
items_to_import = {location}
|
||||
prediction_types = { x.value: x for x in SchedulerPredictionType }
|
||||
@ -126,18 +128,100 @@ async def import_model(
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
|
||||
@models_router.post(
|
||||
"/add",
|
||||
operation_id="add_model",
|
||||
responses= {
|
||||
201: {"description" : "The model added successfully"},
|
||||
404: {"description" : "The model could not be found"},
|
||||
424: {"description" : "The model appeared to add successfully, but could not be found in the model manager"},
|
||||
409: {"description" : "There is already a model corresponding to this path or repo_id"},
|
||||
},
|
||||
status_code=201,
|
||||
response_model=ImportModelResponse
|
||||
)
|
||||
async def add_model(
|
||||
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
|
||||
) -> ImportModelResponse:
|
||||
""" Add a model using the configuration information appropriate for its type. Only local models can be added by path"""
|
||||
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
try:
|
||||
ApiDependencies.invoker.services.model_manager.add_model(
|
||||
info.model_name,
|
||||
info.base_model,
|
||||
info.model_type,
|
||||
model_attributes = info.dict()
|
||||
)
|
||||
logger.info(f'Successfully added {info.model_name}')
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
model_name=info.model_name,
|
||||
base_model=info.base_model,
|
||||
model_type=info.model_type
|
||||
)
|
||||
return parse_obj_as(ImportModelResponse, model_raw)
|
||||
except KeyError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except ValueError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
|
||||
@models_router.post(
|
||||
"/rename/{base_model}/{model_type}/{model_name}",
|
||||
operation_id="rename_model",
|
||||
responses= {
|
||||
201: {"description" : "The model was renamed successfully"},
|
||||
404: {"description" : "The model could not be found"},
|
||||
409: {"description" : "There is already a model corresponding to the new name"},
|
||||
},
|
||||
status_code=201,
|
||||
response_model=ImportModelResponse
|
||||
)
|
||||
async def rename_model(
|
||||
base_model: BaseModelType = Path(description="Base model"),
|
||||
model_type: ModelType = Path(description="The type of model"),
|
||||
model_name: str = Path(description="current model name"),
|
||||
new_name: Optional[str] = Query(description="new model name", default=None),
|
||||
new_base: Optional[BaseModelType] = Query(description="new model base", default=None),
|
||||
) -> ImportModelResponse:
|
||||
""" Rename a model"""
|
||||
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
try:
|
||||
result = ApiDependencies.invoker.services.model_manager.rename_model(
|
||||
base_model = base_model,
|
||||
model_type = model_type,
|
||||
model_name = model_name,
|
||||
new_name = new_name,
|
||||
new_base = new_base,
|
||||
)
|
||||
logger.debug(result)
|
||||
logger.info(f'Successfully renamed {model_name}=>{new_name}')
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
model_name=new_name or model_name,
|
||||
base_model=new_base or base_model,
|
||||
model_type=model_type
|
||||
)
|
||||
return parse_obj_as(ImportModelResponse, model_raw)
|
||||
except KeyError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except ValueError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
|
||||
@models_router.delete(
|
||||
"/{base_model}/{model_type}/{model_name}",
|
||||
operation_id="del_model",
|
||||
responses={
|
||||
204: {
|
||||
"description": "Model deleted successfully"
|
||||
},
|
||||
404: {
|
||||
"description": "Model not found"
|
||||
}
|
||||
204: { "description": "Model deleted successfully" },
|
||||
404: { "description": "Model not found" }
|
||||
},
|
||||
status_code = 204,
|
||||
response_model = None,
|
||||
)
|
||||
async def delete_model(
|
||||
base_model: BaseModelType = Path(description="Base model"),
|
||||
@ -173,14 +257,17 @@ async def convert_model(
|
||||
base_model: BaseModelType = Path(description="Base model"),
|
||||
model_type: ModelType = Path(description="The type of model"),
|
||||
model_name: str = Path(description="model name"),
|
||||
convert_dest_directory: Optional[str] = Query(default=None, description="Save the converted model to the designated directory"),
|
||||
) -> ConvertModelResponse:
|
||||
"""Convert a checkpoint model into a diffusers model"""
|
||||
"""Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
try:
|
||||
logger.info(f"Converting model: {model_name}")
|
||||
dest = pathlib.Path(convert_dest_directory) if convert_dest_directory else None
|
||||
ApiDependencies.invoker.services.model_manager.convert_model(model_name,
|
||||
base_model = base_model,
|
||||
model_type = model_type
|
||||
model_type = model_type,
|
||||
convert_dest_directory = dest,
|
||||
)
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(model_name,
|
||||
base_model = base_model,
|
||||
@ -192,6 +279,53 @@ async def convert_model(
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
return response
|
||||
|
||||
@models_router.get(
|
||||
"/search",
|
||||
operation_id="search_for_models",
|
||||
responses={
|
||||
200: { "description": "Directory searched successfully" },
|
||||
404: { "description": "Invalid directory path" },
|
||||
},
|
||||
status_code = 200,
|
||||
response_model = List[pathlib.Path]
|
||||
)
|
||||
async def search_for_models(
|
||||
search_path: pathlib.Path = Query(description="Directory path to search for models")
|
||||
)->List[pathlib.Path]:
|
||||
if not search_path.is_dir():
|
||||
raise HTTPException(status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory")
|
||||
return ApiDependencies.invoker.services.model_manager.search_for_models([search_path])
|
||||
|
||||
@models_router.get(
|
||||
"/ckpt_confs",
|
||||
operation_id="list_ckpt_configs",
|
||||
responses={
|
||||
200: { "description" : "paths retrieved successfully" },
|
||||
},
|
||||
status_code = 200,
|
||||
response_model = List[pathlib.Path]
|
||||
)
|
||||
async def list_ckpt_configs(
|
||||
)->List[pathlib.Path]:
|
||||
"""Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT."""
|
||||
return ApiDependencies.invoker.services.model_manager.list_checkpoint_configs()
|
||||
|
||||
|
||||
@models_router.get(
|
||||
"/sync",
|
||||
operation_id="sync_to_config",
|
||||
responses={
|
||||
201: { "description": "synchronization successful" },
|
||||
},
|
||||
status_code = 201,
|
||||
response_model = None
|
||||
)
|
||||
async def sync_to_config(
|
||||
)->None:
|
||||
"""Call after making changes to models.yaml, autoimport directories or models directory to synchronize
|
||||
in-memory data structures with disk data structures."""
|
||||
return ApiDependencies.invoker.services.model_manager.sync_to_config()
|
||||
|
||||
@models_router.put(
|
||||
"/merge/{base_model}",
|
||||
operation_id="merge_models",
|
||||
@ -210,17 +344,21 @@ async def merge_models(
|
||||
alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
|
||||
interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"),
|
||||
force: Optional[bool] = Body(description="Force merging of models created with different versions of diffusers", default=False),
|
||||
merge_dest_directory: Optional[str] = Body(description="Save the merged model to the designated directory (with 'merged_model_name' appended)", default=None)
|
||||
) -> MergeModelResponse:
|
||||
"""Convert a checkpoint model into a diffusers model"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
try:
|
||||
logger.info(f"Merging models: {model_names}")
|
||||
logger.info(f"Merging models: {model_names} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
|
||||
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
|
||||
result = ApiDependencies.invoker.services.model_manager.merge_models(model_names,
|
||||
base_model,
|
||||
merged_model_name or "+".join(model_names),
|
||||
alpha,
|
||||
interp,
|
||||
force)
|
||||
merged_model_name=merged_model_name or "+".join(model_names),
|
||||
alpha=alpha,
|
||||
interp=interp,
|
||||
force=force,
|
||||
merge_dest_directory = dest
|
||||
)
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(result.name,
|
||||
base_model = base_model,
|
||||
model_type = ModelType.Main,
|
||||
|
@ -9,6 +9,7 @@ from typing import Literal, Optional, Union, List, Dict
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field, validator
|
||||
|
||||
from ...backend.model_management import BaseModelType, ModelType
|
||||
from ..models.image import ImageField, ImageCategory, ResourceOrigin
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
@ -105,9 +106,15 @@ CONTROLNET_MODE_VALUES = Literal[tuple(["balanced", "more_prompt", "more_control
|
||||
# CONTROLNET_RESIZE_VALUES = Literal[tuple(["just_resize", "crop_resize", "fill_resize"])]
|
||||
|
||||
|
||||
class ControlNetModelField(BaseModel):
|
||||
"""ControlNet model field"""
|
||||
|
||||
model_name: str = Field(description="Name of the ControlNet model")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
|
||||
class ControlField(BaseModel):
|
||||
image: ImageField = Field(default=None, description="The control image")
|
||||
control_model: Optional[str] = Field(default=None, description="The ControlNet model to use")
|
||||
control_model: Optional[ControlNetModelField] = Field(default=None, description="The ControlNet model to use")
|
||||
# control_weight: Optional[float] = Field(default=1, description="weight given to controlnet")
|
||||
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
|
||||
begin_step_percent: float = Field(default=0, ge=0, le=1,
|
||||
@ -118,15 +125,15 @@ class ControlField(BaseModel):
|
||||
# resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
|
||||
|
||||
@validator("control_weight")
|
||||
def abs_le_one(cls, v):
|
||||
"""validate that all abs(values) are <=1"""
|
||||
def validate_control_weight(cls, v):
|
||||
"""Validate that all control weights in the valid range"""
|
||||
if isinstance(v, list):
|
||||
for i in v:
|
||||
if abs(i) > 1:
|
||||
raise ValueError('all abs(control_weight) must be <= 1')
|
||||
if i < -1 or i > 2:
|
||||
raise ValueError('Control weights must be within -1 to 2 range')
|
||||
else:
|
||||
if abs(v) > 1:
|
||||
raise ValueError('abs(control_weight) must be <= 1')
|
||||
if v < -1 or v > 2:
|
||||
raise ValueError('Control weights must be within -1 to 2 range')
|
||||
return v
|
||||
class Config:
|
||||
schema_extra = {
|
||||
@ -134,6 +141,7 @@ class ControlField(BaseModel):
|
||||
"ui": {
|
||||
"type_hints": {
|
||||
"control_weight": "float",
|
||||
"control_model": "controlnet_model",
|
||||
# "control_weight": "number",
|
||||
}
|
||||
}
|
||||
@ -154,10 +162,10 @@ class ControlNetInvocation(BaseInvocation):
|
||||
type: Literal["controlnet"] = "controlnet"
|
||||
# Inputs
|
||||
image: ImageField = Field(default=None, description="The control image")
|
||||
control_model: CONTROLNET_NAME_VALUES = Field(default="lllyasviel/sd-controlnet-canny",
|
||||
control_model: ControlNetModelField = Field(default="lllyasviel/sd-controlnet-canny",
|
||||
description="control model used")
|
||||
control_weight: Union[float, List[float]] = Field(default=1.0, description="The weight given to the ControlNet")
|
||||
begin_step_percent: float = Field(default=0, ge=0, le=1,
|
||||
begin_step_percent: float = Field(default=0, ge=-1, le=2,
|
||||
description="When the ControlNet is first applied (% of total steps)")
|
||||
end_step_percent: float = Field(default=1, ge=0, le=1,
|
||||
description="When the ControlNet is last applied (% of total steps)")
|
||||
|
@ -1,5 +1,6 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from contextlib import ExitStack
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
import einops
|
||||
@ -11,6 +12,7 @@ from pydantic import BaseModel, Field, validator
|
||||
|
||||
from invokeai.app.invocations.metadata import CoreMetadata
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
from invokeai.backend.model_management.models.base import ModelType
|
||||
|
||||
from ...backend.model_management.lora import ModelPatcher
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
@ -71,16 +73,21 @@ def get_scheduler(
|
||||
scheduler_name: str,
|
||||
) -> Scheduler:
|
||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(
|
||||
scheduler_name, SCHEDULER_MAP['ddim'])
|
||||
scheduler_name, SCHEDULER_MAP['ddim']
|
||||
)
|
||||
orig_scheduler_info = context.services.model_manager.get_model(
|
||||
**scheduler_info.dict())
|
||||
**scheduler_info.dict()
|
||||
)
|
||||
with orig_scheduler_info as orig_scheduler:
|
||||
scheduler_config = orig_scheduler.config
|
||||
|
||||
if "_backup" in scheduler_config:
|
||||
scheduler_config = scheduler_config["_backup"]
|
||||
scheduler_config = {**scheduler_config, **
|
||||
scheduler_extra_config, "_backup": scheduler_config}
|
||||
scheduler_config = {
|
||||
**scheduler_config,
|
||||
**scheduler_extra_config,
|
||||
"_backup": scheduler_config,
|
||||
}
|
||||
scheduler = scheduler_class.from_config(scheduler_config)
|
||||
|
||||
# hack copied over from generate.py
|
||||
@ -137,8 +144,11 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
|
||||
# TODO: pass this an emitter method or something? or a session for dispatching?
|
||||
def dispatch_progress(
|
||||
self, context: InvocationContext, source_node_id: str,
|
||||
intermediate_state: PipelineIntermediateState) -> None:
|
||||
self,
|
||||
context: InvocationContext,
|
||||
source_node_id: str,
|
||||
intermediate_state: PipelineIntermediateState,
|
||||
) -> None:
|
||||
stable_diffusion_step_callback(
|
||||
context=context,
|
||||
intermediate_state=intermediate_state,
|
||||
@ -147,11 +157,16 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
def get_conditioning_data(
|
||||
self, context: InvocationContext, scheduler) -> ConditioningData:
|
||||
self,
|
||||
context: InvocationContext,
|
||||
scheduler,
|
||||
) -> ConditioningData:
|
||||
c, extra_conditioning_info = context.services.latents.get(
|
||||
self.positive_conditioning.conditioning_name)
|
||||
self.positive_conditioning.conditioning_name
|
||||
)
|
||||
uc, _ = context.services.latents.get(
|
||||
self.negative_conditioning.conditioning_name)
|
||||
self.negative_conditioning.conditioning_name
|
||||
)
|
||||
|
||||
conditioning_data = ConditioningData(
|
||||
unconditioned_embeddings=uc,
|
||||
@ -178,7 +193,10 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
return conditioning_data
|
||||
|
||||
def create_pipeline(
|
||||
self, unet, scheduler) -> StableDiffusionGeneratorPipeline:
|
||||
self,
|
||||
unet,
|
||||
scheduler,
|
||||
) -> StableDiffusionGeneratorPipeline:
|
||||
# TODO:
|
||||
# configure_model_padding(
|
||||
# unet,
|
||||
@ -213,6 +231,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
model: StableDiffusionGeneratorPipeline,
|
||||
control_input: List[ControlField],
|
||||
latents_shape: List[int],
|
||||
exit_stack: ExitStack,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
) -> List[ControlNetData]:
|
||||
|
||||
@ -238,25 +257,19 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
control_data = []
|
||||
control_models = []
|
||||
for control_info in control_list:
|
||||
# handle control models
|
||||
if ("," in control_info.control_model):
|
||||
control_model_split = control_info.control_model.split(",")
|
||||
control_name = control_model_split[0]
|
||||
control_subfolder = control_model_split[1]
|
||||
print("Using HF model subfolders")
|
||||
print(" control_name: ", control_name)
|
||||
print(" control_subfolder: ", control_subfolder)
|
||||
control_model = ControlNetModel.from_pretrained(
|
||||
control_name, subfolder=control_subfolder,
|
||||
torch_dtype=model.unet.dtype).to(
|
||||
model.device)
|
||||
else:
|
||||
control_model = ControlNetModel.from_pretrained(
|
||||
control_info.control_model, torch_dtype=model.unet.dtype).to(model.device)
|
||||
control_model = exit_stack.enter_context(
|
||||
context.services.model_manager.get_model(
|
||||
model_name=control_info.control_model.model_name,
|
||||
model_type=ModelType.ControlNet,
|
||||
base_model=control_info.control_model.base_model,
|
||||
)
|
||||
)
|
||||
|
||||
control_models.append(control_model)
|
||||
control_image_field = control_info.image
|
||||
input_image = context.services.images.get_pil_image(
|
||||
control_image_field.image_name)
|
||||
control_image_field.image_name
|
||||
)
|
||||
# self.image.image_type, self.image.image_name
|
||||
# FIXME: still need to test with different widths, heights, devices, dtypes
|
||||
# and add in batch_size, num_images_per_prompt?
|
||||
@ -278,7 +291,8 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
weight=control_info.control_weight,
|
||||
begin_step_percent=control_info.begin_step_percent,
|
||||
end_step_percent=control_info.end_step_percent,
|
||||
control_mode=control_info.control_mode,)
|
||||
control_mode=control_info.control_mode,
|
||||
)
|
||||
control_data.append(control_item)
|
||||
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
|
||||
return control_data
|
||||
@ -289,7 +303,8 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
|
||||
# Get the source node id (we are invoking the prepared node)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(
|
||||
context.graph_execution_state_id)
|
||||
context.graph_execution_state_id
|
||||
)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
|
||||
def step_callback(state: PipelineIntermediateState):
|
||||
@ -298,14 +313,17 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
def _lora_loader():
|
||||
for lora in self.unet.loras:
|
||||
lora_info = context.services.model_manager.get_model(
|
||||
**lora.dict(exclude={"weight"}))
|
||||
**lora.dict(exclude={"weight"})
|
||||
)
|
||||
yield (lora_info.context.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
|
||||
unet_info = context.services.model_manager.get_model(
|
||||
**self.unet.unet.dict())
|
||||
with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
|
||||
**self.unet.unet.dict()
|
||||
)
|
||||
with ExitStack() as exit_stack,\
|
||||
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
|
||||
unet_info as unet:
|
||||
|
||||
scheduler = get_scheduler(
|
||||
@ -322,6 +340,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
latents_shape=noise.shape,
|
||||
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
||||
do_classifier_free_guidance=True,
|
||||
exit_stack=exit_stack,
|
||||
)
|
||||
|
||||
# TODO: Verify the noise is the right size
|
||||
@ -374,7 +393,8 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
|
||||
# Get the source node id (we are invoking the prepared node)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(
|
||||
context.graph_execution_state_id)
|
||||
context.graph_execution_state_id
|
||||
)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
|
||||
def step_callback(state: PipelineIntermediateState):
|
||||
@ -383,14 +403,17 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
def _lora_loader():
|
||||
for lora in self.unet.loras:
|
||||
lora_info = context.services.model_manager.get_model(
|
||||
**lora.dict(exclude={"weight"}))
|
||||
**lora.dict(exclude={"weight"})
|
||||
)
|
||||
yield (lora_info.context.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
|
||||
unet_info = context.services.model_manager.get_model(
|
||||
**self.unet.unet.dict())
|
||||
with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
|
||||
**self.unet.unet.dict()
|
||||
)
|
||||
with ExitStack() as exit_stack,\
|
||||
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
|
||||
unet_info as unet:
|
||||
|
||||
scheduler = get_scheduler(
|
||||
@ -407,11 +430,13 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
latents_shape=noise.shape,
|
||||
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
||||
do_classifier_free_guidance=True,
|
||||
exit_stack=exit_stack,
|
||||
)
|
||||
|
||||
# TODO: Verify the noise is the right size
|
||||
initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
|
||||
latent, device=unet.device, dtype=latent.dtype)
|
||||
latent, device=unet.device, dtype=latent.dtype
|
||||
)
|
||||
|
||||
timesteps, _ = pipeline.get_img2img_timesteps(
|
||||
self.steps,
|
||||
@ -535,7 +560,8 @@ class ResizeLatentsInvocation(BaseInvocation):
|
||||
resized_latents = torch.nn.functional.interpolate(
|
||||
latents, size=(self.height // 8, self.width // 8),
|
||||
mode=self.mode, antialias=self.antialias
|
||||
if self.mode in ["bilinear", "bicubic"] else False,)
|
||||
if self.mode in ["bilinear", "bicubic"] else False,
|
||||
)
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
torch.cuda.empty_cache()
|
||||
@ -569,7 +595,8 @@ class ScaleLatentsInvocation(BaseInvocation):
|
||||
resized_latents = torch.nn.functional.interpolate(
|
||||
latents, scale_factor=self.scale_factor, mode=self.mode,
|
||||
antialias=self.antialias
|
||||
if self.mode in ["bilinear", "bicubic"] else False,)
|
||||
if self.mode in ["bilinear", "bicubic"] else False,
|
||||
)
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
torch.cuda.empty_cache()
|
||||
|
@ -271,8 +271,8 @@ class InvokeAISettings(BaseSettings):
|
||||
|
||||
@classmethod
|
||||
def _excluded(self)->List[str]:
|
||||
# combination of deprecated parameters and internal ones
|
||||
return ['type','initconf', 'gpu_mem_reserved', 'max_loaded_models', 'version']
|
||||
# combination of deprecated parameters and internal ones that shouldn't be exposed
|
||||
return ['type','initconf', 'gpu_mem_reserved', 'max_loaded_models', 'version', 'from_file', 'model', 'root']
|
||||
|
||||
class Config:
|
||||
env_file_encoding = 'utf-8'
|
||||
|
@ -19,7 +19,7 @@ from invokeai.backend.model_management import (
|
||||
ModelMerger,
|
||||
MergeInterpolationMethod,
|
||||
)
|
||||
|
||||
from invokeai.backend.model_management.model_search import FindModels
|
||||
|
||||
import torch
|
||||
from invokeai.app.models.exceptions import CanceledException
|
||||
@ -167,6 +167,27 @@ class ModelManagerServiceBase(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def rename_model(self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
new_name: str,
|
||||
):
|
||||
"""
|
||||
Rename the indicated model.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_checkpoint_configs(
|
||||
self
|
||||
)->List[Path]:
|
||||
"""
|
||||
List the checkpoint config paths from ROOT/configs/stable-diffusion.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def convert_model(
|
||||
self,
|
||||
@ -220,6 +241,7 @@ class ModelManagerServiceBase(ABC):
|
||||
alpha: Optional[float] = 0.5,
|
||||
interp: Optional[MergeInterpolationMethod] = None,
|
||||
force: Optional[bool] = False,
|
||||
merge_dest_directory: Optional[Path] = None
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Merge two to three diffusrs pipeline models and save as a new model.
|
||||
@ -228,6 +250,23 @@ class ModelManagerServiceBase(ABC):
|
||||
:param merged_model_name: Name of destination merged model
|
||||
:param alpha: Alpha strength to apply to 2d and 3d model
|
||||
:param interp: Interpolation method. None (default)
|
||||
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search_for_models(self, directory: Path)->List[Path]:
|
||||
"""
|
||||
Return list of all models found in the designated directory.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def sync_to_config(self):
|
||||
"""
|
||||
Re-read models.yaml, rescan the models directory, and reimport models
|
||||
in the autoimport directories. Call after making changes outside the
|
||||
model manager API.
|
||||
"""
|
||||
pass
|
||||
|
||||
@ -431,16 +470,18 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
"""
|
||||
Delete the named model from configuration. If delete_files is true,
|
||||
then the underlying weight file or diffusers directory will be deleted
|
||||
as well. Call commit() to write to disk.
|
||||
as well.
|
||||
"""
|
||||
self.logger.debug(f'delete model {model_name}')
|
||||
self.mgr.del_model(model_name, base_model, model_type)
|
||||
self.mgr.commit()
|
||||
|
||||
def convert_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: Union[ModelType.Main,ModelType.Vae],
|
||||
convert_dest_directory: Optional[Path] = Field(default=None, description="Optional directory location for merged model"),
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Convert a checkpoint file into a diffusers folder, deleting the cached
|
||||
@ -449,13 +490,14 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
:param model_name: Name of the model to convert
|
||||
:param base_model: Base model type
|
||||
:param model_type: Type of model ['vae' or 'main']
|
||||
:param convert_dest_directory: Save the converted model to the designated directory (`models/etc/etc` by default)
|
||||
|
||||
This will raise a ValueError unless the model is not a checkpoint. It will
|
||||
also raise a ValueError in the event that there is a similarly-named diffusers
|
||||
directory already in place.
|
||||
"""
|
||||
self.logger.debug(f'convert model {model_name}')
|
||||
return self.mgr.convert_model(model_name, base_model, model_type)
|
||||
return self.mgr.convert_model(model_name, base_model, model_type, convert_dest_directory)
|
||||
|
||||
def commit(self, conf_file: Optional[Path]=None):
|
||||
"""
|
||||
@ -536,6 +578,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
alpha: Optional[float] = 0.5,
|
||||
interp: Optional[MergeInterpolationMethod] = None,
|
||||
force: Optional[bool] = False,
|
||||
merge_dest_directory: Optional[Path] = Field(default=None, description="Optional directory location for merged model"),
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Merge two to three diffusrs pipeline models and save as a new model.
|
||||
@ -544,6 +587,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
:param merged_model_name: Name of destination merged model
|
||||
:param alpha: Alpha strength to apply to 2d and 3d model
|
||||
:param interp: Interpolation method. None (default)
|
||||
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
|
||||
"""
|
||||
merger = ModelMerger(self.mgr)
|
||||
try:
|
||||
@ -554,7 +598,55 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
alpha = alpha,
|
||||
interp = interp,
|
||||
force = force,
|
||||
merge_dest_directory=merge_dest_directory,
|
||||
)
|
||||
except AssertionError as e:
|
||||
raise ValueError(e)
|
||||
return result
|
||||
|
||||
def search_for_models(self, directory: Path)->List[Path]:
|
||||
"""
|
||||
Return list of all models found in the designated directory.
|
||||
"""
|
||||
search = FindModels(directory,self.logger)
|
||||
return search.list_models()
|
||||
|
||||
def sync_to_config(self):
|
||||
"""
|
||||
Re-read models.yaml, rescan the models directory, and reimport models
|
||||
in the autoimport directories. Call after making changes outside the
|
||||
model manager API.
|
||||
"""
|
||||
return self.mgr.sync_to_config()
|
||||
|
||||
def list_checkpoint_configs(self)->List[Path]:
|
||||
"""
|
||||
List the checkpoint config paths from ROOT/configs/stable-diffusion.
|
||||
"""
|
||||
config = self.mgr.app_config
|
||||
conf_path = config.legacy_conf_path
|
||||
root_path = config.root_path
|
||||
return [(conf_path / x).relative_to(root_path) for x in conf_path.glob('**/*.yaml')]
|
||||
|
||||
def rename_model(self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
new_name: str = None,
|
||||
new_base: BaseModelType = None,
|
||||
):
|
||||
"""
|
||||
Rename the indicated model. Can provide a new name and/or a new base.
|
||||
:param model_name: Current name of the model
|
||||
:param base_model: Current base of the model
|
||||
:param model_type: Model type (can't be changed)
|
||||
:param new_name: New name for the model
|
||||
:param new_base: New base for the model
|
||||
"""
|
||||
self.mgr.rename_model(base_model = base_model,
|
||||
model_type = model_type,
|
||||
model_name = model_name,
|
||||
new_name = new_name,
|
||||
new_base = new_base,
|
||||
)
|
||||
|
||||
|
@ -593,9 +593,12 @@ script, which will perform a full upgrade in place."""
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
config.parse_args(['--root',str(dest_root)])
|
||||
|
||||
# TODO: revisit
|
||||
# assert (dest_root / 'models').is_dir(), f"{dest_root} does not contain a 'models' subdirectory"
|
||||
# assert (dest_root / 'invokeai.yaml').exists(), f"{dest_root} does not contain an InvokeAI init file."
|
||||
# TODO: revisit - don't rely on invokeai.yaml to exist yet!
|
||||
dest_is_setup = (dest_root / 'models/core').exists() and (dest_root / 'databases').exists()
|
||||
if not dest_is_setup:
|
||||
import invokeai.frontend.install.invokeai_configure
|
||||
from invokeai.backend.install.invokeai_configure import initialize_rootdir
|
||||
initialize_rootdir(dest_root, True)
|
||||
|
||||
do_migrate(src_root,dest_root)
|
||||
|
||||
|
@ -71,8 +71,6 @@ class ModelInstallList:
|
||||
class InstallSelections():
|
||||
install_models: List[str]= field(default_factory=list)
|
||||
remove_models: List[str]=field(default_factory=list)
|
||||
# scan_directory: Path = None
|
||||
# autoscan_on_startup: bool=False
|
||||
|
||||
@dataclass
|
||||
class ModelLoadInfo():
|
||||
|
@ -247,6 +247,7 @@ import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.util import CUDA_DEVICE, Chdir
|
||||
from .model_cache import ModelCache, ModelLocker
|
||||
from .model_search import ModelSearch
|
||||
from .models import (
|
||||
BaseModelType, ModelType, SubModelType,
|
||||
ModelError, SchedulerPredictionType, MODEL_CLASSES,
|
||||
@ -323,15 +324,6 @@ class ModelManager(object):
|
||||
# TODO: metadata not found
|
||||
# TODO: version check
|
||||
|
||||
self.models = dict()
|
||||
for model_key, model_config in config.items():
|
||||
model_name, base_model, model_type = self.parse_key(model_key)
|
||||
model_class = MODEL_CLASSES[base_model][model_type]
|
||||
# alias for config file
|
||||
model_config["model_format"] = model_config.pop("format")
|
||||
self.models[model_key] = model_class.create_config(**model_config)
|
||||
|
||||
# check config version number and update on disk/RAM if necessary
|
||||
self.app_config = InvokeAIAppConfig.get_config()
|
||||
self.logger = logger
|
||||
self.cache = ModelCache(
|
||||
@ -342,11 +334,41 @@ class ModelManager(object):
|
||||
sequential_offload = sequential_offload,
|
||||
logger = logger,
|
||||
)
|
||||
|
||||
self._read_models(config)
|
||||
|
||||
def _read_models(self, config: Optional[DictConfig] = None):
|
||||
if not config:
|
||||
if self.config_path:
|
||||
config = OmegaConf.load(self.config_path)
|
||||
else:
|
||||
return
|
||||
|
||||
self.models = dict()
|
||||
for model_key, model_config in config.items():
|
||||
if model_key.startswith('_'):
|
||||
continue
|
||||
model_name, base_model, model_type = self.parse_key(model_key)
|
||||
model_class = MODEL_CLASSES[base_model][model_type]
|
||||
# alias for config file
|
||||
model_config["model_format"] = model_config.pop("format")
|
||||
self.models[model_key] = model_class.create_config(**model_config)
|
||||
|
||||
# check config version number and update on disk/RAM if necessary
|
||||
self.cache_keys = dict()
|
||||
|
||||
# add controlnet, lora and textual_inversion models from disk
|
||||
self.scan_models_directory()
|
||||
|
||||
def sync_to_config(self):
|
||||
"""
|
||||
Call this when `models.yaml` has been changed externally.
|
||||
This will reinitialize internal data structures
|
||||
"""
|
||||
# Reread models directory; note that this will reinitialize the cache,
|
||||
# causing otherwise unreferenced models to be removed from memory
|
||||
self._read_models()
|
||||
|
||||
def model_exists(
|
||||
self,
|
||||
model_name: str,
|
||||
@ -527,7 +549,10 @@ class ModelManager(object):
|
||||
model_keys = [self.create_key(model_name, base_model, model_type)] if model_name else sorted(self.models, key=str.casefold)
|
||||
models = []
|
||||
for model_key in model_keys:
|
||||
model_config = self.models[model_key]
|
||||
model_config = self.models.get(model_key)
|
||||
if not model_config:
|
||||
self.logger.error(f'Unknown model {model_name}')
|
||||
raise KeyError(f'Unknown model {model_name}')
|
||||
|
||||
cur_model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
|
||||
if base_model is not None and cur_base_model != base_model:
|
||||
@ -646,11 +671,61 @@ class ModelManager(object):
|
||||
config = model_config,
|
||||
)
|
||||
|
||||
def rename_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
new_name: str = None,
|
||||
new_base: BaseModelType = None,
|
||||
):
|
||||
'''
|
||||
Rename or rebase a model.
|
||||
'''
|
||||
if new_name is None and new_base is None:
|
||||
self.logger.error("rename_model() called with neither a new_name nor a new_base. {model_name} unchanged.")
|
||||
return
|
||||
|
||||
model_key = self.create_key(model_name, base_model, model_type)
|
||||
model_cfg = self.models.get(model_key, None)
|
||||
if not model_cfg:
|
||||
raise KeyError(f"Unknown model: {model_key}")
|
||||
|
||||
old_path = self.app_config.root_path / model_cfg.path
|
||||
new_name = new_name or model_name
|
||||
new_base = new_base or base_model
|
||||
new_key = self.create_key(new_name, new_base, model_type)
|
||||
if new_key in self.models:
|
||||
raise ValueError(f'Attempt to overwrite existing model definition "{new_key}"')
|
||||
|
||||
# if this is a model file/directory that we manage ourselves, we need to move it
|
||||
if old_path.is_relative_to(self.app_config.models_path):
|
||||
new_path = self.app_config.root_path / 'models' / new_base.value / model_type.value / new_name
|
||||
move(old_path, new_path)
|
||||
model_cfg.path = str(new_path.relative_to(self.app_config.root_path))
|
||||
|
||||
# clean up caches
|
||||
old_model_cache = self._get_model_cache_path(old_path)
|
||||
if old_model_cache.exists():
|
||||
if old_model_cache.is_dir():
|
||||
rmtree(str(old_model_cache))
|
||||
else:
|
||||
old_model_cache.unlink()
|
||||
|
||||
cache_ids = self.cache_keys.pop(model_key, [])
|
||||
for cache_id in cache_ids:
|
||||
self.cache.uncache_model(cache_id)
|
||||
|
||||
self.models.pop(model_key, None) # delete
|
||||
self.models[new_key] = model_cfg
|
||||
self.commit()
|
||||
|
||||
def convert_model (
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: Union[ModelType.Main,ModelType.Vae],
|
||||
dest_directory: Optional[Path]=None,
|
||||
) -> AddModelResult:
|
||||
'''
|
||||
Convert a checkpoint file into a diffusers folder, deleting the cached
|
||||
@ -677,14 +752,14 @@ class ModelManager(object):
|
||||
)
|
||||
checkpoint_path = self.app_config.root_path / info["path"]
|
||||
old_diffusers_path = self.app_config.models_path / model.location
|
||||
new_diffusers_path = self.app_config.models_path / base_model.value / model_type.value / model_name
|
||||
new_diffusers_path = (dest_directory or self.app_config.models_path / base_model.value / model_type.value) / model_name
|
||||
if new_diffusers_path.exists():
|
||||
raise ValueError(f"A diffusers model already exists at {new_diffusers_path}")
|
||||
|
||||
try:
|
||||
move(old_diffusers_path,new_diffusers_path)
|
||||
info["model_format"] = "diffusers"
|
||||
info["path"] = str(new_diffusers_path.relative_to(self.app_config.root_path))
|
||||
info["path"] = str(new_diffusers_path) if dest_directory else str(new_diffusers_path.relative_to(self.app_config.root_path))
|
||||
info.pop('config')
|
||||
|
||||
result = self.add_model(model_name, base_model, model_type,
|
||||
@ -824,6 +899,7 @@ class ModelManager(object):
|
||||
if (new_models_found or imported_models) and self.config_path:
|
||||
self.commit()
|
||||
|
||||
|
||||
def autoimport(self)->Dict[str, AddModelResult]:
|
||||
'''
|
||||
Scan the autoimport directory (if defined) and import new models, delete defunct models.
|
||||
@ -832,62 +908,41 @@ class ModelManager(object):
|
||||
from invokeai.backend.install.model_install_backend import ModelInstall
|
||||
from invokeai.frontend.install.model_install import ask_user_for_prediction_type
|
||||
|
||||
|
||||
class ScanAndImport(ModelSearch):
|
||||
def __init__(self, directories, logger, ignore: Set[Path], installer: ModelInstall):
|
||||
super().__init__(directories, logger)
|
||||
self.installer = installer
|
||||
self.ignore = ignore
|
||||
|
||||
def on_search_started(self):
|
||||
self.new_models_found = dict()
|
||||
|
||||
def on_model_found(self, model: Path):
|
||||
if model not in self.ignore:
|
||||
self.new_models_found.update(self.installer.heuristic_import(model))
|
||||
|
||||
def on_search_completed(self):
|
||||
self.logger.info(f'Scanned {self._items_scanned} files and directories, imported {len(self.new_models_found)} models')
|
||||
|
||||
def models_found(self):
|
||||
return self.new_models_found
|
||||
|
||||
|
||||
installer = ModelInstall(config = self.app_config,
|
||||
model_manager = self,
|
||||
prediction_type_helper = ask_user_for_prediction_type,
|
||||
)
|
||||
|
||||
scanned_dirs = set()
|
||||
|
||||
config = self.app_config
|
||||
known_paths = {(self.app_config.root_path / x['path']) for x in self.list_models()}
|
||||
|
||||
for autodir in [config.autoimport_dir,
|
||||
config.lora_dir,
|
||||
config.embedding_dir,
|
||||
config.controlnet_dir]:
|
||||
if autodir is None:
|
||||
continue
|
||||
|
||||
self.logger.info(f'Scanning {autodir} for models to import')
|
||||
installed = dict()
|
||||
|
||||
autodir = self.app_config.root_path / autodir
|
||||
if not autodir.exists():
|
||||
continue
|
||||
|
||||
items_scanned = 0
|
||||
new_models_found = dict()
|
||||
|
||||
for root, dirs, files in os.walk(autodir):
|
||||
items_scanned += len(dirs) + len(files)
|
||||
for d in dirs:
|
||||
path = Path(root) / d
|
||||
if path in known_paths or path.parent in scanned_dirs:
|
||||
scanned_dirs.add(path)
|
||||
continue
|
||||
if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}]):
|
||||
try:
|
||||
new_models_found.update(installer.heuristic_import(path))
|
||||
scanned_dirs.add(path)
|
||||
except ValueError as e:
|
||||
self.logger.warning(str(e))
|
||||
|
||||
for f in files:
|
||||
path = Path(root) / f
|
||||
if path in known_paths or path.parent in scanned_dirs:
|
||||
continue
|
||||
if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}:
|
||||
try:
|
||||
import_result = installer.heuristic_import(path)
|
||||
new_models_found.update(import_result)
|
||||
except ValueError as e:
|
||||
self.logger.warning(str(e))
|
||||
|
||||
self.logger.info(f'Scanned {items_scanned} files and directories, imported {len(new_models_found)} models')
|
||||
installed.update(new_models_found)
|
||||
|
||||
return installed
|
||||
known_paths = {config.root_path / x['path'] for x in self.list_models()}
|
||||
directories = {config.root_path / x for x in [config.autoimport_dir,
|
||||
config.lora_dir,
|
||||
config.embedding_dir,
|
||||
config.controlnet_dir]
|
||||
}
|
||||
scanner = ScanAndImport(directories, self.logger, ignore=known_paths, installer=installer)
|
||||
scanner.search()
|
||||
return scanner.models_found()
|
||||
|
||||
def heuristic_import(self,
|
||||
items_to_import: Set[str],
|
||||
@ -925,3 +980,4 @@ class ModelManager(object):
|
||||
successfully_installed.update(installed)
|
||||
self.commit()
|
||||
return successfully_installed
|
||||
|
||||
|
@ -11,7 +11,7 @@ from enum import Enum
|
||||
from pathlib import Path
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers import logging as dlogging
|
||||
from typing import List, Union
|
||||
from typing import List, Union, Optional
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
@ -74,6 +74,7 @@ class ModelMerger(object):
|
||||
alpha: float = 0.5,
|
||||
interp: MergeInterpolationMethod = None,
|
||||
force: bool = False,
|
||||
merge_dest_directory: Optional[Path] = None,
|
||||
**kwargs,
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
@ -85,7 +86,7 @@ class ModelMerger(object):
|
||||
:param interp: The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None.
|
||||
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C).
|
||||
:param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
|
||||
|
||||
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
|
||||
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
|
||||
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
|
||||
"""
|
||||
@ -111,7 +112,7 @@ class ModelMerger(object):
|
||||
merged_pipe = self.merge_diffusion_models(
|
||||
model_paths, alpha, merge_method, force, **kwargs
|
||||
)
|
||||
dump_path = config.models_path / base_model.value / ModelType.Main.value
|
||||
dump_path = Path(merge_dest_directory) if merge_dest_directory else config.models_path / base_model.value / ModelType.Main.value
|
||||
dump_path.mkdir(parents=True, exist_ok=True)
|
||||
dump_path = dump_path / merged_model_name
|
||||
|
||||
|
103
invokeai/backend/model_management/model_search.py
Normal file
103
invokeai/backend/model_management/model_search.py
Normal file
@ -0,0 +1,103 @@
|
||||
# Copyright 2023, Lincoln D. Stein and the InvokeAI Team
|
||||
"""
|
||||
Abstract base class for recursive directory search for models.
|
||||
"""
|
||||
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Set, types
|
||||
from pathlib import Path
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
class ModelSearch(ABC):
|
||||
def __init__(self, directories: List[Path], logger: types.ModuleType=logger):
|
||||
"""
|
||||
Initialize a recursive model directory search.
|
||||
:param directories: List of directory Paths to recurse through
|
||||
:param logger: Logger to use
|
||||
"""
|
||||
self.directories = directories
|
||||
self.logger = logger
|
||||
self._items_scanned = 0
|
||||
self._models_found = 0
|
||||
self._scanned_dirs = set()
|
||||
self._scanned_paths = set()
|
||||
self._pruned_paths = set()
|
||||
|
||||
@abstractmethod
|
||||
def on_search_started(self):
|
||||
"""
|
||||
Called before the scan starts.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def on_model_found(self, model: Path):
|
||||
"""
|
||||
Process a found model. Raise an exception if something goes wrong.
|
||||
:param model: Model to process - could be a directory or checkpoint.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def on_search_completed(self):
|
||||
"""
|
||||
Perform some activity when the scan is completed. May use instance
|
||||
variables, items_scanned and models_found
|
||||
"""
|
||||
pass
|
||||
|
||||
def search(self):
|
||||
self.on_search_started()
|
||||
for dir in self.directories:
|
||||
self.walk_directory(dir)
|
||||
self.on_search_completed()
|
||||
|
||||
def walk_directory(self, path: Path):
|
||||
for root, dirs, files in os.walk(path):
|
||||
if str(Path(root).name).startswith('.'):
|
||||
self._pruned_paths.add(root)
|
||||
if any([Path(root).is_relative_to(x) for x in self._pruned_paths]):
|
||||
continue
|
||||
|
||||
self._items_scanned += len(dirs) + len(files)
|
||||
for d in dirs:
|
||||
path = Path(root) / d
|
||||
if path in self._scanned_paths or path.parent in self._scanned_dirs:
|
||||
self._scanned_dirs.add(path)
|
||||
continue
|
||||
if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}]):
|
||||
try:
|
||||
self.on_model_found(path)
|
||||
self._models_found += 1
|
||||
self._scanned_dirs.add(path)
|
||||
except Exception as e:
|
||||
self.logger.warning(str(e))
|
||||
|
||||
for f in files:
|
||||
path = Path(root) / f
|
||||
if path.parent in self._scanned_dirs:
|
||||
continue
|
||||
if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}:
|
||||
try:
|
||||
self.on_model_found(path)
|
||||
self._models_found += 1
|
||||
except Exception as e:
|
||||
self.logger.warning(str(e))
|
||||
|
||||
class FindModels(ModelSearch):
|
||||
def on_search_started(self):
|
||||
self.models_found: Set[Path] = set()
|
||||
|
||||
def on_model_found(self,model: Path):
|
||||
self.models_found.add(model)
|
||||
|
||||
def on_search_completed(self):
|
||||
pass
|
||||
|
||||
def list_models(self) -> List[Path]:
|
||||
self.search()
|
||||
return self.models_found
|
||||
|
||||
|
@ -48,7 +48,9 @@ for base_model, models in MODEL_CLASSES.items():
|
||||
model_configs.discard(None)
|
||||
MODEL_CONFIGS.extend(model_configs)
|
||||
|
||||
for cfg in model_configs:
|
||||
# LS: sort to get the checkpoint configs first, which makes
|
||||
# for a better template in the Swagger docs
|
||||
for cfg in sorted(model_configs, key=lambda x: str(x)):
|
||||
model_name, cfg_name = cfg.__qualname__.split('.')[-2:]
|
||||
openapi_cfg_name = model_name + cfg_name
|
||||
if openapi_cfg_name in vars():
|
||||
|
@ -59,7 +59,6 @@ class ModelConfigBase(BaseModel):
|
||||
path: str # or Path
|
||||
description: Optional[str] = Field(None)
|
||||
model_format: Optional[str] = Field(None)
|
||||
# do not save to config
|
||||
error: Optional[ModelError] = Field(None)
|
||||
|
||||
class Config:
|
||||
|
@ -1,8 +1,7 @@
|
||||
import os
|
||||
import torch
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union, Literal
|
||||
from typing import Optional
|
||||
from .base import (
|
||||
ModelBase,
|
||||
ModelConfigBase,
|
||||
@ -14,6 +13,7 @@ from .base import (
|
||||
calc_model_size_by_data,
|
||||
classproperty,
|
||||
InvalidModelException,
|
||||
ModelNotFoundException,
|
||||
)
|
||||
|
||||
class ControlNetModelFormat(str, Enum):
|
||||
@ -60,10 +60,20 @@ class ControlNetModel(ModelBase):
|
||||
if child_type is not None:
|
||||
raise Exception("There is no child models in controlnet model")
|
||||
|
||||
model = self.model_class.from_pretrained(
|
||||
self.model_path,
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
model = None
|
||||
for variant in ['fp16',None]:
|
||||
try:
|
||||
model = self.model_class.from_pretrained(
|
||||
self.model_path,
|
||||
torch_dtype=torch_dtype,
|
||||
variant=variant,
|
||||
)
|
||||
break
|
||||
except:
|
||||
pass
|
||||
if not model:
|
||||
raise ModelNotFoundException()
|
||||
|
||||
# calc more accurate size
|
||||
self.model_size = calc_model_size_by_data(model)
|
||||
return model
|
||||
|
@ -38,7 +38,6 @@ class StableDiffusion1Model(DiffusersModel):
|
||||
config: str
|
||||
variant: ModelVariantType
|
||||
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert base_model == BaseModelType.StableDiffusion1
|
||||
assert model_type == ModelType.Main
|
||||
|
169
invokeai/frontend/web/dist/assets/App-3986879c.js
vendored
Normal file
169
invokeai/frontend/web/dist/assets/App-3986879c.js
vendored
Normal file
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
1
invokeai/frontend/web/dist/assets/App-6125620a.css
vendored
Normal file
1
invokeai/frontend/web/dist/assets/App-6125620a.css
vendored
Normal file
File diff suppressed because one or more lines are too long
199
invokeai/frontend/web/dist/assets/App-a44d46fe.js
vendored
199
invokeai/frontend/web/dist/assets/App-a44d46fe.js
vendored
File diff suppressed because one or more lines are too long
169
invokeai/frontend/web/dist/assets/App-c8b96e06.js
vendored
Normal file
169
invokeai/frontend/web/dist/assets/App-c8b96e06.js
vendored
Normal file
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
1
invokeai/frontend/web/dist/assets/MantineProvider-e5b33be1.js
vendored
Normal file
1
invokeai/frontend/web/dist/assets/MantineProvider-e5b33be1.js
vendored
Normal file
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
302
invokeai/frontend/web/dist/assets/ThemeLocaleProvider-fa40c0d9.js
vendored
Normal file
302
invokeai/frontend/web/dist/assets/ThemeLocaleProvider-fa40c0d9.js
vendored
Normal file
File diff suppressed because one or more lines are too long
125
invokeai/frontend/web/dist/assets/index-078526aa.js
vendored
125
invokeai/frontend/web/dist/assets/index-078526aa.js
vendored
File diff suppressed because one or more lines are too long
125
invokeai/frontend/web/dist/assets/index-8888b06f.js
vendored
Normal file
125
invokeai/frontend/web/dist/assets/index-8888b06f.js
vendored
Normal file
File diff suppressed because one or more lines are too long
125
invokeai/frontend/web/dist/assets/index-f1a5f9cf.js
vendored
Normal file
125
invokeai/frontend/web/dist/assets/index-f1a5f9cf.js
vendored
Normal file
File diff suppressed because one or more lines are too long
2
invokeai/frontend/web/dist/index.html
vendored
2
invokeai/frontend/web/dist/index.html
vendored
@ -12,7 +12,7 @@
|
||||
margin: 0;
|
||||
}
|
||||
</style>
|
||||
<script type="module" crossorigin src="./assets/index-078526aa.js"></script>
|
||||
<script type="module" crossorigin src="./assets/index-8888b06f.js"></script>
|
||||
</head>
|
||||
|
||||
<body dir="ltr">
|
||||
|
31
invokeai/frontend/web/dist/locales/en.json
vendored
31
invokeai/frontend/web/dist/locales/en.json
vendored
@ -102,7 +102,8 @@
|
||||
"openInNewTab": "Open in New Tab",
|
||||
"dontAskMeAgain": "Don't ask me again",
|
||||
"areYouSure": "Are you sure?",
|
||||
"imagePrompt": "Image Prompt"
|
||||
"imagePrompt": "Image Prompt",
|
||||
"clearNodes": "Are you sure you want to clear all nodes?"
|
||||
},
|
||||
"gallery": {
|
||||
"generations": "Generations",
|
||||
@ -118,7 +119,7 @@
|
||||
"pinGallery": "Pin Gallery",
|
||||
"allImagesLoaded": "All Images Loaded",
|
||||
"loadMore": "Load More",
|
||||
"noImagesInGallery": "No Images In Gallery",
|
||||
"noImagesInGallery": "No Images to Display",
|
||||
"deleteImage": "Delete Image",
|
||||
"deleteImageBin": "Deleted images will be sent to your operating system's Bin.",
|
||||
"deleteImagePermanent": "Deleted images cannot be restored.",
|
||||
@ -342,6 +343,7 @@
|
||||
"safetensorModels": "SafeTensors",
|
||||
"modelAdded": "Model Added",
|
||||
"modelUpdated": "Model Updated",
|
||||
"modelUpdateFailed": "Model Update Failed",
|
||||
"modelEntryDeleted": "Model Entry Deleted",
|
||||
"cannotUseSpaces": "Cannot Use Spaces",
|
||||
"addNew": "Add New",
|
||||
@ -396,8 +398,8 @@
|
||||
"delete": "Delete",
|
||||
"deleteModel": "Delete Model",
|
||||
"deleteConfig": "Delete Config",
|
||||
"deleteMsg1": "Are you sure you want to delete this model entry from InvokeAI?",
|
||||
"deleteMsg2": "This will not delete the model checkpoint file from your disk. You can readd them if you wish to.",
|
||||
"deleteMsg1": "Are you sure you want to delete this model from InvokeAI?",
|
||||
"deleteMsg2": "This WILL delete the model from disk if it is in the InvokeAI root folder. If you are using a custom location, then the model WILL NOT be deleted from disk.",
|
||||
"formMessageDiffusersModelLocation": "Diffusers Model Location",
|
||||
"formMessageDiffusersModelLocationDesc": "Please enter at least one.",
|
||||
"formMessageDiffusersVAELocation": "VAE Location",
|
||||
@ -408,7 +410,7 @@
|
||||
"convertToDiffusersHelpText2": "This process will replace your Model Manager entry with the Diffusers version of the same model.",
|
||||
"convertToDiffusersHelpText3": "Your checkpoint file on the disk will NOT be deleted or modified in anyway. You can add your checkpoint to the Model Manager again if you want to.",
|
||||
"convertToDiffusersHelpText4": "This is a one time process only. It might take around 30s-60s depending on the specifications of your computer.",
|
||||
"convertToDiffusersHelpText5": "Please make sure you have enough disk space. Models generally vary between 4GB-7GB in size.",
|
||||
"convertToDiffusersHelpText5": "Please make sure you have enough disk space. Models generally vary between 2GB-7GB in size.",
|
||||
"convertToDiffusersHelpText6": "Do you wish to convert this model?",
|
||||
"convertToDiffusersSaveLocation": "Save Location",
|
||||
"v1": "v1",
|
||||
@ -419,12 +421,14 @@
|
||||
"pathToCustomConfig": "Path To Custom Config",
|
||||
"statusConverting": "Converting",
|
||||
"modelConverted": "Model Converted",
|
||||
"modelConversionFailed": "Model Conversion Failed",
|
||||
"sameFolder": "Same folder",
|
||||
"invokeRoot": "InvokeAI folder",
|
||||
"custom": "Custom",
|
||||
"customSaveLocation": "Custom Save Location",
|
||||
"merge": "Merge",
|
||||
"modelsMerged": "Models Merged",
|
||||
"modelsMergeFailed": "Model Merge Failed",
|
||||
"mergeModels": "Merge Models",
|
||||
"modelOne": "Model 1",
|
||||
"modelTwo": "Model 2",
|
||||
@ -445,7 +449,8 @@
|
||||
"weightedSum": "Weighted Sum",
|
||||
"none": "none",
|
||||
"addDifference": "Add Difference",
|
||||
"pickModelType": "Pick Model Type"
|
||||
"pickModelType": "Pick Model Type",
|
||||
"selectModel": "Select Model"
|
||||
},
|
||||
"parameters": {
|
||||
"general": "General",
|
||||
@ -528,7 +533,7 @@
|
||||
"hidePreview": "Hide Preview",
|
||||
"showPreview": "Show Preview",
|
||||
"controlNetControlMode": "Control Mode",
|
||||
"clipSkip": "Clip Skip",
|
||||
"clipSkip": "CLIP Skip",
|
||||
"aspectRatio": "Ratio"
|
||||
},
|
||||
"settings": {
|
||||
@ -593,7 +598,11 @@
|
||||
"metadataLoadFailed": "Failed to load metadata",
|
||||
"initialImageSet": "Initial Image Set",
|
||||
"initialImageNotSet": "Initial Image Not Set",
|
||||
"initialImageNotSetDesc": "Could not load initial image"
|
||||
"initialImageNotSetDesc": "Could not load initial image",
|
||||
"nodesSaved": "Nodes Saved",
|
||||
"nodesLoaded": "Nodes Loaded",
|
||||
"nodesLoadedFailed": "Failed To Load Nodes",
|
||||
"nodesCleared": "Nodes Cleared"
|
||||
},
|
||||
"tooltip": {
|
||||
"feature": {
|
||||
@ -674,5 +683,11 @@
|
||||
"showProgressImages": "Show Progress Images",
|
||||
"hideProgressImages": "Hide Progress Images",
|
||||
"swapSizes": "Swap Sizes"
|
||||
},
|
||||
"nodes": {
|
||||
"reloadSchema": "Reload Schema",
|
||||
"saveNodes": "Save Nodes",
|
||||
"loadNodes": "Load Nodes",
|
||||
"clearNodes": "Clear Nodes"
|
||||
}
|
||||
}
|
||||
|
@ -13,7 +13,11 @@ import { RootState } from 'app/store/store';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'controlNet' });
|
||||
|
||||
const predicate: AnyListenerPredicate<RootState> = (action, state) => {
|
||||
const predicate: AnyListenerPredicate<RootState> = (
|
||||
action,
|
||||
state,
|
||||
prevState
|
||||
) => {
|
||||
const isActionMatched =
|
||||
controlNetProcessorParamsChanged.match(action) ||
|
||||
controlNetModelChanged.match(action) ||
|
||||
@ -25,6 +29,16 @@ const predicate: AnyListenerPredicate<RootState> = (action, state) => {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (controlNetAutoConfigToggled.match(action)) {
|
||||
// do not process if the user just disabled auto-config
|
||||
if (
|
||||
prevState.controlNet.controlNets[action.payload.controlNetId]
|
||||
.shouldAutoConfig === true
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
const { controlImage, processorType, shouldAutoConfig } =
|
||||
state.controlNet.controlNets[action.payload.controlNetId];
|
||||
|
||||
|
@ -10,6 +10,7 @@ import { zMainModel } from 'features/parameters/types/parameterSchemas';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { startAppListening } from '..';
|
||||
import { controlNetRemoved } from 'features/controlNet/store/controlNetSlice';
|
||||
|
||||
const moduleLog = log.child({ module: 'models' });
|
||||
|
||||
@ -51,7 +52,14 @@ export const addModelSelectedListener = () => {
|
||||
modelsCleared += 1;
|
||||
}
|
||||
|
||||
// TODO: handle incompatible controlnet; pending model manager support
|
||||
const { controlNets } = state.controlNet;
|
||||
forEach(controlNets, (controlNet, controlNetId) => {
|
||||
if (controlNet.model?.base_model !== base_model) {
|
||||
dispatch(controlNetRemoved({ controlNetId }));
|
||||
modelsCleared += 1;
|
||||
}
|
||||
});
|
||||
|
||||
if (modelsCleared > 0) {
|
||||
dispatch(
|
||||
addToast(
|
||||
|
@ -11,6 +11,7 @@ import {
|
||||
import { forEach, some } from 'lodash-es';
|
||||
import { modelsApi } from 'services/api/endpoints/models';
|
||||
import { startAppListening } from '..';
|
||||
import { controlNetRemoved } from 'features/controlNet/store/controlNetSlice';
|
||||
|
||||
const moduleLog = log.child({ module: 'models' });
|
||||
|
||||
@ -127,7 +128,22 @@ export const addModelsLoadedListener = () => {
|
||||
matcher: modelsApi.endpoints.getControlNetModels.matchFulfilled,
|
||||
effect: async (action, { getState, dispatch }) => {
|
||||
// ControlNet models loaded - need to remove missing ControlNets from state
|
||||
// TODO: pending model manager controlnet support
|
||||
const controlNets = getState().controlNet.controlNets;
|
||||
|
||||
forEach(controlNets, (controlNet, controlNetId) => {
|
||||
const isControlNetAvailable = some(
|
||||
action.payload.entities,
|
||||
(m) =>
|
||||
m?.model_name === controlNet?.model?.model_name &&
|
||||
m?.base_model === controlNet?.model?.base_model
|
||||
);
|
||||
|
||||
if (isControlNetAvailable) {
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(controlNetRemoved({ controlNetId }));
|
||||
});
|
||||
},
|
||||
});
|
||||
};
|
||||
|
@ -1,5 +1,5 @@
|
||||
import {
|
||||
CONTROLNET_MODELS,
|
||||
// CONTROLNET_MODELS,
|
||||
CONTROLNET_PROCESSORS,
|
||||
} from 'features/controlNet/store/constants';
|
||||
import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||
@ -128,7 +128,7 @@ export type AppConfig = {
|
||||
canRestoreDeletedImagesFromBin: boolean;
|
||||
sd: {
|
||||
defaultModel?: string;
|
||||
disabledControlNetModels: (keyof typeof CONTROLNET_MODELS)[];
|
||||
disabledControlNetModels: string[];
|
||||
disabledControlNetProcessors: (keyof typeof CONTROLNET_PROCESSORS)[];
|
||||
iterations: {
|
||||
initial: number;
|
||||
|
@ -170,12 +170,14 @@ const IAIDndImage = (props: IAIDndImageProps) => {
|
||||
</>
|
||||
)}
|
||||
{!imageDTO && isUploadDisabled && noContentFallback}
|
||||
<IAIDroppable
|
||||
data={droppableData}
|
||||
disabled={isDropDisabled}
|
||||
dropLabel={dropLabel}
|
||||
/>
|
||||
{imageDTO && (
|
||||
{!isDropDisabled && (
|
||||
<IAIDroppable
|
||||
data={droppableData}
|
||||
disabled={isDropDisabled}
|
||||
dropLabel={dropLabel}
|
||||
/>
|
||||
)}
|
||||
{imageDTO && !isDragDisabled && (
|
||||
<IAIDraggable
|
||||
data={draggableData}
|
||||
disabled={isDragDisabled || !imageDTO}
|
||||
|
@ -1,17 +1,25 @@
|
||||
import { Tooltip } from '@chakra-ui/react';
|
||||
import { FormControl, FormLabel, Tooltip } from '@chakra-ui/react';
|
||||
import { MultiSelect, MultiSelectProps } from '@mantine/core';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
|
||||
import { useMantineMultiSelectStyles } from 'mantine-theme/hooks/useMantineMultiSelectStyles';
|
||||
import { KeyboardEvent, RefObject, memo, useCallback } from 'react';
|
||||
|
||||
type IAIMultiSelectProps = MultiSelectProps & {
|
||||
type IAIMultiSelectProps = Omit<MultiSelectProps, 'label'> & {
|
||||
tooltip?: string;
|
||||
inputRef?: RefObject<HTMLInputElement>;
|
||||
label?: string;
|
||||
};
|
||||
|
||||
const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
|
||||
const { searchable = true, tooltip, inputRef, ...rest } = props;
|
||||
const {
|
||||
searchable = true,
|
||||
tooltip,
|
||||
inputRef,
|
||||
label,
|
||||
disabled,
|
||||
...rest
|
||||
} = props;
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const handleKeyDown = useCallback(
|
||||
@ -37,7 +45,15 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
|
||||
return (
|
||||
<Tooltip label={tooltip} placement="top" hasArrow isOpen={true}>
|
||||
<MultiSelect
|
||||
label={
|
||||
label ? (
|
||||
<FormControl isDisabled={disabled}>
|
||||
<FormLabel>{label}</FormLabel>
|
||||
</FormControl>
|
||||
) : undefined
|
||||
}
|
||||
ref={inputRef}
|
||||
disabled={disabled}
|
||||
onKeyDown={handleKeyDown}
|
||||
onKeyUp={handleKeyUp}
|
||||
searchable={searchable}
|
||||
|
@ -1,4 +1,4 @@
|
||||
import { Tooltip } from '@chakra-ui/react';
|
||||
import { FormControl, FormLabel, Tooltip } from '@chakra-ui/react';
|
||||
import { Select, SelectProps } from '@mantine/core';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
|
||||
@ -11,13 +11,22 @@ export type IAISelectDataType = {
|
||||
tooltip?: string;
|
||||
};
|
||||
|
||||
type IAISelectProps = SelectProps & {
|
||||
type IAISelectProps = Omit<SelectProps, 'label'> & {
|
||||
tooltip?: string;
|
||||
label?: string;
|
||||
inputRef?: RefObject<HTMLInputElement>;
|
||||
};
|
||||
|
||||
const IAIMantineSearchableSelect = (props: IAISelectProps) => {
|
||||
const { searchable = true, tooltip, inputRef, onChange, ...rest } = props;
|
||||
const {
|
||||
searchable = true,
|
||||
tooltip,
|
||||
inputRef,
|
||||
onChange,
|
||||
label,
|
||||
disabled,
|
||||
...rest
|
||||
} = props;
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const [searchValue, setSearchValue] = useState('');
|
||||
@ -61,6 +70,14 @@ const IAIMantineSearchableSelect = (props: IAISelectProps) => {
|
||||
<Tooltip label={tooltip} placement="top" hasArrow>
|
||||
<Select
|
||||
ref={inputRef}
|
||||
label={
|
||||
label ? (
|
||||
<FormControl isDisabled={disabled}>
|
||||
<FormLabel>{label}</FormLabel>
|
||||
</FormControl>
|
||||
) : undefined
|
||||
}
|
||||
disabled={disabled}
|
||||
searchValue={searchValue}
|
||||
onSearchChange={setSearchValue}
|
||||
onChange={handleChange}
|
||||
|
@ -1,4 +1,4 @@
|
||||
import { Tooltip } from '@chakra-ui/react';
|
||||
import { FormControl, FormLabel, Tooltip } from '@chakra-ui/react';
|
||||
import { Select, SelectProps } from '@mantine/core';
|
||||
import { useMantineSelectStyles } from 'mantine-theme/hooks/useMantineSelectStyles';
|
||||
import { RefObject, memo } from 'react';
|
||||
@ -9,19 +9,32 @@ export type IAISelectDataType = {
|
||||
tooltip?: string;
|
||||
};
|
||||
|
||||
type IAISelectProps = SelectProps & {
|
||||
type IAISelectProps = Omit<SelectProps, 'label'> & {
|
||||
tooltip?: string;
|
||||
inputRef?: RefObject<HTMLInputElement>;
|
||||
label?: string;
|
||||
};
|
||||
|
||||
const IAIMantineSelect = (props: IAISelectProps) => {
|
||||
const { tooltip, inputRef, ...rest } = props;
|
||||
const { tooltip, inputRef, label, disabled, ...rest } = props;
|
||||
|
||||
const styles = useMantineSelectStyles();
|
||||
|
||||
return (
|
||||
<Tooltip label={tooltip} placement="top" hasArrow>
|
||||
<Select ref={inputRef} styles={styles} {...rest} />
|
||||
<Select
|
||||
label={
|
||||
label ? (
|
||||
<FormControl isDisabled={disabled}>
|
||||
<FormLabel>{label}</FormLabel>
|
||||
</FormControl>
|
||||
) : undefined
|
||||
}
|
||||
disabled={disabled}
|
||||
ref={inputRef}
|
||||
styles={styles}
|
||||
{...rest}
|
||||
/>
|
||||
</Tooltip>
|
||||
);
|
||||
};
|
||||
|
@ -43,11 +43,6 @@ import { useTranslation } from 'react-i18next';
|
||||
import { BiReset } from 'react-icons/bi';
|
||||
import IAIIconButton, { IAIIconButtonProps } from './IAIIconButton';
|
||||
|
||||
const SLIDER_MARK_STYLES: ChakraProps['sx'] = {
|
||||
mt: 1.5,
|
||||
fontSize: '2xs',
|
||||
};
|
||||
|
||||
export type IAIFullSliderProps = {
|
||||
label?: string;
|
||||
value: number;
|
||||
@ -207,7 +202,7 @@ const IAISlider = (props: IAIFullSliderProps) => {
|
||||
{...sliderFormControlProps}
|
||||
>
|
||||
{label && (
|
||||
<FormLabel {...sliderFormLabelProps} mb={-1}>
|
||||
<FormLabel sx={withInput ? { mb: -1.5 } : {}} {...sliderFormLabelProps}>
|
||||
{label}
|
||||
</FormLabel>
|
||||
)}
|
||||
@ -233,7 +228,6 @@ const IAISlider = (props: IAIFullSliderProps) => {
|
||||
sx={{
|
||||
insetInlineStart: '0 !important',
|
||||
insetInlineEnd: 'unset !important',
|
||||
...SLIDER_MARK_STYLES,
|
||||
}}
|
||||
{...sliderMarkProps}
|
||||
>
|
||||
@ -244,7 +238,6 @@ const IAISlider = (props: IAIFullSliderProps) => {
|
||||
sx={{
|
||||
insetInlineStart: 'unset !important',
|
||||
insetInlineEnd: '0 !important',
|
||||
...SLIDER_MARK_STYLES,
|
||||
}}
|
||||
{...sliderMarkProps}
|
||||
>
|
||||
@ -263,7 +256,6 @@ const IAISlider = (props: IAIFullSliderProps) => {
|
||||
sx={{
|
||||
insetInlineStart: '0 !important',
|
||||
insetInlineEnd: 'unset !important',
|
||||
...SLIDER_MARK_STYLES,
|
||||
}}
|
||||
{...sliderMarkProps}
|
||||
>
|
||||
@ -278,7 +270,6 @@ const IAISlider = (props: IAIFullSliderProps) => {
|
||||
sx={{
|
||||
insetInlineStart: 'unset !important',
|
||||
insetInlineEnd: '0 !important',
|
||||
...SLIDER_MARK_STYLES,
|
||||
}}
|
||||
{...sliderMarkProps}
|
||||
>
|
||||
@ -291,7 +282,7 @@ const IAISlider = (props: IAIFullSliderProps) => {
|
||||
key={m}
|
||||
value={m}
|
||||
sx={{
|
||||
...SLIDER_MARK_STYLES,
|
||||
transform: 'translateX(-50%)',
|
||||
}}
|
||||
{...sliderMarkProps}
|
||||
>
|
||||
|
@ -5,6 +5,7 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { validateSeedWeights } from 'common/util/seedWeightPairs';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import { modelsApi } from '../../services/api/endpoints/models';
|
||||
import { forEach } from 'lodash-es';
|
||||
|
||||
const readinessSelector = createSelector(
|
||||
[stateSelector, activeTabNameSelector],
|
||||
@ -52,6 +53,13 @@ const readinessSelector = createSelector(
|
||||
reasonsWhyNotReady.push('Seed-Weights badly formatted.');
|
||||
}
|
||||
|
||||
forEach(state.controlNet.controlNets, (controlNet, id) => {
|
||||
if (!controlNet.model) {
|
||||
isReady = false;
|
||||
reasonsWhyNotReady.push('ControlNet ${id} has no model selected.');
|
||||
}
|
||||
});
|
||||
|
||||
// All good
|
||||
return { isReady, reasonsWhyNotReady };
|
||||
},
|
||||
|
@ -1,10 +1,9 @@
|
||||
import { Box, ChakraProps, Flex, useColorMode } from '@chakra-ui/react';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { Box, Flex } from '@chakra-ui/react';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { FaCopy, FaTrash } from 'react-icons/fa';
|
||||
import {
|
||||
ControlNetConfig,
|
||||
controlNetAdded,
|
||||
controlNetDuplicated,
|
||||
controlNetRemoved,
|
||||
controlNetToggled,
|
||||
} from '../store/controlNetSlice';
|
||||
@ -12,6 +11,9 @@ import ParamControlNetModel from './parameters/ParamControlNetModel';
|
||||
import ParamControlNetWeight from './parameters/ParamControlNetWeight';
|
||||
|
||||
import { ChevronUpIcon } from '@chakra-ui/icons';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import IAISwitch from 'common/components/IAISwitch';
|
||||
import { useToggle } from 'react-use';
|
||||
@ -22,41 +24,41 @@ import ParamControlNetShouldAutoConfig from './ParamControlNetShouldAutoConfig';
|
||||
import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd';
|
||||
import ParamControlNetControlMode from './parameters/ParamControlNetControlMode';
|
||||
import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect';
|
||||
import { mode } from 'theme/util/mode';
|
||||
|
||||
const expandedControlImageSx: ChakraProps['sx'] = { maxH: 96 };
|
||||
|
||||
type ControlNetProps = {
|
||||
controlNet: ControlNetConfig;
|
||||
controlNetId: string;
|
||||
};
|
||||
|
||||
const ControlNet = (props: ControlNetProps) => {
|
||||
const {
|
||||
controlNetId,
|
||||
isEnabled,
|
||||
model,
|
||||
weight,
|
||||
beginStepPct,
|
||||
endStepPct,
|
||||
controlMode,
|
||||
controlImage,
|
||||
processedControlImage,
|
||||
processorNode,
|
||||
processorType,
|
||||
shouldAutoConfig,
|
||||
} = props.controlNet;
|
||||
const { controlNetId } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const selector = createSelector(
|
||||
stateSelector,
|
||||
({ controlNet }) => {
|
||||
const { isEnabled, shouldAutoConfig } =
|
||||
controlNet.controlNets[controlNetId];
|
||||
|
||||
return { isEnabled, shouldAutoConfig };
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const { isEnabled, shouldAutoConfig } = useAppSelector(selector);
|
||||
const [isExpanded, toggleIsExpanded] = useToggle(false);
|
||||
const { colorMode } = useColorMode();
|
||||
|
||||
const handleDelete = useCallback(() => {
|
||||
dispatch(controlNetRemoved({ controlNetId }));
|
||||
}, [controlNetId, dispatch]);
|
||||
|
||||
const handleDuplicate = useCallback(() => {
|
||||
dispatch(
|
||||
controlNetAdded({ controlNetId: uuidv4(), controlNet: props.controlNet })
|
||||
controlNetDuplicated({
|
||||
sourceControlNetId: controlNetId,
|
||||
newControlNetId: uuidv4(),
|
||||
})
|
||||
);
|
||||
}, [dispatch, props.controlNet]);
|
||||
}, [controlNetId, dispatch]);
|
||||
|
||||
const handleToggleIsEnabled = useCallback(() => {
|
||||
dispatch(controlNetToggled({ controlNetId }));
|
||||
@ -68,15 +70,18 @@ const ControlNet = (props: ControlNetProps) => {
|
||||
flexDir: 'column',
|
||||
gap: 2,
|
||||
p: 3,
|
||||
bg: mode('base.200', 'base.850')(colorMode),
|
||||
borderRadius: 'base',
|
||||
position: 'relative',
|
||||
bg: 'base.200',
|
||||
_dark: {
|
||||
bg: 'base.850',
|
||||
},
|
||||
}}
|
||||
>
|
||||
<Flex sx={{ gap: 2 }}>
|
||||
<Flex sx={{ gap: 2, alignItems: 'center' }}>
|
||||
<IAISwitch
|
||||
tooltip="Toggle"
|
||||
aria-label="Toggle"
|
||||
tooltip={'Toggle this ControlNet'}
|
||||
aria-label={'Toggle this ControlNet'}
|
||||
isChecked={isEnabled}
|
||||
onChange={handleToggleIsEnabled}
|
||||
/>
|
||||
@ -90,7 +95,7 @@ const ControlNet = (props: ControlNetProps) => {
|
||||
transitionDuration: '0.1s',
|
||||
}}
|
||||
>
|
||||
<ParamControlNetModel controlNetId={controlNetId} model={model} />
|
||||
<ParamControlNetModel controlNetId={controlNetId} />
|
||||
</Box>
|
||||
<IAIIconButton
|
||||
size="sm"
|
||||
@ -109,21 +114,26 @@ const ControlNet = (props: ControlNetProps) => {
|
||||
/>
|
||||
<IAIIconButton
|
||||
size="sm"
|
||||
aria-label="Show All Options"
|
||||
tooltip={isExpanded ? 'Hide Advanced' : 'Show Advanced'}
|
||||
aria-label={isExpanded ? 'Hide Advanced' : 'Show Advanced'}
|
||||
onClick={toggleIsExpanded}
|
||||
variant="link"
|
||||
icon={
|
||||
<ChevronUpIcon
|
||||
sx={{
|
||||
boxSize: 4,
|
||||
color: mode('base.700', 'base.300')(colorMode),
|
||||
color: 'base.700',
|
||||
transform: isExpanded ? 'rotate(0deg)' : 'rotate(180deg)',
|
||||
transitionProperty: 'common',
|
||||
transitionDuration: 'normal',
|
||||
_dark: {
|
||||
color: 'base.300',
|
||||
},
|
||||
}}
|
||||
/>
|
||||
}
|
||||
/>
|
||||
|
||||
{!shouldAutoConfig && (
|
||||
<Box
|
||||
sx={{
|
||||
@ -131,85 +141,59 @@ const ControlNet = (props: ControlNetProps) => {
|
||||
w: 1.5,
|
||||
h: 1.5,
|
||||
borderRadius: 'full',
|
||||
bg: mode('error.700', 'error.200')(colorMode),
|
||||
top: 4,
|
||||
insetInlineEnd: 4,
|
||||
bg: 'accent.700',
|
||||
_dark: {
|
||||
bg: 'accent.400',
|
||||
},
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
</Flex>
|
||||
{isEnabled && (
|
||||
<>
|
||||
<Flex sx={{ w: 'full', flexDirection: 'column' }}>
|
||||
<Flex sx={{ gap: 4, w: 'full' }}>
|
||||
<Flex
|
||||
sx={{
|
||||
flexDir: 'column',
|
||||
gap: 3,
|
||||
w: 'full',
|
||||
paddingInlineStart: 1,
|
||||
paddingInlineEnd: isExpanded ? 1 : 0,
|
||||
pb: 2,
|
||||
justifyContent: 'space-between',
|
||||
}}
|
||||
>
|
||||
<ParamControlNetWeight
|
||||
controlNetId={controlNetId}
|
||||
weight={weight}
|
||||
mini={!isExpanded}
|
||||
/>
|
||||
<ParamControlNetBeginEnd
|
||||
controlNetId={controlNetId}
|
||||
beginStepPct={beginStepPct}
|
||||
endStepPct={endStepPct}
|
||||
mini={!isExpanded}
|
||||
/>
|
||||
</Flex>
|
||||
{!isExpanded && (
|
||||
<Flex
|
||||
sx={{
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
h: 24,
|
||||
w: 24,
|
||||
aspectRatio: '1/1',
|
||||
}}
|
||||
>
|
||||
<ControlNetImagePreview
|
||||
controlNet={props.controlNet}
|
||||
height={24}
|
||||
/>
|
||||
</Flex>
|
||||
)}
|
||||
</Flex>
|
||||
<ParamControlNetControlMode
|
||||
controlNetId={controlNetId}
|
||||
controlMode={controlMode}
|
||||
/>
|
||||
<Flex sx={{ w: 'full', flexDirection: 'column' }}>
|
||||
<Flex sx={{ gap: 4, w: 'full', alignItems: 'center' }}>
|
||||
<Flex
|
||||
sx={{
|
||||
flexDir: 'column',
|
||||
gap: 3,
|
||||
h: 28,
|
||||
w: 'full',
|
||||
paddingInlineStart: 1,
|
||||
paddingInlineEnd: isExpanded ? 1 : 0,
|
||||
pb: 2,
|
||||
justifyContent: 'space-between',
|
||||
}}
|
||||
>
|
||||
<ParamControlNetWeight controlNetId={controlNetId} />
|
||||
<ParamControlNetBeginEnd controlNetId={controlNetId} />
|
||||
</Flex>
|
||||
|
||||
{isExpanded && (
|
||||
<>
|
||||
<Box mt={2}>
|
||||
<ControlNetImagePreview
|
||||
controlNet={props.controlNet}
|
||||
height={96}
|
||||
/>
|
||||
</Box>
|
||||
<ParamControlNetProcessorSelect
|
||||
controlNetId={controlNetId}
|
||||
processorNode={processorNode}
|
||||
/>
|
||||
<ControlNetProcessorComponent
|
||||
controlNetId={controlNetId}
|
||||
processorNode={processorNode}
|
||||
/>
|
||||
<ParamControlNetShouldAutoConfig
|
||||
controlNetId={controlNetId}
|
||||
shouldAutoConfig={shouldAutoConfig}
|
||||
/>
|
||||
</>
|
||||
{!isExpanded && (
|
||||
<Flex
|
||||
sx={{
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
h: 28,
|
||||
w: 28,
|
||||
aspectRatio: '1/1',
|
||||
mt: 3,
|
||||
}}
|
||||
>
|
||||
<ControlNetImagePreview controlNetId={controlNetId} height={28} />
|
||||
</Flex>
|
||||
)}
|
||||
</Flex>
|
||||
<Box mt={2}>
|
||||
<ParamControlNetControlMode controlNetId={controlNetId} />
|
||||
</Box>
|
||||
<ParamControlNetProcessorSelect controlNetId={controlNetId} />
|
||||
</Flex>
|
||||
|
||||
{isExpanded && (
|
||||
<>
|
||||
<ControlNetImagePreview controlNetId={controlNetId} height="392px" />
|
||||
<ParamControlNetShouldAutoConfig controlNetId={controlNetId} />
|
||||
<ControlNetProcessorComponent controlNetId={controlNetId} />
|
||||
</>
|
||||
)}
|
||||
</Flex>
|
||||
|
@ -5,42 +5,57 @@ import {
|
||||
TypesafeDraggableData,
|
||||
TypesafeDroppableData,
|
||||
} from 'app/components/ImageDnd/typesafeDnd';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIDndImage from 'common/components/IAIDndImage';
|
||||
import { memo, useCallback, useMemo, useState } from 'react';
|
||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||
import { PostUploadAction } from 'services/api/thunks/image';
|
||||
import {
|
||||
ControlNetConfig,
|
||||
controlNetImageChanged,
|
||||
controlNetSelector,
|
||||
} from '../store/controlNetSlice';
|
||||
|
||||
const selector = createSelector(
|
||||
controlNetSelector,
|
||||
(controlNet) => {
|
||||
const { pendingControlImages } = controlNet;
|
||||
return { pendingControlImages };
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
import { controlNetImageChanged } from '../store/controlNetSlice';
|
||||
|
||||
type Props = {
|
||||
controlNet: ControlNetConfig;
|
||||
controlNetId: string;
|
||||
height: SystemStyleObject['h'];
|
||||
};
|
||||
|
||||
const ControlNetImagePreview = (props: Props) => {
|
||||
const { height } = props;
|
||||
const {
|
||||
controlNetId,
|
||||
controlImage: controlImageName,
|
||||
processedControlImage: processedControlImageName,
|
||||
processorType,
|
||||
} = props.controlNet;
|
||||
const { height, controlNetId } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const { pendingControlImages } = useAppSelector(selector);
|
||||
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
stateSelector,
|
||||
({ controlNet }) => {
|
||||
const { pendingControlImages } = controlNet;
|
||||
const {
|
||||
controlImage,
|
||||
processedControlImage,
|
||||
processorType,
|
||||
isEnabled,
|
||||
} = controlNet.controlNets[controlNetId];
|
||||
|
||||
return {
|
||||
controlImageName: controlImage,
|
||||
processedControlImageName: processedControlImage,
|
||||
processorType,
|
||||
isEnabled,
|
||||
pendingControlImages,
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
),
|
||||
[controlNetId]
|
||||
);
|
||||
|
||||
const {
|
||||
controlImageName,
|
||||
processedControlImageName,
|
||||
processorType,
|
||||
pendingControlImages,
|
||||
isEnabled,
|
||||
} = useAppSelector(selector);
|
||||
|
||||
const [isMouseOverImage, setIsMouseOverImage] = useState(false);
|
||||
|
||||
@ -110,13 +125,15 @@ const ControlNetImagePreview = (props: Props) => {
|
||||
h: height,
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
pointerEvents: isEnabled ? 'auto' : 'none',
|
||||
opacity: isEnabled ? 1 : 0.5,
|
||||
}}
|
||||
>
|
||||
<IAIDndImage
|
||||
draggableData={draggableData}
|
||||
droppableData={droppableData}
|
||||
imageDTO={controlImage}
|
||||
isDropDisabled={shouldShowProcessedImage}
|
||||
isDropDisabled={shouldShowProcessedImage || !isEnabled}
|
||||
onClickReset={handleResetControlImage}
|
||||
postUploadAction={postUploadAction}
|
||||
resetTooltip="Reset Control Image"
|
||||
@ -140,6 +157,7 @@ const ControlNetImagePreview = (props: Props) => {
|
||||
droppableData={droppableData}
|
||||
imageDTO={processedControlImage}
|
||||
isUploadDisabled={true}
|
||||
isDropDisabled={!isEnabled}
|
||||
onClickReset={handleResetControlImage}
|
||||
resetTooltip="Reset Control Image"
|
||||
withResetIcon={Boolean(controlImage)}
|
||||
|
@ -1,10 +1,13 @@
|
||||
import { memo } from 'react';
|
||||
import { RequiredControlNetProcessorNode } from '../store/types';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { memo, useMemo } from 'react';
|
||||
import CannyProcessor from './processors/CannyProcessor';
|
||||
import HedProcessor from './processors/HedProcessor';
|
||||
import LineartProcessor from './processors/LineartProcessor';
|
||||
import LineartAnimeProcessor from './processors/LineartAnimeProcessor';
|
||||
import ContentShuffleProcessor from './processors/ContentShuffleProcessor';
|
||||
import HedProcessor from './processors/HedProcessor';
|
||||
import LineartAnimeProcessor from './processors/LineartAnimeProcessor';
|
||||
import LineartProcessor from './processors/LineartProcessor';
|
||||
import MediapipeFaceProcessor from './processors/MediapipeFaceProcessor';
|
||||
import MidasDepthProcessor from './processors/MidasDepthProcessor';
|
||||
import MlsdImageProcessor from './processors/MlsdImageProcessor';
|
||||
@ -15,23 +18,45 @@ import ZoeDepthProcessor from './processors/ZoeDepthProcessor';
|
||||
|
||||
export type ControlNetProcessorProps = {
|
||||
controlNetId: string;
|
||||
processorNode: RequiredControlNetProcessorNode;
|
||||
};
|
||||
|
||||
const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
|
||||
const { controlNetId, processorNode } = props;
|
||||
const { controlNetId } = props;
|
||||
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
stateSelector,
|
||||
({ controlNet }) => {
|
||||
const { isEnabled, processorNode } =
|
||||
controlNet.controlNets[controlNetId];
|
||||
|
||||
return { isEnabled, processorNode };
|
||||
},
|
||||
defaultSelectorOptions
|
||||
),
|
||||
[controlNetId]
|
||||
);
|
||||
|
||||
const { isEnabled, processorNode } = useAppSelector(selector);
|
||||
|
||||
if (processorNode.type === 'canny_image_processor') {
|
||||
return (
|
||||
<CannyProcessor
|
||||
controlNetId={controlNetId}
|
||||
processorNode={processorNode}
|
||||
isEnabled={isEnabled}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (processorNode.type === 'hed_image_processor') {
|
||||
return (
|
||||
<HedProcessor controlNetId={controlNetId} processorNode={processorNode} />
|
||||
<HedProcessor
|
||||
controlNetId={controlNetId}
|
||||
processorNode={processorNode}
|
||||
isEnabled={isEnabled}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
@ -40,6 +65,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
|
||||
<LineartProcessor
|
||||
controlNetId={controlNetId}
|
||||
processorNode={processorNode}
|
||||
isEnabled={isEnabled}
|
||||
/>
|
||||
);
|
||||
}
|
||||
@ -49,6 +75,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
|
||||
<ContentShuffleProcessor
|
||||
controlNetId={controlNetId}
|
||||
processorNode={processorNode}
|
||||
isEnabled={isEnabled}
|
||||
/>
|
||||
);
|
||||
}
|
||||
@ -58,6 +85,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
|
||||
<LineartAnimeProcessor
|
||||
controlNetId={controlNetId}
|
||||
processorNode={processorNode}
|
||||
isEnabled={isEnabled}
|
||||
/>
|
||||
);
|
||||
}
|
||||
@ -67,6 +95,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
|
||||
<MediapipeFaceProcessor
|
||||
controlNetId={controlNetId}
|
||||
processorNode={processorNode}
|
||||
isEnabled={isEnabled}
|
||||
/>
|
||||
);
|
||||
}
|
||||
@ -76,6 +105,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
|
||||
<MidasDepthProcessor
|
||||
controlNetId={controlNetId}
|
||||
processorNode={processorNode}
|
||||
isEnabled={isEnabled}
|
||||
/>
|
||||
);
|
||||
}
|
||||
@ -85,6 +115,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
|
||||
<MlsdImageProcessor
|
||||
controlNetId={controlNetId}
|
||||
processorNode={processorNode}
|
||||
isEnabled={isEnabled}
|
||||
/>
|
||||
);
|
||||
}
|
||||
@ -94,6 +125,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
|
||||
<NormalBaeProcessor
|
||||
controlNetId={controlNetId}
|
||||
processorNode={processorNode}
|
||||
isEnabled={isEnabled}
|
||||
/>
|
||||
);
|
||||
}
|
||||
@ -103,6 +135,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
|
||||
<OpenposeProcessor
|
||||
controlNetId={controlNetId}
|
||||
processorNode={processorNode}
|
||||
isEnabled={isEnabled}
|
||||
/>
|
||||
);
|
||||
}
|
||||
@ -112,6 +145,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
|
||||
<PidiProcessor
|
||||
controlNetId={controlNetId}
|
||||
processorNode={processorNode}
|
||||
isEnabled={isEnabled}
|
||||
/>
|
||||
);
|
||||
}
|
||||
@ -121,6 +155,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
|
||||
<ZoeDepthProcessor
|
||||
controlNetId={controlNetId}
|
||||
processorNode={processorNode}
|
||||
isEnabled={isEnabled}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
@ -1,18 +1,36 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAISwitch from 'common/components/IAISwitch';
|
||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||
import { controlNetAutoConfigToggled } from 'features/controlNet/store/controlNetSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
|
||||
type Props = {
|
||||
controlNetId: string;
|
||||
shouldAutoConfig: boolean;
|
||||
};
|
||||
|
||||
const ParamControlNetShouldAutoConfig = (props: Props) => {
|
||||
const { controlNetId, shouldAutoConfig } = props;
|
||||
const { controlNetId } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const isReady = useIsReadyToInvoke();
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
stateSelector,
|
||||
({ controlNet }) => {
|
||||
const { isEnabled, shouldAutoConfig } =
|
||||
controlNet.controlNets[controlNetId];
|
||||
return { isEnabled, shouldAutoConfig };
|
||||
},
|
||||
defaultSelectorOptions
|
||||
),
|
||||
[controlNetId]
|
||||
);
|
||||
|
||||
const { isEnabled, shouldAutoConfig } = useAppSelector(selector);
|
||||
const isBusy = useAppSelector(selectIsBusy);
|
||||
|
||||
const handleShouldAutoConfigChanged = useCallback(() => {
|
||||
dispatch(controlNetAutoConfigToggled({ controlNetId }));
|
||||
}, [controlNetId, dispatch]);
|
||||
@ -23,7 +41,7 @@ const ParamControlNetShouldAutoConfig = (props: Props) => {
|
||||
aria-label="Auto configure processor"
|
||||
isChecked={shouldAutoConfig}
|
||||
onChange={handleShouldAutoConfigChanged}
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
@ -1,5 +1,4 @@
|
||||
import {
|
||||
ChakraProps,
|
||||
FormControl,
|
||||
FormLabel,
|
||||
HStack,
|
||||
@ -10,34 +9,41 @@ import {
|
||||
RangeSliderTrack,
|
||||
Tooltip,
|
||||
} from '@chakra-ui/react';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import {
|
||||
controlNetBeginStepPctChanged,
|
||||
controlNetEndStepPctChanged,
|
||||
} from 'features/controlNet/store/controlNetSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const SLIDER_MARK_STYLES: ChakraProps['sx'] = {
|
||||
mt: 1.5,
|
||||
fontSize: '2xs',
|
||||
fontWeight: '500',
|
||||
color: 'base.400',
|
||||
};
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
|
||||
type Props = {
|
||||
controlNetId: string;
|
||||
beginStepPct: number;
|
||||
endStepPct: number;
|
||||
mini?: boolean;
|
||||
};
|
||||
|
||||
const formatPct = (v: number) => `${Math.round(v * 100)}%`;
|
||||
|
||||
const ParamControlNetBeginEnd = (props: Props) => {
|
||||
const { controlNetId, beginStepPct, mini = false, endStepPct } = props;
|
||||
const { controlNetId } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
stateSelector,
|
||||
({ controlNet }) => {
|
||||
const { beginStepPct, endStepPct, isEnabled } =
|
||||
controlNet.controlNets[controlNetId];
|
||||
return { beginStepPct, endStepPct, isEnabled };
|
||||
},
|
||||
defaultSelectorOptions
|
||||
),
|
||||
[controlNetId]
|
||||
);
|
||||
|
||||
const { beginStepPct, endStepPct, isEnabled } = useAppSelector(selector);
|
||||
|
||||
const handleStepPctChanged = useCallback(
|
||||
(v: number[]) => {
|
||||
@ -55,7 +61,7 @@ const ParamControlNetBeginEnd = (props: Props) => {
|
||||
}, [controlNetId, dispatch]);
|
||||
|
||||
return (
|
||||
<FormControl>
|
||||
<FormControl isDisabled={!isEnabled}>
|
||||
<FormLabel>Begin / End Step Percentage</FormLabel>
|
||||
<HStack w="100%" gap={2} alignItems="center">
|
||||
<RangeSlider
|
||||
@ -66,6 +72,7 @@ const ParamControlNetBeginEnd = (props: Props) => {
|
||||
max={1}
|
||||
step={0.01}
|
||||
minStepsBetweenThumbs={5}
|
||||
isDisabled={!isEnabled}
|
||||
>
|
||||
<RangeSliderTrack>
|
||||
<RangeSliderFilledTrack />
|
||||
@ -76,38 +83,33 @@ const ParamControlNetBeginEnd = (props: Props) => {
|
||||
<Tooltip label={formatPct(endStepPct)} placement="top" hasArrow>
|
||||
<RangeSliderThumb index={1} />
|
||||
</Tooltip>
|
||||
{!mini && (
|
||||
<>
|
||||
<RangeSliderMark
|
||||
value={0}
|
||||
sx={{
|
||||
insetInlineStart: '0 !important',
|
||||
insetInlineEnd: 'unset !important',
|
||||
...SLIDER_MARK_STYLES,
|
||||
}}
|
||||
>
|
||||
0%
|
||||
</RangeSliderMark>
|
||||
<RangeSliderMark
|
||||
value={0.5}
|
||||
sx={{
|
||||
...SLIDER_MARK_STYLES,
|
||||
}}
|
||||
>
|
||||
50%
|
||||
</RangeSliderMark>
|
||||
<RangeSliderMark
|
||||
value={1}
|
||||
sx={{
|
||||
insetInlineStart: 'unset !important',
|
||||
insetInlineEnd: '0 !important',
|
||||
...SLIDER_MARK_STYLES,
|
||||
}}
|
||||
>
|
||||
100%
|
||||
</RangeSliderMark>
|
||||
</>
|
||||
)}
|
||||
<RangeSliderMark
|
||||
value={0}
|
||||
sx={{
|
||||
insetInlineStart: '0 !important',
|
||||
insetInlineEnd: 'unset !important',
|
||||
}}
|
||||
>
|
||||
0%
|
||||
</RangeSliderMark>
|
||||
<RangeSliderMark
|
||||
value={0.5}
|
||||
sx={{
|
||||
insetInlineStart: '50% !important',
|
||||
transform: 'translateX(-50%)',
|
||||
}}
|
||||
>
|
||||
50%
|
||||
</RangeSliderMark>
|
||||
<RangeSliderMark
|
||||
value={1}
|
||||
sx={{
|
||||
insetInlineStart: 'unset !important',
|
||||
insetInlineEnd: '0 !important',
|
||||
}}
|
||||
>
|
||||
100%
|
||||
</RangeSliderMark>
|
||||
</RangeSlider>
|
||||
</HStack>
|
||||
</FormControl>
|
||||
|
@ -1,15 +1,17 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
import {
|
||||
ControlModes,
|
||||
controlNetControlModeChanged,
|
||||
} from 'features/controlNet/store/controlNetSlice';
|
||||
import { useCallback } from 'react';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
type ParamControlNetControlModeProps = {
|
||||
controlNetId: string;
|
||||
controlMode: string;
|
||||
};
|
||||
|
||||
const CONTROL_MODE_DATA = [
|
||||
@ -22,8 +24,23 @@ const CONTROL_MODE_DATA = [
|
||||
export default function ParamControlNetControlMode(
|
||||
props: ParamControlNetControlModeProps
|
||||
) {
|
||||
const { controlNetId, controlMode = false } = props;
|
||||
const { controlNetId } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
stateSelector,
|
||||
({ controlNet }) => {
|
||||
const { controlMode, isEnabled } =
|
||||
controlNet.controlNets[controlNetId];
|
||||
return { controlMode, isEnabled };
|
||||
},
|
||||
defaultSelectorOptions
|
||||
),
|
||||
[controlNetId]
|
||||
);
|
||||
|
||||
const { controlMode, isEnabled } = useAppSelector(selector);
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
@ -36,7 +53,8 @@ export default function ParamControlNetControlMode(
|
||||
|
||||
return (
|
||||
<IAIMantineSelect
|
||||
label={t('parameters.controlNetControlMode')}
|
||||
disabled={!isEnabled}
|
||||
label="Control Mode"
|
||||
data={CONTROL_MODE_DATA}
|
||||
value={String(controlMode)}
|
||||
onChange={handleControlModeChange}
|
||||
|
@ -29,6 +29,9 @@ const ParamControlNetFeatureToggle = () => {
|
||||
label="Enable ControlNet"
|
||||
isChecked={isEnabled}
|
||||
onChange={handleChange}
|
||||
formControlProps={{
|
||||
width: '100%',
|
||||
}}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
@ -1,28 +0,0 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAISwitch from 'common/components/IAISwitch';
|
||||
import { controlNetToggled } from 'features/controlNet/store/controlNetSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
|
||||
type ParamControlNetIsEnabledProps = {
|
||||
controlNetId: string;
|
||||
isEnabled: boolean;
|
||||
};
|
||||
|
||||
const ParamControlNetIsEnabled = (props: ParamControlNetIsEnabledProps) => {
|
||||
const { controlNetId, isEnabled } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const handleIsEnabledChanged = useCallback(() => {
|
||||
dispatch(controlNetToggled({ controlNetId }));
|
||||
}, [dispatch, controlNetId]);
|
||||
|
||||
return (
|
||||
<IAISwitch
|
||||
label="Enabled"
|
||||
isChecked={isEnabled}
|
||||
onChange={handleIsEnabledChanged}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ParamControlNetIsEnabled);
|
@ -1,36 +0,0 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAIFullCheckbox from 'common/components/IAIFullCheckbox';
|
||||
import IAISwitch from 'common/components/IAISwitch';
|
||||
import {
|
||||
controlNetToggled,
|
||||
isControlNetImagePreprocessedToggled,
|
||||
} from 'features/controlNet/store/controlNetSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
|
||||
type ParamControlNetIsEnabledProps = {
|
||||
controlNetId: string;
|
||||
isControlImageProcessed: boolean;
|
||||
};
|
||||
|
||||
const ParamControlNetIsEnabled = (props: ParamControlNetIsEnabledProps) => {
|
||||
const { controlNetId, isControlImageProcessed } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const handleIsControlImageProcessedToggled = useCallback(() => {
|
||||
dispatch(
|
||||
isControlNetImagePreprocessedToggled({
|
||||
controlNetId,
|
||||
})
|
||||
);
|
||||
}, [controlNetId, dispatch]);
|
||||
|
||||
return (
|
||||
<IAISwitch
|
||||
label="Preprocess"
|
||||
isChecked={isControlImageProcessed}
|
||||
onChange={handleIsControlImageProcessedToggled}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ParamControlNetIsEnabled);
|
@ -1,59 +1,118 @@
|
||||
import { SelectItem } from '@mantine/core';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIMantineSearchableSelect, {
|
||||
IAISelectDataType,
|
||||
} from 'common/components/IAIMantineSearchableSelect';
|
||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||
import {
|
||||
CONTROLNET_MODELS,
|
||||
ControlNetModelName,
|
||||
} from 'features/controlNet/store/constants';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
|
||||
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
|
||||
import { controlNetModelChanged } from 'features/controlNet/store/controlNetSlice';
|
||||
import { configSelector } from 'features/system/store/configSelectors';
|
||||
import { map } from 'lodash-es';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||
import { modelIdToControlNetModelParam } from 'features/parameters/util/modelIdToControlNetModelParam';
|
||||
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
type ParamControlNetModelProps = {
|
||||
controlNetId: string;
|
||||
model: ControlNetModelName;
|
||||
};
|
||||
|
||||
const selector = createSelector(configSelector, (config) => {
|
||||
const controlNetModels: IAISelectDataType[] = map(CONTROLNET_MODELS, (m) => ({
|
||||
label: m.label,
|
||||
value: m.type,
|
||||
})).filter(
|
||||
(d) =>
|
||||
!config.sd.disabledControlNetModels.includes(
|
||||
d.value as ControlNetModelName
|
||||
)
|
||||
const ParamControlNetModel = (props: ParamControlNetModelProps) => {
|
||||
const { controlNetId } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const isBusy = useAppSelector(selectIsBusy);
|
||||
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
stateSelector,
|
||||
({ generation, controlNet }) => {
|
||||
const { model } = generation;
|
||||
const controlNetModel = controlNet.controlNets[controlNetId]?.model;
|
||||
const isEnabled = controlNet.controlNets[controlNetId]?.isEnabled;
|
||||
return { mainModel: model, controlNetModel, isEnabled };
|
||||
},
|
||||
defaultSelectorOptions
|
||||
),
|
||||
[controlNetId]
|
||||
);
|
||||
|
||||
return controlNetModels;
|
||||
});
|
||||
const { mainModel, controlNetModel, isEnabled } = useAppSelector(selector);
|
||||
|
||||
const ParamControlNetModel = (props: ParamControlNetModelProps) => {
|
||||
const { controlNetId, model } = props;
|
||||
const controlNetModels = useAppSelector(selector);
|
||||
const dispatch = useAppDispatch();
|
||||
const isReady = useIsReadyToInvoke();
|
||||
const { data: controlNetModels } = useGetControlNetModelsQuery();
|
||||
|
||||
const data = useMemo(() => {
|
||||
if (!controlNetModels) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const data: SelectItem[] = [];
|
||||
|
||||
forEach(controlNetModels.entities, (model, id) => {
|
||||
if (!model) {
|
||||
return;
|
||||
}
|
||||
|
||||
const disabled = model?.base_model !== mainModel?.base_model;
|
||||
|
||||
data.push({
|
||||
value: id,
|
||||
label: model.model_name,
|
||||
group: MODEL_TYPE_MAP[model.base_model],
|
||||
disabled,
|
||||
tooltip: disabled
|
||||
? `Incompatible base model: ${model.base_model}`
|
||||
: undefined,
|
||||
});
|
||||
});
|
||||
|
||||
return data;
|
||||
}, [controlNetModels, mainModel?.base_model]);
|
||||
|
||||
// grab the full model entity from the RTK Query cache
|
||||
const selectedModel = useMemo(
|
||||
() =>
|
||||
controlNetModels?.entities[
|
||||
`${controlNetModel?.base_model}/controlnet/${controlNetModel?.model_name}`
|
||||
] ?? null,
|
||||
[
|
||||
controlNetModel?.base_model,
|
||||
controlNetModel?.model_name,
|
||||
controlNetModels?.entities,
|
||||
]
|
||||
);
|
||||
|
||||
const handleModelChanged = useCallback(
|
||||
(val: string | null) => {
|
||||
// TODO: do not cast
|
||||
const model = val as ControlNetModelName;
|
||||
dispatch(controlNetModelChanged({ controlNetId, model }));
|
||||
(v: string | null) => {
|
||||
if (!v) {
|
||||
return;
|
||||
}
|
||||
|
||||
const newControlNetModel = modelIdToControlNetModelParam(v);
|
||||
|
||||
if (!newControlNetModel) {
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(
|
||||
controlNetModelChanged({ controlNetId, model: newControlNetModel })
|
||||
);
|
||||
},
|
||||
[controlNetId, dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<IAIMantineSearchableSelect
|
||||
data={controlNetModels}
|
||||
value={model}
|
||||
itemComponent={IAIMantineSelectItemWithTooltip}
|
||||
data={data}
|
||||
error={
|
||||
!selectedModel || mainModel?.base_model !== selectedModel.base_model
|
||||
}
|
||||
placeholder="Select a model"
|
||||
value={selectedModel?.id ?? null}
|
||||
onChange={handleModelChanged}
|
||||
disabled={!isReady}
|
||||
tooltip={model}
|
||||
disabled={isBusy || !isEnabled}
|
||||
tooltip={selectedModel?.description}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
@ -1,24 +1,22 @@
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIMantineSearchableSelect, {
|
||||
IAISelectDataType,
|
||||
} from 'common/components/IAIMantineSearchableSelect';
|
||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||
import { configSelector } from 'features/system/store/configSelectors';
|
||||
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||
import { map } from 'lodash-es';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { CONTROLNET_PROCESSORS } from '../../store/constants';
|
||||
import { controlNetProcessorTypeChanged } from '../../store/controlNetSlice';
|
||||
import {
|
||||
ControlNetProcessorNode,
|
||||
ControlNetProcessorType,
|
||||
} from '../../store/types';
|
||||
import { ControlNetProcessorType } from '../../store/types';
|
||||
import { FormControl, FormLabel } from '@chakra-ui/react';
|
||||
|
||||
type ParamControlNetProcessorSelectProps = {
|
||||
controlNetId: string;
|
||||
processorNode: ControlNetProcessorNode;
|
||||
};
|
||||
|
||||
const selector = createSelector(
|
||||
@ -54,10 +52,24 @@ const selector = createSelector(
|
||||
const ParamControlNetProcessorSelect = (
|
||||
props: ParamControlNetProcessorSelectProps
|
||||
) => {
|
||||
const { controlNetId, processorNode } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const isReady = useIsReadyToInvoke();
|
||||
const { controlNetId } = props;
|
||||
const processorNodeSelector = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
stateSelector,
|
||||
({ controlNet }) => {
|
||||
const { isEnabled, processorNode } =
|
||||
controlNet.controlNets[controlNetId];
|
||||
return { isEnabled, processorNode };
|
||||
},
|
||||
defaultSelectorOptions
|
||||
),
|
||||
[controlNetId]
|
||||
);
|
||||
const isBusy = useAppSelector(selectIsBusy);
|
||||
const controlNetProcessors = useAppSelector(selector);
|
||||
const { isEnabled, processorNode } = useAppSelector(processorNodeSelector);
|
||||
|
||||
const handleProcessorTypeChanged = useCallback(
|
||||
(v: string | null) => {
|
||||
@ -77,7 +89,7 @@ const ParamControlNetProcessorSelect = (
|
||||
value={processorNode.type ?? 'canny_image_processor'}
|
||||
data={controlNetProcessors}
|
||||
onChange={handleProcessorTypeChanged}
|
||||
disabled={!isReady}
|
||||
disabled={isBusy || !isEnabled}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
@ -1,18 +1,32 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { controlNetWeightChanged } from 'features/controlNet/store/controlNetSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
|
||||
type ParamControlNetWeightProps = {
|
||||
controlNetId: string;
|
||||
weight: number;
|
||||
mini?: boolean;
|
||||
};
|
||||
|
||||
const ParamControlNetWeight = (props: ParamControlNetWeightProps) => {
|
||||
const { controlNetId, weight, mini = false } = props;
|
||||
const { controlNetId } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
stateSelector,
|
||||
({ controlNet }) => {
|
||||
const { weight, isEnabled } = controlNet.controlNets[controlNetId];
|
||||
return { weight, isEnabled };
|
||||
},
|
||||
defaultSelectorOptions
|
||||
),
|
||||
[controlNetId]
|
||||
);
|
||||
|
||||
const { weight, isEnabled } = useAppSelector(selector);
|
||||
const handleWeightChanged = useCallback(
|
||||
(weight: number) => {
|
||||
dispatch(controlNetWeightChanged({ controlNetId, weight }));
|
||||
@ -22,15 +36,15 @@ const ParamControlNetWeight = (props: ParamControlNetWeightProps) => {
|
||||
|
||||
return (
|
||||
<IAISlider
|
||||
isDisabled={!isEnabled}
|
||||
label={'Weight'}
|
||||
sliderFormLabelProps={{ pb: 2 }}
|
||||
value={weight}
|
||||
onChange={handleWeightChanged}
|
||||
min={-1}
|
||||
max={1}
|
||||
min={0}
|
||||
max={2}
|
||||
step={0.01}
|
||||
withSliderMarks={!mini}
|
||||
sliderMarks={[-1, 0, 1]}
|
||||
withSliderMarks
|
||||
sliderMarks={[0, 1, 2]}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
@ -1,22 +1,25 @@
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
||||
import { RequiredCannyImageProcessorInvocation } from 'features/controlNet/store/types';
|
||||
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||
|
||||
const DEFAULTS = CONTROLNET_PROCESSORS.canny_image_processor.default;
|
||||
const DEFAULTS = CONTROLNET_PROCESSORS.canny_image_processor
|
||||
.default as RequiredCannyImageProcessorInvocation;
|
||||
|
||||
type CannyProcessorProps = {
|
||||
controlNetId: string;
|
||||
processorNode: RequiredCannyImageProcessorInvocation;
|
||||
isEnabled: boolean;
|
||||
};
|
||||
|
||||
const CannyProcessor = (props: CannyProcessorProps) => {
|
||||
const { controlNetId, processorNode } = props;
|
||||
const { controlNetId, processorNode, isEnabled } = props;
|
||||
const { low_threshold, high_threshold } = processorNode;
|
||||
const isReady = useIsReadyToInvoke();
|
||||
const isBusy = useAppSelector(selectIsBusy);
|
||||
const processorChanged = useProcessorNodeChanged();
|
||||
|
||||
const handleLowThresholdChanged = useCallback(
|
||||
@ -48,7 +51,7 @@ const CannyProcessor = (props: CannyProcessorProps) => {
|
||||
return (
|
||||
<ProcessorWrapper>
|
||||
<IAISlider
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
label="Low Threshold"
|
||||
value={low_threshold}
|
||||
onChange={handleLowThresholdChanged}
|
||||
@ -60,7 +63,7 @@ const CannyProcessor = (props: CannyProcessorProps) => {
|
||||
withSliderMarks
|
||||
/>
|
||||
<IAISlider
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
label="High Threshold"
|
||||
value={high_threshold}
|
||||
onChange={handleHighThresholdChanged}
|
||||
|
@ -4,20 +4,23 @@ import { RequiredContentShuffleImageProcessorInvocation } from 'features/control
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||
|
||||
const DEFAULTS = CONTROLNET_PROCESSORS.content_shuffle_image_processor.default;
|
||||
const DEFAULTS = CONTROLNET_PROCESSORS.content_shuffle_image_processor
|
||||
.default as RequiredContentShuffleImageProcessorInvocation;
|
||||
|
||||
type Props = {
|
||||
controlNetId: string;
|
||||
processorNode: RequiredContentShuffleImageProcessorInvocation;
|
||||
isEnabled: boolean;
|
||||
};
|
||||
|
||||
const ContentShuffleProcessor = (props: Props) => {
|
||||
const { controlNetId, processorNode } = props;
|
||||
const { controlNetId, processorNode, isEnabled } = props;
|
||||
const { image_resolution, detect_resolution, w, h, f } = processorNode;
|
||||
const processorChanged = useProcessorNodeChanged();
|
||||
const isReady = useIsReadyToInvoke();
|
||||
const isBusy = useAppSelector(selectIsBusy);
|
||||
|
||||
const handleDetectResolutionChanged = useCallback(
|
||||
(v: number) => {
|
||||
@ -96,7 +99,7 @@ const ContentShuffleProcessor = (props: Props) => {
|
||||
max={4096}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
<IAISlider
|
||||
label="Image Resolution"
|
||||
@ -108,7 +111,7 @@ const ContentShuffleProcessor = (props: Props) => {
|
||||
max={4096}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
<IAISlider
|
||||
label="W"
|
||||
@ -120,7 +123,7 @@ const ContentShuffleProcessor = (props: Props) => {
|
||||
max={4096}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
<IAISlider
|
||||
label="H"
|
||||
@ -132,7 +135,7 @@ const ContentShuffleProcessor = (props: Props) => {
|
||||
max={4096}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
<IAISlider
|
||||
label="F"
|
||||
@ -144,7 +147,7 @@ const ContentShuffleProcessor = (props: Props) => {
|
||||
max={4096}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
</ProcessorWrapper>
|
||||
);
|
||||
|
@ -1,25 +1,29 @@
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import IAISwitch from 'common/components/IAISwitch';
|
||||
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
||||
import { RequiredHedImageProcessorInvocation } from 'features/controlNet/store/types';
|
||||
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||
import { ChangeEvent, memo, useCallback } from 'react';
|
||||
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||
|
||||
const DEFAULTS = CONTROLNET_PROCESSORS.hed_image_processor.default;
|
||||
const DEFAULTS = CONTROLNET_PROCESSORS.hed_image_processor
|
||||
.default as RequiredHedImageProcessorInvocation;
|
||||
|
||||
type HedProcessorProps = {
|
||||
controlNetId: string;
|
||||
processorNode: RequiredHedImageProcessorInvocation;
|
||||
isEnabled: boolean;
|
||||
};
|
||||
|
||||
const HedPreprocessor = (props: HedProcessorProps) => {
|
||||
const {
|
||||
controlNetId,
|
||||
processorNode: { detect_resolution, image_resolution, scribble },
|
||||
isEnabled,
|
||||
} = props;
|
||||
const isReady = useIsReadyToInvoke();
|
||||
const isBusy = useAppSelector(selectIsBusy);
|
||||
const processorChanged = useProcessorNodeChanged();
|
||||
|
||||
const handleDetectResolutionChanged = useCallback(
|
||||
@ -67,7 +71,7 @@ const HedPreprocessor = (props: HedProcessorProps) => {
|
||||
max={4096}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
<IAISlider
|
||||
label="Image Resolution"
|
||||
@ -79,13 +83,13 @@ const HedPreprocessor = (props: HedProcessorProps) => {
|
||||
max={4096}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
<IAISwitch
|
||||
label="Scribble"
|
||||
isChecked={scribble}
|
||||
onChange={handleScribbleChanged}
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
</ProcessorWrapper>
|
||||
);
|
||||
|
@ -1,23 +1,26 @@
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
||||
import { RequiredLineartAnimeImageProcessorInvocation } from 'features/controlNet/store/types';
|
||||
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||
|
||||
const DEFAULTS = CONTROLNET_PROCESSORS.lineart_anime_image_processor.default;
|
||||
const DEFAULTS = CONTROLNET_PROCESSORS.lineart_anime_image_processor
|
||||
.default as RequiredLineartAnimeImageProcessorInvocation;
|
||||
|
||||
type Props = {
|
||||
controlNetId: string;
|
||||
processorNode: RequiredLineartAnimeImageProcessorInvocation;
|
||||
isEnabled: boolean;
|
||||
};
|
||||
|
||||
const LineartAnimeProcessor = (props: Props) => {
|
||||
const { controlNetId, processorNode } = props;
|
||||
const { controlNetId, processorNode, isEnabled } = props;
|
||||
const { image_resolution, detect_resolution } = processorNode;
|
||||
const processorChanged = useProcessorNodeChanged();
|
||||
const isReady = useIsReadyToInvoke();
|
||||
const isBusy = useAppSelector(selectIsBusy);
|
||||
|
||||
const handleDetectResolutionChanged = useCallback(
|
||||
(v: number) => {
|
||||
@ -57,7 +60,7 @@ const LineartAnimeProcessor = (props: Props) => {
|
||||
max={4096}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
<IAISlider
|
||||
label="Image Resolution"
|
||||
@ -69,7 +72,7 @@ const LineartAnimeProcessor = (props: Props) => {
|
||||
max={4096}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
</ProcessorWrapper>
|
||||
);
|
||||
|
@ -1,24 +1,27 @@
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import IAISwitch from 'common/components/IAISwitch';
|
||||
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
||||
import { RequiredLineartImageProcessorInvocation } from 'features/controlNet/store/types';
|
||||
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||
import { ChangeEvent, memo, useCallback } from 'react';
|
||||
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||
|
||||
const DEFAULTS = CONTROLNET_PROCESSORS.lineart_image_processor.default;
|
||||
const DEFAULTS = CONTROLNET_PROCESSORS.lineart_image_processor
|
||||
.default as RequiredLineartImageProcessorInvocation;
|
||||
|
||||
type LineartProcessorProps = {
|
||||
controlNetId: string;
|
||||
processorNode: RequiredLineartImageProcessorInvocation;
|
||||
isEnabled: boolean;
|
||||
};
|
||||
|
||||
const LineartProcessor = (props: LineartProcessorProps) => {
|
||||
const { controlNetId, processorNode } = props;
|
||||
const { controlNetId, processorNode, isEnabled } = props;
|
||||
const { image_resolution, detect_resolution, coarse } = processorNode;
|
||||
const processorChanged = useProcessorNodeChanged();
|
||||
const isReady = useIsReadyToInvoke();
|
||||
const isBusy = useAppSelector(selectIsBusy);
|
||||
|
||||
const handleDetectResolutionChanged = useCallback(
|
||||
(v: number) => {
|
||||
@ -65,7 +68,7 @@ const LineartProcessor = (props: LineartProcessorProps) => {
|
||||
max={4096}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
<IAISlider
|
||||
label="Image Resolution"
|
||||
@ -77,13 +80,13 @@ const LineartProcessor = (props: LineartProcessorProps) => {
|
||||
max={4096}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
<IAISwitch
|
||||
label="Coarse"
|
||||
isChecked={coarse}
|
||||
onChange={handleCoarseChanged}
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
</ProcessorWrapper>
|
||||
);
|
||||
|
@ -1,23 +1,26 @@
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
||||
import { RequiredMediapipeFaceProcessorInvocation } from 'features/controlNet/store/types';
|
||||
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||
|
||||
const DEFAULTS = CONTROLNET_PROCESSORS.mediapipe_face_processor.default;
|
||||
const DEFAULTS = CONTROLNET_PROCESSORS.mediapipe_face_processor
|
||||
.default as RequiredMediapipeFaceProcessorInvocation;
|
||||
|
||||
type Props = {
|
||||
controlNetId: string;
|
||||
processorNode: RequiredMediapipeFaceProcessorInvocation;
|
||||
isEnabled: boolean;
|
||||
};
|
||||
|
||||
const MediapipeFaceProcessor = (props: Props) => {
|
||||
const { controlNetId, processorNode } = props;
|
||||
const { controlNetId, processorNode, isEnabled } = props;
|
||||
const { max_faces, min_confidence } = processorNode;
|
||||
const processorChanged = useProcessorNodeChanged();
|
||||
const isReady = useIsReadyToInvoke();
|
||||
const isBusy = useAppSelector(selectIsBusy);
|
||||
|
||||
const handleMaxFacesChanged = useCallback(
|
||||
(v: number) => {
|
||||
@ -53,7 +56,7 @@ const MediapipeFaceProcessor = (props: Props) => {
|
||||
max={20}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
<IAISlider
|
||||
label="Min Confidence"
|
||||
@ -66,7 +69,7 @@ const MediapipeFaceProcessor = (props: Props) => {
|
||||
step={0.01}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
</ProcessorWrapper>
|
||||
);
|
||||
|
@ -1,23 +1,26 @@
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
||||
import { RequiredMidasDepthImageProcessorInvocation } from 'features/controlNet/store/types';
|
||||
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||
|
||||
const DEFAULTS = CONTROLNET_PROCESSORS.midas_depth_image_processor.default;
|
||||
const DEFAULTS = CONTROLNET_PROCESSORS.midas_depth_image_processor
|
||||
.default as RequiredMidasDepthImageProcessorInvocation;
|
||||
|
||||
type Props = {
|
||||
controlNetId: string;
|
||||
processorNode: RequiredMidasDepthImageProcessorInvocation;
|
||||
isEnabled: boolean;
|
||||
};
|
||||
|
||||
const MidasDepthProcessor = (props: Props) => {
|
||||
const { controlNetId, processorNode } = props;
|
||||
const { controlNetId, processorNode, isEnabled } = props;
|
||||
const { a_mult, bg_th } = processorNode;
|
||||
const processorChanged = useProcessorNodeChanged();
|
||||
const isReady = useIsReadyToInvoke();
|
||||
const isBusy = useAppSelector(selectIsBusy);
|
||||
|
||||
const handleAMultChanged = useCallback(
|
||||
(v: number) => {
|
||||
@ -54,7 +57,7 @@ const MidasDepthProcessor = (props: Props) => {
|
||||
step={0.01}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
<IAISlider
|
||||
label="bg_th"
|
||||
@ -67,7 +70,7 @@ const MidasDepthProcessor = (props: Props) => {
|
||||
step={0.01}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
</ProcessorWrapper>
|
||||
);
|
||||
|
@ -1,23 +1,26 @@
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
||||
import { RequiredMlsdImageProcessorInvocation } from 'features/controlNet/store/types';
|
||||
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||
|
||||
const DEFAULTS = CONTROLNET_PROCESSORS.mlsd_image_processor.default;
|
||||
const DEFAULTS = CONTROLNET_PROCESSORS.mlsd_image_processor
|
||||
.default as RequiredMlsdImageProcessorInvocation;
|
||||
|
||||
type Props = {
|
||||
controlNetId: string;
|
||||
processorNode: RequiredMlsdImageProcessorInvocation;
|
||||
isEnabled: boolean;
|
||||
};
|
||||
|
||||
const MlsdImageProcessor = (props: Props) => {
|
||||
const { controlNetId, processorNode } = props;
|
||||
const { controlNetId, processorNode, isEnabled } = props;
|
||||
const { image_resolution, detect_resolution, thr_d, thr_v } = processorNode;
|
||||
const processorChanged = useProcessorNodeChanged();
|
||||
const isReady = useIsReadyToInvoke();
|
||||
const isBusy = useAppSelector(selectIsBusy);
|
||||
|
||||
const handleDetectResolutionChanged = useCallback(
|
||||
(v: number) => {
|
||||
@ -79,7 +82,7 @@ const MlsdImageProcessor = (props: Props) => {
|
||||
max={4096}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
<IAISlider
|
||||
label="Image Resolution"
|
||||
@ -91,7 +94,7 @@ const MlsdImageProcessor = (props: Props) => {
|
||||
max={4096}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
<IAISlider
|
||||
label="W"
|
||||
@ -104,7 +107,7 @@ const MlsdImageProcessor = (props: Props) => {
|
||||
step={0.01}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
<IAISlider
|
||||
label="H"
|
||||
@ -117,7 +120,7 @@ const MlsdImageProcessor = (props: Props) => {
|
||||
step={0.01}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
</ProcessorWrapper>
|
||||
);
|
||||
|
@ -1,23 +1,26 @@
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
||||
import { RequiredNormalbaeImageProcessorInvocation } from 'features/controlNet/store/types';
|
||||
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||
|
||||
const DEFAULTS = CONTROLNET_PROCESSORS.normalbae_image_processor.default;
|
||||
const DEFAULTS = CONTROLNET_PROCESSORS.normalbae_image_processor
|
||||
.default as RequiredNormalbaeImageProcessorInvocation;
|
||||
|
||||
type Props = {
|
||||
controlNetId: string;
|
||||
processorNode: RequiredNormalbaeImageProcessorInvocation;
|
||||
isEnabled: boolean;
|
||||
};
|
||||
|
||||
const NormalBaeProcessor = (props: Props) => {
|
||||
const { controlNetId, processorNode } = props;
|
||||
const { controlNetId, processorNode, isEnabled } = props;
|
||||
const { image_resolution, detect_resolution } = processorNode;
|
||||
const processorChanged = useProcessorNodeChanged();
|
||||
const isReady = useIsReadyToInvoke();
|
||||
const isBusy = useAppSelector(selectIsBusy);
|
||||
|
||||
const handleDetectResolutionChanged = useCallback(
|
||||
(v: number) => {
|
||||
@ -57,7 +60,7 @@ const NormalBaeProcessor = (props: Props) => {
|
||||
max={4096}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
<IAISlider
|
||||
label="Image Resolution"
|
||||
@ -69,7 +72,7 @@ const NormalBaeProcessor = (props: Props) => {
|
||||
max={4096}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
</ProcessorWrapper>
|
||||
);
|
||||
|
@ -1,24 +1,27 @@
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import IAISwitch from 'common/components/IAISwitch';
|
||||
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
||||
import { RequiredOpenposeImageProcessorInvocation } from 'features/controlNet/store/types';
|
||||
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||
import { ChangeEvent, memo, useCallback } from 'react';
|
||||
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||
|
||||
const DEFAULTS = CONTROLNET_PROCESSORS.openpose_image_processor.default;
|
||||
const DEFAULTS = CONTROLNET_PROCESSORS.openpose_image_processor
|
||||
.default as RequiredOpenposeImageProcessorInvocation;
|
||||
|
||||
type Props = {
|
||||
controlNetId: string;
|
||||
processorNode: RequiredOpenposeImageProcessorInvocation;
|
||||
isEnabled: boolean;
|
||||
};
|
||||
|
||||
const OpenposeProcessor = (props: Props) => {
|
||||
const { controlNetId, processorNode } = props;
|
||||
const { controlNetId, processorNode, isEnabled } = props;
|
||||
const { image_resolution, detect_resolution, hand_and_face } = processorNode;
|
||||
const processorChanged = useProcessorNodeChanged();
|
||||
const isReady = useIsReadyToInvoke();
|
||||
const isBusy = useAppSelector(selectIsBusy);
|
||||
|
||||
const handleDetectResolutionChanged = useCallback(
|
||||
(v: number) => {
|
||||
@ -65,7 +68,7 @@ const OpenposeProcessor = (props: Props) => {
|
||||
max={4096}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
<IAISlider
|
||||
label="Image Resolution"
|
||||
@ -77,13 +80,13 @@ const OpenposeProcessor = (props: Props) => {
|
||||
max={4096}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
<IAISwitch
|
||||
label="Hand and Face"
|
||||
isChecked={hand_and_face}
|
||||
onChange={handleHandAndFaceChanged}
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
</ProcessorWrapper>
|
||||
);
|
||||
|
@ -1,24 +1,27 @@
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import IAISwitch from 'common/components/IAISwitch';
|
||||
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
||||
import { RequiredPidiImageProcessorInvocation } from 'features/controlNet/store/types';
|
||||
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||
import { ChangeEvent, memo, useCallback } from 'react';
|
||||
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||
|
||||
const DEFAULTS = CONTROLNET_PROCESSORS.pidi_image_processor.default;
|
||||
const DEFAULTS = CONTROLNET_PROCESSORS.pidi_image_processor
|
||||
.default as RequiredPidiImageProcessorInvocation;
|
||||
|
||||
type Props = {
|
||||
controlNetId: string;
|
||||
processorNode: RequiredPidiImageProcessorInvocation;
|
||||
isEnabled: boolean;
|
||||
};
|
||||
|
||||
const PidiProcessor = (props: Props) => {
|
||||
const { controlNetId, processorNode } = props;
|
||||
const { controlNetId, processorNode, isEnabled } = props;
|
||||
const { image_resolution, detect_resolution, scribble, safe } = processorNode;
|
||||
const processorChanged = useProcessorNodeChanged();
|
||||
const isReady = useIsReadyToInvoke();
|
||||
const isBusy = useAppSelector(selectIsBusy);
|
||||
|
||||
const handleDetectResolutionChanged = useCallback(
|
||||
(v: number) => {
|
||||
@ -72,7 +75,7 @@ const PidiProcessor = (props: Props) => {
|
||||
max={4096}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
<IAISlider
|
||||
label="Image Resolution"
|
||||
@ -84,7 +87,7 @@ const PidiProcessor = (props: Props) => {
|
||||
max={4096}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
<IAISwitch
|
||||
label="Scribble"
|
||||
@ -95,7 +98,7 @@ const PidiProcessor = (props: Props) => {
|
||||
label="Safe"
|
||||
isChecked={safe}
|
||||
onChange={handleSafeChanged}
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
</ProcessorWrapper>
|
||||
);
|
||||
|
@ -4,6 +4,7 @@ import { memo } from 'react';
|
||||
type Props = {
|
||||
controlNetId: string;
|
||||
processorNode: RequiredZoeDepthImageProcessorInvocation;
|
||||
isEnabled: boolean;
|
||||
};
|
||||
|
||||
const ZoeDepthProcessor = (props: Props) => {
|
||||
|
@ -173,91 +173,17 @@ export const CONTROLNET_PROCESSORS: ControlNetProcessorsDict = {
|
||||
},
|
||||
};
|
||||
|
||||
type ControlNetModelsDict = Record<string, ControlNetModel>;
|
||||
|
||||
type ControlNetModel = {
|
||||
type: string;
|
||||
label: string;
|
||||
description?: string;
|
||||
defaultProcessor?: ControlNetProcessorType;
|
||||
export const CONTROLNET_MODEL_DEFAULT_PROCESSORS: {
|
||||
[key: string]: ControlNetProcessorType;
|
||||
} = {
|
||||
canny: 'canny_image_processor',
|
||||
mlsd: 'mlsd_image_processor',
|
||||
depth: 'midas_depth_image_processor',
|
||||
bae: 'normalbae_image_processor',
|
||||
lineart: 'lineart_image_processor',
|
||||
lineart_anime: 'lineart_anime_image_processor',
|
||||
softedge: 'hed_image_processor',
|
||||
shuffle: 'content_shuffle_image_processor',
|
||||
openpose: 'openpose_image_processor',
|
||||
mediapipe: 'mediapipe_face_processor',
|
||||
};
|
||||
|
||||
export const CONTROLNET_MODELS: ControlNetModelsDict = {
|
||||
'lllyasviel/control_v11p_sd15_canny': {
|
||||
type: 'lllyasviel/control_v11p_sd15_canny',
|
||||
label: 'Canny',
|
||||
defaultProcessor: 'canny_image_processor',
|
||||
},
|
||||
'lllyasviel/control_v11p_sd15_inpaint': {
|
||||
type: 'lllyasviel/control_v11p_sd15_inpaint',
|
||||
label: 'Inpaint',
|
||||
defaultProcessor: 'none',
|
||||
},
|
||||
'lllyasviel/control_v11p_sd15_mlsd': {
|
||||
type: 'lllyasviel/control_v11p_sd15_mlsd',
|
||||
label: 'M-LSD',
|
||||
defaultProcessor: 'mlsd_image_processor',
|
||||
},
|
||||
'lllyasviel/control_v11f1p_sd15_depth': {
|
||||
type: 'lllyasviel/control_v11f1p_sd15_depth',
|
||||
label: 'Depth',
|
||||
defaultProcessor: 'midas_depth_image_processor',
|
||||
},
|
||||
'lllyasviel/control_v11p_sd15_normalbae': {
|
||||
type: 'lllyasviel/control_v11p_sd15_normalbae',
|
||||
label: 'Normal Map (BAE)',
|
||||
defaultProcessor: 'normalbae_image_processor',
|
||||
},
|
||||
'lllyasviel/control_v11p_sd15_seg': {
|
||||
type: 'lllyasviel/control_v11p_sd15_seg',
|
||||
label: 'Segmentation',
|
||||
defaultProcessor: 'none',
|
||||
},
|
||||
'lllyasviel/control_v11p_sd15_lineart': {
|
||||
type: 'lllyasviel/control_v11p_sd15_lineart',
|
||||
label: 'Lineart',
|
||||
defaultProcessor: 'lineart_image_processor',
|
||||
},
|
||||
'lllyasviel/control_v11p_sd15s2_lineart_anime': {
|
||||
type: 'lllyasviel/control_v11p_sd15s2_lineart_anime',
|
||||
label: 'Lineart Anime',
|
||||
defaultProcessor: 'lineart_anime_image_processor',
|
||||
},
|
||||
'lllyasviel/control_v11p_sd15_scribble': {
|
||||
type: 'lllyasviel/control_v11p_sd15_scribble',
|
||||
label: 'Scribble',
|
||||
defaultProcessor: 'none',
|
||||
},
|
||||
'lllyasviel/control_v11p_sd15_softedge': {
|
||||
type: 'lllyasviel/control_v11p_sd15_softedge',
|
||||
label: 'Soft Edge',
|
||||
defaultProcessor: 'hed_image_processor',
|
||||
},
|
||||
'lllyasviel/control_v11e_sd15_shuffle': {
|
||||
type: 'lllyasviel/control_v11e_sd15_shuffle',
|
||||
label: 'Content Shuffle',
|
||||
defaultProcessor: 'content_shuffle_image_processor',
|
||||
},
|
||||
'lllyasviel/control_v11p_sd15_openpose': {
|
||||
type: 'lllyasviel/control_v11p_sd15_openpose',
|
||||
label: 'Openpose',
|
||||
defaultProcessor: 'openpose_image_processor',
|
||||
},
|
||||
'lllyasviel/control_v11f1e_sd15_tile': {
|
||||
type: 'lllyasviel/control_v11f1e_sd15_tile',
|
||||
label: 'Tile (experimental)',
|
||||
defaultProcessor: 'none',
|
||||
},
|
||||
'lllyasviel/control_v11e_sd15_ip2p': {
|
||||
type: 'lllyasviel/control_v11e_sd15_ip2p',
|
||||
label: 'Pix2Pix (experimental)',
|
||||
defaultProcessor: 'none',
|
||||
},
|
||||
'CrucibleAI/ControlNetMediaPipeFace': {
|
||||
type: 'CrucibleAI/ControlNetMediaPipeFace',
|
||||
label: 'Mediapipe Face',
|
||||
defaultProcessor: 'mediapipe_face_processor',
|
||||
},
|
||||
};
|
||||
|
||||
export type ControlNetModelName = keyof typeof CONTROLNET_MODELS;
|
||||
|
@ -1,22 +1,20 @@
|
||||
import { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
import { ControlNetModelParam } from 'features/parameters/types/parameterSchemas';
|
||||
import { cloneDeep, forEach } from 'lodash-es';
|
||||
import { imageDeleted } from 'services/api/thunks/image';
|
||||
import { isAnySessionRejected } from 'services/api/thunks/session';
|
||||
import { appSocketInvocationError } from 'services/events/actions';
|
||||
import { controlNetImageProcessed } from './actions';
|
||||
import {
|
||||
CONTROLNET_MODEL_DEFAULT_PROCESSORS,
|
||||
CONTROLNET_PROCESSORS,
|
||||
} from './constants';
|
||||
import {
|
||||
ControlNetProcessorType,
|
||||
RequiredCannyImageProcessorInvocation,
|
||||
RequiredControlNetProcessorNode,
|
||||
} from './types';
|
||||
import {
|
||||
CONTROLNET_MODELS,
|
||||
CONTROLNET_PROCESSORS,
|
||||
ControlNetModelName,
|
||||
} from './constants';
|
||||
import { controlNetImageProcessed } from './actions';
|
||||
import { imageDeleted, imageUrlsReceived } from 'services/api/thunks/image';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { isAnySessionRejected } from 'services/api/thunks/session';
|
||||
import { appSocketInvocationError } from 'services/events/actions';
|
||||
|
||||
export type ControlModes =
|
||||
| 'balanced'
|
||||
@ -26,7 +24,7 @@ export type ControlModes =
|
||||
|
||||
export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
|
||||
isEnabled: true,
|
||||
model: CONTROLNET_MODELS['lllyasviel/control_v11p_sd15_canny'].type,
|
||||
model: null,
|
||||
weight: 1,
|
||||
beginStepPct: 0,
|
||||
endStepPct: 1,
|
||||
@ -42,7 +40,7 @@ export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
|
||||
export type ControlNetConfig = {
|
||||
controlNetId: string;
|
||||
isEnabled: boolean;
|
||||
model: ControlNetModelName;
|
||||
model: ControlNetModelParam | null;
|
||||
weight: number;
|
||||
beginStepPct: number;
|
||||
endStepPct: number;
|
||||
@ -86,6 +84,19 @@ export const controlNetSlice = createSlice({
|
||||
controlNetId,
|
||||
};
|
||||
},
|
||||
controlNetDuplicated: (
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
sourceControlNetId: string;
|
||||
newControlNetId: string;
|
||||
}>
|
||||
) => {
|
||||
const { sourceControlNetId, newControlNetId } = action.payload;
|
||||
|
||||
const newControlnet = cloneDeep(state.controlNets[sourceControlNetId]);
|
||||
newControlnet.controlNetId = newControlNetId;
|
||||
state.controlNets[newControlNetId] = newControlnet;
|
||||
},
|
||||
controlNetAddedFromImage: (
|
||||
state,
|
||||
action: PayloadAction<{ controlNetId: string; controlImage: string }>
|
||||
@ -147,7 +158,7 @@ export const controlNetSlice = createSlice({
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
controlNetId: string;
|
||||
model: ControlNetModelName;
|
||||
model: ControlNetModelParam;
|
||||
}>
|
||||
) => {
|
||||
const { controlNetId, model } = action.payload;
|
||||
@ -155,7 +166,15 @@ export const controlNetSlice = createSlice({
|
||||
state.controlNets[controlNetId].processedControlImage = null;
|
||||
|
||||
if (state.controlNets[controlNetId].shouldAutoConfig) {
|
||||
const processorType = CONTROLNET_MODELS[model].defaultProcessor;
|
||||
let processorType: ControlNetProcessorType | undefined = undefined;
|
||||
|
||||
for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
|
||||
if (model.model_name.includes(modelSubstring)) {
|
||||
processorType = CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (processorType) {
|
||||
state.controlNets[controlNetId].processorType = processorType;
|
||||
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[
|
||||
@ -241,9 +260,19 @@ export const controlNetSlice = createSlice({
|
||||
|
||||
if (newShouldAutoConfig) {
|
||||
// manage the processor for the user
|
||||
const processorType =
|
||||
CONTROLNET_MODELS[state.controlNets[controlNetId].model]
|
||||
.defaultProcessor;
|
||||
let processorType: ControlNetProcessorType | undefined = undefined;
|
||||
|
||||
for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
|
||||
if (
|
||||
state.controlNets[controlNetId].model?.model_name.includes(
|
||||
modelSubstring
|
||||
)
|
||||
) {
|
||||
processorType = CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (processorType) {
|
||||
state.controlNets[controlNetId].processorType = processorType;
|
||||
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[
|
||||
@ -272,7 +301,8 @@ export const controlNetSlice = createSlice({
|
||||
});
|
||||
|
||||
builder.addCase(imageDeleted.pending, (state, action) => {
|
||||
// Preemptively remove the image from the gallery
|
||||
// Preemptively remove the image from all controlnets
|
||||
// TODO: doesn't the imageusage stuff do this for us?
|
||||
const { image_name } = action.meta.arg;
|
||||
forEach(state.controlNets, (c) => {
|
||||
if (c.controlImage === image_name) {
|
||||
@ -285,21 +315,6 @@ export const controlNetSlice = createSlice({
|
||||
});
|
||||
});
|
||||
|
||||
// builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
|
||||
// const { image_name, image_url, thumbnail_url } = action.payload;
|
||||
|
||||
// forEach(state.controlNets, (c) => {
|
||||
// if (c.controlImage?.image_name === image_name) {
|
||||
// c.controlImage.image_url = image_url;
|
||||
// c.controlImage.thumbnail_url = thumbnail_url;
|
||||
// }
|
||||
// if (c.processedControlImage?.image_name === image_name) {
|
||||
// c.processedControlImage.image_url = image_url;
|
||||
// c.processedControlImage.thumbnail_url = thumbnail_url;
|
||||
// }
|
||||
// });
|
||||
// });
|
||||
|
||||
builder.addCase(appSocketInvocationError, (state, action) => {
|
||||
state.pendingControlImages = [];
|
||||
});
|
||||
@ -313,6 +328,7 @@ export const controlNetSlice = createSlice({
|
||||
export const {
|
||||
isControlNetEnabledToggled,
|
||||
controlNetAdded,
|
||||
controlNetDuplicated,
|
||||
controlNetAddedFromImage,
|
||||
controlNetRemoved,
|
||||
controlNetImageChanged,
|
||||
|
@ -7,6 +7,7 @@ import ClipInputFieldComponent from './fields/ClipInputFieldComponent';
|
||||
import ColorInputFieldComponent from './fields/ColorInputFieldComponent';
|
||||
import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent';
|
||||
import ControlInputFieldComponent from './fields/ControlInputFieldComponent';
|
||||
import ControlNetModelInputFieldComponent from './fields/ControlNetModelInputFieldComponent';
|
||||
import EnumInputFieldComponent from './fields/EnumInputFieldComponent';
|
||||
import ImageCollectionInputFieldComponent from './fields/ImageCollectionInputFieldComponent';
|
||||
import ImageInputFieldComponent from './fields/ImageInputFieldComponent';
|
||||
@ -174,6 +175,16 @@ const InputFieldComponent = (props: InputFieldComponentProps) => {
|
||||
);
|
||||
}
|
||||
|
||||
if (type === 'controlnet_model' && template.type === 'controlnet_model') {
|
||||
return (
|
||||
<ControlNetModelInputFieldComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
template={template}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (type === 'array' && template.type === 'array') {
|
||||
return (
|
||||
<ArrayInputFieldComponent
|
||||
|
@ -0,0 +1,103 @@
|
||||
import { SelectItem } from '@mantine/core';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
ControlNetModelInputFieldTemplate,
|
||||
ControlNetModelInputFieldValue,
|
||||
} from 'features/nodes/types/types';
|
||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||
import { modelIdToControlNetModelParam } from 'features/parameters/util/modelIdToControlNetModelParam';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
|
||||
import { FieldComponentProps } from './types';
|
||||
|
||||
const ControlNetModelInputFieldComponent = (
|
||||
props: FieldComponentProps<
|
||||
ControlNetModelInputFieldValue,
|
||||
ControlNetModelInputFieldTemplate
|
||||
>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
const controlNetModel = field.value;
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const { data: controlNetModels } = useGetControlNetModelsQuery();
|
||||
|
||||
// grab the full model entity from the RTK Query cache
|
||||
const selectedModel = useMemo(
|
||||
() =>
|
||||
controlNetModels?.entities[
|
||||
`${controlNetModel?.base_model}/controlnet/${controlNetModel?.model_name}`
|
||||
] ?? null,
|
||||
[
|
||||
controlNetModel?.base_model,
|
||||
controlNetModel?.model_name,
|
||||
controlNetModels?.entities,
|
||||
]
|
||||
);
|
||||
|
||||
const data = useMemo(() => {
|
||||
if (!controlNetModels) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const data: SelectItem[] = [];
|
||||
|
||||
forEach(controlNetModels.entities, (model, id) => {
|
||||
if (!model) {
|
||||
return;
|
||||
}
|
||||
|
||||
data.push({
|
||||
value: id,
|
||||
label: model.model_name,
|
||||
group: MODEL_TYPE_MAP[model.base_model],
|
||||
});
|
||||
});
|
||||
|
||||
return data;
|
||||
}, [controlNetModels]);
|
||||
|
||||
const handleValueChanged = useCallback(
|
||||
(v: string | null) => {
|
||||
if (!v) {
|
||||
return;
|
||||
}
|
||||
|
||||
const newControlNetModel = modelIdToControlNetModelParam(v);
|
||||
|
||||
if (!newControlNetModel) {
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(
|
||||
fieldValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value: newControlNetModel,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<IAIMantineSelect
|
||||
tooltip={selectedModel?.description}
|
||||
label={
|
||||
selectedModel?.base_model && MODEL_TYPE_MAP[selectedModel?.base_model]
|
||||
}
|
||||
value={selectedModel?.id ?? null}
|
||||
placeholder="Pick one"
|
||||
error={!selectedModel}
|
||||
data={data}
|
||||
onChange={handleValueChanged}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ControlNetModelInputFieldComponent);
|
@ -1,6 +1,7 @@
|
||||
import { createSlice, PayloadAction } from '@reduxjs/toolkit';
|
||||
import { RootState } from 'app/store/store';
|
||||
import {
|
||||
ControlNetModelParam,
|
||||
LoRAModelParam,
|
||||
MainModelParam,
|
||||
VaeModelParam,
|
||||
@ -81,7 +82,8 @@ const nodesSlice = createSlice({
|
||||
| ImageField[]
|
||||
| MainModelParam
|
||||
| VaeModelParam
|
||||
| LoRAModelParam;
|
||||
| LoRAModelParam
|
||||
| ControlNetModelParam;
|
||||
}>
|
||||
) => {
|
||||
const { nodeId, fieldName, value } = action.payload;
|
||||
|
@ -19,6 +19,8 @@ export const FIELD_TYPE_MAP: Record<string, FieldType> = {
|
||||
model: 'model',
|
||||
vae_model: 'vae_model',
|
||||
lora_model: 'lora_model',
|
||||
controlnet_model: 'controlnet_model',
|
||||
ControlNetModelField: 'controlnet_model',
|
||||
array: 'array',
|
||||
item: 'item',
|
||||
ColorField: 'color',
|
||||
@ -130,6 +132,12 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
|
||||
title: 'LoRA',
|
||||
description: 'Models are models.',
|
||||
},
|
||||
controlnet_model: {
|
||||
color: 'teal',
|
||||
colorCssVar: getColorTokenCssVariable('teal'),
|
||||
title: 'ControlNet',
|
||||
description: 'Models are models.',
|
||||
},
|
||||
array: {
|
||||
color: 'gray',
|
||||
colorCssVar: getColorTokenCssVariable('gray'),
|
||||
|
@ -1,4 +1,5 @@
|
||||
import {
|
||||
ControlNetModelParam,
|
||||
LoRAModelParam,
|
||||
MainModelParam,
|
||||
VaeModelParam,
|
||||
@ -71,6 +72,7 @@ export type FieldType =
|
||||
| 'model'
|
||||
| 'vae_model'
|
||||
| 'lora_model'
|
||||
| 'controlnet_model'
|
||||
| 'array'
|
||||
| 'item'
|
||||
| 'color'
|
||||
@ -100,6 +102,7 @@ export type InputFieldValue =
|
||||
| MainModelInputFieldValue
|
||||
| VaeModelInputFieldValue
|
||||
| LoRAModelInputFieldValue
|
||||
| ControlNetModelInputFieldValue
|
||||
| ArrayInputFieldValue
|
||||
| ItemInputFieldValue
|
||||
| ColorInputFieldValue
|
||||
@ -127,6 +130,7 @@ export type InputFieldTemplate =
|
||||
| ModelInputFieldTemplate
|
||||
| VaeModelInputFieldTemplate
|
||||
| LoRAModelInputFieldTemplate
|
||||
| ControlNetModelInputFieldTemplate
|
||||
| ArrayInputFieldTemplate
|
||||
| ItemInputFieldTemplate
|
||||
| ColorInputFieldTemplate
|
||||
@ -249,6 +253,11 @@ export type LoRAModelInputFieldValue = FieldValueBase & {
|
||||
value?: LoRAModelParam;
|
||||
};
|
||||
|
||||
export type ControlNetModelInputFieldValue = FieldValueBase & {
|
||||
type: 'controlnet_model';
|
||||
value?: ControlNetModelParam;
|
||||
};
|
||||
|
||||
export type ArrayInputFieldValue = FieldValueBase & {
|
||||
type: 'array';
|
||||
value?: (string | number)[];
|
||||
@ -368,6 +377,11 @@ export type LoRAModelInputFieldTemplate = InputFieldTemplateBase & {
|
||||
type: 'lora_model';
|
||||
};
|
||||
|
||||
export type ControlNetModelInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: string;
|
||||
type: 'controlnet_model';
|
||||
};
|
||||
|
||||
export type ArrayInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: [];
|
||||
type: 'array';
|
||||
|
@ -9,6 +9,7 @@ import {
|
||||
ColorInputFieldTemplate,
|
||||
ConditioningInputFieldTemplate,
|
||||
ControlInputFieldTemplate,
|
||||
ControlNetModelInputFieldTemplate,
|
||||
EnumInputFieldTemplate,
|
||||
FieldType,
|
||||
FloatInputFieldTemplate,
|
||||
@ -207,6 +208,21 @@ const buildLoRAModelInputFieldTemplate = ({
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildControlNetModelInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): ControlNetModelInputFieldTemplate => {
|
||||
const template: ControlNetModelInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'controlnet_model',
|
||||
inputRequirement: 'always',
|
||||
inputKind: 'direct',
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildImageInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@ -479,6 +495,9 @@ export const buildInputFieldTemplate = (
|
||||
if (['lora_model'].includes(fieldType)) {
|
||||
return buildLoRAModelInputFieldTemplate({ schemaObject, baseField });
|
||||
}
|
||||
if (['controlnet_model'].includes(fieldType)) {
|
||||
return buildControlNetModelInputFieldTemplate({ schemaObject, baseField });
|
||||
}
|
||||
if (['enum'].includes(fieldType)) {
|
||||
return buildEnumInputFieldTemplate({ schemaObject, baseField });
|
||||
}
|
||||
|
@ -83,6 +83,10 @@ export const buildInputFieldValue = (
|
||||
if (template.type === 'lora_model') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
if (template.type === 'controlnet_model') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
}
|
||||
|
||||
return fieldValue;
|
||||
|
@ -2,12 +2,13 @@ import { Divider, Flex } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import IAICollapse from 'common/components/IAICollapse';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import ControlNet from 'features/controlNet/components/ControlNet';
|
||||
import ParamControlNetFeatureToggle from 'features/controlNet/components/parameters/ParamControlNetFeatureToggle';
|
||||
import {
|
||||
controlNetAdded,
|
||||
controlNetModelChanged,
|
||||
controlNetSelector,
|
||||
} from 'features/controlNet/store/controlNetSlice';
|
||||
import { getValidControlNets } from 'features/controlNet/util/getValidControlNets';
|
||||
@ -15,6 +16,8 @@ import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
import { map } from 'lodash-es';
|
||||
import { Fragment, memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FaPlus } from 'react-icons/fa';
|
||||
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
|
||||
const selector = createSelector(
|
||||
@ -39,10 +42,23 @@ const ParamControlNetCollapse = () => {
|
||||
const { controlNetsArray, activeLabel } = useAppSelector(selector);
|
||||
const isControlNetDisabled = useFeatureStatus('controlNet').isFeatureDisabled;
|
||||
const dispatch = useAppDispatch();
|
||||
const { firstModel } = useGetControlNetModelsQuery(undefined, {
|
||||
selectFromResult: (result) => {
|
||||
const firstModel = result.data?.entities[result.data?.ids[0]];
|
||||
return {
|
||||
firstModel,
|
||||
};
|
||||
},
|
||||
});
|
||||
|
||||
const handleClickedAddControlNet = useCallback(() => {
|
||||
dispatch(controlNetAdded({ controlNetId: uuidv4() }));
|
||||
}, [dispatch]);
|
||||
if (!firstModel) {
|
||||
return;
|
||||
}
|
||||
const controlNetId = uuidv4();
|
||||
dispatch(controlNetAdded({ controlNetId }));
|
||||
dispatch(controlNetModelChanged({ controlNetId, model: firstModel }));
|
||||
}, [dispatch, firstModel]);
|
||||
|
||||
if (isControlNetDisabled) {
|
||||
return null;
|
||||
@ -51,16 +67,39 @@ const ParamControlNetCollapse = () => {
|
||||
return (
|
||||
<IAICollapse label="ControlNet" activeLabel={activeLabel}>
|
||||
<Flex sx={{ flexDir: 'column', gap: 3 }}>
|
||||
<ParamControlNetFeatureToggle />
|
||||
<Flex gap={2} alignItems="center">
|
||||
<Flex
|
||||
sx={{
|
||||
flexDirection: 'column',
|
||||
w: '100%',
|
||||
gap: 2,
|
||||
px: 4,
|
||||
py: 2,
|
||||
borderRadius: 4,
|
||||
bg: 'base.200',
|
||||
_dark: {
|
||||
bg: 'base.850',
|
||||
},
|
||||
}}
|
||||
>
|
||||
<ParamControlNetFeatureToggle />
|
||||
</Flex>
|
||||
<IAIIconButton
|
||||
tooltip="Add ControlNet"
|
||||
aria-label="Add ControlNet"
|
||||
icon={<FaPlus />}
|
||||
isDisabled={!firstModel}
|
||||
flexGrow={1}
|
||||
size="md"
|
||||
onClick={handleClickedAddControlNet}
|
||||
/>
|
||||
</Flex>
|
||||
{controlNetsArray.map((c, i) => (
|
||||
<Fragment key={c.controlNetId}>
|
||||
{i > 0 && <Divider />}
|
||||
<ControlNet controlNet={c} />
|
||||
<ControlNet controlNetId={c.controlNetId} />
|
||||
</Fragment>
|
||||
))}
|
||||
<IAIButton flexGrow={1} onClick={handleClickedAddControlNet}>
|
||||
Add ControlNet
|
||||
</IAIButton>
|
||||
</Flex>
|
||||
</IAICollapse>
|
||||
);
|
||||
|
@ -37,6 +37,7 @@ const ParamVAEModelSelect = () => {
|
||||
return [];
|
||||
}
|
||||
|
||||
// add a "default" option, this means use the main model's included VAE
|
||||
const data: SelectItem[] = [
|
||||
{
|
||||
value: 'default',
|
||||
|
@ -17,8 +17,10 @@ import { FaPlay } from 'react-icons/fa';
|
||||
const IN_PROGRESS_STYLES: ChakraProps['sx'] = {
|
||||
_disabled: {
|
||||
bg: 'none',
|
||||
color: 'base.600',
|
||||
cursor: 'not-allowed',
|
||||
_hover: {
|
||||
color: 'base.600',
|
||||
bg: 'none',
|
||||
},
|
||||
},
|
||||
|
@ -180,6 +180,23 @@ export type LoRAModelParam = z.infer<typeof zLoRAModel>;
|
||||
*/
|
||||
export const isValidLoRAModel = (val: unknown): val is LoRAModelParam =>
|
||||
zLoRAModel.safeParse(val).success;
|
||||
/**
|
||||
* Zod schema for ControlNet models
|
||||
*/
|
||||
export const zControlNetModel = z.object({
|
||||
model_name: z.string().min(1),
|
||||
base_model: zBaseModel,
|
||||
});
|
||||
/**
|
||||
* Type alias for model parameter, inferred from its zod schema
|
||||
*/
|
||||
export type ControlNetModelParam = z.infer<typeof zLoRAModel>;
|
||||
/**
|
||||
* Validates/type-guards a value as a model parameter
|
||||
*/
|
||||
export const isValidControlNetModel = (
|
||||
val: unknown
|
||||
): val is ControlNetModelParam => zControlNetModel.safeParse(val).success;
|
||||
|
||||
/**
|
||||
* Zod schema for l2l strength parameter
|
||||
|
@ -0,0 +1,30 @@
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { zControlNetModel } from 'features/parameters/types/parameterSchemas';
|
||||
import { ControlNetModelField } from 'services/api/types';
|
||||
|
||||
const moduleLog = log.child({ module: 'models' });
|
||||
|
||||
export const modelIdToControlNetModelParam = (
|
||||
controlNetModelId: string
|
||||
): ControlNetModelField | undefined => {
|
||||
const [base_model, model_type, model_name] = controlNetModelId.split('/');
|
||||
|
||||
const result = zControlNetModel.safeParse({
|
||||
base_model,
|
||||
model_name,
|
||||
});
|
||||
|
||||
if (!result.success) {
|
||||
moduleLog.error(
|
||||
{
|
||||
controlNetModelId,
|
||||
errors: result.error.format(),
|
||||
},
|
||||
'Failed to parse ControlNet model id'
|
||||
);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
return result.data;
|
||||
};
|
@ -1,9 +1,12 @@
|
||||
import { LoRAModelParam, zLoRAModel } from '../types/parameterSchemas';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
|
||||
const moduleLog = log.child({ module: 'models' });
|
||||
|
||||
export const modelIdToLoRAModelParam = (
|
||||
loraId: string
|
||||
loraModelId: string
|
||||
): LoRAModelParam | undefined => {
|
||||
const [base_model, model_type, model_name] = loraId.split('/');
|
||||
const [base_model, model_type, model_name] = loraModelId.split('/');
|
||||
|
||||
const result = zLoRAModel.safeParse({
|
||||
base_model,
|
||||
@ -11,6 +14,13 @@ export const modelIdToLoRAModelParam = (
|
||||
});
|
||||
|
||||
if (!result.success) {
|
||||
moduleLog.error(
|
||||
{
|
||||
loraModelId,
|
||||
errors: result.error.format(),
|
||||
},
|
||||
'Failed to parse LoRA model id'
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -2,11 +2,14 @@ import {
|
||||
MainModelParam,
|
||||
zMainModel,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
|
||||
const moduleLog = log.child({ module: 'models' });
|
||||
|
||||
export const modelIdToMainModelParam = (
|
||||
modelId: string
|
||||
mainModelId: string
|
||||
): MainModelParam | undefined => {
|
||||
const [base_model, model_type, model_name] = modelId.split('/');
|
||||
const [base_model, model_type, model_name] = mainModelId.split('/');
|
||||
|
||||
const result = zMainModel.safeParse({
|
||||
base_model,
|
||||
@ -14,6 +17,13 @@ export const modelIdToMainModelParam = (
|
||||
});
|
||||
|
||||
if (!result.success) {
|
||||
moduleLog.error(
|
||||
{
|
||||
mainModelId,
|
||||
errors: result.error.format(),
|
||||
},
|
||||
'Failed to parse main model id'
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -1,9 +1,12 @@
|
||||
import { VaeModelParam, zVaeModel } from '../types/parameterSchemas';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
|
||||
const moduleLog = log.child({ module: 'models' });
|
||||
|
||||
export const modelIdToVAEModelParam = (
|
||||
modelId: string
|
||||
vaeModelId: string
|
||||
): VaeModelParam | undefined => {
|
||||
const [base_model, model_type, model_name] = modelId.split('/');
|
||||
const [base_model, model_type, model_name] = vaeModelId.split('/');
|
||||
|
||||
const result = zVaeModel.safeParse({
|
||||
base_model,
|
||||
@ -11,6 +14,13 @@ export const modelIdToVAEModelParam = (
|
||||
});
|
||||
|
||||
if (!result.success) {
|
||||
moduleLog.error(
|
||||
{
|
||||
vaeModelId,
|
||||
errors: result.error.format(),
|
||||
},
|
||||
'Failed to parse VAE model id'
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -19,9 +19,9 @@ const ImageToImageTabParameters = () => {
|
||||
<ParamNegativeConditioning />
|
||||
<ProcessButtons />
|
||||
<ImageToImageTabCoreParameters />
|
||||
<ParamControlNetCollapse />
|
||||
<ParamLoraCollapse />
|
||||
<ParamDynamicPromptsCollapse />
|
||||
<ParamControlNetCollapse />
|
||||
<ParamVariationCollapse />
|
||||
<ParamNoiseCollapse />
|
||||
<ParamSymmetryCollapse />
|
||||
|
@ -20,9 +20,9 @@ const TextToImageTabParameters = () => {
|
||||
<ParamNegativeConditioning />
|
||||
<ProcessButtons />
|
||||
<TextToImageTabCoreParameters />
|
||||
<ParamControlNetCollapse />
|
||||
<ParamLoraCollapse />
|
||||
<ParamDynamicPromptsCollapse />
|
||||
<ParamControlNetCollapse />
|
||||
<ParamVariationCollapse />
|
||||
<ParamNoiseCollapse />
|
||||
<ParamSymmetryCollapse />
|
||||
|
@ -19,9 +19,9 @@ const UnifiedCanvasParameters = () => {
|
||||
<ParamNegativeConditioning />
|
||||
<ProcessButtons />
|
||||
<UnifiedCanvasCoreParameters />
|
||||
<ParamControlNetCollapse />
|
||||
<ParamLoraCollapse />
|
||||
<ParamDynamicPromptsCollapse />
|
||||
<ParamControlNetCollapse />
|
||||
<ParamVariationCollapse />
|
||||
<ParamSymmetryCollapse />
|
||||
<ParamSeamCorrectionCollapse />
|
||||
|
@ -734,7 +734,7 @@ export type components = {
|
||||
* Control Model
|
||||
* @description The ControlNet model to use
|
||||
*/
|
||||
control_model: string;
|
||||
control_model: components["schemas"]["ControlNetModelField"];
|
||||
/**
|
||||
* Control Weight
|
||||
* @description The weight given to the ControlNet
|
||||
@ -792,9 +792,8 @@ export type components = {
|
||||
* Control Model
|
||||
* @description control model used
|
||||
* @default lllyasviel/sd-controlnet-canny
|
||||
* @enum {string}
|
||||
*/
|
||||
control_model?: "lllyasviel/sd-controlnet-canny" | "lllyasviel/sd-controlnet-depth" | "lllyasviel/sd-controlnet-hed" | "lllyasviel/sd-controlnet-seg" | "lllyasviel/sd-controlnet-openpose" | "lllyasviel/sd-controlnet-scribble" | "lllyasviel/sd-controlnet-normal" | "lllyasviel/sd-controlnet-mlsd" | "lllyasviel/control_v11p_sd15_canny" | "lllyasviel/control_v11p_sd15_openpose" | "lllyasviel/control_v11p_sd15_seg" | "lllyasviel/control_v11f1p_sd15_depth" | "lllyasviel/control_v11p_sd15_normalbae" | "lllyasviel/control_v11p_sd15_scribble" | "lllyasviel/control_v11p_sd15_mlsd" | "lllyasviel/control_v11p_sd15_softedge" | "lllyasviel/control_v11p_sd15s2_lineart_anime" | "lllyasviel/control_v11p_sd15_lineart" | "lllyasviel/control_v11p_sd15_inpaint" | "lllyasviel/control_v11e_sd15_shuffle" | "lllyasviel/control_v11e_sd15_ip2p" | "lllyasviel/control_v11f1e_sd15_tile" | "thibaud/controlnet-sd21-openpose-diffusers" | "thibaud/controlnet-sd21-canny-diffusers" | "thibaud/controlnet-sd21-depth-diffusers" | "thibaud/controlnet-sd21-scribble-diffusers" | "thibaud/controlnet-sd21-hed-diffusers" | "thibaud/controlnet-sd21-zoedepth-diffusers" | "thibaud/controlnet-sd21-color-diffusers" | "thibaud/controlnet-sd21-openposev2-diffusers" | "thibaud/controlnet-sd21-lineart-diffusers" | "thibaud/controlnet-sd21-normalbae-diffusers" | "thibaud/controlnet-sd21-ade20k-diffusers" | "CrucibleAI/ControlNetMediaPipeFace,diffusion_sd15" | "CrucibleAI/ControlNetMediaPipeFace";
|
||||
control_model?: components["schemas"]["ControlNetModelField"];
|
||||
/**
|
||||
* Control Weight
|
||||
* @description The weight given to the ControlNet
|
||||
@ -838,6 +837,19 @@ export type components = {
|
||||
model_format: components["schemas"]["ControlNetModelFormat"];
|
||||
error?: components["schemas"]["ModelError"];
|
||||
};
|
||||
/**
|
||||
* ControlNetModelField
|
||||
* @description ControlNet model field
|
||||
*/
|
||||
ControlNetModelField: {
|
||||
/**
|
||||
* Model Name
|
||||
* @description Name of the ControlNet model
|
||||
*/
|
||||
model_name: string;
|
||||
/** @description Base model */
|
||||
base_model: components["schemas"]["BaseModelType"];
|
||||
};
|
||||
/**
|
||||
* ControlNetModelFormat
|
||||
* @description An enumeration.
|
||||
@ -1923,12 +1935,12 @@ export type components = {
|
||||
* Width
|
||||
* @description The width to resize to (px)
|
||||
*/
|
||||
width: number;
|
||||
width?: number;
|
||||
/**
|
||||
* Height
|
||||
* @description The height to resize to (px)
|
||||
*/
|
||||
height: number;
|
||||
height?: number;
|
||||
/**
|
||||
* Resample Mode
|
||||
* @description The resampling mode
|
||||
@ -3911,13 +3923,15 @@ export type components = {
|
||||
/**
|
||||
* Width
|
||||
* @description The width to resize to (px)
|
||||
* @default 512
|
||||
*/
|
||||
width: number;
|
||||
width?: number;
|
||||
/**
|
||||
* Height
|
||||
* @description The height to resize to (px)
|
||||
* @default 512
|
||||
*/
|
||||
height: number;
|
||||
height?: number;
|
||||
/**
|
||||
* Mode
|
||||
* @description The interpolation mode
|
||||
@ -4605,18 +4619,18 @@ export type components = {
|
||||
*/
|
||||
image?: components["schemas"]["ImageField"];
|
||||
};
|
||||
/**
|
||||
* StableDiffusion2ModelFormat
|
||||
* @description An enumeration.
|
||||
* @enum {string}
|
||||
*/
|
||||
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
|
||||
/**
|
||||
* StableDiffusion1ModelFormat
|
||||
* @description An enumeration.
|
||||
* @enum {string}
|
||||
*/
|
||||
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
|
||||
/**
|
||||
* StableDiffusion2ModelFormat
|
||||
* @description An enumeration.
|
||||
* @enum {string}
|
||||
*/
|
||||
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
|
||||
};
|
||||
responses: never;
|
||||
parameters: never;
|
||||
|
@ -32,6 +32,8 @@ export type BaseModelType = components['schemas']['BaseModelType'];
|
||||
export type MainModelField = components['schemas']['MainModelField'];
|
||||
export type VAEModelField = components['schemas']['VAEModelField'];
|
||||
export type LoRAModelField = components['schemas']['LoRAModelField'];
|
||||
export type ControlNetModelField =
|
||||
components['schemas']['ControlNetModelField'];
|
||||
export type ModelsList = components['schemas']['ModelsList'];
|
||||
export type ControlField = components['schemas']['ControlField'];
|
||||
|
||||
|
@ -30,7 +30,7 @@ const invokeAIThumb = defineStyle((props) => {
|
||||
|
||||
const invokeAIMark = defineStyle((props) => {
|
||||
return {
|
||||
fontSize: 'xs',
|
||||
fontSize: '2xs',
|
||||
fontWeight: '500',
|
||||
color: mode('base.700', 'base.400')(props),
|
||||
mt: 2,
|
||||
|
Loading…
Reference in New Issue
Block a user