Merge branch 'main' into lstein/new-model-manager

This commit is contained in:
StAlKeR7779 2023-06-13 23:37:52 +03:00 committed by GitHub
commit c9ae26a176
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
39 changed files with 784 additions and 181 deletions

View File

@ -1,11 +1,12 @@
# InvokeAI nodes for ControlNet image preprocessors # InvokeAI nodes for ControlNet image preprocessors
# initial implementation by Gregg Helt, 2023 # initial implementation by Gregg Helt, 2023
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux # heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
from builtins import float
import numpy as np import numpy as np
from typing import Literal, Optional, Union, List from typing import Literal, Optional, Union, List
from PIL import Image, ImageFilter, ImageOps from PIL import Image, ImageFilter, ImageOps
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, validator
from ..models.image import ImageField, ImageCategory, ResourceOrigin from ..models.image import ImageField, ImageCategory, ResourceOrigin
from .baseinvocation import ( from .baseinvocation import (
@ -14,6 +15,7 @@ from .baseinvocation import (
InvocationContext, InvocationContext,
InvocationConfig, InvocationConfig,
) )
from controlnet_aux import ( from controlnet_aux import (
CannyDetector, CannyDetector,
HEDdetector, HEDdetector,
@ -96,15 +98,32 @@ CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
class ControlField(BaseModel): class ControlField(BaseModel):
image: ImageField = Field(default=None, description="The control image") image: ImageField = Field(default=None, description="The control image")
control_model: Optional[str] = Field(default=None, description="The ControlNet model to use") control_model: Optional[str] = Field(default=None, description="The ControlNet model to use")
control_weight: Optional[float] = Field(default=1, description="The weight given to the ControlNet") # control_weight: Optional[float] = Field(default=1, description="weight given to controlnet")
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
begin_step_percent: float = Field(default=0, ge=0, le=1, begin_step_percent: float = Field(default=0, ge=0, le=1,
description="When the ControlNet is first applied (% of total steps)") description="When the ControlNet is first applied (% of total steps)")
end_step_percent: float = Field(default=1, ge=0, le=1, end_step_percent: float = Field(default=1, ge=0, le=1,
description="When the ControlNet is last applied (% of total steps)") description="When the ControlNet is last applied (% of total steps)")
@validator("control_weight")
def abs_le_one(cls, v):
"""validate that all abs(values) are <=1"""
if isinstance(v, list):
for i in v:
if abs(i) > 1:
raise ValueError('all abs(control_weight) must be <= 1')
else:
if abs(v) > 1:
raise ValueError('abs(control_weight) must be <= 1')
return v
class Config: class Config:
schema_extra = { schema_extra = {
"required": ["image", "control_model", "control_weight", "begin_step_percent", "end_step_percent"] "required": ["image", "control_model", "control_weight", "begin_step_percent", "end_step_percent"],
"ui": {
"type_hints": {
"control_weight": "float",
# "control_weight": "number",
}
}
} }
@ -112,7 +131,7 @@ class ControlOutput(BaseInvocationOutput):
"""node output for ControlNet info""" """node output for ControlNet info"""
# fmt: off # fmt: off
type: Literal["control_output"] = "control_output" type: Literal["control_output"] = "control_output"
control: ControlField = Field(default=None, description="The output control image") control: ControlField = Field(default=None, description="The control info")
# fmt: on # fmt: on
@ -123,15 +142,28 @@ class ControlNetInvocation(BaseInvocation):
# Inputs # Inputs
image: ImageField = Field(default=None, description="The control image") image: ImageField = Field(default=None, description="The control image")
control_model: CONTROLNET_NAME_VALUES = Field(default="lllyasviel/sd-controlnet-canny", control_model: CONTROLNET_NAME_VALUES = Field(default="lllyasviel/sd-controlnet-canny",
description="The ControlNet model to use") description="control model used")
control_weight: float = Field(default=1.0, ge=0, le=1, description="The weight given to the ControlNet") control_weight: Union[float, List[float]] = Field(default=1.0, description="The weight given to the ControlNet")
# TODO: add support in backend core for begin_step_percent, end_step_percent, guess_mode # TODO: add support in backend core for begin_step_percent, end_step_percent, guess_mode
begin_step_percent: float = Field(default=0, ge=0, le=1, begin_step_percent: float = Field(default=0, ge=0, le=1,
description="When the ControlNet is first applied (% of total steps)") description="When the ControlNet is first applied (% of total steps)")
end_step_percent: float = Field(default=1, ge=0, le=1, end_step_percent: float = Field(default=1, ge=0, le=1,
description="When the ControlNet is last applied (% of total steps)") description="When the ControlNet is last applied (% of total steps)")
# fmt: on # fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["latents"],
"type_hints": {
"model": "model",
"control": "control",
# "cfg_scale": "float",
"cfg_scale": "number",
"control_weight": "float",
}
},
}
def invoke(self, context: InvocationContext) -> ControlOutput: def invoke(self, context: InvocationContext) -> ControlOutput:
@ -161,7 +193,6 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
return image return image
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
raw_image = context.services.images.get_pil_image( raw_image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name self.image.image_origin, self.image.image_name
) )

View File

@ -3,8 +3,6 @@
from contextlib import ExitStack from contextlib import ExitStack
from typing import List, Literal, Optional, Union from typing import List, Literal, Optional, Union
import einops
from pydantic import BaseModel, Field, validator from pydantic import BaseModel, Field, validator
import torch import torch
from diffusers import ControlNetModel from diffusers import ControlNetModel
@ -173,23 +171,36 @@ class TextToLatentsInvocation(BaseInvocation):
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation") negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
noise: Optional[LatentsField] = Field(description="The noise to use") noise: Optional[LatentsField] = Field(description="The noise to use")
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image") steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", ) cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" ) scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
unet: UNetField = Field(default=None, description="UNet submodel") unet: UNetField = Field(default=None, description="UNet submodel")
control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use") control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
#seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
#seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
# fmt: on # fmt: on
@validator("cfg_scale")
def ge_one(cls, v):
"""validate that all cfg_scale values are >= 1"""
if isinstance(v, list):
for i in v:
if i < 1:
raise ValueError('cfg_scale must be greater than 1')
else:
if v < 1:
raise ValueError('cfg_scale must be greater than 1')
return v
# Schema customisation # Schema customisation
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {
"tags": ["latents", "image"], "tags": ["latents"],
"type_hints": { "type_hints": {
"model": "model", "model": "model",
"control": "control", "control": "control",
# "cfg_scale": "float",
"cfg_scale": "number"
} }
}, },
} }
@ -210,10 +221,10 @@ class TextToLatentsInvocation(BaseInvocation):
uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name) uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
conditioning_data = ConditioningData( conditioning_data = ConditioningData(
uc, unconditioned_embeddings=uc,
c, text_embeddings=c,
self.cfg_scale, guidance_scale=self.cfg_scale,
extra_conditioning_info, extra=extra_conditioning_info,
postprocessing_settings=PostprocessingSettings( postprocessing_settings=PostprocessingSettings(
threshold=0.0,#threshold, threshold=0.0,#threshold,
warmup=0.2,#warmup, warmup=0.2,#warmup,
@ -351,10 +362,10 @@ class TextToLatentsInvocation(BaseInvocation):
loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras] loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
print("type of control input: ", type(self.control)) control_data = self.prep_control_data(model=model, context=context, control_input=self.control,
control_data = self.prep_control_data(model=pipeline, context=context, control_input=self.control, latents_shape=noise.shape,
latents_shape=noise.shape, # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=(self.cfg_scale >= 1.0)) do_classifier_free_guidance=True,)
with ModelPatcher.apply_lora_unet(pipeline.unet, loras): with ModelPatcher.apply_lora_unet(pipeline.unet, loras):
# TODO: Verify the noise is the right size # TODO: Verify the noise is the right size
@ -364,7 +375,7 @@ class TextToLatentsInvocation(BaseInvocation):
num_inference_steps=self.steps, num_inference_steps=self.steps,
conditioning_data=conditioning_data, conditioning_data=conditioning_data,
control_data=control_data, # list[ControlNetData] control_data=control_data, # list[ControlNetData]
callback=step_callback callback=step_callback,
) )
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
@ -391,6 +402,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
"type_hints": { "type_hints": {
"model": "model", "model": "model",
"control": "control", "control": "control",
"cfg_scale": "number",
} }
}, },
} }
@ -421,6 +433,12 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
pipeline = self.create_pipeline(unet, scheduler) pipeline = self.create_pipeline(unet, scheduler)
conditioning_data = self.get_conditioning_data(context, scheduler) conditioning_data = self.get_conditioning_data(context, scheduler)
control_data = self.prep_control_data(model=model, context=context, control_input=self.control,
latents_shape=noise.shape,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,
)
# TODO: Verify the noise is the right size # TODO: Verify the noise is the right size
initial_latents = latent if self.strength < 1.0 else torch.zeros_like( initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
@ -442,6 +460,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
noise=noise, noise=noise,
num_inference_steps=self.steps, num_inference_steps=self.steps,
conditioning_data=conditioning_data, conditioning_data=conditioning_data,
control_data=control_data, # list[ControlNetData]
callback=step_callback callback=step_callback
) )

View File

