Merge branch 'main' into update-textual-inversion-training

This commit is contained in:
Lincoln Stein 2023-07-15 17:44:45 -04:00 committed by GitHub
commit f66ead0819
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
88 changed files with 2840 additions and 1075 deletions

4
.github/CODEOWNERS vendored
View File

@ -6,7 +6,7 @@
/mkdocs.yml @lstein @blessedcoolant /mkdocs.yml @lstein @blessedcoolant
# nodes # nodes
/invokeai/app/ @Kyle0654 @blessedcoolant /invokeai/app/ @Kyle0654 @blessedcoolant @psychedelicious @brandonrising
# installation and configuration # installation and configuration
/pyproject.toml @lstein @blessedcoolant /pyproject.toml @lstein @blessedcoolant
@ -22,7 +22,7 @@
/invokeai/backend @blessedcoolant @psychedelicious @lstein @maryhipp /invokeai/backend @blessedcoolant @psychedelicious @lstein @maryhipp
# generation, model management, postprocessing # generation, model management, postprocessing
/invokeai/backend @damian0815 @lstein @blessedcoolant @jpphoto @gregghelt2 @StAlKeR7779 /invokeai/backend @damian0815 @lstein @blessedcoolant @gregghelt2 @StAlKeR7779 @brandonrising
# front ends # front ends
/invokeai/frontend/CLI @lstein /invokeai/frontend/CLI @lstein

View File

