mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into diffusers-upgrade
This commit is contained in:
commit
2a814d886b
@ -4,7 +4,6 @@ from inspect import signature
|
|||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
|
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
|
||||||
@ -15,15 +14,19 @@ from fastapi_events.middleware import EventHandlerASGIMiddleware
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from pydantic.schema import schema
|
from pydantic.schema import schema
|
||||||
|
|
||||||
|
#This should come early so that modules can log their initialization properly
|
||||||
|
from .services.config import InvokeAIAppConfig
|
||||||
|
from ..backend.util.logging import InvokeAILogger
|
||||||
|
app_config = InvokeAIAppConfig.get_config()
|
||||||
|
app_config.parse_args()
|
||||||
|
logger = InvokeAILogger.getLogger(config=app_config)
|
||||||
|
|
||||||
import invokeai.frontend.web as web_dir
|
import invokeai.frontend.web as web_dir
|
||||||
|
|
||||||
from .api.dependencies import ApiDependencies
|
from .api.dependencies import ApiDependencies
|
||||||
from .api.routers import sessions, models, images
|
from .api.routers import sessions, models, images
|
||||||
from .api.sockets import SocketIO
|
from .api.sockets import SocketIO
|
||||||
from .invocations.baseinvocation import BaseInvocation
|
from .invocations.baseinvocation import BaseInvocation
|
||||||
from .services.config import InvokeAIAppConfig
|
|
||||||
|
|
||||||
logger = InvokeAILogger.getLogger()
|
|
||||||
|
|
||||||
# Create the app
|
# Create the app
|
||||||
# TODO: create this all in a method so configuration/etc. can be passed in?
|
# TODO: create this all in a method so configuration/etc. can be passed in?
|
||||||
@ -41,11 +44,6 @@ app.add_middleware(
|
|||||||
|
|
||||||
socket_io = SocketIO(app)
|
socket_io = SocketIO(app)
|
||||||
|
|
||||||
# initialize config
|
|
||||||
# this is a module global
|
|
||||||
app_config = InvokeAIAppConfig.get_config()
|
|
||||||
app_config.parse_args()
|
|
||||||
|
|
||||||
# Add startup event to load dependencies
|
# Add startup event to load dependencies
|
||||||
@app.on_event("startup")
|
@app.on_event("startup")
|
||||||
async def startup_event():
|
async def startup_event():
|
||||||
|
@ -13,14 +13,20 @@ from typing import (
|
|||||||
|
|
||||||
from pydantic import BaseModel, ValidationError
|
from pydantic import BaseModel, ValidationError
|
||||||
from pydantic.fields import Field
|
from pydantic.fields import Field
|
||||||
|
|
||||||
|
# This should come early so that the logger can pick up its configuration options
|
||||||
|
from .services.config import InvokeAIAppConfig
|
||||||
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
|
config = InvokeAIAppConfig.get_config()
|
||||||
|
config.parse_args()
|
||||||
|
logger = InvokeAILogger().getLogger(config=config)
|
||||||
|
|
||||||
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
||||||
from invokeai.app.services.images import ImageService
|
from invokeai.app.services.images import ImageService
|
||||||
from invokeai.app.services.metadata import CoreMetadataService
|
from invokeai.app.services.metadata import CoreMetadataService
|
||||||
from invokeai.app.services.resource_name import SimpleNameService
|
from invokeai.app.services.resource_name import SimpleNameService
|
||||||
from invokeai.app.services.urls import LocalUrlService
|
from invokeai.app.services.urls import LocalUrlService
|
||||||
|
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
from .services.default_graphs import create_system_graphs
|
from .services.default_graphs import create_system_graphs
|
||||||
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||||
|
|
||||||
@ -38,7 +44,7 @@ from .services.invocation_services import InvocationServices
|
|||||||
from .services.invoker import Invoker
|
from .services.invoker import Invoker
|
||||||
from .services.processor import DefaultInvocationProcessor
|
from .services.processor import DefaultInvocationProcessor
|
||||||
from .services.sqlite import SqliteItemStorage
|
from .services.sqlite import SqliteItemStorage
|
||||||
from .services.config import InvokeAIAppConfig
|
|
||||||
|
|
||||||
class CliCommand(BaseModel):
|
class CliCommand(BaseModel):
|
||||||
command: Union[BaseCommand.get_commands() + BaseInvocation.get_invocations()] = Field(discriminator="type") # type: ignore
|
command: Union[BaseCommand.get_commands() + BaseInvocation.get_invocations()] = Field(discriminator="type") # type: ignore
|
||||||
@ -47,7 +53,6 @@ class CliCommand(BaseModel):
|
|||||||
class InvalidArgs(Exception):
|
class InvalidArgs(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def add_invocation_args(command_parser):
|
def add_invocation_args(command_parser):
|
||||||
# Add linking capability
|
# Add linking capability
|
||||||
command_parser.add_argument(
|
command_parser.add_argument(
|
||||||
@ -191,14 +196,7 @@ def invoke_all(context: CliContext):
|
|||||||
|
|
||||||
raise SessionError()
|
raise SessionError()
|
||||||
|
|
||||||
|
|
||||||
logger = logger.InvokeAILogger.getLogger()
|
|
||||||
|
|
||||||
|
|
||||||
def invoke_cli():
|
def invoke_cli():
|
||||||
# this gets the basic configuration
|
|
||||||
config = InvokeAIAppConfig.get_config()
|
|
||||||
config.parse_args()
|
|
||||||
|
|
||||||
# get the optional list of invocations to execute on the command line
|
# get the optional list of invocations to execute on the command line
|
||||||
parser = config.get_parser()
|
parser = config.get_parser()
|
||||||
|
@ -1,11 +1,12 @@
|
|||||||
# InvokeAI nodes for ControlNet image preprocessors
|
# InvokeAI nodes for ControlNet image preprocessors
|
||||||
# initial implementation by Gregg Helt, 2023
|
# initial implementation by Gregg Helt, 2023
|
||||||
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
|
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
|
||||||
|
from builtins import float
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from typing import Literal, Optional, Union, List
|
from typing import Literal, Optional, Union, List
|
||||||
from PIL import Image, ImageFilter, ImageOps
|
from PIL import Image, ImageFilter, ImageOps
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, validator
|
||||||
|
|
||||||
from ..models.image import ImageField, ImageCategory, ResourceOrigin
|
from ..models.image import ImageField, ImageCategory, ResourceOrigin
|
||||||
from .baseinvocation import (
|
from .baseinvocation import (
|
||||||
@ -14,6 +15,7 @@ from .baseinvocation import (
|
|||||||
InvocationContext,
|
InvocationContext,
|
||||||
InvocationConfig,
|
InvocationConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
from controlnet_aux import (
|
from controlnet_aux import (
|
||||||
CannyDetector,
|
CannyDetector,
|
||||||
HEDdetector,
|
HEDdetector,
|
||||||
@ -96,15 +98,32 @@ CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
|
|||||||
class ControlField(BaseModel):
|
class ControlField(BaseModel):
|
||||||
image: ImageField = Field(default=None, description="The control image")
|
image: ImageField = Field(default=None, description="The control image")
|
||||||
control_model: Optional[str] = Field(default=None, description="The ControlNet model to use")
|
control_model: Optional[str] = Field(default=None, description="The ControlNet model to use")
|
||||||
control_weight: Optional[float] = Field(default=1, description="The weight given to the ControlNet")
|
# control_weight: Optional[float] = Field(default=1, description="weight given to controlnet")
|
||||||
|
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
|
||||||
begin_step_percent: float = Field(default=0, ge=0, le=1,
|
begin_step_percent: float = Field(default=0, ge=0, le=1,
|
||||||
description="When the ControlNet is first applied (% of total steps)")
|
description="When the ControlNet is first applied (% of total steps)")
|
||||||
end_step_percent: float = Field(default=1, ge=0, le=1,
|
end_step_percent: float = Field(default=1, ge=0, le=1,
|
||||||
description="When the ControlNet is last applied (% of total steps)")
|
description="When the ControlNet is last applied (% of total steps)")
|
||||||
|
@validator("control_weight")
|
||||||
|
def abs_le_one(cls, v):
|
||||||
|
"""validate that all abs(values) are <=1"""
|
||||||
|
if isinstance(v, list):
|
||||||
|
for i in v:
|
||||||
|
if abs(i) > 1:
|
||||||
|
raise ValueError('all abs(control_weight) must be <= 1')
|
||||||
|
else:
|
||||||
|
if abs(v) > 1:
|
||||||
|
raise ValueError('abs(control_weight) must be <= 1')
|
||||||
|
return v
|
||||||
class Config:
|
class Config:
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"required": ["image", "control_model", "control_weight", "begin_step_percent", "end_step_percent"]
|
"required": ["image", "control_model", "control_weight", "begin_step_percent", "end_step_percent"],
|
||||||
|
"ui": {
|
||||||
|
"type_hints": {
|
||||||
|
"control_weight": "float",
|
||||||
|
# "control_weight": "number",
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -112,7 +131,7 @@ class ControlOutput(BaseInvocationOutput):
|
|||||||
"""node output for ControlNet info"""
|
"""node output for ControlNet info"""
|
||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["control_output"] = "control_output"
|
type: Literal["control_output"] = "control_output"
|
||||||
control: ControlField = Field(default=None, description="The output control image")
|
control: ControlField = Field(default=None, description="The control info")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
@ -123,15 +142,28 @@ class ControlNetInvocation(BaseInvocation):
|
|||||||
# Inputs
|
# Inputs
|
||||||
image: ImageField = Field(default=None, description="The control image")
|
image: ImageField = Field(default=None, description="The control image")
|
||||||
control_model: CONTROLNET_NAME_VALUES = Field(default="lllyasviel/sd-controlnet-canny",
|
control_model: CONTROLNET_NAME_VALUES = Field(default="lllyasviel/sd-controlnet-canny",
|
||||||
description="The ControlNet model to use")
|
description="control model used")
|
||||||
control_weight: float = Field(default=1.0, ge=0, le=1, description="The weight given to the ControlNet")
|
control_weight: Union[float, List[float]] = Field(default=1.0, description="The weight given to the ControlNet")
|
||||||
# TODO: add support in backend core for begin_step_percent, end_step_percent, guess_mode
|
# TODO: add support in backend core for begin_step_percent, end_step_percent, guess_mode
|
||||||
begin_step_percent: float = Field(default=0, ge=0, le=1,
|
begin_step_percent: float = Field(default=0, ge=0, le=1,
|
||||||
description="When the ControlNet is first applied (% of total steps)")
|
description="When the ControlNet is first applied (% of total steps)")
|
||||||
end_step_percent: float = Field(default=1, ge=0, le=1,
|
end_step_percent: float = Field(default=1, ge=0, le=1,
|
||||||
description="When the ControlNet is last applied (% of total steps)")
|
description="When the ControlNet is last applied (% of total steps)")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
|
class Config(InvocationConfig):
|
||||||
|
schema_extra = {
|
||||||
|
"ui": {
|
||||||
|
"tags": ["latents"],
|
||||||
|
"type_hints": {
|
||||||
|
"model": "model",
|
||||||
|
"control": "control",
|
||||||
|
# "cfg_scale": "float",
|
||||||
|
"cfg_scale": "number",
|
||||||
|
"control_weight": "float",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ControlOutput:
|
def invoke(self, context: InvocationContext) -> ControlOutput:
|
||||||
|
|
||||||
@ -161,7 +193,6 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
return image
|
return image
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
|
|
||||||
raw_image = context.services.images.get_pil_image(
|
raw_image = context.services.images.get_pil_image(
|
||||||
self.image.image_origin, self.image.image_name
|
self.image.image_origin, self.image.image_name
|
||||||
)
|
)
|
||||||
|
@ -174,22 +174,36 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
|
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
|
||||||
noise: Optional[LatentsField] = Field(description="The noise to use")
|
noise: Optional[LatentsField] = Field(description="The noise to use")
|
||||||
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
||||||
cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||||
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
|
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
|
||||||
model: str = Field(default="", description="The model to use (currently ignored)")
|
model: str = Field(default="", description="The model to use (currently ignored)")
|
||||||
control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
|
control: Union[ControlField, List[ControlField]] = Field(default=None, description="The control to use")
|
||||||
# seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
# seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||||
# seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
# seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
|
@validator("cfg_scale")
|
||||||
|
def ge_one(cls, v):
|
||||||
|
"""validate that all cfg_scale values are >= 1"""
|
||||||
|
if isinstance(v, list):
|
||||||
|
for i in v:
|
||||||
|
if i < 1:
|
||||||
|
raise ValueError('cfg_scale must be greater than 1')
|
||||||
|
else:
|
||||||
|
if v < 1:
|
||||||
|
raise ValueError('cfg_scale must be greater than 1')
|
||||||
|
return v
|
||||||
|
|
||||||
# Schema customisation
|
# Schema customisation
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {
|
"ui": {
|
||||||
"tags": ["latents", "image"],
|
"tags": ["latents"],
|
||||||
"type_hints": {
|
"type_hints": {
|
||||||
"model": "model",
|
"model": "model",
|
||||||
"control": "control",
|
"control": "control",
|
||||||
|
# "cfg_scale": "float",
|
||||||
|
"cfg_scale": "number"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -244,10 +258,10 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
[c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
|
[c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
|
||||||
|
|
||||||
conditioning_data = ConditioningData(
|
conditioning_data = ConditioningData(
|
||||||
uc,
|
unconditioned_embeddings=uc,
|
||||||
c,
|
text_embeddings=c,
|
||||||
self.cfg_scale,
|
guidance_scale=self.cfg_scale,
|
||||||
extra_conditioning_info,
|
extra=extra_conditioning_info,
|
||||||
postprocessing_settings=PostprocessingSettings(
|
postprocessing_settings=PostprocessingSettings(
|
||||||
threshold=0.0,#threshold,
|
threshold=0.0,#threshold,
|
||||||
warmup=0.2,#warmup,
|
warmup=0.2,#warmup,
|
||||||
@ -348,7 +362,8 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
control_data = self.prep_control_data(model=model, context=context, control_input=self.control,
|
control_data = self.prep_control_data(model=model, context=context, control_input=self.control,
|
||||||
latents_shape=noise.shape,
|
latents_shape=noise.shape,
|
||||||
do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
||||||
|
do_classifier_free_guidance=True,)
|
||||||
|
|
||||||
# TODO: Verify the noise is the right size
|
# TODO: Verify the noise is the right size
|
||||||
result_latents, result_attention_map_saver = model.latents_from_embeddings(
|
result_latents, result_attention_map_saver = model.latents_from_embeddings(
|
||||||
@ -385,6 +400,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
|||||||
"type_hints": {
|
"type_hints": {
|
||||||
"model": "model",
|
"model": "model",
|
||||||
"control": "control",
|
"control": "control",
|
||||||
|
"cfg_scale": "number",
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -403,10 +419,11 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
|||||||
model = self.get_model(context.services.model_manager)
|
model = self.get_model(context.services.model_manager)
|
||||||
conditioning_data = self.get_conditioning_data(context, model)
|
conditioning_data = self.get_conditioning_data(context, model)
|
||||||
|
|
||||||
print("type of control input: ", type(self.control))
|
|
||||||
control_data = self.prep_control_data(model=model, context=context, control_input=self.control,
|
control_data = self.prep_control_data(model=model, context=context, control_input=self.control,
|
||||||
latents_shape=noise.shape,
|
latents_shape=noise.shape,
|
||||||
do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
||||||
|
do_classifier_free_guidance=True,
|
||||||
|
)
|
||||||
|
|
||||||
# TODO: Verify the noise is the right size
|
# TODO: Verify the noise is the right size
|
||||||
|
|
||||||
|
237
invokeai/app/invocations/param_easing.py
Normal file
237
invokeai/app/invocations/param_easing.py
Normal file
@ -0,0 +1,237 @@
|
|||||||
|
import io
|
||||||
|
from typing import Literal, Optional, Any
|
||||||
|
|
||||||
|
# from PIL.Image import Image
|
||||||
|
import PIL.Image
|
||||||
|
from matplotlib.ticker import MaxNLocator
|
||||||
|
from matplotlib.figure import Figure
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
from easing_functions import (
|
||||||
|
LinearInOut,
|
||||||
|
QuadEaseInOut, QuadEaseIn, QuadEaseOut,
|
||||||
|
CubicEaseInOut, CubicEaseIn, CubicEaseOut,
|
||||||
|
QuarticEaseInOut, QuarticEaseIn, QuarticEaseOut,
|
||||||
|
QuinticEaseInOut, QuinticEaseIn, QuinticEaseOut,
|
||||||
|
SineEaseInOut, SineEaseIn, SineEaseOut,
|
||||||
|
CircularEaseIn, CircularEaseInOut, CircularEaseOut,
|
||||||
|
ExponentialEaseInOut, ExponentialEaseIn, ExponentialEaseOut,
|
||||||
|
ElasticEaseIn, ElasticEaseInOut, ElasticEaseOut,
|
||||||
|
BackEaseIn, BackEaseInOut, BackEaseOut,
|
||||||
|
BounceEaseIn, BounceEaseInOut, BounceEaseOut)
|
||||||
|
|
||||||
|
from .baseinvocation import (
|
||||||
|
BaseInvocation,
|
||||||
|
BaseInvocationOutput,
|
||||||
|
InvocationContext,
|
||||||
|
InvocationConfig,
|
||||||
|
)
|
||||||
|
from ...backend.util.logging import InvokeAILogger
|
||||||
|
from .collections import FloatCollectionOutput
|
||||||
|
|
||||||
|
|
||||||
|
class FloatLinearRangeInvocation(BaseInvocation):
|
||||||
|
"""Creates a range"""
|
||||||
|
|
||||||
|
type: Literal["float_range"] = "float_range"
|
||||||
|
|
||||||
|
# Inputs
|
||||||
|
start: float = Field(default=5, description="The first value of the range")
|
||||||
|
stop: float = Field(default=10, description="The last value of the range")
|
||||||
|
steps: int = Field(default=30, description="number of values to interpolate over (including start and stop)")
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
||||||
|
param_list = list(np.linspace(self.start, self.stop, self.steps))
|
||||||
|
return FloatCollectionOutput(
|
||||||
|
collection=param_list
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
EASING_FUNCTIONS_MAP = {
|
||||||
|
"Linear": LinearInOut,
|
||||||
|
"QuadIn": QuadEaseIn,
|
||||||
|
"QuadOut": QuadEaseOut,
|
||||||
|
"QuadInOut": QuadEaseInOut,
|
||||||
|
"CubicIn": CubicEaseIn,
|
||||||
|
"CubicOut": CubicEaseOut,
|
||||||
|
"CubicInOut": CubicEaseInOut,
|
||||||
|
"QuarticIn": QuarticEaseIn,
|
||||||
|
"QuarticOut": QuarticEaseOut,
|
||||||
|
"QuarticInOut": QuarticEaseInOut,
|
||||||
|
"QuinticIn": QuinticEaseIn,
|
||||||
|
"QuinticOut": QuinticEaseOut,
|
||||||
|
"QuinticInOut": QuinticEaseInOut,
|
||||||
|
"SineIn": SineEaseIn,
|
||||||
|
"SineOut": SineEaseOut,
|
||||||
|
"SineInOut": SineEaseInOut,
|
||||||
|
"CircularIn": CircularEaseIn,
|
||||||
|
"CircularOut": CircularEaseOut,
|
||||||
|
"CircularInOut": CircularEaseInOut,
|
||||||
|
"ExponentialIn": ExponentialEaseIn,
|
||||||
|
"ExponentialOut": ExponentialEaseOut,
|
||||||
|
"ExponentialInOut": ExponentialEaseInOut,
|
||||||
|
"ElasticIn": ElasticEaseIn,
|
||||||
|
"ElasticOut": ElasticEaseOut,
|
||||||
|
"ElasticInOut": ElasticEaseInOut,
|
||||||
|
"BackIn": BackEaseIn,
|
||||||
|
"BackOut": BackEaseOut,
|
||||||
|
"BackInOut": BackEaseInOut,
|
||||||
|
"BounceIn": BounceEaseIn,
|
||||||
|
"BounceOut": BounceEaseOut,
|
||||||
|
"BounceInOut": BounceEaseInOut,
|
||||||
|
}
|
||||||
|
|
||||||
|
EASING_FUNCTION_KEYS: Any = Literal[
|
||||||
|
tuple(list(EASING_FUNCTIONS_MAP.keys()))
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# actually I think for now could just use CollectionOutput (which is list[Any]
|
||||||
|
class StepParamEasingInvocation(BaseInvocation):
|
||||||
|
"""Experimental per-step parameter easing for denoising steps"""
|
||||||
|
|
||||||
|
type: Literal["step_param_easing"] = "step_param_easing"
|
||||||
|
|
||||||
|
# Inputs
|
||||||
|
# fmt: off
|
||||||
|
easing: EASING_FUNCTION_KEYS = Field(default="Linear", description="The easing function to use")
|
||||||
|
num_steps: int = Field(default=20, description="number of denoising steps")
|
||||||
|
start_value: float = Field(default=0.0, description="easing starting value")
|
||||||
|
end_value: float = Field(default=1.0, description="easing ending value")
|
||||||
|
start_step_percent: float = Field(default=0.0, description="fraction of steps at which to start easing")
|
||||||
|
end_step_percent: float = Field(default=1.0, description="fraction of steps after which to end easing")
|
||||||
|
# if None, then start_value is used prior to easing start
|
||||||
|
pre_start_value: Optional[float] = Field(default=None, description="value before easing start")
|
||||||
|
# if None, then end value is used prior to easing end
|
||||||
|
post_end_value: Optional[float] = Field(default=None, description="value after easing end")
|
||||||
|
mirror: bool = Field(default=False, description="include mirror of easing function")
|
||||||
|
# FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely
|
||||||
|
# alt_mirror: bool = Field(default=False, description="alternative mirroring by dual easing")
|
||||||
|
show_easing_plot: bool = Field(default=False, description="show easing plot")
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
||||||
|
log_diagnostics = False
|
||||||
|
# convert from start_step_percent to nearest step <= (steps * start_step_percent)
|
||||||
|
# start_step = int(np.floor(self.num_steps * self.start_step_percent))
|
||||||
|
start_step = int(np.round(self.num_steps * self.start_step_percent))
|
||||||
|
# convert from end_step_percent to nearest step >= (steps * end_step_percent)
|
||||||
|
# end_step = int(np.ceil((self.num_steps - 1) * self.end_step_percent))
|
||||||
|
end_step = int(np.round((self.num_steps - 1) * self.end_step_percent))
|
||||||
|
|
||||||
|
# end_step = int(np.ceil(self.num_steps * self.end_step_percent))
|
||||||
|
num_easing_steps = end_step - start_step + 1
|
||||||
|
|
||||||
|
# num_presteps = max(start_step - 1, 0)
|
||||||
|
num_presteps = start_step
|
||||||
|
num_poststeps = self.num_steps - (num_presteps + num_easing_steps)
|
||||||
|
prelist = list(num_presteps * [self.pre_start_value])
|
||||||
|
postlist = list(num_poststeps * [self.post_end_value])
|
||||||
|
|
||||||
|
if log_diagnostics:
|
||||||
|
logger = InvokeAILogger.getLogger(name="StepParamEasing")
|
||||||
|
logger.debug("start_step: " + str(start_step))
|
||||||
|
logger.debug("end_step: " + str(end_step))
|
||||||
|
logger.debug("num_easing_steps: " + str(num_easing_steps))
|
||||||
|
logger.debug("num_presteps: " + str(num_presteps))
|
||||||
|
logger.debug("num_poststeps: " + str(num_poststeps))
|
||||||
|
logger.debug("prelist size: " + str(len(prelist)))
|
||||||
|
logger.debug("postlist size: " + str(len(postlist)))
|
||||||
|
logger.debug("prelist: " + str(prelist))
|
||||||
|
logger.debug("postlist: " + str(postlist))
|
||||||
|
|
||||||
|
easing_class = EASING_FUNCTIONS_MAP[self.easing]
|
||||||
|
if log_diagnostics:
|
||||||
|
logger.debug("easing class: " + str(easing_class))
|
||||||
|
easing_list = list()
|
||||||
|
if self.mirror: # "expected" mirroring
|
||||||
|
# if number of steps is even, squeeze duration down to (number_of_steps)/2
|
||||||
|
# and create reverse copy of list to append
|
||||||
|
# if number of steps is odd, squeeze duration down to ceil(number_of_steps/2)
|
||||||
|
# and create reverse copy of list[1:end-1]
|
||||||
|
# but if even then number_of_steps/2 === ceil(number_of_steps/2), so can just use ceil always
|
||||||
|
|
||||||
|
base_easing_duration = int(np.ceil(num_easing_steps/2.0))
|
||||||
|
if log_diagnostics: logger.debug("base easing duration: " + str(base_easing_duration))
|
||||||
|
even_num_steps = (num_easing_steps % 2 == 0) # even number of steps
|
||||||
|
easing_function = easing_class(start=self.start_value,
|
||||||
|
end=self.end_value,
|
||||||
|
duration=base_easing_duration - 1)
|
||||||
|
base_easing_vals = list()
|
||||||
|
for step_index in range(base_easing_duration):
|
||||||
|
easing_val = easing_function.ease(step_index)
|
||||||
|
base_easing_vals.append(easing_val)
|
||||||
|
if log_diagnostics:
|
||||||
|
logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(easing_val))
|
||||||
|
if even_num_steps:
|
||||||
|
mirror_easing_vals = list(reversed(base_easing_vals))
|
||||||
|
else:
|
||||||
|
mirror_easing_vals = list(reversed(base_easing_vals[0:-1]))
|
||||||
|
if log_diagnostics:
|
||||||
|
logger.debug("base easing vals: " + str(base_easing_vals))
|
||||||
|
logger.debug("mirror easing vals: " + str(mirror_easing_vals))
|
||||||
|
easing_list = base_easing_vals + mirror_easing_vals
|
||||||
|
|
||||||
|
# FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely
|
||||||
|
# elif self.alt_mirror: # function mirroring (unintuitive behavior (at least to me))
|
||||||
|
# # half_ease_duration = round(num_easing_steps - 1 / 2)
|
||||||
|
# half_ease_duration = round((num_easing_steps - 1) / 2)
|
||||||
|
# easing_function = easing_class(start=self.start_value,
|
||||||
|
# end=self.end_value,
|
||||||
|
# duration=half_ease_duration,
|
||||||
|
# )
|
||||||
|
#
|
||||||
|
# mirror_function = easing_class(start=self.end_value,
|
||||||
|
# end=self.start_value,
|
||||||
|
# duration=half_ease_duration,
|
||||||
|
# )
|
||||||
|
# for step_index in range(num_easing_steps):
|
||||||
|
# if step_index <= half_ease_duration:
|
||||||
|
# step_val = easing_function.ease(step_index)
|
||||||
|
# else:
|
||||||
|
# step_val = mirror_function.ease(step_index - half_ease_duration)
|
||||||
|
# easing_list.append(step_val)
|
||||||
|
# if log_diagnostics: logger.debug(step_index, step_val)
|
||||||
|
#
|
||||||
|
|
||||||
|
else: # no mirroring (default)
|
||||||
|
easing_function = easing_class(start=self.start_value,
|
||||||
|
end=self.end_value,
|
||||||
|
duration=num_easing_steps - 1)
|
||||||
|
for step_index in range(num_easing_steps):
|
||||||
|
step_val = easing_function.ease(step_index)
|
||||||
|
easing_list.append(step_val)
|
||||||
|
if log_diagnostics:
|
||||||
|
logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(step_val))
|
||||||
|
|
||||||
|
if log_diagnostics:
|
||||||
|
logger.debug("prelist size: " + str(len(prelist)))
|
||||||
|
logger.debug("easing_list size: " + str(len(easing_list)))
|
||||||
|
logger.debug("postlist size: " + str(len(postlist)))
|
||||||
|
|
||||||
|
param_list = prelist + easing_list + postlist
|
||||||
|
|
||||||
|
if self.show_easing_plot:
|
||||||
|
plt.figure()
|
||||||
|
plt.xlabel("Step")
|
||||||
|
plt.ylabel("Param Value")
|
||||||
|
plt.title("Per-Step Values Based On Easing: " + self.easing)
|
||||||
|
plt.bar(range(len(param_list)), param_list)
|
||||||
|
# plt.plot(param_list)
|
||||||
|
ax = plt.gca()
|
||||||
|
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
|
||||||
|
buf = io.BytesIO()
|
||||||
|
plt.savefig(buf, format='png')
|
||||||
|
buf.seek(0)
|
||||||
|
im = PIL.Image.open(buf)
|
||||||
|
im.show()
|
||||||
|
buf.close()
|
||||||
|
|
||||||
|
# output array of size steps, each entry list[i] is param value for step i
|
||||||
|
return FloatCollectionOutput(
|
||||||
|
collection=param_list
|
||||||
|
)
|
@ -1,4 +1,4 @@
|
|||||||
from typing import Optional
|
from typing import Optional, Union, List
|
||||||
from pydantic import BaseModel, Extra, Field, StrictFloat, StrictInt, StrictStr
|
from pydantic import BaseModel, Extra, Field, StrictFloat, StrictInt, StrictStr
|
||||||
|
|
||||||
|
|
||||||
@ -47,7 +47,9 @@ class ImageMetadata(BaseModel):
|
|||||||
default=None, description="The seed used for noise generation."
|
default=None, description="The seed used for noise generation."
|
||||||
)
|
)
|
||||||
"""The seed used for noise generation"""
|
"""The seed used for noise generation"""
|
||||||
cfg_scale: Optional[StrictFloat] = Field(
|
# cfg_scale: Optional[StrictFloat] = Field(
|
||||||
|
# cfg_scale: Union[float, list[float]] = Field(
|
||||||
|
cfg_scale: Union[StrictFloat, List[StrictFloat]] = Field(
|
||||||
default=None, description="The classifier-free guidance scale."
|
default=None, description="The classifier-free guidance scale."
|
||||||
)
|
)
|
||||||
"""The classifier-free guidance scale"""
|
"""The classifier-free guidance scale"""
|
||||||
|
@ -65,7 +65,6 @@ from typing import Optional, Union, List, get_args
|
|||||||
def is_union_subtype(t1, t2):
|
def is_union_subtype(t1, t2):
|
||||||
t1_args = get_args(t1)
|
t1_args = get_args(t1)
|
||||||
t2_args = get_args(t2)
|
t2_args = get_args(t2)
|
||||||
|
|
||||||
if not t1_args:
|
if not t1_args:
|
||||||
# t1 is a single type
|
# t1 is a single type
|
||||||
return t1 in t2_args
|
return t1 in t2_args
|
||||||
@ -86,7 +85,6 @@ def is_list_or_contains_list(t):
|
|||||||
for arg in t_args:
|
for arg in t_args:
|
||||||
if get_origin(arg) is list:
|
if get_origin(arg) is list:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
@ -393,7 +391,7 @@ class Graph(BaseModel):
|
|||||||
from_node = self.get_node(edge.source.node_id)
|
from_node = self.get_node(edge.source.node_id)
|
||||||
to_node = self.get_node(edge.destination.node_id)
|
to_node = self.get_node(edge.destination.node_id)
|
||||||
except NodeNotFoundError:
|
except NodeNotFoundError:
|
||||||
raise InvalidEdgeError("One or both nodes don't exist")
|
raise InvalidEdgeError("One or both nodes don't exist: {edge.source.node_id} -> {edge.destination.node_id}")
|
||||||
|
|
||||||
# Validate that an edge to this node+field doesn't already exist
|
# Validate that an edge to this node+field doesn't already exist
|
||||||
input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field)
|
input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field)
|
||||||
@ -404,41 +402,41 @@ class Graph(BaseModel):
|
|||||||
g = self.nx_graph_flat()
|
g = self.nx_graph_flat()
|
||||||
g.add_edge(edge.source.node_id, edge.destination.node_id)
|
g.add_edge(edge.source.node_id, edge.destination.node_id)
|
||||||
if not nx.is_directed_acyclic_graph(g):
|
if not nx.is_directed_acyclic_graph(g):
|
||||||
raise InvalidEdgeError(f'Edge creates a cycle in the graph')
|
raise InvalidEdgeError(f'Edge creates a cycle in the graph: {edge.source.node_id} -> {edge.destination.node_id}')
|
||||||
|
|
||||||
# Validate that the field types are compatible
|
# Validate that the field types are compatible
|
||||||
if not are_connections_compatible(
|
if not are_connections_compatible(
|
||||||
from_node, edge.source.field, to_node, edge.destination.field
|
from_node, edge.source.field, to_node, edge.destination.field
|
||||||
):
|
):
|
||||||
raise InvalidEdgeError(f'Fields are incompatible')
|
raise InvalidEdgeError(f'Fields are incompatible: cannot connect {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
|
||||||
|
|
||||||
# Validate if iterator output type matches iterator input type (if this edge results in both being set)
|
# Validate if iterator output type matches iterator input type (if this edge results in both being set)
|
||||||
if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection":
|
if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection":
|
||||||
if not self._is_iterator_connection_valid(
|
if not self._is_iterator_connection_valid(
|
||||||
edge.destination.node_id, new_input=edge.source
|
edge.destination.node_id, new_input=edge.source
|
||||||
):
|
):
|
||||||
raise InvalidEdgeError(f'Iterator input type does not match iterator output type')
|
raise InvalidEdgeError(f'Iterator input type does not match iterator output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
|
||||||
|
|
||||||
# Validate if iterator input type matches output type (if this edge results in both being set)
|
# Validate if iterator input type matches output type (if this edge results in both being set)
|
||||||
if isinstance(from_node, IterateInvocation) and edge.source.field == "item":
|
if isinstance(from_node, IterateInvocation) and edge.source.field == "item":
|
||||||
if not self._is_iterator_connection_valid(
|
if not self._is_iterator_connection_valid(
|
||||||
edge.source.node_id, new_output=edge.destination
|
edge.source.node_id, new_output=edge.destination
|
||||||
):
|
):
|
||||||
raise InvalidEdgeError(f'Iterator output type does not match iterator input type')
|
raise InvalidEdgeError(f'Iterator output type does not match iterator input type:, {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
|
||||||
|
|
||||||
# Validate if collector input type matches output type (if this edge results in both being set)
|
# Validate if collector input type matches output type (if this edge results in both being set)
|
||||||
if isinstance(to_node, CollectInvocation) and edge.destination.field == "item":
|
if isinstance(to_node, CollectInvocation) and edge.destination.field == "item":
|
||||||
if not self._is_collector_connection_valid(
|
if not self._is_collector_connection_valid(
|
||||||
edge.destination.node_id, new_input=edge.source
|
edge.destination.node_id, new_input=edge.source
|
||||||
):
|
):
|
||||||
raise InvalidEdgeError(f'Collector output type does not match collector input type')
|
raise InvalidEdgeError(f'Collector output type does not match collector input type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
|
||||||
|
|
||||||
# Validate if collector output type matches input type (if this edge results in both being set)
|
# Validate if collector output type matches input type (if this edge results in both being set)
|
||||||
if isinstance(from_node, CollectInvocation) and edge.source.field == "collection":
|
if isinstance(from_node, CollectInvocation) and edge.source.field == "collection":
|
||||||
if not self._is_collector_connection_valid(
|
if not self._is_collector_connection_valid(
|
||||||
edge.source.node_id, new_output=edge.destination
|
edge.source.node_id, new_output=edge.destination
|
||||||
):
|
):
|
||||||
raise InvalidEdgeError(f'Collector input type does not match collector output type')
|
raise InvalidEdgeError(f'Collector input type does not match collector output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
|
||||||
|
|
||||||
|
|
||||||
def has_node(self, node_path: str) -> bool:
|
def has_node(self, node_path: str) -> bool:
|
||||||
@ -859,11 +857,9 @@ class GraphExecutionState(BaseModel):
|
|||||||
if next_node is None:
|
if next_node is None:
|
||||||
prepared_id = self._prepare()
|
prepared_id = self._prepare()
|
||||||
|
|
||||||
# TODO: prepare multiple nodes at once?
|
# Prepare as many nodes as we can
|
||||||
# while prepared_id is not None and not isinstance(self.graph.nodes[prepared_id], IterateInvocation):
|
while prepared_id is not None:
|
||||||
# prepared_id = self._prepare()
|
prepared_id = self._prepare()
|
||||||
|
|
||||||
if prepared_id is not None:
|
|
||||||
next_node = self._get_next_node()
|
next_node = self._get_next_node()
|
||||||
|
|
||||||
# Get values from edges
|
# Get values from edges
|
||||||
@ -1010,14 +1006,30 @@ class GraphExecutionState(BaseModel):
|
|||||||
# Get flattened source graph
|
# Get flattened source graph
|
||||||
g = self.graph.nx_graph_flat()
|
g = self.graph.nx_graph_flat()
|
||||||
|
|
||||||
# Find next unprepared node where all source nodes are executed
|
# Find next node that:
|
||||||
|
# - was not already prepared
|
||||||
|
# - is not an iterate node whose inputs have not been executed
|
||||||
|
# - does not have an unexecuted iterate ancestor
|
||||||
sorted_nodes = nx.topological_sort(g)
|
sorted_nodes = nx.topological_sort(g)
|
||||||
next_node_id = next(
|
next_node_id = next(
|
||||||
(
|
(
|
||||||
n
|
n
|
||||||
for n in sorted_nodes
|
for n in sorted_nodes
|
||||||
|
# exclude nodes that have already been prepared
|
||||||
if n not in self.source_prepared_mapping
|
if n not in self.source_prepared_mapping
|
||||||
and all((e[0] in self.executed for e in g.in_edges(n)))
|
# exclude iterate nodes whose inputs have not been executed
|
||||||
|
and not (
|
||||||
|
isinstance(self.graph.get_node(n), IterateInvocation) # `n` is an iterate node...
|
||||||
|
and not all((e[0] in self.executed for e in g.in_edges(n))) # ...that has unexecuted inputs
|
||||||
|
)
|
||||||
|
# exclude nodes who have unexecuted iterate ancestors
|
||||||
|
and not any(
|
||||||
|
(
|
||||||
|
isinstance(self.graph.get_node(a), IterateInvocation) # `a` is an iterate ancestor of `n`...
|
||||||
|
and a not in self.executed # ...that is not executed
|
||||||
|
for a in nx.ancestors(g, n) # for all ancestors `a` of node `n`
|
||||||
|
)
|
||||||
|
)
|
||||||
),
|
),
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
@ -1114,9 +1126,22 @@ class GraphExecutionState(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _get_next_node(self) -> Optional[BaseInvocation]:
|
def _get_next_node(self) -> Optional[BaseInvocation]:
|
||||||
|
"""Gets the deepest node that is ready to be executed"""
|
||||||
g = self.execution_graph.nx_graph()
|
g = self.execution_graph.nx_graph()
|
||||||
sorted_nodes = nx.topological_sort(g)
|
|
||||||
next_node = next((n for n in sorted_nodes if n not in self.executed), None)
|
# Depth-first search with pre-order traversal is a depth-first topological sort
|
||||||
|
sorted_nodes = nx.dfs_preorder_nodes(g)
|
||||||
|
|
||||||
|
next_node = next(
|
||||||
|
(
|
||||||
|
n
|
||||||
|
for n in sorted_nodes
|
||||||
|
if n not in self.executed # the node must not already be executed...
|
||||||
|
and all((e[0] in self.executed for e in g.in_edges(n))) # ...and all its inputs must be executed
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
if next_node is None:
|
if next_node is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@ -22,7 +22,8 @@ class Invoker:
|
|||||||
def invoke(
|
def invoke(
|
||||||
self, graph_execution_state: GraphExecutionState, invoke_all: bool = False
|
self, graph_execution_state: GraphExecutionState, invoke_all: bool = False
|
||||||
) -> str | None:
|
) -> str | None:
|
||||||
"""Determines the next node to invoke and returns the id of the invoked node, or None if there are no nodes to execute"""
|
"""Determines the next node to invoke and enqueues it, preparing if needed.
|
||||||
|
Returns the id of the queued node, or `None` if there are no nodes left to enqueue."""
|
||||||
|
|
||||||
# Get the next invocation
|
# Get the next invocation
|
||||||
invocation = graph_execution_state.next()
|
invocation = graph_execution_state.next()
|
||||||
|
@ -40,6 +40,7 @@ import invokeai.configs as configs
|
|||||||
from invokeai.app.services.config import (
|
from invokeai.app.services.config import (
|
||||||
InvokeAIAppConfig,
|
InvokeAIAppConfig,
|
||||||
)
|
)
|
||||||
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
from invokeai.frontend.install.model_install import addModelsForm, process_and_execute
|
from invokeai.frontend.install.model_install import addModelsForm, process_and_execute
|
||||||
from invokeai.frontend.install.widgets import (
|
from invokeai.frontend.install.widgets import (
|
||||||
CenteredButtonPress,
|
CenteredButtonPress,
|
||||||
@ -80,6 +81,7 @@ INIT_FILE_PREAMBLE = """# InvokeAI initialization file
|
|||||||
# or renaming it and then running invokeai-configure again.
|
# or renaming it and then running invokeai-configure again.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
logger=None
|
||||||
|
|
||||||
# --------------------------------------------
|
# --------------------------------------------
|
||||||
def postscript(errors: None):
|
def postscript(errors: None):
|
||||||
@ -824,6 +826,7 @@ def main():
|
|||||||
if opt.full_precision:
|
if opt.full_precision:
|
||||||
invoke_args.extend(['--precision','float32'])
|
invoke_args.extend(['--precision','float32'])
|
||||||
config.parse_args(invoke_args)
|
config.parse_args(invoke_args)
|
||||||
|
logger = InvokeAILogger().getLogger(config=config)
|
||||||
|
|
||||||
errors = set()
|
errors = set()
|
||||||
|
|
||||||
|
@ -784,7 +784,7 @@ class ModelManager(object):
|
|||||||
|
|
||||||
self.logger.info(f"Probing {thing} for import")
|
self.logger.info(f"Probing {thing} for import")
|
||||||
|
|
||||||
if thing.startswith(("http:", "https:", "ftp:")):
|
if str(thing).startswith(("http:", "https:", "ftp:")):
|
||||||
self.logger.info(f"{thing} appears to be a URL")
|
self.logger.info(f"{thing} appears to be a URL")
|
||||||
model_path = self._resolve_path(
|
model_path = self._resolve_path(
|
||||||
thing, "models/ldm/stable-diffusion-v1"
|
thing, "models/ldm/stable-diffusion-v1"
|
||||||
|
@ -218,7 +218,7 @@ class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
|
|||||||
class ControlNetData:
|
class ControlNetData:
|
||||||
model: ControlNetModel = Field(default=None)
|
model: ControlNetModel = Field(default=None)
|
||||||
image_tensor: torch.Tensor= Field(default=None)
|
image_tensor: torch.Tensor= Field(default=None)
|
||||||
weight: float = Field(default=1.0)
|
weight: Union[float, List[float]]= Field(default=1.0)
|
||||||
begin_step_percent: float = Field(default=0.0)
|
begin_step_percent: float = Field(default=0.0)
|
||||||
end_step_percent: float = Field(default=1.0)
|
end_step_percent: float = Field(default=1.0)
|
||||||
|
|
||||||
@ -226,7 +226,7 @@ class ControlNetData:
|
|||||||
class ConditioningData:
|
class ConditioningData:
|
||||||
unconditioned_embeddings: torch.Tensor
|
unconditioned_embeddings: torch.Tensor
|
||||||
text_embeddings: torch.Tensor
|
text_embeddings: torch.Tensor
|
||||||
guidance_scale: float
|
guidance_scale: Union[float, List[float]]
|
||||||
"""
|
"""
|
||||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
|
`guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
|
||||||
@ -662,7 +662,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
down_block_res_samples, mid_block_res_sample = None, None
|
down_block_res_samples, mid_block_res_sample = None, None
|
||||||
|
|
||||||
if control_data is not None:
|
if control_data is not None:
|
||||||
if conditioning_data.guidance_scale > 1.0:
|
# FIXME: make sure guidance_scale < 1.0 is handled correctly if doing per-step guidance setting
|
||||||
|
# if conditioning_data.guidance_scale > 1.0:
|
||||||
|
if conditioning_data.guidance_scale is not None:
|
||||||
# expand the latents input to control model if doing classifier free guidance
|
# expand the latents input to control model if doing classifier free guidance
|
||||||
# (which I think for now is always true, there is conditional elsewhere that stops execution if
|
# (which I think for now is always true, there is conditional elsewhere that stops execution if
|
||||||
# classifier_free_guidance is <= 1.0 ?)
|
# classifier_free_guidance is <= 1.0 ?)
|
||||||
@ -679,13 +681,19 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
# only apply controlnet if current step is within the controlnet's begin/end step range
|
# only apply controlnet if current step is within the controlnet's begin/end step range
|
||||||
if step_index >= first_control_step and step_index <= last_control_step:
|
if step_index >= first_control_step and step_index <= last_control_step:
|
||||||
# print("running controlnet", i, "for step", step_index)
|
# print("running controlnet", i, "for step", step_index)
|
||||||
|
if isinstance(control_datum.weight, list):
|
||||||
|
# if controlnet has multiple weights, use the weight for the current step
|
||||||
|
controlnet_weight = control_datum.weight[step_index]
|
||||||
|
else:
|
||||||
|
# if controlnet has a single weight, use it for all steps
|
||||||
|
controlnet_weight = control_datum.weight
|
||||||
down_samples, mid_sample = control_datum.model(
|
down_samples, mid_sample = control_datum.model(
|
||||||
sample=latent_control_input,
|
sample=latent_control_input,
|
||||||
timestep=timestep,
|
timestep=timestep,
|
||||||
encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings,
|
encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings,
|
||||||
conditioning_data.text_embeddings]),
|
conditioning_data.text_embeddings]),
|
||||||
controlnet_cond=control_datum.image_tensor,
|
controlnet_cond=control_datum.image_tensor,
|
||||||
conditioning_scale=control_datum.weight,
|
conditioning_scale=controlnet_weight,
|
||||||
# cross_attention_kwargs,
|
# cross_attention_kwargs,
|
||||||
guess_mode=False,
|
guess_mode=False,
|
||||||
return_dict=False,
|
return_dict=False,
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from math import ceil
|
from math import ceil
|
||||||
from typing import Any, Callable, Dict, Optional, Union
|
from typing import Any, Callable, Dict, Optional, Union, List
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -180,7 +180,8 @@ class InvokeAIDiffuserComponent:
|
|||||||
sigma: torch.Tensor,
|
sigma: torch.Tensor,
|
||||||
unconditioning: Union[torch.Tensor, dict],
|
unconditioning: Union[torch.Tensor, dict],
|
||||||
conditioning: Union[torch.Tensor, dict],
|
conditioning: Union[torch.Tensor, dict],
|
||||||
unconditional_guidance_scale: float,
|
# unconditional_guidance_scale: float,
|
||||||
|
unconditional_guidance_scale: Union[float, List[float]],
|
||||||
step_index: Optional[int] = None,
|
step_index: Optional[int] = None,
|
||||||
total_step_count: Optional[int] = None,
|
total_step_count: Optional[int] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@ -195,6 +196,11 @@ class InvokeAIDiffuserComponent:
|
|||||||
:return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning.
|
:return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if isinstance(unconditional_guidance_scale, list):
|
||||||
|
guidance_scale = unconditional_guidance_scale[step_index]
|
||||||
|
else:
|
||||||
|
guidance_scale = unconditional_guidance_scale
|
||||||
|
|
||||||
cross_attention_control_types_to_do = []
|
cross_attention_control_types_to_do = []
|
||||||
context: Context = self.cross_attention_control_context
|
context: Context = self.cross_attention_control_context
|
||||||
if self.cross_attention_control_context is not None:
|
if self.cross_attention_control_context is not None:
|
||||||
@ -243,7 +249,8 @@ class InvokeAIDiffuserComponent:
|
|||||||
)
|
)
|
||||||
|
|
||||||
combined_next_x = self._combine(
|
combined_next_x = self._combine(
|
||||||
unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale
|
# unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale
|
||||||
|
unconditioned_next_x, conditioned_next_x, guidance_scale
|
||||||
)
|
)
|
||||||
|
|
||||||
return combined_next_x
|
return combined_next_x
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
# Copyright (c) 2023 Lincoln D. Stein and The InvokeAI Development Team
|
# Copyright (c) 2023 Lincoln D. Stein and The InvokeAI Development Team
|
||||||
|
|
||||||
"""invokeai.util.logging
|
"""
|
||||||
|
invokeai.util.logging
|
||||||
|
|
||||||
Logging class for InvokeAI that produces console messages
|
Logging class for InvokeAI that produces console messages
|
||||||
|
|
||||||
@ -11,6 +12,7 @@ from invokeai.backend.util.logging import InvokeAILogger
|
|||||||
logger = InvokeAILogger.getLogger(name='InvokeAI') // Initialization
|
logger = InvokeAILogger.getLogger(name='InvokeAI') // Initialization
|
||||||
(or)
|
(or)
|
||||||
logger = InvokeAILogger.getLogger(__name__) // To use the filename
|
logger = InvokeAILogger.getLogger(__name__) // To use the filename
|
||||||
|
logger.configure()
|
||||||
|
|
||||||
logger.critical('this is critical') // Critical Message
|
logger.critical('this is critical') // Critical Message
|
||||||
logger.error('this is an error') // Error Message
|
logger.error('this is an error') // Error Message
|
||||||
@ -28,6 +30,149 @@ Console messages:
|
|||||||
Alternate Method (in this case the logger name will be set to InvokeAI):
|
Alternate Method (in this case the logger name will be set to InvokeAI):
|
||||||
import invokeai.backend.util.logging as IAILogger
|
import invokeai.backend.util.logging as IAILogger
|
||||||
IAILogger.debug('this is a debugging message')
|
IAILogger.debug('this is a debugging message')
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
The default configuration will print to stderr on the console. To add
|
||||||
|
additional logging handlers, call getLogger with an initialized InvokeAIAppConfig
|
||||||
|
object:
|
||||||
|
|
||||||
|
|
||||||
|
config = InvokeAIAppConfig.get_config()
|
||||||
|
config.parse_args()
|
||||||
|
logger = InvokeAILogger.getLogger(config=config)
|
||||||
|
|
||||||
|
### Three command-line options control logging:
|
||||||
|
|
||||||
|
`--log_handlers <handler1> <handler2> ...`
|
||||||
|
|
||||||
|
This option activates one or more log handlers. Options are "console", "file", "syslog" and "http". To specify more than one, separate them by spaces:
|
||||||
|
|
||||||
|
```
|
||||||
|
invokeai-web --log_handlers console syslog=/dev/log file=C:\\Users\\fred\\invokeai.log
|
||||||
|
```
|
||||||
|
|
||||||
|
The format of these options is described below.
|
||||||
|
|
||||||
|
### `--log_format {plain|color|legacy|syslog}`
|
||||||
|
|
||||||
|
This controls the format of log messages written to the console. Only the "console" log handler is currently affected by this setting.
|
||||||
|
|
||||||
|
* "plain" provides formatted messages like this:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
|
||||||
|
[2023-05-24 23:18:2[2023-05-24 23:18:50,352]::[InvokeAI]::DEBUG --> this is a debug message
|
||||||
|
[2023-05-24 23:18:50,352]::[InvokeAI]::INFO --> this is an informational messages
|
||||||
|
[2023-05-24 23:18:50,352]::[InvokeAI]::WARNING --> this is a warning
|
||||||
|
[2023-05-24 23:18:50,352]::[InvokeAI]::ERROR --> this is an error
|
||||||
|
[2023-05-24 23:18:50,352]::[InvokeAI]::CRITICAL --> this is a critical error
|
||||||
|
```
|
||||||
|
|
||||||
|
* "color" produces similar output, but the text will be color coded to indicate the severity of the message.
|
||||||
|
|
||||||
|
* "legacy" produces output similar to InvokeAI versions 2.3 and earlier:
|
||||||
|
|
||||||
|
```
|
||||||
|
### this is a critical error
|
||||||
|
*** this is an error
|
||||||
|
** this is a warning
|
||||||
|
>> this is an informational messages
|
||||||
|
| this is a debug message
|
||||||
|
```
|
||||||
|
|
||||||
|
* "syslog" produces messages suitable for syslog entries:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
InvokeAI [2691178] <CRITICAL> this is a critical error
|
||||||
|
InvokeAI [2691178] <ERROR> this is an error
|
||||||
|
InvokeAI [2691178] <WARNING> this is a warning
|
||||||
|
InvokeAI [2691178] <INFO> this is an informational messages
|
||||||
|
InvokeAI [2691178] <DEBUG> this is a debug message
|
||||||
|
```
|
||||||
|
|
||||||
|
(note that the date, time and hostname will be added by the syslog system)
|
||||||
|
|
||||||
|
### `--log_level {debug|info|warning|error|critical}`
|
||||||
|
|
||||||
|
Providing this command-line option will cause only messages at the specified level or above to be emitted.
|
||||||
|
|
||||||
|
## Console logging
|
||||||
|
|
||||||
|
When "console" is provided to `--log_handlers`, messages will be written to the command line window in which InvokeAI was launched. By default, the color formatter will be used unless overridden by `--log_format`.
|
||||||
|
|
||||||
|
## File logging
|
||||||
|
|
||||||
|
When "file" is provided to `--log_handlers`, entries will be written to the file indicated in the path argument. By default, the "plain" format will be used:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
invokeai-web --log_handlers file=/var/log/invokeai.log
|
||||||
|
```
|
||||||
|
|
||||||
|
## Syslog logging
|
||||||
|
|
||||||
|
When "syslog" is requested, entries will be sent to the syslog system. There are a variety of ways to control where the log message is sent:
|
||||||
|
|
||||||
|
* Send to the local machine using the `/dev/log` socket:
|
||||||
|
|
||||||
|
```
|
||||||
|
invokeai-web --log_handlers syslog=/dev/log
|
||||||
|
```
|
||||||
|
|
||||||
|
* Send to the local machine using a UDP message:
|
||||||
|
|
||||||
|
```
|
||||||
|
invokeai-web --log_handlers syslog=localhost
|
||||||
|
```
|
||||||
|
|
||||||
|
* Send to the local machine using a UDP message on a nonstandard port:
|
||||||
|
|
||||||
|
```
|
||||||
|
invokeai-web --log_handlers syslog=localhost:512
|
||||||
|
```
|
||||||
|
|
||||||
|
* Send to a remote machine named "loghost" on the local LAN using facility LOG_USER and UDP packets:
|
||||||
|
|
||||||
|
```
|
||||||
|
invokeai-web --log_handlers syslog=loghost,facility=LOG_USER,socktype=SOCK_DGRAM
|
||||||
|
```
|
||||||
|
|
||||||
|
This can be abbreviated `syslog=loghost`, as LOG_USER and SOCK_DGRAM are defaults.
|
||||||
|
|
||||||
|
* Send to a remote machine named "loghost" using the facility LOCAL0 and using a TCP socket:
|
||||||
|
|
||||||
|
```
|
||||||
|
invokeai-web --log_handlers syslog=loghost,facility=LOG_LOCAL0,socktype=SOCK_STREAM
|
||||||
|
```
|
||||||
|
|
||||||
|
If no arguments are specified (just a bare "syslog"), then the logging system will look for a UNIX socket named `/dev/log`, and if not found try to send a UDP message to `localhost`. The Macintosh OS used to support logging to a socket named `/var/run/syslog`, but this feature has since been disabled.
|
||||||
|
|
||||||
|
## Web logging
|
||||||
|
|
||||||
|
If you have access to a web server that is configured to log messages when a particular URL is requested, you can log using the "http" method:
|
||||||
|
|
||||||
|
```
|
||||||
|
invokeai-web --log_handlers http=http://my.server/path/to/logger,method=POST
|
||||||
|
```
|
||||||
|
|
||||||
|
The optional [,method=] part can be used to specify whether the URL accepts GET (default) or POST messages.
|
||||||
|
|
||||||
|
Currently password authentication and SSL are not supported.
|
||||||
|
|
||||||
|
## Using the configuration file
|
||||||
|
|
||||||
|
You can set and forget logging options by adding a "Logging" section to `invokeai.yaml`:
|
||||||
|
|
||||||
|
```
|
||||||
|
InvokeAI:
|
||||||
|
[... other settings...]
|
||||||
|
Logging:
|
||||||
|
log_handlers:
|
||||||
|
- console
|
||||||
|
- syslog=/dev/log
|
||||||
|
log_level: info
|
||||||
|
log_format: color
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging.handlers
|
import logging.handlers
|
||||||
@ -180,14 +325,17 @@ class InvokeAILogger(object):
|
|||||||
loggers = dict()
|
loggers = dict()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def getLogger(cls, name: str = 'InvokeAI') -> logging.Logger:
|
def getLogger(cls,
|
||||||
config = get_invokeai_config()
|
name: str = 'InvokeAI',
|
||||||
|
config: InvokeAIAppConfig=InvokeAIAppConfig.get_config())->logging.Logger:
|
||||||
if name not in cls.loggers:
|
if name in cls.loggers:
|
||||||
|
logger = cls.loggers[name]
|
||||||
|
logger.handlers.clear()
|
||||||
|
else:
|
||||||
logger = logging.getLogger(name)
|
logger = logging.getLogger(name)
|
||||||
logger.setLevel(config.log_level.upper()) # yes, strings work here
|
logger.setLevel(config.log_level.upper()) # yes, strings work here
|
||||||
for ch in cls.getLoggers(config):
|
for ch in cls.getLoggers(config):
|
||||||
logger.addHandler(ch)
|
logger.addHandler(ch)
|
||||||
cls.loggers[name] = logger
|
cls.loggers[name] = logger
|
||||||
return cls.loggers[name]
|
return cls.loggers[name]
|
||||||
|
|
||||||
@ -199,9 +347,11 @@ class InvokeAILogger(object):
|
|||||||
handler_name,*args = handler.split('=',2)
|
handler_name,*args = handler.split('=',2)
|
||||||
args = args[0] if len(args) > 0 else None
|
args = args[0] if len(args) > 0 else None
|
||||||
|
|
||||||
# console is the only handler that gets a custom formatter
|
# console and file get the fancy formatter.
|
||||||
|
# syslog gets a simple one
|
||||||
|
# http gets no custom formatter
|
||||||
|
formatter = LOG_FORMATTERS[config.log_format]
|
||||||
if handler_name=='console':
|
if handler_name=='console':
|
||||||
formatter = LOG_FORMATTERS[config.log_format]
|
|
||||||
ch = logging.StreamHandler()
|
ch = logging.StreamHandler()
|
||||||
ch.setFormatter(formatter())
|
ch.setFormatter(formatter())
|
||||||
handlers.append(ch)
|
handlers.append(ch)
|
||||||
@ -212,7 +362,9 @@ class InvokeAILogger(object):
|
|||||||
handlers.append(ch)
|
handlers.append(ch)
|
||||||
|
|
||||||
elif handler_name=='file':
|
elif handler_name=='file':
|
||||||
handlers.append(cls._parse_file_args(args))
|
ch = cls._parse_file_args(args)
|
||||||
|
ch.setFormatter(formatter())
|
||||||
|
handlers.append(ch)
|
||||||
|
|
||||||
elif handler_name=='http':
|
elif handler_name=='http':
|
||||||
handlers.append(cls._parse_http_args(args))
|
handlers.append(cls._parse_http_args(args))
|
||||||
|
@ -28,7 +28,7 @@ import torch
|
|||||||
from npyscreen import widget
|
from npyscreen import widget
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
|
|
||||||
from invokeai.backend.install.model_install_backend import (
|
from invokeai.backend.install.model_install_backend import (
|
||||||
Dataset_path,
|
Dataset_path,
|
||||||
@ -939,6 +939,7 @@ def main():
|
|||||||
if opt.full_precision:
|
if opt.full_precision:
|
||||||
invoke_args.extend(['--precision','float32'])
|
invoke_args.extend(['--precision','float32'])
|
||||||
config.parse_args(invoke_args)
|
config.parse_args(invoke_args)
|
||||||
|
logger = InvokeAILogger().getLogger(config=config)
|
||||||
|
|
||||||
if not (config.root_dir / config.conf_path.parent).exists():
|
if not (config.root_dir / config.conf_path.parent).exists():
|
||||||
logger.info(
|
logger.info(
|
||||||
|
@ -22,6 +22,7 @@ import { APP_HEIGHT, APP_WIDTH } from 'theme/util/constants';
|
|||||||
import GlobalHotkeys from './GlobalHotkeys';
|
import GlobalHotkeys from './GlobalHotkeys';
|
||||||
import Toaster from './Toaster';
|
import Toaster from './Toaster';
|
||||||
import DeleteImageModal from 'features/gallery/components/DeleteImageModal';
|
import DeleteImageModal from 'features/gallery/components/DeleteImageModal';
|
||||||
|
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
||||||
|
|
||||||
const DEFAULT_CONFIG = {};
|
const DEFAULT_CONFIG = {};
|
||||||
|
|
||||||
@ -66,10 +67,17 @@ const App = ({
|
|||||||
setIsReady(true);
|
setIsReady(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (isApplicationReady) {
|
||||||
|
// TODO: This is a jank fix for canvas not filling the screen on first load
|
||||||
|
setTimeout(() => {
|
||||||
|
dispatch(requestCanvasRescale());
|
||||||
|
}, 200);
|
||||||
|
}
|
||||||
|
|
||||||
return () => {
|
return () => {
|
||||||
setIsReady && setIsReady(false);
|
setIsReady && setIsReady(false);
|
||||||
};
|
};
|
||||||
}, [isApplicationReady, setIsReady]);
|
}, [dispatch, isApplicationReady, setIsReady]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
|
@ -40,11 +40,11 @@ const ImageDndContext = (props: ImageDndContextProps) => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
const mouseSensor = useSensor(MouseSensor, {
|
const mouseSensor = useSensor(MouseSensor, {
|
||||||
activationConstraint: { delay: 250, tolerance: 5 },
|
activationConstraint: { delay: 150, tolerance: 5 },
|
||||||
});
|
});
|
||||||
|
|
||||||
const touchSensor = useSensor(TouchSensor, {
|
const touchSensor = useSensor(TouchSensor, {
|
||||||
activationConstraint: { delay: 250, tolerance: 5 },
|
activationConstraint: { delay: 150, tolerance: 5 },
|
||||||
});
|
});
|
||||||
// TODO: Use KeyboardSensor - needs composition of multiple collisionDetection algos
|
// TODO: Use KeyboardSensor - needs composition of multiple collisionDetection algos
|
||||||
// Alternatively, fix `rectIntersection` collection detection to work with the drag overlay
|
// Alternatively, fix `rectIntersection` collection detection to work with the drag overlay
|
||||||
|
@ -1,3 +1,7 @@
|
|||||||
|
import {
|
||||||
|
CONTROLNET_MODELS,
|
||||||
|
CONTROLNET_PROCESSORS,
|
||||||
|
} from 'features/controlNet/store/constants';
|
||||||
import { InvokeTabName } from 'features/ui/store/tabMap';
|
import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||||
import { O } from 'ts-toolbelt';
|
import { O } from 'ts-toolbelt';
|
||||||
|
|
||||||
@ -117,6 +121,8 @@ export type AppConfig = {
|
|||||||
canRestoreDeletedImagesFromBin: boolean;
|
canRestoreDeletedImagesFromBin: boolean;
|
||||||
sd: {
|
sd: {
|
||||||
defaultModel?: string;
|
defaultModel?: string;
|
||||||
|
disabledControlNetModels: (keyof typeof CONTROLNET_MODELS)[];
|
||||||
|
disabledControlNetProcessors: (keyof typeof CONTROLNET_PROCESSORS)[];
|
||||||
iterations: {
|
iterations: {
|
||||||
initial: number;
|
initial: number;
|
||||||
min: number;
|
min: number;
|
||||||
|
@ -2,7 +2,6 @@ import { CheckIcon, ChevronUpIcon } from '@chakra-ui/icons';
|
|||||||
import {
|
import {
|
||||||
Box,
|
Box,
|
||||||
Flex,
|
Flex,
|
||||||
FlexProps,
|
|
||||||
FormControl,
|
FormControl,
|
||||||
FormControlProps,
|
FormControlProps,
|
||||||
FormLabel,
|
FormLabel,
|
||||||
@ -16,42 +15,64 @@ import {
|
|||||||
} from '@chakra-ui/react';
|
} from '@chakra-ui/react';
|
||||||
import { autoUpdate, offset, shift, useFloating } from '@floating-ui/react-dom';
|
import { autoUpdate, offset, shift, useFloating } from '@floating-ui/react-dom';
|
||||||
import { useSelect } from 'downshift';
|
import { useSelect } from 'downshift';
|
||||||
|
import { isString } from 'lodash-es';
|
||||||
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
||||||
|
|
||||||
import { memo, useMemo } from 'react';
|
import { memo, useLayoutEffect, useMemo } from 'react';
|
||||||
import { getInputOutlineStyles } from 'theme/util/getInputOutlineStyles';
|
import { getInputOutlineStyles } from 'theme/util/getInputOutlineStyles';
|
||||||
|
|
||||||
export type ItemTooltips = { [key: string]: string };
|
export type ItemTooltips = { [key: string]: string };
|
||||||
|
|
||||||
|
export type IAICustomSelectOption = {
|
||||||
|
value: string;
|
||||||
|
label: string;
|
||||||
|
tooltip?: string;
|
||||||
|
};
|
||||||
|
|
||||||
type IAICustomSelectProps = {
|
type IAICustomSelectProps = {
|
||||||
label?: string;
|
label?: string;
|
||||||
items: string[];
|
value: string;
|
||||||
itemTooltips?: ItemTooltips;
|
data: IAICustomSelectOption[] | string[];
|
||||||
selectedItem: string;
|
onChange: (v: string) => void;
|
||||||
setSelectedItem: (v: string | null | undefined) => void;
|
|
||||||
withCheckIcon?: boolean;
|
withCheckIcon?: boolean;
|
||||||
formControlProps?: FormControlProps;
|
formControlProps?: FormControlProps;
|
||||||
buttonProps?: FlexProps;
|
|
||||||
tooltip?: string;
|
tooltip?: string;
|
||||||
tooltipProps?: Omit<TooltipProps, 'children'>;
|
tooltipProps?: Omit<TooltipProps, 'children'>;
|
||||||
ellipsisPosition?: 'start' | 'end';
|
ellipsisPosition?: 'start' | 'end';
|
||||||
|
isDisabled?: boolean;
|
||||||
};
|
};
|
||||||
|
|
||||||
const IAICustomSelect = (props: IAICustomSelectProps) => {
|
const IAICustomSelect = (props: IAICustomSelectProps) => {
|
||||||
const {
|
const {
|
||||||
label,
|
label,
|
||||||
items,
|
|
||||||
itemTooltips,
|
|
||||||
setSelectedItem,
|
|
||||||
selectedItem,
|
|
||||||
withCheckIcon,
|
withCheckIcon,
|
||||||
formControlProps,
|
formControlProps,
|
||||||
tooltip,
|
tooltip,
|
||||||
buttonProps,
|
|
||||||
tooltipProps,
|
tooltipProps,
|
||||||
ellipsisPosition = 'end',
|
ellipsisPosition = 'end',
|
||||||
|
data,
|
||||||
|
value,
|
||||||
|
onChange,
|
||||||
|
isDisabled = false,
|
||||||
} = props;
|
} = props;
|
||||||
|
|
||||||
|
const values = useMemo(() => {
|
||||||
|
return data.map<IAICustomSelectOption>((v) => {
|
||||||
|
if (isString(v)) {
|
||||||
|
return { value: v, label: v };
|
||||||
|
}
|
||||||
|
return v;
|
||||||
|
});
|
||||||
|
}, [data]);
|
||||||
|
|
||||||
|
const stringValues = useMemo(() => {
|
||||||
|
return values.map((v) => v.value);
|
||||||
|
}, [values]);
|
||||||
|
|
||||||
|
const valueData = useMemo(() => {
|
||||||
|
return values.find((v) => v.value === value);
|
||||||
|
}, [values, value]);
|
||||||
|
|
||||||
const {
|
const {
|
||||||
isOpen,
|
isOpen,
|
||||||
getToggleButtonProps,
|
getToggleButtonProps,
|
||||||
@ -60,17 +81,24 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
|
|||||||
highlightedIndex,
|
highlightedIndex,
|
||||||
getItemProps,
|
getItemProps,
|
||||||
} = useSelect({
|
} = useSelect({
|
||||||
items,
|
items: stringValues,
|
||||||
selectedItem,
|
selectedItem: value,
|
||||||
onSelectedItemChange: ({ selectedItem: newSelectedItem }) =>
|
onSelectedItemChange: ({ selectedItem: newSelectedItem }) => {
|
||||||
setSelectedItem(newSelectedItem),
|
newSelectedItem && onChange(newSelectedItem);
|
||||||
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
const { refs, floatingStyles } = useFloating<HTMLButtonElement>({
|
const { refs, floatingStyles, update } = useFloating<HTMLButtonElement>({
|
||||||
whileElementsMounted: autoUpdate,
|
// whileElementsMounted: autoUpdate,
|
||||||
middleware: [offset(4), shift({ crossAxis: true, padding: 8 })],
|
middleware: [offset(4), shift({ crossAxis: true, padding: 8 })],
|
||||||
});
|
});
|
||||||
|
|
||||||
|
useLayoutEffect(() => {
|
||||||
|
if (isOpen && refs.reference.current && refs.floating.current) {
|
||||||
|
return autoUpdate(refs.reference.current, refs.floating.current, update);
|
||||||
|
}
|
||||||
|
}, [isOpen, update, refs.floating, refs.reference]);
|
||||||
|
|
||||||
const labelTextDirection = useMemo(() => {
|
const labelTextDirection = useMemo(() => {
|
||||||
if (ellipsisPosition === 'start') {
|
if (ellipsisPosition === 'start') {
|
||||||
return document.dir === 'rtl' ? 'ltr' : 'rtl';
|
return document.dir === 'rtl' ? 'ltr' : 'rtl';
|
||||||
@ -93,8 +121,7 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
|
|||||||
)}
|
)}
|
||||||
<Tooltip label={tooltip} {...tooltipProps}>
|
<Tooltip label={tooltip} {...tooltipProps}>
|
||||||
<Flex
|
<Flex
|
||||||
{...getToggleButtonProps({ ref: refs.setReference })}
|
{...getToggleButtonProps({ ref: refs.reference })}
|
||||||
{...buttonProps}
|
|
||||||
sx={{
|
sx={{
|
||||||
alignItems: 'center',
|
alignItems: 'center',
|
||||||
userSelect: 'none',
|
userSelect: 'none',
|
||||||
@ -105,6 +132,8 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
|
|||||||
px: 2,
|
px: 2,
|
||||||
gap: 2,
|
gap: 2,
|
||||||
justifyContent: 'space-between',
|
justifyContent: 'space-between',
|
||||||
|
pointerEvents: isDisabled ? 'none' : undefined,
|
||||||
|
opacity: isDisabled ? 0.5 : undefined,
|
||||||
...getInputOutlineStyles(),
|
...getInputOutlineStyles(),
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
@ -119,7 +148,7 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
|
|||||||
direction: labelTextDirection,
|
direction: labelTextDirection,
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
{selectedItem}
|
{valueData?.label}
|
||||||
</Text>
|
</Text>
|
||||||
<ChevronUpIcon
|
<ChevronUpIcon
|
||||||
sx={{
|
sx={{
|
||||||
@ -135,7 +164,7 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
|
|||||||
{isOpen && (
|
{isOpen && (
|
||||||
<List
|
<List
|
||||||
as={Flex}
|
as={Flex}
|
||||||
ref={refs.setFloating}
|
ref={refs.floating}
|
||||||
sx={{
|
sx={{
|
||||||
...floatingStyles,
|
...floatingStyles,
|
||||||
top: 0,
|
top: 0,
|
||||||
@ -155,8 +184,8 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
|
|||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<OverlayScrollbarsComponent>
|
<OverlayScrollbarsComponent>
|
||||||
{items.map((item, index) => {
|
{values.map((v, index) => {
|
||||||
const isSelected = selectedItem === item;
|
const isSelected = value === v.value;
|
||||||
const isHighlighted = highlightedIndex === index;
|
const isHighlighted = highlightedIndex === index;
|
||||||
const fontWeight = isSelected ? 700 : 500;
|
const fontWeight = isSelected ? 700 : 500;
|
||||||
const bg = isHighlighted
|
const bg = isHighlighted
|
||||||
@ -166,9 +195,9 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
|
|||||||
: undefined;
|
: undefined;
|
||||||
return (
|
return (
|
||||||
<Tooltip
|
<Tooltip
|
||||||
isDisabled={!itemTooltips}
|
isDisabled={!v.tooltip}
|
||||||
key={`${item}${index}`}
|
key={`${v.value}${index}`}
|
||||||
label={itemTooltips?.[item]}
|
label={v.tooltip}
|
||||||
hasArrow
|
hasArrow
|
||||||
placement="right"
|
placement="right"
|
||||||
>
|
>
|
||||||
@ -182,8 +211,7 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
|
|||||||
transitionProperty: 'common',
|
transitionProperty: 'common',
|
||||||
transitionDuration: '0.15s',
|
transitionDuration: '0.15s',
|
||||||
}}
|
}}
|
||||||
key={`${item}${index}`}
|
{...getItemProps({ item: v.value, index })}
|
||||||
{...getItemProps({ item, index })}
|
|
||||||
>
|
>
|
||||||
{withCheckIcon ? (
|
{withCheckIcon ? (
|
||||||
<Grid gridTemplateColumns="1.25rem auto">
|
<Grid gridTemplateColumns="1.25rem auto">
|
||||||
@ -198,7 +226,7 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
|
|||||||
fontWeight,
|
fontWeight,
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
{item}
|
{v.label}
|
||||||
</Text>
|
</Text>
|
||||||
</GridItem>
|
</GridItem>
|
||||||
</Grid>
|
</Grid>
|
||||||
@ -210,7 +238,7 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
|
|||||||
fontWeight,
|
fontWeight,
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
{item}
|
{v.label}
|
||||||
</Text>
|
</Text>
|
||||||
)}
|
)}
|
||||||
</ListItem>
|
</ListItem>
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import {
|
import {
|
||||||
|
ChakraProps,
|
||||||
FormControl,
|
FormControl,
|
||||||
FormControlProps,
|
FormControlProps,
|
||||||
FormLabel,
|
FormLabel,
|
||||||
@ -39,6 +40,11 @@ import { BiReset } from 'react-icons/bi';
|
|||||||
import IAIIconButton, { IAIIconButtonProps } from './IAIIconButton';
|
import IAIIconButton, { IAIIconButtonProps } from './IAIIconButton';
|
||||||
import { roundDownToMultiple } from 'common/util/roundDownToMultiple';
|
import { roundDownToMultiple } from 'common/util/roundDownToMultiple';
|
||||||
|
|
||||||
|
const SLIDER_MARK_STYLES: ChakraProps['sx'] = {
|
||||||
|
mt: 1.5,
|
||||||
|
fontSize: '2xs',
|
||||||
|
};
|
||||||
|
|
||||||
export type IAIFullSliderProps = {
|
export type IAIFullSliderProps = {
|
||||||
label?: string;
|
label?: string;
|
||||||
value: number;
|
value: number;
|
||||||
@ -57,6 +63,7 @@ export type IAIFullSliderProps = {
|
|||||||
hideTooltip?: boolean;
|
hideTooltip?: boolean;
|
||||||
isCompact?: boolean;
|
isCompact?: boolean;
|
||||||
isDisabled?: boolean;
|
isDisabled?: boolean;
|
||||||
|
sliderMarks?: number[];
|
||||||
sliderFormControlProps?: FormControlProps;
|
sliderFormControlProps?: FormControlProps;
|
||||||
sliderFormLabelProps?: FormLabelProps;
|
sliderFormLabelProps?: FormLabelProps;
|
||||||
sliderMarkProps?: Omit<SliderMarkProps, 'value'>;
|
sliderMarkProps?: Omit<SliderMarkProps, 'value'>;
|
||||||
@ -88,6 +95,7 @@ const IAISlider = (props: IAIFullSliderProps) => {
|
|||||||
hideTooltip = false,
|
hideTooltip = false,
|
||||||
isCompact = false,
|
isCompact = false,
|
||||||
isDisabled = false,
|
isDisabled = false,
|
||||||
|
sliderMarks,
|
||||||
handleReset,
|
handleReset,
|
||||||
sliderFormControlProps,
|
sliderFormControlProps,
|
||||||
sliderFormLabelProps,
|
sliderFormLabelProps,
|
||||||
@ -198,14 +206,14 @@ const IAISlider = (props: IAIFullSliderProps) => {
|
|||||||
isDisabled={isDisabled}
|
isDisabled={isDisabled}
|
||||||
{...rest}
|
{...rest}
|
||||||
>
|
>
|
||||||
{withSliderMarks && (
|
{withSliderMarks && !sliderMarks && (
|
||||||
<>
|
<>
|
||||||
<SliderMark
|
<SliderMark
|
||||||
value={min}
|
value={min}
|
||||||
sx={{
|
sx={{
|
||||||
insetInlineStart: '0 !important',
|
insetInlineStart: '0 !important',
|
||||||
insetInlineEnd: 'unset !important',
|
insetInlineEnd: 'unset !important',
|
||||||
mt: 1.5,
|
...SLIDER_MARK_STYLES,
|
||||||
}}
|
}}
|
||||||
{...sliderMarkProps}
|
{...sliderMarkProps}
|
||||||
>
|
>
|
||||||
@ -216,7 +224,7 @@ const IAISlider = (props: IAIFullSliderProps) => {
|
|||||||
sx={{
|
sx={{
|
||||||
insetInlineStart: 'unset !important',
|
insetInlineStart: 'unset !important',
|
||||||
insetInlineEnd: '0 !important',
|
insetInlineEnd: '0 !important',
|
||||||
mt: 1.5,
|
...SLIDER_MARK_STYLES,
|
||||||
}}
|
}}
|
||||||
{...sliderMarkProps}
|
{...sliderMarkProps}
|
||||||
>
|
>
|
||||||
@ -224,6 +232,56 @@ const IAISlider = (props: IAIFullSliderProps) => {
|
|||||||
</SliderMark>
|
</SliderMark>
|
||||||
</>
|
</>
|
||||||
)}
|
)}
|
||||||
|
{withSliderMarks && sliderMarks && (
|
||||||
|
<>
|
||||||
|
{sliderMarks.map((m, i) => {
|
||||||
|
if (i === 0) {
|
||||||
|
return (
|
||||||
|
<SliderMark
|
||||||
|
key={m}
|
||||||
|
value={m}
|
||||||
|
sx={{
|
||||||
|
insetInlineStart: '0 !important',
|
||||||
|
insetInlineEnd: 'unset !important',
|
||||||
|
...SLIDER_MARK_STYLES,
|
||||||
|
}}
|
||||||
|
{...sliderMarkProps}
|
||||||
|
>
|
||||||
|
{m}
|
||||||
|
</SliderMark>
|
||||||
|
);
|
||||||
|
} else if (i === sliderMarks.length - 1) {
|
||||||
|
return (
|
||||||
|
<SliderMark
|
||||||
|
key={m}
|
||||||
|
value={m}
|
||||||
|
sx={{
|
||||||
|
insetInlineStart: 'unset !important',
|
||||||
|
insetInlineEnd: '0 !important',
|
||||||
|
...SLIDER_MARK_STYLES,
|
||||||
|
}}
|
||||||
|
{...sliderMarkProps}
|
||||||
|
>
|
||||||
|
{m}
|
||||||
|
</SliderMark>
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
return (
|
||||||
|
<SliderMark
|
||||||
|
key={m}
|
||||||
|
value={m}
|
||||||
|
sx={{
|
||||||
|
...SLIDER_MARK_STYLES,
|
||||||
|
}}
|
||||||
|
{...sliderMarkProps}
|
||||||
|
>
|
||||||
|
{m}
|
||||||
|
</SliderMark>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
})}
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
|
||||||
<SliderTrack {...sliderTrackProps}>
|
<SliderTrack {...sliderTrackProps}>
|
||||||
<SliderFilledTrack />
|
<SliderFilledTrack />
|
||||||
|
@ -16,7 +16,6 @@ import {
|
|||||||
setShouldShowIntermediates,
|
setShouldShowIntermediates,
|
||||||
setShouldSnapToGrid,
|
setShouldSnapToGrid,
|
||||||
} from 'features/canvas/store/canvasSlice';
|
} from 'features/canvas/store/canvasSlice';
|
||||||
import EmptyTempFolderButtonModal from 'features/system/components/ClearTempFolderButtonModal';
|
|
||||||
import { isEqual } from 'lodash-es';
|
import { isEqual } from 'lodash-es';
|
||||||
|
|
||||||
import { ChangeEvent } from 'react';
|
import { ChangeEvent } from 'react';
|
||||||
@ -159,7 +158,6 @@ const IAICanvasSettingsButtonPopover = () => {
|
|||||||
onChange={(e) => dispatch(setShouldAntialias(e.target.checked))}
|
onChange={(e) => dispatch(setShouldAntialias(e.target.checked))}
|
||||||
/>
|
/>
|
||||||
<ClearCanvasHistoryButtonModal />
|
<ClearCanvasHistoryButtonModal />
|
||||||
<EmptyTempFolderButtonModal />
|
|
||||||
</Flex>
|
</Flex>
|
||||||
</IAIPopover>
|
</IAIPopover>
|
||||||
);
|
);
|
||||||
|
@ -30,7 +30,10 @@ import {
|
|||||||
} from './canvasTypes';
|
} from './canvasTypes';
|
||||||
import { ImageDTO } from 'services/api';
|
import { ImageDTO } from 'services/api';
|
||||||
import { sessionCanceled } from 'services/thunks/session';
|
import { sessionCanceled } from 'services/thunks/session';
|
||||||
import { setShouldUseCanvasBetaLayout } from 'features/ui/store/uiSlice';
|
import {
|
||||||
|
setActiveTab,
|
||||||
|
setShouldUseCanvasBetaLayout,
|
||||||
|
} from 'features/ui/store/uiSlice';
|
||||||
import { imageUrlsReceived } from 'services/thunks/image';
|
import { imageUrlsReceived } from 'services/thunks/image';
|
||||||
|
|
||||||
export const initialLayerState: CanvasLayerState = {
|
export const initialLayerState: CanvasLayerState = {
|
||||||
@ -857,6 +860,11 @@ export const canvasSlice = createSlice({
|
|||||||
builder.addCase(setShouldUseCanvasBetaLayout, (state, action) => {
|
builder.addCase(setShouldUseCanvasBetaLayout, (state, action) => {
|
||||||
state.doesCanvasNeedScaling = true;
|
state.doesCanvasNeedScaling = true;
|
||||||
});
|
});
|
||||||
|
|
||||||
|
builder.addCase(setActiveTab, (state, action) => {
|
||||||
|
state.doesCanvasNeedScaling = true;
|
||||||
|
});
|
||||||
|
|
||||||
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
|
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
|
||||||
const { image_name, image_origin, image_url, thumbnail_url } =
|
const { image_name, image_origin, image_url, thumbnail_url } =
|
||||||
action.payload;
|
action.payload;
|
||||||
|
@ -143,7 +143,7 @@ const ControlNet = (props: ControlNetProps) => {
|
|||||||
flexDir: 'column',
|
flexDir: 'column',
|
||||||
gap: 2,
|
gap: 2,
|
||||||
w: 'full',
|
w: 'full',
|
||||||
h: 24,
|
h: isExpanded ? 28 : 24,
|
||||||
paddingInlineStart: 1,
|
paddingInlineStart: 1,
|
||||||
paddingInlineEnd: isExpanded ? 1 : 0,
|
paddingInlineEnd: isExpanded ? 1 : 0,
|
||||||
pb: 2,
|
pb: 2,
|
||||||
@ -153,13 +153,13 @@ const ControlNet = (props: ControlNetProps) => {
|
|||||||
<ParamControlNetWeight
|
<ParamControlNetWeight
|
||||||
controlNetId={controlNetId}
|
controlNetId={controlNetId}
|
||||||
weight={weight}
|
weight={weight}
|
||||||
mini
|
mini={!isExpanded}
|
||||||
/>
|
/>
|
||||||
<ParamControlNetBeginEnd
|
<ParamControlNetBeginEnd
|
||||||
controlNetId={controlNetId}
|
controlNetId={controlNetId}
|
||||||
beginStepPct={beginStepPct}
|
beginStepPct={beginStepPct}
|
||||||
endStepPct={endStepPct}
|
endStepPct={endStepPct}
|
||||||
mini
|
mini={!isExpanded}
|
||||||
/>
|
/>
|
||||||
</Flex>
|
</Flex>
|
||||||
{!isExpanded && (
|
{!isExpanded && (
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
import IAISwitch from 'common/components/IAISwitch';
|
import IAISwitch from 'common/components/IAISwitch';
|
||||||
|
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||||
import { controlNetAutoConfigToggled } from 'features/controlNet/store/controlNetSlice';
|
import { controlNetAutoConfigToggled } from 'features/controlNet/store/controlNetSlice';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
|
|
||||||
@ -11,7 +12,7 @@ type Props = {
|
|||||||
const ParamControlNetShouldAutoConfig = (props: Props) => {
|
const ParamControlNetShouldAutoConfig = (props: Props) => {
|
||||||
const { controlNetId, shouldAutoConfig } = props;
|
const { controlNetId, shouldAutoConfig } = props;
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
const isReady = useIsReadyToInvoke();
|
||||||
const handleShouldAutoConfigChanged = useCallback(() => {
|
const handleShouldAutoConfigChanged = useCallback(() => {
|
||||||
dispatch(controlNetAutoConfigToggled({ controlNetId }));
|
dispatch(controlNetAutoConfigToggled({ controlNetId }));
|
||||||
}, [controlNetId, dispatch]);
|
}, [controlNetId, dispatch]);
|
||||||
@ -22,6 +23,7 @@ const ParamControlNetShouldAutoConfig = (props: Props) => {
|
|||||||
aria-label="Auto configure processor"
|
aria-label="Auto configure processor"
|
||||||
isChecked={shouldAutoConfig}
|
isChecked={shouldAutoConfig}
|
||||||
onChange={handleShouldAutoConfigChanged}
|
onChange={handleShouldAutoConfigChanged}
|
||||||
|
isDisabled={!isReady}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import {
|
import {
|
||||||
|
ChakraProps,
|
||||||
FormControl,
|
FormControl,
|
||||||
FormLabel,
|
FormLabel,
|
||||||
HStack,
|
HStack,
|
||||||
@ -10,14 +11,19 @@ import {
|
|||||||
Tooltip,
|
Tooltip,
|
||||||
} from '@chakra-ui/react';
|
} from '@chakra-ui/react';
|
||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
import IAIIconButton from 'common/components/IAIIconButton';
|
|
||||||
import {
|
import {
|
||||||
controlNetBeginStepPctChanged,
|
controlNetBeginStepPctChanged,
|
||||||
controlNetEndStepPctChanged,
|
controlNetEndStepPctChanged,
|
||||||
} from 'features/controlNet/store/controlNetSlice';
|
} from 'features/controlNet/store/controlNetSlice';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { BiReset } from 'react-icons/bi';
|
|
||||||
|
const SLIDER_MARK_STYLES: ChakraProps['sx'] = {
|
||||||
|
mt: 1.5,
|
||||||
|
fontSize: '2xs',
|
||||||
|
fontWeight: '500',
|
||||||
|
color: 'base.400',
|
||||||
|
};
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
controlNetId: string;
|
controlNetId: string;
|
||||||
@ -29,7 +35,7 @@ type Props = {
|
|||||||
const formatPct = (v: number) => `${Math.round(v * 100)}%`;
|
const formatPct = (v: number) => `${Math.round(v * 100)}%`;
|
||||||
|
|
||||||
const ParamControlNetBeginEnd = (props: Props) => {
|
const ParamControlNetBeginEnd = (props: Props) => {
|
||||||
const { controlNetId, beginStepPct, endStepPct, mini = false } = props;
|
const { controlNetId, beginStepPct, mini = false, endStepPct } = props;
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
@ -75,12 +81,9 @@ const ParamControlNetBeginEnd = (props: Props) => {
|
|||||||
<RangeSliderMark
|
<RangeSliderMark
|
||||||
value={0}
|
value={0}
|
||||||
sx={{
|
sx={{
|
||||||
fontSize: 'xs',
|
|
||||||
fontWeight: '500',
|
|
||||||
color: 'base.200',
|
|
||||||
insetInlineStart: '0 !important',
|
insetInlineStart: '0 !important',
|
||||||
insetInlineEnd: 'unset !important',
|
insetInlineEnd: 'unset !important',
|
||||||
mt: 1.5,
|
...SLIDER_MARK_STYLES,
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
0%
|
0%
|
||||||
@ -88,10 +91,7 @@ const ParamControlNetBeginEnd = (props: Props) => {
|
|||||||
<RangeSliderMark
|
<RangeSliderMark
|
||||||
value={0.5}
|
value={0.5}
|
||||||
sx={{
|
sx={{
|
||||||
fontSize: 'xs',
|
...SLIDER_MARK_STYLES,
|
||||||
fontWeight: '500',
|
|
||||||
color: 'base.200',
|
|
||||||
mt: 1.5,
|
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
50%
|
50%
|
||||||
@ -99,12 +99,9 @@ const ParamControlNetBeginEnd = (props: Props) => {
|
|||||||
<RangeSliderMark
|
<RangeSliderMark
|
||||||
value={1}
|
value={1}
|
||||||
sx={{
|
sx={{
|
||||||
fontSize: 'xs',
|
|
||||||
fontWeight: '500',
|
|
||||||
color: 'base.200',
|
|
||||||
insetInlineStart: 'unset !important',
|
insetInlineStart: 'unset !important',
|
||||||
insetInlineEnd: '0 !important',
|
insetInlineEnd: '0 !important',
|
||||||
mt: 1.5,
|
...SLIDER_MARK_STYLES,
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
100%
|
100%
|
||||||
@ -112,16 +109,6 @@ const ParamControlNetBeginEnd = (props: Props) => {
|
|||||||
</>
|
</>
|
||||||
)}
|
)}
|
||||||
</RangeSlider>
|
</RangeSlider>
|
||||||
|
|
||||||
{!mini && (
|
|
||||||
<IAIIconButton
|
|
||||||
size="sm"
|
|
||||||
aria-label={t('accessibility.reset')}
|
|
||||||
tooltip={t('accessibility.reset')}
|
|
||||||
icon={<BiReset />}
|
|
||||||
onClick={handleStepPctReset}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
</HStack>
|
</HStack>
|
||||||
</FormControl>
|
</FormControl>
|
||||||
);
|
);
|
||||||
|
@ -1,41 +1,85 @@
|
|||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import IAICustomSelect from 'common/components/IAICustomSelect';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import IAICustomSelect, {
|
||||||
|
IAICustomSelectOption,
|
||||||
|
} from 'common/components/IAICustomSelect';
|
||||||
|
import IAISelect from 'common/components/IAISelect';
|
||||||
|
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||||
import {
|
import {
|
||||||
CONTROLNET_MODELS,
|
CONTROLNET_MODELS,
|
||||||
ControlNetModel,
|
ControlNetModelName,
|
||||||
} from 'features/controlNet/store/constants';
|
} from 'features/controlNet/store/constants';
|
||||||
import { controlNetModelChanged } from 'features/controlNet/store/controlNetSlice';
|
import { controlNetModelChanged } from 'features/controlNet/store/controlNetSlice';
|
||||||
import { memo, useCallback } from 'react';
|
import { configSelector } from 'features/system/store/configSelectors';
|
||||||
|
import { map } from 'lodash-es';
|
||||||
|
import { ChangeEvent, memo, useCallback } from 'react';
|
||||||
|
|
||||||
type ParamControlNetModelProps = {
|
type ParamControlNetModelProps = {
|
||||||
controlNetId: string;
|
controlNetId: string;
|
||||||
model: ControlNetModel;
|
model: ControlNetModelName;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const selector = createSelector(configSelector, (config) => {
|
||||||
|
return map(CONTROLNET_MODELS, (m) => ({
|
||||||
|
key: m.label,
|
||||||
|
value: m.type,
|
||||||
|
})).filter((d) => !config.sd.disabledControlNetModels.includes(d.value));
|
||||||
|
});
|
||||||
|
|
||||||
|
// const DATA: IAICustomSelectOption[] = map(CONTROLNET_MODELS, (m) => ({
|
||||||
|
// value: m.type,
|
||||||
|
// label: m.label,
|
||||||
|
// tooltip: m.type,
|
||||||
|
// }));
|
||||||
|
|
||||||
const ParamControlNetModel = (props: ParamControlNetModelProps) => {
|
const ParamControlNetModel = (props: ParamControlNetModelProps) => {
|
||||||
const { controlNetId, model } = props;
|
const { controlNetId, model } = props;
|
||||||
|
const controlNetModels = useAppSelector(selector);
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
const isReady = useIsReadyToInvoke();
|
||||||
|
|
||||||
const handleModelChanged = useCallback(
|
const handleModelChanged = useCallback(
|
||||||
(val: string | null | undefined) => {
|
(e: ChangeEvent<HTMLSelectElement>) => {
|
||||||
// TODO: do not cast
|
// TODO: do not cast
|
||||||
const model = val as ControlNetModel;
|
const model = e.target.value as ControlNetModelName;
|
||||||
dispatch(controlNetModelChanged({ controlNetId, model }));
|
dispatch(controlNetModelChanged({ controlNetId, model }));
|
||||||
},
|
},
|
||||||
[controlNetId, dispatch]
|
[controlNetId, dispatch]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// const handleModelChanged = useCallback(
|
||||||
|
// (val: string | null | undefined) => {
|
||||||
|
// // TODO: do not cast
|
||||||
|
// const model = val as ControlNetModelName;
|
||||||
|
// dispatch(controlNetModelChanged({ controlNetId, model }));
|
||||||
|
// },
|
||||||
|
// [controlNetId, dispatch]
|
||||||
|
// );
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<IAICustomSelect
|
<IAISelect
|
||||||
tooltip={model}
|
tooltip={model}
|
||||||
tooltipProps={{ placement: 'top', hasArrow: true }}
|
tooltipProps={{ placement: 'top', hasArrow: true }}
|
||||||
items={CONTROLNET_MODELS}
|
validValues={controlNetModels}
|
||||||
selectedItem={model}
|
value={model}
|
||||||
setSelectedItem={handleModelChanged}
|
onChange={handleModelChanged}
|
||||||
ellipsisPosition="start"
|
isDisabled={!isReady}
|
||||||
withCheckIcon
|
// ellipsisPosition="start"
|
||||||
|
// withCheckIcon
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
|
// return (
|
||||||
|
// <IAICustomSelect
|
||||||
|
// tooltip={model}
|
||||||
|
// tooltipProps={{ placement: 'top', hasArrow: true }}
|
||||||
|
// data={DATA}
|
||||||
|
// value={model}
|
||||||
|
// onChange={handleModelChanged}
|
||||||
|
// isDisabled={!isReady}
|
||||||
|
// ellipsisPosition="start"
|
||||||
|
// withCheckIcon
|
||||||
|
// />
|
||||||
|
// );
|
||||||
};
|
};
|
||||||
|
|
||||||
export default memo(ParamControlNetModel);
|
export default memo(ParamControlNetModel);
|
||||||
|
@ -1,47 +1,115 @@
|
|||||||
import IAICustomSelect from 'common/components/IAICustomSelect';
|
import IAICustomSelect, {
|
||||||
import { memo, useCallback } from 'react';
|
IAICustomSelectOption,
|
||||||
|
} from 'common/components/IAICustomSelect';
|
||||||
|
import { ChangeEvent, memo, useCallback } from 'react';
|
||||||
import {
|
import {
|
||||||
ControlNetProcessorNode,
|
ControlNetProcessorNode,
|
||||||
ControlNetProcessorType,
|
ControlNetProcessorType,
|
||||||
} from '../../store/types';
|
} from '../../store/types';
|
||||||
import { controlNetProcessorTypeChanged } from '../../store/controlNetSlice';
|
import { controlNetProcessorTypeChanged } from '../../store/controlNetSlice';
|
||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { CONTROLNET_PROCESSORS } from '../../store/constants';
|
import { CONTROLNET_PROCESSORS } from '../../store/constants';
|
||||||
|
import { map } from 'lodash-es';
|
||||||
|
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||||
|
import IAISelect from 'common/components/IAISelect';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { configSelector } from 'features/system/store/configSelectors';
|
||||||
|
|
||||||
type ParamControlNetProcessorSelectProps = {
|
type ParamControlNetProcessorSelectProps = {
|
||||||
controlNetId: string;
|
controlNetId: string;
|
||||||
processorNode: ControlNetProcessorNode;
|
processorNode: ControlNetProcessorNode;
|
||||||
};
|
};
|
||||||
|
|
||||||
const CONTROLNET_PROCESSOR_TYPES = Object.keys(
|
const CONTROLNET_PROCESSOR_TYPES = map(CONTROLNET_PROCESSORS, (p) => ({
|
||||||
CONTROLNET_PROCESSORS
|
value: p.type,
|
||||||
) as ControlNetProcessorType[];
|
key: p.label,
|
||||||
|
})).sort((a, b) =>
|
||||||
|
// sort 'none' to the top
|
||||||
|
a.value === 'none' ? -1 : b.value === 'none' ? 1 : a.key.localeCompare(b.key)
|
||||||
|
);
|
||||||
|
|
||||||
|
const selector = createSelector(configSelector, (config) => {
|
||||||
|
return map(CONTROLNET_PROCESSORS, (p) => ({
|
||||||
|
value: p.type,
|
||||||
|
key: p.label,
|
||||||
|
}))
|
||||||
|
.sort((a, b) =>
|
||||||
|
// sort 'none' to the top
|
||||||
|
a.value === 'none'
|
||||||
|
? -1
|
||||||
|
: b.value === 'none'
|
||||||
|
? 1
|
||||||
|
: a.key.localeCompare(b.key)
|
||||||
|
)
|
||||||
|
.filter((d) => !config.sd.disabledControlNetProcessors.includes(d.value));
|
||||||
|
});
|
||||||
|
|
||||||
|
// const CONTROLNET_PROCESSOR_TYPES: IAICustomSelectOption[] = map(
|
||||||
|
// CONTROLNET_PROCESSORS,
|
||||||
|
// (p) => ({
|
||||||
|
// value: p.type,
|
||||||
|
// label: p.label,
|
||||||
|
// tooltip: p.description,
|
||||||
|
// })
|
||||||
|
// ).sort((a, b) =>
|
||||||
|
// // sort 'none' to the top
|
||||||
|
// a.value === 'none'
|
||||||
|
// ? -1
|
||||||
|
// : b.value === 'none'
|
||||||
|
// ? 1
|
||||||
|
// : a.label.localeCompare(b.label)
|
||||||
|
// );
|
||||||
|
|
||||||
const ParamControlNetProcessorSelect = (
|
const ParamControlNetProcessorSelect = (
|
||||||
props: ParamControlNetProcessorSelectProps
|
props: ParamControlNetProcessorSelectProps
|
||||||
) => {
|
) => {
|
||||||
const { controlNetId, processorNode } = props;
|
const { controlNetId, processorNode } = props;
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
const isReady = useIsReadyToInvoke();
|
||||||
|
const controlNetProcessors = useAppSelector(selector);
|
||||||
|
|
||||||
const handleProcessorTypeChanged = useCallback(
|
const handleProcessorTypeChanged = useCallback(
|
||||||
(v: string | null | undefined) => {
|
(e: ChangeEvent<HTMLSelectElement>) => {
|
||||||
dispatch(
|
dispatch(
|
||||||
controlNetProcessorTypeChanged({
|
controlNetProcessorTypeChanged({
|
||||||
controlNetId,
|
controlNetId,
|
||||||
processorType: v as ControlNetProcessorType,
|
processorType: e.target.value as ControlNetProcessorType,
|
||||||
})
|
})
|
||||||
);
|
);
|
||||||
},
|
},
|
||||||
[controlNetId, dispatch]
|
[controlNetId, dispatch]
|
||||||
);
|
);
|
||||||
|
// const handleProcessorTypeChanged = useCallback(
|
||||||
|
// (v: string | null | undefined) => {
|
||||||
|
// dispatch(
|
||||||
|
// controlNetProcessorTypeChanged({
|
||||||
|
// controlNetId,
|
||||||
|
// processorType: v as ControlNetProcessorType,
|
||||||
|
// })
|
||||||
|
// );
|
||||||
|
// },
|
||||||
|
// [controlNetId, dispatch]
|
||||||
|
// );
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<IAICustomSelect
|
<IAISelect
|
||||||
label="Processor"
|
label="Processor"
|
||||||
items={CONTROLNET_PROCESSOR_TYPES}
|
value={processorNode.type ?? 'canny_image_processor'}
|
||||||
selectedItem={processorNode.type ?? 'canny_image_processor'}
|
validValues={controlNetProcessors}
|
||||||
setSelectedItem={handleProcessorTypeChanged}
|
onChange={handleProcessorTypeChanged}
|
||||||
withCheckIcon
|
isDisabled={!isReady}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
|
// return (
|
||||||
|
// <IAICustomSelect
|
||||||
|
// label="Processor"
|
||||||
|
// value={processorNode.type ?? 'canny_image_processor'}
|
||||||
|
// data={CONTROLNET_PROCESSOR_TYPES}
|
||||||
|
// onChange={handleProcessorTypeChanged}
|
||||||
|
// withCheckIcon
|
||||||
|
// isDisabled={!isReady}
|
||||||
|
// />
|
||||||
|
// );
|
||||||
};
|
};
|
||||||
|
|
||||||
export default memo(ParamControlNetProcessorSelect);
|
export default memo(ParamControlNetProcessorSelect);
|
||||||
|
@ -20,36 +20,17 @@ const ParamControlNetWeight = (props: ParamControlNetWeightProps) => {
|
|||||||
[controlNetId, dispatch]
|
[controlNetId, dispatch]
|
||||||
);
|
);
|
||||||
|
|
||||||
const handleWeightReset = () => {
|
|
||||||
dispatch(controlNetWeightChanged({ controlNetId, weight: 1 }));
|
|
||||||
};
|
|
||||||
|
|
||||||
if (mini) {
|
|
||||||
return (
|
|
||||||
<IAISlider
|
|
||||||
label={'Weight'}
|
|
||||||
sliderFormLabelProps={{ pb: 1 }}
|
|
||||||
value={weight}
|
|
||||||
onChange={handleWeightChanged}
|
|
||||||
min={0}
|
|
||||||
max={1}
|
|
||||||
step={0.01}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<IAISlider
|
<IAISlider
|
||||||
label="Weight"
|
label={'Weight'}
|
||||||
|
sliderFormLabelProps={{ pb: 2 }}
|
||||||
value={weight}
|
value={weight}
|
||||||
onChange={handleWeightChanged}
|
onChange={handleWeightChanged}
|
||||||
withInput
|
min={-1}
|
||||||
withReset
|
|
||||||
handleReset={handleWeightReset}
|
|
||||||
withSliderMarks
|
|
||||||
min={0}
|
|
||||||
max={1}
|
max={1}
|
||||||
step={0.01}
|
step={0.01}
|
||||||
|
withSliderMarks={!mini}
|
||||||
|
sliderMarks={[-1, 0, 1]}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -4,6 +4,7 @@ import { RequiredCannyImageProcessorInvocation } from 'features/controlNet/store
|
|||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||||
import ProcessorWrapper from './common/ProcessorWrapper';
|
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||||
|
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||||
|
|
||||||
const DEFAULTS = CONTROLNET_PROCESSORS.canny_image_processor.default;
|
const DEFAULTS = CONTROLNET_PROCESSORS.canny_image_processor.default;
|
||||||
|
|
||||||
@ -15,6 +16,7 @@ type CannyProcessorProps = {
|
|||||||
const CannyProcessor = (props: CannyProcessorProps) => {
|
const CannyProcessor = (props: CannyProcessorProps) => {
|
||||||
const { controlNetId, processorNode } = props;
|
const { controlNetId, processorNode } = props;
|
||||||
const { low_threshold, high_threshold } = processorNode;
|
const { low_threshold, high_threshold } = processorNode;
|
||||||
|
const isReady = useIsReadyToInvoke();
|
||||||
const processorChanged = useProcessorNodeChanged();
|
const processorChanged = useProcessorNodeChanged();
|
||||||
|
|
||||||
const handleLowThresholdChanged = useCallback(
|
const handleLowThresholdChanged = useCallback(
|
||||||
@ -46,6 +48,7 @@ const CannyProcessor = (props: CannyProcessorProps) => {
|
|||||||
return (
|
return (
|
||||||
<ProcessorWrapper>
|
<ProcessorWrapper>
|
||||||
<IAISlider
|
<IAISlider
|
||||||
|
isDisabled={!isReady}
|
||||||
label="Low Threshold"
|
label="Low Threshold"
|
||||||
value={low_threshold}
|
value={low_threshold}
|
||||||
onChange={handleLowThresholdChanged}
|
onChange={handleLowThresholdChanged}
|
||||||
@ -54,8 +57,10 @@ const CannyProcessor = (props: CannyProcessorProps) => {
|
|||||||
min={0}
|
min={0}
|
||||||
max={255}
|
max={255}
|
||||||
withInput
|
withInput
|
||||||
|
withSliderMarks
|
||||||
/>
|
/>
|
||||||
<IAISlider
|
<IAISlider
|
||||||
|
isDisabled={!isReady}
|
||||||
label="High Threshold"
|
label="High Threshold"
|
||||||
value={high_threshold}
|
value={high_threshold}
|
||||||
onChange={handleHighThresholdChanged}
|
onChange={handleHighThresholdChanged}
|
||||||
@ -64,6 +69,7 @@ const CannyProcessor = (props: CannyProcessorProps) => {
|
|||||||
min={0}
|
min={0}
|
||||||
max={255}
|
max={255}
|
||||||
withInput
|
withInput
|
||||||
|
withSliderMarks
|
||||||
/>
|
/>
|
||||||
</ProcessorWrapper>
|
</ProcessorWrapper>
|
||||||
);
|
);
|
||||||
|
@ -4,6 +4,7 @@ import { RequiredContentShuffleImageProcessorInvocation } from 'features/control
|
|||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||||
import ProcessorWrapper from './common/ProcessorWrapper';
|
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||||
|
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||||
|
|
||||||
const DEFAULTS = CONTROLNET_PROCESSORS.content_shuffle_image_processor.default;
|
const DEFAULTS = CONTROLNET_PROCESSORS.content_shuffle_image_processor.default;
|
||||||
|
|
||||||
@ -16,6 +17,7 @@ const ContentShuffleProcessor = (props: Props) => {
|
|||||||
const { controlNetId, processorNode } = props;
|
const { controlNetId, processorNode } = props;
|
||||||
const { image_resolution, detect_resolution, w, h, f } = processorNode;
|
const { image_resolution, detect_resolution, w, h, f } = processorNode;
|
||||||
const processorChanged = useProcessorNodeChanged();
|
const processorChanged = useProcessorNodeChanged();
|
||||||
|
const isReady = useIsReadyToInvoke();
|
||||||
|
|
||||||
const handleDetectResolutionChanged = useCallback(
|
const handleDetectResolutionChanged = useCallback(
|
||||||
(v: number) => {
|
(v: number) => {
|
||||||
@ -93,6 +95,8 @@ const ContentShuffleProcessor = (props: Props) => {
|
|||||||
min={0}
|
min={0}
|
||||||
max={4096}
|
max={4096}
|
||||||
withInput
|
withInput
|
||||||
|
withSliderMarks
|
||||||
|
isDisabled={!isReady}
|
||||||
/>
|
/>
|
||||||
<IAISlider
|
<IAISlider
|
||||||
label="Image Resolution"
|
label="Image Resolution"
|
||||||
@ -103,6 +107,8 @@ const ContentShuffleProcessor = (props: Props) => {
|
|||||||
min={0}
|
min={0}
|
||||||
max={4096}
|
max={4096}
|
||||||
withInput
|
withInput
|
||||||
|
withSliderMarks
|
||||||
|
isDisabled={!isReady}
|
||||||
/>
|
/>
|
||||||
<IAISlider
|
<IAISlider
|
||||||
label="W"
|
label="W"
|
||||||
@ -113,6 +119,8 @@ const ContentShuffleProcessor = (props: Props) => {
|
|||||||
min={0}
|
min={0}
|
||||||
max={4096}
|
max={4096}
|
||||||
withInput
|
withInput
|
||||||
|
withSliderMarks
|
||||||
|
isDisabled={!isReady}
|
||||||
/>
|
/>
|
||||||
<IAISlider
|
<IAISlider
|
||||||
label="H"
|
label="H"
|
||||||
@ -123,6 +131,8 @@ const ContentShuffleProcessor = (props: Props) => {
|
|||||||
min={0}
|
min={0}
|
||||||
max={4096}
|
max={4096}
|
||||||
withInput
|
withInput
|
||||||
|
withSliderMarks
|
||||||
|
isDisabled={!isReady}
|
||||||
/>
|
/>
|
||||||
<IAISlider
|
<IAISlider
|
||||||
label="F"
|
label="F"
|
||||||
@ -133,6 +143,8 @@ const ContentShuffleProcessor = (props: Props) => {
|
|||||||
min={0}
|
min={0}
|
||||||
max={4096}
|
max={4096}
|
||||||
withInput
|
withInput
|
||||||
|
withSliderMarks
|
||||||
|
isDisabled={!isReady}
|
||||||
/>
|
/>
|
||||||
</ProcessorWrapper>
|
</ProcessorWrapper>
|
||||||
);
|
);
|
||||||
|
@ -5,6 +5,7 @@ import { RequiredHedImageProcessorInvocation } from 'features/controlNet/store/t
|
|||||||
import { ChangeEvent, memo, useCallback } from 'react';
|
import { ChangeEvent, memo, useCallback } from 'react';
|
||||||
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||||
import ProcessorWrapper from './common/ProcessorWrapper';
|
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||||
|
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||||
|
|
||||||
const DEFAULTS = CONTROLNET_PROCESSORS.hed_image_processor.default;
|
const DEFAULTS = CONTROLNET_PROCESSORS.hed_image_processor.default;
|
||||||
|
|
||||||
@ -18,7 +19,7 @@ const HedPreprocessor = (props: HedProcessorProps) => {
|
|||||||
controlNetId,
|
controlNetId,
|
||||||
processorNode: { detect_resolution, image_resolution, scribble },
|
processorNode: { detect_resolution, image_resolution, scribble },
|
||||||
} = props;
|
} = props;
|
||||||
|
const isReady = useIsReadyToInvoke();
|
||||||
const processorChanged = useProcessorNodeChanged();
|
const processorChanged = useProcessorNodeChanged();
|
||||||
|
|
||||||
const handleDetectResolutionChanged = useCallback(
|
const handleDetectResolutionChanged = useCallback(
|
||||||
@ -65,6 +66,8 @@ const HedPreprocessor = (props: HedProcessorProps) => {
|
|||||||
min={0}
|
min={0}
|
||||||
max={4096}
|
max={4096}
|
||||||
withInput
|
withInput
|
||||||
|
withSliderMarks
|
||||||
|
isDisabled={!isReady}
|
||||||
/>
|
/>
|
||||||
<IAISlider
|
<IAISlider
|
||||||
label="Image Resolution"
|
label="Image Resolution"
|
||||||
@ -75,11 +78,14 @@ const HedPreprocessor = (props: HedProcessorProps) => {
|
|||||||
min={0}
|
min={0}
|
||||||
max={4096}
|
max={4096}
|
||||||
withInput
|
withInput
|
||||||
|
withSliderMarks
|
||||||
|
isDisabled={!isReady}
|
||||||
/>
|
/>
|
||||||
<IAISwitch
|
<IAISwitch
|
||||||
label="Scribble"
|
label="Scribble"
|
||||||
isChecked={scribble}
|
isChecked={scribble}
|
||||||
onChange={handleScribbleChanged}
|
onChange={handleScribbleChanged}
|
||||||
|
isDisabled={!isReady}
|
||||||
/>
|
/>
|
||||||
</ProcessorWrapper>
|
</ProcessorWrapper>
|
||||||
);
|
);
|
||||||
|
@ -4,6 +4,7 @@ import { RequiredLineartAnimeImageProcessorInvocation } from 'features/controlNe
|
|||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||||
import ProcessorWrapper from './common/ProcessorWrapper';
|
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||||
|
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||||
|
|
||||||
const DEFAULTS = CONTROLNET_PROCESSORS.lineart_anime_image_processor.default;
|
const DEFAULTS = CONTROLNET_PROCESSORS.lineart_anime_image_processor.default;
|
||||||
|
|
||||||
@ -16,6 +17,7 @@ const LineartAnimeProcessor = (props: Props) => {
|
|||||||
const { controlNetId, processorNode } = props;
|
const { controlNetId, processorNode } = props;
|
||||||
const { image_resolution, detect_resolution } = processorNode;
|
const { image_resolution, detect_resolution } = processorNode;
|
||||||
const processorChanged = useProcessorNodeChanged();
|
const processorChanged = useProcessorNodeChanged();
|
||||||
|
const isReady = useIsReadyToInvoke();
|
||||||
|
|
||||||
const handleDetectResolutionChanged = useCallback(
|
const handleDetectResolutionChanged = useCallback(
|
||||||
(v: number) => {
|
(v: number) => {
|
||||||
@ -54,6 +56,8 @@ const LineartAnimeProcessor = (props: Props) => {
|
|||||||
min={0}
|
min={0}
|
||||||
max={4096}
|
max={4096}
|
||||||
withInput
|
withInput
|
||||||
|
withSliderMarks
|
||||||
|
isDisabled={!isReady}
|
||||||
/>
|
/>
|
||||||
<IAISlider
|
<IAISlider
|
||||||
label="Image Resolution"
|
label="Image Resolution"
|
||||||
@ -64,6 +68,8 @@ const LineartAnimeProcessor = (props: Props) => {
|
|||||||
min={0}
|
min={0}
|
||||||
max={4096}
|
max={4096}
|
||||||
withInput
|
withInput
|
||||||
|
withSliderMarks
|
||||||
|
isDisabled={!isReady}
|
||||||
/>
|
/>
|
||||||
</ProcessorWrapper>
|
</ProcessorWrapper>
|
||||||
);
|
);
|
||||||
|
@ -5,6 +5,7 @@ import { RequiredLineartImageProcessorInvocation } from 'features/controlNet/sto
|
|||||||
import { ChangeEvent, memo, useCallback } from 'react';
|
import { ChangeEvent, memo, useCallback } from 'react';
|
||||||
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||||
import ProcessorWrapper from './common/ProcessorWrapper';
|
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||||
|
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||||
|
|
||||||
const DEFAULTS = CONTROLNET_PROCESSORS.lineart_image_processor.default;
|
const DEFAULTS = CONTROLNET_PROCESSORS.lineart_image_processor.default;
|
||||||
|
|
||||||
@ -17,6 +18,7 @@ const LineartProcessor = (props: LineartProcessorProps) => {
|
|||||||
const { controlNetId, processorNode } = props;
|
const { controlNetId, processorNode } = props;
|
||||||
const { image_resolution, detect_resolution, coarse } = processorNode;
|
const { image_resolution, detect_resolution, coarse } = processorNode;
|
||||||
const processorChanged = useProcessorNodeChanged();
|
const processorChanged = useProcessorNodeChanged();
|
||||||
|
const isReady = useIsReadyToInvoke();
|
||||||
|
|
||||||
const handleDetectResolutionChanged = useCallback(
|
const handleDetectResolutionChanged = useCallback(
|
||||||
(v: number) => {
|
(v: number) => {
|
||||||
@ -62,6 +64,8 @@ const LineartProcessor = (props: LineartProcessorProps) => {
|
|||||||
min={0}
|
min={0}
|
||||||
max={4096}
|
max={4096}
|
||||||
withInput
|
withInput
|
||||||
|
withSliderMarks
|
||||||
|
isDisabled={!isReady}
|
||||||
/>
|
/>
|
||||||
<IAISlider
|
<IAISlider
|
||||||
label="Image Resolution"
|
label="Image Resolution"
|
||||||
@ -72,11 +76,14 @@ const LineartProcessor = (props: LineartProcessorProps) => {
|
|||||||
min={0}
|
min={0}
|
||||||
max={4096}
|
max={4096}
|
||||||
withInput
|
withInput
|
||||||
|
withSliderMarks
|
||||||
|
isDisabled={!isReady}
|
||||||
/>
|
/>
|
||||||
<IAISwitch
|
<IAISwitch
|
||||||
label="Coarse"
|
label="Coarse"
|
||||||
isChecked={coarse}
|
isChecked={coarse}
|
||||||
onChange={handleCoarseChanged}
|
onChange={handleCoarseChanged}
|
||||||
|
isDisabled={!isReady}
|
||||||
/>
|
/>
|
||||||
</ProcessorWrapper>
|
</ProcessorWrapper>
|
||||||
);
|
);
|
||||||
|
@ -4,6 +4,7 @@ import { RequiredMediapipeFaceProcessorInvocation } from 'features/controlNet/st
|
|||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||||
import ProcessorWrapper from './common/ProcessorWrapper';
|
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||||
|
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||||
|
|
||||||
const DEFAULTS = CONTROLNET_PROCESSORS.mediapipe_face_processor.default;
|
const DEFAULTS = CONTROLNET_PROCESSORS.mediapipe_face_processor.default;
|
||||||
|
|
||||||
@ -16,6 +17,7 @@ const MediapipeFaceProcessor = (props: Props) => {
|
|||||||
const { controlNetId, processorNode } = props;
|
const { controlNetId, processorNode } = props;
|
||||||
const { max_faces, min_confidence } = processorNode;
|
const { max_faces, min_confidence } = processorNode;
|
||||||
const processorChanged = useProcessorNodeChanged();
|
const processorChanged = useProcessorNodeChanged();
|
||||||
|
const isReady = useIsReadyToInvoke();
|
||||||
|
|
||||||
const handleMaxFacesChanged = useCallback(
|
const handleMaxFacesChanged = useCallback(
|
||||||
(v: number) => {
|
(v: number) => {
|
||||||
@ -50,6 +52,8 @@ const MediapipeFaceProcessor = (props: Props) => {
|
|||||||
min={1}
|
min={1}
|
||||||
max={20}
|
max={20}
|
||||||
withInput
|
withInput
|
||||||
|
withSliderMarks
|
||||||
|
isDisabled={!isReady}
|
||||||
/>
|
/>
|
||||||
<IAISlider
|
<IAISlider
|
||||||
label="Min Confidence"
|
label="Min Confidence"
|
||||||
@ -61,6 +65,8 @@ const MediapipeFaceProcessor = (props: Props) => {
|
|||||||
max={1}
|
max={1}
|
||||||
step={0.01}
|
step={0.01}
|
||||||
withInput
|
withInput
|
||||||
|
withSliderMarks
|
||||||
|
isDisabled={!isReady}
|
||||||
/>
|
/>
|
||||||
</ProcessorWrapper>
|
</ProcessorWrapper>
|
||||||
);
|
);
|
||||||
|
@ -4,6 +4,7 @@ import { RequiredMidasDepthImageProcessorInvocation } from 'features/controlNet/
|
|||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||||
import ProcessorWrapper from './common/ProcessorWrapper';
|
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||||
|
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||||
|
|
||||||
const DEFAULTS = CONTROLNET_PROCESSORS.midas_depth_image_processor.default;
|
const DEFAULTS = CONTROLNET_PROCESSORS.midas_depth_image_processor.default;
|
||||||
|
|
||||||
@ -16,6 +17,7 @@ const MidasDepthProcessor = (props: Props) => {
|
|||||||
const { controlNetId, processorNode } = props;
|
const { controlNetId, processorNode } = props;
|
||||||
const { a_mult, bg_th } = processorNode;
|
const { a_mult, bg_th } = processorNode;
|
||||||
const processorChanged = useProcessorNodeChanged();
|
const processorChanged = useProcessorNodeChanged();
|
||||||
|
const isReady = useIsReadyToInvoke();
|
||||||
|
|
||||||
const handleAMultChanged = useCallback(
|
const handleAMultChanged = useCallback(
|
||||||
(v: number) => {
|
(v: number) => {
|
||||||
@ -51,6 +53,8 @@ const MidasDepthProcessor = (props: Props) => {
|
|||||||
max={20}
|
max={20}
|
||||||
step={0.01}
|
step={0.01}
|
||||||
withInput
|
withInput
|
||||||
|
withSliderMarks
|
||||||
|
isDisabled={!isReady}
|
||||||
/>
|
/>
|
||||||
<IAISlider
|
<IAISlider
|
||||||
label="bg_th"
|
label="bg_th"
|
||||||
@ -62,6 +66,8 @@ const MidasDepthProcessor = (props: Props) => {
|
|||||||
max={20}
|
max={20}
|
||||||
step={0.01}
|
step={0.01}
|
||||||
withInput
|
withInput
|
||||||
|
withSliderMarks
|
||||||
|
isDisabled={!isReady}
|
||||||
/>
|
/>
|
||||||
</ProcessorWrapper>
|
</ProcessorWrapper>
|
||||||
);
|
);
|
||||||
|
@ -4,6 +4,7 @@ import { RequiredMlsdImageProcessorInvocation } from 'features/controlNet/store/
|
|||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||||
import ProcessorWrapper from './common/ProcessorWrapper';
|
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||||
|
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||||
|
|
||||||
const DEFAULTS = CONTROLNET_PROCESSORS.mlsd_image_processor.default;
|
const DEFAULTS = CONTROLNET_PROCESSORS.mlsd_image_processor.default;
|
||||||
|
|
||||||
@ -16,6 +17,7 @@ const MlsdImageProcessor = (props: Props) => {
|
|||||||
const { controlNetId, processorNode } = props;
|
const { controlNetId, processorNode } = props;
|
||||||
const { image_resolution, detect_resolution, thr_d, thr_v } = processorNode;
|
const { image_resolution, detect_resolution, thr_d, thr_v } = processorNode;
|
||||||
const processorChanged = useProcessorNodeChanged();
|
const processorChanged = useProcessorNodeChanged();
|
||||||
|
const isReady = useIsReadyToInvoke();
|
||||||
|
|
||||||
const handleDetectResolutionChanged = useCallback(
|
const handleDetectResolutionChanged = useCallback(
|
||||||
(v: number) => {
|
(v: number) => {
|
||||||
@ -76,6 +78,8 @@ const MlsdImageProcessor = (props: Props) => {
|
|||||||
min={0}
|
min={0}
|
||||||
max={4096}
|
max={4096}
|
||||||
withInput
|
withInput
|
||||||
|
withSliderMarks
|
||||||
|
isDisabled={!isReady}
|
||||||
/>
|
/>
|
||||||
<IAISlider
|
<IAISlider
|
||||||
label="Image Resolution"
|
label="Image Resolution"
|
||||||
@ -86,6 +90,8 @@ const MlsdImageProcessor = (props: Props) => {
|
|||||||
min={0}
|
min={0}
|
||||||
max={4096}
|
max={4096}
|
||||||
withInput
|
withInput
|
||||||
|
withSliderMarks
|
||||||
|
isDisabled={!isReady}
|
||||||
/>
|
/>
|
||||||
<IAISlider
|
<IAISlider
|
||||||
label="W"
|
label="W"
|
||||||
@ -97,6 +103,8 @@ const MlsdImageProcessor = (props: Props) => {
|
|||||||
max={1}
|
max={1}
|
||||||
step={0.01}
|
step={0.01}
|
||||||
withInput
|
withInput
|
||||||
|
withSliderMarks
|
||||||
|
isDisabled={!isReady}
|
||||||
/>
|
/>
|
||||||
<IAISlider
|
<IAISlider
|
||||||
label="H"
|
label="H"
|
||||||
@ -108,6 +116,8 @@ const MlsdImageProcessor = (props: Props) => {
|
|||||||
max={1}
|
max={1}
|
||||||
step={0.01}
|
step={0.01}
|
||||||
withInput
|
withInput
|
||||||
|
withSliderMarks
|
||||||
|
isDisabled={!isReady}
|
||||||
/>
|
/>
|
||||||
</ProcessorWrapper>
|
</ProcessorWrapper>
|
||||||
);
|
);
|
||||||
|
@ -4,6 +4,7 @@ import { RequiredNormalbaeImageProcessorInvocation } from 'features/controlNet/s
|
|||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||||
import ProcessorWrapper from './common/ProcessorWrapper';
|
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||||
|
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||||
|
|
||||||
const DEFAULTS = CONTROLNET_PROCESSORS.normalbae_image_processor.default;
|
const DEFAULTS = CONTROLNET_PROCESSORS.normalbae_image_processor.default;
|
||||||
|
|
||||||
@ -16,6 +17,7 @@ const NormalBaeProcessor = (props: Props) => {
|
|||||||
const { controlNetId, processorNode } = props;
|
const { controlNetId, processorNode } = props;
|
||||||
const { image_resolution, detect_resolution } = processorNode;
|
const { image_resolution, detect_resolution } = processorNode;
|
||||||
const processorChanged = useProcessorNodeChanged();
|
const processorChanged = useProcessorNodeChanged();
|
||||||
|
const isReady = useIsReadyToInvoke();
|
||||||
|
|
||||||
const handleDetectResolutionChanged = useCallback(
|
const handleDetectResolutionChanged = useCallback(
|
||||||
(v: number) => {
|
(v: number) => {
|
||||||
@ -54,6 +56,8 @@ const NormalBaeProcessor = (props: Props) => {
|
|||||||
min={0}
|
min={0}
|
||||||
max={4096}
|
max={4096}
|
||||||
withInput
|
withInput
|
||||||
|
withSliderMarks
|
||||||
|
isDisabled={!isReady}
|
||||||
/>
|
/>
|
||||||
<IAISlider
|
<IAISlider
|
||||||
label="Image Resolution"
|
label="Image Resolution"
|
||||||
@ -64,6 +68,8 @@ const NormalBaeProcessor = (props: Props) => {
|
|||||||
min={0}
|
min={0}
|
||||||
max={4096}
|
max={4096}
|
||||||
withInput
|
withInput
|
||||||
|
withSliderMarks
|
||||||
|
isDisabled={!isReady}
|
||||||
/>
|
/>
|
||||||
</ProcessorWrapper>
|
</ProcessorWrapper>
|
||||||
);
|
);
|
||||||
|
@ -5,6 +5,7 @@ import { RequiredOpenposeImageProcessorInvocation } from 'features/controlNet/st
|
|||||||
import { ChangeEvent, memo, useCallback } from 'react';
|
import { ChangeEvent, memo, useCallback } from 'react';
|
||||||
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||||
import ProcessorWrapper from './common/ProcessorWrapper';
|
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||||
|
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||||
|
|
||||||
const DEFAULTS = CONTROLNET_PROCESSORS.openpose_image_processor.default;
|
const DEFAULTS = CONTROLNET_PROCESSORS.openpose_image_processor.default;
|
||||||
|
|
||||||
@ -17,6 +18,7 @@ const OpenposeProcessor = (props: Props) => {
|
|||||||
const { controlNetId, processorNode } = props;
|
const { controlNetId, processorNode } = props;
|
||||||
const { image_resolution, detect_resolution, hand_and_face } = processorNode;
|
const { image_resolution, detect_resolution, hand_and_face } = processorNode;
|
||||||
const processorChanged = useProcessorNodeChanged();
|
const processorChanged = useProcessorNodeChanged();
|
||||||
|
const isReady = useIsReadyToInvoke();
|
||||||
|
|
||||||
const handleDetectResolutionChanged = useCallback(
|
const handleDetectResolutionChanged = useCallback(
|
||||||
(v: number) => {
|
(v: number) => {
|
||||||
@ -62,6 +64,8 @@ const OpenposeProcessor = (props: Props) => {
|
|||||||
min={0}
|
min={0}
|
||||||
max={4096}
|
max={4096}
|
||||||
withInput
|
withInput
|
||||||
|
withSliderMarks
|
||||||
|
isDisabled={!isReady}
|
||||||
/>
|
/>
|
||||||
<IAISlider
|
<IAISlider
|
||||||
label="Image Resolution"
|
label="Image Resolution"
|
||||||
@ -72,11 +76,14 @@ const OpenposeProcessor = (props: Props) => {
|
|||||||
min={0}
|
min={0}
|
||||||
max={4096}
|
max={4096}
|
||||||
withInput
|
withInput
|
||||||
|
withSliderMarks
|
||||||
|
isDisabled={!isReady}
|
||||||
/>
|
/>
|
||||||
<IAISwitch
|
<IAISwitch
|
||||||
label="Hand and Face"
|
label="Hand and Face"
|
||||||
isChecked={hand_and_face}
|
isChecked={hand_and_face}
|
||||||
onChange={handleHandAndFaceChanged}
|
onChange={handleHandAndFaceChanged}
|
||||||
|
isDisabled={!isReady}
|
||||||
/>
|
/>
|
||||||
</ProcessorWrapper>
|
</ProcessorWrapper>
|
||||||
);
|
);
|
||||||
|
@ -5,6 +5,7 @@ import { RequiredPidiImageProcessorInvocation } from 'features/controlNet/store/
|
|||||||
import { ChangeEvent, memo, useCallback } from 'react';
|
import { ChangeEvent, memo, useCallback } from 'react';
|
||||||
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||||
import ProcessorWrapper from './common/ProcessorWrapper';
|
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||||
|
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||||
|
|
||||||
const DEFAULTS = CONTROLNET_PROCESSORS.pidi_image_processor.default;
|
const DEFAULTS = CONTROLNET_PROCESSORS.pidi_image_processor.default;
|
||||||
|
|
||||||
@ -17,6 +18,7 @@ const PidiProcessor = (props: Props) => {
|
|||||||
const { controlNetId, processorNode } = props;
|
const { controlNetId, processorNode } = props;
|
||||||
const { image_resolution, detect_resolution, scribble, safe } = processorNode;
|
const { image_resolution, detect_resolution, scribble, safe } = processorNode;
|
||||||
const processorChanged = useProcessorNodeChanged();
|
const processorChanged = useProcessorNodeChanged();
|
||||||
|
const isReady = useIsReadyToInvoke();
|
||||||
|
|
||||||
const handleDetectResolutionChanged = useCallback(
|
const handleDetectResolutionChanged = useCallback(
|
||||||
(v: number) => {
|
(v: number) => {
|
||||||
@ -69,6 +71,8 @@ const PidiProcessor = (props: Props) => {
|
|||||||
min={0}
|
min={0}
|
||||||
max={4096}
|
max={4096}
|
||||||
withInput
|
withInput
|
||||||
|
withSliderMarks
|
||||||
|
isDisabled={!isReady}
|
||||||
/>
|
/>
|
||||||
<IAISlider
|
<IAISlider
|
||||||
label="Image Resolution"
|
label="Image Resolution"
|
||||||
@ -79,13 +83,20 @@ const PidiProcessor = (props: Props) => {
|
|||||||
min={0}
|
min={0}
|
||||||
max={4096}
|
max={4096}
|
||||||
withInput
|
withInput
|
||||||
|
withSliderMarks
|
||||||
|
isDisabled={!isReady}
|
||||||
/>
|
/>
|
||||||
<IAISwitch
|
<IAISwitch
|
||||||
label="Scribble"
|
label="Scribble"
|
||||||
isChecked={scribble}
|
isChecked={scribble}
|
||||||
onChange={handleScribbleChanged}
|
onChange={handleScribbleChanged}
|
||||||
/>
|
/>
|
||||||
<IAISwitch label="Safe" isChecked={safe} onChange={handleSafeChanged} />
|
<IAISwitch
|
||||||
|
label="Safe"
|
||||||
|
isChecked={safe}
|
||||||
|
onChange={handleSafeChanged}
|
||||||
|
isDisabled={!isReady}
|
||||||
|
/>
|
||||||
</ProcessorWrapper>
|
</ProcessorWrapper>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -5,12 +5,12 @@ import {
|
|||||||
} from './types';
|
} from './types';
|
||||||
|
|
||||||
type ControlNetProcessorsDict = Record<
|
type ControlNetProcessorsDict = Record<
|
||||||
ControlNetProcessorType,
|
string,
|
||||||
{
|
{
|
||||||
type: ControlNetProcessorType;
|
type: ControlNetProcessorType | 'none';
|
||||||
label: string;
|
label: string;
|
||||||
description: string;
|
description: string;
|
||||||
default: RequiredControlNetProcessorNode;
|
default: RequiredControlNetProcessorNode | { type: 'none' };
|
||||||
}
|
}
|
||||||
>;
|
>;
|
||||||
|
|
||||||
@ -26,7 +26,7 @@ type ControlNetProcessorsDict = Record<
|
|||||||
export const CONTROLNET_PROCESSORS = {
|
export const CONTROLNET_PROCESSORS = {
|
||||||
none: {
|
none: {
|
||||||
type: 'none',
|
type: 'none',
|
||||||
label: 'None',
|
label: 'none',
|
||||||
description: '',
|
description: '',
|
||||||
default: {
|
default: {
|
||||||
type: 'none',
|
type: 'none',
|
||||||
@ -116,7 +116,7 @@ export const CONTROLNET_PROCESSORS = {
|
|||||||
},
|
},
|
||||||
mlsd_image_processor: {
|
mlsd_image_processor: {
|
||||||
type: 'mlsd_image_processor',
|
type: 'mlsd_image_processor',
|
||||||
label: 'MLSD',
|
label: 'M-LSD',
|
||||||
description: '',
|
description: '',
|
||||||
default: {
|
default: {
|
||||||
id: 'mlsd_image_processor',
|
id: 'mlsd_image_processor',
|
||||||
@ -129,7 +129,7 @@ export const CONTROLNET_PROCESSORS = {
|
|||||||
},
|
},
|
||||||
normalbae_image_processor: {
|
normalbae_image_processor: {
|
||||||
type: 'normalbae_image_processor',
|
type: 'normalbae_image_processor',
|
||||||
label: 'NormalBae',
|
label: 'Normal BAE',
|
||||||
description: '',
|
description: '',
|
||||||
default: {
|
default: {
|
||||||
id: 'normalbae_image_processor',
|
id: 'normalbae_image_processor',
|
||||||
@ -174,39 +174,84 @@ export const CONTROLNET_PROCESSORS = {
|
|||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
export const CONTROLNET_MODELS = [
|
type ControlNetModel = {
|
||||||
'lllyasviel/control_v11p_sd15_canny',
|
type: string;
|
||||||
'lllyasviel/control_v11p_sd15_inpaint',
|
label: string;
|
||||||
'lllyasviel/control_v11p_sd15_mlsd',
|
description?: string;
|
||||||
'lllyasviel/control_v11f1p_sd15_depth',
|
defaultProcessor?: ControlNetProcessorType;
|
||||||
'lllyasviel/control_v11p_sd15_normalbae',
|
|
||||||
'lllyasviel/control_v11p_sd15_seg',
|
|
||||||
'lllyasviel/control_v11p_sd15_lineart',
|
|
||||||
'lllyasviel/control_v11p_sd15s2_lineart_anime',
|
|
||||||
'lllyasviel/control_v11p_sd15_scribble',
|
|
||||||
'lllyasviel/control_v11p_sd15_softedge',
|
|
||||||
'lllyasviel/control_v11e_sd15_shuffle',
|
|
||||||
'lllyasviel/control_v11p_sd15_openpose',
|
|
||||||
'lllyasviel/control_v11f1e_sd15_tile',
|
|
||||||
'lllyasviel/control_v11e_sd15_ip2p',
|
|
||||||
'CrucibleAI/ControlNetMediaPipeFace',
|
|
||||||
];
|
|
||||||
|
|
||||||
export type ControlNetModel = (typeof CONTROLNET_MODELS)[number];
|
|
||||||
|
|
||||||
export const CONTROLNET_MODEL_MAP: Record<
|
|
||||||
ControlNetModel,
|
|
||||||
ControlNetProcessorType
|
|
||||||
> = {
|
|
||||||
'lllyasviel/control_v11p_sd15_canny': 'canny_image_processor',
|
|
||||||
'lllyasviel/control_v11p_sd15_mlsd': 'mlsd_image_processor',
|
|
||||||
'lllyasviel/control_v11f1p_sd15_depth': 'midas_depth_image_processor',
|
|
||||||
'lllyasviel/control_v11p_sd15_normalbae': 'normalbae_image_processor',
|
|
||||||
'lllyasviel/control_v11p_sd15_lineart': 'lineart_image_processor',
|
|
||||||
'lllyasviel/control_v11p_sd15s2_lineart_anime':
|
|
||||||
'lineart_anime_image_processor',
|
|
||||||
'lllyasviel/control_v11p_sd15_softedge': 'hed_image_processor',
|
|
||||||
'lllyasviel/control_v11e_sd15_shuffle': 'content_shuffle_image_processor',
|
|
||||||
'lllyasviel/control_v11p_sd15_openpose': 'openpose_image_processor',
|
|
||||||
'CrucibleAI/ControlNetMediaPipeFace': 'mediapipe_face_processor',
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export const CONTROLNET_MODELS = {
|
||||||
|
'lllyasviel/control_v11p_sd15_canny': {
|
||||||
|
type: 'lllyasviel/control_v11p_sd15_canny',
|
||||||
|
label: 'Canny',
|
||||||
|
defaultProcessor: 'canny_image_processor',
|
||||||
|
},
|
||||||
|
'lllyasviel/control_v11p_sd15_inpaint': {
|
||||||
|
type: 'lllyasviel/control_v11p_sd15_inpaint',
|
||||||
|
label: 'Inpaint',
|
||||||
|
},
|
||||||
|
'lllyasviel/control_v11p_sd15_mlsd': {
|
||||||
|
type: 'lllyasviel/control_v11p_sd15_mlsd',
|
||||||
|
label: 'M-LSD',
|
||||||
|
defaultProcessor: 'mlsd_image_processor',
|
||||||
|
},
|
||||||
|
'lllyasviel/control_v11f1p_sd15_depth': {
|
||||||
|
type: 'lllyasviel/control_v11f1p_sd15_depth',
|
||||||
|
label: 'Depth',
|
||||||
|
defaultProcessor: 'midas_depth_image_processor',
|
||||||
|
},
|
||||||
|
'lllyasviel/control_v11p_sd15_normalbae': {
|
||||||
|
type: 'lllyasviel/control_v11p_sd15_normalbae',
|
||||||
|
label: 'Normal Map (BAE)',
|
||||||
|
defaultProcessor: 'normalbae_image_processor',
|
||||||
|
},
|
||||||
|
'lllyasviel/control_v11p_sd15_seg': {
|
||||||
|
type: 'lllyasviel/control_v11p_sd15_seg',
|
||||||
|
label: 'Segmentation',
|
||||||
|
},
|
||||||
|
'lllyasviel/control_v11p_sd15_lineart': {
|
||||||
|
type: 'lllyasviel/control_v11p_sd15_lineart',
|
||||||
|
label: 'Lineart',
|
||||||
|
defaultProcessor: 'lineart_image_processor',
|
||||||
|
},
|
||||||
|
'lllyasviel/control_v11p_sd15s2_lineart_anime': {
|
||||||
|
type: 'lllyasviel/control_v11p_sd15s2_lineart_anime',
|
||||||
|
label: 'Lineart Anime',
|
||||||
|
defaultProcessor: 'lineart_anime_image_processor',
|
||||||
|
},
|
||||||
|
'lllyasviel/control_v11p_sd15_scribble': {
|
||||||
|
type: 'lllyasviel/control_v11p_sd15_scribble',
|
||||||
|
label: 'Scribble',
|
||||||
|
},
|
||||||
|
'lllyasviel/control_v11p_sd15_softedge': {
|
||||||
|
type: 'lllyasviel/control_v11p_sd15_softedge',
|
||||||
|
label: 'Soft Edge',
|
||||||
|
defaultProcessor: 'hed_image_processor',
|
||||||
|
},
|
||||||
|
'lllyasviel/control_v11e_sd15_shuffle': {
|
||||||
|
type: 'lllyasviel/control_v11e_sd15_shuffle',
|
||||||
|
label: 'Content Shuffle',
|
||||||
|
defaultProcessor: 'content_shuffle_image_processor',
|
||||||
|
},
|
||||||
|
'lllyasviel/control_v11p_sd15_openpose': {
|
||||||
|
type: 'lllyasviel/control_v11p_sd15_openpose',
|
||||||
|
label: 'Openpose',
|
||||||
|
defaultProcessor: 'openpose_image_processor',
|
||||||
|
},
|
||||||
|
'lllyasviel/control_v11f1e_sd15_tile': {
|
||||||
|
type: 'lllyasviel/control_v11f1e_sd15_tile',
|
||||||
|
label: 'Tile (experimental)',
|
||||||
|
},
|
||||||
|
'lllyasviel/control_v11e_sd15_ip2p': {
|
||||||
|
type: 'lllyasviel/control_v11e_sd15_ip2p',
|
||||||
|
label: 'Pix2Pix (experimental)',
|
||||||
|
},
|
||||||
|
'CrucibleAI/ControlNetMediaPipeFace': {
|
||||||
|
type: 'CrucibleAI/ControlNetMediaPipeFace',
|
||||||
|
label: 'Mediapipe Face',
|
||||||
|
defaultProcessor: 'mediapipe_face_processor',
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
export type ControlNetModelName = keyof typeof CONTROLNET_MODELS;
|
||||||
|
@ -9,9 +9,8 @@ import {
|
|||||||
} from './types';
|
} from './types';
|
||||||
import {
|
import {
|
||||||
CONTROLNET_MODELS,
|
CONTROLNET_MODELS,
|
||||||
CONTROLNET_MODEL_MAP,
|
|
||||||
CONTROLNET_PROCESSORS,
|
CONTROLNET_PROCESSORS,
|
||||||
ControlNetModel,
|
ControlNetModelName,
|
||||||
} from './constants';
|
} from './constants';
|
||||||
import { controlNetImageProcessed } from './actions';
|
import { controlNetImageProcessed } from './actions';
|
||||||
import { imageDeleted, imageUrlsReceived } from 'services/thunks/image';
|
import { imageDeleted, imageUrlsReceived } from 'services/thunks/image';
|
||||||
@ -21,7 +20,7 @@ import { appSocketInvocationError } from 'services/events/actions';
|
|||||||
|
|
||||||
export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
|
export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
|
||||||
isEnabled: true,
|
isEnabled: true,
|
||||||
model: CONTROLNET_MODELS[0],
|
model: CONTROLNET_MODELS['lllyasviel/control_v11p_sd15_canny'].type,
|
||||||
weight: 1,
|
weight: 1,
|
||||||
beginStepPct: 0,
|
beginStepPct: 0,
|
||||||
endStepPct: 1,
|
endStepPct: 1,
|
||||||
@ -36,7 +35,7 @@ export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
|
|||||||
export type ControlNetConfig = {
|
export type ControlNetConfig = {
|
||||||
controlNetId: string;
|
controlNetId: string;
|
||||||
isEnabled: boolean;
|
isEnabled: boolean;
|
||||||
model: ControlNetModel;
|
model: ControlNetModelName;
|
||||||
weight: number;
|
weight: number;
|
||||||
beginStepPct: number;
|
beginStepPct: number;
|
||||||
endStepPct: number;
|
endStepPct: number;
|
||||||
@ -138,14 +137,17 @@ export const controlNetSlice = createSlice({
|
|||||||
},
|
},
|
||||||
controlNetModelChanged: (
|
controlNetModelChanged: (
|
||||||
state,
|
state,
|
||||||
action: PayloadAction<{ controlNetId: string; model: ControlNetModel }>
|
action: PayloadAction<{
|
||||||
|
controlNetId: string;
|
||||||
|
model: ControlNetModelName;
|
||||||
|
}>
|
||||||
) => {
|
) => {
|
||||||
const { controlNetId, model } = action.payload;
|
const { controlNetId, model } = action.payload;
|
||||||
state.controlNets[controlNetId].model = model;
|
state.controlNets[controlNetId].model = model;
|
||||||
state.controlNets[controlNetId].processedControlImage = null;
|
state.controlNets[controlNetId].processedControlImage = null;
|
||||||
|
|
||||||
if (state.controlNets[controlNetId].shouldAutoConfig) {
|
if (state.controlNets[controlNetId].shouldAutoConfig) {
|
||||||
const processorType = CONTROLNET_MODEL_MAP[model];
|
const processorType = CONTROLNET_MODELS[model].defaultProcessor;
|
||||||
if (processorType) {
|
if (processorType) {
|
||||||
state.controlNets[controlNetId].processorType = processorType;
|
state.controlNets[controlNetId].processorType = processorType;
|
||||||
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[
|
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[
|
||||||
@ -225,7 +227,8 @@ export const controlNetSlice = createSlice({
|
|||||||
if (newShouldAutoConfig) {
|
if (newShouldAutoConfig) {
|
||||||
// manage the processor for the user
|
// manage the processor for the user
|
||||||
const processorType =
|
const processorType =
|
||||||
CONTROLNET_MODEL_MAP[state.controlNets[controlNetId].model];
|
CONTROLNET_MODELS[state.controlNets[controlNetId].model]
|
||||||
|
.defaultProcessor;
|
||||||
if (processorType) {
|
if (processorType) {
|
||||||
state.controlNets[controlNetId].processorType = processorType;
|
state.controlNets[controlNetId].processorType = processorType;
|
||||||
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[
|
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[
|
||||||
|
@ -18,6 +18,8 @@ export const FIELD_TYPE_MAP: Record<string, FieldType> = {
|
|||||||
ColorField: 'color',
|
ColorField: 'color',
|
||||||
ControlField: 'control',
|
ControlField: 'control',
|
||||||
control: 'control',
|
control: 'control',
|
||||||
|
cfg_scale: 'float',
|
||||||
|
control_weight: 'float',
|
||||||
};
|
};
|
||||||
|
|
||||||
const COLOR_TOKEN_VALUE = 500;
|
const COLOR_TOKEN_VALUE = 500;
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { forEach, size } from 'lodash-es';
|
import { filter, forEach, size } from 'lodash-es';
|
||||||
import { CollectInvocation, ControlNetInvocation } from 'services/api';
|
import { CollectInvocation, ControlNetInvocation } from 'services/api';
|
||||||
import { NonNullableGraph } from '../types/types';
|
import { NonNullableGraph } from '../types/types';
|
||||||
|
|
||||||
@ -12,8 +12,16 @@ export const addControlNetToLinearGraph = (
|
|||||||
): void => {
|
): void => {
|
||||||
const { isEnabled: isControlNetEnabled, controlNets } = state.controlNet;
|
const { isEnabled: isControlNetEnabled, controlNets } = state.controlNet;
|
||||||
|
|
||||||
|
const validControlNets = filter(
|
||||||
|
controlNets,
|
||||||
|
(c) =>
|
||||||
|
c.isEnabled &&
|
||||||
|
(Boolean(c.processedControlImage) ||
|
||||||
|
(c.processorType === 'none' && Boolean(c.controlImage)))
|
||||||
|
);
|
||||||
|
|
||||||
// Add ControlNet
|
// Add ControlNet
|
||||||
if (isControlNetEnabled) {
|
if (isControlNetEnabled && validControlNets.length > 0) {
|
||||||
if (size(controlNets) > 1) {
|
if (size(controlNets) > 1) {
|
||||||
const controlNetIterateNode: CollectInvocation = {
|
const controlNetIterateNode: CollectInvocation = {
|
||||||
id: CONTROL_NET_COLLECT,
|
id: CONTROL_NET_COLLECT,
|
||||||
|
@ -3,10 +3,11 @@ import { Scheduler } from 'app/constants';
|
|||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import IAICustomSelect from 'common/components/IAICustomSelect';
|
import IAICustomSelect from 'common/components/IAICustomSelect';
|
||||||
|
import IAISelect from 'common/components/IAISelect';
|
||||||
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
||||||
import { setScheduler } from 'features/parameters/store/generationSlice';
|
import { setScheduler } from 'features/parameters/store/generationSlice';
|
||||||
import { uiSelector } from 'features/ui/store/uiSelectors';
|
import { uiSelector } from 'features/ui/store/uiSelectors';
|
||||||
import { memo, useCallback } from 'react';
|
import { ChangeEvent, memo, useCallback } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
@ -14,9 +15,11 @@ const selector = createSelector(
|
|||||||
(ui, generation) => {
|
(ui, generation) => {
|
||||||
// TODO: DPMSolverSinglestepScheduler is fixed in https://github.com/huggingface/diffusers/pull/3413
|
// TODO: DPMSolverSinglestepScheduler is fixed in https://github.com/huggingface/diffusers/pull/3413
|
||||||
// but we need to wait for the next release before removing this special handling.
|
// but we need to wait for the next release before removing this special handling.
|
||||||
const allSchedulers = ui.schedulers.filter((scheduler) => {
|
const allSchedulers = ui.schedulers
|
||||||
return !['dpmpp_2s'].includes(scheduler);
|
.filter((scheduler) => {
|
||||||
});
|
return !['dpmpp_2s'].includes(scheduler);
|
||||||
|
})
|
||||||
|
.sort((a, b) => a.localeCompare(b));
|
||||||
|
|
||||||
return {
|
return {
|
||||||
scheduler: generation.scheduler,
|
scheduler: generation.scheduler,
|
||||||
@ -33,24 +36,39 @@ const ParamScheduler = () => {
|
|||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const handleChange = useCallback(
|
const handleChange = useCallback(
|
||||||
(v: string | null | undefined) => {
|
(e: ChangeEvent<HTMLSelectElement>) => {
|
||||||
if (!v) {
|
dispatch(setScheduler(e.target.value as Scheduler));
|
||||||
return;
|
|
||||||
}
|
|
||||||
dispatch(setScheduler(v as Scheduler));
|
|
||||||
},
|
},
|
||||||
[dispatch]
|
[dispatch]
|
||||||
);
|
);
|
||||||
|
// const handleChange = useCallback(
|
||||||
|
// (v: string | null | undefined) => {
|
||||||
|
// if (!v) {
|
||||||
|
// return;
|
||||||
|
// }
|
||||||
|
// dispatch(setScheduler(v as Scheduler));
|
||||||
|
// },
|
||||||
|
// [dispatch]
|
||||||
|
// );
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<IAICustomSelect
|
<IAISelect
|
||||||
label={t('parameters.scheduler')}
|
label={t('parameters.scheduler')}
|
||||||
selectedItem={scheduler}
|
value={scheduler}
|
||||||
setSelectedItem={handleChange}
|
validValues={allSchedulers}
|
||||||
items={allSchedulers}
|
onChange={handleChange}
|
||||||
withCheckIcon
|
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// return (
|
||||||
|
// <IAICustomSelect
|
||||||
|
// label={t('parameters.scheduler')}
|
||||||
|
// value={scheduler}
|
||||||
|
// data={allSchedulers}
|
||||||
|
// onChange={handleChange}
|
||||||
|
// withCheckIcon
|
||||||
|
// />
|
||||||
|
// );
|
||||||
};
|
};
|
||||||
|
|
||||||
export default memo(ParamScheduler);
|
export default memo(ParamScheduler);
|
||||||
|
@ -1,41 +0,0 @@
|
|||||||
// import { emptyTempFolder } from 'app/socketio/actions';
|
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import IAIAlertDialog from 'common/components/IAIAlertDialog';
|
|
||||||
import IAIButton from 'common/components/IAIButton';
|
|
||||||
import { isStagingSelector } from 'features/canvas/store/canvasSelectors';
|
|
||||||
import {
|
|
||||||
clearCanvasHistory,
|
|
||||||
resetCanvas,
|
|
||||||
} from 'features/canvas/store/canvasSlice';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
import { FaTrash } from 'react-icons/fa';
|
|
||||||
|
|
||||||
const EmptyTempFolderButtonModal = () => {
|
|
||||||
const isStaging = useAppSelector(isStagingSelector);
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
const { t } = useTranslation();
|
|
||||||
|
|
||||||
const acceptCallback = () => {
|
|
||||||
dispatch(emptyTempFolder());
|
|
||||||
dispatch(resetCanvas());
|
|
||||||
dispatch(clearCanvasHistory());
|
|
||||||
};
|
|
||||||
|
|
||||||
return (
|
|
||||||
<IAIAlertDialog
|
|
||||||
title={t('unifiedCanvas.emptyTempImageFolder')}
|
|
||||||
acceptCallback={acceptCallback}
|
|
||||||
acceptButtonText={t('unifiedCanvas.emptyFolder')}
|
|
||||||
triggerComponent={
|
|
||||||
<IAIButton leftIcon={<FaTrash />} size="sm" isDisabled={isStaging}>
|
|
||||||
{t('unifiedCanvas.emptyTempImageFolder')}
|
|
||||||
</IAIButton>
|
|
||||||
}
|
|
||||||
>
|
|
||||||
<p>{t('unifiedCanvas.emptyTempImagesFolderMessage')}</p>
|
|
||||||
<br />
|
|
||||||
<p>{t('unifiedCanvas.emptyTempImagesFolderConfirm')}</p>
|
|
||||||
</IAIAlertDialog>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
export default EmptyTempFolderButtonModal;
|
|
@ -1,37 +1,39 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { memo, useCallback } from 'react';
|
import { ChangeEvent, memo, useCallback } from 'react';
|
||||||
import { isEqual } from 'lodash-es';
|
import { isEqual } from 'lodash-es';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import {
|
import { selectModelsAll, selectModelsById } from '../store/modelSlice';
|
||||||
selectModelsAll,
|
|
||||||
selectModelsById,
|
|
||||||
selectModelsIds,
|
|
||||||
} from '../store/modelSlice';
|
|
||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { modelSelected } from 'features/parameters/store/generationSlice';
|
import { modelSelected } from 'features/parameters/store/generationSlice';
|
||||||
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
||||||
import IAICustomSelect, {
|
import IAICustomSelect, {
|
||||||
ItemTooltips,
|
IAICustomSelectOption,
|
||||||
} from 'common/components/IAICustomSelect';
|
} from 'common/components/IAICustomSelect';
|
||||||
|
import IAISelect from 'common/components/IAISelect';
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
[(state: RootState) => state, generationSelector],
|
[(state: RootState) => state, generationSelector],
|
||||||
(state, generation) => {
|
(state, generation) => {
|
||||||
const selectedModel = selectModelsById(state, generation.model);
|
const selectedModel = selectModelsById(state, generation.model);
|
||||||
const allModelNames = selectModelsIds(state).map((id) => String(id));
|
|
||||||
const allModelTooltips = selectModelsAll(state).reduce(
|
const modelData = selectModelsAll(state)
|
||||||
(allModelTooltips, model) => {
|
.map((m) => ({
|
||||||
allModelTooltips[model.name] = model.description ?? '';
|
value: m.name,
|
||||||
return allModelTooltips;
|
key: m.name,
|
||||||
},
|
}))
|
||||||
{} as ItemTooltips
|
.sort((a, b) => a.key.localeCompare(b.key));
|
||||||
);
|
// const modelData = selectModelsAll(state)
|
||||||
|
// .map<IAICustomSelectOption>((m) => ({
|
||||||
|
// value: m.name,
|
||||||
|
// label: m.name,
|
||||||
|
// tooltip: m.description,
|
||||||
|
// }))
|
||||||
|
// .sort((a, b) => a.label.localeCompare(b.label));
|
||||||
return {
|
return {
|
||||||
allModelNames,
|
|
||||||
allModelTooltips,
|
|
||||||
selectedModel,
|
selectedModel,
|
||||||
|
modelData,
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -44,30 +46,45 @@ const selector = createSelector(
|
|||||||
const ModelSelect = () => {
|
const ModelSelect = () => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const { allModelNames, allModelTooltips, selectedModel } =
|
const { selectedModel, modelData } = useAppSelector(selector);
|
||||||
useAppSelector(selector);
|
|
||||||
const handleChangeModel = useCallback(
|
const handleChangeModel = useCallback(
|
||||||
(v: string | null | undefined) => {
|
(e: ChangeEvent<HTMLSelectElement>) => {
|
||||||
if (!v) {
|
dispatch(modelSelected(e.target.value));
|
||||||
return;
|
|
||||||
}
|
|
||||||
dispatch(modelSelected(v));
|
|
||||||
},
|
},
|
||||||
[dispatch]
|
[dispatch]
|
||||||
);
|
);
|
||||||
|
// const handleChangeModel = useCallback(
|
||||||
|
// (v: string | null | undefined) => {
|
||||||
|
// if (!v) {
|
||||||
|
// return;
|
||||||
|
// }
|
||||||
|
// dispatch(modelSelected(v));
|
||||||
|
// },
|
||||||
|
// [dispatch]
|
||||||
|
// );
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<IAICustomSelect
|
<IAISelect
|
||||||
label={t('modelManager.model')}
|
label={t('modelManager.model')}
|
||||||
tooltip={selectedModel?.description}
|
tooltip={selectedModel?.description}
|
||||||
items={allModelNames}
|
validValues={modelData}
|
||||||
itemTooltips={allModelTooltips}
|
value={selectedModel?.name ?? ''}
|
||||||
selectedItem={selectedModel?.name ?? ''}
|
onChange={handleChangeModel}
|
||||||
setSelectedItem={handleChangeModel}
|
|
||||||
withCheckIcon={true}
|
|
||||||
tooltipProps={{ placement: 'top', hasArrow: true }}
|
tooltipProps={{ placement: 'top', hasArrow: true }}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// return (
|
||||||
|
// <IAICustomSelect
|
||||||
|
// label={t('modelManager.model')}
|
||||||
|
// tooltip={selectedModel?.description}
|
||||||
|
// data={modelData}
|
||||||
|
// value={selectedModel?.name ?? ''}
|
||||||
|
// onChange={handleChangeModel}
|
||||||
|
// withCheckIcon={true}
|
||||||
|
// tooltipProps={{ placement: 'top', hasArrow: true }}
|
||||||
|
// />
|
||||||
|
// );
|
||||||
};
|
};
|
||||||
|
|
||||||
export default memo(ModelSelect);
|
export default memo(ModelSelect);
|
||||||
|
@ -10,6 +10,8 @@ export const initialConfigState: AppConfig = {
|
|||||||
disabledSDFeatures: [],
|
disabledSDFeatures: [],
|
||||||
canRestoreDeletedImagesFromBin: true,
|
canRestoreDeletedImagesFromBin: true,
|
||||||
sd: {
|
sd: {
|
||||||
|
disabledControlNetModels: [],
|
||||||
|
disabledControlNetProcessors: [],
|
||||||
iterations: {
|
iterations: {
|
||||||
initial: 1,
|
initial: 1,
|
||||||
min: 1,
|
min: 1,
|
||||||
|
@ -47,3 +47,6 @@ export const languageSelector = createSelector(
|
|||||||
(system) => system.language,
|
(system) => system.language,
|
||||||
defaultSelectorOptions
|
defaultSelectorOptions
|
||||||
);
|
);
|
||||||
|
|
||||||
|
export const isProcessingSelector = (state: RootState) =>
|
||||||
|
state.system.isProcessing;
|
||||||
|
@ -14,7 +14,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|||||||
import { setIsLightboxOpen } from 'features/lightbox/store/lightboxSlice';
|
import { setIsLightboxOpen } from 'features/lightbox/store/lightboxSlice';
|
||||||
import { InvokeTabName } from 'features/ui/store/tabMap';
|
import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||||
import { setActiveTab, togglePanels } from 'features/ui/store/uiSlice';
|
import { setActiveTab, togglePanels } from 'features/ui/store/uiSlice';
|
||||||
import { memo, ReactNode, useCallback, useMemo } from 'react';
|
import { memo, MouseEvent, ReactNode, useCallback, useMemo } from 'react';
|
||||||
import { useHotkeys } from 'react-hotkeys-hook';
|
import { useHotkeys } from 'react-hotkeys-hook';
|
||||||
import { MdDeviceHub, MdGridOn } from 'react-icons/md';
|
import { MdDeviceHub, MdGridOn } from 'react-icons/md';
|
||||||
import { GoTextSize } from 'react-icons/go';
|
import { GoTextSize } from 'react-icons/go';
|
||||||
@ -47,22 +47,22 @@ export interface InvokeTabInfo {
|
|||||||
const tabs: InvokeTabInfo[] = [
|
const tabs: InvokeTabInfo[] = [
|
||||||
{
|
{
|
||||||
id: 'txt2img',
|
id: 'txt2img',
|
||||||
icon: <Icon as={GoTextSize} sx={{ boxSize: 6 }} />,
|
icon: <Icon as={GoTextSize} sx={{ boxSize: 6, pointerEvents: 'none' }} />,
|
||||||
content: <TextToImageTab />,
|
content: <TextToImageTab />,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
id: 'img2img',
|
id: 'img2img',
|
||||||
icon: <Icon as={FaImage} sx={{ boxSize: 6 }} />,
|
icon: <Icon as={FaImage} sx={{ boxSize: 6, pointerEvents: 'none' }} />,
|
||||||
content: <ImageTab />,
|
content: <ImageTab />,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
id: 'unifiedCanvas',
|
id: 'unifiedCanvas',
|
||||||
icon: <Icon as={MdGridOn} sx={{ boxSize: 6 }} />,
|
icon: <Icon as={MdGridOn} sx={{ boxSize: 6, pointerEvents: 'none' }} />,
|
||||||
content: <UnifiedCanvasTab />,
|
content: <UnifiedCanvasTab />,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
id: 'nodes',
|
id: 'nodes',
|
||||||
icon: <Icon as={MdDeviceHub} sx={{ boxSize: 6 }} />,
|
icon: <Icon as={MdDeviceHub} sx={{ boxSize: 6, pointerEvents: 'none' }} />,
|
||||||
content: <NodesTab />,
|
content: <NodesTab />,
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
@ -119,6 +119,12 @@ const InvokeTabs = () => {
|
|||||||
}
|
}
|
||||||
}, [dispatch, activeTabName]);
|
}, [dispatch, activeTabName]);
|
||||||
|
|
||||||
|
const handleClickTab = useCallback((e: MouseEvent<HTMLElement>) => {
|
||||||
|
if (e.target instanceof HTMLElement) {
|
||||||
|
e.target.blur();
|
||||||
|
}
|
||||||
|
}, []);
|
||||||
|
|
||||||
const tabs = useMemo(
|
const tabs = useMemo(
|
||||||
() =>
|
() =>
|
||||||
enabledTabs.map((tab) => (
|
enabledTabs.map((tab) => (
|
||||||
@ -128,7 +134,7 @@ const InvokeTabs = () => {
|
|||||||
label={String(t(`common.${tab.id}` as ResourceKey))}
|
label={String(t(`common.${tab.id}` as ResourceKey))}
|
||||||
placement="end"
|
placement="end"
|
||||||
>
|
>
|
||||||
<Tab>
|
<Tab onClick={handleClickTab}>
|
||||||
<VisuallyHidden>
|
<VisuallyHidden>
|
||||||
{String(t(`common.${tab.id}` as ResourceKey))}
|
{String(t(`common.${tab.id}` as ResourceKey))}
|
||||||
</VisuallyHidden>
|
</VisuallyHidden>
|
||||||
@ -136,7 +142,7 @@ const InvokeTabs = () => {
|
|||||||
</Tab>
|
</Tab>
|
||||||
</Tooltip>
|
</Tooltip>
|
||||||
)),
|
)),
|
||||||
[t, enabledTabs]
|
[enabledTabs, t, handleClickTab]
|
||||||
);
|
);
|
||||||
|
|
||||||
const tabPanels = useMemo(
|
const tabPanels = useMemo(
|
||||||
|
@ -12,7 +12,6 @@ import {
|
|||||||
setShouldShowCanvasDebugInfo,
|
setShouldShowCanvasDebugInfo,
|
||||||
setShouldShowIntermediates,
|
setShouldShowIntermediates,
|
||||||
} from 'features/canvas/store/canvasSlice';
|
} from 'features/canvas/store/canvasSlice';
|
||||||
import EmptyTempFolderButtonModal from 'features/system/components/ClearTempFolderButtonModal';
|
|
||||||
|
|
||||||
import { FaWrench } from 'react-icons/fa';
|
import { FaWrench } from 'react-icons/fa';
|
||||||
|
|
||||||
@ -105,7 +104,6 @@ const UnifiedCanvasSettings = () => {
|
|||||||
onChange={(e) => dispatch(setShouldAntialias(e.target.checked))}
|
onChange={(e) => dispatch(setShouldAntialias(e.target.checked))}
|
||||||
/>
|
/>
|
||||||
<ClearCanvasHistoryButtonModal />
|
<ClearCanvasHistoryButtonModal />
|
||||||
<EmptyTempFolderButtonModal />
|
|
||||||
</Flex>
|
</Flex>
|
||||||
</IAIPopover>
|
</IAIPopover>
|
||||||
);
|
);
|
||||||
|
@ -55,8 +55,6 @@ const UnifiedCanvasContent = () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
useLayoutEffect(() => {
|
useLayoutEffect(() => {
|
||||||
dispatch(requestCanvasRescale());
|
|
||||||
|
|
||||||
const resizeCallback = () => {
|
const resizeCallback = () => {
|
||||||
dispatch(requestCanvasRescale());
|
dispatch(requestCanvasRescale());
|
||||||
};
|
};
|
||||||
|
@ -7,30 +7,26 @@ import type { ImageField } from './ImageField';
|
|||||||
/**
|
/**
|
||||||
* Applies HED edge detection to image
|
* Applies HED edge detection to image
|
||||||
*/
|
*/
|
||||||
export type HedImageProcessorInvocation = {
|
export type HedImageprocessorInvocation = {
|
||||||
/**
|
/**
|
||||||
* The id of this node. Must be unique among all nodes.
|
* The id of this node. Must be unique among all nodes.
|
||||||
*/
|
*/
|
||||||
id: string;
|
id: string;
|
||||||
/**
|
|
||||||
* Whether or not this node is an intermediate node.
|
|
||||||
*/
|
|
||||||
is_intermediate?: boolean;
|
|
||||||
type?: 'hed_image_processor';
|
type?: 'hed_image_processor';
|
||||||
/**
|
/**
|
||||||
* The image to process
|
* image to process
|
||||||
*/
|
*/
|
||||||
image?: ImageField;
|
image?: ImageField;
|
||||||
/**
|
/**
|
||||||
* The pixel resolution for detection
|
* pixel resolution for edge detection
|
||||||
*/
|
*/
|
||||||
detect_resolution?: number;
|
detect_resolution?: number;
|
||||||
/**
|
/**
|
||||||
* The pixel resolution for the output image
|
* pixel resolution for output image
|
||||||
*/
|
*/
|
||||||
image_resolution?: number;
|
image_resolution?: number;
|
||||||
/**
|
/**
|
||||||
* Whether to use scribble mode
|
* whether to use scribble mode
|
||||||
*/
|
*/
|
||||||
scribble?: boolean;
|
scribble?: boolean;
|
||||||
};
|
};
|
||||||
|
@ -0,0 +1,33 @@
|
|||||||
|
/* istanbul ignore file */
|
||||||
|
/* tslint:disable */
|
||||||
|
/* eslint-disable */
|
||||||
|
|
||||||
|
import type { ImageField } from './ImageField';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Applies HED edge detection to image
|
||||||
|
*/
|
||||||
|
export type HedImageprocessorInvocation = {
|
||||||
|
/**
|
||||||
|
* The id of this node. Must be unique among all nodes.
|
||||||
|
*/
|
||||||
|
id: string;
|
||||||
|
type?: 'hed_image_processor';
|
||||||
|
/**
|
||||||
|
* image to process
|
||||||
|
*/
|
||||||
|
image?: ImageField;
|
||||||
|
/**
|
||||||
|
* pixel resolution for edge detection
|
||||||
|
*/
|
||||||
|
detect_resolution?: number;
|
||||||
|
/**
|
||||||
|
* pixel resolution for output image
|
||||||
|
*/
|
||||||
|
image_resolution?: number;
|
||||||
|
/**
|
||||||
|
* whether to use scribble mode
|
||||||
|
*/
|
||||||
|
scribble?: boolean;
|
||||||
|
};
|
||||||
|
|
@ -30,7 +30,7 @@ const invokeAIMark = defineStyle((_props) => {
|
|||||||
return {
|
return {
|
||||||
fontSize: 'xs',
|
fontSize: 'xs',
|
||||||
fontWeight: '500',
|
fontWeight: '500',
|
||||||
color: 'base.200',
|
color: 'base.400',
|
||||||
mt: 2,
|
mt: 2,
|
||||||
insetInlineStart: 'unset',
|
insetInlineStart: 'unset',
|
||||||
};
|
};
|
||||||
|
@ -44,6 +44,7 @@ dependencies = [
|
|||||||
"datasets",
|
"datasets",
|
||||||
"diffusers[torch]~=0.17.0",
|
"diffusers[torch]~=0.17.0",
|
||||||
"dnspython==2.2.1",
|
"dnspython==2.2.1",
|
||||||
|
"easing-functions",
|
||||||
"einops",
|
"einops",
|
||||||
"eventlet",
|
"eventlet",
|
||||||
"facexlib",
|
"facexlib",
|
||||||
@ -56,6 +57,7 @@ dependencies = [
|
|||||||
"flaskwebgui==1.0.3",
|
"flaskwebgui==1.0.3",
|
||||||
"gfpgan==1.3.8",
|
"gfpgan==1.3.8",
|
||||||
"huggingface-hub>=0.11.1",
|
"huggingface-hub>=0.11.1",
|
||||||
|
"matplotlib", # needed for plotting of Penner easing functions
|
||||||
"mediapipe", # needed for "mediapipeface" controlnet model
|
"mediapipe", # needed for "mediapipeface" controlnet model
|
||||||
"npyscreen",
|
"npyscreen",
|
||||||
"numpy<1.24",
|
"numpy<1.24",
|
||||||
|
@ -121,3 +121,78 @@ def test_graph_state_collects(mock_services):
|
|||||||
assert isinstance(n6[0], CollectInvocation)
|
assert isinstance(n6[0], CollectInvocation)
|
||||||
|
|
||||||
assert sorted(g.results[n6[0].id].collection) == sorted(test_prompts)
|
assert sorted(g.results[n6[0].id].collection) == sorted(test_prompts)
|
||||||
|
|
||||||
|
|
||||||
|
def test_graph_state_prepares_eagerly(mock_services):
|
||||||
|
"""Tests that all prepareable nodes are prepared"""
|
||||||
|
graph = Graph()
|
||||||
|
|
||||||
|
test_prompts = ["Banana sushi", "Cat sushi"]
|
||||||
|
graph.add_node(PromptCollectionTestInvocation(id="prompt_collection", collection=list(test_prompts)))
|
||||||
|
graph.add_node(IterateInvocation(id="iterate"))
|
||||||
|
graph.add_node(PromptTestInvocation(id="prompt_iterated"))
|
||||||
|
graph.add_edge(create_edge("prompt_collection", "collection", "iterate", "collection"))
|
||||||
|
graph.add_edge(create_edge("iterate", "item", "prompt_iterated", "prompt"))
|
||||||
|
|
||||||
|
# separated, fully-preparable chain of nodes
|
||||||
|
graph.add_node(PromptTestInvocation(id="prompt_chain_1", prompt="Dinosaur sushi"))
|
||||||
|
graph.add_node(PromptTestInvocation(id="prompt_chain_2"))
|
||||||
|
graph.add_node(PromptTestInvocation(id="prompt_chain_3"))
|
||||||
|
graph.add_edge(create_edge("prompt_chain_1", "prompt", "prompt_chain_2", "prompt"))
|
||||||
|
graph.add_edge(create_edge("prompt_chain_2", "prompt", "prompt_chain_3", "prompt"))
|
||||||
|
|
||||||
|
g = GraphExecutionState(graph=graph)
|
||||||
|
g.next()
|
||||||
|
|
||||||
|
assert "prompt_collection" in g.source_prepared_mapping
|
||||||
|
assert "prompt_chain_1" in g.source_prepared_mapping
|
||||||
|
assert "prompt_chain_2" in g.source_prepared_mapping
|
||||||
|
assert "prompt_chain_3" in g.source_prepared_mapping
|
||||||
|
assert "iterate" not in g.source_prepared_mapping
|
||||||
|
assert "prompt_iterated" not in g.source_prepared_mapping
|
||||||
|
|
||||||
|
|
||||||
|
def test_graph_executes_depth_first(mock_services):
|
||||||
|
"""Tests that the graph executes depth-first, executing a branch as far as possible before moving to the next branch"""
|
||||||
|
graph = Graph()
|
||||||
|
|
||||||
|
test_prompts = ["Banana sushi", "Cat sushi"]
|
||||||
|
graph.add_node(PromptCollectionTestInvocation(id="prompt_collection", collection=list(test_prompts)))
|
||||||
|
graph.add_node(IterateInvocation(id="iterate"))
|
||||||
|
graph.add_node(PromptTestInvocation(id="prompt_iterated"))
|
||||||
|
graph.add_node(PromptTestInvocation(id="prompt_successor"))
|
||||||
|
graph.add_edge(create_edge("prompt_collection", "collection", "iterate", "collection"))
|
||||||
|
graph.add_edge(create_edge("iterate", "item", "prompt_iterated", "prompt"))
|
||||||
|
graph.add_edge(create_edge("prompt_iterated", "prompt", "prompt_successor", "prompt"))
|
||||||
|
|
||||||
|
g = GraphExecutionState(graph=graph)
|
||||||
|
n1 = invoke_next(g, mock_services)
|
||||||
|
n2 = invoke_next(g, mock_services)
|
||||||
|
n3 = invoke_next(g, mock_services)
|
||||||
|
n4 = invoke_next(g, mock_services)
|
||||||
|
|
||||||
|
# Because ordering is not guaranteed, we cannot compare results directly.
|
||||||
|
# Instead, we must count the number of results.
|
||||||
|
def get_completed_count(g, id):
|
||||||
|
ids = [i for i in g.source_prepared_mapping[id]]
|
||||||
|
completed_ids = [i for i in g.executed if i in ids]
|
||||||
|
return len(completed_ids)
|
||||||
|
|
||||||
|
# Check at each step that the number of executed nodes matches the expectation for depth-first execution
|
||||||
|
assert get_completed_count(g, "prompt_iterated") == 1
|
||||||
|
assert get_completed_count(g, "prompt_successor") == 0
|
||||||
|
|
||||||
|
n5 = invoke_next(g, mock_services)
|
||||||
|
|
||||||
|
assert get_completed_count(g, "prompt_iterated") == 1
|
||||||
|
assert get_completed_count(g, "prompt_successor") == 1
|
||||||
|
|
||||||
|
n6 = invoke_next(g, mock_services)
|
||||||
|
|
||||||
|
assert get_completed_count(g, "prompt_iterated") == 2
|
||||||
|
assert get_completed_count(g, "prompt_successor") == 1
|
||||||
|
|
||||||
|
n7 = invoke_next(g, mock_services)
|
||||||
|
|
||||||
|
assert get_completed_count(g, "prompt_iterated") == 2
|
||||||
|
assert get_completed_count(g, "prompt_successor") == 2
|
||||||
|
Loading…
Reference in New Issue
Block a user