Merge branch 'main' into diffusers-upgrade

This commit is contained in:
blessedcoolant 2023-06-13 05:29:15 +12:00 committed by GitHub
commit 2a814d886b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
55 changed files with 1277 additions and 361 deletions

View File

@ -4,7 +4,6 @@ from inspect import signature
import uvicorn import uvicorn
from invokeai.backend.util.logging import InvokeAILogger
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
@ -15,15 +14,19 @@ from fastapi_events.middleware import EventHandlerASGIMiddleware
from pathlib import Path from pathlib import Path
from pydantic.schema import schema from pydantic.schema import schema
#This should come early so that modules can log their initialization properly
from .services.config import InvokeAIAppConfig
from ..backend.util.logging import InvokeAILogger
app_config = InvokeAIAppConfig.get_config()
app_config.parse_args()
logger = InvokeAILogger.getLogger(config=app_config)
import invokeai.frontend.web as web_dir import invokeai.frontend.web as web_dir
from .api.dependencies import ApiDependencies from .api.dependencies import ApiDependencies
from .api.routers import sessions, models, images from .api.routers import sessions, models, images
from .api.sockets import SocketIO from .api.sockets import SocketIO
from .invocations.baseinvocation import BaseInvocation from .invocations.baseinvocation import BaseInvocation
from .services.config import InvokeAIAppConfig
logger = InvokeAILogger.getLogger()
# Create the app # Create the app
# TODO: create this all in a method so configuration/etc. can be passed in? # TODO: create this all in a method so configuration/etc. can be passed in?
@ -41,11 +44,6 @@ app.add_middleware(
socket_io = SocketIO(app) socket_io = SocketIO(app)
# initialize config
# this is a module global
app_config = InvokeAIAppConfig.get_config()
app_config.parse_args()
# Add startup event to load dependencies # Add startup event to load dependencies
@app.on_event("startup") @app.on_event("startup")
async def startup_event(): async def startup_event():

View File

@ -13,14 +13,20 @@ from typing import (
from pydantic import BaseModel, ValidationError from pydantic import BaseModel, ValidationError
from pydantic.fields import Field from pydantic.fields import Field
# This should come early so that the logger can pick up its configuration options
from .services.config import InvokeAIAppConfig
from invokeai.backend.util.logging import InvokeAILogger
config = InvokeAIAppConfig.get_config()
config.parse_args()
logger = InvokeAILogger().getLogger(config=config)
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService from invokeai.app.services.images import ImageService
from invokeai.app.services.metadata import CoreMetadataService from invokeai.app.services.metadata import CoreMetadataService
from invokeai.app.services.resource_name import SimpleNameService from invokeai.app.services.resource_name import SimpleNameService
from invokeai.app.services.urls import LocalUrlService from invokeai.app.services.urls import LocalUrlService
import invokeai.backend.util.logging as logger
from .services.default_graphs import create_system_graphs from .services.default_graphs import create_system_graphs
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
@ -38,7 +44,7 @@ from .services.invocation_services import InvocationServices
from .services.invoker import Invoker from .services.invoker import Invoker
from .services.processor import DefaultInvocationProcessor from .services.processor import DefaultInvocationProcessor
from .services.sqlite import SqliteItemStorage from .services.sqlite import SqliteItemStorage
from .services.config import InvokeAIAppConfig
class CliCommand(BaseModel): class CliCommand(BaseModel):
command: Union[BaseCommand.get_commands() + BaseInvocation.get_invocations()] = Field(discriminator="type") # type: ignore command: Union[BaseCommand.get_commands() + BaseInvocation.get_invocations()] = Field(discriminator="type") # type: ignore
@ -47,7 +53,6 @@ class CliCommand(BaseModel):
class InvalidArgs(Exception): class InvalidArgs(Exception):
pass pass
def add_invocation_args(command_parser): def add_invocation_args(command_parser):
# Add linking capability # Add linking capability
command_parser.add_argument( command_parser.add_argument(
@ -191,14 +196,7 @@ def invoke_all(context: CliContext):
raise SessionError() raise SessionError()
logger = logger.InvokeAILogger.getLogger()
def invoke_cli(): def invoke_cli():
# this gets the basic configuration
config = InvokeAIAppConfig.get_config()
config.parse_args()
# get the optional list of invocations to execute on the command line # get the optional list of invocations to execute on the command line
parser = config.get_parser() parser = config.get_parser()

View File

@ -1,11 +1,12 @@
# InvokeAI nodes for ControlNet image preprocessors # InvokeAI nodes for ControlNet image preprocessors
# initial implementation by Gregg Helt, 2023 # initial implementation by Gregg Helt, 2023
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux # heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
from builtins import float
import numpy as np import numpy as np
from typing import Literal, Optional, Union, List from typing import Literal, Optional, Union, List
from PIL import Image, ImageFilter, ImageOps from PIL import Image, ImageFilter, ImageOps
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, validator
from ..models.image import ImageField, ImageCategory, ResourceOrigin from ..models.image import ImageField, ImageCategory, ResourceOrigin
from .baseinvocation import ( from .baseinvocation import (
@ -14,6 +15,7 @@ from .baseinvocation import (
InvocationContext, InvocationContext,
InvocationConfig, InvocationConfig,
) )
from controlnet_aux import ( from controlnet_aux import (
CannyDetector, CannyDetector,
HEDdetector, HEDdetector,
@ -96,15 +98,32 @@ CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
class ControlField(BaseModel): class ControlField(BaseModel):
image: ImageField = Field(default=None, description="The control image") image: ImageField = Field(default=None, description="The control image")
control_model: Optional[str] = Field(default=None, description="The ControlNet model to use") control_model: Optional[str] = Field(default=None, description="The ControlNet model to use")
control_weight: Optional[float] = Field(default=1, description="The weight given to the ControlNet") # control_weight: Optional[float] = Field(default=1, description="weight given to controlnet")
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
begin_step_percent: float = Field(default=0, ge=0, le=1, begin_step_percent: float = Field(default=0, ge=0, le=1,
description="When the ControlNet is first applied (% of total steps)") description="When the ControlNet is first applied (% of total steps)")
end_step_percent: float = Field(default=1, ge=0, le=1, end_step_percent: float = Field(default=1, ge=0, le=1,
description="When the ControlNet is last applied (% of total steps)") description="When the ControlNet is last applied (% of total steps)")
@validator("control_weight")
def abs_le_one(cls, v):
"""validate that all abs(values) are <=1"""
if isinstance(v, list):
for i in v:
if abs(i) > 1:
raise ValueError('all abs(control_weight) must be <= 1')
else:
if abs(v) > 1:
raise ValueError('abs(control_weight) must be <= 1')
return v
class Config: class Config:
schema_extra = { schema_extra = {
"required": ["image", "control_model", "control_weight", "begin_step_percent", "end_step_percent"] "required": ["image", "control_model", "control_weight", "begin_step_percent", "end_step_percent"],
"ui": {
"type_hints": {
"control_weight": "float",
# "control_weight": "number",
}
}
} }
@ -112,7 +131,7 @@ class ControlOutput(BaseInvocationOutput):
"""node output for ControlNet info""" """node output for ControlNet info"""
# fmt: off # fmt: off
type: Literal["control_output"] = "control_output" type: Literal["control_output"] = "control_output"
control: ControlField = Field(default=None, description="The output control image") control: ControlField = Field(default=None, description="The control info")
# fmt: on # fmt: on
@ -123,15 +142,28 @@ class ControlNetInvocation(BaseInvocation):
# Inputs # Inputs
image: ImageField = Field(default=None, description="The control image") image: ImageField = Field(default=None, description="The control image")
control_model: CONTROLNET_NAME_VALUES = Field(default="lllyasviel/sd-controlnet-canny", control_model: CONTROLNET_NAME_VALUES = Field(default="lllyasviel/sd-controlnet-canny",
description="The ControlNet model to use") description="control model used")
control_weight: float = Field(default=1.0, ge=0, le=1, description="The weight given to the ControlNet") control_weight: Union[float, List[float]] = Field(default=1.0, description="The weight given to the ControlNet")
# TODO: add support in backend core for begin_step_percent, end_step_percent, guess_mode # TODO: add support in backend core for begin_step_percent, end_step_percent, guess_mode
begin_step_percent: float = Field(default=0, ge=0, le=1, begin_step_percent: float = Field(default=0, ge=0, le=1,
description="When the ControlNet is first applied (% of total steps)") description="When the ControlNet is first applied (% of total steps)")
end_step_percent: float = Field(default=1, ge=0, le=1, end_step_percent: float = Field(default=1, ge=0, le=1,
description="When the ControlNet is last applied (% of total steps)") description="When the ControlNet is last applied (% of total steps)")
# fmt: on # fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["latents"],
"type_hints": {
"model": "model",
"control": "control",
# "cfg_scale": "float",
"cfg_scale": "number",
"control_weight": "float",
}
},
}
def invoke(self, context: InvocationContext) -> ControlOutput: def invoke(self, context: InvocationContext) -> ControlOutput:
@ -161,7 +193,6 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
return image return image
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
raw_image = context.services.images.get_pil_image( raw_image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name self.image.image_origin, self.image.image_name
) )

View File

@ -174,22 +174,36 @@ class TextToLatentsInvocation(BaseInvocation):
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation") negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
noise: Optional[LatentsField] = Field(description="The noise to use") noise: Optional[LatentsField] = Field(description="The noise to use")
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image") steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", ) cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" ) scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
model: str = Field(default="", description="The model to use (currently ignored)") model: str = Field(default="", description="The model to use (currently ignored)")
control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use") control: Union[ControlField, List[ControlField]] = Field(default=None, description="The control to use")
# seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", ) # seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
# seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'") # seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
# fmt: on # fmt: on
@validator("cfg_scale")
def ge_one(cls, v):
"""validate that all cfg_scale values are >= 1"""
if isinstance(v, list):
for i in v:
if i < 1:
raise ValueError('cfg_scale must be greater than 1')
else:
if v < 1:
raise ValueError('cfg_scale must be greater than 1')
return v
# Schema customisation # Schema customisation
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {
"tags": ["latents", "image"], "tags": ["latents"],
"type_hints": { "type_hints": {
"model": "model", "model": "model",
"control": "control", "control": "control",
# "cfg_scale": "float",
"cfg_scale": "number"
} }
}, },
} }
@ -244,10 +258,10 @@ class TextToLatentsInvocation(BaseInvocation):
[c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc]) [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
conditioning_data = ConditioningData( conditioning_data = ConditioningData(
uc, unconditioned_embeddings=uc,
c, text_embeddings=c,
self.cfg_scale, guidance_scale=self.cfg_scale,
extra_conditioning_info, extra=extra_conditioning_info,
postprocessing_settings=PostprocessingSettings( postprocessing_settings=PostprocessingSettings(
threshold=0.0,#threshold, threshold=0.0,#threshold,
warmup=0.2,#warmup, warmup=0.2,#warmup,
@ -348,7 +362,8 @@ class TextToLatentsInvocation(BaseInvocation):
control_data = self.prep_control_data(model=model, context=context, control_input=self.control, control_data = self.prep_control_data(model=model, context=context, control_input=self.control,
latents_shape=noise.shape, latents_shape=noise.shape,
do_classifier_free_guidance=(self.cfg_scale >= 1.0)) # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,)
# TODO: Verify the noise is the right size # TODO: Verify the noise is the right size
result_latents, result_attention_map_saver = model.latents_from_embeddings( result_latents, result_attention_map_saver = model.latents_from_embeddings(
@ -385,6 +400,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
"type_hints": { "type_hints": {
"model": "model", "model": "model",
"control": "control", "control": "control",
"cfg_scale": "number",
} }
}, },
} }
@ -403,10 +419,11 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
model = self.get_model(context.services.model_manager) model = self.get_model(context.services.model_manager)
conditioning_data = self.get_conditioning_data(context, model) conditioning_data = self.get_conditioning_data(context, model)
print("type of control input: ", type(self.control))
control_data = self.prep_control_data(model=model, context=context, control_input=self.control, control_data = self.prep_control_data(model=model, context=context, control_input=self.control,
latents_shape=noise.shape, latents_shape=noise.shape,
do_classifier_free_guidance=(self.cfg_scale >= 1.0)) # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,
)
# TODO: Verify the noise is the right size # TODO: Verify the noise is the right size

View File

@ -0,0 +1,237 @@
import io
from typing import Literal, Optional, Any
# from PIL.Image import Image
import PIL.Image
from matplotlib.ticker import MaxNLocator
from matplotlib.figure import Figure
from pydantic import BaseModel, Field
import numpy as np
import matplotlib.pyplot as plt
from easing_functions import (
LinearInOut,
QuadEaseInOut, QuadEaseIn, QuadEaseOut,
CubicEaseInOut, CubicEaseIn, CubicEaseOut,
QuarticEaseInOut, QuarticEaseIn, QuarticEaseOut,
QuinticEaseInOut, QuinticEaseIn, QuinticEaseOut,
SineEaseInOut, SineEaseIn, SineEaseOut,
CircularEaseIn, CircularEaseInOut, CircularEaseOut,
ExponentialEaseInOut, ExponentialEaseIn, ExponentialEaseOut,
ElasticEaseIn, ElasticEaseInOut, ElasticEaseOut,
BackEaseIn, BackEaseInOut, BackEaseOut,
BounceEaseIn, BounceEaseInOut, BounceEaseOut)
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
InvocationContext,
InvocationConfig,
)
from ...backend.util.logging import InvokeAILogger
from .collections import FloatCollectionOutput
class FloatLinearRangeInvocation(BaseInvocation):
"""Creates a range"""
type: Literal["float_range"] = "float_range"
# Inputs
start: float = Field(default=5, description="The first value of the range")
stop: float = Field(default=10, description="The last value of the range")
steps: int = Field(default=30, description="number of values to interpolate over (including start and stop)")
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
param_list = list(np.linspace(self.start, self.stop, self.steps))
return FloatCollectionOutput(
collection=param_list
)
EASING_FUNCTIONS_MAP = {
"Linear": LinearInOut,
"QuadIn": QuadEaseIn,
"QuadOut": QuadEaseOut,
"QuadInOut": QuadEaseInOut,
"CubicIn": CubicEaseIn,
"CubicOut": CubicEaseOut,
"CubicInOut": CubicEaseInOut,
"QuarticIn": QuarticEaseIn,
"QuarticOut": QuarticEaseOut,
"QuarticInOut": QuarticEaseInOut,
"QuinticIn": QuinticEaseIn,
"QuinticOut": QuinticEaseOut,
"QuinticInOut": QuinticEaseInOut,
"SineIn": SineEaseIn,
"SineOut": SineEaseOut,
"SineInOut": SineEaseInOut,
"CircularIn": CircularEaseIn,
"CircularOut": CircularEaseOut,
"CircularInOut": CircularEaseInOut,
"ExponentialIn": ExponentialEaseIn,
"ExponentialOut": ExponentialEaseOut,
"ExponentialInOut": ExponentialEaseInOut,
"ElasticIn": ElasticEaseIn,
"ElasticOut": ElasticEaseOut,
"ElasticInOut": ElasticEaseInOut,
"BackIn": BackEaseIn,
"BackOut": BackEaseOut,
"BackInOut": BackEaseInOut,
"BounceIn": BounceEaseIn,
"BounceOut": BounceEaseOut,
"BounceInOut": BounceEaseInOut,
}
EASING_FUNCTION_KEYS: Any = Literal[
tuple(list(EASING_FUNCTIONS_MAP.keys()))
]
# actually I think for now could just use CollectionOutput (which is list[Any]
class StepParamEasingInvocation(BaseInvocation):
"""Experimental per-step parameter easing for denoising steps"""
type: Literal["step_param_easing"] = "step_param_easing"
# Inputs
# fmt: off
easing: EASING_FUNCTION_KEYS = Field(default="Linear", description="The easing function to use")
num_steps: int = Field(default=20, description="number of denoising steps")
start_value: float = Field(default=0.0, description="easing starting value")
end_value: float = Field(default=1.0, description="easing ending value")
start_step_percent: float = Field(default=0.0, description="fraction of steps at which to start easing")
end_step_percent: float = Field(default=1.0, description="fraction of steps after which to end easing")
# if None, then start_value is used prior to easing start
pre_start_value: Optional[float] = Field(default=None, description="value before easing start")
# if None, then end value is used prior to easing end
post_end_value: Optional[float] = Field(default=None, description="value after easing end")
mirror: bool = Field(default=False, description="include mirror of easing function")
# FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely
# alt_mirror: bool = Field(default=False, description="alternative mirroring by dual easing")
show_easing_plot: bool = Field(default=False, description="show easing plot")
# fmt: on
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
log_diagnostics = False
# convert from start_step_percent to nearest step <= (steps * start_step_percent)
# start_step = int(np.floor(self.num_steps * self.start_step_percent))
start_step = int(np.round(self.num_steps * self.start_step_percent))
# convert from end_step_percent to nearest step >= (steps * end_step_percent)
# end_step = int(np.ceil((self.num_steps - 1) * self.end_step_percent))
end_step = int(np.round((self.num_steps - 1) * self.end_step_percent))
# end_step = int(np.ceil(self.num_steps * self.end_step_percent))
num_easing_steps = end_step - start_step + 1
# num_presteps = max(start_step - 1, 0)
num_presteps = start_step
num_poststeps = self.num_steps - (num_presteps + num_easing_steps)
prelist = list(num_presteps * [self.pre_start_value])
postlist = list(num_poststeps * [self.post_end_value])
if log_diagnostics:
logger = InvokeAILogger.getLogger(name="StepParamEasing")
logger.debug("start_step: " + str(start_step))
logger.debug("end_step: " + str(end_step))
logger.debug("num_easing_steps: " + str(num_easing_steps))
logger.debug("num_presteps: " + str(num_presteps))
logger.debug("num_poststeps: " + str(num_poststeps))
logger.debug("prelist size: " + str(len(prelist)))
logger.debug("postlist size: " + str(len(postlist)))
logger.debug("prelist: " + str(prelist))
logger.debug("postlist: " + str(postlist))
easing_class = EASING_FUNCTIONS_MAP[self.easing]
if log_diagnostics:
logger.debug("easing class: " + str(easing_class))
easing_list = list()
if self.mirror: # "expected" mirroring
# if number of steps is even, squeeze duration down to (number_of_steps)/2
# and create reverse copy of list to append
# if number of steps is odd, squeeze duration down to ceil(number_of_steps/2)
# and create reverse copy of list[1:end-1]
# but if even then number_of_steps/2 === ceil(number_of_steps/2), so can just use ceil always
base_easing_duration = int(np.ceil(num_easing_steps/2.0))
if log_diagnostics: logger.debug("base easing duration: " + str(base_easing_duration))
even_num_steps = (num_easing_steps % 2 == 0) # even number of steps
easing_function = easing_class(start=self.start_value,
end=self.end_value,
duration=base_easing_duration - 1)
base_easing_vals = list()
for step_index in range(base_easing_duration):
easing_val = easing_function.ease(step_index)
base_easing_vals.append(easing_val)
if log_diagnostics:
logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(easing_val))
if even_num_steps:
mirror_easing_vals = list(reversed(base_easing_vals))
else:
mirror_easing_vals = list(reversed(base_easing_vals[0:-1]))
if log_diagnostics:
logger.debug("base easing vals: " + str(base_easing_vals))
logger.debug("mirror easing vals: " + str(mirror_easing_vals))
easing_list = base_easing_vals + mirror_easing_vals
# FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely
# elif self.alt_mirror: # function mirroring (unintuitive behavior (at least to me))
# # half_ease_duration = round(num_easing_steps - 1 / 2)
# half_ease_duration = round((num_easing_steps - 1) / 2)
# easing_function = easing_class(start=self.start_value,
# end=self.end_value,
# duration=half_ease_duration,
# )
#
# mirror_function = easing_class(start=self.end_value,
# end=self.start_value,
# duration=half_ease_duration,
# )
# for step_index in range(num_easing_steps):
# if step_index <= half_ease_duration:
# step_val = easing_function.ease(step_index)
# else:
# step_val = mirror_function.ease(step_index - half_ease_duration)
# easing_list.append(step_val)
# if log_diagnostics: logger.debug(step_index, step_val)
#
else: # no mirroring (default)
easing_function = easing_class(start=self.start_value,
end=self.end_value,
duration=num_easing_steps - 1)
for step_index in range(num_easing_steps):
step_val = easing_function.ease(step_index)
easing_list.append(step_val)
if log_diagnostics:
logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(step_val))
if log_diagnostics:
logger.debug("prelist size: " + str(len(prelist)))
logger.debug("easing_list size: " + str(len(easing_list)))
logger.debug("postlist size: " + str(len(postlist)))
param_list = prelist + easing_list + postlist
if self.show_easing_plot:
plt.figure()
plt.xlabel("Step")
plt.ylabel("Param Value")
plt.title("Per-Step Values Based On Easing: " + self.easing)
plt.bar(range(len(param_list)), param_list)
# plt.plot(param_list)
ax = plt.gca()
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
buf = io.BytesIO()
plt.savefig(buf, format='png')
buf.seek(0)
im = PIL.Image.open(buf)
im.show()
buf.close()
# output array of size steps, each entry list[i] is param value for step i
return FloatCollectionOutput(
collection=param_list
)

View File

@ -1,4 +1,4 @@
from typing import Optional from typing import Optional, Union, List
from pydantic import BaseModel, Extra, Field, StrictFloat, StrictInt, StrictStr from pydantic import BaseModel, Extra, Field, StrictFloat, StrictInt, StrictStr
@ -47,7 +47,9 @@ class ImageMetadata(BaseModel):
default=None, description="The seed used for noise generation." default=None, description="The seed used for noise generation."
) )
"""The seed used for noise generation""" """The seed used for noise generation"""
cfg_scale: Optional[StrictFloat] = Field( # cfg_scale: Optional[StrictFloat] = Field(
# cfg_scale: Union[float, list[float]] = Field(
cfg_scale: Union[StrictFloat, List[StrictFloat]] = Field(
default=None, description="The classifier-free guidance scale." default=None, description="The classifier-free guidance scale."
) )
"""The classifier-free guidance scale""" """The classifier-free guidance scale"""