@ -0,0 +1,237 @@
import io
from typing import Literal, Optional, Any
# from PIL.Image import Image
import PIL.Image
from matplotlib.ticker import MaxNLocator
from matplotlib.figure import Figure
from pydantic import BaseModel, Field
import numpy as np
import matplotlib.pyplot as plt
from easing_functions import (
LinearInOut,
QuadEaseInOut, QuadEaseIn, QuadEaseOut,
CubicEaseInOut, CubicEaseIn, CubicEaseOut,
QuarticEaseInOut, QuarticEaseIn, QuarticEaseOut,
QuinticEaseInOut, QuinticEaseIn, QuinticEaseOut,
SineEaseInOut, SineEaseIn, SineEaseOut,
CircularEaseIn, CircularEaseInOut, CircularEaseOut,
ExponentialEaseInOut, ExponentialEaseIn, ExponentialEaseOut,
ElasticEaseIn, ElasticEaseInOut, ElasticEaseOut,
BackEaseIn, BackEaseInOut, BackEaseOut,
BounceEaseIn, BounceEaseInOut, BounceEaseOut)
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
InvocationContext,
InvocationConfig,
)
from ...backend.util.logging import InvokeAILogger
from .collections import FloatCollectionOutput
class FloatLinearRangeInvocation(BaseInvocation):
"""Creates a range"""
type: Literal["float_range"] = "float_range"
# Inputs
start: float = Field(default=5, description="The first value of the range")
stop: float = Field(default=10, description="The last value of the range")
steps: int = Field(default=30, description="number of values to interpolate over (including start and stop)")
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
param_list = list(np.linspace(self.start, self.stop, self.steps))
return FloatCollectionOutput(
collection=param_list
)
EASING_FUNCTIONS_MAP = {
"Linear": LinearInOut,
"QuadIn": QuadEaseIn,
"QuadOut": QuadEaseOut,
"QuadInOut": QuadEaseInOut,
"CubicIn": CubicEaseIn,
"CubicOut": CubicEaseOut,
"CubicInOut": CubicEaseInOut,
"QuarticIn": QuarticEaseIn,
"QuarticOut": QuarticEaseOut,
"QuarticInOut": QuarticEaseInOut,
"QuinticIn": QuinticEaseIn,
"QuinticOut": QuinticEaseOut,
"QuinticInOut": QuinticEaseInOut,
"SineIn": SineEaseIn,
"SineOut": SineEaseOut,
"SineInOut": SineEaseInOut,
"CircularIn": CircularEaseIn,
"CircularOut": CircularEaseOut,
"CircularInOut": CircularEaseInOut,
"ExponentialIn": ExponentialEaseIn,
"ExponentialOut": ExponentialEaseOut,
"ExponentialInOut": ExponentialEaseInOut,
"ElasticIn": ElasticEaseIn,
"ElasticOut": ElasticEaseOut,
"ElasticInOut": ElasticEaseInOut,
"BackIn": BackEaseIn,
"BackOut": BackEaseOut,
"BackInOut": BackEaseInOut,
"BounceIn": BounceEaseIn,
"BounceOut": BounceEaseOut,
"BounceInOut": BounceEaseInOut,
}
EASING_FUNCTION_KEYS: Any = Literal[
tuple(list(EASING_FUNCTIONS_MAP.keys()))
]
# actually I think for now could just use CollectionOutput (which is list[Any]
class StepParamEasingInvocation(BaseInvocation):
"""Experimental per-step parameter easing for denoising steps"""
type: Literal["step_param_easing"] = "step_param_easing"
# Inputs
# fmt: off
easing: EASING_FUNCTION_KEYS = Field(default="Linear", description="The easing function to use")
num_steps: int = Field(default=20, description="number of denoising steps")
start_value: float = Field(default=0.0, description="easing starting value")
end_value: float = Field(default=1.0, description="easing ending value")
start_step_percent: float = Field(default=0.0, description="fraction of steps at which to start easing")
end_step_percent: float = Field(default=1.0, description="fraction of steps after which to end easing")
# if None, then start_value is used prior to easing start
pre_start_value: Optional[float] = Field(default=None, description="value before easing start")
# if None, then end value is used prior to easing end
post_end_value: Optional[float] = Field(default=None, description="value after easing end")
mirror: bool = Field(default=False, description="include mirror of easing function")
# FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely
# alt_mirror: bool = Field(default=False, description="alternative mirroring by dual easing")
show_easing_plot: bool = Field(default=False, description="show easing plot")
# fmt: on
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
log_diagnostics = False
# convert from start_step_percent to nearest step <= (steps * start_step_percent)
# start_step = int(np.floor(self.num_steps * self.start_step_percent))
start_step = int(np.round(self.num_steps * self.start_step_percent))
# convert from end_step_percent to nearest step >= (steps * end_step_percent)
# end_step = int(np.ceil((self.num_steps - 1) * self.end_step_percent))
end_step = int(np.round((self.num_steps - 1) * self.end_step_percent))
# end_step = int(np.ceil(self.num_steps * self.end_step_percent))
num_easing_steps = end_step - start_step + 1
# num_presteps = max(start_step - 1, 0)
num_presteps = start_step
num_poststeps = self.num_steps - (num_presteps + num_easing_steps)
prelist = list(num_presteps * [self.pre_start_value])
postlist = list(num_poststeps * [self.post_end_value])
if log_diagnostics:
logger = InvokeAILogger.getLogger(name="StepParamEasing")
logger.debug("start_step: " + str(start_step))
logger.debug("end_step: " + str(end_step))
logger.debug("num_easing_steps: " + str(num_easing_steps))
logger.debug("num_presteps: " + str(num_presteps))
logger.debug("num_poststeps: " + str(num_poststeps))
logger.debug("prelist size: " + str(len(prelist)))
logger.debug("postlist size: " + str(len(postlist)))
logger.debug("prelist: " + str(prelist))
logger.debug("postlist: " + str(postlist))
easing_class = EASING_FUNCTIONS_MAP[self.easing]
if log_diagnostics:
logger.debug("easing class: " + str(easing_class))
easing_list = list()
if self.mirror: # "expected" mirroring
# if number of steps is even, squeeze duration down to (number_of_steps)/2
# and create reverse copy of list to append
# if number of steps is odd, squeeze duration down to ceil(number_of_steps/2)
# and create reverse copy of list[1:end-1]
# but if even then number_of_steps/2 === ceil(number_of_steps/2), so can just use ceil always
base_easing_duration = int(np.ceil(num_easing_steps/2.0))
if log_diagnostics: logger.debug("base easing duration: " + str(base_easing_duration))
even_num_steps = (num_easing_steps % 2 == 0) # even number of steps
easing_function = easing_class(start=self.start_value,
end=self.end_value,
duration=base_easing_duration - 1)
base_easing_vals = list()
for step_index in range(base_easing_duration):
easing_val = easing_function.ease(step_index)
base_easing_vals.append(easing_val)
if log_diagnostics:
logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(easing_val))
if even_num_steps:
mirror_easing_vals = list(reversed(base_easing_vals))
else:
mirror_easing_vals = list(reversed(base_easing_vals[0:-1]))
if log_diagnostics:
logger.debug("base easing vals: " + str(base_easing_vals))
logger.debug("mirror easing vals: " + str(mirror_easing_vals))
easing_list = base_easing_vals + mirror_easing_vals
# FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely
# elif self.alt_mirror: # function mirroring (unintuitive behavior (at least to me))
# # half_ease_duration = round(num_easing_steps - 1 / 2)
# half_ease_duration = round((num_easing_steps - 1) / 2)
# easing_function = easing_class(start=self.start_value,
# end=self.end_value,
# duration=half_ease_duration,
# )
#
# mirror_function = easing_class(start=self.end_value,
# end=self.start_value,
# duration=half_ease_duration,
# )
# for step_index in range(num_easing_steps):
# if step_index <= half_ease_duration:
# step_val = easing_function.ease(step_index)
# else:
# step_val = mirror_function.ease(step_index - half_ease_duration)
# easing_list.append(step_val)
# if log_diagnostics: logger.debug(step_index, step_val)
#
else: # no mirroring (default)
easing_function = easing_class(start=self.start_value,
end=self.end_value,
duration=num_easing_steps - 1)
for step_index in range(num_easing_steps):
step_val = easing_function.ease(step_index)
easing_list.append(step_val)
if log_diagnostics:
logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(step_val))
if log_diagnostics:
logger.debug("prelist size: " + str(len(prelist)))
logger.debug("easing_list size: " + str(len(easing_list)))
logger.debug("postlist size: " + str(len(postlist)))
param_list = prelist + easing_list + postlist
if self.show_easing_plot:
plt.figure()
plt.xlabel("Step")
plt.ylabel("Param Value")
plt.title("Per-Step Values Based On Easing: " + self.easing)
plt.bar(range(len(param_list)), param_list)
# plt.plot(param_list)
ax = plt.gca()
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
buf = io.BytesIO()
plt.savefig(buf, format='png')
buf.seek(0)
im = PIL.Image.open(buf)
im.show()
buf.close()
# output array of size steps, each entry list[i] is param value for step i
return FloatCollectionOutput(
collection=param_list
)

View File

@ -1,4 +1,4 @@
from typing import Optional from typing import Optional, Union, List
from pydantic import BaseModel, Extra, Field, StrictFloat, StrictInt, StrictStr from pydantic import BaseModel, Extra, Field, StrictFloat, StrictInt, StrictStr
@ -47,7 +47,9 @@ class ImageMetadata(BaseModel):
default=None, description="The seed used for noise generation." default=None, description="The seed used for noise generation."
) )
"""The seed used for noise generation""" """The seed used for noise generation"""
cfg_scale: Optional[StrictFloat] = Field( # cfg_scale: Optional[StrictFloat] = Field(
# cfg_scale: Union[float, list[float]] = Field(
cfg_scale: Union[StrictFloat, List[StrictFloat]] = Field(
default=None, description="The classifier-free guidance scale." default=None, description="The classifier-free guidance scale."
) )
"""The classifier-free guidance scale""" """The classifier-free guidance scale"""

View File

