merge with main

This commit is contained in:
Lincoln Stein 2023-06-06 22:18:41 -04:00
commit 1f9e1eb964
198 changed files with 5781 additions and 1409 deletions

171
docs/features/LOGGING.md Normal file
View File

@ -0,0 +1,171 @@
---
title: Controlling Logging
---
# :material-image-off: Controlling Logging
## Controlling How InvokeAI Logs Status Messages
InvokeAI logs status messages using a configurable logging system. You
can log to the terminal window, to a designated file on the local
machine, to the syslog facility on a Linux or Mac, or to a properly
configured web server. You can configure several logs at the same
time, and control the level of message logged and the logging format
(to a limited extent).
Three command-line options control logging:
### `--log_handlers <handler1> <handler2> ...`
This option activates one or more log handlers. Options are "console",
"file", "syslog" and "http". To specify more than one, separate them
by spaces:
```bash
invokeai-web --log_handlers console syslog=/dev/log file=C:\Users\fred\invokeai.log
```
The format of these options is described below.
### `--log_format {plain|color|legacy|syslog}`
This controls the format of log messages written to the console. Only
the "console" log handler is currently affected by this setting.
* "plain" provides formatted messages like this:
```bash
[2023-05-24 23:18:2[2023-05-24 23:18:50,352]::[InvokeAI]::DEBUG --> this is a debug message
[2023-05-24 23:18:50,352]::[InvokeAI]::INFO --> this is an informational messages
[2023-05-24 23:18:50,352]::[InvokeAI]::WARNING --> this is a warning
[2023-05-24 23:18:50,352]::[InvokeAI]::ERROR --> this is an error
[2023-05-24 23:18:50,352]::[InvokeAI]::CRITICAL --> this is a critical error
```
* "color" produces similar output, but the text will be color coded to
indicate the severity of the message.
* "legacy" produces output similar to InvokeAI versions 2.3 and earlier:
```bash
### this is a critical error
*** this is an error
** this is a warning
>> this is an informational messages
| this is a debug message
```
* "syslog" produces messages suitable for syslog entries:
```bash
InvokeAI [2691178] <CRITICAL> this is a critical error
InvokeAI [2691178] <ERROR> this is an error
InvokeAI [2691178] <WARNING> this is a warning
InvokeAI [2691178] <INFO> this is an informational messages
InvokeAI [2691178] <DEBUG> this is a debug message
```
(note that the date, time and hostname will be added by the syslog
system)
### `--log_level {debug|info|warning|error|critical}`
Providing this command-line option will cause only messages at the
specified level or above to be emitted.
## Console logging
When "console" is provided to `--log_handlers`, messages will be
written to the command line window in which InvokeAI was launched. By
default, the color formatter will be used unless overridden by
`--log_format`.
## File logging
When "file" is provided to `--log_handlers`, entries will be written
to the file indicated in the path argument. By default, the "plain"
format will be used:
```bash
invokeai-web --log_handlers file=/var/log/invokeai.log
```
## Syslog logging
When "syslog" is requested, entries will be sent to the syslog
system. There are a variety of ways to control where the log message
is sent:
* Send to the local machine using the `/dev/log` socket:
```
invokeai-web --log_handlers syslog=/dev/log
```
* Send to the local machine using a UDP message:
```
invokeai-web --log_handlers syslog=localhost
```
* Send to the local machine using a UDP message on a nonstandard
port:
```
invokeai-web --log_handlers syslog=localhost:512
```
* Send to a remote machine named "loghost" on the local LAN using
facility LOG_USER and UDP packets:
```
invokeai-web --log_handlers syslog=loghost,facility=LOG_USER,socktype=SOCK_DGRAM
```
This can be abbreviated `syslog=loghost`, as LOG_USER and SOCK_DGRAM
are defaults.
* Send to a remote machine named "loghost" using the facility LOCAL0
and using a TCP socket:
```
invokeai-web --log_handlers syslog=loghost,facility=LOG_LOCAL0,socktype=SOCK_STREAM
```
If no arguments are specified (just a bare "syslog"), then the logging
system will look for a UNIX socket named `/dev/log`, and if not found
try to send a UDP message to `localhost`. The Macintosh OS used to
support logging to a socket named `/var/run/syslog`, but this feature
has since been disabled.
## Web logging
If you have access to a web server that is configured to log messages
when a particular URL is requested, you can log using the "http"
method:
```
invokeai-web --log_handlers http=http://my.server/path/to/logger,method=POST
```
The optional [,method=] part can be used to specify whether the URL
accepts GET (default) or POST messages.
Currently password authentication and SSL are not supported.
## Using the configuration file
You can set and forget logging options by adding a "Logging" section
to `invokeai.yaml`:
```
InvokeAI:
[... other settings...]
Logging:
log_handlers:
- console
- syslog=/dev/log
log_level: info
log_format: color
```

View File

@ -57,6 +57,9 @@ Personalize models by adding your own style or subjects.
## * [The NSFW Checker](NSFW.md) ## * [The NSFW Checker](NSFW.md)
Prevent InvokeAI from displaying unwanted racy images. Prevent InvokeAI from displaying unwanted racy images.
## * [Controlling Logging](LOGGING.md)
Control how InvokeAI logs status messages.
## * [Miscellaneous](OTHER.md) ## * [Miscellaneous](OTHER.md)
Run InvokeAI on Google Colab, generate images with repeating patterns, Run InvokeAI on Google Colab, generate images with repeating patterns,
batch process a file of prompts, increase the "creativity" of image batch process a file of prompts, increase the "creativity" of image

View File

@ -39,7 +39,8 @@ socket_io = SocketIO(app)
# initialize config # initialize config
# this is a module global # this is a module global
app_config = InvokeAIAppConfig() app_config = InvokeAIAppConfig.get_config()
app_config.parse_args()
# Add startup event to load dependencies # Add startup event to load dependencies
@app.on_event("startup") @app.on_event("startup")

View File

@ -39,7 +39,7 @@ from .services.model_manager_service import ModelManagerService
from .services.processor import DefaultInvocationProcessor from .services.processor import DefaultInvocationProcessor
from .services.restoration_services import RestorationServices from .services.restoration_services import RestorationServices
from .services.sqlite import SqliteItemStorage from .services.sqlite import SqliteItemStorage
from .services.config import InvokeAIAppConfig
class CliCommand(BaseModel): class CliCommand(BaseModel):
command: Union[BaseCommand.get_commands() + BaseInvocation.get_invocations()] = Field(discriminator="type") # type: ignore command: Union[BaseCommand.get_commands() + BaseInvocation.get_invocations()] = Field(discriminator="type") # type: ignore
@ -198,7 +198,8 @@ logger = logger.InvokeAILogger.getLogger()
def invoke_cli(): def invoke_cli():
# this gets the basic configuration # this gets the basic configuration
config = get_invokeai_config() config = InvokeAIAppConfig.get_config()
config.parse_args()
# get the optional list of invocations to execute on the command line # get the optional list of invocations to execute on the command line
parser = config.get_parser() parser = config.get_parser()

View File

@ -4,12 +4,10 @@ from contextlib import ExitStack
import re import re
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
from .model import ClipField from .model import ClipField
from ...backend.util.devices import choose_torch_device, torch_dtype from ...backend.util.devices import torch_dtype
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
from ...backend.stable_diffusion.textual_inversion_manager import TextualInversionManager
from ...backend.model_management import SDModelType from ...backend.model_management import SDModelType
from ...backend.model_management.lora import ModelPatcher from ...backend.model_management.lora import ModelPatcher
@ -18,7 +16,7 @@ from compel.prompt_parser import (
Blend, Blend,
CrossAttentionControlSubstitute, CrossAttentionControlSubstitute,
FlattenedPrompt, FlattenedPrompt,
Fragment, Fragment, Conjunction,
) )
@ -81,7 +79,7 @@ class CompelInvocation(BaseInvocation):
context.services.model_manager.get_model(model_name=name, model_type=SDModelType.TextualInversion) context.services.model_manager.get_model(model_name=name, model_type=SDModelType.TextualInversion)
) )
) )
except Exception as e: except Exception:
#print(e) #print(e)
#import traceback #import traceback
#print(traceback.format_exc()) #print(traceback.format_exc())
@ -98,7 +96,6 @@ class CompelInvocation(BaseInvocation):
truncate_long_prompts=True, # TODO: truncate_long_prompts=True, # TODO:
) )
conjunction = Compel.parse_prompt_string(self.prompt) conjunction = Compel.parse_prompt_string(self.prompt)
prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0] prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0]
@ -106,16 +103,15 @@ class CompelInvocation(BaseInvocation):
log_tokenization_for_prompt_object(prompt, tokenizer) log_tokenization_for_prompt_object(prompt, tokenizer)
c, options = compel.build_conditioning_tensor_for_prompt_object(prompt) c, options = compel.build_conditioning_tensor_for_prompt_object(prompt)
# TODO: long prompt support # TODO: long prompt support
#if not self.truncate_long_prompts: #if not self.truncate_long_prompts:
# [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc]) # [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo( ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
tokens_count_including_eos_bos=get_max_token_count(tokenizer, prompt), tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
cross_attention_control_args=options.get("cross_attention_control", None), cross_attention_control_args=options.get("cross_attention_control", None),
) )
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning" conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
# TODO: hacky but works ;D maybe rename latents somehow? # TODO: hacky but works ;D maybe rename latents somehow?
@ -129,14 +125,22 @@ class CompelInvocation(BaseInvocation):
def get_max_token_count( def get_max_token_count(
tokenizer, prompt: Union[FlattenedPrompt, Blend], truncate_if_too_long=False tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction], truncate_if_too_long=False
) -> int: ) -> int:
if type(prompt) is Blend: if type(prompt) is Blend:
blend: Blend = prompt blend: Blend = prompt
return max( return max(
[ [
get_max_token_count(tokenizer, c, truncate_if_too_long) get_max_token_count(tokenizer, p, truncate_if_too_long)
for c in blend.prompts for p in blend.prompts
]
)
elif type(prompt) is Conjunction:
conjunction: Conjunction = prompt
return sum(
[
get_max_token_count(tokenizer, p, truncate_if_too_long)
for p in conjunction.prompts
] ]
) )
else: else:
@ -171,6 +175,22 @@ def get_tokens_for_prompt_object(
return tokens return tokens
def log_tokenization_for_conjunction(
c: Conjunction, tokenizer, display_label_prefix=None
):
display_label_prefix = display_label_prefix or ""
for i, p in enumerate(c.prompts):
if len(c.prompts)>1:
this_display_label_prefix = f"{display_label_prefix}(conjunction part {i + 1}, weight={c.weights[i]})"
else:
this_display_label_prefix = display_label_prefix
log_tokenization_for_prompt_object(
p,
tokenizer,
display_label_prefix=this_display_label_prefix
)
def log_tokenization_for_prompt_object( def log_tokenization_for_prompt_object(
p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None
): ):

View File

@ -94,13 +94,13 @@ CONTROLNET_DEFAULT_MODELS = [
CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)] CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
class ControlField(BaseModel): class ControlField(BaseModel):
image: ImageField = Field(default=None, description="processed image") image: ImageField = Field(default=None, description="The control image")
control_model: Optional[str] = Field(default=None, description="control model used") control_model: Optional[str] = Field(default=None, description="The ControlNet model to use")
control_weight: Optional[float] = Field(default=1, description="weight given to controlnet") control_weight: Optional[float] = Field(default=1, description="The weight given to the ControlNet")
begin_step_percent: float = Field(default=0, ge=0, le=1, begin_step_percent: float = Field(default=0, ge=0, le=1,
description="% of total steps at which controlnet is first applied") description="When the ControlNet is first applied (% of total steps)")
end_step_percent: float = Field(default=1, ge=0, le=1, end_step_percent: float = Field(default=1, ge=0, le=1,
description="% of total steps at which controlnet is last applied") description="When the ControlNet is last applied (% of total steps)")
class Config: class Config:
schema_extra = { schema_extra = {
@ -112,7 +112,7 @@ class ControlOutput(BaseInvocationOutput):
"""node output for ControlNet info""" """node output for ControlNet info"""
# fmt: off # fmt: off
type: Literal["control_output"] = "control_output" type: Literal["control_output"] = "control_output"
control: ControlField = Field(default=None, description="The control info dict") control: ControlField = Field(default=None, description="The output control image")
# fmt: on # fmt: on
@ -121,15 +121,15 @@ class ControlNetInvocation(BaseInvocation):
# fmt: off # fmt: off
type: Literal["controlnet"] = "controlnet" type: Literal["controlnet"] = "controlnet"
# Inputs # Inputs
image: ImageField = Field(default=None, description="image to process") image: ImageField = Field(default=None, description="The control image")
control_model: CONTROLNET_NAME_VALUES = Field(default="lllyasviel/sd-controlnet-canny", control_model: CONTROLNET_NAME_VALUES = Field(default="lllyasviel/sd-controlnet-canny",
description="control model used") description="The ControlNet model to use")
control_weight: float = Field(default=1.0, ge=0, le=1, description="weight given to controlnet") control_weight: float = Field(default=1.0, ge=0, le=1, description="The weight given to the ControlNet")
# TODO: add support in backend core for begin_step_percent, end_step_percent, guess_mode # TODO: add support in backend core for begin_step_percent, end_step_percent, guess_mode
begin_step_percent: float = Field(default=0, ge=0, le=1, begin_step_percent: float = Field(default=0, ge=0, le=1,
description="% of total steps at which controlnet is first applied") description="When the ControlNet is first applied (% of total steps)")
end_step_percent: float = Field(default=1, ge=0, le=1, end_step_percent: float = Field(default=1, ge=0, le=1,
description="% of total steps at which controlnet is last applied") description="When the ControlNet is last applied (% of total steps)")
# fmt: on # fmt: on
@ -152,7 +152,7 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
# fmt: off # fmt: off
type: Literal["image_processor"] = "image_processor" type: Literal["image_processor"] = "image_processor"
# Inputs # Inputs
image: ImageField = Field(default=None, description="image to process") image: ImageField = Field(default=None, description="The image to process")
# fmt: on # fmt: on
@ -204,8 +204,8 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfi
# fmt: off # fmt: off
type: Literal["canny_image_processor"] = "canny_image_processor" type: Literal["canny_image_processor"] = "canny_image_processor"
# Input # Input
low_threshold: float = Field(default=100, ge=0, description="low threshold of Canny pixel gradient") low_threshold: int = Field(default=100, ge=0, le=255, description="The low threshold of the Canny pixel gradient (0-255)")
high_threshold: float = Field(default=200, ge=0, description="high threshold of Canny pixel gradient") high_threshold: int = Field(default=200, ge=0, le=255, description="The high threshold of the Canny pixel gradient (0-255)")
# fmt: on # fmt: on
def run_processor(self, image): def run_processor(self, image):
@ -214,16 +214,16 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfi
return processed_image return processed_image
class HedImageprocessorInvocation(ImageProcessorInvocation, PILInvocationConfig): class HedImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
"""Applies HED edge detection to image""" """Applies HED edge detection to image"""
# fmt: off # fmt: off
type: Literal["hed_image_processor"] = "hed_image_processor" type: Literal["hed_image_processor"] = "hed_image_processor"
# Inputs # Inputs
detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection") detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image") image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
# safe not supported in controlnet_aux v0.0.3 # safe not supported in controlnet_aux v0.0.3
# safe: bool = Field(default=False, description="whether to use safe mode") # safe: bool = Field(default=False, description="whether to use safe mode")
scribble: bool = Field(default=False, description="whether to use scribble mode") scribble: bool = Field(default=False, description="Whether to use scribble mode")
# fmt: on # fmt: on
def run_processor(self, image): def run_processor(self, image):
@ -243,9 +243,9 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation, PILInvocationCon
# fmt: off # fmt: off
type: Literal["lineart_image_processor"] = "lineart_image_processor" type: Literal["lineart_image_processor"] = "lineart_image_processor"
# Inputs # Inputs
detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection") detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image") image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
coarse: bool = Field(default=False, description="whether to use coarse mode") coarse: bool = Field(default=False, description="Whether to use coarse mode")
# fmt: on # fmt: on
def run_processor(self, image): def run_processor(self, image):
@ -262,8 +262,8 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation, PILInvocati
# fmt: off # fmt: off
type: Literal["lineart_anime_image_processor"] = "lineart_anime_image_processor" type: Literal["lineart_anime_image_processor"] = "lineart_anime_image_processor"
# Inputs # Inputs
detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection") detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image") image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
# fmt: on # fmt: on
def run_processor(self, image): def run_processor(self, image):
@ -280,9 +280,9 @@ class OpenposeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
# fmt: off # fmt: off
type: Literal["openpose_image_processor"] = "openpose_image_processor" type: Literal["openpose_image_processor"] = "openpose_image_processor"
# Inputs # Inputs
hand_and_face: bool = Field(default=False, description="whether to use hands and face mode") hand_and_face: bool = Field(default=False, description="Whether to use hands and face mode")
detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection") detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image") image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
# fmt: on # fmt: on
def run_processor(self, image): def run_processor(self, image):
@ -300,8 +300,8 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocation
# fmt: off # fmt: off
type: Literal["midas_depth_image_processor"] = "midas_depth_image_processor" type: Literal["midas_depth_image_processor"] = "midas_depth_image_processor"
# Inputs # Inputs
a_mult: float = Field(default=2.0, ge=0, description="Midas parameter a = amult * PI") a_mult: float = Field(default=2.0, ge=0, description="Midas parameter `a_mult` (a = a_mult * PI)")
bg_th: float = Field(default=0.1, ge=0, description="Midas parameter bg_th") bg_th: float = Field(default=0.1, ge=0, description="Midas parameter `bg_th`")
# depth_and_normal not supported in controlnet_aux v0.0.3 # depth_and_normal not supported in controlnet_aux v0.0.3
# depth_and_normal: bool = Field(default=False, description="whether to use depth and normal mode") # depth_and_normal: bool = Field(default=False, description="whether to use depth and normal mode")
# fmt: on # fmt: on
@ -322,8 +322,8 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationC
# fmt: off # fmt: off
type: Literal["normalbae_image_processor"] = "normalbae_image_processor" type: Literal["normalbae_image_processor"] = "normalbae_image_processor"
# Inputs # Inputs
detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection") detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image") image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
# fmt: on # fmt: on
def run_processor(self, image): def run_processor(self, image):
@ -339,10 +339,10 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig
# fmt: off # fmt: off
type: Literal["mlsd_image_processor"] = "mlsd_image_processor" type: Literal["mlsd_image_processor"] = "mlsd_image_processor"
# Inputs # Inputs
detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection") detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image") image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
thr_v: float = Field(default=0.1, ge=0, description="MLSD parameter thr_v") thr_v: float = Field(default=0.1, ge=0, description="MLSD parameter `thr_v`")
thr_d: float = Field(default=0.1, ge=0, description="MLSD parameter thr_d") thr_d: float = Field(default=0.1, ge=0, description="MLSD parameter `thr_d`")
# fmt: on # fmt: on
def run_processor(self, image): def run_processor(self, image):
@ -360,10 +360,10 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig
# fmt: off # fmt: off
type: Literal["pidi_image_processor"] = "pidi_image_processor" type: Literal["pidi_image_processor"] = "pidi_image_processor"
# Inputs # Inputs
detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection") detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image") image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
safe: bool = Field(default=False, description="whether to use safe mode") safe: bool = Field(default=False, description="Whether to use safe mode")
scribble: bool = Field(default=False, description="whether to use scribble mode") scribble: bool = Field(default=False, description="Whether to use scribble mode")
# fmt: on # fmt: on
def run_processor(self, image): def run_processor(self, image):
@ -381,11 +381,11 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation, PILInvoca
# fmt: off # fmt: off
type: Literal["content_shuffle_image_processor"] = "content_shuffle_image_processor" type: Literal["content_shuffle_image_processor"] = "content_shuffle_image_processor"
# Inputs # Inputs
detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection") detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image") image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
h: Union[int | None] = Field(default=512, ge=0, description="content shuffle h parameter") h: Union[int, None] = Field(default=512, ge=0, description="Content shuffle `h` parameter")
w: Union[int | None] = Field(default=512, ge=0, description="content shuffle w parameter") w: Union[int, None] = Field(default=512, ge=0, description="Content shuffle `w` parameter")
f: Union[int | None] = Field(default=256, ge=0, description="cont") f: Union[int, None] = Field(default=256, ge=0, description="Content shuffle `f` parameter")
# fmt: on # fmt: on
def run_processor(self, image): def run_processor(self, image):
@ -418,8 +418,8 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
# fmt: off # fmt: off
type: Literal["mediapipe_face_processor"] = "mediapipe_face_processor" type: Literal["mediapipe_face_processor"] = "mediapipe_face_processor"
# Inputs # Inputs
max_faces: int = Field(default=1, ge=1, description="maximum number of faces to detect") max_faces: int = Field(default=1, ge=1, description="Maximum number of faces to detect")
min_confidence: float = Field(default=0.5, ge=0, le=1, description="minimum confidence for face detection") min_confidence: float = Field(default=0.5, ge=0, le=1, description="Minimum confidence for face detection")
# fmt: on # fmt: on
def run_processor(self, image): def run_processor(self, image):

View File

@ -4,11 +4,12 @@ from contextlib import ExitStack
from typing import List, Literal, Optional, Union from typing import List, Literal, Optional, Union
import einops import einops
from pydantic import BaseModel, Field, validator
import torch import torch
from diffusers import ControlNetModel from diffusers import ControlNetModel
from diffusers.image_processor import VaeImageProcessor from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import SchedulerMixin as Scheduler from diffusers.schedulers import SchedulerMixin as Scheduler
from pydantic import BaseModel, Field, validator
from invokeai.app.util.misc import SEED_MAX, get_random_seed from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.app.util.step_callback import stable_diffusion_step_callback

View File

