mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into lstein/new-model-manager
This commit is contained in:
commit
c9ae26a176
@ -1,11 +1,12 @@
|
||||
# InvokeAI nodes for ControlNet image preprocessors
|
||||
# initial implementation by Gregg Helt, 2023
|
||||
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
|
||||
from builtins import float
|
||||
|
||||
import numpy as np
|
||||
from typing import Literal, Optional, Union, List
|
||||
from PIL import Image, ImageFilter, ImageOps
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, validator
|
||||
|
||||
from ..models.image import ImageField, ImageCategory, ResourceOrigin
|
||||
from .baseinvocation import (
|
||||
@ -14,6 +15,7 @@ from .baseinvocation import (
|
||||
InvocationContext,
|
||||
InvocationConfig,
|
||||
)
|
||||
|
||||
from controlnet_aux import (
|
||||
CannyDetector,
|
||||
HEDdetector,
|
||||
@ -96,15 +98,32 @@ CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
|
||||
class ControlField(BaseModel):
|
||||
image: ImageField = Field(default=None, description="The control image")
|
||||
control_model: Optional[str] = Field(default=None, description="The ControlNet model to use")
|
||||
control_weight: Optional[float] = Field(default=1, description="The weight given to the ControlNet")
|
||||
# control_weight: Optional[float] = Field(default=1, description="weight given to controlnet")
|
||||
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
|
||||
begin_step_percent: float = Field(default=0, ge=0, le=1,
|
||||
description="When the ControlNet is first applied (% of total steps)")
|
||||
description="When the ControlNet is first applied (% of total steps)")
|
||||
end_step_percent: float = Field(default=1, ge=0, le=1,
|
||||
description="When the ControlNet is last applied (% of total steps)")
|
||||
|
||||
@validator("control_weight")
|
||||
def abs_le_one(cls, v):
|
||||
"""validate that all abs(values) are <=1"""
|
||||
if isinstance(v, list):
|
||||
for i in v:
|
||||
if abs(i) > 1:
|
||||
raise ValueError('all abs(control_weight) must be <= 1')
|
||||
else:
|
||||
if abs(v) > 1:
|
||||
raise ValueError('abs(control_weight) must be <= 1')
|
||||
return v
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"required": ["image", "control_model", "control_weight", "begin_step_percent", "end_step_percent"]
|
||||
"required": ["image", "control_model", "control_weight", "begin_step_percent", "end_step_percent"],
|
||||
"ui": {
|
||||
"type_hints": {
|
||||
"control_weight": "float",
|
||||
# "control_weight": "number",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -112,7 +131,7 @@ class ControlOutput(BaseInvocationOutput):
|
||||
"""node output for ControlNet info"""
|
||||
# fmt: off
|
||||
type: Literal["control_output"] = "control_output"
|
||||
control: ControlField = Field(default=None, description="The output control image")
|
||||
control: ControlField = Field(default=None, description="The control info")
|
||||
# fmt: on
|
||||
|
||||
|
||||
@ -123,15 +142,28 @@ class ControlNetInvocation(BaseInvocation):
|
||||
# Inputs
|
||||
image: ImageField = Field(default=None, description="The control image")
|
||||
control_model: CONTROLNET_NAME_VALUES = Field(default="lllyasviel/sd-controlnet-canny",
|
||||
description="The ControlNet model to use")
|
||||
control_weight: float = Field(default=1.0, ge=0, le=1, description="The weight given to the ControlNet")
|
||||
description="control model used")
|
||||
control_weight: Union[float, List[float]] = Field(default=1.0, description="The weight given to the ControlNet")
|
||||
# TODO: add support in backend core for begin_step_percent, end_step_percent, guess_mode
|
||||
begin_step_percent: float = Field(default=0, ge=0, le=1,
|
||||
description="When the ControlNet is first applied (% of total steps)")
|
||||
description="When the ControlNet is first applied (% of total steps)")
|
||||
end_step_percent: float = Field(default=1, ge=0, le=1,
|
||||
description="When the ControlNet is last applied (% of total steps)")
|
||||
description="When the ControlNet is last applied (% of total steps)")
|
||||
# fmt: on
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["latents"],
|
||||
"type_hints": {
|
||||
"model": "model",
|
||||
"control": "control",
|
||||
# "cfg_scale": "float",
|
||||
"cfg_scale": "number",
|
||||
"control_weight": "float",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ControlOutput:
|
||||
|
||||
@ -161,7 +193,6 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
|
||||
return image
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
|
||||
raw_image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
|
@ -3,8 +3,6 @@
|
||||
from contextlib import ExitStack
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
import einops
|
||||
|
||||
from pydantic import BaseModel, Field, validator
|
||||
import torch
|
||||
from diffusers import ControlNetModel
|
||||
@ -173,23 +171,36 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
|
||||
noise: Optional[LatentsField] = Field(description="The noise to use")
|
||||
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
||||
cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||
cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
|
||||
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||
seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||
|
||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
||||
control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
|
||||
#seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||
#seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||
# fmt: on
|
||||
|
||||
@validator("cfg_scale")
|
||||
def ge_one(cls, v):
|
||||
"""validate that all cfg_scale values are >= 1"""
|
||||
if isinstance(v, list):
|
||||
for i in v:
|
||||
if i < 1:
|
||||
raise ValueError('cfg_scale must be greater than 1')
|
||||
else:
|
||||
if v < 1:
|
||||
raise ValueError('cfg_scale must be greater than 1')
|
||||
return v
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["latents", "image"],
|
||||
"tags": ["latents"],
|
||||
"type_hints": {
|
||||
"model": "model",
|
||||
"control": "control",
|
||||
# "cfg_scale": "float",
|
||||
"cfg_scale": "number"
|
||||
}
|
||||
},
|
||||
}
|
||||
@ -210,10 +221,10 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
|
||||
|
||||
conditioning_data = ConditioningData(
|
||||
uc,
|
||||
c,
|
||||
self.cfg_scale,
|
||||
extra_conditioning_info,
|
||||
unconditioned_embeddings=uc,
|
||||
text_embeddings=c,
|
||||
guidance_scale=self.cfg_scale,
|
||||
extra=extra_conditioning_info,
|
||||
postprocessing_settings=PostprocessingSettings(
|
||||
threshold=0.0,#threshold,
|
||||
warmup=0.2,#warmup,
|
||||
@ -351,10 +362,10 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
|
||||
loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
|
||||
|
||||
print("type of control input: ", type(self.control))
|
||||
control_data = self.prep_control_data(model=pipeline, context=context, control_input=self.control,
|
||||
latents_shape=noise.shape,
|
||||
do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
||||
control_data = self.prep_control_data(model=model, context=context, control_input=self.control,
|
||||
latents_shape=noise.shape,
|
||||
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
||||
do_classifier_free_guidance=True,)
|
||||
|
||||
with ModelPatcher.apply_lora_unet(pipeline.unet, loras):
|
||||
# TODO: Verify the noise is the right size
|
||||
@ -364,7 +375,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
num_inference_steps=self.steps,
|
||||
conditioning_data=conditioning_data,
|
||||
control_data=control_data, # list[ControlNetData]
|
||||
callback=step_callback
|
||||
callback=step_callback,
|
||||
)
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
@ -391,6 +402,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
"type_hints": {
|
||||
"model": "model",
|
||||
"control": "control",
|
||||
"cfg_scale": "number",
|
||||
}
|
||||
},
|
||||
}
|
||||
@ -421,6 +433,12 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
|
||||
pipeline = self.create_pipeline(unet, scheduler)
|
||||
conditioning_data = self.get_conditioning_data(context, scheduler)
|
||||
|
||||
control_data = self.prep_control_data(model=model, context=context, control_input=self.control,
|
||||
latents_shape=noise.shape,
|
||||
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
||||
do_classifier_free_guidance=True,
|
||||
)
|
||||
|
||||
# TODO: Verify the noise is the right size
|
||||
initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
|
||||
@ -442,6 +460,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
noise=noise,
|
||||
num_inference_steps=self.steps,
|
||||
conditioning_data=conditioning_data,
|
||||
control_data=control_data, # list[ControlNetData]
|
||||
callback=step_callback
|
||||
)
|
||||
|
||||
|
237
invokeai/app/invocations/param_easing.py
Normal file
237
invokeai/app/invocations/param_easing.py
Normal file
@ -0,0 +1,237 @@
|
||||
import io
|
||||
from typing import Literal, Optional, Any
|
||||
|
||||
# from PIL.Image import Image
|
||||
import PIL.Image
|
||||
from matplotlib.ticker import MaxNLocator
|
||||
from matplotlib.figure import Figure
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from easing_functions import (
|
||||
LinearInOut,
|
||||
QuadEaseInOut, QuadEaseIn, QuadEaseOut,
|
||||
CubicEaseInOut, CubicEaseIn, CubicEaseOut,
|
||||
QuarticEaseInOut, QuarticEaseIn, QuarticEaseOut,
|
||||
QuinticEaseInOut, QuinticEaseIn, QuinticEaseOut,
|
||||
SineEaseInOut, SineEaseIn, SineEaseOut,
|
||||
CircularEaseIn, CircularEaseInOut, CircularEaseOut,
|
||||
ExponentialEaseInOut, ExponentialEaseIn, ExponentialEaseOut,
|
||||
ElasticEaseIn, ElasticEaseInOut, ElasticEaseOut,
|
||||
BackEaseIn, BackEaseInOut, BackEaseOut,
|
||||
BounceEaseIn, BounceEaseInOut, BounceEaseOut)
|
||||
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
InvocationContext,
|
||||
InvocationConfig,
|
||||
)
|
||||
from ...backend.util.logging import InvokeAILogger
|
||||
from .collections import FloatCollectionOutput
|
||||
|
||||
|
||||
class FloatLinearRangeInvocation(BaseInvocation):
|
||||
"""Creates a range"""
|
||||
|
||||
type: Literal["float_range"] = "float_range"
|
||||
|
||||
# Inputs
|
||||
start: float = Field(default=5, description="The first value of the range")
|
||||
stop: float = Field(default=10, description="The last value of the range")
|
||||
steps: int = Field(default=30, description="number of values to interpolate over (including start and stop)")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
||||
param_list = list(np.linspace(self.start, self.stop, self.steps))
|
||||
return FloatCollectionOutput(
|
||||
collection=param_list
|
||||
)
|
||||
|
||||
|
||||
EASING_FUNCTIONS_MAP = {
|
||||
"Linear": LinearInOut,
|
||||
"QuadIn": QuadEaseIn,
|
||||
"QuadOut": QuadEaseOut,
|
||||
"QuadInOut": QuadEaseInOut,
|
||||
"CubicIn": CubicEaseIn,
|
||||
"CubicOut": CubicEaseOut,
|
||||
"CubicInOut": CubicEaseInOut,
|
||||
"QuarticIn": QuarticEaseIn,
|
||||
"QuarticOut": QuarticEaseOut,
|
||||
"QuarticInOut": QuarticEaseInOut,
|
||||
"QuinticIn": QuinticEaseIn,
|
||||
"QuinticOut": QuinticEaseOut,
|
||||
"QuinticInOut": QuinticEaseInOut,
|
||||
"SineIn": SineEaseIn,
|
||||
"SineOut": SineEaseOut,
|
||||
"SineInOut": SineEaseInOut,
|
||||
"CircularIn": CircularEaseIn,
|
||||
"CircularOut": CircularEaseOut,
|
||||
"CircularInOut": CircularEaseInOut,
|
||||
"ExponentialIn": ExponentialEaseIn,
|
||||
"ExponentialOut": ExponentialEaseOut,
|
||||
"ExponentialInOut": ExponentialEaseInOut,
|
||||
"ElasticIn": ElasticEaseIn,
|
||||
"ElasticOut": ElasticEaseOut,
|
||||
"ElasticInOut": ElasticEaseInOut,
|
||||
"BackIn": BackEaseIn,
|
||||
"BackOut": BackEaseOut,
|
||||
"BackInOut": BackEaseInOut,
|
||||
"BounceIn": BounceEaseIn,
|
||||
"BounceOut": BounceEaseOut,
|
||||
"BounceInOut": BounceEaseInOut,
|
||||
}
|
||||
|
||||
EASING_FUNCTION_KEYS: Any = Literal[
|
||||
tuple(list(EASING_FUNCTIONS_MAP.keys()))
|
||||
]
|
||||
|
||||
|
||||
# actually I think for now could just use CollectionOutput (which is list[Any]
|
||||
class StepParamEasingInvocation(BaseInvocation):
|
||||
"""Experimental per-step parameter easing for denoising steps"""
|
||||
|
||||
type: Literal["step_param_easing"] = "step_param_easing"
|
||||
|
||||
# Inputs
|
||||
# fmt: off
|
||||
easing: EASING_FUNCTION_KEYS = Field(default="Linear", description="The easing function to use")
|
||||
num_steps: int = Field(default=20, description="number of denoising steps")
|
||||
start_value: float = Field(default=0.0, description="easing starting value")
|
||||
end_value: float = Field(default=1.0, description="easing ending value")
|
||||
start_step_percent: float = Field(default=0.0, description="fraction of steps at which to start easing")
|
||||
end_step_percent: float = Field(default=1.0, description="fraction of steps after which to end easing")
|
||||
# if None, then start_value is used prior to easing start
|
||||
pre_start_value: Optional[float] = Field(default=None, description="value before easing start")
|
||||
# if None, then end value is used prior to easing end
|
||||
post_end_value: Optional[float] = Field(default=None, description="value after easing end")
|
||||
mirror: bool = Field(default=False, description="include mirror of easing function")
|
||||
# FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely
|
||||
# alt_mirror: bool = Field(default=False, description="alternative mirroring by dual easing")
|
||||
show_easing_plot: bool = Field(default=False, description="show easing plot")
|
||||
# fmt: on
|
||||
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
||||
log_diagnostics = False
|
||||
# convert from start_step_percent to nearest step <= (steps * start_step_percent)
|
||||
# start_step = int(np.floor(self.num_steps * self.start_step_percent))
|
||||
start_step = int(np.round(self.num_steps * self.start_step_percent))
|
||||
# convert from end_step_percent to nearest step >= (steps * end_step_percent)
|
||||
# end_step = int(np.ceil((self.num_steps - 1) * self.end_step_percent))
|
||||
end_step = int(np.round((self.num_steps - 1) * self.end_step_percent))
|
||||
|
||||
# end_step = int(np.ceil(self.num_steps * self.end_step_percent))
|
||||
num_easing_steps = end_step - start_step + 1
|
||||
|
||||
# num_presteps = max(start_step - 1, 0)
|
||||
num_presteps = start_step
|
||||
num_poststeps = self.num_steps - (num_presteps + num_easing_steps)
|
||||
prelist = list(num_presteps * [self.pre_start_value])
|
||||
postlist = list(num_poststeps * [self.post_end_value])
|
||||
|
||||
if log_diagnostics:
|
||||
logger = InvokeAILogger.getLogger(name="StepParamEasing")
|
||||
logger.debug("start_step: " + str(start_step))
|
||||
logger.debug("end_step: " + str(end_step))
|
||||
logger.debug("num_easing_steps: " + str(num_easing_steps))
|
||||
logger.debug("num_presteps: " + str(num_presteps))
|
||||
logger.debug("num_poststeps: " + str(num_poststeps))
|
||||
logger.debug("prelist size: " + str(len(prelist)))
|
||||
logger.debug("postlist size: " + str(len(postlist)))
|
||||
logger.debug("prelist: " + str(prelist))
|
||||
logger.debug("postlist: " + str(postlist))
|
||||
|
||||
easing_class = EASING_FUNCTIONS_MAP[self.easing]
|
||||
if log_diagnostics:
|
||||
logger.debug("easing class: " + str(easing_class))
|
||||
easing_list = list()
|
||||
if self.mirror: # "expected" mirroring
|
||||
# if number of steps is even, squeeze duration down to (number_of_steps)/2
|
||||
# and create reverse copy of list to append
|
||||
# if number of steps is odd, squeeze duration down to ceil(number_of_steps/2)
|
||||
# and create reverse copy of list[1:end-1]
|
||||
# but if even then number_of_steps/2 === ceil(number_of_steps/2), so can just use ceil always
|
||||
|
||||
base_easing_duration = int(np.ceil(num_easing_steps/2.0))
|
||||
if log_diagnostics: logger.debug("base easing duration: " + str(base_easing_duration))
|
||||
even_num_steps = (num_easing_steps % 2 == 0) # even number of steps
|
||||
easing_function = easing_class(start=self.start_value,
|
||||
end=self.end_value,
|
||||
duration=base_easing_duration - 1)
|
||||
base_easing_vals = list()
|
||||
for step_index in range(base_easing_duration):
|
||||
easing_val = easing_function.ease(step_index)
|
||||
base_easing_vals.append(easing_val)
|
||||
if log_diagnostics:
|
||||
logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(easing_val))
|
||||
if even_num_steps:
|
||||
mirror_easing_vals = list(reversed(base_easing_vals))
|
||||
else:
|
||||
mirror_easing_vals = list(reversed(base_easing_vals[0:-1]))
|
||||
if log_diagnostics:
|
||||
logger.debug("base easing vals: " + str(base_easing_vals))
|
||||
logger.debug("mirror easing vals: " + str(mirror_easing_vals))
|
||||
easing_list = base_easing_vals + mirror_easing_vals
|
||||
|
||||
# FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely
|
||||
# elif self.alt_mirror: # function mirroring (unintuitive behavior (at least to me))
|
||||
# # half_ease_duration = round(num_easing_steps - 1 / 2)
|
||||
# half_ease_duration = round((num_easing_steps - 1) / 2)
|
||||
# easing_function = easing_class(start=self.start_value,
|
||||
# end=self.end_value,
|
||||
# duration=half_ease_duration,
|
||||
# )
|
||||
#
|
||||
# mirror_function = easing_class(start=self.end_value,
|
||||
# end=self.start_value,
|
||||
# duration=half_ease_duration,
|
||||
# )
|
||||
# for step_index in range(num_easing_steps):
|
||||
# if step_index <= half_ease_duration:
|
||||
# step_val = easing_function.ease(step_index)
|
||||
# else:
|
||||
# step_val = mirror_function.ease(step_index - half_ease_duration)
|
||||
# easing_list.append(step_val)
|
||||
# if log_diagnostics: logger.debug(step_index, step_val)
|
||||
#
|
||||
|
||||
else: # no mirroring (default)
|
||||
easing_function = easing_class(start=self.start_value,
|
||||
end=self.end_value,
|
||||
duration=num_easing_steps - 1)
|
||||
for step_index in range(num_easing_steps):
|
||||
step_val = easing_function.ease(step_index)
|
||||
easing_list.append(step_val)
|
||||
if log_diagnostics:
|
||||
logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(step_val))
|
||||
|
||||
if log_diagnostics:
|
||||
logger.debug("prelist size: " + str(len(prelist)))
|
||||
logger.debug("easing_list size: " + str(len(easing_list)))
|
||||
logger.debug("postlist size: " + str(len(postlist)))
|
||||
|
||||
param_list = prelist + easing_list + postlist
|
||||
|
||||
if self.show_easing_plot:
|
||||
plt.figure()
|
||||
plt.xlabel("Step")
|
||||
plt.ylabel("Param Value")
|
||||
plt.title("Per-Step Values Based On Easing: " + self.easing)
|
||||
plt.bar(range(len(param_list)), param_list)
|
||||
# plt.plot(param_list)
|
||||
ax = plt.gca()
|
||||
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
|
||||
buf = io.BytesIO()
|
||||
plt.savefig(buf, format='png')
|
||||
buf.seek(0)
|
||||
im = PIL.Image.open(buf)
|
||||
im.show()
|
||||
buf.close()
|
||||
|
||||
# output array of size steps, each entry list[i] is param value for step i
|
||||
return FloatCollectionOutput(
|
||||
collection=param_list
|
||||
)
|
@ -1,4 +1,4 @@
|
||||
from typing import Optional
|
||||
from typing import Optional, Union, List
|
||||
from pydantic import BaseModel, Extra, Field, StrictFloat, StrictInt, StrictStr
|
||||
|
||||
|
||||
@ -47,7 +47,9 @@ class ImageMetadata(BaseModel):
|
||||
default=None, description="The seed used for noise generation."
|
||||
)
|
||||
"""The seed used for noise generation"""
|
||||
cfg_scale: Optional[StrictFloat] = Field(
|
||||
# cfg_scale: Optional[StrictFloat] = Field(
|
||||
# cfg_scale: Union[float, list[float]] = Field(
|
||||
cfg_scale: Union[StrictFloat, List[StrictFloat]] = Field(
|
||||
default=None, description="The classifier-free guidance scale."
|
||||
)
|
||||
"""The classifier-free guidance scale"""
|
||||
|
@ -65,7 +65,6 @@ from typing import Optional, Union, List, get_args
|
||||
def is_union_subtype(t1, t2):
|
||||
t1_args = get_args(t1)
|
||||
t2_args = get_args(t2)
|
||||
|
||||
if not t1_args:
|
||||
# t1 is a single type
|
||||
return t1 in t2_args
|
||||
@ -86,7 +85,6 @@ def is_list_or_contains_list(t):
|
||||
for arg in t_args:
|
||||
if get_origin(arg) is list:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
@ -393,7 +391,7 @@ class Graph(BaseModel):
|
||||
from_node = self.get_node(edge.source.node_id)
|
||||
to_node = self.get_node(edge.destination.node_id)
|
||||
except NodeNotFoundError:
|
||||
raise InvalidEdgeError("One or both nodes don't exist")
|
||||
raise InvalidEdgeError("One or both nodes don't exist: {edge.source.node_id} -> {edge.destination.node_id}")
|
||||
|
||||
# Validate that an edge to this node+field doesn't already exist
|
||||
input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field)
|
||||
@ -404,41 +402,41 @@ class Graph(BaseModel):
|
||||
g = self.nx_graph_flat()
|
||||
g.add_edge(edge.source.node_id, edge.destination.node_id)
|
||||
if not nx.is_directed_acyclic_graph(g):
|
||||
raise InvalidEdgeError(f'Edge creates a cycle in the graph')
|
||||
raise InvalidEdgeError(f'Edge creates a cycle in the graph: {edge.source.node_id} -> {edge.destination.node_id}')
|
||||
|
||||
# Validate that the field types are compatible
|
||||
if not are_connections_compatible(
|
||||
from_node, edge.source.field, to_node, edge.destination.field
|
||||
):
|
||||
raise InvalidEdgeError(f'Fields are incompatible')
|
||||
raise InvalidEdgeError(f'Fields are incompatible: cannot connect {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
|
||||
|
||||
# Validate if iterator output type matches iterator input type (if this edge results in both being set)
|
||||
if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection":
|
||||
if not self._is_iterator_connection_valid(
|
||||
edge.destination.node_id, new_input=edge.source
|
||||
):
|
||||
raise InvalidEdgeError(f'Iterator input type does not match iterator output type')
|
||||
raise InvalidEdgeError(f'Iterator input type does not match iterator output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
|
||||
|
||||
# Validate if iterator input type matches output type (if this edge results in both being set)
|
||||
if isinstance(from_node, IterateInvocation) and edge.source.field == "item":
|
||||
if not self._is_iterator_connection_valid(
|
||||
edge.source.node_id, new_output=edge.destination
|
||||
):
|
||||
raise InvalidEdgeError(f'Iterator output type does not match iterator input type')
|
||||
raise InvalidEdgeError(f'Iterator output type does not match iterator input type:, {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
|
||||
|
||||
# Validate if collector input type matches output type (if this edge results in both being set)
|
||||
if isinstance(to_node, CollectInvocation) and edge.destination.field == "item":
|
||||
if not self._is_collector_connection_valid(
|
||||
edge.destination.node_id, new_input=edge.source
|
||||
):
|
||||
raise InvalidEdgeError(f'Collector output type does not match collector input type')
|
||||
raise InvalidEdgeError(f'Collector output type does not match collector input type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
|
||||
|
||||
# Validate if collector output type matches input type (if this edge results in both being set)
|
||||
if isinstance(from_node, CollectInvocation) and edge.source.field == "collection":
|
||||
if not self._is_collector_connection_valid(
|
||||
edge.source.node_id, new_output=edge.destination
|
||||
):
|
||||
raise InvalidEdgeError(f'Collector input type does not match collector output type')
|
||||
raise InvalidEdgeError(f'Collector input type does not match collector output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
|
||||
|
||||
|
||||
def has_node(self, node_path: str) -> bool:
|
||||
|
@ -6,7 +6,7 @@ import torch
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
from diffusers.models.controlnet import ControlNetModel, ControlNetOutput
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
|
||||
from diffusers.pipelines.controlnet import MultiControlNetModel
|
||||
|
||||
from ..stable_diffusion import (
|
||||
ConditioningData,
|
||||
|
@ -23,7 +23,7 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
|
||||
StableDiffusionPipeline,
|
||||
)
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
|
||||
from diffusers.pipelines.controlnet import MultiControlNetModel
|
||||
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import (
|
||||
StableDiffusionImg2ImgPipeline,
|
||||
@ -218,7 +218,7 @@ class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
|
||||
class ControlNetData:
|
||||
model: ControlNetModel = Field(default=None)
|
||||
image_tensor: torch.Tensor= Field(default=None)
|
||||
weight: float = Field(default=1.0)
|
||||
weight: Union[float, List[float]]= Field(default=1.0)
|
||||
begin_step_percent: float = Field(default=0.0)
|
||||
end_step_percent: float = Field(default=1.0)
|
||||
|
||||
@ -226,7 +226,7 @@ class ControlNetData:
|
||||
class ConditioningData:
|
||||
unconditioned_embeddings: torch.Tensor
|
||||
text_embeddings: torch.Tensor
|
||||
guidance_scale: float
|
||||
guidance_scale: Union[float, List[float]]
|
||||
"""
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
|
||||
@ -662,7 +662,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
down_block_res_samples, mid_block_res_sample = None, None
|
||||
|
||||
if control_data is not None:
|
||||
if conditioning_data.guidance_scale > 1.0:
|
||||
# FIXME: make sure guidance_scale < 1.0 is handled correctly if doing per-step guidance setting
|
||||
# if conditioning_data.guidance_scale > 1.0:
|
||||
if conditioning_data.guidance_scale is not None:
|
||||
# expand the latents input to control model if doing classifier free guidance
|
||||
# (which I think for now is always true, there is conditional elsewhere that stops execution if
|
||||
# classifier_free_guidance is <= 1.0 ?)
|
||||
@ -679,13 +681,19 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
# only apply controlnet if current step is within the controlnet's begin/end step range
|
||||
if step_index >= first_control_step and step_index <= last_control_step:
|
||||
# print("running controlnet", i, "for step", step_index)
|
||||
if isinstance(control_datum.weight, list):
|
||||
# if controlnet has multiple weights, use the weight for the current step
|
||||
controlnet_weight = control_datum.weight[step_index]
|
||||
else:
|
||||
# if controlnet has a single weight, use it for all steps
|
||||
controlnet_weight = control_datum.weight
|
||||
down_samples, mid_sample = control_datum.model(
|
||||
sample=latent_control_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings,
|
||||
conditioning_data.text_embeddings]),
|
||||
controlnet_cond=control_datum.image_tensor,
|
||||
conditioning_scale=control_datum.weight,
|
||||
conditioning_scale=controlnet_weight,
|
||||
# cross_attention_kwargs,
|
||||
guess_mode=False,
|
||||
return_dict=False,
|
||||
|
@ -1,7 +1,7 @@
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from math import ceil
|
||||
from typing import Any, Callable, Dict, Optional, Union
|
||||
from typing import Any, Callable, Dict, Optional, Union, List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -180,7 +180,8 @@ class InvokeAIDiffuserComponent:
|
||||
sigma: torch.Tensor,
|
||||
unconditioning: Union[torch.Tensor, dict],
|
||||
conditioning: Union[torch.Tensor, dict],
|
||||
unconditional_guidance_scale: float,
|
||||
# unconditional_guidance_scale: float,
|
||||
unconditional_guidance_scale: Union[float, List[float]],
|
||||
step_index: Optional[int] = None,
|
||||
total_step_count: Optional[int] = None,
|
||||
**kwargs,
|
||||
@ -195,6 +196,11 @@ class InvokeAIDiffuserComponent:
|
||||
:return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning.
|
||||
"""
|
||||
|
||||
if isinstance(unconditional_guidance_scale, list):
|
||||
guidance_scale = unconditional_guidance_scale[step_index]
|
||||
else:
|
||||
guidance_scale = unconditional_guidance_scale
|
||||
|
||||
cross_attention_control_types_to_do = []
|
||||
context: Context = self.cross_attention_control_context
|
||||
if self.cross_attention_control_context is not None:
|
||||
@ -243,7 +249,8 @@ class InvokeAIDiffuserComponent:
|
||||
)
|
||||
|
||||
combined_next_x = self._combine(
|
||||
unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale
|
||||
# unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale
|
||||
unconditioned_next_x, conditioned_next_x, guidance_scale
|
||||
)
|
||||
|
||||
return combined_next_x
|
||||
@ -497,7 +504,7 @@ class InvokeAIDiffuserComponent:
|
||||
logger.debug(
|
||||
f"min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}"
|
||||
)
|
||||
logger.debug(
|
||||
logger.debug(
|
||||
f"{outside / latents.numel() * 100:.2f}% values outside threshold"
|
||||
)
|
||||
|
||||
|
@ -1,3 +1,7 @@
|
||||
import {
|
||||
CONTROLNET_MODELS,
|
||||
CONTROLNET_PROCESSORS,
|
||||
} from 'features/controlNet/store/constants';
|
||||
import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||
import { O } from 'ts-toolbelt';
|
||||
|
||||
@ -117,6 +121,8 @@ export type AppConfig = {
|
||||
canRestoreDeletedImagesFromBin: boolean;
|
||||
sd: {
|
||||
defaultModel?: string;
|
||||
disabledControlNetModels: (keyof typeof CONTROLNET_MODELS)[];
|
||||
disabledControlNetProcessors: (keyof typeof CONTROLNET_PROCESSORS)[];
|
||||
iterations: {
|
||||
initial: number;
|
||||
min: number;
|
||||
|
@ -18,7 +18,7 @@ import { useSelect } from 'downshift';
|
||||
import { isString } from 'lodash-es';
|
||||
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
||||
|
||||
import { memo, useMemo } from 'react';
|
||||
import { memo, useLayoutEffect, useMemo } from 'react';
|
||||
import { getInputOutlineStyles } from 'theme/util/getInputOutlineStyles';
|
||||
|
||||
export type ItemTooltips = { [key: string]: string };
|
||||
@ -39,6 +39,7 @@ type IAICustomSelectProps = {
|
||||
tooltip?: string;
|
||||
tooltipProps?: Omit<TooltipProps, 'children'>;
|
||||
ellipsisPosition?: 'start' | 'end';
|
||||
isDisabled?: boolean;
|
||||
};
|
||||
|
||||
const IAICustomSelect = (props: IAICustomSelectProps) => {
|
||||
@ -52,6 +53,7 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
|
||||
data,
|
||||
value,
|
||||
onChange,
|
||||
isDisabled = false,
|
||||
} = props;
|
||||
|
||||
const values = useMemo(() => {
|
||||
@ -86,11 +88,17 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
|
||||
},
|
||||
});
|
||||
|
||||
const { refs, floatingStyles } = useFloating<HTMLButtonElement>({
|
||||
whileElementsMounted: autoUpdate,
|
||||
const { refs, floatingStyles, update } = useFloating<HTMLButtonElement>({
|
||||
// whileElementsMounted: autoUpdate,
|
||||
middleware: [offset(4), shift({ crossAxis: true, padding: 8 })],
|
||||
});
|
||||
|
||||
useLayoutEffect(() => {
|
||||
if (isOpen && refs.reference.current && refs.floating.current) {
|
||||
return autoUpdate(refs.reference.current, refs.floating.current, update);
|
||||
}
|
||||
}, [isOpen, update, refs.floating, refs.reference]);
|
||||
|
||||
const labelTextDirection = useMemo(() => {
|
||||
if (ellipsisPosition === 'start') {
|
||||
return document.dir === 'rtl' ? 'ltr' : 'rtl';
|
||||
@ -124,6 +132,8 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
|
||||
px: 2,
|
||||
gap: 2,
|
||||
justifyContent: 'space-between',
|
||||
pointerEvents: isDisabled ? 'none' : undefined,
|
||||
opacity: isDisabled ? 0.5 : undefined,
|
||||
...getInputOutlineStyles(),
|
||||
}}
|
||||
>
|
||||
|
@ -1,4 +1,5 @@
|
||||
import {
|
||||
ChakraProps,
|
||||
FormControl,
|
||||
FormControlProps,
|
||||
FormLabel,
|
||||
@ -39,6 +40,11 @@ import { BiReset } from 'react-icons/bi';
|
||||
import IAIIconButton, { IAIIconButtonProps } from './IAIIconButton';
|
||||
import { roundDownToMultiple } from 'common/util/roundDownToMultiple';
|
||||
|
||||
const SLIDER_MARK_STYLES: ChakraProps['sx'] = {
|
||||
mt: 1.5,
|
||||
fontSize: '2xs',
|
||||
};
|
||||
|
||||
export type IAIFullSliderProps = {
|
||||
label?: string;
|
||||
value: number;
|
||||
@ -57,6 +63,7 @@ export type IAIFullSliderProps = {
|
||||
hideTooltip?: boolean;
|
||||
isCompact?: boolean;
|
||||
isDisabled?: boolean;
|
||||
sliderMarks?: number[];
|
||||
sliderFormControlProps?: FormControlProps;
|
||||
sliderFormLabelProps?: FormLabelProps;
|
||||
sliderMarkProps?: Omit<SliderMarkProps, 'value'>;
|
||||
@ -88,6 +95,7 @@ const IAISlider = (props: IAIFullSliderProps) => {
|
||||
hideTooltip = false,
|
||||
isCompact = false,
|
||||
isDisabled = false,
|
||||
sliderMarks,
|
||||
handleReset,
|
||||
sliderFormControlProps,
|
||||
sliderFormLabelProps,
|
||||
@ -198,14 +206,14 @@ const IAISlider = (props: IAIFullSliderProps) => {
|
||||
isDisabled={isDisabled}
|
||||
{...rest}
|
||||
>
|
||||
{withSliderMarks && (
|
||||
{withSliderMarks && !sliderMarks && (
|
||||
<>
|
||||
<SliderMark
|
||||
value={min}
|
||||
sx={{
|
||||
insetInlineStart: '0 !important',
|
||||
insetInlineEnd: 'unset !important',
|
||||
mt: 1.5,
|
||||
...SLIDER_MARK_STYLES,
|
||||
}}
|
||||
{...sliderMarkProps}
|
||||
>
|
||||
@ -216,7 +224,7 @@ const IAISlider = (props: IAIFullSliderProps) => {
|
||||
sx={{
|
||||
insetInlineStart: 'unset !important',
|
||||
insetInlineEnd: '0 !important',
|
||||
mt: 1.5,
|
||||
...SLIDER_MARK_STYLES,
|
||||
}}
|
||||
{...sliderMarkProps}
|
||||
>
|
||||
@ -224,6 +232,56 @@ const IAISlider = (props: IAIFullSliderProps) => {
|
||||
</SliderMark>
|
||||
</>
|
||||
)}
|
||||
{withSliderMarks && sliderMarks && (
|
||||
<>
|
||||
{sliderMarks.map((m, i) => {
|
||||
if (i === 0) {
|
||||
return (
|
||||
<SliderMark
|
||||
key={m}
|
||||
value={m}
|
||||
sx={{
|
||||
insetInlineStart: '0 !important',
|
||||
insetInlineEnd: 'unset !important',
|
||||
...SLIDER_MARK_STYLES,
|
||||
}}
|
||||
{...sliderMarkProps}
|
||||
>
|
||||
{m}
|
||||
</SliderMark>
|
||||
);
|
||||
} else if (i === sliderMarks.length - 1) {
|
||||
return (
|
||||
<SliderMark
|
||||
key={m}
|
||||
value={m}
|
||||
sx={{
|
||||
insetInlineStart: 'unset !important',
|
||||
insetInlineEnd: '0 !important',
|
||||
...SLIDER_MARK_STYLES,
|
||||
}}
|
||||
{...sliderMarkProps}
|
||||
>
|
||||
{m}
|
||||
</SliderMark>
|
||||
);
|
||||
} else {
|
||||
return (
|
||||
<SliderMark
|
||||
key={m}
|
||||
value={m}
|
||||
sx={{
|
||||
...SLIDER_MARK_STYLES,
|
||||
}}
|
||||
{...sliderMarkProps}
|
||||
>
|
||||
{m}
|
||||
</SliderMark>
|
||||
);
|
||||
}
|
||||
})}
|
||||
</>
|
||||
)}
|
||||
|
||||
<SliderTrack {...sliderTrackProps}>
|
||||
<SliderFilledTrack />
|
||||
|
@ -143,7 +143,7 @@ const ControlNet = (props: ControlNetProps) => {
|
||||
flexDir: 'column',
|
||||
gap: 2,
|
||||
w: 'full',
|
||||
h: 24,
|
||||
h: isExpanded ? 28 : 24,
|
||||
paddingInlineStart: 1,
|
||||
paddingInlineEnd: isExpanded ? 1 : 0,
|
||||
pb: 2,
|
||||
@ -153,13 +153,13 @@ const ControlNet = (props: ControlNetProps) => {
|
||||
<ParamControlNetWeight
|
||||
controlNetId={controlNetId}
|
||||
weight={weight}
|
||||
mini
|
||||
mini={!isExpanded}
|
||||
/>
|
||||
<ParamControlNetBeginEnd
|
||||
controlNetId={controlNetId}
|
||||
beginStepPct={beginStepPct}
|
||||
endStepPct={endStepPct}
|
||||
mini
|
||||
mini={!isExpanded}
|
||||
/>
|
||||
</Flex>
|
||||
{!isExpanded && (
|
||||
|
@ -1,5 +1,6 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAISwitch from 'common/components/IAISwitch';
|
||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||
import { controlNetAutoConfigToggled } from 'features/controlNet/store/controlNetSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
|
||||
@ -11,7 +12,7 @@ type Props = {
|
||||
const ParamControlNetShouldAutoConfig = (props: Props) => {
|
||||
const { controlNetId, shouldAutoConfig } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const isReady = useIsReadyToInvoke();
|
||||
const handleShouldAutoConfigChanged = useCallback(() => {
|
||||
dispatch(controlNetAutoConfigToggled({ controlNetId }));
|
||||
}, [controlNetId, dispatch]);
|
||||
@ -22,6 +23,7 @@ const ParamControlNetShouldAutoConfig = (props: Props) => {
|
||||
aria-label="Auto configure processor"
|
||||
isChecked={shouldAutoConfig}
|
||||
onChange={handleShouldAutoConfigChanged}
|
||||
isDisabled={!isReady}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
@ -1,4 +1,5 @@
|
||||
import {
|
||||
ChakraProps,
|
||||
FormControl,
|
||||
FormLabel,
|
||||
HStack,
|
||||
@ -10,14 +11,19 @@ import {
|
||||
Tooltip,
|
||||
} from '@chakra-ui/react';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import {
|
||||
controlNetBeginStepPctChanged,
|
||||
controlNetEndStepPctChanged,
|
||||
} from 'features/controlNet/store/controlNetSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { BiReset } from 'react-icons/bi';
|
||||
|
||||
const SLIDER_MARK_STYLES: ChakraProps['sx'] = {
|
||||
mt: 1.5,
|
||||
fontSize: '2xs',
|
||||
fontWeight: '500',
|
||||
color: 'base.400',
|
||||
};
|
||||
|
||||
type Props = {
|
||||
controlNetId: string;
|
||||
@ -29,7 +35,7 @@ type Props = {
|
||||
const formatPct = (v: number) => `${Math.round(v * 100)}%`;
|
||||
|
||||
const ParamControlNetBeginEnd = (props: Props) => {
|
||||
const { controlNetId, beginStepPct, endStepPct, mini = false } = props;
|
||||
const { controlNetId, beginStepPct, mini = false, endStepPct } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
@ -75,12 +81,9 @@ const ParamControlNetBeginEnd = (props: Props) => {
|
||||
<RangeSliderMark
|
||||
value={0}
|
||||
sx={{
|
||||
fontSize: 'xs',
|
||||
fontWeight: '500',
|
||||
color: 'base.200',
|
||||
insetInlineStart: '0 !important',
|
||||
insetInlineEnd: 'unset !important',
|
||||
mt: 1.5,
|
||||
...SLIDER_MARK_STYLES,
|
||||
}}
|
||||
>
|
||||
0%
|
||||
@ -88,10 +91,7 @@ const ParamControlNetBeginEnd = (props: Props) => {
|
||||
<RangeSliderMark
|
||||
value={0.5}
|
||||
sx={{
|
||||
fontSize: 'xs',
|
||||
fontWeight: '500',
|
||||
color: 'base.200',
|
||||
mt: 1.5,
|
||||
...SLIDER_MARK_STYLES,
|
||||
}}
|
||||
>
|
||||
50%
|
||||
@ -99,12 +99,9 @@ const ParamControlNetBeginEnd = (props: Props) => {
|
||||
<RangeSliderMark
|
||||
value={1}
|
||||
sx={{
|
||||
fontSize: 'xs',
|
||||
fontWeight: '500',
|
||||
color: 'base.200',
|
||||
insetInlineStart: 'unset !important',
|
||||
insetInlineEnd: '0 !important',
|
||||
mt: 1.5,
|
||||
...SLIDER_MARK_STYLES,
|
||||
}}
|
||||
>
|
||||
100%
|
||||
@ -112,16 +109,6 @@ const ParamControlNetBeginEnd = (props: Props) => {
|
||||
</>
|
||||
)}
|
||||
</RangeSlider>
|
||||
|
||||
{!mini && (
|
||||
<IAIIconButton
|
||||
size="sm"
|
||||
aria-label={t('accessibility.reset')}
|
||||
tooltip={t('accessibility.reset')}
|
||||
icon={<BiReset />}
|
||||
onClick={handleStepPctReset}
|
||||
/>
|
||||
)}
|
||||
</HStack>
|
||||
</FormControl>
|
||||
);
|
||||
|
@ -1,50 +1,85 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAICustomSelect, {
|
||||
IAICustomSelectOption,
|
||||
} from 'common/components/IAICustomSelect';
|
||||
import IAISelect from 'common/components/IAISelect';
|
||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||
import {
|
||||
CONTROLNET_MODELS,
|
||||
ControlNetModelName,
|
||||
} from 'features/controlNet/store/constants';
|
||||
import { controlNetModelChanged } from 'features/controlNet/store/controlNetSlice';
|
||||
import { configSelector } from 'features/system/store/configSelectors';
|
||||
import { map } from 'lodash-es';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { ChangeEvent, memo, useCallback } from 'react';
|
||||
|
||||
type ParamControlNetModelProps = {
|
||||
controlNetId: string;
|
||||
model: ControlNetModelName;
|
||||
};
|
||||
|
||||
const DATA: IAICustomSelectOption[] = map(CONTROLNET_MODELS, (m) => ({
|
||||
value: m.type,
|
||||
label: m.label,
|
||||
tooltip: m.type,
|
||||
}));
|
||||
const selector = createSelector(configSelector, (config) => {
|
||||
return map(CONTROLNET_MODELS, (m) => ({
|
||||
key: m.label,
|
||||
value: m.type,
|
||||
})).filter((d) => !config.sd.disabledControlNetModels.includes(d.value));
|
||||
});
|
||||
|
||||
// const DATA: IAICustomSelectOption[] = map(CONTROLNET_MODELS, (m) => ({
|
||||
// value: m.type,
|
||||
// label: m.label,
|
||||
// tooltip: m.type,
|
||||
// }));
|
||||
|
||||
const ParamControlNetModel = (props: ParamControlNetModelProps) => {
|
||||
const { controlNetId, model } = props;
|
||||
const controlNetModels = useAppSelector(selector);
|
||||
const dispatch = useAppDispatch();
|
||||
const isReady = useIsReadyToInvoke();
|
||||
|
||||
const handleModelChanged = useCallback(
|
||||
(val: string | null | undefined) => {
|
||||
(e: ChangeEvent<HTMLSelectElement>) => {
|
||||
// TODO: do not cast
|
||||
const model = val as ControlNetModelName;
|
||||
const model = e.target.value as ControlNetModelName;
|
||||
dispatch(controlNetModelChanged({ controlNetId, model }));
|
||||
},
|
||||
[controlNetId, dispatch]
|
||||
);
|
||||
|
||||
// const handleModelChanged = useCallback(
|
||||
// (val: string | null | undefined) => {
|
||||
// // TODO: do not cast
|
||||
// const model = val as ControlNetModelName;
|
||||
// dispatch(controlNetModelChanged({ controlNetId, model }));
|
||||
// },
|
||||
// [controlNetId, dispatch]
|
||||
// );
|
||||
|
||||
return (
|
||||
<IAICustomSelect
|
||||
<IAISelect
|
||||
tooltip={model}
|
||||
tooltipProps={{ placement: 'top', hasArrow: true }}
|
||||
data={DATA}
|
||||
validValues={controlNetModels}
|
||||
value={model}
|
||||
onChange={handleModelChanged}
|
||||
ellipsisPosition="start"
|
||||
withCheckIcon
|
||||
isDisabled={!isReady}
|
||||
// ellipsisPosition="start"
|
||||
// withCheckIcon
|
||||
/>
|
||||
);
|
||||
// return (
|
||||
// <IAICustomSelect
|
||||
// tooltip={model}
|
||||
// tooltipProps={{ placement: 'top', hasArrow: true }}
|
||||
// data={DATA}
|
||||
// value={model}
|
||||
// onChange={handleModelChanged}
|
||||
// isDisabled={!isReady}
|
||||
// ellipsisPosition="start"
|
||||
// withCheckIcon
|
||||
// />
|
||||
// );
|
||||
};
|
||||
|
||||
export default memo(ParamControlNetModel);
|
||||
|
@ -1,62 +1,115 @@
|
||||
import IAICustomSelect, {
|
||||
IAICustomSelectOption,
|
||||
} from 'common/components/IAICustomSelect';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { ChangeEvent, memo, useCallback } from 'react';
|
||||
import {
|
||||
ControlNetProcessorNode,
|
||||
ControlNetProcessorType,
|
||||
} from '../../store/types';
|
||||
import { controlNetProcessorTypeChanged } from '../../store/controlNetSlice';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { CONTROLNET_PROCESSORS } from '../../store/constants';
|
||||
import { map } from 'lodash-es';
|
||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||
import IAISelect from 'common/components/IAISelect';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { configSelector } from 'features/system/store/configSelectors';
|
||||
|
||||
type ParamControlNetProcessorSelectProps = {
|
||||
controlNetId: string;
|
||||
processorNode: ControlNetProcessorNode;
|
||||
};
|
||||
|
||||
const CONTROLNET_PROCESSOR_TYPES: IAICustomSelectOption[] = map(
|
||||
CONTROLNET_PROCESSORS,
|
||||
(p) => ({
|
||||
value: p.type,
|
||||
label: p.label,
|
||||
tooltip: p.description,
|
||||
})
|
||||
).sort((a, b) =>
|
||||
const CONTROLNET_PROCESSOR_TYPES = map(CONTROLNET_PROCESSORS, (p) => ({
|
||||
value: p.type,
|
||||
key: p.label,
|
||||
})).sort((a, b) =>
|
||||
// sort 'none' to the top
|
||||
a.value === 'none'
|
||||
? -1
|
||||
: b.value === 'none'
|
||||
? 1
|
||||
: a.label.localeCompare(b.label)
|
||||
a.value === 'none' ? -1 : b.value === 'none' ? 1 : a.key.localeCompare(b.key)
|
||||
);
|
||||
|
||||
const selector = createSelector(configSelector, (config) => {
|
||||
return map(CONTROLNET_PROCESSORS, (p) => ({
|
||||
value: p.type,
|
||||
key: p.label,
|
||||
}))
|
||||
.sort((a, b) =>
|
||||
// sort 'none' to the top
|
||||
a.value === 'none'
|
||||
? -1
|
||||
: b.value === 'none'
|
||||
? 1
|
||||
: a.key.localeCompare(b.key)
|
||||
)
|
||||
.filter((d) => !config.sd.disabledControlNetProcessors.includes(d.value));
|
||||
});
|
||||
|
||||
// const CONTROLNET_PROCESSOR_TYPES: IAICustomSelectOption[] = map(
|
||||
// CONTROLNET_PROCESSORS,
|
||||
// (p) => ({
|
||||
// value: p.type,
|
||||
// label: p.label,
|
||||
// tooltip: p.description,
|
||||
// })
|
||||
// ).sort((a, b) =>
|
||||
// // sort 'none' to the top
|
||||
// a.value === 'none'
|
||||
// ? -1
|
||||
// : b.value === 'none'
|
||||
// ? 1
|
||||
// : a.label.localeCompare(b.label)
|
||||
// );
|
||||
|
||||
const ParamControlNetProcessorSelect = (
|
||||
props: ParamControlNetProcessorSelectProps
|
||||
) => {
|
||||
const { controlNetId, processorNode } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const isReady = useIsReadyToInvoke();
|
||||
const controlNetProcessors = useAppSelector(selector);
|
||||
|
||||
const handleProcessorTypeChanged = useCallback(
|
||||
(v: string | null | undefined) => {
|
||||
(e: ChangeEvent<HTMLSelectElement>) => {
|
||||
dispatch(
|
||||
controlNetProcessorTypeChanged({
|
||||
controlNetId,
|
||||
processorType: v as ControlNetProcessorType,
|
||||
processorType: e.target.value as ControlNetProcessorType,
|
||||
})
|
||||
);
|
||||
},
|
||||
[controlNetId, dispatch]
|
||||
);
|
||||
// const handleProcessorTypeChanged = useCallback(
|
||||
// (v: string | null | undefined) => {
|
||||
// dispatch(
|
||||
// controlNetProcessorTypeChanged({
|
||||
// controlNetId,
|
||||
// processorType: v as ControlNetProcessorType,
|
||||
// })
|
||||
// );
|
||||
// },
|
||||
// [controlNetId, dispatch]
|
||||
// );
|
||||
|
||||
return (
|
||||
<IAICustomSelect
|
||||
<IAISelect
|
||||
label="Processor"
|
||||
value={processorNode.type ?? 'canny_image_processor'}
|
||||
data={CONTROLNET_PROCESSOR_TYPES}
|
||||
validValues={controlNetProcessors}
|
||||
onChange={handleProcessorTypeChanged}
|
||||
withCheckIcon
|
||||
isDisabled={!isReady}
|
||||
/>
|
||||
);
|
||||
// return (
|
||||
// <IAICustomSelect
|
||||
// label="Processor"
|
||||
// value={processorNode.type ?? 'canny_image_processor'}
|
||||
// data={CONTROLNET_PROCESSOR_TYPES}
|
||||
// onChange={handleProcessorTypeChanged}
|
||||
// withCheckIcon
|
||||
// isDisabled={!isReady}
|
||||
// />
|
||||
// );
|
||||
};
|
||||
|
||||
export default memo(ParamControlNetProcessorSelect);
|
||||
|
@ -20,36 +20,17 @@ const ParamControlNetWeight = (props: ParamControlNetWeightProps) => {
|
||||
[controlNetId, dispatch]
|
||||
);
|
||||
|
||||
const handleWeightReset = () => {
|
||||
dispatch(controlNetWeightChanged({ controlNetId, weight: 1 }));
|
||||
};
|
||||
|
||||
if (mini) {
|
||||
return (
|
||||
<IAISlider
|
||||
label={'Weight'}
|
||||
sliderFormLabelProps={{ pb: 1 }}
|
||||
value={weight}
|
||||
onChange={handleWeightChanged}
|
||||
min={0}
|
||||
max={1}
|
||||
step={0.01}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<IAISlider
|
||||
label="Weight"
|
||||
label={'Weight'}
|
||||
sliderFormLabelProps={{ pb: 2 }}
|
||||
value={weight}
|
||||
onChange={handleWeightChanged}
|
||||
withInput
|
||||
withReset
|
||||
handleReset={handleWeightReset}
|
||||
withSliderMarks
|
||||
min={0}
|
||||
min={-1}
|
||||
max={1}
|
||||
step={0.01}
|
||||
withSliderMarks={!mini}
|
||||
sliderMarks={[-1, 0, 1]}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
@ -4,6 +4,7 @@ import { RequiredCannyImageProcessorInvocation } from 'features/controlNet/store
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||
|
||||
const DEFAULTS = CONTROLNET_PROCESSORS.canny_image_processor.default;
|
||||
|
||||
@ -15,6 +16,7 @@ type CannyProcessorProps = {
|
||||
const CannyProcessor = (props: CannyProcessorProps) => {
|
||||
const { controlNetId, processorNode } = props;
|
||||
const { low_threshold, high_threshold } = processorNode;
|
||||
const isReady = useIsReadyToInvoke();
|
||||
const processorChanged = useProcessorNodeChanged();
|
||||
|
||||
const handleLowThresholdChanged = useCallback(
|
||||
@ -46,6 +48,7 @@ const CannyProcessor = (props: CannyProcessorProps) => {
|
||||
return (
|
||||
<ProcessorWrapper>
|
||||
<IAISlider
|
||||
isDisabled={!isReady}
|
||||
label="Low Threshold"
|
||||
value={low_threshold}
|
||||
onChange={handleLowThresholdChanged}
|
||||
@ -54,8 +57,10 @@ const CannyProcessor = (props: CannyProcessorProps) => {
|
||||
min={0}
|
||||
max={255}
|
||||
withInput
|
||||
withSliderMarks
|
||||
/>
|
||||
<IAISlider
|
||||
isDisabled={!isReady}
|
||||
label="High Threshold"
|
||||
value={high_threshold}
|
||||
onChange={handleHighThresholdChanged}
|
||||
@ -64,6 +69,7 @@ const CannyProcessor = (props: CannyProcessorProps) => {
|
||||
min={0}
|
||||
max={255}
|
||||
withInput
|
||||
withSliderMarks
|
||||
/>
|
||||
</ProcessorWrapper>
|
||||
);
|
||||
|
@ -4,6 +4,7 @@ import { RequiredContentShuffleImageProcessorInvocation } from 'features/control
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||
|
||||
const DEFAULTS = CONTROLNET_PROCESSORS.content_shuffle_image_processor.default;
|
||||
|
||||
@ -16,6 +17,7 @@ const ContentShuffleProcessor = (props: Props) => {
|
||||
const { controlNetId, processorNode } = props;
|
||||
const { image_resolution, detect_resolution, w, h, f } = processorNode;
|
||||
const processorChanged = useProcessorNodeChanged();
|
||||
const isReady = useIsReadyToInvoke();
|
||||
|
||||
const handleDetectResolutionChanged = useCallback(
|
||||
(v: number) => {
|
||||
@ -93,6 +95,8 @@ const ContentShuffleProcessor = (props: Props) => {
|
||||
min={0}
|
||||
max={4096}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
/>
|
||||
<IAISlider
|
||||
label="Image Resolution"
|
||||
@ -103,6 +107,8 @@ const ContentShuffleProcessor = (props: Props) => {
|
||||
min={0}
|
||||
max={4096}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
/>
|
||||
<IAISlider
|
||||
label="W"
|
||||
@ -113,6 +119,8 @@ const ContentShuffleProcessor = (props: Props) => {
|
||||
min={0}
|
||||
max={4096}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
/>
|
||||
<IAISlider
|
||||
label="H"
|
||||
@ -123,6 +131,8 @@ const ContentShuffleProcessor = (props: Props) => {
|
||||
min={0}
|
||||
max={4096}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
/>
|
||||
<IAISlider
|
||||
label="F"
|
||||
@ -133,6 +143,8 @@ const ContentShuffleProcessor = (props: Props) => {
|
||||
min={0}
|
||||
max={4096}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
/>
|
||||
</ProcessorWrapper>
|
||||
);
|
||||
|
@ -5,6 +5,7 @@ import { RequiredHedImageProcessorInvocation } from 'features/controlNet/store/t
|
||||
import { ChangeEvent, memo, useCallback } from 'react';
|
||||
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||
|
||||
const DEFAULTS = CONTROLNET_PROCESSORS.hed_image_processor.default;
|
||||
|
||||
@ -18,7 +19,7 @@ const HedPreprocessor = (props: HedProcessorProps) => {
|
||||
controlNetId,
|
||||
processorNode: { detect_resolution, image_resolution, scribble },
|
||||
} = props;
|
||||
|
||||
const isReady = useIsReadyToInvoke();
|
||||
const processorChanged = useProcessorNodeChanged();
|
||||
|
||||
const handleDetectResolutionChanged = useCallback(
|
||||
@ -65,6 +66,8 @@ const HedPreprocessor = (props: HedProcessorProps) => {
|
||||
min={0}
|
||||
max={4096}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
/>
|
||||
<IAISlider
|
||||
label="Image Resolution"
|
||||
@ -75,11 +78,14 @@ const HedPreprocessor = (props: HedProcessorProps) => {
|
||||
min={0}
|
||||
max={4096}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
/>
|
||||
<IAISwitch
|
||||
label="Scribble"
|
||||
isChecked={scribble}
|
||||
onChange={handleScribbleChanged}
|
||||
isDisabled={!isReady}
|
||||
/>
|
||||
</ProcessorWrapper>
|
||||
);
|
||||
|
@ -4,6 +4,7 @@ import { RequiredLineartAnimeImageProcessorInvocation } from 'features/controlNe
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||
|
||||
const DEFAULTS = CONTROLNET_PROCESSORS.lineart_anime_image_processor.default;
|
||||
|
||||
@ -16,6 +17,7 @@ const LineartAnimeProcessor = (props: Props) => {
|
||||
const { controlNetId, processorNode } = props;
|
||||
const { image_resolution, detect_resolution } = processorNode;
|
||||
const processorChanged = useProcessorNodeChanged();
|
||||
const isReady = useIsReadyToInvoke();
|
||||
|
||||
const handleDetectResolutionChanged = useCallback(
|
||||
(v: number) => {
|
||||
@ -54,6 +56,8 @@ const LineartAnimeProcessor = (props: Props) => {
|
||||
min={0}
|
||||
max={4096}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
/>
|
||||
<IAISlider
|
||||
label="Image Resolution"
|
||||
@ -64,6 +68,8 @@ const LineartAnimeProcessor = (props: Props) => {
|
||||
min={0}
|
||||
max={4096}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
/>
|
||||
</ProcessorWrapper>
|
||||
);
|
||||
|
@ -5,6 +5,7 @@ import { RequiredLineartImageProcessorInvocation } from 'features/controlNet/sto
|
||||
import { ChangeEvent, memo, useCallback } from 'react';
|
||||
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||
|
||||
const DEFAULTS = CONTROLNET_PROCESSORS.lineart_image_processor.default;
|
||||
|
||||
@ -17,6 +18,7 @@ const LineartProcessor = (props: LineartProcessorProps) => {
|
||||
const { controlNetId, processorNode } = props;
|
||||
const { image_resolution, detect_resolution, coarse } = processorNode;
|
||||
const processorChanged = useProcessorNodeChanged();
|
||||
const isReady = useIsReadyToInvoke();
|
||||
|
||||
const handleDetectResolutionChanged = useCallback(
|
||||
(v: number) => {
|
||||
@ -62,6 +64,8 @@ const LineartProcessor = (props: LineartProcessorProps) => {
|
||||
min={0}
|
||||
max={4096}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
/>
|
||||
<IAISlider
|
||||
label="Image Resolution"
|
||||
@ -72,11 +76,14 @@ const LineartProcessor = (props: LineartProcessorProps) => {
|
||||
min={0}
|
||||
max={4096}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
/>
|
||||
<IAISwitch
|
||||
label="Coarse"
|
||||
isChecked={coarse}
|
||||
onChange={handleCoarseChanged}
|
||||
isDisabled={!isReady}
|
||||
/>
|
||||
</ProcessorWrapper>
|
||||
);
|
||||
|
@ -4,6 +4,7 @@ import { RequiredMediapipeFaceProcessorInvocation } from 'features/controlNet/st
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||
|
||||
const DEFAULTS = CONTROLNET_PROCESSORS.mediapipe_face_processor.default;
|
||||
|
||||
@ -16,6 +17,7 @@ const MediapipeFaceProcessor = (props: Props) => {
|
||||
const { controlNetId, processorNode } = props;
|
||||
const { max_faces, min_confidence } = processorNode;
|
||||
const processorChanged = useProcessorNodeChanged();
|
||||
const isReady = useIsReadyToInvoke();
|
||||
|
||||
const handleMaxFacesChanged = useCallback(
|
||||
(v: number) => {
|
||||
@ -50,6 +52,8 @@ const MediapipeFaceProcessor = (props: Props) => {
|
||||
min={1}
|
||||
max={20}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
/>
|
||||
<IAISlider
|
||||
label="Min Confidence"
|
||||
@ -61,6 +65,8 @@ const MediapipeFaceProcessor = (props: Props) => {
|
||||
max={1}
|
||||
step={0.01}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
/>
|
||||
</ProcessorWrapper>
|
||||
);
|
||||
|
@ -4,6 +4,7 @@ import { RequiredMidasDepthImageProcessorInvocation } from 'features/controlNet/
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||
|
||||
const DEFAULTS = CONTROLNET_PROCESSORS.midas_depth_image_processor.default;
|
||||
|
||||
@ -16,6 +17,7 @@ const MidasDepthProcessor = (props: Props) => {
|
||||
const { controlNetId, processorNode } = props;
|
||||
const { a_mult, bg_th } = processorNode;
|
||||
const processorChanged = useProcessorNodeChanged();
|
||||
const isReady = useIsReadyToInvoke();
|
||||
|
||||
const handleAMultChanged = useCallback(
|
||||
(v: number) => {
|
||||
@ -51,6 +53,8 @@ const MidasDepthProcessor = (props: Props) => {
|
||||
max={20}
|
||||
step={0.01}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
/>
|
||||
<IAISlider
|
||||
label="bg_th"
|
||||
@ -62,6 +66,8 @@ const MidasDepthProcessor = (props: Props) => {
|
||||
max={20}
|
||||
step={0.01}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
/>
|
||||
</ProcessorWrapper>
|
||||
);
|
||||
|
@ -4,6 +4,7 @@ import { RequiredMlsdImageProcessorInvocation } from 'features/controlNet/store/
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||
|
||||
const DEFAULTS = CONTROLNET_PROCESSORS.mlsd_image_processor.default;
|
||||
|
||||
@ -16,6 +17,7 @@ const MlsdImageProcessor = (props: Props) => {
|
||||
const { controlNetId, processorNode } = props;
|
||||
const { image_resolution, detect_resolution, thr_d, thr_v } = processorNode;
|
||||
const processorChanged = useProcessorNodeChanged();
|
||||
const isReady = useIsReadyToInvoke();
|
||||
|
||||
const handleDetectResolutionChanged = useCallback(
|
||||
(v: number) => {
|
||||
@ -76,6 +78,8 @@ const MlsdImageProcessor = (props: Props) => {
|
||||
min={0}
|
||||
max={4096}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
/>
|
||||
<IAISlider
|
||||
label="Image Resolution"
|
||||
@ -86,6 +90,8 @@ const MlsdImageProcessor = (props: Props) => {
|
||||
min={0}
|
||||
max={4096}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
/>
|
||||
<IAISlider
|
||||
label="W"
|
||||
@ -97,6 +103,8 @@ const MlsdImageProcessor = (props: Props) => {
|
||||
max={1}
|
||||
step={0.01}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
/>
|
||||
<IAISlider
|
||||
label="H"
|
||||
@ -108,6 +116,8 @@ const MlsdImageProcessor = (props: Props) => {
|
||||
max={1}
|
||||
step={0.01}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
/>
|
||||
</ProcessorWrapper>
|
||||
);
|
||||
|
@ -4,6 +4,7 @@ import { RequiredNormalbaeImageProcessorInvocation } from 'features/controlNet/s
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||
|
||||
const DEFAULTS = CONTROLNET_PROCESSORS.normalbae_image_processor.default;
|
||||
|
||||
@ -16,6 +17,7 @@ const NormalBaeProcessor = (props: Props) => {
|
||||
const { controlNetId, processorNode } = props;
|
||||
const { image_resolution, detect_resolution } = processorNode;
|
||||
const processorChanged = useProcessorNodeChanged();
|
||||
const isReady = useIsReadyToInvoke();
|
||||
|
||||
const handleDetectResolutionChanged = useCallback(
|
||||
(v: number) => {
|
||||
@ -54,6 +56,8 @@ const NormalBaeProcessor = (props: Props) => {
|
||||
min={0}
|
||||
max={4096}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
/>
|
||||
<IAISlider
|
||||
label="Image Resolution"
|
||||
@ -64,6 +68,8 @@ const NormalBaeProcessor = (props: Props) => {
|
||||
min={0}
|
||||
max={4096}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
/>
|
||||
</ProcessorWrapper>
|
||||
);
|
||||
|
@ -5,6 +5,7 @@ import { RequiredOpenposeImageProcessorInvocation } from 'features/controlNet/st
|
||||
import { ChangeEvent, memo, useCallback } from 'react';
|
||||
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||
|
||||
const DEFAULTS = CONTROLNET_PROCESSORS.openpose_image_processor.default;
|
||||
|
||||
@ -17,6 +18,7 @@ const OpenposeProcessor = (props: Props) => {
|
||||
const { controlNetId, processorNode } = props;
|
||||
const { image_resolution, detect_resolution, hand_and_face } = processorNode;
|
||||
const processorChanged = useProcessorNodeChanged();
|
||||
const isReady = useIsReadyToInvoke();
|
||||
|
||||
const handleDetectResolutionChanged = useCallback(
|
||||
(v: number) => {
|
||||
@ -62,6 +64,8 @@ const OpenposeProcessor = (props: Props) => {
|
||||
min={0}
|
||||
max={4096}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
/>
|
||||
<IAISlider
|
||||
label="Image Resolution"
|
||||
@ -72,11 +76,14 @@ const OpenposeProcessor = (props: Props) => {
|
||||
min={0}
|
||||
max={4096}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
/>
|
||||
<IAISwitch
|
||||
label="Hand and Face"
|
||||
isChecked={hand_and_face}
|
||||
onChange={handleHandAndFaceChanged}
|
||||
isDisabled={!isReady}
|
||||
/>
|
||||
</ProcessorWrapper>
|
||||
);
|
||||
|
@ -5,6 +5,7 @@ import { RequiredPidiImageProcessorInvocation } from 'features/controlNet/store/
|
||||
import { ChangeEvent, memo, useCallback } from 'react';
|
||||
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||
|
||||
const DEFAULTS = CONTROLNET_PROCESSORS.pidi_image_processor.default;
|
||||
|
||||
@ -17,6 +18,7 @@ const PidiProcessor = (props: Props) => {
|
||||
const { controlNetId, processorNode } = props;
|
||||
const { image_resolution, detect_resolution, scribble, safe } = processorNode;
|
||||
const processorChanged = useProcessorNodeChanged();
|
||||
const isReady = useIsReadyToInvoke();
|
||||
|
||||
const handleDetectResolutionChanged = useCallback(
|
||||
(v: number) => {
|
||||
@ -69,6 +71,8 @@ const PidiProcessor = (props: Props) => {
|
||||
min={0}
|
||||
max={4096}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
/>
|
||||
<IAISlider
|
||||
label="Image Resolution"
|
||||
@ -79,13 +83,20 @@ const PidiProcessor = (props: Props) => {
|
||||
min={0}
|
||||
max={4096}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
/>
|
||||
<IAISwitch
|
||||
label="Scribble"
|
||||
isChecked={scribble}
|
||||
onChange={handleScribbleChanged}
|
||||
/>
|
||||
<IAISwitch label="Safe" isChecked={safe} onChange={handleSafeChanged} />
|
||||
<IAISwitch
|
||||
label="Safe"
|
||||
isChecked={safe}
|
||||
onChange={handleSafeChanged}
|
||||
isDisabled={!isReady}
|
||||
/>
|
||||
</ProcessorWrapper>
|
||||
);
|
||||
};
|
||||
|
@ -23,7 +23,7 @@ type ControlNetProcessorsDict = Record<
|
||||
*
|
||||
* TODO: Generate from the OpenAPI schema
|
||||
*/
|
||||
export const CONTROLNET_PROCESSORS: ControlNetProcessorsDict = {
|
||||
export const CONTROLNET_PROCESSORS = {
|
||||
none: {
|
||||
type: 'none',
|
||||
label: 'none',
|
||||
@ -129,7 +129,7 @@ export const CONTROLNET_PROCESSORS: ControlNetProcessorsDict = {
|
||||
},
|
||||
normalbae_image_processor: {
|
||||
type: 'normalbae_image_processor',
|
||||
label: 'NormalBae',
|
||||
label: 'Normal BAE',
|
||||
description: '',
|
||||
default: {
|
||||
id: 'normalbae_image_processor',
|
||||
@ -181,7 +181,7 @@ type ControlNetModel = {
|
||||
defaultProcessor?: ControlNetProcessorType;
|
||||
};
|
||||
|
||||
export const CONTROLNET_MODELS: Record<string, ControlNetModel> = {
|
||||
export const CONTROLNET_MODELS = {
|
||||
'lllyasviel/control_v11p_sd15_canny': {
|
||||
type: 'lllyasviel/control_v11p_sd15_canny',
|
||||
label: 'Canny',
|
||||
@ -208,7 +208,7 @@ export const CONTROLNET_MODELS: Record<string, ControlNetModel> = {
|
||||
},
|
||||
'lllyasviel/control_v11p_sd15_seg': {
|
||||
type: 'lllyasviel/control_v11p_sd15_seg',
|
||||
label: 'Segment Anything',
|
||||
label: 'Segmentation',
|
||||
},
|
||||
'lllyasviel/control_v11p_sd15_lineart': {
|
||||
type: 'lllyasviel/control_v11p_sd15_lineart',
|
||||
|
@ -21,6 +21,8 @@ export const FIELD_TYPE_MAP: Record<string, FieldType> = {
|
||||
ColorField: 'color',
|
||||
ControlField: 'control',
|
||||
control: 'control',
|
||||
cfg_scale: 'float',
|
||||
control_weight: 'float',
|
||||
};
|
||||
|
||||
const COLOR_TOKEN_VALUE = 500;
|
||||
|
@ -1,5 +1,5 @@
|
||||
import { RootState } from 'app/store/store';
|
||||
import { forEach, size } from 'lodash-es';
|
||||
import { filter, forEach, size } from 'lodash-es';
|
||||
import { CollectInvocation, ControlNetInvocation } from 'services/api';
|
||||
import { NonNullableGraph } from '../types/types';
|
||||
|
||||
@ -12,8 +12,16 @@ export const addControlNetToLinearGraph = (
|
||||
): void => {
|
||||
const { isEnabled: isControlNetEnabled, controlNets } = state.controlNet;
|
||||
|
||||
const validControlNets = filter(
|
||||
controlNets,
|
||||
(c) =>
|
||||
c.isEnabled &&
|
||||
(Boolean(c.processedControlImage) ||
|
||||
(c.processorType === 'none' && Boolean(c.controlImage)))
|
||||
);
|
||||
|
||||
// Add ControlNet
|
||||
if (isControlNetEnabled) {
|
||||
if (isControlNetEnabled && validControlNets.length > 0) {
|
||||
if (size(controlNets) > 1) {
|
||||
const controlNetIterateNode: CollectInvocation = {
|
||||
id: CONTROL_NET_COLLECT,
|
||||
|
@ -3,10 +3,11 @@ import { Scheduler } from 'app/constants';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAICustomSelect from 'common/components/IAICustomSelect';
|
||||
import IAISelect from 'common/components/IAISelect';
|
||||
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
||||
import { setScheduler } from 'features/parameters/store/generationSlice';
|
||||
import { uiSelector } from 'features/ui/store/uiSelectors';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { ChangeEvent, memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const selector = createSelector(
|
||||
@ -35,24 +36,39 @@ const ParamScheduler = () => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const handleChange = useCallback(
|
||||
(v: string | null | undefined) => {
|
||||
if (!v) {
|
||||
return;
|
||||
}
|
||||
dispatch(setScheduler(v as Scheduler));
|
||||
(e: ChangeEvent<HTMLSelectElement>) => {
|
||||
dispatch(setScheduler(e.target.value as Scheduler));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
// const handleChange = useCallback(
|
||||
// (v: string | null | undefined) => {
|
||||
// if (!v) {
|
||||
// return;
|
||||
// }
|
||||
// dispatch(setScheduler(v as Scheduler));
|
||||
// },
|
||||
// [dispatch]
|
||||
// );
|
||||
|
||||
return (
|
||||
<IAICustomSelect
|
||||
<IAISelect
|
||||
label={t('parameters.scheduler')}
|
||||
value={scheduler}
|
||||
data={allSchedulers}
|
||||
validValues={allSchedulers}
|
||||
onChange={handleChange}
|
||||
withCheckIcon
|
||||
/>
|
||||
);
|
||||
|
||||
// return (
|
||||
// <IAICustomSelect
|
||||
// label={t('parameters.scheduler')}
|
||||
// value={scheduler}
|
||||
// data={allSchedulers}
|
||||
// onChange={handleChange}
|
||||
// withCheckIcon
|
||||
// />
|
||||
// );
|
||||
};
|
||||
|
||||
export default memo(ParamScheduler);
|
||||
|
@ -1,5 +1,5 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { ChangeEvent, memo, useCallback } from 'react';
|
||||
import { isEqual } from 'lodash-es';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
@ -11,6 +11,7 @@ import { generationSelector } from 'features/parameters/store/generationSelector
|
||||
import IAICustomSelect, {
|
||||
IAICustomSelectOption,
|
||||
} from 'common/components/IAICustomSelect';
|
||||
import IAISelect from 'common/components/IAISelect';
|
||||
|
||||
const selector = createSelector(
|
||||
[(state: RootState) => state, generationSelector],
|
||||
@ -18,12 +19,18 @@ const selector = createSelector(
|
||||
const selectedModel = selectModelsById(state, generation.model);
|
||||
|
||||
const modelData = selectModelsAll(state)
|
||||
.map<IAICustomSelectOption>((m) => ({
|
||||
.map((m) => ({
|
||||
value: m.name,
|
||||
label: m.name,
|
||||
tooltip: m.description,
|
||||
key: m.name,
|
||||
}))
|
||||
.sort((a, b) => a.label.localeCompare(b.label));
|
||||
.sort((a, b) => a.key.localeCompare(b.key));
|
||||
// const modelData = selectModelsAll(state)
|
||||
// .map<IAICustomSelectOption>((m) => ({
|
||||
// value: m.name,
|
||||
// label: m.name,
|
||||
// tooltip: m.description,
|
||||
// }))
|
||||
// .sort((a, b) => a.label.localeCompare(b.label));
|
||||
return {
|
||||
selectedModel,
|
||||
modelData,
|
||||
@ -41,26 +48,43 @@ const ModelSelect = () => {
|
||||
const { t } = useTranslation();
|
||||
const { selectedModel, modelData } = useAppSelector(selector);
|
||||
const handleChangeModel = useCallback(
|
||||
(v: string | null | undefined) => {
|
||||
if (!v) {
|
||||
return;
|
||||
}
|
||||
dispatch(modelSelected(v));
|
||||
(e: ChangeEvent<HTMLSelectElement>) => {
|
||||
dispatch(modelSelected(e.target.value));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
// const handleChangeModel = useCallback(
|
||||
// (v: string | null | undefined) => {
|
||||
// if (!v) {
|
||||
// return;
|
||||
// }
|
||||
// dispatch(modelSelected(v));
|
||||
// },
|
||||
// [dispatch]
|
||||
// );
|
||||
|
||||
return (
|
||||
<IAICustomSelect
|
||||
<IAISelect
|
||||
label={t('modelManager.model')}
|
||||
tooltip={selectedModel?.description}
|
||||
data={modelData}
|
||||
validValues={modelData}
|
||||
value={selectedModel?.name ?? ''}
|
||||
onChange={handleChangeModel}
|
||||
withCheckIcon={true}
|
||||
tooltipProps={{ placement: 'top', hasArrow: true }}
|
||||
/>
|
||||
);
|
||||
|
||||
// return (
|
||||
// <IAICustomSelect
|
||||
// label={t('modelManager.model')}
|
||||
// tooltip={selectedModel?.description}
|
||||
// data={modelData}
|
||||
// value={selectedModel?.name ?? ''}
|
||||
// onChange={handleChangeModel}
|
||||
// withCheckIcon={true}
|
||||
// tooltipProps={{ placement: 'top', hasArrow: true }}
|
||||
// />
|
||||
// );
|
||||
};
|
||||
|
||||
export default memo(ModelSelect);
|
||||
|
@ -10,6 +10,8 @@ export const initialConfigState: AppConfig = {
|
||||
disabledSDFeatures: [],
|
||||
canRestoreDeletedImagesFromBin: true,
|
||||
sd: {
|
||||
disabledControlNetModels: [],
|
||||
disabledControlNetProcessors: [],
|
||||
iterations: {
|
||||
initial: 1,
|
||||
min: 1,
|
||||
|
@ -47,3 +47,6 @@ export const languageSelector = createSelector(
|
||||
(system) => system.language,
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
export const isProcessingSelector = (state: RootState) =>
|
||||
state.system.isProcessing;
|
||||
|
@ -7,30 +7,26 @@ import type { ImageField } from './ImageField';
|
||||
/**
|
||||
* Applies HED edge detection to image
|
||||
*/
|
||||
export type HedImageProcessorInvocation = {
|
||||
export type HedImageprocessorInvocation = {
|
||||
/**
|
||||
* The id of this node. Must be unique among all nodes.
|
||||
*/
|
||||
id: string;
|
||||
/**
|
||||
* Whether or not this node is an intermediate node.
|
||||
*/
|
||||
is_intermediate?: boolean;
|
||||
type?: 'hed_image_processor';
|
||||
/**
|
||||
* The image to process
|
||||
* image to process
|
||||
*/
|
||||
image?: ImageField;
|
||||
/**
|
||||
* The pixel resolution for detection
|
||||
* pixel resolution for edge detection
|
||||
*/
|
||||
detect_resolution?: number;
|
||||
/**
|
||||
* The pixel resolution for the output image
|
||||
* pixel resolution for output image
|
||||
*/
|
||||
image_resolution?: number;
|
||||
/**
|
||||
* Whether to use scribble mode
|
||||
* whether to use scribble mode
|
||||
*/
|
||||
scribble?: boolean;
|
||||
};
|
||||
|
@ -0,0 +1,33 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
import type { ImageField } from './ImageField';
|
||||
|
||||
/**
|
||||
* Applies HED edge detection to image
|
||||
*/
|
||||
export type HedImageprocessorInvocation = {
|
||||
/**
|
||||
* The id of this node. Must be unique among all nodes.
|
||||
*/
|
||||
id: string;
|
||||
type?: 'hed_image_processor';
|
||||
/**
|
||||
* image to process
|
||||
*/
|
||||
image?: ImageField;
|
||||
/**
|
||||
* pixel resolution for edge detection
|
||||
*/
|
||||
detect_resolution?: number;
|
||||
/**
|
||||
* pixel resolution for output image
|
||||
*/
|
||||
image_resolution?: number;
|
||||
/**
|
||||
* whether to use scribble mode
|
||||
*/
|
||||
scribble?: boolean;
|
||||
};
|
||||
|
@ -30,7 +30,7 @@ const invokeAIMark = defineStyle((_props) => {
|
||||
return {
|
||||
fontSize: 'xs',
|
||||
fontWeight: '500',
|
||||
color: 'base.200',
|
||||
color: 'base.400',
|
||||
mt: 2,
|
||||
insetInlineStart: 'unset',
|
||||
};
|
||||
|
@ -42,8 +42,9 @@ dependencies = [
|
||||
"controlnet-aux>=0.0.4",
|
||||
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
|
||||
"datasets",
|
||||
"diffusers[torch]~=0.16.1",
|
||||
"diffusers[torch]~=0.17.0",
|
||||
"dnspython==2.2.1",
|
||||
"easing-functions",
|
||||
"einops",
|
||||
"eventlet",
|
||||
"facexlib",
|
||||
@ -56,6 +57,7 @@ dependencies = [
|
||||
"flaskwebgui==1.0.3",
|
||||
"gfpgan==1.3.8",
|
||||
"huggingface-hub>=0.11.1",
|
||||
"matplotlib", # needed for plotting of Penner easing functions
|
||||
"mediapipe", # needed for "mediapipeface" controlnet model
|
||||
"npyscreen",
|
||||
"numpy<1.24",
|
||||
|
Loading…
Reference in New Issue
Block a user