mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
merge with main
This commit is contained in:
commit
1f9e1eb964
171
docs/features/LOGGING.md
Normal file
171
docs/features/LOGGING.md
Normal file
@ -0,0 +1,171 @@
|
|||||||
|
---
|
||||||
|
title: Controlling Logging
|
||||||
|
---
|
||||||
|
|
||||||
|
# :material-image-off: Controlling Logging
|
||||||
|
|
||||||
|
## Controlling How InvokeAI Logs Status Messages
|
||||||
|
|
||||||
|
InvokeAI logs status messages using a configurable logging system. You
|
||||||
|
can log to the terminal window, to a designated file on the local
|
||||||
|
machine, to the syslog facility on a Linux or Mac, or to a properly
|
||||||
|
configured web server. You can configure several logs at the same
|
||||||
|
time, and control the level of message logged and the logging format
|
||||||
|
(to a limited extent).
|
||||||
|
|
||||||
|
Three command-line options control logging:
|
||||||
|
|
||||||
|
### `--log_handlers <handler1> <handler2> ...`
|
||||||
|
|
||||||
|
This option activates one or more log handlers. Options are "console",
|
||||||
|
"file", "syslog" and "http". To specify more than one, separate them
|
||||||
|
by spaces:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
invokeai-web --log_handlers console syslog=/dev/log file=C:\Users\fred\invokeai.log
|
||||||
|
```
|
||||||
|
|
||||||
|
The format of these options is described below.
|
||||||
|
|
||||||
|
### `--log_format {plain|color|legacy|syslog}`
|
||||||
|
|
||||||
|
This controls the format of log messages written to the console. Only
|
||||||
|
the "console" log handler is currently affected by this setting.
|
||||||
|
|
||||||
|
* "plain" provides formatted messages like this:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
|
||||||
|
[2023-05-24 23:18:2[2023-05-24 23:18:50,352]::[InvokeAI]::DEBUG --> this is a debug message
|
||||||
|
[2023-05-24 23:18:50,352]::[InvokeAI]::INFO --> this is an informational messages
|
||||||
|
[2023-05-24 23:18:50,352]::[InvokeAI]::WARNING --> this is a warning
|
||||||
|
[2023-05-24 23:18:50,352]::[InvokeAI]::ERROR --> this is an error
|
||||||
|
[2023-05-24 23:18:50,352]::[InvokeAI]::CRITICAL --> this is a critical error
|
||||||
|
```
|
||||||
|
|
||||||
|
* "color" produces similar output, but the text will be color coded to
|
||||||
|
indicate the severity of the message.
|
||||||
|
|
||||||
|
* "legacy" produces output similar to InvokeAI versions 2.3 and earlier:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
### this is a critical error
|
||||||
|
*** this is an error
|
||||||
|
** this is a warning
|
||||||
|
>> this is an informational messages
|
||||||
|
| this is a debug message
|
||||||
|
```
|
||||||
|
|
||||||
|
* "syslog" produces messages suitable for syslog entries:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
InvokeAI [2691178] <CRITICAL> this is a critical error
|
||||||
|
InvokeAI [2691178] <ERROR> this is an error
|
||||||
|
InvokeAI [2691178] <WARNING> this is a warning
|
||||||
|
InvokeAI [2691178] <INFO> this is an informational messages
|
||||||
|
InvokeAI [2691178] <DEBUG> this is a debug message
|
||||||
|
```
|
||||||
|
|
||||||
|
(note that the date, time and hostname will be added by the syslog
|
||||||
|
system)
|
||||||
|
|
||||||
|
### `--log_level {debug|info|warning|error|critical}`
|
||||||
|
|
||||||
|
Providing this command-line option will cause only messages at the
|
||||||
|
specified level or above to be emitted.
|
||||||
|
|
||||||
|
## Console logging
|
||||||
|
|
||||||
|
When "console" is provided to `--log_handlers`, messages will be
|
||||||
|
written to the command line window in which InvokeAI was launched. By
|
||||||
|
default, the color formatter will be used unless overridden by
|
||||||
|
`--log_format`.
|
||||||
|
|
||||||
|
## File logging
|
||||||
|
|
||||||
|
When "file" is provided to `--log_handlers`, entries will be written
|
||||||
|
to the file indicated in the path argument. By default, the "plain"
|
||||||
|
format will be used:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
invokeai-web --log_handlers file=/var/log/invokeai.log
|
||||||
|
```
|
||||||
|
|
||||||
|
## Syslog logging
|
||||||
|
|
||||||
|
When "syslog" is requested, entries will be sent to the syslog
|
||||||
|
system. There are a variety of ways to control where the log message
|
||||||
|
is sent:
|
||||||
|
|
||||||
|
* Send to the local machine using the `/dev/log` socket:
|
||||||
|
|
||||||
|
```
|
||||||
|
invokeai-web --log_handlers syslog=/dev/log
|
||||||
|
```
|
||||||
|
|
||||||
|
* Send to the local machine using a UDP message:
|
||||||
|
|
||||||
|
```
|
||||||
|
invokeai-web --log_handlers syslog=localhost
|
||||||
|
```
|
||||||
|
|
||||||
|
* Send to the local machine using a UDP message on a nonstandard
|
||||||
|
port:
|
||||||
|
|
||||||
|
```
|
||||||
|
invokeai-web --log_handlers syslog=localhost:512
|
||||||
|
```
|
||||||
|
|
||||||
|
* Send to a remote machine named "loghost" on the local LAN using
|
||||||
|
facility LOG_USER and UDP packets:
|
||||||
|
|
||||||
|
```
|
||||||
|
invokeai-web --log_handlers syslog=loghost,facility=LOG_USER,socktype=SOCK_DGRAM
|
||||||
|
```
|
||||||
|
|
||||||
|
This can be abbreviated `syslog=loghost`, as LOG_USER and SOCK_DGRAM
|
||||||
|
are defaults.
|
||||||
|
|
||||||
|
* Send to a remote machine named "loghost" using the facility LOCAL0
|
||||||
|
and using a TCP socket:
|
||||||
|
|
||||||
|
```
|
||||||
|
invokeai-web --log_handlers syslog=loghost,facility=LOG_LOCAL0,socktype=SOCK_STREAM
|
||||||
|
```
|
||||||
|
|
||||||
|
If no arguments are specified (just a bare "syslog"), then the logging
|
||||||
|
system will look for a UNIX socket named `/dev/log`, and if not found
|
||||||
|
try to send a UDP message to `localhost`. The Macintosh OS used to
|
||||||
|
support logging to a socket named `/var/run/syslog`, but this feature
|
||||||
|
has since been disabled.
|
||||||
|
|
||||||
|
## Web logging
|
||||||
|
|
||||||
|
If you have access to a web server that is configured to log messages
|
||||||
|
when a particular URL is requested, you can log using the "http"
|
||||||
|
method:
|
||||||
|
|
||||||
|
```
|
||||||
|
invokeai-web --log_handlers http=http://my.server/path/to/logger,method=POST
|
||||||
|
```
|
||||||
|
|
||||||
|
The optional [,method=] part can be used to specify whether the URL
|
||||||
|
accepts GET (default) or POST messages.
|
||||||
|
|
||||||
|
Currently password authentication and SSL are not supported.
|
||||||
|
|
||||||
|
## Using the configuration file
|
||||||
|
|
||||||
|
You can set and forget logging options by adding a "Logging" section
|
||||||
|
to `invokeai.yaml`:
|
||||||
|
|
||||||
|
```
|
||||||
|
InvokeAI:
|
||||||
|
[... other settings...]
|
||||||
|
Logging:
|
||||||
|
log_handlers:
|
||||||
|
- console
|
||||||
|
- syslog=/dev/log
|
||||||
|
log_level: info
|
||||||
|
log_format: color
|
||||||
|
```
|
@ -57,6 +57,9 @@ Personalize models by adding your own style or subjects.
|
|||||||
## * [The NSFW Checker](NSFW.md)
|
## * [The NSFW Checker](NSFW.md)
|
||||||
Prevent InvokeAI from displaying unwanted racy images.
|
Prevent InvokeAI from displaying unwanted racy images.
|
||||||
|
|
||||||
|
## * [Controlling Logging](LOGGING.md)
|
||||||
|
Control how InvokeAI logs status messages.
|
||||||
|
|
||||||
## * [Miscellaneous](OTHER.md)
|
## * [Miscellaneous](OTHER.md)
|
||||||
Run InvokeAI on Google Colab, generate images with repeating patterns,
|
Run InvokeAI on Google Colab, generate images with repeating patterns,
|
||||||
batch process a file of prompts, increase the "creativity" of image
|
batch process a file of prompts, increase the "creativity" of image
|
||||||
|
@ -39,7 +39,8 @@ socket_io = SocketIO(app)
|
|||||||
|
|
||||||
# initialize config
|
# initialize config
|
||||||
# this is a module global
|
# this is a module global
|
||||||
app_config = InvokeAIAppConfig()
|
app_config = InvokeAIAppConfig.get_config()
|
||||||
|
app_config.parse_args()
|
||||||
|
|
||||||
# Add startup event to load dependencies
|
# Add startup event to load dependencies
|
||||||
@app.on_event("startup")
|
@app.on_event("startup")
|
||||||
|
@ -39,7 +39,7 @@ from .services.model_manager_service import ModelManagerService
|
|||||||
from .services.processor import DefaultInvocationProcessor
|
from .services.processor import DefaultInvocationProcessor
|
||||||
from .services.restoration_services import RestorationServices
|
from .services.restoration_services import RestorationServices
|
||||||
from .services.sqlite import SqliteItemStorage
|
from .services.sqlite import SqliteItemStorage
|
||||||
|
from .services.config import InvokeAIAppConfig
|
||||||
|
|
||||||
class CliCommand(BaseModel):
|
class CliCommand(BaseModel):
|
||||||
command: Union[BaseCommand.get_commands() + BaseInvocation.get_invocations()] = Field(discriminator="type") # type: ignore
|
command: Union[BaseCommand.get_commands() + BaseInvocation.get_invocations()] = Field(discriminator="type") # type: ignore
|
||||||
@ -198,7 +198,8 @@ logger = logger.InvokeAILogger.getLogger()
|
|||||||
|
|
||||||
def invoke_cli():
|
def invoke_cli():
|
||||||
# this gets the basic configuration
|
# this gets the basic configuration
|
||||||
config = get_invokeai_config()
|
config = InvokeAIAppConfig.get_config()
|
||||||
|
config.parse_args()
|
||||||
|
|
||||||
# get the optional list of invocations to execute on the command line
|
# get the optional list of invocations to execute on the command line
|
||||||
parser = config.get_parser()
|
parser = config.get_parser()
|
||||||
|
@ -4,12 +4,10 @@ from contextlib import ExitStack
|
|||||||
import re
|
import re
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
||||||
|
|
||||||
from .model import ClipField
|
from .model import ClipField
|
||||||
|
|
||||||
from ...backend.util.devices import choose_torch_device, torch_dtype
|
from ...backend.util.devices import torch_dtype
|
||||||
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
|
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
|
||||||
from ...backend.stable_diffusion.textual_inversion_manager import TextualInversionManager
|
|
||||||
from ...backend.model_management import SDModelType
|
from ...backend.model_management import SDModelType
|
||||||
from ...backend.model_management.lora import ModelPatcher
|
from ...backend.model_management.lora import ModelPatcher
|
||||||
|
|
||||||
@ -18,7 +16,7 @@ from compel.prompt_parser import (
|
|||||||
Blend,
|
Blend,
|
||||||
CrossAttentionControlSubstitute,
|
CrossAttentionControlSubstitute,
|
||||||
FlattenedPrompt,
|
FlattenedPrompt,
|
||||||
Fragment,
|
Fragment, Conjunction,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -81,7 +79,7 @@ class CompelInvocation(BaseInvocation):
|
|||||||
context.services.model_manager.get_model(model_name=name, model_type=SDModelType.TextualInversion)
|
context.services.model_manager.get_model(model_name=name, model_type=SDModelType.TextualInversion)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
#print(e)
|
#print(e)
|
||||||
#import traceback
|
#import traceback
|
||||||
#print(traceback.format_exc())
|
#print(traceback.format_exc())
|
||||||
@ -98,7 +96,6 @@ class CompelInvocation(BaseInvocation):
|
|||||||
truncate_long_prompts=True, # TODO:
|
truncate_long_prompts=True, # TODO:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
conjunction = Compel.parse_prompt_string(self.prompt)
|
conjunction = Compel.parse_prompt_string(self.prompt)
|
||||||
prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0]
|
prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0]
|
||||||
|
|
||||||
@ -106,16 +103,15 @@ class CompelInvocation(BaseInvocation):
|
|||||||
log_tokenization_for_prompt_object(prompt, tokenizer)
|
log_tokenization_for_prompt_object(prompt, tokenizer)
|
||||||
|
|
||||||
c, options = compel.build_conditioning_tensor_for_prompt_object(prompt)
|
c, options = compel.build_conditioning_tensor_for_prompt_object(prompt)
|
||||||
|
|
||||||
# TODO: long prompt support
|
# TODO: long prompt support
|
||||||
#if not self.truncate_long_prompts:
|
#if not self.truncate_long_prompts:
|
||||||
# [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
|
# [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
|
||||||
|
|
||||||
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
|
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
|
||||||
tokens_count_including_eos_bos=get_max_token_count(tokenizer, prompt),
|
tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
|
||||||
cross_attention_control_args=options.get("cross_attention_control", None),
|
cross_attention_control_args=options.get("cross_attention_control", None),
|
||||||
)
|
)
|
||||||
|
|
||||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
||||||
|
|
||||||
# TODO: hacky but works ;D maybe rename latents somehow?
|
# TODO: hacky but works ;D maybe rename latents somehow?
|
||||||
@ -129,14 +125,22 @@ class CompelInvocation(BaseInvocation):
|
|||||||
|
|
||||||
|
|
||||||
def get_max_token_count(
|
def get_max_token_count(
|
||||||
tokenizer, prompt: Union[FlattenedPrompt, Blend], truncate_if_too_long=False
|
tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction], truncate_if_too_long=False
|
||||||
) -> int:
|
) -> int:
|
||||||
if type(prompt) is Blend:
|
if type(prompt) is Blend:
|
||||||
blend: Blend = prompt
|
blend: Blend = prompt
|
||||||
return max(
|
return max(
|
||||||
[
|
[
|
||||||
get_max_token_count(tokenizer, c, truncate_if_too_long)
|
get_max_token_count(tokenizer, p, truncate_if_too_long)
|
||||||
for c in blend.prompts
|
for p in blend.prompts
|
||||||
|
]
|
||||||
|
)
|
||||||
|
elif type(prompt) is Conjunction:
|
||||||
|
conjunction: Conjunction = prompt
|
||||||
|
return sum(
|
||||||
|
[
|
||||||
|
get_max_token_count(tokenizer, p, truncate_if_too_long)
|
||||||
|
for p in conjunction.prompts
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -171,6 +175,22 @@ def get_tokens_for_prompt_object(
|
|||||||
return tokens
|
return tokens
|
||||||
|
|
||||||
|
|
||||||
|
def log_tokenization_for_conjunction(
|
||||||
|
c: Conjunction, tokenizer, display_label_prefix=None
|
||||||
|
):
|
||||||
|
display_label_prefix = display_label_prefix or ""
|
||||||
|
for i, p in enumerate(c.prompts):
|
||||||
|
if len(c.prompts)>1:
|
||||||
|
this_display_label_prefix = f"{display_label_prefix}(conjunction part {i + 1}, weight={c.weights[i]})"
|
||||||
|
else:
|
||||||
|
this_display_label_prefix = display_label_prefix
|
||||||
|
log_tokenization_for_prompt_object(
|
||||||
|
p,
|
||||||
|
tokenizer,
|
||||||
|
display_label_prefix=this_display_label_prefix
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def log_tokenization_for_prompt_object(
|
def log_tokenization_for_prompt_object(
|
||||||
p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None
|
p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None
|
||||||
):
|
):
|
||||||
|
@ -94,13 +94,13 @@ CONTROLNET_DEFAULT_MODELS = [
|
|||||||
CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
|
CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
|
||||||
|
|
||||||
class ControlField(BaseModel):
|
class ControlField(BaseModel):
|
||||||
image: ImageField = Field(default=None, description="processed image")
|
image: ImageField = Field(default=None, description="The control image")
|
||||||
control_model: Optional[str] = Field(default=None, description="control model used")
|
control_model: Optional[str] = Field(default=None, description="The ControlNet model to use")
|
||||||
control_weight: Optional[float] = Field(default=1, description="weight given to controlnet")
|
control_weight: Optional[float] = Field(default=1, description="The weight given to the ControlNet")
|
||||||
begin_step_percent: float = Field(default=0, ge=0, le=1,
|
begin_step_percent: float = Field(default=0, ge=0, le=1,
|
||||||
description="% of total steps at which controlnet is first applied")
|
description="When the ControlNet is first applied (% of total steps)")
|
||||||
end_step_percent: float = Field(default=1, ge=0, le=1,
|
end_step_percent: float = Field(default=1, ge=0, le=1,
|
||||||
description="% of total steps at which controlnet is last applied")
|
description="When the ControlNet is last applied (% of total steps)")
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
@ -112,7 +112,7 @@ class ControlOutput(BaseInvocationOutput):
|
|||||||
"""node output for ControlNet info"""
|
"""node output for ControlNet info"""
|
||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["control_output"] = "control_output"
|
type: Literal["control_output"] = "control_output"
|
||||||
control: ControlField = Field(default=None, description="The control info dict")
|
control: ControlField = Field(default=None, description="The output control image")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
@ -121,15 +121,15 @@ class ControlNetInvocation(BaseInvocation):
|
|||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["controlnet"] = "controlnet"
|
type: Literal["controlnet"] = "controlnet"
|
||||||
# Inputs
|
# Inputs
|
||||||
image: ImageField = Field(default=None, description="image to process")
|
image: ImageField = Field(default=None, description="The control image")
|
||||||
control_model: CONTROLNET_NAME_VALUES = Field(default="lllyasviel/sd-controlnet-canny",
|
control_model: CONTROLNET_NAME_VALUES = Field(default="lllyasviel/sd-controlnet-canny",
|
||||||
description="control model used")
|
description="The ControlNet model to use")
|
||||||
control_weight: float = Field(default=1.0, ge=0, le=1, description="weight given to controlnet")
|
control_weight: float = Field(default=1.0, ge=0, le=1, description="The weight given to the ControlNet")
|
||||||
# TODO: add support in backend core for begin_step_percent, end_step_percent, guess_mode
|
# TODO: add support in backend core for begin_step_percent, end_step_percent, guess_mode
|
||||||
begin_step_percent: float = Field(default=0, ge=0, le=1,
|
begin_step_percent: float = Field(default=0, ge=0, le=1,
|
||||||
description="% of total steps at which controlnet is first applied")
|
description="When the ControlNet is first applied (% of total steps)")
|
||||||
end_step_percent: float = Field(default=1, ge=0, le=1,
|
end_step_percent: float = Field(default=1, ge=0, le=1,
|
||||||
description="% of total steps at which controlnet is last applied")
|
description="When the ControlNet is last applied (% of total steps)")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
@ -152,7 +152,7 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["image_processor"] = "image_processor"
|
type: Literal["image_processor"] = "image_processor"
|
||||||
# Inputs
|
# Inputs
|
||||||
image: ImageField = Field(default=None, description="image to process")
|
image: ImageField = Field(default=None, description="The image to process")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
@ -204,8 +204,8 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfi
|
|||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["canny_image_processor"] = "canny_image_processor"
|
type: Literal["canny_image_processor"] = "canny_image_processor"
|
||||||
# Input
|
# Input
|
||||||
low_threshold: float = Field(default=100, ge=0, description="low threshold of Canny pixel gradient")
|
low_threshold: int = Field(default=100, ge=0, le=255, description="The low threshold of the Canny pixel gradient (0-255)")
|
||||||
high_threshold: float = Field(default=200, ge=0, description="high threshold of Canny pixel gradient")
|
high_threshold: int = Field(default=200, ge=0, le=255, description="The high threshold of the Canny pixel gradient (0-255)")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
@ -214,16 +214,16 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfi
|
|||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
class HedImageprocessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
class HedImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||||
"""Applies HED edge detection to image"""
|
"""Applies HED edge detection to image"""
|
||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["hed_image_processor"] = "hed_image_processor"
|
type: Literal["hed_image_processor"] = "hed_image_processor"
|
||||||
# Inputs
|
# Inputs
|
||||||
detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
|
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
||||||
image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
|
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
||||||
# safe not supported in controlnet_aux v0.0.3
|
# safe not supported in controlnet_aux v0.0.3
|
||||||
# safe: bool = Field(default=False, description="whether to use safe mode")
|
# safe: bool = Field(default=False, description="whether to use safe mode")
|
||||||
scribble: bool = Field(default=False, description="whether to use scribble mode")
|
scribble: bool = Field(default=False, description="Whether to use scribble mode")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
@ -243,9 +243,9 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation, PILInvocationCon
|
|||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["lineart_image_processor"] = "lineart_image_processor"
|
type: Literal["lineart_image_processor"] = "lineart_image_processor"
|
||||||
# Inputs
|
# Inputs
|
||||||
detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
|
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
||||||
image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
|
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
||||||
coarse: bool = Field(default=False, description="whether to use coarse mode")
|
coarse: bool = Field(default=False, description="Whether to use coarse mode")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
@ -262,8 +262,8 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation, PILInvocati
|
|||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["lineart_anime_image_processor"] = "lineart_anime_image_processor"
|
type: Literal["lineart_anime_image_processor"] = "lineart_anime_image_processor"
|
||||||
# Inputs
|
# Inputs
|
||||||
detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
|
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
||||||
image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
|
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
@ -280,9 +280,9 @@ class OpenposeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
|
|||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["openpose_image_processor"] = "openpose_image_processor"
|
type: Literal["openpose_image_processor"] = "openpose_image_processor"
|
||||||
# Inputs
|
# Inputs
|
||||||
hand_and_face: bool = Field(default=False, description="whether to use hands and face mode")
|
hand_and_face: bool = Field(default=False, description="Whether to use hands and face mode")
|
||||||
detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
|
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
||||||
image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
|
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
@ -300,8 +300,8 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocation
|
|||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["midas_depth_image_processor"] = "midas_depth_image_processor"
|
type: Literal["midas_depth_image_processor"] = "midas_depth_image_processor"
|
||||||
# Inputs
|
# Inputs
|
||||||
a_mult: float = Field(default=2.0, ge=0, description="Midas parameter a = amult * PI")
|
a_mult: float = Field(default=2.0, ge=0, description="Midas parameter `a_mult` (a = a_mult * PI)")
|
||||||
bg_th: float = Field(default=0.1, ge=0, description="Midas parameter bg_th")
|
bg_th: float = Field(default=0.1, ge=0, description="Midas parameter `bg_th`")
|
||||||
# depth_and_normal not supported in controlnet_aux v0.0.3
|
# depth_and_normal not supported in controlnet_aux v0.0.3
|
||||||
# depth_and_normal: bool = Field(default=False, description="whether to use depth and normal mode")
|
# depth_and_normal: bool = Field(default=False, description="whether to use depth and normal mode")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
@ -322,8 +322,8 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationC
|
|||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["normalbae_image_processor"] = "normalbae_image_processor"
|
type: Literal["normalbae_image_processor"] = "normalbae_image_processor"
|
||||||
# Inputs
|
# Inputs
|
||||||
detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
|
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
||||||
image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
|
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
@ -339,10 +339,10 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig
|
|||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["mlsd_image_processor"] = "mlsd_image_processor"
|
type: Literal["mlsd_image_processor"] = "mlsd_image_processor"
|
||||||
# Inputs
|
# Inputs
|
||||||
detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
|
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
||||||
image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
|
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
||||||
thr_v: float = Field(default=0.1, ge=0, description="MLSD parameter thr_v")
|
thr_v: float = Field(default=0.1, ge=0, description="MLSD parameter `thr_v`")
|
||||||
thr_d: float = Field(default=0.1, ge=0, description="MLSD parameter thr_d")
|
thr_d: float = Field(default=0.1, ge=0, description="MLSD parameter `thr_d`")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
@ -360,10 +360,10 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig
|
|||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["pidi_image_processor"] = "pidi_image_processor"
|
type: Literal["pidi_image_processor"] = "pidi_image_processor"
|
||||||
# Inputs
|
# Inputs
|
||||||
detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
|
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
||||||
image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
|
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
||||||
safe: bool = Field(default=False, description="whether to use safe mode")
|
safe: bool = Field(default=False, description="Whether to use safe mode")
|
||||||
scribble: bool = Field(default=False, description="whether to use scribble mode")
|
scribble: bool = Field(default=False, description="Whether to use scribble mode")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
@ -381,11 +381,11 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation, PILInvoca
|
|||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["content_shuffle_image_processor"] = "content_shuffle_image_processor"
|
type: Literal["content_shuffle_image_processor"] = "content_shuffle_image_processor"
|
||||||
# Inputs
|
# Inputs
|
||||||
detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
|
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
||||||
image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
|
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
||||||
h: Union[int | None] = Field(default=512, ge=0, description="content shuffle h parameter")
|
h: Union[int, None] = Field(default=512, ge=0, description="Content shuffle `h` parameter")
|
||||||
w: Union[int | None] = Field(default=512, ge=0, description="content shuffle w parameter")
|
w: Union[int, None] = Field(default=512, ge=0, description="Content shuffle `w` parameter")
|
||||||
f: Union[int | None] = Field(default=256, ge=0, description="cont")
|
f: Union[int, None] = Field(default=256, ge=0, description="Content shuffle `f` parameter")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
@ -418,8 +418,8 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
|
|||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["mediapipe_face_processor"] = "mediapipe_face_processor"
|
type: Literal["mediapipe_face_processor"] = "mediapipe_face_processor"
|
||||||
# Inputs
|
# Inputs
|
||||||
max_faces: int = Field(default=1, ge=1, description="maximum number of faces to detect")
|
max_faces: int = Field(default=1, ge=1, description="Maximum number of faces to detect")
|
||||||
min_confidence: float = Field(default=0.5, ge=0, le=1, description="minimum confidence for face detection")
|
min_confidence: float = Field(default=0.5, ge=0, le=1, description="Minimum confidence for face detection")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
|
@ -4,11 +4,12 @@ from contextlib import ExitStack
|
|||||||
from typing import List, Literal, Optional, Union
|
from typing import List, Literal, Optional, Union
|
||||||
|
|
||||||
import einops
|
import einops
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, validator
|
||||||
import torch
|
import torch
|
||||||
from diffusers import ControlNetModel
|
from diffusers import ControlNetModel
|
||||||
from diffusers.image_processor import VaeImageProcessor
|
from diffusers.image_processor import VaeImageProcessor
|
||||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||||
from pydantic import BaseModel, Field, validator
|
|
||||||
|
|
||||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||||
|
@ -51,18 +51,32 @@ in INVOKEAI_ROOT. You can replace supersede this by providing any
|
|||||||
OmegaConf dictionary object initialization time:
|
OmegaConf dictionary object initialization time:
|
||||||
|
|
||||||
omegaconf = OmegaConf.load('/tmp/init.yaml')
|
omegaconf = OmegaConf.load('/tmp/init.yaml')
|
||||||
conf = InvokeAIAppConfig(conf=omegaconf)
|
conf = InvokeAIAppConfig()
|
||||||
|
conf.parse_args(conf=omegaconf)
|
||||||
|
|
||||||
By default, InvokeAIAppConfig will parse the contents of `sys.argv` at
|
InvokeAIAppConfig.parse_args() will parse the contents of `sys.argv`
|
||||||
initialization time. You may pass a list of strings in the optional
|
at initialization time. You may pass a list of strings in the optional
|
||||||
`argv` argument to use instead of the system argv:
|
`argv` argument to use instead of the system argv:
|
||||||
|
|
||||||
conf = InvokeAIAppConfig(arg=['--xformers_enabled'])
|
conf.parse_args(argv=['--xformers_enabled'])
|
||||||
|
|
||||||
It is also possible to set a value at initialization time. This value
|
It is also possible to set a value at initialization time. However, if
|
||||||
has highest priority.
|
you call parse_args() it may be overwritten.
|
||||||
|
|
||||||
conf = InvokeAIAppConfig(xformers_enabled=True)
|
conf = InvokeAIAppConfig(xformers_enabled=True)
|
||||||
|
conf.parse_args(argv=['--no-xformers'])
|
||||||
|
conf.xformers_enabled
|
||||||
|
# False
|
||||||
|
|
||||||
|
|
||||||
|
To avoid this, use `get_config()` to retrieve the application-wide
|
||||||
|
configuration object. This will retain any properties set at object
|
||||||
|
creation time:
|
||||||
|
|
||||||
|
conf = InvokeAIAppConfig.get_config(xformers_enabled=True)
|
||||||
|
conf.parse_args(argv=['--no-xformers'])
|
||||||
|
conf.xformers_enabled
|
||||||
|
# True
|
||||||
|
|
||||||
Any setting can be overwritten by setting an environment variable of
|
Any setting can be overwritten by setting an environment variable of
|
||||||
form: "INVOKEAI_<setting>", as in:
|
form: "INVOKEAI_<setting>", as in:
|
||||||
@ -76,18 +90,23 @@ Order of precedence (from highest):
|
|||||||
4) config file options
|
4) config file options
|
||||||
5) pydantic defaults
|
5) pydantic defaults
|
||||||
|
|
||||||
Typical usage:
|
Typical usage at the top level file:
|
||||||
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.invocations.generate import TextToImageInvocation
|
|
||||||
|
|
||||||
# get global configuration and print its nsfw_checker value
|
# get global configuration and print its nsfw_checker value
|
||||||
conf = InvokeAIAppConfig()
|
conf = InvokeAIAppConfig.get_config()
|
||||||
|
conf.parse_args()
|
||||||
|
print(conf.nsfw_checker)
|
||||||
|
|
||||||
|
Typical usage in a backend module:
|
||||||
|
|
||||||
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
|
||||||
|
# get global configuration and print its nsfw_checker value
|
||||||
|
conf = InvokeAIAppConfig.get_config()
|
||||||
print(conf.nsfw_checker)
|
print(conf.nsfw_checker)
|
||||||
|
|
||||||
# get the text2image invocation and print its step value
|
|
||||||
text2image = TextToImageInvocation()
|
|
||||||
print(text2image.steps)
|
|
||||||
|
|
||||||
Computed properties:
|
Computed properties:
|
||||||
|
|
||||||
@ -103,10 +122,11 @@ a Path object:
|
|||||||
lora_path - path to the LoRA directory
|
lora_path - path to the LoRA directory
|
||||||
|
|
||||||
In most cases, you will want to create a single InvokeAIAppConfig
|
In most cases, you will want to create a single InvokeAIAppConfig
|
||||||
object for the entire application. The get_invokeai_config() function
|
object for the entire application. The InvokeAIAppConfig.get_config() function
|
||||||
does this:
|
does this:
|
||||||
|
|
||||||
config = get_invokeai_config()
|
config = InvokeAIAppConfig.get_config()
|
||||||
|
config.parse_args() # read values from the command line/config file
|
||||||
print(config.root)
|
print(config.root)
|
||||||
|
|
||||||
# Subclassing
|
# Subclassing
|
||||||
@ -140,24 +160,22 @@ two configs are kept in separate sections of the config file:
|
|||||||
legacy_conf_dir: configs/stable-diffusion
|
legacy_conf_dir: configs/stable-diffusion
|
||||||
outdir: outputs
|
outdir: outputs
|
||||||
...
|
...
|
||||||
|
|
||||||
'''
|
'''
|
||||||
|
from __future__ import annotations
|
||||||
import argparse
|
import argparse
|
||||||
import pydoc
|
import pydoc
|
||||||
import typing
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
from omegaconf import OmegaConf, DictConfig
|
from omegaconf import OmegaConf, DictConfig
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from pydantic import BaseSettings, Field, parse_obj_as
|
from pydantic import BaseSettings, Field, parse_obj_as
|
||||||
from typing import Any, ClassVar, Dict, List, Literal, Type, Union, get_origin, get_type_hints, get_args
|
from typing import ClassVar, Dict, List, Literal, Type, Union, get_origin, get_type_hints, get_args
|
||||||
|
|
||||||
INIT_FILE = Path('invokeai.yaml')
|
INIT_FILE = Path('invokeai.yaml')
|
||||||
LEGACY_INIT_FILE = Path('invokeai.init')
|
LEGACY_INIT_FILE = Path('invokeai.init')
|
||||||
|
|
||||||
# This global stores a singleton InvokeAIAppConfig configuration object
|
|
||||||
global_config = None
|
|
||||||
|
|
||||||
class InvokeAISettings(BaseSettings):
|
class InvokeAISettings(BaseSettings):
|
||||||
'''
|
'''
|
||||||
Runtime configuration settings in which default values are
|
Runtime configuration settings in which default values are
|
||||||
@ -168,7 +186,7 @@ class InvokeAISettings(BaseSettings):
|
|||||||
|
|
||||||
def parse_args(self, argv: list=sys.argv[1:]):
|
def parse_args(self, argv: list=sys.argv[1:]):
|
||||||
parser = self.get_parser()
|
parser = self.get_parser()
|
||||||
opt, _ = parser.parse_known_args(argv)
|
opt = parser.parse_args(argv)
|
||||||
for name in self.__fields__:
|
for name in self.__fields__:
|
||||||
if name not in self._excluded():
|
if name not in self._excluded():
|
||||||
setattr(self, name, getattr(opt,name))
|
setattr(self, name, getattr(opt,name))
|
||||||
@ -330,6 +348,9 @@ the command-line client (recommended for experts only), or
|
|||||||
can be changed by editing the file "INVOKEAI_ROOT/invokeai.yaml" or by
|
can be changed by editing the file "INVOKEAI_ROOT/invokeai.yaml" or by
|
||||||
setting environment variables INVOKEAI_<setting>.
|
setting environment variables INVOKEAI_<setting>.
|
||||||
'''
|
'''
|
||||||
|
singleton_config: ClassVar[InvokeAIAppConfig] = None
|
||||||
|
singleton_init: ClassVar[Dict] = None
|
||||||
|
|
||||||
#fmt: off
|
#fmt: off
|
||||||
type: Literal["InvokeAI"] = "InvokeAI"
|
type: Literal["InvokeAI"] = "InvokeAI"
|
||||||
host : str = Field(default="127.0.0.1", description="IP address to bind to", category='Web Server')
|
host : str = Field(default="127.0.0.1", description="IP address to bind to", category='Web Server')
|
||||||
@ -367,35 +388,51 @@ setting environment variables INVOKEAI_<setting>.
|
|||||||
|
|
||||||
model : str = Field(default='stable-diffusion-1.5', description='Initial model name', category='Models')
|
model : str = Field(default='stable-diffusion-1.5', description='Initial model name', category='Models')
|
||||||
embeddings : bool = Field(default=True, description='Load contents of embeddings directory', category='Models')
|
embeddings : bool = Field(default=True, description='Load contents of embeddings directory', category='Models')
|
||||||
|
|
||||||
|
log_handlers : List[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>"', category="Logging")
|
||||||
|
# note - would be better to read the log_format values from logging.py, but this creates circular dependencies issues
|
||||||
|
log_format : Literal[tuple(['plain','color','syslog','legacy'])] = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style', category="Logging")
|
||||||
|
log_level : Literal[tuple(["debug","info","warning","error","critical"])] = Field(default="debug", description="Emit logging messages at this level or higher", category="Logging")
|
||||||
#fmt: on
|
#fmt: on
|
||||||
|
|
||||||
def __init__(self, conf: DictConfig = None, argv: List[str]=None, **kwargs):
|
def parse_args(self, argv: List[str]=None, conf: DictConfig = None, clobber=False):
|
||||||
'''
|
'''
|
||||||
Initialize InvokeAIAppconfig.
|
Update settings with contents of init file, environment, and
|
||||||
|
command-line settings.
|
||||||
:param conf: alternate Omegaconf dictionary object
|
:param conf: alternate Omegaconf dictionary object
|
||||||
:param argv: aternate sys.argv list
|
:param argv: aternate sys.argv list
|
||||||
:param **kwargs: attributes to initialize with
|
:param clobber: ovewrite any initialization parameters passed during initialization
|
||||||
'''
|
'''
|
||||||
super().__init__(**kwargs)
|
|
||||||
|
|
||||||
# Set the runtime root directory. We parse command-line switches here
|
# Set the runtime root directory. We parse command-line switches here
|
||||||
# in order to pick up the --root_dir option.
|
# in order to pick up the --root_dir option.
|
||||||
self.parse_args(argv)
|
super().parse_args(argv)
|
||||||
if conf is None:
|
if conf is None:
|
||||||
try:
|
try:
|
||||||
conf = OmegaConf.load(self.root_dir / INIT_FILE)
|
conf = OmegaConf.load(self.root_dir / INIT_FILE)
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
InvokeAISettings.initconf = conf
|
InvokeAISettings.initconf = conf
|
||||||
|
|
||||||
# parse args again in order to pick up settings in configuration file
|
# parse args again in order to pick up settings in configuration file
|
||||||
self.parse_args(argv)
|
super().parse_args(argv)
|
||||||
|
|
||||||
# restore initialization values
|
if self.singleton_init and not clobber:
|
||||||
hints = get_type_hints(self)
|
hints = get_type_hints(self.__class__)
|
||||||
for k in kwargs:
|
for k in self.singleton_init:
|
||||||
setattr(self,k,parse_obj_as(hints[k],kwargs[k]))
|
setattr(self,k,parse_obj_as(hints[k],self.singleton_init[k]))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_config(cls,**kwargs)->InvokeAIAppConfig:
|
||||||
|
'''
|
||||||
|
This returns a singleton InvokeAIAppConfig configuration object.
|
||||||
|
'''
|
||||||
|
if cls.singleton_config is None \
|
||||||
|
or type(cls.singleton_config)!=cls \
|
||||||
|
or (kwargs and cls.singleton_init != kwargs):
|
||||||
|
cls.singleton_config = cls(**kwargs)
|
||||||
|
cls.singleton_init = kwargs
|
||||||
|
return cls.singleton_config
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def root_path(self)->Path:
|
def root_path(self)->Path:
|
||||||
'''
|
'''
|
||||||
@ -520,11 +557,8 @@ class PagingArgumentParser(argparse.ArgumentParser):
|
|||||||
text = self.format_help()
|
text = self.format_help()
|
||||||
pydoc.pager(text)
|
pydoc.pager(text)
|
||||||
|
|
||||||
def get_invokeai_config(cls:Type[InvokeAISettings]=InvokeAIAppConfig,**kwargs)->InvokeAIAppConfig:
|
def get_invokeai_config(**kwargs)->InvokeAIAppConfig:
|
||||||
'''
|
'''
|
||||||
This returns a singleton InvokeAIAppConfig configuration object.
|
Legacy function which returns InvokeAIAppConfig.get_config()
|
||||||
'''
|
'''
|
||||||
global global_config
|
return InvokeAIAppConfig.get_config(**kwargs)
|
||||||
if global_config is None or type(global_config)!=cls:
|
|
||||||
global_config = cls(**kwargs)
|
|
||||||
return global_config
|
|
||||||
|
@ -26,7 +26,6 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
|||||||
self._table_name = table_name
|
self._table_name = table_name
|
||||||
self._id_field = id_field # TODO: validate that T has this field
|
self._id_field = id_field # TODO: validate that T has this field
|
||||||
self._lock = Lock()
|
self._lock = Lock()
|
||||||
|
|
||||||
self._conn = sqlite3.connect(
|
self._conn = sqlite3.connect(
|
||||||
self._filename, check_same_thread=False
|
self._filename, check_same_thread=False
|
||||||
) # TODO: figure out a better threading solution
|
) # TODO: figure out a better threading solution
|
||||||
|
@ -35,15 +35,19 @@ from transformers import (
|
|||||||
CLIPTextModel,
|
CLIPTextModel,
|
||||||
CLIPTokenizer,
|
CLIPTokenizer,
|
||||||
)
|
)
|
||||||
|
|
||||||
import invokeai.configs as configs
|
import invokeai.configs as configs
|
||||||
|
|
||||||
|
from invokeai.app.services.config import (
|
||||||
|
get_invokeai_config,
|
||||||
|
InvokeAIAppConfig,
|
||||||
|
)
|
||||||
from invokeai.frontend.install.model_install import addModelsForm, process_and_execute
|
from invokeai.frontend.install.model_install import addModelsForm, process_and_execute
|
||||||
from invokeai.frontend.install.widgets import (
|
from invokeai.frontend.install.widgets import (
|
||||||
CenteredButtonPress,
|
CenteredButtonPress,
|
||||||
IntTitleSlider,
|
IntTitleSlider,
|
||||||
set_min_terminal_size,
|
set_min_terminal_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
from invokeai.backend.config.legacy_arg_parsing import legacy_parser
|
from invokeai.backend.config.legacy_arg_parsing import legacy_parser
|
||||||
from invokeai.backend.config.model_install_backend import (
|
from invokeai.backend.config.model_install_backend import (
|
||||||
default_dataset,
|
default_dataset,
|
||||||
@ -51,10 +55,8 @@ from invokeai.backend.config.model_install_backend import (
|
|||||||
hf_download_with_resume,
|
hf_download_with_resume,
|
||||||
recommended_datasets,
|
recommended_datasets,
|
||||||
)
|
)
|
||||||
from invokeai.app.services.config import (
|
|
||||||
get_invokeai_config,
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
InvokeAIAppConfig,
|
|
||||||
)
|
|
||||||
|
|
||||||
warnings.filterwarnings("ignore")
|
warnings.filterwarnings("ignore")
|
||||||
|
|
||||||
@ -62,7 +64,8 @@ transformers.logging.set_verbosity_error()
|
|||||||
|
|
||||||
|
|
||||||
# --------------------------globals-----------------------
|
# --------------------------globals-----------------------
|
||||||
config = get_invokeai_config()
|
|
||||||
|
config = InvokeAIAppConfig.get_config()
|
||||||
|
|
||||||
Model_dir = "models"
|
Model_dir = "models"
|
||||||
Weights_dir = "ldm/stable-diffusion-v1/"
|
Weights_dir = "ldm/stable-diffusion-v1/"
|
||||||
@ -634,7 +637,7 @@ def edit_opts(program_opts: Namespace, invokeai_opts: Namespace) -> argparse.Nam
|
|||||||
|
|
||||||
|
|
||||||
def default_startup_options(init_file: Path) -> Namespace:
|
def default_startup_options(init_file: Path) -> Namespace:
|
||||||
opts = InvokeAIAppConfig(argv=[])
|
opts = InvokeAIAppConfig.get_config()
|
||||||
outdir = Path(opts.outdir)
|
outdir = Path(opts.outdir)
|
||||||
if not outdir.is_absolute():
|
if not outdir.is_absolute():
|
||||||
opts.outdir = str(config.root / opts.outdir)
|
opts.outdir = str(config.root / opts.outdir)
|
||||||
@ -699,7 +702,7 @@ def write_opts(opts: Namespace, init_file: Path):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# this will load current settings
|
# this will load current settings
|
||||||
config = InvokeAIAppConfig()
|
config = InvokeAIAppConfig.get_config()
|
||||||
for key,value in opts.__dict__.items():
|
for key,value in opts.__dict__.items():
|
||||||
if hasattr(config,key):
|
if hasattr(config,key):
|
||||||
setattr(config,key,value)
|
setattr(config,key,value)
|
||||||
@ -731,7 +734,7 @@ def write_default_options(program_opts: Namespace, initfile: Path):
|
|||||||
# yaml format.
|
# yaml format.
|
||||||
def migrate_init_file(legacy_format:Path):
|
def migrate_init_file(legacy_format:Path):
|
||||||
old = legacy_parser.parse_args([f'@{str(legacy_format)}'])
|
old = legacy_parser.parse_args([f'@{str(legacy_format)}'])
|
||||||
new = InvokeAIAppConfig(conf={})
|
new = InvokeAIAppConfig.get_config()
|
||||||
|
|
||||||
fields = list(get_type_hints(InvokeAIAppConfig).keys())
|
fields = list(get_type_hints(InvokeAIAppConfig).keys())
|
||||||
for attr in fields:
|
for attr in fields:
|
||||||
@ -820,8 +823,9 @@ def main():
|
|||||||
if old_init_file.exists() and not new_init_file.exists():
|
if old_init_file.exists() and not new_init_file.exists():
|
||||||
print('** Migrating invokeai.init to invokeai.yaml')
|
print('** Migrating invokeai.init to invokeai.yaml')
|
||||||
migrate_init_file(old_init_file)
|
migrate_init_file(old_init_file)
|
||||||
config = get_invokeai_config() # reread defaults
|
|
||||||
|
|
||||||
|
# Load new init file into config
|
||||||
|
config.parse_args(argv=[],conf=OmegaConf.load(new_init_file))
|
||||||
|
|
||||||
if not config.model_conf_path.exists():
|
if not config.model_conf_path.exists():
|
||||||
initialize_rootdir(config.root, opt.yes_to_all)
|
initialize_rootdir(config.root, opt.yes_to_all)
|
||||||
|
@ -19,7 +19,7 @@ from tqdm import tqdm
|
|||||||
|
|
||||||
import invokeai.configs as configs
|
import invokeai.configs as configs
|
||||||
|
|
||||||
from invokeai.app.services.config import get_invokeai_config
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from ..model_management import ModelManager
|
from ..model_management import ModelManager
|
||||||
from ..stable_diffusion import StableDiffusionGeneratorPipeline
|
from ..stable_diffusion import StableDiffusionGeneratorPipeline
|
||||||
|
|
||||||
@ -27,7 +27,8 @@ from ..stable_diffusion import StableDiffusionGeneratorPipeline
|
|||||||
warnings.filterwarnings("ignore")
|
warnings.filterwarnings("ignore")
|
||||||
|
|
||||||
# --------------------------globals-----------------------
|
# --------------------------globals-----------------------
|
||||||
config = get_invokeai_config()
|
config = InvokeAIAppConfig.get_config()
|
||||||
|
|
||||||
Model_dir = "models"
|
Model_dir = "models"
|
||||||
Weights_dir = "ldm/stable-diffusion-v1/"
|
Weights_dir = "ldm/stable-diffusion-v1/"
|
||||||
|
|
||||||
|
@ -6,7 +6,8 @@ be suppressed or deferred
|
|||||||
"""
|
"""
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.app.services.config import get_invokeai_config
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
config = InvokeAIAppConfig.get_config()
|
||||||
|
|
||||||
class PatchMatch:
|
class PatchMatch:
|
||||||
"""
|
"""
|
||||||
@ -21,7 +22,6 @@ class PatchMatch:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _load_patch_match(self):
|
def _load_patch_match(self):
|
||||||
config = get_invokeai_config()
|
|
||||||
if self.tried_load:
|
if self.tried_load:
|
||||||
return
|
return
|
||||||
if config.try_patchmatch:
|
if config.try_patchmatch:
|
||||||
|
@ -33,10 +33,11 @@ from PIL import Image, ImageOps
|
|||||||
from transformers import AutoProcessor, CLIPSegForImageSegmentation
|
from transformers import AutoProcessor, CLIPSegForImageSegmentation
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.app.services.config import get_invokeai_config
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
|
||||||
CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
|
CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
|
||||||
CLIPSEG_SIZE = 352
|
CLIPSEG_SIZE = 352
|
||||||
|
config = InvokeAIAppConfig.get_config()
|
||||||
|
|
||||||
class SegmentedGrayscale(object):
|
class SegmentedGrayscale(object):
|
||||||
def __init__(self, image: Image, heatmap: torch.Tensor):
|
def __init__(self, image: Image, heatmap: torch.Tensor):
|
||||||
@ -83,7 +84,6 @@ class Txt2Mask(object):
|
|||||||
|
|
||||||
def __init__(self, device="cpu", refined=False):
|
def __init__(self, device="cpu", refined=False):
|
||||||
logger.info("Initializing clipseg model for text to mask inference")
|
logger.info("Initializing clipseg model for text to mask inference")
|
||||||
config = get_invokeai_config()
|
|
||||||
|
|
||||||
# BUG: we are not doing anything with the device option at this time
|
# BUG: we are not doing anything with the device option at this time
|
||||||
self.device = device
|
self.device = device
|
||||||
|
@ -26,7 +26,7 @@ import torch
|
|||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.app.services.config import get_invokeai_config
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
|
||||||
from .model_manager import ModelManager, SDLegacyType
|
from .model_manager import ModelManager, SDLegacyType
|
||||||
from .model_cache import ModelCache
|
from .model_cache import ModelCache
|
||||||
@ -857,7 +857,7 @@ def convert_ldm_bert_checkpoint(checkpoint, config):
|
|||||||
|
|
||||||
def convert_ldm_clip_checkpoint(checkpoint):
|
def convert_ldm_clip_checkpoint(checkpoint):
|
||||||
text_model = CLIPTextModel.from_pretrained(
|
text_model = CLIPTextModel.from_pretrained(
|
||||||
"openai/clip-vit-large-patch14", cache_dir=get_invokeai_config().cache_dir
|
"openai/clip-vit-large-patch14", cache_dir=InvokeAIAppConfig.get_config().cache_dir
|
||||||
)
|
)
|
||||||
|
|
||||||
keys = list(checkpoint.keys())
|
keys = list(checkpoint.keys())
|
||||||
@ -912,7 +912,7 @@ textenc_pattern = re.compile("|".join(protected.keys()))
|
|||||||
|
|
||||||
|
|
||||||
def convert_paint_by_example_checkpoint(checkpoint):
|
def convert_paint_by_example_checkpoint(checkpoint):
|
||||||
cache_dir = get_invokeai_config().cache_dir
|
cache_dir = InvokeAIAppConfig.get_config().cache_dir
|
||||||
config = CLIPVisionConfig.from_pretrained(
|
config = CLIPVisionConfig.from_pretrained(
|
||||||
"openai/clip-vit-large-patch14", cache_dir=cache_dir
|
"openai/clip-vit-large-patch14", cache_dir=cache_dir
|
||||||
)
|
)
|
||||||
@ -984,7 +984,7 @@ def convert_paint_by_example_checkpoint(checkpoint):
|
|||||||
|
|
||||||
|
|
||||||
def convert_open_clip_checkpoint(checkpoint):
|
def convert_open_clip_checkpoint(checkpoint):
|
||||||
cache_dir = get_invokeai_config().cache_dir
|
cache_dir = InvokeAIAppConfig.get_config().cache_dir
|
||||||
text_model = CLIPTextModel.from_pretrained(
|
text_model = CLIPTextModel.from_pretrained(
|
||||||
"stabilityai/stable-diffusion-2", subfolder="text_encoder", cache_dir=cache_dir
|
"stabilityai/stable-diffusion-2", subfolder="text_encoder", cache_dir=cache_dir
|
||||||
)
|
)
|
||||||
@ -1120,7 +1120,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
|||||||
:param vae: A diffusers VAE to load into the pipeline.
|
:param vae: A diffusers VAE to load into the pipeline.
|
||||||
:param vae_path: Path to a checkpoint VAE that will be converted into diffusers and loaded into the pipeline.
|
:param vae_path: Path to a checkpoint VAE that will be converted into diffusers and loaded into the pipeline.
|
||||||
"""
|
"""
|
||||||
config = get_invokeai_config()
|
config = InvokeAIAppConfig.get_config()
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.simplefilter("ignore")
|
warnings.simplefilter("ignore")
|
||||||
verbosity = dlogging.get_verbosity()
|
verbosity = dlogging.get_verbosity()
|
||||||
|
@ -149,7 +149,7 @@ from omegaconf import OmegaConf
|
|||||||
from omegaconf.dictconfig import DictConfig
|
from omegaconf.dictconfig import DictConfig
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.app.services.config import get_invokeai_config
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.backend.util import download_with_resume
|
from invokeai.backend.util import download_with_resume
|
||||||
|
|
||||||
from ..util import CUDA_DEVICE
|
from ..util import CUDA_DEVICE
|
||||||
@ -224,7 +224,7 @@ class ModelManager(object):
|
|||||||
raise ValueError('config argument must be an OmegaConf object, a Path or a string')
|
raise ValueError('config argument must be an OmegaConf object, a Path or a string')
|
||||||
|
|
||||||
# check config version number and update on disk/RAM if necessary
|
# check config version number and update on disk/RAM if necessary
|
||||||
self.globals = get_invokeai_config()
|
self.globals = InvokeAIAppConfig.get_config()
|
||||||
self._update_config_file_version()
|
self._update_config_file_version()
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
self.cache = ModelCache(
|
self.cache = ModelCache(
|
||||||
@ -1149,13 +1149,17 @@ class ModelManager(object):
|
|||||||
"""\
|
"""\
|
||||||
# This file describes the alternative machine learning models
|
# This file describes the alternative machine learning models
|
||||||
# available to InvokeAI script.
|
# available to InvokeAI script.
|
||||||
"""
|
#
|
||||||
|
# To add a new model, follow the examples below. Each
|
||||||
|
# model requires a model config file, a weights file,
|
||||||
|
# and the width and height of the images it
|
||||||
|
# was trained on.
|
||||||
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _delete_model_from_cache(cls,repo_id):
|
def _delete_model_from_cache(cls,repo_id):
|
||||||
cache_info = scan_cache_dir(get_invokeai_config().cache_dir)
|
cache_info = scan_cache_dir(InvokeAIAppConfig.get_config().cache_dir)
|
||||||
|
|
||||||
# I'm sure there is a way to do this with comprehensions
|
# I'm sure there is a way to do this with comprehensions
|
||||||
# but the code quickly became incomprehensible!
|
# but the code quickly became incomprehensible!
|
||||||
@ -1172,7 +1176,7 @@ class ModelManager(object):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _abs_path(path: str | Path) -> Path:
|
def _abs_path(path: str | Path) -> Path:
|
||||||
globals = get_invokeai_config()
|
globals = InvokeAIAppConfig.get_config()
|
||||||
if path is None or Path(path).is_absolute():
|
if path is None or Path(path).is_absolute():
|
||||||
return path
|
return path
|
||||||
return Path(globals.root_dir, path).resolve()
|
return Path(globals.root_dir, path).resolve()
|
||||||
|
@ -22,10 +22,12 @@ from compel.prompt_parser import (
|
|||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
|
|
||||||
from invokeai.app.services.config import get_invokeai_config
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from ..stable_diffusion import InvokeAIDiffuserComponent
|
from ..stable_diffusion import InvokeAIDiffuserComponent
|
||||||
from ..util import torch_dtype
|
from ..util import torch_dtype
|
||||||
|
|
||||||
|
config = InvokeAIAppConfig.get_config()
|
||||||
|
|
||||||
def get_uc_and_c_and_ec(prompt_string,
|
def get_uc_and_c_and_ec(prompt_string,
|
||||||
model: InvokeAIDiffuserComponent,
|
model: InvokeAIDiffuserComponent,
|
||||||
log_tokens=False, skip_normalize_legacy_blend=False):
|
log_tokens=False, skip_normalize_legacy_blend=False):
|
||||||
@ -38,9 +40,7 @@ def get_uc_and_c_and_ec(prompt_string,
|
|||||||
textual_inversion_manager=model.textual_inversion_manager,
|
textual_inversion_manager=model.textual_inversion_manager,
|
||||||
dtype_for_device_getter=torch_dtype,
|
dtype_for_device_getter=torch_dtype,
|
||||||
truncate_long_prompts=False,
|
truncate_long_prompts=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
config = get_invokeai_config()
|
|
||||||
|
|
||||||
# get rid of any newline characters
|
# get rid of any newline characters
|
||||||
prompt_string = prompt_string.replace("\n", " ")
|
prompt_string = prompt_string.replace("\n", " ")
|
||||||
@ -283,6 +283,8 @@ def split_weighted_subprompts(text, skip_normalize=False) -> list:
|
|||||||
(match.group("prompt").replace("\\:", ":"), float(match.group("weight") or 1))
|
(match.group("prompt").replace("\\:", ":"), float(match.group("weight") or 1))
|
||||||
for match in re.finditer(prompt_parser, text)
|
for match in re.finditer(prompt_parser, text)
|
||||||
]
|
]
|
||||||
|
if len(parsed_prompts) == 0:
|
||||||
|
return []
|
||||||
if skip_normalize:
|
if skip_normalize:
|
||||||
return parsed_prompts
|
return parsed_prompts
|
||||||
weight_sum = sum(map(lambda x: x[1], parsed_prompts))
|
weight_sum = sum(map(lambda x: x[1], parsed_prompts))
|
||||||
|
@ -6,7 +6,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.app.services.config import get_invokeai_config
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
|
||||||
pretrained_model_url = (
|
pretrained_model_url = (
|
||||||
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
|
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
|
||||||
@ -18,7 +18,7 @@ class CodeFormerRestoration:
|
|||||||
self, codeformer_dir="models/codeformer", codeformer_model_path="codeformer.pth"
|
self, codeformer_dir="models/codeformer", codeformer_model_path="codeformer.pth"
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
||||||
self.globals = get_invokeai_config()
|
self.globals = InvokeAIAppConfig.get_config()
|
||||||
codeformer_dir = self.globals.root_dir / codeformer_dir
|
codeformer_dir = self.globals.root_dir / codeformer_dir
|
||||||
self.model_path = codeformer_dir / codeformer_model_path
|
self.model_path = codeformer_dir / codeformer_model_path
|
||||||
self.codeformer_model_exists = self.model_path.exists()
|
self.codeformer_model_exists = self.model_path.exists()
|
||||||
|
@ -7,11 +7,11 @@ import torch
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.app.services.config import get_invokeai_config
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
|
||||||
class GFPGAN:
|
class GFPGAN:
|
||||||
def __init__(self, gfpgan_model_path="models/gfpgan/GFPGANv1.4.pth") -> None:
|
def __init__(self, gfpgan_model_path="models/gfpgan/GFPGANv1.4.pth") -> None:
|
||||||
self.globals = get_invokeai_config()
|
self.globals = InvokeAIAppConfig.get_config()
|
||||||
if not os.path.isabs(gfpgan_model_path):
|
if not os.path.isabs(gfpgan_model_path):
|
||||||
gfpgan_model_path = self.globals.root_dir / gfpgan_model_path
|
gfpgan_model_path = self.globals.root_dir / gfpgan_model_path
|
||||||
self.model_path = gfpgan_model_path
|
self.model_path = gfpgan_model_path
|
||||||
|
@ -6,8 +6,8 @@ from PIL import Image
|
|||||||
from PIL.Image import Image as ImageType
|
from PIL.Image import Image as ImageType
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.app.services.config import get_invokeai_config
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
config = get_invokeai_config()
|
config = InvokeAIAppConfig.get_config()
|
||||||
|
|
||||||
class ESRGAN:
|
class ESRGAN:
|
||||||
def __init__(self, bg_tile_size=400) -> None:
|
def __init__(self, bg_tile_size=400) -> None:
|
||||||
|
@ -15,9 +15,11 @@ from transformers import AutoFeatureExtractor
|
|||||||
|
|
||||||
import invokeai.assets.web as web_assets
|
import invokeai.assets.web as web_assets
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.app.services.config import get_invokeai_config
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from .util import CPU_DEVICE
|
from .util import CPU_DEVICE
|
||||||
|
|
||||||
|
config = InvokeAIAppConfig.get_config()
|
||||||
|
|
||||||
class SafetyChecker(object):
|
class SafetyChecker(object):
|
||||||
CAUTION_IMG = "caution.png"
|
CAUTION_IMG = "caution.png"
|
||||||
|
|
||||||
@ -26,7 +28,6 @@ class SafetyChecker(object):
|
|||||||
caution = Image.open(path)
|
caution = Image.open(path)
|
||||||
self.caution_img = caution.resize((caution.width // 2, caution.height // 2))
|
self.caution_img = caution.resize((caution.width // 2, caution.height // 2))
|
||||||
self.device = device
|
self.device = device
|
||||||
config = get_invokeai_config()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
||||||
|
@ -17,15 +17,16 @@ from huggingface_hub import (
|
|||||||
hf_hub_url,
|
hf_hub_url,
|
||||||
)
|
)
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
from invokeai.app.services.config import get_invokeai_config
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
logger = InvokeAILogger.getLogger()
|
||||||
|
|
||||||
class HuggingFaceConceptsLibrary(object):
|
class HuggingFaceConceptsLibrary(object):
|
||||||
def __init__(self, root=None):
|
def __init__(self, root=None):
|
||||||
"""
|
"""
|
||||||
Initialize the Concepts object. May optionally pass a root directory.
|
Initialize the Concepts object. May optionally pass a root directory.
|
||||||
"""
|
"""
|
||||||
self.config = get_invokeai_config()
|
self.config = InvokeAIAppConfig.get_config()
|
||||||
self.root = root or self.config.root
|
self.root = root or self.config.root
|
||||||
self.hf_api = HfApi()
|
self.hf_api = HfApi()
|
||||||
self.local_concepts = dict()
|
self.local_concepts = dict()
|
||||||
|
@ -40,7 +40,7 @@ from torchvision.transforms.functional import resize as tv_resize
|
|||||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||||
from typing_extensions import ParamSpec
|
from typing_extensions import ParamSpec
|
||||||
|
|
||||||
from invokeai.app.services.config import get_invokeai_config
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from ..util import CPU_DEVICE, normalize_device
|
from ..util import CPU_DEVICE, normalize_device
|
||||||
from .diffusion import (
|
from .diffusion import (
|
||||||
AttentionMapSaver,
|
AttentionMapSaver,
|
||||||
@ -364,7 +364,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
"""
|
"""
|
||||||
if xformers is available, use it, otherwise use sliced attention.
|
if xformers is available, use it, otherwise use sliced attention.
|
||||||
"""
|
"""
|
||||||
config = get_invokeai_config()
|
config = InvokeAIAppConfig.get_config()
|
||||||
if (
|
if (
|
||||||
torch.cuda.is_available()
|
torch.cuda.is_available()
|
||||||
and is_xformers_available()
|
and is_xformers_available()
|
||||||
|
@ -10,7 +10,7 @@ from diffusers.models.attention_processor import AttentionProcessor
|
|||||||
from typing_extensions import TypeAlias
|
from typing_extensions import TypeAlias
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.app.services.config import get_invokeai_config
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
|
||||||
from .cross_attention_control import (
|
from .cross_attention_control import (
|
||||||
Arguments,
|
Arguments,
|
||||||
@ -72,7 +72,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
:param model: the unet model to pass through to cross attention control
|
:param model: the unet model to pass through to cross attention control
|
||||||
:param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning)
|
:param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning)
|
||||||
"""
|
"""
|
||||||
config = get_invokeai_config()
|
config = InvokeAIAppConfig.get_config()
|
||||||
self.conditioning = None
|
self.conditioning = None
|
||||||
self.model = model
|
self.model = model
|
||||||
self.is_running_diffusers = is_running_diffusers
|
self.is_running_diffusers = is_running_diffusers
|
||||||
@ -112,23 +112,25 @@ class InvokeAIDiffuserComponent:
|
|||||||
# TODO resuscitate attention map saving
|
# TODO resuscitate attention map saving
|
||||||
# self.remove_attention_map_saving()
|
# self.remove_attention_map_saving()
|
||||||
|
|
||||||
def override_cross_attention(
|
# apparently unused code
|
||||||
self, conditioning: ExtraConditioningInfo, step_count: int
|
# TODO: delete
|
||||||
) -> Dict[str, AttentionProcessor]:
|
# def override_cross_attention(
|
||||||
"""
|
# self, conditioning: ExtraConditioningInfo, step_count: int
|
||||||
setup cross attention .swap control. for diffusers this replaces the attention processor, so
|
# ) -> Dict[str, AttentionProcessor]:
|
||||||
the previous attention processor is returned so that the caller can restore it later.
|
# """
|
||||||
"""
|
# setup cross attention .swap control. for diffusers this replaces the attention processor, so
|
||||||
self.conditioning = conditioning
|
# the previous attention processor is returned so that the caller can restore it later.
|
||||||
self.cross_attention_control_context = Context(
|
# """
|
||||||
arguments=self.conditioning.cross_attention_control_args,
|
# self.conditioning = conditioning
|
||||||
step_count=step_count,
|
# self.cross_attention_control_context = Context(
|
||||||
)
|
# arguments=self.conditioning.cross_attention_control_args,
|
||||||
return override_cross_attention(
|
# step_count=step_count,
|
||||||
self.model,
|
# )
|
||||||
self.cross_attention_control_context,
|
# return override_cross_attention(
|
||||||
is_running_diffusers=self.is_running_diffusers,
|
# self.model,
|
||||||
)
|
# self.cross_attention_control_context,
|
||||||
|
# is_running_diffusers=self.is_running_diffusers,
|
||||||
|
# )
|
||||||
|
|
||||||
def restore_default_cross_attention(
|
def restore_default_cross_attention(
|
||||||
self, restore_attention_processor: Optional["AttentionProcessor"] = None
|
self, restore_attention_processor: Optional["AttentionProcessor"] = None
|
||||||
|
@ -88,7 +88,7 @@ def save_progress(
|
|||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
config = InvokeAIAppConfig(argv=[])
|
config = InvokeAIAppConfig.get_config()
|
||||||
parser = PagingArgumentParser(
|
parser = PagingArgumentParser(
|
||||||
description="Textual inversion training"
|
description="Textual inversion training"
|
||||||
)
|
)
|
||||||
|
@ -17,3 +17,5 @@ from .util import (
|
|||||||
instantiate_from_config,
|
instantiate_from_config,
|
||||||
url_attachment_name,
|
url_attachment_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -4,15 +4,15 @@ from contextlib import nullcontext
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import autocast
|
from torch import autocast
|
||||||
from invokeai.app.services.config import get_invokeai_config
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
|
||||||
CPU_DEVICE = torch.device("cpu")
|
CPU_DEVICE = torch.device("cpu")
|
||||||
CUDA_DEVICE = torch.device("cuda")
|
CUDA_DEVICE = torch.device("cuda")
|
||||||
MPS_DEVICE = torch.device("mps")
|
MPS_DEVICE = torch.device("mps")
|
||||||
|
config = InvokeAIAppConfig.get_config()
|
||||||
|
|
||||||
def choose_torch_device() -> torch.device:
|
def choose_torch_device() -> torch.device:
|
||||||
"""Convenience routine for guessing which GPU device to run model on"""
|
"""Convenience routine for guessing which GPU device to run model on"""
|
||||||
config = get_invokeai_config()
|
|
||||||
if config.always_use_cpu:
|
if config.always_use_cpu:
|
||||||
return CPU_DEVICE
|
return CPU_DEVICE
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
@ -32,7 +32,6 @@ def choose_precision(device: torch.device) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def torch_dtype(device: torch.device) -> torch.dtype:
|
def torch_dtype(device: torch.device) -> torch.dtype:
|
||||||
config = get_invokeai_config()
|
|
||||||
if config.full_precision:
|
if config.full_precision:
|
||||||
return torch.float32
|
return torch.float32
|
||||||
if choose_precision(device) == "float16":
|
if choose_precision(device) == "float16":
|
||||||
|
@ -31,7 +31,20 @@ IAILogger.debug('this is a debugging message')
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import logging.handlers
|
||||||
|
import socket
|
||||||
|
import urllib.parse
|
||||||
|
|
||||||
|
from abc import abstractmethod
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from invokeai.app.services.config import InvokeAIAppConfig, get_invokeai_config
|
||||||
|
|
||||||
|
try:
|
||||||
|
import syslog
|
||||||
|
SYSLOG_AVAILABLE = True
|
||||||
|
except:
|
||||||
|
SYSLOG_AVAILABLE = False
|
||||||
|
|
||||||
# module level functions
|
# module level functions
|
||||||
def debug(msg, *args, **kwargs):
|
def debug(msg, *args, **kwargs):
|
||||||
@ -62,11 +75,77 @@ def getLogger(name: str = None) -> logging.Logger:
|
|||||||
return InvokeAILogger.getLogger(name)
|
return InvokeAILogger.getLogger(name)
|
||||||
|
|
||||||
|
|
||||||
class InvokeAILogFormatter(logging.Formatter):
|
_FACILITY_MAP = dict(
|
||||||
|
LOG_KERN = syslog.LOG_KERN,
|
||||||
|
LOG_USER = syslog.LOG_USER,
|
||||||
|
LOG_MAIL = syslog.LOG_MAIL,
|
||||||
|
LOG_DAEMON = syslog.LOG_DAEMON,
|
||||||
|
LOG_AUTH = syslog.LOG_AUTH,
|
||||||
|
LOG_LPR = syslog.LOG_LPR,
|
||||||
|
LOG_NEWS = syslog.LOG_NEWS,
|
||||||
|
LOG_UUCP = syslog.LOG_UUCP,
|
||||||
|
LOG_CRON = syslog.LOG_CRON,
|
||||||
|
LOG_SYSLOG = syslog.LOG_SYSLOG,
|
||||||
|
LOG_LOCAL0 = syslog.LOG_LOCAL0,
|
||||||
|
LOG_LOCAL1 = syslog.LOG_LOCAL1,
|
||||||
|
LOG_LOCAL2 = syslog.LOG_LOCAL2,
|
||||||
|
LOG_LOCAL3 = syslog.LOG_LOCAL3,
|
||||||
|
LOG_LOCAL4 = syslog.LOG_LOCAL4,
|
||||||
|
LOG_LOCAL5 = syslog.LOG_LOCAL5,
|
||||||
|
LOG_LOCAL6 = syslog.LOG_LOCAL6,
|
||||||
|
LOG_LOCAL7 = syslog.LOG_LOCAL7,
|
||||||
|
) if SYSLOG_AVAILABLE else dict()
|
||||||
|
|
||||||
|
_SOCK_MAP = dict(
|
||||||
|
SOCK_STREAM = socket.SOCK_STREAM,
|
||||||
|
SOCK_DGRAM = socket.SOCK_DGRAM,
|
||||||
|
)
|
||||||
|
|
||||||
|
class InvokeAIFormatter(logging.Formatter):
|
||||||
|
'''
|
||||||
|
Base class for logging formatter
|
||||||
|
|
||||||
|
'''
|
||||||
|
def format(self, record):
|
||||||
|
formatter = logging.Formatter(self.log_fmt(record.levelno))
|
||||||
|
return formatter.format(record)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def log_fmt(self, levelno: int)->str:
|
||||||
|
pass
|
||||||
|
|
||||||
|
class InvokeAISyslogFormatter(InvokeAIFormatter):
|
||||||
|
'''
|
||||||
|
Formatting for syslog
|
||||||
|
'''
|
||||||
|
def log_fmt(self, levelno: int)->str:
|
||||||
|
return '%(name)s [%(process)d] <%(levelname)s> %(message)s'
|
||||||
|
|
||||||
|
class InvokeAILegacyLogFormatter(InvokeAIFormatter):
|
||||||
|
'''
|
||||||
|
Formatting for the InvokeAI Logger (legacy version)
|
||||||
|
'''
|
||||||
|
FORMATS = {
|
||||||
|
logging.DEBUG: " | %(message)s",
|
||||||
|
logging.INFO: ">> %(message)s",
|
||||||
|
logging.WARNING: "** %(message)s",
|
||||||
|
logging.ERROR: "*** %(message)s",
|
||||||
|
logging.CRITICAL: "### %(message)s",
|
||||||
|
}
|
||||||
|
def log_fmt(self,levelno:int)->str:
|
||||||
|
return self.FORMATS.get(levelno)
|
||||||
|
|
||||||
|
class InvokeAIPlainLogFormatter(InvokeAIFormatter):
|
||||||
|
'''
|
||||||
|
Custom Formatting for the InvokeAI Logger (plain version)
|
||||||
|
'''
|
||||||
|
def log_fmt(self, levelno: int)->str:
|
||||||
|
return "[%(asctime)s]::[%(name)s]::%(levelname)s --> %(message)s"
|
||||||
|
|
||||||
|
class InvokeAIColorLogFormatter(InvokeAIFormatter):
|
||||||
'''
|
'''
|
||||||
Custom Formatting for the InvokeAI Logger
|
Custom Formatting for the InvokeAI Logger
|
||||||
'''
|
'''
|
||||||
|
|
||||||
# Color Codes
|
# Color Codes
|
||||||
grey = "\x1b[38;20m"
|
grey = "\x1b[38;20m"
|
||||||
yellow = "\x1b[33;20m"
|
yellow = "\x1b[33;20m"
|
||||||
@ -88,23 +167,109 @@ class InvokeAILogFormatter(logging.Formatter):
|
|||||||
logging.CRITICAL: bold_red + log_format + reset
|
logging.CRITICAL: bold_red + log_format + reset
|
||||||
}
|
}
|
||||||
|
|
||||||
def format(self, record):
|
def log_fmt(self, levelno: int)->str:
|
||||||
log_fmt = self.FORMATS.get(record.levelno)
|
return self.FORMATS.get(levelno)
|
||||||
formatter = logging.Formatter(log_fmt, datefmt="%d-%m-%Y %H:%M:%S")
|
|
||||||
return formatter.format(record)
|
|
||||||
|
|
||||||
|
LOG_FORMATTERS = {
|
||||||
|
'plain': InvokeAIPlainLogFormatter,
|
||||||
|
'color': InvokeAIColorLogFormatter,
|
||||||
|
'syslog': InvokeAISyslogFormatter,
|
||||||
|
'legacy': InvokeAILegacyLogFormatter,
|
||||||
|
}
|
||||||
|
|
||||||
class InvokeAILogger(object):
|
class InvokeAILogger(object):
|
||||||
loggers = dict()
|
loggers = dict()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def getLogger(cls, name: str = 'InvokeAI') -> logging.Logger:
|
def getLogger(cls, name: str = 'InvokeAI') -> logging.Logger:
|
||||||
|
config = get_invokeai_config()
|
||||||
|
|
||||||
if name not in cls.loggers:
|
if name not in cls.loggers:
|
||||||
logger = logging.getLogger(name)
|
logger = logging.getLogger(name)
|
||||||
logger.setLevel(logging.DEBUG)
|
logger.setLevel(config.log_level.upper()) # yes, strings work here
|
||||||
ch = logging.StreamHandler()
|
for ch in cls.getLoggers(config):
|
||||||
fmt = InvokeAILogFormatter()
|
logger.addHandler(ch)
|
||||||
ch.setFormatter(fmt)
|
|
||||||
logger.addHandler(ch)
|
|
||||||
cls.loggers[name] = logger
|
cls.loggers[name] = logger
|
||||||
return cls.loggers[name]
|
return cls.loggers[name]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def getLoggers(cls, config: InvokeAIAppConfig) -> list[logging.Handler]:
|
||||||
|
handler_strs = config.log_handlers
|
||||||
|
print(f'handler_strs={handler_strs}')
|
||||||
|
handlers = list()
|
||||||
|
for handler in handler_strs:
|
||||||
|
handler_name,*args = handler.split('=',2)
|
||||||
|
args = args[0] if len(args) > 0 else None
|
||||||
|
|
||||||
|
# console is the only handler that gets a custom formatter
|
||||||
|
if handler_name=='console':
|
||||||
|
formatter = LOG_FORMATTERS[config.log_format]
|
||||||
|
ch = logging.StreamHandler()
|
||||||
|
ch.setFormatter(formatter())
|
||||||
|
handlers.append(ch)
|
||||||
|
|
||||||
|
elif handler_name=='syslog':
|
||||||
|
ch = cls._parse_syslog_args(args)
|
||||||
|
ch.setFormatter(InvokeAISyslogFormatter())
|
||||||
|
handlers.append(ch)
|
||||||
|
|
||||||
|
elif handler_name=='file':
|
||||||
|
handlers.append(cls._parse_file_args(args))
|
||||||
|
|
||||||
|
elif handler_name=='http':
|
||||||
|
handlers.append(cls._parse_http_args(args))
|
||||||
|
return handlers
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _parse_syslog_args(
|
||||||
|
args: str=None
|
||||||
|
)-> logging.Handler:
|
||||||
|
if not SYSLOG_AVAILABLE:
|
||||||
|
raise ValueError("syslog is not available on this system")
|
||||||
|
if not args:
|
||||||
|
args='/dev/log' if Path('/dev/log').exists() else 'address:localhost:514'
|
||||||
|
syslog_args = dict()
|
||||||
|
try:
|
||||||
|
for a in args.split(','):
|
||||||
|
arg_name,*arg_value = a.split(':',2)
|
||||||
|
if arg_name=='address':
|
||||||
|
host,*port = arg_value
|
||||||
|
port = 514 if len(port)==0 else int(port[0])
|
||||||
|
syslog_args['address'] = (host,port)
|
||||||
|
elif arg_name=='facility':
|
||||||
|
syslog_args['facility'] = _FACILITY_MAP[arg_value[0]]
|
||||||
|
elif arg_name=='socktype':
|
||||||
|
syslog_args['socktype'] = _SOCK_MAP[arg_value[0]]
|
||||||
|
else:
|
||||||
|
syslog_args['address'] = arg_name
|
||||||
|
except:
|
||||||
|
raise ValueError(f"{args} is not a value argument list for syslog logging")
|
||||||
|
return logging.handlers.SysLogHandler(**syslog_args)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _parse_file_args(args: str=None)-> logging.Handler:
|
||||||
|
if not args:
|
||||||
|
raise ValueError("please provide filename for file logging using format 'file=/path/to/logfile.txt'")
|
||||||
|
return logging.FileHandler(args)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _parse_http_args(args: str=None)-> logging.Handler:
|
||||||
|
if not args:
|
||||||
|
raise ValueError("please provide destination for http logging using format 'http=url'")
|
||||||
|
arg_list = args.split(',')
|
||||||
|
url = urllib.parse.urlparse(arg_list.pop(0))
|
||||||
|
if url.scheme != 'http':
|
||||||
|
raise ValueError(f"the http logging module can only log to HTTP URLs, but {url.scheme} was specified")
|
||||||
|
host = url.hostname
|
||||||
|
path = url.path
|
||||||
|
port = url.port or 80
|
||||||
|
|
||||||
|
syslog_args = dict()
|
||||||
|
for a in arg_list:
|
||||||
|
arg_name, *arg_value = a.split(':',2)
|
||||||
|
if arg_name=='method':
|
||||||
|
arg_value = arg_value[0] if len(arg_value)>0 else 'GET'
|
||||||
|
syslog_args[arg_name] = arg_value
|
||||||
|
else: # TODO: Provide support for SSL context and credentials
|
||||||
|
pass
|
||||||
|
return logging.handlers.HTTPHandler(f'{host}:{port}',path,**syslog_args)
|
||||||
|
@ -40,13 +40,13 @@ from .widgets import (
|
|||||||
TextBox,
|
TextBox,
|
||||||
set_min_terminal_size,
|
set_min_terminal_size,
|
||||||
)
|
)
|
||||||
from invokeai.app.services.config import get_invokeai_config
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
|
||||||
# minimum size for the UI
|
# minimum size for the UI
|
||||||
MIN_COLS = 120
|
MIN_COLS = 120
|
||||||
MIN_LINES = 45
|
MIN_LINES = 45
|
||||||
|
|
||||||
config = get_invokeai_config()
|
config = InvokeAIAppConfig.get_config()
|
||||||
|
|
||||||
class addModelsForm(npyscreen.FormMultiPage):
|
class addModelsForm(npyscreen.FormMultiPage):
|
||||||
# for responsive resizing - disabled
|
# for responsive resizing - disabled
|
||||||
|
@ -20,12 +20,12 @@ from npyscreen import widget
|
|||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.services.config import get_invokeai_config
|
from invokeai.services.config import InvokeAIAppConfig
|
||||||
from ...backend.model_management import ModelManager
|
from ...backend.model_management import ModelManager
|
||||||
from ...frontend.install.widgets import FloatTitleSlider
|
from ...frontend.install.widgets import FloatTitleSlider
|
||||||
|
|
||||||
DEST_MERGED_MODEL_DIR = "merged_models"
|
DEST_MERGED_MODEL_DIR = "merged_models"
|
||||||
config = get_invokeai_config()
|
config = InvokeAIAppConfig.get_config()
|
||||||
|
|
||||||
def merge_diffusion_models(
|
def merge_diffusion_models(
|
||||||
model_ids_or_paths: List[Union[str, Path]],
|
model_ids_or_paths: List[Union[str, Path]],
|
||||||
|
@ -22,7 +22,7 @@ from omegaconf import OmegaConf
|
|||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
|
|
||||||
from invokeai.app.services.config import get_invokeai_config
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from ...backend.training import (
|
from ...backend.training import (
|
||||||
do_textual_inversion_training,
|
do_textual_inversion_training,
|
||||||
parse_args
|
parse_args
|
||||||
@ -423,7 +423,7 @@ def do_front_end(args: Namespace):
|
|||||||
save_args(args)
|
save_args(args)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
do_textual_inversion_training(get_invokeai_config(),**args)
|
do_textual_inversion_training(InvokeAIAppConfig.get_config(),**args)
|
||||||
copy_to_embeddings_folder(args)
|
copy_to_embeddings_folder(args)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("An exception occurred during training. The exception was:")
|
logger.error("An exception occurred during training. The exception was:")
|
||||||
@ -436,7 +436,7 @@ def main():
|
|||||||
global config
|
global config
|
||||||
|
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
config = get_invokeai_config(argv=[])
|
config = InvokeAIAppConfig.get_config()
|
||||||
|
|
||||||
# change root if needed
|
# change root if needed
|
||||||
if args.root_dir:
|
if args.root_dir:
|
||||||
|
@ -26,10 +26,10 @@ We need to start the nodes web server, which serves the OpenAPI schema to the ge
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
# from the repo root
|
# from the repo root
|
||||||
python scripts/invoke-new.py --web
|
python scripts/invokeai-web.py
|
||||||
```
|
```
|
||||||
|
|
||||||
2. Generate the API client.
|
2. Generate the API client.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# from invokeai/frontend/web/
|
# from invokeai/frontend/web/
|
||||||
|
@ -12,7 +12,14 @@ Code in `invokeai/frontend/web/` if you want to have a look.
|
|||||||
|
|
||||||
## Stack
|
## Stack
|
||||||
|
|
||||||
State management is Redux via [Redux Toolkit](https://github.com/reduxjs/redux-toolkit). Communication with server is a mix of HTTP and [socket.io](https://github.com/socketio/socket.io-client) (with a custom redux middleware to help).
|
State management is Redux via [Redux Toolkit](https://github.com/reduxjs/redux-toolkit). We lean heavily on RTK:
|
||||||
|
- `createAsyncThunk` for HTTP requests
|
||||||
|
- `createEntityAdapter` for fetching images and models
|
||||||
|
- `createListenerMiddleware` for workflows
|
||||||
|
|
||||||
|
The API client and associated types are generated from the OpenAPI schema. See API_CLIENT.md.
|
||||||
|
|
||||||
|
Communication with server is a mix of HTTP and [socket.io](https://github.com/socketio/socket.io-client) (with a simple socket.io redux middleware to help).
|
||||||
|
|
||||||
[Chakra-UI](https://github.com/chakra-ui/chakra-ui) for components and styling.
|
[Chakra-UI](https://github.com/chakra-ui/chakra-ui) for components and styling.
|
||||||
|
|
||||||
@ -37,9 +44,15 @@ From `invokeai/frontend/web/` run `yarn install` to get everything set up.
|
|||||||
Start everything in dev mode:
|
Start everything in dev mode:
|
||||||
|
|
||||||
1. Start the dev server: `yarn dev`
|
1. Start the dev server: `yarn dev`
|
||||||
2. Start the InvokeAI Nodes backend: `python scripts/invokeai-new.py --web # run from the repo root`
|
2. Start the InvokeAI Nodes backend: `python scripts/invokeai-web.py # run from the repo root`
|
||||||
3. Point your browser to the dev server address e.g. <http://localhost:5173/>
|
3. Point your browser to the dev server address e.g. <http://localhost:5173/>
|
||||||
|
|
||||||
|
#### VSCode Remote Dev
|
||||||
|
|
||||||
|
We've noticed an intermittent issue with the VSCode Remote Dev port forwarding. If you use this feature of VSCode, you may intermittently click the Invoke button and then get nothing until the request times out. Suggest disabling the IDE's port forwarding feature and doing it manually via SSH:
|
||||||
|
|
||||||
|
`ssh -L 9090:localhost:9090 -L 5173:localhost:5173 user@host`
|
||||||
|
|
||||||
### Production builds
|
### Production builds
|
||||||
|
|
||||||
For a number of technical and logistical reasons, we need to commit UI build artefacts to the repo.
|
For a number of technical and logistical reasons, we need to commit UI build artefacts to the repo.
|
||||||
|
@ -60,6 +60,8 @@
|
|||||||
"@chakra-ui/styled-system": "^2.9.0",
|
"@chakra-ui/styled-system": "^2.9.0",
|
||||||
"@chakra-ui/theme-tools": "^2.0.16",
|
"@chakra-ui/theme-tools": "^2.0.16",
|
||||||
"@dagrejs/graphlib": "^2.1.12",
|
"@dagrejs/graphlib": "^2.1.12",
|
||||||
|
"@dnd-kit/core": "^6.0.8",
|
||||||
|
"@dnd-kit/modifiers": "^6.0.1",
|
||||||
"@emotion/react": "^11.10.6",
|
"@emotion/react": "^11.10.6",
|
||||||
"@emotion/styled": "^11.10.6",
|
"@emotion/styled": "^11.10.6",
|
||||||
"@floating-ui/react-dom": "^2.0.0",
|
"@floating-ui/react-dom": "^2.0.0",
|
||||||
@ -87,7 +89,7 @@
|
|||||||
"react-dropzone": "^14.2.3",
|
"react-dropzone": "^14.2.3",
|
||||||
"react-hotkeys-hook": "4.4.0",
|
"react-hotkeys-hook": "4.4.0",
|
||||||
"react-i18next": "^12.2.2",
|
"react-i18next": "^12.2.2",
|
||||||
"react-icons": "^4.7.1",
|
"react-icons": "^4.9.0",
|
||||||
"react-konva": "^18.2.7",
|
"react-konva": "^18.2.7",
|
||||||
"react-redux": "^8.0.5",
|
"react-redux": "^8.0.5",
|
||||||
"react-resizable-panels": "^0.0.42",
|
"react-resizable-panels": "^0.0.42",
|
||||||
|
@ -21,6 +21,7 @@ import { ReactNode, memo, useCallback, useEffect, useState } from 'react';
|
|||||||
import { APP_HEIGHT, APP_WIDTH } from 'theme/util/constants';
|
import { APP_HEIGHT, APP_WIDTH } from 'theme/util/constants';
|
||||||
import GlobalHotkeys from './GlobalHotkeys';
|
import GlobalHotkeys from './GlobalHotkeys';
|
||||||
import Toaster from './Toaster';
|
import Toaster from './Toaster';
|
||||||
|
import DeleteImageModal from 'features/gallery/components/DeleteImageModal';
|
||||||
|
|
||||||
const DEFAULT_CONFIG = {};
|
const DEFAULT_CONFIG = {};
|
||||||
|
|
||||||
@ -76,18 +77,21 @@ const App = ({
|
|||||||
{isLightboxEnabled && <Lightbox />}
|
{isLightboxEnabled && <Lightbox />}
|
||||||
<ImageUploader>
|
<ImageUploader>
|
||||||
<Grid
|
<Grid
|
||||||
gap={4}
|
sx={{
|
||||||
p={4}
|
gap: 4,
|
||||||
gridAutoRows="min-content auto"
|
p: 4,
|
||||||
w={APP_WIDTH}
|
gridAutoRows: 'min-content auto',
|
||||||
h={APP_HEIGHT}
|
w: 'full',
|
||||||
|
h: 'full',
|
||||||
|
}}
|
||||||
>
|
>
|
||||||
{headerComponent || <SiteHeader />}
|
{headerComponent || <SiteHeader />}
|
||||||
<Flex
|
<Flex
|
||||||
gap={4}
|
sx={{
|
||||||
w={{ base: '100vw', xl: 'full' }}
|
gap: 4,
|
||||||
h="full"
|
w: 'full',
|
||||||
flexDir={{ base: 'column', xl: 'row' }}
|
h: 'full',
|
||||||
|
}}
|
||||||
>
|
>
|
||||||
<InvokeTabs />
|
<InvokeTabs />
|
||||||
</Flex>
|
</Flex>
|
||||||
@ -130,6 +134,7 @@ const App = ({
|
|||||||
<FloatingGalleryButton />
|
<FloatingGalleryButton />
|
||||||
</Portal>
|
</Portal>
|
||||||
</Grid>
|
</Grid>
|
||||||
|
<DeleteImageModal />
|
||||||
<Toaster />
|
<Toaster />
|
||||||
<GlobalHotkeys />
|
<GlobalHotkeys />
|
||||||
</>
|
</>
|
||||||
|
@ -0,0 +1,71 @@
|
|||||||
|
import {
|
||||||
|
DndContext,
|
||||||
|
DragEndEvent,
|
||||||
|
DragOverlay,
|
||||||
|
DragStartEvent,
|
||||||
|
KeyboardSensor,
|
||||||
|
MouseSensor,
|
||||||
|
TouchSensor,
|
||||||
|
pointerWithin,
|
||||||
|
useSensor,
|
||||||
|
useSensors,
|
||||||
|
} from '@dnd-kit/core';
|
||||||
|
import { PropsWithChildren, memo, useCallback, useState } from 'react';
|
||||||
|
import OverlayDragImage from './OverlayDragImage';
|
||||||
|
import { ImageDTO } from 'services/api';
|
||||||
|
import { isImageDTO } from 'services/types/guards';
|
||||||
|
import { snapCenterToCursor } from '@dnd-kit/modifiers';
|
||||||
|
|
||||||
|
type ImageDndContextProps = PropsWithChildren;
|
||||||
|
|
||||||
|
const ImageDndContext = (props: ImageDndContextProps) => {
|
||||||
|
const [draggedImage, setDraggedImage] = useState<ImageDTO | null>(null);
|
||||||
|
|
||||||
|
const handleDragStart = useCallback((event: DragStartEvent) => {
|
||||||
|
const dragData = event.active.data.current;
|
||||||
|
if (dragData && 'image' in dragData && isImageDTO(dragData.image)) {
|
||||||
|
setDraggedImage(dragData.image);
|
||||||
|
}
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const handleDragEnd = useCallback(
|
||||||
|
(event: DragEndEvent) => {
|
||||||
|
const handleDrop = event.over?.data.current?.handleDrop;
|
||||||
|
if (handleDrop && typeof handleDrop === 'function' && draggedImage) {
|
||||||
|
handleDrop(draggedImage);
|
||||||
|
}
|
||||||
|
setDraggedImage(null);
|
||||||
|
},
|
||||||
|
[draggedImage]
|
||||||
|
);
|
||||||
|
|
||||||
|
const mouseSensor = useSensor(MouseSensor, {
|
||||||
|
activationConstraint: { distance: 15 },
|
||||||
|
});
|
||||||
|
|
||||||
|
const touchSensor = useSensor(TouchSensor, {
|
||||||
|
activationConstraint: { distance: 15 },
|
||||||
|
});
|
||||||
|
// TODO: Use KeyboardSensor - needs composition of multiple collisionDetection algos
|
||||||
|
// Alternatively, fix `rectIntersection` collection detection to work with the drag overlay
|
||||||
|
// (currently the drag element collision rect is not correctly calculated)
|
||||||
|
// const keyboardSensor = useSensor(KeyboardSensor);
|
||||||
|
|
||||||
|
const sensors = useSensors(mouseSensor, touchSensor);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<DndContext
|
||||||
|
onDragStart={handleDragStart}
|
||||||
|
onDragEnd={handleDragEnd}
|
||||||
|
sensors={sensors}
|
||||||
|
collisionDetection={pointerWithin}
|
||||||
|
>
|
||||||
|
{props.children}
|
||||||
|
<DragOverlay dropAnimation={null} modifiers={[snapCenterToCursor]}>
|
||||||
|
{draggedImage && <OverlayDragImage image={draggedImage} />}
|
||||||
|
</DragOverlay>
|
||||||
|
</DndContext>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(ImageDndContext);
|
@ -0,0 +1,36 @@
|
|||||||
|
import { Box, Image } from '@chakra-ui/react';
|
||||||
|
import { memo } from 'react';
|
||||||
|
import { ImageDTO } from 'services/api';
|
||||||
|
|
||||||
|
type OverlayDragImageProps = {
|
||||||
|
image: ImageDTO;
|
||||||
|
};
|
||||||
|
|
||||||
|
const OverlayDragImage = (props: OverlayDragImageProps) => {
|
||||||
|
return (
|
||||||
|
<Box
|
||||||
|
style={{
|
||||||
|
width: '100%',
|
||||||
|
height: '100%',
|
||||||
|
display: 'flex',
|
||||||
|
alignItems: 'center',
|
||||||
|
justifyContent: 'center',
|
||||||
|
userSelect: 'none',
|
||||||
|
cursor: 'grabbing',
|
||||||
|
opacity: 0.5,
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<Image
|
||||||
|
sx={{
|
||||||
|
maxW: 36,
|
||||||
|
maxH: 36,
|
||||||
|
borderRadius: 'base',
|
||||||
|
shadow: 'dark-lg',
|
||||||
|
}}
|
||||||
|
src={props.image.thumbnail_url}
|
||||||
|
/>
|
||||||
|
</Box>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(OverlayDragImage);
|
@ -16,6 +16,11 @@ import { PartialAppConfig } from 'app/types/invokeai';
|
|||||||
import '../../i18n';
|
import '../../i18n';
|
||||||
import { socketMiddleware } from 'services/events/middleware';
|
import { socketMiddleware } from 'services/events/middleware';
|
||||||
import { Middleware } from '@reduxjs/toolkit';
|
import { Middleware } from '@reduxjs/toolkit';
|
||||||
|
import ImageDndContext from './ImageDnd/ImageDndContext';
|
||||||
|
import {
|
||||||
|
DeleteImageContext,
|
||||||
|
DeleteImageContextProvider,
|
||||||
|
} from 'app/contexts/DeleteImageContext';
|
||||||
|
|
||||||
const App = lazy(() => import('./App'));
|
const App = lazy(() => import('./App'));
|
||||||
const ThemeLocaleProvider = lazy(() => import('./ThemeLocaleProvider'));
|
const ThemeLocaleProvider = lazy(() => import('./ThemeLocaleProvider'));
|
||||||
@ -69,11 +74,15 @@ const InvokeAIUI = ({
|
|||||||
<Provider store={store}>
|
<Provider store={store}>
|
||||||
<React.Suspense fallback={<Loading />}>
|
<React.Suspense fallback={<Loading />}>
|
||||||
<ThemeLocaleProvider>
|
<ThemeLocaleProvider>
|
||||||
<App
|
<ImageDndContext>
|
||||||
config={config}
|
<DeleteImageContextProvider>
|
||||||
headerComponent={headerComponent}
|
<App
|
||||||
setIsReady={setIsReady}
|
config={config}
|
||||||
/>
|
headerComponent={headerComponent}
|
||||||
|
setIsReady={setIsReady}
|
||||||
|
/>
|
||||||
|
</DeleteImageContextProvider>
|
||||||
|
</ImageDndContext>
|
||||||
</ThemeLocaleProvider>
|
</ThemeLocaleProvider>
|
||||||
</React.Suspense>
|
</React.Suspense>
|
||||||
</Provider>
|
</Provider>
|
||||||
|
203
invokeai/frontend/web/src/app/contexts/DeleteImageContext.tsx
Normal file
203
invokeai/frontend/web/src/app/contexts/DeleteImageContext.tsx
Normal file
@ -0,0 +1,203 @@
|
|||||||
|
import { useDisclosure } from '@chakra-ui/react';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
|
import { requestedImageDeletion } from 'features/gallery/store/actions';
|
||||||
|
import { systemSelector } from 'features/system/store/systemSelectors';
|
||||||
|
import {
|
||||||
|
PropsWithChildren,
|
||||||
|
createContext,
|
||||||
|
useCallback,
|
||||||
|
useEffect,
|
||||||
|
useState,
|
||||||
|
} from 'react';
|
||||||
|
import { ImageDTO } from 'services/api';
|
||||||
|
import { RootState } from 'app/store/store';
|
||||||
|
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
|
||||||
|
import { controlNetSelector } from 'features/controlNet/store/controlNetSlice';
|
||||||
|
import { nodesSelecter } from 'features/nodes/store/nodesSlice';
|
||||||
|
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
||||||
|
import { some } from 'lodash-es';
|
||||||
|
|
||||||
|
export type ImageUsage = {
|
||||||
|
isInitialImage: boolean;
|
||||||
|
isCanvasImage: boolean;
|
||||||
|
isNodesImage: boolean;
|
||||||
|
isControlNetImage: boolean;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const selectImageUsage = createSelector(
|
||||||
|
[
|
||||||
|
generationSelector,
|
||||||
|
canvasSelector,
|
||||||
|
nodesSelecter,
|
||||||
|
controlNetSelector,
|
||||||
|
(state: RootState, image_name?: string) => image_name,
|
||||||
|
],
|
||||||
|
(generation, canvas, nodes, controlNet, image_name) => {
|
||||||
|
const isInitialImage = generation.initialImage?.image_name === image_name;
|
||||||
|
|
||||||
|
const isCanvasImage = canvas.layerState.objects.some(
|
||||||
|
(obj) => obj.kind === 'image' && obj.image.image_name === image_name
|
||||||
|
);
|
||||||
|
|
||||||
|
const isNodesImage = nodes.nodes.some((node) => {
|
||||||
|
return some(
|
||||||
|
node.data.inputs,
|
||||||
|
(input) =>
|
||||||
|
input.type === 'image' && input.value?.image_name === image_name
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
const isControlNetImage = some(
|
||||||
|
controlNet.controlNets,
|
||||||
|
(c) =>
|
||||||
|
c.controlImage?.image_name === image_name ||
|
||||||
|
c.processedControlImage?.image_name === image_name
|
||||||
|
);
|
||||||
|
|
||||||
|
const imageUsage: ImageUsage = {
|
||||||
|
isInitialImage,
|
||||||
|
isCanvasImage,
|
||||||
|
isNodesImage,
|
||||||
|
isControlNetImage,
|
||||||
|
};
|
||||||
|
|
||||||
|
return imageUsage;
|
||||||
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
);
|
||||||
|
|
||||||
|
type DeleteImageContextValue = {
|
||||||
|
/**
|
||||||
|
* Whether the delete image dialog is open.
|
||||||
|
*/
|
||||||
|
isOpen: boolean;
|
||||||
|
/**
|
||||||
|
* Closes the delete image dialog.
|
||||||
|
*/
|
||||||
|
onClose: () => void;
|
||||||
|
/**
|
||||||
|
* Opens the delete image dialog and handles all deletion-related checks.
|
||||||
|
*/
|
||||||
|
onDelete: (image?: ImageDTO) => void;
|
||||||
|
/**
|
||||||
|
* The image pending deletion
|
||||||
|
*/
|
||||||
|
image?: ImageDTO;
|
||||||
|
/**
|
||||||
|
* The features in which this image is used
|
||||||
|
*/
|
||||||
|
imageUsage?: ImageUsage;
|
||||||
|
/**
|
||||||
|
* Immediately deletes an image.
|
||||||
|
*
|
||||||
|
* You probably don't want to use this - use `onDelete` instead.
|
||||||
|
*/
|
||||||
|
onImmediatelyDelete: () => void;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const DeleteImageContext = createContext<DeleteImageContextValue>({
|
||||||
|
isOpen: false,
|
||||||
|
onClose: () => undefined,
|
||||||
|
onImmediatelyDelete: () => undefined,
|
||||||
|
onDelete: () => undefined,
|
||||||
|
});
|
||||||
|
|
||||||
|
const selector = createSelector(
|
||||||
|
[systemSelector],
|
||||||
|
(system) => {
|
||||||
|
const { isProcessing, isConnected, shouldConfirmOnDelete } = system;
|
||||||
|
|
||||||
|
return {
|
||||||
|
canDeleteImage: isConnected && !isProcessing,
|
||||||
|
shouldConfirmOnDelete,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
);
|
||||||
|
|
||||||
|
type Props = PropsWithChildren;
|
||||||
|
|
||||||
|
export const DeleteImageContextProvider = (props: Props) => {
|
||||||
|
const { canDeleteImage, shouldConfirmOnDelete } = useAppSelector(selector);
|
||||||
|
const [imageToDelete, setImageToDelete] = useState<ImageDTO>();
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||||
|
|
||||||
|
// Check where the image to be deleted is used (eg init image, controlnet, etc.)
|
||||||
|
const imageUsage = useAppSelector((state) =>
|
||||||
|
selectImageUsage(state, imageToDelete?.image_name)
|
||||||
|
);
|
||||||
|
|
||||||
|
// Clean up after deleting or dismissing the modal
|
||||||
|
const closeAndClearImageToDelete = useCallback(() => {
|
||||||
|
setImageToDelete(undefined);
|
||||||
|
onClose();
|
||||||
|
}, [onClose]);
|
||||||
|
|
||||||
|
// Dispatch the actual deletion action, to be handled by listener middleware
|
||||||
|
const handleActualDeletion = useCallback(
|
||||||
|
(image: ImageDTO) => {
|
||||||
|
dispatch(requestedImageDeletion({ image, imageUsage }));
|
||||||
|
closeAndClearImageToDelete();
|
||||||
|
},
|
||||||
|
[closeAndClearImageToDelete, dispatch, imageUsage]
|
||||||
|
);
|
||||||
|
|
||||||
|
// This is intended to be called by the delete button in the dialog
|
||||||
|
const onImmediatelyDelete = useCallback(() => {
|
||||||
|
if (canDeleteImage && imageToDelete) {
|
||||||
|
handleActualDeletion(imageToDelete);
|
||||||
|
}
|
||||||
|
closeAndClearImageToDelete();
|
||||||
|
}, [
|
||||||
|
canDeleteImage,
|
||||||
|
imageToDelete,
|
||||||
|
closeAndClearImageToDelete,
|
||||||
|
handleActualDeletion,
|
||||||
|
]);
|
||||||
|
|
||||||
|
const handleGatedDeletion = useCallback(
|
||||||
|
(image: ImageDTO) => {
|
||||||
|
if (shouldConfirmOnDelete || some(imageUsage)) {
|
||||||
|
// If we should confirm on delete, or if the image is in use, open the dialog
|
||||||
|
onOpen();
|
||||||
|
} else {
|
||||||
|
handleActualDeletion(image);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
[imageUsage, shouldConfirmOnDelete, onOpen, handleActualDeletion]
|
||||||
|
);
|
||||||
|
|
||||||
|
// Consumers of the context call this to delete an image
|
||||||
|
const onDelete = useCallback((image?: ImageDTO) => {
|
||||||
|
if (!image) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// Set the image to delete, then let the effect call the actual deletion
|
||||||
|
setImageToDelete(image);
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
// We need to use an effect here to trigger the image usage selector, else we get a stale value
|
||||||
|
if (imageToDelete) {
|
||||||
|
handleGatedDeletion(imageToDelete);
|
||||||
|
}
|
||||||
|
}, [handleGatedDeletion, imageToDelete]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<DeleteImageContext.Provider
|
||||||
|
value={{
|
||||||
|
isOpen,
|
||||||
|
image: imageToDelete,
|
||||||
|
onClose: closeAndClearImageToDelete,
|
||||||
|
onDelete,
|
||||||
|
onImmediatelyDelete,
|
||||||
|
imageUsage,
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{props.children}
|
||||||
|
</DeleteImageContext.Provider>
|
||||||
|
);
|
||||||
|
};
|
@ -1,4 +1,5 @@
|
|||||||
import { canvasPersistDenylist } from 'features/canvas/store/canvasPersistDenylist';
|
import { canvasPersistDenylist } from 'features/canvas/store/canvasPersistDenylist';
|
||||||
|
import { controlNetDenylist } from 'features/controlNet/store/controlNetDenylist';
|
||||||
import { galleryPersistDenylist } from 'features/gallery/store/galleryPersistDenylist';
|
import { galleryPersistDenylist } from 'features/gallery/store/galleryPersistDenylist';
|
||||||
import { lightboxPersistDenylist } from 'features/lightbox/store/lightboxPersistDenylist';
|
import { lightboxPersistDenylist } from 'features/lightbox/store/lightboxPersistDenylist';
|
||||||
import { nodesPersistDenylist } from 'features/nodes/store/nodesPersistDenylist';
|
import { nodesPersistDenylist } from 'features/nodes/store/nodesPersistDenylist';
|
||||||
@ -23,6 +24,7 @@ const serializationDenylist: {
|
|||||||
system: systemPersistDenylist,
|
system: systemPersistDenylist,
|
||||||
// config: configPersistDenyList,
|
// config: configPersistDenyList,
|
||||||
ui: uiPersistDenylist,
|
ui: uiPersistDenylist,
|
||||||
|
controlNet: controlNetDenylist,
|
||||||
// hotkeys: hotkeysPersistDenylist,
|
// hotkeys: hotkeysPersistDenylist,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import { initialCanvasState } from 'features/canvas/store/canvasSlice';
|
import { initialCanvasState } from 'features/canvas/store/canvasSlice';
|
||||||
|
import { initialControlNetState } from 'features/controlNet/store/controlNetSlice';
|
||||||
import { initialGalleryState } from 'features/gallery/store/gallerySlice';
|
import { initialGalleryState } from 'features/gallery/store/gallerySlice';
|
||||||
import { initialImagesState } from 'features/gallery/store/imagesSlice';
|
import { initialImagesState } from 'features/gallery/store/imagesSlice';
|
||||||
import { initialLightboxState } from 'features/lightbox/store/lightboxSlice';
|
import { initialLightboxState } from 'features/lightbox/store/lightboxSlice';
|
||||||
@ -28,6 +29,7 @@ const initialStates: {
|
|||||||
ui: initialUIState,
|
ui: initialUIState,
|
||||||
hotkeys: initialHotkeysState,
|
hotkeys: initialHotkeysState,
|
||||||
images: initialImagesState,
|
images: initialImagesState,
|
||||||
|
controlNet: initialControlNetState,
|
||||||
};
|
};
|
||||||
|
|
||||||
export const unserialize: UnserializeFunction = (data, key) => {
|
export const unserialize: UnserializeFunction = (data, key) => {
|
||||||
|
@ -70,6 +70,9 @@ import {
|
|||||||
import { addStagingAreaImageSavedListener } from './listeners/stagingAreaImageSaved';
|
import { addStagingAreaImageSavedListener } from './listeners/stagingAreaImageSaved';
|
||||||
import { addCommitStagingAreaImageListener } from './listeners/addCommitStagingAreaImageListener';
|
import { addCommitStagingAreaImageListener } from './listeners/addCommitStagingAreaImageListener';
|
||||||
import { addImageCategoriesChangedListener } from './listeners/imageCategoriesChanged';
|
import { addImageCategoriesChangedListener } from './listeners/imageCategoriesChanged';
|
||||||
|
import { addControlNetImageProcessedListener } from './listeners/controlNetImageProcessed';
|
||||||
|
import { addControlNetAutoProcessListener } from './listeners/controlNetAutoProcess';
|
||||||
|
import { addUpdateImageUrlsOnConnectListener } from './listeners/updateImageUrlsOnConnect';
|
||||||
|
|
||||||
export const listenerMiddleware = createListenerMiddleware();
|
export const listenerMiddleware = createListenerMiddleware();
|
||||||
|
|
||||||
@ -173,3 +176,10 @@ addReceivedPageOfImagesRejectedListener();
|
|||||||
|
|
||||||
// Gallery
|
// Gallery
|
||||||
addImageCategoriesChangedListener();
|
addImageCategoriesChangedListener();
|
||||||
|
|
||||||
|
// ControlNet
|
||||||
|
addControlNetImageProcessedListener();
|
||||||
|
addControlNetAutoProcessListener();
|
||||||
|
|
||||||
|
// Update image URLs on connect
|
||||||
|
addUpdateImageUrlsOnConnectListener();
|
||||||
|
@ -28,6 +28,13 @@ export const addCanvasCopiedToClipboardListener = () => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
copyBlobToClipboard(blob);
|
copyBlobToClipboard(blob);
|
||||||
|
|
||||||
|
dispatch(
|
||||||
|
addToast({
|
||||||
|
title: 'Canvas Copied to Clipboard',
|
||||||
|
status: 'success',
|
||||||
|
})
|
||||||
|
);
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
@ -27,7 +27,8 @@ export const addCanvasDownloadedAsImageListener = () => {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
downloadBlob(blob, 'mergedCanvas.png');
|
downloadBlob(blob, 'canvas.png');
|
||||||
|
dispatch(addToast({ title: 'Canvas Downloaded', status: 'success' }));
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
@ -1,22 +1,20 @@
|
|||||||
import { canvasMerged } from 'features/canvas/store/actions';
|
import { canvasMerged } from 'features/canvas/store/actions';
|
||||||
import { startAppListening } from '..';
|
import { startAppListening } from '..';
|
||||||
import { log } from 'app/logging/useLogger';
|
import { log } from 'app/logging/useLogger';
|
||||||
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
|
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
import { imageUploaded } from 'services/thunks/image';
|
import { imageUploaded } from 'services/thunks/image';
|
||||||
import { v4 as uuidv4 } from 'uuid';
|
|
||||||
import { setMergedCanvas } from 'features/canvas/store/canvasSlice';
|
import { setMergedCanvas } from 'features/canvas/store/canvasSlice';
|
||||||
import { getCanvasBaseLayer } from 'features/canvas/util/konvaInstanceProvider';
|
import { getCanvasBaseLayer } from 'features/canvas/util/konvaInstanceProvider';
|
||||||
|
import { getFullBaseLayerBlob } from 'features/canvas/util/getFullBaseLayerBlob';
|
||||||
|
|
||||||
const moduleLog = log.child({ namespace: 'canvasCopiedToClipboardListener' });
|
const moduleLog = log.child({ namespace: 'canvasCopiedToClipboardListener' });
|
||||||
|
export const MERGED_CANVAS_FILENAME = 'mergedCanvas.png';
|
||||||
|
|
||||||
export const addCanvasMergedListener = () => {
|
export const addCanvasMergedListener = () => {
|
||||||
startAppListening({
|
startAppListening({
|
||||||
actionCreator: canvasMerged,
|
actionCreator: canvasMerged,
|
||||||
effect: async (action, { dispatch, getState, take }) => {
|
effect: async (action, { dispatch, getState, take }) => {
|
||||||
const state = getState();
|
const blob = await getFullBaseLayerBlob();
|
||||||
|
|
||||||
const blob = await getBaseLayerBlob(state, true);
|
|
||||||
|
|
||||||
if (!blob) {
|
if (!blob) {
|
||||||
moduleLog.error('Problem getting base layer blob');
|
moduleLog.error('Problem getting base layer blob');
|
||||||
@ -48,12 +46,12 @@ export const addCanvasMergedListener = () => {
|
|||||||
relativeTo: canvasBaseLayer.getParent(),
|
relativeTo: canvasBaseLayer.getParent(),
|
||||||
});
|
});
|
||||||
|
|
||||||
const filename = `mergedCanvas_${uuidv4()}.png`;
|
const imageUploadedRequest = dispatch(
|
||||||
|
|
||||||
dispatch(
|
|
||||||
imageUploaded({
|
imageUploaded({
|
||||||
formData: {
|
formData: {
|
||||||
file: new File([blob], filename, { type: 'image/png' }),
|
file: new File([blob], MERGED_CANVAS_FILENAME, {
|
||||||
|
type: 'image/png',
|
||||||
|
}),
|
||||||
},
|
},
|
||||||
imageCategory: 'general',
|
imageCategory: 'general',
|
||||||
isIntermediate: true,
|
isIntermediate: true,
|
||||||
@ -61,9 +59,11 @@ export const addCanvasMergedListener = () => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
const [{ payload }] = await take(
|
const [{ payload }] = await take(
|
||||||
(action): action is ReturnType<typeof imageUploaded.fulfilled> =>
|
(
|
||||||
imageUploaded.fulfilled.match(action) &&
|
uploadedImageAction
|
||||||
action.meta.arg.formData.file.name === filename
|
): uploadedImageAction is ReturnType<typeof imageUploaded.fulfilled> =>
|
||||||
|
imageUploaded.fulfilled.match(uploadedImageAction) &&
|
||||||
|
uploadedImageAction.meta.requestId === imageUploadedRequest.requestId
|
||||||
);
|
);
|
||||||
|
|
||||||
const mergedCanvasImage = payload;
|
const mergedCanvasImage = payload;
|
||||||
|
@ -4,9 +4,10 @@ import { log } from 'app/logging/useLogger';
|
|||||||
import { imageUploaded } from 'services/thunks/image';
|
import { imageUploaded } from 'services/thunks/image';
|
||||||
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
|
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
import { v4 as uuidv4 } from 'uuid';
|
|
||||||
import { imageUpserted } from 'features/gallery/store/imagesSlice';
|
import { imageUpserted } from 'features/gallery/store/imagesSlice';
|
||||||
|
|
||||||
|
export const SAVED_CANVAS_FILENAME = 'savedCanvas.png';
|
||||||
|
|
||||||
const moduleLog = log.child({ namespace: 'canvasSavedToGalleryListener' });
|
const moduleLog = log.child({ namespace: 'canvasSavedToGalleryListener' });
|
||||||
|
|
||||||
export const addCanvasSavedToGalleryListener = () => {
|
export const addCanvasSavedToGalleryListener = () => {
|
||||||
@ -15,7 +16,7 @@ export const addCanvasSavedToGalleryListener = () => {
|
|||||||
effect: async (action, { dispatch, getState, take }) => {
|
effect: async (action, { dispatch, getState, take }) => {
|
||||||
const state = getState();
|
const state = getState();
|
||||||
|
|
||||||
const blob = await getBaseLayerBlob(state, true);
|
const blob = await getBaseLayerBlob(state);
|
||||||
|
|
||||||
if (!blob) {
|
if (!blob) {
|
||||||
moduleLog.error('Problem getting base layer blob');
|
moduleLog.error('Problem getting base layer blob');
|
||||||
@ -29,12 +30,12 @@ export const addCanvasSavedToGalleryListener = () => {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const filename = `mergedCanvas_${uuidv4()}.png`;
|
const imageUploadedRequest = dispatch(
|
||||||
|
|
||||||
dispatch(
|
|
||||||
imageUploaded({
|
imageUploaded({
|
||||||
formData: {
|
formData: {
|
||||||
file: new File([blob], filename, { type: 'image/png' }),
|
file: new File([blob], SAVED_CANVAS_FILENAME, {
|
||||||
|
type: 'image/png',
|
||||||
|
}),
|
||||||
},
|
},
|
||||||
imageCategory: 'general',
|
imageCategory: 'general',
|
||||||
isIntermediate: false,
|
isIntermediate: false,
|
||||||
@ -42,9 +43,11 @@ export const addCanvasSavedToGalleryListener = () => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
const [{ payload: uploadedImageDTO }] = await take(
|
const [{ payload: uploadedImageDTO }] = await take(
|
||||||
(action): action is ReturnType<typeof imageUploaded.fulfilled> =>
|
(
|
||||||
imageUploaded.fulfilled.match(action) &&
|
uploadedImageAction
|
||||||
action.meta.arg.formData.file.name === filename
|
): uploadedImageAction is ReturnType<typeof imageUploaded.fulfilled> =>
|
||||||
|
imageUploaded.fulfilled.match(uploadedImageAction) &&
|
||||||
|
uploadedImageAction.meta.requestId === imageUploadedRequest.requestId
|
||||||
);
|
);
|
||||||
|
|
||||||
dispatch(imageUpserted(uploadedImageDTO));
|
dispatch(imageUpserted(uploadedImageDTO));
|
||||||
|
@ -0,0 +1,59 @@
|
|||||||
|
import { AnyAction } from '@reduxjs/toolkit';
|
||||||
|
import { startAppListening } from '..';
|
||||||
|
import { log } from 'app/logging/useLogger';
|
||||||
|
import { controlNetImageProcessed } from 'features/controlNet/store/actions';
|
||||||
|
import {
|
||||||
|
controlNetImageChanged,
|
||||||
|
controlNetProcessorParamsChanged,
|
||||||
|
controlNetProcessorTypeChanged,
|
||||||
|
} from 'features/controlNet/store/controlNetSlice';
|
||||||
|
import { RootState } from 'app/store/store';
|
||||||
|
|
||||||
|
const moduleLog = log.child({ namespace: 'controlNet' });
|
||||||
|
|
||||||
|
const predicate = (action: AnyAction, state: RootState) => {
|
||||||
|
const isActionMatched =
|
||||||
|
controlNetProcessorParamsChanged.match(action) ||
|
||||||
|
controlNetImageChanged.match(action) ||
|
||||||
|
controlNetProcessorTypeChanged.match(action);
|
||||||
|
|
||||||
|
if (!isActionMatched) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
const { controlImage, processorType } =
|
||||||
|
state.controlNet.controlNets[action.payload.controlNetId];
|
||||||
|
|
||||||
|
const isProcessorSelected = processorType !== 'none';
|
||||||
|
|
||||||
|
const isBusy = state.system.isProcessing;
|
||||||
|
|
||||||
|
const hasControlImage = Boolean(controlImage);
|
||||||
|
|
||||||
|
return isProcessorSelected && !isBusy && hasControlImage;
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Listener that automatically processes a ControlNet image when its processor parameters are changed.
|
||||||
|
*
|
||||||
|
* The network request is debounced by 1 second.
|
||||||
|
*/
|
||||||
|
export const addControlNetAutoProcessListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
predicate,
|
||||||
|
effect: async (
|
||||||
|
action,
|
||||||
|
{ dispatch, getState, cancelActiveListeners, delay }
|
||||||
|
) => {
|
||||||
|
const { controlNetId } = action.payload;
|
||||||
|
|
||||||
|
// Cancel any in-progress instances of this listener
|
||||||
|
cancelActiveListeners();
|
||||||
|
|
||||||
|
// Delay before starting actual work
|
||||||
|
await delay(300);
|
||||||
|
|
||||||
|
dispatch(controlNetImageProcessed({ controlNetId }));
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
@ -0,0 +1,93 @@
|
|||||||
|
import { startAppListening } from '..';
|
||||||
|
import { imageMetadataReceived } from 'services/thunks/image';
|
||||||
|
import { log } from 'app/logging/useLogger';
|
||||||
|
import { controlNetImageProcessed } from 'features/controlNet/store/actions';
|
||||||
|
import { Graph } from 'services/api';
|
||||||
|
import { sessionCreated } from 'services/thunks/session';
|
||||||
|
import { sessionReadyToInvoke } from 'features/system/store/actions';
|
||||||
|
import { socketInvocationComplete } from 'services/events/actions';
|
||||||
|
import { isImageOutput } from 'services/types/guards';
|
||||||
|
import { controlNetProcessedImageChanged } from 'features/controlNet/store/controlNetSlice';
|
||||||
|
import { pick } from 'lodash-es';
|
||||||
|
|
||||||
|
const moduleLog = log.child({ namespace: 'controlNet' });
|
||||||
|
|
||||||
|
export const addControlNetImageProcessedListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: controlNetImageProcessed,
|
||||||
|
effect: async (
|
||||||
|
action,
|
||||||
|
{ dispatch, getState, take, unsubscribe, subscribe }
|
||||||
|
) => {
|
||||||
|
const { controlNetId } = action.payload;
|
||||||
|
const controlNet = getState().controlNet.controlNets[controlNetId];
|
||||||
|
|
||||||
|
if (!controlNet.controlImage) {
|
||||||
|
moduleLog.error('Unable to process ControlNet image');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ControlNet one-off procressing graph is just the processor node, no edges.
|
||||||
|
// Also we need to grab the image.
|
||||||
|
const graph: Graph = {
|
||||||
|
nodes: {
|
||||||
|
[controlNet.processorNode.id]: {
|
||||||
|
...controlNet.processorNode,
|
||||||
|
is_intermediate: true,
|
||||||
|
image: pick(controlNet.controlImage, [
|
||||||
|
'image_name',
|
||||||
|
'image_origin',
|
||||||
|
]),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
// Create a session to run the graph & wait til it's ready to invoke
|
||||||
|
const sessionCreatedAction = dispatch(sessionCreated({ graph }));
|
||||||
|
const [sessionCreatedFulfilledAction] = await take(
|
||||||
|
(action): action is ReturnType<typeof sessionCreated.fulfilled> =>
|
||||||
|
sessionCreated.fulfilled.match(action) &&
|
||||||
|
action.meta.requestId === sessionCreatedAction.requestId
|
||||||
|
);
|
||||||
|
|
||||||
|
const sessionId = sessionCreatedFulfilledAction.payload.id;
|
||||||
|
|
||||||
|
// Invoke the session & wait til it's complete
|
||||||
|
dispatch(sessionReadyToInvoke());
|
||||||
|
const [invocationCompleteAction] = await take(
|
||||||
|
(action): action is ReturnType<typeof socketInvocationComplete> =>
|
||||||
|
socketInvocationComplete.match(action) &&
|
||||||
|
action.payload.data.graph_execution_state_id === sessionId
|
||||||
|
);
|
||||||
|
|
||||||
|
// We still have to check the output type
|
||||||
|
if (isImageOutput(invocationCompleteAction.payload.data.result)) {
|
||||||
|
const { image_name } =
|
||||||
|
invocationCompleteAction.payload.data.result.image;
|
||||||
|
|
||||||
|
// Wait for the ImageDTO to be received
|
||||||
|
const [imageMetadataReceivedAction] = await take(
|
||||||
|
(
|
||||||
|
action
|
||||||
|
): action is ReturnType<typeof imageMetadataReceived.fulfilled> =>
|
||||||
|
imageMetadataReceived.fulfilled.match(action) &&
|
||||||
|
action.payload.image_name === image_name
|
||||||
|
);
|
||||||
|
const processedControlImage = imageMetadataReceivedAction.payload;
|
||||||
|
|
||||||
|
moduleLog.debug(
|
||||||
|
{ data: { arg: action.payload, processedControlImage } },
|
||||||
|
'ControlNet image processed'
|
||||||
|
);
|
||||||
|
|
||||||
|
// Update the processed image in the store
|
||||||
|
dispatch(
|
||||||
|
controlNetProcessedImageChanged({
|
||||||
|
controlNetId,
|
||||||
|
processedControlImage,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
@ -6,10 +6,13 @@ import { clamp } from 'lodash-es';
|
|||||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||||
import {
|
import {
|
||||||
imageRemoved,
|
imageRemoved,
|
||||||
imagesAdapter,
|
|
||||||
selectImagesEntities,
|
selectImagesEntities,
|
||||||
selectImagesIds,
|
selectImagesIds,
|
||||||
} from 'features/gallery/store/imagesSlice';
|
} from 'features/gallery/store/imagesSlice';
|
||||||
|
import { resetCanvas } from 'features/canvas/store/canvasSlice';
|
||||||
|
import { controlNetReset } from 'features/controlNet/store/controlNetSlice';
|
||||||
|
import { clearInitialImage } from 'features/parameters/store/generationSlice';
|
||||||
|
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
|
||||||
|
|
||||||
const moduleLog = log.child({ namespace: 'addRequestedImageDeletionListener' });
|
const moduleLog = log.child({ namespace: 'addRequestedImageDeletionListener' });
|
||||||
|
|
||||||
@ -20,11 +23,7 @@ export const addRequestedImageDeletionListener = () => {
|
|||||||
startAppListening({
|
startAppListening({
|
||||||
actionCreator: requestedImageDeletion,
|
actionCreator: requestedImageDeletion,
|
||||||
effect: (action, { dispatch, getState }) => {
|
effect: (action, { dispatch, getState }) => {
|
||||||
const image = action.payload;
|
const { image, imageUsage } = action.payload;
|
||||||
if (!image) {
|
|
||||||
moduleLog.warn('No image provided');
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const { image_name, image_origin } = image;
|
const { image_name, image_origin } = image;
|
||||||
|
|
||||||
@ -58,8 +57,28 @@ export const addRequestedImageDeletionListener = () => {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// We need to reset the features where the image is in use - none of these work if their image(s) don't exist
|
||||||
|
|
||||||
|
if (imageUsage.isCanvasImage) {
|
||||||
|
dispatch(resetCanvas());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (imageUsage.isControlNetImage) {
|
||||||
|
dispatch(controlNetReset());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (imageUsage.isInitialImage) {
|
||||||
|
dispatch(clearInitialImage());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (imageUsage.isNodesImage) {
|
||||||
|
dispatch(nodeEditorReset());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Preemptively remove from gallery
|
||||||
dispatch(imageRemoved(image_name));
|
dispatch(imageRemoved(image_name));
|
||||||
|
|
||||||
|
// Delete from server
|
||||||
dispatch(
|
dispatch(
|
||||||
imageDeleted({ imageName: image_name, imageOrigin: image_origin })
|
imageDeleted({ imageName: image_name, imageOrigin: image_origin })
|
||||||
);
|
);
|
||||||
@ -74,9 +93,7 @@ export const addImageDeletedPendingListener = () => {
|
|||||||
startAppListening({
|
startAppListening({
|
||||||
actionCreator: imageDeleted.pending,
|
actionCreator: imageDeleted.pending,
|
||||||
effect: (action, { dispatch, getState }) => {
|
effect: (action, { dispatch, getState }) => {
|
||||||
const { imageName, imageOrigin } = action.meta.arg;
|
//
|
||||||
// Preemptively remove the image from the gallery
|
|
||||||
imagesAdapter.removeOne(getState().images, imageName);
|
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import { log } from 'app/logging/useLogger';
|
import { log } from 'app/logging/useLogger';
|
||||||
import { startAppListening } from '..';
|
import { startAppListening } from '..';
|
||||||
import { imageMetadataReceived } from 'services/thunks/image';
|
import { imageMetadataReceived, imageUpdated } from 'services/thunks/image';
|
||||||
import { imageUpserted } from 'features/gallery/store/imagesSlice';
|
import { imageUpserted } from 'features/gallery/store/imagesSlice';
|
||||||
|
|
||||||
const moduleLog = log.child({ namespace: 'image' });
|
const moduleLog = log.child({ namespace: 'image' });
|
||||||
@ -10,10 +10,29 @@ export const addImageMetadataReceivedFulfilledListener = () => {
|
|||||||
actionCreator: imageMetadataReceived.fulfilled,
|
actionCreator: imageMetadataReceived.fulfilled,
|
||||||
effect: (action, { getState, dispatch }) => {
|
effect: (action, { getState, dispatch }) => {
|
||||||
const image = action.payload;
|
const image = action.payload;
|
||||||
if (image.is_intermediate) {
|
|
||||||
|
const state = getState();
|
||||||
|
|
||||||
|
if (
|
||||||
|
image.session_id === state.canvas.layerState.stagingArea.sessionId &&
|
||||||
|
state.canvas.shouldAutoSave
|
||||||
|
) {
|
||||||
|
dispatch(
|
||||||
|
imageUpdated({
|
||||||
|
imageName: image.image_name,
|
||||||
|
imageOrigin: image.image_origin,
|
||||||
|
requestBody: { is_intermediate: false },
|
||||||
|
})
|
||||||
|
);
|
||||||
|
} else if (image.is_intermediate) {
|
||||||
// No further actions needed for intermediate images
|
// No further actions needed for intermediate images
|
||||||
|
moduleLog.trace(
|
||||||
|
{ data: { image } },
|
||||||
|
'Image metadata received (intermediate), skipping'
|
||||||
|
);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
moduleLog.debug({ data: { image } }, 'Image metadata received');
|
moduleLog.debug({ data: { image } }, 'Image metadata received');
|
||||||
dispatch(imageUpserted(image));
|
dispatch(imageUpserted(image));
|
||||||
},
|
},
|
||||||
|
@ -3,6 +3,8 @@ import { imageUploaded } from 'services/thunks/image';
|
|||||||
import { addToast } from 'features/system/store/systemSlice';
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
import { log } from 'app/logging/useLogger';
|
import { log } from 'app/logging/useLogger';
|
||||||
import { imageUpserted } from 'features/gallery/store/imagesSlice';
|
import { imageUpserted } from 'features/gallery/store/imagesSlice';
|
||||||
|
import { SAVED_CANVAS_FILENAME } from './canvasSavedToGallery';
|
||||||
|
import { MERGED_CANVAS_FILENAME } from './canvasMerged';
|
||||||
|
|
||||||
const moduleLog = log.child({ namespace: 'image' });
|
const moduleLog = log.child({ namespace: 'image' });
|
||||||
|
|
||||||
@ -19,9 +21,22 @@ export const addImageUploadedFulfilledListener = () => {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const state = getState();
|
const originalFileName = action.meta.arg.formData.file.name;
|
||||||
|
|
||||||
dispatch(imageUpserted(image));
|
dispatch(imageUpserted(image));
|
||||||
|
|
||||||
|
if (originalFileName === SAVED_CANVAS_FILENAME) {
|
||||||
|
dispatch(
|
||||||
|
addToast({ title: 'Canvas Saved to Gallery', status: 'success' })
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (originalFileName === MERGED_CANVAS_FILENAME) {
|
||||||
|
dispatch(addToast({ title: 'Canvas Merged', status: 'success' }));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
dispatch(addToast({ title: 'Image Uploaded', status: 'success' }));
|
dispatch(addToast({ title: 'Image Uploaded', status: 'success' }));
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import { log } from 'app/logging/useLogger';
|
import { log } from 'app/logging/useLogger';
|
||||||
import { startAppListening } from '..';
|
import { startAppListening } from '..';
|
||||||
import { imageUrlsReceived } from 'services/thunks/image';
|
import { imageUrlsReceived } from 'services/thunks/image';
|
||||||
import { imagesAdapter } from 'features/gallery/store/imagesSlice';
|
import { imageUpdatedOne } from 'features/gallery/store/imagesSlice';
|
||||||
|
|
||||||
const moduleLog = log.child({ namespace: 'image' });
|
const moduleLog = log.child({ namespace: 'image' });
|
||||||
|
|
||||||
@ -14,13 +14,12 @@ export const addImageUrlsReceivedFulfilledListener = () => {
|
|||||||
|
|
||||||
const { image_name, image_url, thumbnail_url } = image;
|
const { image_name, image_url, thumbnail_url } = image;
|
||||||
|
|
||||||
imagesAdapter.updateOne(getState().images, {
|
dispatch(
|
||||||
id: image_name,
|
imageUpdatedOne({
|
||||||
changes: {
|
id: image_name,
|
||||||
image_url,
|
changes: { image_url, thumbnail_url },
|
||||||
thumbnail_url,
|
})
|
||||||
},
|
);
|
||||||
});
|
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
@ -2,12 +2,10 @@ import { initialImageChanged } from 'features/parameters/store/generationSlice';
|
|||||||
import { t } from 'i18next';
|
import { t } from 'i18next';
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
import { startAppListening } from '..';
|
import { startAppListening } from '..';
|
||||||
import {
|
import { initialImageSelected } from 'features/parameters/store/actions';
|
||||||
initialImageSelected,
|
|
||||||
isImageDTO,
|
|
||||||
} from 'features/parameters/store/actions';
|
|
||||||
import { makeToast } from 'app/components/Toaster';
|
import { makeToast } from 'app/components/Toaster';
|
||||||
import { selectImagesById } from 'features/gallery/store/imagesSlice';
|
import { selectImagesById } from 'features/gallery/store/imagesSlice';
|
||||||
|
import { isImageDTO } from 'services/types/guards';
|
||||||
|
|
||||||
export const addInitialImageSelectedListener = () => {
|
export const addInitialImageSelectedListener = () => {
|
||||||
startAppListening({
|
startAppListening({
|
||||||
|
@ -0,0 +1,93 @@
|
|||||||
|
import { socketConnected } from 'services/events/actions';
|
||||||
|
import { startAppListening } from '..';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
||||||
|
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
|
||||||
|
import { nodesSelecter } from 'features/nodes/store/nodesSlice';
|
||||||
|
import { controlNetSelector } from 'features/controlNet/store/controlNetSlice';
|
||||||
|
import { ImageDTO } from 'services/api';
|
||||||
|
import { forEach, uniqBy } from 'lodash-es';
|
||||||
|
import { imageUrlsReceived } from 'services/thunks/image';
|
||||||
|
import { log } from 'app/logging/useLogger';
|
||||||
|
import { selectImagesEntities } from 'features/gallery/store/imagesSlice';
|
||||||
|
|
||||||
|
const moduleLog = log.child({ namespace: 'images' });
|
||||||
|
|
||||||
|
const selectAllUsedImages = createSelector(
|
||||||
|
[
|
||||||
|
generationSelector,
|
||||||
|
canvasSelector,
|
||||||
|
nodesSelecter,
|
||||||
|
controlNetSelector,
|
||||||
|
selectImagesEntities,
|
||||||
|
],
|
||||||
|
(generation, canvas, nodes, controlNet, imageEntities) => {
|
||||||
|
const allUsedImages: ImageDTO[] = [];
|
||||||
|
|
||||||
|
if (generation.initialImage) {
|
||||||
|
allUsedImages.push(generation.initialImage);
|
||||||
|
}
|
||||||
|
|
||||||
|
canvas.layerState.objects.forEach((obj) => {
|
||||||
|
if (obj.kind === 'image') {
|
||||||
|
allUsedImages.push(obj.image);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
nodes.nodes.forEach((node) => {
|
||||||
|
forEach(node.data.inputs, (input) => {
|
||||||
|
if (input.type === 'image' && input.value) {
|
||||||
|
allUsedImages.push(input.value);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
forEach(controlNet.controlNets, (c) => {
|
||||||
|
if (c.controlImage) {
|
||||||
|
allUsedImages.push(c.controlImage);
|
||||||
|
}
|
||||||
|
if (c.processedControlImage) {
|
||||||
|
allUsedImages.push(c.processedControlImage);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
forEach(imageEntities, (image) => {
|
||||||
|
if (image) {
|
||||||
|
allUsedImages.push(image);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
const uniqueImages = uniqBy(allUsedImages, 'image_name');
|
||||||
|
|
||||||
|
return uniqueImages;
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
export const addUpdateImageUrlsOnConnectListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: socketConnected,
|
||||||
|
effect: async (action, { dispatch, getState, take }) => {
|
||||||
|
const state = getState();
|
||||||
|
|
||||||
|
if (!state.config.shouldUpdateImagesOnConnect) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const allUsedImages = selectAllUsedImages(state);
|
||||||
|
|
||||||
|
moduleLog.trace(
|
||||||
|
{ data: allUsedImages },
|
||||||
|
`Fetching new image URLs for ${allUsedImages.length} images`
|
||||||
|
);
|
||||||
|
|
||||||
|
allUsedImages.forEach(({ image_name, image_origin }) => {
|
||||||
|
dispatch(
|
||||||
|
imageUrlsReceived({
|
||||||
|
imageName: image_name,
|
||||||
|
imageOrigin: image_origin,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
});
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
@ -13,6 +13,7 @@ import galleryReducer from 'features/gallery/store/gallerySlice';
|
|||||||
import imagesReducer from 'features/gallery/store/imagesSlice';
|
import imagesReducer from 'features/gallery/store/imagesSlice';
|
||||||
import lightboxReducer from 'features/lightbox/store/lightboxSlice';
|
import lightboxReducer from 'features/lightbox/store/lightboxSlice';
|
||||||
import generationReducer from 'features/parameters/store/generationSlice';
|
import generationReducer from 'features/parameters/store/generationSlice';
|
||||||
|
import controlNetReducer from 'features/controlNet/store/controlNetSlice';
|
||||||
import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
|
import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
|
||||||
import systemReducer from 'features/system/store/systemSlice';
|
import systemReducer from 'features/system/store/systemSlice';
|
||||||
// import sessionReducer from 'features/system/store/sessionSlice';
|
// import sessionReducer from 'features/system/store/sessionSlice';
|
||||||
@ -45,6 +46,7 @@ const allReducers = {
|
|||||||
ui: uiReducer,
|
ui: uiReducer,
|
||||||
hotkeys: hotkeysReducer,
|
hotkeys: hotkeysReducer,
|
||||||
images: imagesReducer,
|
images: imagesReducer,
|
||||||
|
controlNet: controlNetReducer,
|
||||||
// session: sessionReducer,
|
// session: sessionReducer,
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -62,6 +64,7 @@ const rememberedKeys: (keyof typeof allReducers)[] = [
|
|||||||
'postprocessing',
|
'postprocessing',
|
||||||
'system',
|
'system',
|
||||||
'ui',
|
'ui',
|
||||||
|
'controlNet',
|
||||||
// 'hotkeys',
|
// 'hotkeys',
|
||||||
// 'config',
|
// 'config',
|
||||||
];
|
];
|
||||||
|
@ -95,6 +95,7 @@ export type AppFeature =
|
|||||||
* A disable-able Stable Diffusion feature
|
* A disable-able Stable Diffusion feature
|
||||||
*/
|
*/
|
||||||
export type SDFeature =
|
export type SDFeature =
|
||||||
|
| 'controlNet'
|
||||||
| 'noise'
|
| 'noise'
|
||||||
| 'variation'
|
| 'variation'
|
||||||
| 'symmetry'
|
| 'symmetry'
|
||||||
@ -107,12 +108,9 @@ export type SDFeature =
|
|||||||
*/
|
*/
|
||||||
export type AppConfig = {
|
export type AppConfig = {
|
||||||
/**
|
/**
|
||||||
* Whether or not URLs should be transformed to use a different host
|
* Whether or not we should update image urls when image loading errors
|
||||||
*/
|
|
||||||
shouldTransformUrls: boolean;
|
|
||||||
/**
|
|
||||||
* Whether or not we need to re-fetch images
|
|
||||||
*/
|
*/
|
||||||
|
shouldUpdateImagesOnConnect: boolean;
|
||||||
disabledTabs: InvokeTabName[];
|
disabledTabs: InvokeTabName[];
|
||||||
disabledFeatures: AppFeature[];
|
disabledFeatures: AppFeature[];
|
||||||
disabledSDFeatures: SDFeature[];
|
disabledSDFeatures: SDFeature[];
|
||||||
|
@ -1,17 +0,0 @@
|
|||||||
import { Checkbox, CheckboxProps } from '@chakra-ui/react';
|
|
||||||
import { memo, ReactNode } from 'react';
|
|
||||||
|
|
||||||
type IAICheckboxProps = CheckboxProps & {
|
|
||||||
label: string | ReactNode;
|
|
||||||
};
|
|
||||||
|
|
||||||
const IAICheckbox = (props: IAICheckboxProps) => {
|
|
||||||
const { label, ...rest } = props;
|
|
||||||
return (
|
|
||||||
<Checkbox colorScheme="accent" {...rest}>
|
|
||||||
{label}
|
|
||||||
</Checkbox>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default memo(IAICheckbox);
|
|
@ -49,7 +49,7 @@ const IAICollapse = (props: IAIToggleCollapseProps) => {
|
|||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
</Flex>
|
</Flex>
|
||||||
<Collapse in={isOpen} animateOpacity>
|
<Collapse in={isOpen} animateOpacity style={{ overflow: 'unset' }}>
|
||||||
<Box sx={{ p: 4, borderBottomRadius: 'base', bg: 'base.800' }}>
|
<Box sx={{ p: 4, borderBottomRadius: 'base', bg: 'base.800' }}>
|
||||||
{children}
|
{children}
|
||||||
</Box>
|
</Box>
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
import { CheckIcon } from '@chakra-ui/icons';
|
import { CheckIcon, ChevronUpIcon } from '@chakra-ui/icons';
|
||||||
import {
|
import {
|
||||||
Box,
|
Box,
|
||||||
Flex,
|
Flex,
|
||||||
@ -10,7 +10,6 @@ import {
|
|||||||
GridItem,
|
GridItem,
|
||||||
List,
|
List,
|
||||||
ListItem,
|
ListItem,
|
||||||
Select,
|
|
||||||
Text,
|
Text,
|
||||||
Tooltip,
|
Tooltip,
|
||||||
TooltipProps,
|
TooltipProps,
|
||||||
@ -19,7 +18,8 @@ import { autoUpdate, offset, shift, useFloating } from '@floating-ui/react-dom';
|
|||||||
import { useSelect } from 'downshift';
|
import { useSelect } from 'downshift';
|
||||||
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
||||||
|
|
||||||
import { memo } from 'react';
|
import { memo, useMemo } from 'react';
|
||||||
|
import { getInputOutlineStyles } from 'theme/util/getInputOutlineStyles';
|
||||||
|
|
||||||
export type ItemTooltips = { [key: string]: string };
|
export type ItemTooltips = { [key: string]: string };
|
||||||
|
|
||||||
@ -34,6 +34,7 @@ type IAICustomSelectProps = {
|
|||||||
buttonProps?: FlexProps;
|
buttonProps?: FlexProps;
|
||||||
tooltip?: string;
|
tooltip?: string;
|
||||||
tooltipProps?: Omit<TooltipProps, 'children'>;
|
tooltipProps?: Omit<TooltipProps, 'children'>;
|
||||||
|
ellipsisPosition?: 'start' | 'end';
|
||||||
};
|
};
|
||||||
|
|
||||||
const IAICustomSelect = (props: IAICustomSelectProps) => {
|
const IAICustomSelect = (props: IAICustomSelectProps) => {
|
||||||
@ -48,6 +49,7 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
|
|||||||
tooltip,
|
tooltip,
|
||||||
buttonProps,
|
buttonProps,
|
||||||
tooltipProps,
|
tooltipProps,
|
||||||
|
ellipsisPosition = 'end',
|
||||||
} = props;
|
} = props;
|
||||||
|
|
||||||
const {
|
const {
|
||||||
@ -69,6 +71,14 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
|
|||||||
middleware: [offset(4), shift({ crossAxis: true, padding: 8 })],
|
middleware: [offset(4), shift({ crossAxis: true, padding: 8 })],
|
||||||
});
|
});
|
||||||
|
|
||||||
|
const labelTextDirection = useMemo(() => {
|
||||||
|
if (ellipsisPosition === 'start') {
|
||||||
|
return document.dir === 'rtl' ? 'ltr' : 'rtl';
|
||||||
|
}
|
||||||
|
|
||||||
|
return document.dir;
|
||||||
|
}, [ellipsisPosition]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<FormControl sx={{ w: 'full' }} {...formControlProps}>
|
<FormControl sx={{ w: 'full' }} {...formControlProps}>
|
||||||
{label && (
|
{label && (
|
||||||
@ -82,20 +92,44 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
|
|||||||
</FormLabel>
|
</FormLabel>
|
||||||
)}
|
)}
|
||||||
<Tooltip label={tooltip} {...tooltipProps}>
|
<Tooltip label={tooltip} {...tooltipProps}>
|
||||||
<Select
|
<Flex
|
||||||
{...getToggleButtonProps({ ref: refs.setReference })}
|
{...getToggleButtonProps({ ref: refs.setReference })}
|
||||||
{...buttonProps}
|
{...buttonProps}
|
||||||
as={Flex}
|
|
||||||
sx={{
|
sx={{
|
||||||
alignItems: 'center',
|
alignItems: 'center',
|
||||||
userSelect: 'none',
|
userSelect: 'none',
|
||||||
cursor: 'pointer',
|
cursor: 'pointer',
|
||||||
|
overflow: 'hidden',
|
||||||
|
width: 'full',
|
||||||
|
py: 1,
|
||||||
|
px: 2,
|
||||||
|
gap: 2,
|
||||||
|
justifyContent: 'space-between',
|
||||||
|
...getInputOutlineStyles(),
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<Text sx={{ fontSize: 'sm', fontWeight: 500, color: 'base.100' }}>
|
<Text
|
||||||
|
sx={{
|
||||||
|
fontSize: 'sm',
|
||||||
|
fontWeight: 500,
|
||||||
|
color: 'base.100',
|
||||||
|
whiteSpace: 'nowrap',
|
||||||
|
overflow: 'hidden',
|
||||||
|
textOverflow: 'ellipsis',
|
||||||
|
direction: labelTextDirection,
|
||||||
|
}}
|
||||||
|
>
|
||||||
{selectedItem}
|
{selectedItem}
|
||||||
</Text>
|
</Text>
|
||||||
</Select>
|
<ChevronUpIcon
|
||||||
|
sx={{
|
||||||
|
color: 'base.300',
|
||||||
|
transform: isOpen ? 'rotate(0deg)' : 'rotate(180deg)',
|
||||||
|
transitionProperty: 'common',
|
||||||
|
transitionDuration: 'normal',
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
</Flex>
|
||||||
</Tooltip>
|
</Tooltip>
|
||||||
<Box {...getMenuProps()}>
|
<Box {...getMenuProps()}>
|
||||||
{isOpen && (
|
{isOpen && (
|
||||||
@ -104,11 +138,10 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
|
|||||||
ref={refs.setFloating}
|
ref={refs.setFloating}
|
||||||
sx={{
|
sx={{
|
||||||
...floatingStyles,
|
...floatingStyles,
|
||||||
width: 'max-content',
|
|
||||||
top: 0,
|
top: 0,
|
||||||
left: 0,
|
insetInlineStart: 0,
|
||||||
flexDirection: 'column',
|
flexDirection: 'column',
|
||||||
zIndex: 1,
|
zIndex: 2,
|
||||||
bg: 'base.800',
|
bg: 'base.800',
|
||||||
borderRadius: 'base',
|
borderRadius: 'base',
|
||||||
border: '1px',
|
border: '1px',
|
||||||
@ -118,61 +151,72 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
|
|||||||
px: 0,
|
px: 0,
|
||||||
h: 'fit-content',
|
h: 'fit-content',
|
||||||
maxH: 64,
|
maxH: 64,
|
||||||
|
minW: 48,
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<OverlayScrollbarsComponent>
|
<OverlayScrollbarsComponent>
|
||||||
{items.map((item, index) => (
|
{items.map((item, index) => {
|
||||||
<Tooltip
|
const isSelected = selectedItem === item;
|
||||||
isDisabled={!itemTooltips}
|
const isHighlighted = highlightedIndex === index;
|
||||||
key={`${item}${index}`}
|
const fontWeight = isSelected ? 700 : 500;
|
||||||
label={itemTooltips?.[item]}
|
const bg = isHighlighted
|
||||||
hasArrow
|
? 'base.700'
|
||||||
placement="right"
|
: isSelected
|
||||||
>
|
? 'base.750'
|
||||||
<ListItem
|
: undefined;
|
||||||
sx={{
|
return (
|
||||||
bg: highlightedIndex === index ? 'base.700' : undefined,
|
<Tooltip
|
||||||
py: 1,
|
isDisabled={!itemTooltips}
|
||||||
paddingInlineStart: 3,
|
|
||||||
paddingInlineEnd: 6,
|
|
||||||
cursor: 'pointer',
|
|
||||||
transitionProperty: 'common',
|
|
||||||
transitionDuration: '0.15s',
|
|
||||||
}}
|
|
||||||
key={`${item}${index}`}
|
key={`${item}${index}`}
|
||||||
{...getItemProps({ item, index })}
|
label={itemTooltips?.[item]}
|
||||||
|
hasArrow
|
||||||
|
placement="right"
|
||||||
>
|
>
|
||||||
{withCheckIcon ? (
|
<ListItem
|
||||||
<Grid gridTemplateColumns="1.25rem auto">
|
sx={{
|
||||||
<GridItem>
|
bg,
|
||||||
{selectedItem === item && <CheckIcon boxSize={2} />}
|
py: 1,
|
||||||
</GridItem>
|
paddingInlineStart: 3,
|
||||||
<GridItem>
|
paddingInlineEnd: 6,
|
||||||
<Text
|
cursor: 'pointer',
|
||||||
sx={{
|
transitionProperty: 'common',
|
||||||
fontSize: 'sm',
|
transitionDuration: '0.15s',
|
||||||
color: 'base.100',
|
}}
|
||||||
fontWeight: 500,
|
key={`${item}${index}`}
|
||||||
}}
|
{...getItemProps({ item, index })}
|
||||||
>
|
>
|
||||||
{item}
|
{withCheckIcon ? (
|
||||||
</Text>
|
<Grid gridTemplateColumns="1.25rem auto">
|
||||||
</GridItem>
|
<GridItem>
|
||||||
</Grid>
|
{isSelected && <CheckIcon boxSize={2} />}
|
||||||
) : (
|
</GridItem>
|
||||||
<Text
|
<GridItem>
|
||||||
sx={{
|
<Text
|
||||||
fontSize: 'sm',
|
sx={{
|
||||||
color: 'base.100',
|
fontSize: 'sm',
|
||||||
fontWeight: 500,
|
color: 'base.100',
|
||||||
}}
|
fontWeight,
|
||||||
>
|
}}
|
||||||
{item}
|
>
|
||||||
</Text>
|
{item}
|
||||||
)}
|
</Text>
|
||||||
</ListItem>
|
</GridItem>
|
||||||
</Tooltip>
|
</Grid>
|
||||||
))}
|
) : (
|
||||||
|
<Text
|
||||||
|
sx={{
|
||||||
|
fontSize: 'sm',
|
||||||
|
color: 'base.50',
|
||||||
|
fontWeight,
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{item}
|
||||||
|
</Text>
|
||||||
|
)}
|
||||||
|
</ListItem>
|
||||||
|
</Tooltip>
|
||||||
|
);
|
||||||
|
})}
|
||||||
</OverlayScrollbarsComponent>
|
</OverlayScrollbarsComponent>
|
||||||
</List>
|
</List>
|
||||||
)}
|
)}
|
||||||
|
167
invokeai/frontend/web/src/common/components/IAIDndImage.tsx
Normal file
167
invokeai/frontend/web/src/common/components/IAIDndImage.tsx
Normal file
@ -0,0 +1,167 @@
|
|||||||
|
import { Box, Flex, Icon, IconButtonProps, Image } from '@chakra-ui/react';
|
||||||
|
import { useDraggable, useDroppable } from '@dnd-kit/core';
|
||||||
|
import { useCombinedRefs } from '@dnd-kit/utilities';
|
||||||
|
import IAIIconButton from 'common/components/IAIIconButton';
|
||||||
|
import { IAIImageFallback } from 'common/components/IAIImageFallback';
|
||||||
|
import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay';
|
||||||
|
import { AnimatePresence } from 'framer-motion';
|
||||||
|
import { ReactElement, SyntheticEvent } from 'react';
|
||||||
|
import { memo, useRef } from 'react';
|
||||||
|
import { FaImage, FaTimes } from 'react-icons/fa';
|
||||||
|
import { ImageDTO } from 'services/api';
|
||||||
|
import { v4 as uuidv4 } from 'uuid';
|
||||||
|
import IAIDropOverlay from './IAIDropOverlay';
|
||||||
|
|
||||||
|
type IAIDndImageProps = {
|
||||||
|
image: ImageDTO | null | undefined;
|
||||||
|
onDrop: (droppedImage: ImageDTO) => void;
|
||||||
|
onReset?: () => void;
|
||||||
|
onError?: (event: SyntheticEvent<HTMLImageElement>) => void;
|
||||||
|
onLoad?: (event: SyntheticEvent<HTMLImageElement>) => void;
|
||||||
|
resetIconSize?: IconButtonProps['size'];
|
||||||
|
withResetIcon?: boolean;
|
||||||
|
withMetadataOverlay?: boolean;
|
||||||
|
isDragDisabled?: boolean;
|
||||||
|
isDropDisabled?: boolean;
|
||||||
|
fallback?: ReactElement;
|
||||||
|
payloadImage?: ImageDTO | null | undefined;
|
||||||
|
minSize?: number;
|
||||||
|
};
|
||||||
|
|
||||||
|
const IAIDndImage = (props: IAIDndImageProps) => {
|
||||||
|
const {
|
||||||
|
image,
|
||||||
|
onDrop,
|
||||||
|
onReset,
|
||||||
|
onError,
|
||||||
|
resetIconSize = 'md',
|
||||||
|
withResetIcon = false,
|
||||||
|
withMetadataOverlay = false,
|
||||||
|
isDropDisabled = false,
|
||||||
|
isDragDisabled = false,
|
||||||
|
fallback = <IAIImageFallback />,
|
||||||
|
payloadImage,
|
||||||
|
minSize = 24,
|
||||||
|
} = props;
|
||||||
|
const dndId = useRef(uuidv4());
|
||||||
|
const {
|
||||||
|
isOver,
|
||||||
|
setNodeRef: setDroppableRef,
|
||||||
|
active,
|
||||||
|
} = useDroppable({
|
||||||
|
id: dndId.current,
|
||||||
|
disabled: isDropDisabled,
|
||||||
|
data: {
|
||||||
|
handleDrop: onDrop,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
const {
|
||||||
|
attributes,
|
||||||
|
listeners,
|
||||||
|
setNodeRef: setDraggableRef,
|
||||||
|
} = useDraggable({
|
||||||
|
id: dndId.current,
|
||||||
|
data: {
|
||||||
|
image: payloadImage ? payloadImage : image,
|
||||||
|
},
|
||||||
|
disabled: isDragDisabled,
|
||||||
|
});
|
||||||
|
|
||||||
|
const setNodeRef = useCombinedRefs(setDroppableRef, setDraggableRef);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex
|
||||||
|
sx={{
|
||||||
|
width: 'full',
|
||||||
|
height: 'full',
|
||||||
|
alignItems: 'center',
|
||||||
|
justifyContent: 'center',
|
||||||
|
position: 'relative',
|
||||||
|
minW: minSize,
|
||||||
|
minH: minSize,
|
||||||
|
userSelect: 'none',
|
||||||
|
cursor: 'grab',
|
||||||
|
}}
|
||||||
|
{...attributes}
|
||||||
|
{...listeners}
|
||||||
|
ref={setNodeRef}
|
||||||
|
>
|
||||||
|
{image && (
|
||||||
|
<Flex
|
||||||
|
sx={{
|
||||||
|
w: 'full',
|
||||||
|
h: 'full',
|
||||||
|
position: 'relative',
|
||||||
|
alignItems: 'center',
|
||||||
|
justifyContent: 'center',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<Image
|
||||||
|
src={image.image_url}
|
||||||
|
fallbackStrategy="beforeLoadOrError"
|
||||||
|
fallback={fallback}
|
||||||
|
onError={onError}
|
||||||
|
objectFit="contain"
|
||||||
|
draggable={false}
|
||||||
|
sx={{
|
||||||
|
maxW: 'full',
|
||||||
|
maxH: 'full',
|
||||||
|
borderRadius: 'base',
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
{withMetadataOverlay && <ImageMetadataOverlay image={image} />}
|
||||||
|
{onReset && withResetIcon && (
|
||||||
|
<Box
|
||||||
|
sx={{
|
||||||
|
position: 'absolute',
|
||||||
|
top: 0,
|
||||||
|
right: 0,
|
||||||
|
p: 2,
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<IAIIconButton
|
||||||
|
size={resetIconSize}
|
||||||
|
tooltip="Reset Image"
|
||||||
|
aria-label="Reset Image"
|
||||||
|
icon={<FaTimes />}
|
||||||
|
onClick={onReset}
|
||||||
|
/>
|
||||||
|
</Box>
|
||||||
|
)}
|
||||||
|
<AnimatePresence>
|
||||||
|
{active && <IAIDropOverlay isOver={isOver} />}
|
||||||
|
</AnimatePresence>
|
||||||
|
</Flex>
|
||||||
|
)}
|
||||||
|
{!image && (
|
||||||
|
<>
|
||||||
|
<Flex
|
||||||
|
sx={{
|
||||||
|
minH: minSize,
|
||||||
|
bg: 'base.850',
|
||||||
|
w: 'full',
|
||||||
|
h: 'full',
|
||||||
|
alignItems: 'center',
|
||||||
|
justifyContent: 'center',
|
||||||
|
borderRadius: 'base',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<Icon
|
||||||
|
as={FaImage}
|
||||||
|
sx={{
|
||||||
|
boxSize: 24,
|
||||||
|
color: 'base.500',
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
</Flex>
|
||||||
|
<AnimatePresence>
|
||||||
|
{active && <IAIDropOverlay isOver={isOver} />}
|
||||||
|
</AnimatePresence>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(IAIDndImage);
|
@ -0,0 +1,91 @@
|
|||||||
|
import { Flex, Text } from '@chakra-ui/react';
|
||||||
|
import { motion } from 'framer-motion';
|
||||||
|
import { memo, useRef } from 'react';
|
||||||
|
import { v4 as uuidv4 } from 'uuid';
|
||||||
|
|
||||||
|
type Props = {
|
||||||
|
isOver: boolean;
|
||||||
|
label?: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const IAIDropOverlay = (props: Props) => {
|
||||||
|
const { isOver, label = 'Drop' } = props;
|
||||||
|
const motionId = useRef(uuidv4());
|
||||||
|
return (
|
||||||
|
<motion.div
|
||||||
|
key={motionId.current}
|
||||||
|
initial={{
|
||||||
|
opacity: 0,
|
||||||
|
}}
|
||||||
|
animate={{
|
||||||
|
opacity: 1,
|
||||||
|
transition: { duration: 0.1 },
|
||||||
|
}}
|
||||||
|
exit={{
|
||||||
|
opacity: 0,
|
||||||
|
transition: { duration: 0.1 },
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<Flex
|
||||||
|
sx={{
|
||||||
|
position: 'absolute',
|
||||||
|
top: 0,
|
||||||
|
left: 0,
|
||||||
|
w: 'full',
|
||||||
|
h: 'full',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<Flex
|
||||||
|
sx={{
|
||||||
|
position: 'absolute',
|
||||||
|
top: 0,
|
||||||
|
left: 0,
|
||||||
|
w: 'full',
|
||||||
|
h: 'full',
|
||||||
|
bg: 'base.900',
|
||||||
|
opacity: 0.7,
|
||||||
|
borderRadius: 'base',
|
||||||
|
alignItems: 'center',
|
||||||
|
justifyContent: 'center',
|
||||||
|
transitionProperty: 'common',
|
||||||
|
transitionDuration: '0.1s',
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
|
||||||
|
<Flex
|
||||||
|
sx={{
|
||||||
|
position: 'absolute',
|
||||||
|
top: 0,
|
||||||
|
left: 0,
|
||||||
|
w: 'full',
|
||||||
|
h: 'full',
|
||||||
|
opacity: 1,
|
||||||
|
borderWidth: 2,
|
||||||
|
borderColor: isOver ? 'base.200' : 'base.500',
|
||||||
|
borderRadius: 'base',
|
||||||
|
borderStyle: 'dashed',
|
||||||
|
transitionProperty: 'common',
|
||||||
|
transitionDuration: '0.1s',
|
||||||
|
alignItems: 'center',
|
||||||
|
justifyContent: 'center',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<Text
|
||||||
|
sx={{
|
||||||
|
fontSize: '2xl',
|
||||||
|
fontWeight: 600,
|
||||||
|
transform: isOver ? 'scale(1.1)' : 'scale(1)',
|
||||||
|
color: isOver ? 'base.100' : 'base.500',
|
||||||
|
transitionProperty: 'common',
|
||||||
|
transitionDuration: '0.1s',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{label}
|
||||||
|
</Text>
|
||||||
|
</Flex>
|
||||||
|
</Flex>
|
||||||
|
</motion.div>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(IAIDropOverlay);
|
@ -0,0 +1,25 @@
|
|||||||
|
import {
|
||||||
|
Checkbox,
|
||||||
|
CheckboxProps,
|
||||||
|
FormControl,
|
||||||
|
FormControlProps,
|
||||||
|
FormLabel,
|
||||||
|
} from '@chakra-ui/react';
|
||||||
|
import { memo, ReactNode } from 'react';
|
||||||
|
|
||||||
|
type IAIFullCheckboxProps = CheckboxProps & {
|
||||||
|
label: string | ReactNode;
|
||||||
|
formControlProps?: FormControlProps;
|
||||||
|
};
|
||||||
|
|
||||||
|
const IAIFullCheckbox = (props: IAIFullCheckboxProps) => {
|
||||||
|
const { label, formControlProps, ...rest } = props;
|
||||||
|
return (
|
||||||
|
<FormControl {...formControlProps}>
|
||||||
|
<FormLabel>{label}</FormLabel>
|
||||||
|
<Checkbox colorScheme="accent" {...rest} />
|
||||||
|
</FormControl>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(IAIFullCheckbox);
|
@ -0,0 +1,27 @@
|
|||||||
|
import { Flex, FlexProps, Spinner, SpinnerProps } from '@chakra-ui/react';
|
||||||
|
|
||||||
|
type Props = FlexProps & {
|
||||||
|
spinnerProps?: SpinnerProps;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const IAIImageFallback = (props: Props) => {
|
||||||
|
const { spinnerProps, ...rest } = props;
|
||||||
|
const { sx, ...restFlexProps } = rest;
|
||||||
|
return (
|
||||||
|
<Flex
|
||||||
|
sx={{
|
||||||
|
bg: 'base.900',
|
||||||
|
opacity: 0.7,
|
||||||
|
w: 'full',
|
||||||
|
h: 'full',
|
||||||
|
alignItems: 'center',
|
||||||
|
justifyContent: 'center',
|
||||||
|
borderRadius: 'base',
|
||||||
|
...sx,
|
||||||
|
}}
|
||||||
|
{...restFlexProps}
|
||||||
|
>
|
||||||
|
<Spinner size="xl" {...spinnerProps} />
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
@ -0,0 +1,19 @@
|
|||||||
|
import { Checkbox, CheckboxProps, Text } from '@chakra-ui/react';
|
||||||
|
import { memo, ReactElement } from 'react';
|
||||||
|
|
||||||
|
type IAISimpleCheckboxProps = CheckboxProps & {
|
||||||
|
label: string | ReactElement;
|
||||||
|
};
|
||||||
|
|
||||||
|
const IAISimpleCheckbox = (props: IAISimpleCheckboxProps) => {
|
||||||
|
const { label, ...rest } = props;
|
||||||
|
return (
|
||||||
|
<Checkbox colorScheme="accent" {...rest}>
|
||||||
|
<Text color="base.200" fontSize="md">
|
||||||
|
{label}
|
||||||
|
</Text>
|
||||||
|
</Checkbox>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(IAISimpleCheckbox);
|
@ -40,7 +40,7 @@ import IAIIconButton, { IAIIconButtonProps } from './IAIIconButton';
|
|||||||
import { roundDownToMultiple } from 'common/util/roundDownToMultiple';
|
import { roundDownToMultiple } from 'common/util/roundDownToMultiple';
|
||||||
|
|
||||||
export type IAIFullSliderProps = {
|
export type IAIFullSliderProps = {
|
||||||
label: string;
|
label?: string;
|
||||||
value: number;
|
value: number;
|
||||||
min?: number;
|
min?: number;
|
||||||
max?: number;
|
max?: number;
|
||||||
@ -178,9 +178,11 @@ const IAISlider = (props: IAIFullSliderProps) => {
|
|||||||
isDisabled={isDisabled}
|
isDisabled={isDisabled}
|
||||||
{...sliderFormControlProps}
|
{...sliderFormControlProps}
|
||||||
>
|
>
|
||||||
<FormLabel {...sliderFormLabelProps} mb={-1}>
|
{label && (
|
||||||
{label}
|
<FormLabel {...sliderFormLabelProps} mb={-1}>
|
||||||
</FormLabel>
|
{label}
|
||||||
|
</FormLabel>
|
||||||
|
)}
|
||||||
|
|
||||||
<HStack w="100%" gap={2} alignItems="center">
|
<HStack w="100%" gap={2} alignItems="center">
|
||||||
<Slider
|
<Slider
|
||||||
@ -203,6 +205,7 @@ const IAISlider = (props: IAIFullSliderProps) => {
|
|||||||
sx={{
|
sx={{
|
||||||
insetInlineStart: '0 !important',
|
insetInlineStart: '0 !important',
|
||||||
insetInlineEnd: 'unset !important',
|
insetInlineEnd: 'unset !important',
|
||||||
|
mt: 1.5,
|
||||||
}}
|
}}
|
||||||
{...sliderMarkProps}
|
{...sliderMarkProps}
|
||||||
>
|
>
|
||||||
@ -213,6 +216,7 @@ const IAISlider = (props: IAIFullSliderProps) => {
|
|||||||
sx={{
|
sx={{
|
||||||
insetInlineStart: 'unset !important',
|
insetInlineStart: 'unset !important',
|
||||||
insetInlineEnd: '0 !important',
|
insetInlineEnd: '0 !important',
|
||||||
|
mt: 1.5,
|
||||||
}}
|
}}
|
||||||
{...sliderMarkProps}
|
{...sliderMarkProps}
|
||||||
>
|
>
|
||||||
|
@ -5,6 +5,7 @@ import {
|
|||||||
FormLabelProps,
|
FormLabelProps,
|
||||||
Switch,
|
Switch,
|
||||||
SwitchProps,
|
SwitchProps,
|
||||||
|
Tooltip,
|
||||||
} from '@chakra-ui/react';
|
} from '@chakra-ui/react';
|
||||||
import { memo } from 'react';
|
import { memo } from 'react';
|
||||||
|
|
||||||
@ -13,6 +14,7 @@ interface Props extends SwitchProps {
|
|||||||
width?: string | number;
|
width?: string | number;
|
||||||
formControlProps?: FormControlProps;
|
formControlProps?: FormControlProps;
|
||||||
formLabelProps?: FormLabelProps;
|
formLabelProps?: FormLabelProps;
|
||||||
|
tooltip?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -25,22 +27,27 @@ const IAISwitch = (props: Props) => {
|
|||||||
width = 'auto',
|
width = 'auto',
|
||||||
formControlProps,
|
formControlProps,
|
||||||
formLabelProps,
|
formLabelProps,
|
||||||
|
tooltip,
|
||||||
...rest
|
...rest
|
||||||
} = props;
|
} = props;
|
||||||
return (
|
return (
|
||||||
<FormControl
|
<Tooltip label={tooltip} hasArrow placement="top" isDisabled={!tooltip}>
|
||||||
isDisabled={isDisabled}
|
<FormControl
|
||||||
width={width}
|
isDisabled={isDisabled}
|
||||||
display="flex"
|
width={width}
|
||||||
gap={4}
|
display="flex"
|
||||||
alignItems="center"
|
gap={4}
|
||||||
{...formControlProps}
|
alignItems="center"
|
||||||
>
|
{...formControlProps}
|
||||||
<FormLabel my={1} flexGrow={1} {...formLabelProps}>
|
>
|
||||||
{label}
|
{label && (
|
||||||
</FormLabel>
|
<FormLabel my={1} flexGrow={1} {...formLabelProps}>
|
||||||
<Switch {...rest} />
|
{label}
|
||||||
</FormControl>
|
</FormLabel>
|
||||||
|
)}
|
||||||
|
<Switch {...rest} />
|
||||||
|
</FormControl>
|
||||||
|
</Tooltip>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import { Badge, Flex } from '@chakra-ui/react';
|
import { Badge, Flex } from '@chakra-ui/react';
|
||||||
import { isNumber, isString } from 'lodash-es';
|
import { isString } from 'lodash-es';
|
||||||
import { useMemo } from 'react';
|
import { useMemo } from 'react';
|
||||||
import { ImageDTO } from 'services/api';
|
import { ImageDTO } from 'services/api';
|
||||||
|
|
||||||
@ -8,14 +8,6 @@ type ImageMetadataOverlayProps = {
|
|||||||
};
|
};
|
||||||
|
|
||||||
const ImageMetadataOverlay = ({ image }: ImageMetadataOverlayProps) => {
|
const ImageMetadataOverlay = ({ image }: ImageMetadataOverlayProps) => {
|
||||||
const dimensions = useMemo(() => {
|
|
||||||
if (!isNumber(image.metadata?.width) || isNumber(!image.metadata?.height)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
return `${image.metadata?.width} × ${image.metadata?.height}`;
|
|
||||||
}, [image.metadata]);
|
|
||||||
|
|
||||||
const model = useMemo(() => {
|
const model = useMemo(() => {
|
||||||
if (!isString(image.metadata?.model)) {
|
if (!isString(image.metadata?.model)) {
|
||||||
return;
|
return;
|
||||||
@ -31,17 +23,15 @@ const ImageMetadataOverlay = ({ image }: ImageMetadataOverlayProps) => {
|
|||||||
flexDirection: 'column',
|
flexDirection: 'column',
|
||||||
position: 'absolute',
|
position: 'absolute',
|
||||||
top: 0,
|
top: 0,
|
||||||
right: 0,
|
insetInlineStart: 0,
|
||||||
p: 2,
|
p: 2,
|
||||||
alignItems: 'flex-end',
|
alignItems: 'flex-start',
|
||||||
gap: 2,
|
gap: 2,
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
{dimensions && (
|
<Badge variant="solid" colorScheme="base">
|
||||||
<Badge variant="solid" colorScheme="base">
|
{image.width} × {image.height}
|
||||||
{dimensions}
|
</Badge>
|
||||||
</Badge>
|
|
||||||
)}
|
|
||||||
{model && (
|
{model && (
|
||||||
<Badge variant="solid" colorScheme="base">
|
<Badge variant="solid" colorScheme="base">
|
||||||
{model}
|
{model}
|
||||||
|
@ -1,42 +0,0 @@
|
|||||||
import { ButtonGroup, Flex, Spacer, Text } from '@chakra-ui/react';
|
|
||||||
import IAIIconButton from 'common/components/IAIIconButton';
|
|
||||||
|
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
import { FaUndo, FaUpload } from 'react-icons/fa';
|
|
||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
|
||||||
import { useCallback } from 'react';
|
|
||||||
import { clearInitialImage } from 'features/parameters/store/generationSlice';
|
|
||||||
import useImageUploader from 'common/hooks/useImageUploader';
|
|
||||||
|
|
||||||
const InitialImageButtons = () => {
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
const { t } = useTranslation();
|
|
||||||
const { openUploader } = useImageUploader();
|
|
||||||
|
|
||||||
const handleResetInitialImage = useCallback(() => {
|
|
||||||
dispatch(clearInitialImage());
|
|
||||||
}, [dispatch]);
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Flex w="full" alignItems="center">
|
|
||||||
<Text size="sm" fontWeight={500} color="base.300">
|
|
||||||
{t('parameters.initialImage')}
|
|
||||||
</Text>
|
|
||||||
<Spacer />
|
|
||||||
<ButtonGroup>
|
|
||||||
<IAIIconButton
|
|
||||||
icon={<FaUndo />}
|
|
||||||
aria-label={t('accessibility.reset')}
|
|
||||||
onClick={handleResetInitialImage}
|
|
||||||
/>
|
|
||||||
<IAIIconButton
|
|
||||||
icon={<FaUpload />}
|
|
||||||
onClick={openUploader}
|
|
||||||
aria-label={t('common.upload')}
|
|
||||||
/>
|
|
||||||
</ButtonGroup>
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default InitialImageButtons;
|
|
@ -1,12 +1,12 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import { validateSeedWeights } from 'common/util/seedWeightPairs';
|
import { validateSeedWeights } from 'common/util/seedWeightPairs';
|
||||||
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
||||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
import { systemSelector } from 'features/system/store/systemSelectors';
|
||||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||||
import { isEqual } from 'lodash-es';
|
|
||||||
|
|
||||||
export const readinessSelector = createSelector(
|
const readinessSelector = createSelector(
|
||||||
[generationSelector, systemSelector, activeTabNameSelector],
|
[generationSelector, systemSelector, activeTabNameSelector],
|
||||||
(generation, system, activeTabName) => {
|
(generation, system, activeTabName) => {
|
||||||
const {
|
const {
|
||||||
@ -60,3 +60,8 @@ export const readinessSelector = createSelector(
|
|||||||
},
|
},
|
||||||
defaultSelectorOptions
|
defaultSelectorOptions
|
||||||
);
|
);
|
||||||
|
|
||||||
|
export const useIsReadyToInvoke = () => {
|
||||||
|
const { isReady } = useAppSelector(readinessSelector);
|
||||||
|
return isReady;
|
||||||
|
};
|
@ -1,34 +0,0 @@
|
|||||||
import { RootState } from 'app/store/store';
|
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import { useCallback } from 'react';
|
|
||||||
import { OpenAPI } from 'services/api';
|
|
||||||
|
|
||||||
export const getUrlAlt = (url: string, shouldTransformUrls: boolean) => {
|
|
||||||
if (OpenAPI.BASE && shouldTransformUrls) {
|
|
||||||
return [OpenAPI.BASE, url].join('/');
|
|
||||||
}
|
|
||||||
|
|
||||||
return url;
|
|
||||||
};
|
|
||||||
|
|
||||||
export const useGetUrl = () => {
|
|
||||||
const shouldTransformUrls = useAppSelector(
|
|
||||||
(state: RootState) => state.config.shouldTransformUrls
|
|
||||||
);
|
|
||||||
|
|
||||||
const getUrl = useCallback(
|
|
||||||
(url?: string) => {
|
|
||||||
if (OpenAPI.BASE && shouldTransformUrls) {
|
|
||||||
return [OpenAPI.BASE, url].join('/');
|
|
||||||
}
|
|
||||||
|
|
||||||
return url;
|
|
||||||
},
|
|
||||||
[shouldTransformUrls]
|
|
||||||
);
|
|
||||||
|
|
||||||
return {
|
|
||||||
shouldTransformUrls,
|
|
||||||
getUrl,
|
|
||||||
};
|
|
||||||
};
|
|
@ -1,6 +1,5 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { useGetUrl } from 'common/util/getUrl';
|
|
||||||
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
|
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
|
||||||
import { rgbaColorToString } from 'features/canvas/util/colorToString';
|
import { rgbaColorToString } from 'features/canvas/util/colorToString';
|
||||||
import { isEqual } from 'lodash-es';
|
import { isEqual } from 'lodash-es';
|
||||||
@ -33,7 +32,6 @@ const selector = createSelector(
|
|||||||
|
|
||||||
const IAICanvasObjectRenderer = () => {
|
const IAICanvasObjectRenderer = () => {
|
||||||
const { objects } = useAppSelector(selector);
|
const { objects } = useAppSelector(selector);
|
||||||
const { getUrl } = useGetUrl();
|
|
||||||
|
|
||||||
if (!objects) return null;
|
if (!objects) return null;
|
||||||
|
|
||||||
@ -46,7 +44,7 @@ const IAICanvasObjectRenderer = () => {
|
|||||||
key={i}
|
key={i}
|
||||||
x={obj.x}
|
x={obj.x}
|
||||||
y={obj.y}
|
y={obj.y}
|
||||||
url={getUrl(obj.image.image_url)}
|
url={obj.image.image_url}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
} else if (isCanvasBaseLine(obj)) {
|
} else if (isCanvasBaseLine(obj)) {
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { useGetUrl } from 'common/util/getUrl';
|
|
||||||
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
|
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
|
||||||
import { GroupConfig } from 'konva/lib/Group';
|
import { GroupConfig } from 'konva/lib/Group';
|
||||||
import { isEqual } from 'lodash-es';
|
import { isEqual } from 'lodash-es';
|
||||||
@ -56,13 +55,12 @@ const IAICanvasStagingArea = (props: Props) => {
|
|||||||
width,
|
width,
|
||||||
height,
|
height,
|
||||||
} = useAppSelector(selector);
|
} = useAppSelector(selector);
|
||||||
const { getUrl } = useGetUrl();
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Group {...rest}>
|
<Group {...rest}>
|
||||||
{shouldShowStagingImage && currentStagingAreaImage && (
|
{shouldShowStagingImage && currentStagingAreaImage && (
|
||||||
<IAICanvasImage
|
<IAICanvasImage
|
||||||
url={getUrl(currentStagingAreaImage.image.image_url) ?? ''}
|
url={currentStagingAreaImage.image.image_url}
|
||||||
x={x}
|
x={x}
|
||||||
y={y}
|
y={y}
|
||||||
/>
|
/>
|
||||||
|
@ -2,7 +2,7 @@ import { ButtonGroup, Flex } from '@chakra-ui/react';
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAIButton from 'common/components/IAIButton';
|
import IAIButton from 'common/components/IAIButton';
|
||||||
import IAICheckbox from 'common/components/IAICheckbox';
|
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
|
||||||
import IAIColorPicker from 'common/components/IAIColorPicker';
|
import IAIColorPicker from 'common/components/IAIColorPicker';
|
||||||
import IAIIconButton from 'common/components/IAIIconButton';
|
import IAIIconButton from 'common/components/IAIIconButton';
|
||||||
import IAIPopover from 'common/components/IAIPopover';
|
import IAIPopover from 'common/components/IAIPopover';
|
||||||
@ -117,12 +117,12 @@ const IAICanvasMaskOptions = () => {
|
|||||||
}
|
}
|
||||||
>
|
>
|
||||||
<Flex direction="column" gap={2}>
|
<Flex direction="column" gap={2}>
|
||||||
<IAICheckbox
|
<IAISimpleCheckbox
|
||||||
label={`${t('unifiedCanvas.enableMask')} (H)`}
|
label={`${t('unifiedCanvas.enableMask')} (H)`}
|
||||||
isChecked={isMaskEnabled}
|
isChecked={isMaskEnabled}
|
||||||
onChange={handleToggleEnableMask}
|
onChange={handleToggleEnableMask}
|
||||||
/>
|
/>
|
||||||
<IAICheckbox
|
<IAISimpleCheckbox
|
||||||
label={t('unifiedCanvas.preserveMaskedArea')}
|
label={t('unifiedCanvas.preserveMaskedArea')}
|
||||||
isChecked={shouldPreserveMaskedArea}
|
isChecked={shouldPreserveMaskedArea}
|
||||||
onChange={(e) =>
|
onChange={(e) =>
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import { Flex } from '@chakra-ui/react';
|
import { Flex } from '@chakra-ui/react';
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAICheckbox from 'common/components/IAICheckbox';
|
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
|
||||||
import IAIIconButton from 'common/components/IAIIconButton';
|
import IAIIconButton from 'common/components/IAIIconButton';
|
||||||
import IAIPopover from 'common/components/IAIPopover';
|
import IAIPopover from 'common/components/IAIPopover';
|
||||||
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
|
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
|
||||||
@ -102,50 +102,50 @@ const IAICanvasSettingsButtonPopover = () => {
|
|||||||
}
|
}
|
||||||
>
|
>
|
||||||
<Flex direction="column" gap={2}>
|
<Flex direction="column" gap={2}>
|
||||||
<IAICheckbox
|
<IAISimpleCheckbox
|
||||||
label={t('unifiedCanvas.showIntermediates')}
|
label={t('unifiedCanvas.showIntermediates')}
|
||||||
isChecked={shouldShowIntermediates}
|
isChecked={shouldShowIntermediates}
|
||||||
onChange={(e) =>
|
onChange={(e) =>
|
||||||
dispatch(setShouldShowIntermediates(e.target.checked))
|
dispatch(setShouldShowIntermediates(e.target.checked))
|
||||||
}
|
}
|
||||||
/>
|
/>
|
||||||
<IAICheckbox
|
<IAISimpleCheckbox
|
||||||
label={t('unifiedCanvas.showGrid')}
|
label={t('unifiedCanvas.showGrid')}
|
||||||
isChecked={shouldShowGrid}
|
isChecked={shouldShowGrid}
|
||||||
onChange={(e) => dispatch(setShouldShowGrid(e.target.checked))}
|
onChange={(e) => dispatch(setShouldShowGrid(e.target.checked))}
|
||||||
/>
|
/>
|
||||||
<IAICheckbox
|
<IAISimpleCheckbox
|
||||||
label={t('unifiedCanvas.snapToGrid')}
|
label={t('unifiedCanvas.snapToGrid')}
|
||||||
isChecked={shouldSnapToGrid}
|
isChecked={shouldSnapToGrid}
|
||||||
onChange={handleChangeShouldSnapToGrid}
|
onChange={handleChangeShouldSnapToGrid}
|
||||||
/>
|
/>
|
||||||
<IAICheckbox
|
<IAISimpleCheckbox
|
||||||
label={t('unifiedCanvas.darkenOutsideSelection')}
|
label={t('unifiedCanvas.darkenOutsideSelection')}
|
||||||
isChecked={shouldDarkenOutsideBoundingBox}
|
isChecked={shouldDarkenOutsideBoundingBox}
|
||||||
onChange={(e) =>
|
onChange={(e) =>
|
||||||
dispatch(setShouldDarkenOutsideBoundingBox(e.target.checked))
|
dispatch(setShouldDarkenOutsideBoundingBox(e.target.checked))
|
||||||
}
|
}
|
||||||
/>
|
/>
|
||||||
<IAICheckbox
|
<IAISimpleCheckbox
|
||||||
label={t('unifiedCanvas.autoSaveToGallery')}
|
label={t('unifiedCanvas.autoSaveToGallery')}
|
||||||
isChecked={shouldAutoSave}
|
isChecked={shouldAutoSave}
|
||||||
onChange={(e) => dispatch(setShouldAutoSave(e.target.checked))}
|
onChange={(e) => dispatch(setShouldAutoSave(e.target.checked))}
|
||||||
/>
|
/>
|
||||||
<IAICheckbox
|
<IAISimpleCheckbox
|
||||||
label={t('unifiedCanvas.saveBoxRegionOnly')}
|
label={t('unifiedCanvas.saveBoxRegionOnly')}
|
||||||
isChecked={shouldCropToBoundingBoxOnSave}
|
isChecked={shouldCropToBoundingBoxOnSave}
|
||||||
onChange={(e) =>
|
onChange={(e) =>
|
||||||
dispatch(setShouldCropToBoundingBoxOnSave(e.target.checked))
|
dispatch(setShouldCropToBoundingBoxOnSave(e.target.checked))
|
||||||
}
|
}
|
||||||
/>
|
/>
|
||||||
<IAICheckbox
|
<IAISimpleCheckbox
|
||||||
label={t('unifiedCanvas.limitStrokesToBox')}
|
label={t('unifiedCanvas.limitStrokesToBox')}
|
||||||
isChecked={shouldRestrictStrokesToBox}
|
isChecked={shouldRestrictStrokesToBox}
|
||||||
onChange={(e) =>
|
onChange={(e) =>
|
||||||
dispatch(setShouldRestrictStrokesToBox(e.target.checked))
|
dispatch(setShouldRestrictStrokesToBox(e.target.checked))
|
||||||
}
|
}
|
||||||
/>
|
/>
|
||||||
<IAICheckbox
|
<IAISimpleCheckbox
|
||||||
label={t('unifiedCanvas.showCanvasDebugInfo')}
|
label={t('unifiedCanvas.showCanvasDebugInfo')}
|
||||||
isChecked={shouldShowCanvasDebugInfo}
|
isChecked={shouldShowCanvasDebugInfo}
|
||||||
onChange={(e) =>
|
onChange={(e) =>
|
||||||
@ -153,7 +153,7 @@ const IAICanvasSettingsButtonPopover = () => {
|
|||||||
}
|
}
|
||||||
/>
|
/>
|
||||||
|
|
||||||
<IAICheckbox
|
<IAISimpleCheckbox
|
||||||
label={t('unifiedCanvas.antialiasing')}
|
label={t('unifiedCanvas.antialiasing')}
|
||||||
isChecked={shouldAntialias}
|
isChecked={shouldAntialias}
|
||||||
onChange={(e) => dispatch(setShouldAntialias(e.target.checked))}
|
onChange={(e) => dispatch(setShouldAntialias(e.target.checked))}
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
import { ButtonGroup, Flex } from '@chakra-ui/react';
|
import { Box, ButtonGroup, Flex } from '@chakra-ui/react';
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAIIconButton from 'common/components/IAIIconButton';
|
import IAIIconButton from 'common/components/IAIIconButton';
|
||||||
@ -210,16 +210,19 @@ const IAICanvasToolbar = () => {
|
|||||||
sx={{
|
sx={{
|
||||||
alignItems: 'center',
|
alignItems: 'center',
|
||||||
gap: 2,
|
gap: 2,
|
||||||
|
flexWrap: 'wrap',
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<IAISelect
|
<Box w={24}>
|
||||||
tooltip={`${t('unifiedCanvas.layer')} (Q)`}
|
<IAISelect
|
||||||
tooltipProps={{ hasArrow: true, placement: 'top' }}
|
tooltip={`${t('unifiedCanvas.layer')} (Q)`}
|
||||||
value={layer}
|
tooltipProps={{ hasArrow: true, placement: 'top' }}
|
||||||
validValues={LAYER_NAMES_DICT}
|
value={layer}
|
||||||
onChange={handleChangeLayer}
|
validValues={LAYER_NAMES_DICT}
|
||||||
isDisabled={isStaging}
|
onChange={handleChangeLayer}
|
||||||
/>
|
isDisabled={isStaging}
|
||||||
|
/>
|
||||||
|
</Box>
|
||||||
|
|
||||||
<IAICanvasMaskOptions />
|
<IAICanvasMaskOptions />
|
||||||
<IAICanvasToolChooserOptions />
|
<IAICanvasToolChooserOptions />
|
||||||
|
@ -30,6 +30,8 @@ import {
|
|||||||
} from './canvasTypes';
|
} from './canvasTypes';
|
||||||
import { ImageDTO } from 'services/api';
|
import { ImageDTO } from 'services/api';
|
||||||
import { sessionCanceled } from 'services/thunks/session';
|
import { sessionCanceled } from 'services/thunks/session';
|
||||||
|
import { setShouldUseCanvasBetaLayout } from 'features/ui/store/uiSlice';
|
||||||
|
import { imageUrlsReceived } from 'services/thunks/image';
|
||||||
|
|
||||||
export const initialLayerState: CanvasLayerState = {
|
export const initialLayerState: CanvasLayerState = {
|
||||||
objects: [],
|
objects: [],
|
||||||
@ -851,6 +853,30 @@ export const canvasSlice = createSlice({
|
|||||||
state.layerState.stagingArea = initialLayerState.stagingArea;
|
state.layerState.stagingArea = initialLayerState.stagingArea;
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
builder.addCase(setShouldUseCanvasBetaLayout, (state, action) => {
|
||||||
|
state.doesCanvasNeedScaling = true;
|
||||||
|
});
|
||||||
|
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
|
||||||
|
const { image_name, image_origin, image_url, thumbnail_url } =
|
||||||
|
action.payload;
|
||||||
|
|
||||||
|
state.layerState.objects.forEach((object) => {
|
||||||
|
if (object.kind === 'image') {
|
||||||
|
if (object.image.image_name === image_name) {
|
||||||
|
object.image.image_url = image_url;
|
||||||
|
object.image.thumbnail_url = thumbnail_url;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
state.layerState.stagingArea.images.forEach((stagedImage) => {
|
||||||
|
if (stagedImage.image.image_name === image_name) {
|
||||||
|
stagedImage.image.image_url = image_url;
|
||||||
|
stagedImage.image.thumbnail_url = thumbnail_url;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -2,10 +2,10 @@ import { getCanvasBaseLayer } from './konvaInstanceProvider';
|
|||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { konvaNodeToBlob } from './konvaNodeToBlob';
|
import { konvaNodeToBlob } from './konvaNodeToBlob';
|
||||||
|
|
||||||
export const getBaseLayerBlob = async (
|
/**
|
||||||
state: RootState,
|
* Get the canvas base layer blob, with or without bounding box according to `shouldCropToBoundingBoxOnSave`
|
||||||
withoutBoundingBox?: boolean
|
*/
|
||||||
) => {
|
export const getBaseLayerBlob = async (state: RootState) => {
|
||||||
const canvasBaseLayer = getCanvasBaseLayer();
|
const canvasBaseLayer = getCanvasBaseLayer();
|
||||||
|
|
||||||
if (!canvasBaseLayer) {
|
if (!canvasBaseLayer) {
|
||||||
@ -24,15 +24,14 @@ export const getBaseLayerBlob = async (
|
|||||||
|
|
||||||
const absPos = clonedBaseLayer.getAbsolutePosition();
|
const absPos = clonedBaseLayer.getAbsolutePosition();
|
||||||
|
|
||||||
const boundingBox =
|
const boundingBox = shouldCropToBoundingBoxOnSave
|
||||||
shouldCropToBoundingBoxOnSave && !withoutBoundingBox
|
? {
|
||||||
? {
|
x: boundingBoxCoordinates.x + absPos.x,
|
||||||
x: boundingBoxCoordinates.x + absPos.x,
|
y: boundingBoxCoordinates.y + absPos.y,
|
||||||
y: boundingBoxCoordinates.y + absPos.y,
|
width: boundingBoxDimensions.width,
|
||||||
width: boundingBoxDimensions.width,
|
height: boundingBoxDimensions.height,
|
||||||
height: boundingBoxDimensions.height,
|
}
|
||||||
}
|
: clonedBaseLayer.getClientRect();
|
||||||
: clonedBaseLayer.getClientRect();
|
|
||||||
|
|
||||||
return konvaNodeToBlob(clonedBaseLayer, boundingBox);
|
return konvaNodeToBlob(clonedBaseLayer, boundingBox);
|
||||||
};
|
};
|
||||||
|
@ -0,0 +1,19 @@
|
|||||||
|
import { getCanvasBaseLayer } from './konvaInstanceProvider';
|
||||||
|
import { konvaNodeToBlob } from './konvaNodeToBlob';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Gets the canvas base layer blob, without bounding box
|
||||||
|
*/
|
||||||
|
export const getFullBaseLayerBlob = async () => {
|
||||||
|
const canvasBaseLayer = getCanvasBaseLayer();
|
||||||
|
|
||||||
|
if (!canvasBaseLayer) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const clonedBaseLayer = canvasBaseLayer.clone();
|
||||||
|
|
||||||
|
clonedBaseLayer.scale({ x: 1, y: 1 });
|
||||||
|
|
||||||
|
return konvaNodeToBlob(clonedBaseLayer, clonedBaseLayer.getClientRect());
|
||||||
|
};
|
@ -0,0 +1,258 @@
|
|||||||
|
import { memo, useCallback } from 'react';
|
||||||
|
import {
|
||||||
|
ControlNetConfig,
|
||||||
|
controlNetAdded,
|
||||||
|
controlNetRemoved,
|
||||||
|
controlNetToggled,
|
||||||
|
} from '../store/controlNetSlice';
|
||||||
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
|
import ParamControlNetModel from './parameters/ParamControlNetModel';
|
||||||
|
import ParamControlNetWeight from './parameters/ParamControlNetWeight';
|
||||||
|
import {
|
||||||
|
Checkbox,
|
||||||
|
Flex,
|
||||||
|
FormControl,
|
||||||
|
FormLabel,
|
||||||
|
HStack,
|
||||||
|
TabList,
|
||||||
|
TabPanels,
|
||||||
|
Tabs,
|
||||||
|
Tab,
|
||||||
|
TabPanel,
|
||||||
|
Box,
|
||||||
|
} from '@chakra-ui/react';
|
||||||
|
import { FaCopy, FaPlus, FaTrash, FaWrench } from 'react-icons/fa';
|
||||||
|
|
||||||
|
import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd';
|
||||||
|
import ControlNetImagePreview from './ControlNetImagePreview';
|
||||||
|
import IAIIconButton from 'common/components/IAIIconButton';
|
||||||
|
import { v4 as uuidv4 } from 'uuid';
|
||||||
|
import { useToggle } from 'react-use';
|
||||||
|
import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect';
|
||||||
|
import ControlNetProcessorComponent from './ControlNetProcessorComponent';
|
||||||
|
import ControlNetPreprocessButton from './ControlNetPreprocessButton';
|
||||||
|
import IAIButton from 'common/components/IAIButton';
|
||||||
|
import IAISwitch from 'common/components/IAISwitch';
|
||||||
|
import { ChevronDownIcon, ChevronUpIcon } from '@chakra-ui/icons';
|
||||||
|
|
||||||
|
type ControlNetProps = {
|
||||||
|
controlNet: ControlNetConfig;
|
||||||
|
};
|
||||||
|
|
||||||
|
const ControlNet = (props: ControlNetProps) => {
|
||||||
|
const {
|
||||||
|
controlNetId,
|
||||||
|
isEnabled,
|
||||||
|
model,
|
||||||
|
weight,
|
||||||
|
beginStepPct,
|
||||||
|
endStepPct,
|
||||||
|
controlImage,
|
||||||
|
processedControlImage,
|
||||||
|
processorNode,
|
||||||
|
processorType,
|
||||||
|
} = props.controlNet;
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const [shouldShowAdvanced, onToggleAdvanced] = useToggle(false);
|
||||||
|
|
||||||
|
const handleDelete = useCallback(() => {
|
||||||
|
dispatch(controlNetRemoved({ controlNetId }));
|
||||||
|
}, [controlNetId, dispatch]);
|
||||||
|
|
||||||
|
const handleDuplicate = useCallback(() => {
|
||||||
|
dispatch(
|
||||||
|
controlNetAdded({ controlNetId: uuidv4(), controlNet: props.controlNet })
|
||||||
|
);
|
||||||
|
}, [dispatch, props.controlNet]);
|
||||||
|
|
||||||
|
const handleToggleIsEnabled = useCallback(() => {
|
||||||
|
dispatch(controlNetToggled({ controlNetId }));
|
||||||
|
}, [controlNetId, dispatch]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex
|
||||||
|
sx={{
|
||||||
|
flexDir: 'column',
|
||||||
|
gap: 2,
|
||||||
|
p: 3,
|
||||||
|
bg: 'base.850',
|
||||||
|
borderRadius: 'base',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<Flex sx={{ gap: 2 }}>
|
||||||
|
<IAISwitch
|
||||||
|
tooltip="Toggle"
|
||||||
|
aria-label="Toggle"
|
||||||
|
isChecked={isEnabled}
|
||||||
|
onChange={handleToggleIsEnabled}
|
||||||
|
/>
|
||||||
|
<Box
|
||||||
|
sx={{
|
||||||
|
w: 'full',
|
||||||
|
minW: 0,
|
||||||
|
opacity: isEnabled ? 1 : 0.5,
|
||||||
|
pointerEvents: isEnabled ? 'auto' : 'none',
|
||||||
|
transitionProperty: 'common',
|
||||||
|
transitionDuration: '0.1s',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<ParamControlNetModel controlNetId={controlNetId} model={model} />
|
||||||
|
</Box>
|
||||||
|
<IAIIconButton
|
||||||
|
size="sm"
|
||||||
|
tooltip="Duplicate"
|
||||||
|
aria-label="Duplicate"
|
||||||
|
onClick={handleDuplicate}
|
||||||
|
icon={<FaCopy />}
|
||||||
|
/>
|
||||||
|
<IAIIconButton
|
||||||
|
size="sm"
|
||||||
|
tooltip="Delete"
|
||||||
|
aria-label="Delete"
|
||||||
|
colorScheme="error"
|
||||||
|
onClick={handleDelete}
|
||||||
|
icon={<FaTrash />}
|
||||||
|
/>
|
||||||
|
<IAIIconButton
|
||||||
|
size="sm"
|
||||||
|
aria-label="Expand"
|
||||||
|
onClick={onToggleAdvanced}
|
||||||
|
variant="link"
|
||||||
|
icon={
|
||||||
|
<ChevronUpIcon
|
||||||
|
sx={{
|
||||||
|
boxSize: 4,
|
||||||
|
color: 'base.300',
|
||||||
|
transform: shouldShowAdvanced
|
||||||
|
? 'rotate(0deg)'
|
||||||
|
: 'rotate(180deg)',
|
||||||
|
transitionProperty: 'common',
|
||||||
|
transitionDuration: 'normal',
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
</Flex>
|
||||||
|
{isEnabled && (
|
||||||
|
<>
|
||||||
|
<Flex sx={{ gap: 4 }}>
|
||||||
|
<Flex
|
||||||
|
sx={{
|
||||||
|
flexDir: 'column',
|
||||||
|
gap: 2,
|
||||||
|
w: 'full',
|
||||||
|
h: 24,
|
||||||
|
paddingInlineStart: 1,
|
||||||
|
paddingInlineEnd: shouldShowAdvanced ? 1 : 0,
|
||||||
|
pb: 2,
|
||||||
|
justifyContent: 'space-between',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<ParamControlNetWeight
|
||||||
|
controlNetId={controlNetId}
|
||||||
|
weight={weight}
|
||||||
|
mini
|
||||||
|
/>
|
||||||
|
<ParamControlNetBeginEnd
|
||||||
|
controlNetId={controlNetId}
|
||||||
|
beginStepPct={beginStepPct}
|
||||||
|
endStepPct={endStepPct}
|
||||||
|
mini
|
||||||
|
/>
|
||||||
|
</Flex>
|
||||||
|
{!shouldShowAdvanced && (
|
||||||
|
<Flex
|
||||||
|
sx={{
|
||||||
|
alignItems: 'center',
|
||||||
|
justifyContent: 'center',
|
||||||
|
h: 24,
|
||||||
|
w: 24,
|
||||||
|
aspectRatio: '1/1',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<ControlNetImagePreview controlNet={props.controlNet} />
|
||||||
|
</Flex>
|
||||||
|
)}
|
||||||
|
</Flex>
|
||||||
|
{shouldShowAdvanced && (
|
||||||
|
<>
|
||||||
|
<Box pt={2}>
|
||||||
|
<ControlNetImagePreview controlNet={props.controlNet} />
|
||||||
|
</Box>
|
||||||
|
<ParamControlNetProcessorSelect
|
||||||
|
controlNetId={controlNetId}
|
||||||
|
processorNode={processorNode}
|
||||||
|
/>
|
||||||
|
<ControlNetProcessorComponent
|
||||||
|
controlNetId={controlNetId}
|
||||||
|
processorNode={processorNode}
|
||||||
|
/>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex sx={{ flexDir: 'column', gap: 3 }}>
|
||||||
|
<ControlNetImagePreview controlNet={props.controlNet} />
|
||||||
|
<ParamControlNetModel controlNetId={controlNetId} model={model} />
|
||||||
|
<Tabs
|
||||||
|
isFitted
|
||||||
|
orientation="horizontal"
|
||||||
|
variant="enclosed"
|
||||||
|
size="sm"
|
||||||
|
colorScheme="accent"
|
||||||
|
>
|
||||||
|
<TabList>
|
||||||
|
<Tab
|
||||||
|
sx={{ 'button&': { _selected: { borderBottomColor: 'base.800' } } }}
|
||||||
|
>
|
||||||
|
Model Config
|
||||||
|
</Tab>
|
||||||
|
<Tab
|
||||||
|
sx={{ 'button&': { _selected: { borderBottomColor: 'base.800' } } }}
|
||||||
|
>
|
||||||
|
Preprocess
|
||||||
|
</Tab>
|
||||||
|
</TabList>
|
||||||
|
<TabPanels sx={{ pt: 2 }}>
|
||||||
|
<TabPanel sx={{ p: 0 }}>
|
||||||
|
<ParamControlNetWeight
|
||||||
|
controlNetId={controlNetId}
|
||||||
|
weight={weight}
|
||||||
|
/>
|
||||||
|
<ParamControlNetBeginEnd
|
||||||
|
controlNetId={controlNetId}
|
||||||
|
beginStepPct={beginStepPct}
|
||||||
|
endStepPct={endStepPct}
|
||||||
|
/>
|
||||||
|
</TabPanel>
|
||||||
|
<TabPanel sx={{ p: 0 }}>
|
||||||
|
<ParamControlNetProcessorSelect
|
||||||
|
controlNetId={controlNetId}
|
||||||
|
processorNode={processorNode}
|
||||||
|
/>
|
||||||
|
<ControlNetProcessorComponent
|
||||||
|
controlNetId={controlNetId}
|
||||||
|
processorNode={processorNode}
|
||||||
|
/>
|
||||||
|
<ControlNetPreprocessButton controlNet={props.controlNet} />
|
||||||
|
{/* <IAIButton
|
||||||
|
size="sm"
|
||||||
|
leftIcon={<FaUndo />}
|
||||||
|
onClick={handleReset}
|
||||||
|
isDisabled={Boolean(!processedControlImage)}
|
||||||
|
>
|
||||||
|
Reset Processing
|
||||||
|
</IAIButton> */}
|
||||||
|
</TabPanel>
|
||||||
|
</TabPanels>
|
||||||
|
</Tabs>
|
||||||
|
<IAIButton onClick={handleDelete}>Remove ControlNet</IAIButton>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(ControlNet);
|
@ -0,0 +1,144 @@
|
|||||||
|
import { memo, useCallback, useRef, useState } from 'react';
|
||||||
|
import { ImageDTO } from 'services/api';
|
||||||
|
import {
|
||||||
|
ControlNetConfig,
|
||||||
|
controlNetImageChanged,
|
||||||
|
controlNetSelector,
|
||||||
|
} from '../store/controlNetSlice';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { Box } from '@chakra-ui/react';
|
||||||
|
import IAIDndImage from 'common/components/IAIDndImage';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
|
import { AnimatePresence, motion } from 'framer-motion';
|
||||||
|
import { IAIImageFallback } from 'common/components/IAIImageFallback';
|
||||||
|
import { useHoverDirty } from 'react-use';
|
||||||
|
|
||||||
|
const selector = createSelector(
|
||||||
|
controlNetSelector,
|
||||||
|
(controlNet) => {
|
||||||
|
const { isProcessingControlImage } = controlNet;
|
||||||
|
return { isProcessingControlImage };
|
||||||
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
);
|
||||||
|
|
||||||
|
type Props = {
|
||||||
|
controlNet: ControlNetConfig;
|
||||||
|
};
|
||||||
|
|
||||||
|
const ControlNetImagePreview = (props: Props) => {
|
||||||
|
const { controlNetId, controlImage, processedControlImage, processorType } =
|
||||||
|
props.controlNet;
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const { isProcessingControlImage } = useAppSelector(selector);
|
||||||
|
const containerRef = useRef<HTMLDivElement>(null);
|
||||||
|
|
||||||
|
const isMouseOverImage = useHoverDirty(containerRef);
|
||||||
|
|
||||||
|
const handleDrop = useCallback(
|
||||||
|
(droppedImage: ImageDTO) => {
|
||||||
|
if (controlImage?.image_name === droppedImage.image_name) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
dispatch(
|
||||||
|
controlNetImageChanged({ controlNetId, controlImage: droppedImage })
|
||||||
|
);
|
||||||
|
},
|
||||||
|
[controlImage, controlNetId, dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
const shouldShowProcessedImageBackdrop =
|
||||||
|
Number(controlImage?.width) > Number(processedControlImage?.width) ||
|
||||||
|
Number(controlImage?.height) > Number(processedControlImage?.height);
|
||||||
|
|
||||||
|
const shouldShowProcessedImage =
|
||||||
|
controlImage &&
|
||||||
|
processedControlImage &&
|
||||||
|
!isMouseOverImage &&
|
||||||
|
!isProcessingControlImage &&
|
||||||
|
processorType !== 'none';
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Box
|
||||||
|
ref={containerRef}
|
||||||
|
sx={{ position: 'relative', w: 'full', h: 'full', aspectRatio: '1/1' }}
|
||||||
|
>
|
||||||
|
<IAIDndImage
|
||||||
|
image={controlImage}
|
||||||
|
onDrop={handleDrop}
|
||||||
|
isDropDisabled={Boolean(
|
||||||
|
processedControlImage && processorType !== 'none'
|
||||||
|
)}
|
||||||
|
/>
|
||||||
|
<AnimatePresence>
|
||||||
|
{shouldShowProcessedImage && (
|
||||||
|
<motion.div
|
||||||
|
initial={{
|
||||||
|
opacity: 0,
|
||||||
|
}}
|
||||||
|
animate={{
|
||||||
|
opacity: 1,
|
||||||
|
transition: { duration: 0.1 },
|
||||||
|
}}
|
||||||
|
exit={{
|
||||||
|
opacity: 0,
|
||||||
|
transition: { duration: 0.1 },
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<Box
|
||||||
|
sx={{
|
||||||
|
position: 'absolute',
|
||||||
|
w: 'full',
|
||||||
|
h: 'full',
|
||||||
|
top: 0,
|
||||||
|
insetInlineStart: 0,
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{shouldShowProcessedImageBackdrop && (
|
||||||
|
<Box
|
||||||
|
sx={{
|
||||||
|
w: 'full',
|
||||||
|
h: 'full',
|
||||||
|
bg: 'base.900',
|
||||||
|
opacity: 0.7,
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
<Box
|
||||||
|
sx={{
|
||||||
|
position: 'absolute',
|
||||||
|
top: 0,
|
||||||
|
insetInlineStart: 0,
|
||||||
|
w: 'full',
|
||||||
|
h: 'full',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<IAIDndImage
|
||||||
|
image={processedControlImage}
|
||||||
|
onDrop={handleDrop}
|
||||||
|
payloadImage={controlImage}
|
||||||
|
/>
|
||||||
|
</Box>
|
||||||
|
</Box>
|
||||||
|
</motion.div>
|
||||||
|
)}
|
||||||
|
</AnimatePresence>
|
||||||
|
{isProcessingControlImage && (
|
||||||
|
<Box
|
||||||
|
sx={{
|
||||||
|
position: 'absolute',
|
||||||
|
top: 0,
|
||||||
|
insetInlineStart: 0,
|
||||||
|
w: 'full',
|
||||||
|
h: 'full',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<IAIImageFallback />
|
||||||
|
</Box>
|
||||||
|
)}
|
||||||
|
</Box>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(ControlNetImagePreview);
|
@ -0,0 +1,36 @@
|
|||||||
|
import IAIButton from 'common/components/IAIButton';
|
||||||
|
import { memo, useCallback } from 'react';
|
||||||
|
import { ControlNetConfig } from '../store/controlNetSlice';
|
||||||
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
|
import { controlNetImageProcessed } from '../store/actions';
|
||||||
|
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||||
|
|
||||||
|
type Props = {
|
||||||
|
controlNet: ControlNetConfig;
|
||||||
|
};
|
||||||
|
|
||||||
|
const ControlNetPreprocessButton = (props: Props) => {
|
||||||
|
const { controlNetId, controlImage } = props.controlNet;
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const isReady = useIsReadyToInvoke();
|
||||||
|
|
||||||
|
const handleProcess = useCallback(() => {
|
||||||
|
dispatch(
|
||||||
|
controlNetImageProcessed({
|
||||||
|
controlNetId,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
}, [controlNetId, dispatch]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<IAIButton
|
||||||
|
size="sm"
|
||||||
|
onClick={handleProcess}
|
||||||
|
isDisabled={Boolean(!controlImage) || !isReady}
|
||||||
|
>
|
||||||
|
Preprocess
|
||||||
|
</IAIButton>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(ControlNetPreprocessButton);
|
@ -0,0 +1,131 @@
|
|||||||
|
import { memo } from 'react';
|
||||||
|
import { RequiredControlNetProcessorNode } from '../store/types';
|
||||||
|
import CannyProcessor from './processors/CannyProcessor';
|
||||||
|
import HedProcessor from './processors/HedProcessor';
|
||||||
|
import LineartProcessor from './processors/LineartProcessor';
|
||||||
|
import LineartAnimeProcessor from './processors/LineartAnimeProcessor';
|
||||||
|
import ContentShuffleProcessor from './processors/ContentShuffleProcessor';
|
||||||
|
import MediapipeFaceProcessor from './processors/MediapipeFaceProcessor';
|
||||||
|
import MidasDepthProcessor from './processors/MidasDepthProcessor';
|
||||||
|
import MlsdImageProcessor from './processors/MlsdImageProcessor';
|
||||||
|
import NormalBaeProcessor from './processors/NormalBaeProcessor';
|
||||||
|
import OpenposeProcessor from './processors/OpenposeProcessor';
|
||||||
|
import PidiProcessor from './processors/PidiProcessor';
|
||||||
|
import ZoeDepthProcessor from './processors/ZoeDepthProcessor';
|
||||||
|
|
||||||
|
export type ControlNetProcessorProps = {
|
||||||
|
controlNetId: string;
|
||||||
|
processorNode: RequiredControlNetProcessorNode;
|
||||||
|
};
|
||||||
|
|
||||||
|
const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
|
||||||
|
const { controlNetId, processorNode } = props;
|
||||||
|
if (processorNode.type === 'canny_image_processor') {
|
||||||
|
return (
|
||||||
|
<CannyProcessor
|
||||||
|
controlNetId={controlNetId}
|
||||||
|
processorNode={processorNode}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (processorNode.type === 'hed_image_processor') {
|
||||||
|
return (
|
||||||
|
<HedProcessor controlNetId={controlNetId} processorNode={processorNode} />
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (processorNode.type === 'lineart_image_processor') {
|
||||||
|
return (
|
||||||
|
<LineartProcessor
|
||||||
|
controlNetId={controlNetId}
|
||||||
|
processorNode={processorNode}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (processorNode.type === 'content_shuffle_image_processor') {
|
||||||
|
return (
|
||||||
|
<ContentShuffleProcessor
|
||||||
|
controlNetId={controlNetId}
|
||||||
|
processorNode={processorNode}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (processorNode.type === 'lineart_anime_image_processor') {
|
||||||
|
return (
|
||||||
|
<LineartAnimeProcessor
|
||||||
|
controlNetId={controlNetId}
|
||||||
|
processorNode={processorNode}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (processorNode.type === 'mediapipe_face_processor') {
|
||||||
|
return (
|
||||||
|
<MediapipeFaceProcessor
|
||||||
|
controlNetId={controlNetId}
|
||||||
|
processorNode={processorNode}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (processorNode.type === 'midas_depth_image_processor') {
|
||||||
|
return (
|
||||||
|
<MidasDepthProcessor
|
||||||
|
controlNetId={controlNetId}
|
||||||
|
processorNode={processorNode}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (processorNode.type === 'mlsd_image_processor') {
|
||||||
|
return (
|
||||||
|
<MlsdImageProcessor
|
||||||
|
controlNetId={controlNetId}
|
||||||
|
processorNode={processorNode}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (processorNode.type === 'normalbae_image_processor') {
|
||||||
|
return (
|
||||||
|
<NormalBaeProcessor
|
||||||
|
controlNetId={controlNetId}
|
||||||
|
processorNode={processorNode}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (processorNode.type === 'openpose_image_processor') {
|
||||||
|
return (
|
||||||
|
<OpenposeProcessor
|
||||||
|
controlNetId={controlNetId}
|
||||||
|
processorNode={processorNode}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (processorNode.type === 'pidi_image_processor') {
|
||||||
|
return (
|
||||||
|
<PidiProcessor
|
||||||
|
controlNetId={controlNetId}
|
||||||
|
processorNode={processorNode}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (processorNode.type === 'zoe_depth_image_processor') {
|
||||||
|
return (
|
||||||
|
<ZoeDepthProcessor
|
||||||
|
controlNetId={controlNetId}
|
||||||
|
processorNode={processorNode}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return null;
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(ControlNetProcessorComponent);
|
@ -0,0 +1,20 @@
|
|||||||
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
|
import { controlNetProcessorParamsChanged } from 'features/controlNet/store/controlNetSlice';
|
||||||
|
import { ControlNetProcessorNode } from 'features/controlNet/store/types';
|
||||||
|
import { useCallback } from 'react';
|
||||||
|
|
||||||
|
export const useProcessorNodeChanged = () => {
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const handleProcessorNodeChanged = useCallback(
|
||||||
|
(controlNetId: string, changes: Partial<ControlNetProcessorNode>) => {
|
||||||
|
dispatch(
|
||||||
|
controlNetProcessorParamsChanged({
|
||||||
|
controlNetId,
|
||||||
|
changes,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
},
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
return handleProcessorNodeChanged;
|
||||||
|
};
|
@ -0,0 +1,130 @@
|
|||||||
|
import {
|
||||||
|
FormControl,
|
||||||
|
FormLabel,
|
||||||
|
HStack,
|
||||||
|
RangeSlider,
|
||||||
|
RangeSliderFilledTrack,
|
||||||
|
RangeSliderMark,
|
||||||
|
RangeSliderThumb,
|
||||||
|
RangeSliderTrack,
|
||||||
|
Tooltip,
|
||||||
|
} from '@chakra-ui/react';
|
||||||
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
|
import IAIIconButton from 'common/components/IAIIconButton';
|
||||||
|
import {
|
||||||
|
controlNetBeginStepPctChanged,
|
||||||
|
controlNetEndStepPctChanged,
|
||||||
|
} from 'features/controlNet/store/controlNetSlice';
|
||||||
|
import { memo, useCallback } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { BiReset } from 'react-icons/bi';
|
||||||
|
|
||||||
|
type Props = {
|
||||||
|
controlNetId: string;
|
||||||
|
beginStepPct: number;
|
||||||
|
endStepPct: number;
|
||||||
|
mini?: boolean;
|
||||||
|
};
|
||||||
|
|
||||||
|
const formatPct = (v: number) => `${Math.round(v * 100)}%`;
|
||||||
|
|
||||||
|
const ParamControlNetBeginEnd = (props: Props) => {
|
||||||
|
const { controlNetId, beginStepPct, endStepPct, mini = false } = props;
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const { t } = useTranslation();
|
||||||
|
|
||||||
|
const handleStepPctChanged = useCallback(
|
||||||
|
(v: number[]) => {
|
||||||
|
dispatch(
|
||||||
|
controlNetBeginStepPctChanged({ controlNetId, beginStepPct: v[0] })
|
||||||
|
);
|
||||||
|
dispatch(controlNetEndStepPctChanged({ controlNetId, endStepPct: v[1] }));
|
||||||
|
},
|
||||||
|
[controlNetId, dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleStepPctReset = useCallback(() => {
|
||||||
|
dispatch(controlNetBeginStepPctChanged({ controlNetId, beginStepPct: 0 }));
|
||||||
|
dispatch(controlNetEndStepPctChanged({ controlNetId, endStepPct: 1 }));
|
||||||
|
}, [controlNetId, dispatch]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<FormControl>
|
||||||
|
<FormLabel>Begin / End Step Percentage</FormLabel>
|
||||||
|
<HStack w="100%" gap={2} alignItems="center">
|
||||||
|
<RangeSlider
|
||||||
|
aria-label={['Begin Step %', 'End Step %']}
|
||||||
|
value={[beginStepPct, endStepPct]}
|
||||||
|
onChange={handleStepPctChanged}
|
||||||
|
min={0}
|
||||||
|
max={1}
|
||||||
|
step={0.01}
|
||||||
|
minStepsBetweenThumbs={5}
|
||||||
|
>
|
||||||
|
<RangeSliderTrack>
|
||||||
|
<RangeSliderFilledTrack />
|
||||||
|
</RangeSliderTrack>
|
||||||
|
<Tooltip label={formatPct(beginStepPct)} placement="top" hasArrow>
|
||||||
|
<RangeSliderThumb index={0} />
|
||||||
|
</Tooltip>
|
||||||
|
<Tooltip label={formatPct(endStepPct)} placement="top" hasArrow>
|
||||||
|
<RangeSliderThumb index={1} />
|
||||||
|
</Tooltip>
|
||||||
|
{!mini && (
|
||||||
|
<>
|
||||||
|
<RangeSliderMark
|
||||||
|
value={0}
|
||||||
|
sx={{
|
||||||
|
fontSize: 'xs',
|
||||||
|
fontWeight: '500',
|
||||||
|
color: 'base.200',
|
||||||
|
insetInlineStart: '0 !important',
|
||||||
|
insetInlineEnd: 'unset !important',
|
||||||
|
mt: 1.5,
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
0%
|
||||||
|
</RangeSliderMark>
|
||||||
|
<RangeSliderMark
|
||||||
|
value={0.5}
|
||||||
|
sx={{
|
||||||
|
fontSize: 'xs',
|
||||||
|
fontWeight: '500',
|
||||||
|
color: 'base.200',
|
||||||
|
mt: 1.5,
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
50%
|
||||||
|
</RangeSliderMark>
|
||||||
|
<RangeSliderMark
|
||||||
|
value={1}
|
||||||
|
sx={{
|
||||||
|
fontSize: 'xs',
|
||||||
|
fontWeight: '500',
|
||||||
|
color: 'base.200',
|
||||||
|
insetInlineStart: 'unset !important',
|
||||||
|
insetInlineEnd: '0 !important',
|
||||||
|
mt: 1.5,
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
100%
|
||||||
|
</RangeSliderMark>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
</RangeSlider>
|
||||||
|
|
||||||
|
{!mini && (
|
||||||
|
<IAIIconButton
|
||||||
|
size="sm"
|
||||||
|
aria-label={t('accessibility.reset')}
|
||||||
|
tooltip={t('accessibility.reset')}
|
||||||
|
icon={<BiReset />}
|
||||||
|
onClick={handleStepPctReset}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
</HStack>
|
||||||
|
</FormControl>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(ParamControlNetBeginEnd);
|
@ -0,0 +1,28 @@
|
|||||||
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
|
import IAISwitch from 'common/components/IAISwitch';
|
||||||
|
import { controlNetToggled } from 'features/controlNet/store/controlNetSlice';
|
||||||
|
import { memo, useCallback } from 'react';
|
||||||
|
|
||||||
|
type ParamControlNetIsEnabledProps = {
|
||||||
|
controlNetId: string;
|
||||||
|
isEnabled: boolean;
|
||||||
|
};
|
||||||
|
|
||||||
|
const ParamControlNetIsEnabled = (props: ParamControlNetIsEnabledProps) => {
|
||||||
|
const { controlNetId, isEnabled } = props;
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
const handleIsEnabledChanged = useCallback(() => {
|
||||||
|
dispatch(controlNetToggled({ controlNetId }));
|
||||||
|
}, [dispatch, controlNetId]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<IAISwitch
|
||||||
|
label="Enabled"
|
||||||
|
isChecked={isEnabled}
|
||||||
|
onChange={handleIsEnabledChanged}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(ParamControlNetIsEnabled);
|
@ -0,0 +1,36 @@
|
|||||||
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
|
import IAIFullCheckbox from 'common/components/IAIFullCheckbox';
|
||||||
|
import IAISwitch from 'common/components/IAISwitch';
|
||||||
|
import {
|
||||||
|
controlNetToggled,
|
||||||
|
isControlNetImagePreprocessedToggled,
|
||||||
|
} from 'features/controlNet/store/controlNetSlice';
|
||||||
|
import { memo, useCallback } from 'react';
|
||||||
|
|
||||||
|
type ParamControlNetIsEnabledProps = {
|
||||||
|
controlNetId: string;
|
||||||
|
isControlImageProcessed: boolean;
|
||||||
|
};
|
||||||
|
|
||||||
|
const ParamControlNetIsEnabled = (props: ParamControlNetIsEnabledProps) => {
|
||||||
|
const { controlNetId, isControlImageProcessed } = props;
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
const handleIsControlImageProcessedToggled = useCallback(() => {
|
||||||
|
dispatch(
|
||||||
|
isControlNetImagePreprocessedToggled({
|
||||||
|
controlNetId,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
}, [controlNetId, dispatch]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<IAISwitch
|
||||||
|
label="Preprocess"
|
||||||
|
isChecked={isControlImageProcessed}
|
||||||
|
onChange={handleIsControlImageProcessedToggled}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(ParamControlNetIsEnabled);
|
@ -0,0 +1,41 @@
|
|||||||
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
|
import IAICustomSelect from 'common/components/IAICustomSelect';
|
||||||
|
import {
|
||||||
|
CONTROLNET_MODELS,
|
||||||
|
ControlNetModel,
|
||||||
|
} from 'features/controlNet/store/constants';
|
||||||
|
import { controlNetModelChanged } from 'features/controlNet/store/controlNetSlice';
|
||||||
|
import { memo, useCallback } from 'react';
|
||||||
|
|
||||||
|
type ParamIsControlNetModelProps = {
|
||||||
|
controlNetId: string;
|
||||||
|
model: ControlNetModel;
|
||||||
|
};
|
||||||
|
|
||||||
|
const ParamIsControlNetModel = (props: ParamIsControlNetModelProps) => {
|
||||||
|
const { controlNetId, model } = props;
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
const handleModelChanged = useCallback(
|
||||||
|
(val: string | null | undefined) => {
|
||||||
|
// TODO: do not cast
|
||||||
|
const model = val as ControlNetModel;
|
||||||
|
dispatch(controlNetModelChanged({ controlNetId, model }));
|
||||||
|
},
|
||||||
|
[controlNetId, dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<IAICustomSelect
|
||||||
|
tooltip={model}
|
||||||
|
tooltipProps={{ placement: 'top', hasArrow: true }}
|
||||||
|
items={CONTROLNET_MODELS}
|
||||||
|
selectedItem={model}
|
||||||
|
setSelectedItem={handleModelChanged}
|
||||||
|
ellipsisPosition="start"
|
||||||
|
withCheckIcon
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(ParamIsControlNetModel);
|
@ -0,0 +1,47 @@
|
|||||||
|
import IAICustomSelect from 'common/components/IAICustomSelect';
|
||||||
|
import { memo, useCallback } from 'react';
|
||||||
|
import {
|
||||||
|
ControlNetProcessorNode,
|
||||||
|
ControlNetProcessorType,
|
||||||
|
} from '../../store/types';
|
||||||
|
import { controlNetProcessorTypeChanged } from '../../store/controlNetSlice';
|
||||||
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
|
import { CONTROLNET_PROCESSORS } from '../../store/constants';
|
||||||
|
|
||||||
|
type ParamControlNetProcessorSelectProps = {
|
||||||
|
controlNetId: string;
|
||||||
|
processorNode: ControlNetProcessorNode;
|
||||||
|
};
|
||||||
|
|
||||||
|
const CONTROLNET_PROCESSOR_TYPES = Object.keys(
|
||||||
|
CONTROLNET_PROCESSORS
|
||||||
|
) as ControlNetProcessorType[];
|
||||||
|
|
||||||
|
const ParamControlNetProcessorSelect = (
|
||||||
|
props: ParamControlNetProcessorSelectProps
|
||||||
|
) => {
|
||||||
|
const { controlNetId, processorNode } = props;
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const handleProcessorTypeChanged = useCallback(
|
||||||
|
(v: string | null | undefined) => {
|
||||||
|
dispatch(
|
||||||
|
controlNetProcessorTypeChanged({
|
||||||
|
controlNetId,
|
||||||
|
processorType: v as ControlNetProcessorType,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
},
|
||||||
|
[controlNetId, dispatch]
|
||||||
|
);
|
||||||
|
return (
|
||||||
|
<IAICustomSelect
|
||||||
|
label="Processor"
|
||||||
|
items={CONTROLNET_PROCESSOR_TYPES}
|
||||||
|
selectedItem={processorNode.type ?? 'canny_image_processor'}
|
||||||
|
setSelectedItem={handleProcessorTypeChanged}
|
||||||
|
withCheckIcon
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(ParamControlNetProcessorSelect);
|
@ -0,0 +1,57 @@
|
|||||||
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
|
import IAISlider from 'common/components/IAISlider';
|
||||||
|
import { controlNetWeightChanged } from 'features/controlNet/store/controlNetSlice';
|
||||||
|
import { memo, useCallback } from 'react';
|
||||||
|
|
||||||
|
type ParamControlNetWeightProps = {
|
||||||
|
controlNetId: string;
|
||||||
|
weight: number;
|
||||||
|
mini?: boolean;
|
||||||
|
};
|
||||||
|
|
||||||
|
const ParamControlNetWeight = (props: ParamControlNetWeightProps) => {
|
||||||
|
const { controlNetId, weight, mini = false } = props;
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
const handleWeightChanged = useCallback(
|
||||||
|
(weight: number) => {
|
||||||
|
dispatch(controlNetWeightChanged({ controlNetId, weight }));
|
||||||
|
},
|
||||||
|
[controlNetId, dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleWeightReset = () => {
|
||||||
|
dispatch(controlNetWeightChanged({ controlNetId, weight: 1 }));
|
||||||
|
};
|
||||||
|
|
||||||
|
if (mini) {
|
||||||
|
return (
|
||||||
|
<IAISlider
|
||||||
|
label={'Weight'}
|
||||||
|
sliderFormLabelProps={{ pb: 1 }}
|
||||||
|
value={weight}
|
||||||
|
onChange={handleWeightChanged}
|
||||||
|
min={0}
|
||||||
|
max={1}
|
||||||
|
step={0.01}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<IAISlider
|
||||||
|
label="Weight"
|
||||||
|
value={weight}
|
||||||
|
onChange={handleWeightChanged}
|
||||||
|
withInput
|
||||||
|
withReset
|
||||||
|
handleReset={handleWeightReset}
|
||||||
|
withSliderMarks
|
||||||
|
min={0}
|
||||||
|
max={1}
|
||||||
|
step={0.01}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(ParamControlNetWeight);
|
@ -0,0 +1,72 @@
|
|||||||
|
import IAISlider from 'common/components/IAISlider';
|
||||||
|
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
||||||
|
import { RequiredCannyImageProcessorInvocation } from 'features/controlNet/store/types';
|
||||||
|
import { memo, useCallback } from 'react';
|
||||||
|
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||||
|
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||||
|
|
||||||
|
const DEFAULTS = CONTROLNET_PROCESSORS.canny_image_processor.default;
|
||||||
|
|
||||||
|
type CannyProcessorProps = {
|
||||||
|
controlNetId: string;
|
||||||
|
processorNode: RequiredCannyImageProcessorInvocation;
|
||||||
|
};
|
||||||
|
|
||||||
|
const CannyProcessor = (props: CannyProcessorProps) => {
|
||||||
|
const { controlNetId, processorNode } = props;
|
||||||
|
const { low_threshold, high_threshold } = processorNode;
|
||||||
|
const processorChanged = useProcessorNodeChanged();
|
||||||
|
|
||||||
|
const handleLowThresholdChanged = useCallback(
|
||||||
|
(v: number) => {
|
||||||
|
processorChanged(controlNetId, { low_threshold: v });
|
||||||
|
},
|
||||||
|
[controlNetId, processorChanged]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleLowThresholdReset = useCallback(() => {
|
||||||
|
processorChanged(controlNetId, {
|
||||||
|
low_threshold: DEFAULTS.low_threshold,
|
||||||
|
});
|
||||||
|
}, [controlNetId, processorChanged]);
|
||||||
|
|
||||||
|
const handleHighThresholdChanged = useCallback(
|
||||||
|
(v: number) => {
|
||||||
|
processorChanged(controlNetId, { high_threshold: v });
|
||||||
|
},
|
||||||
|
[controlNetId, processorChanged]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleHighThresholdReset = useCallback(() => {
|
||||||
|
processorChanged(controlNetId, {
|
||||||
|
high_threshold: DEFAULTS.high_threshold,
|
||||||
|
});
|
||||||
|
}, [controlNetId, processorChanged]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<ProcessorWrapper>
|
||||||
|
<IAISlider
|
||||||
|
label="Low Threshold"
|
||||||
|
value={low_threshold}
|
||||||
|
onChange={handleLowThresholdChanged}
|
||||||
|
handleReset={handleLowThresholdReset}
|
||||||
|
withReset
|
||||||
|
min={0}
|
||||||
|
max={255}
|
||||||
|
withInput
|
||||||
|
/>
|
||||||
|
<IAISlider
|
||||||
|
label="High Threshold"
|
||||||
|
value={high_threshold}
|
||||||
|
onChange={handleHighThresholdChanged}
|
||||||
|
handleReset={handleHighThresholdReset}
|
||||||
|
withReset
|
||||||
|
min={0}
|
||||||
|
max={255}
|
||||||
|
withInput
|
||||||
|
/>
|
||||||
|
</ProcessorWrapper>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(CannyProcessor);
|
@ -0,0 +1,141 @@
|
|||||||
|
import IAISlider from 'common/components/IAISlider';
|
||||||
|
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
||||||
|
import { RequiredContentShuffleImageProcessorInvocation } from 'features/controlNet/store/types';
|
||||||
|
import { memo, useCallback } from 'react';
|
||||||
|
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||||
|
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||||
|
|
||||||
|
const DEFAULTS = CONTROLNET_PROCESSORS.content_shuffle_image_processor.default;
|
||||||
|
|
||||||
|
type Props = {
|
||||||
|
controlNetId: string;
|
||||||
|
processorNode: RequiredContentShuffleImageProcessorInvocation;
|
||||||
|
};
|
||||||
|
|
||||||
|
const ContentShuffleProcessor = (props: Props) => {
|
||||||
|
const { controlNetId, processorNode } = props;
|
||||||
|
const { image_resolution, detect_resolution, w, h, f } = processorNode;
|
||||||
|
const processorChanged = useProcessorNodeChanged();
|
||||||
|
|
||||||
|
const handleDetectResolutionChanged = useCallback(
|
||||||
|
(v: number) => {
|
||||||
|
processorChanged(controlNetId, { detect_resolution: v });
|
||||||
|
},
|
||||||
|
[controlNetId, processorChanged]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleDetectResolutionReset = useCallback(() => {
|
||||||
|
processorChanged(controlNetId, {
|
||||||
|
detect_resolution: DEFAULTS.detect_resolution,
|
||||||
|
});
|
||||||
|
}, [controlNetId, processorChanged]);
|
||||||
|
|
||||||
|
const handleImageResolutionChanged = useCallback(
|
||||||
|
(v: number) => {
|
||||||
|
processorChanged(controlNetId, { image_resolution: v });
|
||||||
|
},
|
||||||
|
[controlNetId, processorChanged]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleImageResolutionReset = useCallback(() => {
|
||||||
|
processorChanged(controlNetId, {
|
||||||
|
image_resolution: DEFAULTS.image_resolution,
|
||||||
|
});
|
||||||
|
}, [controlNetId, processorChanged]);
|
||||||
|
|
||||||
|
const handleWChanged = useCallback(
|
||||||
|
(v: number) => {
|
||||||
|
processorChanged(controlNetId, { w: v });
|
||||||
|
},
|
||||||
|
[controlNetId, processorChanged]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleWReset = useCallback(() => {
|
||||||
|
processorChanged(controlNetId, {
|
||||||
|
w: DEFAULTS.w,
|
||||||
|
});
|
||||||
|
}, [controlNetId, processorChanged]);
|
||||||
|
|
||||||
|
const handleHChanged = useCallback(
|
||||||
|
(v: number) => {
|
||||||
|
processorChanged(controlNetId, { h: v });
|
||||||
|
},
|
||||||
|
[controlNetId, processorChanged]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleHReset = useCallback(() => {
|
||||||
|
processorChanged(controlNetId, {
|
||||||
|
h: DEFAULTS.h,
|
||||||
|
});
|
||||||
|
}, [controlNetId, processorChanged]);
|
||||||
|
|
||||||
|
const handleFChanged = useCallback(
|
||||||
|
(v: number) => {
|
||||||
|
processorChanged(controlNetId, { f: v });
|
||||||
|
},
|
||||||
|
[controlNetId, processorChanged]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleFReset = useCallback(() => {
|
||||||
|
processorChanged(controlNetId, {
|
||||||
|
f: DEFAULTS.f,
|
||||||
|
});
|
||||||
|
}, [controlNetId, processorChanged]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<ProcessorWrapper>
|
||||||
|
<IAISlider
|
||||||
|
label="Detect Resolution"
|
||||||
|
value={detect_resolution}
|
||||||
|
onChange={handleDetectResolutionChanged}
|
||||||
|
handleReset={handleDetectResolutionReset}
|
||||||
|
withReset
|
||||||
|
min={0}
|
||||||
|
max={4096}
|
||||||
|
withInput
|
||||||
|
/>
|
||||||
|
<IAISlider
|
||||||
|
label="Image Resolution"
|
||||||
|
value={image_resolution}
|
||||||
|
onChange={handleImageResolutionChanged}
|
||||||
|
handleReset={handleImageResolutionReset}
|
||||||
|
withReset
|
||||||
|
min={0}
|
||||||
|
max={4096}
|
||||||
|
withInput
|
||||||
|
/>
|
||||||
|
<IAISlider
|
||||||
|
label="W"
|
||||||
|
value={w}
|
||||||
|
onChange={handleWChanged}
|
||||||
|
handleReset={handleWReset}
|
||||||
|
withReset
|
||||||
|
min={0}
|
||||||
|
max={4096}
|
||||||
|
withInput
|
||||||
|
/>
|
||||||
|
<IAISlider
|
||||||
|
label="H"
|
||||||
|
value={h}
|
||||||
|
onChange={handleHChanged}
|
||||||
|
handleReset={handleHReset}
|
||||||
|
withReset
|
||||||
|
min={0}
|
||||||
|
max={4096}
|
||||||
|
withInput
|
||||||
|
/>
|
||||||
|
<IAISlider
|
||||||
|
label="F"
|
||||||
|
value={f}
|
||||||
|
onChange={handleFChanged}
|
||||||
|
handleReset={handleFReset}
|
||||||
|
withReset
|
||||||
|
min={0}
|
||||||
|
max={4096}
|
||||||
|
withInput
|
||||||
|
/>
|
||||||
|
</ProcessorWrapper>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(ContentShuffleProcessor);
|
@ -0,0 +1,88 @@
|
|||||||
|
import IAISlider from 'common/components/IAISlider';
|
||||||
|
import IAISwitch from 'common/components/IAISwitch';
|
||||||
|
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
||||||
|
import { RequiredHedImageProcessorInvocation } from 'features/controlNet/store/types';
|
||||||
|
import { ChangeEvent, memo, useCallback } from 'react';
|
||||||
|
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||||
|
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||||
|
|
||||||
|
const DEFAULTS = CONTROLNET_PROCESSORS.hed_image_processor.default;
|
||||||
|
|
||||||
|
type HedProcessorProps = {
|
||||||
|
controlNetId: string;
|
||||||
|
processorNode: RequiredHedImageProcessorInvocation;
|
||||||
|
};
|
||||||
|
|
||||||
|
const HedPreprocessor = (props: HedProcessorProps) => {
|
||||||
|
const {
|
||||||
|
controlNetId,
|
||||||
|
processorNode: { detect_resolution, image_resolution, scribble },
|
||||||
|
} = props;
|
||||||
|
|
||||||
|
const processorChanged = useProcessorNodeChanged();
|
||||||
|
|
||||||
|
const handleDetectResolutionChanged = useCallback(
|
||||||
|
(v: number) => {
|
||||||
|
processorChanged(controlNetId, { detect_resolution: v });
|
||||||
|
},
|
||||||
|
[controlNetId, processorChanged]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleImageResolutionChanged = useCallback(
|
||||||
|
(v: number) => {
|
||||||
|
processorChanged(controlNetId, { image_resolution: v });
|
||||||
|
},
|
||||||
|
[controlNetId, processorChanged]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleScribbleChanged = useCallback(
|
||||||
|
(e: ChangeEvent<HTMLInputElement>) => {
|
||||||
|
processorChanged(controlNetId, { scribble: e.target.checked });
|
||||||
|
},
|
||||||
|
[controlNetId, processorChanged]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleDetectResolutionReset = useCallback(() => {
|
||||||
|
processorChanged(controlNetId, {
|
||||||
|
detect_resolution: DEFAULTS.detect_resolution,
|
||||||
|
});
|
||||||
|
}, [controlNetId, processorChanged]);
|
||||||
|
|
||||||
|
const handleImageResolutionReset = useCallback(() => {
|
||||||
|
processorChanged(controlNetId, {
|
||||||
|
image_resolution: DEFAULTS.image_resolution,
|
||||||
|
});
|
||||||
|
}, [controlNetId, processorChanged]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<ProcessorWrapper>
|
||||||
|
<IAISlider
|
||||||
|
label="Detect Resolution"
|
||||||
|
value={detect_resolution}
|
||||||
|
onChange={handleDetectResolutionChanged}
|
||||||
|
handleReset={handleDetectResolutionReset}
|
||||||
|
withReset
|
||||||
|
min={0}
|
||||||
|
max={4096}
|
||||||
|
withInput
|
||||||
|
/>
|
||||||
|
<IAISlider
|
||||||
|
label="Image Resolution"
|
||||||
|
value={image_resolution}
|
||||||
|
onChange={handleImageResolutionChanged}
|
||||||
|
handleReset={handleImageResolutionReset}
|
||||||
|
withReset
|
||||||
|
min={0}
|
||||||
|
max={4096}
|
||||||
|
withInput
|
||||||
|
/>
|
||||||
|
<IAISwitch
|
||||||
|
label="Scribble"
|
||||||
|
isChecked={scribble}
|
||||||
|
onChange={handleScribbleChanged}
|
||||||
|
/>
|
||||||
|
</ProcessorWrapper>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(HedPreprocessor);
|
@ -0,0 +1,72 @@
|
|||||||
|
import IAISlider from 'common/components/IAISlider';
|
||||||
|
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
||||||
|
import { RequiredLineartAnimeImageProcessorInvocation } from 'features/controlNet/store/types';
|
||||||
|
import { memo, useCallback } from 'react';
|
||||||
|
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||||
|
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||||
|
|
||||||
|
const DEFAULTS = CONTROLNET_PROCESSORS.lineart_anime_image_processor.default;
|
||||||
|
|
||||||
|
type Props = {
|
||||||
|
controlNetId: string;
|
||||||
|
processorNode: RequiredLineartAnimeImageProcessorInvocation;
|
||||||
|
};
|
||||||
|
|
||||||
|
const LineartAnimeProcessor = (props: Props) => {
|
||||||
|
const { controlNetId, processorNode } = props;
|
||||||
|
const { image_resolution, detect_resolution } = processorNode;
|
||||||
|
const processorChanged = useProcessorNodeChanged();
|
||||||
|
|
||||||
|
const handleDetectResolutionChanged = useCallback(
|
||||||
|
(v: number) => {
|
||||||
|
processorChanged(controlNetId, { detect_resolution: v });
|
||||||
|
},
|
||||||
|
[controlNetId, processorChanged]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleImageResolutionChanged = useCallback(
|
||||||
|
(v: number) => {
|
||||||
|
processorChanged(controlNetId, { image_resolution: v });
|
||||||
|
},
|
||||||
|
[controlNetId, processorChanged]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleDetectResolutionReset = useCallback(() => {
|
||||||
|
processorChanged(controlNetId, {
|
||||||
|
detect_resolution: DEFAULTS.detect_resolution,
|
||||||
|
});
|
||||||
|
}, [controlNetId, processorChanged]);
|
||||||
|
|
||||||
|
const handleImageResolutionReset = useCallback(() => {
|
||||||
|
processorChanged(controlNetId, {
|
||||||
|
image_resolution: DEFAULTS.image_resolution,
|
||||||
|
});
|
||||||
|
}, [controlNetId, processorChanged]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<ProcessorWrapper>
|
||||||
|
<IAISlider
|
||||||
|
label="Detect Resolution"
|
||||||
|
value={detect_resolution}
|
||||||
|
onChange={handleDetectResolutionChanged}
|
||||||
|
handleReset={handleDetectResolutionReset}
|
||||||
|
withReset
|
||||||
|
min={0}
|
||||||
|
max={4096}
|
||||||
|
withInput
|
||||||
|
/>
|
||||||
|
<IAISlider
|
||||||
|
label="Image Resolution"
|
||||||
|
value={image_resolution}
|
||||||
|
onChange={handleImageResolutionChanged}
|
||||||
|
handleReset={handleImageResolutionReset}
|
||||||
|
withReset
|
||||||
|
min={0}
|
||||||
|
max={4096}
|
||||||
|
withInput
|
||||||
|
/>
|
||||||
|
</ProcessorWrapper>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(LineartAnimeProcessor);
|
@ -0,0 +1,85 @@
|
|||||||
|
import IAISlider from 'common/components/IAISlider';
|
||||||
|
import IAISwitch from 'common/components/IAISwitch';
|
||||||
|
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
||||||
|
import { RequiredLineartImageProcessorInvocation } from 'features/controlNet/store/types';
|
||||||
|
import { ChangeEvent, memo, useCallback } from 'react';
|
||||||
|
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||||
|
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||||
|
|
||||||
|
const DEFAULTS = CONTROLNET_PROCESSORS.lineart_image_processor.default;
|
||||||
|
|
||||||
|
type LineartProcessorProps = {
|
||||||
|
controlNetId: string;
|
||||||
|
processorNode: RequiredLineartImageProcessorInvocation;
|
||||||
|
};
|
||||||
|
|
||||||
|
const LineartProcessor = (props: LineartProcessorProps) => {
|
||||||
|
const { controlNetId, processorNode } = props;
|
||||||
|
const { image_resolution, detect_resolution, coarse } = processorNode;
|
||||||
|
const processorChanged = useProcessorNodeChanged();
|
||||||
|
|
||||||
|
const handleDetectResolutionChanged = useCallback(
|
||||||
|
(v: number) => {
|
||||||
|
processorChanged(controlNetId, { detect_resolution: v });
|
||||||
|
},
|
||||||
|
[controlNetId, processorChanged]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleImageResolutionChanged = useCallback(
|
||||||
|
(v: number) => {
|
||||||
|
processorChanged(controlNetId, { image_resolution: v });
|
||||||
|
},
|
||||||
|
[controlNetId, processorChanged]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleDetectResolutionReset = useCallback(() => {
|
||||||
|
processorChanged(controlNetId, {
|
||||||
|
detect_resolution: DEFAULTS.detect_resolution,
|
||||||
|
});
|
||||||
|
}, [controlNetId, processorChanged]);
|
||||||
|
|
||||||
|
const handleImageResolutionReset = useCallback(() => {
|
||||||
|
processorChanged(controlNetId, {
|
||||||
|
image_resolution: DEFAULTS.image_resolution,
|
||||||
|
});
|
||||||
|
}, [controlNetId, processorChanged]);
|
||||||
|
|
||||||
|
const handleCoarseChanged = useCallback(
|
||||||
|
(e: ChangeEvent<HTMLInputElement>) => {
|
||||||
|
processorChanged(controlNetId, { coarse: e.target.checked });
|
||||||
|
},
|
||||||
|
[controlNetId, processorChanged]
|
||||||
|
);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<ProcessorWrapper>
|
||||||
|
<IAISlider
|
||||||
|
label="Detect Resolution"
|
||||||
|
value={detect_resolution}
|
||||||
|
onChange={handleDetectResolutionChanged}
|
||||||
|
handleReset={handleDetectResolutionReset}
|
||||||
|
withReset
|
||||||
|
min={0}
|
||||||
|
max={4096}
|
||||||
|
withInput
|
||||||
|
/>
|
||||||
|
<IAISlider
|
||||||
|
label="Image Resolution"
|
||||||
|
value={image_resolution}
|
||||||
|
onChange={handleImageResolutionChanged}
|
||||||
|
handleReset={handleImageResolutionReset}
|
||||||
|
withReset
|
||||||
|
min={0}
|
||||||
|
max={4096}
|
||||||
|
withInput
|
||||||
|
/>
|
||||||
|
<IAISwitch
|
||||||
|
label="Coarse"
|
||||||
|
isChecked={coarse}
|
||||||
|
onChange={handleCoarseChanged}
|
||||||
|
/>
|
||||||
|
</ProcessorWrapper>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(LineartProcessor);
|
@ -0,0 +1,69 @@
|
|||||||
|
import IAISlider from 'common/components/IAISlider';
|
||||||
|
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
||||||
|
import { RequiredMediapipeFaceProcessorInvocation } from 'features/controlNet/store/types';
|
||||||
|
import { memo, useCallback } from 'react';
|
||||||
|
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||||
|
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||||
|
|
||||||
|
const DEFAULTS = CONTROLNET_PROCESSORS.mediapipe_face_processor.default;
|
||||||
|
|
||||||
|
type Props = {
|
||||||
|
controlNetId: string;
|
||||||
|
processorNode: RequiredMediapipeFaceProcessorInvocation;
|
||||||
|
};
|
||||||
|
|
||||||
|
const MediapipeFaceProcessor = (props: Props) => {
|
||||||
|
const { controlNetId, processorNode } = props;
|
||||||
|
const { max_faces, min_confidence } = processorNode;
|
||||||
|
const processorChanged = useProcessorNodeChanged();
|
||||||
|
|
||||||
|
const handleMaxFacesChanged = useCallback(
|
||||||
|
(v: number) => {
|
||||||
|
processorChanged(controlNetId, { max_faces: v });
|
||||||
|
},
|
||||||
|
[controlNetId, processorChanged]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleMinConfidenceChanged = useCallback(
|
||||||
|
(v: number) => {
|
||||||
|
processorChanged(controlNetId, { min_confidence: v });
|
||||||
|
},
|
||||||
|
[controlNetId, processorChanged]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleMaxFacesReset = useCallback(() => {
|
||||||
|
processorChanged(controlNetId, { max_faces: DEFAULTS.max_faces });
|
||||||
|
}, [controlNetId, processorChanged]);
|
||||||
|
|
||||||
|
const handleMinConfidenceReset = useCallback(() => {
|
||||||
|
processorChanged(controlNetId, { min_confidence: DEFAULTS.min_confidence });
|
||||||
|
}, [controlNetId, processorChanged]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<ProcessorWrapper>
|
||||||
|
<IAISlider
|
||||||
|
label="Max Faces"
|
||||||
|
value={max_faces}
|
||||||
|
onChange={handleMaxFacesChanged}
|
||||||
|
handleReset={handleMaxFacesReset}
|
||||||
|
withReset
|
||||||
|
min={1}
|
||||||
|
max={20}
|
||||||
|
withInput
|
||||||
|
/>
|
||||||
|
<IAISlider
|
||||||
|
label="Min Confidence"
|
||||||
|
value={min_confidence}
|
||||||
|
onChange={handleMinConfidenceChanged}
|
||||||
|
handleReset={handleMinConfidenceReset}
|
||||||
|
withReset
|
||||||
|
min={0}
|
||||||
|
max={1}
|
||||||
|
step={0.01}
|
||||||
|
withInput
|
||||||
|
/>
|
||||||
|
</ProcessorWrapper>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(MediapipeFaceProcessor);
|
@ -0,0 +1,70 @@
|
|||||||
|
import IAISlider from 'common/components/IAISlider';
|
||||||
|
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
||||||
|
import { RequiredMidasDepthImageProcessorInvocation } from 'features/controlNet/store/types';
|
||||||
|
import { memo, useCallback } from 'react';
|
||||||
|
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||||
|
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||||
|
|
||||||
|
const DEFAULTS = CONTROLNET_PROCESSORS.midas_depth_image_processor.default;
|
||||||
|
|
||||||
|
type Props = {
|
||||||
|
controlNetId: string;
|
||||||
|
processorNode: RequiredMidasDepthImageProcessorInvocation;
|
||||||
|
};
|
||||||
|
|
||||||
|
const MidasDepthProcessor = (props: Props) => {
|
||||||
|
const { controlNetId, processorNode } = props;
|
||||||
|
const { a_mult, bg_th } = processorNode;
|
||||||
|
const processorChanged = useProcessorNodeChanged();
|
||||||
|
|
||||||
|
const handleAMultChanged = useCallback(
|
||||||
|
(v: number) => {
|
||||||
|
processorChanged(controlNetId, { a_mult: v });
|
||||||
|
},
|
||||||
|
[controlNetId, processorChanged]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleBgThChanged = useCallback(
|
||||||
|
(v: number) => {
|
||||||
|
processorChanged(controlNetId, { bg_th: v });
|
||||||
|
},
|
||||||
|
[controlNetId, processorChanged]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleAMultReset = useCallback(() => {
|
||||||
|
processorChanged(controlNetId, { a_mult: DEFAULTS.a_mult });
|
||||||
|
}, [controlNetId, processorChanged]);
|
||||||
|
|
||||||
|
const handleBgThReset = useCallback(() => {
|
||||||
|
processorChanged(controlNetId, { bg_th: DEFAULTS.bg_th });
|
||||||
|
}, [controlNetId, processorChanged]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<ProcessorWrapper>
|
||||||
|
<IAISlider
|
||||||
|
label="a_mult"
|
||||||
|
value={a_mult}
|
||||||
|
onChange={handleAMultChanged}
|
||||||
|
handleReset={handleAMultReset}
|
||||||
|
withReset
|
||||||
|
min={0}
|
||||||
|
max={20}
|
||||||
|
step={0.01}
|
||||||
|
withInput
|
||||||
|
/>
|
||||||
|
<IAISlider
|
||||||
|
label="bg_th"
|
||||||
|
value={bg_th}
|
||||||
|
onChange={handleBgThChanged}
|
||||||
|
handleReset={handleBgThReset}
|
||||||
|
withReset
|
||||||
|
min={0}
|
||||||
|
max={20}
|
||||||
|
step={0.01}
|
||||||
|
withInput
|
||||||
|
/>
|
||||||
|
</ProcessorWrapper>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(MidasDepthProcessor);
|
@ -0,0 +1,116 @@
|
|||||||
|
import IAISlider from 'common/components/IAISlider';
|
||||||
|
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
||||||
|
import { RequiredMlsdImageProcessorInvocation } from 'features/controlNet/store/types';
|
||||||
|
import { memo, useCallback } from 'react';
|
||||||
|
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||||
|
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||||
|
|
||||||
|
const DEFAULTS = CONTROLNET_PROCESSORS.mlsd_image_processor.default;
|
||||||
|
|
||||||
|
type Props = {
|
||||||
|
controlNetId: string;
|
||||||
|
processorNode: RequiredMlsdImageProcessorInvocation;
|
||||||
|
};
|
||||||
|
|
||||||
|
const MlsdImageProcessor = (props: Props) => {
|
||||||
|
const { controlNetId, processorNode } = props;
|
||||||
|
const { image_resolution, detect_resolution, thr_d, thr_v } = processorNode;
|
||||||
|
const processorChanged = useProcessorNodeChanged();
|
||||||
|
|
||||||
|
const handleDetectResolutionChanged = useCallback(
|
||||||
|
(v: number) => {
|
||||||
|
processorChanged(controlNetId, { detect_resolution: v });
|
||||||
|
},
|
||||||
|
[controlNetId, processorChanged]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleImageResolutionChanged = useCallback(
|
||||||
|
(v: number) => {
|
||||||
|
processorChanged(controlNetId, { image_resolution: v });
|
||||||
|
},
|
||||||
|
[controlNetId, processorChanged]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleThrDChanged = useCallback(
|
||||||
|
(v: number) => {
|
||||||
|
processorChanged(controlNetId, { thr_d: v });
|
||||||
|
},
|
||||||
|
[controlNetId, processorChanged]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleThrVChanged = useCallback(
|
||||||
|
(v: number) => {
|
||||||
|
processorChanged(controlNetId, { thr_v: v });
|
||||||
|
},
|
||||||
|
[controlNetId, processorChanged]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleDetectResolutionReset = useCallback(() => {
|
||||||
|
processorChanged(controlNetId, {
|
||||||
|
detect_resolution: DEFAULTS.detect_resolution,
|
||||||
|
});
|
||||||
|
}, [controlNetId, processorChanged]);
|
||||||
|
|
||||||
|
const handleImageResolutionReset = useCallback(() => {
|
||||||
|
processorChanged(controlNetId, {
|
||||||
|
image_resolution: DEFAULTS.image_resolution,
|
||||||
|
});
|
||||||
|
}, [controlNetId, processorChanged]);
|
||||||
|
|
||||||
|
const handleThrDReset = useCallback(() => {
|
||||||
|
processorChanged(controlNetId, { thr_d: DEFAULTS.thr_d });
|
||||||
|
}, [controlNetId, processorChanged]);
|
||||||
|
|
||||||
|
const handleThrVReset = useCallback(() => {
|
||||||
|
processorChanged(controlNetId, { thr_v: DEFAULTS.thr_v });
|
||||||
|
}, [controlNetId, processorChanged]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<ProcessorWrapper>
|
||||||
|
<IAISlider
|
||||||
|
label="Detect Resolution"
|
||||||
|
value={detect_resolution}
|
||||||
|
onChange={handleDetectResolutionChanged}
|
||||||
|
handleReset={handleDetectResolutionReset}
|
||||||
|
withReset
|
||||||
|
min={0}
|
||||||
|
max={4096}
|
||||||
|
withInput
|
||||||
|
/>
|
||||||
|
<IAISlider
|
||||||
|
label="Image Resolution"
|
||||||
|
value={image_resolution}
|
||||||
|
onChange={handleImageResolutionChanged}
|
||||||
|
handleReset={handleImageResolutionReset}
|
||||||
|
withReset
|
||||||
|
min={0}
|
||||||
|
max={4096}
|
||||||
|
withInput
|
||||||
|
/>
|
||||||
|
<IAISlider
|
||||||
|
label="W"
|
||||||
|
value={thr_d}
|
||||||
|
onChange={handleThrDChanged}
|
||||||
|
handleReset={handleThrDReset}
|
||||||
|
withReset
|
||||||
|
min={0}
|
||||||
|
max={1}
|
||||||
|
step={0.01}
|
||||||
|
withInput
|
||||||
|
/>
|
||||||
|
<IAISlider
|
||||||
|
label="H"
|
||||||
|
value={thr_v}
|
||||||
|
onChange={handleThrVChanged}
|
||||||
|
handleReset={handleThrVReset}
|
||||||
|
withReset
|
||||||
|
min={0}
|
||||||
|
max={1}
|
||||||
|
step={0.01}
|
||||||
|
withInput
|
||||||
|
/>
|
||||||
|
</ProcessorWrapper>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(MlsdImageProcessor);
|
@ -0,0 +1,72 @@
|
|||||||
|
import IAISlider from 'common/components/IAISlider';
|
||||||
|
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
||||||
|
import { RequiredNormalbaeImageProcessorInvocation } from 'features/controlNet/store/types';
|
||||||
|
import { memo, useCallback } from 'react';
|
||||||
|
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||||
|
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||||
|
|
||||||
|
const DEFAULTS = CONTROLNET_PROCESSORS.normalbae_image_processor.default;
|
||||||
|
|
||||||
|
type Props = {
|
||||||
|
controlNetId: string;
|
||||||
|
processorNode: RequiredNormalbaeImageProcessorInvocation;
|
||||||
|
};
|
||||||
|
|
||||||
|
const NormalBaeProcessor = (props: Props) => {
|
||||||
|
const { controlNetId, processorNode } = props;
|
||||||
|
const { image_resolution, detect_resolution } = processorNode;
|
||||||
|
const processorChanged = useProcessorNodeChanged();
|
||||||
|
|
||||||
|
const handleDetectResolutionChanged = useCallback(
|
||||||
|
(v: number) => {
|
||||||
|
processorChanged(controlNetId, { detect_resolution: v });
|
||||||
|
},
|
||||||
|
[controlNetId, processorChanged]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleImageResolutionChanged = useCallback(
|
||||||
|
(v: number) => {
|
||||||
|
processorChanged(controlNetId, { image_resolution: v });
|
||||||
|
},
|
||||||
|
[controlNetId, processorChanged]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleDetectResolutionReset = useCallback(() => {
|
||||||
|
processorChanged(controlNetId, {
|
||||||
|
detect_resolution: DEFAULTS.detect_resolution,
|
||||||
|
});
|
||||||
|
}, [controlNetId, processorChanged]);
|
||||||
|
|
||||||
|
const handleImageResolutionReset = useCallback(() => {
|
||||||
|
processorChanged(controlNetId, {
|
||||||
|
image_resolution: DEFAULTS.image_resolution,
|
||||||
|
});
|
||||||
|
}, [controlNetId, processorChanged]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<ProcessorWrapper>
|
||||||
|
<IAISlider
|
||||||
|
label="Detect Resolution"
|
||||||
|
value={detect_resolution}
|
||||||
|
onChange={handleDetectResolutionChanged}
|
||||||
|
handleReset={handleDetectResolutionReset}
|
||||||
|
withReset
|
||||||
|
min={0}
|
||||||
|
max={4096}
|
||||||
|
withInput
|
||||||
|
/>
|
||||||
|
<IAISlider
|
||||||
|
label="Image Resolution"
|
||||||
|
value={image_resolution}
|
||||||
|
onChange={handleImageResolutionChanged}
|
||||||
|
handleReset={handleImageResolutionReset}
|
||||||
|
withReset
|
||||||
|
min={0}
|
||||||
|
max={4096}
|
||||||
|
withInput
|
||||||
|
/>
|
||||||
|
</ProcessorWrapper>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(NormalBaeProcessor);
|
@ -0,0 +1,85 @@
|
|||||||
|
import IAISlider from 'common/components/IAISlider';
|
||||||
|
import IAISwitch from 'common/components/IAISwitch';
|
||||||
|
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
||||||
|
import { RequiredOpenposeImageProcessorInvocation } from 'features/controlNet/store/types';
|
||||||
|
import { ChangeEvent, memo, useCallback } from 'react';
|
||||||
|
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||||
|
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||||
|
|
||||||
|
const DEFAULTS = CONTROLNET_PROCESSORS.openpose_image_processor.default;
|
||||||
|
|
||||||
|
type Props = {
|
||||||
|
controlNetId: string;
|
||||||
|
processorNode: RequiredOpenposeImageProcessorInvocation;
|
||||||
|
};
|
||||||
|
|
||||||
|
const OpenposeProcessor = (props: Props) => {
|
||||||
|
const { controlNetId, processorNode } = props;
|
||||||
|
const { image_resolution, detect_resolution, hand_and_face } = processorNode;
|
||||||
|
const processorChanged = useProcessorNodeChanged();
|
||||||
|
|
||||||
|
const handleDetectResolutionChanged = useCallback(
|
||||||
|
(v: number) => {
|
||||||
|
processorChanged(controlNetId, { detect_resolution: v });
|
||||||
|
},
|
||||||
|
[controlNetId, processorChanged]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleImageResolutionChanged = useCallback(
|
||||||
|
(v: number) => {
|
||||||
|
processorChanged(controlNetId, { image_resolution: v });
|
||||||
|
},
|
||||||
|
[controlNetId, processorChanged]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleDetectResolutionReset = useCallback(() => {
|
||||||
|
processorChanged(controlNetId, {
|
||||||
|
detect_resolution: DEFAULTS.detect_resolution,
|
||||||
|
});
|
||||||
|
}, [controlNetId, processorChanged]);
|
||||||
|
|
||||||
|
const handleImageResolutionReset = useCallback(() => {
|
||||||
|
processorChanged(controlNetId, {
|
||||||
|
image_resolution: DEFAULTS.image_resolution,
|
||||||
|
});
|
||||||
|
}, [controlNetId, processorChanged]);
|
||||||
|
|
||||||
|
const handleHandAndFaceChanged = useCallback(
|
||||||
|
(e: ChangeEvent<HTMLInputElement>) => {
|
||||||
|
processorChanged(controlNetId, { hand_and_face: e.target.checked });
|
||||||
|
},
|
||||||
|
[controlNetId, processorChanged]
|
||||||
|
);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<ProcessorWrapper>
|
||||||
|
<IAISlider
|
||||||
|
label="Detect Resolution"
|
||||||
|
value={detect_resolution}
|
||||||
|
onChange={handleDetectResolutionChanged}
|
||||||
|
handleReset={handleDetectResolutionReset}
|
||||||
|
withReset
|
||||||
|
min={0}
|
||||||
|
max={4096}
|
||||||
|
withInput
|
||||||
|
/>
|
||||||
|
<IAISlider
|
||||||
|
label="Image Resolution"
|
||||||
|
value={image_resolution}
|
||||||
|
onChange={handleImageResolutionChanged}
|
||||||
|
handleReset={handleImageResolutionReset}
|
||||||
|
withReset
|
||||||
|
min={0}
|
||||||
|
max={4096}
|
||||||
|
withInput
|
||||||
|
/>
|
||||||
|
<IAISwitch
|
||||||
|
label="Hand and Face"
|
||||||
|
isChecked={hand_and_face}
|
||||||
|
onChange={handleHandAndFaceChanged}
|
||||||
|
/>
|
||||||
|
</ProcessorWrapper>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(OpenposeProcessor);
|
@ -0,0 +1,93 @@
|
|||||||
|
import IAISlider from 'common/components/IAISlider';
|
||||||
|
import IAISwitch from 'common/components/IAISwitch';
|
||||||
|
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
||||||
|
import { RequiredPidiImageProcessorInvocation } from 'features/controlNet/store/types';
|
||||||
|
import { ChangeEvent, memo, useCallback } from 'react';
|
||||||
|
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||||
|
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||||
|
|
||||||
|
const DEFAULTS = CONTROLNET_PROCESSORS.pidi_image_processor.default;
|
||||||
|
|
||||||
|
type Props = {
|
||||||
|
controlNetId: string;
|
||||||
|
processorNode: RequiredPidiImageProcessorInvocation;
|
||||||
|
};
|
||||||
|
|
||||||
|
const PidiProcessor = (props: Props) => {
|
||||||
|
const { controlNetId, processorNode } = props;
|
||||||
|
const { image_resolution, detect_resolution, scribble, safe } = processorNode;
|
||||||
|
const processorChanged = useProcessorNodeChanged();
|
||||||
|
|
||||||
|
const handleDetectResolutionChanged = useCallback(
|
||||||
|
(v: number) => {
|
||||||
|
processorChanged(controlNetId, { detect_resolution: v });
|
||||||
|
},
|
||||||
|
[controlNetId, processorChanged]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleImageResolutionChanged = useCallback(
|
||||||
|
(v: number) => {
|
||||||
|
processorChanged(controlNetId, { image_resolution: v });
|
||||||
|
},
|
||||||
|
[controlNetId, processorChanged]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleDetectResolutionReset = useCallback(() => {
|
||||||
|
processorChanged(controlNetId, {
|
||||||
|
detect_resolution: DEFAULTS.detect_resolution,
|
||||||
|
});
|
||||||
|
}, [controlNetId, processorChanged]);
|
||||||
|
|
||||||
|
const handleImageResolutionReset = useCallback(() => {
|
||||||
|
processorChanged(controlNetId, {
|
||||||
|
image_resolution: DEFAULTS.image_resolution,
|
||||||
|
});
|
||||||
|
}, [controlNetId, processorChanged]);
|
||||||
|
|
||||||
|
const handleScribbleChanged = useCallback(
|
||||||
|
(e: ChangeEvent<HTMLInputElement>) => {
|
||||||
|
processorChanged(controlNetId, { scribble: e.target.checked });
|
||||||
|
},
|
||||||
|
[controlNetId, processorChanged]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleSafeChanged = useCallback(
|
||||||
|
(e: ChangeEvent<HTMLInputElement>) => {
|
||||||
|
processorChanged(controlNetId, { safe: e.target.checked });
|
||||||
|
},
|
||||||
|
[controlNetId, processorChanged]
|
||||||
|
);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<ProcessorWrapper>
|
||||||
|
<IAISlider
|
||||||
|
label="Detect Resolution"
|
||||||
|
value={detect_resolution}
|
||||||
|
onChange={handleDetectResolutionChanged}
|
||||||
|
handleReset={handleDetectResolutionReset}
|
||||||
|
withReset
|
||||||
|
min={0}
|
||||||
|
max={4096}
|
||||||
|
withInput
|
||||||
|
/>
|
||||||
|
<IAISlider
|
||||||
|
label="Image Resolution"
|
||||||
|
value={image_resolution}
|
||||||
|
onChange={handleImageResolutionChanged}
|
||||||
|
handleReset={handleImageResolutionReset}
|
||||||
|
withReset
|
||||||
|
min={0}
|
||||||
|
max={4096}
|
||||||
|
withInput
|
||||||
|
/>
|
||||||
|
<IAISwitch
|
||||||
|
label="Scribble"
|
||||||
|
isChecked={scribble}
|
||||||
|
onChange={handleScribbleChanged}
|
||||||
|
/>
|
||||||
|
<IAISwitch label="Safe" isChecked={safe} onChange={handleSafeChanged} />
|
||||||
|
</ProcessorWrapper>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(PidiProcessor);
|
@ -0,0 +1,14 @@
|
|||||||
|
import { RequiredZoeDepthImageProcessorInvocation } from 'features/controlNet/store/types';
|
||||||
|
import { memo } from 'react';
|
||||||
|
|
||||||
|
type Props = {
|
||||||
|
controlNetId: string;
|
||||||
|
processorNode: RequiredZoeDepthImageProcessorInvocation;
|
||||||
|
};
|
||||||
|
|
||||||
|
const ZoeDepthProcessor = (props: Props) => {
|
||||||
|
// Has no parameters?
|
||||||
|
return null;
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(ZoeDepthProcessor);
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user