@ -51,18 +51,32 @@ in INVOKEAI_ROOT. You can replace supersede this by providing any
OmegaConf dictionary object initialization time: OmegaConf dictionary object initialization time:
omegaconf = OmegaConf.load('/tmp/init.yaml') omegaconf = OmegaConf.load('/tmp/init.yaml')
conf = InvokeAIAppConfig(conf=omegaconf) conf = InvokeAIAppConfig()
conf.parse_args(conf=omegaconf)
By default, InvokeAIAppConfig will parse the contents of `sys.argv` at InvokeAIAppConfig.parse_args() will parse the contents of `sys.argv`
initialization time. You may pass a list of strings in the optional at initialization time. You may pass a list of strings in the optional
`argv` argument to use instead of the system argv: `argv` argument to use instead of the system argv:
conf = InvokeAIAppConfig(arg=['--xformers_enabled']) conf.parse_args(argv=['--xformers_enabled'])
It is also possible to set a value at initialization time. This value It is also possible to set a value at initialization time. However, if
has highest priority. you call parse_args() it may be overwritten.
conf = InvokeAIAppConfig(xformers_enabled=True) conf = InvokeAIAppConfig(xformers_enabled=True)
conf.parse_args(argv=['--no-xformers'])
conf.xformers_enabled
# False
To avoid this, use `get_config()` to retrieve the application-wide
configuration object. This will retain any properties set at object
creation time:
conf = InvokeAIAppConfig.get_config(xformers_enabled=True)
conf.parse_args(argv=['--no-xformers'])
conf.xformers_enabled
# True
Any setting can be overwritten by setting an environment variable of Any setting can be overwritten by setting an environment variable of
form: "INVOKEAI_<setting>", as in: form: "INVOKEAI_<setting>", as in:
@ -76,18 +90,23 @@ Order of precedence (from highest):
4) config file options 4) config file options
5) pydantic defaults 5) pydantic defaults
Typical usage: Typical usage at the top level file:
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.invocations.generate import TextToImageInvocation
# get global configuration and print its nsfw_checker value # get global configuration and print its nsfw_checker value
conf = InvokeAIAppConfig() conf = InvokeAIAppConfig.get_config()
conf.parse_args()
print(conf.nsfw_checker)
Typical usage in a backend module:
from invokeai.app.services.config import InvokeAIAppConfig
# get global configuration and print its nsfw_checker value
conf = InvokeAIAppConfig.get_config()
print(conf.nsfw_checker) print(conf.nsfw_checker)
# get the text2image invocation and print its step value
text2image = TextToImageInvocation()
print(text2image.steps)
Computed properties: Computed properties:
@ -103,10 +122,11 @@ a Path object:
lora_path - path to the LoRA directory lora_path - path to the LoRA directory
In most cases, you will want to create a single InvokeAIAppConfig In most cases, you will want to create a single InvokeAIAppConfig
object for the entire application. The get_invokeai_config() function object for the entire application. The InvokeAIAppConfig.get_config() function
does this: does this:
config = get_invokeai_config() config = InvokeAIAppConfig.get_config()
config.parse_args() # read values from the command line/config file
print(config.root) print(config.root)
# Subclassing # Subclassing
@ -140,24 +160,22 @@ two configs are kept in separate sections of the config file:
legacy_conf_dir: configs/stable-diffusion legacy_conf_dir: configs/stable-diffusion
outdir: outputs outdir: outputs
... ...
''' '''
from __future__ import annotations
import argparse import argparse
import pydoc import pydoc
import typing
import os import os
import sys import sys
from argparse import ArgumentParser from argparse import ArgumentParser
from omegaconf import OmegaConf, DictConfig from omegaconf import OmegaConf, DictConfig
from pathlib import Path from pathlib import Path
from pydantic import BaseSettings, Field, parse_obj_as from pydantic import BaseSettings, Field, parse_obj_as
from typing import Any, ClassVar, Dict, List, Literal, Type, Union, get_origin, get_type_hints, get_args from typing import ClassVar, Dict, List, Literal, Type, Union, get_origin, get_type_hints, get_args
INIT_FILE = Path('invokeai.yaml') INIT_FILE = Path('invokeai.yaml')
LEGACY_INIT_FILE = Path('invokeai.init') LEGACY_INIT_FILE = Path('invokeai.init')
# This global stores a singleton InvokeAIAppConfig configuration object
global_config = None
class InvokeAISettings(BaseSettings): class InvokeAISettings(BaseSettings):
''' '''
Runtime configuration settings in which default values are Runtime configuration settings in which default values are
@ -168,7 +186,7 @@ class InvokeAISettings(BaseSettings):
def parse_args(self, argv: list=sys.argv[1:]): def parse_args(self, argv: list=sys.argv[1:]):
parser = self.get_parser() parser = self.get_parser()
opt, _ = parser.parse_known_args(argv) opt = parser.parse_args(argv)
for name in self.__fields__: for name in self.__fields__:
if name not in self._excluded(): if name not in self._excluded():
setattr(self, name, getattr(opt,name)) setattr(self, name, getattr(opt,name))
@ -330,6 +348,9 @@ the command-line client (recommended for experts only), or
can be changed by editing the file "INVOKEAI_ROOT/invokeai.yaml" or by can be changed by editing the file "INVOKEAI_ROOT/invokeai.yaml" or by
setting environment variables INVOKEAI_<setting>. setting environment variables INVOKEAI_<setting>.
''' '''
singleton_config: ClassVar[InvokeAIAppConfig] = None
singleton_init: ClassVar[Dict] = None
#fmt: off #fmt: off
type: Literal["InvokeAI"] = "InvokeAI" type: Literal["InvokeAI"] = "InvokeAI"
host : str = Field(default="127.0.0.1", description="IP address to bind to", category='Web Server') host : str = Field(default="127.0.0.1", description="IP address to bind to", category='Web Server')
@ -367,35 +388,51 @@ setting environment variables INVOKEAI_<setting>.
model : str = Field(default='stable-diffusion-1.5', description='Initial model name', category='Models') model : str = Field(default='stable-diffusion-1.5', description='Initial model name', category='Models')
embeddings : bool = Field(default=True, description='Load contents of embeddings directory', category='Models') embeddings : bool = Field(default=True, description='Load contents of embeddings directory', category='Models')
log_handlers : List[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>"', category="Logging")
# note - would be better to read the log_format values from logging.py, but this creates circular dependencies issues
log_format : Literal[tuple(['plain','color','syslog','legacy'])] = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style', category="Logging")
log_level : Literal[tuple(["debug","info","warning","error","critical"])] = Field(default="debug", description="Emit logging messages at this level or higher", category="Logging")
#fmt: on #fmt: on
def __init__(self, conf: DictConfig = None, argv: List[str]=None, **kwargs): def parse_args(self, argv: List[str]=None, conf: DictConfig = None, clobber=False):
''' '''
Initialize InvokeAIAppconfig. Update settings with contents of init file, environment, and
command-line settings.
:param conf: alternate Omegaconf dictionary object :param conf: alternate Omegaconf dictionary object
:param argv: aternate sys.argv list :param argv: aternate sys.argv list
:param **kwargs: attributes to initialize with :param clobber: ovewrite any initialization parameters passed during initialization
''' '''
super().__init__(**kwargs)
# Set the runtime root directory. We parse command-line switches here # Set the runtime root directory. We parse command-line switches here
# in order to pick up the --root_dir option. # in order to pick up the --root_dir option.
self.parse_args(argv) super().parse_args(argv)
if conf is None: if conf is None:
try: try:
conf = OmegaConf.load(self.root_dir / INIT_FILE) conf = OmegaConf.load(self.root_dir / INIT_FILE)
except: except:
pass pass
InvokeAISettings.initconf = conf InvokeAISettings.initconf = conf
# parse args again in order to pick up settings in configuration file # parse args again in order to pick up settings in configuration file
self.parse_args(argv) super().parse_args(argv)
# restore initialization values if self.singleton_init and not clobber:
hints = get_type_hints(self) hints = get_type_hints(self.__class__)
for k in kwargs: for k in self.singleton_init:
setattr(self,k,parse_obj_as(hints[k],kwargs[k])) setattr(self,k,parse_obj_as(hints[k],self.singleton_init[k]))
@classmethod
def get_config(cls,**kwargs)->InvokeAIAppConfig:
'''
This returns a singleton InvokeAIAppConfig configuration object.
'''
if cls.singleton_config is None \
or type(cls.singleton_config)!=cls \
or (kwargs and cls.singleton_init != kwargs):
cls.singleton_config = cls(**kwargs)
cls.singleton_init = kwargs
return cls.singleton_config
@property @property
def root_path(self)->Path: def root_path(self)->Path:
''' '''
@ -520,11 +557,8 @@ class PagingArgumentParser(argparse.ArgumentParser):
text = self.format_help() text = self.format_help()
pydoc.pager(text) pydoc.pager(text)
def get_invokeai_config(cls:Type[InvokeAISettings]=InvokeAIAppConfig,**kwargs)->InvokeAIAppConfig: def get_invokeai_config(**kwargs)->InvokeAIAppConfig:
''' '''
This returns a singleton InvokeAIAppConfig configuration object. Legacy function which returns InvokeAIAppConfig.get_config()
''' '''
global global_config return InvokeAIAppConfig.get_config(**kwargs)
if global_config is None or type(global_config)!=cls:
global_config = cls(**kwargs)
return global_config

View File

@ -26,7 +26,6 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
self._table_name = table_name self._table_name = table_name
self._id_field = id_field # TODO: validate that T has this field self._id_field = id_field # TODO: validate that T has this field
self._lock = Lock() self._lock = Lock()
self._conn = sqlite3.connect( self._conn = sqlite3.connect(
self._filename, check_same_thread=False self._filename, check_same_thread=False
) # TODO: figure out a better threading solution ) # TODO: figure out a better threading solution

View File

@ -35,15 +35,19 @@ from transformers import (
CLIPTextModel, CLIPTextModel,
CLIPTokenizer, CLIPTokenizer,
) )
import invokeai.configs as configs import invokeai.configs as configs
from invokeai.app.services.config import (
get_invokeai_config,
InvokeAIAppConfig,
)
from invokeai.frontend.install.model_install import addModelsForm, process_and_execute from invokeai.frontend.install.model_install import addModelsForm, process_and_execute
from invokeai.frontend.install.widgets import ( from invokeai.frontend.install.widgets import (
CenteredButtonPress, CenteredButtonPress,
IntTitleSlider, IntTitleSlider,
set_min_terminal_size, set_min_terminal_size,
) )
from invokeai.backend.config.legacy_arg_parsing import legacy_parser from invokeai.backend.config.legacy_arg_parsing import legacy_parser
from invokeai.backend.config.model_install_backend import ( from invokeai.backend.config.model_install_backend import (
default_dataset, default_dataset,
@ -51,10 +55,8 @@ from invokeai.backend.config.model_install_backend import (
hf_download_with_resume, hf_download_with_resume,
recommended_datasets, recommended_datasets,
) )
from invokeai.app.services.config import (
get_invokeai_config, from invokeai.app.services.config import InvokeAIAppConfig
InvokeAIAppConfig,
)
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
@ -62,7 +64,8 @@ transformers.logging.set_verbosity_error()
# --------------------------globals----------------------- # --------------------------globals-----------------------
config = get_invokeai_config()
config = InvokeAIAppConfig.get_config()
Model_dir = "models" Model_dir = "models"
Weights_dir = "ldm/stable-diffusion-v1/" Weights_dir = "ldm/stable-diffusion-v1/"
@ -634,7 +637,7 @@ def edit_opts(program_opts: Namespace, invokeai_opts: Namespace) -> argparse.Nam
def default_startup_options(init_file: Path) -> Namespace: def default_startup_options(init_file: Path) -> Namespace:
opts = InvokeAIAppConfig(argv=[]) opts = InvokeAIAppConfig.get_config()
outdir = Path(opts.outdir) outdir = Path(opts.outdir)
if not outdir.is_absolute(): if not outdir.is_absolute():
opts.outdir = str(config.root / opts.outdir) opts.outdir = str(config.root / opts.outdir)
@ -699,7 +702,7 @@ def write_opts(opts: Namespace, init_file: Path):
""" """
# this will load current settings # this will load current settings
config = InvokeAIAppConfig() config = InvokeAIAppConfig.get_config()
for key,value in opts.__dict__.items(): for key,value in opts.__dict__.items():
if hasattr(config,key): if hasattr(config,key):
setattr(config,key,value) setattr(config,key,value)
@ -731,7 +734,7 @@ def write_default_options(program_opts: Namespace, initfile: Path):
# yaml format. # yaml format.
def migrate_init_file(legacy_format:Path): def migrate_init_file(legacy_format:Path):
old = legacy_parser.parse_args([f'@{str(legacy_format)}']) old = legacy_parser.parse_args([f'@{str(legacy_format)}'])
new = InvokeAIAppConfig(conf={}) new = InvokeAIAppConfig.get_config()
fields = list(get_type_hints(InvokeAIAppConfig).keys()) fields = list(get_type_hints(InvokeAIAppConfig).keys())
for attr in fields: for attr in fields:
@ -820,8 +823,9 @@ def main():
if old_init_file.exists() and not new_init_file.exists(): if old_init_file.exists() and not new_init_file.exists():
print('** Migrating invokeai.init to invokeai.yaml') print('** Migrating invokeai.init to invokeai.yaml')
migrate_init_file(old_init_file) migrate_init_file(old_init_file)
config = get_invokeai_config() # reread defaults
# Load new init file into config
config.parse_args(argv=[],conf=OmegaConf.load(new_init_file))
if not config.model_conf_path.exists(): if not config.model_conf_path.exists():
initialize_rootdir(config.root, opt.yes_to_all) initialize_rootdir(config.root, opt.yes_to_all)

View File

@ -19,7 +19,7 @@ from tqdm import tqdm
import invokeai.configs as configs import invokeai.configs as configs
from invokeai.app.services.config import get_invokeai_config from invokeai.app.services.config import InvokeAIAppConfig
from ..model_management import ModelManager from ..model_management import ModelManager
from ..stable_diffusion import StableDiffusionGeneratorPipeline from ..stable_diffusion import StableDiffusionGeneratorPipeline
@ -27,7 +27,8 @@ from ..stable_diffusion import StableDiffusionGeneratorPipeline
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
# --------------------------globals----------------------- # --------------------------globals-----------------------
config = get_invokeai_config() config = InvokeAIAppConfig.get_config()
Model_dir = "models" Model_dir = "models"
Weights_dir = "ldm/stable-diffusion-v1/" Weights_dir = "ldm/stable-diffusion-v1/"

View File

@ -6,7 +6,8 @@ be suppressed or deferred
""" """
import numpy as np import numpy as np
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.app.services.config import get_invokeai_config from invokeai.app.services.config import InvokeAIAppConfig
config = InvokeAIAppConfig.get_config()
class PatchMatch: class PatchMatch:
""" """
@ -21,7 +22,6 @@ class PatchMatch:
@classmethod @classmethod
def _load_patch_match(self): def _load_patch_match(self):
config = get_invokeai_config()
if self.tried_load: if self.tried_load:
return return
if config.try_patchmatch: if config.try_patchmatch:

View File

@ -33,10 +33,11 @@ from PIL import Image, ImageOps
from transformers import AutoProcessor, CLIPSegForImageSegmentation from transformers import AutoProcessor, CLIPSegForImageSegmentation
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.app.services.config import get_invokeai_config from invokeai.app.services.config import InvokeAIAppConfig
CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined" CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
CLIPSEG_SIZE = 352 CLIPSEG_SIZE = 352
config = InvokeAIAppConfig.get_config()
class SegmentedGrayscale(object): class SegmentedGrayscale(object):
def __init__(self, image: Image, heatmap: torch.Tensor): def __init__(self, image: Image, heatmap: torch.Tensor):
@ -83,7 +84,6 @@ class Txt2Mask(object):
def __init__(self, device="cpu", refined=False): def __init__(self, device="cpu", refined=False):
logger.info("Initializing clipseg model for text to mask inference") logger.info("Initializing clipseg model for text to mask inference")
config = get_invokeai_config()
# BUG: we are not doing anything with the device option at this time # BUG: we are not doing anything with the device option at this time
self.device = device self.device = device

View File

@ -26,7 +26,7 @@ import torch
from safetensors.torch import load_file from safetensors.torch import load_file
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.app.services.config import get_invokeai_config from invokeai.app.services.config import InvokeAIAppConfig
from .model_manager import ModelManager, SDLegacyType from .model_manager import ModelManager, SDLegacyType
from .model_cache import ModelCache from .model_cache import ModelCache
@ -857,7 +857,7 @@ def convert_ldm_bert_checkpoint(checkpoint, config):
def convert_ldm_clip_checkpoint(checkpoint): def convert_ldm_clip_checkpoint(checkpoint):
text_model = CLIPTextModel.from_pretrained( text_model = CLIPTextModel.from_pretrained(
"openai/clip-vit-large-patch14", cache_dir=get_invokeai_config().cache_dir "openai/clip-vit-large-patch14", cache_dir=InvokeAIAppConfig.get_config().cache_dir
) )
keys = list(checkpoint.keys()) keys = list(checkpoint.keys())
@ -912,7 +912,7 @@ textenc_pattern = re.compile("|".join(protected.keys()))
def convert_paint_by_example_checkpoint(checkpoint): def convert_paint_by_example_checkpoint(checkpoint):
cache_dir = get_invokeai_config().cache_dir cache_dir = InvokeAIAppConfig.get_config().cache_dir
config = CLIPVisionConfig.from_pretrained( config = CLIPVisionConfig.from_pretrained(
"openai/clip-vit-large-patch14", cache_dir=cache_dir "openai/clip-vit-large-patch14", cache_dir=cache_dir
) )
@ -984,7 +984,7 @@ def convert_paint_by_example_checkpoint(checkpoint):
def convert_open_clip_checkpoint(checkpoint): def convert_open_clip_checkpoint(checkpoint):
cache_dir = get_invokeai_config().cache_dir cache_dir = InvokeAIAppConfig.get_config().cache_dir
text_model = CLIPTextModel.from_pretrained( text_model = CLIPTextModel.from_pretrained(
"stabilityai/stable-diffusion-2", subfolder="text_encoder", cache_dir=cache_dir "stabilityai/stable-diffusion-2", subfolder="text_encoder", cache_dir=cache_dir
) )
@ -1120,7 +1120,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
:param vae: A diffusers VAE to load into the pipeline. :param vae: A diffusers VAE to load into the pipeline.
:param vae_path: Path to a checkpoint VAE that will be converted into diffusers and loaded into the pipeline. :param vae_path: Path to a checkpoint VAE that will be converted into diffusers and loaded into the pipeline.
""" """
config = get_invokeai_config() config = InvokeAIAppConfig.get_config()
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore") warnings.simplefilter("ignore")
verbosity = dlogging.get_verbosity() verbosity = dlogging.get_verbosity()

View File

@ -149,7 +149,7 @@ from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig from omegaconf.dictconfig import DictConfig
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.app.services.config import get_invokeai_config from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util import download_with_resume from invokeai.backend.util import download_with_resume
from ..util import CUDA_DEVICE from ..util import CUDA_DEVICE
@ -224,7 +224,7 @@ class ModelManager(object):
raise ValueError('config argument must be an OmegaConf object, a Path or a string') raise ValueError('config argument must be an OmegaConf object, a Path or a string')
# check config version number and update on disk/RAM if necessary # check config version number and update on disk/RAM if necessary
self.globals = get_invokeai_config() self.globals = InvokeAIAppConfig.get_config()
self._update_config_file_version() self._update_config_file_version()
self.logger = logger self.logger = logger
self.cache = ModelCache( self.cache = ModelCache(
@ -1149,13 +1149,17 @@ class ModelManager(object):
"""\ """\
# This file describes the alternative machine learning models # This file describes the alternative machine learning models
# available to InvokeAI script. # available to InvokeAI script.
""" #
# To add a new model, follow the examples below. Each
# model requires a model config file, a weights file,
# and the width and height of the images it
# was trained on.
"""
) )
@classmethod @classmethod
def _delete_model_from_cache(cls,repo_id): def _delete_model_from_cache(cls,repo_id):
cache_info = scan_cache_dir(get_invokeai_config().cache_dir) cache_info = scan_cache_dir(InvokeAIAppConfig.get_config().cache_dir)
# I'm sure there is a way to do this with comprehensions # I'm sure there is a way to do this with comprehensions
# but the code quickly became incomprehensible! # but the code quickly became incomprehensible!
@ -1172,7 +1176,7 @@ class ModelManager(object):
@staticmethod @staticmethod
def _abs_path(path: str | Path) -> Path: def _abs_path(path: str | Path) -> Path:
globals = get_invokeai_config() globals = InvokeAIAppConfig.get_config()
if path is None or Path(path).is_absolute(): if path is None or Path(path).is_absolute():
return path return path
return Path(globals.root_dir, path).resolve() return Path(globals.root_dir, path).resolve()

View File

@ -22,10 +22,12 @@ from compel.prompt_parser import (
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.app.services.config import get_invokeai_config from invokeai.app.services.config import InvokeAIAppConfig
from ..stable_diffusion import InvokeAIDiffuserComponent from ..stable_diffusion import InvokeAIDiffuserComponent
from ..util import torch_dtype from ..util import torch_dtype
config = InvokeAIAppConfig.get_config()
def get_uc_and_c_and_ec(prompt_string, def get_uc_and_c_and_ec(prompt_string,
model: InvokeAIDiffuserComponent, model: InvokeAIDiffuserComponent,
log_tokens=False, skip_normalize_legacy_blend=False): log_tokens=False, skip_normalize_legacy_blend=False):
@ -38,9 +40,7 @@ def get_uc_and_c_and_ec(prompt_string,
textual_inversion_manager=model.textual_inversion_manager, textual_inversion_manager=model.textual_inversion_manager,
dtype_for_device_getter=torch_dtype, dtype_for_device_getter=torch_dtype,
truncate_long_prompts=False, truncate_long_prompts=False,
) )
config = get_invokeai_config()
# get rid of any newline characters # get rid of any newline characters
prompt_string = prompt_string.replace("\n", " ") prompt_string = prompt_string.replace("\n", " ")
@ -283,6 +283,8 @@ def split_weighted_subprompts(text, skip_normalize=False) -> list:
(match.group("prompt").replace("\\:", ":"), float(match.group("weight") or 1)) (match.group("prompt").replace("\\:", ":"), float(match.group("weight") or 1))
for match in re.finditer(prompt_parser, text) for match in re.finditer(prompt_parser, text)
] ]
if len(parsed_prompts) == 0:
return []
if skip_normalize: if skip_normalize:
return parsed_prompts return parsed_prompts
weight_sum = sum(map(lambda x: x[1], parsed_prompts)) weight_sum = sum(map(lambda x: x[1], parsed_prompts))

View File

@ -6,7 +6,7 @@ import numpy as np
import torch import torch
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.app.services.config import get_invokeai_config from invokeai.app.services.config import InvokeAIAppConfig
pretrained_model_url = ( pretrained_model_url = (
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth" "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
@ -18,7 +18,7 @@ class CodeFormerRestoration:
self, codeformer_dir="models/codeformer", codeformer_model_path="codeformer.pth" self, codeformer_dir="models/codeformer", codeformer_model_path="codeformer.pth"
) -> None: ) -> None:
self.globals = get_invokeai_config() self.globals = InvokeAIAppConfig.get_config()
codeformer_dir = self.globals.root_dir / codeformer_dir codeformer_dir = self.globals.root_dir / codeformer_dir
self.model_path = codeformer_dir / codeformer_model_path self.model_path = codeformer_dir / codeformer_model_path
self.codeformer_model_exists = self.model_path.exists() self.codeformer_model_exists = self.model_path.exists()

View File

@ -7,11 +7,11 @@ import torch
from PIL import Image from PIL import Image
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.app.services.config import get_invokeai_config from invokeai.app.services.config import InvokeAIAppConfig
class GFPGAN: class GFPGAN:
def __init__(self, gfpgan_model_path="models/gfpgan/GFPGANv1.4.pth") -> None: def __init__(self, gfpgan_model_path="models/gfpgan/GFPGANv1.4.pth") -> None:
self.globals = get_invokeai_config() self.globals = InvokeAIAppConfig.get_config()
if not os.path.isabs(gfpgan_model_path): if not os.path.isabs(gfpgan_model_path):
gfpgan_model_path = self.globals.root_dir / gfpgan_model_path gfpgan_model_path = self.globals.root_dir / gfpgan_model_path
self.model_path = gfpgan_model_path self.model_path = gfpgan_model_path

View File

@ -6,8 +6,8 @@ from PIL import Image
from PIL.Image import Image as ImageType from PIL.Image import Image as ImageType
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.app.services.config import get_invokeai_config from invokeai.app.services.config import InvokeAIAppConfig
config = get_invokeai_config() config = InvokeAIAppConfig.get_config()
class ESRGAN: class ESRGAN:
def __init__(self, bg_tile_size=400) -> None: def __init__(self, bg_tile_size=400) -> None:

View File