@ -65,7 +65,6 @@ from typing import Optional, Union, List, get_args
def is_union_subtype(t1, t2): def is_union_subtype(t1, t2):
t1_args = get_args(t1) t1_args = get_args(t1)
t2_args = get_args(t2) t2_args = get_args(t2)
if not t1_args: if not t1_args:
# t1 is a single type # t1 is a single type
return t1 in t2_args return t1 in t2_args
@ -86,7 +85,6 @@ def is_list_or_contains_list(t):
for arg in t_args: for arg in t_args:
if get_origin(arg) is list: if get_origin(arg) is list:
return True return True
return False return False
@ -393,7 +391,7 @@ class Graph(BaseModel):
from_node = self.get_node(edge.source.node_id) from_node = self.get_node(edge.source.node_id)
to_node = self.get_node(edge.destination.node_id) to_node = self.get_node(edge.destination.node_id)
except NodeNotFoundError: except NodeNotFoundError:
raise InvalidEdgeError("One or both nodes don't exist") raise InvalidEdgeError("One or both nodes don't exist: {edge.source.node_id} -> {edge.destination.node_id}")
# Validate that an edge to this node+field doesn't already exist # Validate that an edge to this node+field doesn't already exist
input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field) input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field)
@ -404,41 +402,41 @@ class Graph(BaseModel):
g = self.nx_graph_flat() g = self.nx_graph_flat()
g.add_edge(edge.source.node_id, edge.destination.node_id) g.add_edge(edge.source.node_id, edge.destination.node_id)
if not nx.is_directed_acyclic_graph(g): if not nx.is_directed_acyclic_graph(g):
raise InvalidEdgeError(f'Edge creates a cycle in the graph') raise InvalidEdgeError(f'Edge creates a cycle in the graph: {edge.source.node_id} -> {edge.destination.node_id}')
# Validate that the field types are compatible # Validate that the field types are compatible
if not are_connections_compatible( if not are_connections_compatible(
from_node, edge.source.field, to_node, edge.destination.field from_node, edge.source.field, to_node, edge.destination.field
): ):
raise InvalidEdgeError(f'Fields are incompatible') raise InvalidEdgeError(f'Fields are incompatible: cannot connect {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
# Validate if iterator output type matches iterator input type (if this edge results in both being set) # Validate if iterator output type matches iterator input type (if this edge results in both being set)
if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection": if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection":
if not self._is_iterator_connection_valid( if not self._is_iterator_connection_valid(
edge.destination.node_id, new_input=edge.source edge.destination.node_id, new_input=edge.source
): ):
raise InvalidEdgeError(f'Iterator input type does not match iterator output type') raise InvalidEdgeError(f'Iterator input type does not match iterator output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
# Validate if iterator input type matches output type (if this edge results in both being set) # Validate if iterator input type matches output type (if this edge results in both being set)
if isinstance(from_node, IterateInvocation) and edge.source.field == "item": if isinstance(from_node, IterateInvocation) and edge.source.field == "item":
if not self._is_iterator_connection_valid( if not self._is_iterator_connection_valid(
edge.source.node_id, new_output=edge.destination edge.source.node_id, new_output=edge.destination
): ):
raise InvalidEdgeError(f'Iterator output type does not match iterator input type') raise InvalidEdgeError(f'Iterator output type does not match iterator input type:, {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
# Validate if collector input type matches output type (if this edge results in both being set) # Validate if collector input type matches output type (if this edge results in both being set)
if isinstance(to_node, CollectInvocation) and edge.destination.field == "item": if isinstance(to_node, CollectInvocation) and edge.destination.field == "item":
if not self._is_collector_connection_valid( if not self._is_collector_connection_valid(
edge.destination.node_id, new_input=edge.source edge.destination.node_id, new_input=edge.source
): ):
raise InvalidEdgeError(f'Collector output type does not match collector input type') raise InvalidEdgeError(f'Collector output type does not match collector input type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
# Validate if collector output type matches input type (if this edge results in both being set) # Validate if collector output type matches input type (if this edge results in both being set)
if isinstance(from_node, CollectInvocation) and edge.source.field == "collection": if isinstance(from_node, CollectInvocation) and edge.source.field == "collection":
if not self._is_collector_connection_valid( if not self._is_collector_connection_valid(
edge.source.node_id, new_output=edge.destination edge.source.node_id, new_output=edge.destination
): ):
raise InvalidEdgeError(f'Collector input type does not match collector output type') raise InvalidEdgeError(f'Collector input type does not match collector output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
def has_node(self, node_path: str) -> bool: def has_node(self, node_path: str) -> bool:

View File

@ -6,7 +6,7 @@ import torch
from typing import Any, Callable, Dict, List, Optional, Tuple, Union from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from diffusers.models.controlnet import ControlNetModel, ControlNetOutput from diffusers.models.controlnet import ControlNetModel, ControlNetOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel from diffusers.pipelines.controlnet import MultiControlNetModel
from ..stable_diffusion import ( from ..stable_diffusion import (
ConditioningData, ConditioningData,

View File

@ -23,7 +23,7 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import ( from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
StableDiffusionPipeline, StableDiffusionPipeline,
) )
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel from diffusers.pipelines.controlnet import MultiControlNetModel
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import ( from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import (
StableDiffusionImg2ImgPipeline, StableDiffusionImg2ImgPipeline,
@ -218,7 +218,7 @@ class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
class ControlNetData: class ControlNetData:
model: ControlNetModel = Field(default=None) model: ControlNetModel = Field(default=None)
image_tensor: torch.Tensor= Field(default=None) image_tensor: torch.Tensor= Field(default=None)
weight: float = Field(default=1.0) weight: Union[float, List[float]]= Field(default=1.0)
begin_step_percent: float = Field(default=0.0) begin_step_percent: float = Field(default=0.0)
end_step_percent: float = Field(default=1.0) end_step_percent: float = Field(default=1.0)
@ -226,7 +226,7 @@ class ControlNetData:
class ConditioningData: class ConditioningData:
unconditioned_embeddings: torch.Tensor unconditioned_embeddings: torch.Tensor
text_embeddings: torch.Tensor text_embeddings: torch.Tensor
guidance_scale: float guidance_scale: Union[float, List[float]]
""" """
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
@ -662,7 +662,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
down_block_res_samples, mid_block_res_sample = None, None down_block_res_samples, mid_block_res_sample = None, None
if control_data is not None: if control_data is not None:
if conditioning_data.guidance_scale > 1.0: # FIXME: make sure guidance_scale < 1.0 is handled correctly if doing per-step guidance setting
# if conditioning_data.guidance_scale > 1.0:
if conditioning_data.guidance_scale is not None:
# expand the latents input to control model if doing classifier free guidance # expand the latents input to control model if doing classifier free guidance
# (which I think for now is always true, there is conditional elsewhere that stops execution if # (which I think for now is always true, there is conditional elsewhere that stops execution if
# classifier_free_guidance is <= 1.0 ?) # classifier_free_guidance is <= 1.0 ?)
@ -679,13 +681,19 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# only apply controlnet if current step is within the controlnet's begin/end step range # only apply controlnet if current step is within the controlnet's begin/end step range
if step_index >= first_control_step and step_index <= last_control_step: if step_index >= first_control_step and step_index <= last_control_step:
# print("running controlnet", i, "for step", step_index) # print("running controlnet", i, "for step", step_index)
if isinstance(control_datum.weight, list):
# if controlnet has multiple weights, use the weight for the current step
controlnet_weight = control_datum.weight[step_index]
else:
# if controlnet has a single weight, use it for all steps
controlnet_weight = control_datum.weight
down_samples, mid_sample = control_datum.model( down_samples, mid_sample = control_datum.model(
sample=latent_control_input, sample=latent_control_input,
timestep=timestep, timestep=timestep,
encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings, encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings,
conditioning_data.text_embeddings]), conditioning_data.text_embeddings]),
controlnet_cond=control_datum.image_tensor, controlnet_cond=control_datum.image_tensor,
conditioning_scale=control_datum.weight, conditioning_scale=controlnet_weight,
# cross_attention_kwargs, # cross_attention_kwargs,
guess_mode=False, guess_mode=False,
return_dict=False, return_dict=False,

View File

@ -1,7 +1,7 @@
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from math import ceil from math import ceil
from typing import Any, Callable, Dict, Optional, Union from typing import Any, Callable, Dict, Optional, Union, List
import numpy as np import numpy as np
import torch import torch
@ -180,7 +180,8 @@ class InvokeAIDiffuserComponent:
sigma: torch.Tensor, sigma: torch.Tensor,
unconditioning: Union[torch.Tensor, dict], unconditioning: Union[torch.Tensor, dict],
conditioning: Union[torch.Tensor, dict], conditioning: Union[torch.Tensor, dict],
unconditional_guidance_scale: float, # unconditional_guidance_scale: float,
unconditional_guidance_scale: Union[float, List[float]],
step_index: Optional[int] = None, step_index: Optional[int] = None,
total_step_count: Optional[int] = None, total_step_count: Optional[int] = None,
**kwargs, **kwargs,
@ -195,6 +196,11 @@ class InvokeAIDiffuserComponent:
:return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning. :return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning.
""" """
if isinstance(unconditional_guidance_scale, list):
guidance_scale = unconditional_guidance_scale[step_index]
else:
guidance_scale = unconditional_guidance_scale
cross_attention_control_types_to_do = [] cross_attention_control_types_to_do = []
context: Context = self.cross_attention_control_context context: Context = self.cross_attention_control_context
if self.cross_attention_control_context is not None: if self.cross_attention_control_context is not None:
@ -243,7 +249,8 @@ class InvokeAIDiffuserComponent:
) )
combined_next_x = self._combine( combined_next_x = self._combine(
unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale # unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale
unconditioned_next_x, conditioned_next_x, guidance_scale
) )
return combined_next_x return combined_next_x
@ -497,7 +504,7 @@ class InvokeAIDiffuserComponent:
logger.debug( logger.debug(
f"min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}" f"min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}"
) )
logger.debug( logger.debug(
f"{outside / latents.numel() * 100:.2f}% values outside threshold" f"{outside / latents.numel() * 100:.2f}% values outside threshold"
) )

View File

@ -1,3 +1,7 @@
import {
CONTROLNET_MODELS,
CONTROLNET_PROCESSORS,
} from 'features/controlNet/store/constants';
import { InvokeTabName } from 'features/ui/store/tabMap'; import { InvokeTabName } from 'features/ui/store/tabMap';
import { O } from 'ts-toolbelt'; import { O } from 'ts-toolbelt';
@ -117,6 +121,8 @@ export type AppConfig = {
canRestoreDeletedImagesFromBin: boolean; canRestoreDeletedImagesFromBin: boolean;
sd: { sd: {
defaultModel?: string; defaultModel?: string;
disabledControlNetModels: (keyof typeof CONTROLNET_MODELS)[];
disabledControlNetProcessors: (keyof typeof CONTROLNET_PROCESSORS)[];
iterations: { iterations: {
initial: number; initial: number;
min: number; min: number;

View File

@ -18,7 +18,7 @@ import { useSelect } from 'downshift';
import { isString } from 'lodash-es'; import { isString } from 'lodash-es';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react'; import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import { memo, useMemo } from 'react'; import { memo, useLayoutEffect, useMemo } from 'react';
import { getInputOutlineStyles } from 'theme/util/getInputOutlineStyles'; import { getInputOutlineStyles } from 'theme/util/getInputOutlineStyles';
export type ItemTooltips = { [key: string]: string }; export type ItemTooltips = { [key: string]: string };
@ -39,6 +39,7 @@ type IAICustomSelectProps = {
tooltip?: string; tooltip?: string;
tooltipProps?: Omit<TooltipProps, 'children'>; tooltipProps?: Omit<TooltipProps, 'children'>;
ellipsisPosition?: 'start' | 'end'; ellipsisPosition?: 'start' | 'end';
isDisabled?: boolean;
}; };
const IAICustomSelect = (props: IAICustomSelectProps) => { const IAICustomSelect = (props: IAICustomSelectProps) => {
@ -52,6 +53,7 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
data, data,
value, value,
onChange, onChange,
isDisabled = false,
} = props; } = props;
const values = useMemo(() => { const values = useMemo(() => {
@ -86,11 +88,17 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
}, },
}); });
const { refs, floatingStyles } = useFloating<HTMLButtonElement>({ const { refs, floatingStyles, update } = useFloating<HTMLButtonElement>({
whileElementsMounted: autoUpdate, // whileElementsMounted: autoUpdate,
middleware: [offset(4), shift({ crossAxis: true, padding: 8 })], middleware: [offset(4), shift({ crossAxis: true, padding: 8 })],
}); });
useLayoutEffect(() => {
if (isOpen && refs.reference.current && refs.floating.current) {
return autoUpdate(refs.reference.current, refs.floating.current, update);
}
}, [isOpen, update, refs.floating, refs.reference]);
const labelTextDirection = useMemo(() => { const labelTextDirection = useMemo(() => {
if (ellipsisPosition === 'start') { if (ellipsisPosition === 'start') {
return document.dir === 'rtl' ? 'ltr' : 'rtl'; return document.dir === 'rtl' ? 'ltr' : 'rtl';
@ -124,6 +132,8 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
px: 2, px: 2,
gap: 2, gap: 2,
justifyContent: 'space-between', justifyContent: 'space-between',
pointerEvents: isDisabled ? 'none' : undefined,
opacity: isDisabled ? 0.5 : undefined,
...getInputOutlineStyles(), ...getInputOutlineStyles(),
}} }}
> >

View File

@ -1,4 +1,5 @@
import { import {
ChakraProps,
FormControl, FormControl,
FormControlProps, FormControlProps,
FormLabel, FormLabel,
@ -39,6 +40,11 @@ import { BiReset } from 'react-icons/bi';
import IAIIconButton, { IAIIconButtonProps } from './IAIIconButton'; import IAIIconButton, { IAIIconButtonProps } from './IAIIconButton';
import { roundDownToMultiple } from 'common/util/roundDownToMultiple'; import { roundDownToMultiple } from 'common/util/roundDownToMultiple';
const SLIDER_MARK_STYLES: ChakraProps['sx'] = {
mt: 1.5,
fontSize: '2xs',
};
export type IAIFullSliderProps = { export type IAIFullSliderProps = {
label?: string; label?: string;
value: number; value: number;
@ -57,6 +63,7 @@ export type IAIFullSliderProps = {
hideTooltip?: boolean; hideTooltip?: boolean;
isCompact?: boolean; isCompact?: boolean;
isDisabled?: boolean; isDisabled?: boolean;
sliderMarks?: number[];
sliderFormControlProps?: FormControlProps; sliderFormControlProps?: FormControlProps;
sliderFormLabelProps?: FormLabelProps; sliderFormLabelProps?: FormLabelProps;
sliderMarkProps?: Omit<SliderMarkProps, 'value'>; sliderMarkProps?: Omit<SliderMarkProps, 'value'>;
@ -88,6 +95,7 @@ const IAISlider = (props: IAIFullSliderProps) => {
hideTooltip = false, hideTooltip = false,
isCompact = false, isCompact = false,
isDisabled = false, isDisabled = false,
sliderMarks,
handleReset, handleReset,
sliderFormControlProps, sliderFormControlProps,
sliderFormLabelProps, sliderFormLabelProps,
@ -198,14 +206,14 @@ const IAISlider = (props: IAIFullSliderProps) => {
isDisabled={isDisabled} isDisabled={isDisabled}
{...rest} {...rest}
> >
{withSliderMarks && ( {withSliderMarks && !sliderMarks && (
<> <>
<SliderMark <SliderMark
value={min} value={min}
sx={{ sx={{
insetInlineStart: '0 !important', insetInlineStart: '0 !important',
insetInlineEnd: 'unset !important', insetInlineEnd: 'unset !important',
mt: 1.5, ...SLIDER_MARK_STYLES,
}} }}
{...sliderMarkProps} {...sliderMarkProps}
> >
@ -216,7 +224,7 @@ const IAISlider = (props: IAIFullSliderProps) => {
sx={{ sx={{
insetInlineStart: 'unset !important', insetInlineStart: 'unset !important',
insetInlineEnd: '0 !important', insetInlineEnd: '0 !important',
mt: 1.5, ...SLIDER_MARK_STYLES,
}} }}
{...sliderMarkProps} {...sliderMarkProps}
> >
@ -224,6 +232,56 @@ const IAISlider = (props: IAIFullSliderProps) => {
</SliderMark> </SliderMark>
</> </>
)} )}
{withSliderMarks && sliderMarks && (
<>
{sliderMarks.map((m, i) => {
if (i === 0) {
return (
<SliderMark
key={m}
value={m}
sx={{
insetInlineStart: '0 !important',
insetInlineEnd: 'unset !important',
...SLIDER_MARK_STYLES,
}}
{...sliderMarkProps}
>
{m}
</SliderMark>
);
} else if (i === sliderMarks.length - 1) {
return (
<SliderMark
key={m}
value={m}
sx={{
insetInlineStart: 'unset !important',
insetInlineEnd: '0 !important',
...SLIDER_MARK_STYLES,
}}
{...sliderMarkProps}
>
{m}
</SliderMark>
);
} else {
return (
<SliderMark
key={m}
value={m}
sx={{
...SLIDER_MARK_STYLES,
}}
{...sliderMarkProps}
>
{m}
</SliderMark>
);
}
})}
</>
)}
<SliderTrack {...sliderTrackProps}> <SliderTrack {...sliderTrackProps}>
<SliderFilledTrack /> <SliderFilledTrack />

View File

@ -143,7 +143,7 @@ const ControlNet = (props: ControlNetProps) => {
flexDir: 'column', flexDir: 'column',
gap: 2, gap: 2,
w: 'full', w: 'full',
h: 24, h: isExpanded ? 28 : 24,
paddingInlineStart: 1, paddingInlineStart: 1,
paddingInlineEnd: isExpanded ? 1 : 0, paddingInlineEnd: isExpanded ? 1 : 0,
pb: 2, pb: 2,
@ -153,13 +153,13 @@ const ControlNet = (props: ControlNetProps) => {
<ParamControlNetWeight <ParamControlNetWeight
controlNetId={controlNetId} controlNetId={controlNetId}
weight={weight} weight={weight}
mini mini={!isExpanded}
/> />
<ParamControlNetBeginEnd <ParamControlNetBeginEnd
controlNetId={controlNetId} controlNetId={controlNetId}
beginStepPct={beginStepPct} beginStepPct={beginStepPct}
endStepPct={endStepPct} endStepPct={endStepPct}
mini mini={!isExpanded}
/> />
</Flex> </Flex>
{!isExpanded && ( {!isExpanded && (

View File

@ -1,5 +1,6 @@
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import IAISwitch from 'common/components/IAISwitch'; import IAISwitch from 'common/components/IAISwitch';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import { controlNetAutoConfigToggled } from 'features/controlNet/store/controlNetSlice'; import { controlNetAutoConfigToggled } from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
@ -11,7 +12,7 @@ type Props = {
const ParamControlNetShouldAutoConfig = (props: Props) => { const ParamControlNetShouldAutoConfig = (props: Props) => {
const { controlNetId, shouldAutoConfig } = props; const { controlNetId, shouldAutoConfig } = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const isReady = useIsReadyToInvoke();
const handleShouldAutoConfigChanged = useCallback(() => { const handleShouldAutoConfigChanged = useCallback(() => {
dispatch(controlNetAutoConfigToggled({ controlNetId })); dispatch(controlNetAutoConfigToggled({ controlNetId }));
}, [controlNetId, dispatch]); }, [controlNetId, dispatch]);
@ -22,6 +23,7 @@ const ParamControlNetShouldAutoConfig = (props: Props) => {
aria-label="Auto configure processor" aria-label="Auto configure processor"
isChecked={shouldAutoConfig} isChecked={shouldAutoConfig}
onChange={handleShouldAutoConfigChanged} onChange={handleShouldAutoConfigChanged}
isDisabled={!isReady}
/> />
); );
}; };

View File

@ -1,4 +1,5 @@
import { import {
ChakraProps,
FormControl, FormControl,
FormLabel, FormLabel,
HStack, HStack,
@ -10,14 +11,19 @@ import {
Tooltip, Tooltip,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import { import {
controlNetBeginStepPctChanged, controlNetBeginStepPctChanged,
controlNetEndStepPctChanged, controlNetEndStepPctChanged,
} from 'features/controlNet/store/controlNetSlice'; } from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { BiReset } from 'react-icons/bi';
const SLIDER_MARK_STYLES: ChakraProps['sx'] = {
mt: 1.5,
fontSize: '2xs',
fontWeight: '500',
color: 'base.400',
};
type Props = { type Props = {
controlNetId: string; controlNetId: string;
@ -29,7 +35,7 @@ type Props = {
const formatPct = (v: number) => `${Math.round(v * 100)}%`; const formatPct = (v: number) => `${Math.round(v * 100)}%`;
const ParamControlNetBeginEnd = (props: Props) => { const ParamControlNetBeginEnd = (props: Props) => {
const { controlNetId, beginStepPct, endStepPct, mini = false } = props; const { controlNetId, beginStepPct, mini = false, endStepPct } = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
@ -75,12 +81,9 @@ const ParamControlNetBeginEnd = (props: Props) => {
<RangeSliderMark <RangeSliderMark
value={0} value={0}
sx={{ sx={{
fontSize: 'xs',
fontWeight: '500',
color: 'base.200',
insetInlineStart: '0 !important', insetInlineStart: '0 !important',
insetInlineEnd: 'unset !important', insetInlineEnd: 'unset !important',
mt: 1.5, ...SLIDER_MARK_STYLES,
}} }}
> >
0% 0%
@ -88,10 +91,7 @@ const ParamControlNetBeginEnd = (props: Props) => {
<RangeSliderMark <RangeSliderMark
value={0.5} value={0.5}
sx={{ sx={{
fontSize: 'xs', ...SLIDER_MARK_STYLES,
fontWeight: '500',
color: 'base.200',
mt: 1.5,
}} }}
> >
50% 50%
@ -99,12 +99,9 @@ const ParamControlNetBeginEnd = (props: Props) => {
<RangeSliderMark <RangeSliderMark
value={1} value={1}
sx={{ sx={{
fontSize: 'xs',
fontWeight: '500',
color: 'base.200',
insetInlineStart: 'unset !important', insetInlineStart: 'unset !important',
insetInlineEnd: '0 !important', insetInlineEnd: '0 !important',
mt: 1.5, ...SLIDER_MARK_STYLES,
}} }}
> >
100% 100%
@ -112,16 +109,6 @@ const ParamControlNetBeginEnd = (props: Props) => {
</> </>
)} )}
</RangeSlider> </RangeSlider>
{!mini && (
<IAIIconButton
size="sm"
aria-label={t('accessibility.reset')}
tooltip={t('accessibility.reset')}
icon={<BiReset />}
onClick={handleStepPctReset}
/>
)}
</HStack> </HStack>
</FormControl> </FormControl>
); );

View File

@ -1,50 +1,85 @@
import { useAppDispatch } from 'app/store/storeHooks'; import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAICustomSelect, { import IAICustomSelect, {
IAICustomSelectOption, IAICustomSelectOption,
} from 'common/components/IAICustomSelect'; } from 'common/components/IAICustomSelect';
import IAISelect from 'common/components/IAISelect';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import { import {
CONTROLNET_MODELS, CONTROLNET_MODELS,
ControlNetModelName, ControlNetModelName,
} from 'features/controlNet/store/constants'; } from 'features/controlNet/store/constants';
import { controlNetModelChanged } from 'features/controlNet/store/controlNetSlice'; import { controlNetModelChanged } from 'features/controlNet/store/controlNetSlice';
import { configSelector } from 'features/system/store/configSelectors';
import { map } from 'lodash-es'; import { map } from 'lodash-es';
import { memo, useCallback } from 'react'; import { ChangeEvent, memo, useCallback } from 'react';
type ParamControlNetModelProps = { type ParamControlNetModelProps = {
controlNetId: string; controlNetId: string;
model: ControlNetModelName; model: ControlNetModelName;
}; };
const DATA: IAICustomSelectOption[] = map(CONTROLNET_MODELS, (m) => ({ const selector = createSelector(configSelector, (config) => {
value: m.type, return map(CONTROLNET_MODELS, (m) => ({
label: m.label, key: m.label,
tooltip: m.type, value: m.type,
})); })).filter((d) => !config.sd.disabledControlNetModels.includes(d.value));
});
// const DATA: IAICustomSelectOption[] = map(CONTROLNET_MODELS, (m) => ({
// value: m.type,
// label: m.label,
// tooltip: m.type,
// }));
const ParamControlNetModel = (props: ParamControlNetModelProps) => { const ParamControlNetModel = (props: ParamControlNetModelProps) => {
const { controlNetId, model } = props; const { controlNetId, model } = props;
const controlNetModels = useAppSelector(selector);
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const isReady = useIsReadyToInvoke();
const handleModelChanged = useCallback( const handleModelChanged = useCallback(
(val: string | null | undefined) => { (e: ChangeEvent<HTMLSelectElement>) => {
// TODO: do not cast // TODO: do not cast
const model = val as ControlNetModelName; const model = e.target.value as ControlNetModelName;
dispatch(controlNetModelChanged({ controlNetId, model })); dispatch(controlNetModelChanged({ controlNetId, model }));
}, },
[controlNetId, dispatch] [controlNetId, dispatch]
); );
// const handleModelChanged = useCallback(
// (val: string | null | undefined) => {
// // TODO: do not cast
// const model = val as ControlNetModelName;
// dispatch(controlNetModelChanged({ controlNetId, model }));
// },
// [controlNetId, dispatch]
// );
return ( return (
<IAICustomSelect <IAISelect
tooltip={model} tooltip={model}
tooltipProps={{ placement: 'top', hasArrow: true }} tooltipProps={{ placement: 'top', hasArrow: true }}
data={DATA} validValues={controlNetModels}
value={model} value={model}
onChange={handleModelChanged} onChange={handleModelChanged}
ellipsisPosition="start" isDisabled={!isReady}
withCheckIcon // ellipsisPosition="start"
// withCheckIcon
/> />
); );
// return (
// <IAICustomSelect
// tooltip={model}
// tooltipProps={{ placement: 'top', hasArrow: true }}
// data={DATA}
// value={model}
// onChange={handleModelChanged}
// isDisabled={!isReady}
// ellipsisPosition="start"
// withCheckIcon
// />
// );
}; };
export default memo(ParamControlNetModel); export default memo(ParamControlNetModel);

View File

@ -1,62 +1,115 @@
import IAICustomSelect, { import IAICustomSelect, {
IAICustomSelectOption, IAICustomSelectOption,
} from 'common/components/IAICustomSelect'; } from 'common/components/IAICustomSelect';
import { memo, useCallback } from 'react'; import { ChangeEvent, memo, useCallback } from 'react';
import { import {
ControlNetProcessorNode, ControlNetProcessorNode,
ControlNetProcessorType, ControlNetProcessorType,
} from '../../store/types'; } from '../../store/types';
import { controlNetProcessorTypeChanged } from '../../store/controlNetSlice'; import { controlNetProcessorTypeChanged } from '../../store/controlNetSlice';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { CONTROLNET_PROCESSORS } from '../../store/constants'; import { CONTROLNET_PROCESSORS } from '../../store/constants';
import { map } from 'lodash-es'; import { map } from 'lodash-es';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import IAISelect from 'common/components/IAISelect';
import { createSelector } from '@reduxjs/toolkit';
import { configSelector } from 'features/system/store/configSelectors';
type ParamControlNetProcessorSelectProps = { type ParamControlNetProcessorSelectProps = {
controlNetId: string; controlNetId: string;
processorNode: ControlNetProcessorNode; processorNode: ControlNetProcessorNode;
}; };
const CONTROLNET_PROCESSOR_TYPES: IAICustomSelectOption[] = map( const CONTROLNET_PROCESSOR_TYPES = map(CONTROLNET_PROCESSORS, (p) => ({
CONTROLNET_PROCESSORS, value: p.type,
(p) => ({ key: p.label,
value: p.type, })).sort((a, b) =>
label: p.label,
tooltip: p.description,
})
).sort((a, b) =>
// sort 'none' to the top // sort 'none' to the top
a.value === 'none' a.value === 'none' ? -1 : b.value === 'none' ? 1 : a.key.localeCompare(b.key)
? -1
: b.value === 'none'
? 1
: a.label.localeCompare(b.label)
); );
const selector = createSelector(configSelector, (config) => {
return map(CONTROLNET_PROCESSORS, (p) => ({
value: p.type,
key: p.label,
}))
.sort((a, b) =>
// sort 'none' to the top
a.value === 'none'
? -1
: b.value === 'none'
? 1
: a.key.localeCompare(b.key)
)
.filter((d) => !config.sd.disabledControlNetProcessors.includes(d.value));
});
// const CONTROLNET_PROCESSOR_TYPES: IAICustomSelectOption[] = map(
// CONTROLNET_PROCESSORS,
// (p) => ({
// value: p.type,
// label: p.label,
// tooltip: p.description,
// })
// ).sort((a, b) =>
// // sort 'none' to the top
// a.value === 'none'
// ? -1
// : b.value === 'none'
// ? 1
// : a.label.localeCompare(b.label)
// );
const ParamControlNetProcessorSelect = ( const ParamControlNetProcessorSelect = (
props: ParamControlNetProcessorSelectProps props: ParamControlNetProcessorSelectProps
) => { ) => {
const { controlNetId, processorNode } = props; const { controlNetId, processorNode } = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const isReady = useIsReadyToInvoke();
const controlNetProcessors = useAppSelector(selector);
const handleProcessorTypeChanged = useCallback( const handleProcessorTypeChanged = useCallback(
(v: string | null | undefined) => { (e: ChangeEvent<HTMLSelectElement>) => {
dispatch( dispatch(
controlNetProcessorTypeChanged({ controlNetProcessorTypeChanged({
controlNetId, controlNetId,
processorType: v as ControlNetProcessorType, processorType: e.target.value as ControlNetProcessorType,
}) })
); );
}, },
[controlNetId, dispatch] [controlNetId, dispatch]
); );
// const handleProcessorTypeChanged = useCallback(
// (v: string | null | undefined) => {
// dispatch(
// controlNetProcessorTypeChanged({
// controlNetId,
// processorType: v as ControlNetProcessorType,
// })
// );
// },
// [controlNetId, dispatch]
// );
return ( return (
<IAICustomSelect <IAISelect
label="Processor" label="Processor"
value={processorNode.type ?? 'canny_image_processor'} value={processorNode.type ?? 'canny_image_processor'}
data={CONTROLNET_PROCESSOR_TYPES} validValues={controlNetProcessors}
onChange={handleProcessorTypeChanged} onChange={handleProcessorTypeChanged}
withCheckIcon isDisabled={!isReady}
/> />
); );
// return (
// <IAICustomSelect
// label="Processor"
// value={processorNode.type ?? 'canny_image_processor'}
// data={CONTROLNET_PROCESSOR_TYPES}
// onChange={handleProcessorTypeChanged}
// withCheckIcon
// isDisabled={!isReady}
// />
// );
}; };
export default memo(ParamControlNetProcessorSelect); export default memo(ParamControlNetProcessorSelect);

View File

@ -20,36 +20,17 @@ const ParamControlNetWeight = (props: ParamControlNetWeightProps) => {
[controlNetId, dispatch] [controlNetId, dispatch]
); );
const handleWeightReset = () => {
dispatch(controlNetWeightChanged({ controlNetId, weight: 1 }));
};
if (mini) {
return (
<IAISlider
label={'Weight'}
sliderFormLabelProps={{ pb: 1 }}
value={weight}
onChange={handleWeightChanged}
min={0}
max={1}
step={0.01}
/>
);
}
return ( return (
<IAISlider <IAISlider
label="Weight" label={'Weight'}
sliderFormLabelProps={{ pb: 2 }}
value={weight} value={weight}
onChange={handleWeightChanged} onChange={handleWeightChanged}
withInput min={-1}
withReset
handleReset={handleWeightReset}
withSliderMarks
min={0}
max={1} max={1}
step={0.01} step={0.01}
withSliderMarks={!mini}
sliderMarks={[-1, 0, 1]}
/> />
); );
}; };

View File

@ -4,6 +4,7 @@ import { RequiredCannyImageProcessorInvocation } from 'features/controlNet/store
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper'; import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.canny_image_processor.default; const DEFAULTS = CONTROLNET_PROCESSORS.canny_image_processor.default;
@ -15,6 +16,7 @@ type CannyProcessorProps = {
const CannyProcessor = (props: CannyProcessorProps) => { const CannyProcessor = (props: CannyProcessorProps) => {
const { controlNetId, processorNode } = props; const { controlNetId, processorNode } = props;
const { low_threshold, high_threshold } = processorNode; const { low_threshold, high_threshold } = processorNode;
const isReady = useIsReadyToInvoke();
const processorChanged = useProcessorNodeChanged(); const processorChanged = useProcessorNodeChanged();
const handleLowThresholdChanged = useCallback( const handleLowThresholdChanged = useCallback(
@ -46,6 +48,7 @@ const CannyProcessor = (props: CannyProcessorProps) => {
return ( return (
<ProcessorWrapper> <ProcessorWrapper>
<IAISlider <IAISlider
isDisabled={!isReady}
label="Low Threshold" label="Low Threshold"
value={low_threshold} value={low_threshold}
onChange={handleLowThresholdChanged} onChange={handleLowThresholdChanged}
@ -54,8 +57,10 @@ const CannyProcessor = (props: CannyProcessorProps) => {
min={0} min={0}
max={255} max={255}
withInput withInput
withSliderMarks
/> />
<IAISlider <IAISlider
isDisabled={!isReady}
label="High Threshold" label="High Threshold"
value={high_threshold} value={high_threshold}
onChange={handleHighThresholdChanged} onChange={handleHighThresholdChanged}
@ -64,6 +69,7 @@ const CannyProcessor = (props: CannyProcessorProps) => {
min={0} min={0}
max={255} max={255}
withInput withInput
withSliderMarks
/> />
</ProcessorWrapper> </ProcessorWrapper>
); );

View File

@ -4,6 +4,7 @@ import { RequiredContentShuffleImageProcessorInvocation } from 'features/control
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper'; import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.content_shuffle_image_processor.default; const DEFAULTS = CONTROLNET_PROCESSORS.content_shuffle_image_processor.default;
@ -16,6 +17,7 @@ const ContentShuffleProcessor = (props: Props) => {
const { controlNetId, processorNode } = props; const { controlNetId, processorNode } = props;
const { image_resolution, detect_resolution, w, h, f } = processorNode; const { image_resolution, detect_resolution, w, h, f } = processorNode;
const processorChanged = useProcessorNodeChanged(); const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke();
const handleDetectResolutionChanged = useCallback( const handleDetectResolutionChanged = useCallback(
(v: number) => { (v: number) => {
@ -93,6 +95,8 @@ const ContentShuffleProcessor = (props: Props) => {
min={0} min={0}
max={4096} max={4096}
withInput withInput
withSliderMarks
isDisabled={!isReady}
/> />
<IAISlider <IAISlider
label="Image Resolution" label="Image Resolution"
@ -103,6 +107,8 @@ const ContentShuffleProcessor = (props: Props) => {
min={0} min={0}
max={4096} max={4096}
withInput withInput
withSliderMarks
isDisabled={!isReady}
/> />
<IAISlider <IAISlider
label="W" label="W"
@ -113,6 +119,8 @@ const ContentShuffleProcessor = (props: Props) => {
min={0} min={0}
max={4096} max={4096}
withInput withInput
withSliderMarks
isDisabled={!isReady}
/> />
<IAISlider <IAISlider
label="H" label="H"
@ -123,6 +131,8 @@ const ContentShuffleProcessor = (props: Props) => {
min={0} min={0}
max={4096} max={4096}
withInput withInput
withSliderMarks
isDisabled={!isReady}
/> />
<IAISlider <IAISlider
label="F" label="F"
@ -133,6 +143,8 @@ const ContentShuffleProcessor = (props: Props) => {
min={0} min={0}
max={4096} max={4096}
withInput withInput
withSliderMarks
isDisabled={!isReady}
/> />
</ProcessorWrapper> </ProcessorWrapper>
); );

View File

@ -5,6 +5,7 @@ import { RequiredHedImageProcessorInvocation } from 'features/controlNet/store/t
import { ChangeEvent, memo, useCallback } from 'react'; import { ChangeEvent, memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper'; import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.hed_image_processor.default; const DEFAULTS = CONTROLNET_PROCESSORS.hed_image_processor.default;
@ -18,7 +19,7 @@ const HedPreprocessor = (props: HedProcessorProps) => {
controlNetId, controlNetId,
processorNode: { detect_resolution, image_resolution, scribble }, processorNode: { detect_resolution, image_resolution, scribble },
} = props; } = props;
const isReady = useIsReadyToInvoke();
const processorChanged = useProcessorNodeChanged(); const processorChanged = useProcessorNodeChanged();
const handleDetectResolutionChanged = useCallback( const handleDetectResolutionChanged = useCallback(
@ -65,6 +66,8 @@ const HedPreprocessor = (props: HedProcessorProps) => {
min={0} min={0}
max={4096} max={4096}
withInput withInput
withSliderMarks
isDisabled={!isReady}
/> />
<IAISlider <IAISlider
label="Image Resolution" label="Image Resolution"
@ -75,11 +78,14 @@ const HedPreprocessor = (props: HedProcessorProps) => {
min={0} min={0}
max={4096} max={4096}
withInput withInput
withSliderMarks
isDisabled={!isReady}
/> />
<IAISwitch <IAISwitch
label="Scribble" label="Scribble"
isChecked={scribble} isChecked={scribble}
onChange={handleScribbleChanged} onChange={handleScribbleChanged}
isDisabled={!isReady}
/> />
</ProcessorWrapper> </ProcessorWrapper>
); );

View File

@ -4,6 +4,7 @@ import { RequiredLineartAnimeImageProcessorInvocation } from 'features/controlNe
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper'; import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.lineart_anime_image_processor.default; const DEFAULTS = CONTROLNET_PROCESSORS.lineart_anime_image_processor.default;
@ -16,6 +17,7 @@ const LineartAnimeProcessor = (props: Props) => {
const { controlNetId, processorNode } = props; const { controlNetId, processorNode } = props;
const { image_resolution, detect_resolution } = processorNode; const { image_resolution, detect_resolution } = processorNode;
const processorChanged = useProcessorNodeChanged(); const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke();
const handleDetectResolutionChanged = useCallback( const handleDetectResolutionChanged = useCallback(
(v: number) => { (v: number) => {
@ -54,6 +56,8 @@ const LineartAnimeProcessor = (props: Props) => {
min={0} min={0}
max={4096} max={4096}
withInput withInput
withSliderMarks
isDisabled={!isReady}
/> />
<IAISlider <IAISlider
label="Image Resolution" label="Image Resolution"
@ -64,6 +68,8 @@ const LineartAnimeProcessor = (props: Props) => {
min={0} min={0}
max={4096} max={4096}
withInput withInput
withSliderMarks
isDisabled={!isReady}
/> />
</ProcessorWrapper> </ProcessorWrapper>
); );

View File

@ -5,6 +5,7 @@ import { RequiredLineartImageProcessorInvocation } from 'features/controlNet/sto
import { ChangeEvent, memo, useCallback } from 'react'; import { ChangeEvent, memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper'; import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.lineart_image_processor.default; const DEFAULTS = CONTROLNET_PROCESSORS.lineart_image_processor.default;
@ -17,6 +18,7 @@ const LineartProcessor = (props: LineartProcessorProps) => {
const { controlNetId, processorNode } = props; const { controlNetId, processorNode } = props;
const { image_resolution, detect_resolution, coarse } = processorNode; const { image_resolution, detect_resolution, coarse } = processorNode;
const processorChanged = useProcessorNodeChanged(); const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke();
const handleDetectResolutionChanged = useCallback( const handleDetectResolutionChanged = useCallback(
(v: number) => { (v: number) => {
@ -62,6 +64,8 @@ const LineartProcessor = (props: LineartProcessorProps) => {
min={0} min={0}
max={4096} max={4096}
withInput withInput
withSliderMarks
isDisabled={!isReady}
/> />
<IAISlider <IAISlider
label="Image Resolution" label="Image Resolution"
@ -72,11 +76,14 @@ const LineartProcessor = (props: LineartProcessorProps) => {
min={0} min={0}
max={4096} max={4096}
withInput withInput
withSliderMarks
isDisabled={!isReady}
/> />
<IAISwitch <IAISwitch
label="Coarse" label="Coarse"
isChecked={coarse} isChecked={coarse}
onChange={handleCoarseChanged} onChange={handleCoarseChanged}
isDisabled={!isReady}
/> />
</ProcessorWrapper> </ProcessorWrapper>
); );

View File

@ -4,6 +4,7 @@ import { RequiredMediapipeFaceProcessorInvocation } from 'features/controlNet/st
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper'; import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.mediapipe_face_processor.default; const DEFAULTS = CONTROLNET_PROCESSORS.mediapipe_face_processor.default;
@ -16,6 +17,7 @@ const MediapipeFaceProcessor = (props: Props) => {
const { controlNetId, processorNode } = props; const { controlNetId, processorNode } = props;
const { max_faces, min_confidence } = processorNode; const { max_faces, min_confidence } = processorNode;
const processorChanged = useProcessorNodeChanged(); const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke();
const handleMaxFacesChanged = useCallback( const handleMaxFacesChanged = useCallback(
(v: number) => { (v: number) => {
@ -50,6 +52,8 @@ const MediapipeFaceProcessor = (props: Props) => {
min={1} min={1}
max={20} max={20}
withInput withInput
withSliderMarks
isDisabled={!isReady}
/> />
<IAISlider <IAISlider
label="Min Confidence" label="Min Confidence"
@ -61,6 +65,8 @@ const MediapipeFaceProcessor = (props: Props) => {
max={1} max={1}
step={0.01} step={0.01}
withInput withInput
withSliderMarks
isDisabled={!isReady}
/> />
</ProcessorWrapper> </ProcessorWrapper>
); );

View File

@ -4,6 +4,7 @@ import { RequiredMidasDepthImageProcessorInvocation } from 'features/controlNet/
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper'; import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.midas_depth_image_processor.default; const DEFAULTS = CONTROLNET_PROCESSORS.midas_depth_image_processor.default;
@ -16,6 +17,7 @@ const MidasDepthProcessor = (props: Props) => {
const { controlNetId, processorNode } = props; const { controlNetId, processorNode } = props;
const { a_mult, bg_th } = processorNode; const { a_mult, bg_th } = processorNode;
const processorChanged = useProcessorNodeChanged(); const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke();
const handleAMultChanged = useCallback( const handleAMultChanged = useCallback(
(v: number) => { (v: number) => {
@ -51,6 +53,8 @@ const MidasDepthProcessor = (props: Props) => {
max={20} max={20}
step={0.01} step={0.01}
withInput withInput
withSliderMarks
isDisabled={!isReady}
/> />
<IAISlider <IAISlider
label="bg_th" label="bg_th"
@ -62,6 +66,8 @@ const MidasDepthProcessor = (props: Props) => {
max={20} max={20}
step={0.01} step={0.01}
withInput withInput
withSliderMarks
isDisabled={!isReady}
/> />
</ProcessorWrapper> </ProcessorWrapper>
); );

View File

@ -4,6 +4,7 @@ import { RequiredMlsdImageProcessorInvocation } from 'features/controlNet/store/
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper'; import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.mlsd_image_processor.default; const DEFAULTS = CONTROLNET_PROCESSORS.mlsd_image_processor.default;
@ -16,6 +17,7 @@ const MlsdImageProcessor = (props: Props) => {
const { controlNetId, processorNode } = props; const { controlNetId, processorNode } = props;
const { image_resolution, detect_resolution, thr_d, thr_v } = processorNode; const { image_resolution, detect_resolution, thr_d, thr_v } = processorNode;
const processorChanged = useProcessorNodeChanged(); const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke();
const handleDetectResolutionChanged = useCallback( const handleDetectResolutionChanged = useCallback(
(v: number) => { (v: number) => {
@ -76,6 +78,8 @@ const MlsdImageProcessor = (props: Props) => {
min={0} min={0}
max={4096} max={4096}
withInput withInput
withSliderMarks
isDisabled={!isReady}
/> />
<IAISlider <IAISlider
label="Image Resolution" label="Image Resolution"
@ -86,6 +90,8 @@ const MlsdImageProcessor = (props: Props) => {
min={0} min={0}
max={4096} max={4096}
withInput withInput
withSliderMarks
isDisabled={!isReady}
/> />
<IAISlider <IAISlider
label="W" label="W"
@ -97,6 +103,8 @@ const MlsdImageProcessor = (props: Props) => {
max={1} max={1}
step={0.01} step={0.01}
withInput withInput
withSliderMarks
isDisabled={!isReady}
/> />
<IAISlider <IAISlider
label="H" label="H"
@ -108,6 +116,8 @@ const MlsdImageProcessor = (props: Props) => {
max={1} max={1}
step={0.01} step={0.01}
withInput withInput
withSliderMarks
isDisabled={!isReady}
/> />
</ProcessorWrapper> </ProcessorWrapper>
); );

View File

@ -4,6 +4,7 @@ import { RequiredNormalbaeImageProcessorInvocation } from 'features/controlNet/s
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper'; import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.normalbae_image_processor.default; const DEFAULTS = CONTROLNET_PROCESSORS.normalbae_image_processor.default;
@ -16,6 +17,7 @@ const NormalBaeProcessor = (props: Props) => {
const { controlNetId, processorNode } = props; const { controlNetId, processorNode } = props;
const { image_resolution, detect_resolution } = processorNode; const { image_resolution, detect_resolution } = processorNode;
const processorChanged = useProcessorNodeChanged(); const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke();
const handleDetectResolutionChanged = useCallback( const handleDetectResolutionChanged = useCallback(
(v: number) => { (v: number) => {
@ -54,6 +56,8 @@ const NormalBaeProcessor = (props: Props) => {
min={0} min={0}
max={4096} max={4096}
withInput withInput
withSliderMarks
isDisabled={!isReady}
/> />
<IAISlider <IAISlider
label="Image Resolution" label="Image Resolution"
@ -64,6 +68,8 @@ const NormalBaeProcessor = (props: Props) => {
min={0} min={0}
max={4096} max={4096}
withInput withInput
withSliderMarks
isDisabled={!isReady}
/> />
</ProcessorWrapper> </ProcessorWrapper>
); );

View File

@ -5,6 +5,7 @@ import { RequiredOpenposeImageProcessorInvocation } from 'features/controlNet/st
import { ChangeEvent, memo, useCallback } from 'react'; import { ChangeEvent, memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper'; import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.openpose_image_processor.default; const DEFAULTS = CONTROLNET_PROCESSORS.openpose_image_processor.default;
@ -17,6 +18,7 @@ const OpenposeProcessor = (props: Props) => {
const { controlNetId, processorNode } = props; const { controlNetId, processorNode } = props;
const { image_resolution, detect_resolution, hand_and_face } = processorNode; const { image_resolution, detect_resolution, hand_and_face } = processorNode;
const processorChanged = useProcessorNodeChanged(); const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke();
const handleDetectResolutionChanged = useCallback( const handleDetectResolutionChanged = useCallback(
(v: number) => { (v: number) => {
@ -62,6 +64,8 @@ const OpenposeProcessor = (props: Props) => {
min={0} min={0}
max={4096} max={4096}
withInput withInput
withSliderMarks
isDisabled={!isReady}
/> />
<IAISlider <IAISlider
label="Image Resolution" label="Image Resolution"
@ -72,11 +76,14 @@ const OpenposeProcessor = (props: Props) => {
min={0} min={0}
max={4096} max={4096}
withInput withInput
withSliderMarks
isDisabled={!isReady}
/> />
<IAISwitch <IAISwitch
label="Hand and Face" label="Hand and Face"
isChecked={hand_and_face} isChecked={hand_and_face}
onChange={handleHandAndFaceChanged} onChange={handleHandAndFaceChanged}
isDisabled={!isReady}
/> />
</ProcessorWrapper> </ProcessorWrapper>
); );

View File

@ -5,6 +5,7 @@ import { RequiredPidiImageProcessorInvocation } from 'features/controlNet/store/
import { ChangeEvent, memo, useCallback } from 'react'; import { ChangeEvent, memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper'; import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.pidi_image_processor.default; const DEFAULTS = CONTROLNET_PROCESSORS.pidi_image_processor.default;
@ -17,6 +18,7 @@ const PidiProcessor = (props: Props) => {
const { controlNetId, processorNode } = props; const { controlNetId, processorNode } = props;
const { image_resolution, detect_resolution, scribble, safe } = processorNode; const { image_resolution, detect_resolution, scribble, safe } = processorNode;
const processorChanged = useProcessorNodeChanged(); const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke();
const handleDetectResolutionChanged = useCallback( const handleDetectResolutionChanged = useCallback(
(v: number) => { (v: number) => {
@ -69,6 +71,8 @@ const PidiProcessor = (props: Props) => {
min={0} min={0}
max={4096} max={4096}
withInput withInput
withSliderMarks
isDisabled={!isReady}
/> />
<IAISlider <IAISlider
label="Image Resolution" label="Image Resolution"
@ -79,13 +83,20 @@ const PidiProcessor = (props: Props) => {
min={0} min={0}
max={4096} max={4096}
withInput withInput
withSliderMarks
isDisabled={!isReady}
/> />
<IAISwitch <IAISwitch
label="Scribble" label="Scribble"
isChecked={scribble} isChecked={scribble}
onChange={handleScribbleChanged} onChange={handleScribbleChanged}
/> />
<IAISwitch label="Safe" isChecked={safe} onChange={handleSafeChanged} /> <IAISwitch
label="Safe"
isChecked={safe}
onChange={handleSafeChanged}
isDisabled={!isReady}
/>
</ProcessorWrapper> </ProcessorWrapper>
); );
}; };

View File

@ -23,7 +23,7 @@ type ControlNetProcessorsDict = Record<
* *
* TODO: Generate from the OpenAPI schema * TODO: Generate from the OpenAPI schema
*/ */
export const CONTROLNET_PROCESSORS: ControlNetProcessorsDict = { export const CONTROLNET_PROCESSORS = {
none: { none: {
type: 'none', type: 'none',
label: 'none', label: 'none',
@ -129,7 +129,7 @@ export const CONTROLNET_PROCESSORS: ControlNetProcessorsDict = {
}, },
normalbae_image_processor: { normalbae_image_processor: {
type: 'normalbae_image_processor', type: 'normalbae_image_processor',
label: 'NormalBae', label: 'Normal BAE',
description: '', description: '',
default: { default: {
id: 'normalbae_image_processor', id: 'normalbae_image_processor',
@ -181,7 +181,7 @@ type ControlNetModel = {
defaultProcessor?: ControlNetProcessorType; defaultProcessor?: ControlNetProcessorType;
}; };
export const CONTROLNET_MODELS: Record<string, ControlNetModel> = { export const CONTROLNET_MODELS = {
'lllyasviel/control_v11p_sd15_canny': { 'lllyasviel/control_v11p_sd15_canny': {
type: 'lllyasviel/control_v11p_sd15_canny', type: 'lllyasviel/control_v11p_sd15_canny',
label: 'Canny', label: 'Canny',
@ -208,7 +208,7 @@ export const CONTROLNET_MODELS: Record<string, ControlNetModel> = {
}, },
'lllyasviel/control_v11p_sd15_seg': { 'lllyasviel/control_v11p_sd15_seg': {
type: 'lllyasviel/control_v11p_sd15_seg', type: 'lllyasviel/control_v11p_sd15_seg',
label: 'Segment Anything', label: 'Segmentation',
}, },
'lllyasviel/control_v11p_sd15_lineart': { 'lllyasviel/control_v11p_sd15_lineart': {
type: 'lllyasviel/control_v11p_sd15_lineart', type: 'lllyasviel/control_v11p_sd15_lineart',

View File

@ -21,6 +21,8 @@ export const FIELD_TYPE_MAP: Record<string, FieldType> = {
ColorField: 'color', ColorField: 'color',
ControlField: 'control', ControlField: 'control',
control: 'control', control: 'control',
cfg_scale: 'float',
control_weight: 'float',
}; };
const COLOR_TOKEN_VALUE = 500; const COLOR_TOKEN_VALUE = 500;

View File

@ -1,5 +1,5 @@
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { forEach, size } from 'lodash-es'; import { filter, forEach, size } from 'lodash-es';
import { CollectInvocation, ControlNetInvocation } from 'services/api'; import { CollectInvocation, ControlNetInvocation } from 'services/api';
import { NonNullableGraph } from '../types/types'; import { NonNullableGraph } from '../types/types';
@ -12,8 +12,16 @@ export const addControlNetToLinearGraph = (
): void => { ): void => {
const { isEnabled: isControlNetEnabled, controlNets } = state.controlNet; const { isEnabled: isControlNetEnabled, controlNets } = state.controlNet;
const validControlNets = filter(
controlNets,
(c) =>
c.isEnabled &&
(Boolean(c.processedControlImage) ||
(c.processorType === 'none' && Boolean(c.controlImage)))
);
// Add ControlNet // Add ControlNet
if (isControlNetEnabled) { if (isControlNetEnabled && validControlNets.length > 0) {
if (size(controlNets) > 1) { if (size(controlNets) > 1) {
const controlNetIterateNode: CollectInvocation = { const controlNetIterateNode: CollectInvocation = {
id: CONTROL_NET_COLLECT, id: CONTROL_NET_COLLECT,

View File

@ -3,10 +3,11 @@ import { Scheduler } from 'app/constants';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAICustomSelect from 'common/components/IAICustomSelect'; import IAICustomSelect from 'common/components/IAICustomSelect';
import IAISelect from 'common/components/IAISelect';
import { generationSelector } from 'features/parameters/store/generationSelectors'; import { generationSelector } from 'features/parameters/store/generationSelectors';
import { setScheduler } from 'features/parameters/store/generationSlice'; import { setScheduler } from 'features/parameters/store/generationSlice';
import { uiSelector } from 'features/ui/store/uiSelectors'; import { uiSelector } from 'features/ui/store/uiSelectors';
import { memo, useCallback } from 'react'; import { ChangeEvent, memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
const selector = createSelector( const selector = createSelector(
@ -35,24 +36,39 @@ const ParamScheduler = () => {
const { t } = useTranslation(); const { t } = useTranslation();
const handleChange = useCallback( const handleChange = useCallback(
(v: string | null | undefined) => { (e: ChangeEvent<HTMLSelectElement>) => {
if (!v) { dispatch(setScheduler(e.target.value as Scheduler));
return;
}
dispatch(setScheduler(v as Scheduler));
}, },
[dispatch] [dispatch]
); );
// const handleChange = useCallback(
// (v: string | null | undefined) => {
// if (!v) {
// return;
// }
// dispatch(setScheduler(v as Scheduler));
// },
// [dispatch]
// );
return ( return (
<IAICustomSelect <IAISelect
label={t('parameters.scheduler')} label={t('parameters.scheduler')}
value={scheduler} value={scheduler}
data={allSchedulers} validValues={allSchedulers}
onChange={handleChange} onChange={handleChange}
withCheckIcon
/> />
); );
// return (
// <IAICustomSelect
// label={t('parameters.scheduler')}
// value={scheduler}
// data={allSchedulers}
// onChange={handleChange}
// withCheckIcon
// />
// );
}; };
export default memo(ParamScheduler); export default memo(ParamScheduler);

View File

@ -1,5 +1,5 @@
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { memo, useCallback } from 'react'; import { ChangeEvent, memo, useCallback } from 'react';
import { isEqual } from 'lodash-es'; import { isEqual } from 'lodash-es';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
@ -11,6 +11,7 @@ import { generationSelector } from 'features/parameters/store/generationSelector
import IAICustomSelect, { import IAICustomSelect, {
IAICustomSelectOption, IAICustomSelectOption,
} from 'common/components/IAICustomSelect'; } from 'common/components/IAICustomSelect';
import IAISelect from 'common/components/IAISelect';
const selector = createSelector( const selector = createSelector(
[(state: RootState) => state, generationSelector], [(state: RootState) => state, generationSelector],
@ -18,12 +19,18 @@ const selector = createSelector(
const selectedModel = selectModelsById(state, generation.model); const selectedModel = selectModelsById(state, generation.model);
const modelData = selectModelsAll(state) const modelData = selectModelsAll(state)
.map<IAICustomSelectOption>((m) => ({ .map((m) => ({
value: m.name, value: m.name,
label: m.name, key: m.name,
tooltip: m.description,
})) }))
.sort((a, b) => a.label.localeCompare(b.label)); .sort((a, b) => a.key.localeCompare(b.key));
// const modelData = selectModelsAll(state)
// .map<IAICustomSelectOption>((m) => ({
// value: m.name,
// label: m.name,
// tooltip: m.description,
// }))
// .sort((a, b) => a.label.localeCompare(b.label));
return { return {
selectedModel, selectedModel,
modelData, modelData,
@ -41,26 +48,43 @@ const ModelSelect = () => {
const { t } = useTranslation(); const { t } = useTranslation();
const { selectedModel, modelData } = useAppSelector(selector); const { selectedModel, modelData } = useAppSelector(selector);
const handleChangeModel = useCallback( const handleChangeModel = useCallback(
(v: string | null | undefined) => { (e: ChangeEvent<HTMLSelectElement>) => {
if (!v) { dispatch(modelSelected(e.target.value));
return;
}
dispatch(modelSelected(v));
}, },
[dispatch] [dispatch]
); );
// const handleChangeModel = useCallback(
// (v: string | null | undefined) => {
// if (!v) {
// return;
// }
// dispatch(modelSelected(v));
// },
// [dispatch]
// );
return ( return (
<IAICustomSelect <IAISelect
label={t('modelManager.model')} label={t('modelManager.model')}
tooltip={selectedModel?.description} tooltip={selectedModel?.description}
data={modelData} validValues={modelData}
value={selectedModel?.name ?? ''} value={selectedModel?.name ?? ''}
onChange={handleChangeModel} onChange={handleChangeModel}
withCheckIcon={true}
tooltipProps={{ placement: 'top', hasArrow: true }} tooltipProps={{ placement: 'top', hasArrow: true }}
/> />
); );
// return (
// <IAICustomSelect
// label={t('modelManager.model')}
// tooltip={selectedModel?.description}
// data={modelData}
// value={selectedModel?.name ?? ''}
// onChange={handleChangeModel}
// withCheckIcon={true}
// tooltipProps={{ placement: 'top', hasArrow: true }}
// />
// );
}; };
export default memo(ModelSelect); export default memo(ModelSelect);

View File

@ -10,6 +10,8 @@ export const initialConfigState: AppConfig = {
disabledSDFeatures: [], disabledSDFeatures: [],
canRestoreDeletedImagesFromBin: true, canRestoreDeletedImagesFromBin: true,
sd: { sd: {
disabledControlNetModels: [],
disabledControlNetProcessors: [],
iterations: { iterations: {
initial: 1, initial: 1,
min: 1, min: 1,

View File

@ -47,3 +47,6 @@ export const languageSelector = createSelector(
(system) => system.language, (system) => system.language,
defaultSelectorOptions defaultSelectorOptions
); );
export const isProcessingSelector = (state: RootState) =>
state.system.isProcessing;

View File

@ -7,30 +7,26 @@ import type { ImageField } from './ImageField';
/** /**
* Applies HED edge detection to image * Applies HED edge detection to image
*/ */
export type HedImageProcessorInvocation = { export type HedImageprocessorInvocation = {
/** /**
* The id of this node. Must be unique among all nodes. * The id of this node. Must be unique among all nodes.
*/ */
id: string; id: string;
/**
* Whether or not this node is an intermediate node.
*/
is_intermediate?: boolean;
type?: 'hed_image_processor'; type?: 'hed_image_processor';
/** /**
* The image to process * image to process
*/ */
image?: ImageField; image?: ImageField;
/** /**
* The pixel resolution for detection * pixel resolution for edge detection
*/ */
detect_resolution?: number; detect_resolution?: number;
/** /**
* The pixel resolution for the output image * pixel resolution for output image
*/ */
image_resolution?: number; image_resolution?: number;
/** /**
* Whether to use scribble mode * whether to use scribble mode
*/ */
scribble?: boolean; scribble?: boolean;
}; };

View File

@ -0,0 +1,33 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
import type { ImageField } from './ImageField';
/**
* Applies HED edge detection to image
*/
export type HedImageprocessorInvocation = {
/**
* The id of this node. Must be unique among all nodes.
*/
id: string;
type?: 'hed_image_processor';
/**
* image to process
*/
image?: ImageField;
/**
* pixel resolution for edge detection
*/
detect_resolution?: number;
/**
* pixel resolution for output image
*/
image_resolution?: number;
/**
* whether to use scribble mode
*/
scribble?: boolean;
};

View File

@ -30,7 +30,7 @@ const invokeAIMark = defineStyle((_props) => {
return { return {
fontSize: 'xs', fontSize: 'xs',
fontWeight: '500', fontWeight: '500',
color: 'base.200', color: 'base.400',
mt: 2, mt: 2,
insetInlineStart: 'unset', insetInlineStart: 'unset',
}; };

View File

@ -42,8 +42,9 @@ dependencies = [
"controlnet-aux>=0.0.4", "controlnet-aux>=0.0.4",
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26 "timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
"datasets", "datasets",
"diffusers[torch]~=0.16.1", "diffusers[torch]~=0.17.0",
"dnspython==2.2.1", "dnspython==2.2.1",
"easing-functions",
"einops", "einops",
"eventlet", "eventlet",
"facexlib", "facexlib",
@ -56,6 +57,7 @@ dependencies = [
"flaskwebgui==1.0.3", "flaskwebgui==1.0.3",
"gfpgan==1.3.8", "gfpgan==1.3.8",
"huggingface-hub>=0.11.1", "huggingface-hub>=0.11.1",
"matplotlib", # needed for plotting of Penner easing functions
"mediapipe", # needed for "mediapipeface" controlnet model "mediapipe", # needed for "mediapipeface" controlnet model
"npyscreen", "npyscreen",
"numpy<1.24", "numpy<1.24",