@ -0,0 +1,287 @@
---
title: Configuration
---
# :material-tune-variant: InvokeAI Configuration
## Intro
InvokeAI has numerous runtime settings which can be used to adjust
many aspects of its operations, including the location of files and
directories, memory usage, and performance. These settings can be
viewed and customized in several ways:
1. By editing settings in the `invokeai.yaml` file.
2. By setting environment variables.
3. On the command-line, when InvokeAI is launched.
In addition, the most commonly changed settings are accessible
graphically via the `invokeai-configure` script.
### How the Configuration System Works
When InvokeAI is launched, the very first thing it needs to do is to
find its "root" directory, which contains its configuration files,
installed models, its database of images, and the folder(s) of
generated images themselves. In this document, the root directory will
be referred to as ROOT.
#### Finding the Root Directory
To find its root directory, InvokeAI uses the following recipe:
1. It first looks for the argument `--root <path>` on the command line
it was launched from, and uses the indicated path if present.
2. Next it looks for the environment variable INVOKEAI_ROOT, and uses
the directory path found there if present.
3. If neither of these are present, then InvokeAI looks for the
folder containing the `.venv` Python virtual environment directory for
the currently active environment. This directory is checked for files
expected inside the InvokeAI root before it is used.
4. Finally, InvokeAI looks for a directory in the current user's home
directory named `invokeai`.
#### Reading the InvokeAI Configuration File
Once the root directory has been located, InvokeAI looks for a file
named `ROOT/invokeai.yaml`, and if present reads configuration values
from it. The top of this file looks like this:
```
InvokeAI:
Web Server:
host: localhost
port: 9090
allow_origins: []
allow_credentials: true
allow_methods:
- '*'
allow_headers:
- '*'
Features:
esrgan: true
internet_available: true
log_tokenization: false
nsfw_checker: false
patchmatch: true
restore: true
...
```
This lines in this file are used to establish default values for
Invoke's settings. In the above fragment, the Web Server's listening
port is set to 9090 by the `port` setting.
You can edit this file with a text editor such as "Notepad" (do not
use Word or any other word processor). When editing, be careful to
maintain the indentation, and do not add extraneous text, as syntax
errors will prevent InvokeAI from launching. A basic guide to the
format of YAML files can be found
[here](https://circleci.com/blog/what-is-yaml-a-beginner-s-guide/).
You can fix a broken `invokeai.yaml` by deleting it and running the
configuration script again -- option [7] in the launcher, "Re-run the
configure script".
#### Reading Environment Variables
Next InvokeAI looks for defined environment variables in the format
`INVOKEAI_<setting_name>`, for example `INVOKEAI_port`. Environment
variable values take precedence over configuration file variables. On
a Macintosh system, for example, you could change the port that the
web server listens on by setting the environment variable this way:
```
export INVOKEAI_port=8000
invokeai-web
```
Please check out these
[Macintosh](https://phoenixnap.com/kb/set-environment-variable-mac)
and
[Windows](https://phoenixnap.com/kb/windows-set-environment-variable)
guides for setting temporary and permanent environment variables.
#### Reading the Command Line
Lastly, InvokeAI takes settings from the command line, which override
everything else. The command-line settings have the same name as the
corresponding configuration file settings, preceded by a `--`, for
example `--port 8000`.
If you are using the launcher (`invoke.sh` or `invoke.bat`) to launch
InvokeAI, then just pass the command-line arguments to the launcher:
```
invoke.bat --port 8000 --host 0.0.0.0
```
The arguments will be applied when you select the web server option
(and the other options as well).
If, on the other hand, you prefer to launch InvokeAI directly from the
command line, you would first activate the virtual environment (known
as the "developer's console" in the launcher), and run `invokeai-web`:
```
> C:\Users\Fred\invokeai\.venv\scripts\activate
(.venv) > invokeai-web --port 8000 --host 0.0.0.0
```
You can get a listing and brief instructions for each of the
command-line options by giving the `--help` argument:
```
(.venv) > invokeai-web --help
usage: InvokeAI [-h] [--host HOST] [--port PORT] [--allow_origins [ALLOW_ORIGINS ...]] [--allow_credentials | --no-allow_credentials]
[--allow_methods [ALLOW_METHODS ...]] [--allow_headers [ALLOW_HEADERS ...]] [--esrgan | --no-esrgan]
[--internet_available | --no-internet_available] [--log_tokenization | --no-log_tokenization]
[--nsfw_checker | --no-nsfw_checker] [--patchmatch | --no-patchmatch] [--restore | --no-restore]
[--always_use_cpu | --no-always_use_cpu] [--free_gpu_mem | --no-free_gpu_mem] [--max_cache_size MAX_CACHE_SIZE]
[--max_vram_cache_size MAX_VRAM_CACHE_SIZE] [--precision {auto,float16,float32,autocast}]
[--sequential_guidance | --no-sequential_guidance] [--xformers_enabled | --no-xformers_enabled]
[--tiled_decode | --no-tiled_decode] [--root ROOT] [--autoimport_dir AUTOIMPORT_DIR] [--lora_dir LORA_DIR]
[--embedding_dir EMBEDDING_DIR] [--controlnet_dir CONTROLNET_DIR] [--conf_path CONF_PATH] [--models_dir MODELS_DIR]
[--legacy_conf_dir LEGACY_CONF_DIR] [--db_dir DB_DIR] [--outdir OUTDIR] [--from_file FROM_FILE]
[--use_memory_db | --no-use_memory_db] [--model MODEL] [--log_handlers [LOG_HANDLERS ...]]
[--log_format {plain,color,syslog,legacy}] [--log_level {debug,info,warning,error,critical}]
...
```
## The Configuration Settings
The configuration settings are divided into several distinct
groups in `invokeia.yaml`:
### Web Server
| Setting | Default Value | Description |
|----------|----------------|--------------|
| `host` | `localhost` | Name or IP address of the network interface that the web server will listen on |
| `port` | `9090` | Network port number that the web server will listen on |
| `allow_origins` | `[]` | A list of host names or IP addresses that are allowed to connect to the InvokeAI API in the format `['host1','host2',...]` |
| `allow_credentials | `true` | Require credentials for a foreign host to access the InvokeAI API (don't change this) |
| `allow_methods` | `*` | List of HTTP methods ("GET", "POST") that the web server is allowed to use when accessing the API |
| `allow_headers` | `*` | List of HTTP headers that the web server will accept when accessing the API |
The documentation for InvokeAI's API can be accessed by browsing to the following URL: [http://localhost:9090/docs].
### Features
These configuration settings allow you to enable and disable various InvokeAI features:
| Setting | Default Value | Description |
|----------|----------------|--------------|
| `esrgan` | `true` | Activate the ESRGAN upscaling options|
| `internet_available` | `true` | When a resource is not available locally, try to fetch it via the internet |
| `log_tokenization` | `false` | Before each text2image generation, print a color-coded representation of the prompt to the console; this can help understand why a prompt is not working as expected |
| `nsfw_checker` | `true` | Activate the NSFW checker to blur out risque images |
| `patchmatch` | `true` | Activate the "patchmatch" algorithm for improved inpainting |
| `restore` | `true` | Activate the facial restoration features (DEPRECATED; restoration features will be removed in 3.0.0) |
### Memory/Performance
These options tune InvokeAI's memory and performance characteristics.
| Setting | Default Value | Description |
|----------|----------------|--------------|
| `always_use_cpu` | `false` | Use the CPU to generate images, even if a GPU is available |
| `free_gpu_mem` | `false` | Aggressively free up GPU memory after each operation; this will allow you to run in low-VRAM environments with some performance penalties |
| `max_cache_size` | `6` | Amount of CPU RAM (in GB) to reserve for caching models in memory; more cache allows you to keep models in memory and switch among them quickly |
| `max_vram_cache_size` | `2.75` | Amount of GPU VRAM (in GB) to reserve for caching models in VRAM; more cache speeds up generation but reduces the size of the images that can be generated. This can be set to zero to maximize the amount of memory available for generation. |
| `precision` | `auto` | Floating point precision. One of `auto`, `float16` or `float32`. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system |
| `sequential_guidance` | `false` | Calculate guidance in serial rather than in parallel, lowering memory requirements at the cost of some performance loss |
| `xformers_enabled` | `true` | If the x-formers memory-efficient attention module is installed, activate it for better memory usage and generation speed|
| `tiled_decode` | `false` | If true, then during the VAE decoding phase the image will be decoded a section at a time, reducing memory consumption at the cost of a performance hit |
### Paths
These options set the paths of various directories and files used by
InvokeAI. Relative paths are interpreted relative to INVOKEAI_ROOT, so
if INVOKEAI_ROOT is `/home/fred/invokeai` and the path is
`autoimport/main`, then the corresponding directory will be located at
`/home/fred/invokeai/autoimport/main`.
| Setting | Default Value | Description |
|----------|----------------|--------------|
| `autoimport_dir` | `autoimport/main` | At startup time, read and import any main model files found in this directory |
| `lora_dir` | `autoimport/lora` | At startup time, read and import any LoRA/LyCORIS models found in this directory |
| `embedding_dir` | `autoimport/embedding` | At startup time, read and import any textual inversion (embedding) models found in this directory |
| `controlnet_dir` | `autoimport/controlnet` | At startup time, read and import any ControlNet models found in this directory |
| `conf_path` | `configs/models.yaml` | Location of the `models.yaml` model configuration file |
| `models_dir` | `models` | Location of the directory containing models installed by InvokeAI's model manager |
| `legacy_conf_dir` | `configs/stable-diffusion` | Location of the directory containing the .yaml configuration files for legacy checkpoint models |
| `db_dir` | `databases` | Location of the directory containing InvokeAI's image, schema and session database |
| `outdir` | `outputs` | Location of the directory in which the gallery of generated and uploaded images will be stored |
| `use_memory_db` | `false` | Keep database information in memory rather than on disk; this will not preserve image gallery information across restarts |
Note that the autoimport directories will be searched recursively,
allowing you to organize the models into folders and subfolders in any
way you wish. In addition, while we have split up autoimport
directories by the type of model they contain, this isn't
necessary. You can combine different model types in the same folder
and InvokeAI will figure out what they are. So you can easily use just
one autoimport directory by commenting out the unneeded paths:
```
Paths:
autoimport_dir: autoimport
# lora_dir: null
# embedding_dir: null
# controlnet_dir: null
```
### Logging
These settings control the information, warning, and debugging
messages printed to the console log while InvokeAI is running:
| Setting | Default Value | Description |
|----------|----------------|--------------|
| `log_handlers` | `console` | This controls where log messages are sent, and can be a list of one or more destinations. Values include `console`, `file`, `syslog` and `http`. These are described in more detail below |
| `log_format` | `color` | This controls the formatting of the log messages. Values are `plain`, `color`, `legacy` and `syslog` |
| `log_level` | `debug` | This filters messages according to the level of severity and can be one of `debug`, `info`, `warning`, `error` and `critical`. For example, setting to `warning` will display all messages at the warning level or higher, but won't display "debug" or "info" messages |
Several different log handler destinations are available, and multiple destinations are supported by providing a list:
```
log_handlers:
- console
- syslog=localhost
- file=/var/log/invokeai.log
```
* `console` is the default. It prints log messages to the command-line window from which InvokeAI was launched.
* `syslog` is only available on Linux and Macintosh systems. It uses
the operating system's "syslog" facility to write log file entries
locally or to a remote logging machine. `syslog` offers a variety
of configuration options:
```
syslog=/dev/log` - log to the /dev/log device
syslog=localhost` - log to the network logger running on the local machine
syslog=localhost:512` - same as above, but using a non-standard port
syslog=fredserver,facility=LOG_USER,socktype=SOCK_DRAM`
- Log to LAN-connected server "fredserver" using the facility LOG_USER and datagram packets.
```
* `http` can be used to log to a remote web server. The server must be
properly configured to receive and act on log messages. The option
accepts the URL to the web server, and a `method` argument
indicating whether the message should be submitted using the GET or
POST method.
```
http=http://my.server/path/to/logger,method=POST
```
The `log_format` option provides several alternative formats:
* `color` - default format providing time, date and a message, using text colors to distinguish different log severities
* `plain` - same as above, but monochrome text only
* `syslog` - the log level and error message only, allowing the syslog system to attach the time and date
* `legacy` - a format similar to the one used by the legacy 2.3 InvokeAI releases.

View File

@ -153,6 +153,9 @@ This method is recommended for those familiar with running Docker containers
- [Prompt Syntax](features/PROMPTS.md) - [Prompt Syntax](features/PROMPTS.md)
- [Generating Variations](features/VARIATIONS.md) - [Generating Variations](features/VARIATIONS.md)
### InvokeAI Configuration
- [Guide to InvokeAI Runtime Settings](features/CONFIGURATION.md)
## :octicons-log-16: Important Changes Since Version 2.3 ## :octicons-log-16: Important Changes Since Version 2.3
### Nodes ### Nodes

View File

@ -1,6 +1,7 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2024 Lincoln Stein # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2023 Lincoln D. Stein
import pathlib
from typing import Literal, List, Optional, Union from typing import Literal, List, Optional, Union
from fastapi import Body, Path, Query, Response from fastapi import Body, Path, Query, Response
@ -22,6 +23,7 @@ UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]
class ModelsList(BaseModel): class ModelsList(BaseModel):
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]] models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
@ -78,7 +80,7 @@ async def update_model(
return model_response return model_response
@models_router.post( @models_router.post(
"/", "/import",
operation_id="import_model", operation_id="import_model",
responses= { responses= {
201: {"description" : "The model imported successfully"}, 201: {"description" : "The model imported successfully"},
@ -94,7 +96,7 @@ async def import_model(
prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = \ prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = \
Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"), Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
) -> ImportModelResponse: ) -> ImportModelResponse:
""" Add a model using its local path, repo_id, or remote URL """ """ Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically """
items_to_import = {location} items_to_import = {location}
prediction_types = { x.value: x for x in SchedulerPredictionType } prediction_types = { x.value: x for x in SchedulerPredictionType }
@ -126,18 +128,100 @@ async def import_model(
logger.error(str(e)) logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e)) raise HTTPException(status_code=409, detail=str(e))
@models_router.post(
"/add",
operation_id="add_model",
responses= {
201: {"description" : "The model added successfully"},
404: {"description" : "The model could not be found"},
424: {"description" : "The model appeared to add successfully, but could not be found in the model manager"},
409: {"description" : "There is already a model corresponding to this path or repo_id"},
},
status_code=201,
response_model=ImportModelResponse
)
async def add_model(
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
) -> ImportModelResponse:
""" Add a model using the configuration information appropriate for its type. Only local models can be added by path"""
logger = ApiDependencies.invoker.services.logger
try:
ApiDependencies.invoker.services.model_manager.add_model(
info.model_name,
info.base_model,
info.model_type,
model_attributes = info.dict()
)
logger.info(f'Successfully added {info.model_name}')
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
model_name=info.model_name,
base_model=info.base_model,
model_type=info.model_type
)
return parse_obj_as(ImportModelResponse, model_raw)
except KeyError as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
@models_router.post(
"/rename/{base_model}/{model_type}/{model_name}",
operation_id="rename_model",
responses= {
201: {"description" : "The model was renamed successfully"},
404: {"description" : "The model could not be found"},
409: {"description" : "There is already a model corresponding to the new name"},
},
status_code=201,
response_model=ImportModelResponse
)
async def rename_model(
base_model: BaseModelType = Path(description="Base model"),
model_type: ModelType = Path(description="The type of model"),
model_name: str = Path(description="current model name"),
new_name: Optional[str] = Query(description="new model name", default=None),
new_base: Optional[BaseModelType] = Query(description="new model base", default=None),
) -> ImportModelResponse:
""" Rename a model"""
logger = ApiDependencies.invoker.services.logger
try:
result = ApiDependencies.invoker.services.model_manager.rename_model(
base_model = base_model,
model_type = model_type,
model_name = model_name,
new_name = new_name,
new_base = new_base,
)
logger.debug(result)
logger.info(f'Successfully renamed {model_name}=>{new_name}')
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
model_name=new_name or model_name,
base_model=new_base or base_model,
model_type=model_type
)
return parse_obj_as(ImportModelResponse, model_raw)
except KeyError as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
@models_router.delete( @models_router.delete(
"/{base_model}/{model_type}/{model_name}", "/{base_model}/{model_type}/{model_name}",
operation_id="del_model", operation_id="del_model",
responses={ responses={
204: { 204: { "description": "Model deleted successfully" },
"description": "Model deleted successfully" 404: { "description": "Model not found" }
},
404: {
"description": "Model not found"
}
}, },
status_code = 204,
response_model = None,
) )
async def delete_model( async def delete_model(
base_model: BaseModelType = Path(description="Base model"), base_model: BaseModelType = Path(description="Base model"),
@ -173,14 +257,17 @@ async def convert_model(
base_model: BaseModelType = Path(description="Base model"), base_model: BaseModelType = Path(description="Base model"),
model_type: ModelType = Path(description="The type of model"), model_type: ModelType = Path(description="The type of model"),
model_name: str = Path(description="model name"), model_name: str = Path(description="model name"),
convert_dest_directory: Optional[str] = Query(default=None, description="Save the converted model to the designated directory"),
) -> ConvertModelResponse: ) -> ConvertModelResponse:
"""Convert a checkpoint model into a diffusers model""" """Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none."""
logger = ApiDependencies.invoker.services.logger logger = ApiDependencies.invoker.services.logger
try: try:
logger.info(f"Converting model: {model_name}") logger.info(f"Converting model: {model_name}")
dest = pathlib.Path(convert_dest_directory) if convert_dest_directory else None
ApiDependencies.invoker.services.model_manager.convert_model(model_name, ApiDependencies.invoker.services.model_manager.convert_model(model_name,
base_model = base_model, base_model = base_model,
model_type = model_type model_type = model_type,
convert_dest_directory = dest,
) )
model_raw = ApiDependencies.invoker.services.model_manager.list_model(model_name, model_raw = ApiDependencies.invoker.services.model_manager.list_model(model_name,
base_model = base_model, base_model = base_model,
@ -192,6 +279,53 @@ async def convert_model(
raise HTTPException(status_code=400, detail=str(e)) raise HTTPException(status_code=400, detail=str(e))
return response return response
@models_router.get(
"/search",
operation_id="search_for_models",
responses={
200: { "description": "Directory searched successfully" },
404: { "description": "Invalid directory path" },
},
status_code = 200,
response_model = List[pathlib.Path]
)
async def search_for_models(
search_path: pathlib.Path = Query(description="Directory path to search for models")
)->List[pathlib.Path]:
if not search_path.is_dir():
raise HTTPException(status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory")
return ApiDependencies.invoker.services.model_manager.search_for_models([search_path])
@models_router.get(
"/ckpt_confs",
operation_id="list_ckpt_configs",
responses={
200: { "description" : "paths retrieved successfully" },
},
status_code = 200,
response_model = List[pathlib.Path]
)
async def list_ckpt_configs(
)->List[pathlib.Path]:
"""Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT."""
return ApiDependencies.invoker.services.model_manager.list_checkpoint_configs()
@models_router.get(
"/sync",
operation_id="sync_to_config",
responses={
201: { "description": "synchronization successful" },
},
status_code = 201,
response_model = None
)
async def sync_to_config(
)->None:
"""Call after making changes to models.yaml, autoimport directories or models directory to synchronize
in-memory data structures with disk data structures."""
return ApiDependencies.invoker.services.model_manager.sync_to_config()
@models_router.put( @models_router.put(
"/merge/{base_model}", "/merge/{base_model}",
operation_id="merge_models", operation_id="merge_models",
@ -210,17 +344,21 @@ async def merge_models(
alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5), alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"), interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"),
force: Optional[bool] = Body(description="Force merging of models created with different versions of diffusers", default=False), force: Optional[bool] = Body(description="Force merging of models created with different versions of diffusers", default=False),
merge_dest_directory: Optional[str] = Body(description="Save the merged model to the designated directory (with 'merged_model_name' appended)", default=None)
) -> MergeModelResponse: ) -> MergeModelResponse:
"""Convert a checkpoint model into a diffusers model""" """Convert a checkpoint model into a diffusers model"""
logger = ApiDependencies.invoker.services.logger logger = ApiDependencies.invoker.services.logger
try: try:
logger.info(f"Merging models: {model_names}") logger.info(f"Merging models: {model_names} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
result = ApiDependencies.invoker.services.model_manager.merge_models(model_names, result = ApiDependencies.invoker.services.model_manager.merge_models(model_names,
base_model, base_model,
merged_model_name or "+".join(model_names), merged_model_name=merged_model_name or "+".join(model_names),
alpha, alpha=alpha,
interp, interp=interp,
force) force=force,
merge_dest_directory = dest
)
model_raw = ApiDependencies.invoker.services.model_manager.list_model(result.name, model_raw = ApiDependencies.invoker.services.model_manager.list_model(result.name,
base_model = base_model, base_model = base_model,
model_type = ModelType.Main, model_type = ModelType.Main,

View File

@ -9,6 +9,7 @@ from typing import Literal, Optional, Union, List, Dict
from PIL import Image from PIL import Image
from pydantic import BaseModel, Field, validator from pydantic import BaseModel, Field, validator
from ...backend.model_management import BaseModelType, ModelType
from ..models.image import ImageField, ImageCategory, ResourceOrigin from ..models.image import ImageField, ImageCategory, ResourceOrigin
from .baseinvocation import ( from .baseinvocation import (
BaseInvocation, BaseInvocation,
@ -105,9 +106,15 @@ CONTROLNET_MODE_VALUES = Literal[tuple(["balanced", "more_prompt", "more_control
# CONTROLNET_RESIZE_VALUES = Literal[tuple(["just_resize", "crop_resize", "fill_resize"])] # CONTROLNET_RESIZE_VALUES = Literal[tuple(["just_resize", "crop_resize", "fill_resize"])]
class ControlNetModelField(BaseModel):
"""ControlNet model field"""
model_name: str = Field(description="Name of the ControlNet model")
base_model: BaseModelType = Field(description="Base model")
class ControlField(BaseModel): class ControlField(BaseModel):
image: ImageField = Field(default=None, description="The control image") image: ImageField = Field(default=None, description="The control image")
control_model: Optional[str] = Field(default=None, description="The ControlNet model to use") control_model: Optional[ControlNetModelField] = Field(default=None, description="The ControlNet model to use")
# control_weight: Optional[float] = Field(default=1, description="weight given to controlnet") # control_weight: Optional[float] = Field(default=1, description="weight given to controlnet")
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet") control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
begin_step_percent: float = Field(default=0, ge=0, le=1, begin_step_percent: float = Field(default=0, ge=0, le=1,
@ -118,15 +125,15 @@ class ControlField(BaseModel):
# resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use") # resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
@validator("control_weight") @validator("control_weight")
def abs_le_one(cls, v): def validate_control_weight(cls, v):
"""validate that all abs(values) are <=1""" """Validate that all control weights in the valid range"""
if isinstance(v, list): if isinstance(v, list):
for i in v: for i in v:
if abs(i) > 1: if i < -1 or i > 2:
raise ValueError('all abs(control_weight) must be <= 1') raise ValueError('Control weights must be within -1 to 2 range')
else: else:
if abs(v) > 1: if v < -1 or v > 2:
raise ValueError('abs(control_weight) must be <= 1') raise ValueError('Control weights must be within -1 to 2 range')
return v return v
class Config: class Config:
schema_extra = { schema_extra = {
@ -134,6 +141,7 @@ class ControlField(BaseModel):
"ui": { "ui": {
"type_hints": { "type_hints": {
"control_weight": "float", "control_weight": "float",
"control_model": "controlnet_model",
# "control_weight": "number", # "control_weight": "number",
} }
} }
@ -154,10 +162,10 @@ class ControlNetInvocation(BaseInvocation):
type: Literal["controlnet"] = "controlnet" type: Literal["controlnet"] = "controlnet"
# Inputs # Inputs
image: ImageField = Field(default=None, description="The control image") image: ImageField = Field(default=None, description="The control image")
control_model: CONTROLNET_NAME_VALUES = Field(default="lllyasviel/sd-controlnet-canny", control_model: ControlNetModelField = Field(default="lllyasviel/sd-controlnet-canny",
description="control model used") description="control model used")
control_weight: Union[float, List[float]] = Field(default=1.0, description="The weight given to the ControlNet") control_weight: Union[float, List[float]] = Field(default=1.0, description="The weight given to the ControlNet")
begin_step_percent: float = Field(default=0, ge=0, le=1, begin_step_percent: float = Field(default=0, ge=-1, le=2,
description="When the ControlNet is first applied (% of total steps)") description="When the ControlNet is first applied (% of total steps)")
end_step_percent: float = Field(default=1, ge=0, le=1, end_step_percent: float = Field(default=1, ge=0, le=1,
description="When the ControlNet is last applied (% of total steps)") description="When the ControlNet is last applied (% of total steps)")

View File

@ -1,5 +1,6 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from contextlib import ExitStack
from typing import List, Literal, Optional, Union from typing import List, Literal, Optional, Union
import einops import einops
@ -11,6 +12,7 @@ from pydantic import BaseModel, Field, validator
from invokeai.app.invocations.metadata import CoreMetadata from invokeai.app.invocations.metadata import CoreMetadata
from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_management.models.base import ModelType
from ...backend.model_management.lora import ModelPatcher from ...backend.model_management.lora import ModelPatcher
from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.stable_diffusion import PipelineIntermediateState
@ -71,16 +73,21 @@ def get_scheduler(
scheduler_name: str, scheduler_name: str,
) -> Scheduler: ) -> Scheduler:
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get( scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(
scheduler_name, SCHEDULER_MAP['ddim']) scheduler_name, SCHEDULER_MAP['ddim']
)
orig_scheduler_info = context.services.model_manager.get_model( orig_scheduler_info = context.services.model_manager.get_model(
**scheduler_info.dict()) **scheduler_info.dict()
)
with orig_scheduler_info as orig_scheduler: with orig_scheduler_info as orig_scheduler:
scheduler_config = orig_scheduler.config scheduler_config = orig_scheduler.config
if "_backup" in scheduler_config: if "_backup" in scheduler_config:
scheduler_config = scheduler_config["_backup"] scheduler_config = scheduler_config["_backup"]
scheduler_config = {**scheduler_config, ** scheduler_config = {
scheduler_extra_config, "_backup": scheduler_config} **scheduler_config,
**scheduler_extra_config,
"_backup": scheduler_config,
}
scheduler = scheduler_class.from_config(scheduler_config) scheduler = scheduler_class.from_config(scheduler_config)
# hack copied over from generate.py # hack copied over from generate.py
@ -137,8 +144,11 @@ class TextToLatentsInvocation(BaseInvocation):
# TODO: pass this an emitter method or something? or a session for dispatching? # TODO: pass this an emitter method or something? or a session for dispatching?
def dispatch_progress( def dispatch_progress(
self, context: InvocationContext, source_node_id: str, self,
intermediate_state: PipelineIntermediateState) -> None: context: InvocationContext,
source_node_id: str,
intermediate_state: PipelineIntermediateState,
) -> None:
stable_diffusion_step_callback( stable_diffusion_step_callback(
context=context, context=context,
intermediate_state=intermediate_state, intermediate_state=intermediate_state,
@ -147,11 +157,16 @@ class TextToLatentsInvocation(BaseInvocation):
) )
def get_conditioning_data( def get_conditioning_data(
self, context: InvocationContext, scheduler) -> ConditioningData: self,
context: InvocationContext,
scheduler,
) -> ConditioningData:
c, extra_conditioning_info = context.services.latents.get( c, extra_conditioning_info = context.services.latents.get(
self.positive_conditioning.conditioning_name) self.positive_conditioning.conditioning_name
)
uc, _ = context.services.latents.get( uc, _ = context.services.latents.get(
self.negative_conditioning.conditioning_name) self.negative_conditioning.conditioning_name
)
conditioning_data = ConditioningData( conditioning_data = ConditioningData(
unconditioned_embeddings=uc, unconditioned_embeddings=uc,
@ -178,7 +193,10 @@ class TextToLatentsInvocation(BaseInvocation):
return conditioning_data return conditioning_data
def create_pipeline( def create_pipeline(
self, unet, scheduler) -> StableDiffusionGeneratorPipeline: self,
unet,
scheduler,
) -> StableDiffusionGeneratorPipeline:
# TODO: # TODO:
# configure_model_padding( # configure_model_padding(
# unet, # unet,
@ -213,6 +231,7 @@ class TextToLatentsInvocation(BaseInvocation):
model: StableDiffusionGeneratorPipeline, model: StableDiffusionGeneratorPipeline,
control_input: List[ControlField], control_input: List[ControlField],
latents_shape: List[int], latents_shape: List[int],
exit_stack: ExitStack,
do_classifier_free_guidance: bool = True, do_classifier_free_guidance: bool = True,
) -> List[ControlNetData]: ) -> List[ControlNetData]:
@ -238,25 +257,19 @@ class TextToLatentsInvocation(BaseInvocation):
control_data = [] control_data = []
control_models = [] control_models = []
for control_info in control_list: for control_info in control_list:
# handle control models control_model = exit_stack.enter_context(
if ("," in control_info.control_model): context.services.model_manager.get_model(
control_model_split = control_info.control_model.split(",") model_name=control_info.control_model.model_name,
control_name = control_model_split[0] model_type=ModelType.ControlNet,
control_subfolder = control_model_split[1] base_model=control_info.control_model.base_model,
print("Using HF model subfolders") )
print(" control_name: ", control_name) )
print(" control_subfolder: ", control_subfolder)
control_model = ControlNetModel.from_pretrained(
control_name, subfolder=control_subfolder,
torch_dtype=model.unet.dtype).to(
model.device)
else:
control_model = ControlNetModel.from_pretrained(
control_info.control_model, torch_dtype=model.unet.dtype).to(model.device)
control_models.append(control_model) control_models.append(control_model)
control_image_field = control_info.image control_image_field = control_info.image
input_image = context.services.images.get_pil_image( input_image = context.services.images.get_pil_image(
control_image_field.image_name) control_image_field.image_name
)
# self.image.image_type, self.image.image_name # self.image.image_type, self.image.image_name
# FIXME: still need to test with different widths, heights, devices, dtypes # FIXME: still need to test with different widths, heights, devices, dtypes
# and add in batch_size, num_images_per_prompt? # and add in batch_size, num_images_per_prompt?
@ -278,7 +291,8 @@ class TextToLatentsInvocation(BaseInvocation):
weight=control_info.control_weight, weight=control_info.control_weight,
begin_step_percent=control_info.begin_step_percent, begin_step_percent=control_info.begin_step_percent,
end_step_percent=control_info.end_step_percent, end_step_percent=control_info.end_step_percent,
control_mode=control_info.control_mode,) control_mode=control_info.control_mode,
)
control_data.append(control_item) control_data.append(control_item)
# MultiControlNetModel has been refactored out, just need list[ControlNetData] # MultiControlNetModel has been refactored out, just need list[ControlNetData]
return control_data return control_data
@ -289,7 +303,8 @@ class TextToLatentsInvocation(BaseInvocation):
# Get the source node id (we are invoking the prepared node) # Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get( graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id) context.graph_execution_state_id
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id] source_node_id = graph_execution_state.prepared_source_mapping[self.id]
def step_callback(state: PipelineIntermediateState): def step_callback(state: PipelineIntermediateState):
@ -298,14 +313,17 @@ class TextToLatentsInvocation(BaseInvocation):
def _lora_loader(): def _lora_loader():
for lora in self.unet.loras: for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model( lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"})) **lora.dict(exclude={"weight"})
)
yield (lora_info.context.model, lora.weight) yield (lora_info.context.model, lora.weight)
del lora_info del lora_info
return return
unet_info = context.services.model_manager.get_model( unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict()) **self.unet.unet.dict()
with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ )
with ExitStack() as exit_stack,\
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
unet_info as unet: unet_info as unet:
scheduler = get_scheduler( scheduler = get_scheduler(
@ -322,6 +340,7 @@ class TextToLatentsInvocation(BaseInvocation):
latents_shape=noise.shape, latents_shape=noise.shape,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0)) # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True, do_classifier_free_guidance=True,
exit_stack=exit_stack,
) )
# TODO: Verify the noise is the right size # TODO: Verify the noise is the right size
@ -374,7 +393,8 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
# Get the source node id (we are invoking the prepared node) # Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get( graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id) context.graph_execution_state_id
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id] source_node_id = graph_execution_state.prepared_source_mapping[self.id]
def step_callback(state: PipelineIntermediateState): def step_callback(state: PipelineIntermediateState):
@ -383,14 +403,17 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
def _lora_loader(): def _lora_loader():
for lora in self.unet.loras: for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model( lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"})) **lora.dict(exclude={"weight"})
)
yield (lora_info.context.model, lora.weight) yield (lora_info.context.model, lora.weight)
del lora_info del lora_info
return return
unet_info = context.services.model_manager.get_model( unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict()) **self.unet.unet.dict()
with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ )
with ExitStack() as exit_stack,\
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
unet_info as unet: unet_info as unet:
scheduler = get_scheduler( scheduler = get_scheduler(
@ -407,11 +430,13 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
latents_shape=noise.shape, latents_shape=noise.shape,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0)) # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True, do_classifier_free_guidance=True,
exit_stack=exit_stack,
) )
# TODO: Verify the noise is the right size # TODO: Verify the noise is the right size
initial_latents = latent if self.strength < 1.0 else torch.zeros_like( initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
latent, device=unet.device, dtype=latent.dtype) latent, device=unet.device, dtype=latent.dtype
)
timesteps, _ = pipeline.get_img2img_timesteps( timesteps, _ = pipeline.get_img2img_timesteps(
self.steps, self.steps,
@ -535,7 +560,8 @@ class ResizeLatentsInvocation(BaseInvocation):
resized_latents = torch.nn.functional.interpolate( resized_latents = torch.nn.functional.interpolate(
latents, size=(self.height // 8, self.width // 8), latents, size=(self.height // 8, self.width // 8),
mode=self.mode, antialias=self.antialias mode=self.mode, antialias=self.antialias
if self.mode in ["bilinear", "bicubic"] else False,) if self.mode in ["bilinear", "bicubic"] else False,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache() torch.cuda.empty_cache()
@ -569,7 +595,8 @@ class ScaleLatentsInvocation(BaseInvocation):
resized_latents = torch.nn.functional.interpolate( resized_latents = torch.nn.functional.interpolate(
latents, scale_factor=self.scale_factor, mode=self.mode, latents, scale_factor=self.scale_factor, mode=self.mode,
antialias=self.antialias antialias=self.antialias
if self.mode in ["bilinear", "bicubic"] else False,) if self.mode in ["bilinear", "bicubic"] else False,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache() torch.cuda.empty_cache()

View File

@ -271,8 +271,8 @@ class InvokeAISettings(BaseSettings):
@classmethod @classmethod
def _excluded(self)->List[str]: def _excluded(self)->List[str]:
# combination of deprecated parameters and internal ones # combination of deprecated parameters and internal ones that shouldn't be exposed
return ['type','initconf', 'gpu_mem_reserved', 'max_loaded_models', 'version'] return ['type','initconf', 'gpu_mem_reserved', 'max_loaded_models', 'version', 'from_file', 'model', 'root']
class Config: class Config:
env_file_encoding = 'utf-8' env_file_encoding = 'utf-8'

View File

@ -19,7 +19,7 @@ from invokeai.backend.model_management import (
ModelMerger, ModelMerger,
MergeInterpolationMethod, MergeInterpolationMethod,
) )
from invokeai.backend.model_management.model_search import FindModels
import torch import torch
from invokeai.app.models.exceptions import CanceledException from invokeai.app.models.exceptions import CanceledException
@ -167,6 +167,27 @@ class ModelManagerServiceBase(ABC):
""" """
pass pass
@abstractmethod
def rename_model(self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
new_name: str,
):
"""
Rename the indicated model.
"""
pass
@abstractmethod
def list_checkpoint_configs(
self
)->List[Path]:
"""
List the checkpoint config paths from ROOT/configs/stable-diffusion.
"""
pass
@abstractmethod @abstractmethod
def convert_model( def convert_model(
self, self,
@ -220,6 +241,7 @@ class ModelManagerServiceBase(ABC):
alpha: Optional[float] = 0.5, alpha: Optional[float] = 0.5,
interp: Optional[MergeInterpolationMethod] = None, interp: Optional[MergeInterpolationMethod] = None,
force: Optional[bool] = False, force: Optional[bool] = False,
merge_dest_directory: Optional[Path] = None
) -> AddModelResult: ) -> AddModelResult:
""" """
Merge two to three diffusrs pipeline models and save as a new model. Merge two to three diffusrs pipeline models and save as a new model.
@ -228,6 +250,23 @@ class ModelManagerServiceBase(ABC):
:param merged_model_name: Name of destination merged model :param merged_model_name: Name of destination merged model
:param alpha: Alpha strength to apply to 2d and 3d model :param alpha: Alpha strength to apply to 2d and 3d model
:param interp: Interpolation method. None (default) :param interp: Interpolation method. None (default)
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
"""
pass
@abstractmethod
def search_for_models(self, directory: Path)->List[Path]:
"""
Return list of all models found in the designated directory.
"""
pass
@abstractmethod
def sync_to_config(self):
"""
Re-read models.yaml, rescan the models directory, and reimport models
in the autoimport directories. Call after making changes outside the
model manager API.
""" """
pass pass
@ -431,16 +470,18 @@ class ModelManagerService(ModelManagerServiceBase):
""" """
Delete the named model from configuration. If delete_files is true, Delete the named model from configuration. If delete_files is true,
then the underlying weight file or diffusers directory will be deleted then the underlying weight file or diffusers directory will be deleted
as well. Call commit() to write to disk. as well.
""" """
self.logger.debug(f'delete model {model_name}') self.logger.debug(f'delete model {model_name}')
self.mgr.del_model(model_name, base_model, model_type) self.mgr.del_model(model_name, base_model, model_type)
self.mgr.commit()
def convert_model( def convert_model(
self, self,
model_name: str, model_name: str,
base_model: BaseModelType, base_model: BaseModelType,
model_type: Union[ModelType.Main,ModelType.Vae], model_type: Union[ModelType.Main,ModelType.Vae],
convert_dest_directory: Optional[Path] = Field(default=None, description="Optional directory location for merged model"),
) -> AddModelResult: ) -> AddModelResult:
""" """
Convert a checkpoint file into a diffusers folder, deleting the cached Convert a checkpoint file into a diffusers folder, deleting the cached
@ -449,13 +490,14 @@ class ModelManagerService(ModelManagerServiceBase):
:param model_name: Name of the model to convert :param model_name: Name of the model to convert
:param base_model: Base model type :param base_model: Base model type
:param model_type: Type of model ['vae' or 'main'] :param model_type: Type of model ['vae' or 'main']
:param convert_dest_directory: Save the converted model to the designated directory (`models/etc/etc` by default)
This will raise a ValueError unless the model is not a checkpoint. It will This will raise a ValueError unless the model is not a checkpoint. It will
also raise a ValueError in the event that there is a similarly-named diffusers also raise a ValueError in the event that there is a similarly-named diffusers
directory already in place. directory already in place.
""" """
self.logger.debug(f'convert model {model_name}') self.logger.debug(f'convert model {model_name}')
return self.mgr.convert_model(model_name, base_model, model_type) return self.mgr.convert_model(model_name, base_model, model_type, convert_dest_directory)
def commit(self, conf_file: Optional[Path]=None): def commit(self, conf_file: Optional[Path]=None):
""" """
@ -536,6 +578,7 @@ class ModelManagerService(ModelManagerServiceBase):
alpha: Optional[float] = 0.5, alpha: Optional[float] = 0.5,
interp: Optional[MergeInterpolationMethod] = None, interp: Optional[MergeInterpolationMethod] = None,
force: Optional[bool] = False, force: Optional[bool] = False,
merge_dest_directory: Optional[Path] = Field(default=None, description="Optional directory location for merged model"),
) -> AddModelResult: ) -> AddModelResult:
""" """
Merge two to three diffusrs pipeline models and save as a new model. Merge two to three diffusrs pipeline models and save as a new model.
@ -544,6 +587,7 @@ class ModelManagerService(ModelManagerServiceBase):
:param merged_model_name: Name of destination merged model :param merged_model_name: Name of destination merged model
:param alpha: Alpha strength to apply to 2d and 3d model :param alpha: Alpha strength to apply to 2d and 3d model
:param interp: Interpolation method. None (default) :param interp: Interpolation method. None (default)
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
""" """
merger = ModelMerger(self.mgr) merger = ModelMerger(self.mgr)
try: try:
@ -554,7 +598,55 @@ class ModelManagerService(ModelManagerServiceBase):
alpha = alpha, alpha = alpha,
interp = interp, interp = interp,
force = force, force = force,
merge_dest_directory=merge_dest_directory,
) )
except AssertionError as e: except AssertionError as e:
raise ValueError(e) raise ValueError(e)
return result return result
def search_for_models(self, directory: Path)->List[Path]:
"""
Return list of all models found in the designated directory.
"""
search = FindModels(directory,self.logger)
return search.list_models()
def sync_to_config(self):
"""
Re-read models.yaml, rescan the models directory, and reimport models
in the autoimport directories. Call after making changes outside the
model manager API.
"""
return self.mgr.sync_to_config()
def list_checkpoint_configs(self)->List[Path]:
"""
List the checkpoint config paths from ROOT/configs/stable-diffusion.
"""
config = self.mgr.app_config
conf_path = config.legacy_conf_path
root_path = config.root_path
return [(conf_path / x).relative_to(root_path) for x in conf_path.glob('**/*.yaml')]
def rename_model(self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
new_name: str = None,
new_base: BaseModelType = None,
):
"""
Rename the indicated model. Can provide a new name and/or a new base.
:param model_name: Current name of the model
:param base_model: Current base of the model
:param model_type: Model type (can't be changed)
:param new_name: New name for the model
:param new_base: New base for the model
"""
self.mgr.rename_model(base_model = base_model,
model_type = model_type,
model_name = model_name,
new_name = new_name,
new_base = new_base,
)

View File

@ -593,9 +593,12 @@ script, which will perform a full upgrade in place."""
config = InvokeAIAppConfig.get_config() config = InvokeAIAppConfig.get_config()
config.parse_args(['--root',str(dest_root)]) config.parse_args(['--root',str(dest_root)])
# TODO: revisit # TODO: revisit - don't rely on invokeai.yaml to exist yet!
# assert (dest_root / 'models').is_dir(), f"{dest_root} does not contain a 'models' subdirectory" dest_is_setup = (dest_root / 'models/core').exists() and (dest_root / 'databases').exists()
# assert (dest_root / 'invokeai.yaml').exists(), f"{dest_root} does not contain an InvokeAI init file." if not dest_is_setup:
import invokeai.frontend.install.invokeai_configure
from invokeai.backend.install.invokeai_configure import initialize_rootdir
initialize_rootdir(dest_root, True)
do_migrate(src_root,dest_root) do_migrate(src_root,dest_root)

View File

@ -71,8 +71,6 @@ class ModelInstallList:
class InstallSelections(): class InstallSelections():
install_models: List[str]= field(default_factory=list) install_models: List[str]= field(default_factory=list)
remove_models: List[str]=field(default_factory=list) remove_models: List[str]=field(default_factory=list)
# scan_directory: Path = None
# autoscan_on_startup: bool=False
@dataclass @dataclass
class ModelLoadInfo(): class ModelLoadInfo():

View File

@ -247,6 +247,7 @@ import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util import CUDA_DEVICE, Chdir from invokeai.backend.util import CUDA_DEVICE, Chdir
from .model_cache import ModelCache, ModelLocker from .model_cache import ModelCache, ModelLocker
from .model_search import ModelSearch
from .models import ( from .models import (
BaseModelType, ModelType, SubModelType, BaseModelType, ModelType, SubModelType,
ModelError, SchedulerPredictionType, MODEL_CLASSES, ModelError, SchedulerPredictionType, MODEL_CLASSES,
@ -323,15 +324,6 @@ class ModelManager(object):
# TODO: metadata not found # TODO: metadata not found
# TODO: version check # TODO: version check
self.models = dict()
for model_key, model_config in config.items():
model_name, base_model, model_type = self.parse_key(model_key)
model_class = MODEL_CLASSES[base_model][model_type]
# alias for config file
model_config["model_format"] = model_config.pop("format")
self.models[model_key] = model_class.create_config(**model_config)
# check config version number and update on disk/RAM if necessary
self.app_config = InvokeAIAppConfig.get_config() self.app_config = InvokeAIAppConfig.get_config()
self.logger = logger self.logger = logger
self.cache = ModelCache( self.cache = ModelCache(
@ -342,11 +334,41 @@ class ModelManager(object):
sequential_offload = sequential_offload, sequential_offload = sequential_offload,
logger = logger, logger = logger,
) )
self._read_models(config)
def _read_models(self, config: Optional[DictConfig] = None):
if not config:
if self.config_path:
config = OmegaConf.load(self.config_path)
else:
return
self.models = dict()
for model_key, model_config in config.items():
if model_key.startswith('_'):
continue
model_name, base_model, model_type = self.parse_key(model_key)
model_class = MODEL_CLASSES[base_model][model_type]
# alias for config file
model_config["model_format"] = model_config.pop("format")
self.models[model_key] = model_class.create_config(**model_config)
# check config version number and update on disk/RAM if necessary
self.cache_keys = dict() self.cache_keys = dict()
# add controlnet, lora and textual_inversion models from disk # add controlnet, lora and textual_inversion models from disk
self.scan_models_directory() self.scan_models_directory()
def sync_to_config(self):
"""
Call this when `models.yaml` has been changed externally.
This will reinitialize internal data structures
"""
# Reread models directory; note that this will reinitialize the cache,
# causing otherwise unreferenced models to be removed from memory
self._read_models()
def model_exists( def model_exists(
self, self,
model_name: str, model_name: str,
@ -527,7 +549,10 @@ class ModelManager(object):
model_keys = [self.create_key(model_name, base_model, model_type)] if model_name else sorted(self.models, key=str.casefold) model_keys = [self.create_key(model_name, base_model, model_type)] if model_name else sorted(self.models, key=str.casefold)
models = [] models = []
for model_key in model_keys: for model_key in model_keys:
model_config = self.models[model_key] model_config = self.models.get(model_key)
if not model_config:
self.logger.error(f'Unknown model {model_name}')
raise KeyError(f'Unknown model {model_name}')
cur_model_name, cur_base_model, cur_model_type = self.parse_key(model_key) cur_model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
if base_model is not None and cur_base_model != base_model: if base_model is not None and cur_base_model != base_model:
@ -646,11 +671,61 @@ class ModelManager(object):
config = model_config, config = model_config,
) )
def rename_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
new_name: str = None,
new_base: BaseModelType = None,
):
'''
Rename or rebase a model.
'''
if new_name is None and new_base is None:
self.logger.error("rename_model() called with neither a new_name nor a new_base. {model_name} unchanged.")
return
model_key = self.create_key(model_name, base_model, model_type)
model_cfg = self.models.get(model_key, None)
if not model_cfg:
raise KeyError(f"Unknown model: {model_key}")
old_path = self.app_config.root_path / model_cfg.path
new_name = new_name or model_name
new_base = new_base or base_model
new_key = self.create_key(new_name, new_base, model_type)
if new_key in self.models:
raise ValueError(f'Attempt to overwrite existing model definition "{new_key}"')
# if this is a model file/directory that we manage ourselves, we need to move it
if old_path.is_relative_to(self.app_config.models_path):
new_path = self.app_config.root_path / 'models' / new_base.value / model_type.value / new_name
move(old_path, new_path)
model_cfg.path = str(new_path.relative_to(self.app_config.root_path))
# clean up caches
old_model_cache = self._get_model_cache_path(old_path)
if old_model_cache.exists():
if old_model_cache.is_dir():
rmtree(str(old_model_cache))
else:
old_model_cache.unlink()
cache_ids = self.cache_keys.pop(model_key, [])
for cache_id in cache_ids:
self.cache.uncache_model(cache_id)
self.models.pop(model_key, None) # delete
self.models[new_key] = model_cfg
self.commit()
def convert_model ( def convert_model (
self, self,
model_name: str, model_name: str,
base_model: BaseModelType, base_model: BaseModelType,
model_type: Union[ModelType.Main,ModelType.Vae], model_type: Union[ModelType.Main,ModelType.Vae],
dest_directory: Optional[Path]=None,
) -> AddModelResult: ) -> AddModelResult:
''' '''
Convert a checkpoint file into a diffusers folder, deleting the cached Convert a checkpoint file into a diffusers folder, deleting the cached
@ -677,14 +752,14 @@ class ModelManager(object):
) )
checkpoint_path = self.app_config.root_path / info["path"] checkpoint_path = self.app_config.root_path / info["path"]
old_diffusers_path = self.app_config.models_path / model.location old_diffusers_path = self.app_config.models_path / model.location
new_diffusers_path = self.app_config.models_path / base_model.value / model_type.value / model_name new_diffusers_path = (dest_directory or self.app_config.models_path / base_model.value / model_type.value) / model_name
if new_diffusers_path.exists(): if new_diffusers_path.exists():
raise ValueError(f"A diffusers model already exists at {new_diffusers_path}") raise ValueError(f"A diffusers model already exists at {new_diffusers_path}")
try: try:
move(old_diffusers_path,new_diffusers_path) move(old_diffusers_path,new_diffusers_path)
info["model_format"] = "diffusers" info["model_format"] = "diffusers"
info["path"] = str(new_diffusers_path.relative_to(self.app_config.root_path)) info["path"] = str(new_diffusers_path) if dest_directory else str(new_diffusers_path.relative_to(self.app_config.root_path))
info.pop('config') info.pop('config')
result = self.add_model(model_name, base_model, model_type, result = self.add_model(model_name, base_model, model_type,
@ -824,6 +899,7 @@ class ModelManager(object):
if (new_models_found or imported_models) and self.config_path: if (new_models_found or imported_models) and self.config_path:
self.commit() self.commit()
def autoimport(self)->Dict[str, AddModelResult]: def autoimport(self)->Dict[str, AddModelResult]:
''' '''
Scan the autoimport directory (if defined) and import new models, delete defunct models. Scan the autoimport directory (if defined) and import new models, delete defunct models.
@ -832,62 +908,41 @@ class ModelManager(object):
from invokeai.backend.install.model_install_backend import ModelInstall from invokeai.backend.install.model_install_backend import ModelInstall
from invokeai.frontend.install.model_install import ask_user_for_prediction_type from invokeai.frontend.install.model_install import ask_user_for_prediction_type
class ScanAndImport(ModelSearch):
def __init__(self, directories, logger, ignore: Set[Path], installer: ModelInstall):
super().__init__(directories, logger)
self.installer = installer
self.ignore = ignore
def on_search_started(self):
self.new_models_found = dict()
def on_model_found(self, model: Path):
if model not in self.ignore:
self.new_models_found.update(self.installer.heuristic_import(model))
def on_search_completed(self):
self.logger.info(f'Scanned {self._items_scanned} files and directories, imported {len(self.new_models_found)} models')
def models_found(self):
return self.new_models_found
installer = ModelInstall(config = self.app_config, installer = ModelInstall(config = self.app_config,
model_manager = self, model_manager = self,
prediction_type_helper = ask_user_for_prediction_type, prediction_type_helper = ask_user_for_prediction_type,
) )
scanned_dirs = set()
config = self.app_config config = self.app_config
known_paths = {(self.app_config.root_path / x['path']) for x in self.list_models()} known_paths = {config.root_path / x['path'] for x in self.list_models()}
directories = {config.root_path / x for x in [config.autoimport_dir,
for autodir in [config.autoimport_dir,
config.lora_dir, config.lora_dir,
config.embedding_dir, config.embedding_dir,
config.controlnet_dir]: config.controlnet_dir]
if autodir is None: }
continue scanner = ScanAndImport(directories, self.logger, ignore=known_paths, installer=installer)
scanner.search()
self.logger.info(f'Scanning {autodir} for models to import') return scanner.models_found()
installed = dict()
autodir = self.app_config.root_path / autodir
if not autodir.exists():
continue
items_scanned = 0
new_models_found = dict()
for root, dirs, files in os.walk(autodir):
items_scanned += len(dirs) + len(files)
for d in dirs:
path = Path(root) / d
if path in known_paths or path.parent in scanned_dirs:
scanned_dirs.add(path)
continue
if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}]):
try:
new_models_found.update(installer.heuristic_import(path))
scanned_dirs.add(path)
except ValueError as e:
self.logger.warning(str(e))
for f in files:
path = Path(root) / f
if path in known_paths or path.parent in scanned_dirs:
continue
if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}:
try:
import_result = installer.heuristic_import(path)
new_models_found.update(import_result)
except ValueError as e:
self.logger.warning(str(e))
self.logger.info(f'Scanned {items_scanned} files and directories, imported {len(new_models_found)} models')
installed.update(new_models_found)
return installed
def heuristic_import(self, def heuristic_import(self,
items_to_import: Set[str], items_to_import: Set[str],
@ -925,3 +980,4 @@ class ModelManager(object):
successfully_installed.update(installed) successfully_installed.update(installed)
self.commit() self.commit()
return successfully_installed return successfully_installed

View File

@ -11,7 +11,7 @@ from enum import Enum
from pathlib import Path from pathlib import Path
from diffusers import DiffusionPipeline from diffusers import DiffusionPipeline
from diffusers import logging as dlogging from diffusers import logging as dlogging
from typing import List, Union from typing import List, Union, Optional
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
@ -74,6 +74,7 @@ class ModelMerger(object):
alpha: float = 0.5, alpha: float = 0.5,
interp: MergeInterpolationMethod = None, interp: MergeInterpolationMethod = None,
force: bool = False, force: bool = False,
merge_dest_directory: Optional[Path] = None,
**kwargs, **kwargs,
) -> AddModelResult: ) -> AddModelResult:
""" """
@ -85,7 +86,7 @@ class ModelMerger(object):
:param interp: The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None. :param interp: The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None.
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C). Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C).
:param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False. :param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
**kwargs - the default DiffusionPipeline.get_config_dict kwargs: **kwargs - the default DiffusionPipeline.get_config_dict kwargs:
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
""" """
@ -111,7 +112,7 @@ class ModelMerger(object):
merged_pipe = self.merge_diffusion_models( merged_pipe = self.merge_diffusion_models(
model_paths, alpha, merge_method, force, **kwargs model_paths, alpha, merge_method, force, **kwargs
) )
dump_path = config.models_path / base_model.value / ModelType.Main.value dump_path = Path(merge_dest_directory) if merge_dest_directory else config.models_path / base_model.value / ModelType.Main.value
dump_path.mkdir(parents=True, exist_ok=True) dump_path.mkdir(parents=True, exist_ok=True)
dump_path = dump_path / merged_model_name dump_path = dump_path / merged_model_name

View File

@ -0,0 +1,103 @@
# Copyright 2023, Lincoln D. Stein and the InvokeAI Team
"""
Abstract base class for recursive directory search for models.
"""
import os
from abc import ABC, abstractmethod
from typing import List, Set, types
from pathlib import Path
import invokeai.backend.util.logging as logger
class ModelSearch(ABC):
def __init__(self, directories: List[Path], logger: types.ModuleType=logger):
"""
Initialize a recursive model directory search.
:param directories: List of directory Paths to recurse through
:param logger: Logger to use
"""
self.directories = directories
self.logger = logger
self._items_scanned = 0
self._models_found = 0
self._scanned_dirs = set()
self._scanned_paths = set()
self._pruned_paths = set()
@abstractmethod
def on_search_started(self):
"""
Called before the scan starts.
"""
pass
@abstractmethod
def on_model_found(self, model: Path):
"""
Process a found model. Raise an exception if something goes wrong.
:param model: Model to process - could be a directory or checkpoint.
"""
pass
@abstractmethod
def on_search_completed(self):
"""
Perform some activity when the scan is completed. May use instance
variables, items_scanned and models_found
"""
pass
def search(self):
self.on_search_started()
for dir in self.directories:
self.walk_directory(dir)
self.on_search_completed()
def walk_directory(self, path: Path):
for root, dirs, files in os.walk(path):
if str(Path(root).name).startswith('.'):
self._pruned_paths.add(root)
if any([Path(root).is_relative_to(x) for x in self._pruned_paths]):
continue
self._items_scanned += len(dirs) + len(files)
for d in dirs:
path = Path(root) / d
if path in self._scanned_paths or path.parent in self._scanned_dirs:
self._scanned_dirs.add(path)
continue
if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}]):
try:
self.on_model_found(path)
self._models_found += 1
self._scanned_dirs.add(path)
except Exception as e:
self.logger.warning(str(e))
for f in files:
path = Path(root) / f
if path.parent in self._scanned_dirs:
continue
if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}:
try:
self.on_model_found(path)
self._models_found += 1
except Exception as e:
self.logger.warning(str(e))
class FindModels(ModelSearch):
def on_search_started(self):
self.models_found: Set[Path] = set()
def on_model_found(self,model: Path):
self.models_found.add(model)
def on_search_completed(self):
pass
def list_models(self) -> List[Path]:
self.search()
return self.models_found

View File

@ -48,7 +48,9 @@ for base_model, models in MODEL_CLASSES.items():
model_configs.discard(None) model_configs.discard(None)
MODEL_CONFIGS.extend(model_configs) MODEL_CONFIGS.extend(model_configs)
for cfg in model_configs: # LS: sort to get the checkpoint configs first, which makes
# for a better template in the Swagger docs
for cfg in sorted(model_configs, key=lambda x: str(x)):
model_name, cfg_name = cfg.__qualname__.split('.')[-2:] model_name, cfg_name = cfg.__qualname__.split('.')[-2:]
openapi_cfg_name = model_name + cfg_name openapi_cfg_name = model_name + cfg_name
if openapi_cfg_name in vars(): if openapi_cfg_name in vars():

View File

@ -59,7 +59,6 @@ class ModelConfigBase(BaseModel):
path: str # or Path path: str # or Path
description: Optional[str] = Field(None) description: Optional[str] = Field(None)
model_format: Optional[str] = Field(None) model_format: Optional[str] = Field(None)
# do not save to config
error: Optional[ModelError] = Field(None) error: Optional[ModelError] = Field(None)
class Config: class Config:

View File

@ -1,8 +1,7 @@
import os import os
import torch import torch
from enum import Enum from enum import Enum
from pathlib import Path from typing import Optional
from typing import Optional, Union, Literal
from .base import ( from .base import (
ModelBase, ModelBase,
ModelConfigBase, ModelConfigBase,
@ -14,6 +13,7 @@ from .base import (
calc_model_size_by_data, calc_model_size_by_data,
classproperty, classproperty,
InvalidModelException, InvalidModelException,
ModelNotFoundException,
) )
class ControlNetModelFormat(str, Enum): class ControlNetModelFormat(str, Enum):
@ -60,10 +60,20 @@ class ControlNetModel(ModelBase):
if child_type is not None: if child_type is not None:
raise Exception("There is no child models in controlnet model") raise Exception("There is no child models in controlnet model")
model = None
for variant in ['fp16',None]:
try:
model = self.model_class.from_pretrained( model = self.model_class.from_pretrained(
self.model_path, self.model_path,
torch_dtype=torch_dtype, torch_dtype=torch_dtype,
variant=variant,
) )
break
except:
pass
if not model:
raise ModelNotFoundException()
# calc more accurate size # calc more accurate size
self.model_size = calc_model_size_by_data(model) self.model_size = calc_model_size_by_data(model)
return model return model

View File

@ -38,7 +38,6 @@ class StableDiffusion1Model(DiffusersModel):
config: str config: str
variant: ModelVariantType variant: ModelVariantType
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert base_model == BaseModelType.StableDiffusion1 assert base_model == BaseModelType.StableDiffusion1
assert model_type == ModelType.Main assert model_type == ModelType.Main

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -12,7 +12,7 @@
margin: 0; margin: 0;
} }
</style> </style>
<script type="module" crossorigin src="./assets/index-078526aa.js"></script> <script type="module" crossorigin src="./assets/index-8888b06f.js"></script>
</head> </head>
<body dir="ltr"> <body dir="ltr">

View File

@ -102,7 +102,8 @@
"openInNewTab": "Open in New Tab", "openInNewTab": "Open in New Tab",
"dontAskMeAgain": "Don't ask me again", "dontAskMeAgain": "Don't ask me again",
"areYouSure": "Are you sure?", "areYouSure": "Are you sure?",
"imagePrompt": "Image Prompt" "imagePrompt": "Image Prompt",
"clearNodes": "Are you sure you want to clear all nodes?"
}, },
"gallery": { "gallery": {
"generations": "Generations", "generations": "Generations",
@ -118,7 +119,7 @@
"pinGallery": "Pin Gallery", "pinGallery": "Pin Gallery",
"allImagesLoaded": "All Images Loaded", "allImagesLoaded": "All Images Loaded",
"loadMore": "Load More", "loadMore": "Load More",
"noImagesInGallery": "No Images In Gallery", "noImagesInGallery": "No Images to Display",
"deleteImage": "Delete Image", "deleteImage": "Delete Image",
"deleteImageBin": "Deleted images will be sent to your operating system's Bin.", "deleteImageBin": "Deleted images will be sent to your operating system's Bin.",
"deleteImagePermanent": "Deleted images cannot be restored.", "deleteImagePermanent": "Deleted images cannot be restored.",
@ -342,6 +343,7 @@
"safetensorModels": "SafeTensors", "safetensorModels": "SafeTensors",
"modelAdded": "Model Added", "modelAdded": "Model Added",
"modelUpdated": "Model Updated", "modelUpdated": "Model Updated",
"modelUpdateFailed": "Model Update Failed",
"modelEntryDeleted": "Model Entry Deleted", "modelEntryDeleted": "Model Entry Deleted",
"cannotUseSpaces": "Cannot Use Spaces", "cannotUseSpaces": "Cannot Use Spaces",
"addNew": "Add New", "addNew": "Add New",
@ -396,8 +398,8 @@
"delete": "Delete", "delete": "Delete",
"deleteModel": "Delete Model", "deleteModel": "Delete Model",
"deleteConfig": "Delete Config", "deleteConfig": "Delete Config",
"deleteMsg1": "Are you sure you want to delete this model entry from InvokeAI?", "deleteMsg1": "Are you sure you want to delete this model from InvokeAI?",
"deleteMsg2": "This will not delete the model checkpoint file from your disk. You can readd them if you wish to.", "deleteMsg2": "This WILL delete the model from disk if it is in the InvokeAI root folder. If you are using a custom location, then the model WILL NOT be deleted from disk.",
"formMessageDiffusersModelLocation": "Diffusers Model Location", "formMessageDiffusersModelLocation": "Diffusers Model Location",
"formMessageDiffusersModelLocationDesc": "Please enter at least one.", "formMessageDiffusersModelLocationDesc": "Please enter at least one.",
"formMessageDiffusersVAELocation": "VAE Location", "formMessageDiffusersVAELocation": "VAE Location",
@ -408,7 +410,7 @@
"convertToDiffusersHelpText2": "This process will replace your Model Manager entry with the Diffusers version of the same model.", "convertToDiffusersHelpText2": "This process will replace your Model Manager entry with the Diffusers version of the same model.",
"convertToDiffusersHelpText3": "Your checkpoint file on the disk will NOT be deleted or modified in anyway. You can add your checkpoint to the Model Manager again if you want to.", "convertToDiffusersHelpText3": "Your checkpoint file on the disk will NOT be deleted or modified in anyway. You can add your checkpoint to the Model Manager again if you want to.",
"convertToDiffusersHelpText4": "This is a one time process only. It might take around 30s-60s depending on the specifications of your computer.", "convertToDiffusersHelpText4": "This is a one time process only. It might take around 30s-60s depending on the specifications of your computer.",
"convertToDiffusersHelpText5": "Please make sure you have enough disk space. Models generally vary between 4GB-7GB in size.", "convertToDiffusersHelpText5": "Please make sure you have enough disk space. Models generally vary between 2GB-7GB in size.",
"convertToDiffusersHelpText6": "Do you wish to convert this model?", "convertToDiffusersHelpText6": "Do you wish to convert this model?",
"convertToDiffusersSaveLocation": "Save Location", "convertToDiffusersSaveLocation": "Save Location",
"v1": "v1", "v1": "v1",
@ -419,12 +421,14 @@
"pathToCustomConfig": "Path To Custom Config", "pathToCustomConfig": "Path To Custom Config",
"statusConverting": "Converting", "statusConverting": "Converting",
"modelConverted": "Model Converted", "modelConverted": "Model Converted",
"modelConversionFailed": "Model Conversion Failed",
"sameFolder": "Same folder", "sameFolder": "Same folder",
"invokeRoot": "InvokeAI folder", "invokeRoot": "InvokeAI folder",
"custom": "Custom", "custom": "Custom",
"customSaveLocation": "Custom Save Location", "customSaveLocation": "Custom Save Location",
"merge": "Merge", "merge": "Merge",
"modelsMerged": "Models Merged", "modelsMerged": "Models Merged",
"modelsMergeFailed": "Model Merge Failed",
"mergeModels": "Merge Models", "mergeModels": "Merge Models",
"modelOne": "Model 1", "modelOne": "Model 1",
"modelTwo": "Model 2", "modelTwo": "Model 2",
@ -445,7 +449,8 @@
"weightedSum": "Weighted Sum", "weightedSum": "Weighted Sum",
"none": "none", "none": "none",
"addDifference": "Add Difference", "addDifference": "Add Difference",
"pickModelType": "Pick Model Type" "pickModelType": "Pick Model Type",
"selectModel": "Select Model"
}, },
"parameters": { "parameters": {
"general": "General", "general": "General",
@ -528,7 +533,7 @@
"hidePreview": "Hide Preview", "hidePreview": "Hide Preview",
"showPreview": "Show Preview", "showPreview": "Show Preview",
"controlNetControlMode": "Control Mode", "controlNetControlMode": "Control Mode",
"clipSkip": "Clip Skip", "clipSkip": "CLIP Skip",
"aspectRatio": "Ratio" "aspectRatio": "Ratio"
}, },
"settings": { "settings": {
@ -593,7 +598,11 @@
"metadataLoadFailed": "Failed to load metadata", "metadataLoadFailed": "Failed to load metadata",
"initialImageSet": "Initial Image Set", "initialImageSet": "Initial Image Set",
"initialImageNotSet": "Initial Image Not Set", "initialImageNotSet": "Initial Image Not Set",
"initialImageNotSetDesc": "Could not load initial image" "initialImageNotSetDesc": "Could not load initial image",
"nodesSaved": "Nodes Saved",
"nodesLoaded": "Nodes Loaded",
"nodesLoadedFailed": "Failed To Load Nodes",
"nodesCleared": "Nodes Cleared"
}, },
"tooltip": { "tooltip": {
"feature": { "feature": {
@ -674,5 +683,11 @@
"showProgressImages": "Show Progress Images", "showProgressImages": "Show Progress Images",
"hideProgressImages": "Hide Progress Images", "hideProgressImages": "Hide Progress Images",
"swapSizes": "Swap Sizes" "swapSizes": "Swap Sizes"
},
"nodes": {
"reloadSchema": "Reload Schema",
"saveNodes": "Save Nodes",
"loadNodes": "Load Nodes",
"clearNodes": "Clear Nodes"
} }
} }

View File

@ -13,7 +13,11 @@ import { RootState } from 'app/store/store';
const moduleLog = log.child({ namespace: 'controlNet' }); const moduleLog = log.child({ namespace: 'controlNet' });
const predicate: AnyListenerPredicate<RootState> = (action, state) => { const predicate: AnyListenerPredicate<RootState> = (
action,
state,
prevState
) => {
const isActionMatched = const isActionMatched =
controlNetProcessorParamsChanged.match(action) || controlNetProcessorParamsChanged.match(action) ||
controlNetModelChanged.match(action) || controlNetModelChanged.match(action) ||
@ -25,6 +29,16 @@ const predicate: AnyListenerPredicate<RootState> = (action, state) => {
return false; return false;
} }
if (controlNetAutoConfigToggled.match(action)) {
// do not process if the user just disabled auto-config
if (
prevState.controlNet.controlNets[action.payload.controlNetId]
.shouldAutoConfig === true
) {
return false;
}
}
const { controlImage, processorType, shouldAutoConfig } = const { controlImage, processorType, shouldAutoConfig } =
state.controlNet.controlNets[action.payload.controlNetId]; state.controlNet.controlNets[action.payload.controlNetId];

View File

@ -10,6 +10,7 @@ import { zMainModel } from 'features/parameters/types/parameterSchemas';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
import { forEach } from 'lodash-es'; import { forEach } from 'lodash-es';
import { startAppListening } from '..'; import { startAppListening } from '..';
import { controlNetRemoved } from 'features/controlNet/store/controlNetSlice';
const moduleLog = log.child({ module: 'models' }); const moduleLog = log.child({ module: 'models' });
@ -51,7 +52,14 @@ export const addModelSelectedListener = () => {
modelsCleared += 1; modelsCleared += 1;
} }
// TODO: handle incompatible controlnet; pending model manager support const { controlNets } = state.controlNet;
forEach(controlNets, (controlNet, controlNetId) => {
if (controlNet.model?.base_model !== base_model) {
dispatch(controlNetRemoved({ controlNetId }));
modelsCleared += 1;
}
});
if (modelsCleared > 0) { if (modelsCleared > 0) {
dispatch( dispatch(
addToast( addToast(

View File

@ -11,6 +11,7 @@ import {
import { forEach, some } from 'lodash-es'; import { forEach, some } from 'lodash-es';
import { modelsApi } from 'services/api/endpoints/models'; import { modelsApi } from 'services/api/endpoints/models';
import { startAppListening } from '..'; import { startAppListening } from '..';
import { controlNetRemoved } from 'features/controlNet/store/controlNetSlice';
const moduleLog = log.child({ module: 'models' }); const moduleLog = log.child({ module: 'models' });
@ -127,7 +128,22 @@ export const addModelsLoadedListener = () => {
matcher: modelsApi.endpoints.getControlNetModels.matchFulfilled, matcher: modelsApi.endpoints.getControlNetModels.matchFulfilled,
effect: async (action, { getState, dispatch }) => { effect: async (action, { getState, dispatch }) => {
// ControlNet models loaded - need to remove missing ControlNets from state // ControlNet models loaded - need to remove missing ControlNets from state
// TODO: pending model manager controlnet support const controlNets = getState().controlNet.controlNets;
forEach(controlNets, (controlNet, controlNetId) => {
const isControlNetAvailable = some(
action.payload.entities,
(m) =>
m?.model_name === controlNet?.model?.model_name &&
m?.base_model === controlNet?.model?.base_model
);
if (isControlNetAvailable) {
return;
}
dispatch(controlNetRemoved({ controlNetId }));
});
}, },
}); });
}; };

View File

@ -1,5 +1,5 @@
import { import {
CONTROLNET_MODELS, // CONTROLNET_MODELS,
CONTROLNET_PROCESSORS, CONTROLNET_PROCESSORS,
} from 'features/controlNet/store/constants'; } from 'features/controlNet/store/constants';
import { InvokeTabName } from 'features/ui/store/tabMap'; import { InvokeTabName } from 'features/ui/store/tabMap';
@ -128,7 +128,7 @@ export type AppConfig = {
canRestoreDeletedImagesFromBin: boolean; canRestoreDeletedImagesFromBin: boolean;
sd: { sd: {
defaultModel?: string; defaultModel?: string;
disabledControlNetModels: (keyof typeof CONTROLNET_MODELS)[]; disabledControlNetModels: string[];
disabledControlNetProcessors: (keyof typeof CONTROLNET_PROCESSORS)[]; disabledControlNetProcessors: (keyof typeof CONTROLNET_PROCESSORS)[];
iterations: { iterations: {
initial: number; initial: number;

View File

@ -170,12 +170,14 @@ const IAIDndImage = (props: IAIDndImageProps) => {
</> </>
)} )}
{!imageDTO && isUploadDisabled && noContentFallback} {!imageDTO && isUploadDisabled && noContentFallback}
{!isDropDisabled && (
<IAIDroppable <IAIDroppable
data={droppableData} data={droppableData}
disabled={isDropDisabled} disabled={isDropDisabled}
dropLabel={dropLabel} dropLabel={dropLabel}
/> />
{imageDTO && ( )}
{imageDTO && !isDragDisabled && (
<IAIDraggable <IAIDraggable
data={draggableData} data={draggableData}
disabled={isDragDisabled || !imageDTO} disabled={isDragDisabled || !imageDTO}

View File

@ -1,17 +1,25 @@
import { Tooltip } from '@chakra-ui/react'; import { FormControl, FormLabel, Tooltip } from '@chakra-ui/react';
import { MultiSelect, MultiSelectProps } from '@mantine/core'; import { MultiSelect, MultiSelectProps } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice'; import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
import { useMantineMultiSelectStyles } from 'mantine-theme/hooks/useMantineMultiSelectStyles'; import { useMantineMultiSelectStyles } from 'mantine-theme/hooks/useMantineMultiSelectStyles';
import { KeyboardEvent, RefObject, memo, useCallback } from 'react'; import { KeyboardEvent, RefObject, memo, useCallback } from 'react';
type IAIMultiSelectProps = MultiSelectProps & { type IAIMultiSelectProps = Omit<MultiSelectProps, 'label'> & {
tooltip?: string; tooltip?: string;
inputRef?: RefObject<HTMLInputElement>; inputRef?: RefObject<HTMLInputElement>;
label?: string;
}; };
const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => { const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
const { searchable = true, tooltip, inputRef, ...rest } = props; const {
searchable = true,
tooltip,
inputRef,
label,
disabled,
...rest
} = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const handleKeyDown = useCallback( const handleKeyDown = useCallback(
@ -37,7 +45,15 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
return ( return (
<Tooltip label={tooltip} placement="top" hasArrow isOpen={true}> <Tooltip label={tooltip} placement="top" hasArrow isOpen={true}>
<MultiSelect <MultiSelect
label={
label ? (
<FormControl isDisabled={disabled}>
<FormLabel>{label}</FormLabel>
</FormControl>
) : undefined
}
ref={inputRef} ref={inputRef}
disabled={disabled}
onKeyDown={handleKeyDown} onKeyDown={handleKeyDown}
onKeyUp={handleKeyUp} onKeyUp={handleKeyUp}
searchable={searchable} searchable={searchable}

View File

@ -1,4 +1,4 @@
import { Tooltip } from '@chakra-ui/react'; import { FormControl, FormLabel, Tooltip } from '@chakra-ui/react';
import { Select, SelectProps } from '@mantine/core'; import { Select, SelectProps } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice'; import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
@ -11,13 +11,22 @@ export type IAISelectDataType = {
tooltip?: string; tooltip?: string;
}; };
type IAISelectProps = SelectProps & { type IAISelectProps = Omit<SelectProps, 'label'> & {
tooltip?: string; tooltip?: string;
label?: string;
inputRef?: RefObject<HTMLInputElement>; inputRef?: RefObject<HTMLInputElement>;
}; };
const IAIMantineSearchableSelect = (props: IAISelectProps) => { const IAIMantineSearchableSelect = (props: IAISelectProps) => {
const { searchable = true, tooltip, inputRef, onChange, ...rest } = props; const {
searchable = true,
tooltip,
inputRef,
onChange,
label,
disabled,
...rest
} = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const [searchValue, setSearchValue] = useState(''); const [searchValue, setSearchValue] = useState('');
@ -61,6 +70,14 @@ const IAIMantineSearchableSelect = (props: IAISelectProps) => {
<Tooltip label={tooltip} placement="top" hasArrow> <Tooltip label={tooltip} placement="top" hasArrow>
<Select <Select
ref={inputRef} ref={inputRef}
label={
label ? (
<FormControl isDisabled={disabled}>
<FormLabel>{label}</FormLabel>
</FormControl>
) : undefined
}
disabled={disabled}
searchValue={searchValue} searchValue={searchValue}
onSearchChange={setSearchValue} onSearchChange={setSearchValue}
onChange={handleChange} onChange={handleChange}

View File

@ -1,4 +1,4 @@
import { Tooltip } from '@chakra-ui/react'; import { FormControl, FormLabel, Tooltip } from '@chakra-ui/react';
import { Select, SelectProps } from '@mantine/core'; import { Select, SelectProps } from '@mantine/core';
import { useMantineSelectStyles } from 'mantine-theme/hooks/useMantineSelectStyles'; import { useMantineSelectStyles } from 'mantine-theme/hooks/useMantineSelectStyles';
import { RefObject, memo } from 'react'; import { RefObject, memo } from 'react';
@ -9,19 +9,32 @@ export type IAISelectDataType = {
tooltip?: string; tooltip?: string;
}; };
type IAISelectProps = SelectProps & { type IAISelectProps = Omit<SelectProps, 'label'> & {
tooltip?: string; tooltip?: string;
inputRef?: RefObject<HTMLInputElement>; inputRef?: RefObject<HTMLInputElement>;
label?: string;
}; };
const IAIMantineSelect = (props: IAISelectProps) => { const IAIMantineSelect = (props: IAISelectProps) => {
const { tooltip, inputRef, ...rest } = props; const { tooltip, inputRef, label, disabled, ...rest } = props;
const styles = useMantineSelectStyles(); const styles = useMantineSelectStyles();
return ( return (
<Tooltip label={tooltip} placement="top" hasArrow> <Tooltip label={tooltip} placement="top" hasArrow>
<Select ref={inputRef} styles={styles} {...rest} /> <Select
label={
label ? (
<FormControl isDisabled={disabled}>
<FormLabel>{label}</FormLabel>
</FormControl>
) : undefined
}
disabled={disabled}
ref={inputRef}
styles={styles}
{...rest}
/>
</Tooltip> </Tooltip>
); );
}; };

View File

@ -43,11 +43,6 @@ import { useTranslation } from 'react-i18next';
import { BiReset } from 'react-icons/bi'; import { BiReset } from 'react-icons/bi';
import IAIIconButton, { IAIIconButtonProps } from './IAIIconButton'; import IAIIconButton, { IAIIconButtonProps } from './IAIIconButton';
const SLIDER_MARK_STYLES: ChakraProps['sx'] = {
mt: 1.5,
fontSize: '2xs',
};
export type IAIFullSliderProps = { export type IAIFullSliderProps = {
label?: string; label?: string;
value: number; value: number;
@ -207,7 +202,7 @@ const IAISlider = (props: IAIFullSliderProps) => {
{...sliderFormControlProps} {...sliderFormControlProps}
> >
{label && ( {label && (
<FormLabel {...sliderFormLabelProps} mb={-1}> <FormLabel sx={withInput ? { mb: -1.5 } : {}} {...sliderFormLabelProps}>
{label} {label}
</FormLabel> </FormLabel>
)} )}
@ -233,7 +228,6 @@ const IAISlider = (props: IAIFullSliderProps) => {
sx={{ sx={{
insetInlineStart: '0 !important', insetInlineStart: '0 !important',
insetInlineEnd: 'unset !important', insetInlineEnd: 'unset !important',
...SLIDER_MARK_STYLES,
}} }}
{...sliderMarkProps} {...sliderMarkProps}
> >
@ -244,7 +238,6 @@ const IAISlider = (props: IAIFullSliderProps) => {
sx={{ sx={{
insetInlineStart: 'unset !important', insetInlineStart: 'unset !important',
insetInlineEnd: '0 !important', insetInlineEnd: '0 !important',
...SLIDER_MARK_STYLES,
}} }}
{...sliderMarkProps} {...sliderMarkProps}
> >
@ -263,7 +256,6 @@ const IAISlider = (props: IAIFullSliderProps) => {
sx={{ sx={{
insetInlineStart: '0 !important', insetInlineStart: '0 !important',
insetInlineEnd: 'unset !important', insetInlineEnd: 'unset !important',
...SLIDER_MARK_STYLES,
}} }}
{...sliderMarkProps} {...sliderMarkProps}
> >
@ -278,7 +270,6 @@ const IAISlider = (props: IAIFullSliderProps) => {
sx={{ sx={{
insetInlineStart: 'unset !important', insetInlineStart: 'unset !important',
insetInlineEnd: '0 !important', insetInlineEnd: '0 !important',
...SLIDER_MARK_STYLES,
}} }}
{...sliderMarkProps} {...sliderMarkProps}
> >
@ -291,7 +282,7 @@ const IAISlider = (props: IAIFullSliderProps) => {
key={m} key={m}
value={m} value={m}
sx={{ sx={{
...SLIDER_MARK_STYLES, transform: 'translateX(-50%)',
}} }}
{...sliderMarkProps} {...sliderMarkProps}
> >

View File

@ -5,6 +5,7 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { validateSeedWeights } from 'common/util/seedWeightPairs'; import { validateSeedWeights } from 'common/util/seedWeightPairs';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { modelsApi } from '../../services/api/endpoints/models'; import { modelsApi } from '../../services/api/endpoints/models';
import { forEach } from 'lodash-es';
const readinessSelector = createSelector( const readinessSelector = createSelector(
[stateSelector, activeTabNameSelector], [stateSelector, activeTabNameSelector],
@ -52,6 +53,13 @@ const readinessSelector = createSelector(
reasonsWhyNotReady.push('Seed-Weights badly formatted.'); reasonsWhyNotReady.push('Seed-Weights badly formatted.');
} }
forEach(state.controlNet.controlNets, (controlNet, id) => {
if (!controlNet.model) {
isReady = false;
reasonsWhyNotReady.push('ControlNet ${id} has no model selected.');
}
});
// All good // All good
return { isReady, reasonsWhyNotReady }; return { isReady, reasonsWhyNotReady };
}, },

View File

@ -1,10 +1,9 @@
import { Box, ChakraProps, Flex, useColorMode } from '@chakra-ui/react'; import { Box, Flex } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { FaCopy, FaTrash } from 'react-icons/fa'; import { FaCopy, FaTrash } from 'react-icons/fa';
import { import {
ControlNetConfig, controlNetDuplicated,
controlNetAdded,
controlNetRemoved, controlNetRemoved,
controlNetToggled, controlNetToggled,
} from '../store/controlNetSlice'; } from '../store/controlNetSlice';
@ -12,6 +11,9 @@ import ParamControlNetModel from './parameters/ParamControlNetModel';
import ParamControlNetWeight from './parameters/ParamControlNetWeight'; import ParamControlNetWeight from './parameters/ParamControlNetWeight';
import { ChevronUpIcon } from '@chakra-ui/icons'; import { ChevronUpIcon } from '@chakra-ui/icons';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIIconButton from 'common/components/IAIIconButton'; import IAIIconButton from 'common/components/IAIIconButton';
import IAISwitch from 'common/components/IAISwitch'; import IAISwitch from 'common/components/IAISwitch';
import { useToggle } from 'react-use'; import { useToggle } from 'react-use';
@ -22,41 +24,41 @@ import ParamControlNetShouldAutoConfig from './ParamControlNetShouldAutoConfig';
import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd'; import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd';
import ParamControlNetControlMode from './parameters/ParamControlNetControlMode'; import ParamControlNetControlMode from './parameters/ParamControlNetControlMode';
import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect'; import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect';
import { mode } from 'theme/util/mode';
const expandedControlImageSx: ChakraProps['sx'] = { maxH: 96 };
type ControlNetProps = { type ControlNetProps = {
controlNet: ControlNetConfig; controlNetId: string;
}; };
const ControlNet = (props: ControlNetProps) => { const ControlNet = (props: ControlNetProps) => {
const { const { controlNetId } = props;
controlNetId,
isEnabled,
model,
weight,
beginStepPct,
endStepPct,
controlMode,
controlImage,
processedControlImage,
processorNode,
processorType,
shouldAutoConfig,
} = props.controlNet;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const selector = createSelector(
stateSelector,
({ controlNet }) => {
const { isEnabled, shouldAutoConfig } =
controlNet.controlNets[controlNetId];
return { isEnabled, shouldAutoConfig };
},
defaultSelectorOptions
);
const { isEnabled, shouldAutoConfig } = useAppSelector(selector);
const [isExpanded, toggleIsExpanded] = useToggle(false); const [isExpanded, toggleIsExpanded] = useToggle(false);
const { colorMode } = useColorMode();
const handleDelete = useCallback(() => { const handleDelete = useCallback(() => {
dispatch(controlNetRemoved({ controlNetId })); dispatch(controlNetRemoved({ controlNetId }));
}, [controlNetId, dispatch]); }, [controlNetId, dispatch]);
const handleDuplicate = useCallback(() => { const handleDuplicate = useCallback(() => {
dispatch( dispatch(
controlNetAdded({ controlNetId: uuidv4(), controlNet: props.controlNet }) controlNetDuplicated({
sourceControlNetId: controlNetId,
newControlNetId: uuidv4(),
})
); );
}, [dispatch, props.controlNet]); }, [controlNetId, dispatch]);
const handleToggleIsEnabled = useCallback(() => { const handleToggleIsEnabled = useCallback(() => {
dispatch(controlNetToggled({ controlNetId })); dispatch(controlNetToggled({ controlNetId }));
@ -68,15 +70,18 @@ const ControlNet = (props: ControlNetProps) => {
flexDir: 'column', flexDir: 'column',
gap: 2, gap: 2,
p: 3, p: 3,
bg: mode('base.200', 'base.850')(colorMode),
borderRadius: 'base', borderRadius: 'base',
position: 'relative', position: 'relative',
bg: 'base.200',
_dark: {
bg: 'base.850',
},
}} }}
> >
<Flex sx={{ gap: 2 }}> <Flex sx={{ gap: 2, alignItems: 'center' }}>
<IAISwitch <IAISwitch
tooltip="Toggle" tooltip={'Toggle this ControlNet'}
aria-label="Toggle" aria-label={'Toggle this ControlNet'}
isChecked={isEnabled} isChecked={isEnabled}
onChange={handleToggleIsEnabled} onChange={handleToggleIsEnabled}
/> />
@ -90,7 +95,7 @@ const ControlNet = (props: ControlNetProps) => {
transitionDuration: '0.1s', transitionDuration: '0.1s',
}} }}
> >
<ParamControlNetModel controlNetId={controlNetId} model={model} /> <ParamControlNetModel controlNetId={controlNetId} />
</Box> </Box>
<IAIIconButton <IAIIconButton
size="sm" size="sm"
@ -109,21 +114,26 @@ const ControlNet = (props: ControlNetProps) => {
/> />
<IAIIconButton <IAIIconButton
size="sm" size="sm"
aria-label="Show All Options" tooltip={isExpanded ? 'Hide Advanced' : 'Show Advanced'}
aria-label={isExpanded ? 'Hide Advanced' : 'Show Advanced'}
onClick={toggleIsExpanded} onClick={toggleIsExpanded}
variant="link" variant="link"
icon={ icon={
<ChevronUpIcon <ChevronUpIcon
sx={{ sx={{
boxSize: 4, boxSize: 4,
color: mode('base.700', 'base.300')(colorMode), color: 'base.700',
transform: isExpanded ? 'rotate(0deg)' : 'rotate(180deg)', transform: isExpanded ? 'rotate(0deg)' : 'rotate(180deg)',
transitionProperty: 'common', transitionProperty: 'common',
transitionDuration: 'normal', transitionDuration: 'normal',
_dark: {
color: 'base.300',
},
}} }}
/> />
} }
/> />
{!shouldAutoConfig && ( {!shouldAutoConfig && (
<Box <Box
sx={{ sx={{
@ -131,21 +141,23 @@ const ControlNet = (props: ControlNetProps) => {
w: 1.5, w: 1.5,
h: 1.5, h: 1.5,
borderRadius: 'full', borderRadius: 'full',
bg: mode('error.700', 'error.200')(colorMode),
top: 4, top: 4,
insetInlineEnd: 4, insetInlineEnd: 4,
bg: 'accent.700',
_dark: {
bg: 'accent.400',
},
}} }}
/> />
)} )}
</Flex> </Flex>
{isEnabled && (
<>
<Flex sx={{ w: 'full', flexDirection: 'column' }}> <Flex sx={{ w: 'full', flexDirection: 'column' }}>
<Flex sx={{ gap: 4, w: 'full' }}> <Flex sx={{ gap: 4, w: 'full', alignItems: 'center' }}>
<Flex <Flex
sx={{ sx={{
flexDir: 'column', flexDir: 'column',
gap: 3, gap: 3,
h: 28,
w: 'full', w: 'full',
paddingInlineStart: 1, paddingInlineStart: 1,
paddingInlineEnd: isExpanded ? 1 : 0, paddingInlineEnd: isExpanded ? 1 : 0,
@ -153,63 +165,35 @@ const ControlNet = (props: ControlNetProps) => {
justifyContent: 'space-between', justifyContent: 'space-between',
}} }}
> >
<ParamControlNetWeight <ParamControlNetWeight controlNetId={controlNetId} />
controlNetId={controlNetId} <ParamControlNetBeginEnd controlNetId={controlNetId} />
weight={weight}
mini={!isExpanded}
/>
<ParamControlNetBeginEnd
controlNetId={controlNetId}
beginStepPct={beginStepPct}
endStepPct={endStepPct}
mini={!isExpanded}
/>
</Flex> </Flex>
{!isExpanded && ( {!isExpanded && (
<Flex <Flex
sx={{ sx={{
alignItems: 'center', alignItems: 'center',
justifyContent: 'center', justifyContent: 'center',
h: 24, h: 28,
w: 24, w: 28,
aspectRatio: '1/1', aspectRatio: '1/1',
mt: 3,
}} }}
> >
<ControlNetImagePreview <ControlNetImagePreview controlNetId={controlNetId} height={28} />
controlNet={props.controlNet}
height={24}
/>
</Flex> </Flex>
)} )}
</Flex> </Flex>
<ParamControlNetControlMode <Box mt={2}>
controlNetId={controlNetId} <ParamControlNetControlMode controlNetId={controlNetId} />
controlMode={controlMode} </Box>
/> <ParamControlNetProcessorSelect controlNetId={controlNetId} />
</Flex> </Flex>
{isExpanded && ( {isExpanded && (
<> <>
<Box mt={2}> <ControlNetImagePreview controlNetId={controlNetId} height="392px" />
<ControlNetImagePreview <ParamControlNetShouldAutoConfig controlNetId={controlNetId} />
controlNet={props.controlNet} <ControlNetProcessorComponent controlNetId={controlNetId} />
height={96}
/>
</Box>
<ParamControlNetProcessorSelect
controlNetId={controlNetId}
processorNode={processorNode}
/>
<ControlNetProcessorComponent
controlNetId={controlNetId}
processorNode={processorNode}
/>
<ParamControlNetShouldAutoConfig
controlNetId={controlNetId}
shouldAutoConfig={shouldAutoConfig}
/>
</>
)}
</> </>
)} )}
</Flex> </Flex>

View File

@ -5,42 +5,57 @@ import {
TypesafeDraggableData, TypesafeDraggableData,
TypesafeDroppableData, TypesafeDroppableData,
} from 'app/components/ImageDnd/typesafeDnd'; } from 'app/components/ImageDnd/typesafeDnd';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIDndImage from 'common/components/IAIDndImage'; import IAIDndImage from 'common/components/IAIDndImage';
import { memo, useCallback, useMemo, useState } from 'react'; import { memo, useCallback, useMemo, useState } from 'react';
import { useGetImageDTOQuery } from 'services/api/endpoints/images'; import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { PostUploadAction } from 'services/api/thunks/image'; import { PostUploadAction } from 'services/api/thunks/image';
import { import { controlNetImageChanged } from '../store/controlNetSlice';
ControlNetConfig,
controlNetImageChanged,
controlNetSelector,
} from '../store/controlNetSlice';
const selector = createSelector(
controlNetSelector,
(controlNet) => {
const { pendingControlImages } = controlNet;
return { pendingControlImages };
},
defaultSelectorOptions
);
type Props = { type Props = {
controlNet: ControlNetConfig; controlNetId: string;
height: SystemStyleObject['h']; height: SystemStyleObject['h'];
}; };
const ControlNetImagePreview = (props: Props) => { const ControlNetImagePreview = (props: Props) => {
const { height } = props; const { height, controlNetId } = props;
const {
controlNetId,
controlImage: controlImageName,
processedControlImage: processedControlImageName,
processorType,
} = props.controlNet;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { pendingControlImages } = useAppSelector(selector);
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlNet }) => {
const { pendingControlImages } = controlNet;
const {
controlImage,
processedControlImage,
processorType,
isEnabled,
} = controlNet.controlNets[controlNetId];
return {
controlImageName: controlImage,
processedControlImageName: processedControlImage,
processorType,
isEnabled,
pendingControlImages,
};
},
defaultSelectorOptions
),
[controlNetId]
);
const {
controlImageName,
processedControlImageName,
processorType,
pendingControlImages,
isEnabled,
} = useAppSelector(selector);
const [isMouseOverImage, setIsMouseOverImage] = useState(false); const [isMouseOverImage, setIsMouseOverImage] = useState(false);
@ -110,13 +125,15 @@ const ControlNetImagePreview = (props: Props) => {
h: height, h: height,
alignItems: 'center', alignItems: 'center',
justifyContent: 'center', justifyContent: 'center',
pointerEvents: isEnabled ? 'auto' : 'none',
opacity: isEnabled ? 1 : 0.5,
}} }}
> >
<IAIDndImage <IAIDndImage
draggableData={draggableData} draggableData={draggableData}
droppableData={droppableData} droppableData={droppableData}
imageDTO={controlImage} imageDTO={controlImage}
isDropDisabled={shouldShowProcessedImage} isDropDisabled={shouldShowProcessedImage || !isEnabled}
onClickReset={handleResetControlImage} onClickReset={handleResetControlImage}
postUploadAction={postUploadAction} postUploadAction={postUploadAction}
resetTooltip="Reset Control Image" resetTooltip="Reset Control Image"
@ -140,6 +157,7 @@ const ControlNetImagePreview = (props: Props) => {
droppableData={droppableData} droppableData={droppableData}
imageDTO={processedControlImage} imageDTO={processedControlImage}
isUploadDisabled={true} isUploadDisabled={true}
isDropDisabled={!isEnabled}
onClickReset={handleResetControlImage} onClickReset={handleResetControlImage}
resetTooltip="Reset Control Image" resetTooltip="Reset Control Image"
withResetIcon={Boolean(controlImage)} withResetIcon={Boolean(controlImage)}

View File

@ -1,10 +1,13 @@
import { memo } from 'react'; import { createSelector } from '@reduxjs/toolkit';
import { RequiredControlNetProcessorNode } from '../store/types'; import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { memo, useMemo } from 'react';
import CannyProcessor from './processors/CannyProcessor'; import CannyProcessor from './processors/CannyProcessor';
import HedProcessor from './processors/HedProcessor';
import LineartProcessor from './processors/LineartProcessor';
import LineartAnimeProcessor from './processors/LineartAnimeProcessor';
import ContentShuffleProcessor from './processors/ContentShuffleProcessor'; import ContentShuffleProcessor from './processors/ContentShuffleProcessor';
import HedProcessor from './processors/HedProcessor';
import LineartAnimeProcessor from './processors/LineartAnimeProcessor';
import LineartProcessor from './processors/LineartProcessor';
import MediapipeFaceProcessor from './processors/MediapipeFaceProcessor'; import MediapipeFaceProcessor from './processors/MediapipeFaceProcessor';
import MidasDepthProcessor from './processors/MidasDepthProcessor'; import MidasDepthProcessor from './processors/MidasDepthProcessor';
import MlsdImageProcessor from './processors/MlsdImageProcessor'; import MlsdImageProcessor from './processors/MlsdImageProcessor';
@ -15,23 +18,45 @@ import ZoeDepthProcessor from './processors/ZoeDepthProcessor';
export type ControlNetProcessorProps = { export type ControlNetProcessorProps = {
controlNetId: string; controlNetId: string;
processorNode: RequiredControlNetProcessorNode;
}; };
const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => { const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
const { controlNetId, processorNode } = props; const { controlNetId } = props;
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlNet }) => {
const { isEnabled, processorNode } =
controlNet.controlNets[controlNetId];
return { isEnabled, processorNode };
},
defaultSelectorOptions
),
[controlNetId]
);
const { isEnabled, processorNode } = useAppSelector(selector);
if (processorNode.type === 'canny_image_processor') { if (processorNode.type === 'canny_image_processor') {
return ( return (
<CannyProcessor <CannyProcessor
controlNetId={controlNetId} controlNetId={controlNetId}
processorNode={processorNode} processorNode={processorNode}
isEnabled={isEnabled}
/> />
); );
} }
if (processorNode.type === 'hed_image_processor') { if (processorNode.type === 'hed_image_processor') {
return ( return (
<HedProcessor controlNetId={controlNetId} processorNode={processorNode} /> <HedProcessor
controlNetId={controlNetId}
processorNode={processorNode}
isEnabled={isEnabled}
/>
); );
} }
@ -40,6 +65,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<LineartProcessor <LineartProcessor
controlNetId={controlNetId} controlNetId={controlNetId}
processorNode={processorNode} processorNode={processorNode}
isEnabled={isEnabled}
/> />
); );
} }
@ -49,6 +75,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<ContentShuffleProcessor <ContentShuffleProcessor
controlNetId={controlNetId} controlNetId={controlNetId}
processorNode={processorNode} processorNode={processorNode}
isEnabled={isEnabled}
/> />
); );
} }
@ -58,6 +85,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<LineartAnimeProcessor <LineartAnimeProcessor
controlNetId={controlNetId} controlNetId={controlNetId}
processorNode={processorNode} processorNode={processorNode}
isEnabled={isEnabled}
/> />
); );
} }
@ -67,6 +95,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<MediapipeFaceProcessor <MediapipeFaceProcessor
controlNetId={controlNetId} controlNetId={controlNetId}
processorNode={processorNode} processorNode={processorNode}
isEnabled={isEnabled}
/> />
); );
} }
@ -76,6 +105,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<MidasDepthProcessor <MidasDepthProcessor
controlNetId={controlNetId} controlNetId={controlNetId}
processorNode={processorNode} processorNode={processorNode}
isEnabled={isEnabled}
/> />
); );
} }
@ -85,6 +115,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<MlsdImageProcessor <MlsdImageProcessor
controlNetId={controlNetId} controlNetId={controlNetId}
processorNode={processorNode} processorNode={processorNode}
isEnabled={isEnabled}
/> />
); );
} }
@ -94,6 +125,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<NormalBaeProcessor <NormalBaeProcessor
controlNetId={controlNetId} controlNetId={controlNetId}
processorNode={processorNode} processorNode={processorNode}
isEnabled={isEnabled}
/> />
); );
} }
@ -103,6 +135,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<OpenposeProcessor <OpenposeProcessor
controlNetId={controlNetId} controlNetId={controlNetId}
processorNode={processorNode} processorNode={processorNode}
isEnabled={isEnabled}
/> />
); );
} }
@ -112,6 +145,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<PidiProcessor <PidiProcessor
controlNetId={controlNetId} controlNetId={controlNetId}
processorNode={processorNode} processorNode={processorNode}
isEnabled={isEnabled}
/> />
); );
} }
@ -121,6 +155,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<ZoeDepthProcessor <ZoeDepthProcessor
controlNetId={controlNetId} controlNetId={controlNetId}
processorNode={processorNode} processorNode={processorNode}
isEnabled={isEnabled}
/> />
); );
} }

View File

@ -1,18 +1,36 @@
import { useAppDispatch } from 'app/store/storeHooks'; import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISwitch from 'common/components/IAISwitch'; import IAISwitch from 'common/components/IAISwitch';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import { controlNetAutoConfigToggled } from 'features/controlNet/store/controlNetSlice'; import { controlNetAutoConfigToggled } from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react'; import { selectIsBusy } from 'features/system/store/systemSelectors';
import { memo, useCallback, useMemo } from 'react';
type Props = { type Props = {
controlNetId: string; controlNetId: string;
shouldAutoConfig: boolean;
}; };
const ParamControlNetShouldAutoConfig = (props: Props) => { const ParamControlNetShouldAutoConfig = (props: Props) => {
const { controlNetId, shouldAutoConfig } = props; const { controlNetId } = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const isReady = useIsReadyToInvoke(); const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlNet }) => {
const { isEnabled, shouldAutoConfig } =
controlNet.controlNets[controlNetId];
return { isEnabled, shouldAutoConfig };
},
defaultSelectorOptions
),
[controlNetId]
);
const { isEnabled, shouldAutoConfig } = useAppSelector(selector);
const isBusy = useAppSelector(selectIsBusy);
const handleShouldAutoConfigChanged = useCallback(() => { const handleShouldAutoConfigChanged = useCallback(() => {
dispatch(controlNetAutoConfigToggled({ controlNetId })); dispatch(controlNetAutoConfigToggled({ controlNetId }));
}, [controlNetId, dispatch]); }, [controlNetId, dispatch]);
@ -23,7 +41,7 @@ const ParamControlNetShouldAutoConfig = (props: Props) => {
aria-label="Auto configure processor" aria-label="Auto configure processor"
isChecked={shouldAutoConfig} isChecked={shouldAutoConfig}
onChange={handleShouldAutoConfigChanged} onChange={handleShouldAutoConfigChanged}
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
); );
}; };

View File

@ -1,5 +1,4 @@
import { import {
ChakraProps,
FormControl, FormControl,
FormLabel, FormLabel,
HStack, HStack,
@ -10,34 +9,41 @@ import {
RangeSliderTrack, RangeSliderTrack,
Tooltip, Tooltip,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks'; import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { import {
controlNetBeginStepPctChanged, controlNetBeginStepPctChanged,
controlNetEndStepPctChanged, controlNetEndStepPctChanged,
} from 'features/controlNet/store/controlNetSlice'; } from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react'; import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
const SLIDER_MARK_STYLES: ChakraProps['sx'] = {
mt: 1.5,
fontSize: '2xs',
fontWeight: '500',
color: 'base.400',
};
type Props = { type Props = {
controlNetId: string; controlNetId: string;
beginStepPct: number;
endStepPct: number;
mini?: boolean;
}; };
const formatPct = (v: number) => `${Math.round(v * 100)}%`; const formatPct = (v: number) => `${Math.round(v * 100)}%`;
const ParamControlNetBeginEnd = (props: Props) => { const ParamControlNetBeginEnd = (props: Props) => {
const { controlNetId, beginStepPct, mini = false, endStepPct } = props; const { controlNetId } = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation();
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlNet }) => {
const { beginStepPct, endStepPct, isEnabled } =
controlNet.controlNets[controlNetId];
return { beginStepPct, endStepPct, isEnabled };
},
defaultSelectorOptions
),
[controlNetId]
);
const { beginStepPct, endStepPct, isEnabled } = useAppSelector(selector);
const handleStepPctChanged = useCallback( const handleStepPctChanged = useCallback(
(v: number[]) => { (v: number[]) => {
@ -55,7 +61,7 @@ const ParamControlNetBeginEnd = (props: Props) => {
}, [controlNetId, dispatch]); }, [controlNetId, dispatch]);
return ( return (
<FormControl> <FormControl isDisabled={!isEnabled}>
<FormLabel>Begin / End Step Percentage</FormLabel> <FormLabel>Begin / End Step Percentage</FormLabel>
<HStack w="100%" gap={2} alignItems="center"> <HStack w="100%" gap={2} alignItems="center">
<RangeSlider <RangeSlider
@ -66,6 +72,7 @@ const ParamControlNetBeginEnd = (props: Props) => {
max={1} max={1}
step={0.01} step={0.01}
minStepsBetweenThumbs={5} minStepsBetweenThumbs={5}
isDisabled={!isEnabled}
> >
<RangeSliderTrack> <RangeSliderTrack>
<RangeSliderFilledTrack /> <RangeSliderFilledTrack />
@ -76,14 +83,11 @@ const ParamControlNetBeginEnd = (props: Props) => {
<Tooltip label={formatPct(endStepPct)} placement="top" hasArrow> <Tooltip label={formatPct(endStepPct)} placement="top" hasArrow>
<RangeSliderThumb index={1} /> <RangeSliderThumb index={1} />
</Tooltip> </Tooltip>
{!mini && (
<>
<RangeSliderMark <RangeSliderMark
value={0} value={0}
sx={{ sx={{
insetInlineStart: '0 !important', insetInlineStart: '0 !important',
insetInlineEnd: 'unset !important', insetInlineEnd: 'unset !important',
...SLIDER_MARK_STYLES,
}} }}
> >
0% 0%
@ -91,7 +95,8 @@ const ParamControlNetBeginEnd = (props: Props) => {
<RangeSliderMark <RangeSliderMark
value={0.5} value={0.5}
sx={{ sx={{
...SLIDER_MARK_STYLES, insetInlineStart: '50% !important',
transform: 'translateX(-50%)',
}} }}
> >
50% 50%
@ -101,13 +106,10 @@ const ParamControlNetBeginEnd = (props: Props) => {
sx={{ sx={{
insetInlineStart: 'unset !important', insetInlineStart: 'unset !important',
insetInlineEnd: '0 !important', insetInlineEnd: '0 !important',
...SLIDER_MARK_STYLES,
}} }}
> >
100% 100%
</RangeSliderMark> </RangeSliderMark>
</>
)}
</RangeSlider> </RangeSlider>
</HStack> </HStack>
</FormControl> </FormControl>

View File

@ -1,15 +1,17 @@
import { useAppDispatch } from 'app/store/storeHooks'; import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { import {
ControlModes, ControlModes,
controlNetControlModeChanged, controlNetControlModeChanged,
} from 'features/controlNet/store/controlNetSlice'; } from 'features/controlNet/store/controlNetSlice';
import { useCallback } from 'react'; import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
type ParamControlNetControlModeProps = { type ParamControlNetControlModeProps = {
controlNetId: string; controlNetId: string;
controlMode: string;
}; };
const CONTROL_MODE_DATA = [ const CONTROL_MODE_DATA = [
@ -22,8 +24,23 @@ const CONTROL_MODE_DATA = [
export default function ParamControlNetControlMode( export default function ParamControlNetControlMode(
props: ParamControlNetControlModeProps props: ParamControlNetControlModeProps
) { ) {
const { controlNetId, controlMode = false } = props; const { controlNetId } = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlNet }) => {
const { controlMode, isEnabled } =
controlNet.controlNets[controlNetId];
return { controlMode, isEnabled };
},
defaultSelectorOptions
),
[controlNetId]
);
const { controlMode, isEnabled } = useAppSelector(selector);
const { t } = useTranslation(); const { t } = useTranslation();
@ -36,7 +53,8 @@ export default function ParamControlNetControlMode(
return ( return (
<IAIMantineSelect <IAIMantineSelect
label={t('parameters.controlNetControlMode')} disabled={!isEnabled}
label="Control Mode"
data={CONTROL_MODE_DATA} data={CONTROL_MODE_DATA}
value={String(controlMode)} value={String(controlMode)}
onChange={handleControlModeChange} onChange={handleControlModeChange}

View File

@ -29,6 +29,9 @@ const ParamControlNetFeatureToggle = () => {
label="Enable ControlNet" label="Enable ControlNet"
isChecked={isEnabled} isChecked={isEnabled}
onChange={handleChange} onChange={handleChange}
formControlProps={{
width: '100%',
}}
/> />
); );
}; };

View File

@ -1,28 +0,0 @@
import { useAppDispatch } from 'app/store/storeHooks';
import IAISwitch from 'common/components/IAISwitch';
import { controlNetToggled } from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react';
type ParamControlNetIsEnabledProps = {
controlNetId: string;
isEnabled: boolean;
};
const ParamControlNetIsEnabled = (props: ParamControlNetIsEnabledProps) => {
const { controlNetId, isEnabled } = props;
const dispatch = useAppDispatch();
const handleIsEnabledChanged = useCallback(() => {
dispatch(controlNetToggled({ controlNetId }));
}, [dispatch, controlNetId]);
return (
<IAISwitch
label="Enabled"
isChecked={isEnabled}
onChange={handleIsEnabledChanged}
/>
);
};
export default memo(ParamControlNetIsEnabled);

View File

@ -1,36 +0,0 @@
import { useAppDispatch } from 'app/store/storeHooks';
import IAIFullCheckbox from 'common/components/IAIFullCheckbox';
import IAISwitch from 'common/components/IAISwitch';
import {
controlNetToggled,
isControlNetImagePreprocessedToggled,
} from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react';
type ParamControlNetIsEnabledProps = {
controlNetId: string;
isControlImageProcessed: boolean;
};
const ParamControlNetIsEnabled = (props: ParamControlNetIsEnabledProps) => {
const { controlNetId, isControlImageProcessed } = props;
const dispatch = useAppDispatch();
const handleIsControlImageProcessedToggled = useCallback(() => {
dispatch(
isControlNetImagePreprocessedToggled({
controlNetId,
})
);
}, [controlNetId, dispatch]);
return (
<IAISwitch
label="Preprocess"
isChecked={isControlImageProcessed}
onChange={handleIsControlImageProcessedToggled}
/>
);
};
export default memo(ParamControlNetIsEnabled);

View File

@ -1,59 +1,118 @@
import { SelectItem } from '@mantine/core';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSearchableSelect, { import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
IAISelectDataType, import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
} from 'common/components/IAIMantineSearchableSelect'; import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import {
CONTROLNET_MODELS,
ControlNetModelName,
} from 'features/controlNet/store/constants';
import { controlNetModelChanged } from 'features/controlNet/store/controlNetSlice'; import { controlNetModelChanged } from 'features/controlNet/store/controlNetSlice';
import { configSelector } from 'features/system/store/configSelectors'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { map } from 'lodash-es'; import { modelIdToControlNetModelParam } from 'features/parameters/util/modelIdToControlNetModelParam';
import { memo, useCallback } from 'react'; import { selectIsBusy } from 'features/system/store/systemSelectors';
import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
type ParamControlNetModelProps = { type ParamControlNetModelProps = {
controlNetId: string; controlNetId: string;
model: ControlNetModelName;
}; };
const selector = createSelector(configSelector, (config) => { const ParamControlNetModel = (props: ParamControlNetModelProps) => {
const controlNetModels: IAISelectDataType[] = map(CONTROLNET_MODELS, (m) => ({ const { controlNetId } = props;
label: m.label, const dispatch = useAppDispatch();
value: m.type, const isBusy = useAppSelector(selectIsBusy);
})).filter(
(d) => const selector = useMemo(
!config.sd.disabledControlNetModels.includes( () =>
d.value as ControlNetModelName createSelector(
) stateSelector,
({ generation, controlNet }) => {
const { model } = generation;
const controlNetModel = controlNet.controlNets[controlNetId]?.model;
const isEnabled = controlNet.controlNets[controlNetId]?.isEnabled;
return { mainModel: model, controlNetModel, isEnabled };
},
defaultSelectorOptions
),
[controlNetId]
); );
return controlNetModels; const { mainModel, controlNetModel, isEnabled } = useAppSelector(selector);
});
const ParamControlNetModel = (props: ParamControlNetModelProps) => { const { data: controlNetModels } = useGetControlNetModelsQuery();
const { controlNetId, model } = props;
const controlNetModels = useAppSelector(selector); const data = useMemo(() => {
const dispatch = useAppDispatch(); if (!controlNetModels) {
const isReady = useIsReadyToInvoke(); return [];
}
const data: SelectItem[] = [];
forEach(controlNetModels.entities, (model, id) => {
if (!model) {
return;
}
const disabled = model?.base_model !== mainModel?.base_model;
data.push({
value: id,
label: model.model_name,
group: MODEL_TYPE_MAP[model.base_model],
disabled,
tooltip: disabled
? `Incompatible base model: ${model.base_model}`
: undefined,
});
});
return data;
}, [controlNetModels, mainModel?.base_model]);
// grab the full model entity from the RTK Query cache
const selectedModel = useMemo(
() =>
controlNetModels?.entities[
`${controlNetModel?.base_model}/controlnet/${controlNetModel?.model_name}`
] ?? null,
[
controlNetModel?.base_model,
controlNetModel?.model_name,
controlNetModels?.entities,
]
);
const handleModelChanged = useCallback( const handleModelChanged = useCallback(
(val: string | null) => { (v: string | null) => {
// TODO: do not cast if (!v) {
const model = val as ControlNetModelName; return;
dispatch(controlNetModelChanged({ controlNetId, model })); }
const newControlNetModel = modelIdToControlNetModelParam(v);
if (!newControlNetModel) {
return;
}
dispatch(
controlNetModelChanged({ controlNetId, model: newControlNetModel })
);
}, },
[controlNetId, dispatch] [controlNetId, dispatch]
); );
return ( return (
<IAIMantineSearchableSelect <IAIMantineSearchableSelect
data={controlNetModels} itemComponent={IAIMantineSelectItemWithTooltip}
value={model} data={data}
error={
!selectedModel || mainModel?.base_model !== selectedModel.base_model
}
placeholder="Select a model"
value={selectedModel?.id ?? null}
onChange={handleModelChanged} onChange={handleModelChanged}
disabled={!isReady} disabled={isBusy || !isEnabled}
tooltip={model} tooltip={selectedModel?.description}
/> />
); );
}; };

View File

@ -1,24 +1,22 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSearchableSelect, { import IAIMantineSearchableSelect, {
IAISelectDataType, IAISelectDataType,
} from 'common/components/IAIMantineSearchableSelect'; } from 'common/components/IAIMantineSearchableSelect';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import { configSelector } from 'features/system/store/configSelectors'; import { configSelector } from 'features/system/store/configSelectors';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { map } from 'lodash-es'; import { map } from 'lodash-es';
import { memo, useCallback } from 'react'; import { memo, useCallback, useMemo } from 'react';
import { CONTROLNET_PROCESSORS } from '../../store/constants'; import { CONTROLNET_PROCESSORS } from '../../store/constants';
import { controlNetProcessorTypeChanged } from '../../store/controlNetSlice'; import { controlNetProcessorTypeChanged } from '../../store/controlNetSlice';
import { import { ControlNetProcessorType } from '../../store/types';
ControlNetProcessorNode, import { FormControl, FormLabel } from '@chakra-ui/react';
ControlNetProcessorType,
} from '../../store/types';
type ParamControlNetProcessorSelectProps = { type ParamControlNetProcessorSelectProps = {
controlNetId: string; controlNetId: string;
processorNode: ControlNetProcessorNode;
}; };
const selector = createSelector( const selector = createSelector(
@ -54,10 +52,24 @@ const selector = createSelector(
const ParamControlNetProcessorSelect = ( const ParamControlNetProcessorSelect = (
props: ParamControlNetProcessorSelectProps props: ParamControlNetProcessorSelectProps
) => { ) => {
const { controlNetId, processorNode } = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const isReady = useIsReadyToInvoke(); const { controlNetId } = props;
const processorNodeSelector = useMemo(
() =>
createSelector(
stateSelector,
({ controlNet }) => {
const { isEnabled, processorNode } =
controlNet.controlNets[controlNetId];
return { isEnabled, processorNode };
},
defaultSelectorOptions
),
[controlNetId]
);
const isBusy = useAppSelector(selectIsBusy);
const controlNetProcessors = useAppSelector(selector); const controlNetProcessors = useAppSelector(selector);
const { isEnabled, processorNode } = useAppSelector(processorNodeSelector);
const handleProcessorTypeChanged = useCallback( const handleProcessorTypeChanged = useCallback(
(v: string | null) => { (v: string | null) => {
@ -77,7 +89,7 @@ const ParamControlNetProcessorSelect = (
value={processorNode.type ?? 'canny_image_processor'} value={processorNode.type ?? 'canny_image_processor'}
data={controlNetProcessors} data={controlNetProcessors}
onChange={handleProcessorTypeChanged} onChange={handleProcessorTypeChanged}
disabled={!isReady} disabled={isBusy || !isEnabled}
/> />
); );
}; };

View File

@ -1,18 +1,32 @@
import { useAppDispatch } from 'app/store/storeHooks'; import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
import { controlNetWeightChanged } from 'features/controlNet/store/controlNetSlice'; import { controlNetWeightChanged } from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react'; import { memo, useCallback, useMemo } from 'react';
type ParamControlNetWeightProps = { type ParamControlNetWeightProps = {
controlNetId: string; controlNetId: string;
weight: number;
mini?: boolean;
}; };
const ParamControlNetWeight = (props: ParamControlNetWeightProps) => { const ParamControlNetWeight = (props: ParamControlNetWeightProps) => {
const { controlNetId, weight, mini = false } = props; const { controlNetId } = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlNet }) => {
const { weight, isEnabled } = controlNet.controlNets[controlNetId];
return { weight, isEnabled };
},
defaultSelectorOptions
),
[controlNetId]
);
const { weight, isEnabled } = useAppSelector(selector);
const handleWeightChanged = useCallback( const handleWeightChanged = useCallback(
(weight: number) => { (weight: number) => {
dispatch(controlNetWeightChanged({ controlNetId, weight })); dispatch(controlNetWeightChanged({ controlNetId, weight }));
@ -22,15 +36,15 @@ const ParamControlNetWeight = (props: ParamControlNetWeightProps) => {
return ( return (
<IAISlider <IAISlider
isDisabled={!isEnabled}
label={'Weight'} label={'Weight'}
sliderFormLabelProps={{ pb: 2 }}
value={weight} value={weight}
onChange={handleWeightChanged} onChange={handleWeightChanged}
min={-1} min={0}
max={1} max={2}
step={0.01} step={0.01}
withSliderMarks={!mini} withSliderMarks
sliderMarks={[-1, 0, 1]} sliderMarks={[0, 1, 2]}
/> />
); );
}; };

View File

@ -1,22 +1,25 @@
import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants'; import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredCannyImageProcessorInvocation } from 'features/controlNet/store/types'; import { RequiredCannyImageProcessorInvocation } from 'features/controlNet/store/types';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper'; import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.canny_image_processor.default; const DEFAULTS = CONTROLNET_PROCESSORS.canny_image_processor
.default as RequiredCannyImageProcessorInvocation;
type CannyProcessorProps = { type CannyProcessorProps = {
controlNetId: string; controlNetId: string;
processorNode: RequiredCannyImageProcessorInvocation; processorNode: RequiredCannyImageProcessorInvocation;
isEnabled: boolean;
}; };
const CannyProcessor = (props: CannyProcessorProps) => { const CannyProcessor = (props: CannyProcessorProps) => {
const { controlNetId, processorNode } = props; const { controlNetId, processorNode, isEnabled } = props;
const { low_threshold, high_threshold } = processorNode; const { low_threshold, high_threshold } = processorNode;
const isReady = useIsReadyToInvoke(); const isBusy = useAppSelector(selectIsBusy);
const processorChanged = useProcessorNodeChanged(); const processorChanged = useProcessorNodeChanged();
const handleLowThresholdChanged = useCallback( const handleLowThresholdChanged = useCallback(
@ -48,7 +51,7 @@ const CannyProcessor = (props: CannyProcessorProps) => {
return ( return (
<ProcessorWrapper> <ProcessorWrapper>
<IAISlider <IAISlider
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
label="Low Threshold" label="Low Threshold"
value={low_threshold} value={low_threshold}
onChange={handleLowThresholdChanged} onChange={handleLowThresholdChanged}
@ -60,7 +63,7 @@ const CannyProcessor = (props: CannyProcessorProps) => {
withSliderMarks withSliderMarks
/> />
<IAISlider <IAISlider
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
label="High Threshold" label="High Threshold"
value={high_threshold} value={high_threshold}
onChange={handleHighThresholdChanged} onChange={handleHighThresholdChanged}

View File

@ -4,20 +4,23 @@ import { RequiredContentShuffleImageProcessorInvocation } from 'features/control
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper'; import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke'; import { useAppSelector } from 'app/store/storeHooks';
import { selectIsBusy } from 'features/system/store/systemSelectors';
const DEFAULTS = CONTROLNET_PROCESSORS.content_shuffle_image_processor.default; const DEFAULTS = CONTROLNET_PROCESSORS.content_shuffle_image_processor
.default as RequiredContentShuffleImageProcessorInvocation;
type Props = { type Props = {
controlNetId: string; controlNetId: string;
processorNode: RequiredContentShuffleImageProcessorInvocation; processorNode: RequiredContentShuffleImageProcessorInvocation;
isEnabled: boolean;
}; };
const ContentShuffleProcessor = (props: Props) => { const ContentShuffleProcessor = (props: Props) => {
const { controlNetId, processorNode } = props; const { controlNetId, processorNode, isEnabled } = props;
const { image_resolution, detect_resolution, w, h, f } = processorNode; const { image_resolution, detect_resolution, w, h, f } = processorNode;
const processorChanged = useProcessorNodeChanged(); const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke(); const isBusy = useAppSelector(selectIsBusy);
const handleDetectResolutionChanged = useCallback( const handleDetectResolutionChanged = useCallback(
(v: number) => { (v: number) => {
@ -96,7 +99,7 @@ const ContentShuffleProcessor = (props: Props) => {
max={4096} max={4096}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
<IAISlider <IAISlider
label="Image Resolution" label="Image Resolution"
@ -108,7 +111,7 @@ const ContentShuffleProcessor = (props: Props) => {
max={4096} max={4096}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
<IAISlider <IAISlider
label="W" label="W"
@ -120,7 +123,7 @@ const ContentShuffleProcessor = (props: Props) => {
max={4096} max={4096}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
<IAISlider <IAISlider
label="H" label="H"
@ -132,7 +135,7 @@ const ContentShuffleProcessor = (props: Props) => {
max={4096} max={4096}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
<IAISlider <IAISlider
label="F" label="F"
@ -144,7 +147,7 @@ const ContentShuffleProcessor = (props: Props) => {
max={4096} max={4096}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
</ProcessorWrapper> </ProcessorWrapper>
); );

View File

@ -1,25 +1,29 @@
import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
import IAISwitch from 'common/components/IAISwitch'; import IAISwitch from 'common/components/IAISwitch';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants'; import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredHedImageProcessorInvocation } from 'features/controlNet/store/types'; import { RequiredHedImageProcessorInvocation } from 'features/controlNet/store/types';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { ChangeEvent, memo, useCallback } from 'react'; import { ChangeEvent, memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper'; import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.hed_image_processor.default; const DEFAULTS = CONTROLNET_PROCESSORS.hed_image_processor
.default as RequiredHedImageProcessorInvocation;
type HedProcessorProps = { type HedProcessorProps = {
controlNetId: string; controlNetId: string;
processorNode: RequiredHedImageProcessorInvocation; processorNode: RequiredHedImageProcessorInvocation;
isEnabled: boolean;
}; };
const HedPreprocessor = (props: HedProcessorProps) => { const HedPreprocessor = (props: HedProcessorProps) => {
const { const {
controlNetId, controlNetId,
processorNode: { detect_resolution, image_resolution, scribble }, processorNode: { detect_resolution, image_resolution, scribble },
isEnabled,
} = props; } = props;
const isReady = useIsReadyToInvoke(); const isBusy = useAppSelector(selectIsBusy);
const processorChanged = useProcessorNodeChanged(); const processorChanged = useProcessorNodeChanged();
const handleDetectResolutionChanged = useCallback( const handleDetectResolutionChanged = useCallback(
@ -67,7 +71,7 @@ const HedPreprocessor = (props: HedProcessorProps) => {
max={4096} max={4096}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
<IAISlider <IAISlider
label="Image Resolution" label="Image Resolution"
@ -79,13 +83,13 @@ const HedPreprocessor = (props: HedProcessorProps) => {
max={4096} max={4096}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
<IAISwitch <IAISwitch
label="Scribble" label="Scribble"
isChecked={scribble} isChecked={scribble}
onChange={handleScribbleChanged} onChange={handleScribbleChanged}
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
</ProcessorWrapper> </ProcessorWrapper>
); );

View File

@ -1,23 +1,26 @@
import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants'; import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredLineartAnimeImageProcessorInvocation } from 'features/controlNet/store/types'; import { RequiredLineartAnimeImageProcessorInvocation } from 'features/controlNet/store/types';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper'; import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.lineart_anime_image_processor.default; const DEFAULTS = CONTROLNET_PROCESSORS.lineart_anime_image_processor
.default as RequiredLineartAnimeImageProcessorInvocation;
type Props = { type Props = {
controlNetId: string; controlNetId: string;
processorNode: RequiredLineartAnimeImageProcessorInvocation; processorNode: RequiredLineartAnimeImageProcessorInvocation;
isEnabled: boolean;
}; };
const LineartAnimeProcessor = (props: Props) => { const LineartAnimeProcessor = (props: Props) => {
const { controlNetId, processorNode } = props; const { controlNetId, processorNode, isEnabled } = props;
const { image_resolution, detect_resolution } = processorNode; const { image_resolution, detect_resolution } = processorNode;
const processorChanged = useProcessorNodeChanged(); const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke(); const isBusy = useAppSelector(selectIsBusy);
const handleDetectResolutionChanged = useCallback( const handleDetectResolutionChanged = useCallback(
(v: number) => { (v: number) => {
@ -57,7 +60,7 @@ const LineartAnimeProcessor = (props: Props) => {
max={4096} max={4096}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
<IAISlider <IAISlider
label="Image Resolution" label="Image Resolution"
@ -69,7 +72,7 @@ const LineartAnimeProcessor = (props: Props) => {
max={4096} max={4096}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
</ProcessorWrapper> </ProcessorWrapper>
); );

View File

@ -1,24 +1,27 @@
import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
import IAISwitch from 'common/components/IAISwitch'; import IAISwitch from 'common/components/IAISwitch';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants'; import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredLineartImageProcessorInvocation } from 'features/controlNet/store/types'; import { RequiredLineartImageProcessorInvocation } from 'features/controlNet/store/types';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { ChangeEvent, memo, useCallback } from 'react'; import { ChangeEvent, memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper'; import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.lineart_image_processor.default; const DEFAULTS = CONTROLNET_PROCESSORS.lineart_image_processor
.default as RequiredLineartImageProcessorInvocation;
type LineartProcessorProps = { type LineartProcessorProps = {
controlNetId: string; controlNetId: string;
processorNode: RequiredLineartImageProcessorInvocation; processorNode: RequiredLineartImageProcessorInvocation;
isEnabled: boolean;
}; };
const LineartProcessor = (props: LineartProcessorProps) => { const LineartProcessor = (props: LineartProcessorProps) => {
const { controlNetId, processorNode } = props; const { controlNetId, processorNode, isEnabled } = props;
const { image_resolution, detect_resolution, coarse } = processorNode; const { image_resolution, detect_resolution, coarse } = processorNode;
const processorChanged = useProcessorNodeChanged(); const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke(); const isBusy = useAppSelector(selectIsBusy);
const handleDetectResolutionChanged = useCallback( const handleDetectResolutionChanged = useCallback(
(v: number) => { (v: number) => {
@ -65,7 +68,7 @@ const LineartProcessor = (props: LineartProcessorProps) => {
max={4096} max={4096}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
<IAISlider <IAISlider
label="Image Resolution" label="Image Resolution"
@ -77,13 +80,13 @@ const LineartProcessor = (props: LineartProcessorProps) => {
max={4096} max={4096}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
<IAISwitch <IAISwitch
label="Coarse" label="Coarse"
isChecked={coarse} isChecked={coarse}
onChange={handleCoarseChanged} onChange={handleCoarseChanged}
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
</ProcessorWrapper> </ProcessorWrapper>
); );

View File

@ -1,23 +1,26 @@
import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants'; import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredMediapipeFaceProcessorInvocation } from 'features/controlNet/store/types'; import { RequiredMediapipeFaceProcessorInvocation } from 'features/controlNet/store/types';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper'; import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.mediapipe_face_processor.default; const DEFAULTS = CONTROLNET_PROCESSORS.mediapipe_face_processor
.default as RequiredMediapipeFaceProcessorInvocation;
type Props = { type Props = {
controlNetId: string; controlNetId: string;
processorNode: RequiredMediapipeFaceProcessorInvocation; processorNode: RequiredMediapipeFaceProcessorInvocation;
isEnabled: boolean;
}; };
const MediapipeFaceProcessor = (props: Props) => { const MediapipeFaceProcessor = (props: Props) => {
const { controlNetId, processorNode } = props; const { controlNetId, processorNode, isEnabled } = props;
const { max_faces, min_confidence } = processorNode; const { max_faces, min_confidence } = processorNode;
const processorChanged = useProcessorNodeChanged(); const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke(); const isBusy = useAppSelector(selectIsBusy);
const handleMaxFacesChanged = useCallback( const handleMaxFacesChanged = useCallback(
(v: number) => { (v: number) => {
@ -53,7 +56,7 @@ const MediapipeFaceProcessor = (props: Props) => {
max={20} max={20}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
<IAISlider <IAISlider
label="Min Confidence" label="Min Confidence"
@ -66,7 +69,7 @@ const MediapipeFaceProcessor = (props: Props) => {
step={0.01} step={0.01}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
</ProcessorWrapper> </ProcessorWrapper>
); );

View File

@ -1,23 +1,26 @@
import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants'; import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredMidasDepthImageProcessorInvocation } from 'features/controlNet/store/types'; import { RequiredMidasDepthImageProcessorInvocation } from 'features/controlNet/store/types';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper'; import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.midas_depth_image_processor.default; const DEFAULTS = CONTROLNET_PROCESSORS.midas_depth_image_processor
.default as RequiredMidasDepthImageProcessorInvocation;
type Props = { type Props = {
controlNetId: string; controlNetId: string;
processorNode: RequiredMidasDepthImageProcessorInvocation; processorNode: RequiredMidasDepthImageProcessorInvocation;
isEnabled: boolean;
}; };
const MidasDepthProcessor = (props: Props) => { const MidasDepthProcessor = (props: Props) => {
const { controlNetId, processorNode } = props; const { controlNetId, processorNode, isEnabled } = props;
const { a_mult, bg_th } = processorNode; const { a_mult, bg_th } = processorNode;
const processorChanged = useProcessorNodeChanged(); const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke(); const isBusy = useAppSelector(selectIsBusy);
const handleAMultChanged = useCallback( const handleAMultChanged = useCallback(
(v: number) => { (v: number) => {
@ -54,7 +57,7 @@ const MidasDepthProcessor = (props: Props) => {
step={0.01} step={0.01}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
<IAISlider <IAISlider
label="bg_th" label="bg_th"
@ -67,7 +70,7 @@ const MidasDepthProcessor = (props: Props) => {
step={0.01} step={0.01}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
</ProcessorWrapper> </ProcessorWrapper>
); );

View File

@ -1,23 +1,26 @@
import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants'; import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredMlsdImageProcessorInvocation } from 'features/controlNet/store/types'; import { RequiredMlsdImageProcessorInvocation } from 'features/controlNet/store/types';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper'; import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.mlsd_image_processor.default; const DEFAULTS = CONTROLNET_PROCESSORS.mlsd_image_processor
.default as RequiredMlsdImageProcessorInvocation;
type Props = { type Props = {
controlNetId: string; controlNetId: string;
processorNode: RequiredMlsdImageProcessorInvocation; processorNode: RequiredMlsdImageProcessorInvocation;
isEnabled: boolean;
}; };
const MlsdImageProcessor = (props: Props) => { const MlsdImageProcessor = (props: Props) => {
const { controlNetId, processorNode } = props; const { controlNetId, processorNode, isEnabled } = props;
const { image_resolution, detect_resolution, thr_d, thr_v } = processorNode; const { image_resolution, detect_resolution, thr_d, thr_v } = processorNode;
const processorChanged = useProcessorNodeChanged(); const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke(); const isBusy = useAppSelector(selectIsBusy);
const handleDetectResolutionChanged = useCallback( const handleDetectResolutionChanged = useCallback(
(v: number) => { (v: number) => {
@ -79,7 +82,7 @@ const MlsdImageProcessor = (props: Props) => {
max={4096} max={4096}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
<IAISlider <IAISlider
label="Image Resolution" label="Image Resolution"
@ -91,7 +94,7 @@ const MlsdImageProcessor = (props: Props) => {
max={4096} max={4096}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
<IAISlider <IAISlider
label="W" label="W"
@ -104,7 +107,7 @@ const MlsdImageProcessor = (props: Props) => {
step={0.01} step={0.01}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
<IAISlider <IAISlider
label="H" label="H"
@ -117,7 +120,7 @@ const MlsdImageProcessor = (props: Props) => {
step={0.01} step={0.01}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
</ProcessorWrapper> </ProcessorWrapper>
); );

View File

@ -1,23 +1,26 @@
import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants'; import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredNormalbaeImageProcessorInvocation } from 'features/controlNet/store/types'; import { RequiredNormalbaeImageProcessorInvocation } from 'features/controlNet/store/types';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper'; import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.normalbae_image_processor.default; const DEFAULTS = CONTROLNET_PROCESSORS.normalbae_image_processor
.default as RequiredNormalbaeImageProcessorInvocation;
type Props = { type Props = {
controlNetId: string; controlNetId: string;
processorNode: RequiredNormalbaeImageProcessorInvocation; processorNode: RequiredNormalbaeImageProcessorInvocation;
isEnabled: boolean;
}; };
const NormalBaeProcessor = (props: Props) => { const NormalBaeProcessor = (props: Props) => {
const { controlNetId, processorNode } = props; const { controlNetId, processorNode, isEnabled } = props;
const { image_resolution, detect_resolution } = processorNode; const { image_resolution, detect_resolution } = processorNode;
const processorChanged = useProcessorNodeChanged(); const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke(); const isBusy = useAppSelector(selectIsBusy);
const handleDetectResolutionChanged = useCallback( const handleDetectResolutionChanged = useCallback(
(v: number) => { (v: number) => {
@ -57,7 +60,7 @@ const NormalBaeProcessor = (props: Props) => {
max={4096} max={4096}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
<IAISlider <IAISlider
label="Image Resolution" label="Image Resolution"
@ -69,7 +72,7 @@ const NormalBaeProcessor = (props: Props) => {
max={4096} max={4096}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
</ProcessorWrapper> </ProcessorWrapper>
); );

View File

@ -1,24 +1,27 @@
import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
import IAISwitch from 'common/components/IAISwitch'; import IAISwitch from 'common/components/IAISwitch';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants'; import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredOpenposeImageProcessorInvocation } from 'features/controlNet/store/types'; import { RequiredOpenposeImageProcessorInvocation } from 'features/controlNet/store/types';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { ChangeEvent, memo, useCallback } from 'react'; import { ChangeEvent, memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper'; import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.openpose_image_processor.default; const DEFAULTS = CONTROLNET_PROCESSORS.openpose_image_processor
.default as RequiredOpenposeImageProcessorInvocation;
type Props = { type Props = {
controlNetId: string; controlNetId: string;
processorNode: RequiredOpenposeImageProcessorInvocation; processorNode: RequiredOpenposeImageProcessorInvocation;
isEnabled: boolean;
}; };
const OpenposeProcessor = (props: Props) => { const OpenposeProcessor = (props: Props) => {
const { controlNetId, processorNode } = props; const { controlNetId, processorNode, isEnabled } = props;
const { image_resolution, detect_resolution, hand_and_face } = processorNode; const { image_resolution, detect_resolution, hand_and_face } = processorNode;
const processorChanged = useProcessorNodeChanged(); const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke(); const isBusy = useAppSelector(selectIsBusy);
const handleDetectResolutionChanged = useCallback( const handleDetectResolutionChanged = useCallback(
(v: number) => { (v: number) => {
@ -65,7 +68,7 @@ const OpenposeProcessor = (props: Props) => {
max={4096} max={4096}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
<IAISlider <IAISlider
label="Image Resolution" label="Image Resolution"
@ -77,13 +80,13 @@ const OpenposeProcessor = (props: Props) => {
max={4096} max={4096}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
<IAISwitch <IAISwitch
label="Hand and Face" label="Hand and Face"
isChecked={hand_and_face} isChecked={hand_and_face}
onChange={handleHandAndFaceChanged} onChange={handleHandAndFaceChanged}
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
</ProcessorWrapper> </ProcessorWrapper>
); );

View File

@ -1,24 +1,27 @@
import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
import IAISwitch from 'common/components/IAISwitch'; import IAISwitch from 'common/components/IAISwitch';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants'; import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredPidiImageProcessorInvocation } from 'features/controlNet/store/types'; import { RequiredPidiImageProcessorInvocation } from 'features/controlNet/store/types';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { ChangeEvent, memo, useCallback } from 'react'; import { ChangeEvent, memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper'; import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.pidi_image_processor.default; const DEFAULTS = CONTROLNET_PROCESSORS.pidi_image_processor
.default as RequiredPidiImageProcessorInvocation;
type Props = { type Props = {
controlNetId: string; controlNetId: string;
processorNode: RequiredPidiImageProcessorInvocation; processorNode: RequiredPidiImageProcessorInvocation;
isEnabled: boolean;
}; };
const PidiProcessor = (props: Props) => { const PidiProcessor = (props: Props) => {
const { controlNetId, processorNode } = props; const { controlNetId, processorNode, isEnabled } = props;
const { image_resolution, detect_resolution, scribble, safe } = processorNode; const { image_resolution, detect_resolution, scribble, safe } = processorNode;
const processorChanged = useProcessorNodeChanged(); const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke(); const isBusy = useAppSelector(selectIsBusy);
const handleDetectResolutionChanged = useCallback( const handleDetectResolutionChanged = useCallback(
(v: number) => { (v: number) => {
@ -72,7 +75,7 @@ const PidiProcessor = (props: Props) => {
max={4096} max={4096}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
<IAISlider <IAISlider
label="Image Resolution" label="Image Resolution"
@ -84,7 +87,7 @@ const PidiProcessor = (props: Props) => {
max={4096} max={4096}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
<IAISwitch <IAISwitch
label="Scribble" label="Scribble"
@ -95,7 +98,7 @@ const PidiProcessor = (props: Props) => {
label="Safe" label="Safe"
isChecked={safe} isChecked={safe}
onChange={handleSafeChanged} onChange={handleSafeChanged}
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
</ProcessorWrapper> </ProcessorWrapper>
); );

View File

@ -4,6 +4,7 @@ import { memo } from 'react';
type Props = { type Props = {
controlNetId: string; controlNetId: string;
processorNode: RequiredZoeDepthImageProcessorInvocation; processorNode: RequiredZoeDepthImageProcessorInvocation;
isEnabled: boolean;
}; };
const ZoeDepthProcessor = (props: Props) => { const ZoeDepthProcessor = (props: Props) => {

View File

@ -173,91 +173,17 @@ export const CONTROLNET_PROCESSORS: ControlNetProcessorsDict = {
}, },
}; };
type ControlNetModelsDict = Record<string, ControlNetModel>; export const CONTROLNET_MODEL_DEFAULT_PROCESSORS: {
[key: string]: ControlNetProcessorType;
type ControlNetModel = { } = {
type: string; canny: 'canny_image_processor',
label: string; mlsd: 'mlsd_image_processor',
description?: string; depth: 'midas_depth_image_processor',
defaultProcessor?: ControlNetProcessorType; bae: 'normalbae_image_processor',
lineart: 'lineart_image_processor',
lineart_anime: 'lineart_anime_image_processor',
softedge: 'hed_image_processor',
shuffle: 'content_shuffle_image_processor',
openpose: 'openpose_image_processor',
mediapipe: 'mediapipe_face_processor',
}; };
export const CONTROLNET_MODELS: ControlNetModelsDict = {
'lllyasviel/control_v11p_sd15_canny': {
type: 'lllyasviel/control_v11p_sd15_canny',
label: 'Canny',
defaultProcessor: 'canny_image_processor',
},
'lllyasviel/control_v11p_sd15_inpaint': {
type: 'lllyasviel/control_v11p_sd15_inpaint',
label: 'Inpaint',
defaultProcessor: 'none',
},
'lllyasviel/control_v11p_sd15_mlsd': {
type: 'lllyasviel/control_v11p_sd15_mlsd',
label: 'M-LSD',
defaultProcessor: 'mlsd_image_processor',
},
'lllyasviel/control_v11f1p_sd15_depth': {
type: 'lllyasviel/control_v11f1p_sd15_depth',
label: 'Depth',
defaultProcessor: 'midas_depth_image_processor',
},
'lllyasviel/control_v11p_sd15_normalbae': {
type: 'lllyasviel/control_v11p_sd15_normalbae',
label: 'Normal Map (BAE)',
defaultProcessor: 'normalbae_image_processor',
},
'lllyasviel/control_v11p_sd15_seg': {
type: 'lllyasviel/control_v11p_sd15_seg',
label: 'Segmentation',
defaultProcessor: 'none',
},
'lllyasviel/control_v11p_sd15_lineart': {
type: 'lllyasviel/control_v11p_sd15_lineart',
label: 'Lineart',
defaultProcessor: 'lineart_image_processor',
},
'lllyasviel/control_v11p_sd15s2_lineart_anime': {
type: 'lllyasviel/control_v11p_sd15s2_lineart_anime',
label: 'Lineart Anime',
defaultProcessor: 'lineart_anime_image_processor',
},
'lllyasviel/control_v11p_sd15_scribble': {
type: 'lllyasviel/control_v11p_sd15_scribble',
label: 'Scribble',
defaultProcessor: 'none',
},
'lllyasviel/control_v11p_sd15_softedge': {
type: 'lllyasviel/control_v11p_sd15_softedge',
label: 'Soft Edge',
defaultProcessor: 'hed_image_processor',
},
'lllyasviel/control_v11e_sd15_shuffle': {
type: 'lllyasviel/control_v11e_sd15_shuffle',
label: 'Content Shuffle',
defaultProcessor: 'content_shuffle_image_processor',
},
'lllyasviel/control_v11p_sd15_openpose': {
type: 'lllyasviel/control_v11p_sd15_openpose',
label: 'Openpose',
defaultProcessor: 'openpose_image_processor',
},
'lllyasviel/control_v11f1e_sd15_tile': {
type: 'lllyasviel/control_v11f1e_sd15_tile',
label: 'Tile (experimental)',
defaultProcessor: 'none',
},
'lllyasviel/control_v11e_sd15_ip2p': {
type: 'lllyasviel/control_v11e_sd15_ip2p',
label: 'Pix2Pix (experimental)',
defaultProcessor: 'none',
},
'CrucibleAI/ControlNetMediaPipeFace': {
type: 'CrucibleAI/ControlNetMediaPipeFace',
label: 'Mediapipe Face',
defaultProcessor: 'mediapipe_face_processor',
},
};
export type ControlNetModelName = keyof typeof CONTROLNET_MODELS;

View File

@ -1,22 +1,20 @@
import { PayloadAction } from '@reduxjs/toolkit'; import { PayloadAction, createSlice } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { ImageDTO } from 'services/api/types'; import { ControlNetModelParam } from 'features/parameters/types/parameterSchemas';
import { cloneDeep, forEach } from 'lodash-es';
import { imageDeleted } from 'services/api/thunks/image';
import { isAnySessionRejected } from 'services/api/thunks/session';
import { appSocketInvocationError } from 'services/events/actions';
import { controlNetImageProcessed } from './actions';
import {
CONTROLNET_MODEL_DEFAULT_PROCESSORS,
CONTROLNET_PROCESSORS,
} from './constants';
import { import {
ControlNetProcessorType, ControlNetProcessorType,
RequiredCannyImageProcessorInvocation, RequiredCannyImageProcessorInvocation,
RequiredControlNetProcessorNode, RequiredControlNetProcessorNode,
} from './types'; } from './types';
import {
CONTROLNET_MODELS,
CONTROLNET_PROCESSORS,
ControlNetModelName,
} from './constants';
import { controlNetImageProcessed } from './actions';
import { imageDeleted, imageUrlsReceived } from 'services/api/thunks/image';
import { forEach } from 'lodash-es';
import { isAnySessionRejected } from 'services/api/thunks/session';
import { appSocketInvocationError } from 'services/events/actions';
export type ControlModes = export type ControlModes =
| 'balanced' | 'balanced'
@ -26,7 +24,7 @@ export type ControlModes =
export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = { export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
isEnabled: true, isEnabled: true,
model: CONTROLNET_MODELS['lllyasviel/control_v11p_sd15_canny'].type, model: null,
weight: 1, weight: 1,
beginStepPct: 0, beginStepPct: 0,
endStepPct: 1, endStepPct: 1,
@ -42,7 +40,7 @@ export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
export type ControlNetConfig = { export type ControlNetConfig = {
controlNetId: string; controlNetId: string;
isEnabled: boolean; isEnabled: boolean;
model: ControlNetModelName; model: ControlNetModelParam | null;
weight: number; weight: number;
beginStepPct: number; beginStepPct: number;
endStepPct: number; endStepPct: number;
@ -86,6 +84,19 @@ export const controlNetSlice = createSlice({
controlNetId, controlNetId,
}; };
}, },
controlNetDuplicated: (
state,
action: PayloadAction<{
sourceControlNetId: string;
newControlNetId: string;
}>
) => {
const { sourceControlNetId, newControlNetId } = action.payload;
const newControlnet = cloneDeep(state.controlNets[sourceControlNetId]);
newControlnet.controlNetId = newControlNetId;
state.controlNets[newControlNetId] = newControlnet;
},
controlNetAddedFromImage: ( controlNetAddedFromImage: (
state, state,
action: PayloadAction<{ controlNetId: string; controlImage: string }> action: PayloadAction<{ controlNetId: string; controlImage: string }>
@ -147,7 +158,7 @@ export const controlNetSlice = createSlice({
state, state,
action: PayloadAction<{ action: PayloadAction<{
controlNetId: string; controlNetId: string;
model: ControlNetModelName; model: ControlNetModelParam;
}> }>
) => { ) => {
const { controlNetId, model } = action.payload; const { controlNetId, model } = action.payload;
@ -155,7 +166,15 @@ export const controlNetSlice = createSlice({
state.controlNets[controlNetId].processedControlImage = null; state.controlNets[controlNetId].processedControlImage = null;
if (state.controlNets[controlNetId].shouldAutoConfig) { if (state.controlNets[controlNetId].shouldAutoConfig) {
const processorType = CONTROLNET_MODELS[model].defaultProcessor; let processorType: ControlNetProcessorType | undefined = undefined;
for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
if (model.model_name.includes(modelSubstring)) {
processorType = CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring];
break;
}
}
if (processorType) { if (processorType) {
state.controlNets[controlNetId].processorType = processorType; state.controlNets[controlNetId].processorType = processorType;
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[ state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[
@ -241,9 +260,19 @@ export const controlNetSlice = createSlice({
if (newShouldAutoConfig) { if (newShouldAutoConfig) {
// manage the processor for the user // manage the processor for the user
const processorType = let processorType: ControlNetProcessorType | undefined = undefined;
CONTROLNET_MODELS[state.controlNets[controlNetId].model]
.defaultProcessor; for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
if (
state.controlNets[controlNetId].model?.model_name.includes(
modelSubstring
)
) {
processorType = CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring];
break;
}
}
if (processorType) { if (processorType) {
state.controlNets[controlNetId].processorType = processorType; state.controlNets[controlNetId].processorType = processorType;
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[ state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[
@ -272,7 +301,8 @@ export const controlNetSlice = createSlice({
}); });
builder.addCase(imageDeleted.pending, (state, action) => { builder.addCase(imageDeleted.pending, (state, action) => {
// Preemptively remove the image from the gallery // Preemptively remove the image from all controlnets
// TODO: doesn't the imageusage stuff do this for us?
const { image_name } = action.meta.arg; const { image_name } = action.meta.arg;
forEach(state.controlNets, (c) => { forEach(state.controlNets, (c) => {
if (c.controlImage === image_name) { if (c.controlImage === image_name) {
@ -285,21 +315,6 @@ export const controlNetSlice = createSlice({
}); });
}); });
// builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
// const { image_name, image_url, thumbnail_url } = action.payload;
// forEach(state.controlNets, (c) => {
// if (c.controlImage?.image_name === image_name) {
// c.controlImage.image_url = image_url;
// c.controlImage.thumbnail_url = thumbnail_url;
// }
// if (c.processedControlImage?.image_name === image_name) {
// c.processedControlImage.image_url = image_url;
// c.processedControlImage.thumbnail_url = thumbnail_url;
// }
// });
// });
builder.addCase(appSocketInvocationError, (state, action) => { builder.addCase(appSocketInvocationError, (state, action) => {
state.pendingControlImages = []; state.pendingControlImages = [];
}); });
@ -313,6 +328,7 @@ export const controlNetSlice = createSlice({
export const { export const {
isControlNetEnabledToggled, isControlNetEnabledToggled,
controlNetAdded, controlNetAdded,
controlNetDuplicated,
controlNetAddedFromImage, controlNetAddedFromImage,
controlNetRemoved, controlNetRemoved,
controlNetImageChanged, controlNetImageChanged,

View File

@ -7,6 +7,7 @@ import ClipInputFieldComponent from './fields/ClipInputFieldComponent';
import ColorInputFieldComponent from './fields/ColorInputFieldComponent'; import ColorInputFieldComponent from './fields/ColorInputFieldComponent';
import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent'; import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent';
import ControlInputFieldComponent from './fields/ControlInputFieldComponent'; import ControlInputFieldComponent from './fields/ControlInputFieldComponent';
import ControlNetModelInputFieldComponent from './fields/ControlNetModelInputFieldComponent';
import EnumInputFieldComponent from './fields/EnumInputFieldComponent'; import EnumInputFieldComponent from './fields/EnumInputFieldComponent';
import ImageCollectionInputFieldComponent from './fields/ImageCollectionInputFieldComponent'; import ImageCollectionInputFieldComponent from './fields/ImageCollectionInputFieldComponent';
import ImageInputFieldComponent from './fields/ImageInputFieldComponent'; import ImageInputFieldComponent from './fields/ImageInputFieldComponent';
@ -174,6 +175,16 @@ const InputFieldComponent = (props: InputFieldComponentProps) => {
); );
} }
if (type === 'controlnet_model' && template.type === 'controlnet_model') {
return (
<ControlNetModelInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'array' && template.type === 'array') { if (type === 'array' && template.type === 'array') {
return ( return (
<ArrayInputFieldComponent <ArrayInputFieldComponent

View File

@ -0,0 +1,103 @@
import { SelectItem } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import {
ControlNetModelInputFieldTemplate,
ControlNetModelInputFieldValue,
} from 'features/nodes/types/types';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToControlNetModelParam } from 'features/parameters/util/modelIdToControlNetModelParam';
import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
import { FieldComponentProps } from './types';
const ControlNetModelInputFieldComponent = (
props: FieldComponentProps<
ControlNetModelInputFieldValue,
ControlNetModelInputFieldTemplate
>
) => {
const { nodeId, field } = props;
const controlNetModel = field.value;
const dispatch = useAppDispatch();
const { t } = useTranslation();
const { data: controlNetModels } = useGetControlNetModelsQuery();
// grab the full model entity from the RTK Query cache
const selectedModel = useMemo(
() =>
controlNetModels?.entities[
`${controlNetModel?.base_model}/controlnet/${controlNetModel?.model_name}`
] ?? null,
[
controlNetModel?.base_model,
controlNetModel?.model_name,
controlNetModels?.entities,
]
);
const data = useMemo(() => {
if (!controlNetModels) {
return [];
}
const data: SelectItem[] = [];
forEach(controlNetModels.entities, (model, id) => {
if (!model) {
return;
}
data.push({
value: id,
label: model.model_name,
group: MODEL_TYPE_MAP[model.base_model],
});
});
return data;
}, [controlNetModels]);
const handleValueChanged = useCallback(
(v: string | null) => {
if (!v) {
return;
}
const newControlNetModel = modelIdToControlNetModelParam(v);
if (!newControlNetModel) {
return;
}
dispatch(
fieldValueChanged({
nodeId,
fieldName: field.name,
value: newControlNetModel,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<IAIMantineSelect
tooltip={selectedModel?.description}
label={
selectedModel?.base_model && MODEL_TYPE_MAP[selectedModel?.base_model]
}
value={selectedModel?.id ?? null}
placeholder="Pick one"
error={!selectedModel}
data={data}
onChange={handleValueChanged}
/>
);
};
export default memo(ControlNetModelInputFieldComponent);

View File

@ -1,6 +1,7 @@
import { createSlice, PayloadAction } from '@reduxjs/toolkit'; import { createSlice, PayloadAction } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { import {
ControlNetModelParam,
LoRAModelParam, LoRAModelParam,
MainModelParam, MainModelParam,
VaeModelParam, VaeModelParam,
@ -81,7 +82,8 @@ const nodesSlice = createSlice({
| ImageField[] | ImageField[]
| MainModelParam | MainModelParam
| VaeModelParam | VaeModelParam
| LoRAModelParam; | LoRAModelParam
| ControlNetModelParam;
}> }>
) => { ) => {
const { nodeId, fieldName, value } = action.payload; const { nodeId, fieldName, value } = action.payload;

View File

@ -19,6 +19,8 @@ export const FIELD_TYPE_MAP: Record<string, FieldType> = {
model: 'model', model: 'model',
vae_model: 'vae_model', vae_model: 'vae_model',
lora_model: 'lora_model', lora_model: 'lora_model',
controlnet_model: 'controlnet_model',
ControlNetModelField: 'controlnet_model',
array: 'array', array: 'array',
item: 'item', item: 'item',
ColorField: 'color', ColorField: 'color',
@ -130,6 +132,12 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
title: 'LoRA', title: 'LoRA',
description: 'Models are models.', description: 'Models are models.',
}, },
controlnet_model: {
color: 'teal',
colorCssVar: getColorTokenCssVariable('teal'),
title: 'ControlNet',
description: 'Models are models.',
},
array: { array: {
color: 'gray', color: 'gray',
colorCssVar: getColorTokenCssVariable('gray'), colorCssVar: getColorTokenCssVariable('gray'),

View File

@ -1,4 +1,5 @@
import { import {
ControlNetModelParam,
LoRAModelParam, LoRAModelParam,
MainModelParam, MainModelParam,
VaeModelParam, VaeModelParam,
@ -71,6 +72,7 @@ export type FieldType =
| 'model' | 'model'
| 'vae_model' | 'vae_model'
| 'lora_model' | 'lora_model'
| 'controlnet_model'
| 'array' | 'array'
| 'item' | 'item'
| 'color' | 'color'
@ -100,6 +102,7 @@ export type InputFieldValue =
| MainModelInputFieldValue | MainModelInputFieldValue
| VaeModelInputFieldValue | VaeModelInputFieldValue
| LoRAModelInputFieldValue | LoRAModelInputFieldValue
| ControlNetModelInputFieldValue
| ArrayInputFieldValue | ArrayInputFieldValue
| ItemInputFieldValue | ItemInputFieldValue
| ColorInputFieldValue | ColorInputFieldValue
@ -127,6 +130,7 @@ export type InputFieldTemplate =
| ModelInputFieldTemplate | ModelInputFieldTemplate
| VaeModelInputFieldTemplate | VaeModelInputFieldTemplate
| LoRAModelInputFieldTemplate | LoRAModelInputFieldTemplate
| ControlNetModelInputFieldTemplate
| ArrayInputFieldTemplate | ArrayInputFieldTemplate
| ItemInputFieldTemplate | ItemInputFieldTemplate
| ColorInputFieldTemplate | ColorInputFieldTemplate
@ -249,6 +253,11 @@ export type LoRAModelInputFieldValue = FieldValueBase & {
value?: LoRAModelParam; value?: LoRAModelParam;
}; };
export type ControlNetModelInputFieldValue = FieldValueBase & {
type: 'controlnet_model';
value?: ControlNetModelParam;
};
export type ArrayInputFieldValue = FieldValueBase & { export type ArrayInputFieldValue = FieldValueBase & {
type: 'array'; type: 'array';
value?: (string | number)[]; value?: (string | number)[];
@ -368,6 +377,11 @@ export type LoRAModelInputFieldTemplate = InputFieldTemplateBase & {
type: 'lora_model'; type: 'lora_model';
}; };
export type ControlNetModelInputFieldTemplate = InputFieldTemplateBase & {
default: string;
type: 'controlnet_model';
};
export type ArrayInputFieldTemplate = InputFieldTemplateBase & { export type ArrayInputFieldTemplate = InputFieldTemplateBase & {
default: []; default: [];
type: 'array'; type: 'array';

View File

@ -9,6 +9,7 @@ import {
ColorInputFieldTemplate, ColorInputFieldTemplate,
ConditioningInputFieldTemplate, ConditioningInputFieldTemplate,
ControlInputFieldTemplate, ControlInputFieldTemplate,
ControlNetModelInputFieldTemplate,
EnumInputFieldTemplate, EnumInputFieldTemplate,
FieldType, FieldType,
FloatInputFieldTemplate, FloatInputFieldTemplate,
@ -207,6 +208,21 @@ const buildLoRAModelInputFieldTemplate = ({
return template; return template;
}; };
const buildControlNetModelInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): ControlNetModelInputFieldTemplate => {
const template: ControlNetModelInputFieldTemplate = {
...baseField,
type: 'controlnet_model',
inputRequirement: 'always',
inputKind: 'direct',
default: schemaObject.default ?? undefined,
};
return template;
};
const buildImageInputFieldTemplate = ({ const buildImageInputFieldTemplate = ({
schemaObject, schemaObject,
baseField, baseField,
@ -479,6 +495,9 @@ export const buildInputFieldTemplate = (
if (['lora_model'].includes(fieldType)) { if (['lora_model'].includes(fieldType)) {
return buildLoRAModelInputFieldTemplate({ schemaObject, baseField }); return buildLoRAModelInputFieldTemplate({ schemaObject, baseField });
} }
if (['controlnet_model'].includes(fieldType)) {
return buildControlNetModelInputFieldTemplate({ schemaObject, baseField });
}
if (['enum'].includes(fieldType)) { if (['enum'].includes(fieldType)) {
return buildEnumInputFieldTemplate({ schemaObject, baseField }); return buildEnumInputFieldTemplate({ schemaObject, baseField });
} }

View File

@ -83,6 +83,10 @@ export const buildInputFieldValue = (
if (template.type === 'lora_model') { if (template.type === 'lora_model') {
fieldValue.value = undefined; fieldValue.value = undefined;
} }
if (template.type === 'controlnet_model') {
fieldValue.value = undefined;
}
} }
return fieldValue; return fieldValue;

View File

@ -2,12 +2,13 @@ import { Divider, Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIButton from 'common/components/IAIButton';
import IAICollapse from 'common/components/IAICollapse'; import IAICollapse from 'common/components/IAICollapse';
import IAIIconButton from 'common/components/IAIIconButton';
import ControlNet from 'features/controlNet/components/ControlNet'; import ControlNet from 'features/controlNet/components/ControlNet';
import ParamControlNetFeatureToggle from 'features/controlNet/components/parameters/ParamControlNetFeatureToggle'; import ParamControlNetFeatureToggle from 'features/controlNet/components/parameters/ParamControlNetFeatureToggle';
import { import {
controlNetAdded, controlNetAdded,
controlNetModelChanged,
controlNetSelector, controlNetSelector,
} from 'features/controlNet/store/controlNetSlice'; } from 'features/controlNet/store/controlNetSlice';
import { getValidControlNets } from 'features/controlNet/util/getValidControlNets'; import { getValidControlNets } from 'features/controlNet/util/getValidControlNets';
@ -15,6 +16,8 @@ import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { map } from 'lodash-es'; import { map } from 'lodash-es';
import { Fragment, memo, useCallback } from 'react'; import { Fragment, memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { FaPlus } from 'react-icons/fa';
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
import { v4 as uuidv4 } from 'uuid'; import { v4 as uuidv4 } from 'uuid';
const selector = createSelector( const selector = createSelector(
@ -39,10 +42,23 @@ const ParamControlNetCollapse = () => {
const { controlNetsArray, activeLabel } = useAppSelector(selector); const { controlNetsArray, activeLabel } = useAppSelector(selector);
const isControlNetDisabled = useFeatureStatus('controlNet').isFeatureDisabled; const isControlNetDisabled = useFeatureStatus('controlNet').isFeatureDisabled;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { firstModel } = useGetControlNetModelsQuery(undefined, {
selectFromResult: (result) => {
const firstModel = result.data?.entities[result.data?.ids[0]];
return {
firstModel,
};
},
});
const handleClickedAddControlNet = useCallback(() => { const handleClickedAddControlNet = useCallback(() => {
dispatch(controlNetAdded({ controlNetId: uuidv4() })); if (!firstModel) {
}, [dispatch]); return;
}
const controlNetId = uuidv4();
dispatch(controlNetAdded({ controlNetId }));
dispatch(controlNetModelChanged({ controlNetId, model: firstModel }));
}, [dispatch, firstModel]);
if (isControlNetDisabled) { if (isControlNetDisabled) {
return null; return null;
@ -51,16 +67,39 @@ const ParamControlNetCollapse = () => {
return ( return (
<IAICollapse label="ControlNet" activeLabel={activeLabel}> <IAICollapse label="ControlNet" activeLabel={activeLabel}>
<Flex sx={{ flexDir: 'column', gap: 3 }}> <Flex sx={{ flexDir: 'column', gap: 3 }}>
<Flex gap={2} alignItems="center">
<Flex
sx={{
flexDirection: 'column',
w: '100%',
gap: 2,
px: 4,
py: 2,
borderRadius: 4,
bg: 'base.200',
_dark: {
bg: 'base.850',
},
}}
>
<ParamControlNetFeatureToggle /> <ParamControlNetFeatureToggle />
</Flex>
<IAIIconButton
tooltip="Add ControlNet"
aria-label="Add ControlNet"
icon={<FaPlus />}
isDisabled={!firstModel}
flexGrow={1}
size="md"
onClick={handleClickedAddControlNet}
/>
</Flex>
{controlNetsArray.map((c, i) => ( {controlNetsArray.map((c, i) => (
<Fragment key={c.controlNetId}> <Fragment key={c.controlNetId}>
{i > 0 && <Divider />} {i > 0 && <Divider />}
<ControlNet controlNet={c} /> <ControlNet controlNetId={c.controlNetId} />
</Fragment> </Fragment>
))} ))}
<IAIButton flexGrow={1} onClick={handleClickedAddControlNet}>
Add ControlNet
</IAIButton>
</Flex> </Flex>
</IAICollapse> </IAICollapse>
); );

View File

@ -37,6 +37,7 @@ const ParamVAEModelSelect = () => {
return []; return [];
} }
// add a "default" option, this means use the main model's included VAE
const data: SelectItem[] = [ const data: SelectItem[] = [
{ {
value: 'default', value: 'default',

View File

@ -17,8 +17,10 @@ import { FaPlay } from 'react-icons/fa';
const IN_PROGRESS_STYLES: ChakraProps['sx'] = { const IN_PROGRESS_STYLES: ChakraProps['sx'] = {
_disabled: { _disabled: {
bg: 'none', bg: 'none',
color: 'base.600',
cursor: 'not-allowed', cursor: 'not-allowed',
_hover: { _hover: {
color: 'base.600',
bg: 'none', bg: 'none',
}, },
}, },

View File

@ -180,6 +180,23 @@ export type LoRAModelParam = z.infer<typeof zLoRAModel>;
*/ */
export const isValidLoRAModel = (val: unknown): val is LoRAModelParam => export const isValidLoRAModel = (val: unknown): val is LoRAModelParam =>
zLoRAModel.safeParse(val).success; zLoRAModel.safeParse(val).success;
/**
* Zod schema for ControlNet models
*/
export const zControlNetModel = z.object({
model_name: z.string().min(1),
base_model: zBaseModel,
});
/**
* Type alias for model parameter, inferred from its zod schema
*/
export type ControlNetModelParam = z.infer<typeof zLoRAModel>;
/**
* Validates/type-guards a value as a model parameter
*/
export const isValidControlNetModel = (
val: unknown
): val is ControlNetModelParam => zControlNetModel.safeParse(val).success;
/** /**
* Zod schema for l2l strength parameter * Zod schema for l2l strength parameter

View File

@ -0,0 +1,30 @@
import { log } from 'app/logging/useLogger';
import { zControlNetModel } from 'features/parameters/types/parameterSchemas';
import { ControlNetModelField } from 'services/api/types';
const moduleLog = log.child({ module: 'models' });
export const modelIdToControlNetModelParam = (
controlNetModelId: string
): ControlNetModelField | undefined => {
const [base_model, model_type, model_name] = controlNetModelId.split('/');
const result = zControlNetModel.safeParse({
base_model,
model_name,
});
if (!result.success) {
moduleLog.error(
{
controlNetModelId,
errors: result.error.format(),
},
'Failed to parse ControlNet model id'
);
return;
}
return result.data;
};

View File

@ -1,9 +1,12 @@
import { LoRAModelParam, zLoRAModel } from '../types/parameterSchemas'; import { LoRAModelParam, zLoRAModel } from '../types/parameterSchemas';
import { log } from 'app/logging/useLogger';
const moduleLog = log.child({ module: 'models' });
export const modelIdToLoRAModelParam = ( export const modelIdToLoRAModelParam = (
loraId: string loraModelId: string
): LoRAModelParam | undefined => { ): LoRAModelParam | undefined => {
const [base_model, model_type, model_name] = loraId.split('/'); const [base_model, model_type, model_name] = loraModelId.split('/');
const result = zLoRAModel.safeParse({ const result = zLoRAModel.safeParse({
base_model, base_model,
@ -11,6 +14,13 @@ export const modelIdToLoRAModelParam = (
}); });
if (!result.success) { if (!result.success) {
moduleLog.error(
{
loraModelId,
errors: result.error.format(),
},
'Failed to parse LoRA model id'
);
return; return;
} }

View File

@ -2,11 +2,14 @@ import {
MainModelParam, MainModelParam,
zMainModel, zMainModel,
} from 'features/parameters/types/parameterSchemas'; } from 'features/parameters/types/parameterSchemas';
import { log } from 'app/logging/useLogger';
const moduleLog = log.child({ module: 'models' });
export const modelIdToMainModelParam = ( export const modelIdToMainModelParam = (
modelId: string mainModelId: string
): MainModelParam | undefined => { ): MainModelParam | undefined => {
const [base_model, model_type, model_name] = modelId.split('/'); const [base_model, model_type, model_name] = mainModelId.split('/');
const result = zMainModel.safeParse({ const result = zMainModel.safeParse({
base_model, base_model,
@ -14,6 +17,13 @@ export const modelIdToMainModelParam = (
}); });
if (!result.success) { if (!result.success) {
moduleLog.error(
{
mainModelId,
errors: result.error.format(),
},
'Failed to parse main model id'
);
return; return;
} }

View File

@ -1,9 +1,12 @@
import { VaeModelParam, zVaeModel } from '../types/parameterSchemas'; import { VaeModelParam, zVaeModel } from '../types/parameterSchemas';
import { log } from 'app/logging/useLogger';
const moduleLog = log.child({ module: 'models' });
export const modelIdToVAEModelParam = ( export const modelIdToVAEModelParam = (
modelId: string vaeModelId: string
): VaeModelParam | undefined => { ): VaeModelParam | undefined => {
const [base_model, model_type, model_name] = modelId.split('/'); const [base_model, model_type, model_name] = vaeModelId.split('/');
const result = zVaeModel.safeParse({ const result = zVaeModel.safeParse({
base_model, base_model,
@ -11,6 +14,13 @@ export const modelIdToVAEModelParam = (
}); });
if (!result.success) { if (!result.success) {
moduleLog.error(
{
vaeModelId,
errors: result.error.format(),
},
'Failed to parse VAE model id'
);
return; return;
} }

View File

@ -19,9 +19,9 @@ const ImageToImageTabParameters = () => {
<ParamNegativeConditioning /> <ParamNegativeConditioning />
<ProcessButtons /> <ProcessButtons />
<ImageToImageTabCoreParameters /> <ImageToImageTabCoreParameters />
<ParamControlNetCollapse />
<ParamLoraCollapse /> <ParamLoraCollapse />
<ParamDynamicPromptsCollapse /> <ParamDynamicPromptsCollapse />
<ParamControlNetCollapse />
<ParamVariationCollapse /> <ParamVariationCollapse />
<ParamNoiseCollapse /> <ParamNoiseCollapse />
<ParamSymmetryCollapse /> <ParamSymmetryCollapse />

View File

@ -20,9 +20,9 @@ const TextToImageTabParameters = () => {
<ParamNegativeConditioning /> <ParamNegativeConditioning />
<ProcessButtons /> <ProcessButtons />
<TextToImageTabCoreParameters /> <TextToImageTabCoreParameters />
<ParamControlNetCollapse />
<ParamLoraCollapse /> <ParamLoraCollapse />
<ParamDynamicPromptsCollapse /> <ParamDynamicPromptsCollapse />
<ParamControlNetCollapse />
<ParamVariationCollapse /> <ParamVariationCollapse />
<ParamNoiseCollapse /> <ParamNoiseCollapse />
<ParamSymmetryCollapse /> <ParamSymmetryCollapse />

View File

@ -19,9 +19,9 @@ const UnifiedCanvasParameters = () => {
<ParamNegativeConditioning /> <ParamNegativeConditioning />
<ProcessButtons /> <ProcessButtons />
<UnifiedCanvasCoreParameters /> <UnifiedCanvasCoreParameters />
<ParamControlNetCollapse />
<ParamLoraCollapse /> <ParamLoraCollapse />
<ParamDynamicPromptsCollapse /> <ParamDynamicPromptsCollapse />
<ParamControlNetCollapse />
<ParamVariationCollapse /> <ParamVariationCollapse />
<ParamSymmetryCollapse /> <ParamSymmetryCollapse />
<ParamSeamCorrectionCollapse /> <ParamSeamCorrectionCollapse />

View File

@ -734,7 +734,7 @@ export type components = {
* Control Model * Control Model
* @description The ControlNet model to use * @description The ControlNet model to use
*/ */
control_model: string; control_model: components["schemas"]["ControlNetModelField"];
/** /**
* Control Weight * Control Weight
* @description The weight given to the ControlNet * @description The weight given to the ControlNet
@ -792,9 +792,8 @@ export type components = {
* Control Model * Control Model
* @description control model used * @description control model used
* @default lllyasviel/sd-controlnet-canny * @default lllyasviel/sd-controlnet-canny
* @enum {string}
*/ */
control_model?: "lllyasviel/sd-controlnet-canny" | "lllyasviel/sd-controlnet-depth" | "lllyasviel/sd-controlnet-hed" | "lllyasviel/sd-controlnet-seg" | "lllyasviel/sd-controlnet-openpose" | "lllyasviel/sd-controlnet-scribble" | "lllyasviel/sd-controlnet-normal" | "lllyasviel/sd-controlnet-mlsd" | "lllyasviel/control_v11p_sd15_canny" | "lllyasviel/control_v11p_sd15_openpose" | "lllyasviel/control_v11p_sd15_seg" | "lllyasviel/control_v11f1p_sd15_depth" | "lllyasviel/control_v11p_sd15_normalbae" | "lllyasviel/control_v11p_sd15_scribble" | "lllyasviel/control_v11p_sd15_mlsd" | "lllyasviel/control_v11p_sd15_softedge" | "lllyasviel/control_v11p_sd15s2_lineart_anime" | "lllyasviel/control_v11p_sd15_lineart" | "lllyasviel/control_v11p_sd15_inpaint" | "lllyasviel/control_v11e_sd15_shuffle" | "lllyasviel/control_v11e_sd15_ip2p" | "lllyasviel/control_v11f1e_sd15_tile" | "thibaud/controlnet-sd21-openpose-diffusers" | "thibaud/controlnet-sd21-canny-diffusers" | "thibaud/controlnet-sd21-depth-diffusers" | "thibaud/controlnet-sd21-scribble-diffusers" | "thibaud/controlnet-sd21-hed-diffusers" | "thibaud/controlnet-sd21-zoedepth-diffusers" | "thibaud/controlnet-sd21-color-diffusers" | "thibaud/controlnet-sd21-openposev2-diffusers" | "thibaud/controlnet-sd21-lineart-diffusers" | "thibaud/controlnet-sd21-normalbae-diffusers" | "thibaud/controlnet-sd21-ade20k-diffusers" | "CrucibleAI/ControlNetMediaPipeFace,diffusion_sd15" | "CrucibleAI/ControlNetMediaPipeFace"; control_model?: components["schemas"]["ControlNetModelField"];
/** /**
* Control Weight * Control Weight
* @description The weight given to the ControlNet * @description The weight given to the ControlNet
@ -838,6 +837,19 @@ export type components = {
model_format: components["schemas"]["ControlNetModelFormat"]; model_format: components["schemas"]["ControlNetModelFormat"];
error?: components["schemas"]["ModelError"]; error?: components["schemas"]["ModelError"];
}; };
/**
* ControlNetModelField
* @description ControlNet model field
*/
ControlNetModelField: {
/**
* Model Name
* @description Name of the ControlNet model
*/
model_name: string;
/** @description Base model */
base_model: components["schemas"]["BaseModelType"];
};
/** /**
* ControlNetModelFormat * ControlNetModelFormat
* @description An enumeration. * @description An enumeration.
@ -1923,12 +1935,12 @@ export type components = {
* Width * Width
* @description The width to resize to (px) * @description The width to resize to (px)
*/ */
width: number; width?: number;
/** /**
* Height * Height
* @description The height to resize to (px) * @description The height to resize to (px)
*/ */
height: number; height?: number;
/** /**
* Resample Mode * Resample Mode
* @description The resampling mode * @description The resampling mode
@ -3911,13 +3923,15 @@ export type components = {
/** /**
* Width * Width
* @description The width to resize to (px) * @description The width to resize to (px)
* @default 512
*/ */
width: number; width?: number;
/** /**
* Height * Height
* @description The height to resize to (px) * @description The height to resize to (px)
* @default 512
*/ */
height: number; height?: number;
/** /**
* Mode * Mode
* @description The interpolation mode * @description The interpolation mode
@ -4605,18 +4619,18 @@ export type components = {
*/ */
image?: components["schemas"]["ImageField"]; image?: components["schemas"]["ImageField"];
}; };
/**
* StableDiffusion2ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
/** /**
* StableDiffusion1ModelFormat * StableDiffusion1ModelFormat
* @description An enumeration. * @description An enumeration.
* @enum {string} * @enum {string}
*/ */
StableDiffusion1ModelFormat: "checkpoint" | "diffusers"; StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusion2ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
}; };
responses: never; responses: never;
parameters: never; parameters: never;

View File

@ -32,6 +32,8 @@ export type BaseModelType = components['schemas']['BaseModelType'];
export type MainModelField = components['schemas']['MainModelField']; export type MainModelField = components['schemas']['MainModelField'];
export type VAEModelField = components['schemas']['VAEModelField']; export type VAEModelField = components['schemas']['VAEModelField'];
export type LoRAModelField = components['schemas']['LoRAModelField']; export type LoRAModelField = components['schemas']['LoRAModelField'];
export type ControlNetModelField =
components['schemas']['ControlNetModelField'];
export type ModelsList = components['schemas']['ModelsList']; export type ModelsList = components['schemas']['ModelsList'];
export type ControlField = components['schemas']['ControlField']; export type ControlField = components['schemas']['ControlField'];

View File

@ -30,7 +30,7 @@ const invokeAIThumb = defineStyle((props) => {
const invokeAIMark = defineStyle((props) => { const invokeAIMark = defineStyle((props) => {
return { return {
fontSize: 'xs', fontSize: '2xs',
fontWeight: '500', fontWeight: '500',
color: mode('base.700', 'base.400')(props), color: mode('base.700', 'base.400')(props),
mt: 2, mt: 2,