Merge branch 'main' into lstein/new-model-manager

This commit is contained in:
StAlKeR7779 2023-06-13 23:37:52 +03:00 committed by GitHub
commit c9ae26a176
39 changed files with 784 additions and 181 deletions

View File

@ -1,11 +1,12 @@
# InvokeAI nodes for ControlNet image preprocessors
# initial implementation by Gregg Helt, 2023
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
from builtins import float
import numpy as np
from typing import Literal, Optional, Union, List
from PIL import Image, ImageFilter, ImageOps
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, validator
from ..models.image import ImageField, ImageCategory, ResourceOrigin
from .baseinvocation import (
@ -14,6 +15,7 @@ from .baseinvocation import (
InvocationContext,
InvocationConfig,
)
from controlnet_aux import (
CannyDetector,
HEDdetector,
@ -96,15 +98,32 @@ CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
class ControlField(BaseModel):
image: ImageField = Field(default=None, description="The control image")
control_model: Optional[str] = Field(default=None, description="The ControlNet model to use")
control_weight: Optional[float] = Field(default=1, description="The weight given to the ControlNet")
# control_weight: Optional[float] = Field(default=1, description="weight given to controlnet")
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
begin_step_percent: float = Field(default=0, ge=0, le=1,
description="When the ControlNet is first applied (% of total steps)")
description="When the ControlNet is first applied (% of total steps)")
end_step_percent: float = Field(default=1, ge=0, le=1,
description="When the ControlNet is last applied (% of total steps)")
@validator("control_weight")
def abs_le_one(cls, v):
"""validate that all abs(values) are <=1"""
if isinstance(v, list):
for i in v:
if abs(i) > 1:
raise ValueError('all abs(control_weight) must be <= 1')
else:
if abs(v) > 1:
raise ValueError('abs(control_weight) must be <= 1')
return v
class Config:
schema_extra = {
"required": ["image", "control_model", "control_weight", "begin_step_percent", "end_step_percent"]
"required": ["image", "control_model", "control_weight", "begin_step_percent", "end_step_percent"],
"ui": {
"type_hints": {
"control_weight": "float",
# "control_weight": "number",
}
}
}
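
For reference, a minimal pydantic v1 sketch of how the widened control_weight field validates; ControlFieldSketch is an illustrative stand-in, not code from this commit:

from typing import List, Union
from pydantic import BaseModel, Field, validator

class ControlFieldSketch(BaseModel):
    # stand-in for ControlField above; only the weight field is modeled
    control_weight: Union[float, List[float]] = Field(default=1)

    @validator("control_weight")
    def abs_le_one(cls, v):
        # accept a scalar or a per-step list; every |value| must be <= 1
        values = v if isinstance(v, list) else [v]
        if any(abs(i) > 1 for i in values):
            raise ValueError("all abs(control_weight) must be <= 1")
        return v

ControlFieldSketch(control_weight=[0.5, -0.75, 1.0])  # accepted
# ControlFieldSketch(control_weight=1.5)              # raises ValidationError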
@ -112,7 +131,7 @@ class ControlOutput(BaseInvocationOutput):
"""node output for ControlNet info"""
# fmt: off
type: Literal["control_output"] = "control_output"
control: ControlField = Field(default=None, description="The output control image")
control: ControlField = Field(default=None, description="The control info")
# fmt: on
@ -123,15 +142,28 @@ class ControlNetInvocation(BaseInvocation):
# Inputs
image: ImageField = Field(default=None, description="The control image")
control_model: CONTROLNET_NAME_VALUES = Field(default="lllyasviel/sd-controlnet-canny",
description="The ControlNet model to use")
control_weight: float = Field(default=1.0, ge=0, le=1, description="The weight given to the ControlNet")
description="control model used")
control_weight: Union[float, List[float]] = Field(default=1.0, description="The weight given to the ControlNet")
# TODO: add support in backend core for begin_step_percent, end_step_percent, guess_mode
begin_step_percent: float = Field(default=0, ge=0, le=1,
description="When the ControlNet is first applied (% of total steps)")
description="When the ControlNet is first applied (% of total steps)")
end_step_percent: float = Field(default=1, ge=0, le=1,
description="When the ControlNet is last applied (% of total steps)")
description="When the ControlNet is last applied (% of total steps)")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["latents"],
"type_hints": {
"model": "model",
"control": "control",
# "cfg_scale": "float",
"cfg_scale": "number",
"control_weight": "float",
}
},
}
def invoke(self, context: InvocationContext) -> ControlOutput:
@ -161,7 +193,6 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
return image
def invoke(self, context: InvocationContext) -> ImageOutput:
raw_image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)

View File

@ -3,8 +3,6 @@
from contextlib import ExitStack
from typing import List, Literal, Optional, Union
import einops
from pydantic import BaseModel, Field, validator
import torch
from diffusers import ControlNetModel
@ -173,23 +171,36 @@ class TextToLatentsInvocation(BaseInvocation):
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
noise: Optional[LatentsField] = Field(description="The noise to use")
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
unet: UNetField = Field(default=None, description="UNet submodel")
control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
#seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
#seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
# fmt: on
@validator("cfg_scale")
def ge_one(cls, v):
"""validate that all cfg_scale values are >= 1"""
if isinstance(v, list):
for i in v:
if i < 1:
raise ValueError('cfg_scale must be greater than 1')
else:
if v < 1:
raise ValueError('cfg_scale must be greater than 1')
return v
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["latents", "image"],
"tags": ["latents"],
"type_hints": {
"model": "model",
"control": "control",
# "cfg_scale": "float",
"cfg_scale": "number"
}
},
}
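
The widened cfg_scale now accepts one guidance value per denoising step. A hedged sketch of building such a schedule (numbers illustrative), comparable to what the float_range node added below emits:

import numpy as np

steps = 10
# linear schedule from 7.5 down to 4.0, one value per denoising step
cfg_scale = [float(v) for v in np.linspace(7.5, 4.0, steps)]
assert len(cfg_scale) == steps and all(v >= 1 for v in cfg_scale)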
@ -210,10 +221,10 @@ class TextToLatentsInvocation(BaseInvocation):
uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
conditioning_data = ConditioningData(
uc,
c,
self.cfg_scale,
extra_conditioning_info,
unconditioned_embeddings=uc,
text_embeddings=c,
guidance_scale=self.cfg_scale,
extra=extra_conditioning_info,
postprocessing_settings=PostprocessingSettings(
threshold=0.0,#threshold,
warmup=0.2,#warmup,
@ -351,10 +362,10 @@ class TextToLatentsInvocation(BaseInvocation):
loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
print("type of control input: ", type(self.control))
control_data = self.prep_control_data(model=pipeline, context=context, control_input=self.control,
latents_shape=noise.shape,
do_classifier_free_guidance=(self.cfg_scale >= 1.0))
control_data = self.prep_control_data(model=model, context=context, control_input=self.control,
latents_shape=noise.shape,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,)
with ModelPatcher.apply_lora_unet(pipeline.unet, loras):
# TODO: Verify the noise is the right size
@ -364,7 +375,7 @@ class TextToLatentsInvocation(BaseInvocation):
num_inference_steps=self.steps,
conditioning_data=conditioning_data,
control_data=control_data, # list[ControlNetData]
callback=step_callback
callback=step_callback,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
@ -391,6 +402,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
"type_hints": {
"model": "model",
"control": "control",
"cfg_scale": "number",
}
},
}
@ -421,6 +433,12 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
pipeline = self.create_pipeline(unet, scheduler)
conditioning_data = self.get_conditioning_data(context, scheduler)
control_data = self.prep_control_data(model=model, context=context, control_input=self.control,
latents_shape=noise.shape,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,
)
# TODO: Verify the noise is the right size
initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
@ -442,6 +460,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
noise=noise,
num_inference_steps=self.steps,
conditioning_data=conditioning_data,
control_data=control_data, # list[ControlNetData]
callback=step_callback
)

View File

@ -0,0 +1,237 @@
import io
from typing import Literal, Optional, Any
# from PIL.Image import Image
import PIL.Image
from matplotlib.ticker import MaxNLocator
from matplotlib.figure import Figure
from pydantic import BaseModel, Field
import numpy as np
import matplotlib.pyplot as plt
from easing_functions import (
LinearInOut,
QuadEaseInOut, QuadEaseIn, QuadEaseOut,
CubicEaseInOut, CubicEaseIn, CubicEaseOut,
QuarticEaseInOut, QuarticEaseIn, QuarticEaseOut,
QuinticEaseInOut, QuinticEaseIn, QuinticEaseOut,
SineEaseInOut, SineEaseIn, SineEaseOut,
CircularEaseIn, CircularEaseInOut, CircularEaseOut,
ExponentialEaseInOut, ExponentialEaseIn, ExponentialEaseOut,
ElasticEaseIn, ElasticEaseInOut, ElasticEaseOut,
BackEaseIn, BackEaseInOut, BackEaseOut,
BounceEaseIn, BounceEaseInOut, BounceEaseOut)
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
InvocationContext,
InvocationConfig,
)
from ...backend.util.logging import InvokeAILogger
from .collections import FloatCollectionOutput
class FloatLinearRangeInvocation(BaseInvocation):
"""Creates a range"""
type: Literal["float_range"] = "float_range"
# Inputs
start: float = Field(default=5, description="The first value of the range")
stop: float = Field(default=10, description="The last value of the range")
steps: int = Field(default=30, description="number of values to interpolate over (including start and stop)")
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
param_list = list(np.linspace(self.start, self.stop, self.steps))
return FloatCollectionOutput(
collection=param_list
)
EASING_FUNCTIONS_MAP = {
"Linear": LinearInOut,
"QuadIn": QuadEaseIn,
"QuadOut": QuadEaseOut,
"QuadInOut": QuadEaseInOut,
"CubicIn": CubicEaseIn,
"CubicOut": CubicEaseOut,
"CubicInOut": CubicEaseInOut,
"QuarticIn": QuarticEaseIn,
"QuarticOut": QuarticEaseOut,
"QuarticInOut": QuarticEaseInOut,
"QuinticIn": QuinticEaseIn,
"QuinticOut": QuinticEaseOut,
"QuinticInOut": QuinticEaseInOut,
"SineIn": SineEaseIn,
"SineOut": SineEaseOut,
"SineInOut": SineEaseInOut,
"CircularIn": CircularEaseIn,
"CircularOut": CircularEaseOut,
"CircularInOut": CircularEaseInOut,
"ExponentialIn": ExponentialEaseIn,
"ExponentialOut": ExponentialEaseOut,
"ExponentialInOut": ExponentialEaseInOut,
"ElasticIn": ElasticEaseIn,
"ElasticOut": ElasticEaseOut,
"ElasticInOut": ElasticEaseInOut,
"BackIn": BackEaseIn,
"BackOut": BackEaseOut,
"BackInOut": BackEaseInOut,
"BounceIn": BounceEaseIn,
"BounceOut": BounceEaseOut,
"BounceInOut": BounceEaseInOut,
}
EASING_FUNCTION_KEYS: Any = Literal[
tuple(list(EASING_FUNCTIONS_MAP.keys()))
]
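
All of these easing classes share the start/end/duration constructor and the .ease(step_index) sampler used by invoke() below; a tiny usage sketch:

from easing_functions import CubicEaseInOut

# parameterized by start/end values and a duration in steps,
# then sampled with .ease(step_index), exactly as invoke() does below
ease = CubicEaseInOut(start=0.0, end=1.0, duration=9)
values = [ease.ease(i) for i in range(10)]  # eases 0.0 -> 1.0 over 10 steps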
# actually I think for now could just use CollectionOutput (which is list[Any])
class StepParamEasingInvocation(BaseInvocation):
"""Experimental per-step parameter easing for denoising steps"""
type: Literal["step_param_easing"] = "step_param_easing"
# Inputs
# fmt: off
easing: EASING_FUNCTION_KEYS = Field(default="Linear", description="The easing function to use")
num_steps: int = Field(default=20, description="number of denoising steps")
start_value: float = Field(default=0.0, description="easing starting value")
end_value: float = Field(default=1.0, description="easing ending value")
start_step_percent: float = Field(default=0.0, description="fraction of steps at which to start easing")
end_step_percent: float = Field(default=1.0, description="fraction of steps after which to end easing")
# if None, then start_value is used prior to easing start
pre_start_value: Optional[float] = Field(default=None, description="value before easing start")
# if None, then end value is used prior to easing end
post_end_value: Optional[float] = Field(default=None, description="value after easing end")
mirror: bool = Field(default=False, description="include mirror of easing function")
# FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely
# alt_mirror: bool = Field(default=False, description="alternative mirroring by dual easing")
show_easing_plot: bool = Field(default=False, description="show easing plot")
# fmt: on
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
log_diagnostics = False
# convert from start_step_percent to nearest step <= (steps * start_step_percent)
# start_step = int(np.floor(self.num_steps * self.start_step_percent))
start_step = int(np.round(self.num_steps * self.start_step_percent))
# convert from end_step_percent to nearest step >= (steps * end_step_percent)
# end_step = int(np.ceil((self.num_steps - 1) * self.end_step_percent))
end_step = int(np.round((self.num_steps - 1) * self.end_step_percent))
# end_step = int(np.ceil(self.num_steps * self.end_step_percent))
num_easing_steps = end_step - start_step + 1
# num_presteps = max(start_step - 1, 0)
num_presteps = start_step
num_poststeps = self.num_steps - (num_presteps + num_easing_steps)
prelist = list(num_presteps * [self.pre_start_value])
postlist = list(num_poststeps * [self.post_end_value])
if log_diagnostics:
logger = InvokeAILogger.getLogger(name="StepParamEasing")
logger.debug("start_step: " + str(start_step))
logger.debug("end_step: " + str(end_step))
logger.debug("num_easing_steps: " + str(num_easing_steps))
logger.debug("num_presteps: " + str(num_presteps))
logger.debug("num_poststeps: " + str(num_poststeps))
logger.debug("prelist size: " + str(len(prelist)))
logger.debug("postlist size: " + str(len(postlist)))
logger.debug("prelist: " + str(prelist))
logger.debug("postlist: " + str(postlist))
easing_class = EASING_FUNCTIONS_MAP[self.easing]
if log_diagnostics:
logger.debug("easing class: " + str(easing_class))
easing_list = list()
if self.mirror: # "expected" mirroring
# if number of steps is even, squeeze duration down to (number_of_steps)/2
# and create reverse copy of list to append
# if number of steps is odd, squeeze duration down to ceil(number_of_steps/2)
# and create reverse copy of list[1:end-1]
# but if even then number_of_steps/2 == ceil(number_of_steps/2), so can just use ceil always
base_easing_duration = int(np.ceil(num_easing_steps/2.0))
if log_diagnostics: logger.debug("base easing duration: " + str(base_easing_duration))
even_num_steps = (num_easing_steps % 2 == 0) # even number of steps
easing_function = easing_class(start=self.start_value,
end=self.end_value,
duration=base_easing_duration - 1)
base_easing_vals = list()
for step_index in range(base_easing_duration):
easing_val = easing_function.ease(step_index)
base_easing_vals.append(easing_val)
if log_diagnostics:
logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(easing_val))
if even_num_steps:
mirror_easing_vals = list(reversed(base_easing_vals))
else:
mirror_easing_vals = list(reversed(base_easing_vals[0:-1]))
if log_diagnostics:
logger.debug("base easing vals: " + str(base_easing_vals))
logger.debug("mirror easing vals: " + str(mirror_easing_vals))
easing_list = base_easing_vals + mirror_easing_vals
# FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely
# elif self.alt_mirror: # function mirroring (unintuitive behavior (at least to me))
# # half_ease_duration = round(num_easing_steps - 1 / 2)
# half_ease_duration = round((num_easing_steps - 1) / 2)
# easing_function = easing_class(start=self.start_value,
# end=self.end_value,
# duration=half_ease_duration,
# )
#
# mirror_function = easing_class(start=self.end_value,
# end=self.start_value,
# duration=half_ease_duration,
# )
# for step_index in range(num_easing_steps):
# if step_index <= half_ease_duration:
# step_val = easing_function.ease(step_index)
# else:
# step_val = mirror_function.ease(step_index - half_ease_duration)
# easing_list.append(step_val)
# if log_diagnostics: logger.debug(step_index, step_val)
#
else: # no mirroring (default)
easing_function = easing_class(start=self.start_value,
end=self.end_value,
duration=num_easing_steps - 1)
for step_index in range(num_easing_steps):
step_val = easing_function.ease(step_index)
easing_list.append(step_val)
if log_diagnostics:
logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(step_val))
if log_diagnostics:
logger.debug("prelist size: " + str(len(prelist)))
logger.debug("easing_list size: " + str(len(easing_list)))
logger.debug("postlist size: " + str(len(postlist)))
param_list = prelist + easing_list + postlist
if self.show_easing_plot:
plt.figure()
plt.xlabel("Step")
plt.ylabel("Param Value")
plt.title("Per-Step Values Based On Easing: " + self.easing)
plt.bar(range(len(param_list)), param_list)
# plt.plot(param_list)
ax = plt.gca()
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
buf = io.BytesIO()
plt.savefig(buf, format='png')
buf.seek(0)
im = PIL.Image.open(buf)
im.show()
buf.close()
# output array of size steps, each entry list[i] is param value for step i
return FloatCollectionOutput(
collection=param_list
)
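
A small worked sketch of the pre/eased/post partition computed above; the numbers are illustrative and a linear ramp stands in for the easing function:

import numpy as np

num_steps, start_pct, end_pct = 20, 0.25, 0.75
start_step = int(np.round(num_steps * start_pct))        # 5
end_step = int(np.round((num_steps - 1) * end_pct))      # 14
num_easing_steps = end_step - start_step + 1             # 10
prelist = start_step * [0.0]                             # pre_start_value
easing_list = list(np.linspace(0.0, 1.0, num_easing_steps))
postlist = (num_steps - len(prelist) - num_easing_steps) * [1.0]
param_list = prelist + easing_list + postlist
assert len(param_list) == num_steps                      # one value per step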

View File

@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, Union, List
from pydantic import BaseModel, Extra, Field, StrictFloat, StrictInt, StrictStr
@ -47,7 +47,9 @@ class ImageMetadata(BaseModel):
default=None, description="The seed used for noise generation."
)
"""The seed used for noise generation"""
cfg_scale: Optional[StrictFloat] = Field(
# cfg_scale: Optional[StrictFloat] = Field(
# cfg_scale: Union[float, list[float]] = Field(
cfg_scale: Union[StrictFloat, List[StrictFloat]] = Field(
default=None, description="The classifier-free guidance scale."
)
"""The classifier-free guidance scale"""

View File

@ -65,7 +65,6 @@ from typing import Optional, Union, List, get_args
def is_union_subtype(t1, t2):
t1_args = get_args(t1)
t2_args = get_args(t2)
if not t1_args:
# t1 is a single type
return t1 in t2_args
@ -86,7 +85,6 @@ def is_list_or_contains_list(t):
for arg in t_args:
if get_origin(arg) is list:
return True
return False
@ -393,7 +391,7 @@ class Graph(BaseModel):
from_node = self.get_node(edge.source.node_id)
to_node = self.get_node(edge.destination.node_id)
except NodeNotFoundError:
raise InvalidEdgeError("One or both nodes don't exist")
raise InvalidEdgeError("One or both nodes don't exist: {edge.source.node_id} -> {edge.destination.node_id}")
# Validate that an edge to this node+field doesn't already exist
input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field)
@ -404,41 +402,41 @@ class Graph(BaseModel):
g = self.nx_graph_flat()
g.add_edge(edge.source.node_id, edge.destination.node_id)
if not nx.is_directed_acyclic_graph(g):
raise InvalidEdgeError(f'Edge creates a cycle in the graph')
raise InvalidEdgeError(f'Edge creates a cycle in the graph: {edge.source.node_id} -> {edge.destination.node_id}')
# Validate that the field types are compatible
if not are_connections_compatible(
from_node, edge.source.field, to_node, edge.destination.field
):
raise InvalidEdgeError(f'Fields are incompatible')
raise InvalidEdgeError(f'Fields are incompatible: cannot connect {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
# Validate if iterator output type matches iterator input type (if this edge results in both being set)
if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection":
if not self._is_iterator_connection_valid(
edge.destination.node_id, new_input=edge.source
):
raise InvalidEdgeError(f'Iterator input type does not match iterator output type')
raise InvalidEdgeError(f'Iterator input type does not match iterator output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
# Validate if iterator input type matches output type (if this edge results in both being set)
if isinstance(from_node, IterateInvocation) and edge.source.field == "item":
if not self._is_iterator_connection_valid(
edge.source.node_id, new_output=edge.destination
):
raise InvalidEdgeError(f'Iterator output type does not match iterator input type')
raise InvalidEdgeError(f'Iterator output type does not match iterator input type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
# Validate if collector input type matches output type (if this edge results in both being set)
if isinstance(to_node, CollectInvocation) and edge.destination.field == "item":
if not self._is_collector_connection_valid(
edge.destination.node_id, new_input=edge.source
):
raise InvalidEdgeError(f'Collector output type does not match collector input type')
raise InvalidEdgeError(f'Collector output type does not match collector input type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
# Validate if collector output type matches input type (if this edge results in both being set)
if isinstance(from_node, CollectInvocation) and edge.source.field == "collection":
if not self._is_collector_connection_valid(
edge.source.node_id, new_output=edge.destination
):
raise InvalidEdgeError(f'Collector input type does not match collector output type')
raise InvalidEdgeError(f'Collector input type does not match collector output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
def has_node(self, node_path: str) -> bool:

View File

@ -6,7 +6,7 @@ import torch
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from diffusers.models.controlnet import ControlNetModel, ControlNetOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
from diffusers.pipelines.controlnet import MultiControlNetModel
from ..stable_diffusion import (
ConditioningData,

View File

@ -23,7 +23,7 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
StableDiffusionPipeline,
)
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
from diffusers.pipelines.controlnet import MultiControlNetModel
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import (
StableDiffusionImg2ImgPipeline,
@ -218,7 +218,7 @@ class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
class ControlNetData:
model: ControlNetModel = Field(default=None)
image_tensor: torch.Tensor= Field(default=None)
weight: float = Field(default=1.0)
weight: Union[float, List[float]]= Field(default=1.0)
begin_step_percent: float = Field(default=0.0)
end_step_percent: float = Field(default=1.0)
@ -226,7 +226,7 @@ class ControlNetData:
class ConditioningData:
unconditioned_embeddings: torch.Tensor
text_embeddings: torch.Tensor
guidance_scale: float
guidance_scale: Union[float, List[float]]
"""
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
@ -662,7 +662,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
down_block_res_samples, mid_block_res_sample = None, None
if control_data is not None:
if conditioning_data.guidance_scale > 1.0:
# FIXME: make sure guidance_scale < 1.0 is handled correctly if doing per-step guidance setting
# if conditioning_data.guidance_scale > 1.0:
if conditioning_data.guidance_scale is not None:
# expand the latents input to control model if doing classifier free guidance
# (which I think is always true for now; there is a conditional elsewhere that stops execution if
# classifier_free_guidance is <= 1.0?)
@ -679,13 +681,19 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# only apply controlnet if current step is within the controlnet's begin/end step range
if step_index >= first_control_step and step_index <= last_control_step:
# print("running controlnet", i, "for step", step_index)
if isinstance(control_datum.weight, list):
# if controlnet has multiple weights, use the weight for the current step
controlnet_weight = control_datum.weight[step_index]
else:
# if controlnet has a single weight, use it for all steps
controlnet_weight = control_datum.weight
down_samples, mid_sample = control_datum.model(
sample=latent_control_input,
timestep=timestep,
encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings,
conditioning_data.text_embeddings]),
controlnet_cond=control_datum.image_tensor,
conditioning_scale=control_datum.weight,
conditioning_scale=controlnet_weight,
# cross_attention_kwargs,
guess_mode=False,
return_dict=False,
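
The weight lookup above reduces to a small helper; a sketch (resolve_weight is illustrative, not a function in this commit):

from typing import List, Union

def resolve_weight(weight: Union[float, List[float]], step_index: int) -> float:
    # a list holds one weight per denoising step; a scalar applies to every step
    return weight[step_index] if isinstance(weight, list) else weight

assert resolve_weight(0.8, 3) == 0.8
assert resolve_weight([0.0, 0.5, 1.0], 1) == 0.5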

View File

@ -1,7 +1,7 @@
from contextlib import contextmanager
from dataclasses import dataclass
from math import ceil
from typing import Any, Callable, Dict, Optional, Union
from typing import Any, Callable, Dict, Optional, Union, List
import numpy as np
import torch
@ -180,7 +180,8 @@ class InvokeAIDiffuserComponent:
sigma: torch.Tensor,
unconditioning: Union[torch.Tensor, dict],
conditioning: Union[torch.Tensor, dict],
unconditional_guidance_scale: float,
# unconditional_guidance_scale: float,
unconditional_guidance_scale: Union[float, List[float]],
step_index: Optional[int] = None,
total_step_count: Optional[int] = None,
**kwargs,
@ -195,6 +196,11 @@ class InvokeAIDiffuserComponent:
:return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning.
"""
if isinstance(unconditional_guidance_scale, list):
guidance_scale = unconditional_guidance_scale[step_index]
else:
guidance_scale = unconditional_guidance_scale
cross_attention_control_types_to_do = []
context: Context = self.cross_attention_control_context
if self.cross_attention_control_context is not None:
@ -243,7 +249,8 @@ class InvokeAIDiffuserComponent:
)
combined_next_x = self._combine(
unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale
# unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale
unconditioned_next_x, conditioned_next_x, guidance_scale
)
return combined_next_x
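
For context, the standard classifier-free-guidance combination that _combine applies, now fed the per-step guidance_scale selected above (sketch; signature assumed from usage):

import torch

def combine(unconditioned_next_x: torch.Tensor,
            conditioned_next_x: torch.Tensor,
            guidance_scale: float) -> torch.Tensor:
    # standard CFG: uncond + scale * (cond - uncond)
    return unconditioned_next_x + guidance_scale * (conditioned_next_x - unconditioned_next_x)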
@ -497,7 +504,7 @@ class InvokeAIDiffuserComponent:
logger.debug(
f"min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}"
)
logger.debug(
f"{outside / latents.numel() * 100:.2f}% values outside threshold"
)

View File

@ -1,3 +1,7 @@
import {
CONTROLNET_MODELS,
CONTROLNET_PROCESSORS,
} from 'features/controlNet/store/constants';
import { InvokeTabName } from 'features/ui/store/tabMap';
import { O } from 'ts-toolbelt';
@ -117,6 +121,8 @@ export type AppConfig = {
canRestoreDeletedImagesFromBin: boolean;
sd: {
defaultModel?: string;
disabledControlNetModels: (keyof typeof CONTROLNET_MODELS)[];
disabledControlNetProcessors: (keyof typeof CONTROLNET_PROCESSORS)[];
iterations: {
initial: number;
min: number;

View File

@ -18,7 +18,7 @@ import { useSelect } from 'downshift';
import { isString } from 'lodash-es';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import { memo, useMemo } from 'react';
import { memo, useLayoutEffect, useMemo } from 'react';
import { getInputOutlineStyles } from 'theme/util/getInputOutlineStyles';
export type ItemTooltips = { [key: string]: string };
@ -39,6 +39,7 @@ type IAICustomSelectProps = {
tooltip?: string;
tooltipProps?: Omit<TooltipProps, 'children'>;
ellipsisPosition?: 'start' | 'end';
isDisabled?: boolean;
};
const IAICustomSelect = (props: IAICustomSelectProps) => {
@ -52,6 +53,7 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
data,
value,
onChange,
isDisabled = false,
} = props;
const values = useMemo(() => {
@ -86,11 +88,17 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
},
});
const { refs, floatingStyles } = useFloating<HTMLButtonElement>({
whileElementsMounted: autoUpdate,
const { refs, floatingStyles, update } = useFloating<HTMLButtonElement>({
// whileElementsMounted: autoUpdate,
middleware: [offset(4), shift({ crossAxis: true, padding: 8 })],
});
useLayoutEffect(() => {
if (isOpen && refs.reference.current && refs.floating.current) {
return autoUpdate(refs.reference.current, refs.floating.current, update);
}
}, [isOpen, update, refs.floating, refs.reference]);
const labelTextDirection = useMemo(() => {
if (ellipsisPosition === 'start') {
return document.dir === 'rtl' ? 'ltr' : 'rtl';
@ -124,6 +132,8 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
px: 2,
gap: 2,
justifyContent: 'space-between',
pointerEvents: isDisabled ? 'none' : undefined,
opacity: isDisabled ? 0.5 : undefined,
...getInputOutlineStyles(),
}}
>

View File

@ -1,4 +1,5 @@
import {
ChakraProps,
FormControl,
FormControlProps,
FormLabel,
@ -39,6 +40,11 @@ import { BiReset } from 'react-icons/bi';
import IAIIconButton, { IAIIconButtonProps } from './IAIIconButton';
import { roundDownToMultiple } from 'common/util/roundDownToMultiple';
const SLIDER_MARK_STYLES: ChakraProps['sx'] = {
mt: 1.5,
fontSize: '2xs',
};
export type IAIFullSliderProps = {
label?: string;
value: number;
@ -57,6 +63,7 @@ export type IAIFullSliderProps = {
hideTooltip?: boolean;
isCompact?: boolean;
isDisabled?: boolean;
sliderMarks?: number[];
sliderFormControlProps?: FormControlProps;
sliderFormLabelProps?: FormLabelProps;
sliderMarkProps?: Omit<SliderMarkProps, 'value'>;
@ -88,6 +95,7 @@ const IAISlider = (props: IAIFullSliderProps) => {
hideTooltip = false,
isCompact = false,
isDisabled = false,
sliderMarks,
handleReset,
sliderFormControlProps,
sliderFormLabelProps,
@ -198,14 +206,14 @@ const IAISlider = (props: IAIFullSliderProps) => {
isDisabled={isDisabled}
{...rest}
>
{withSliderMarks && (
{withSliderMarks && !sliderMarks && (
<>
<SliderMark
value={min}
sx={{
insetInlineStart: '0 !important',
insetInlineEnd: 'unset !important',
mt: 1.5,
...SLIDER_MARK_STYLES,
}}
{...sliderMarkProps}
>
@ -216,7 +224,7 @@ const IAISlider = (props: IAIFullSliderProps) => {
sx={{
insetInlineStart: 'unset !important',
insetInlineEnd: '0 !important',
mt: 1.5,
...SLIDER_MARK_STYLES,
}}
{...sliderMarkProps}
>
@ -224,6 +232,56 @@ const IAISlider = (props: IAIFullSliderProps) => {
</SliderMark>
</>
)}
{withSliderMarks && sliderMarks && (
<>
{sliderMarks.map((m, i) => {
if (i === 0) {
return (
<SliderMark
key={m}
value={m}
sx={{
insetInlineStart: '0 !important',
insetInlineEnd: 'unset !important',
...SLIDER_MARK_STYLES,
}}
{...sliderMarkProps}
>
{m}
</SliderMark>
);
} else if (i === sliderMarks.length - 1) {
return (
<SliderMark
key={m}
value={m}
sx={{
insetInlineStart: 'unset !important',
insetInlineEnd: '0 !important',
...SLIDER_MARK_STYLES,
}}
{...sliderMarkProps}
>
{m}
</SliderMark>
);
} else {
return (
<SliderMark
key={m}
value={m}
sx={{
...SLIDER_MARK_STYLES,
}}
{...sliderMarkProps}
>
{m}
</SliderMark>
);
}
})}
</>
)}
<SliderTrack {...sliderTrackProps}>
<SliderFilledTrack />

View File

@ -143,7 +143,7 @@ const ControlNet = (props: ControlNetProps) => {
flexDir: 'column',
gap: 2,
w: 'full',
h: 24,
h: isExpanded ? 28 : 24,
paddingInlineStart: 1,
paddingInlineEnd: isExpanded ? 1 : 0,
pb: 2,
@ -153,13 +153,13 @@ const ControlNet = (props: ControlNetProps) => {
<ParamControlNetWeight
controlNetId={controlNetId}
weight={weight}
mini
mini={!isExpanded}
/>
<ParamControlNetBeginEnd
controlNetId={controlNetId}
beginStepPct={beginStepPct}
endStepPct={endStepPct}
mini
mini={!isExpanded}
/>
</Flex>
{!isExpanded && (

View File

@ -1,5 +1,6 @@
import { useAppDispatch } from 'app/store/storeHooks';
import IAISwitch from 'common/components/IAISwitch';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import { controlNetAutoConfigToggled } from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react';
@ -11,7 +12,7 @@ type Props = {
const ParamControlNetShouldAutoConfig = (props: Props) => {
const { controlNetId, shouldAutoConfig } = props;
const dispatch = useAppDispatch();
const isReady = useIsReadyToInvoke();
const handleShouldAutoConfigChanged = useCallback(() => {
dispatch(controlNetAutoConfigToggled({ controlNetId }));
}, [controlNetId, dispatch]);
@ -22,6 +23,7 @@ const ParamControlNetShouldAutoConfig = (props: Props) => {
aria-label="Auto configure processor"
isChecked={shouldAutoConfig}
onChange={handleShouldAutoConfigChanged}
isDisabled={!isReady}
/>
);
};

View File

@ -1,4 +1,5 @@
import {
ChakraProps,
FormControl,
FormLabel,
HStack,
@ -10,14 +11,19 @@ import {
Tooltip,
} from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import {
controlNetBeginStepPctChanged,
controlNetEndStepPctChanged,
} from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { BiReset } from 'react-icons/bi';
const SLIDER_MARK_STYLES: ChakraProps['sx'] = {
mt: 1.5,
fontSize: '2xs',
fontWeight: '500',
color: 'base.400',
};
type Props = {
controlNetId: string;
@ -29,7 +35,7 @@ type Props = {
const formatPct = (v: number) => `${Math.round(v * 100)}%`;
const ParamControlNetBeginEnd = (props: Props) => {
const { controlNetId, beginStepPct, endStepPct, mini = false } = props;
const { controlNetId, beginStepPct, mini = false, endStepPct } = props;
const dispatch = useAppDispatch();
const { t } = useTranslation();
@ -75,12 +81,9 @@ const ParamControlNetBeginEnd = (props: Props) => {
<RangeSliderMark
value={0}
sx={{
fontSize: 'xs',
fontWeight: '500',
color: 'base.200',
insetInlineStart: '0 !important',
insetInlineEnd: 'unset !important',
mt: 1.5,
...SLIDER_MARK_STYLES,
}}
>
0%
@ -88,10 +91,7 @@ const ParamControlNetBeginEnd = (props: Props) => {
<RangeSliderMark
value={0.5}
sx={{
fontSize: 'xs',
fontWeight: '500',
color: 'base.200',
mt: 1.5,
...SLIDER_MARK_STYLES,
}}
>
50%
@ -99,12 +99,9 @@ const ParamControlNetBeginEnd = (props: Props) => {
<RangeSliderMark
value={1}
sx={{
fontSize: 'xs',
fontWeight: '500',
color: 'base.200',
insetInlineStart: 'unset !important',
insetInlineEnd: '0 !important',
mt: 1.5,
...SLIDER_MARK_STYLES,
}}
>
100%
@ -112,16 +109,6 @@ const ParamControlNetBeginEnd = (props: Props) => {
</>
)}
</RangeSlider>
{!mini && (
<IAIIconButton
size="sm"
aria-label={t('accessibility.reset')}
tooltip={t('accessibility.reset')}
icon={<BiReset />}
onClick={handleStepPctReset}
/>
)}
</HStack>
</FormControl>
);

View File

@ -1,50 +1,85 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAICustomSelect, {
IAICustomSelectOption,
} from 'common/components/IAICustomSelect';
import IAISelect from 'common/components/IAISelect';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import {
CONTROLNET_MODELS,
ControlNetModelName,
} from 'features/controlNet/store/constants';
import { controlNetModelChanged } from 'features/controlNet/store/controlNetSlice';
import { configSelector } from 'features/system/store/configSelectors';
import { map } from 'lodash-es';
import { memo, useCallback } from 'react';
import { ChangeEvent, memo, useCallback } from 'react';
type ParamControlNetModelProps = {
controlNetId: string;
model: ControlNetModelName;
};
const DATA: IAICustomSelectOption[] = map(CONTROLNET_MODELS, (m) => ({
value: m.type,
label: m.label,
tooltip: m.type,
}));
const selector = createSelector(configSelector, (config) => {
return map(CONTROLNET_MODELS, (m) => ({
key: m.label,
value: m.type,
})).filter((d) => !config.sd.disabledControlNetModels.includes(d.value));
});
// const DATA: IAICustomSelectOption[] = map(CONTROLNET_MODELS, (m) => ({
// value: m.type,
// label: m.label,
// tooltip: m.type,
// }));
const ParamControlNetModel = (props: ParamControlNetModelProps) => {
const { controlNetId, model } = props;
const controlNetModels = useAppSelector(selector);
const dispatch = useAppDispatch();
const isReady = useIsReadyToInvoke();
const handleModelChanged = useCallback(
(val: string | null | undefined) => {
(e: ChangeEvent<HTMLSelectElement>) => {
// TODO: do not cast
const model = val as ControlNetModelName;
const model = e.target.value as ControlNetModelName;
dispatch(controlNetModelChanged({ controlNetId, model }));
},
[controlNetId, dispatch]
);
// const handleModelChanged = useCallback(
// (val: string | null | undefined) => {
// // TODO: do not cast
// const model = val as ControlNetModelName;
// dispatch(controlNetModelChanged({ controlNetId, model }));
// },
// [controlNetId, dispatch]
// );
return (
<IAICustomSelect
<IAISelect
tooltip={model}
tooltipProps={{ placement: 'top', hasArrow: true }}
data={DATA}
validValues={controlNetModels}
value={model}
onChange={handleModelChanged}
ellipsisPosition="start"
withCheckIcon
isDisabled={!isReady}
// ellipsisPosition="start"
// withCheckIcon
/>
);
// return (
// <IAICustomSelect
// tooltip={model}
// tooltipProps={{ placement: 'top', hasArrow: true }}
// data={DATA}
// value={model}
// onChange={handleModelChanged}
// isDisabled={!isReady}
// ellipsisPosition="start"
// withCheckIcon
// />
// );
};
export default memo(ParamControlNetModel);

View File

@ -1,62 +1,115 @@
import IAICustomSelect, {
IAICustomSelectOption,
} from 'common/components/IAICustomSelect';
import { memo, useCallback } from 'react';
import { ChangeEvent, memo, useCallback } from 'react';
import {
ControlNetProcessorNode,
ControlNetProcessorType,
} from '../../store/types';
import { controlNetProcessorTypeChanged } from '../../store/controlNetSlice';
import { useAppDispatch } from 'app/store/storeHooks';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { CONTROLNET_PROCESSORS } from '../../store/constants';
import { map } from 'lodash-es';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import IAISelect from 'common/components/IAISelect';
import { createSelector } from '@reduxjs/toolkit';
import { configSelector } from 'features/system/store/configSelectors';
type ParamControlNetProcessorSelectProps = {
controlNetId: string;
processorNode: ControlNetProcessorNode;
};
const CONTROLNET_PROCESSOR_TYPES: IAICustomSelectOption[] = map(
CONTROLNET_PROCESSORS,
(p) => ({
value: p.type,
label: p.label,
tooltip: p.description,
})
).sort((a, b) =>
const CONTROLNET_PROCESSOR_TYPES = map(CONTROLNET_PROCESSORS, (p) => ({
value: p.type,
key: p.label,
})).sort((a, b) =>
// sort 'none' to the top
a.value === 'none'
? -1
: b.value === 'none'
? 1
: a.label.localeCompare(b.label)
a.value === 'none' ? -1 : b.value === 'none' ? 1 : a.key.localeCompare(b.key)
);
const selector = createSelector(configSelector, (config) => {
return map(CONTROLNET_PROCESSORS, (p) => ({
value: p.type,
key: p.label,
}))
.sort((a, b) =>
// sort 'none' to the top
a.value === 'none'
? -1
: b.value === 'none'
? 1
: a.key.localeCompare(b.key)
)
.filter((d) => !config.sd.disabledControlNetProcessors.includes(d.value));
});
// const CONTROLNET_PROCESSOR_TYPES: IAICustomSelectOption[] = map(
// CONTROLNET_PROCESSORS,
// (p) => ({
// value: p.type,
// label: p.label,
// tooltip: p.description,
// })
// ).sort((a, b) =>
// // sort 'none' to the top
// a.value === 'none'
// ? -1
// : b.value === 'none'
// ? 1
// : a.label.localeCompare(b.label)
// );
const ParamControlNetProcessorSelect = (
props: ParamControlNetProcessorSelectProps
) => {
const { controlNetId, processorNode } = props;
const dispatch = useAppDispatch();
const isReady = useIsReadyToInvoke();
const controlNetProcessors = useAppSelector(selector);
const handleProcessorTypeChanged = useCallback(
(v: string | null | undefined) => {
(e: ChangeEvent<HTMLSelectElement>) => {
dispatch(
controlNetProcessorTypeChanged({
controlNetId,
processorType: v as ControlNetProcessorType,
processorType: e.target.value as ControlNetProcessorType,
})
);
},
[controlNetId, dispatch]
);
// const handleProcessorTypeChanged = useCallback(
// (v: string | null | undefined) => {
// dispatch(
// controlNetProcessorTypeChanged({
// controlNetId,
// processorType: v as ControlNetProcessorType,
// })
// );
// },
// [controlNetId, dispatch]
// );
return (
<IAICustomSelect
<IAISelect
label="Processor"
value={processorNode.type ?? 'canny_image_processor'}
data={CONTROLNET_PROCESSOR_TYPES}
validValues={controlNetProcessors}
onChange={handleProcessorTypeChanged}
withCheckIcon
isDisabled={!isReady}
/>
);
// return (
// <IAICustomSelect
// label="Processor"
// value={processorNode.type ?? 'canny_image_processor'}
// data={CONTROLNET_PROCESSOR_TYPES}
// onChange={handleProcessorTypeChanged}
// withCheckIcon
// isDisabled={!isReady}
// />
// );
};
export default memo(ParamControlNetProcessorSelect);

View File

@ -20,36 +20,17 @@ const ParamControlNetWeight = (props: ParamControlNetWeightProps) => {
[controlNetId, dispatch]
);
const handleWeightReset = () => {
dispatch(controlNetWeightChanged({ controlNetId, weight: 1 }));
};
if (mini) {
return (
<IAISlider
label={'Weight'}
sliderFormLabelProps={{ pb: 1 }}
value={weight}
onChange={handleWeightChanged}
min={0}
max={1}
step={0.01}
/>
);
}
return (
<IAISlider
label="Weight"
label={'Weight'}
sliderFormLabelProps={{ pb: 2 }}
value={weight}
onChange={handleWeightChanged}
withInput
withReset
handleReset={handleWeightReset}
withSliderMarks
min={0}
min={-1}
max={1}
step={0.01}
withSliderMarks={!mini}
sliderMarks={[-1, 0, 1]}
/>
);
};

View File

@ -4,6 +4,7 @@ import { RequiredCannyImageProcessorInvocation } from 'features/controlNet/store
import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.canny_image_processor.default;
@ -15,6 +16,7 @@ type CannyProcessorProps = {
const CannyProcessor = (props: CannyProcessorProps) => {
const { controlNetId, processorNode } = props;
const { low_threshold, high_threshold } = processorNode;
const isReady = useIsReadyToInvoke();
const processorChanged = useProcessorNodeChanged();
const handleLowThresholdChanged = useCallback(
@ -46,6 +48,7 @@ const CannyProcessor = (props: CannyProcessorProps) => {
return (
<ProcessorWrapper>
<IAISlider
isDisabled={!isReady}
label="Low Threshold"
value={low_threshold}
onChange={handleLowThresholdChanged}
@ -54,8 +57,10 @@ const CannyProcessor = (props: CannyProcessorProps) => {
min={0}
max={255}
withInput
withSliderMarks
/>
<IAISlider
isDisabled={!isReady}
label="High Threshold"
value={high_threshold}
onChange={handleHighThresholdChanged}
@ -64,6 +69,7 @@ const CannyProcessor = (props: CannyProcessorProps) => {
min={0}
max={255}
withInput
withSliderMarks
/>
</ProcessorWrapper>
);

View File

@ -4,6 +4,7 @@ import { RequiredContentShuffleImageProcessorInvocation } from 'features/control
import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.content_shuffle_image_processor.default;
@ -16,6 +17,7 @@ const ContentShuffleProcessor = (props: Props) => {
const { controlNetId, processorNode } = props;
const { image_resolution, detect_resolution, w, h, f } = processorNode;
const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke();
const handleDetectResolutionChanged = useCallback(
(v: number) => {
@ -93,6 +95,8 @@ const ContentShuffleProcessor = (props: Props) => {
min={0}
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
/>
<IAISlider
label="Image Resolution"
@ -103,6 +107,8 @@ const ContentShuffleProcessor = (props: Props) => {
min={0}
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
/>
<IAISlider
label="W"
@ -113,6 +119,8 @@ const ContentShuffleProcessor = (props: Props) => {
min={0}
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
/>
<IAISlider
label="H"
@ -123,6 +131,8 @@ const ContentShuffleProcessor = (props: Props) => {
min={0}
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
/>
<IAISlider
label="F"
@ -133,6 +143,8 @@ const ContentShuffleProcessor = (props: Props) => {
min={0}
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
/>
</ProcessorWrapper>
);

View File

@ -5,6 +5,7 @@ import { RequiredHedImageProcessorInvocation } from 'features/controlNet/store/t
import { ChangeEvent, memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.hed_image_processor.default;
@ -18,7 +19,7 @@ const HedPreprocessor = (props: HedProcessorProps) => {
controlNetId,
processorNode: { detect_resolution, image_resolution, scribble },
} = props;
const isReady = useIsReadyToInvoke();
const processorChanged = useProcessorNodeChanged();
const handleDetectResolutionChanged = useCallback(
@ -65,6 +66,8 @@ const HedPreprocessor = (props: HedProcessorProps) => {
min={0}
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
/>
<IAISlider
label="Image Resolution"
@ -75,11 +78,14 @@ const HedPreprocessor = (props: HedProcessorProps) => {
min={0}
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
/>
<IAISwitch
label="Scribble"
isChecked={scribble}
onChange={handleScribbleChanged}
isDisabled={!isReady}
/>
</ProcessorWrapper>
);

View File

@ -4,6 +4,7 @@ import { RequiredLineartAnimeImageProcessorInvocation } from 'features/controlNe
import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.lineart_anime_image_processor.default;
@ -16,6 +17,7 @@ const LineartAnimeProcessor = (props: Props) => {
const { controlNetId, processorNode } = props;
const { image_resolution, detect_resolution } = processorNode;
const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke();
const handleDetectResolutionChanged = useCallback(
(v: number) => {
@ -54,6 +56,8 @@ const LineartAnimeProcessor = (props: Props) => {
min={0}
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
/>
<IAISlider
label="Image Resolution"
@ -64,6 +68,8 @@ const LineartAnimeProcessor = (props: Props) => {
min={0}
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
/>
</ProcessorWrapper>
);

View File

@ -5,6 +5,7 @@ import { RequiredLineartImageProcessorInvocation } from 'features/controlNet/sto
import { ChangeEvent, memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.lineart_image_processor.default;
@ -17,6 +18,7 @@ const LineartProcessor = (props: LineartProcessorProps) => {
const { controlNetId, processorNode } = props;
const { image_resolution, detect_resolution, coarse } = processorNode;
const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke();
const handleDetectResolutionChanged = useCallback(
(v: number) => {
@ -62,6 +64,8 @@ const LineartProcessor = (props: LineartProcessorProps) => {
min={0}
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
/>
<IAISlider
label="Image Resolution"
@ -72,11 +76,14 @@ const LineartProcessor = (props: LineartProcessorProps) => {
min={0}
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
/>
<IAISwitch
label="Coarse"
isChecked={coarse}
onChange={handleCoarseChanged}
isDisabled={!isReady}
/>
</ProcessorWrapper>
);

View File

@ -4,6 +4,7 @@ import { RequiredMediapipeFaceProcessorInvocation } from 'features/controlNet/st
import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.mediapipe_face_processor.default;
@ -16,6 +17,7 @@ const MediapipeFaceProcessor = (props: Props) => {
const { controlNetId, processorNode } = props;
const { max_faces, min_confidence } = processorNode;
const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke();
const handleMaxFacesChanged = useCallback(
(v: number) => {
@ -50,6 +52,8 @@ const MediapipeFaceProcessor = (props: Props) => {
min={1}
max={20}
withInput
withSliderMarks
isDisabled={!isReady}
/>
<IAISlider
label="Min Confidence"
@ -61,6 +65,8 @@ const MediapipeFaceProcessor = (props: Props) => {
max={1}
step={0.01}
withInput
withSliderMarks
isDisabled={!isReady}
/>
</ProcessorWrapper>
);

View File

@ -4,6 +4,7 @@ import { RequiredMidasDepthImageProcessorInvocation } from 'features/controlNet/
import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.midas_depth_image_processor.default;
@ -16,6 +17,7 @@ const MidasDepthProcessor = (props: Props) => {
const { controlNetId, processorNode } = props;
const { a_mult, bg_th } = processorNode;
const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke();
const handleAMultChanged = useCallback(
(v: number) => {
@ -51,6 +53,8 @@ const MidasDepthProcessor = (props: Props) => {
max={20}
step={0.01}
withInput
withSliderMarks
isDisabled={!isReady}
/>
<IAISlider
label="bg_th"
@ -62,6 +66,8 @@ const MidasDepthProcessor = (props: Props) => {
max={20}
step={0.01}
withInput
withSliderMarks
isDisabled={!isReady}
/>
</ProcessorWrapper>
);

View File

@ -4,6 +4,7 @@ import { RequiredMlsdImageProcessorInvocation } from 'features/controlNet/store/
import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.mlsd_image_processor.default;
@ -16,6 +17,7 @@ const MlsdImageProcessor = (props: Props) => {
const { controlNetId, processorNode } = props;
const { image_resolution, detect_resolution, thr_d, thr_v } = processorNode;
const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke();
const handleDetectResolutionChanged = useCallback(
(v: number) => {
@ -76,6 +78,8 @@ const MlsdImageProcessor = (props: Props) => {
min={0}
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
/>
<IAISlider
label="Image Resolution"
@ -86,6 +90,8 @@ const MlsdImageProcessor = (props: Props) => {
min={0}
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
/>
<IAISlider
label="W"
@ -97,6 +103,8 @@ const MlsdImageProcessor = (props: Props) => {
max={1}
step={0.01}
withInput
withSliderMarks
isDisabled={!isReady}
/>
<IAISlider
label="H"
@ -108,6 +116,8 @@ const MlsdImageProcessor = (props: Props) => {
max={1}
step={0.01}
withInput
withSliderMarks
isDisabled={!isReady}
/>
</ProcessorWrapper>
);

View File

@ -4,6 +4,7 @@ import { RequiredNormalbaeImageProcessorInvocation } from 'features/controlNet/s
import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.normalbae_image_processor.default;
@ -16,6 +17,7 @@ const NormalBaeProcessor = (props: Props) => {
const { controlNetId, processorNode } = props;
const { image_resolution, detect_resolution } = processorNode;
const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke();
const handleDetectResolutionChanged = useCallback(
(v: number) => {
@ -54,6 +56,8 @@ const NormalBaeProcessor = (props: Props) => {
min={0}
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
/>
<IAISlider
label="Image Resolution"
@ -64,6 +68,8 @@ const NormalBaeProcessor = (props: Props) => {
min={0}
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
/>
</ProcessorWrapper>
);

View File

@ -5,6 +5,7 @@ import { RequiredOpenposeImageProcessorInvocation } from 'features/controlNet/st
import { ChangeEvent, memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.openpose_image_processor.default;
@ -17,6 +18,7 @@ const OpenposeProcessor = (props: Props) => {
const { controlNetId, processorNode } = props;
const { image_resolution, detect_resolution, hand_and_face } = processorNode;
const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke();
const handleDetectResolutionChanged = useCallback(
(v: number) => {
@ -62,6 +64,8 @@ const OpenposeProcessor = (props: Props) => {
min={0}
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
/>
<IAISlider
label="Image Resolution"
@ -72,11 +76,14 @@ const OpenposeProcessor = (props: Props) => {
min={0}
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
/>
<IAISwitch
label="Hand and Face"
isChecked={hand_and_face}
onChange={handleHandAndFaceChanged}
isDisabled={!isReady}
/>
</ProcessorWrapper>
);

View File

@ -5,6 +5,7 @@ import { RequiredPidiImageProcessorInvocation } from 'features/controlNet/store/
import { ChangeEvent, memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.pidi_image_processor.default;
@ -17,6 +18,7 @@ const PidiProcessor = (props: Props) => {
const { controlNetId, processorNode } = props;
const { image_resolution, detect_resolution, scribble, safe } = processorNode;
const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke();
const handleDetectResolutionChanged = useCallback(
(v: number) => {
@ -69,6 +71,8 @@ const PidiProcessor = (props: Props) => {
min={0}
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
/>
<IAISlider
label="Image Resolution"
@ -79,13 +83,20 @@ const PidiProcessor = (props: Props) => {
min={0}
max={4096}
withInput
withSliderMarks
isDisabled={!isReady}
/>
<IAISwitch
label="Scribble"
isChecked={scribble}
onChange={handleScribbleChanged}
/>
<IAISwitch label="Safe" isChecked={safe} onChange={handleSafeChanged} />
<IAISwitch
label="Safe"
isChecked={safe}
onChange={handleSafeChanged}
isDisabled={!isReady}
/>
</ProcessorWrapper>
);
};

View File

@ -23,7 +23,7 @@ type ControlNetProcessorsDict = Record<
*
* TODO: Generate from the OpenAPI schema
*/
export const CONTROLNET_PROCESSORS: ControlNetProcessorsDict = {
export const CONTROLNET_PROCESSORS = {
none: {
type: 'none',
label: 'none',
@ -129,7 +129,7 @@ export const CONTROLNET_PROCESSORS: ControlNetProcessorsDict = {
},
normalbae_image_processor: {
type: 'normalbae_image_processor',
label: 'NormalBae',
label: 'Normal BAE',
description: '',
default: {
id: 'normalbae_image_processor',
@ -181,7 +181,7 @@ type ControlNetModel = {
defaultProcessor?: ControlNetProcessorType;
};
export const CONTROLNET_MODELS: Record<string, ControlNetModel> = {
export const CONTROLNET_MODELS = {
'lllyasviel/control_v11p_sd15_canny': {
type: 'lllyasviel/control_v11p_sd15_canny',
label: 'Canny',
@ -208,7 +208,7 @@ export const CONTROLNET_MODELS: Record<string, ControlNetModel> = {
},
'lllyasviel/control_v11p_sd15_seg': {
type: 'lllyasviel/control_v11p_sd15_seg',
label: 'Segment Anything',
label: 'Segmentation',
},
'lllyasviel/control_v11p_sd15_lineart': {
type: 'lllyasviel/control_v11p_sd15_lineart',

View File

@ -21,6 +21,8 @@ export const FIELD_TYPE_MAP: Record<string, FieldType> = {
ColorField: 'color',
ControlField: 'control',
control: 'control',
cfg_scale: 'float',
control_weight: 'float',
};
const COLOR_TOKEN_VALUE = 500;
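
The two new entries map snake_case graph field names (cfg_scale, control_weight) to UI field types by name rather than by OpenAPI type. A lookup consuming this map might be sketched as follows; getFieldType, its fallback chain, and the import path are illustrative assumptions:

import { FIELD_TYPE_MAP } from 'features/nodes/types/constants';

// Illustrative resolver: prefer a mapping keyed by field name, then by
// the OpenAPI type name, and fall back to treating the field as a string.
const getFieldType = (fieldName: string, openAPIType: string): string =>
  FIELD_TYPE_MAP[fieldName] ?? FIELD_TYPE_MAP[openAPIType] ?? 'string';

// e.g. getFieldType('control_weight', 'number') resolves to 'float'.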

View File

@ -1,5 +1,5 @@
import { RootState } from 'app/store/store';
import { forEach, size } from 'lodash-es';
import { filter, forEach, size } from 'lodash-es';
import { CollectInvocation, ControlNetInvocation } from 'services/api';
import { NonNullableGraph } from '../types/types';
@ -12,8 +12,16 @@ export const addControlNetToLinearGraph = (
): void => {
const { isEnabled: isControlNetEnabled, controlNets } = state.controlNet;
const validControlNets = filter(
controlNets,
(c) =>
c.isEnabled &&
(Boolean(c.processedControlImage) ||
(c.processorType === 'none' && Boolean(c.controlImage)))
);
// Add ControlNet
if (isControlNetEnabled) {
if (isControlNetEnabled && validControlNets.length > 0) {
if (size(controlNets) > 1) {
const controlNetIterateNode: CollectInvocation = {
id: CONTROL_NET_COLLECT,
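
The new guard builds ControlNet nodes only when at least one ControlNet is enabled and has a usable image: a processed image, or a raw control image when the processor is 'none'. The predicate in isolation, with the ControlNet type trimmed to just the fields the filter reads:

import { filter } from 'lodash-es';

// Trimmed to the fields the predicate reads; the real type has more.
type ControlNetLike = {
  isEnabled: boolean;
  processorType: string;
  controlImage: string | null;
  processedControlImage: string | null;
};

// lodash filter over an object collection returns an array of its values.
const getValidControlNets = (controlNets: Record<string, ControlNetLike>) =>
  filter(
    controlNets,
    (c) =>
      c.isEnabled &&
      (Boolean(c.processedControlImage) ||
        (c.processorType === 'none' && Boolean(c.controlImage)))
  );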

View File

@ -3,10 +3,11 @@ import { Scheduler } from 'app/constants';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAICustomSelect from 'common/components/IAICustomSelect';
import IAISelect from 'common/components/IAISelect';
import { generationSelector } from 'features/parameters/store/generationSelectors';
import { setScheduler } from 'features/parameters/store/generationSlice';
import { uiSelector } from 'features/ui/store/uiSelectors';
import { memo, useCallback } from 'react';
import { ChangeEvent, memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
const selector = createSelector(
@ -35,24 +36,39 @@ const ParamScheduler = () => {
const { t } = useTranslation();
const handleChange = useCallback(
(v: string | null | undefined) => {
if (!v) {
return;
}
dispatch(setScheduler(v as Scheduler));
(e: ChangeEvent<HTMLSelectElement>) => {
dispatch(setScheduler(e.target.value as Scheduler));
},
[dispatch]
);
// const handleChange = useCallback(
// (v: string | null | undefined) => {
// if (!v) {
// return;
// }
// dispatch(setScheduler(v as Scheduler));
// },
// [dispatch]
// );
return (
<IAICustomSelect
<IAISelect
label={t('parameters.scheduler')}
value={scheduler}
data={allSchedulers}
validValues={allSchedulers}
onChange={handleChange}
withCheckIcon
/>
);
// return (
// <IAICustomSelect
// label={t('parameters.scheduler')}
// value={scheduler}
// data={allSchedulers}
// onChange={handleChange}
// withCheckIcon
// />
// );
};
export default memo(ParamScheduler);
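
Both select swaps in this commit, here and in ModelSelect below, move from a value callback to the native select change event: the handler reads e.target.value and narrows it to the store's union type. The handler contract in isolation; makeHandleChange and the Scheduler subset shown are illustrative:

import { ChangeEvent } from 'react';

type Scheduler = 'ddim' | 'euler' | 'lms'; // illustrative subset of the real union

// A native <select> handler receives the whole event rather than a plain
// value, so the string is read from e.target.value and cast explicitly.
const makeHandleChange =
  (onSelect: (s: Scheduler) => void) =>
  (e: ChangeEvent<HTMLSelectElement>) =>
    onSelect(e.target.value as Scheduler);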

View File

@ -1,5 +1,5 @@
import { createSelector } from '@reduxjs/toolkit';
import { memo, useCallback } from 'react';
import { ChangeEvent, memo, useCallback } from 'react';
import { isEqual } from 'lodash-es';
import { useTranslation } from 'react-i18next';
@ -11,6 +11,7 @@ import { generationSelector } from 'features/parameters/store/generationSelector
import IAICustomSelect, {
IAICustomSelectOption,
} from 'common/components/IAICustomSelect';
import IAISelect from 'common/components/IAISelect';
const selector = createSelector(
[(state: RootState) => state, generationSelector],
@ -18,12 +19,18 @@ const selector = createSelector(
const selectedModel = selectModelsById(state, generation.model);
const modelData = selectModelsAll(state)
.map<IAICustomSelectOption>((m) => ({
.map((m) => ({
value: m.name,
label: m.name,
tooltip: m.description,
key: m.name,
}))
.sort((a, b) => a.label.localeCompare(b.label));
.sort((a, b) => a.key.localeCompare(b.key));
// const modelData = selectModelsAll(state)
// .map<IAICustomSelectOption>((m) => ({
// value: m.name,
// label: m.name,
// tooltip: m.description,
// }))
// .sort((a, b) => a.label.localeCompare(b.label));
return {
selectedModel,
modelData,
@ -41,26 +48,43 @@ const ModelSelect = () => {
const { t } = useTranslation();
const { selectedModel, modelData } = useAppSelector(selector);
const handleChangeModel = useCallback(
(v: string | null | undefined) => {
if (!v) {
return;
}
dispatch(modelSelected(v));
(e: ChangeEvent<HTMLSelectElement>) => {
dispatch(modelSelected(e.target.value));
},
[dispatch]
);
// const handleChangeModel = useCallback(
// (v: string | null | undefined) => {
// if (!v) {
// return;
// }
// dispatch(modelSelected(v));
// },
// [dispatch]
// );
return (
<IAICustomSelect
<IAISelect
label={t('modelManager.model')}
tooltip={selectedModel?.description}
data={modelData}
validValues={modelData}
value={selectedModel?.name ?? ''}
onChange={handleChangeModel}
withCheckIcon={true}
tooltipProps={{ placement: 'top', hasArrow: true }}
/>
);
// return (
// <IAICustomSelect
// label={t('modelManager.model')}
// tooltip={selectedModel?.description}
// data={modelData}
// value={selectedModel?.name ?? ''}
// onChange={handleChangeModel}
// withCheckIcon={true}
// tooltipProps={{ placement: 'top', hasArrow: true }}
// />
// );
};
export default memo(ModelSelect);

View File

@ -10,6 +10,8 @@ export const initialConfigState: AppConfig = {
disabledSDFeatures: [],
canRestoreDeletedImagesFromBin: true,
sd: {
disabledControlNetModels: [],
disabledControlNetProcessors: [],
iterations: {
initial: 1,
min: 1,
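
The two new arrays let a deployment hide specific ControlNet models and processors at config time. One plausible consumer, sketched here; getEnabledModelKeys and the import path are assumptions, not part of this commit:

import { CONTROLNET_MODELS } from 'features/controlNet/store/constants';

// Illustrative filter: drop any model key the deployment has disabled.
const getEnabledModelKeys = (disabledControlNetModels: string[]) =>
  Object.keys(CONTROLNET_MODELS).filter(
    (key) => !disabledControlNetModels.includes(key)
  );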

View File

@ -47,3 +47,6 @@ export const languageSelector = createSelector(
(system) => system.language,
defaultSelectorOptions
);
export const isProcessingSelector = (state: RootState) =>
state.system.isProcessing;
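
isProcessingSelector is a plain function rather than a createSelector result; since it reads a single primitive, memoization would buy nothing. Typical usage inside a component, sketched here (the button component and import paths are assumptions):

import { useAppSelector } from 'app/store/storeHooks';
import { isProcessingSelector } from 'features/system/store/systemSelectors';

// Re-renders only when state.system.isProcessing itself changes.
const InvokeButton = () => {
  const isProcessing = useAppSelector(isProcessingSelector);
  return <button disabled={isProcessing}>Invoke</button>;
};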

View File

@ -7,30 +7,26 @@ import type { ImageField } from './ImageField';
/**
* Applies HED edge detection to image
*/
export type HedImageProcessorInvocation = {
export type HedImageprocessorInvocation = {
/**
* The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Whether or not this node is an intermediate node.
*/
is_intermediate?: boolean;
type?: 'hed_image_processor';
/**
* The image to process
* image to process
*/
image?: ImageField;
/**
* The pixel resolution for detection
* pixel resolution for edge detection
*/
detect_resolution?: number;
/**
* The pixel resolution for the output image
* pixel resolution for output image
*/
image_resolution?: number;
/**
* Whether to use scribble mode
* whether to use scribble mode
*/
scribble?: boolean;
};
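
Since every field except id is optional, a caller can send a sparse payload and rely on backend defaults. A minimal node conforming to this type, with illustrative values (the 'services/api' re-export is assumed, matching the imports used elsewhere in this commit):

import type { HedImageprocessorInvocation } from 'services/api';

// Only id is required; omitted fields fall back to server-side defaults.
const hedNode: HedImageprocessorInvocation = {
  id: 'hed_1',
  type: 'hed_image_processor',
  detect_resolution: 512,
  image_resolution: 512,
  scribble: false,
};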

View File

@ -0,0 +1,33 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
import type { ImageField } from './ImageField';
/**
* Applies HED edge detection to image
*/
export type HedImageprocessorInvocation = {
/**
* The id of this node. Must be unique among all nodes.
*/
id: string;
type?: 'hed_image_processor';
/**
* image to process
*/
image?: ImageField;
/**
* pixel resolution for edge detection
*/
detect_resolution?: number;
/**
* pixel resolution for output image
*/
image_resolution?: number;
/**
* whether to use scribble mode
*/
scribble?: boolean;
};

View File

@ -30,7 +30,7 @@ const invokeAIMark = defineStyle((_props) => {
return {
fontSize: 'xs',
fontWeight: '500',
color: 'base.200',
color: 'base.400',
mt: 2,
insetInlineStart: 'unset',
};

View File

@ -42,8 +42,9 @@ dependencies = [
"controlnet-aux>=0.0.4",
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
"datasets",
"diffusers[torch]~=0.16.1",
"diffusers[torch]~=0.17.0",
"dnspython==2.2.1",
"easing-functions",
"einops",
"eventlet",
"facexlib",
@ -56,6 +57,7 @@ dependencies = [
"flaskwebgui==1.0.3",
"gfpgan==1.3.8",
"huggingface-hub>=0.11.1",
"matplotlib", # needed for plotting of Penner easing functions
"mediapipe", # needed for "mediapipeface" controlnet model
"npyscreen",
"numpy<1.24",