@ -15,9 +15,11 @@ from transformers import AutoFeatureExtractor
import invokeai.assets.web as web_assets import invokeai.assets.web as web_assets
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.app.services.config import get_invokeai_config from invokeai.app.services.config import InvokeAIAppConfig
from .util import CPU_DEVICE from .util import CPU_DEVICE
config = InvokeAIAppConfig.get_config()
class SafetyChecker(object): class SafetyChecker(object):
CAUTION_IMG = "caution.png" CAUTION_IMG = "caution.png"
@ -26,7 +28,6 @@ class SafetyChecker(object):
caution = Image.open(path) caution = Image.open(path)
self.caution_img = caution.resize((caution.width // 2, caution.height // 2)) self.caution_img = caution.resize((caution.width // 2, caution.height // 2))
self.device = device self.device = device
config = get_invokeai_config()
try: try:
safety_model_id = "CompVis/stable-diffusion-safety-checker" safety_model_id = "CompVis/stable-diffusion-safety-checker"

View File

@ -17,15 +17,16 @@ from huggingface_hub import (
hf_hub_url, hf_hub_url,
) )
import invokeai.backend.util.logging as logger from invokeai.backend.util.logging import InvokeAILogger
from invokeai.app.services.config import get_invokeai_config from invokeai.app.services.config import InvokeAIAppConfig
logger = InvokeAILogger.getLogger()
class HuggingFaceConceptsLibrary(object): class HuggingFaceConceptsLibrary(object):
def __init__(self, root=None): def __init__(self, root=None):
""" """
Initialize the Concepts object. May optionally pass a root directory. Initialize the Concepts object. May optionally pass a root directory.
""" """
self.config = get_invokeai_config() self.config = InvokeAIAppConfig.get_config()
self.root = root or self.config.root self.root = root or self.config.root
self.hf_api = HfApi() self.hf_api = HfApi()
self.local_concepts = dict() self.local_concepts = dict()

View File

@ -40,7 +40,7 @@ from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from typing_extensions import ParamSpec from typing_extensions import ParamSpec
from invokeai.app.services.config import get_invokeai_config from invokeai.app.services.config import InvokeAIAppConfig
from ..util import CPU_DEVICE, normalize_device from ..util import CPU_DEVICE, normalize_device
from .diffusion import ( from .diffusion import (
AttentionMapSaver, AttentionMapSaver,
@ -364,7 +364,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
""" """
if xformers is available, use it, otherwise use sliced attention. if xformers is available, use it, otherwise use sliced attention.
""" """
config = get_invokeai_config() config = InvokeAIAppConfig.get_config()
if ( if (
torch.cuda.is_available() torch.cuda.is_available()
and is_xformers_available() and is_xformers_available()

View File

@ -10,7 +10,7 @@ from diffusers.models.attention_processor import AttentionProcessor
from typing_extensions import TypeAlias from typing_extensions import TypeAlias
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.app.services.config import get_invokeai_config from invokeai.app.services.config import InvokeAIAppConfig
from .cross_attention_control import ( from .cross_attention_control import (
Arguments, Arguments,
@ -72,7 +72,7 @@ class InvokeAIDiffuserComponent:
:param model: the unet model to pass through to cross attention control :param model: the unet model to pass through to cross attention control
:param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning) :param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning)
""" """
config = get_invokeai_config() config = InvokeAIAppConfig.get_config()
self.conditioning = None self.conditioning = None
self.model = model self.model = model
self.is_running_diffusers = is_running_diffusers self.is_running_diffusers = is_running_diffusers
@ -112,23 +112,25 @@ class InvokeAIDiffuserComponent:
# TODO resuscitate attention map saving # TODO resuscitate attention map saving
# self.remove_attention_map_saving() # self.remove_attention_map_saving()
def override_cross_attention( # apparently unused code
self, conditioning: ExtraConditioningInfo, step_count: int # TODO: delete
) -> Dict[str, AttentionProcessor]: # def override_cross_attention(
""" # self, conditioning: ExtraConditioningInfo, step_count: int
setup cross attention .swap control. for diffusers this replaces the attention processor, so # ) -> Dict[str, AttentionProcessor]:
the previous attention processor is returned so that the caller can restore it later. # """
""" # setup cross attention .swap control. for diffusers this replaces the attention processor, so
self.conditioning = conditioning # the previous attention processor is returned so that the caller can restore it later.
self.cross_attention_control_context = Context( # """
arguments=self.conditioning.cross_attention_control_args, # self.conditioning = conditioning
step_count=step_count, # self.cross_attention_control_context = Context(
) # arguments=self.conditioning.cross_attention_control_args,
return override_cross_attention( # step_count=step_count,
self.model, # )
self.cross_attention_control_context, # return override_cross_attention(
is_running_diffusers=self.is_running_diffusers, # self.model,
) # self.cross_attention_control_context,
# is_running_diffusers=self.is_running_diffusers,
# )
def restore_default_cross_attention( def restore_default_cross_attention(
self, restore_attention_processor: Optional["AttentionProcessor"] = None self, restore_attention_processor: Optional["AttentionProcessor"] = None

View File

@ -88,7 +88,7 @@ def save_progress(
def parse_args(): def parse_args():
config = InvokeAIAppConfig(argv=[]) config = InvokeAIAppConfig.get_config()
parser = PagingArgumentParser( parser = PagingArgumentParser(
description="Textual inversion training" description="Textual inversion training"
) )

View File

@ -17,3 +17,5 @@ from .util import (
instantiate_from_config, instantiate_from_config,
url_attachment_name, url_attachment_name,
) )

View File

@ -4,15 +4,15 @@ from contextlib import nullcontext
import torch import torch
from torch import autocast from torch import autocast
from invokeai.app.services.config import get_invokeai_config from invokeai.app.services.config import InvokeAIAppConfig
CPU_DEVICE = torch.device("cpu") CPU_DEVICE = torch.device("cpu")
CUDA_DEVICE = torch.device("cuda") CUDA_DEVICE = torch.device("cuda")
MPS_DEVICE = torch.device("mps") MPS_DEVICE = torch.device("mps")
config = InvokeAIAppConfig.get_config()
def choose_torch_device() -> torch.device: def choose_torch_device() -> torch.device:
"""Convenience routine for guessing which GPU device to run model on""" """Convenience routine for guessing which GPU device to run model on"""
config = get_invokeai_config()
if config.always_use_cpu: if config.always_use_cpu:
return CPU_DEVICE return CPU_DEVICE
if torch.cuda.is_available(): if torch.cuda.is_available():
@ -32,7 +32,6 @@ def choose_precision(device: torch.device) -> str:
def torch_dtype(device: torch.device) -> torch.dtype: def torch_dtype(device: torch.device) -> torch.dtype:
config = get_invokeai_config()
if config.full_precision: if config.full_precision:
return torch.float32 return torch.float32
if choose_precision(device) == "float16": if choose_precision(device) == "float16":

View File

@ -31,7 +31,20 @@ IAILogger.debug('this is a debugging message')
""" """
import logging import logging
import logging.handlers
import socket
import urllib.parse
from abc import abstractmethod
from pathlib import Path
from invokeai.app.services.config import InvokeAIAppConfig, get_invokeai_config
try:
import syslog
SYSLOG_AVAILABLE = True
except:
SYSLOG_AVAILABLE = False
# module level functions # module level functions
def debug(msg, *args, **kwargs): def debug(msg, *args, **kwargs):
@ -62,11 +75,77 @@ def getLogger(name: str = None) -> logging.Logger:
return InvokeAILogger.getLogger(name) return InvokeAILogger.getLogger(name)
class InvokeAILogFormatter(logging.Formatter): _FACILITY_MAP = dict(
LOG_KERN = syslog.LOG_KERN,
LOG_USER = syslog.LOG_USER,
LOG_MAIL = syslog.LOG_MAIL,
LOG_DAEMON = syslog.LOG_DAEMON,
LOG_AUTH = syslog.LOG_AUTH,
LOG_LPR = syslog.LOG_LPR,
LOG_NEWS = syslog.LOG_NEWS,
LOG_UUCP = syslog.LOG_UUCP,
LOG_CRON = syslog.LOG_CRON,
LOG_SYSLOG = syslog.LOG_SYSLOG,
LOG_LOCAL0 = syslog.LOG_LOCAL0,
LOG_LOCAL1 = syslog.LOG_LOCAL1,
LOG_LOCAL2 = syslog.LOG_LOCAL2,
LOG_LOCAL3 = syslog.LOG_LOCAL3,
LOG_LOCAL4 = syslog.LOG_LOCAL4,
LOG_LOCAL5 = syslog.LOG_LOCAL5,
LOG_LOCAL6 = syslog.LOG_LOCAL6,
LOG_LOCAL7 = syslog.LOG_LOCAL7,
) if SYSLOG_AVAILABLE else dict()
_SOCK_MAP = dict(
SOCK_STREAM = socket.SOCK_STREAM,
SOCK_DGRAM = socket.SOCK_DGRAM,
)
class InvokeAIFormatter(logging.Formatter):
'''
Base class for logging formatter
'''
def format(self, record):
formatter = logging.Formatter(self.log_fmt(record.levelno))
return formatter.format(record)
@abstractmethod
def log_fmt(self, levelno: int)->str:
pass
class InvokeAISyslogFormatter(InvokeAIFormatter):
'''
Formatting for syslog
'''
def log_fmt(self, levelno: int)->str:
return '%(name)s [%(process)d] <%(levelname)s> %(message)s'
class InvokeAILegacyLogFormatter(InvokeAIFormatter):
'''
Formatting for the InvokeAI Logger (legacy version)
'''
FORMATS = {
logging.DEBUG: " | %(message)s",
logging.INFO: ">> %(message)s",
logging.WARNING: "** %(message)s",
logging.ERROR: "*** %(message)s",
logging.CRITICAL: "### %(message)s",
}
def log_fmt(self,levelno:int)->str:
return self.FORMATS.get(levelno)
class InvokeAIPlainLogFormatter(InvokeAIFormatter):
'''
Custom Formatting for the InvokeAI Logger (plain version)
'''
def log_fmt(self, levelno: int)->str:
return "[%(asctime)s]::[%(name)s]::%(levelname)s --> %(message)s"
class InvokeAIColorLogFormatter(InvokeAIFormatter):
''' '''
Custom Formatting for the InvokeAI Logger Custom Formatting for the InvokeAI Logger
''' '''
# Color Codes # Color Codes
grey = "\x1b[38;20m" grey = "\x1b[38;20m"
yellow = "\x1b[33;20m" yellow = "\x1b[33;20m"
@ -88,23 +167,109 @@ class InvokeAILogFormatter(logging.Formatter):
logging.CRITICAL: bold_red + log_format + reset logging.CRITICAL: bold_red + log_format + reset
} }
def format(self, record): def log_fmt(self, levelno: int)->str:
log_fmt = self.FORMATS.get(record.levelno) return self.FORMATS.get(levelno)
formatter = logging.Formatter(log_fmt, datefmt="%d-%m-%Y %H:%M:%S")
return formatter.format(record)
LOG_FORMATTERS = {
'plain': InvokeAIPlainLogFormatter,
'color': InvokeAIColorLogFormatter,
'syslog': InvokeAISyslogFormatter,
'legacy': InvokeAILegacyLogFormatter,
}
class InvokeAILogger(object): class InvokeAILogger(object):
loggers = dict() loggers = dict()
@classmethod @classmethod
def getLogger(cls, name: str = 'InvokeAI') -> logging.Logger: def getLogger(cls, name: str = 'InvokeAI') -> logging.Logger:
config = get_invokeai_config()
if name not in cls.loggers: if name not in cls.loggers:
logger = logging.getLogger(name) logger = logging.getLogger(name)
logger.setLevel(logging.DEBUG) logger.setLevel(config.log_level.upper()) # yes, strings work here
ch = logging.StreamHandler() for ch in cls.getLoggers(config):
fmt = InvokeAILogFormatter() logger.addHandler(ch)
ch.setFormatter(fmt)
logger.addHandler(ch)
cls.loggers[name] = logger cls.loggers[name] = logger
return cls.loggers[name] return cls.loggers[name]
@classmethod
def getLoggers(cls, config: InvokeAIAppConfig) -> list[logging.Handler]:
handler_strs = config.log_handlers
print(f'handler_strs={handler_strs}')
handlers = list()
for handler in handler_strs:
handler_name,*args = handler.split('=',2)
args = args[0] if len(args) > 0 else None
# console is the only handler that gets a custom formatter
if handler_name=='console':
formatter = LOG_FORMATTERS[config.log_format]
ch = logging.StreamHandler()
ch.setFormatter(formatter())
handlers.append(ch)
elif handler_name=='syslog':
ch = cls._parse_syslog_args(args)
ch.setFormatter(InvokeAISyslogFormatter())
handlers.append(ch)
elif handler_name=='file':
handlers.append(cls._parse_file_args(args))
elif handler_name=='http':
handlers.append(cls._parse_http_args(args))
return handlers
@staticmethod
def _parse_syslog_args(
args: str=None
)-> logging.Handler:
if not SYSLOG_AVAILABLE:
raise ValueError("syslog is not available on this system")
if not args:
args='/dev/log' if Path('/dev/log').exists() else 'address:localhost:514'
syslog_args = dict()
try:
for a in args.split(','):
arg_name,*arg_value = a.split(':',2)
if arg_name=='address':
host,*port = arg_value
port = 514 if len(port)==0 else int(port[0])
syslog_args['address'] = (host,port)
elif arg_name=='facility':
syslog_args['facility'] = _FACILITY_MAP[arg_value[0]]
elif arg_name=='socktype':
syslog_args['socktype'] = _SOCK_MAP[arg_value[0]]
else:
syslog_args['address'] = arg_name
except:
raise ValueError(f"{args} is not a value argument list for syslog logging")
return logging.handlers.SysLogHandler(**syslog_args)
@staticmethod
def _parse_file_args(args: str=None)-> logging.Handler:
if not args:
raise ValueError("please provide filename for file logging using format 'file=/path/to/logfile.txt'")
return logging.FileHandler(args)
@staticmethod
def _parse_http_args(args: str=None)-> logging.Handler:
if not args:
raise ValueError("please provide destination for http logging using format 'http=url'")
arg_list = args.split(',')
url = urllib.parse.urlparse(arg_list.pop(0))
if url.scheme != 'http':
raise ValueError(f"the http logging module can only log to HTTP URLs, but {url.scheme} was specified")
host = url.hostname
path = url.path
port = url.port or 80
syslog_args = dict()
for a in arg_list:
arg_name, *arg_value = a.split(':',2)
if arg_name=='method':
arg_value = arg_value[0] if len(arg_value)>0 else 'GET'
syslog_args[arg_name] = arg_value
else: # TODO: Provide support for SSL context and credentials
pass
return logging.handlers.HTTPHandler(f'{host}:{port}',path,**syslog_args)

View File

@ -40,13 +40,13 @@ from .widgets import (
TextBox, TextBox,
set_min_terminal_size, set_min_terminal_size,
) )
from invokeai.app.services.config import get_invokeai_config from invokeai.app.services.config import InvokeAIAppConfig
# minimum size for the UI # minimum size for the UI
MIN_COLS = 120 MIN_COLS = 120
MIN_LINES = 45 MIN_LINES = 45
config = get_invokeai_config() config = InvokeAIAppConfig.get_config()
class addModelsForm(npyscreen.FormMultiPage): class addModelsForm(npyscreen.FormMultiPage):
# for responsive resizing - disabled # for responsive resizing - disabled

View File

@ -20,12 +20,12 @@ from npyscreen import widget
from omegaconf import OmegaConf from omegaconf import OmegaConf
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.services.config import get_invokeai_config from invokeai.services.config import InvokeAIAppConfig
from ...backend.model_management import ModelManager from ...backend.model_management import ModelManager
from ...frontend.install.widgets import FloatTitleSlider from ...frontend.install.widgets import FloatTitleSlider
DEST_MERGED_MODEL_DIR = "merged_models" DEST_MERGED_MODEL_DIR = "merged_models"
config = get_invokeai_config() config = InvokeAIAppConfig.get_config()
def merge_diffusion_models( def merge_diffusion_models(
model_ids_or_paths: List[Union[str, Path]], model_ids_or_paths: List[Union[str, Path]],

View File

@ -22,7 +22,7 @@ from omegaconf import OmegaConf
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.app.services.config import get_invokeai_config from invokeai.app.services.config import InvokeAIAppConfig
from ...backend.training import ( from ...backend.training import (
do_textual_inversion_training, do_textual_inversion_training,
parse_args parse_args
@ -423,7 +423,7 @@ def do_front_end(args: Namespace):
save_args(args) save_args(args)
try: try:
do_textual_inversion_training(get_invokeai_config(),**args) do_textual_inversion_training(InvokeAIAppConfig.get_config(),**args)
copy_to_embeddings_folder(args) copy_to_embeddings_folder(args)
except Exception as e: except Exception as e:
logger.error("An exception occurred during training. The exception was:") logger.error("An exception occurred during training. The exception was:")
@ -436,7 +436,7 @@ def main():
global config global config
args = parse_args() args = parse_args()
config = get_invokeai_config(argv=[]) config = InvokeAIAppConfig.get_config()
# change root if needed # change root if needed
if args.root_dir: if args.root_dir:

View File

@ -26,10 +26,10 @@ We need to start the nodes web server, which serves the OpenAPI schema to the ge
```bash ```bash
# from the repo root # from the repo root
python scripts/invoke-new.py --web python scripts/invokeai-web.py
``` ```
2. Generate the API client. 2. Generate the API client.
```bash ```bash
# from invokeai/frontend/web/ # from invokeai/frontend/web/

View File

@ -12,7 +12,14 @@ Code in `invokeai/frontend/web/` if you want to have a look.
## Stack ## Stack
State management is Redux via [Redux Toolkit](https://github.com/reduxjs/redux-toolkit). Communication with server is a mix of HTTP and [socket.io](https://github.com/socketio/socket.io-client) (with a custom redux middleware to help). State management is Redux via [Redux Toolkit](https://github.com/reduxjs/redux-toolkit). We lean heavily on RTK:
- `createAsyncThunk` for HTTP requests
- `createEntityAdapter` for fetching images and models
- `createListenerMiddleware` for workflows
The API client and associated types are generated from the OpenAPI schema. See API_CLIENT.md.
Communication with server is a mix of HTTP and [socket.io](https://github.com/socketio/socket.io-client) (with a simple socket.io redux middleware to help).
[Chakra-UI](https://github.com/chakra-ui/chakra-ui) for components and styling. [Chakra-UI](https://github.com/chakra-ui/chakra-ui) for components and styling.
@ -37,9 +44,15 @@ From `invokeai/frontend/web/` run `yarn install` to get everything set up.
Start everything in dev mode: Start everything in dev mode:
1. Start the dev server: `yarn dev` 1. Start the dev server: `yarn dev`
2. Start the InvokeAI Nodes backend: `python scripts/invokeai-new.py --web # run from the repo root` 2. Start the InvokeAI Nodes backend: `python scripts/invokeai-web.py # run from the repo root`
3. Point your browser to the dev server address e.g. <http://localhost:5173/> 3. Point your browser to the dev server address e.g. <http://localhost:5173/>
#### VSCode Remote Dev
We've noticed an intermittent issue with the VSCode Remote Dev port forwarding. If you use this feature of VSCode, you may intermittently click the Invoke button and then get nothing until the request times out. Suggest disabling the IDE's port forwarding feature and doing it manually via SSH:
`ssh -L 9090:localhost:9090 -L 5173:localhost:5173 user@host`
### Production builds ### Production builds
For a number of technical and logistical reasons, we need to commit UI build artefacts to the repo. For a number of technical and logistical reasons, we need to commit UI build artefacts to the repo.

View File

@ -60,6 +60,8 @@
"@chakra-ui/styled-system": "^2.9.0", "@chakra-ui/styled-system": "^2.9.0",
"@chakra-ui/theme-tools": "^2.0.16", "@chakra-ui/theme-tools": "^2.0.16",
"@dagrejs/graphlib": "^2.1.12", "@dagrejs/graphlib": "^2.1.12",
"@dnd-kit/core": "^6.0.8",
"@dnd-kit/modifiers": "^6.0.1",
"@emotion/react": "^11.10.6", "@emotion/react": "^11.10.6",
"@emotion/styled": "^11.10.6", "@emotion/styled": "^11.10.6",
"@floating-ui/react-dom": "^2.0.0", "@floating-ui/react-dom": "^2.0.0",
@ -87,7 +89,7 @@
"react-dropzone": "^14.2.3", "react-dropzone": "^14.2.3",
"react-hotkeys-hook": "4.4.0", "react-hotkeys-hook": "4.4.0",
"react-i18next": "^12.2.2", "react-i18next": "^12.2.2",
"react-icons": "^4.7.1", "react-icons": "^4.9.0",
"react-konva": "^18.2.7", "react-konva": "^18.2.7",
"react-redux": "^8.0.5", "react-redux": "^8.0.5",
"react-resizable-panels": "^0.0.42", "react-resizable-panels": "^0.0.42",

View File

@ -21,6 +21,7 @@ import { ReactNode, memo, useCallback, useEffect, useState } from 'react';
import { APP_HEIGHT, APP_WIDTH } from 'theme/util/constants'; import { APP_HEIGHT, APP_WIDTH } from 'theme/util/constants';
import GlobalHotkeys from './GlobalHotkeys'; import GlobalHotkeys from './GlobalHotkeys';
import Toaster from './Toaster'; import Toaster from './Toaster';
import DeleteImageModal from 'features/gallery/components/DeleteImageModal';
const DEFAULT_CONFIG = {}; const DEFAULT_CONFIG = {};
@ -76,18 +77,21 @@ const App = ({
{isLightboxEnabled && <Lightbox />} {isLightboxEnabled && <Lightbox />}
<ImageUploader> <ImageUploader>
<Grid <Grid
gap={4} sx={{
p={4} gap: 4,
gridAutoRows="min-content auto" p: 4,
w={APP_WIDTH} gridAutoRows: 'min-content auto',
h={APP_HEIGHT} w: 'full',
h: 'full',
}}
> >
{headerComponent || <SiteHeader />} {headerComponent || <SiteHeader />}
<Flex <Flex
gap={4} sx={{
w={{ base: '100vw', xl: 'full' }} gap: 4,
h="full" w: 'full',
flexDir={{ base: 'column', xl: 'row' }} h: 'full',
}}
> >
<InvokeTabs /> <InvokeTabs />
</Flex> </Flex>
@ -130,6 +134,7 @@ const App = ({
<FloatingGalleryButton /> <FloatingGalleryButton />
</Portal> </Portal>
</Grid> </Grid>
<DeleteImageModal />
<Toaster /> <Toaster />
<GlobalHotkeys /> <GlobalHotkeys />
</> </>

View File

@ -0,0 +1,71 @@
import {
DndContext,
DragEndEvent,
DragOverlay,
DragStartEvent,
KeyboardSensor,
MouseSensor,
TouchSensor,
pointerWithin,
useSensor,
useSensors,
} from '@dnd-kit/core';
import { PropsWithChildren, memo, useCallback, useState } from 'react';
import OverlayDragImage from './OverlayDragImage';
import { ImageDTO } from 'services/api';
import { isImageDTO } from 'services/types/guards';
import { snapCenterToCursor } from '@dnd-kit/modifiers';
type ImageDndContextProps = PropsWithChildren;
const ImageDndContext = (props: ImageDndContextProps) => {
const [draggedImage, setDraggedImage] = useState<ImageDTO | null>(null);
const handleDragStart = useCallback((event: DragStartEvent) => {
const dragData = event.active.data.current;
if (dragData && 'image' in dragData && isImageDTO(dragData.image)) {
setDraggedImage(dragData.image);
}
}, []);
const handleDragEnd = useCallback(
(event: DragEndEvent) => {
const handleDrop = event.over?.data.current?.handleDrop;
if (handleDrop && typeof handleDrop === 'function' && draggedImage) {
handleDrop(draggedImage);
}
setDraggedImage(null);
},
[draggedImage]
);
const mouseSensor = useSensor(MouseSensor, {
activationConstraint: { distance: 15 },
});
const touchSensor = useSensor(TouchSensor, {
activationConstraint: { distance: 15 },
});
// TODO: Use KeyboardSensor - needs composition of multiple collisionDetection algos
// Alternatively, fix `rectIntersection` collection detection to work with the drag overlay
// (currently the drag element collision rect is not correctly calculated)
// const keyboardSensor = useSensor(KeyboardSensor);
const sensors = useSensors(mouseSensor, touchSensor);
return (
<DndContext
onDragStart={handleDragStart}
onDragEnd={handleDragEnd}
sensors={sensors}
collisionDetection={pointerWithin}
>
{props.children}
<DragOverlay dropAnimation={null} modifiers={[snapCenterToCursor]}>
{draggedImage && <OverlayDragImage image={draggedImage} />}
</DragOverlay>
</DndContext>
);
};
export default memo(ImageDndContext);

View File

@ -0,0 +1,36 @@
import { Box, Image } from '@chakra-ui/react';
import { memo } from 'react';
import { ImageDTO } from 'services/api';
type OverlayDragImageProps = {
image: ImageDTO;
};
const OverlayDragImage = (props: OverlayDragImageProps) => {
return (
<Box
style={{
width: '100%',
height: '100%',
display: 'flex',
alignItems: 'center',
justifyContent: 'center',
userSelect: 'none',
cursor: 'grabbing',
opacity: 0.5,
}}
>
<Image
sx={{
maxW: 36,
maxH: 36,
borderRadius: 'base',
shadow: 'dark-lg',
}}
src={props.image.thumbnail_url}
/>
</Box>
);
};
export default memo(OverlayDragImage);

View File

@ -16,6 +16,11 @@ import { PartialAppConfig } from 'app/types/invokeai';
import '../../i18n'; import '../../i18n';
import { socketMiddleware } from 'services/events/middleware'; import { socketMiddleware } from 'services/events/middleware';
import { Middleware } from '@reduxjs/toolkit'; import { Middleware } from '@reduxjs/toolkit';
import ImageDndContext from './ImageDnd/ImageDndContext';
import {
DeleteImageContext,
DeleteImageContextProvider,
} from 'app/contexts/DeleteImageContext';
const App = lazy(() => import('./App')); const App = lazy(() => import('./App'));
const ThemeLocaleProvider = lazy(() => import('./ThemeLocaleProvider')); const ThemeLocaleProvider = lazy(() => import('./ThemeLocaleProvider'));
@ -69,11 +74,15 @@ const InvokeAIUI = ({
<Provider store={store}> <Provider store={store}>
<React.Suspense fallback={<Loading />}> <React.Suspense fallback={<Loading />}>
<ThemeLocaleProvider> <ThemeLocaleProvider>
<App <ImageDndContext>
config={config} <DeleteImageContextProvider>
headerComponent={headerComponent} <App
setIsReady={setIsReady} config={config}
/> headerComponent={headerComponent}
setIsReady={setIsReady}
/>
</DeleteImageContextProvider>
</ImageDndContext>
</ThemeLocaleProvider> </ThemeLocaleProvider>
</React.Suspense> </React.Suspense>
</Provider> </Provider>

View File

@ -0,0 +1,203 @@
import { useDisclosure } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { requestedImageDeletion } from 'features/gallery/store/actions';
import { systemSelector } from 'features/system/store/systemSelectors';
import {
PropsWithChildren,
createContext,
useCallback,
useEffect,
useState,
} from 'react';
import { ImageDTO } from 'services/api';
import { RootState } from 'app/store/store';
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
import { controlNetSelector } from 'features/controlNet/store/controlNetSlice';
import { nodesSelecter } from 'features/nodes/store/nodesSlice';
import { generationSelector } from 'features/parameters/store/generationSelectors';
import { some } from 'lodash-es';
export type ImageUsage = {
isInitialImage: boolean;
isCanvasImage: boolean;
isNodesImage: boolean;
isControlNetImage: boolean;
};
export const selectImageUsage = createSelector(
[
generationSelector,
canvasSelector,
nodesSelecter,
controlNetSelector,
(state: RootState, image_name?: string) => image_name,
],
(generation, canvas, nodes, controlNet, image_name) => {
const isInitialImage = generation.initialImage?.image_name === image_name;
const isCanvasImage = canvas.layerState.objects.some(
(obj) => obj.kind === 'image' && obj.image.image_name === image_name
);
const isNodesImage = nodes.nodes.some((node) => {
return some(
node.data.inputs,
(input) =>
input.type === 'image' && input.value?.image_name === image_name
);
});
const isControlNetImage = some(
controlNet.controlNets,
(c) =>
c.controlImage?.image_name === image_name ||
c.processedControlImage?.image_name === image_name
);
const imageUsage: ImageUsage = {
isInitialImage,
isCanvasImage,
isNodesImage,
isControlNetImage,
};
return imageUsage;
},
defaultSelectorOptions
);
type DeleteImageContextValue = {
/**
* Whether the delete image dialog is open.
*/
isOpen: boolean;
/**
* Closes the delete image dialog.
*/
onClose: () => void;
/**
* Opens the delete image dialog and handles all deletion-related checks.
*/
onDelete: (image?: ImageDTO) => void;
/**
* The image pending deletion
*/
image?: ImageDTO;
/**
* The features in which this image is used
*/
imageUsage?: ImageUsage;
/**
* Immediately deletes an image.
*
* You probably don't want to use this - use `onDelete` instead.
*/
onImmediatelyDelete: () => void;
};
export const DeleteImageContext = createContext<DeleteImageContextValue>({
isOpen: false,
onClose: () => undefined,
onImmediatelyDelete: () => undefined,
onDelete: () => undefined,
});
const selector = createSelector(
[systemSelector],
(system) => {
const { isProcessing, isConnected, shouldConfirmOnDelete } = system;
return {
canDeleteImage: isConnected && !isProcessing,
shouldConfirmOnDelete,
};
},
defaultSelectorOptions
);
type Props = PropsWithChildren;
export const DeleteImageContextProvider = (props: Props) => {
const { canDeleteImage, shouldConfirmOnDelete } = useAppSelector(selector);
const [imageToDelete, setImageToDelete] = useState<ImageDTO>();
const dispatch = useAppDispatch();
const { isOpen, onOpen, onClose } = useDisclosure();
// Check where the image to be deleted is used (eg init image, controlnet, etc.)
const imageUsage = useAppSelector((state) =>
selectImageUsage(state, imageToDelete?.image_name)
);
// Clean up after deleting or dismissing the modal
const closeAndClearImageToDelete = useCallback(() => {
setImageToDelete(undefined);
onClose();
}, [onClose]);
// Dispatch the actual deletion action, to be handled by listener middleware
const handleActualDeletion = useCallback(
(image: ImageDTO) => {
dispatch(requestedImageDeletion({ image, imageUsage }));
closeAndClearImageToDelete();
},
[closeAndClearImageToDelete, dispatch, imageUsage]
);
// This is intended to be called by the delete button in the dialog
const onImmediatelyDelete = useCallback(() => {
if (canDeleteImage && imageToDelete) {
handleActualDeletion(imageToDelete);
}
closeAndClearImageToDelete();
}, [
canDeleteImage,
imageToDelete,
closeAndClearImageToDelete,
handleActualDeletion,
]);
const handleGatedDeletion = useCallback(
(image: ImageDTO) => {
if (shouldConfirmOnDelete || some(imageUsage)) {
// If we should confirm on delete, or if the image is in use, open the dialog
onOpen();
} else {
handleActualDeletion(image);
}
},
[imageUsage, shouldConfirmOnDelete, onOpen, handleActualDeletion]
);
// Consumers of the context call this to delete an image
const onDelete = useCallback((image?: ImageDTO) => {
if (!image) {
return;
}
// Set the image to delete, then let the effect call the actual deletion
setImageToDelete(image);
}, []);
useEffect(() => {
// We need to use an effect here to trigger the image usage selector, else we get a stale value
if (imageToDelete) {
handleGatedDeletion(imageToDelete);
}
}, [handleGatedDeletion, imageToDelete]);
return (
<DeleteImageContext.Provider
value={{
isOpen,
image: imageToDelete,
onClose: closeAndClearImageToDelete,
onDelete,
onImmediatelyDelete,
imageUsage,
}}
>
{props.children}
</DeleteImageContext.Provider>
);
};

View File

@ -1,4 +1,5 @@
import { canvasPersistDenylist } from 'features/canvas/store/canvasPersistDenylist'; import { canvasPersistDenylist } from 'features/canvas/store/canvasPersistDenylist';
import { controlNetDenylist } from 'features/controlNet/store/controlNetDenylist';
import { galleryPersistDenylist } from 'features/gallery/store/galleryPersistDenylist'; import { galleryPersistDenylist } from 'features/gallery/store/galleryPersistDenylist';
import { lightboxPersistDenylist } from 'features/lightbox/store/lightboxPersistDenylist'; import { lightboxPersistDenylist } from 'features/lightbox/store/lightboxPersistDenylist';
import { nodesPersistDenylist } from 'features/nodes/store/nodesPersistDenylist'; import { nodesPersistDenylist } from 'features/nodes/store/nodesPersistDenylist';
@ -23,6 +24,7 @@ const serializationDenylist: {
system: systemPersistDenylist, system: systemPersistDenylist,
// config: configPersistDenyList, // config: configPersistDenyList,
ui: uiPersistDenylist, ui: uiPersistDenylist,
controlNet: controlNetDenylist,
// hotkeys: hotkeysPersistDenylist, // hotkeys: hotkeysPersistDenylist,
}; };

View File

@ -1,4 +1,5 @@
import { initialCanvasState } from 'features/canvas/store/canvasSlice'; import { initialCanvasState } from 'features/canvas/store/canvasSlice';
import { initialControlNetState } from 'features/controlNet/store/controlNetSlice';
import { initialGalleryState } from 'features/gallery/store/gallerySlice'; import { initialGalleryState } from 'features/gallery/store/gallerySlice';
import { initialImagesState } from 'features/gallery/store/imagesSlice'; import { initialImagesState } from 'features/gallery/store/imagesSlice';
import { initialLightboxState } from 'features/lightbox/store/lightboxSlice'; import { initialLightboxState } from 'features/lightbox/store/lightboxSlice';
@ -28,6 +29,7 @@ const initialStates: {
ui: initialUIState, ui: initialUIState,
hotkeys: initialHotkeysState, hotkeys: initialHotkeysState,
images: initialImagesState, images: initialImagesState,
controlNet: initialControlNetState,
}; };
export const unserialize: UnserializeFunction = (data, key) => { export const unserialize: UnserializeFunction = (data, key) => {

View File

@ -70,6 +70,9 @@ import {
import { addStagingAreaImageSavedListener } from './listeners/stagingAreaImageSaved'; import { addStagingAreaImageSavedListener } from './listeners/stagingAreaImageSaved';
import { addCommitStagingAreaImageListener } from './listeners/addCommitStagingAreaImageListener'; import { addCommitStagingAreaImageListener } from './listeners/addCommitStagingAreaImageListener';
import { addImageCategoriesChangedListener } from './listeners/imageCategoriesChanged'; import { addImageCategoriesChangedListener } from './listeners/imageCategoriesChanged';
import { addControlNetImageProcessedListener } from './listeners/controlNetImageProcessed';
import { addControlNetAutoProcessListener } from './listeners/controlNetAutoProcess';
import { addUpdateImageUrlsOnConnectListener } from './listeners/updateImageUrlsOnConnect';
export const listenerMiddleware = createListenerMiddleware(); export const listenerMiddleware = createListenerMiddleware();
@ -173,3 +176,10 @@ addReceivedPageOfImagesRejectedListener();
// Gallery // Gallery
addImageCategoriesChangedListener(); addImageCategoriesChangedListener();
// ControlNet
addControlNetImageProcessedListener();
addControlNetAutoProcessListener();
// Update image URLs on connect
addUpdateImageUrlsOnConnectListener();

View File

@ -28,6 +28,13 @@ export const addCanvasCopiedToClipboardListener = () => {
} }
copyBlobToClipboard(blob); copyBlobToClipboard(blob);
dispatch(
addToast({
title: 'Canvas Copied to Clipboard',
status: 'success',
})
);
}, },
}); });
}; };

View File

@ -27,7 +27,8 @@ export const addCanvasDownloadedAsImageListener = () => {
return; return;
} }
downloadBlob(blob, 'mergedCanvas.png'); downloadBlob(blob, 'canvas.png');
dispatch(addToast({ title: 'Canvas Downloaded', status: 'success' }));
}, },
}); });
}; };

