Merge branch 'main' into update-textual-inversion-training

This commit is contained in:
Lincoln Stein 2023-07-15 17:44:45 -04:00 committed by GitHub
commit f66ead0819
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
88 changed files with 2840 additions and 1075 deletions

4
.github/CODEOWNERS vendored
View File

@ -6,7 +6,7 @@
/mkdocs.yml @lstein @blessedcoolant
# nodes
/invokeai/app/ @Kyle0654 @blessedcoolant
/invokeai/app/ @Kyle0654 @blessedcoolant @psychedelicious @brandonrising
# installation and configuration
/pyproject.toml @lstein @blessedcoolant
@ -22,7 +22,7 @@
/invokeai/backend @blessedcoolant @psychedelicious @lstein @maryhipp
# generation, model management, postprocessing
/invokeai/backend @damian0815 @lstein @blessedcoolant @jpphoto @gregghelt2 @StAlKeR7779
/invokeai/backend @damian0815 @lstein @blessedcoolant @gregghelt2 @StAlKeR7779 @brandonrising
# front ends
/invokeai/frontend/CLI @lstein

View File

@ -0,0 +1,287 @@
---
title: Configuration
---
# :material-tune-variant: InvokeAI Configuration
## Intro
InvokeAI has numerous runtime settings which can be used to adjust
many aspects of its operations, including the location of files and
directories, memory usage, and performance. These settings can be
viewed and customized in several ways:
1. By editing settings in the `invokeai.yaml` file.
2. By setting environment variables.
3. On the command-line, when InvokeAI is launched.
In addition, the most commonly changed settings are accessible
graphically via the `invokeai-configure` script.
### How the Configuration System Works
When InvokeAI is launched, the very first thing it needs to do is to
find its "root" directory, which contains its configuration files,
installed models, its database of images, and the folder(s) of
generated images themselves. In this document, the root directory will
be referred to as ROOT.
#### Finding the Root Directory
To find its root directory, InvokeAI uses the following recipe:
1. It first looks for the argument `--root <path>` on the command line
it was launched from, and uses the indicated path if present.
2. Next it looks for the environment variable INVOKEAI_ROOT, and uses
the directory path found there if present.
3. If neither of these are present, then InvokeAI looks for the
folder containing the `.venv` Python virtual environment directory for
the currently active environment. This directory is checked for files
expected inside the InvokeAI root before it is used.
4. Finally, InvokeAI looks for a directory in the current user's home
directory named `invokeai`.
#### Reading the InvokeAI Configuration File
Once the root directory has been located, InvokeAI looks for a file
named `ROOT/invokeai.yaml`, and if present reads configuration values
from it. The top of this file looks like this:
```
InvokeAI:
Web Server:
host: localhost
port: 9090
allow_origins: []
allow_credentials: true
allow_methods:
- '*'
allow_headers:
- '*'
Features:
esrgan: true
internet_available: true
log_tokenization: false
nsfw_checker: false
patchmatch: true
restore: true
...
```
This lines in this file are used to establish default values for
Invoke's settings. In the above fragment, the Web Server's listening
port is set to 9090 by the `port` setting.
You can edit this file with a text editor such as "Notepad" (do not
use Word or any other word processor). When editing, be careful to
maintain the indentation, and do not add extraneous text, as syntax
errors will prevent InvokeAI from launching. A basic guide to the
format of YAML files can be found
[here](https://circleci.com/blog/what-is-yaml-a-beginner-s-guide/).
You can fix a broken `invokeai.yaml` by deleting it and running the
configuration script again -- option [7] in the launcher, "Re-run the
configure script".
#### Reading Environment Variables
Next InvokeAI looks for defined environment variables in the format
`INVOKEAI_<setting_name>`, for example `INVOKEAI_port`. Environment
variable values take precedence over configuration file variables. On
a Macintosh system, for example, you could change the port that the
web server listens on by setting the environment variable this way:
```
export INVOKEAI_port=8000
invokeai-web
```
Please check out these
[Macintosh](https://phoenixnap.com/kb/set-environment-variable-mac)
and
[Windows](https://phoenixnap.com/kb/windows-set-environment-variable)
guides for setting temporary and permanent environment variables.
#### Reading the Command Line
Lastly, InvokeAI takes settings from the command line, which override
everything else. The command-line settings have the same name as the
corresponding configuration file settings, preceded by a `--`, for
example `--port 8000`.
If you are using the launcher (`invoke.sh` or `invoke.bat`) to launch
InvokeAI, then just pass the command-line arguments to the launcher:
```
invoke.bat --port 8000 --host 0.0.0.0
```
The arguments will be applied when you select the web server option
(and the other options as well).
If, on the other hand, you prefer to launch InvokeAI directly from the
command line, you would first activate the virtual environment (known
as the "developer's console" in the launcher), and run `invokeai-web`:
```
> C:\Users\Fred\invokeai\.venv\scripts\activate
(.venv) > invokeai-web --port 8000 --host 0.0.0.0
```
You can get a listing and brief instructions for each of the
command-line options by giving the `--help` argument:
```
(.venv) > invokeai-web --help
usage: InvokeAI [-h] [--host HOST] [--port PORT] [--allow_origins [ALLOW_ORIGINS ...]] [--allow_credentials | --no-allow_credentials]
[--allow_methods [ALLOW_METHODS ...]] [--allow_headers [ALLOW_HEADERS ...]] [--esrgan | --no-esrgan]
[--internet_available | --no-internet_available] [--log_tokenization | --no-log_tokenization]
[--nsfw_checker | --no-nsfw_checker] [--patchmatch | --no-patchmatch] [--restore | --no-restore]
[--always_use_cpu | --no-always_use_cpu] [--free_gpu_mem | --no-free_gpu_mem] [--max_cache_size MAX_CACHE_SIZE]
[--max_vram_cache_size MAX_VRAM_CACHE_SIZE] [--precision {auto,float16,float32,autocast}]
[--sequential_guidance | --no-sequential_guidance] [--xformers_enabled | --no-xformers_enabled]
[--tiled_decode | --no-tiled_decode] [--root ROOT] [--autoimport_dir AUTOIMPORT_DIR] [--lora_dir LORA_DIR]
[--embedding_dir EMBEDDING_DIR] [--controlnet_dir CONTROLNET_DIR] [--conf_path CONF_PATH] [--models_dir MODELS_DIR]
[--legacy_conf_dir LEGACY_CONF_DIR] [--db_dir DB_DIR] [--outdir OUTDIR] [--from_file FROM_FILE]
[--use_memory_db | --no-use_memory_db] [--model MODEL] [--log_handlers [LOG_HANDLERS ...]]
[--log_format {plain,color,syslog,legacy}] [--log_level {debug,info,warning,error,critical}]
...
```
## The Configuration Settings
The configuration settings are divided into several distinct
groups in `invokeia.yaml`:
### Web Server
| Setting | Default Value | Description |
|----------|----------------|--------------|
| `host` | `localhost` | Name or IP address of the network interface that the web server will listen on |
| `port` | `9090` | Network port number that the web server will listen on |
| `allow_origins` | `[]` | A list of host names or IP addresses that are allowed to connect to the InvokeAI API in the format `['host1','host2',...]` |
| `allow_credentials | `true` | Require credentials for a foreign host to access the InvokeAI API (don't change this) |
| `allow_methods` | `*` | List of HTTP methods ("GET", "POST") that the web server is allowed to use when accessing the API |
| `allow_headers` | `*` | List of HTTP headers that the web server will accept when accessing the API |
The documentation for InvokeAI's API can be accessed by browsing to the following URL: [http://localhost:9090/docs].
### Features
These configuration settings allow you to enable and disable various InvokeAI features:
| Setting | Default Value | Description |
|----------|----------------|--------------|
| `esrgan` | `true` | Activate the ESRGAN upscaling options|
| `internet_available` | `true` | When a resource is not available locally, try to fetch it via the internet |
| `log_tokenization` | `false` | Before each text2image generation, print a color-coded representation of the prompt to the console; this can help understand why a prompt is not working as expected |
| `nsfw_checker` | `true` | Activate the NSFW checker to blur out risque images |
| `patchmatch` | `true` | Activate the "patchmatch" algorithm for improved inpainting |
| `restore` | `true` | Activate the facial restoration features (DEPRECATED; restoration features will be removed in 3.0.0) |
### Memory/Performance
These options tune InvokeAI's memory and performance characteristics.
| Setting | Default Value | Description |
|----------|----------------|--------------|
| `always_use_cpu` | `false` | Use the CPU to generate images, even if a GPU is available |
| `free_gpu_mem` | `false` | Aggressively free up GPU memory after each operation; this will allow you to run in low-VRAM environments with some performance penalties |
| `max_cache_size` | `6` | Amount of CPU RAM (in GB) to reserve for caching models in memory; more cache allows you to keep models in memory and switch among them quickly |
| `max_vram_cache_size` | `2.75` | Amount of GPU VRAM (in GB) to reserve for caching models in VRAM; more cache speeds up generation but reduces the size of the images that can be generated. This can be set to zero to maximize the amount of memory available for generation. |
| `precision` | `auto` | Floating point precision. One of `auto`, `float16` or `float32`. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system |
| `sequential_guidance` | `false` | Calculate guidance in serial rather than in parallel, lowering memory requirements at the cost of some performance loss |
| `xformers_enabled` | `true` | If the x-formers memory-efficient attention module is installed, activate it for better memory usage and generation speed|
| `tiled_decode` | `false` | If true, then during the VAE decoding phase the image will be decoded a section at a time, reducing memory consumption at the cost of a performance hit |
### Paths
These options set the paths of various directories and files used by
InvokeAI. Relative paths are interpreted relative to INVOKEAI_ROOT, so
if INVOKEAI_ROOT is `/home/fred/invokeai` and the path is
`autoimport/main`, then the corresponding directory will be located at
`/home/fred/invokeai/autoimport/main`.
| Setting | Default Value | Description |
|----------|----------------|--------------|
| `autoimport_dir` | `autoimport/main` | At startup time, read and import any main model files found in this directory |
| `lora_dir` | `autoimport/lora` | At startup time, read and import any LoRA/LyCORIS models found in this directory |
| `embedding_dir` | `autoimport/embedding` | At startup time, read and import any textual inversion (embedding) models found in this directory |
| `controlnet_dir` | `autoimport/controlnet` | At startup time, read and import any ControlNet models found in this directory |
| `conf_path` | `configs/models.yaml` | Location of the `models.yaml` model configuration file |
| `models_dir` | `models` | Location of the directory containing models installed by InvokeAI's model manager |
| `legacy_conf_dir` | `configs/stable-diffusion` | Location of the directory containing the .yaml configuration files for legacy checkpoint models |
| `db_dir` | `databases` | Location of the directory containing InvokeAI's image, schema and session database |
| `outdir` | `outputs` | Location of the directory in which the gallery of generated and uploaded images will be stored |
| `use_memory_db` | `false` | Keep database information in memory rather than on disk; this will not preserve image gallery information across restarts |
Note that the autoimport directories will be searched recursively,
allowing you to organize the models into folders and subfolders in any
way you wish. In addition, while we have split up autoimport
directories by the type of model they contain, this isn't
necessary. You can combine different model types in the same folder
and InvokeAI will figure out what they are. So you can easily use just
one autoimport directory by commenting out the unneeded paths:
```
Paths:
autoimport_dir: autoimport
# lora_dir: null
# embedding_dir: null
# controlnet_dir: null
```
### Logging
These settings control the information, warning, and debugging
messages printed to the console log while InvokeAI is running:
| Setting | Default Value | Description |
|----------|----------------|--------------|
| `log_handlers` | `console` | This controls where log messages are sent, and can be a list of one or more destinations. Values include `console`, `file`, `syslog` and `http`. These are described in more detail below |
| `log_format` | `color` | This controls the formatting of the log messages. Values are `plain`, `color`, `legacy` and `syslog` |
| `log_level` | `debug` | This filters messages according to the level of severity and can be one of `debug`, `info`, `warning`, `error` and `critical`. For example, setting to `warning` will display all messages at the warning level or higher, but won't display "debug" or "info" messages |
Several different log handler destinations are available, and multiple destinations are supported by providing a list:
```
log_handlers:
- console
- syslog=localhost
- file=/var/log/invokeai.log
```
* `console` is the default. It prints log messages to the command-line window from which InvokeAI was launched.
* `syslog` is only available on Linux and Macintosh systems. It uses
the operating system's "syslog" facility to write log file entries
locally or to a remote logging machine. `syslog` offers a variety
of configuration options:
```
syslog=/dev/log` - log to the /dev/log device
syslog=localhost` - log to the network logger running on the local machine
syslog=localhost:512` - same as above, but using a non-standard port
syslog=fredserver,facility=LOG_USER,socktype=SOCK_DRAM`
- Log to LAN-connected server "fredserver" using the facility LOG_USER and datagram packets.
```
* `http` can be used to log to a remote web server. The server must be
properly configured to receive and act on log messages. The option
accepts the URL to the web server, and a `method` argument
indicating whether the message should be submitted using the GET or
POST method.
```
http=http://my.server/path/to/logger,method=POST
```
The `log_format` option provides several alternative formats:
* `color` - default format providing time, date and a message, using text colors to distinguish different log severities
* `plain` - same as above, but monochrome text only
* `syslog` - the log level and error message only, allowing the syslog system to attach the time and date
* `legacy` - a format similar to the one used by the legacy 2.3 InvokeAI releases.

View File

@ -153,6 +153,9 @@ This method is recommended for those familiar with running Docker containers
- [Prompt Syntax](features/PROMPTS.md)
- [Generating Variations](features/VARIATIONS.md)
### InvokeAI Configuration
- [Guide to InvokeAI Runtime Settings](features/CONFIGURATION.md)
## :octicons-log-16: Important Changes Since Version 2.3
### Nodes

View File

@ -1,6 +1,7 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2024 Lincoln Stein
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2023 Lincoln D. Stein
import pathlib
from typing import Literal, List, Optional, Union
from fastapi import Body, Path, Query, Response
@ -22,6 +23,7 @@ UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]
class ModelsList(BaseModel):
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
@ -78,7 +80,7 @@ async def update_model(
return model_response
@models_router.post(
"/",
"/import",
operation_id="import_model",
responses= {
201: {"description" : "The model imported successfully"},
@ -94,7 +96,7 @@ async def import_model(
prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = \
Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
) -> ImportModelResponse:
""" Add a model using its local path, repo_id, or remote URL """
""" Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically """
items_to_import = {location}
prediction_types = { x.value: x for x in SchedulerPredictionType }
@ -126,18 +128,100 @@ async def import_model(
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
@models_router.post(
"/add",
operation_id="add_model",
responses= {
201: {"description" : "The model added successfully"},
404: {"description" : "The model could not be found"},
424: {"description" : "The model appeared to add successfully, but could not be found in the model manager"},
409: {"description" : "There is already a model corresponding to this path or repo_id"},
},
status_code=201,
response_model=ImportModelResponse
)
async def add_model(
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
) -> ImportModelResponse:
""" Add a model using the configuration information appropriate for its type. Only local models can be added by path"""
logger = ApiDependencies.invoker.services.logger
try:
ApiDependencies.invoker.services.model_manager.add_model(
info.model_name,
info.base_model,
info.model_type,
model_attributes = info.dict()
)
logger.info(f'Successfully added {info.model_name}')
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
model_name=info.model_name,
base_model=info.base_model,
model_type=info.model_type
)
return parse_obj_as(ImportModelResponse, model_raw)
except KeyError as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
@models_router.post(
"/rename/{base_model}/{model_type}/{model_name}",
operation_id="rename_model",
responses= {
201: {"description" : "The model was renamed successfully"},
404: {"description" : "The model could not be found"},
409: {"description" : "There is already a model corresponding to the new name"},
},
status_code=201,
response_model=ImportModelResponse
)
async def rename_model(
base_model: BaseModelType = Path(description="Base model"),
model_type: ModelType = Path(description="The type of model"),
model_name: str = Path(description="current model name"),
new_name: Optional[str] = Query(description="new model name", default=None),
new_base: Optional[BaseModelType] = Query(description="new model base", default=None),
) -> ImportModelResponse:
""" Rename a model"""
logger = ApiDependencies.invoker.services.logger
try:
result = ApiDependencies.invoker.services.model_manager.rename_model(
base_model = base_model,
model_type = model_type,
model_name = model_name,
new_name = new_name,
new_base = new_base,
)
logger.debug(result)
logger.info(f'Successfully renamed {model_name}=>{new_name}')
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
model_name=new_name or model_name,
base_model=new_base or base_model,
model_type=model_type
)
return parse_obj_as(ImportModelResponse, model_raw)
except KeyError as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
@models_router.delete(
"/{base_model}/{model_type}/{model_name}",
operation_id="del_model",
responses={
204: {
"description": "Model deleted successfully"
},
404: {
"description": "Model not found"
}
204: { "description": "Model deleted successfully" },
404: { "description": "Model not found" }
},
status_code = 204,
response_model = None,
)
async def delete_model(
base_model: BaseModelType = Path(description="Base model"),
@ -173,14 +257,17 @@ async def convert_model(
base_model: BaseModelType = Path(description="Base model"),
model_type: ModelType = Path(description="The type of model"),
model_name: str = Path(description="model name"),
convert_dest_directory: Optional[str] = Query(default=None, description="Save the converted model to the designated directory"),
) -> ConvertModelResponse:
"""Convert a checkpoint model into a diffusers model"""
"""Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none."""
logger = ApiDependencies.invoker.services.logger
try:
logger.info(f"Converting model: {model_name}")
dest = pathlib.Path(convert_dest_directory) if convert_dest_directory else None
ApiDependencies.invoker.services.model_manager.convert_model(model_name,
base_model = base_model,
model_type = model_type
model_type = model_type,
convert_dest_directory = dest,
)
model_raw = ApiDependencies.invoker.services.model_manager.list_model(model_name,
base_model = base_model,
@ -192,6 +279,53 @@ async def convert_model(
raise HTTPException(status_code=400, detail=str(e))
return response
@models_router.get(
"/search",
operation_id="search_for_models",
responses={
200: { "description": "Directory searched successfully" },
404: { "description": "Invalid directory path" },
},
status_code = 200,
response_model = List[pathlib.Path]
)
async def search_for_models(
search_path: pathlib.Path = Query(description="Directory path to search for models")
)->List[pathlib.Path]:
if not search_path.is_dir():
raise HTTPException(status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory")
return ApiDependencies.invoker.services.model_manager.search_for_models([search_path])
@models_router.get(
"/ckpt_confs",
operation_id="list_ckpt_configs",
responses={
200: { "description" : "paths retrieved successfully" },
},
status_code = 200,
response_model = List[pathlib.Path]
)
async def list_ckpt_configs(
)->List[pathlib.Path]:
"""Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT."""
return ApiDependencies.invoker.services.model_manager.list_checkpoint_configs()
@models_router.get(
"/sync",
operation_id="sync_to_config",
responses={
201: { "description": "synchronization successful" },
},
status_code = 201,
response_model = None
)
async def sync_to_config(
)->None:
"""Call after making changes to models.yaml, autoimport directories or models directory to synchronize
in-memory data structures with disk data structures."""
return ApiDependencies.invoker.services.model_manager.sync_to_config()
@models_router.put(
"/merge/{base_model}",
operation_id="merge_models",
@ -210,17 +344,21 @@ async def merge_models(
alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"),
force: Optional[bool] = Body(description="Force merging of models created with different versions of diffusers", default=False),
merge_dest_directory: Optional[str] = Body(description="Save the merged model to the designated directory (with 'merged_model_name' appended)", default=None)
) -> MergeModelResponse:
"""Convert a checkpoint model into a diffusers model"""
logger = ApiDependencies.invoker.services.logger
try:
logger.info(f"Merging models: {model_names}")
logger.info(f"Merging models: {model_names} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
result = ApiDependencies.invoker.services.model_manager.merge_models(model_names,
base_model,
merged_model_name or "+".join(model_names),
alpha,
interp,
force)
merged_model_name=merged_model_name or "+".join(model_names),
alpha=alpha,
interp=interp,
force=force,
merge_dest_directory = dest
)
model_raw = ApiDependencies.invoker.services.model_manager.list_model(result.name,
base_model = base_model,
model_type = ModelType.Main,

View File

@ -9,6 +9,7 @@ from typing import Literal, Optional, Union, List, Dict
from PIL import Image
from pydantic import BaseModel, Field, validator
from ...backend.model_management import BaseModelType, ModelType
from ..models.image import ImageField, ImageCategory, ResourceOrigin
from .baseinvocation import (
BaseInvocation,
@ -105,9 +106,15 @@ CONTROLNET_MODE_VALUES = Literal[tuple(["balanced", "more_prompt", "more_control
# CONTROLNET_RESIZE_VALUES = Literal[tuple(["just_resize", "crop_resize", "fill_resize"])]
class ControlNetModelField(BaseModel):
"""ControlNet model field"""
model_name: str = Field(description="Name of the ControlNet model")
base_model: BaseModelType = Field(description="Base model")
class ControlField(BaseModel):
image: ImageField = Field(default=None, description="The control image")
control_model: Optional[str] = Field(default=None, description="The ControlNet model to use")
control_model: Optional[ControlNetModelField] = Field(default=None, description="The ControlNet model to use")
# control_weight: Optional[float] = Field(default=1, description="weight given to controlnet")
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
begin_step_percent: float = Field(default=0, ge=0, le=1,
@ -118,15 +125,15 @@ class ControlField(BaseModel):
# resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
@validator("control_weight")
def abs_le_one(cls, v):
"""validate that all abs(values) are <=1"""
def validate_control_weight(cls, v):
"""Validate that all control weights in the valid range"""
if isinstance(v, list):
for i in v:
if abs(i) > 1:
raise ValueError('all abs(control_weight) must be <= 1')
if i < -1 or i > 2:
raise ValueError('Control weights must be within -1 to 2 range')
else:
if abs(v) > 1:
raise ValueError('abs(control_weight) must be <= 1')
if v < -1 or v > 2:
raise ValueError('Control weights must be within -1 to 2 range')
return v
class Config:
schema_extra = {
@ -134,6 +141,7 @@ class ControlField(BaseModel):
"ui": {
"type_hints": {
"control_weight": "float",
"control_model": "controlnet_model",
# "control_weight": "number",
}
}
@ -154,10 +162,10 @@ class ControlNetInvocation(BaseInvocation):
type: Literal["controlnet"] = "controlnet"
# Inputs
image: ImageField = Field(default=None, description="The control image")
control_model: CONTROLNET_NAME_VALUES = Field(default="lllyasviel/sd-controlnet-canny",
control_model: ControlNetModelField = Field(default="lllyasviel/sd-controlnet-canny",
description="control model used")
control_weight: Union[float, List[float]] = Field(default=1.0, description="The weight given to the ControlNet")
begin_step_percent: float = Field(default=0, ge=0, le=1,
begin_step_percent: float = Field(default=0, ge=-1, le=2,
description="When the ControlNet is first applied (% of total steps)")
end_step_percent: float = Field(default=1, ge=0, le=1,
description="When the ControlNet is last applied (% of total steps)")

View File

@ -1,5 +1,6 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from contextlib import ExitStack
from typing import List, Literal, Optional, Union
import einops
@ -11,6 +12,7 @@ from pydantic import BaseModel, Field, validator
from invokeai.app.invocations.metadata import CoreMetadata
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_management.models.base import ModelType
from ...backend.model_management.lora import ModelPatcher
from ...backend.stable_diffusion import PipelineIntermediateState
@ -71,16 +73,21 @@ def get_scheduler(
scheduler_name: str,
) -> Scheduler:
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(
scheduler_name, SCHEDULER_MAP['ddim'])
scheduler_name, SCHEDULER_MAP['ddim']
)
orig_scheduler_info = context.services.model_manager.get_model(
**scheduler_info.dict())
**scheduler_info.dict()
)
with orig_scheduler_info as orig_scheduler:
scheduler_config = orig_scheduler.config
if "_backup" in scheduler_config:
scheduler_config = scheduler_config["_backup"]
scheduler_config = {**scheduler_config, **
scheduler_extra_config, "_backup": scheduler_config}
scheduler_config = {
**scheduler_config,
**scheduler_extra_config,
"_backup": scheduler_config,
}
scheduler = scheduler_class.from_config(scheduler_config)
# hack copied over from generate.py
@ -137,8 +144,11 @@ class TextToLatentsInvocation(BaseInvocation):
# TODO: pass this an emitter method or something? or a session for dispatching?
def dispatch_progress(
self, context: InvocationContext, source_node_id: str,
intermediate_state: PipelineIntermediateState) -> None:
self,
context: InvocationContext,
source_node_id: str,
intermediate_state: PipelineIntermediateState,
) -> None:
stable_diffusion_step_callback(
context=context,
intermediate_state=intermediate_state,
@ -147,11 +157,16 @@ class TextToLatentsInvocation(BaseInvocation):
)
def get_conditioning_data(
self, context: InvocationContext, scheduler) -> ConditioningData:
self,
context: InvocationContext,
scheduler,
) -> ConditioningData:
c, extra_conditioning_info = context.services.latents.get(
self.positive_conditioning.conditioning_name)
self.positive_conditioning.conditioning_name
)
uc, _ = context.services.latents.get(
self.negative_conditioning.conditioning_name)
self.negative_conditioning.conditioning_name
)
conditioning_data = ConditioningData(
unconditioned_embeddings=uc,
@ -178,7 +193,10 @@ class TextToLatentsInvocation(BaseInvocation):
return conditioning_data
def create_pipeline(
self, unet, scheduler) -> StableDiffusionGeneratorPipeline:
self,
unet,
scheduler,
) -> StableDiffusionGeneratorPipeline:
# TODO:
# configure_model_padding(
# unet,
@ -213,6 +231,7 @@ class TextToLatentsInvocation(BaseInvocation):
model: StableDiffusionGeneratorPipeline,
control_input: List[ControlField],
latents_shape: List[int],
exit_stack: ExitStack,
do_classifier_free_guidance: bool = True,
) -> List[ControlNetData]:
@ -238,25 +257,19 @@ class TextToLatentsInvocation(BaseInvocation):
control_data = []
control_models = []
for control_info in control_list:
# handle control models
if ("," in control_info.control_model):
control_model_split = control_info.control_model.split(",")
control_name = control_model_split[0]
control_subfolder = control_model_split[1]
print("Using HF model subfolders")
print(" control_name: ", control_name)
print(" control_subfolder: ", control_subfolder)
control_model = ControlNetModel.from_pretrained(
control_name, subfolder=control_subfolder,
torch_dtype=model.unet.dtype).to(
model.device)
else:
control_model = ControlNetModel.from_pretrained(
control_info.control_model, torch_dtype=model.unet.dtype).to(model.device)
control_model = exit_stack.enter_context(
context.services.model_manager.get_model(
model_name=control_info.control_model.model_name,
model_type=ModelType.ControlNet,
base_model=control_info.control_model.base_model,
)
)
control_models.append(control_model)
control_image_field = control_info.image
input_image = context.services.images.get_pil_image(
control_image_field.image_name)
control_image_field.image_name
)
# self.image.image_type, self.image.image_name
# FIXME: still need to test with different widths, heights, devices, dtypes
# and add in batch_size, num_images_per_prompt?
@ -278,7 +291,8 @@ class TextToLatentsInvocation(BaseInvocation):
weight=control_info.control_weight,
begin_step_percent=control_info.begin_step_percent,
end_step_percent=control_info.end_step_percent,
control_mode=control_info.control_mode,)
control_mode=control_info.control_mode,
)
control_data.append(control_item)
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
return control_data
@ -289,7 +303,8 @@ class TextToLatentsInvocation(BaseInvocation):
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id)
context.graph_execution_state_id
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
def step_callback(state: PipelineIntermediateState):
@ -298,14 +313,17 @@ class TextToLatentsInvocation(BaseInvocation):
def _lora_loader():
for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}))
**lora.dict(exclude={"weight"})
)
yield (lora_info.context.model, lora.weight)
del lora_info
return
unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict())
with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
**self.unet.unet.dict()
)
with ExitStack() as exit_stack,\
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
unet_info as unet:
scheduler = get_scheduler(
@ -322,6 +340,7 @@ class TextToLatentsInvocation(BaseInvocation):
latents_shape=noise.shape,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,
exit_stack=exit_stack,
)
# TODO: Verify the noise is the right size
@ -374,7 +393,8 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id)
context.graph_execution_state_id
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
def step_callback(state: PipelineIntermediateState):
@ -383,14 +403,17 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
def _lora_loader():
for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}))
**lora.dict(exclude={"weight"})
)
yield (lora_info.context.model, lora.weight)
del lora_info
return
unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict())
with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
**self.unet.unet.dict()
)
with ExitStack() as exit_stack,\
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
unet_info as unet:
scheduler = get_scheduler(
@ -407,11 +430,13 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
latents_shape=noise.shape,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,
exit_stack=exit_stack,
)
# TODO: Verify the noise is the right size
initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
latent, device=unet.device, dtype=latent.dtype)
latent, device=unet.device, dtype=latent.dtype
)
timesteps, _ = pipeline.get_img2img_timesteps(
self.steps,
@ -535,7 +560,8 @@ class ResizeLatentsInvocation(BaseInvocation):
resized_latents = torch.nn.functional.interpolate(
latents, size=(self.height // 8, self.width // 8),
mode=self.mode, antialias=self.antialias
if self.mode in ["bilinear", "bicubic"] else False,)
if self.mode in ["bilinear", "bicubic"] else False,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache()
@ -569,7 +595,8 @@ class ScaleLatentsInvocation(BaseInvocation):
resized_latents = torch.nn.functional.interpolate(
latents, scale_factor=self.scale_factor, mode=self.mode,
antialias=self.antialias
if self.mode in ["bilinear", "bicubic"] else False,)
if self.mode in ["bilinear", "bicubic"] else False,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache()

View File

@ -271,8 +271,8 @@ class InvokeAISettings(BaseSettings):
@classmethod
def _excluded(self)->List[str]:
# combination of deprecated parameters and internal ones
return ['type','initconf', 'gpu_mem_reserved', 'max_loaded_models', 'version']
# combination of deprecated parameters and internal ones that shouldn't be exposed
return ['type','initconf', 'gpu_mem_reserved', 'max_loaded_models', 'version', 'from_file', 'model', 'root']
class Config:
env_file_encoding = 'utf-8'

View File

@ -19,7 +19,7 @@ from invokeai.backend.model_management import (
ModelMerger,
MergeInterpolationMethod,
)
from invokeai.backend.model_management.model_search import FindModels
import torch
from invokeai.app.models.exceptions import CanceledException
@ -167,6 +167,27 @@ class ModelManagerServiceBase(ABC):
"""
pass
@abstractmethod
def rename_model(self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
new_name: str,
):
"""
Rename the indicated model.
"""
pass
@abstractmethod
def list_checkpoint_configs(
self
)->List[Path]:
"""
List the checkpoint config paths from ROOT/configs/stable-diffusion.
"""
pass
@abstractmethod
def convert_model(
self,
@ -220,6 +241,7 @@ class ModelManagerServiceBase(ABC):
alpha: Optional[float] = 0.5,
interp: Optional[MergeInterpolationMethod] = None,
force: Optional[bool] = False,
merge_dest_directory: Optional[Path] = None
) -> AddModelResult:
"""
Merge two to three diffusrs pipeline models and save as a new model.
@ -228,6 +250,23 @@ class ModelManagerServiceBase(ABC):
:param merged_model_name: Name of destination merged model
:param alpha: Alpha strength to apply to 2d and 3d model
:param interp: Interpolation method. None (default)
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
"""
pass
@abstractmethod
def search_for_models(self, directory: Path)->List[Path]:
"""
Return list of all models found in the designated directory.
"""
pass
@abstractmethod
def sync_to_config(self):
"""
Re-read models.yaml, rescan the models directory, and reimport models
in the autoimport directories. Call after making changes outside the
model manager API.
"""
pass
@ -431,16 +470,18 @@ class ModelManagerService(ModelManagerServiceBase):
"""
Delete the named model from configuration. If delete_files is true,
then the underlying weight file or diffusers directory will be deleted
as well. Call commit() to write to disk.
as well.
"""
self.logger.debug(f'delete model {model_name}')
self.mgr.del_model(model_name, base_model, model_type)
self.mgr.commit()
def convert_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: Union[ModelType.Main,ModelType.Vae],
convert_dest_directory: Optional[Path] = Field(default=None, description="Optional directory location for merged model"),
) -> AddModelResult:
"""
Convert a checkpoint file into a diffusers folder, deleting the cached
@ -449,13 +490,14 @@ class ModelManagerService(ModelManagerServiceBase):
:param model_name: Name of the model to convert
:param base_model: Base model type
:param model_type: Type of model ['vae' or 'main']
:param convert_dest_directory: Save the converted model to the designated directory (`models/etc/etc` by default)
This will raise a ValueError unless the model is not a checkpoint. It will
also raise a ValueError in the event that there is a similarly-named diffusers
directory already in place.
"""
self.logger.debug(f'convert model {model_name}')
return self.mgr.convert_model(model_name, base_model, model_type)
return self.mgr.convert_model(model_name, base_model, model_type, convert_dest_directory)
def commit(self, conf_file: Optional[Path]=None):
"""
@ -536,6 +578,7 @@ class ModelManagerService(ModelManagerServiceBase):
alpha: Optional[float] = 0.5,
interp: Optional[MergeInterpolationMethod] = None,
force: Optional[bool] = False,
merge_dest_directory: Optional[Path] = Field(default=None, description="Optional directory location for merged model"),
) -> AddModelResult:
"""
Merge two to three diffusrs pipeline models and save as a new model.
@ -544,6 +587,7 @@ class ModelManagerService(ModelManagerServiceBase):
:param merged_model_name: Name of destination merged model
:param alpha: Alpha strength to apply to 2d and 3d model
:param interp: Interpolation method. None (default)
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
"""
merger = ModelMerger(self.mgr)
try:
@ -554,7 +598,55 @@ class ModelManagerService(ModelManagerServiceBase):
alpha = alpha,
interp = interp,
force = force,
merge_dest_directory=merge_dest_directory,
)
except AssertionError as e:
raise ValueError(e)
return result
def search_for_models(self, directory: Path)->List[Path]:
"""
Return list of all models found in the designated directory.
"""
search = FindModels(directory,self.logger)
return search.list_models()
def sync_to_config(self):
"""
Re-read models.yaml, rescan the models directory, and reimport models
in the autoimport directories. Call after making changes outside the
model manager API.
"""
return self.mgr.sync_to_config()
def list_checkpoint_configs(self)->List[Path]:
"""
List the checkpoint config paths from ROOT/configs/stable-diffusion.
"""
config = self.mgr.app_config
conf_path = config.legacy_conf_path
root_path = config.root_path
return [(conf_path / x).relative_to(root_path) for x in conf_path.glob('**/*.yaml')]
def rename_model(self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
new_name: str = None,
new_base: BaseModelType = None,
):
"""
Rename the indicated model. Can provide a new name and/or a new base.
:param model_name: Current name of the model
:param base_model: Current base of the model
:param model_type: Model type (can't be changed)
:param new_name: New name for the model
:param new_base: New base for the model
"""
self.mgr.rename_model(base_model = base_model,
model_type = model_type,
model_name = model_name,
new_name = new_name,
new_base = new_base,
)

View File

@ -593,9 +593,12 @@ script, which will perform a full upgrade in place."""
config = InvokeAIAppConfig.get_config()
config.parse_args(['--root',str(dest_root)])
# TODO: revisit
# assert (dest_root / 'models').is_dir(), f"{dest_root} does not contain a 'models' subdirectory"
# assert (dest_root / 'invokeai.yaml').exists(), f"{dest_root} does not contain an InvokeAI init file."
# TODO: revisit - don't rely on invokeai.yaml to exist yet!
dest_is_setup = (dest_root / 'models/core').exists() and (dest_root / 'databases').exists()
if not dest_is_setup:
import invokeai.frontend.install.invokeai_configure
from invokeai.backend.install.invokeai_configure import initialize_rootdir
initialize_rootdir(dest_root, True)
do_migrate(src_root,dest_root)

View File

@ -71,8 +71,6 @@ class ModelInstallList:
class InstallSelections():
install_models: List[str]= field(default_factory=list)
remove_models: List[str]=field(default_factory=list)
# scan_directory: Path = None
# autoscan_on_startup: bool=False
@dataclass
class ModelLoadInfo():

View File

@ -247,6 +247,7 @@ import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util import CUDA_DEVICE, Chdir
from .model_cache import ModelCache, ModelLocker
from .model_search import ModelSearch
from .models import (
BaseModelType, ModelType, SubModelType,
ModelError, SchedulerPredictionType, MODEL_CLASSES,
@ -323,15 +324,6 @@ class ModelManager(object):
# TODO: metadata not found
# TODO: version check
self.models = dict()
for model_key, model_config in config.items():
model_name, base_model, model_type = self.parse_key(model_key)
model_class = MODEL_CLASSES[base_model][model_type]
# alias for config file
model_config["model_format"] = model_config.pop("format")
self.models[model_key] = model_class.create_config(**model_config)
# check config version number and update on disk/RAM if necessary
self.app_config = InvokeAIAppConfig.get_config()
self.logger = logger
self.cache = ModelCache(
@ -342,11 +334,41 @@ class ModelManager(object):
sequential_offload = sequential_offload,
logger = logger,
)
self._read_models(config)
def _read_models(self, config: Optional[DictConfig] = None):
if not config:
if self.config_path:
config = OmegaConf.load(self.config_path)
else:
return
self.models = dict()
for model_key, model_config in config.items():
if model_key.startswith('_'):
continue
model_name, base_model, model_type = self.parse_key(model_key)
model_class = MODEL_CLASSES[base_model][model_type]
# alias for config file
model_config["model_format"] = model_config.pop("format")
self.models[model_key] = model_class.create_config(**model_config)
# check config version number and update on disk/RAM if necessary
self.cache_keys = dict()
# add controlnet, lora and textual_inversion models from disk
self.scan_models_directory()
def sync_to_config(self):
"""
Call this when `models.yaml` has been changed externally.
This will reinitialize internal data structures
"""
# Reread models directory; note that this will reinitialize the cache,
# causing otherwise unreferenced models to be removed from memory
self._read_models()
def model_exists(
self,
model_name: str,
@ -527,7 +549,10 @@ class ModelManager(object):
model_keys = [self.create_key(model_name, base_model, model_type)] if model_name else sorted(self.models, key=str.casefold)
models = []
for model_key in model_keys:
model_config = self.models[model_key]
model_config = self.models.get(model_key)
if not model_config:
self.logger.error(f'Unknown model {model_name}')
raise KeyError(f'Unknown model {model_name}')
cur_model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
if base_model is not None and cur_base_model != base_model:
@ -646,11 +671,61 @@ class ModelManager(object):
config = model_config,
)
def rename_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
new_name: str = None,
new_base: BaseModelType = None,
):
'''
Rename or rebase a model.
'''
if new_name is None and new_base is None:
self.logger.error("rename_model() called with neither a new_name nor a new_base. {model_name} unchanged.")
return
model_key = self.create_key(model_name, base_model, model_type)
model_cfg = self.models.get(model_key, None)
if not model_cfg:
raise KeyError(f"Unknown model: {model_key}")
old_path = self.app_config.root_path / model_cfg.path
new_name = new_name or model_name
new_base = new_base or base_model
new_key = self.create_key(new_name, new_base, model_type)
if new_key in self.models:
raise ValueError(f'Attempt to overwrite existing model definition "{new_key}"')
# if this is a model file/directory that we manage ourselves, we need to move it
if old_path.is_relative_to(self.app_config.models_path):
new_path = self.app_config.root_path / 'models' / new_base.value / model_type.value / new_name
move(old_path, new_path)
model_cfg.path = str(new_path.relative_to(self.app_config.root_path))
# clean up caches
old_model_cache = self._get_model_cache_path(old_path)
if old_model_cache.exists():
if old_model_cache.is_dir():
rmtree(str(old_model_cache))
else:
old_model_cache.unlink()
cache_ids = self.cache_keys.pop(model_key, [])
for cache_id in cache_ids:
self.cache.uncache_model(cache_id)
self.models.pop(model_key, None) # delete
self.models[new_key] = model_cfg
self.commit()
def convert_model (
self,
model_name: str,
base_model: BaseModelType,
model_type: Union[ModelType.Main,ModelType.Vae],
dest_directory: Optional[Path]=None,
) -> AddModelResult:
'''
Convert a checkpoint file into a diffusers folder, deleting the cached
@ -677,14 +752,14 @@ class ModelManager(object):
)
checkpoint_path = self.app_config.root_path / info["path"]
old_diffusers_path = self.app_config.models_path / model.location
new_diffusers_path = self.app_config.models_path / base_model.value / model_type.value / model_name
new_diffusers_path = (dest_directory or self.app_config.models_path / base_model.value / model_type.value) / model_name
if new_diffusers_path.exists():
raise ValueError(f"A diffusers model already exists at {new_diffusers_path}")
try:
move(old_diffusers_path,new_diffusers_path)
info["model_format"] = "diffusers"
info["path"] = str(new_diffusers_path.relative_to(self.app_config.root_path))
info["path"] = str(new_diffusers_path) if dest_directory else str(new_diffusers_path.relative_to(self.app_config.root_path))
info.pop('config')
result = self.add_model(model_name, base_model, model_type,
@ -824,6 +899,7 @@ class ModelManager(object):
if (new_models_found or imported_models) and self.config_path:
self.commit()
def autoimport(self)->Dict[str, AddModelResult]:
'''
Scan the autoimport directory (if defined) and import new models, delete defunct models.
@ -832,62 +908,41 @@ class ModelManager(object):
from invokeai.backend.install.model_install_backend import ModelInstall
from invokeai.frontend.install.model_install import ask_user_for_prediction_type
class ScanAndImport(ModelSearch):
def __init__(self, directories, logger, ignore: Set[Path], installer: ModelInstall):
super().__init__(directories, logger)
self.installer = installer
self.ignore = ignore
def on_search_started(self):
self.new_models_found = dict()
def on_model_found(self, model: Path):
if model not in self.ignore:
self.new_models_found.update(self.installer.heuristic_import(model))
def on_search_completed(self):
self.logger.info(f'Scanned {self._items_scanned} files and directories, imported {len(self.new_models_found)} models')
def models_found(self):
return self.new_models_found
installer = ModelInstall(config = self.app_config,
model_manager = self,
prediction_type_helper = ask_user_for_prediction_type,
)
scanned_dirs = set()
config = self.app_config
known_paths = {(self.app_config.root_path / x['path']) for x in self.list_models()}
for autodir in [config.autoimport_dir,
config.lora_dir,
config.embedding_dir,
config.controlnet_dir]:
if autodir is None:
continue
self.logger.info(f'Scanning {autodir} for models to import')
installed = dict()
autodir = self.app_config.root_path / autodir
if not autodir.exists():
continue
items_scanned = 0
new_models_found = dict()
for root, dirs, files in os.walk(autodir):
items_scanned += len(dirs) + len(files)
for d in dirs:
path = Path(root) / d
if path in known_paths or path.parent in scanned_dirs:
scanned_dirs.add(path)
continue
if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}]):
try:
new_models_found.update(installer.heuristic_import(path))
scanned_dirs.add(path)
except ValueError as e:
self.logger.warning(str(e))
for f in files:
path = Path(root) / f
if path in known_paths or path.parent in scanned_dirs:
continue
if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}:
try:
import_result = installer.heuristic_import(path)
new_models_found.update(import_result)
except ValueError as e:
self.logger.warning(str(e))
self.logger.info(f'Scanned {items_scanned} files and directories, imported {len(new_models_found)} models')
installed.update(new_models_found)
return installed
known_paths = {config.root_path / x['path'] for x in self.list_models()}
directories = {config.root_path / x for x in [config.autoimport_dir,
config.lora_dir,
config.embedding_dir,
config.controlnet_dir]
}
scanner = ScanAndImport(directories, self.logger, ignore=known_paths, installer=installer)
scanner.search()
return scanner.models_found()
def heuristic_import(self,
items_to_import: Set[str],
@ -925,3 +980,4 @@ class ModelManager(object):
successfully_installed.update(installed)
self.commit()
return successfully_installed

View File

@ -11,7 +11,7 @@ from enum import Enum
from pathlib import Path
from diffusers import DiffusionPipeline
from diffusers import logging as dlogging
from typing import List, Union
from typing import List, Union, Optional
import invokeai.backend.util.logging as logger
@ -74,6 +74,7 @@ class ModelMerger(object):
alpha: float = 0.5,
interp: MergeInterpolationMethod = None,
force: bool = False,
merge_dest_directory: Optional[Path] = None,
**kwargs,
) -> AddModelResult:
"""
@ -85,7 +86,7 @@ class ModelMerger(object):
:param interp: The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None.
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C).
:param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
"""
@ -111,7 +112,7 @@ class ModelMerger(object):
merged_pipe = self.merge_diffusion_models(
model_paths, alpha, merge_method, force, **kwargs
)
dump_path = config.models_path / base_model.value / ModelType.Main.value
dump_path = Path(merge_dest_directory) if merge_dest_directory else config.models_path / base_model.value / ModelType.Main.value
dump_path.mkdir(parents=True, exist_ok=True)
dump_path = dump_path / merged_model_name

View File

@ -0,0 +1,103 @@
# Copyright 2023, Lincoln D. Stein and the InvokeAI Team
"""
Abstract base class for recursive directory search for models.
"""
import os
from abc import ABC, abstractmethod
from typing import List, Set, types
from pathlib import Path
import invokeai.backend.util.logging as logger
class ModelSearch(ABC):
def __init__(self, directories: List[Path], logger: types.ModuleType=logger):
"""
Initialize a recursive model directory search.
:param directories: List of directory Paths to recurse through
:param logger: Logger to use
"""
self.directories = directories
self.logger = logger
self._items_scanned = 0
self._models_found = 0
self._scanned_dirs = set()
self._scanned_paths = set()
self._pruned_paths = set()
@abstractmethod
def on_search_started(self):
"""
Called before the scan starts.
"""
pass
@abstractmethod
def on_model_found(self, model: Path):
"""
Process a found model. Raise an exception if something goes wrong.
:param model: Model to process - could be a directory or checkpoint.
"""
pass
@abstractmethod
def on_search_completed(self):
"""
Perform some activity when the scan is completed. May use instance
variables, items_scanned and models_found
"""
pass
def search(self):
self.on_search_started()
for dir in self.directories:
self.walk_directory(dir)
self.on_search_completed()
def walk_directory(self, path: Path):
for root, dirs, files in os.walk(path):
if str(Path(root).name).startswith('.'):
self._pruned_paths.add(root)
if any([Path(root).is_relative_to(x) for x in self._pruned_paths]):
continue
self._items_scanned += len(dirs) + len(files)
for d in dirs:
path = Path(root) / d
if path in self._scanned_paths or path.parent in self._scanned_dirs:
self._scanned_dirs.add(path)
continue
if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}]):
try:
self.on_model_found(path)
self._models_found += 1
self._scanned_dirs.add(path)
except Exception as e:
self.logger.warning(str(e))
for f in files:
path = Path(root) / f
if path.parent in self._scanned_dirs:
continue
if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}:
try:
self.on_model_found(path)
self._models_found += 1
except Exception as e:
self.logger.warning(str(e))
class FindModels(ModelSearch):
def on_search_started(self):
self.models_found: Set[Path] = set()
def on_model_found(self,model: Path):
self.models_found.add(model)
def on_search_completed(self):
pass
def list_models(self) -> List[Path]:
self.search()
return self.models_found

View File

@ -48,7 +48,9 @@ for base_model, models in MODEL_CLASSES.items():
model_configs.discard(None)
MODEL_CONFIGS.extend(model_configs)
for cfg in model_configs:
# LS: sort to get the checkpoint configs first, which makes
# for a better template in the Swagger docs
for cfg in sorted(model_configs, key=lambda x: str(x)):
model_name, cfg_name = cfg.__qualname__.split('.')[-2:]
openapi_cfg_name = model_name + cfg_name
if openapi_cfg_name in vars():

View File

@ -59,7 +59,6 @@ class ModelConfigBase(BaseModel):
path: str # or Path
description: Optional[str] = Field(None)
model_format: Optional[str] = Field(None)
# do not save to config
error: Optional[ModelError] = Field(None)
class Config:

View File

@ -1,8 +1,7 @@
import os
import torch
from enum import Enum
from pathlib import Path
from typing import Optional, Union, Literal
from typing import Optional
from .base import (
ModelBase,
ModelConfigBase,
@ -14,6 +13,7 @@ from .base import (
calc_model_size_by_data,
classproperty,
InvalidModelException,
ModelNotFoundException,
)
class ControlNetModelFormat(str, Enum):
@ -60,10 +60,20 @@ class ControlNetModel(ModelBase):
if child_type is not None:
raise Exception("There is no child models in controlnet model")
model = self.model_class.from_pretrained(
self.model_path,
torch_dtype=torch_dtype,
)
model = None
for variant in ['fp16',None]:
try:
model = self.model_class.from_pretrained(
self.model_path,
torch_dtype=torch_dtype,
variant=variant,
)
break
except:
pass
if not model:
raise ModelNotFoundException()
# calc more accurate size
self.model_size = calc_model_size_by_data(model)
return model

View File

@ -38,7 +38,6 @@ class StableDiffusion1Model(DiffusersModel):
config: str
variant: ModelVariantType
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert base_model == BaseModelType.StableDiffusion1
assert model_type == ModelType.Main

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -12,7 +12,7 @@
margin: 0;
}
</style>
<script type="module" crossorigin src="./assets/index-078526aa.js"></script>
<script type="module" crossorigin src="./assets/index-8888b06f.js"></script>
</head>
<body dir="ltr">

View File

@ -102,7 +102,8 @@
"openInNewTab": "Open in New Tab",
"dontAskMeAgain": "Don't ask me again",
"areYouSure": "Are you sure?",
"imagePrompt": "Image Prompt"
"imagePrompt": "Image Prompt",
"clearNodes": "Are you sure you want to clear all nodes?"
},
"gallery": {
"generations": "Generations",
@ -118,7 +119,7 @@
"pinGallery": "Pin Gallery",
"allImagesLoaded": "All Images Loaded",
"loadMore": "Load More",
"noImagesInGallery": "No Images In Gallery",
"noImagesInGallery": "No Images to Display",
"deleteImage": "Delete Image",
"deleteImageBin": "Deleted images will be sent to your operating system's Bin.",
"deleteImagePermanent": "Deleted images cannot be restored.",
@ -342,6 +343,7 @@
"safetensorModels": "SafeTensors",
"modelAdded": "Model Added",
"modelUpdated": "Model Updated",
"modelUpdateFailed": "Model Update Failed",
"modelEntryDeleted": "Model Entry Deleted",
"cannotUseSpaces": "Cannot Use Spaces",
"addNew": "Add New",
@ -396,8 +398,8 @@
"delete": "Delete",
"deleteModel": "Delete Model",
"deleteConfig": "Delete Config",
"deleteMsg1": "Are you sure you want to delete this model entry from InvokeAI?",
"deleteMsg2": "This will not delete the model checkpoint file from your disk. You can readd them if you wish to.",
"deleteMsg1": "Are you sure you want to delete this model from InvokeAI?",
"deleteMsg2": "This WILL delete the model from disk if it is in the InvokeAI root folder. If you are using a custom location, then the model WILL NOT be deleted from disk.",
"formMessageDiffusersModelLocation": "Diffusers Model Location",
"formMessageDiffusersModelLocationDesc": "Please enter at least one.",
"formMessageDiffusersVAELocation": "VAE Location",
@ -408,7 +410,7 @@
"convertToDiffusersHelpText2": "This process will replace your Model Manager entry with the Diffusers version of the same model.",
"convertToDiffusersHelpText3": "Your checkpoint file on the disk will NOT be deleted or modified in anyway. You can add your checkpoint to the Model Manager again if you want to.",
"convertToDiffusersHelpText4": "This is a one time process only. It might take around 30s-60s depending on the specifications of your computer.",
"convertToDiffusersHelpText5": "Please make sure you have enough disk space. Models generally vary between 4GB-7GB in size.",
"convertToDiffusersHelpText5": "Please make sure you have enough disk space. Models generally vary between 2GB-7GB in size.",
"convertToDiffusersHelpText6": "Do you wish to convert this model?",
"convertToDiffusersSaveLocation": "Save Location",
"v1": "v1",
@ -419,12 +421,14 @@
"pathToCustomConfig": "Path To Custom Config",
"statusConverting": "Converting",
"modelConverted": "Model Converted",
"modelConversionFailed": "Model Conversion Failed",
"sameFolder": "Same folder",
"invokeRoot": "InvokeAI folder",
"custom": "Custom",
"customSaveLocation": "Custom Save Location",
"merge": "Merge",
"modelsMerged": "Models Merged",
"modelsMergeFailed": "Model Merge Failed",
"mergeModels": "Merge Models",
"modelOne": "Model 1",
"modelTwo": "Model 2",
@ -445,7 +449,8 @@
"weightedSum": "Weighted Sum",
"none": "none",
"addDifference": "Add Difference",
"pickModelType": "Pick Model Type"
"pickModelType": "Pick Model Type",
"selectModel": "Select Model"
},
"parameters": {
"general": "General",
@ -528,7 +533,7 @@
"hidePreview": "Hide Preview",
"showPreview": "Show Preview",
"controlNetControlMode": "Control Mode",
"clipSkip": "Clip Skip",
"clipSkip": "CLIP Skip",
"aspectRatio": "Ratio"
},
"settings": {
@ -593,7 +598,11 @@
"metadataLoadFailed": "Failed to load metadata",
"initialImageSet": "Initial Image Set",
"initialImageNotSet": "Initial Image Not Set",
"initialImageNotSetDesc": "Could not load initial image"
"initialImageNotSetDesc": "Could not load initial image",
"nodesSaved": "Nodes Saved",
"nodesLoaded": "Nodes Loaded",
"nodesLoadedFailed": "Failed To Load Nodes",
"nodesCleared": "Nodes Cleared"
},
"tooltip": {
"feature": {
@ -674,5 +683,11 @@
"showProgressImages": "Show Progress Images",
"hideProgressImages": "Hide Progress Images",
"swapSizes": "Swap Sizes"
},
"nodes": {
"reloadSchema": "Reload Schema",
"saveNodes": "Save Nodes",
"loadNodes": "Load Nodes",
"clearNodes": "Clear Nodes"
}
}

View File

@ -13,7 +13,11 @@ import { RootState } from 'app/store/store';
const moduleLog = log.child({ namespace: 'controlNet' });
const predicate: AnyListenerPredicate<RootState> = (action, state) => {
const predicate: AnyListenerPredicate<RootState> = (
action,
state,
prevState
) => {
const isActionMatched =
controlNetProcessorParamsChanged.match(action) ||
controlNetModelChanged.match(action) ||
@ -25,6 +29,16 @@ const predicate: AnyListenerPredicate<RootState> = (action, state) => {
return false;
}
if (controlNetAutoConfigToggled.match(action)) {
// do not process if the user just disabled auto-config
if (
prevState.controlNet.controlNets[action.payload.controlNetId]
.shouldAutoConfig === true
) {
return false;
}
}
const { controlImage, processorType, shouldAutoConfig } =
state.controlNet.controlNets[action.payload.controlNetId];

View File

@ -10,6 +10,7 @@ import { zMainModel } from 'features/parameters/types/parameterSchemas';
import { addToast } from 'features/system/store/systemSlice';
import { forEach } from 'lodash-es';
import { startAppListening } from '..';
import { controlNetRemoved } from 'features/controlNet/store/controlNetSlice';
const moduleLog = log.child({ module: 'models' });
@ -51,7 +52,14 @@ export const addModelSelectedListener = () => {
modelsCleared += 1;
}
// TODO: handle incompatible controlnet; pending model manager support
const { controlNets } = state.controlNet;
forEach(controlNets, (controlNet, controlNetId) => {
if (controlNet.model?.base_model !== base_model) {
dispatch(controlNetRemoved({ controlNetId }));
modelsCleared += 1;
}
});
if (modelsCleared > 0) {
dispatch(
addToast(

View File

@ -11,6 +11,7 @@ import {
import { forEach, some } from 'lodash-es';
import { modelsApi } from 'services/api/endpoints/models';
import { startAppListening } from '..';
import { controlNetRemoved } from 'features/controlNet/store/controlNetSlice';
const moduleLog = log.child({ module: 'models' });
@ -127,7 +128,22 @@ export const addModelsLoadedListener = () => {
matcher: modelsApi.endpoints.getControlNetModels.matchFulfilled,
effect: async (action, { getState, dispatch }) => {
// ControlNet models loaded - need to remove missing ControlNets from state
// TODO: pending model manager controlnet support
const controlNets = getState().controlNet.controlNets;
forEach(controlNets, (controlNet, controlNetId) => {
const isControlNetAvailable = some(
action.payload.entities,
(m) =>
m?.model_name === controlNet?.model?.model_name &&
m?.base_model === controlNet?.model?.base_model
);
if (isControlNetAvailable) {
return;
}
dispatch(controlNetRemoved({ controlNetId }));
});
},
});
};

View File

@ -1,5 +1,5 @@
import {
CONTROLNET_MODELS,
// CONTROLNET_MODELS,
CONTROLNET_PROCESSORS,
} from 'features/controlNet/store/constants';
import { InvokeTabName } from 'features/ui/store/tabMap';
@ -128,7 +128,7 @@ export type AppConfig = {
canRestoreDeletedImagesFromBin: boolean;
sd: {
defaultModel?: string;
disabledControlNetModels: (keyof typeof CONTROLNET_MODELS)[];
disabledControlNetModels: string[];
disabledControlNetProcessors: (keyof typeof CONTROLNET_PROCESSORS)[];
iterations: {
initial: number;

View File

@ -170,12 +170,14 @@ const IAIDndImage = (props: IAIDndImageProps) => {
</>
)}
{!imageDTO && isUploadDisabled && noContentFallback}
<IAIDroppable
data={droppableData}
disabled={isDropDisabled}
dropLabel={dropLabel}
/>
{imageDTO && (
{!isDropDisabled && (
<IAIDroppable
data={droppableData}
disabled={isDropDisabled}
dropLabel={dropLabel}
/>
)}
{imageDTO && !isDragDisabled && (
<IAIDraggable
data={draggableData}
disabled={isDragDisabled || !imageDTO}

View File

@ -1,17 +1,25 @@
import { Tooltip } from '@chakra-ui/react';
import { FormControl, FormLabel, Tooltip } from '@chakra-ui/react';
import { MultiSelect, MultiSelectProps } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
import { useMantineMultiSelectStyles } from 'mantine-theme/hooks/useMantineMultiSelectStyles';
import { KeyboardEvent, RefObject, memo, useCallback } from 'react';
type IAIMultiSelectProps = MultiSelectProps & {
type IAIMultiSelectProps = Omit<MultiSelectProps, 'label'> & {
tooltip?: string;
inputRef?: RefObject<HTMLInputElement>;
label?: string;
};
const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
const { searchable = true, tooltip, inputRef, ...rest } = props;
const {
searchable = true,
tooltip,
inputRef,
label,
disabled,
...rest
} = props;
const dispatch = useAppDispatch();
const handleKeyDown = useCallback(
@ -37,7 +45,15 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
return (
<Tooltip label={tooltip} placement="top" hasArrow isOpen={true}>
<MultiSelect
label={
label ? (
<FormControl isDisabled={disabled}>
<FormLabel>{label}</FormLabel>
</FormControl>
) : undefined
}
ref={inputRef}
disabled={disabled}
onKeyDown={handleKeyDown}
onKeyUp={handleKeyUp}
searchable={searchable}

View File

@ -1,4 +1,4 @@
import { Tooltip } from '@chakra-ui/react';
import { FormControl, FormLabel, Tooltip } from '@chakra-ui/react';
import { Select, SelectProps } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
@ -11,13 +11,22 @@ export type IAISelectDataType = {
tooltip?: string;
};
type IAISelectProps = SelectProps & {
type IAISelectProps = Omit<SelectProps, 'label'> & {
tooltip?: string;
label?: string;
inputRef?: RefObject<HTMLInputElement>;
};
const IAIMantineSearchableSelect = (props: IAISelectProps) => {
const { searchable = true, tooltip, inputRef, onChange, ...rest } = props;
const {
searchable = true,
tooltip,
inputRef,
onChange,
label,
disabled,
...rest
} = props;
const dispatch = useAppDispatch();
const [searchValue, setSearchValue] = useState('');
@ -61,6 +70,14 @@ const IAIMantineSearchableSelect = (props: IAISelectProps) => {
<Tooltip label={tooltip} placement="top" hasArrow>
<Select
ref={inputRef}
label={
label ? (
<FormControl isDisabled={disabled}>
<FormLabel>{label}</FormLabel>
</FormControl>
) : undefined
}
disabled={disabled}
searchValue={searchValue}
onSearchChange={setSearchValue}
onChange={handleChange}

View File

@ -1,4 +1,4 @@
import { Tooltip } from '@chakra-ui/react';
import { FormControl, FormLabel, Tooltip } from '@chakra-ui/react';
import { Select, SelectProps } from '@mantine/core';
import { useMantineSelectStyles } from 'mantine-theme/hooks/useMantineSelectStyles';
import { RefObject, memo } from 'react';
@ -9,19 +9,32 @@ export type IAISelectDataType = {
tooltip?: string;
};
type IAISelectProps = SelectProps & {
type IAISelectProps = Omit<SelectProps, 'label'> & {
tooltip?: string;
inputRef?: RefObject<HTMLInputElement>;
label?: string;
};
const IAIMantineSelect = (props: IAISelectProps) => {
const { tooltip, inputRef, ...rest } = props;
const { tooltip, inputRef, label, disabled, ...rest } = props;
const styles = useMantineSelectStyles();
return (
<Tooltip label={tooltip} placement="top" hasArrow>
<Select ref={inputRef} styles={styles} {...rest} />
<Select
label={
label ? (
<FormControl isDisabled={disabled}>
<FormLabel>{label}</FormLabel>
</FormControl>
) : undefined
}
disabled={disabled}
ref={inputRef}
styles={styles}
{...rest}
/>
</Tooltip>
);
};

View File

@ -43,11 +43,6 @@ import { useTranslation } from 'react-i18next';
import { BiReset } from 'react-icons/bi';
import IAIIconButton, { IAIIconButtonProps } from './IAIIconButton';
const SLIDER_MARK_STYLES: ChakraProps['sx'] = {
mt: 1.5,
fontSize: '2xs',
};
export type IAIFullSliderProps = {
label?: string;
value: number;
@ -207,7 +202,7 @@ const IAISlider = (props: IAIFullSliderProps) => {
{...sliderFormControlProps}
>
{label && (
<FormLabel {...sliderFormLabelProps} mb={-1}>
<FormLabel sx={withInput ? { mb: -1.5 } : {}} {...sliderFormLabelProps}>
{label}
</FormLabel>
)}
@ -233,7 +228,6 @@ const IAISlider = (props: IAIFullSliderProps) => {
sx={{
insetInlineStart: '0 !important',
insetInlineEnd: 'unset !important',
...SLIDER_MARK_STYLES,
}}
{...sliderMarkProps}
>
@ -244,7 +238,6 @@ const IAISlider = (props: IAIFullSliderProps) => {
sx={{
insetInlineStart: 'unset !important',
insetInlineEnd: '0 !important',
...SLIDER_MARK_STYLES,
}}
{...sliderMarkProps}
>
@ -263,7 +256,6 @@ const IAISlider = (props: IAIFullSliderProps) => {
sx={{
insetInlineStart: '0 !important',
insetInlineEnd: 'unset !important',
...SLIDER_MARK_STYLES,
}}
{...sliderMarkProps}
>
@ -278,7 +270,6 @@ const IAISlider = (props: IAIFullSliderProps) => {
sx={{
insetInlineStart: 'unset !important',
insetInlineEnd: '0 !important',
...SLIDER_MARK_STYLES,
}}
{...sliderMarkProps}
>
@ -291,7 +282,7 @@ const IAISlider = (props: IAIFullSliderProps) => {
key={m}
value={m}
sx={{
...SLIDER_MARK_STYLES,
transform: 'translateX(-50%)',
}}
{...sliderMarkProps}
>

View File

@ -5,6 +5,7 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { validateSeedWeights } from 'common/util/seedWeightPairs';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { modelsApi } from '../../services/api/endpoints/models';
import { forEach } from 'lodash-es';
const readinessSelector = createSelector(
[stateSelector, activeTabNameSelector],
@ -52,6 +53,13 @@ const readinessSelector = createSelector(
reasonsWhyNotReady.push('Seed-Weights badly formatted.');
}
forEach(state.controlNet.controlNets, (controlNet, id) => {
if (!controlNet.model) {
isReady = false;
reasonsWhyNotReady.push('ControlNet ${id} has no model selected.');
}
});
// All good
return { isReady, reasonsWhyNotReady };
},

View File

@ -1,10 +1,9 @@
import { Box, ChakraProps, Flex, useColorMode } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { Box, Flex } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { memo, useCallback } from 'react';
import { FaCopy, FaTrash } from 'react-icons/fa';
import {
ControlNetConfig,
controlNetAdded,
controlNetDuplicated,
controlNetRemoved,
controlNetToggled,
} from '../store/controlNetSlice';
@ -12,6 +11,9 @@ import ParamControlNetModel from './parameters/ParamControlNetModel';
import ParamControlNetWeight from './parameters/ParamControlNetWeight';
import { ChevronUpIcon } from '@chakra-ui/icons';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIIconButton from 'common/components/IAIIconButton';
import IAISwitch from 'common/components/IAISwitch';
import { useToggle } from 'react-use';
@ -22,41 +24,41 @@ import ParamControlNetShouldAutoConfig from './ParamControlNetShouldAutoConfig';
import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd';
import ParamControlNetControlMode from './parameters/ParamControlNetControlMode';
import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect';
import { mode } from 'theme/util/mode';
const expandedControlImageSx: ChakraProps['sx'] = { maxH: 96 };
type ControlNetProps = {
controlNet: ControlNetConfig;
controlNetId: string;
};
const ControlNet = (props: ControlNetProps) => {
const {
controlNetId,
isEnabled,
model,
weight,
beginStepPct,
endStepPct,
controlMode,
controlImage,
processedControlImage,
processorNode,
processorType,
shouldAutoConfig,
} = props.controlNet;
const { controlNetId } = props;
const dispatch = useAppDispatch();
const selector = createSelector(
stateSelector,
({ controlNet }) => {
const { isEnabled, shouldAutoConfig } =
controlNet.controlNets[controlNetId];
return { isEnabled, shouldAutoConfig };
},
defaultSelectorOptions
);
const { isEnabled, shouldAutoConfig } = useAppSelector(selector);
const [isExpanded, toggleIsExpanded] = useToggle(false);
const { colorMode } = useColorMode();
const handleDelete = useCallback(() => {
dispatch(controlNetRemoved({ controlNetId }));
}, [controlNetId, dispatch]);
const handleDuplicate = useCallback(() => {
dispatch(
controlNetAdded({ controlNetId: uuidv4(), controlNet: props.controlNet })
controlNetDuplicated({
sourceControlNetId: controlNetId,
newControlNetId: uuidv4(),
})
);
}, [dispatch, props.controlNet]);
}, [controlNetId, dispatch]);
const handleToggleIsEnabled = useCallback(() => {
dispatch(controlNetToggled({ controlNetId }));
@ -68,15 +70,18 @@ const ControlNet = (props: ControlNetProps) => {
flexDir: 'column',
gap: 2,
p: 3,
bg: mode('base.200', 'base.850')(colorMode),
borderRadius: 'base',
position: 'relative',
bg: 'base.200',
_dark: {
bg: 'base.850',
},
}}
>
<Flex sx={{ gap: 2 }}>
<Flex sx={{ gap: 2, alignItems: 'center' }}>
<IAISwitch
tooltip="Toggle"
aria-label="Toggle"
tooltip={'Toggle this ControlNet'}
aria-label={'Toggle this ControlNet'}
isChecked={isEnabled}
onChange={handleToggleIsEnabled}
/>
@ -90,7 +95,7 @@ const ControlNet = (props: ControlNetProps) => {
transitionDuration: '0.1s',
}}
>
<ParamControlNetModel controlNetId={controlNetId} model={model} />
<ParamControlNetModel controlNetId={controlNetId} />
</Box>
<IAIIconButton
size="sm"
@ -109,21 +114,26 @@ const ControlNet = (props: ControlNetProps) => {
/>
<IAIIconButton
size="sm"
aria-label="Show All Options"
tooltip={isExpanded ? 'Hide Advanced' : 'Show Advanced'}
aria-label={isExpanded ? 'Hide Advanced' : 'Show Advanced'}
onClick={toggleIsExpanded}
variant="link"
icon={
<ChevronUpIcon
sx={{
boxSize: 4,
color: mode('base.700', 'base.300')(colorMode),
color: 'base.700',
transform: isExpanded ? 'rotate(0deg)' : 'rotate(180deg)',
transitionProperty: 'common',
transitionDuration: 'normal',
_dark: {
color: 'base.300',
},
}}
/>
}
/>
{!shouldAutoConfig && (
<Box
sx={{
@ -131,85 +141,59 @@ const ControlNet = (props: ControlNetProps) => {
w: 1.5,
h: 1.5,
borderRadius: 'full',
bg: mode('error.700', 'error.200')(colorMode),
top: 4,
insetInlineEnd: 4,
bg: 'accent.700',
_dark: {
bg: 'accent.400',
},
}}
/>
)}
</Flex>
{isEnabled && (
<>
<Flex sx={{ w: 'full', flexDirection: 'column' }}>
<Flex sx={{ gap: 4, w: 'full' }}>
<Flex
sx={{
flexDir: 'column',
gap: 3,
w: 'full',
paddingInlineStart: 1,
paddingInlineEnd: isExpanded ? 1 : 0,
pb: 2,
justifyContent: 'space-between',
}}
>
<ParamControlNetWeight
controlNetId={controlNetId}
weight={weight}
mini={!isExpanded}
/>
<ParamControlNetBeginEnd
controlNetId={controlNetId}
beginStepPct={beginStepPct}
endStepPct={endStepPct}
mini={!isExpanded}
/>
</Flex>
{!isExpanded && (
<Flex
sx={{
alignItems: 'center',
justifyContent: 'center',
h: 24,
w: 24,
aspectRatio: '1/1',
}}
>
<ControlNetImagePreview
controlNet={props.controlNet}
height={24}
/>
</Flex>
)}
</Flex>
<ParamControlNetControlMode
controlNetId={controlNetId}
controlMode={controlMode}
/>
<Flex sx={{ w: 'full', flexDirection: 'column' }}>
<Flex sx={{ gap: 4, w: 'full', alignItems: 'center' }}>
<Flex
sx={{
flexDir: 'column',
gap: 3,
h: 28,
w: 'full',
paddingInlineStart: 1,
paddingInlineEnd: isExpanded ? 1 : 0,
pb: 2,
justifyContent: 'space-between',
}}
>
<ParamControlNetWeight controlNetId={controlNetId} />
<ParamControlNetBeginEnd controlNetId={controlNetId} />
</Flex>
{isExpanded && (
<>
<Box mt={2}>
<ControlNetImagePreview
controlNet={props.controlNet}
height={96}
/>
</Box>
<ParamControlNetProcessorSelect
controlNetId={controlNetId}
processorNode={processorNode}
/>
<ControlNetProcessorComponent
controlNetId={controlNetId}
processorNode={processorNode}
/>
<ParamControlNetShouldAutoConfig
controlNetId={controlNetId}
shouldAutoConfig={shouldAutoConfig}
/>
</>
{!isExpanded && (
<Flex
sx={{
alignItems: 'center',
justifyContent: 'center',
h: 28,
w: 28,
aspectRatio: '1/1',
mt: 3,
}}
>
<ControlNetImagePreview controlNetId={controlNetId} height={28} />
</Flex>
)}
</Flex>
<Box mt={2}>
<ParamControlNetControlMode controlNetId={controlNetId} />
</Box>
<ParamControlNetProcessorSelect controlNetId={controlNetId} />
</Flex>
{isExpanded && (
<>
<ControlNetImagePreview controlNetId={controlNetId} height="392px" />
<ParamControlNetShouldAutoConfig controlNetId={controlNetId} />
<ControlNetProcessorComponent controlNetId={controlNetId} />
</>
)}
</Flex>

View File

@ -5,42 +5,57 @@ import {
TypesafeDraggableData,
TypesafeDroppableData,
} from 'app/components/ImageDnd/typesafeDnd';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIDndImage from 'common/components/IAIDndImage';
import { memo, useCallback, useMemo, useState } from 'react';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { PostUploadAction } from 'services/api/thunks/image';
import {
ControlNetConfig,
controlNetImageChanged,
controlNetSelector,
} from '../store/controlNetSlice';
const selector = createSelector(
controlNetSelector,
(controlNet) => {
const { pendingControlImages } = controlNet;
return { pendingControlImages };
},
defaultSelectorOptions
);
import { controlNetImageChanged } from '../store/controlNetSlice';
type Props = {
controlNet: ControlNetConfig;
controlNetId: string;
height: SystemStyleObject['h'];
};
const ControlNetImagePreview = (props: Props) => {
const { height } = props;
const {
controlNetId,
controlImage: controlImageName,
processedControlImage: processedControlImageName,
processorType,
} = props.controlNet;
const { height, controlNetId } = props;
const dispatch = useAppDispatch();
const { pendingControlImages } = useAppSelector(selector);
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlNet }) => {
const { pendingControlImages } = controlNet;
const {
controlImage,
processedControlImage,
processorType,
isEnabled,
} = controlNet.controlNets[controlNetId];
return {
controlImageName: controlImage,
processedControlImageName: processedControlImage,
processorType,
isEnabled,
pendingControlImages,
};
},
defaultSelectorOptions
),
[controlNetId]
);
const {
controlImageName,
processedControlImageName,
processorType,
pendingControlImages,
isEnabled,
} = useAppSelector(selector);
const [isMouseOverImage, setIsMouseOverImage] = useState(false);
@ -110,13 +125,15 @@ const ControlNetImagePreview = (props: Props) => {
h: height,
alignItems: 'center',
justifyContent: 'center',
pointerEvents: isEnabled ? 'auto' : 'none',
opacity: isEnabled ? 1 : 0.5,
}}
>
<IAIDndImage
draggableData={draggableData}
droppableData={droppableData}
imageDTO={controlImage}
isDropDisabled={shouldShowProcessedImage}
isDropDisabled={shouldShowProcessedImage || !isEnabled}
onClickReset={handleResetControlImage}
postUploadAction={postUploadAction}
resetTooltip="Reset Control Image"
@ -140,6 +157,7 @@ const ControlNetImagePreview = (props: Props) => {
droppableData={droppableData}
imageDTO={processedControlImage}
isUploadDisabled={true}
isDropDisabled={!isEnabled}
onClickReset={handleResetControlImage}
resetTooltip="Reset Control Image"
withResetIcon={Boolean(controlImage)}

View File

@ -1,10 +1,13 @@
import { memo } from 'react';
import { RequiredControlNetProcessorNode } from '../store/types';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { memo, useMemo } from 'react';
import CannyProcessor from './processors/CannyProcessor';
import HedProcessor from './processors/HedProcessor';
import LineartProcessor from './processors/LineartProcessor';
import LineartAnimeProcessor from './processors/LineartAnimeProcessor';
import ContentShuffleProcessor from './processors/ContentShuffleProcessor';
import HedProcessor from './processors/HedProcessor';
import LineartAnimeProcessor from './processors/LineartAnimeProcessor';
import LineartProcessor from './processors/LineartProcessor';
import MediapipeFaceProcessor from './processors/MediapipeFaceProcessor';
import MidasDepthProcessor from './processors/MidasDepthProcessor';
import MlsdImageProcessor from './processors/MlsdImageProcessor';
@ -15,23 +18,45 @@ import ZoeDepthProcessor from './processors/ZoeDepthProcessor';
export type ControlNetProcessorProps = {
controlNetId: string;
processorNode: RequiredControlNetProcessorNode;
};
const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
const { controlNetId, processorNode } = props;
const { controlNetId } = props;
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlNet }) => {
const { isEnabled, processorNode } =
controlNet.controlNets[controlNetId];
return { isEnabled, processorNode };
},
defaultSelectorOptions
),
[controlNetId]
);
const { isEnabled, processorNode } = useAppSelector(selector);
if (processorNode.type === 'canny_image_processor') {
return (
<CannyProcessor
controlNetId={controlNetId}
processorNode={processorNode}
isEnabled={isEnabled}
/>
);
}
if (processorNode.type === 'hed_image_processor') {
return (
<HedProcessor controlNetId={controlNetId} processorNode={processorNode} />
<HedProcessor
controlNetId={controlNetId}
processorNode={processorNode}
isEnabled={isEnabled}
/>
);
}
@ -40,6 +65,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<LineartProcessor
controlNetId={controlNetId}
processorNode={processorNode}
isEnabled={isEnabled}
/>
);
}
@ -49,6 +75,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<ContentShuffleProcessor
controlNetId={controlNetId}
processorNode={processorNode}
isEnabled={isEnabled}
/>
);
}
@ -58,6 +85,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<LineartAnimeProcessor
controlNetId={controlNetId}
processorNode={processorNode}
isEnabled={isEnabled}
/>
);
}
@ -67,6 +95,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<MediapipeFaceProcessor
controlNetId={controlNetId}
processorNode={processorNode}
isEnabled={isEnabled}
/>
);
}
@ -76,6 +105,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<MidasDepthProcessor
controlNetId={controlNetId}
processorNode={processorNode}
isEnabled={isEnabled}
/>
);
}
@ -85,6 +115,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<MlsdImageProcessor
controlNetId={controlNetId}
processorNode={processorNode}
isEnabled={isEnabled}
/>
);
}
@ -94,6 +125,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<NormalBaeProcessor
controlNetId={controlNetId}
processorNode={processorNode}
isEnabled={isEnabled}
/>
);
}
@ -103,6 +135,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<OpenposeProcessor
controlNetId={controlNetId}
processorNode={processorNode}
isEnabled={isEnabled}
/>
);
}
@ -112,6 +145,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<PidiProcessor
controlNetId={controlNetId}
processorNode={processorNode}
isEnabled={isEnabled}
/>
);
}
@ -121,6 +155,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<ZoeDepthProcessor
controlNetId={controlNetId}
processorNode={processorNode}
isEnabled={isEnabled}
/>
);
}

View File

@ -1,18 +1,36 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISwitch from 'common/components/IAISwitch';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import { controlNetAutoConfigToggled } from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { memo, useCallback, useMemo } from 'react';
type Props = {
controlNetId: string;
shouldAutoConfig: boolean;
};
const ParamControlNetShouldAutoConfig = (props: Props) => {
const { controlNetId, shouldAutoConfig } = props;
const { controlNetId } = props;
const dispatch = useAppDispatch();
const isReady = useIsReadyToInvoke();
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlNet }) => {
const { isEnabled, shouldAutoConfig } =
controlNet.controlNets[controlNetId];
return { isEnabled, shouldAutoConfig };
},
defaultSelectorOptions
),
[controlNetId]
);
const { isEnabled, shouldAutoConfig } = useAppSelector(selector);
const isBusy = useAppSelector(selectIsBusy);
const handleShouldAutoConfigChanged = useCallback(() => {
dispatch(controlNetAutoConfigToggled({ controlNetId }));
}, [controlNetId, dispatch]);
@ -23,7 +41,7 @@ const ParamControlNetShouldAutoConfig = (props: Props) => {
aria-label="Auto configure processor"
isChecked={shouldAutoConfig}
onChange={handleShouldAutoConfigChanged}
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
);
};

View File

@ -1,5 +1,4 @@
import {
ChakraProps,
FormControl,
FormLabel,
HStack,
@ -10,34 +9,41 @@ import {
RangeSliderTrack,
Tooltip,
} from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import {
controlNetBeginStepPctChanged,
controlNetEndStepPctChanged,
} from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
const SLIDER_MARK_STYLES: ChakraProps['sx'] = {
mt: 1.5,
fontSize: '2xs',
fontWeight: '500',
color: 'base.400',
};
import { memo, useCallback, useMemo } from 'react';
type Props = {
controlNetId: string;
beginStepPct: number;
endStepPct: number;
mini?: boolean;
};
const formatPct = (v: number) => `${Math.round(v * 100)}%`;
const ParamControlNetBeginEnd = (props: Props) => {
const { controlNetId, beginStepPct, mini = false, endStepPct } = props;
const { controlNetId } = props;
const dispatch = useAppDispatch();
const { t } = useTranslation();
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlNet }) => {
const { beginStepPct, endStepPct, isEnabled } =
controlNet.controlNets[controlNetId];
return { beginStepPct, endStepPct, isEnabled };
},
defaultSelectorOptions
),
[controlNetId]
);
const { beginStepPct, endStepPct, isEnabled } = useAppSelector(selector);
const handleStepPctChanged = useCallback(
(v: number[]) => {
@ -55,7 +61,7 @@ const ParamControlNetBeginEnd = (props: Props) => {
}, [controlNetId, dispatch]);
return (
<FormControl>
<FormControl isDisabled={!isEnabled}>
<FormLabel>Begin / End Step Percentage</FormLabel>
<HStack w="100%" gap={2} alignItems="center">
<RangeSlider
@ -66,6 +72,7 @@ const ParamControlNetBeginEnd = (props: Props) => {
max={1}
step={0.01}
minStepsBetweenThumbs={5}
isDisabled={!isEnabled}
>
<RangeSliderTrack>
<RangeSliderFilledTrack />
@ -76,38 +83,33 @@ const ParamControlNetBeginEnd = (props: Props) => {
<Tooltip label={formatPct(endStepPct)} placement="top" hasArrow>
<RangeSliderThumb index={1} />
</Tooltip>
{!mini && (
<>
<RangeSliderMark
value={0}
sx={{
insetInlineStart: '0 !important',
insetInlineEnd: 'unset !important',
...SLIDER_MARK_STYLES,
}}
>
0%
</RangeSliderMark>
<RangeSliderMark
value={0.5}
sx={{
...SLIDER_MARK_STYLES,
}}
>
50%
</RangeSliderMark>
<RangeSliderMark
value={1}
sx={{
insetInlineStart: 'unset !important',
insetInlineEnd: '0 !important',
...SLIDER_MARK_STYLES,
}}
>
100%
</RangeSliderMark>
</>
)}
<RangeSliderMark
value={0}
sx={{
insetInlineStart: '0 !important',
insetInlineEnd: 'unset !important',
}}
>
0%
</RangeSliderMark>
<RangeSliderMark
value={0.5}
sx={{
insetInlineStart: '50% !important',
transform: 'translateX(-50%)',
}}
>
50%
</RangeSliderMark>
<RangeSliderMark
value={1}
sx={{
insetInlineStart: 'unset !important',
insetInlineEnd: '0 !important',
}}
>
100%
</RangeSliderMark>
</RangeSlider>
</HStack>
</FormControl>

View File

@ -1,15 +1,17 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import {
ControlModes,
controlNetControlModeChanged,
} from 'features/controlNet/store/controlNetSlice';
import { useCallback } from 'react';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
type ParamControlNetControlModeProps = {
controlNetId: string;
controlMode: string;
};
const CONTROL_MODE_DATA = [
@ -22,8 +24,23 @@ const CONTROL_MODE_DATA = [
export default function ParamControlNetControlMode(
props: ParamControlNetControlModeProps
) {
const { controlNetId, controlMode = false } = props;
const { controlNetId } = props;
const dispatch = useAppDispatch();
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlNet }) => {
const { controlMode, isEnabled } =
controlNet.controlNets[controlNetId];
return { controlMode, isEnabled };
},
defaultSelectorOptions
),
[controlNetId]
);
const { controlMode, isEnabled } = useAppSelector(selector);
const { t } = useTranslation();
@ -36,7 +53,8 @@ export default function ParamControlNetControlMode(
return (
<IAIMantineSelect
label={t('parameters.controlNetControlMode')}
disabled={!isEnabled}
label="Control Mode"
data={CONTROL_MODE_DATA}
value={String(controlMode)}
onChange={handleControlModeChange}

View File

@ -29,6 +29,9 @@ const ParamControlNetFeatureToggle = () => {
label="Enable ControlNet"
isChecked={isEnabled}
onChange={handleChange}
formControlProps={{
width: '100%',
}}
/>
);
};

View File

@ -1,28 +0,0 @@
import { useAppDispatch } from 'app/store/storeHooks';
import IAISwitch from 'common/components/IAISwitch';
import { controlNetToggled } from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react';
type ParamControlNetIsEnabledProps = {
controlNetId: string;
isEnabled: boolean;
};
const ParamControlNetIsEnabled = (props: ParamControlNetIsEnabledProps) => {
const { controlNetId, isEnabled } = props;
const dispatch = useAppDispatch();
const handleIsEnabledChanged = useCallback(() => {
dispatch(controlNetToggled({ controlNetId }));
}, [dispatch, controlNetId]);
return (
<IAISwitch
label="Enabled"
isChecked={isEnabled}
onChange={handleIsEnabledChanged}
/>
);
};
export default memo(ParamControlNetIsEnabled);

View File

@ -1,36 +0,0 @@
import { useAppDispatch } from 'app/store/storeHooks';
import IAIFullCheckbox from 'common/components/IAIFullCheckbox';
import IAISwitch from 'common/components/IAISwitch';
import {
controlNetToggled,
isControlNetImagePreprocessedToggled,
} from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react';
type ParamControlNetIsEnabledProps = {
controlNetId: string;
isControlImageProcessed: boolean;
};
const ParamControlNetIsEnabled = (props: ParamControlNetIsEnabledProps) => {
const { controlNetId, isControlImageProcessed } = props;
const dispatch = useAppDispatch();
const handleIsControlImageProcessedToggled = useCallback(() => {
dispatch(
isControlNetImagePreprocessedToggled({
controlNetId,
})
);
}, [controlNetId, dispatch]);
return (
<IAISwitch
label="Preprocess"
isChecked={isControlImageProcessed}
onChange={handleIsControlImageProcessedToggled}
/>
);
};
export default memo(ParamControlNetIsEnabled);

View File

@ -1,59 +1,118 @@
import { SelectItem } from '@mantine/core';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSearchableSelect, {
IAISelectDataType,
} from 'common/components/IAIMantineSearchableSelect';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import {
CONTROLNET_MODELS,
ControlNetModelName,
} from 'features/controlNet/store/constants';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import { controlNetModelChanged } from 'features/controlNet/store/controlNetSlice';
import { configSelector } from 'features/system/store/configSelectors';
import { map } from 'lodash-es';
import { memo, useCallback } from 'react';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToControlNetModelParam } from 'features/parameters/util/modelIdToControlNetModelParam';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
type ParamControlNetModelProps = {
controlNetId: string;
model: ControlNetModelName;
};
const selector = createSelector(configSelector, (config) => {
const controlNetModels: IAISelectDataType[] = map(CONTROLNET_MODELS, (m) => ({
label: m.label,
value: m.type,
})).filter(
(d) =>
!config.sd.disabledControlNetModels.includes(
d.value as ControlNetModelName
)
const ParamControlNetModel = (props: ParamControlNetModelProps) => {
const { controlNetId } = props;
const dispatch = useAppDispatch();
const isBusy = useAppSelector(selectIsBusy);
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ generation, controlNet }) => {
const { model } = generation;
const controlNetModel = controlNet.controlNets[controlNetId]?.model;
const isEnabled = controlNet.controlNets[controlNetId]?.isEnabled;
return { mainModel: model, controlNetModel, isEnabled };
},
defaultSelectorOptions
),
[controlNetId]
);
return controlNetModels;
});
const { mainModel, controlNetModel, isEnabled } = useAppSelector(selector);
const ParamControlNetModel = (props: ParamControlNetModelProps) => {
const { controlNetId, model } = props;
const controlNetModels = useAppSelector(selector);
const dispatch = useAppDispatch();
const isReady = useIsReadyToInvoke();
const { data: controlNetModels } = useGetControlNetModelsQuery();
const data = useMemo(() => {
if (!controlNetModels) {
return [];
}
const data: SelectItem[] = [];
forEach(controlNetModels.entities, (model, id) => {
if (!model) {
return;
}
const disabled = model?.base_model !== mainModel?.base_model;
data.push({
value: id,
label: model.model_name,
group: MODEL_TYPE_MAP[model.base_model],
disabled,
tooltip: disabled
? `Incompatible base model: ${model.base_model}`
: undefined,
});
});
return data;
}, [controlNetModels, mainModel?.base_model]);
// grab the full model entity from the RTK Query cache
const selectedModel = useMemo(
() =>
controlNetModels?.entities[
`${controlNetModel?.base_model}/controlnet/${controlNetModel?.model_name}`
] ?? null,
[
controlNetModel?.base_model,
controlNetModel?.model_name,
controlNetModels?.entities,
]
);
const handleModelChanged = useCallback(
(val: string | null) => {
// TODO: do not cast
const model = val as ControlNetModelName;
dispatch(controlNetModelChanged({ controlNetId, model }));
(v: string | null) => {
if (!v) {
return;
}
const newControlNetModel = modelIdToControlNetModelParam(v);
if (!newControlNetModel) {
return;
}
dispatch(
controlNetModelChanged({ controlNetId, model: newControlNetModel })
);
},
[controlNetId, dispatch]
);
return (
<IAIMantineSearchableSelect
data={controlNetModels}
value={model}
itemComponent={IAIMantineSelectItemWithTooltip}
data={data}
error={
!selectedModel || mainModel?.base_model !== selectedModel.base_model
}
placeholder="Select a model"
value={selectedModel?.id ?? null}
onChange={handleModelChanged}
disabled={!isReady}
tooltip={model}
disabled={isBusy || !isEnabled}
tooltip={selectedModel?.description}
/>
);
};

View File

@ -1,24 +1,22 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSearchableSelect, {
IAISelectDataType,
} from 'common/components/IAIMantineSearchableSelect';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import { configSelector } from 'features/system/store/configSelectors';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { map } from 'lodash-es';
import { memo, useCallback } from 'react';
import { memo, useCallback, useMemo } from 'react';
import { CONTROLNET_PROCESSORS } from '../../store/constants';
import { controlNetProcessorTypeChanged } from '../../store/controlNetSlice';
import {
ControlNetProcessorNode,
ControlNetProcessorType,
} from '../../store/types';
import { ControlNetProcessorType } from '../../store/types';
import { FormControl, FormLabel } from '@chakra-ui/react';
type ParamControlNetProcessorSelectProps = {
controlNetId: string;
processorNode: ControlNetProcessorNode;
};
const selector = createSelector(
@ -54,10 +52,24 @@ const selector = createSelector(
const ParamControlNetProcessorSelect = (
props: ParamControlNetProcessorSelectProps
) => {
const { controlNetId, processorNode } = props;
const dispatch = useAppDispatch();
const isReady = useIsReadyToInvoke();
const { controlNetId } = props;
const processorNodeSelector = useMemo(
() =>
createSelector(
stateSelector,
({ controlNet }) => {
const { isEnabled, processorNode } =
controlNet.controlNets[controlNetId];
return { isEnabled, processorNode };
},
defaultSelectorOptions
),
[controlNetId]
);
const isBusy = useAppSelector(selectIsBusy);
const controlNetProcessors = useAppSelector(selector);
const { isEnabled, processorNode } = useAppSelector(processorNodeSelector);
const handleProcessorTypeChanged = useCallback(
(v: string | null) => {
@ -77,7 +89,7 @@ const ParamControlNetProcessorSelect = (
value={processorNode.type ?? 'canny_image_processor'}
data={controlNetProcessors}
onChange={handleProcessorTypeChanged}
disabled={!isReady}
disabled={isBusy || !isEnabled}
/>
);
};

View File

@ -1,18 +1,32 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISlider from 'common/components/IAISlider';
import { controlNetWeightChanged } from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react';
import { memo, useCallback, useMemo } from 'react';
type ParamControlNetWeightProps = {
controlNetId: string;
weight: number;
mini?: boolean;
};
const ParamControlNetWeight = (props: ParamControlNetWeightProps) => {
const { controlNetId, weight, mini = false } = props;
const { controlNetId } = props;
const dispatch = useAppDispatch();
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlNet }) => {
const { weight, isEnabled } = controlNet.controlNets[controlNetId];
return { weight, isEnabled };
},
defaultSelectorOptions
),
[controlNetId]
);
const { weight, isEnabled } = useAppSelector(selector);
const handleWeightChanged = useCallback(
(weight: number) => {
dispatch(controlNetWeightChanged({ controlNetId, weight }));
@ -22,15 +36,15 @@ const ParamControlNetWeight = (props: ParamControlNetWeightProps) => {
return (
<IAISlider
isDisabled={!isEnabled}
label={'Weight'}
sliderFormLabelProps={{ pb: 2 }}
value={weight}
onChange={handleWeightChanged}
min={-1}
max={1}
min={0}
max={2}
step={0.01}
withSliderMarks={!mini}
sliderMarks={[-1, 0, 1]}
withSliderMarks
sliderMarks={[0, 1, 2]}
/>
);
};

View File

@ -1,22 +1,25 @@
import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredCannyImageProcessorInvocation } from 'features/controlNet/store/types';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.canny_image_processor.default;
const DEFAULTS = CONTROLNET_PROCESSORS.canny_image_processor
.default as RequiredCannyImageProcessorInvocation;
type CannyProcessorProps = {
controlNetId: string;
processorNode: RequiredCannyImageProcessorInvocation;
isEnabled: boolean;
};
const CannyProcessor = (props: CannyProcessorProps) => {
const { controlNetId, processorNode } = props;
const { controlNetId, processorNode, isEnabled } = props;
const { low_threshold, high_threshold } = processorNode;
const isReady = useIsReadyToInvoke();
const isBusy = useAppSelector(selectIsBusy);
const processorChanged = useProcessorNodeChanged();
const handleLowThresholdChanged = useCallback(
@ -48,7 +51,7 @@ const CannyProcessor = (props: CannyProcessorProps) => {
return (
<ProcessorWrapper>
<IAISlider
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
label="Low Threshold"
value={low_threshold}
onChange={handleLowThresholdChanged}
@ -60,7 +63,7 @@ const CannyProcessor = (props: CannyProcessorProps) => {
withSliderMarks
/>
<IAISlider
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
label="High Threshold"
value={high_threshold}
onChange={handleHighThresholdChanged}

View File

@ -4,20 +4,23 @@ import { RequiredContentShuffleImageProcessorInvocation } from 'features/control
import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import { useAppSelector } from 'app/store/storeHooks';
import { selectIsBusy } from 'features/system/store/systemSelectors';
const DEFAULTS = CONTROLNET_PROCESSORS.content_shuffle_image_processor.default;
const DEFAULTS = CONTROLNET_PROCESSORS.content_shuffle_image_processor
.default as RequiredContentShuffleImageProcessorInvocation;
type Props = {
controlNetId: string;
processorNode: RequiredContentShuffleImageProcessorInvocation;
isEnabled: boolean;
};
const ContentShuffleProcessor = (props: Props) => {
const { controlNetId, processorNode } = props;
const { controlNetId, processorNode, isEnabled } = props;
const { image_resolution, detect_resolution, w, h, f } = processorNode;
const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke();
const isBusy = useAppSelector(selectIsBusy);
const handleDetectResolutionChanged = useCallback(
(v: number) => {
@ -96,7 +99,7 @@ const ContentShuffleProcessor = (props: Props) => {
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
<IAISlider
label="Image Resolution"
@ -108,7 +111,7 @@ const ContentShuffleProcessor = (props: Props) => {
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
<IAISlider
label="W"
@ -120,7 +123,7 @@ const ContentShuffleProcessor = (props: Props) => {
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
<IAISlider
label="H"
@ -132,7 +135,7 @@ const ContentShuffleProcessor = (props: Props) => {
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
<IAISlider
label="F"
@ -144,7 +147,7 @@ const ContentShuffleProcessor = (props: Props) => {
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
</ProcessorWrapper>
);

View File

@ -1,25 +1,29 @@
import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import IAISwitch from 'common/components/IAISwitch';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredHedImageProcessorInvocation } from 'features/controlNet/store/types';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { ChangeEvent, memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.hed_image_processor.default;
const DEFAULTS = CONTROLNET_PROCESSORS.hed_image_processor
.default as RequiredHedImageProcessorInvocation;
type HedProcessorProps = {
controlNetId: string;
processorNode: RequiredHedImageProcessorInvocation;
isEnabled: boolean;
};
const HedPreprocessor = (props: HedProcessorProps) => {
const {
controlNetId,
processorNode: { detect_resolution, image_resolution, scribble },
isEnabled,
} = props;
const isReady = useIsReadyToInvoke();
const isBusy = useAppSelector(selectIsBusy);
const processorChanged = useProcessorNodeChanged();
const handleDetectResolutionChanged = useCallback(
@ -67,7 +71,7 @@ const HedPreprocessor = (props: HedProcessorProps) => {
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
<IAISlider
label="Image Resolution"
@ -79,13 +83,13 @@ const HedPreprocessor = (props: HedProcessorProps) => {
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
<IAISwitch
label="Scribble"
isChecked={scribble}
onChange={handleScribbleChanged}
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
</ProcessorWrapper>
);

View File

@ -1,23 +1,26 @@
import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredLineartAnimeImageProcessorInvocation } from 'features/controlNet/store/types';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.lineart_anime_image_processor.default;
const DEFAULTS = CONTROLNET_PROCESSORS.lineart_anime_image_processor
.default as RequiredLineartAnimeImageProcessorInvocation;
type Props = {
controlNetId: string;
processorNode: RequiredLineartAnimeImageProcessorInvocation;
isEnabled: boolean;
};
const LineartAnimeProcessor = (props: Props) => {
const { controlNetId, processorNode } = props;
const { controlNetId, processorNode, isEnabled } = props;
const { image_resolution, detect_resolution } = processorNode;
const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke();
const isBusy = useAppSelector(selectIsBusy);
const handleDetectResolutionChanged = useCallback(
(v: number) => {
@ -57,7 +60,7 @@ const LineartAnimeProcessor = (props: Props) => {
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
<IAISlider
label="Image Resolution"
@ -69,7 +72,7 @@ const LineartAnimeProcessor = (props: Props) => {
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
</ProcessorWrapper>
);

View File

@ -1,24 +1,27 @@
import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import IAISwitch from 'common/components/IAISwitch';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredLineartImageProcessorInvocation } from 'features/controlNet/store/types';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { ChangeEvent, memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.lineart_image_processor.default;
const DEFAULTS = CONTROLNET_PROCESSORS.lineart_image_processor
.default as RequiredLineartImageProcessorInvocation;
type LineartProcessorProps = {
controlNetId: string;
processorNode: RequiredLineartImageProcessorInvocation;
isEnabled: boolean;
};
const LineartProcessor = (props: LineartProcessorProps) => {
const { controlNetId, processorNode } = props;
const { controlNetId, processorNode, isEnabled } = props;
const { image_resolution, detect_resolution, coarse } = processorNode;
const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke();
const isBusy = useAppSelector(selectIsBusy);
const handleDetectResolutionChanged = useCallback(
(v: number) => {
@ -65,7 +68,7 @@ const LineartProcessor = (props: LineartProcessorProps) => {
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
<IAISlider
label="Image Resolution"
@ -77,13 +80,13 @@ const LineartProcessor = (props: LineartProcessorProps) => {
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
<IAISwitch
label="Coarse"
isChecked={coarse}
onChange={handleCoarseChanged}
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
</ProcessorWrapper>
);

View File

@ -1,23 +1,26 @@
import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredMediapipeFaceProcessorInvocation } from 'features/controlNet/store/types';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.mediapipe_face_processor.default;
const DEFAULTS = CONTROLNET_PROCESSORS.mediapipe_face_processor
.default as RequiredMediapipeFaceProcessorInvocation;
type Props = {
controlNetId: string;
processorNode: RequiredMediapipeFaceProcessorInvocation;
isEnabled: boolean;
};
const MediapipeFaceProcessor = (props: Props) => {
const { controlNetId, processorNode } = props;
const { controlNetId, processorNode, isEnabled } = props;
const { max_faces, min_confidence } = processorNode;
const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke();
const isBusy = useAppSelector(selectIsBusy);
const handleMaxFacesChanged = useCallback(
(v: number) => {
@ -53,7 +56,7 @@ const MediapipeFaceProcessor = (props: Props) => {
max={20}
withInput
withSliderMarks
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
<IAISlider
label="Min Confidence"
@ -66,7 +69,7 @@ const MediapipeFaceProcessor = (props: Props) => {
step={0.01}
withInput
withSliderMarks
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
</ProcessorWrapper>
);

View File

@ -1,23 +1,26 @@
import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredMidasDepthImageProcessorInvocation } from 'features/controlNet/store/types';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.midas_depth_image_processor.default;
const DEFAULTS = CONTROLNET_PROCESSORS.midas_depth_image_processor
.default as RequiredMidasDepthImageProcessorInvocation;
type Props = {
controlNetId: string;
processorNode: RequiredMidasDepthImageProcessorInvocation;
isEnabled: boolean;
};
const MidasDepthProcessor = (props: Props) => {
const { controlNetId, processorNode } = props;
const { controlNetId, processorNode, isEnabled } = props;
const { a_mult, bg_th } = processorNode;
const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke();
const isBusy = useAppSelector(selectIsBusy);
const handleAMultChanged = useCallback(
(v: number) => {
@ -54,7 +57,7 @@ const MidasDepthProcessor = (props: Props) => {
step={0.01}
withInput
withSliderMarks
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
<IAISlider
label="bg_th"
@ -67,7 +70,7 @@ const MidasDepthProcessor = (props: Props) => {
step={0.01}
withInput
withSliderMarks
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
</ProcessorWrapper>
);

View File

@ -1,23 +1,26 @@
import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredMlsdImageProcessorInvocation } from 'features/controlNet/store/types';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.mlsd_image_processor.default;
const DEFAULTS = CONTROLNET_PROCESSORS.mlsd_image_processor
.default as RequiredMlsdImageProcessorInvocation;
type Props = {
controlNetId: string;
processorNode: RequiredMlsdImageProcessorInvocation;
isEnabled: boolean;
};
const MlsdImageProcessor = (props: Props) => {
const { controlNetId, processorNode } = props;
const { controlNetId, processorNode, isEnabled } = props;
const { image_resolution, detect_resolution, thr_d, thr_v } = processorNode;
const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke();
const isBusy = useAppSelector(selectIsBusy);
const handleDetectResolutionChanged = useCallback(
(v: number) => {
@ -79,7 +82,7 @@ const MlsdImageProcessor = (props: Props) => {
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
<IAISlider
label="Image Resolution"
@ -91,7 +94,7 @@ const MlsdImageProcessor = (props: Props) => {
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
<IAISlider
label="W"
@ -104,7 +107,7 @@ const MlsdImageProcessor = (props: Props) => {
step={0.01}
withInput
withSliderMarks
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
<IAISlider
label="H"
@ -117,7 +120,7 @@ const MlsdImageProcessor = (props: Props) => {
step={0.01}
withInput
withSliderMarks
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
</ProcessorWrapper>
);

View File

@ -1,23 +1,26 @@
import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredNormalbaeImageProcessorInvocation } from 'features/controlNet/store/types';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.normalbae_image_processor.default;
const DEFAULTS = CONTROLNET_PROCESSORS.normalbae_image_processor
.default as RequiredNormalbaeImageProcessorInvocation;
type Props = {
controlNetId: string;
processorNode: RequiredNormalbaeImageProcessorInvocation;
isEnabled: boolean;
};
const NormalBaeProcessor = (props: Props) => {
const { controlNetId, processorNode } = props;
const { controlNetId, processorNode, isEnabled } = props;
const { image_resolution, detect_resolution } = processorNode;
const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke();
const isBusy = useAppSelector(selectIsBusy);
const handleDetectResolutionChanged = useCallback(
(v: number) => {
@ -57,7 +60,7 @@ const NormalBaeProcessor = (props: Props) => {
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
<IAISlider
label="Image Resolution"
@ -69,7 +72,7 @@ const NormalBaeProcessor = (props: Props) => {
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
</ProcessorWrapper>
);

View File

@ -1,24 +1,27 @@
import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import IAISwitch from 'common/components/IAISwitch';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredOpenposeImageProcessorInvocation } from 'features/controlNet/store/types';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { ChangeEvent, memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.openpose_image_processor.default;
const DEFAULTS = CONTROLNET_PROCESSORS.openpose_image_processor
.default as RequiredOpenposeImageProcessorInvocation;
type Props = {
controlNetId: string;
processorNode: RequiredOpenposeImageProcessorInvocation;
isEnabled: boolean;
};
const OpenposeProcessor = (props: Props) => {
const { controlNetId, processorNode } = props;
const { controlNetId, processorNode, isEnabled } = props;
const { image_resolution, detect_resolution, hand_and_face } = processorNode;
const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke();
const isBusy = useAppSelector(selectIsBusy);
const handleDetectResolutionChanged = useCallback(
(v: number) => {
@ -65,7 +68,7 @@ const OpenposeProcessor = (props: Props) => {
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
<IAISlider
label="Image Resolution"
@ -77,13 +80,13 @@ const OpenposeProcessor = (props: Props) => {
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
<IAISwitch
label="Hand and Face"
isChecked={hand_and_face}
onChange={handleHandAndFaceChanged}
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
</ProcessorWrapper>
);

View File

@ -1,24 +1,27 @@
import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import IAISwitch from 'common/components/IAISwitch';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredPidiImageProcessorInvocation } from 'features/controlNet/store/types';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { ChangeEvent, memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.pidi_image_processor.default;
const DEFAULTS = CONTROLNET_PROCESSORS.pidi_image_processor
.default as RequiredPidiImageProcessorInvocation;
type Props = {
controlNetId: string;
processorNode: RequiredPidiImageProcessorInvocation;
isEnabled: boolean;
};
const PidiProcessor = (props: Props) => {
const { controlNetId, processorNode } = props;
const { controlNetId, processorNode, isEnabled } = props;
const { image_resolution, detect_resolution, scribble, safe } = processorNode;
const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke();
const isBusy = useAppSelector(selectIsBusy);
const handleDetectResolutionChanged = useCallback(
(v: number) => {
@ -72,7 +75,7 @@ const PidiProcessor = (props: Props) => {
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
<IAISlider
label="Image Resolution"
@ -84,7 +87,7 @@ const PidiProcessor = (props: Props) => {
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
<IAISwitch
label="Scribble"
@ -95,7 +98,7 @@ const PidiProcessor = (props: Props) => {
label="Safe"
isChecked={safe}
onChange={handleSafeChanged}
isDisabled={!isReady}
isDisabled={isBusy || !isEnabled}
/>
</ProcessorWrapper>
);

View File

@ -4,6 +4,7 @@ import { memo } from 'react';
type Props = {
controlNetId: string;
processorNode: RequiredZoeDepthImageProcessorInvocation;
isEnabled: boolean;
};
const ZoeDepthProcessor = (props: Props) => {

View File

@ -173,91 +173,17 @@ export const CONTROLNET_PROCESSORS: ControlNetProcessorsDict = {
},
};
type ControlNetModelsDict = Record<string, ControlNetModel>;
type ControlNetModel = {
type: string;
label: string;
description?: string;
defaultProcessor?: ControlNetProcessorType;
export const CONTROLNET_MODEL_DEFAULT_PROCESSORS: {
[key: string]: ControlNetProcessorType;
} = {
canny: 'canny_image_processor',
mlsd: 'mlsd_image_processor',
depth: 'midas_depth_image_processor',
bae: 'normalbae_image_processor',
lineart: 'lineart_image_processor',
lineart_anime: 'lineart_anime_image_processor',
softedge: 'hed_image_processor',
shuffle: 'content_shuffle_image_processor',
openpose: 'openpose_image_processor',
mediapipe: 'mediapipe_face_processor',
};
export const CONTROLNET_MODELS: ControlNetModelsDict = {
'lllyasviel/control_v11p_sd15_canny': {
type: 'lllyasviel/control_v11p_sd15_canny',
label: 'Canny',
defaultProcessor: 'canny_image_processor',
},
'lllyasviel/control_v11p_sd15_inpaint': {
type: 'lllyasviel/control_v11p_sd15_inpaint',
label: 'Inpaint',
defaultProcessor: 'none',
},
'lllyasviel/control_v11p_sd15_mlsd': {
type: 'lllyasviel/control_v11p_sd15_mlsd',
label: 'M-LSD',
defaultProcessor: 'mlsd_image_processor',
},
'lllyasviel/control_v11f1p_sd15_depth': {
type: 'lllyasviel/control_v11f1p_sd15_depth',
label: 'Depth',
defaultProcessor: 'midas_depth_image_processor',
},
'lllyasviel/control_v11p_sd15_normalbae': {
type: 'lllyasviel/control_v11p_sd15_normalbae',
label: 'Normal Map (BAE)',
defaultProcessor: 'normalbae_image_processor',
},
'lllyasviel/control_v11p_sd15_seg': {
type: 'lllyasviel/control_v11p_sd15_seg',
label: 'Segmentation',
defaultProcessor: 'none',
},
'lllyasviel/control_v11p_sd15_lineart': {
type: 'lllyasviel/control_v11p_sd15_lineart',
label: 'Lineart',
defaultProcessor: 'lineart_image_processor',
},
'lllyasviel/control_v11p_sd15s2_lineart_anime': {
type: 'lllyasviel/control_v11p_sd15s2_lineart_anime',
label: 'Lineart Anime',
defaultProcessor: 'lineart_anime_image_processor',
},
'lllyasviel/control_v11p_sd15_scribble': {
type: 'lllyasviel/control_v11p_sd15_scribble',
label: 'Scribble',
defaultProcessor: 'none',
},
'lllyasviel/control_v11p_sd15_softedge': {
type: 'lllyasviel/control_v11p_sd15_softedge',
label: 'Soft Edge',
defaultProcessor: 'hed_image_processor',
},
'lllyasviel/control_v11e_sd15_shuffle': {
type: 'lllyasviel/control_v11e_sd15_shuffle',
label: 'Content Shuffle',
defaultProcessor: 'content_shuffle_image_processor',
},
'lllyasviel/control_v11p_sd15_openpose': {
type: 'lllyasviel/control_v11p_sd15_openpose',
label: 'Openpose',
defaultProcessor: 'openpose_image_processor',
},
'lllyasviel/control_v11f1e_sd15_tile': {
type: 'lllyasviel/control_v11f1e_sd15_tile',
label: 'Tile (experimental)',
defaultProcessor: 'none',
},
'lllyasviel/control_v11e_sd15_ip2p': {
type: 'lllyasviel/control_v11e_sd15_ip2p',
label: 'Pix2Pix (experimental)',
defaultProcessor: 'none',
},
'CrucibleAI/ControlNetMediaPipeFace': {
type: 'CrucibleAI/ControlNetMediaPipeFace',
label: 'Mediapipe Face',
defaultProcessor: 'mediapipe_face_processor',
},
};
export type ControlNetModelName = keyof typeof CONTROLNET_MODELS;

View File

@ -1,22 +1,20 @@
import { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import { ImageDTO } from 'services/api/types';
import { ControlNetModelParam } from 'features/parameters/types/parameterSchemas';
import { cloneDeep, forEach } from 'lodash-es';
import { imageDeleted } from 'services/api/thunks/image';
import { isAnySessionRejected } from 'services/api/thunks/session';
import { appSocketInvocationError } from 'services/events/actions';
import { controlNetImageProcessed } from './actions';
import {
CONTROLNET_MODEL_DEFAULT_PROCESSORS,
CONTROLNET_PROCESSORS,
} from './constants';
import {
ControlNetProcessorType,
RequiredCannyImageProcessorInvocation,
RequiredControlNetProcessorNode,
} from './types';
import {
CONTROLNET_MODELS,
CONTROLNET_PROCESSORS,
ControlNetModelName,
} from './constants';
import { controlNetImageProcessed } from './actions';
import { imageDeleted, imageUrlsReceived } from 'services/api/thunks/image';
import { forEach } from 'lodash-es';
import { isAnySessionRejected } from 'services/api/thunks/session';
import { appSocketInvocationError } from 'services/events/actions';
export type ControlModes =
| 'balanced'
@ -26,7 +24,7 @@ export type ControlModes =
export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
isEnabled: true,
model: CONTROLNET_MODELS['lllyasviel/control_v11p_sd15_canny'].type,
model: null,
weight: 1,
beginStepPct: 0,
endStepPct: 1,
@ -42,7 +40,7 @@ export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
export type ControlNetConfig = {
controlNetId: string;
isEnabled: boolean;
model: ControlNetModelName;
model: ControlNetModelParam | null;
weight: number;
beginStepPct: number;
endStepPct: number;
@ -86,6 +84,19 @@ export const controlNetSlice = createSlice({
controlNetId,
};
},
controlNetDuplicated: (
state,
action: PayloadAction<{
sourceControlNetId: string;
newControlNetId: string;
}>
) => {
const { sourceControlNetId, newControlNetId } = action.payload;
const newControlnet = cloneDeep(state.controlNets[sourceControlNetId]);
newControlnet.controlNetId = newControlNetId;
state.controlNets[newControlNetId] = newControlnet;
},
controlNetAddedFromImage: (
state,
action: PayloadAction<{ controlNetId: string; controlImage: string }>
@ -147,7 +158,7 @@ export const controlNetSlice = createSlice({
state,
action: PayloadAction<{
controlNetId: string;
model: ControlNetModelName;
model: ControlNetModelParam;
}>
) => {
const { controlNetId, model } = action.payload;
@ -155,7 +166,15 @@ export const controlNetSlice = createSlice({
state.controlNets[controlNetId].processedControlImage = null;
if (state.controlNets[controlNetId].shouldAutoConfig) {
const processorType = CONTROLNET_MODELS[model].defaultProcessor;
let processorType: ControlNetProcessorType | undefined = undefined;
for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
if (model.model_name.includes(modelSubstring)) {
processorType = CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring];
break;
}
}
if (processorType) {
state.controlNets[controlNetId].processorType = processorType;
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[
@ -241,9 +260,19 @@ export const controlNetSlice = createSlice({
if (newShouldAutoConfig) {
// manage the processor for the user
const processorType =
CONTROLNET_MODELS[state.controlNets[controlNetId].model]
.defaultProcessor;
let processorType: ControlNetProcessorType | undefined = undefined;
for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
if (
state.controlNets[controlNetId].model?.model_name.includes(
modelSubstring
)
) {
processorType = CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring];
break;
}
}
if (processorType) {
state.controlNets[controlNetId].processorType = processorType;
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[
@ -272,7 +301,8 @@ export const controlNetSlice = createSlice({
});
builder.addCase(imageDeleted.pending, (state, action) => {
// Preemptively remove the image from the gallery
// Preemptively remove the image from all controlnets
// TODO: doesn't the imageusage stuff do this for us?
const { image_name } = action.meta.arg;
forEach(state.controlNets, (c) => {
if (c.controlImage === image_name) {
@ -285,21 +315,6 @@ export const controlNetSlice = createSlice({
});
});
// builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
// const { image_name, image_url, thumbnail_url } = action.payload;
// forEach(state.controlNets, (c) => {
// if (c.controlImage?.image_name === image_name) {
// c.controlImage.image_url = image_url;
// c.controlImage.thumbnail_url = thumbnail_url;
// }
// if (c.processedControlImage?.image_name === image_name) {
// c.processedControlImage.image_url = image_url;
// c.processedControlImage.thumbnail_url = thumbnail_url;
// }
// });
// });
builder.addCase(appSocketInvocationError, (state, action) => {
state.pendingControlImages = [];
});
@ -313,6 +328,7 @@ export const controlNetSlice = createSlice({
export const {
isControlNetEnabledToggled,
controlNetAdded,
controlNetDuplicated,
controlNetAddedFromImage,
controlNetRemoved,
controlNetImageChanged,

View File

@ -7,6 +7,7 @@ import ClipInputFieldComponent from './fields/ClipInputFieldComponent';
import ColorInputFieldComponent from './fields/ColorInputFieldComponent';
import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent';
import ControlInputFieldComponent from './fields/ControlInputFieldComponent';
import ControlNetModelInputFieldComponent from './fields/ControlNetModelInputFieldComponent';
import EnumInputFieldComponent from './fields/EnumInputFieldComponent';
import ImageCollectionInputFieldComponent from './fields/ImageCollectionInputFieldComponent';
import ImageInputFieldComponent from './fields/ImageInputFieldComponent';
@ -174,6 +175,16 @@ const InputFieldComponent = (props: InputFieldComponentProps) => {
);
}
if (type === 'controlnet_model' && template.type === 'controlnet_model') {
return (
<ControlNetModelInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'array' && template.type === 'array') {
return (
<ArrayInputFieldComponent

View File

@ -0,0 +1,103 @@
import { SelectItem } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import {
ControlNetModelInputFieldTemplate,
ControlNetModelInputFieldValue,
} from 'features/nodes/types/types';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToControlNetModelParam } from 'features/parameters/util/modelIdToControlNetModelParam';
import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
import { FieldComponentProps } from './types';
const ControlNetModelInputFieldComponent = (
props: FieldComponentProps<
ControlNetModelInputFieldValue,
ControlNetModelInputFieldTemplate
>
) => {
const { nodeId, field } = props;
const controlNetModel = field.value;
const dispatch = useAppDispatch();
const { t } = useTranslation();
const { data: controlNetModels } = useGetControlNetModelsQuery();
// grab the full model entity from the RTK Query cache
const selectedModel = useMemo(
() =>
controlNetModels?.entities[
`${controlNetModel?.base_model}/controlnet/${controlNetModel?.model_name}`
] ?? null,
[
controlNetModel?.base_model,
controlNetModel?.model_name,
controlNetModels?.entities,
]
);
const data = useMemo(() => {
if (!controlNetModels) {
return [];
}
const data: SelectItem[] = [];
forEach(controlNetModels.entities, (model, id) => {
if (!model) {
return;
}
data.push({
value: id,
label: model.model_name,
group: MODEL_TYPE_MAP[model.base_model],
});
});
return data;
}, [controlNetModels]);
const handleValueChanged = useCallback(
(v: string | null) => {
if (!v) {
return;
}
const newControlNetModel = modelIdToControlNetModelParam(v);
if (!newControlNetModel) {
return;
}
dispatch(
fieldValueChanged({
nodeId,
fieldName: field.name,
value: newControlNetModel,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<IAIMantineSelect
tooltip={selectedModel?.description}
label={
selectedModel?.base_model && MODEL_TYPE_MAP[selectedModel?.base_model]
}
value={selectedModel?.id ?? null}
placeholder="Pick one"
error={!selectedModel}
data={data}
onChange={handleValueChanged}
/>
);
};
export default memo(ControlNetModelInputFieldComponent);

View File

@ -1,6 +1,7 @@
import { createSlice, PayloadAction } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import {
ControlNetModelParam,
LoRAModelParam,
MainModelParam,
VaeModelParam,
@ -81,7 +82,8 @@ const nodesSlice = createSlice({
| ImageField[]
| MainModelParam
| VaeModelParam
| LoRAModelParam;
| LoRAModelParam
| ControlNetModelParam;
}>
) => {
const { nodeId, fieldName, value } = action.payload;

View File

@ -19,6 +19,8 @@ export const FIELD_TYPE_MAP: Record<string, FieldType> = {
model: 'model',
vae_model: 'vae_model',
lora_model: 'lora_model',
controlnet_model: 'controlnet_model',
ControlNetModelField: 'controlnet_model',
array: 'array',
item: 'item',
ColorField: 'color',
@ -130,6 +132,12 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
title: 'LoRA',
description: 'Models are models.',
},
controlnet_model: {
color: 'teal',
colorCssVar: getColorTokenCssVariable('teal'),
title: 'ControlNet',
description: 'Models are models.',
},
array: {
color: 'gray',
colorCssVar: getColorTokenCssVariable('gray'),

View File

@ -1,4 +1,5 @@
import {
ControlNetModelParam,
LoRAModelParam,
MainModelParam,
VaeModelParam,
@ -71,6 +72,7 @@ export type FieldType =
| 'model'
| 'vae_model'
| 'lora_model'
| 'controlnet_model'
| 'array'
| 'item'
| 'color'
@ -100,6 +102,7 @@ export type InputFieldValue =
| MainModelInputFieldValue
| VaeModelInputFieldValue
| LoRAModelInputFieldValue
| ControlNetModelInputFieldValue
| ArrayInputFieldValue
| ItemInputFieldValue
| ColorInputFieldValue
@ -127,6 +130,7 @@ export type InputFieldTemplate =
| ModelInputFieldTemplate
| VaeModelInputFieldTemplate
| LoRAModelInputFieldTemplate
| ControlNetModelInputFieldTemplate
| ArrayInputFieldTemplate
| ItemInputFieldTemplate
| ColorInputFieldTemplate
@ -249,6 +253,11 @@ export type LoRAModelInputFieldValue = FieldValueBase & {
value?: LoRAModelParam;
};
export type ControlNetModelInputFieldValue = FieldValueBase & {
type: 'controlnet_model';
value?: ControlNetModelParam;
};
export type ArrayInputFieldValue = FieldValueBase & {
type: 'array';
value?: (string | number)[];
@ -368,6 +377,11 @@ export type LoRAModelInputFieldTemplate = InputFieldTemplateBase & {
type: 'lora_model';
};
export type ControlNetModelInputFieldTemplate = InputFieldTemplateBase & {
default: string;
type: 'controlnet_model';
};
export type ArrayInputFieldTemplate = InputFieldTemplateBase & {
default: [];
type: 'array';

View File

@ -9,6 +9,7 @@ import {
ColorInputFieldTemplate,
ConditioningInputFieldTemplate,
ControlInputFieldTemplate,
ControlNetModelInputFieldTemplate,
EnumInputFieldTemplate,
FieldType,
FloatInputFieldTemplate,
@ -207,6 +208,21 @@ const buildLoRAModelInputFieldTemplate = ({
return template;
};
const buildControlNetModelInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): ControlNetModelInputFieldTemplate => {
const template: ControlNetModelInputFieldTemplate = {
...baseField,
type: 'controlnet_model',
inputRequirement: 'always',
inputKind: 'direct',
default: schemaObject.default ?? undefined,
};
return template;
};
const buildImageInputFieldTemplate = ({
schemaObject,
baseField,
@ -479,6 +495,9 @@ export const buildInputFieldTemplate = (
if (['lora_model'].includes(fieldType)) {
return buildLoRAModelInputFieldTemplate({ schemaObject, baseField });
}
if (['controlnet_model'].includes(fieldType)) {
return buildControlNetModelInputFieldTemplate({ schemaObject, baseField });
}
if (['enum'].includes(fieldType)) {
return buildEnumInputFieldTemplate({ schemaObject, baseField });
}

View File

@ -83,6 +83,10 @@ export const buildInputFieldValue = (
if (template.type === 'lora_model') {
fieldValue.value = undefined;
}
if (template.type === 'controlnet_model') {
fieldValue.value = undefined;
}
}
return fieldValue;

View File

@ -2,12 +2,13 @@ import { Divider, Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIButton from 'common/components/IAIButton';
import IAICollapse from 'common/components/IAICollapse';
import IAIIconButton from 'common/components/IAIIconButton';
import ControlNet from 'features/controlNet/components/ControlNet';
import ParamControlNetFeatureToggle from 'features/controlNet/components/parameters/ParamControlNetFeatureToggle';
import {
controlNetAdded,
controlNetModelChanged,
controlNetSelector,
} from 'features/controlNet/store/controlNetSlice';
import { getValidControlNets } from 'features/controlNet/util/getValidControlNets';
@ -15,6 +16,8 @@ import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { map } from 'lodash-es';
import { Fragment, memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { FaPlus } from 'react-icons/fa';
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
import { v4 as uuidv4 } from 'uuid';
const selector = createSelector(
@ -39,10 +42,23 @@ const ParamControlNetCollapse = () => {
const { controlNetsArray, activeLabel } = useAppSelector(selector);
const isControlNetDisabled = useFeatureStatus('controlNet').isFeatureDisabled;
const dispatch = useAppDispatch();
const { firstModel } = useGetControlNetModelsQuery(undefined, {
selectFromResult: (result) => {
const firstModel = result.data?.entities[result.data?.ids[0]];
return {
firstModel,
};
},
});
const handleClickedAddControlNet = useCallback(() => {
dispatch(controlNetAdded({ controlNetId: uuidv4() }));
}, [dispatch]);
if (!firstModel) {
return;
}
const controlNetId = uuidv4();
dispatch(controlNetAdded({ controlNetId }));
dispatch(controlNetModelChanged({ controlNetId, model: firstModel }));
}, [dispatch, firstModel]);
if (isControlNetDisabled) {
return null;
@ -51,16 +67,39 @@ const ParamControlNetCollapse = () => {
return (
<IAICollapse label="ControlNet" activeLabel={activeLabel}>
<Flex sx={{ flexDir: 'column', gap: 3 }}>
<ParamControlNetFeatureToggle />
<Flex gap={2} alignItems="center">
<Flex
sx={{
flexDirection: 'column',
w: '100%',
gap: 2,
px: 4,
py: 2,
borderRadius: 4,
bg: 'base.200',
_dark: {
bg: 'base.850',
},
}}
>
<ParamControlNetFeatureToggle />
</Flex>
<IAIIconButton
tooltip="Add ControlNet"
aria-label="Add ControlNet"
icon={<FaPlus />}
isDisabled={!firstModel}
flexGrow={1}
size="md"
onClick={handleClickedAddControlNet}
/>
</Flex>
{controlNetsArray.map((c, i) => (
<Fragment key={c.controlNetId}>
{i > 0 && <Divider />}
<ControlNet controlNet={c} />
<ControlNet controlNetId={c.controlNetId} />
</Fragment>
))}
<IAIButton flexGrow={1} onClick={handleClickedAddControlNet}>
Add ControlNet
</IAIButton>
</Flex>
</IAICollapse>
);

View File

@ -37,6 +37,7 @@ const ParamVAEModelSelect = () => {
return [];
}
// add a "default" option, this means use the main model's included VAE
const data: SelectItem[] = [
{
value: 'default',

View File

@ -17,8 +17,10 @@ import { FaPlay } from 'react-icons/fa';
const IN_PROGRESS_STYLES: ChakraProps['sx'] = {
_disabled: {
bg: 'none',
color: 'base.600',
cursor: 'not-allowed',
_hover: {
color: 'base.600',
bg: 'none',
},
},

View File

@ -180,6 +180,23 @@ export type LoRAModelParam = z.infer<typeof zLoRAModel>;
*/
export const isValidLoRAModel = (val: unknown): val is LoRAModelParam =>
zLoRAModel.safeParse(val).success;
/**
* Zod schema for ControlNet models
*/
export const zControlNetModel = z.object({
model_name: z.string().min(1),
base_model: zBaseModel,
});
/**
* Type alias for model parameter, inferred from its zod schema
*/
export type ControlNetModelParam = z.infer<typeof zLoRAModel>;
/**
* Validates/type-guards a value as a model parameter
*/
export const isValidControlNetModel = (
val: unknown
): val is ControlNetModelParam => zControlNetModel.safeParse(val).success;
/**
* Zod schema for l2l strength parameter

View File

@ -0,0 +1,30 @@
import { log } from 'app/logging/useLogger';
import { zControlNetModel } from 'features/parameters/types/parameterSchemas';
import { ControlNetModelField } from 'services/api/types';
const moduleLog = log.child({ module: 'models' });
export const modelIdToControlNetModelParam = (
controlNetModelId: string
): ControlNetModelField | undefined => {
const [base_model, model_type, model_name] = controlNetModelId.split('/');
const result = zControlNetModel.safeParse({
base_model,
model_name,
});
if (!result.success) {
moduleLog.error(
{
controlNetModelId,
errors: result.error.format(),
},
'Failed to parse ControlNet model id'
);
return;
}
return result.data;
};

View File

@ -1,9 +1,12 @@
import { LoRAModelParam, zLoRAModel } from '../types/parameterSchemas';
import { log } from 'app/logging/useLogger';
const moduleLog = log.child({ module: 'models' });
export const modelIdToLoRAModelParam = (
loraId: string
loraModelId: string
): LoRAModelParam | undefined => {
const [base_model, model_type, model_name] = loraId.split('/');
const [base_model, model_type, model_name] = loraModelId.split('/');
const result = zLoRAModel.safeParse({
base_model,
@ -11,6 +14,13 @@ export const modelIdToLoRAModelParam = (
});
if (!result.success) {
moduleLog.error(
{
loraModelId,
errors: result.error.format(),
},
'Failed to parse LoRA model id'
);
return;
}

View File

@ -2,11 +2,14 @@ import {
MainModelParam,
zMainModel,
} from 'features/parameters/types/parameterSchemas';
import { log } from 'app/logging/useLogger';
const moduleLog = log.child({ module: 'models' });
export const modelIdToMainModelParam = (
modelId: string
mainModelId: string
): MainModelParam | undefined => {
const [base_model, model_type, model_name] = modelId.split('/');
const [base_model, model_type, model_name] = mainModelId.split('/');
const result = zMainModel.safeParse({
base_model,
@ -14,6 +17,13 @@ export const modelIdToMainModelParam = (
});
if (!result.success) {
moduleLog.error(
{
mainModelId,
errors: result.error.format(),
},
'Failed to parse main model id'
);
return;
}

View File

@ -1,9 +1,12 @@
import { VaeModelParam, zVaeModel } from '../types/parameterSchemas';
import { log } from 'app/logging/useLogger';
const moduleLog = log.child({ module: 'models' });
export const modelIdToVAEModelParam = (
modelId: string
vaeModelId: string
): VaeModelParam | undefined => {
const [base_model, model_type, model_name] = modelId.split('/');
const [base_model, model_type, model_name] = vaeModelId.split('/');
const result = zVaeModel.safeParse({
base_model,
@ -11,6 +14,13 @@ export const modelIdToVAEModelParam = (
});
if (!result.success) {
moduleLog.error(
{
vaeModelId,
errors: result.error.format(),
},
'Failed to parse VAE model id'
);
return;
}

View File

@ -19,9 +19,9 @@ const ImageToImageTabParameters = () => {
<ParamNegativeConditioning />
<ProcessButtons />
<ImageToImageTabCoreParameters />
<ParamControlNetCollapse />
<ParamLoraCollapse />
<ParamDynamicPromptsCollapse />
<ParamControlNetCollapse />
<ParamVariationCollapse />
<ParamNoiseCollapse />
<ParamSymmetryCollapse />

View File

@ -20,9 +20,9 @@ const TextToImageTabParameters = () => {
<ParamNegativeConditioning />
<ProcessButtons />
<TextToImageTabCoreParameters />
<ParamControlNetCollapse />
<ParamLoraCollapse />
<ParamDynamicPromptsCollapse />
<ParamControlNetCollapse />
<ParamVariationCollapse />
<ParamNoiseCollapse />
<ParamSymmetryCollapse />

View File

@ -19,9 +19,9 @@ const UnifiedCanvasParameters = () => {
<ParamNegativeConditioning />
<ProcessButtons />
<UnifiedCanvasCoreParameters />
<ParamControlNetCollapse />
<ParamLoraCollapse />
<ParamDynamicPromptsCollapse />
<ParamControlNetCollapse />
<ParamVariationCollapse />
<ParamSymmetryCollapse />
<ParamSeamCorrectionCollapse />

View File

@ -734,7 +734,7 @@ export type components = {
* Control Model
* @description The ControlNet model to use
*/
control_model: string;
control_model: components["schemas"]["ControlNetModelField"];
/**
* Control Weight
* @description The weight given to the ControlNet
@ -792,9 +792,8 @@ export type components = {
* Control Model
* @description control model used
* @default lllyasviel/sd-controlnet-canny
* @enum {string}
*/
control_model?: "lllyasviel/sd-controlnet-canny" | "lllyasviel/sd-controlnet-depth" | "lllyasviel/sd-controlnet-hed" | "lllyasviel/sd-controlnet-seg" | "lllyasviel/sd-controlnet-openpose" | "lllyasviel/sd-controlnet-scribble" | "lllyasviel/sd-controlnet-normal" | "lllyasviel/sd-controlnet-mlsd" | "lllyasviel/control_v11p_sd15_canny" | "lllyasviel/control_v11p_sd15_openpose" | "lllyasviel/control_v11p_sd15_seg" | "lllyasviel/control_v11f1p_sd15_depth" | "lllyasviel/control_v11p_sd15_normalbae" | "lllyasviel/control_v11p_sd15_scribble" | "lllyasviel/control_v11p_sd15_mlsd" | "lllyasviel/control_v11p_sd15_softedge" | "lllyasviel/control_v11p_sd15s2_lineart_anime" | "lllyasviel/control_v11p_sd15_lineart" | "lllyasviel/control_v11p_sd15_inpaint" | "lllyasviel/control_v11e_sd15_shuffle" | "lllyasviel/control_v11e_sd15_ip2p" | "lllyasviel/control_v11f1e_sd15_tile" | "thibaud/controlnet-sd21-openpose-diffusers" | "thibaud/controlnet-sd21-canny-diffusers" | "thibaud/controlnet-sd21-depth-diffusers" | "thibaud/controlnet-sd21-scribble-diffusers" | "thibaud/controlnet-sd21-hed-diffusers" | "thibaud/controlnet-sd21-zoedepth-diffusers" | "thibaud/controlnet-sd21-color-diffusers" | "thibaud/controlnet-sd21-openposev2-diffusers" | "thibaud/controlnet-sd21-lineart-diffusers" | "thibaud/controlnet-sd21-normalbae-diffusers" | "thibaud/controlnet-sd21-ade20k-diffusers" | "CrucibleAI/ControlNetMediaPipeFace,diffusion_sd15" | "CrucibleAI/ControlNetMediaPipeFace";
control_model?: components["schemas"]["ControlNetModelField"];
/**
* Control Weight
* @description The weight given to the ControlNet
@ -838,6 +837,19 @@ export type components = {
model_format: components["schemas"]["ControlNetModelFormat"];
error?: components["schemas"]["ModelError"];
};
/**
* ControlNetModelField
* @description ControlNet model field
*/
ControlNetModelField: {
/**
* Model Name
* @description Name of the ControlNet model
*/
model_name: string;
/** @description Base model */
base_model: components["schemas"]["BaseModelType"];
};
/**
* ControlNetModelFormat
* @description An enumeration.
@ -1923,12 +1935,12 @@ export type components = {
* Width
* @description The width to resize to (px)
*/
width: number;
width?: number;
/**
* Height
* @description The height to resize to (px)
*/
height: number;
height?: number;
/**
* Resample Mode
* @description The resampling mode
@ -3911,13 +3923,15 @@ export type components = {
/**
* Width
* @description The width to resize to (px)
* @default 512
*/
width: number;
width?: number;
/**
* Height
* @description The height to resize to (px)
* @default 512
*/
height: number;
height?: number;
/**
* Mode
* @description The interpolation mode
@ -4605,18 +4619,18 @@ export type components = {
*/
image?: components["schemas"]["ImageField"];
};
/**
* StableDiffusion2ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusion1ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusion2ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
};
responses: never;
parameters: never;

View File

@ -32,6 +32,8 @@ export type BaseModelType = components['schemas']['BaseModelType'];
export type MainModelField = components['schemas']['MainModelField'];
export type VAEModelField = components['schemas']['VAEModelField'];
export type LoRAModelField = components['schemas']['LoRAModelField'];
export type ControlNetModelField =
components['schemas']['ControlNetModelField'];
export type ModelsList = components['schemas']['ModelsList'];
export type ControlField = components['schemas']['ControlField'];

View File

@ -30,7 +30,7 @@ const invokeAIThumb = defineStyle((props) => {
const invokeAIMark = defineStyle((props) => {
return {
fontSize: 'xs',
fontSize: '2xs',
fontWeight: '500',
color: mode('base.700', 'base.400')(props),
mt: 2,