View File

@ -65,7 +65,6 @@ from typing import Optional, Union, List, get_args
def is_union_subtype(t1, t2): def is_union_subtype(t1, t2):
t1_args = get_args(t1) t1_args = get_args(t1)
t2_args = get_args(t2) t2_args = get_args(t2)
if not t1_args: if not t1_args:
# t1 is a single type # t1 is a single type
return t1 in t2_args return t1 in t2_args
@ -86,7 +85,6 @@ def is_list_or_contains_list(t):
for arg in t_args: for arg in t_args:
if get_origin(arg) is list: if get_origin(arg) is list:
return True return True
return False return False
@ -393,7 +391,7 @@ class Graph(BaseModel):
from_node = self.get_node(edge.source.node_id) from_node = self.get_node(edge.source.node_id)
to_node = self.get_node(edge.destination.node_id) to_node = self.get_node(edge.destination.node_id)
except NodeNotFoundError: except NodeNotFoundError:
raise InvalidEdgeError("One or both nodes don't exist") raise InvalidEdgeError("One or both nodes don't exist: {edge.source.node_id} -> {edge.destination.node_id}")
# Validate that an edge to this node+field doesn't already exist # Validate that an edge to this node+field doesn't already exist
input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field) input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field)
@ -404,41 +402,41 @@ class Graph(BaseModel):
g = self.nx_graph_flat() g = self.nx_graph_flat()
g.add_edge(edge.source.node_id, edge.destination.node_id) g.add_edge(edge.source.node_id, edge.destination.node_id)
if not nx.is_directed_acyclic_graph(g): if not nx.is_directed_acyclic_graph(g):
raise InvalidEdgeError(f'Edge creates a cycle in the graph') raise InvalidEdgeError(f'Edge creates a cycle in the graph: {edge.source.node_id} -> {edge.destination.node_id}')
# Validate that the field types are compatible # Validate that the field types are compatible
if not are_connections_compatible( if not are_connections_compatible(
from_node, edge.source.field, to_node, edge.destination.field from_node, edge.source.field, to_node, edge.destination.field
): ):
raise InvalidEdgeError(f'Fields are incompatible') raise InvalidEdgeError(f'Fields are incompatible: cannot connect {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
# Validate if iterator output type matches iterator input type (if this edge results in both being set) # Validate if iterator output type matches iterator input type (if this edge results in both being set)
if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection": if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection":
if not self._is_iterator_connection_valid( if not self._is_iterator_connection_valid(
edge.destination.node_id, new_input=edge.source edge.destination.node_id, new_input=edge.source
): ):
raise InvalidEdgeError(f'Iterator input type does not match iterator output type') raise InvalidEdgeError(f'Iterator input type does not match iterator output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
# Validate if iterator input type matches output type (if this edge results in both being set) # Validate if iterator input type matches output type (if this edge results in both being set)
if isinstance(from_node, IterateInvocation) and edge.source.field == "item": if isinstance(from_node, IterateInvocation) and edge.source.field == "item":
if not self._is_iterator_connection_valid( if not self._is_iterator_connection_valid(
edge.source.node_id, new_output=edge.destination edge.source.node_id, new_output=edge.destination
): ):
raise InvalidEdgeError(f'Iterator output type does not match iterator input type') raise InvalidEdgeError(f'Iterator output type does not match iterator input type:, {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
# Validate if collector input type matches output type (if this edge results in both being set) # Validate if collector input type matches output type (if this edge results in both being set)
if isinstance(to_node, CollectInvocation) and edge.destination.field == "item": if isinstance(to_node, CollectInvocation) and edge.destination.field == "item":
if not self._is_collector_connection_valid( if not self._is_collector_connection_valid(
edge.destination.node_id, new_input=edge.source edge.destination.node_id, new_input=edge.source
): ):
raise InvalidEdgeError(f'Collector output type does not match collector input type') raise InvalidEdgeError(f'Collector output type does not match collector input type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
# Validate if collector output type matches input type (if this edge results in both being set) # Validate if collector output type matches input type (if this edge results in both being set)
if isinstance(from_node, CollectInvocation) and edge.source.field == "collection": if isinstance(from_node, CollectInvocation) and edge.source.field == "collection":
if not self._is_collector_connection_valid( if not self._is_collector_connection_valid(
edge.source.node_id, new_output=edge.destination edge.source.node_id, new_output=edge.destination
): ):
raise InvalidEdgeError(f'Collector input type does not match collector output type') raise InvalidEdgeError(f'Collector input type does not match collector output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
def has_node(self, node_path: str) -> bool: def has_node(self, node_path: str) -> bool:
@ -859,11 +857,9 @@ class GraphExecutionState(BaseModel):
if next_node is None: if next_node is None:
prepared_id = self._prepare() prepared_id = self._prepare()
# TODO: prepare multiple nodes at once? # Prepare as many nodes as we can
# while prepared_id is not None and not isinstance(self.graph.nodes[prepared_id], IterateInvocation): while prepared_id is not None:
# prepared_id = self._prepare() prepared_id = self._prepare()
if prepared_id is not None:
next_node = self._get_next_node() next_node = self._get_next_node()
# Get values from edges # Get values from edges
@ -1010,14 +1006,30 @@ class GraphExecutionState(BaseModel):
# Get flattened source graph # Get flattened source graph
g = self.graph.nx_graph_flat() g = self.graph.nx_graph_flat()
# Find next unprepared node where all source nodes are executed # Find next node that:
# - was not already prepared
# - is not an iterate node whose inputs have not been executed
# - does not have an unexecuted iterate ancestor
sorted_nodes = nx.topological_sort(g) sorted_nodes = nx.topological_sort(g)
next_node_id = next( next_node_id = next(
( (
n n
for n in sorted_nodes for n in sorted_nodes
# exclude nodes that have already been prepared
if n not in self.source_prepared_mapping if n not in self.source_prepared_mapping
and all((e[0] in self.executed for e in g.in_edges(n))) # exclude iterate nodes whose inputs have not been executed
and not (
isinstance(self.graph.get_node(n), IterateInvocation) # `n` is an iterate node...
and not all((e[0] in self.executed for e in g.in_edges(n))) # ...that has unexecuted inputs
)
# exclude nodes who have unexecuted iterate ancestors
and not any(
(
isinstance(self.graph.get_node(a), IterateInvocation) # `a` is an iterate ancestor of `n`...
and a not in self.executed # ...that is not executed
for a in nx.ancestors(g, n) # for all ancestors `a` of node `n`
)
)
), ),
None, None,
) )
@ -1114,9 +1126,22 @@ class GraphExecutionState(BaseModel):
) )
def _get_next_node(self) -> Optional[BaseInvocation]: def _get_next_node(self) -> Optional[BaseInvocation]:
"""Gets the deepest node that is ready to be executed"""
g = self.execution_graph.nx_graph() g = self.execution_graph.nx_graph()
sorted_nodes = nx.topological_sort(g)
next_node = next((n for n in sorted_nodes if n not in self.executed), None) # Depth-first search with pre-order traversal is a depth-first topological sort
sorted_nodes = nx.dfs_preorder_nodes(g)
next_node = next(
(
n
for n in sorted_nodes
if n not in self.executed # the node must not already be executed...
and all((e[0] in self.executed for e in g.in_edges(n))) # ...and all its inputs must be executed
),
None,
)
if next_node is None: if next_node is None:
return None return None

View File

@ -22,7 +22,8 @@ class Invoker:
def invoke( def invoke(
self, graph_execution_state: GraphExecutionState, invoke_all: bool = False self, graph_execution_state: GraphExecutionState, invoke_all: bool = False
) -> str | None: ) -> str | None:
"""Determines the next node to invoke and returns the id of the invoked node, or None if there are no nodes to execute""" """Determines the next node to invoke and enqueues it, preparing if needed.
Returns the id of the queued node, or `None` if there are no nodes left to enqueue."""
# Get the next invocation # Get the next invocation
invocation = graph_execution_state.next() invocation = graph_execution_state.next()

View File

@ -40,6 +40,7 @@ import invokeai.configs as configs
from invokeai.app.services.config import ( from invokeai.app.services.config import (
InvokeAIAppConfig, InvokeAIAppConfig,
) )
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.frontend.install.model_install import addModelsForm, process_and_execute from invokeai.frontend.install.model_install import addModelsForm, process_and_execute
from invokeai.frontend.install.widgets import ( from invokeai.frontend.install.widgets import (
CenteredButtonPress, CenteredButtonPress,
@ -80,6 +81,7 @@ INIT_FILE_PREAMBLE = """# InvokeAI initialization file
# or renaming it and then running invokeai-configure again. # or renaming it and then running invokeai-configure again.
""" """
logger=None
# -------------------------------------------- # --------------------------------------------
def postscript(errors: None): def postscript(errors: None):
@ -824,6 +826,7 @@ def main():
if opt.full_precision: if opt.full_precision:
invoke_args.extend(['--precision','float32']) invoke_args.extend(['--precision','float32'])
config.parse_args(invoke_args) config.parse_args(invoke_args)
logger = InvokeAILogger().getLogger(config=config)
errors = set() errors = set()

View File

@ -784,7 +784,7 @@ class ModelManager(object):
self.logger.info(f"Probing {thing} for import") self.logger.info(f"Probing {thing} for import")
if thing.startswith(("http:", "https:", "ftp:")): if str(thing).startswith(("http:", "https:", "ftp:")):
self.logger.info(f"{thing} appears to be a URL") self.logger.info(f"{thing} appears to be a URL")
model_path = self._resolve_path( model_path = self._resolve_path(
thing, "models/ldm/stable-diffusion-v1" thing, "models/ldm/stable-diffusion-v1"

View File

@ -218,7 +218,7 @@ class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
class ControlNetData: class ControlNetData:
model: ControlNetModel = Field(default=None) model: ControlNetModel = Field(default=None)
image_tensor: torch.Tensor= Field(default=None) image_tensor: torch.Tensor= Field(default=None)
weight: float = Field(default=1.0) weight: Union[float, List[float]]= Field(default=1.0)
begin_step_percent: float = Field(default=0.0) begin_step_percent: float = Field(default=0.0)
end_step_percent: float = Field(default=1.0) end_step_percent: float = Field(default=1.0)
@ -226,7 +226,7 @@ class ControlNetData:
class ConditioningData: class ConditioningData:
unconditioned_embeddings: torch.Tensor unconditioned_embeddings: torch.Tensor
text_embeddings: torch.Tensor text_embeddings: torch.Tensor
guidance_scale: float guidance_scale: Union[float, List[float]]
""" """
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
@ -662,7 +662,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
down_block_res_samples, mid_block_res_sample = None, None down_block_res_samples, mid_block_res_sample = None, None
if control_data is not None: if control_data is not None:
if conditioning_data.guidance_scale > 1.0: # FIXME: make sure guidance_scale < 1.0 is handled correctly if doing per-step guidance setting
# if conditioning_data.guidance_scale > 1.0:
if conditioning_data.guidance_scale is not None:
# expand the latents input to control model if doing classifier free guidance # expand the latents input to control model if doing classifier free guidance
# (which I think for now is always true, there is conditional elsewhere that stops execution if # (which I think for now is always true, there is conditional elsewhere that stops execution if
# classifier_free_guidance is <= 1.0 ?) # classifier_free_guidance is <= 1.0 ?)
@ -679,13 +681,19 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# only apply controlnet if current step is within the controlnet's begin/end step range # only apply controlnet if current step is within the controlnet's begin/end step range
if step_index >= first_control_step and step_index <= last_control_step: if step_index >= first_control_step and step_index <= last_control_step:
# print("running controlnet", i, "for step", step_index) # print("running controlnet", i, "for step", step_index)
if isinstance(control_datum.weight, list):
# if controlnet has multiple weights, use the weight for the current step
controlnet_weight = control_datum.weight[step_index]
else:
# if controlnet has a single weight, use it for all steps
controlnet_weight = control_datum.weight
down_samples, mid_sample = control_datum.model( down_samples, mid_sample = control_datum.model(
sample=latent_control_input, sample=latent_control_input,
timestep=timestep, timestep=timestep,
encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings, encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings,
conditioning_data.text_embeddings]), conditioning_data.text_embeddings]),
controlnet_cond=control_datum.image_tensor, controlnet_cond=control_datum.image_tensor,
conditioning_scale=control_datum.weight, conditioning_scale=controlnet_weight,
# cross_attention_kwargs, # cross_attention_kwargs,
guess_mode=False, guess_mode=False,
return_dict=False, return_dict=False,

View File

@ -1,7 +1,7 @@
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from math import ceil from math import ceil
from typing import Any, Callable, Dict, Optional, Union from typing import Any, Callable, Dict, Optional, Union, List
import numpy as np import numpy as np
import torch import torch
@ -180,7 +180,8 @@ class InvokeAIDiffuserComponent:
sigma: torch.Tensor, sigma: torch.Tensor,
unconditioning: Union[torch.Tensor, dict], unconditioning: Union[torch.Tensor, dict],
conditioning: Union[torch.Tensor, dict], conditioning: Union[torch.Tensor, dict],
unconditional_guidance_scale: float, # unconditional_guidance_scale: float,
unconditional_guidance_scale: Union[float, List[float]],
step_index: Optional[int] = None, step_index: Optional[int] = None,
total_step_count: Optional[int] = None, total_step_count: Optional[int] = None,
**kwargs, **kwargs,
@ -195,6 +196,11 @@ class InvokeAIDiffuserComponent:
:return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning. :return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning.
""" """
if isinstance(unconditional_guidance_scale, list):
guidance_scale = unconditional_guidance_scale[step_index]
else:
guidance_scale = unconditional_guidance_scale
cross_attention_control_types_to_do = [] cross_attention_control_types_to_do = []
context: Context = self.cross_attention_control_context context: Context = self.cross_attention_control_context
if self.cross_attention_control_context is not None: if self.cross_attention_control_context is not None:
@ -243,7 +249,8 @@ class InvokeAIDiffuserComponent:
) )
combined_next_x = self._combine( combined_next_x = self._combine(
unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale # unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale
unconditioned_next_x, conditioned_next_x, guidance_scale
) )
return combined_next_x return combined_next_x

View File

@ -1,6 +1,7 @@
# Copyright (c) 2023 Lincoln D. Stein and The InvokeAI Development Team # Copyright (c) 2023 Lincoln D. Stein and The InvokeAI Development Team
"""invokeai.util.logging """
invokeai.util.logging
Logging class for InvokeAI that produces console messages Logging class for InvokeAI that produces console messages
@ -11,6 +12,7 @@ from invokeai.backend.util.logging import InvokeAILogger
logger = InvokeAILogger.getLogger(name='InvokeAI') // Initialization logger = InvokeAILogger.getLogger(name='InvokeAI') // Initialization
(or) (or)
logger = InvokeAILogger.getLogger(__name__) // To use the filename logger = InvokeAILogger.getLogger(__name__) // To use the filename
logger.configure()
logger.critical('this is critical') // Critical Message logger.critical('this is critical') // Critical Message
logger.error('this is an error') // Error Message logger.error('this is an error') // Error Message
@ -28,6 +30,149 @@ Console messages:
Alternate Method (in this case the logger name will be set to InvokeAI): Alternate Method (in this case the logger name will be set to InvokeAI):
import invokeai.backend.util.logging as IAILogger import invokeai.backend.util.logging as IAILogger
IAILogger.debug('this is a debugging message') IAILogger.debug('this is a debugging message')
## Configuration
The default configuration will print to stderr on the console. To add
additional logging handlers, call getLogger with an initialized InvokeAIAppConfig
object:
config = InvokeAIAppConfig.get_config()
config.parse_args()
logger = InvokeAILogger.getLogger(config=config)
### Three command-line options control logging:
`--log_handlers <handler1> <handler2> ...`
This option activates one or more log handlers. Options are "console", "file", "syslog" and "http". To specify more than one, separate them by spaces:
```
invokeai-web --log_handlers console syslog=/dev/log file=C:\\Users\\fred\\invokeai.log
```
The format of these options is described below.
### `--log_format {plain|color|legacy|syslog}`
This controls the format of log messages written to the console. Only the "console" log handler is currently affected by this setting.
* "plain" provides formatted messages like this:
```bash
[2023-05-24 23:18:2[2023-05-24 23:18:50,352]::[InvokeAI]::DEBUG --> this is a debug message
[2023-05-24 23:18:50,352]::[InvokeAI]::INFO --> this is an informational messages
[2023-05-24 23:18:50,352]::[InvokeAI]::WARNING --> this is a warning
[2023-05-24 23:18:50,352]::[InvokeAI]::ERROR --> this is an error
[2023-05-24 23:18:50,352]::[InvokeAI]::CRITICAL --> this is a critical error
```
* "color" produces similar output, but the text will be color coded to indicate the severity of the message.
* "legacy" produces output similar to InvokeAI versions 2.3 and earlier:
```
### this is a critical error
*** this is an error
** this is a warning
>> this is an informational messages
| this is a debug message
```
* "syslog" produces messages suitable for syslog entries:
```bash
InvokeAI [2691178] <CRITICAL> this is a critical error
InvokeAI [2691178] <ERROR> this is an error
InvokeAI [2691178] <WARNING> this is a warning
InvokeAI [2691178] <INFO> this is an informational messages
InvokeAI [2691178] <DEBUG> this is a debug message
```
(note that the date, time and hostname will be added by the syslog system)
### `--log_level {debug|info|warning|error|critical}`
Providing this command-line option will cause only messages at the specified level or above to be emitted.
## Console logging
When "console" is provided to `--log_handlers`, messages will be written to the command line window in which InvokeAI was launched. By default, the color formatter will be used unless overridden by `--log_format`.
## File logging
When "file" is provided to `--log_handlers`, entries will be written to the file indicated in the path argument. By default, the "plain" format will be used:
```bash
invokeai-web --log_handlers file=/var/log/invokeai.log
```
## Syslog logging
When "syslog" is requested, entries will be sent to the syslog system. There are a variety of ways to control where the log message is sent:
* Send to the local machine using the `/dev/log` socket:
```
invokeai-web --log_handlers syslog=/dev/log
```
* Send to the local machine using a UDP message:
```
invokeai-web --log_handlers syslog=localhost
```
* Send to the local machine using a UDP message on a nonstandard port:
```
invokeai-web --log_handlers syslog=localhost:512
```
* Send to a remote machine named "loghost" on the local LAN using facility LOG_USER and UDP packets:
```
invokeai-web --log_handlers syslog=loghost,facility=LOG_USER,socktype=SOCK_DGRAM
```
This can be abbreviated `syslog=loghost`, as LOG_USER and SOCK_DGRAM are defaults.
* Send to a remote machine named "loghost" using the facility LOCAL0 and using a TCP socket:
```
invokeai-web --log_handlers syslog=loghost,facility=LOG_LOCAL0,socktype=SOCK_STREAM
```
If no arguments are specified (just a bare "syslog"), then the logging system will look for a UNIX socket named `/dev/log`, and if not found try to send a UDP message to `localhost`. The Macintosh OS used to support logging to a socket named `/var/run/syslog`, but this feature has since been disabled.
## Web logging
If you have access to a web server that is configured to log messages when a particular URL is requested, you can log using the "http" method:
```
invokeai-web --log_handlers http=http://my.server/path/to/logger,method=POST
```
The optional [,method=] part can be used to specify whether the URL accepts GET (default) or POST messages.
Currently password authentication and SSL are not supported.
## Using the configuration file
You can set and forget logging options by adding a "Logging" section to `invokeai.yaml`:
```
InvokeAI:
[... other settings...]
Logging:
log_handlers:
- console
- syslog=/dev/log
log_level: info
log_format: color
```
""" """
import logging.handlers import logging.handlers
@ -180,14 +325,17 @@ class InvokeAILogger(object):
loggers = dict() loggers = dict()
@classmethod @classmethod
def getLogger(cls, name: str = 'InvokeAI') -> logging.Logger: def getLogger(cls,
config = get_invokeai_config() name: str = 'InvokeAI',
config: InvokeAIAppConfig=InvokeAIAppConfig.get_config())->logging.Logger:
if name not in cls.loggers: if name in cls.loggers:
logger = cls.loggers[name]
logger.handlers.clear()
else:
logger = logging.getLogger(name) logger = logging.getLogger(name)
logger.setLevel(config.log_level.upper()) # yes, strings work here logger.setLevel(config.log_level.upper()) # yes, strings work here
for ch in cls.getLoggers(config): for ch in cls.getLoggers(config):
logger.addHandler(ch) logger.addHandler(ch)
cls.loggers[name] = logger cls.loggers[name] = logger
return cls.loggers[name] return cls.loggers[name]
@ -199,9 +347,11 @@ class InvokeAILogger(object):
handler_name,*args = handler.split('=',2) handler_name,*args = handler.split('=',2)
args = args[0] if len(args) > 0 else None args = args[0] if len(args) > 0 else None
# console is the only handler that gets a custom formatter # console and file get the fancy formatter.
# syslog gets a simple one
# http gets no custom formatter
formatter = LOG_FORMATTERS[config.log_format]
if handler_name=='console': if handler_name=='console':
formatter = LOG_FORMATTERS[config.log_format]
ch = logging.StreamHandler() ch = logging.StreamHandler()
ch.setFormatter(formatter()) ch.setFormatter(formatter())
handlers.append(ch) handlers.append(ch)
@ -212,7 +362,9 @@ class InvokeAILogger(object):
handlers.append(ch) handlers.append(ch)
elif handler_name=='file': elif handler_name=='file':
handlers.append(cls._parse_file_args(args)) ch = cls._parse_file_args(args)
ch.setFormatter(formatter())
handlers.append(ch)
elif handler_name=='http': elif handler_name=='http':
handlers.append(cls._parse_http_args(args)) handlers.append(cls._parse_http_args(args))

View File

@ -28,7 +28,7 @@ import torch
from npyscreen import widget from npyscreen import widget
from omegaconf import OmegaConf from omegaconf import OmegaConf
import invokeai.backend.util.logging as logger from invokeai.backend.util.logging import InvokeAILogger
from invokeai.backend.install.model_install_backend import ( from invokeai.backend.install.model_install_backend import (
Dataset_path, Dataset_path,
@ -939,6 +939,7 @@ def main():
if opt.full_precision: if opt.full_precision:
invoke_args.extend(['--precision','float32']) invoke_args.extend(['--precision','float32'])
config.parse_args(invoke_args) config.parse_args(invoke_args)
logger = InvokeAILogger().getLogger(config=config)
if not (config.root_dir / config.conf_path.parent).exists(): if not (config.root_dir / config.conf_path.parent).exists():
logger.info( logger.info(

View File

@ -22,6 +22,7 @@ import { APP_HEIGHT, APP_WIDTH } from 'theme/util/constants';
import GlobalHotkeys from './GlobalHotkeys'; import GlobalHotkeys from './GlobalHotkeys';
import Toaster from './Toaster'; import Toaster from './Toaster';
import DeleteImageModal from 'features/gallery/components/DeleteImageModal'; import DeleteImageModal from 'features/gallery/components/DeleteImageModal';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
const DEFAULT_CONFIG = {}; const DEFAULT_CONFIG = {};
@ -66,10 +67,17 @@ const App = ({
setIsReady(true); setIsReady(true);
} }
if (isApplicationReady) {
// TODO: This is a jank fix for canvas not filling the screen on first load
setTimeout(() => {
dispatch(requestCanvasRescale());
}, 200);
}
return () => { return () => {
setIsReady && setIsReady(false); setIsReady && setIsReady(false);
}; };
}, [isApplicationReady, setIsReady]); }, [dispatch, isApplicationReady, setIsReady]);
return ( return (
<> <>

View File

@ -40,11 +40,11 @@ const ImageDndContext = (props: ImageDndContextProps) => {
); );
const mouseSensor = useSensor(MouseSensor, { const mouseSensor = useSensor(MouseSensor, {
activationConstraint: { delay: 250, tolerance: 5 }, activationConstraint: { delay: 150, tolerance: 5 },
}); });
const touchSensor = useSensor(TouchSensor, { const touchSensor = useSensor(TouchSensor, {
activationConstraint: { delay: 250, tolerance: 5 }, activationConstraint: { delay: 150, tolerance: 5 },
}); });
// TODO: Use KeyboardSensor - needs composition of multiple collisionDetection algos // TODO: Use KeyboardSensor - needs composition of multiple collisionDetection algos
// Alternatively, fix `rectIntersection` collection detection to work with the drag overlay // Alternatively, fix `rectIntersection` collection detection to work with the drag overlay

View File

@ -1,3 +1,7 @@
import {
CONTROLNET_MODELS,
CONTROLNET_PROCESSORS,
} from 'features/controlNet/store/constants';
import { InvokeTabName } from 'features/ui/store/tabMap'; import { InvokeTabName } from 'features/ui/store/tabMap';
import { O } from 'ts-toolbelt'; import { O } from 'ts-toolbelt';
@ -117,6 +121,8 @@ export type AppConfig = {
canRestoreDeletedImagesFromBin: boolean; canRestoreDeletedImagesFromBin: boolean;
sd: { sd: {
defaultModel?: string; defaultModel?: string;
disabledControlNetModels: (keyof typeof CONTROLNET_MODELS)[];
disabledControlNetProcessors: (keyof typeof CONTROLNET_PROCESSORS)[];
iterations: { iterations: {
initial: number; initial: number;
min: number; min: number;

View File

@ -2,7 +2,6 @@ import { CheckIcon, ChevronUpIcon } from '@chakra-ui/icons';
import { import {
Box, Box,
Flex, Flex,
FlexProps,
FormControl, FormControl,
FormControlProps, FormControlProps,
FormLabel, FormLabel,
@ -16,42 +15,64 @@ import {
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { autoUpdate, offset, shift, useFloating } from '@floating-ui/react-dom'; import { autoUpdate, offset, shift, useFloating } from '@floating-ui/react-dom';
import { useSelect } from 'downshift'; import { useSelect } from 'downshift';
import { isString } from 'lodash-es';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react'; import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import { memo, useMemo } from 'react'; import { memo, useLayoutEffect, useMemo } from 'react';
import { getInputOutlineStyles } from 'theme/util/getInputOutlineStyles'; import { getInputOutlineStyles } from 'theme/util/getInputOutlineStyles';
export type ItemTooltips = { [key: string]: string }; export type ItemTooltips = { [key: string]: string };
export type IAICustomSelectOption = {
value: string;
label: string;
tooltip?: string;
};
type IAICustomSelectProps = { type IAICustomSelectProps = {
label?: string; label?: string;
items: string[]; value: string;
itemTooltips?: ItemTooltips; data: IAICustomSelectOption[] | string[];
selectedItem: string; onChange: (v: string) => void;
setSelectedItem: (v: string | null | undefined) => void;
withCheckIcon?: boolean; withCheckIcon?: boolean;
formControlProps?: FormControlProps; formControlProps?: FormControlProps;
buttonProps?: FlexProps;
tooltip?: string; tooltip?: string;
tooltipProps?: Omit<TooltipProps, 'children'>; tooltipProps?: Omit<TooltipProps, 'children'>;
ellipsisPosition?: 'start' | 'end'; ellipsisPosition?: 'start' | 'end';
isDisabled?: boolean;
}; };
const IAICustomSelect = (props: IAICustomSelectProps) => { const IAICustomSelect = (props: IAICustomSelectProps) => {
const { const {
label, label,
items,
itemTooltips,
setSelectedItem,
selectedItem,
withCheckIcon, withCheckIcon,
formControlProps, formControlProps,
tooltip, tooltip,
buttonProps,
tooltipProps, tooltipProps,
ellipsisPosition = 'end', ellipsisPosition = 'end',
data,
value,
onChange,
isDisabled = false,
} = props; } = props;
const values = useMemo(() => {
return data.map<IAICustomSelectOption>((v) => {
if (isString(v)) {
return { value: v, label: v };
}
return v;
});
}, [data]);
const stringValues = useMemo(() => {
return values.map((v) => v.value);
}, [values]);
const valueData = useMemo(() => {
return values.find((v) => v.value === value);
}, [values, value]);
const { const {
isOpen, isOpen,
getToggleButtonProps, getToggleButtonProps,
@ -60,17 +81,24 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
highlightedIndex, highlightedIndex,
getItemProps, getItemProps,
} = useSelect({ } = useSelect({
items, items: stringValues,
selectedItem, selectedItem: value,
onSelectedItemChange: ({ selectedItem: newSelectedItem }) => onSelectedItemChange: ({ selectedItem: newSelectedItem }) => {
setSelectedItem(newSelectedItem), newSelectedItem && onChange(newSelectedItem);
},
}); });
const { refs, floatingStyles } = useFloating<HTMLButtonElement>({ const { refs, floatingStyles, update } = useFloating<HTMLButtonElement>({
whileElementsMounted: autoUpdate, // whileElementsMounted: autoUpdate,
middleware: [offset(4), shift({ crossAxis: true, padding: 8 })], middleware: [offset(4), shift({ crossAxis: true, padding: 8 })],
}); });
useLayoutEffect(() => {
if (isOpen && refs.reference.current && refs.floating.current) {
return autoUpdate(refs.reference.current, refs.floating.current, update);
}
}, [isOpen, update, refs.floating, refs.reference]);
const labelTextDirection = useMemo(() => { const labelTextDirection = useMemo(() => {
if (ellipsisPosition === 'start') { if (ellipsisPosition === 'start') {
return document.dir === 'rtl' ? 'ltr' : 'rtl'; return document.dir === 'rtl' ? 'ltr' : 'rtl';
@ -93,8 +121,7 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
)} )}
<Tooltip label={tooltip} {...tooltipProps}> <Tooltip label={tooltip} {...tooltipProps}>
<Flex <Flex
{...getToggleButtonProps({ ref: refs.setReference })} {...getToggleButtonProps({ ref: refs.reference })}
{...buttonProps}
sx={{ sx={{
alignItems: 'center', alignItems: 'center',
userSelect: 'none', userSelect: 'none',
@ -105,6 +132,8 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
px: 2, px: 2,
gap: 2, gap: 2,
justifyContent: 'space-between', justifyContent: 'space-between',
pointerEvents: isDisabled ? 'none' : undefined,
opacity: isDisabled ? 0.5 : undefined,
...getInputOutlineStyles(), ...getInputOutlineStyles(),
}} }}
> >
@ -119,7 +148,7 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
direction: labelTextDirection, direction: labelTextDirection,
}} }}
> >
{selectedItem} {valueData?.label}
</Text> </Text>
<ChevronUpIcon <ChevronUpIcon
sx={{ sx={{
@ -135,7 +164,7 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
{isOpen && ( {isOpen && (
<List <List
as={Flex} as={Flex}
ref={refs.setFloating} ref={refs.floating}
sx={{ sx={{
...floatingStyles, ...floatingStyles,
top: 0, top: 0,
@ -155,8 +184,8 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
}} }}
> >
<OverlayScrollbarsComponent> <OverlayScrollbarsComponent>
{items.map((item, index) => { {values.map((v, index) => {
const isSelected = selectedItem === item; const isSelected = value === v.value;
const isHighlighted = highlightedIndex === index; const isHighlighted = highlightedIndex === index;
const fontWeight = isSelected ? 700 : 500; const fontWeight = isSelected ? 700 : 500;
const bg = isHighlighted const bg = isHighlighted
@ -166,9 +195,9 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
: undefined; : undefined;
return ( return (
<Tooltip <Tooltip
isDisabled={!itemTooltips} isDisabled={!v.tooltip}
key={`${item}${index}`} key={`${v.value}${index}`}
label={itemTooltips?.[item]} label={v.tooltip}
hasArrow hasArrow
placement="right" placement="right"
> >
@ -182,8 +211,7 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
transitionProperty: 'common', transitionProperty: 'common',
transitionDuration: '0.15s', transitionDuration: '0.15s',
}} }}
key={`${item}${index}`} {...getItemProps({ item: v.value, index })}
{...getItemProps({ item, index })}
> >
{withCheckIcon ? ( {withCheckIcon ? (
<Grid gridTemplateColumns="1.25rem auto"> <Grid gridTemplateColumns="1.25rem auto">
@ -198,7 +226,7 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
fontWeight, fontWeight,
}} }}
> >
{item} {v.label}
</Text> </Text>
</GridItem> </GridItem>
</Grid> </Grid>
@ -210,7 +238,7 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
fontWeight, fontWeight,
}} }}
> >
{item} {v.label}
</Text> </Text>
)} )}
</ListItem> </ListItem>

View File

@ -1,4 +1,5 @@
import { import {
ChakraProps,
FormControl, FormControl,
FormControlProps, FormControlProps,
FormLabel, FormLabel,
@ -39,6 +40,11 @@ import { BiReset } from 'react-icons/bi';
import IAIIconButton, { IAIIconButtonProps } from './IAIIconButton'; import IAIIconButton, { IAIIconButtonProps } from './IAIIconButton';
import { roundDownToMultiple } from 'common/util/roundDownToMultiple'; import { roundDownToMultiple } from 'common/util/roundDownToMultiple';
const SLIDER_MARK_STYLES: ChakraProps['sx'] = {
mt: 1.5,
fontSize: '2xs',
};
export type IAIFullSliderProps = { export type IAIFullSliderProps = {
label?: string; label?: string;
value: number; value: number;
@ -57,6 +63,7 @@ export type IAIFullSliderProps = {
hideTooltip?: boolean; hideTooltip?: boolean;
isCompact?: boolean; isCompact?: boolean;
isDisabled?: boolean; isDisabled?: boolean;
sliderMarks?: number[];
sliderFormControlProps?: FormControlProps; sliderFormControlProps?: FormControlProps;
sliderFormLabelProps?: FormLabelProps; sliderFormLabelProps?: FormLabelProps;
sliderMarkProps?: Omit<SliderMarkProps, 'value'>; sliderMarkProps?: Omit<SliderMarkProps, 'value'>;
@ -88,6 +95,7 @@ const IAISlider = (props: IAIFullSliderProps) => {
hideTooltip = false, hideTooltip = false,
isCompact = false, isCompact = false,
isDisabled = false, isDisabled = false,
sliderMarks,
handleReset, handleReset,
sliderFormControlProps, sliderFormControlProps,
sliderFormLabelProps, sliderFormLabelProps,
@ -198,14 +206,14 @@ const IAISlider = (props: IAIFullSliderProps) => {
isDisabled={isDisabled} isDisabled={isDisabled}
{...rest} {...rest}
> >
{withSliderMarks && ( {withSliderMarks && !sliderMarks && (
<> <>
<SliderMark <SliderMark
value={min} value={min}
sx={{ sx={{
insetInlineStart: '0 !important', insetInlineStart: '0 !important',
insetInlineEnd: 'unset !important', insetInlineEnd: 'unset !important',
mt: 1.5, ...SLIDER_MARK_STYLES,
}} }}
{...sliderMarkProps} {...sliderMarkProps}
> >
@ -216,7 +224,7 @@ const IAISlider = (props: IAIFullSliderProps) => {
sx={{ sx={{
insetInlineStart: 'unset !important', insetInlineStart: 'unset !important',
insetInlineEnd: '0 !important', insetInlineEnd: '0 !important',
mt: 1.5, ...SLIDER_MARK_STYLES,
}} }}
{...sliderMarkProps} {...sliderMarkProps}
> >
@ -224,6 +232,56 @@ const IAISlider = (props: IAIFullSliderProps) => {
</SliderMark> </SliderMark>
</> </>
)} )}
{withSliderMarks && sliderMarks && (
<>
{sliderMarks.map((m, i) => {
if (i === 0) {
return (
<SliderMark
key={m}
value={m}
sx={{
insetInlineStart: '0 !important',
insetInlineEnd: 'unset !important',
...SLIDER_MARK_STYLES,
}}
{...sliderMarkProps}
>
{m}
</SliderMark>
);
} else if (i === sliderMarks.length - 1) {
return (
<SliderMark
key={m}
value={m}
sx={{
insetInlineStart: 'unset !important',
insetInlineEnd: '0 !important',
...SLIDER_MARK_STYLES,
}}
{...sliderMarkProps}
>
{m}
</SliderMark>
);
} else {
return (
<SliderMark
key={m}
value={m}
sx={{
...SLIDER_MARK_STYLES,
}}
{...sliderMarkProps}
>
{m}
</SliderMark>
);
}
})}
</>
)}
<SliderTrack {...sliderTrackProps}> <SliderTrack {...sliderTrackProps}>
<SliderFilledTrack /> <SliderFilledTrack />

View File

@ -16,7 +16,6 @@ import {
setShouldShowIntermediates, setShouldShowIntermediates,
setShouldSnapToGrid, setShouldSnapToGrid,
} from 'features/canvas/store/canvasSlice'; } from 'features/canvas/store/canvasSlice';
import EmptyTempFolderButtonModal from 'features/system/components/ClearTempFolderButtonModal';
import { isEqual } from 'lodash-es'; import { isEqual } from 'lodash-es';
import { ChangeEvent } from 'react'; import { ChangeEvent } from 'react';
@ -159,7 +158,6 @@ const IAICanvasSettingsButtonPopover = () => {
onChange={(e) => dispatch(setShouldAntialias(e.target.checked))} onChange={(e) => dispatch(setShouldAntialias(e.target.checked))}
/> />
<ClearCanvasHistoryButtonModal /> <ClearCanvasHistoryButtonModal />
<EmptyTempFolderButtonModal />
</Flex> </Flex>
</IAIPopover> </IAIPopover>
); );

View File

@ -30,7 +30,10 @@ import {
} from './canvasTypes'; } from './canvasTypes';
import { ImageDTO } from 'services/api'; import { ImageDTO } from 'services/api';
import { sessionCanceled } from 'services/thunks/session'; import { sessionCanceled } from 'services/thunks/session';
import { setShouldUseCanvasBetaLayout } from 'features/ui/store/uiSlice'; import {
setActiveTab,
setShouldUseCanvasBetaLayout,
} from 'features/ui/store/uiSlice';
import { imageUrlsReceived } from 'services/thunks/image'; import { imageUrlsReceived } from 'services/thunks/image';
export const initialLayerState: CanvasLayerState = { export const initialLayerState: CanvasLayerState = {
@ -857,6 +860,11 @@ export const canvasSlice = createSlice({
builder.addCase(setShouldUseCanvasBetaLayout, (state, action) => { builder.addCase(setShouldUseCanvasBetaLayout, (state, action) => {
state.doesCanvasNeedScaling = true; state.doesCanvasNeedScaling = true;
}); });
builder.addCase(setActiveTab, (state, action) => {
state.doesCanvasNeedScaling = true;
});
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => { builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
const { image_name, image_origin, image_url, thumbnail_url } = const { image_name, image_origin, image_url, thumbnail_url } =
action.payload; action.payload;

View File

@ -143,7 +143,7 @@ const ControlNet = (props: ControlNetProps) => {
flexDir: 'column', flexDir: 'column',
gap: 2, gap: 2,
w: 'full', w: 'full',
h: 24, h: isExpanded ? 28 : 24,
paddingInlineStart: 1, paddingInlineStart: 1,
paddingInlineEnd: isExpanded ? 1 : 0, paddingInlineEnd: isExpanded ? 1 : 0,
pb: 2, pb: 2,
@ -153,13 +153,13 @@ const ControlNet = (props: ControlNetProps) => {
<ParamControlNetWeight <ParamControlNetWeight
controlNetId={controlNetId} controlNetId={controlNetId}
weight={weight} weight={weight}
mini mini={!isExpanded}
/> />
<ParamControlNetBeginEnd <ParamControlNetBeginEnd
controlNetId={controlNetId} controlNetId={controlNetId}
beginStepPct={beginStepPct} beginStepPct={beginStepPct}
endStepPct={endStepPct} endStepPct={endStepPct}
mini mini={!isExpanded}
/> />
</Flex> </Flex>
{!isExpanded && ( {!isExpanded && (

View File

@ -1,5 +1,6 @@
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import IAISwitch from 'common/components/IAISwitch'; import IAISwitch from 'common/components/IAISwitch';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import { controlNetAutoConfigToggled } from 'features/controlNet/store/controlNetSlice'; import { controlNetAutoConfigToggled } from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
@ -11,7 +12,7 @@ type Props = {
const ParamControlNetShouldAutoConfig = (props: Props) => { const ParamControlNetShouldAutoConfig = (props: Props) => {
const { controlNetId, shouldAutoConfig } = props; const { controlNetId, shouldAutoConfig } = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const isReady = useIsReadyToInvoke();
const handleShouldAutoConfigChanged = useCallback(() => { const handleShouldAutoConfigChanged = useCallback(() => {
dispatch(controlNetAutoConfigToggled({ controlNetId })); dispatch(controlNetAutoConfigToggled({ controlNetId }));
}, [controlNetId, dispatch]); }, [controlNetId, dispatch]);
@ -22,6 +23,7 @@ const ParamControlNetShouldAutoConfig = (props: Props) => {
aria-label="Auto configure processor" aria-label="Auto configure processor"
isChecked={shouldAutoConfig} isChecked={shouldAutoConfig}
onChange={handleShouldAutoConfigChanged} onChange={handleShouldAutoConfigChanged}
isDisabled={!isReady}
/> />
); );
}; };

View File

@ -1,4 +1,5 @@
import { import {
ChakraProps,
FormControl, FormControl,
FormLabel, FormLabel,
HStack, HStack,
@ -10,14 +11,19 @@ import {
Tooltip, Tooltip,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import { import {
controlNetBeginStepPctChanged, controlNetBeginStepPctChanged,
controlNetEndStepPctChanged, controlNetEndStepPctChanged,
} from 'features/controlNet/store/controlNetSlice'; } from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { BiReset } from 'react-icons/bi';
const SLIDER_MARK_STYLES: ChakraProps['sx'] = {
mt: 1.5,
fontSize: '2xs',
fontWeight: '500',
color: 'base.400',
};
type Props = { type Props = {
controlNetId: string; controlNetId: string;
@ -29,7 +35,7 @@ type Props = {
const formatPct = (v: number) => `${Math.round(v * 100)}%`; const formatPct = (v: number) => `${Math.round(v * 100)}%`;
const ParamControlNetBeginEnd = (props: Props) => { const ParamControlNetBeginEnd = (props: Props) => {
const { controlNetId, beginStepPct, endStepPct, mini = false } = props; const { controlNetId, beginStepPct, mini = false, endStepPct } = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
@ -75,12 +81,9 @@ const ParamControlNetBeginEnd = (props: Props) => {
<RangeSliderMark <RangeSliderMark
value={0} value={0}
sx={{ sx={{
fontSize: 'xs',
fontWeight: '500',
color: 'base.200',
insetInlineStart: '0 !important', insetInlineStart: '0 !important',
insetInlineEnd: 'unset !important', insetInlineEnd: 'unset !important',
mt: 1.5, ...SLIDER_MARK_STYLES,
}} }}
> >
0% 0%
@ -88,10 +91,7 @@ const ParamControlNetBeginEnd = (props: Props) => {
<RangeSliderMark <RangeSliderMark
value={0.5} value={0.5}
sx={{ sx={{
fontSize: 'xs', ...SLIDER_MARK_STYLES,
fontWeight: '500',
color: 'base.200',
mt: 1.5,
}} }}
> >
50% 50%
@ -99,12 +99,9 @@ const ParamControlNetBeginEnd = (props: Props) => {
<RangeSliderMark <RangeSliderMark
value={1} value={1}
sx={{ sx={{
fontSize: 'xs',
fontWeight: '500',
color: 'base.200',
insetInlineStart: 'unset !important', insetInlineStart: 'unset !important',
insetInlineEnd: '0 !important', insetInlineEnd: '0 !important',
mt: 1.5, ...SLIDER_MARK_STYLES,
}} }}
> >
100% 100%
@ -112,16 +109,6 @@ const ParamControlNetBeginEnd = (props: Props) => {
</> </>
)} )}
</RangeSlider> </RangeSlider>
{!mini && (
<IAIIconButton
size="sm"
aria-label={t('accessibility.reset')}
tooltip={t('accessibility.reset')}
icon={<BiReset />}
onClick={handleStepPctReset}
/>
)}
</HStack> </HStack>
</FormControl> </FormControl>
); );

View File

@ -1,41 +1,85 @@
import { useAppDispatch } from 'app/store/storeHooks'; import { createSelector } from '@reduxjs/toolkit';
import IAICustomSelect from 'common/components/IAICustomSelect'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAICustomSelect, {
IAICustomSelectOption,
} from 'common/components/IAICustomSelect';
import IAISelect from 'common/components/IAISelect';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import { import {
CONTROLNET_MODELS, CONTROLNET_MODELS,
ControlNetModel, ControlNetModelName,
} from 'features/controlNet/store/constants'; } from 'features/controlNet/store/constants';
import { controlNetModelChanged } from 'features/controlNet/store/controlNetSlice'; import { controlNetModelChanged } from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react'; import { configSelector } from 'features/system/store/configSelectors';
import { map } from 'lodash-es';
import { ChangeEvent, memo, useCallback } from 'react';
type ParamControlNetModelProps = { type ParamControlNetModelProps = {
controlNetId: string; controlNetId: string;
model: ControlNetModel; model: ControlNetModelName;
}; };
const selector = createSelector(configSelector, (config) => {
return map(CONTROLNET_MODELS, (m) => ({
key: m.label,
value: m.type,
})).filter((d) => !config.sd.disabledControlNetModels.includes(d.value));
});
// const DATA: IAICustomSelectOption[] = map(CONTROLNET_MODELS, (m) => ({
// value: m.type,
// label: m.label,
// tooltip: m.type,
// }));
const ParamControlNetModel = (props: ParamControlNetModelProps) => { const ParamControlNetModel = (props: ParamControlNetModelProps) => {
const { controlNetId, model } = props; const { controlNetId, model } = props;
const controlNetModels = useAppSelector(selector);
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const isReady = useIsReadyToInvoke();
const handleModelChanged = useCallback( const handleModelChanged = useCallback(
(val: string | null | undefined) => { (e: ChangeEvent<HTMLSelectElement>) => {
// TODO: do not cast // TODO: do not cast
const model = val as ControlNetModel; const model = e.target.value as ControlNetModelName;
dispatch(controlNetModelChanged({ controlNetId, model })); dispatch(controlNetModelChanged({ controlNetId, model }));
}, },
[controlNetId, dispatch] [controlNetId, dispatch]
); );
// const handleModelChanged = useCallback(
// (val: string | null | undefined) => {
// // TODO: do not cast
// const model = val as ControlNetModelName;
// dispatch(controlNetModelChanged({ controlNetId, model }));
// },
// [controlNetId, dispatch]
// );
return ( return (
<IAICustomSelect <IAISelect
tooltip={model} tooltip={model}
tooltipProps={{ placement: 'top', hasArrow: true }} tooltipProps={{ placement: 'top', hasArrow: true }}
items={CONTROLNET_MODELS} validValues={controlNetModels}
selectedItem={model} value={model}
setSelectedItem={handleModelChanged} onChange={handleModelChanged}
ellipsisPosition="start" isDisabled={!isReady}
withCheckIcon // ellipsisPosition="start"
// withCheckIcon
/> />
); );
// return (
// <IAICustomSelect
// tooltip={model}
// tooltipProps={{ placement: 'top', hasArrow: true }}
// data={DATA}
// value={model}
// onChange={handleModelChanged}
// isDisabled={!isReady}
// ellipsisPosition="start"
// withCheckIcon
// />
// );
}; };
export default memo(ParamControlNetModel); export default memo(ParamControlNetModel);

View File

@ -1,47 +1,115 @@
import IAICustomSelect from 'common/components/IAICustomSelect'; import IAICustomSelect, {
import { memo, useCallback } from 'react'; IAICustomSelectOption,
} from 'common/components/IAICustomSelect';
import { ChangeEvent, memo, useCallback } from 'react';
import { import {
ControlNetProcessorNode, ControlNetProcessorNode,
ControlNetProcessorType, ControlNetProcessorType,
} from '../../store/types'; } from '../../store/types';
import { controlNetProcessorTypeChanged } from '../../store/controlNetSlice'; import { controlNetProcessorTypeChanged } from '../../store/controlNetSlice';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { CONTROLNET_PROCESSORS } from '../../store/constants'; import { CONTROLNET_PROCESSORS } from '../../store/constants';
import { map } from 'lodash-es';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import IAISelect from 'common/components/IAISelect';
import { createSelector } from '@reduxjs/toolkit';
import { configSelector } from 'features/system/store/configSelectors';
type ParamControlNetProcessorSelectProps = { type ParamControlNetProcessorSelectProps = {
controlNetId: string; controlNetId: string;
processorNode: ControlNetProcessorNode; processorNode: ControlNetProcessorNode;
}; };
const CONTROLNET_PROCESSOR_TYPES = Object.keys( const CONTROLNET_PROCESSOR_TYPES = map(CONTROLNET_PROCESSORS, (p) => ({
CONTROLNET_PROCESSORS value: p.type,
) as ControlNetProcessorType[]; key: p.label,
})).sort((a, b) =>
// sort 'none' to the top
a.value === 'none' ? -1 : b.value === 'none' ? 1 : a.key.localeCompare(b.key)
);
const selector = createSelector(configSelector, (config) => {
return map(CONTROLNET_PROCESSORS, (p) => ({
value: p.type,
key: p.label,
}))
.sort((a, b) =>
// sort 'none' to the top
a.value === 'none'
? -1
: b.value === 'none'
? 1
: a.key.localeCompare(b.key)
)
.filter((d) => !config.sd.disabledControlNetProcessors.includes(d.value));
});
// const CONTROLNET_PROCESSOR_TYPES: IAICustomSelectOption[] = map(
// CONTROLNET_PROCESSORS,
// (p) => ({
// value: p.type,
// label: p.label,
// tooltip: p.description,
// })
// ).sort((a, b) =>
// // sort 'none' to the top
// a.value === 'none'
// ? -1
// : b.value === 'none'
// ? 1
// : a.label.localeCompare(b.label)
// );
const ParamControlNetProcessorSelect = ( const ParamControlNetProcessorSelect = (
props: ParamControlNetProcessorSelectProps props: ParamControlNetProcessorSelectProps
) => { ) => {
const { controlNetId, processorNode } = props; const { controlNetId, processorNode } = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const isReady = useIsReadyToInvoke();
const controlNetProcessors = useAppSelector(selector);
const handleProcessorTypeChanged = useCallback( const handleProcessorTypeChanged = useCallback(
(v: string | null | undefined) => { (e: ChangeEvent<HTMLSelectElement>) => {
dispatch( dispatch(
controlNetProcessorTypeChanged({ controlNetProcessorTypeChanged({
controlNetId, controlNetId,
processorType: v as ControlNetProcessorType, processorType: e.target.value as ControlNetProcessorType,
}) })
); );
}, },
[controlNetId, dispatch] [controlNetId, dispatch]
); );
// const handleProcessorTypeChanged = useCallback(
// (v: string | null | undefined) => {
// dispatch(
// controlNetProcessorTypeChanged({
// controlNetId,
// processorType: v as ControlNetProcessorType,
// })
// );
// },
// [controlNetId, dispatch]
// );
return ( return (
<IAICustomSelect <IAISelect
label="Processor" label="Processor"
items={CONTROLNET_PROCESSOR_TYPES} value={processorNode.type ?? 'canny_image_processor'}
selectedItem={processorNode.type ?? 'canny_image_processor'} validValues={controlNetProcessors}
setSelectedItem={handleProcessorTypeChanged} onChange={handleProcessorTypeChanged}
withCheckIcon isDisabled={!isReady}
/> />
); );
// return (
// <IAICustomSelect
// label="Processor"
// value={processorNode.type ?? 'canny_image_processor'}
// data={CONTROLNET_PROCESSOR_TYPES}
// onChange={handleProcessorTypeChanged}
// withCheckIcon
// isDisabled={!isReady}
// />
// );
}; };
export default memo(ParamControlNetProcessorSelect); export default memo(ParamControlNetProcessorSelect);

View File

@ -20,36 +20,17 @@ const ParamControlNetWeight = (props: ParamControlNetWeightProps) => {
[controlNetId, dispatch] [controlNetId, dispatch]
); );
const handleWeightReset = () => {
dispatch(controlNetWeightChanged({ controlNetId, weight: 1 }));
};
if (mini) {
return (
<IAISlider
label={'Weight'}
sliderFormLabelProps={{ pb: 1 }}
value={weight}
onChange={handleWeightChanged}
min={0}
max={1}
step={0.01}
/>
);
}
return ( return (
<IAISlider <IAISlider
label="Weight" label={'Weight'}
sliderFormLabelProps={{ pb: 2 }}
value={weight} value={weight}
onChange={handleWeightChanged} onChange={handleWeightChanged}
withInput min={-1}
withReset
handleReset={handleWeightReset}
withSliderMarks
min={0}
max={1} max={1}
step={0.01} step={0.01}
withSliderMarks={!mini}
sliderMarks={[-1, 0, 1]}
/> />
); );
}; };

View File

@ -4,6 +4,7 @@ import { RequiredCannyImageProcessorInvocation } from 'features/controlNet/store
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper'; import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.canny_image_processor.default; const DEFAULTS = CONTROLNET_PROCESSORS.canny_image_processor.default;
@ -15,6 +16,7 @@ type CannyProcessorProps = {
const CannyProcessor = (props: CannyProcessorProps) => { const CannyProcessor = (props: CannyProcessorProps) => {
const { controlNetId, processorNode } = props; const { controlNetId, processorNode } = props;
const { low_threshold, high_threshold } = processorNode; const { low_threshold, high_threshold } = processorNode;
const isReady = useIsReadyToInvoke();
const processorChanged = useProcessorNodeChanged(); const processorChanged = useProcessorNodeChanged();
const handleLowThresholdChanged = useCallback( const handleLowThresholdChanged = useCallback(
@ -46,6 +48,7 @@ const CannyProcessor = (props: CannyProcessorProps) => {
return ( return (
<ProcessorWrapper> <ProcessorWrapper>
<IAISlider <IAISlider
isDisabled={!isReady}
label="Low Threshold" label="Low Threshold"
value={low_threshold} value={low_threshold}
onChange={handleLowThresholdChanged} onChange={handleLowThresholdChanged}
@ -54,8 +57,10 @@ const CannyProcessor = (props: CannyProcessorProps) => {
min={0} min={0}
max={255} max={255}
withInput withInput
withSliderMarks
/> />
<IAISlider <IAISlider
isDisabled={!isReady}
label="High Threshold" label="High Threshold"
value={high_threshold} value={high_threshold}
onChange={handleHighThresholdChanged} onChange={handleHighThresholdChanged}
@ -64,6 +69,7 @@ const CannyProcessor = (props: CannyProcessorProps) => {
min={0} min={0}
max={255} max={255}
withInput withInput
withSliderMarks
/> />
</ProcessorWrapper> </ProcessorWrapper>
); );

View File

@ -4,6 +4,7 @@ import { RequiredContentShuffleImageProcessorInvocation } from 'features/control
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper'; import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.content_shuffle_image_processor.default; const DEFAULTS = CONTROLNET_PROCESSORS.content_shuffle_image_processor.default;
@ -16,6 +17,7 @@ const ContentShuffleProcessor = (props: Props) => {
const { controlNetId, processorNode } = props; const { controlNetId, processorNode } = props;
const { image_resolution, detect_resolution, w, h, f } = processorNode; const { image_resolution, detect_resolution, w, h, f } = processorNode;
const processorChanged = useProcessorNodeChanged(); const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke();
const handleDetectResolutionChanged = useCallback( const handleDetectResolutionChanged = useCallback(
(v: number) => { (v: number) => {
@ -93,6 +95,8 @@ const ContentShuffleProcessor = (props: Props) => {
min={0} min={0}
max={4096} max={4096}
withInput withInput
withSliderMarks
isDisabled={!isReady}
/> />
<IAISlider <IAISlider
label="Image Resolution" label="Image Resolution"
@ -103,6 +107,8 @@ const ContentShuffleProcessor = (props: Props) => {
min={0} min={0}
max={4096} max={4096}
withInput withInput
withSliderMarks
isDisabled={!isReady}
/> />
<IAISlider <IAISlider
label="W" label="W"
@ -113,6 +119,8 @@ const ContentShuffleProcessor = (props: Props) => {
min={0} min={0}
max={4096} max={4096}
withInput withInput
withSliderMarks
isDisabled={!isReady}
/> />
<IAISlider <IAISlider
label="H" label="H"
@ -123,6 +131,8 @@ const ContentShuffleProcessor = (props: Props) => {
min={0} min={0}
max={4096} max={4096}
withInput withInput
withSliderMarks
isDisabled={!isReady}
/> />
<IAISlider <IAISlider
label="F" label="F"
@ -133,6 +143,8 @@ const ContentShuffleProcessor = (props: Props) => {
min={0} min={0}
max={4096} max={4096}
withInput withInput
withSliderMarks
isDisabled={!isReady}
/> />
</ProcessorWrapper> </ProcessorWrapper>
); );

View File

@ -5,6 +5,7 @@ import { RequiredHedImageProcessorInvocation } from 'features/controlNet/store/t
import { ChangeEvent, memo, useCallback } from 'react'; import { ChangeEvent, memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper'; import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.hed_image_processor.default; const DEFAULTS = CONTROLNET_PROCESSORS.hed_image_processor.default;
@ -18,7 +19,7 @@ const HedPreprocessor = (props: HedProcessorProps) => {
controlNetId, controlNetId,
processorNode: { detect_resolution, image_resolution, scribble }, processorNode: { detect_resolution, image_resolution, scribble },
} = props; } = props;
const isReady = useIsReadyToInvoke();
const processorChanged = useProcessorNodeChanged(); const processorChanged = useProcessorNodeChanged();
const handleDetectResolutionChanged = useCallback( const handleDetectResolutionChanged = useCallback(
@ -65,6 +66,8 @@ const HedPreprocessor = (props: HedProcessorProps) => {
min={0} min={0}
max={4096} max={4096}
withInput withInput
withSliderMarks
isDisabled={!isReady}
/> />
<IAISlider <IAISlider
label="Image Resolution" label="Image Resolution"
@ -75,11 +78,14 @@ const HedPreprocessor = (props: HedProcessorProps) => {
min={0} min={0}
max={4096} max={4096}
withInput withInput
withSliderMarks
isDisabled={!isReady}
/> />
<IAISwitch <IAISwitch
label="Scribble" label="Scribble"
isChecked={scribble} isChecked={scribble}
onChange={handleScribbleChanged} onChange={handleScribbleChanged}
isDisabled={!isReady}
/> />
</ProcessorWrapper> </ProcessorWrapper>
); );

View File

@ -4,6 +4,7 @@ import { RequiredLineartAnimeImageProcessorInvocation } from 'features/controlNe
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper'; import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.lineart_anime_image_processor.default; const DEFAULTS = CONTROLNET_PROCESSORS.lineart_anime_image_processor.default;
@ -16,6 +17,7 @@ const LineartAnimeProcessor = (props: Props) => {
const { controlNetId, processorNode } = props; const { controlNetId, processorNode } = props;
const { image_resolution, detect_resolution } = processorNode; const { image_resolution, detect_resolution } = processorNode;
const processorChanged = useProcessorNodeChanged(); const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke();
const handleDetectResolutionChanged = useCallback( const handleDetectResolutionChanged = useCallback(
(v: number) => { (v: number) => {
@ -54,6 +56,8 @@ const LineartAnimeProcessor = (props: Props) => {
min={0} min={0}
max={4096} max={4096}
withInput withInput
withSliderMarks
isDisabled={!isReady}
/> />
<IAISlider <IAISlider
label="Image Resolution" label="Image Resolution"
@ -64,6 +68,8 @@ const LineartAnimeProcessor = (props: Props) => {
min={0} min={0}
max={4096} max={4096}
withInput withInput
withSliderMarks
isDisabled={!isReady}
/> />
</ProcessorWrapper> </ProcessorWrapper>
); );

View File

@ -5,6 +5,7 @@ import { RequiredLineartImageProcessorInvocation } from 'features/controlNet/sto
import { ChangeEvent, memo, useCallback } from 'react'; import { ChangeEvent, memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper'; import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.lineart_image_processor.default; const DEFAULTS = CONTROLNET_PROCESSORS.lineart_image_processor.default;
@ -17,6 +18,7 @@ const LineartProcessor = (props: LineartProcessorProps) => {
const { controlNetId, processorNode } = props; const { controlNetId, processorNode } = props;
const { image_resolution, detect_resolution, coarse } = processorNode; const { image_resolution, detect_resolution, coarse } = processorNode;
const processorChanged = useProcessorNodeChanged(); const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke();
const handleDetectResolutionChanged = useCallback( const handleDetectResolutionChanged = useCallback(
(v: number) => { (v: number) => {
@ -62,6 +64,8 @@ const LineartProcessor = (props: LineartProcessorProps) => {
min={0} min={0}
max={4096} max={4096}
withInput withInput
withSliderMarks
isDisabled={!isReady}
/> />
<IAISlider <IAISlider
label="Image Resolution" label="Image Resolution"
@ -72,11 +76,14 @@ const LineartProcessor = (props: LineartProcessorProps) => {
min={0} min={0}
max={4096} max={4096}
withInput withInput
withSliderMarks
isDisabled={!isReady}
/> />
<IAISwitch <IAISwitch
label="Coarse" label="Coarse"
isChecked={coarse} isChecked={coarse}
onChange={handleCoarseChanged} onChange={handleCoarseChanged}
isDisabled={!isReady}
/> />
</ProcessorWrapper> </ProcessorWrapper>
); );

View File

@ -4,6 +4,7 @@ import { RequiredMediapipeFaceProcessorInvocation } from 'features/controlNet/st
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper'; import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.mediapipe_face_processor.default; const DEFAULTS = CONTROLNET_PROCESSORS.mediapipe_face_processor.default;
@ -16,6 +17,7 @@ const MediapipeFaceProcessor = (props: Props) => {
const { controlNetId, processorNode } = props; const { controlNetId, processorNode } = props;
const { max_faces, min_confidence } = processorNode; const { max_faces, min_confidence } = processorNode;
const processorChanged = useProcessorNodeChanged(); const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke();
const handleMaxFacesChanged = useCallback( const handleMaxFacesChanged = useCallback(
(v: number) => { (v: number) => {
@ -50,6 +52,8 @@ const MediapipeFaceProcessor = (props: Props) => {
min={1} min={1}
max={20} max={20}
withInput withInput
withSliderMarks
isDisabled={!isReady}
/> />
<IAISlider <IAISlider
label="Min Confidence" label="Min Confidence"
@ -61,6 +65,8 @@ const MediapipeFaceProcessor = (props: Props) => {
max={1} max={1}
step={0.01} step={0.01}
withInput withInput
withSliderMarks
isDisabled={!isReady}
/> />
</ProcessorWrapper> </ProcessorWrapper>
); );

View File

@ -4,6 +4,7 @@ import { RequiredMidasDepthImageProcessorInvocation } from 'features/controlNet/
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper'; import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.midas_depth_image_processor.default; const DEFAULTS = CONTROLNET_PROCESSORS.midas_depth_image_processor.default;
@ -16,6 +17,7 @@ const MidasDepthProcessor = (props: Props) => {
const { controlNetId, processorNode } = props; const { controlNetId, processorNode } = props;
const { a_mult, bg_th } = processorNode; const { a_mult, bg_th } = processorNode;
const processorChanged = useProcessorNodeChanged(); const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke();
const handleAMultChanged = useCallback( const handleAMultChanged = useCallback(
(v: number) => { (v: number) => {
@ -51,6 +53,8 @@ const MidasDepthProcessor = (props: Props) => {
max={20} max={20}
step={0.01} step={0.01}
withInput withInput
withSliderMarks
isDisabled={!isReady}
/> />
<IAISlider <IAISlider
label="bg_th" label="bg_th"
@ -62,6 +66,8 @@ const MidasDepthProcessor = (props: Props) => {
max={20} max={20}
step={0.01} step={0.01}
withInput withInput
withSliderMarks
isDisabled={!isReady}
/> />
</ProcessorWrapper> </ProcessorWrapper>
); );

View File

@ -4,6 +4,7 @@ import { RequiredMlsdImageProcessorInvocation } from 'features/controlNet/store/
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper'; import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.mlsd_image_processor.default; const DEFAULTS = CONTROLNET_PROCESSORS.mlsd_image_processor.default;
@ -16,6 +17,7 @@ const MlsdImageProcessor = (props: Props) => {
const { controlNetId, processorNode } = props; const { controlNetId, processorNode } = props;
const { image_resolution, detect_resolution, thr_d, thr_v } = processorNode; const { image_resolution, detect_resolution, thr_d, thr_v } = processorNode;
const processorChanged = useProcessorNodeChanged(); const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke();
const handleDetectResolutionChanged = useCallback( const handleDetectResolutionChanged = useCallback(
(v: number) => { (v: number) => {
@ -76,6 +78,8 @@ const MlsdImageProcessor = (props: Props) => {
min={0} min={0}
max={4096} max={4096}
withInput withInput
withSliderMarks
isDisabled={!isReady}
/> />
<IAISlider <IAISlider
label="Image Resolution" label="Image Resolution"
@ -86,6 +90,8 @@ const MlsdImageProcessor = (props: Props) => {
min={0} min={0}
max={4096} max={4096}
withInput withInput
withSliderMarks
isDisabled={!isReady}
/> />
<IAISlider <IAISlider
label="W" label="W"
@ -97,6 +103,8 @@ const MlsdImageProcessor = (props: Props) => {
max={1} max={1}
step={0.01} step={0.01}
withInput withInput
withSliderMarks
isDisabled={!isReady}
/> />
<IAISlider <IAISlider
label="H" label="H"
@ -108,6 +116,8 @@ const MlsdImageProcessor = (props: Props) => {
max={1} max={1}
step={0.01} step={0.01}
withInput withInput
withSliderMarks
isDisabled={!isReady}
/> />
</ProcessorWrapper> </ProcessorWrapper>
); );

View File

@ -4,6 +4,7 @@ import { RequiredNormalbaeImageProcessorInvocation } from 'features/controlNet/s
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper'; import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.normalbae_image_processor.default; const DEFAULTS = CONTROLNET_PROCESSORS.normalbae_image_processor.default;
@ -16,6 +17,7 @@ const NormalBaeProcessor = (props: Props) => {
const { controlNetId, processorNode } = props; const { controlNetId, processorNode } = props;
const { image_resolution, detect_resolution } = processorNode; const { image_resolution, detect_resolution } = processorNode;
const processorChanged = useProcessorNodeChanged(); const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke();
const handleDetectResolutionChanged = useCallback( const handleDetectResolutionChanged = useCallback(
(v: number) => { (v: number) => {
@ -54,6 +56,8 @@ const NormalBaeProcessor = (props: Props) => {
min={0} min={0}
max={4096} max={4096}
withInput withInput
withSliderMarks
isDisabled={!isReady}
/> />
<IAISlider <IAISlider
label="Image Resolution" label="Image Resolution"
@ -64,6 +68,8 @@ const NormalBaeProcessor = (props: Props) => {
min={0} min={0}
max={4096} max={4096}
withInput withInput
withSliderMarks
isDisabled={!isReady}
/> />
</ProcessorWrapper> </ProcessorWrapper>
); );

View File

@ -5,6 +5,7 @@ import { RequiredOpenposeImageProcessorInvocation } from 'features/controlNet/st
import { ChangeEvent, memo, useCallback } from 'react'; import { ChangeEvent, memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper'; import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.openpose_image_processor.default; const DEFAULTS = CONTROLNET_PROCESSORS.openpose_image_processor.default;
@ -17,6 +18,7 @@ const OpenposeProcessor = (props: Props) => {
const { controlNetId, processorNode } = props; const { controlNetId, processorNode } = props;
const { image_resolution, detect_resolution, hand_and_face } = processorNode; const { image_resolution, detect_resolution, hand_and_face } = processorNode;
const processorChanged = useProcessorNodeChanged(); const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke();
const handleDetectResolutionChanged = useCallback( const handleDetectResolutionChanged = useCallback(
(v: number) => { (v: number) => {
@ -62,6 +64,8 @@ const OpenposeProcessor = (props: Props) => {
min={0} min={0}
max={4096} max={4096}
withInput withInput
withSliderMarks
isDisabled={!isReady}
/> />
<IAISlider <IAISlider
label="Image Resolution" label="Image Resolution"
@ -72,11 +76,14 @@ const OpenposeProcessor = (props: Props) => {
min={0} min={0}
max={4096} max={4096}
withInput withInput
withSliderMarks
isDisabled={!isReady}
/> />
<IAISwitch <IAISwitch
label="Hand and Face" label="Hand and Face"
isChecked={hand_and_face} isChecked={hand_and_face}
onChange={handleHandAndFaceChanged} onChange={handleHandAndFaceChanged}
isDisabled={!isReady}
/> />
</ProcessorWrapper> </ProcessorWrapper>
); );

View File

@ -5,6 +5,7 @@ import { RequiredPidiImageProcessorInvocation } from 'features/controlNet/store/
import { ChangeEvent, memo, useCallback } from 'react'; import { ChangeEvent, memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper'; import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.pidi_image_processor.default; const DEFAULTS = CONTROLNET_PROCESSORS.pidi_image_processor.default;
@ -17,6 +18,7 @@ const PidiProcessor = (props: Props) => {
const { controlNetId, processorNode } = props; const { controlNetId, processorNode } = props;
const { image_resolution, detect_resolution, scribble, safe } = processorNode; const { image_resolution, detect_resolution, scribble, safe } = processorNode;
const processorChanged = useProcessorNodeChanged(); const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke();
const handleDetectResolutionChanged = useCallback( const handleDetectResolutionChanged = useCallback(
(v: number) => { (v: number) => {
@ -69,6 +71,8 @@ const PidiProcessor = (props: Props) => {
min={0} min={0}
max={4096} max={4096}
withInput withInput
withSliderMarks
isDisabled={!isReady}
/> />
<IAISlider <IAISlider
label="Image Resolution" label="Image Resolution"
@ -79,13 +83,20 @@ const PidiProcessor = (props: Props) => {
min={0} min={0}
max={4096} max={4096}
withInput withInput
withSliderMarks
isDisabled={!isReady}
/> />
<IAISwitch <IAISwitch
label="Scribble" label="Scribble"
isChecked={scribble} isChecked={scribble}
onChange={handleScribbleChanged} onChange={handleScribbleChanged}
/> />
<IAISwitch label="Safe" isChecked={safe} onChange={handleSafeChanged} /> <IAISwitch
label="Safe"
isChecked={safe}
onChange={handleSafeChanged}
isDisabled={!isReady}
/>
</ProcessorWrapper> </ProcessorWrapper>
); );
}; };

View File

@ -5,12 +5,12 @@ import {
} from './types'; } from './types';
type ControlNetProcessorsDict = Record< type ControlNetProcessorsDict = Record<
ControlNetProcessorType, string,
{ {
type: ControlNetProcessorType; type: ControlNetProcessorType | 'none';
label: string; label: string;
description: string; description: string;
default: RequiredControlNetProcessorNode; default: RequiredControlNetProcessorNode | { type: 'none' };
} }
>; >;
@ -26,7 +26,7 @@ type ControlNetProcessorsDict = Record<
export const CONTROLNET_PROCESSORS = { export const CONTROLNET_PROCESSORS = {
none: { none: {
type: 'none', type: 'none',
label: 'None', label: 'none',
description: '', description: '',
default: { default: {
type: 'none', type: 'none',
@ -116,7 +116,7 @@ export const CONTROLNET_PROCESSORS = {
}, },
mlsd_image_processor: { mlsd_image_processor: {
type: 'mlsd_image_processor', type: 'mlsd_image_processor',
label: 'MLSD', label: 'M-LSD',
description: '', description: '',
default: { default: {
id: 'mlsd_image_processor', id: 'mlsd_image_processor',
@ -129,7 +129,7 @@ export const CONTROLNET_PROCESSORS = {
}, },
normalbae_image_processor: { normalbae_image_processor: {
type: 'normalbae_image_processor', type: 'normalbae_image_processor',
label: 'NormalBae', label: 'Normal BAE',
description: '', description: '',
default: { default: {
id: 'normalbae_image_processor', id: 'normalbae_image_processor',
@ -174,39 +174,84 @@ export const CONTROLNET_PROCESSORS = {
}, },
}; };
export const CONTROLNET_MODELS = [ type ControlNetModel = {
'lllyasviel/control_v11p_sd15_canny', type: string;
'lllyasviel/control_v11p_sd15_inpaint', label: string;
'lllyasviel/control_v11p_sd15_mlsd', description?: string;
'lllyasviel/control_v11f1p_sd15_depth', defaultProcessor?: ControlNetProcessorType;
'lllyasviel/control_v11p_sd15_normalbae',
'lllyasviel/control_v11p_sd15_seg',
'lllyasviel/control_v11p_sd15_lineart',
'lllyasviel/control_v11p_sd15s2_lineart_anime',
'lllyasviel/control_v11p_sd15_scribble',
'lllyasviel/control_v11p_sd15_softedge',
'lllyasviel/control_v11e_sd15_shuffle',
'lllyasviel/control_v11p_sd15_openpose',
'lllyasviel/control_v11f1e_sd15_tile',
'lllyasviel/control_v11e_sd15_ip2p',
'CrucibleAI/ControlNetMediaPipeFace',
];
export type ControlNetModel = (typeof CONTROLNET_MODELS)[number];
export const CONTROLNET_MODEL_MAP: Record<
ControlNetModel,
ControlNetProcessorType
> = {
'lllyasviel/control_v11p_sd15_canny': 'canny_image_processor',
'lllyasviel/control_v11p_sd15_mlsd': 'mlsd_image_processor',
'lllyasviel/control_v11f1p_sd15_depth': 'midas_depth_image_processor',
'lllyasviel/control_v11p_sd15_normalbae': 'normalbae_image_processor',
'lllyasviel/control_v11p_sd15_lineart': 'lineart_image_processor',
'lllyasviel/control_v11p_sd15s2_lineart_anime':
'lineart_anime_image_processor',
'lllyasviel/control_v11p_sd15_softedge': 'hed_image_processor',
'lllyasviel/control_v11e_sd15_shuffle': 'content_shuffle_image_processor',
'lllyasviel/control_v11p_sd15_openpose': 'openpose_image_processor',
'CrucibleAI/ControlNetMediaPipeFace': 'mediapipe_face_processor',
}; };
export const CONTROLNET_MODELS = {
'lllyasviel/control_v11p_sd15_canny': {
type: 'lllyasviel/control_v11p_sd15_canny',
label: 'Canny',
defaultProcessor: 'canny_image_processor',
},
'lllyasviel/control_v11p_sd15_inpaint': {
type: 'lllyasviel/control_v11p_sd15_inpaint',
label: 'Inpaint',
},
'lllyasviel/control_v11p_sd15_mlsd': {
type: 'lllyasviel/control_v11p_sd15_mlsd',
label: 'M-LSD',
defaultProcessor: 'mlsd_image_processor',
},
'lllyasviel/control_v11f1p_sd15_depth': {
type: 'lllyasviel/control_v11f1p_sd15_depth',
label: 'Depth',
defaultProcessor: 'midas_depth_image_processor',
},
'lllyasviel/control_v11p_sd15_normalbae': {
type: 'lllyasviel/control_v11p_sd15_normalbae',
label: 'Normal Map (BAE)',
defaultProcessor: 'normalbae_image_processor',
},
'lllyasviel/control_v11p_sd15_seg': {
type: 'lllyasviel/control_v11p_sd15_seg',
label: 'Segmentation',
},
'lllyasviel/control_v11p_sd15_lineart': {
type: 'lllyasviel/control_v11p_sd15_lineart',
label: 'Lineart',
defaultProcessor: 'lineart_image_processor',
},
'lllyasviel/control_v11p_sd15s2_lineart_anime': {
type: 'lllyasviel/control_v11p_sd15s2_lineart_anime',
label: 'Lineart Anime',
defaultProcessor: 'lineart_anime_image_processor',
},
'lllyasviel/control_v11p_sd15_scribble': {
type: 'lllyasviel/control_v11p_sd15_scribble',
label: 'Scribble',
},
'lllyasviel/control_v11p_sd15_softedge': {
type: 'lllyasviel/control_v11p_sd15_softedge',
label: 'Soft Edge',
defaultProcessor: 'hed_image_processor',
},
'lllyasviel/control_v11e_sd15_shuffle': {
type: 'lllyasviel/control_v11e_sd15_shuffle',
label: 'Content Shuffle',
defaultProcessor: 'content_shuffle_image_processor',
},
'lllyasviel/control_v11p_sd15_openpose': {
type: 'lllyasviel/control_v11p_sd15_openpose',
label: 'Openpose',
defaultProcessor: 'openpose_image_processor',
},
'lllyasviel/control_v11f1e_sd15_tile': {
type: 'lllyasviel/control_v11f1e_sd15_tile',
label: 'Tile (experimental)',
},
'lllyasviel/control_v11e_sd15_ip2p': {
type: 'lllyasviel/control_v11e_sd15_ip2p',
label: 'Pix2Pix (experimental)',
},
'CrucibleAI/ControlNetMediaPipeFace': {
type: 'CrucibleAI/ControlNetMediaPipeFace',
label: 'Mediapipe Face',
defaultProcessor: 'mediapipe_face_processor',
},
};
export type ControlNetModelName = keyof typeof CONTROLNET_MODELS;

View File

@ -9,9 +9,8 @@ import {
} from './types'; } from './types';
import { import {
CONTROLNET_MODELS, CONTROLNET_MODELS,
CONTROLNET_MODEL_MAP,
CONTROLNET_PROCESSORS, CONTROLNET_PROCESSORS,
ControlNetModel, ControlNetModelName,
} from './constants'; } from './constants';
import { controlNetImageProcessed } from './actions'; import { controlNetImageProcessed } from './actions';
import { imageDeleted, imageUrlsReceived } from 'services/thunks/image'; import { imageDeleted, imageUrlsReceived } from 'services/thunks/image';
@ -21,7 +20,7 @@ import { appSocketInvocationError } from 'services/events/actions';
export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = { export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
isEnabled: true, isEnabled: true,
model: CONTROLNET_MODELS[0], model: CONTROLNET_MODELS['lllyasviel/control_v11p_sd15_canny'].type,
weight: 1, weight: 1,
beginStepPct: 0, beginStepPct: 0,
endStepPct: 1, endStepPct: 1,
@ -36,7 +35,7 @@ export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
export type ControlNetConfig = { export type ControlNetConfig = {
controlNetId: string; controlNetId: string;
isEnabled: boolean; isEnabled: boolean;
model: ControlNetModel; model: ControlNetModelName;
weight: number; weight: number;
beginStepPct: number; beginStepPct: number;
endStepPct: number; endStepPct: number;
@ -138,14 +137,17 @@ export const controlNetSlice = createSlice({
}, },
controlNetModelChanged: ( controlNetModelChanged: (
state, state,
action: PayloadAction<{ controlNetId: string; model: ControlNetModel }> action: PayloadAction<{
controlNetId: string;
model: ControlNetModelName;
}>
) => { ) => {
const { controlNetId, model } = action.payload; const { controlNetId, model } = action.payload;
state.controlNets[controlNetId].model = model; state.controlNets[controlNetId].model = model;
state.controlNets[controlNetId].processedControlImage = null; state.controlNets[controlNetId].processedControlImage = null;
if (state.controlNets[controlNetId].shouldAutoConfig) { if (state.controlNets[controlNetId].shouldAutoConfig) {
const processorType = CONTROLNET_MODEL_MAP[model]; const processorType = CONTROLNET_MODELS[model].defaultProcessor;
if (processorType) { if (processorType) {
state.controlNets[controlNetId].processorType = processorType; state.controlNets[controlNetId].processorType = processorType;
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[ state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[
@ -225,7 +227,8 @@ export const controlNetSlice = createSlice({
if (newShouldAutoConfig) { if (newShouldAutoConfig) {
// manage the processor for the user // manage the processor for the user
const processorType = const processorType =
CONTROLNET_MODEL_MAP[state.controlNets[controlNetId].model]; CONTROLNET_MODELS[state.controlNets[controlNetId].model]
.defaultProcessor;
if (processorType) { if (processorType) {
state.controlNets[controlNetId].processorType = processorType; state.controlNets[controlNetId].processorType = processorType;
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[ state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[

View File

@ -18,6 +18,8 @@ export const FIELD_TYPE_MAP: Record<string, FieldType> = {
ColorField: 'color', ColorField: 'color',
ControlField: 'control', ControlField: 'control',
control: 'control', control: 'control',
cfg_scale: 'float',
control_weight: 'float',
}; };
const COLOR_TOKEN_VALUE = 500; const COLOR_TOKEN_VALUE = 500;

View File

@ -1,5 +1,5 @@
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { forEach, size } from 'lodash-es'; import { filter, forEach, size } from 'lodash-es';
import { CollectInvocation, ControlNetInvocation } from 'services/api'; import { CollectInvocation, ControlNetInvocation } from 'services/api';
import { NonNullableGraph } from '../types/types'; import { NonNullableGraph } from '../types/types';
@ -12,8 +12,16 @@ export const addControlNetToLinearGraph = (
): void => { ): void => {
const { isEnabled: isControlNetEnabled, controlNets } = state.controlNet; const { isEnabled: isControlNetEnabled, controlNets } = state.controlNet;
const validControlNets = filter(
controlNets,
(c) =>
c.isEnabled &&
(Boolean(c.processedControlImage) ||
(c.processorType === 'none' && Boolean(c.controlImage)))
);
// Add ControlNet // Add ControlNet
if (isControlNetEnabled) { if (isControlNetEnabled && validControlNets.length > 0) {
if (size(controlNets) > 1) { if (size(controlNets) > 1) {
const controlNetIterateNode: CollectInvocation = { const controlNetIterateNode: CollectInvocation = {
id: CONTROL_NET_COLLECT, id: CONTROL_NET_COLLECT,

View File

@ -3,10 +3,11 @@ import { Scheduler } from 'app/constants';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAICustomSelect from 'common/components/IAICustomSelect'; import IAICustomSelect from 'common/components/IAICustomSelect';
import IAISelect from 'common/components/IAISelect';
import { generationSelector } from 'features/parameters/store/generationSelectors'; import { generationSelector } from 'features/parameters/store/generationSelectors';
import { setScheduler } from 'features/parameters/store/generationSlice'; import { setScheduler } from 'features/parameters/store/generationSlice';
import { uiSelector } from 'features/ui/store/uiSelectors'; import { uiSelector } from 'features/ui/store/uiSelectors';
import { memo, useCallback } from 'react'; import { ChangeEvent, memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
const selector = createSelector( const selector = createSelector(
@ -14,9 +15,11 @@ const selector = createSelector(
(ui, generation) => { (ui, generation) => {
// TODO: DPMSolverSinglestepScheduler is fixed in https://github.com/huggingface/diffusers/pull/3413 // TODO: DPMSolverSinglestepScheduler is fixed in https://github.com/huggingface/diffusers/pull/3413
// but we need to wait for the next release before removing this special handling. // but we need to wait for the next release before removing this special handling.
const allSchedulers = ui.schedulers.filter((scheduler) => { const allSchedulers = ui.schedulers
return !['dpmpp_2s'].includes(scheduler); .filter((scheduler) => {
}); return !['dpmpp_2s'].includes(scheduler);
})
.sort((a, b) => a.localeCompare(b));
return { return {
scheduler: generation.scheduler, scheduler: generation.scheduler,
@ -33,24 +36,39 @@ const ParamScheduler = () => {
const { t } = useTranslation(); const { t } = useTranslation();
const handleChange = useCallback( const handleChange = useCallback(
(v: string | null | undefined) => { (e: ChangeEvent<HTMLSelectElement>) => {
if (!v) { dispatch(setScheduler(e.target.value as Scheduler));
return;
}
dispatch(setScheduler(v as Scheduler));
}, },
[dispatch] [dispatch]
); );
// const handleChange = useCallback(
// (v: string | null | undefined) => {
// if (!v) {
// return;
// }
// dispatch(setScheduler(v as Scheduler));
// },
// [dispatch]
// );
return ( return (
<IAICustomSelect <IAISelect
label={t('parameters.scheduler')} label={t('parameters.scheduler')}
selectedItem={scheduler} value={scheduler}
setSelectedItem={handleChange} validValues={allSchedulers}
items={allSchedulers} onChange={handleChange}
withCheckIcon
/> />
); );
// return (
// <IAICustomSelect
// label={t('parameters.scheduler')}
// value={scheduler}
// data={allSchedulers}
// onChange={handleChange}
// withCheckIcon
// />
// );
}; };
export default memo(ParamScheduler); export default memo(ParamScheduler);

View File

@ -1,41 +0,0 @@
// import { emptyTempFolder } from 'app/socketio/actions';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIAlertDialog from 'common/components/IAIAlertDialog';
import IAIButton from 'common/components/IAIButton';
import { isStagingSelector } from 'features/canvas/store/canvasSelectors';
import {
clearCanvasHistory,
resetCanvas,
} from 'features/canvas/store/canvasSlice';
import { useTranslation } from 'react-i18next';
import { FaTrash } from 'react-icons/fa';
const EmptyTempFolderButtonModal = () => {
const isStaging = useAppSelector(isStagingSelector);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const acceptCallback = () => {
dispatch(emptyTempFolder());
dispatch(resetCanvas());
dispatch(clearCanvasHistory());
};
return (
<IAIAlertDialog
title={t('unifiedCanvas.emptyTempImageFolder')}
acceptCallback={acceptCallback}
acceptButtonText={t('unifiedCanvas.emptyFolder')}
triggerComponent={
<IAIButton leftIcon={<FaTrash />} size="sm" isDisabled={isStaging}>
{t('unifiedCanvas.emptyTempImageFolder')}
</IAIButton>
}
>
<p>{t('unifiedCanvas.emptyTempImagesFolderMessage')}</p>
<br />
<p>{t('unifiedCanvas.emptyTempImagesFolderConfirm')}</p>
</IAIAlertDialog>
);
};
export default EmptyTempFolderButtonModal;

View File

@ -1,37 +1,39 @@
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { memo, useCallback } from 'react'; import { ChangeEvent, memo, useCallback } from 'react';
import { isEqual } from 'lodash-es'; import { isEqual } from 'lodash-es';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { import { selectModelsAll, selectModelsById } from '../store/modelSlice';
selectModelsAll,
selectModelsById,
selectModelsIds,
} from '../store/modelSlice';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { modelSelected } from 'features/parameters/store/generationSlice'; import { modelSelected } from 'features/parameters/store/generationSlice';
import { generationSelector } from 'features/parameters/store/generationSelectors'; import { generationSelector } from 'features/parameters/store/generationSelectors';
import IAICustomSelect, { import IAICustomSelect, {
ItemTooltips, IAICustomSelectOption,
} from 'common/components/IAICustomSelect'; } from 'common/components/IAICustomSelect';
import IAISelect from 'common/components/IAISelect';
const selector = createSelector( const selector = createSelector(
[(state: RootState) => state, generationSelector], [(state: RootState) => state, generationSelector],
(state, generation) => { (state, generation) => {
const selectedModel = selectModelsById(state, generation.model); const selectedModel = selectModelsById(state, generation.model);
const allModelNames = selectModelsIds(state).map((id) => String(id));
const allModelTooltips = selectModelsAll(state).reduce( const modelData = selectModelsAll(state)
(allModelTooltips, model) => { .map((m) => ({
allModelTooltips[model.name] = model.description ?? ''; value: m.name,
return allModelTooltips; key: m.name,
}, }))
{} as ItemTooltips .sort((a, b) => a.key.localeCompare(b.key));
); // const modelData = selectModelsAll(state)
// .map<IAICustomSelectOption>((m) => ({
// value: m.name,
// label: m.name,
// tooltip: m.description,
// }))
// .sort((a, b) => a.label.localeCompare(b.label));
return { return {
allModelNames,
allModelTooltips,
selectedModel, selectedModel,
modelData,
}; };
}, },
{ {
@ -44,30 +46,45 @@ const selector = createSelector(
const ModelSelect = () => { const ModelSelect = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const { allModelNames, allModelTooltips, selectedModel } = const { selectedModel, modelData } = useAppSelector(selector);
useAppSelector(selector);
const handleChangeModel = useCallback( const handleChangeModel = useCallback(
(v: string | null | undefined) => { (e: ChangeEvent<HTMLSelectElement>) => {
if (!v) { dispatch(modelSelected(e.target.value));
return;
}
dispatch(modelSelected(v));
}, },
[dispatch] [dispatch]
); );
// const handleChangeModel = useCallback(
// (v: string | null | undefined) => {
// if (!v) {
// return;
// }
// dispatch(modelSelected(v));
// },
// [dispatch]
// );
return ( return (
<IAICustomSelect <IAISelect
label={t('modelManager.model')} label={t('modelManager.model')}
tooltip={selectedModel?.description} tooltip={selectedModel?.description}
items={allModelNames} validValues={modelData}
itemTooltips={allModelTooltips} value={selectedModel?.name ?? ''}
selectedItem={selectedModel?.name ?? ''} onChange={handleChangeModel}
setSelectedItem={handleChangeModel}
withCheckIcon={true}
tooltipProps={{ placement: 'top', hasArrow: true }} tooltipProps={{ placement: 'top', hasArrow: true }}
/> />
); );
// return (
// <IAICustomSelect
// label={t('modelManager.model')}
// tooltip={selectedModel?.description}
// data={modelData}
// value={selectedModel?.name ?? ''}
// onChange={handleChangeModel}
// withCheckIcon={true}
// tooltipProps={{ placement: 'top', hasArrow: true }}
// />
// );
}; };
export default memo(ModelSelect); export default memo(ModelSelect);

View File

@ -10,6 +10,8 @@ export const initialConfigState: AppConfig = {
disabledSDFeatures: [], disabledSDFeatures: [],
canRestoreDeletedImagesFromBin: true, canRestoreDeletedImagesFromBin: true,
sd: { sd: {
disabledControlNetModels: [],
disabledControlNetProcessors: [],
iterations: { iterations: {
initial: 1, initial: 1,
min: 1, min: 1,

View File

@ -47,3 +47,6 @@ export const languageSelector = createSelector(
(system) => system.language, (system) => system.language,
defaultSelectorOptions defaultSelectorOptions
); );
export const isProcessingSelector = (state: RootState) =>
state.system.isProcessing;

View File

@ -14,7 +14,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { setIsLightboxOpen } from 'features/lightbox/store/lightboxSlice'; import { setIsLightboxOpen } from 'features/lightbox/store/lightboxSlice';
import { InvokeTabName } from 'features/ui/store/tabMap'; import { InvokeTabName } from 'features/ui/store/tabMap';
import { setActiveTab, togglePanels } from 'features/ui/store/uiSlice'; import { setActiveTab, togglePanels } from 'features/ui/store/uiSlice';
import { memo, ReactNode, useCallback, useMemo } from 'react'; import { memo, MouseEvent, ReactNode, useCallback, useMemo } from 'react';
import { useHotkeys } from 'react-hotkeys-hook'; import { useHotkeys } from 'react-hotkeys-hook';
import { MdDeviceHub, MdGridOn } from 'react-icons/md'; import { MdDeviceHub, MdGridOn } from 'react-icons/md';
import { GoTextSize } from 'react-icons/go'; import { GoTextSize } from 'react-icons/go';
@ -47,22 +47,22 @@ export interface InvokeTabInfo {
const tabs: InvokeTabInfo[] = [ const tabs: InvokeTabInfo[] = [
{ {
id: 'txt2img', id: 'txt2img',
icon: <Icon as={GoTextSize} sx={{ boxSize: 6 }} />, icon: <Icon as={GoTextSize} sx={{ boxSize: 6, pointerEvents: 'none' }} />,
content: <TextToImageTab />, content: <TextToImageTab />,
}, },
{ {
id: 'img2img', id: 'img2img',
icon: <Icon as={FaImage} sx={{ boxSize: 6 }} />, icon: <Icon as={FaImage} sx={{ boxSize: 6, pointerEvents: 'none' }} />,
content: <ImageTab />, content: <ImageTab />,
}, },
{ {
id: 'unifiedCanvas', id: 'unifiedCanvas',
icon: <Icon as={MdGridOn} sx={{ boxSize: 6 }} />, icon: <Icon as={MdGridOn} sx={{ boxSize: 6, pointerEvents: 'none' }} />,
content: <UnifiedCanvasTab />, content: <UnifiedCanvasTab />,
}, },
{ {
id: 'nodes', id: 'nodes',
icon: <Icon as={MdDeviceHub} sx={{ boxSize: 6 }} />, icon: <Icon as={MdDeviceHub} sx={{ boxSize: 6, pointerEvents: 'none' }} />,
content: <NodesTab />, content: <NodesTab />,
}, },
]; ];
@ -119,6 +119,12 @@ const InvokeTabs = () => {
} }
}, [dispatch, activeTabName]); }, [dispatch, activeTabName]);
const handleClickTab = useCallback((e: MouseEvent<HTMLElement>) => {
if (e.target instanceof HTMLElement) {
e.target.blur();
}
}, []);
const tabs = useMemo( const tabs = useMemo(
() => () =>
enabledTabs.map((tab) => ( enabledTabs.map((tab) => (
@ -128,7 +134,7 @@ const InvokeTabs = () => {
label={String(t(`common.${tab.id}` as ResourceKey))} label={String(t(`common.${tab.id}` as ResourceKey))}
placement="end" placement="end"
> >
<Tab> <Tab onClick={handleClickTab}>
<VisuallyHidden> <VisuallyHidden>
{String(t(`common.${tab.id}` as ResourceKey))} {String(t(`common.${tab.id}` as ResourceKey))}
</VisuallyHidden> </VisuallyHidden>
@ -136,7 +142,7 @@ const InvokeTabs = () => {
</Tab> </Tab>
</Tooltip> </Tooltip>
)), )),
[t, enabledTabs] [enabledTabs, t, handleClickTab]
); );
const tabPanels = useMemo( const tabPanels = useMemo(

View File

@ -12,7 +12,6 @@ import {
setShouldShowCanvasDebugInfo, setShouldShowCanvasDebugInfo,
setShouldShowIntermediates, setShouldShowIntermediates,
} from 'features/canvas/store/canvasSlice'; } from 'features/canvas/store/canvasSlice';
import EmptyTempFolderButtonModal from 'features/system/components/ClearTempFolderButtonModal';
import { FaWrench } from 'react-icons/fa'; import { FaWrench } from 'react-icons/fa';
@ -105,7 +104,6 @@ const UnifiedCanvasSettings = () => {
onChange={(e) => dispatch(setShouldAntialias(e.target.checked))} onChange={(e) => dispatch(setShouldAntialias(e.target.checked))}
/> />
<ClearCanvasHistoryButtonModal /> <ClearCanvasHistoryButtonModal />
<EmptyTempFolderButtonModal />
</Flex> </Flex>
</IAIPopover> </IAIPopover>
); );

View File

@ -55,8 +55,6 @@ const UnifiedCanvasContent = () => {
}); });
useLayoutEffect(() => { useLayoutEffect(() => {
dispatch(requestCanvasRescale());
const resizeCallback = () => { const resizeCallback = () => {
dispatch(requestCanvasRescale()); dispatch(requestCanvasRescale());
}; };

View File

@ -7,30 +7,26 @@ import type { ImageField } from './ImageField';
/** /**
* Applies HED edge detection to image * Applies HED edge detection to image
*/ */
export type HedImageProcessorInvocation = { export type HedImageprocessorInvocation = {
/** /**
* The id of this node. Must be unique among all nodes. * The id of this node. Must be unique among all nodes.
*/ */
id: string; id: string;
/**
* Whether or not this node is an intermediate node.
*/
is_intermediate?: boolean;
type?: 'hed_image_processor'; type?: 'hed_image_processor';
/** /**
* The image to process * image to process
*/ */
image?: ImageField; image?: ImageField;
/** /**
* The pixel resolution for detection * pixel resolution for edge detection
*/ */
detect_resolution?: number; detect_resolution?: number;
/** /**
* The pixel resolution for the output image * pixel resolution for output image
*/ */
image_resolution?: number; image_resolution?: number;
/** /**
* Whether to use scribble mode * whether to use scribble mode
*/ */
scribble?: boolean; scribble?: boolean;
}; };

View File

@ -0,0 +1,33 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
import type { ImageField } from './ImageField';
/**
* Applies HED edge detection to image
*/
export type HedImageprocessorInvocation = {
/**
* The id of this node. Must be unique among all nodes.
*/
id: string;
type?: 'hed_image_processor';
/**
* image to process
*/
image?: ImageField;
/**
* pixel resolution for edge detection
*/
detect_resolution?: number;
/**
* pixel resolution for output image
*/
image_resolution?: number;
/**
* whether to use scribble mode
*/
scribble?: boolean;
};

View File

@ -30,7 +30,7 @@ const invokeAIMark = defineStyle((_props) => {
return { return {
fontSize: 'xs', fontSize: 'xs',
fontWeight: '500', fontWeight: '500',
color: 'base.200', color: 'base.400',
mt: 2, mt: 2,
insetInlineStart: 'unset', insetInlineStart: 'unset',
}; };

View File

@ -44,6 +44,7 @@ dependencies = [
"datasets", "datasets",
"diffusers[torch]~=0.17.0", "diffusers[torch]~=0.17.0",
"dnspython==2.2.1", "dnspython==2.2.1",
"easing-functions",
"einops", "einops",
"eventlet", "eventlet",
"facexlib", "facexlib",
@ -56,6 +57,7 @@ dependencies = [
"flaskwebgui==1.0.3", "flaskwebgui==1.0.3",
"gfpgan==1.3.8", "gfpgan==1.3.8",
"huggingface-hub>=0.11.1", "huggingface-hub>=0.11.1",
"matplotlib", # needed for plotting of Penner easing functions
"mediapipe", # needed for "mediapipeface" controlnet model "mediapipe", # needed for "mediapipeface" controlnet model
"npyscreen", "npyscreen",
"numpy<1.24", "numpy<1.24",

View File

@ -121,3 +121,78 @@ def test_graph_state_collects(mock_services):
assert isinstance(n6[0], CollectInvocation) assert isinstance(n6[0], CollectInvocation)
assert sorted(g.results[n6[0].id].collection) == sorted(test_prompts) assert sorted(g.results[n6[0].id].collection) == sorted(test_prompts)
def test_graph_state_prepares_eagerly(mock_services):
"""Tests that all prepareable nodes are prepared"""
graph = Graph()
test_prompts = ["Banana sushi", "Cat sushi"]
graph.add_node(PromptCollectionTestInvocation(id="prompt_collection", collection=list(test_prompts)))
graph.add_node(IterateInvocation(id="iterate"))
graph.add_node(PromptTestInvocation(id="prompt_iterated"))
graph.add_edge(create_edge("prompt_collection", "collection", "iterate", "collection"))
graph.add_edge(create_edge("iterate", "item", "prompt_iterated", "prompt"))
# separated, fully-preparable chain of nodes
graph.add_node(PromptTestInvocation(id="prompt_chain_1", prompt="Dinosaur sushi"))
graph.add_node(PromptTestInvocation(id="prompt_chain_2"))
graph.add_node(PromptTestInvocation(id="prompt_chain_3"))
graph.add_edge(create_edge("prompt_chain_1", "prompt", "prompt_chain_2", "prompt"))
graph.add_edge(create_edge("prompt_chain_2", "prompt", "prompt_chain_3", "prompt"))
g = GraphExecutionState(graph=graph)
g.next()
assert "prompt_collection" in g.source_prepared_mapping
assert "prompt_chain_1" in g.source_prepared_mapping
assert "prompt_chain_2" in g.source_prepared_mapping
assert "prompt_chain_3" in g.source_prepared_mapping
assert "iterate" not in g.source_prepared_mapping
assert "prompt_iterated" not in g.source_prepared_mapping
def test_graph_executes_depth_first(mock_services):
"""Tests that the graph executes depth-first, executing a branch as far as possible before moving to the next branch"""
graph = Graph()
test_prompts = ["Banana sushi", "Cat sushi"]
graph.add_node(PromptCollectionTestInvocation(id="prompt_collection", collection=list(test_prompts)))
graph.add_node(IterateInvocation(id="iterate"))
graph.add_node(PromptTestInvocation(id="prompt_iterated"))
graph.add_node(PromptTestInvocation(id="prompt_successor"))
graph.add_edge(create_edge("prompt_collection", "collection", "iterate", "collection"))
graph.add_edge(create_edge("iterate", "item", "prompt_iterated", "prompt"))
graph.add_edge(create_edge("prompt_iterated", "prompt", "prompt_successor", "prompt"))
g = GraphExecutionState(graph=graph)
n1 = invoke_next(g, mock_services)
n2 = invoke_next(g, mock_services)
n3 = invoke_next(g, mock_services)
n4 = invoke_next(g, mock_services)
# Because ordering is not guaranteed, we cannot compare results directly.
# Instead, we must count the number of results.
def get_completed_count(g, id):
ids = [i for i in g.source_prepared_mapping[id]]
completed_ids = [i for i in g.executed if i in ids]
return len(completed_ids)
# Check at each step that the number of executed nodes matches the expectation for depth-first execution
assert get_completed_count(g, "prompt_iterated") == 1
assert get_completed_count(g, "prompt_successor") == 0
n5 = invoke_next(g, mock_services)
assert get_completed_count(g, "prompt_iterated") == 1
assert get_completed_count(g, "prompt_successor") == 1
n6 = invoke_next(g, mock_services)
assert get_completed_count(g, "prompt_iterated") == 2
assert get_completed_count(g, "prompt_successor") == 1
n7 = invoke_next(g, mock_services)
assert get_completed_count(g, "prompt_iterated") == 2
assert get_completed_count(g, "prompt_successor") == 2