View File

@ -1,22 +1,20 @@
import { canvasMerged } from 'features/canvas/store/actions'; import { canvasMerged } from 'features/canvas/store/actions';
import { startAppListening } from '..'; import { startAppListening } from '..';
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
import { imageUploaded } from 'services/thunks/image'; import { imageUploaded } from 'services/thunks/image';
import { v4 as uuidv4 } from 'uuid';
import { setMergedCanvas } from 'features/canvas/store/canvasSlice'; import { setMergedCanvas } from 'features/canvas/store/canvasSlice';
import { getCanvasBaseLayer } from 'features/canvas/util/konvaInstanceProvider'; import { getCanvasBaseLayer } from 'features/canvas/util/konvaInstanceProvider';
import { getFullBaseLayerBlob } from 'features/canvas/util/getFullBaseLayerBlob';
const moduleLog = log.child({ namespace: 'canvasCopiedToClipboardListener' }); const moduleLog = log.child({ namespace: 'canvasCopiedToClipboardListener' });
export const MERGED_CANVAS_FILENAME = 'mergedCanvas.png';
export const addCanvasMergedListener = () => { export const addCanvasMergedListener = () => {
startAppListening({ startAppListening({
actionCreator: canvasMerged, actionCreator: canvasMerged,
effect: async (action, { dispatch, getState, take }) => { effect: async (action, { dispatch, getState, take }) => {
const state = getState(); const blob = await getFullBaseLayerBlob();
const blob = await getBaseLayerBlob(state, true);
if (!blob) { if (!blob) {
moduleLog.error('Problem getting base layer blob'); moduleLog.error('Problem getting base layer blob');
@ -48,12 +46,12 @@ export const addCanvasMergedListener = () => {
relativeTo: canvasBaseLayer.getParent(), relativeTo: canvasBaseLayer.getParent(),
}); });
const filename = `mergedCanvas_${uuidv4()}.png`; const imageUploadedRequest = dispatch(
dispatch(
imageUploaded({ imageUploaded({
formData: { formData: {
file: new File([blob], filename, { type: 'image/png' }), file: new File([blob], MERGED_CANVAS_FILENAME, {
type: 'image/png',
}),
}, },
imageCategory: 'general', imageCategory: 'general',
isIntermediate: true, isIntermediate: true,
@ -61,9 +59,11 @@ export const addCanvasMergedListener = () => {
); );
const [{ payload }] = await take( const [{ payload }] = await take(
(action): action is ReturnType<typeof imageUploaded.fulfilled> => (
imageUploaded.fulfilled.match(action) && uploadedImageAction
action.meta.arg.formData.file.name === filename ): uploadedImageAction is ReturnType<typeof imageUploaded.fulfilled> =>
imageUploaded.fulfilled.match(uploadedImageAction) &&
uploadedImageAction.meta.requestId === imageUploadedRequest.requestId
); );
const mergedCanvasImage = payload; const mergedCanvasImage = payload;

View File

@ -4,9 +4,10 @@ import { log } from 'app/logging/useLogger';
import { imageUploaded } from 'services/thunks/image'; import { imageUploaded } from 'services/thunks/image';
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob'; import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
import { v4 as uuidv4 } from 'uuid';
import { imageUpserted } from 'features/gallery/store/imagesSlice'; import { imageUpserted } from 'features/gallery/store/imagesSlice';
export const SAVED_CANVAS_FILENAME = 'savedCanvas.png';
const moduleLog = log.child({ namespace: 'canvasSavedToGalleryListener' }); const moduleLog = log.child({ namespace: 'canvasSavedToGalleryListener' });
export const addCanvasSavedToGalleryListener = () => { export const addCanvasSavedToGalleryListener = () => {
@ -15,7 +16,7 @@ export const addCanvasSavedToGalleryListener = () => {
effect: async (action, { dispatch, getState, take }) => { effect: async (action, { dispatch, getState, take }) => {
const state = getState(); const state = getState();
const blob = await getBaseLayerBlob(state, true); const blob = await getBaseLayerBlob(state);
if (!blob) { if (!blob) {
moduleLog.error('Problem getting base layer blob'); moduleLog.error('Problem getting base layer blob');
@ -29,12 +30,12 @@ export const addCanvasSavedToGalleryListener = () => {
return; return;
} }
const filename = `mergedCanvas_${uuidv4()}.png`; const imageUploadedRequest = dispatch(
dispatch(
imageUploaded({ imageUploaded({
formData: { formData: {
file: new File([blob], filename, { type: 'image/png' }), file: new File([blob], SAVED_CANVAS_FILENAME, {
type: 'image/png',
}),
}, },
imageCategory: 'general', imageCategory: 'general',
isIntermediate: false, isIntermediate: false,
@ -42,9 +43,11 @@ export const addCanvasSavedToGalleryListener = () => {
); );
const [{ payload: uploadedImageDTO }] = await take( const [{ payload: uploadedImageDTO }] = await take(
(action): action is ReturnType<typeof imageUploaded.fulfilled> => (
imageUploaded.fulfilled.match(action) && uploadedImageAction
action.meta.arg.formData.file.name === filename ): uploadedImageAction is ReturnType<typeof imageUploaded.fulfilled> =>
imageUploaded.fulfilled.match(uploadedImageAction) &&
uploadedImageAction.meta.requestId === imageUploadedRequest.requestId
); );
dispatch(imageUpserted(uploadedImageDTO)); dispatch(imageUpserted(uploadedImageDTO));

View File

@ -0,0 +1,59 @@
import { AnyAction } from '@reduxjs/toolkit';
import { startAppListening } from '..';
import { log } from 'app/logging/useLogger';
import { controlNetImageProcessed } from 'features/controlNet/store/actions';
import {
controlNetImageChanged,
controlNetProcessorParamsChanged,
controlNetProcessorTypeChanged,
} from 'features/controlNet/store/controlNetSlice';
import { RootState } from 'app/store/store';
const moduleLog = log.child({ namespace: 'controlNet' });
const predicate = (action: AnyAction, state: RootState) => {
const isActionMatched =
controlNetProcessorParamsChanged.match(action) ||
controlNetImageChanged.match(action) ||
controlNetProcessorTypeChanged.match(action);
if (!isActionMatched) {
return false;
}
const { controlImage, processorType } =
state.controlNet.controlNets[action.payload.controlNetId];
const isProcessorSelected = processorType !== 'none';
const isBusy = state.system.isProcessing;
const hasControlImage = Boolean(controlImage);
return isProcessorSelected && !isBusy && hasControlImage;
};
/**
* Listener that automatically processes a ControlNet image when its processor parameters are changed.
*
* The network request is debounced by 1 second.
*/
export const addControlNetAutoProcessListener = () => {
startAppListening({
predicate,
effect: async (
action,
{ dispatch, getState, cancelActiveListeners, delay }
) => {
const { controlNetId } = action.payload;
// Cancel any in-progress instances of this listener
cancelActiveListeners();
// Delay before starting actual work
await delay(300);
dispatch(controlNetImageProcessed({ controlNetId }));
},
});
};

View File

@ -0,0 +1,93 @@
import { startAppListening } from '..';
import { imageMetadataReceived } from 'services/thunks/image';
import { log } from 'app/logging/useLogger';
import { controlNetImageProcessed } from 'features/controlNet/store/actions';
import { Graph } from 'services/api';
import { sessionCreated } from 'services/thunks/session';
import { sessionReadyToInvoke } from 'features/system/store/actions';
import { socketInvocationComplete } from 'services/events/actions';
import { isImageOutput } from 'services/types/guards';
import { controlNetProcessedImageChanged } from 'features/controlNet/store/controlNetSlice';
import { pick } from 'lodash-es';
const moduleLog = log.child({ namespace: 'controlNet' });
export const addControlNetImageProcessedListener = () => {
startAppListening({
actionCreator: controlNetImageProcessed,
effect: async (
action,
{ dispatch, getState, take, unsubscribe, subscribe }
) => {
const { controlNetId } = action.payload;
const controlNet = getState().controlNet.controlNets[controlNetId];
if (!controlNet.controlImage) {
moduleLog.error('Unable to process ControlNet image');
return;
}
// ControlNet one-off procressing graph is just the processor node, no edges.
// Also we need to grab the image.
const graph: Graph = {
nodes: {
[controlNet.processorNode.id]: {
...controlNet.processorNode,
is_intermediate: true,
image: pick(controlNet.controlImage, [
'image_name',
'image_origin',
]),
},
},
};
// Create a session to run the graph & wait til it's ready to invoke
const sessionCreatedAction = dispatch(sessionCreated({ graph }));
const [sessionCreatedFulfilledAction] = await take(
(action): action is ReturnType<typeof sessionCreated.fulfilled> =>
sessionCreated.fulfilled.match(action) &&
action.meta.requestId === sessionCreatedAction.requestId
);
const sessionId = sessionCreatedFulfilledAction.payload.id;
// Invoke the session & wait til it's complete
dispatch(sessionReadyToInvoke());
const [invocationCompleteAction] = await take(
(action): action is ReturnType<typeof socketInvocationComplete> =>
socketInvocationComplete.match(action) &&
action.payload.data.graph_execution_state_id === sessionId
);
// We still have to check the output type
if (isImageOutput(invocationCompleteAction.payload.data.result)) {
const { image_name } =
invocationCompleteAction.payload.data.result.image;
// Wait for the ImageDTO to be received
const [imageMetadataReceivedAction] = await take(
(
action
): action is ReturnType<typeof imageMetadataReceived.fulfilled> =>
imageMetadataReceived.fulfilled.match(action) &&
action.payload.image_name === image_name
);
const processedControlImage = imageMetadataReceivedAction.payload;
moduleLog.debug(
{ data: { arg: action.payload, processedControlImage } },
'ControlNet image processed'
);
// Update the processed image in the store
dispatch(
controlNetProcessedImageChanged({
controlNetId,
processedControlImage,
})
);
}
},
});
};

View File

@ -6,10 +6,13 @@ import { clamp } from 'lodash-es';
import { imageSelected } from 'features/gallery/store/gallerySlice'; import { imageSelected } from 'features/gallery/store/gallerySlice';
import { import {
imageRemoved, imageRemoved,
imagesAdapter,
selectImagesEntities, selectImagesEntities,
selectImagesIds, selectImagesIds,
} from 'features/gallery/store/imagesSlice'; } from 'features/gallery/store/imagesSlice';
import { resetCanvas } from 'features/canvas/store/canvasSlice';
import { controlNetReset } from 'features/controlNet/store/controlNetSlice';
import { clearInitialImage } from 'features/parameters/store/generationSlice';
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
const moduleLog = log.child({ namespace: 'addRequestedImageDeletionListener' }); const moduleLog = log.child({ namespace: 'addRequestedImageDeletionListener' });
@ -20,11 +23,7 @@ export const addRequestedImageDeletionListener = () => {
startAppListening({ startAppListening({
actionCreator: requestedImageDeletion, actionCreator: requestedImageDeletion,
effect: (action, { dispatch, getState }) => { effect: (action, { dispatch, getState }) => {
const image = action.payload; const { image, imageUsage } = action.payload;
if (!image) {
moduleLog.warn('No image provided');
return;
}
const { image_name, image_origin } = image; const { image_name, image_origin } = image;
@ -58,8 +57,28 @@ export const addRequestedImageDeletionListener = () => {
} }
} }
// We need to reset the features where the image is in use - none of these work if their image(s) don't exist
if (imageUsage.isCanvasImage) {
dispatch(resetCanvas());
}
if (imageUsage.isControlNetImage) {
dispatch(controlNetReset());
}
if (imageUsage.isInitialImage) {
dispatch(clearInitialImage());
}
if (imageUsage.isNodesImage) {
dispatch(nodeEditorReset());
}
// Preemptively remove from gallery
dispatch(imageRemoved(image_name)); dispatch(imageRemoved(image_name));
// Delete from server
dispatch( dispatch(
imageDeleted({ imageName: image_name, imageOrigin: image_origin }) imageDeleted({ imageName: image_name, imageOrigin: image_origin })
); );
@ -74,9 +93,7 @@ export const addImageDeletedPendingListener = () => {
startAppListening({ startAppListening({
actionCreator: imageDeleted.pending, actionCreator: imageDeleted.pending,
effect: (action, { dispatch, getState }) => { effect: (action, { dispatch, getState }) => {
const { imageName, imageOrigin } = action.meta.arg; //
// Preemptively remove the image from the gallery
imagesAdapter.removeOne(getState().images, imageName);
}, },
}); });
}; };

View File

@ -1,6 +1,6 @@
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { startAppListening } from '..'; import { startAppListening } from '..';
import { imageMetadataReceived } from 'services/thunks/image'; import { imageMetadataReceived, imageUpdated } from 'services/thunks/image';
import { imageUpserted } from 'features/gallery/store/imagesSlice'; import { imageUpserted } from 'features/gallery/store/imagesSlice';
const moduleLog = log.child({ namespace: 'image' }); const moduleLog = log.child({ namespace: 'image' });
@ -10,10 +10,29 @@ export const addImageMetadataReceivedFulfilledListener = () => {
actionCreator: imageMetadataReceived.fulfilled, actionCreator: imageMetadataReceived.fulfilled,
effect: (action, { getState, dispatch }) => { effect: (action, { getState, dispatch }) => {
const image = action.payload; const image = action.payload;
if (image.is_intermediate) {
const state = getState();
if (
image.session_id === state.canvas.layerState.stagingArea.sessionId &&
state.canvas.shouldAutoSave
) {
dispatch(
imageUpdated({
imageName: image.image_name,
imageOrigin: image.image_origin,
requestBody: { is_intermediate: false },
})
);
} else if (image.is_intermediate) {
// No further actions needed for intermediate images // No further actions needed for intermediate images
moduleLog.trace(
{ data: { image } },
'Image metadata received (intermediate), skipping'
);
return; return;
} }
moduleLog.debug({ data: { image } }, 'Image metadata received'); moduleLog.debug({ data: { image } }, 'Image metadata received');
dispatch(imageUpserted(image)); dispatch(imageUpserted(image));
}, },

View File

@ -3,6 +3,8 @@ import { imageUploaded } from 'services/thunks/image';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { imageUpserted } from 'features/gallery/store/imagesSlice'; import { imageUpserted } from 'features/gallery/store/imagesSlice';
import { SAVED_CANVAS_FILENAME } from './canvasSavedToGallery';
import { MERGED_CANVAS_FILENAME } from './canvasMerged';
const moduleLog = log.child({ namespace: 'image' }); const moduleLog = log.child({ namespace: 'image' });
@ -19,9 +21,22 @@ export const addImageUploadedFulfilledListener = () => {
return; return;
} }
const state = getState(); const originalFileName = action.meta.arg.formData.file.name;
dispatch(imageUpserted(image)); dispatch(imageUpserted(image));
if (originalFileName === SAVED_CANVAS_FILENAME) {
dispatch(
addToast({ title: 'Canvas Saved to Gallery', status: 'success' })
);
return;
}
if (originalFileName === MERGED_CANVAS_FILENAME) {
dispatch(addToast({ title: 'Canvas Merged', status: 'success' }));
return;
}
dispatch(addToast({ title: 'Image Uploaded', status: 'success' })); dispatch(addToast({ title: 'Image Uploaded', status: 'success' }));
}, },
}); });

View File

@ -1,7 +1,7 @@
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { startAppListening } from '..'; import { startAppListening } from '..';
import { imageUrlsReceived } from 'services/thunks/image'; import { imageUrlsReceived } from 'services/thunks/image';
import { imagesAdapter } from 'features/gallery/store/imagesSlice'; import { imageUpdatedOne } from 'features/gallery/store/imagesSlice';
const moduleLog = log.child({ namespace: 'image' }); const moduleLog = log.child({ namespace: 'image' });
@ -14,13 +14,12 @@ export const addImageUrlsReceivedFulfilledListener = () => {
const { image_name, image_url, thumbnail_url } = image; const { image_name, image_url, thumbnail_url } = image;
imagesAdapter.updateOne(getState().images, { dispatch(
id: image_name, imageUpdatedOne({
changes: { id: image_name,
image_url, changes: { image_url, thumbnail_url },
thumbnail_url, })
}, );
});
}, },
}); });
}; };

View File

@ -2,12 +2,10 @@ import { initialImageChanged } from 'features/parameters/store/generationSlice';
import { t } from 'i18next'; import { t } from 'i18next';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
import { startAppListening } from '..'; import { startAppListening } from '..';
import { import { initialImageSelected } from 'features/parameters/store/actions';
initialImageSelected,
isImageDTO,
} from 'features/parameters/store/actions';
import { makeToast } from 'app/components/Toaster'; import { makeToast } from 'app/components/Toaster';
import { selectImagesById } from 'features/gallery/store/imagesSlice'; import { selectImagesById } from 'features/gallery/store/imagesSlice';
import { isImageDTO } from 'services/types/guards';
export const addInitialImageSelectedListener = () => { export const addInitialImageSelectedListener = () => {
startAppListening({ startAppListening({

View File

@ -0,0 +1,93 @@
import { socketConnected } from 'services/events/actions';
import { startAppListening } from '..';
import { createSelector } from '@reduxjs/toolkit';
import { generationSelector } from 'features/parameters/store/generationSelectors';
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
import { nodesSelecter } from 'features/nodes/store/nodesSlice';
import { controlNetSelector } from 'features/controlNet/store/controlNetSlice';
import { ImageDTO } from 'services/api';
import { forEach, uniqBy } from 'lodash-es';
import { imageUrlsReceived } from 'services/thunks/image';
import { log } from 'app/logging/useLogger';
import { selectImagesEntities } from 'features/gallery/store/imagesSlice';
const moduleLog = log.child({ namespace: 'images' });
const selectAllUsedImages = createSelector(
[
generationSelector,
canvasSelector,
nodesSelecter,
controlNetSelector,
selectImagesEntities,
],
(generation, canvas, nodes, controlNet, imageEntities) => {
const allUsedImages: ImageDTO[] = [];
if (generation.initialImage) {
allUsedImages.push(generation.initialImage);
}
canvas.layerState.objects.forEach((obj) => {
if (obj.kind === 'image') {
allUsedImages.push(obj.image);
}
});
nodes.nodes.forEach((node) => {
forEach(node.data.inputs, (input) => {
if (input.type === 'image' && input.value) {
allUsedImages.push(input.value);
}
});
});
forEach(controlNet.controlNets, (c) => {
if (c.controlImage) {
allUsedImages.push(c.controlImage);
}
if (c.processedControlImage) {
allUsedImages.push(c.processedControlImage);
}
});
forEach(imageEntities, (image) => {
if (image) {
allUsedImages.push(image);
}
});
const uniqueImages = uniqBy(allUsedImages, 'image_name');
return uniqueImages;
}
);
export const addUpdateImageUrlsOnConnectListener = () => {
startAppListening({
actionCreator: socketConnected,
effect: async (action, { dispatch, getState, take }) => {
const state = getState();
if (!state.config.shouldUpdateImagesOnConnect) {
return;
}
const allUsedImages = selectAllUsedImages(state);
moduleLog.trace(
{ data: allUsedImages },
`Fetching new image URLs for ${allUsedImages.length} images`
);
allUsedImages.forEach(({ image_name, image_origin }) => {
dispatch(
imageUrlsReceived({
imageName: image_name,
imageOrigin: image_origin,
})
);
});
},
});
};

View File

@ -13,6 +13,7 @@ import galleryReducer from 'features/gallery/store/gallerySlice';
import imagesReducer from 'features/gallery/store/imagesSlice'; import imagesReducer from 'features/gallery/store/imagesSlice';
import lightboxReducer from 'features/lightbox/store/lightboxSlice'; import lightboxReducer from 'features/lightbox/store/lightboxSlice';
import generationReducer from 'features/parameters/store/generationSlice'; import generationReducer from 'features/parameters/store/generationSlice';
import controlNetReducer from 'features/controlNet/store/controlNetSlice';
import postprocessingReducer from 'features/parameters/store/postprocessingSlice'; import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
import systemReducer from 'features/system/store/systemSlice'; import systemReducer from 'features/system/store/systemSlice';
// import sessionReducer from 'features/system/store/sessionSlice'; // import sessionReducer from 'features/system/store/sessionSlice';
@ -45,6 +46,7 @@ const allReducers = {
ui: uiReducer, ui: uiReducer,
hotkeys: hotkeysReducer, hotkeys: hotkeysReducer,
images: imagesReducer, images: imagesReducer,
controlNet: controlNetReducer,
// session: sessionReducer, // session: sessionReducer,
}; };
@ -62,6 +64,7 @@ const rememberedKeys: (keyof typeof allReducers)[] = [
'postprocessing', 'postprocessing',
'system', 'system',
'ui', 'ui',
'controlNet',
// 'hotkeys', // 'hotkeys',
// 'config', // 'config',
]; ];

View File

@ -95,6 +95,7 @@ export type AppFeature =
* A disable-able Stable Diffusion feature * A disable-able Stable Diffusion feature
*/ */
export type SDFeature = export type SDFeature =
| 'controlNet'
| 'noise' | 'noise'
| 'variation' | 'variation'
| 'symmetry' | 'symmetry'
@ -107,12 +108,9 @@ export type SDFeature =
*/ */
export type AppConfig = { export type AppConfig = {
/** /**
* Whether or not URLs should be transformed to use a different host * Whether or not we should update image urls when image loading errors
*/
shouldTransformUrls: boolean;
/**
* Whether or not we need to re-fetch images
*/ */
shouldUpdateImagesOnConnect: boolean;
disabledTabs: InvokeTabName[]; disabledTabs: InvokeTabName[];
disabledFeatures: AppFeature[]; disabledFeatures: AppFeature[];
disabledSDFeatures: SDFeature[]; disabledSDFeatures: SDFeature[];

View File

@ -1,17 +0,0 @@
import { Checkbox, CheckboxProps } from '@chakra-ui/react';
import { memo, ReactNode } from 'react';
type IAICheckboxProps = CheckboxProps & {
label: string | ReactNode;
};
const IAICheckbox = (props: IAICheckboxProps) => {
const { label, ...rest } = props;
return (
<Checkbox colorScheme="accent" {...rest}>
{label}
</Checkbox>
);
};
export default memo(IAICheckbox);

View File

@ -49,7 +49,7 @@ const IAICollapse = (props: IAIToggleCollapseProps) => {
/> />
)} )}
</Flex> </Flex>
<Collapse in={isOpen} animateOpacity> <Collapse in={isOpen} animateOpacity style={{ overflow: 'unset' }}>
<Box sx={{ p: 4, borderBottomRadius: 'base', bg: 'base.800' }}> <Box sx={{ p: 4, borderBottomRadius: 'base', bg: 'base.800' }}>
{children} {children}
</Box> </Box>

View File

@ -1,4 +1,4 @@
import { CheckIcon } from '@chakra-ui/icons'; import { CheckIcon, ChevronUpIcon } from '@chakra-ui/icons';
import { import {
Box, Box,
Flex, Flex,
@ -10,7 +10,6 @@ import {
GridItem, GridItem,
List, List,
ListItem, ListItem,
Select,
Text, Text,
Tooltip, Tooltip,
TooltipProps, TooltipProps,
@ -19,7 +18,8 @@ import { autoUpdate, offset, shift, useFloating } from '@floating-ui/react-dom';
import { useSelect } from 'downshift'; import { useSelect } from 'downshift';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react'; import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import { memo } from 'react'; import { memo, useMemo } from 'react';
import { getInputOutlineStyles } from 'theme/util/getInputOutlineStyles';
export type ItemTooltips = { [key: string]: string }; export type ItemTooltips = { [key: string]: string };
@ -34,6 +34,7 @@ type IAICustomSelectProps = {
buttonProps?: FlexProps; buttonProps?: FlexProps;
tooltip?: string; tooltip?: string;
tooltipProps?: Omit<TooltipProps, 'children'>; tooltipProps?: Omit<TooltipProps, 'children'>;
ellipsisPosition?: 'start' | 'end';
}; };
const IAICustomSelect = (props: IAICustomSelectProps) => { const IAICustomSelect = (props: IAICustomSelectProps) => {
@ -48,6 +49,7 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
tooltip, tooltip,
buttonProps, buttonProps,
tooltipProps, tooltipProps,
ellipsisPosition = 'end',
} = props; } = props;
const { const {
@ -69,6 +71,14 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
middleware: [offset(4), shift({ crossAxis: true, padding: 8 })], middleware: [offset(4), shift({ crossAxis: true, padding: 8 })],
}); });
const labelTextDirection = useMemo(() => {
if (ellipsisPosition === 'start') {
return document.dir === 'rtl' ? 'ltr' : 'rtl';
}
return document.dir;
}, [ellipsisPosition]);
return ( return (
<FormControl sx={{ w: 'full' }} {...formControlProps}> <FormControl sx={{ w: 'full' }} {...formControlProps}>
{label && ( {label && (
@ -82,20 +92,44 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
</FormLabel> </FormLabel>
)} )}
<Tooltip label={tooltip} {...tooltipProps}> <Tooltip label={tooltip} {...tooltipProps}>
<Select <Flex
{...getToggleButtonProps({ ref: refs.setReference })} {...getToggleButtonProps({ ref: refs.setReference })}
{...buttonProps} {...buttonProps}
as={Flex}
sx={{ sx={{
alignItems: 'center', alignItems: 'center',
userSelect: 'none', userSelect: 'none',
cursor: 'pointer', cursor: 'pointer',
overflow: 'hidden',
width: 'full',
py: 1,
px: 2,
gap: 2,
justifyContent: 'space-between',
...getInputOutlineStyles(),
}} }}
> >
<Text sx={{ fontSize: 'sm', fontWeight: 500, color: 'base.100' }}> <Text
sx={{
fontSize: 'sm',
fontWeight: 500,
color: 'base.100',
whiteSpace: 'nowrap',
overflow: 'hidden',
textOverflow: 'ellipsis',
direction: labelTextDirection,
}}
>
{selectedItem} {selectedItem}
</Text> </Text>
</Select> <ChevronUpIcon
sx={{
color: 'base.300',
transform: isOpen ? 'rotate(0deg)' : 'rotate(180deg)',
transitionProperty: 'common',
transitionDuration: 'normal',
}}
/>
</Flex>
</Tooltip> </Tooltip>
<Box {...getMenuProps()}> <Box {...getMenuProps()}>
{isOpen && ( {isOpen && (
@ -104,11 +138,10 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
ref={refs.setFloating} ref={refs.setFloating}
sx={{ sx={{
...floatingStyles, ...floatingStyles,
width: 'max-content',
top: 0, top: 0,
left: 0, insetInlineStart: 0,
flexDirection: 'column', flexDirection: 'column',
zIndex: 1, zIndex: 2,
bg: 'base.800', bg: 'base.800',
borderRadius: 'base', borderRadius: 'base',
border: '1px', border: '1px',
@ -118,61 +151,72 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
px: 0, px: 0,
h: 'fit-content', h: 'fit-content',
maxH: 64, maxH: 64,
minW: 48,
}} }}
> >
<OverlayScrollbarsComponent> <OverlayScrollbarsComponent>
{items.map((item, index) => ( {items.map((item, index) => {
<Tooltip const isSelected = selectedItem === item;
isDisabled={!itemTooltips} const isHighlighted = highlightedIndex === index;
key={`${item}${index}`} const fontWeight = isSelected ? 700 : 500;
label={itemTooltips?.[item]} const bg = isHighlighted
hasArrow ? 'base.700'
placement="right" : isSelected
> ? 'base.750'
<ListItem : undefined;
sx={{ return (
bg: highlightedIndex === index ? 'base.700' : undefined, <Tooltip
py: 1, isDisabled={!itemTooltips}
paddingInlineStart: 3,
paddingInlineEnd: 6,
cursor: 'pointer',
transitionProperty: 'common',
transitionDuration: '0.15s',
}}
key={`${item}${index}`} key={`${item}${index}`}
{...getItemProps({ item, index })} label={itemTooltips?.[item]}
hasArrow
placement="right"
> >
{withCheckIcon ? ( <ListItem
<Grid gridTemplateColumns="1.25rem auto"> sx={{
<GridItem> bg,
{selectedItem === item && <CheckIcon boxSize={2} />} py: 1,
</GridItem> paddingInlineStart: 3,
<GridItem> paddingInlineEnd: 6,
<Text cursor: 'pointer',
sx={{ transitionProperty: 'common',
fontSize: 'sm', transitionDuration: '0.15s',
color: 'base.100', }}
fontWeight: 500, key={`${item}${index}`}
}} {...getItemProps({ item, index })}
> >
{item} {withCheckIcon ? (
</Text> <Grid gridTemplateColumns="1.25rem auto">
</GridItem> <GridItem>
</Grid> {isSelected && <CheckIcon boxSize={2} />}
) : ( </GridItem>
<Text <GridItem>
sx={{ <Text
fontSize: 'sm', sx={{
color: 'base.100', fontSize: 'sm',
fontWeight: 500, color: 'base.100',
}} fontWeight,
> }}
{item} >
</Text> {item}
)} </Text>
</ListItem> </GridItem>
</Tooltip> </Grid>
))} ) : (
<Text
sx={{
fontSize: 'sm',
color: 'base.50',
fontWeight,
}}
>
{item}
</Text>
)}
</ListItem>
</Tooltip>
);
})}
</OverlayScrollbarsComponent> </OverlayScrollbarsComponent>
</List> </List>
)} )}

View File

@ -0,0 +1,167 @@
import { Box, Flex, Icon, IconButtonProps, Image } from '@chakra-ui/react';
import { useDraggable, useDroppable } from '@dnd-kit/core';
import { useCombinedRefs } from '@dnd-kit/utilities';
import IAIIconButton from 'common/components/IAIIconButton';
import { IAIImageFallback } from 'common/components/IAIImageFallback';
import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay';
import { AnimatePresence } from 'framer-motion';
import { ReactElement, SyntheticEvent } from 'react';
import { memo, useRef } from 'react';
import { FaImage, FaTimes } from 'react-icons/fa';
import { ImageDTO } from 'services/api';
import { v4 as uuidv4 } from 'uuid';
import IAIDropOverlay from './IAIDropOverlay';
type IAIDndImageProps = {
image: ImageDTO | null | undefined;
onDrop: (droppedImage: ImageDTO) => void;
onReset?: () => void;
onError?: (event: SyntheticEvent<HTMLImageElement>) => void;
onLoad?: (event: SyntheticEvent<HTMLImageElement>) => void;
resetIconSize?: IconButtonProps['size'];
withResetIcon?: boolean;
withMetadataOverlay?: boolean;
isDragDisabled?: boolean;
isDropDisabled?: boolean;
fallback?: ReactElement;
payloadImage?: ImageDTO | null | undefined;
minSize?: number;
};
const IAIDndImage = (props: IAIDndImageProps) => {
const {
image,
onDrop,
onReset,
onError,
resetIconSize = 'md',
withResetIcon = false,
withMetadataOverlay = false,
isDropDisabled = false,
isDragDisabled = false,
fallback = <IAIImageFallback />,
payloadImage,
minSize = 24,
} = props;
const dndId = useRef(uuidv4());
const {
isOver,
setNodeRef: setDroppableRef,
active,
} = useDroppable({
id: dndId.current,
disabled: isDropDisabled,
data: {
handleDrop: onDrop,
},
});
const {
attributes,
listeners,
setNodeRef: setDraggableRef,
} = useDraggable({
id: dndId.current,
data: {
image: payloadImage ? payloadImage : image,
},
disabled: isDragDisabled,
});
const setNodeRef = useCombinedRefs(setDroppableRef, setDraggableRef);
return (
<Flex
sx={{
width: 'full',
height: 'full',
alignItems: 'center',
justifyContent: 'center',
position: 'relative',
minW: minSize,
minH: minSize,
userSelect: 'none',
cursor: 'grab',
}}
{...attributes}
{...listeners}
ref={setNodeRef}
>
{image && (
<Flex
sx={{
w: 'full',
h: 'full',
position: 'relative',
alignItems: 'center',
justifyContent: 'center',
}}
>
<Image
src={image.image_url}
fallbackStrategy="beforeLoadOrError"
fallback={fallback}
onError={onError}
objectFit="contain"
draggable={false}
sx={{
maxW: 'full',
maxH: 'full',
borderRadius: 'base',
}}
/>
{withMetadataOverlay && <ImageMetadataOverlay image={image} />}
{onReset && withResetIcon && (
<Box
sx={{
position: 'absolute',
top: 0,
right: 0,
p: 2,
}}
>
<IAIIconButton
size={resetIconSize}
tooltip="Reset Image"
aria-label="Reset Image"
icon={<FaTimes />}
onClick={onReset}
/>
</Box>
)}
<AnimatePresence>
{active && <IAIDropOverlay isOver={isOver} />}
</AnimatePresence>
</Flex>
)}
{!image && (
<>
<Flex
sx={{
minH: minSize,
bg: 'base.850',
w: 'full',
h: 'full',
alignItems: 'center',
justifyContent: 'center',
borderRadius: 'base',
}}
>
<Icon
as={FaImage}
sx={{
boxSize: 24,
color: 'base.500',
}}
/>
</Flex>
<AnimatePresence>
{active && <IAIDropOverlay isOver={isOver} />}
</AnimatePresence>
</>
)}
</Flex>
);
};
export default memo(IAIDndImage);

View File

@ -0,0 +1,91 @@
import { Flex, Text } from '@chakra-ui/react';
import { motion } from 'framer-motion';
import { memo, useRef } from 'react';
import { v4 as uuidv4 } from 'uuid';
type Props = {
isOver: boolean;
label?: string;
};
export const IAIDropOverlay = (props: Props) => {
const { isOver, label = 'Drop' } = props;
const motionId = useRef(uuidv4());
return (
<motion.div
key={motionId.current}
initial={{
opacity: 0,
}}
animate={{
opacity: 1,
transition: { duration: 0.1 },
}}
exit={{
opacity: 0,
transition: { duration: 0.1 },
}}
>
<Flex
sx={{
position: 'absolute',
top: 0,
left: 0,
w: 'full',
h: 'full',
}}
>
<Flex
sx={{
position: 'absolute',
top: 0,
left: 0,
w: 'full',
h: 'full',
bg: 'base.900',
opacity: 0.7,
borderRadius: 'base',
alignItems: 'center',
justifyContent: 'center',
transitionProperty: 'common',
transitionDuration: '0.1s',
}}
/>
<Flex
sx={{
position: 'absolute',
top: 0,
left: 0,
w: 'full',
h: 'full',
opacity: 1,
borderWidth: 2,
borderColor: isOver ? 'base.200' : 'base.500',
borderRadius: 'base',
borderStyle: 'dashed',
transitionProperty: 'common',
transitionDuration: '0.1s',
alignItems: 'center',
justifyContent: 'center',
}}
>
<Text
sx={{
fontSize: '2xl',
fontWeight: 600,
transform: isOver ? 'scale(1.1)' : 'scale(1)',
color: isOver ? 'base.100' : 'base.500',
transitionProperty: 'common',
transitionDuration: '0.1s',
}}
>
{label}
</Text>
</Flex>
</Flex>
</motion.div>
);
};
export default memo(IAIDropOverlay);

View File

@ -0,0 +1,25 @@
import {
Checkbox,
CheckboxProps,
FormControl,
FormControlProps,
FormLabel,
} from '@chakra-ui/react';
import { memo, ReactNode } from 'react';
type IAIFullCheckboxProps = CheckboxProps & {
label: string | ReactNode;
formControlProps?: FormControlProps;
};
const IAIFullCheckbox = (props: IAIFullCheckboxProps) => {
const { label, formControlProps, ...rest } = props;
return (
<FormControl {...formControlProps}>
<FormLabel>{label}</FormLabel>
<Checkbox colorScheme="accent" {...rest} />
</FormControl>
);
};
export default memo(IAIFullCheckbox);

View File

@ -0,0 +1,27 @@
import { Flex, FlexProps, Spinner, SpinnerProps } from '@chakra-ui/react';
type Props = FlexProps & {
spinnerProps?: SpinnerProps;
};
export const IAIImageFallback = (props: Props) => {
const { spinnerProps, ...rest } = props;
const { sx, ...restFlexProps } = rest;
return (
<Flex
sx={{
bg: 'base.900',
opacity: 0.7,
w: 'full',
h: 'full',
alignItems: 'center',
justifyContent: 'center',
borderRadius: 'base',
...sx,
}}
{...restFlexProps}
>
<Spinner size="xl" {...spinnerProps} />
</Flex>
);
};

View File

@ -0,0 +1,19 @@
import { Checkbox, CheckboxProps, Text } from '@chakra-ui/react';
import { memo, ReactElement } from 'react';
type IAISimpleCheckboxProps = CheckboxProps & {
label: string | ReactElement;
};
const IAISimpleCheckbox = (props: IAISimpleCheckboxProps) => {
const { label, ...rest } = props;
return (
<Checkbox colorScheme="accent" {...rest}>
<Text color="base.200" fontSize="md">
{label}
</Text>
</Checkbox>
);
};
export default memo(IAISimpleCheckbox);

View File

@ -40,7 +40,7 @@ import IAIIconButton, { IAIIconButtonProps } from './IAIIconButton';
import { roundDownToMultiple } from 'common/util/roundDownToMultiple'; import { roundDownToMultiple } from 'common/util/roundDownToMultiple';
export type IAIFullSliderProps = { export type IAIFullSliderProps = {
label: string; label?: string;
value: number; value: number;
min?: number; min?: number;
max?: number; max?: number;
@ -178,9 +178,11 @@ const IAISlider = (props: IAIFullSliderProps) => {
isDisabled={isDisabled} isDisabled={isDisabled}
{...sliderFormControlProps} {...sliderFormControlProps}
> >
<FormLabel {...sliderFormLabelProps} mb={-1}> {label && (
{label} <FormLabel {...sliderFormLabelProps} mb={-1}>
</FormLabel> {label}
</FormLabel>
)}
<HStack w="100%" gap={2} alignItems="center"> <HStack w="100%" gap={2} alignItems="center">
<Slider <Slider
@ -203,6 +205,7 @@ const IAISlider = (props: IAIFullSliderProps) => {
sx={{ sx={{
insetInlineStart: '0 !important', insetInlineStart: '0 !important',
insetInlineEnd: 'unset !important', insetInlineEnd: 'unset !important',
mt: 1.5,
}} }}
{...sliderMarkProps} {...sliderMarkProps}
> >
@ -213,6 +216,7 @@ const IAISlider = (props: IAIFullSliderProps) => {
sx={{ sx={{
insetInlineStart: 'unset !important', insetInlineStart: 'unset !important',
insetInlineEnd: '0 !important', insetInlineEnd: '0 !important',
mt: 1.5,
}} }}
{...sliderMarkProps} {...sliderMarkProps}
> >

View File

@ -5,6 +5,7 @@ import {
FormLabelProps, FormLabelProps,
Switch, Switch,
SwitchProps, SwitchProps,
Tooltip,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { memo } from 'react'; import { memo } from 'react';
@ -13,6 +14,7 @@ interface Props extends SwitchProps {
width?: string | number; width?: string | number;
formControlProps?: FormControlProps; formControlProps?: FormControlProps;
formLabelProps?: FormLabelProps; formLabelProps?: FormLabelProps;
tooltip?: string;
} }
/** /**
@ -25,22 +27,27 @@ const IAISwitch = (props: Props) => {
width = 'auto', width = 'auto',
formControlProps, formControlProps,
formLabelProps, formLabelProps,
tooltip,
...rest ...rest
} = props; } = props;
return ( return (
<FormControl <Tooltip label={tooltip} hasArrow placement="top" isDisabled={!tooltip}>
isDisabled={isDisabled} <FormControl
width={width} isDisabled={isDisabled}
display="flex" width={width}
gap={4} display="flex"
alignItems="center" gap={4}
{...formControlProps} alignItems="center"
> {...formControlProps}
<FormLabel my={1} flexGrow={1} {...formLabelProps}> >
{label} {label && (
</FormLabel> <FormLabel my={1} flexGrow={1} {...formLabelProps}>
<Switch {...rest} /> {label}
</FormControl> </FormLabel>
)}
<Switch {...rest} />
</FormControl>
</Tooltip>
); );
}; };

View File

@ -1,5 +1,5 @@
import { Badge, Flex } from '@chakra-ui/react'; import { Badge, Flex } from '@chakra-ui/react';
import { isNumber, isString } from 'lodash-es'; import { isString } from 'lodash-es';
import { useMemo } from 'react'; import { useMemo } from 'react';
import { ImageDTO } from 'services/api'; import { ImageDTO } from 'services/api';
@ -8,14 +8,6 @@ type ImageMetadataOverlayProps = {
}; };
const ImageMetadataOverlay = ({ image }: ImageMetadataOverlayProps) => { const ImageMetadataOverlay = ({ image }: ImageMetadataOverlayProps) => {
const dimensions = useMemo(() => {
if (!isNumber(image.metadata?.width) || isNumber(!image.metadata?.height)) {
return;
}
return `${image.metadata?.width} × ${image.metadata?.height}`;
}, [image.metadata]);
const model = useMemo(() => { const model = useMemo(() => {
if (!isString(image.metadata?.model)) { if (!isString(image.metadata?.model)) {
return; return;
@ -31,17 +23,15 @@ const ImageMetadataOverlay = ({ image }: ImageMetadataOverlayProps) => {
flexDirection: 'column', flexDirection: 'column',
position: 'absolute', position: 'absolute',
top: 0, top: 0,
right: 0, insetInlineStart: 0,
p: 2, p: 2,
alignItems: 'flex-end', alignItems: 'flex-start',
gap: 2, gap: 2,
}} }}
> >
{dimensions && ( <Badge variant="solid" colorScheme="base">
<Badge variant="solid" colorScheme="base"> {image.width} × {image.height}
{dimensions} </Badge>
</Badge>
)}
{model && ( {model && (
<Badge variant="solid" colorScheme="base"> <Badge variant="solid" colorScheme="base">
{model} {model}

View File

@ -1,42 +0,0 @@
import { ButtonGroup, Flex, Spacer, Text } from '@chakra-ui/react';
import IAIIconButton from 'common/components/IAIIconButton';
import { useTranslation } from 'react-i18next';
import { FaUndo, FaUpload } from 'react-icons/fa';
import { useAppDispatch } from 'app/store/storeHooks';
import { useCallback } from 'react';
import { clearInitialImage } from 'features/parameters/store/generationSlice';
import useImageUploader from 'common/hooks/useImageUploader';
const InitialImageButtons = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const { openUploader } = useImageUploader();
const handleResetInitialImage = useCallback(() => {
dispatch(clearInitialImage());
}, [dispatch]);
return (
<Flex w="full" alignItems="center">
<Text size="sm" fontWeight={500} color="base.300">
{t('parameters.initialImage')}
</Text>
<Spacer />
<ButtonGroup>
<IAIIconButton
icon={<FaUndo />}
aria-label={t('accessibility.reset')}
onClick={handleResetInitialImage}
/>
<IAIIconButton
icon={<FaUpload />}
onClick={openUploader}
aria-label={t('common.upload')}
/>
</ButtonGroup>
</Flex>
);
};
export default InitialImageButtons;

View File

@ -1,12 +1,12 @@
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { validateSeedWeights } from 'common/util/seedWeightPairs'; import { validateSeedWeights } from 'common/util/seedWeightPairs';
import { generationSelector } from 'features/parameters/store/generationSelectors'; import { generationSelector } from 'features/parameters/store/generationSelectors';
import { systemSelector } from 'features/system/store/systemSelectors'; import { systemSelector } from 'features/system/store/systemSelectors';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { isEqual } from 'lodash-es';
export const readinessSelector = createSelector( const readinessSelector = createSelector(
[generationSelector, systemSelector, activeTabNameSelector], [generationSelector, systemSelector, activeTabNameSelector],
(generation, system, activeTabName) => { (generation, system, activeTabName) => {
const { const {
@ -60,3 +60,8 @@ export const readinessSelector = createSelector(
}, },
defaultSelectorOptions defaultSelectorOptions
); );
export const useIsReadyToInvoke = () => {
const { isReady } = useAppSelector(readinessSelector);
return isReady;
};

View File

@ -1,34 +0,0 @@
import { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { useCallback } from 'react';
import { OpenAPI } from 'services/api';
export const getUrlAlt = (url: string, shouldTransformUrls: boolean) => {
if (OpenAPI.BASE && shouldTransformUrls) {
return [OpenAPI.BASE, url].join('/');
}
return url;
};
export const useGetUrl = () => {
const shouldTransformUrls = useAppSelector(
(state: RootState) => state.config.shouldTransformUrls
);
const getUrl = useCallback(
(url?: string) => {
if (OpenAPI.BASE && shouldTransformUrls) {
return [OpenAPI.BASE, url].join('/');
}
return url;
},
[shouldTransformUrls]
);
return {
shouldTransformUrls,
getUrl,
};
};

View File

@ -1,6 +1,5 @@
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { useGetUrl } from 'common/util/getUrl';
import { canvasSelector } from 'features/canvas/store/canvasSelectors'; import { canvasSelector } from 'features/canvas/store/canvasSelectors';
import { rgbaColorToString } from 'features/canvas/util/colorToString'; import { rgbaColorToString } from 'features/canvas/util/colorToString';
import { isEqual } from 'lodash-es'; import { isEqual } from 'lodash-es';
@ -33,7 +32,6 @@ const selector = createSelector(
const IAICanvasObjectRenderer = () => { const IAICanvasObjectRenderer = () => {
const { objects } = useAppSelector(selector); const { objects } = useAppSelector(selector);
const { getUrl } = useGetUrl();
if (!objects) return null; if (!objects) return null;
@ -46,7 +44,7 @@ const IAICanvasObjectRenderer = () => {
key={i} key={i}
x={obj.x} x={obj.x}
y={obj.y} y={obj.y}
url={getUrl(obj.image.image_url)} url={obj.image.image_url}
/> />
); );
} else if (isCanvasBaseLine(obj)) { } else if (isCanvasBaseLine(obj)) {

View File

@ -1,6 +1,5 @@
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { useGetUrl } from 'common/util/getUrl';
import { canvasSelector } from 'features/canvas/store/canvasSelectors'; import { canvasSelector } from 'features/canvas/store/canvasSelectors';
import { GroupConfig } from 'konva/lib/Group'; import { GroupConfig } from 'konva/lib/Group';
import { isEqual } from 'lodash-es'; import { isEqual } from 'lodash-es';
@ -56,13 +55,12 @@ const IAICanvasStagingArea = (props: Props) => {
width, width,
height, height,
} = useAppSelector(selector); } = useAppSelector(selector);
const { getUrl } = useGetUrl();
return ( return (
<Group {...rest}> <Group {...rest}>
{shouldShowStagingImage && currentStagingAreaImage && ( {shouldShowStagingImage && currentStagingAreaImage && (
<IAICanvasImage <IAICanvasImage
url={getUrl(currentStagingAreaImage.image.image_url) ?? ''} url={currentStagingAreaImage.image.image_url}
x={x} x={x}
y={y} y={y}
/> />

View File

@ -2,7 +2,7 @@ import { ButtonGroup, Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton'; import IAIButton from 'common/components/IAIButton';
import IAICheckbox from 'common/components/IAICheckbox'; import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
import IAIColorPicker from 'common/components/IAIColorPicker'; import IAIColorPicker from 'common/components/IAIColorPicker';
import IAIIconButton from 'common/components/IAIIconButton'; import IAIIconButton from 'common/components/IAIIconButton';
import IAIPopover from 'common/components/IAIPopover'; import IAIPopover from 'common/components/IAIPopover';
@ -117,12 +117,12 @@ const IAICanvasMaskOptions = () => {
} }
> >
<Flex direction="column" gap={2}> <Flex direction="column" gap={2}>
<IAICheckbox <IAISimpleCheckbox
label={`${t('unifiedCanvas.enableMask')} (H)`} label={`${t('unifiedCanvas.enableMask')} (H)`}
isChecked={isMaskEnabled} isChecked={isMaskEnabled}
onChange={handleToggleEnableMask} onChange={handleToggleEnableMask}
/> />
<IAICheckbox <IAISimpleCheckbox
label={t('unifiedCanvas.preserveMaskedArea')} label={t('unifiedCanvas.preserveMaskedArea')}
isChecked={shouldPreserveMaskedArea} isChecked={shouldPreserveMaskedArea}
onChange={(e) => onChange={(e) =>

View File

@ -1,7 +1,7 @@
import { Flex } from '@chakra-ui/react'; import { Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAICheckbox from 'common/components/IAICheckbox'; import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
import IAIIconButton from 'common/components/IAIIconButton'; import IAIIconButton from 'common/components/IAIIconButton';
import IAIPopover from 'common/components/IAIPopover'; import IAIPopover from 'common/components/IAIPopover';
import { canvasSelector } from 'features/canvas/store/canvasSelectors'; import { canvasSelector } from 'features/canvas/store/canvasSelectors';
@ -102,50 +102,50 @@ const IAICanvasSettingsButtonPopover = () => {
} }
> >
<Flex direction="column" gap={2}> <Flex direction="column" gap={2}>
<IAICheckbox <IAISimpleCheckbox
label={t('unifiedCanvas.showIntermediates')} label={t('unifiedCanvas.showIntermediates')}
isChecked={shouldShowIntermediates} isChecked={shouldShowIntermediates}
onChange={(e) => onChange={(e) =>
dispatch(setShouldShowIntermediates(e.target.checked)) dispatch(setShouldShowIntermediates(e.target.checked))
} }
/> />
<IAICheckbox <IAISimpleCheckbox
label={t('unifiedCanvas.showGrid')} label={t('unifiedCanvas.showGrid')}
isChecked={shouldShowGrid} isChecked={shouldShowGrid}
onChange={(e) => dispatch(setShouldShowGrid(e.target.checked))} onChange={(e) => dispatch(setShouldShowGrid(e.target.checked))}
/> />
<IAICheckbox <IAISimpleCheckbox
label={t('unifiedCanvas.snapToGrid')} label={t('unifiedCanvas.snapToGrid')}
isChecked={shouldSnapToGrid} isChecked={shouldSnapToGrid}
onChange={handleChangeShouldSnapToGrid} onChange={handleChangeShouldSnapToGrid}
/> />
<IAICheckbox <IAISimpleCheckbox
label={t('unifiedCanvas.darkenOutsideSelection')} label={t('unifiedCanvas.darkenOutsideSelection')}
isChecked={shouldDarkenOutsideBoundingBox} isChecked={shouldDarkenOutsideBoundingBox}
onChange={(e) => onChange={(e) =>
dispatch(setShouldDarkenOutsideBoundingBox(e.target.checked)) dispatch(setShouldDarkenOutsideBoundingBox(e.target.checked))
} }
/> />
<IAICheckbox <IAISimpleCheckbox
label={t('unifiedCanvas.autoSaveToGallery')} label={t('unifiedCanvas.autoSaveToGallery')}
isChecked={shouldAutoSave} isChecked={shouldAutoSave}
onChange={(e) => dispatch(setShouldAutoSave(e.target.checked))} onChange={(e) => dispatch(setShouldAutoSave(e.target.checked))}
/> />
<IAICheckbox <IAISimpleCheckbox
label={t('unifiedCanvas.saveBoxRegionOnly')} label={t('unifiedCanvas.saveBoxRegionOnly')}
isChecked={shouldCropToBoundingBoxOnSave} isChecked={shouldCropToBoundingBoxOnSave}
onChange={(e) => onChange={(e) =>
dispatch(setShouldCropToBoundingBoxOnSave(e.target.checked)) dispatch(setShouldCropToBoundingBoxOnSave(e.target.checked))
} }
/> />
<IAICheckbox <IAISimpleCheckbox
label={t('unifiedCanvas.limitStrokesToBox')} label={t('unifiedCanvas.limitStrokesToBox')}
isChecked={shouldRestrictStrokesToBox} isChecked={shouldRestrictStrokesToBox}
onChange={(e) => onChange={(e) =>
dispatch(setShouldRestrictStrokesToBox(e.target.checked)) dispatch(setShouldRestrictStrokesToBox(e.target.checked))
} }
/> />
<IAICheckbox <IAISimpleCheckbox
label={t('unifiedCanvas.showCanvasDebugInfo')} label={t('unifiedCanvas.showCanvasDebugInfo')}
isChecked={shouldShowCanvasDebugInfo} isChecked={shouldShowCanvasDebugInfo}
onChange={(e) => onChange={(e) =>
@ -153,7 +153,7 @@ const IAICanvasSettingsButtonPopover = () => {
} }
/> />
<IAICheckbox <IAISimpleCheckbox
label={t('unifiedCanvas.antialiasing')} label={t('unifiedCanvas.antialiasing')}
isChecked={shouldAntialias} isChecked={shouldAntialias}
onChange={(e) => dispatch(setShouldAntialias(e.target.checked))} onChange={(e) => dispatch(setShouldAntialias(e.target.checked))}

View File

@ -1,4 +1,4 @@
import { ButtonGroup, Flex } from '@chakra-ui/react'; import { Box, ButtonGroup, Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton'; import IAIIconButton from 'common/components/IAIIconButton';
@ -210,16 +210,19 @@ const IAICanvasToolbar = () => {
sx={{ sx={{
alignItems: 'center', alignItems: 'center',
gap: 2, gap: 2,
flexWrap: 'wrap',
}} }}
> >
<IAISelect <Box w={24}>
tooltip={`${t('unifiedCanvas.layer')} (Q)`} <IAISelect
tooltipProps={{ hasArrow: true, placement: 'top' }} tooltip={`${t('unifiedCanvas.layer')} (Q)`}
value={layer} tooltipProps={{ hasArrow: true, placement: 'top' }}
validValues={LAYER_NAMES_DICT} value={layer}
onChange={handleChangeLayer} validValues={LAYER_NAMES_DICT}
isDisabled={isStaging} onChange={handleChangeLayer}
/> isDisabled={isStaging}
/>
</Box>
<IAICanvasMaskOptions /> <IAICanvasMaskOptions />
<IAICanvasToolChooserOptions /> <IAICanvasToolChooserOptions />

View File

@ -30,6 +30,8 @@ import {
} from './canvasTypes'; } from './canvasTypes';
import { ImageDTO } from 'services/api'; import { ImageDTO } from 'services/api';
import { sessionCanceled } from 'services/thunks/session'; import { sessionCanceled } from 'services/thunks/session';
import { setShouldUseCanvasBetaLayout } from 'features/ui/store/uiSlice';
import { imageUrlsReceived } from 'services/thunks/image';
export const initialLayerState: CanvasLayerState = { export const initialLayerState: CanvasLayerState = {
objects: [], objects: [],
@ -851,6 +853,30 @@ export const canvasSlice = createSlice({
state.layerState.stagingArea = initialLayerState.stagingArea; state.layerState.stagingArea = initialLayerState.stagingArea;
} }
}); });
builder.addCase(setShouldUseCanvasBetaLayout, (state, action) => {
state.doesCanvasNeedScaling = true;
});
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
const { image_name, image_origin, image_url, thumbnail_url } =
action.payload;
state.layerState.objects.forEach((object) => {
if (object.kind === 'image') {
if (object.image.image_name === image_name) {
object.image.image_url = image_url;
object.image.thumbnail_url = thumbnail_url;
}
}
});
state.layerState.stagingArea.images.forEach((stagedImage) => {
if (stagedImage.image.image_name === image_name) {
stagedImage.image.image_url = image_url;
stagedImage.image.thumbnail_url = thumbnail_url;
}
});
});
}, },
}); });

View File

@ -2,10 +2,10 @@ import { getCanvasBaseLayer } from './konvaInstanceProvider';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { konvaNodeToBlob } from './konvaNodeToBlob'; import { konvaNodeToBlob } from './konvaNodeToBlob';
export const getBaseLayerBlob = async ( /**
state: RootState, * Get the canvas base layer blob, with or without bounding box according to `shouldCropToBoundingBoxOnSave`
withoutBoundingBox?: boolean */
) => { export const getBaseLayerBlob = async (state: RootState) => {
const canvasBaseLayer = getCanvasBaseLayer(); const canvasBaseLayer = getCanvasBaseLayer();
if (!canvasBaseLayer) { if (!canvasBaseLayer) {
@ -24,15 +24,14 @@ export const getBaseLayerBlob = async (
const absPos = clonedBaseLayer.getAbsolutePosition(); const absPos = clonedBaseLayer.getAbsolutePosition();
const boundingBox = const boundingBox = shouldCropToBoundingBoxOnSave
shouldCropToBoundingBoxOnSave && !withoutBoundingBox ? {
? { x: boundingBoxCoordinates.x + absPos.x,
x: boundingBoxCoordinates.x + absPos.x, y: boundingBoxCoordinates.y + absPos.y,
y: boundingBoxCoordinates.y + absPos.y, width: boundingBoxDimensions.width,
width: boundingBoxDimensions.width, height: boundingBoxDimensions.height,
height: boundingBoxDimensions.height, }
} : clonedBaseLayer.getClientRect();
: clonedBaseLayer.getClientRect();
return konvaNodeToBlob(clonedBaseLayer, boundingBox); return konvaNodeToBlob(clonedBaseLayer, boundingBox);
}; };

View File

@ -0,0 +1,19 @@
import { getCanvasBaseLayer } from './konvaInstanceProvider';
import { konvaNodeToBlob } from './konvaNodeToBlob';
/**
* Gets the canvas base layer blob, without bounding box
*/
export const getFullBaseLayerBlob = async () => {
const canvasBaseLayer = getCanvasBaseLayer();
if (!canvasBaseLayer) {
return;
}
const clonedBaseLayer = canvasBaseLayer.clone();
clonedBaseLayer.scale({ x: 1, y: 1 });
return konvaNodeToBlob(clonedBaseLayer, clonedBaseLayer.getClientRect());
};

View File

@ -0,0 +1,258 @@
import { memo, useCallback } from 'react';
import {
ControlNetConfig,
controlNetAdded,
controlNetRemoved,
controlNetToggled,
} from '../store/controlNetSlice';
import { useAppDispatch } from 'app/store/storeHooks';
import ParamControlNetModel from './parameters/ParamControlNetModel';
import ParamControlNetWeight from './parameters/ParamControlNetWeight';
import {
Checkbox,
Flex,
FormControl,
FormLabel,
HStack,
TabList,
TabPanels,
Tabs,
Tab,
TabPanel,
Box,
} from '@chakra-ui/react';
import { FaCopy, FaPlus, FaTrash, FaWrench } from 'react-icons/fa';
import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd';
import ControlNetImagePreview from './ControlNetImagePreview';
import IAIIconButton from 'common/components/IAIIconButton';
import { v4 as uuidv4 } from 'uuid';
import { useToggle } from 'react-use';
import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect';
import ControlNetProcessorComponent from './ControlNetProcessorComponent';
import ControlNetPreprocessButton from './ControlNetPreprocessButton';
import IAIButton from 'common/components/IAIButton';
import IAISwitch from 'common/components/IAISwitch';
import { ChevronDownIcon, ChevronUpIcon } from '@chakra-ui/icons';
type ControlNetProps = {
controlNet: ControlNetConfig;
};
const ControlNet = (props: ControlNetProps) => {
const {
controlNetId,
isEnabled,
model,
weight,
beginStepPct,
endStepPct,
controlImage,
processedControlImage,
processorNode,
processorType,
} = props.controlNet;
const dispatch = useAppDispatch();
const [shouldShowAdvanced, onToggleAdvanced] = useToggle(false);
const handleDelete = useCallback(() => {
dispatch(controlNetRemoved({ controlNetId }));
}, [controlNetId, dispatch]);
const handleDuplicate = useCallback(() => {
dispatch(
controlNetAdded({ controlNetId: uuidv4(), controlNet: props.controlNet })
);
}, [dispatch, props.controlNet]);
const handleToggleIsEnabled = useCallback(() => {
dispatch(controlNetToggled({ controlNetId }));
}, [controlNetId, dispatch]);
return (
<Flex
sx={{
flexDir: 'column',
gap: 2,
p: 3,
bg: 'base.850',
borderRadius: 'base',
}}
>
<Flex sx={{ gap: 2 }}>
<IAISwitch
tooltip="Toggle"
aria-label="Toggle"
isChecked={isEnabled}
onChange={handleToggleIsEnabled}
/>
<Box
sx={{
w: 'full',
minW: 0,
opacity: isEnabled ? 1 : 0.5,
pointerEvents: isEnabled ? 'auto' : 'none',
transitionProperty: 'common',
transitionDuration: '0.1s',
}}
>
<ParamControlNetModel controlNetId={controlNetId} model={model} />
</Box>
<IAIIconButton
size="sm"
tooltip="Duplicate"
aria-label="Duplicate"
onClick={handleDuplicate}
icon={<FaCopy />}
/>
<IAIIconButton
size="sm"
tooltip="Delete"
aria-label="Delete"
colorScheme="error"
onClick={handleDelete}
icon={<FaTrash />}
/>
<IAIIconButton
size="sm"
aria-label="Expand"
onClick={onToggleAdvanced}
variant="link"
icon={
<ChevronUpIcon
sx={{
boxSize: 4,
color: 'base.300',
transform: shouldShowAdvanced
? 'rotate(0deg)'
: 'rotate(180deg)',
transitionProperty: 'common',
transitionDuration: 'normal',
}}
/>
}
/>
</Flex>
{isEnabled && (
<>
<Flex sx={{ gap: 4 }}>
<Flex
sx={{
flexDir: 'column',
gap: 2,
w: 'full',
h: 24,
paddingInlineStart: 1,
paddingInlineEnd: shouldShowAdvanced ? 1 : 0,
pb: 2,
justifyContent: 'space-between',
}}
>
<ParamControlNetWeight
controlNetId={controlNetId}
weight={weight}
mini
/>
<ParamControlNetBeginEnd
controlNetId={controlNetId}
beginStepPct={beginStepPct}
endStepPct={endStepPct}
mini
/>
</Flex>
{!shouldShowAdvanced && (
<Flex
sx={{
alignItems: 'center',
justifyContent: 'center',
h: 24,
w: 24,
aspectRatio: '1/1',
}}
>
<ControlNetImagePreview controlNet={props.controlNet} />
</Flex>
)}
</Flex>
{shouldShowAdvanced && (
<>
<Box pt={2}>
<ControlNetImagePreview controlNet={props.controlNet} />
</Box>
<ParamControlNetProcessorSelect
controlNetId={controlNetId}
processorNode={processorNode}
/>
<ControlNetProcessorComponent
controlNetId={controlNetId}
processorNode={processorNode}
/>
</>
)}
</>
)}
</Flex>
);
return (
<Flex sx={{ flexDir: 'column', gap: 3 }}>
<ControlNetImagePreview controlNet={props.controlNet} />
<ParamControlNetModel controlNetId={controlNetId} model={model} />
<Tabs
isFitted
orientation="horizontal"
variant="enclosed"
size="sm"
colorScheme="accent"
>
<TabList>
<Tab
sx={{ 'button&': { _selected: { borderBottomColor: 'base.800' } } }}
>
Model Config
</Tab>
<Tab
sx={{ 'button&': { _selected: { borderBottomColor: 'base.800' } } }}
>
Preprocess
</Tab>
</TabList>
<TabPanels sx={{ pt: 2 }}>
<TabPanel sx={{ p: 0 }}>
<ParamControlNetWeight
controlNetId={controlNetId}
weight={weight}
/>
<ParamControlNetBeginEnd
controlNetId={controlNetId}
beginStepPct={beginStepPct}
endStepPct={endStepPct}
/>
</TabPanel>
<TabPanel sx={{ p: 0 }}>
<ParamControlNetProcessorSelect
controlNetId={controlNetId}
processorNode={processorNode}
/>
<ControlNetProcessorComponent
controlNetId={controlNetId}
processorNode={processorNode}
/>
<ControlNetPreprocessButton controlNet={props.controlNet} />
{/* <IAIButton
size="sm"
leftIcon={<FaUndo />}
onClick={handleReset}
isDisabled={Boolean(!processedControlImage)}
>
Reset Processing
</IAIButton> */}
</TabPanel>
</TabPanels>
</Tabs>
<IAIButton onClick={handleDelete}>Remove ControlNet</IAIButton>
</Flex>
);
};
export default memo(ControlNet);

View File

@ -0,0 +1,144 @@
import { memo, useCallback, useRef, useState } from 'react';
import { ImageDTO } from 'services/api';
import {
ControlNetConfig,
controlNetImageChanged,
controlNetSelector,
} from '../store/controlNetSlice';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { Box } from '@chakra-ui/react';
import IAIDndImage from 'common/components/IAIDndImage';
import { createSelector } from '@reduxjs/toolkit';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { AnimatePresence, motion } from 'framer-motion';
import { IAIImageFallback } from 'common/components/IAIImageFallback';
import { useHoverDirty } from 'react-use';
const selector = createSelector(
controlNetSelector,
(controlNet) => {
const { isProcessingControlImage } = controlNet;
return { isProcessingControlImage };
},
defaultSelectorOptions
);
type Props = {
controlNet: ControlNetConfig;
};
const ControlNetImagePreview = (props: Props) => {
const { controlNetId, controlImage, processedControlImage, processorType } =
props.controlNet;
const dispatch = useAppDispatch();
const { isProcessingControlImage } = useAppSelector(selector);
const containerRef = useRef<HTMLDivElement>(null);
const isMouseOverImage = useHoverDirty(containerRef);
const handleDrop = useCallback(
(droppedImage: ImageDTO) => {
if (controlImage?.image_name === droppedImage.image_name) {
return;
}
dispatch(
controlNetImageChanged({ controlNetId, controlImage: droppedImage })
);
},
[controlImage, controlNetId, dispatch]
);
const shouldShowProcessedImageBackdrop =
Number(controlImage?.width) > Number(processedControlImage?.width) ||
Number(controlImage?.height) > Number(processedControlImage?.height);
const shouldShowProcessedImage =
controlImage &&
processedControlImage &&
!isMouseOverImage &&
!isProcessingControlImage &&
processorType !== 'none';
return (
<Box
ref={containerRef}
sx={{ position: 'relative', w: 'full', h: 'full', aspectRatio: '1/1' }}
>
<IAIDndImage
image={controlImage}
onDrop={handleDrop}
isDropDisabled={Boolean(
processedControlImage && processorType !== 'none'
)}
/>
<AnimatePresence>
{shouldShowProcessedImage && (
<motion.div
initial={{
opacity: 0,
}}
animate={{
opacity: 1,
transition: { duration: 0.1 },
}}
exit={{
opacity: 0,
transition: { duration: 0.1 },
}}
>
<Box
sx={{
position: 'absolute',
w: 'full',
h: 'full',
top: 0,
insetInlineStart: 0,
}}
>
{shouldShowProcessedImageBackdrop && (
<Box
sx={{
w: 'full',
h: 'full',
bg: 'base.900',
opacity: 0.7,
}}
/>
)}
<Box
sx={{
position: 'absolute',
top: 0,
insetInlineStart: 0,
w: 'full',
h: 'full',
}}
>
<IAIDndImage
image={processedControlImage}
onDrop={handleDrop}
payloadImage={controlImage}
/>
</Box>
</Box>
</motion.div>
)}
</AnimatePresence>
{isProcessingControlImage && (
<Box
sx={{
position: 'absolute',
top: 0,
insetInlineStart: 0,
w: 'full',
h: 'full',
}}
>
<IAIImageFallback />
</Box>
)}
</Box>
);
};
export default memo(ControlNetImagePreview);

View File

@ -0,0 +1,36 @@
import IAIButton from 'common/components/IAIButton';
import { memo, useCallback } from 'react';
import { ControlNetConfig } from '../store/controlNetSlice';
import { useAppDispatch } from 'app/store/storeHooks';
import { controlNetImageProcessed } from '../store/actions';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
type Props = {
controlNet: ControlNetConfig;
};
const ControlNetPreprocessButton = (props: Props) => {
const { controlNetId, controlImage } = props.controlNet;
const dispatch = useAppDispatch();
const isReady = useIsReadyToInvoke();
const handleProcess = useCallback(() => {
dispatch(
controlNetImageProcessed({
controlNetId,
})
);
}, [controlNetId, dispatch]);
return (
<IAIButton
size="sm"
onClick={handleProcess}
isDisabled={Boolean(!controlImage) || !isReady}
>
Preprocess
</IAIButton>
);
};
export default memo(ControlNetPreprocessButton);

View File

@ -0,0 +1,131 @@
import { memo } from 'react';
import { RequiredControlNetProcessorNode } from '../store/types';
import CannyProcessor from './processors/CannyProcessor';
import HedProcessor from './processors/HedProcessor';
import LineartProcessor from './processors/LineartProcessor';
import LineartAnimeProcessor from './processors/LineartAnimeProcessor';
import ContentShuffleProcessor from './processors/ContentShuffleProcessor';
import MediapipeFaceProcessor from './processors/MediapipeFaceProcessor';
import MidasDepthProcessor from './processors/MidasDepthProcessor';
import MlsdImageProcessor from './processors/MlsdImageProcessor';
import NormalBaeProcessor from './processors/NormalBaeProcessor';
import OpenposeProcessor from './processors/OpenposeProcessor';
import PidiProcessor from './processors/PidiProcessor';
import ZoeDepthProcessor from './processors/ZoeDepthProcessor';
export type ControlNetProcessorProps = {
controlNetId: string;
processorNode: RequiredControlNetProcessorNode;
};
const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
const { controlNetId, processorNode } = props;
if (processorNode.type === 'canny_image_processor') {
return (
<CannyProcessor
controlNetId={controlNetId}
processorNode={processorNode}
/>
);
}
if (processorNode.type === 'hed_image_processor') {
return (
<HedProcessor controlNetId={controlNetId} processorNode={processorNode} />
);
}
if (processorNode.type === 'lineart_image_processor') {
return (
<LineartProcessor
controlNetId={controlNetId}
processorNode={processorNode}
/>
);
}
if (processorNode.type === 'content_shuffle_image_processor') {
return (
<ContentShuffleProcessor
controlNetId={controlNetId}
processorNode={processorNode}
/>
);
}
if (processorNode.type === 'lineart_anime_image_processor') {
return (
<LineartAnimeProcessor
controlNetId={controlNetId}
processorNode={processorNode}
/>
);
}
if (processorNode.type === 'mediapipe_face_processor') {
return (
<MediapipeFaceProcessor
controlNetId={controlNetId}
processorNode={processorNode}
/>
);
}
if (processorNode.type === 'midas_depth_image_processor') {
return (
<MidasDepthProcessor
controlNetId={controlNetId}
processorNode={processorNode}
/>
);
}
if (processorNode.type === 'mlsd_image_processor') {
return (
<MlsdImageProcessor
controlNetId={controlNetId}
processorNode={processorNode}
/>
);
}
if (processorNode.type === 'normalbae_image_processor') {
return (
<NormalBaeProcessor
controlNetId={controlNetId}
processorNode={processorNode}
/>
);
}
if (processorNode.type === 'openpose_image_processor') {
return (
<OpenposeProcessor
controlNetId={controlNetId}
processorNode={processorNode}
/>
);
}
if (processorNode.type === 'pidi_image_processor') {
return (
<PidiProcessor
controlNetId={controlNetId}
processorNode={processorNode}
/>
);
}
if (processorNode.type === 'zoe_depth_image_processor') {
return (
<ZoeDepthProcessor
controlNetId={controlNetId}
processorNode={processorNode}
/>
);
}
return null;
};
export default memo(ControlNetProcessorComponent);

View File

@ -0,0 +1,20 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { controlNetProcessorParamsChanged } from 'features/controlNet/store/controlNetSlice';
import { ControlNetProcessorNode } from 'features/controlNet/store/types';
import { useCallback } from 'react';
export const useProcessorNodeChanged = () => {
const dispatch = useAppDispatch();
const handleProcessorNodeChanged = useCallback(
(controlNetId: string, changes: Partial<ControlNetProcessorNode>) => {
dispatch(
controlNetProcessorParamsChanged({
controlNetId,
changes,
})
);
},
[dispatch]
);
return handleProcessorNodeChanged;
};

View File

@ -0,0 +1,130 @@
import {
FormControl,
FormLabel,
HStack,
RangeSlider,
RangeSliderFilledTrack,
RangeSliderMark,
RangeSliderThumb,
RangeSliderTrack,
Tooltip,
} from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import {
controlNetBeginStepPctChanged,
controlNetEndStepPctChanged,
} from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { BiReset } from 'react-icons/bi';
type Props = {
controlNetId: string;
beginStepPct: number;
endStepPct: number;
mini?: boolean;
};
const formatPct = (v: number) => `${Math.round(v * 100)}%`;
const ParamControlNetBeginEnd = (props: Props) => {
const { controlNetId, beginStepPct, endStepPct, mini = false } = props;
const dispatch = useAppDispatch();
const { t } = useTranslation();
const handleStepPctChanged = useCallback(
(v: number[]) => {
dispatch(
controlNetBeginStepPctChanged({ controlNetId, beginStepPct: v[0] })
);
dispatch(controlNetEndStepPctChanged({ controlNetId, endStepPct: v[1] }));
},
[controlNetId, dispatch]
);
const handleStepPctReset = useCallback(() => {
dispatch(controlNetBeginStepPctChanged({ controlNetId, beginStepPct: 0 }));
dispatch(controlNetEndStepPctChanged({ controlNetId, endStepPct: 1 }));
}, [controlNetId, dispatch]);
return (
<FormControl>
<FormLabel>Begin / End Step Percentage</FormLabel>
<HStack w="100%" gap={2} alignItems="center">
<RangeSlider
aria-label={['Begin Step %', 'End Step %']}
value={[beginStepPct, endStepPct]}
onChange={handleStepPctChanged}
min={0}
max={1}
step={0.01}
minStepsBetweenThumbs={5}
>
<RangeSliderTrack>
<RangeSliderFilledTrack />
</RangeSliderTrack>
<Tooltip label={formatPct(beginStepPct)} placement="top" hasArrow>
<RangeSliderThumb index={0} />
</Tooltip>
<Tooltip label={formatPct(endStepPct)} placement="top" hasArrow>
<RangeSliderThumb index={1} />
</Tooltip>
{!mini && (
<>
<RangeSliderMark
value={0}
sx={{
fontSize: 'xs',
fontWeight: '500',
color: 'base.200',
insetInlineStart: '0 !important',
insetInlineEnd: 'unset !important',
mt: 1.5,
}}
>
0%
</RangeSliderMark>
<RangeSliderMark
value={0.5}
sx={{
fontSize: 'xs',
fontWeight: '500',
color: 'base.200',
mt: 1.5,
}}
>
50%
</RangeSliderMark>
<RangeSliderMark
value={1}
sx={{
fontSize: 'xs',
fontWeight: '500',
color: 'base.200',
insetInlineStart: 'unset !important',
insetInlineEnd: '0 !important',
mt: 1.5,
}}
>
100%
</RangeSliderMark>
</>
)}
</RangeSlider>
{!mini && (
<IAIIconButton
size="sm"
aria-label={t('accessibility.reset')}
tooltip={t('accessibility.reset')}
icon={<BiReset />}
onClick={handleStepPctReset}
/>
)}
</HStack>
</FormControl>
);
};
export default memo(ParamControlNetBeginEnd);

View File

@ -0,0 +1,28 @@
import { useAppDispatch } from 'app/store/storeHooks';
import IAISwitch from 'common/components/IAISwitch';
import { controlNetToggled } from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react';
type ParamControlNetIsEnabledProps = {
controlNetId: string;
isEnabled: boolean;
};
const ParamControlNetIsEnabled = (props: ParamControlNetIsEnabledProps) => {
const { controlNetId, isEnabled } = props;
const dispatch = useAppDispatch();
const handleIsEnabledChanged = useCallback(() => {
dispatch(controlNetToggled({ controlNetId }));
}, [dispatch, controlNetId]);
return (
<IAISwitch
label="Enabled"
isChecked={isEnabled}
onChange={handleIsEnabledChanged}
/>
);
};
export default memo(ParamControlNetIsEnabled);

View File

@ -0,0 +1,36 @@
import { useAppDispatch } from 'app/store/storeHooks';
import IAIFullCheckbox from 'common/components/IAIFullCheckbox';
import IAISwitch from 'common/components/IAISwitch';
import {
controlNetToggled,
isControlNetImagePreprocessedToggled,
} from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react';
type ParamControlNetIsEnabledProps = {
controlNetId: string;
isControlImageProcessed: boolean;
};
const ParamControlNetIsEnabled = (props: ParamControlNetIsEnabledProps) => {
const { controlNetId, isControlImageProcessed } = props;
const dispatch = useAppDispatch();
const handleIsControlImageProcessedToggled = useCallback(() => {
dispatch(
isControlNetImagePreprocessedToggled({
controlNetId,
})
);
}, [controlNetId, dispatch]);
return (
<IAISwitch
label="Preprocess"
isChecked={isControlImageProcessed}
onChange={handleIsControlImageProcessedToggled}
/>
);
};
export default memo(ParamControlNetIsEnabled);

View File

@ -0,0 +1,41 @@
import { useAppDispatch } from 'app/store/storeHooks';
import IAICustomSelect from 'common/components/IAICustomSelect';
import {
CONTROLNET_MODELS,
ControlNetModel,
} from 'features/controlNet/store/constants';
import { controlNetModelChanged } from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react';
type ParamIsControlNetModelProps = {
controlNetId: string;
model: ControlNetModel;
};
const ParamIsControlNetModel = (props: ParamIsControlNetModelProps) => {
const { controlNetId, model } = props;
const dispatch = useAppDispatch();
const handleModelChanged = useCallback(
(val: string | null | undefined) => {
// TODO: do not cast
const model = val as ControlNetModel;
dispatch(controlNetModelChanged({ controlNetId, model }));
},
[controlNetId, dispatch]
);
return (
<IAICustomSelect
tooltip={model}
tooltipProps={{ placement: 'top', hasArrow: true }}
items={CONTROLNET_MODELS}
selectedItem={model}
setSelectedItem={handleModelChanged}
ellipsisPosition="start"
withCheckIcon
/>
);
};
export default memo(ParamIsControlNetModel);

View File

@ -0,0 +1,47 @@
import IAICustomSelect from 'common/components/IAICustomSelect';
import { memo, useCallback } from 'react';
import {
ControlNetProcessorNode,
ControlNetProcessorType,
} from '../../store/types';
import { controlNetProcessorTypeChanged } from '../../store/controlNetSlice';
import { useAppDispatch } from 'app/store/storeHooks';
import { CONTROLNET_PROCESSORS } from '../../store/constants';
type ParamControlNetProcessorSelectProps = {
controlNetId: string;
processorNode: ControlNetProcessorNode;
};
const CONTROLNET_PROCESSOR_TYPES = Object.keys(
CONTROLNET_PROCESSORS
) as ControlNetProcessorType[];
const ParamControlNetProcessorSelect = (
props: ParamControlNetProcessorSelectProps
) => {
const { controlNetId, processorNode } = props;
const dispatch = useAppDispatch();
const handleProcessorTypeChanged = useCallback(
(v: string | null | undefined) => {
dispatch(
controlNetProcessorTypeChanged({
controlNetId,
processorType: v as ControlNetProcessorType,
})
);
},
[controlNetId, dispatch]
);
return (
<IAICustomSelect
label="Processor"
items={CONTROLNET_PROCESSOR_TYPES}
selectedItem={processorNode.type ?? 'canny_image_processor'}
setSelectedItem={handleProcessorTypeChanged}
withCheckIcon
/>
);
};
export default memo(ParamControlNetProcessorSelect);

View File

@ -0,0 +1,57 @@
import { useAppDispatch } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { controlNetWeightChanged } from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react';
type ParamControlNetWeightProps = {
controlNetId: string;
weight: number;
mini?: boolean;
};
const ParamControlNetWeight = (props: ParamControlNetWeightProps) => {
const { controlNetId, weight, mini = false } = props;
const dispatch = useAppDispatch();
const handleWeightChanged = useCallback(
(weight: number) => {
dispatch(controlNetWeightChanged({ controlNetId, weight }));
},
[controlNetId, dispatch]
);
const handleWeightReset = () => {
dispatch(controlNetWeightChanged({ controlNetId, weight: 1 }));
};
if (mini) {
return (
<IAISlider
label={'Weight'}
sliderFormLabelProps={{ pb: 1 }}
value={weight}
onChange={handleWeightChanged}
min={0}
max={1}
step={0.01}
/>
);
}
return (
<IAISlider
label="Weight"
value={weight}
onChange={handleWeightChanged}
withInput
withReset
handleReset={handleWeightReset}
withSliderMarks
min={0}
max={1}
step={0.01}
/>
);
};
export default memo(ParamControlNetWeight);

View File

@ -0,0 +1,72 @@
import IAISlider from 'common/components/IAISlider';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredCannyImageProcessorInvocation } from 'features/controlNet/store/types';
import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper';
const DEFAULTS = CONTROLNET_PROCESSORS.canny_image_processor.default;
type CannyProcessorProps = {
controlNetId: string;
processorNode: RequiredCannyImageProcessorInvocation;
};
const CannyProcessor = (props: CannyProcessorProps) => {
const { controlNetId, processorNode } = props;
const { low_threshold, high_threshold } = processorNode;
const processorChanged = useProcessorNodeChanged();
const handleLowThresholdChanged = useCallback(
(v: number) => {
processorChanged(controlNetId, { low_threshold: v });
},
[controlNetId, processorChanged]
);
const handleLowThresholdReset = useCallback(() => {
processorChanged(controlNetId, {
low_threshold: DEFAULTS.low_threshold,
});
}, [controlNetId, processorChanged]);
const handleHighThresholdChanged = useCallback(
(v: number) => {
processorChanged(controlNetId, { high_threshold: v });
},
[controlNetId, processorChanged]
);
const handleHighThresholdReset = useCallback(() => {
processorChanged(controlNetId, {
high_threshold: DEFAULTS.high_threshold,
});
}, [controlNetId, processorChanged]);
return (
<ProcessorWrapper>
<IAISlider
label="Low Threshold"
value={low_threshold}
onChange={handleLowThresholdChanged}
handleReset={handleLowThresholdReset}
withReset
min={0}
max={255}
withInput
/>
<IAISlider
label="High Threshold"
value={high_threshold}
onChange={handleHighThresholdChanged}
handleReset={handleHighThresholdReset}
withReset
min={0}
max={255}
withInput
/>
</ProcessorWrapper>
);
};
export default memo(CannyProcessor);

View File

@ -0,0 +1,141 @@
import IAISlider from 'common/components/IAISlider';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredContentShuffleImageProcessorInvocation } from 'features/controlNet/store/types';
import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper';
const DEFAULTS = CONTROLNET_PROCESSORS.content_shuffle_image_processor.default;
type Props = {
controlNetId: string;
processorNode: RequiredContentShuffleImageProcessorInvocation;
};
const ContentShuffleProcessor = (props: Props) => {
const { controlNetId, processorNode } = props;
const { image_resolution, detect_resolution, w, h, f } = processorNode;
const processorChanged = useProcessorNodeChanged();
const handleDetectResolutionChanged = useCallback(
(v: number) => {
processorChanged(controlNetId, { detect_resolution: v });
},
[controlNetId, processorChanged]
);
const handleDetectResolutionReset = useCallback(() => {
processorChanged(controlNetId, {
detect_resolution: DEFAULTS.detect_resolution,
});
}, [controlNetId, processorChanged]);
const handleImageResolutionChanged = useCallback(
(v: number) => {
processorChanged(controlNetId, { image_resolution: v });
},
[controlNetId, processorChanged]
);
const handleImageResolutionReset = useCallback(() => {
processorChanged(controlNetId, {
image_resolution: DEFAULTS.image_resolution,
});
}, [controlNetId, processorChanged]);
const handleWChanged = useCallback(
(v: number) => {
processorChanged(controlNetId, { w: v });
},
[controlNetId, processorChanged]
);
const handleWReset = useCallback(() => {
processorChanged(controlNetId, {
w: DEFAULTS.w,
});
}, [controlNetId, processorChanged]);
const handleHChanged = useCallback(
(v: number) => {
processorChanged(controlNetId, { h: v });
},
[controlNetId, processorChanged]
);
const handleHReset = useCallback(() => {
processorChanged(controlNetId, {
h: DEFAULTS.h,
});
}, [controlNetId, processorChanged]);
const handleFChanged = useCallback(
(v: number) => {
processorChanged(controlNetId, { f: v });
},
[controlNetId, processorChanged]
);
const handleFReset = useCallback(() => {
processorChanged(controlNetId, {
f: DEFAULTS.f,
});
}, [controlNetId, processorChanged]);
return (
<ProcessorWrapper>
<IAISlider
label="Detect Resolution"
value={detect_resolution}
onChange={handleDetectResolutionChanged}
handleReset={handleDetectResolutionReset}
withReset
min={0}
max={4096}
withInput
/>
<IAISlider
label="Image Resolution"
value={image_resolution}
onChange={handleImageResolutionChanged}
handleReset={handleImageResolutionReset}
withReset
min={0}
max={4096}
withInput
/>
<IAISlider
label="W"
value={w}
onChange={handleWChanged}
handleReset={handleWReset}
withReset
min={0}
max={4096}
withInput
/>
<IAISlider
label="H"
value={h}
onChange={handleHChanged}
handleReset={handleHReset}
withReset
min={0}
max={4096}
withInput
/>
<IAISlider
label="F"
value={f}
onChange={handleFChanged}
handleReset={handleFReset}
withReset
min={0}
max={4096}
withInput
/>
</ProcessorWrapper>
);
};
export default memo(ContentShuffleProcessor);

View File

@ -0,0 +1,88 @@
import IAISlider from 'common/components/IAISlider';
import IAISwitch from 'common/components/IAISwitch';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredHedImageProcessorInvocation } from 'features/controlNet/store/types';
import { ChangeEvent, memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper';
const DEFAULTS = CONTROLNET_PROCESSORS.hed_image_processor.default;
type HedProcessorProps = {
controlNetId: string;
processorNode: RequiredHedImageProcessorInvocation;
};
const HedPreprocessor = (props: HedProcessorProps) => {
const {
controlNetId,
processorNode: { detect_resolution, image_resolution, scribble },
} = props;
const processorChanged = useProcessorNodeChanged();
const handleDetectResolutionChanged = useCallback(
(v: number) => {
processorChanged(controlNetId, { detect_resolution: v });
},
[controlNetId, processorChanged]
);
const handleImageResolutionChanged = useCallback(
(v: number) => {
processorChanged(controlNetId, { image_resolution: v });
},
[controlNetId, processorChanged]
);
const handleScribbleChanged = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
processorChanged(controlNetId, { scribble: e.target.checked });
},
[controlNetId, processorChanged]
);
const handleDetectResolutionReset = useCallback(() => {
processorChanged(controlNetId, {
detect_resolution: DEFAULTS.detect_resolution,
});
}, [controlNetId, processorChanged]);
const handleImageResolutionReset = useCallback(() => {
processorChanged(controlNetId, {
image_resolution: DEFAULTS.image_resolution,
});
}, [controlNetId, processorChanged]);
return (
<ProcessorWrapper>
<IAISlider
label="Detect Resolution"
value={detect_resolution}
onChange={handleDetectResolutionChanged}
handleReset={handleDetectResolutionReset}
withReset
min={0}
max={4096}
withInput
/>
<IAISlider
label="Image Resolution"
value={image_resolution}
onChange={handleImageResolutionChanged}
handleReset={handleImageResolutionReset}
withReset
min={0}
max={4096}
withInput
/>
<IAISwitch
label="Scribble"
isChecked={scribble}
onChange={handleScribbleChanged}
/>
</ProcessorWrapper>
);
};
export default memo(HedPreprocessor);

View File

@ -0,0 +1,72 @@
import IAISlider from 'common/components/IAISlider';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredLineartAnimeImageProcessorInvocation } from 'features/controlNet/store/types';
import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper';
const DEFAULTS = CONTROLNET_PROCESSORS.lineart_anime_image_processor.default;
type Props = {
controlNetId: string;
processorNode: RequiredLineartAnimeImageProcessorInvocation;
};
const LineartAnimeProcessor = (props: Props) => {
const { controlNetId, processorNode } = props;
const { image_resolution, detect_resolution } = processorNode;
const processorChanged = useProcessorNodeChanged();
const handleDetectResolutionChanged = useCallback(
(v: number) => {
processorChanged(controlNetId, { detect_resolution: v });
},
[controlNetId, processorChanged]
);
const handleImageResolutionChanged = useCallback(
(v: number) => {
processorChanged(controlNetId, { image_resolution: v });
},
[controlNetId, processorChanged]
);
const handleDetectResolutionReset = useCallback(() => {
processorChanged(controlNetId, {
detect_resolution: DEFAULTS.detect_resolution,
});
}, [controlNetId, processorChanged]);
const handleImageResolutionReset = useCallback(() => {
processorChanged(controlNetId, {
image_resolution: DEFAULTS.image_resolution,
});
}, [controlNetId, processorChanged]);
return (
<ProcessorWrapper>
<IAISlider
label="Detect Resolution"
value={detect_resolution}
onChange={handleDetectResolutionChanged}
handleReset={handleDetectResolutionReset}
withReset
min={0}
max={4096}
withInput
/>
<IAISlider
label="Image Resolution"
value={image_resolution}
onChange={handleImageResolutionChanged}
handleReset={handleImageResolutionReset}
withReset
min={0}
max={4096}
withInput
/>
</ProcessorWrapper>
);
};
export default memo(LineartAnimeProcessor);

View File

@ -0,0 +1,85 @@
import IAISlider from 'common/components/IAISlider';
import IAISwitch from 'common/components/IAISwitch';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredLineartImageProcessorInvocation } from 'features/controlNet/store/types';
import { ChangeEvent, memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper';
const DEFAULTS = CONTROLNET_PROCESSORS.lineart_image_processor.default;
type LineartProcessorProps = {
controlNetId: string;
processorNode: RequiredLineartImageProcessorInvocation;
};
const LineartProcessor = (props: LineartProcessorProps) => {
const { controlNetId, processorNode } = props;
const { image_resolution, detect_resolution, coarse } = processorNode;
const processorChanged = useProcessorNodeChanged();
const handleDetectResolutionChanged = useCallback(
(v: number) => {
processorChanged(controlNetId, { detect_resolution: v });
},
[controlNetId, processorChanged]
);
const handleImageResolutionChanged = useCallback(
(v: number) => {
processorChanged(controlNetId, { image_resolution: v });
},
[controlNetId, processorChanged]
);
const handleDetectResolutionReset = useCallback(() => {
processorChanged(controlNetId, {
detect_resolution: DEFAULTS.detect_resolution,
});
}, [controlNetId, processorChanged]);
const handleImageResolutionReset = useCallback(() => {
processorChanged(controlNetId, {
image_resolution: DEFAULTS.image_resolution,
});
}, [controlNetId, processorChanged]);
const handleCoarseChanged = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
processorChanged(controlNetId, { coarse: e.target.checked });
},
[controlNetId, processorChanged]
);
return (
<ProcessorWrapper>
<IAISlider
label="Detect Resolution"
value={detect_resolution}
onChange={handleDetectResolutionChanged}
handleReset={handleDetectResolutionReset}
withReset
min={0}
max={4096}
withInput
/>
<IAISlider
label="Image Resolution"
value={image_resolution}
onChange={handleImageResolutionChanged}
handleReset={handleImageResolutionReset}
withReset
min={0}
max={4096}
withInput
/>
<IAISwitch
label="Coarse"
isChecked={coarse}
onChange={handleCoarseChanged}
/>
</ProcessorWrapper>
);
};
export default memo(LineartProcessor);

View File

@ -0,0 +1,69 @@
import IAISlider from 'common/components/IAISlider';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredMediapipeFaceProcessorInvocation } from 'features/controlNet/store/types';
import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper';
const DEFAULTS = CONTROLNET_PROCESSORS.mediapipe_face_processor.default;
type Props = {
controlNetId: string;
processorNode: RequiredMediapipeFaceProcessorInvocation;
};
const MediapipeFaceProcessor = (props: Props) => {
const { controlNetId, processorNode } = props;
const { max_faces, min_confidence } = processorNode;
const processorChanged = useProcessorNodeChanged();
const handleMaxFacesChanged = useCallback(
(v: number) => {
processorChanged(controlNetId, { max_faces: v });
},
[controlNetId, processorChanged]
);
const handleMinConfidenceChanged = useCallback(
(v: number) => {
processorChanged(controlNetId, { min_confidence: v });
},
[controlNetId, processorChanged]
);
const handleMaxFacesReset = useCallback(() => {
processorChanged(controlNetId, { max_faces: DEFAULTS.max_faces });
}, [controlNetId, processorChanged]);
const handleMinConfidenceReset = useCallback(() => {
processorChanged(controlNetId, { min_confidence: DEFAULTS.min_confidence });
}, [controlNetId, processorChanged]);
return (
<ProcessorWrapper>
<IAISlider
label="Max Faces"
value={max_faces}
onChange={handleMaxFacesChanged}
handleReset={handleMaxFacesReset}
withReset
min={1}
max={20}
withInput
/>
<IAISlider
label="Min Confidence"
value={min_confidence}
onChange={handleMinConfidenceChanged}
handleReset={handleMinConfidenceReset}
withReset
min={0}
max={1}
step={0.01}
withInput
/>
</ProcessorWrapper>
);
};
export default memo(MediapipeFaceProcessor);

View File

@ -0,0 +1,70 @@
import IAISlider from 'common/components/IAISlider';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredMidasDepthImageProcessorInvocation } from 'features/controlNet/store/types';
import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper';
const DEFAULTS = CONTROLNET_PROCESSORS.midas_depth_image_processor.default;
type Props = {
controlNetId: string;
processorNode: RequiredMidasDepthImageProcessorInvocation;
};
const MidasDepthProcessor = (props: Props) => {
const { controlNetId, processorNode } = props;
const { a_mult, bg_th } = processorNode;
const processorChanged = useProcessorNodeChanged();
const handleAMultChanged = useCallback(
(v: number) => {
processorChanged(controlNetId, { a_mult: v });
},
[controlNetId, processorChanged]
);
const handleBgThChanged = useCallback(
(v: number) => {
processorChanged(controlNetId, { bg_th: v });
},
[controlNetId, processorChanged]
);
const handleAMultReset = useCallback(() => {
processorChanged(controlNetId, { a_mult: DEFAULTS.a_mult });
}, [controlNetId, processorChanged]);
const handleBgThReset = useCallback(() => {
processorChanged(controlNetId, { bg_th: DEFAULTS.bg_th });
}, [controlNetId, processorChanged]);
return (
<ProcessorWrapper>
<IAISlider
label="a_mult"
value={a_mult}
onChange={handleAMultChanged}
handleReset={handleAMultReset}
withReset
min={0}
max={20}
step={0.01}
withInput
/>
<IAISlider
label="bg_th"
value={bg_th}
onChange={handleBgThChanged}
handleReset={handleBgThReset}
withReset
min={0}
max={20}
step={0.01}
withInput
/>
</ProcessorWrapper>
);
};
export default memo(MidasDepthProcessor);

View File

@ -0,0 +1,116 @@
import IAISlider from 'common/components/IAISlider';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredMlsdImageProcessorInvocation } from 'features/controlNet/store/types';
import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper';
const DEFAULTS = CONTROLNET_PROCESSORS.mlsd_image_processor.default;
type Props = {
controlNetId: string;
processorNode: RequiredMlsdImageProcessorInvocation;
};
const MlsdImageProcessor = (props: Props) => {
const { controlNetId, processorNode } = props;
const { image_resolution, detect_resolution, thr_d, thr_v } = processorNode;
const processorChanged = useProcessorNodeChanged();
const handleDetectResolutionChanged = useCallback(
(v: number) => {
processorChanged(controlNetId, { detect_resolution: v });
},
[controlNetId, processorChanged]
);
const handleImageResolutionChanged = useCallback(
(v: number) => {
processorChanged(controlNetId, { image_resolution: v });
},
[controlNetId, processorChanged]
);
const handleThrDChanged = useCallback(
(v: number) => {
processorChanged(controlNetId, { thr_d: v });
},
[controlNetId, processorChanged]
);
const handleThrVChanged = useCallback(
(v: number) => {
processorChanged(controlNetId, { thr_v: v });
},
[controlNetId, processorChanged]
);
const handleDetectResolutionReset = useCallback(() => {
processorChanged(controlNetId, {
detect_resolution: DEFAULTS.detect_resolution,
});
}, [controlNetId, processorChanged]);
const handleImageResolutionReset = useCallback(() => {
processorChanged(controlNetId, {
image_resolution: DEFAULTS.image_resolution,
});
}, [controlNetId, processorChanged]);
const handleThrDReset = useCallback(() => {
processorChanged(controlNetId, { thr_d: DEFAULTS.thr_d });
}, [controlNetId, processorChanged]);
const handleThrVReset = useCallback(() => {
processorChanged(controlNetId, { thr_v: DEFAULTS.thr_v });
}, [controlNetId, processorChanged]);
return (
<ProcessorWrapper>
<IAISlider
label="Detect Resolution"
value={detect_resolution}
onChange={handleDetectResolutionChanged}
handleReset={handleDetectResolutionReset}
withReset
min={0}
max={4096}
withInput
/>
<IAISlider
label="Image Resolution"
value={image_resolution}
onChange={handleImageResolutionChanged}
handleReset={handleImageResolutionReset}
withReset
min={0}
max={4096}
withInput
/>
<IAISlider
label="W"
value={thr_d}
onChange={handleThrDChanged}
handleReset={handleThrDReset}
withReset
min={0}
max={1}
step={0.01}
withInput
/>
<IAISlider
label="H"
value={thr_v}
onChange={handleThrVChanged}
handleReset={handleThrVReset}
withReset
min={0}
max={1}
step={0.01}
withInput
/>
</ProcessorWrapper>
);
};
export default memo(MlsdImageProcessor);

View File

@ -0,0 +1,72 @@
import IAISlider from 'common/components/IAISlider';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredNormalbaeImageProcessorInvocation } from 'features/controlNet/store/types';
import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper';
const DEFAULTS = CONTROLNET_PROCESSORS.normalbae_image_processor.default;
type Props = {
controlNetId: string;
processorNode: RequiredNormalbaeImageProcessorInvocation;
};
const NormalBaeProcessor = (props: Props) => {
const { controlNetId, processorNode } = props;
const { image_resolution, detect_resolution } = processorNode;
const processorChanged = useProcessorNodeChanged();
const handleDetectResolutionChanged = useCallback(
(v: number) => {
processorChanged(controlNetId, { detect_resolution: v });
},
[controlNetId, processorChanged]
);
const handleImageResolutionChanged = useCallback(
(v: number) => {
processorChanged(controlNetId, { image_resolution: v });
},
[controlNetId, processorChanged]
);
const handleDetectResolutionReset = useCallback(() => {
processorChanged(controlNetId, {
detect_resolution: DEFAULTS.detect_resolution,
});
}, [controlNetId, processorChanged]);
const handleImageResolutionReset = useCallback(() => {
processorChanged(controlNetId, {
image_resolution: DEFAULTS.image_resolution,
});
}, [controlNetId, processorChanged]);
return (
<ProcessorWrapper>
<IAISlider
label="Detect Resolution"
value={detect_resolution}
onChange={handleDetectResolutionChanged}
handleReset={handleDetectResolutionReset}
withReset
min={0}
max={4096}
withInput
/>
<IAISlider
label="Image Resolution"
value={image_resolution}
onChange={handleImageResolutionChanged}
handleReset={handleImageResolutionReset}
withReset
min={0}
max={4096}
withInput
/>
</ProcessorWrapper>
);
};
export default memo(NormalBaeProcessor);

View File

@ -0,0 +1,85 @@
import IAISlider from 'common/components/IAISlider';
import IAISwitch from 'common/components/IAISwitch';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredOpenposeImageProcessorInvocation } from 'features/controlNet/store/types';
import { ChangeEvent, memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper';
const DEFAULTS = CONTROLNET_PROCESSORS.openpose_image_processor.default;
type Props = {
controlNetId: string;
processorNode: RequiredOpenposeImageProcessorInvocation;
};
const OpenposeProcessor = (props: Props) => {
const { controlNetId, processorNode } = props;
const { image_resolution, detect_resolution, hand_and_face } = processorNode;
const processorChanged = useProcessorNodeChanged();
const handleDetectResolutionChanged = useCallback(
(v: number) => {
processorChanged(controlNetId, { detect_resolution: v });
},
[controlNetId, processorChanged]
);
const handleImageResolutionChanged = useCallback(
(v: number) => {
processorChanged(controlNetId, { image_resolution: v });
},
[controlNetId, processorChanged]
);
const handleDetectResolutionReset = useCallback(() => {
processorChanged(controlNetId, {
detect_resolution: DEFAULTS.detect_resolution,
});
}, [controlNetId, processorChanged]);
const handleImageResolutionReset = useCallback(() => {
processorChanged(controlNetId, {
image_resolution: DEFAULTS.image_resolution,
});
}, [controlNetId, processorChanged]);
const handleHandAndFaceChanged = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
processorChanged(controlNetId, { hand_and_face: e.target.checked });
},
[controlNetId, processorChanged]
);
return (
<ProcessorWrapper>
<IAISlider
label="Detect Resolution"
value={detect_resolution}
onChange={handleDetectResolutionChanged}
handleReset={handleDetectResolutionReset}
withReset
min={0}
max={4096}
withInput
/>
<IAISlider
label="Image Resolution"
value={image_resolution}
onChange={handleImageResolutionChanged}
handleReset={handleImageResolutionReset}
withReset
min={0}
max={4096}
withInput
/>
<IAISwitch
label="Hand and Face"
isChecked={hand_and_face}
onChange={handleHandAndFaceChanged}
/>
</ProcessorWrapper>
);
};
export default memo(OpenposeProcessor);

View File

@ -0,0 +1,93 @@
import IAISlider from 'common/components/IAISlider';
import IAISwitch from 'common/components/IAISwitch';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredPidiImageProcessorInvocation } from 'features/controlNet/store/types';
import { ChangeEvent, memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper';
const DEFAULTS = CONTROLNET_PROCESSORS.pidi_image_processor.default;
type Props = {
controlNetId: string;
processorNode: RequiredPidiImageProcessorInvocation;
};
const PidiProcessor = (props: Props) => {
const { controlNetId, processorNode } = props;
const { image_resolution, detect_resolution, scribble, safe } = processorNode;
const processorChanged = useProcessorNodeChanged();
const handleDetectResolutionChanged = useCallback(
(v: number) => {
processorChanged(controlNetId, { detect_resolution: v });
},
[controlNetId, processorChanged]
);
const handleImageResolutionChanged = useCallback(
(v: number) => {
processorChanged(controlNetId, { image_resolution: v });
},
[controlNetId, processorChanged]
);
const handleDetectResolutionReset = useCallback(() => {
processorChanged(controlNetId, {
detect_resolution: DEFAULTS.detect_resolution,
});
}, [controlNetId, processorChanged]);
const handleImageResolutionReset = useCallback(() => {
processorChanged(controlNetId, {
image_resolution: DEFAULTS.image_resolution,
});
}, [controlNetId, processorChanged]);
const handleScribbleChanged = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
processorChanged(controlNetId, { scribble: e.target.checked });
},
[controlNetId, processorChanged]
);
const handleSafeChanged = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
processorChanged(controlNetId, { safe: e.target.checked });
},
[controlNetId, processorChanged]
);
return (
<ProcessorWrapper>
<IAISlider
label="Detect Resolution"
value={detect_resolution}
onChange={handleDetectResolutionChanged}
handleReset={handleDetectResolutionReset}
withReset
min={0}
max={4096}
withInput
/>
<IAISlider
label="Image Resolution"
value={image_resolution}
onChange={handleImageResolutionChanged}
handleReset={handleImageResolutionReset}
withReset
min={0}
max={4096}
withInput
/>
<IAISwitch
label="Scribble"
isChecked={scribble}
onChange={handleScribbleChanged}
/>
<IAISwitch label="Safe" isChecked={safe} onChange={handleSafeChanged} />
</ProcessorWrapper>
);
};
export default memo(PidiProcessor);

View File

@ -0,0 +1,14 @@
import { RequiredZoeDepthImageProcessorInvocation } from 'features/controlNet/store/types';
import { memo } from 'react';
type Props = {
controlNetId: string;
processorNode: RequiredZoeDepthImageProcessorInvocation;
};
const ZoeDepthProcessor = (props: Props) => {
// Has no parameters?
return null;
};
export default memo(ZoeDepthProcessor);

Some files were not shown because too many files have changed in this diff Show More