Merge branch 'main' into lstein/default-model-install

This commit is contained in:
Lincoln Stein 2023-07-15 08:30:22 -04:00 committed by GitHub
commit 32e7e52d69
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
81 changed files with 1725 additions and 747 deletions

View File

@ -1,6 +1,7 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2024 Lincoln Stein # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2023 Lincoln D. Stein
import pathlib
from typing import Literal, List, Optional, Union from typing import Literal, List, Optional, Union
from fastapi import Body, Path, Query, Response from fastapi import Body, Path, Query, Response
@ -22,6 +23,7 @@ UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]
class ModelsList(BaseModel): class ModelsList(BaseModel):
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]] models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
@ -78,7 +80,7 @@ async def update_model(
return model_response return model_response
@models_router.post( @models_router.post(
"/", "/import",
operation_id="import_model", operation_id="import_model",
responses= { responses= {
201: {"description" : "The model imported successfully"}, 201: {"description" : "The model imported successfully"},
@ -94,7 +96,7 @@ async def import_model(
prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = \ prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = \
Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"), Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
) -> ImportModelResponse: ) -> ImportModelResponse:
""" Add a model using its local path, repo_id, or remote URL """ """ Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically """
items_to_import = {location} items_to_import = {location}
prediction_types = { x.value: x for x in SchedulerPredictionType } prediction_types = { x.value: x for x in SchedulerPredictionType }
@ -126,18 +128,100 @@ async def import_model(
logger.error(str(e)) logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e)) raise HTTPException(status_code=409, detail=str(e))
@models_router.post(
"/add",
operation_id="add_model",
responses= {
201: {"description" : "The model added successfully"},
404: {"description" : "The model could not be found"},
424: {"description" : "The model appeared to add successfully, but could not be found in the model manager"},
409: {"description" : "There is already a model corresponding to this path or repo_id"},
},
status_code=201,
response_model=ImportModelResponse
)
async def add_model(
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
) -> ImportModelResponse:
""" Add a model using the configuration information appropriate for its type. Only local models can be added by path"""
logger = ApiDependencies.invoker.services.logger
try:
ApiDependencies.invoker.services.model_manager.add_model(
info.model_name,
info.base_model,
info.model_type,
model_attributes = info.dict()
)
logger.info(f'Successfully added {info.model_name}')
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
model_name=info.model_name,
base_model=info.base_model,
model_type=info.model_type
)
return parse_obj_as(ImportModelResponse, model_raw)
except KeyError as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
@models_router.post(
"/rename/{base_model}/{model_type}/{model_name}",
operation_id="rename_model",
responses= {
201: {"description" : "The model was renamed successfully"},
404: {"description" : "The model could not be found"},
409: {"description" : "There is already a model corresponding to the new name"},
},
status_code=201,
response_model=ImportModelResponse
)
async def rename_model(
base_model: BaseModelType = Path(description="Base model"),
model_type: ModelType = Path(description="The type of model"),
model_name: str = Path(description="current model name"),
new_name: Optional[str] = Query(description="new model name", default=None),
new_base: Optional[BaseModelType] = Query(description="new model base", default=None),
) -> ImportModelResponse:
""" Rename a model"""
logger = ApiDependencies.invoker.services.logger
try:
result = ApiDependencies.invoker.services.model_manager.rename_model(
base_model = base_model,
model_type = model_type,
model_name = model_name,
new_name = new_name,
new_base = new_base,
)
logger.debug(result)
logger.info(f'Successfully renamed {model_name}=>{new_name}')
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
model_name=new_name or model_name,
base_model=new_base or base_model,
model_type=model_type
)
return parse_obj_as(ImportModelResponse, model_raw)
except KeyError as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
@models_router.delete( @models_router.delete(
"/{base_model}/{model_type}/{model_name}", "/{base_model}/{model_type}/{model_name}",
operation_id="del_model", operation_id="del_model",
responses={ responses={
204: { 204: { "description": "Model deleted successfully" },
"description": "Model deleted successfully" 404: { "description": "Model not found" }
},
404: {
"description": "Model not found"
}
}, },
status_code = 204,
response_model = None,
) )
async def delete_model( async def delete_model(
base_model: BaseModelType = Path(description="Base model"), base_model: BaseModelType = Path(description="Base model"),
@ -173,14 +257,17 @@ async def convert_model(
base_model: BaseModelType = Path(description="Base model"), base_model: BaseModelType = Path(description="Base model"),
model_type: ModelType = Path(description="The type of model"), model_type: ModelType = Path(description="The type of model"),
model_name: str = Path(description="model name"), model_name: str = Path(description="model name"),
convert_dest_directory: Optional[str] = Query(default=None, description="Save the converted model to the designated directory"),
) -> ConvertModelResponse: ) -> ConvertModelResponse:
"""Convert a checkpoint model into a diffusers model""" """Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none."""
logger = ApiDependencies.invoker.services.logger logger = ApiDependencies.invoker.services.logger
try: try:
logger.info(f"Converting model: {model_name}") logger.info(f"Converting model: {model_name}")
dest = pathlib.Path(convert_dest_directory) if convert_dest_directory else None
ApiDependencies.invoker.services.model_manager.convert_model(model_name, ApiDependencies.invoker.services.model_manager.convert_model(model_name,
base_model = base_model, base_model = base_model,
model_type = model_type model_type = model_type,
convert_dest_directory = dest,
) )
model_raw = ApiDependencies.invoker.services.model_manager.list_model(model_name, model_raw = ApiDependencies.invoker.services.model_manager.list_model(model_name,
base_model = base_model, base_model = base_model,
@ -192,6 +279,53 @@ async def convert_model(
raise HTTPException(status_code=400, detail=str(e)) raise HTTPException(status_code=400, detail=str(e))
return response return response
@models_router.get(
"/search",
operation_id="search_for_models",
responses={
200: { "description": "Directory searched successfully" },
404: { "description": "Invalid directory path" },
},
status_code = 200,
response_model = List[pathlib.Path]
)
async def search_for_models(
search_path: pathlib.Path = Query(description="Directory path to search for models")
)->List[pathlib.Path]:
if not search_path.is_dir():
raise HTTPException(status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory")
return ApiDependencies.invoker.services.model_manager.search_for_models([search_path])
@models_router.get(
"/ckpt_confs",
operation_id="list_ckpt_configs",
responses={
200: { "description" : "paths retrieved successfully" },
},
status_code = 200,
response_model = List[pathlib.Path]
)
async def list_ckpt_configs(
)->List[pathlib.Path]:
"""Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT."""
return ApiDependencies.invoker.services.model_manager.list_checkpoint_configs()
@models_router.get(
"/sync",
operation_id="sync_to_config",
responses={
201: { "description": "synchronization successful" },
},
status_code = 201,
response_model = None
)
async def sync_to_config(
)->None:
"""Call after making changes to models.yaml, autoimport directories or models directory to synchronize
in-memory data structures with disk data structures."""
return ApiDependencies.invoker.services.model_manager.sync_to_config()
@models_router.put( @models_router.put(
"/merge/{base_model}", "/merge/{base_model}",
operation_id="merge_models", operation_id="merge_models",
@ -210,17 +344,21 @@ async def merge_models(
alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5), alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"), interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"),
force: Optional[bool] = Body(description="Force merging of models created with different versions of diffusers", default=False), force: Optional[bool] = Body(description="Force merging of models created with different versions of diffusers", default=False),
merge_dest_directory: Optional[str] = Body(description="Save the merged model to the designated directory (with 'merged_model_name' appended)", default=None)
) -> MergeModelResponse: ) -> MergeModelResponse:
"""Convert a checkpoint model into a diffusers model""" """Convert a checkpoint model into a diffusers model"""
logger = ApiDependencies.invoker.services.logger logger = ApiDependencies.invoker.services.logger
try: try:
logger.info(f"Merging models: {model_names}") logger.info(f"Merging models: {model_names} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
result = ApiDependencies.invoker.services.model_manager.merge_models(model_names, result = ApiDependencies.invoker.services.model_manager.merge_models(model_names,
base_model, base_model,
merged_model_name or "+".join(model_names), merged_model_name=merged_model_name or "+".join(model_names),
alpha, alpha=alpha,
interp, interp=interp,
force) force=force,
merge_dest_directory = dest
)
model_raw = ApiDependencies.invoker.services.model_manager.list_model(result.name, model_raw = ApiDependencies.invoker.services.model_manager.list_model(result.name,
base_model = base_model, base_model = base_model,
model_type = ModelType.Main, model_type = ModelType.Main,

View File

@ -100,7 +100,7 @@ class CompelInvocation(BaseInvocation):
text_encoder=text_encoder, text_encoder=text_encoder,
textual_inversion_manager=ti_manager, textual_inversion_manager=ti_manager,
dtype_for_device_getter=torch_dtype, dtype_for_device_getter=torch_dtype,
truncate_long_prompts=True, # TODO: truncate_long_prompts=False,
) )
conjunction = Compel.parse_prompt_string(self.prompt) conjunction = Compel.parse_prompt_string(self.prompt)
@ -112,9 +112,6 @@ class CompelInvocation(BaseInvocation):
c, options = compel.build_conditioning_tensor_for_prompt_object( c, options = compel.build_conditioning_tensor_for_prompt_object(
prompt) prompt)
# TODO: long prompt support
# if not self.truncate_long_prompts:
# [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo( ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
tokens_count_including_eos_bos=get_max_token_count( tokens_count_including_eos_bos=get_max_token_count(
tokenizer, conjunction), tokenizer, conjunction),

View File

@ -9,6 +9,7 @@ from typing import Literal, Optional, Union, List, Dict
from PIL import Image from PIL import Image
from pydantic import BaseModel, Field, validator from pydantic import BaseModel, Field, validator
from ...backend.model_management import BaseModelType, ModelType
from ..models.image import ImageField, ImageCategory, ResourceOrigin from ..models.image import ImageField, ImageCategory, ResourceOrigin
from .baseinvocation import ( from .baseinvocation import (
BaseInvocation, BaseInvocation,
@ -105,9 +106,15 @@ CONTROLNET_MODE_VALUES = Literal[tuple(["balanced", "more_prompt", "more_control
# CONTROLNET_RESIZE_VALUES = Literal[tuple(["just_resize", "crop_resize", "fill_resize"])] # CONTROLNET_RESIZE_VALUES = Literal[tuple(["just_resize", "crop_resize", "fill_resize"])]
class ControlNetModelField(BaseModel):
"""ControlNet model field"""
model_name: str = Field(description="Name of the ControlNet model")
base_model: BaseModelType = Field(description="Base model")
class ControlField(BaseModel): class ControlField(BaseModel):
image: ImageField = Field(default=None, description="The control image") image: ImageField = Field(default=None, description="The control image")
control_model: Optional[str] = Field(default=None, description="The ControlNet model to use") control_model: Optional[ControlNetModelField] = Field(default=None, description="The ControlNet model to use")
# control_weight: Optional[float] = Field(default=1, description="weight given to controlnet") # control_weight: Optional[float] = Field(default=1, description="weight given to controlnet")
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet") control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
begin_step_percent: float = Field(default=0, ge=0, le=1, begin_step_percent: float = Field(default=0, ge=0, le=1,
@ -118,15 +125,15 @@ class ControlField(BaseModel):
# resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use") # resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
@validator("control_weight") @validator("control_weight")
def abs_le_one(cls, v): def validate_control_weight(cls, v):
"""validate that all abs(values) are <=1""" """Validate that all control weights in the valid range"""
if isinstance(v, list): if isinstance(v, list):
for i in v: for i in v:
if abs(i) > 1: if i < -1 or i > 2:
raise ValueError('all abs(control_weight) must be <= 1') raise ValueError('Control weights must be within -1 to 2 range')
else: else:
if abs(v) > 1: if v < -1 or v > 2:
raise ValueError('abs(control_weight) must be <= 1') raise ValueError('Control weights must be within -1 to 2 range')
return v return v
class Config: class Config:
schema_extra = { schema_extra = {
@ -134,6 +141,7 @@ class ControlField(BaseModel):
"ui": { "ui": {
"type_hints": { "type_hints": {
"control_weight": "float", "control_weight": "float",
"control_model": "controlnet_model",
# "control_weight": "number", # "control_weight": "number",
} }
} }
@ -154,10 +162,10 @@ class ControlNetInvocation(BaseInvocation):
type: Literal["controlnet"] = "controlnet" type: Literal["controlnet"] = "controlnet"
# Inputs # Inputs
image: ImageField = Field(default=None, description="The control image") image: ImageField = Field(default=None, description="The control image")
control_model: CONTROLNET_NAME_VALUES = Field(default="lllyasviel/sd-controlnet-canny", control_model: ControlNetModelField = Field(default="lllyasviel/sd-controlnet-canny",
description="control model used") description="control model used")
control_weight: Union[float, List[float]] = Field(default=1.0, description="The weight given to the ControlNet") control_weight: Union[float, List[float]] = Field(default=1.0, description="The weight given to the ControlNet")
begin_step_percent: float = Field(default=0, ge=0, le=1, begin_step_percent: float = Field(default=0, ge=-1, le=2,
description="When the ControlNet is first applied (% of total steps)") description="When the ControlNet is first applied (% of total steps)")
end_step_percent: float = Field(default=1, ge=0, le=1, end_step_percent: float = Field(default=1, ge=0, le=1,
description="When the ControlNet is last applied (% of total steps)") description="When the ControlNet is last applied (% of total steps)")

View File

@ -1,5 +1,6 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from contextlib import ExitStack
from typing import List, Literal, Optional, Union from typing import List, Literal, Optional, Union
import einops import einops
@ -11,6 +12,7 @@ from pydantic import BaseModel, Field, validator
from invokeai.app.invocations.metadata import CoreMetadata from invokeai.app.invocations.metadata import CoreMetadata
from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_management.models.base import ModelType
from ...backend.model_management.lora import ModelPatcher from ...backend.model_management.lora import ModelPatcher
from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.stable_diffusion import PipelineIntermediateState
@ -71,16 +73,21 @@ def get_scheduler(
scheduler_name: str, scheduler_name: str,
) -> Scheduler: ) -> Scheduler:
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get( scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(
scheduler_name, SCHEDULER_MAP['ddim']) scheduler_name, SCHEDULER_MAP['ddim']
)
orig_scheduler_info = context.services.model_manager.get_model( orig_scheduler_info = context.services.model_manager.get_model(
**scheduler_info.dict()) **scheduler_info.dict()
)
with orig_scheduler_info as orig_scheduler: with orig_scheduler_info as orig_scheduler:
scheduler_config = orig_scheduler.config scheduler_config = orig_scheduler.config
if "_backup" in scheduler_config: if "_backup" in scheduler_config:
scheduler_config = scheduler_config["_backup"] scheduler_config = scheduler_config["_backup"]
scheduler_config = {**scheduler_config, ** scheduler_config = {
scheduler_extra_config, "_backup": scheduler_config} **scheduler_config,
**scheduler_extra_config,
"_backup": scheduler_config,
}
scheduler = scheduler_class.from_config(scheduler_config) scheduler = scheduler_class.from_config(scheduler_config)
# hack copied over from generate.py # hack copied over from generate.py
@ -137,8 +144,11 @@ class TextToLatentsInvocation(BaseInvocation):
# TODO: pass this an emitter method or something? or a session for dispatching? # TODO: pass this an emitter method or something? or a session for dispatching?
def dispatch_progress( def dispatch_progress(
self, context: InvocationContext, source_node_id: str, self,
intermediate_state: PipelineIntermediateState) -> None: context: InvocationContext,
source_node_id: str,
intermediate_state: PipelineIntermediateState,
) -> None:
stable_diffusion_step_callback( stable_diffusion_step_callback(
context=context, context=context,
intermediate_state=intermediate_state, intermediate_state=intermediate_state,
@ -147,11 +157,16 @@ class TextToLatentsInvocation(BaseInvocation):
) )
def get_conditioning_data( def get_conditioning_data(
self, context: InvocationContext, scheduler) -> ConditioningData: self,
context: InvocationContext,
scheduler,
) -> ConditioningData:
c, extra_conditioning_info = context.services.latents.get( c, extra_conditioning_info = context.services.latents.get(
self.positive_conditioning.conditioning_name) self.positive_conditioning.conditioning_name
)
uc, _ = context.services.latents.get( uc, _ = context.services.latents.get(
self.negative_conditioning.conditioning_name) self.negative_conditioning.conditioning_name
)
conditioning_data = ConditioningData( conditioning_data = ConditioningData(
unconditioned_embeddings=uc, unconditioned_embeddings=uc,
@ -178,7 +193,10 @@ class TextToLatentsInvocation(BaseInvocation):
return conditioning_data return conditioning_data
def create_pipeline( def create_pipeline(
self, unet, scheduler) -> StableDiffusionGeneratorPipeline: self,
unet,
scheduler,
) -> StableDiffusionGeneratorPipeline:
# TODO: # TODO:
# configure_model_padding( # configure_model_padding(
# unet, # unet,
@ -213,6 +231,7 @@ class TextToLatentsInvocation(BaseInvocation):
model: StableDiffusionGeneratorPipeline, model: StableDiffusionGeneratorPipeline,
control_input: List[ControlField], control_input: List[ControlField],
latents_shape: List[int], latents_shape: List[int],
exit_stack: ExitStack,
do_classifier_free_guidance: bool = True, do_classifier_free_guidance: bool = True,
) -> List[ControlNetData]: ) -> List[ControlNetData]:
@ -238,25 +257,19 @@ class TextToLatentsInvocation(BaseInvocation):
control_data = [] control_data = []
control_models = [] control_models = []
for control_info in control_list: for control_info in control_list:
# handle control models control_model = exit_stack.enter_context(
if ("," in control_info.control_model): context.services.model_manager.get_model(
control_model_split = control_info.control_model.split(",") model_name=control_info.control_model.model_name,
control_name = control_model_split[0] model_type=ModelType.ControlNet,
control_subfolder = control_model_split[1] base_model=control_info.control_model.base_model,
print("Using HF model subfolders") )
print(" control_name: ", control_name) )
print(" control_subfolder: ", control_subfolder)
control_model = ControlNetModel.from_pretrained(
control_name, subfolder=control_subfolder,
torch_dtype=model.unet.dtype).to(
model.device)
else:
control_model = ControlNetModel.from_pretrained(
control_info.control_model, torch_dtype=model.unet.dtype).to(model.device)
control_models.append(control_model) control_models.append(control_model)
control_image_field = control_info.image control_image_field = control_info.image
input_image = context.services.images.get_pil_image( input_image = context.services.images.get_pil_image(
control_image_field.image_name) control_image_field.image_name
)
# self.image.image_type, self.image.image_name # self.image.image_type, self.image.image_name
# FIXME: still need to test with different widths, heights, devices, dtypes # FIXME: still need to test with different widths, heights, devices, dtypes
# and add in batch_size, num_images_per_prompt? # and add in batch_size, num_images_per_prompt?
@ -278,7 +291,8 @@ class TextToLatentsInvocation(BaseInvocation):
weight=control_info.control_weight, weight=control_info.control_weight,
begin_step_percent=control_info.begin_step_percent, begin_step_percent=control_info.begin_step_percent,
end_step_percent=control_info.end_step_percent, end_step_percent=control_info.end_step_percent,
control_mode=control_info.control_mode,) control_mode=control_info.control_mode,
)
control_data.append(control_item) control_data.append(control_item)
# MultiControlNetModel has been refactored out, just need list[ControlNetData] # MultiControlNetModel has been refactored out, just need list[ControlNetData]
return control_data return control_data
@ -289,7 +303,8 @@ class TextToLatentsInvocation(BaseInvocation):
# Get the source node id (we are invoking the prepared node) # Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get( graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id) context.graph_execution_state_id
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id] source_node_id = graph_execution_state.prepared_source_mapping[self.id]
def step_callback(state: PipelineIntermediateState): def step_callback(state: PipelineIntermediateState):
@ -298,14 +313,17 @@ class TextToLatentsInvocation(BaseInvocation):
def _lora_loader(): def _lora_loader():
for lora in self.unet.loras: for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model( lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"})) **lora.dict(exclude={"weight"})
)
yield (lora_info.context.model, lora.weight) yield (lora_info.context.model, lora.weight)
del lora_info del lora_info
return return
unet_info = context.services.model_manager.get_model( unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict()) **self.unet.unet.dict()
with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ )
with ExitStack() as exit_stack,\
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
unet_info as unet: unet_info as unet:
scheduler = get_scheduler( scheduler = get_scheduler(
@ -322,6 +340,7 @@ class TextToLatentsInvocation(BaseInvocation):
latents_shape=noise.shape, latents_shape=noise.shape,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0)) # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True, do_classifier_free_guidance=True,
exit_stack=exit_stack,
) )
# TODO: Verify the noise is the right size # TODO: Verify the noise is the right size
@ -374,7 +393,8 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
# Get the source node id (we are invoking the prepared node) # Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get( graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id) context.graph_execution_state_id
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id] source_node_id = graph_execution_state.prepared_source_mapping[self.id]
def step_callback(state: PipelineIntermediateState): def step_callback(state: PipelineIntermediateState):
@ -383,14 +403,17 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
def _lora_loader(): def _lora_loader():
for lora in self.unet.loras: for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model( lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"})) **lora.dict(exclude={"weight"})
)
yield (lora_info.context.model, lora.weight) yield (lora_info.context.model, lora.weight)
del lora_info del lora_info
return return
unet_info = context.services.model_manager.get_model( unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict()) **self.unet.unet.dict()
with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ )
with ExitStack() as exit_stack,\
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
unet_info as unet: unet_info as unet:
scheduler = get_scheduler( scheduler = get_scheduler(
@ -407,11 +430,13 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
latents_shape=noise.shape, latents_shape=noise.shape,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0)) # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True, do_classifier_free_guidance=True,
exit_stack=exit_stack,
) )
# TODO: Verify the noise is the right size # TODO: Verify the noise is the right size
initial_latents = latent if self.strength < 1.0 else torch.zeros_like( initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
latent, device=unet.device, dtype=latent.dtype) latent, device=unet.device, dtype=latent.dtype
)
timesteps, _ = pipeline.get_img2img_timesteps( timesteps, _ = pipeline.get_img2img_timesteps(
self.steps, self.steps,
@ -535,7 +560,8 @@ class ResizeLatentsInvocation(BaseInvocation):
resized_latents = torch.nn.functional.interpolate( resized_latents = torch.nn.functional.interpolate(
latents, size=(self.height // 8, self.width // 8), latents, size=(self.height // 8, self.width // 8),
mode=self.mode, antialias=self.antialias mode=self.mode, antialias=self.antialias
if self.mode in ["bilinear", "bicubic"] else False,) if self.mode in ["bilinear", "bicubic"] else False,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache() torch.cuda.empty_cache()
@ -569,7 +595,8 @@ class ScaleLatentsInvocation(BaseInvocation):
resized_latents = torch.nn.functional.interpolate( resized_latents = torch.nn.functional.interpolate(
latents, scale_factor=self.scale_factor, mode=self.mode, latents, scale_factor=self.scale_factor, mode=self.mode,
antialias=self.antialias antialias=self.antialias
if self.mode in ["bilinear", "bicubic"] else False,) if self.mode in ["bilinear", "bicubic"] else False,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache() torch.cuda.empty_cache()

View File

@ -19,7 +19,7 @@ from invokeai.backend.model_management import (
ModelMerger, ModelMerger,
MergeInterpolationMethod, MergeInterpolationMethod,
) )
from invokeai.backend.model_management.model_search import FindModels
import torch import torch
from invokeai.app.models.exceptions import CanceledException from invokeai.app.models.exceptions import CanceledException
@ -167,6 +167,27 @@ class ModelManagerServiceBase(ABC):
""" """
pass pass
@abstractmethod
def rename_model(self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
new_name: str,
):
"""
Rename the indicated model.
"""
pass
@abstractmethod
def list_checkpoint_configs(
self
)->List[Path]:
"""
List the checkpoint config paths from ROOT/configs/stable-diffusion.
"""
pass
@abstractmethod @abstractmethod
def convert_model( def convert_model(
self, self,
@ -220,6 +241,7 @@ class ModelManagerServiceBase(ABC):
alpha: Optional[float] = 0.5, alpha: Optional[float] = 0.5,
interp: Optional[MergeInterpolationMethod] = None, interp: Optional[MergeInterpolationMethod] = None,
force: Optional[bool] = False, force: Optional[bool] = False,
merge_dest_directory: Optional[Path] = None
) -> AddModelResult: ) -> AddModelResult:
""" """
Merge two to three diffusrs pipeline models and save as a new model. Merge two to three diffusrs pipeline models and save as a new model.
@ -228,6 +250,23 @@ class ModelManagerServiceBase(ABC):
:param merged_model_name: Name of destination merged model :param merged_model_name: Name of destination merged model
:param alpha: Alpha strength to apply to 2d and 3d model :param alpha: Alpha strength to apply to 2d and 3d model
:param interp: Interpolation method. None (default) :param interp: Interpolation method. None (default)
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
"""
pass
@abstractmethod
def search_for_models(self, directory: Path)->List[Path]:
"""
Return list of all models found in the designated directory.
"""
pass
@abstractmethod
def sync_to_config(self):
"""
Re-read models.yaml, rescan the models directory, and reimport models
in the autoimport directories. Call after making changes outside the
model manager API.
""" """
pass pass
@ -431,16 +470,18 @@ class ModelManagerService(ModelManagerServiceBase):
""" """
Delete the named model from configuration. If delete_files is true, Delete the named model from configuration. If delete_files is true,
then the underlying weight file or diffusers directory will be deleted then the underlying weight file or diffusers directory will be deleted
as well. Call commit() to write to disk. as well.
""" """
self.logger.debug(f'delete model {model_name}') self.logger.debug(f'delete model {model_name}')
self.mgr.del_model(model_name, base_model, model_type) self.mgr.del_model(model_name, base_model, model_type)
self.mgr.commit()
def convert_model( def convert_model(
self, self,
model_name: str, model_name: str,
base_model: BaseModelType, base_model: BaseModelType,
model_type: Union[ModelType.Main,ModelType.Vae], model_type: Union[ModelType.Main,ModelType.Vae],
convert_dest_directory: Optional[Path] = Field(default=None, description="Optional directory location for merged model"),
) -> AddModelResult: ) -> AddModelResult:
""" """
Convert a checkpoint file into a diffusers folder, deleting the cached Convert a checkpoint file into a diffusers folder, deleting the cached
@ -449,13 +490,14 @@ class ModelManagerService(ModelManagerServiceBase):
:param model_name: Name of the model to convert :param model_name: Name of the model to convert
:param base_model: Base model type :param base_model: Base model type
:param model_type: Type of model ['vae' or 'main'] :param model_type: Type of model ['vae' or 'main']
:param convert_dest_directory: Save the converted model to the designated directory (`models/etc/etc` by default)
This will raise a ValueError unless the model is not a checkpoint. It will This will raise a ValueError unless the model is not a checkpoint. It will
also raise a ValueError in the event that there is a similarly-named diffusers also raise a ValueError in the event that there is a similarly-named diffusers
directory already in place. directory already in place.
""" """
self.logger.debug(f'convert model {model_name}') self.logger.debug(f'convert model {model_name}')
return self.mgr.convert_model(model_name, base_model, model_type) return self.mgr.convert_model(model_name, base_model, model_type, convert_dest_directory)
def commit(self, conf_file: Optional[Path]=None): def commit(self, conf_file: Optional[Path]=None):
""" """
@ -536,6 +578,7 @@ class ModelManagerService(ModelManagerServiceBase):
alpha: Optional[float] = 0.5, alpha: Optional[float] = 0.5,
interp: Optional[MergeInterpolationMethod] = None, interp: Optional[MergeInterpolationMethod] = None,
force: Optional[bool] = False, force: Optional[bool] = False,
merge_dest_directory: Optional[Path] = Field(default=None, description="Optional directory location for merged model"),
) -> AddModelResult: ) -> AddModelResult:
""" """
Merge two to three diffusrs pipeline models and save as a new model. Merge two to three diffusrs pipeline models and save as a new model.
@ -544,6 +587,7 @@ class ModelManagerService(ModelManagerServiceBase):
:param merged_model_name: Name of destination merged model :param merged_model_name: Name of destination merged model
:param alpha: Alpha strength to apply to 2d and 3d model :param alpha: Alpha strength to apply to 2d and 3d model
:param interp: Interpolation method. None (default) :param interp: Interpolation method. None (default)
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
""" """
merger = ModelMerger(self.mgr) merger = ModelMerger(self.mgr)
try: try:
@ -554,7 +598,55 @@ class ModelManagerService(ModelManagerServiceBase):
alpha = alpha, alpha = alpha,
interp = interp, interp = interp,
force = force, force = force,
merge_dest_directory=merge_dest_directory,
) )
except AssertionError as e: except AssertionError as e:
raise ValueError(e) raise ValueError(e)
return result return result
def search_for_models(self, directory: Path)->List[Path]:
"""
Return list of all models found in the designated directory.
"""
search = FindModels(directory,self.logger)
return search.list_models()
def sync_to_config(self):
"""
Re-read models.yaml, rescan the models directory, and reimport models
in the autoimport directories. Call after making changes outside the
model manager API.
"""
return self.mgr.sync_to_config()
def list_checkpoint_configs(self)->List[Path]:
"""
List the checkpoint config paths from ROOT/configs/stable-diffusion.
"""
config = self.mgr.app_config
conf_path = config.legacy_conf_path
root_path = config.root_path
return [(conf_path / x).relative_to(root_path) for x in conf_path.glob('**/*.yaml')]
def rename_model(self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
new_name: str = None,
new_base: BaseModelType = None,
):
"""
Rename the indicated model. Can provide a new name and/or a new base.
:param model_name: Current name of the model
:param base_model: Current base of the model
:param model_type: Model type (can't be changed)
:param new_name: New name for the model
:param new_base: New base for the model
"""
self.mgr.rename_model(base_model = base_model,
model_type = model_type,
model_name = model_name,
new_name = new_name,
new_base = new_base,
)

View File

@ -71,8 +71,6 @@ class ModelInstallList:
class InstallSelections(): class InstallSelections():
install_models: List[str]= field(default_factory=list) install_models: List[str]= field(default_factory=list)
remove_models: List[str]=field(default_factory=list) remove_models: List[str]=field(default_factory=list)
# scan_directory: Path = None
# autoscan_on_startup: bool=False
@dataclass @dataclass
class ModelLoadInfo(): class ModelLoadInfo():

View File

@ -247,6 +247,7 @@ import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util import CUDA_DEVICE, Chdir from invokeai.backend.util import CUDA_DEVICE, Chdir
from .model_cache import ModelCache, ModelLocker from .model_cache import ModelCache, ModelLocker
from .model_search import ModelSearch
from .models import ( from .models import (
BaseModelType, ModelType, SubModelType, BaseModelType, ModelType, SubModelType,
ModelError, SchedulerPredictionType, MODEL_CLASSES, ModelError, SchedulerPredictionType, MODEL_CLASSES,
@ -323,15 +324,6 @@ class ModelManager(object):
# TODO: metadata not found # TODO: metadata not found
# TODO: version check # TODO: version check
self.models = dict()
for model_key, model_config in config.items():
model_name, base_model, model_type = self.parse_key(model_key)
model_class = MODEL_CLASSES[base_model][model_type]
# alias for config file
model_config["model_format"] = model_config.pop("format")
self.models[model_key] = model_class.create_config(**model_config)
# check config version number and update on disk/RAM if necessary
self.app_config = InvokeAIAppConfig.get_config() self.app_config = InvokeAIAppConfig.get_config()
self.logger = logger self.logger = logger
self.cache = ModelCache( self.cache = ModelCache(
@ -342,11 +334,41 @@ class ModelManager(object):
sequential_offload = sequential_offload, sequential_offload = sequential_offload,
logger = logger, logger = logger,
) )
self._read_models(config)
def _read_models(self, config: Optional[DictConfig] = None):
if not config:
if self.config_path:
config = OmegaConf.load(self.config_path)
else:
return
self.models = dict()
for model_key, model_config in config.items():
if model_key.startswith('_'):
continue
model_name, base_model, model_type = self.parse_key(model_key)
model_class = MODEL_CLASSES[base_model][model_type]
# alias for config file
model_config["model_format"] = model_config.pop("format")
self.models[model_key] = model_class.create_config(**model_config)
# check config version number and update on disk/RAM if necessary
self.cache_keys = dict() self.cache_keys = dict()
# add controlnet, lora and textual_inversion models from disk # add controlnet, lora and textual_inversion models from disk
self.scan_models_directory() self.scan_models_directory()
def sync_to_config(self):
"""
Call this when `models.yaml` has been changed externally.
This will reinitialize internal data structures
"""
# Reread models directory; note that this will reinitialize the cache,
# causing otherwise unreferenced models to be removed from memory
self._read_models()
def model_exists( def model_exists(
self, self,
model_name: str, model_name: str,
@ -527,7 +549,10 @@ class ModelManager(object):
model_keys = [self.create_key(model_name, base_model, model_type)] if model_name else sorted(self.models, key=str.casefold) model_keys = [self.create_key(model_name, base_model, model_type)] if model_name else sorted(self.models, key=str.casefold)
models = [] models = []
for model_key in model_keys: for model_key in model_keys:
model_config = self.models[model_key] model_config = self.models.get(model_key)
if not model_config:
self.logger.error(f'Unknown model {model_name}')
raise KeyError(f'Unknown model {model_name}')
cur_model_name, cur_base_model, cur_model_type = self.parse_key(model_key) cur_model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
if base_model is not None and cur_base_model != base_model: if base_model is not None and cur_base_model != base_model:
@ -646,11 +671,61 @@ class ModelManager(object):
config = model_config, config = model_config,
) )
def rename_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
new_name: str = None,
new_base: BaseModelType = None,
):
'''
Rename or rebase a model.
'''
if new_name is None and new_base is None:
self.logger.error("rename_model() called with neither a new_name nor a new_base. {model_name} unchanged.")
return
model_key = self.create_key(model_name, base_model, model_type)
model_cfg = self.models.get(model_key, None)
if not model_cfg:
raise KeyError(f"Unknown model: {model_key}")
old_path = self.app_config.root_path / model_cfg.path
new_name = new_name or model_name
new_base = new_base or base_model
new_key = self.create_key(new_name, new_base, model_type)
if new_key in self.models:
raise ValueError(f'Attempt to overwrite existing model definition "{new_key}"')
# if this is a model file/directory that we manage ourselves, we need to move it
if old_path.is_relative_to(self.app_config.models_path):
new_path = self.app_config.root_path / 'models' / new_base.value / model_type.value / new_name
move(old_path, new_path)
model_cfg.path = str(new_path.relative_to(self.app_config.root_path))
# clean up caches
old_model_cache = self._get_model_cache_path(old_path)
if old_model_cache.exists():
if old_model_cache.is_dir():
rmtree(str(old_model_cache))
else:
old_model_cache.unlink()
cache_ids = self.cache_keys.pop(model_key, [])
for cache_id in cache_ids:
self.cache.uncache_model(cache_id)
self.models.pop(model_key, None) # delete
self.models[new_key] = model_cfg
self.commit()
def convert_model ( def convert_model (
self, self,
model_name: str, model_name: str,
base_model: BaseModelType, base_model: BaseModelType,
model_type: Union[ModelType.Main,ModelType.Vae], model_type: Union[ModelType.Main,ModelType.Vae],
dest_directory: Optional[Path]=None,
) -> AddModelResult: ) -> AddModelResult:
''' '''
Convert a checkpoint file into a diffusers folder, deleting the cached Convert a checkpoint file into a diffusers folder, deleting the cached
@ -677,14 +752,14 @@ class ModelManager(object):
) )
checkpoint_path = self.app_config.root_path / info["path"] checkpoint_path = self.app_config.root_path / info["path"]
old_diffusers_path = self.app_config.models_path / model.location old_diffusers_path = self.app_config.models_path / model.location
new_diffusers_path = self.app_config.models_path / base_model.value / model_type.value / model_name new_diffusers_path = (dest_directory or self.app_config.models_path / base_model.value / model_type.value) / model_name
if new_diffusers_path.exists(): if new_diffusers_path.exists():
raise ValueError(f"A diffusers model already exists at {new_diffusers_path}") raise ValueError(f"A diffusers model already exists at {new_diffusers_path}")
try: try:
move(old_diffusers_path,new_diffusers_path) move(old_diffusers_path,new_diffusers_path)
info["model_format"] = "diffusers" info["model_format"] = "diffusers"
info["path"] = str(new_diffusers_path.relative_to(self.app_config.root_path)) info["path"] = str(new_diffusers_path) if dest_directory else str(new_diffusers_path.relative_to(self.app_config.root_path))
info.pop('config') info.pop('config')
result = self.add_model(model_name, base_model, model_type, result = self.add_model(model_name, base_model, model_type,
@ -824,6 +899,7 @@ class ModelManager(object):
if (new_models_found or imported_models) and self.config_path: if (new_models_found or imported_models) and self.config_path:
self.commit() self.commit()
def autoimport(self)->Dict[str, AddModelResult]: def autoimport(self)->Dict[str, AddModelResult]:
''' '''
Scan the autoimport directory (if defined) and import new models, delete defunct models. Scan the autoimport directory (if defined) and import new models, delete defunct models.
@ -832,61 +908,40 @@ class ModelManager(object):
from invokeai.backend.install.model_install_backend import ModelInstall from invokeai.backend.install.model_install_backend import ModelInstall
from invokeai.frontend.install.model_install import ask_user_for_prediction_type from invokeai.frontend.install.model_install import ask_user_for_prediction_type
class ScanAndImport(ModelSearch):
def __init__(self, directories, logger, ignore: Set[Path], installer: ModelInstall):
super().__init__(directories, logger)
self.installer = installer
self.ignore = ignore
def on_search_started(self):
self.new_models_found = dict()
def on_model_found(self, model: Path):
if model not in self.ignore:
self.new_models_found.update(self.installer.heuristic_import(model))
def on_search_completed(self):
self.logger.info(f'Scanned {self._items_scanned} files and directories, imported {len(self.new_models_found)} models')
def models_found(self):
return self.new_models_found
installer = ModelInstall(config = self.app_config, installer = ModelInstall(config = self.app_config,
model_manager = self, model_manager = self,
prediction_type_helper = ask_user_for_prediction_type, prediction_type_helper = ask_user_for_prediction_type,
) )
scanned_dirs = set()
config = self.app_config config = self.app_config
known_paths = {(self.app_config.root_path / x['path']) for x in self.list_models()} known_paths = {config.root_path / x['path'] for x in self.list_models()}
directories = {config.root_path / x for x in [config.autoimport_dir,
for autodir in [config.autoimport_dir, config.lora_dir,
config.lora_dir, config.embedding_dir,
config.embedding_dir, config.controlnet_dir]
config.controlnet_dir]: }
if autodir is None: scanner = ScanAndImport(directories, self.logger, ignore=known_paths, installer=installer)
continue scanner.search()
return scanner.models_found()
installed = dict()
autodir = self.app_config.root_path / autodir
if not autodir.exists():
continue
items_scanned = 0
new_models_found = dict()
for root, dirs, files in os.walk(autodir):
items_scanned += len(dirs) + len(files)
for d in dirs:
path = Path(root) / d
if path in known_paths or path.parent in scanned_dirs:
scanned_dirs.add(path)
continue
if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}]):
try:
new_models_found.update(installer.heuristic_import(path))
scanned_dirs.add(path)
except ValueError as e:
self.logger.warning(str(e))
for f in files:
path = Path(root) / f
if path in known_paths or path.parent in scanned_dirs:
continue
if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}:
try:
import_result = installer.heuristic_import(path)
new_models_found.update(import_result)
except ValueError as e:
self.logger.warning(str(e))
installed.update(new_models_found)
self.logger.info(f'Scanned {items_scanned} files and directories, imported {len(new_models_found)} models')
return installed
def heuristic_import(self, def heuristic_import(self,
items_to_import: Set[str], items_to_import: Set[str],
@ -924,3 +979,4 @@ class ModelManager(object):
successfully_installed.update(installed) successfully_installed.update(installed)
self.commit() self.commit()
return successfully_installed return successfully_installed

View File

@ -11,7 +11,7 @@ from enum import Enum
from pathlib import Path from pathlib import Path
from diffusers import DiffusionPipeline from diffusers import DiffusionPipeline
from diffusers import logging as dlogging from diffusers import logging as dlogging
from typing import List, Union from typing import List, Union, Optional
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
@ -74,6 +74,7 @@ class ModelMerger(object):
alpha: float = 0.5, alpha: float = 0.5,
interp: MergeInterpolationMethod = None, interp: MergeInterpolationMethod = None,
force: bool = False, force: bool = False,
merge_dest_directory: Optional[Path] = None,
**kwargs, **kwargs,
) -> AddModelResult: ) -> AddModelResult:
""" """
@ -85,7 +86,7 @@ class ModelMerger(object):
:param interp: The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None. :param interp: The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None.
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C). Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C).
:param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False. :param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
**kwargs - the default DiffusionPipeline.get_config_dict kwargs: **kwargs - the default DiffusionPipeline.get_config_dict kwargs:
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
""" """
@ -111,7 +112,7 @@ class ModelMerger(object):
merged_pipe = self.merge_diffusion_models( merged_pipe = self.merge_diffusion_models(
model_paths, alpha, merge_method, force, **kwargs model_paths, alpha, merge_method, force, **kwargs
) )
dump_path = config.models_path / base_model.value / ModelType.Main.value dump_path = Path(merge_dest_directory) if merge_dest_directory else config.models_path / base_model.value / ModelType.Main.value
dump_path.mkdir(parents=True, exist_ok=True) dump_path.mkdir(parents=True, exist_ok=True)
dump_path = dump_path / merged_model_name dump_path = dump_path / merged_model_name

View File

@ -0,0 +1,103 @@
# Copyright 2023, Lincoln D. Stein and the InvokeAI Team
"""
Abstract base class for recursive directory search for models.
"""
import os
from abc import ABC, abstractmethod
from typing import List, Set, types
from pathlib import Path
import invokeai.backend.util.logging as logger
class ModelSearch(ABC):
def __init__(self, directories: List[Path], logger: types.ModuleType=logger):
"""
Initialize a recursive model directory search.
:param directories: List of directory Paths to recurse through
:param logger: Logger to use
"""
self.directories = directories
self.logger = logger
self._items_scanned = 0
self._models_found = 0
self._scanned_dirs = set()
self._scanned_paths = set()
self._pruned_paths = set()
@abstractmethod
def on_search_started(self):
"""
Called before the scan starts.
"""
pass
@abstractmethod
def on_model_found(self, model: Path):
"""
Process a found model. Raise an exception if something goes wrong.
:param model: Model to process - could be a directory or checkpoint.
"""
pass
@abstractmethod
def on_search_completed(self):
"""
Perform some activity when the scan is completed. May use instance
variables, items_scanned and models_found
"""
pass
def search(self):
self.on_search_started()
for dir in self.directories:
self.walk_directory(dir)
self.on_search_completed()
def walk_directory(self, path: Path):
for root, dirs, files in os.walk(path):
if str(Path(root).name).startswith('.'):
self._pruned_paths.add(root)
if any([Path(root).is_relative_to(x) for x in self._pruned_paths]):
continue
self._items_scanned += len(dirs) + len(files)
for d in dirs:
path = Path(root) / d
if path in self._scanned_paths or path.parent in self._scanned_dirs:
self._scanned_dirs.add(path)
continue
if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}]):
try:
self.on_model_found(path)
self._models_found += 1
self._scanned_dirs.add(path)
except Exception as e:
self.logger.warning(str(e))
for f in files:
path = Path(root) / f
if path.parent in self._scanned_dirs:
continue
if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}:
try:
self.on_model_found(path)
self._models_found += 1
except Exception as e:
self.logger.warning(str(e))
class FindModels(ModelSearch):
def on_search_started(self):
self.models_found: Set[Path] = set()
def on_model_found(self,model: Path):
self.models_found.add(model)
def on_search_completed(self):
pass
def list_models(self) -> List[Path]:
self.search()
return self.models_found

View File

@ -48,7 +48,9 @@ for base_model, models in MODEL_CLASSES.items():
model_configs.discard(None) model_configs.discard(None)
MODEL_CONFIGS.extend(model_configs) MODEL_CONFIGS.extend(model_configs)
for cfg in model_configs: # LS: sort to get the checkpoint configs first, which makes
# for a better template in the Swagger docs
for cfg in sorted(model_configs, key=lambda x: str(x)):
model_name, cfg_name = cfg.__qualname__.split('.')[-2:] model_name, cfg_name = cfg.__qualname__.split('.')[-2:]
openapi_cfg_name = model_name + cfg_name openapi_cfg_name = model_name + cfg_name
if openapi_cfg_name in vars(): if openapi_cfg_name in vars():

View File

@ -59,7 +59,6 @@ class ModelConfigBase(BaseModel):
path: str # or Path path: str # or Path
description: Optional[str] = Field(None) description: Optional[str] = Field(None)
model_format: Optional[str] = Field(None) model_format: Optional[str] = Field(None)
# do not save to config
error: Optional[ModelError] = Field(None) error: Optional[ModelError] = Field(None)
class Config: class Config:

View File

@ -38,7 +38,6 @@ class StableDiffusion1Model(DiffusersModel):
config: str config: str
variant: ModelVariantType variant: ModelVariantType
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert base_model == BaseModelType.StableDiffusion1 assert base_model == BaseModelType.StableDiffusion1
assert model_type == ModelType.Main assert model_type == ModelType.Main

View File

@ -241,11 +241,45 @@ class InvokeAIDiffuserComponent:
def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs): def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
# fast batched path # fast batched path
def _pad_conditioning(cond, target_len, encoder_attention_mask):
conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
if cond.shape[1] < max_len:
conditioning_attention_mask = torch.cat([
conditioning_attention_mask,
torch.zeros((cond.shape[0], max_len - cond.shape[1]), device=cond.device, dtype=cond.dtype),
], dim=1)
cond = torch.cat([
cond,
torch.zeros((cond.shape[0], max_len - cond.shape[1], cond.shape[2]), device=cond.device, dtype=cond.dtype),
], dim=1)
if encoder_attention_mask is None:
encoder_attention_mask = conditioning_attention_mask
else:
encoder_attention_mask = torch.cat([
encoder_attention_mask,
conditioning_attention_mask,
])
return cond, encoder_attention_mask
x_twice = torch.cat([x] * 2) x_twice = torch.cat([x] * 2)
sigma_twice = torch.cat([sigma] * 2) sigma_twice = torch.cat([sigma] * 2)
encoder_attention_mask = None
if unconditioning.shape[1] != conditioning.shape[1]:
max_len = max(unconditioning.shape[1], conditioning.shape[1])
unconditioning, encoder_attention_mask = _pad_conditioning(unconditioning, max_len, encoder_attention_mask)
conditioning, encoder_attention_mask = _pad_conditioning(conditioning, max_len, encoder_attention_mask)
both_conditionings = torch.cat([unconditioning, conditioning]) both_conditionings = torch.cat([unconditioning, conditioning])
both_results = self.model_forward_callback( both_results = self.model_forward_callback(
x_twice, sigma_twice, both_conditionings, **kwargs, x_twice, sigma_twice, both_conditionings,
encoder_attention_mask=encoder_attention_mask,
**kwargs,
) )
unconditioned_next_x, conditioned_next_x = both_results.chunk(2) unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
return unconditioned_next_x, conditioned_next_x return unconditioned_next_x, conditioned_next_x

View File

@ -13,7 +13,11 @@ import { RootState } from 'app/store/store';
const moduleLog = log.child({ namespace: 'controlNet' }); const moduleLog = log.child({ namespace: 'controlNet' });
const predicate: AnyListenerPredicate<RootState> = (action, state) => { const predicate: AnyListenerPredicate<RootState> = (
action,
state,
prevState
) => {
const isActionMatched = const isActionMatched =
controlNetProcessorParamsChanged.match(action) || controlNetProcessorParamsChanged.match(action) ||
controlNetModelChanged.match(action) || controlNetModelChanged.match(action) ||
@ -25,6 +29,16 @@ const predicate: AnyListenerPredicate<RootState> = (action, state) => {
return false; return false;
} }
if (controlNetAutoConfigToggled.match(action)) {
// do not process if the user just disabled auto-config
if (
prevState.controlNet.controlNets[action.payload.controlNetId]
.shouldAutoConfig === true
) {
return false;
}
}
const { controlImage, processorType, shouldAutoConfig } = const { controlImage, processorType, shouldAutoConfig } =
state.controlNet.controlNets[action.payload.controlNetId]; state.controlNet.controlNets[action.payload.controlNetId];

View File

@ -10,6 +10,7 @@ import { zMainModel } from 'features/parameters/types/parameterSchemas';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
import { forEach } from 'lodash-es'; import { forEach } from 'lodash-es';
import { startAppListening } from '..'; import { startAppListening } from '..';
import { controlNetRemoved } from 'features/controlNet/store/controlNetSlice';
const moduleLog = log.child({ module: 'models' }); const moduleLog = log.child({ module: 'models' });
@ -51,7 +52,14 @@ export const addModelSelectedListener = () => {
modelsCleared += 1; modelsCleared += 1;
} }
// TODO: handle incompatible controlnet; pending model manager support const { controlNets } = state.controlNet;
forEach(controlNets, (controlNet, controlNetId) => {
if (controlNet.model?.base_model !== base_model) {
dispatch(controlNetRemoved({ controlNetId }));
modelsCleared += 1;
}
});
if (modelsCleared > 0) { if (modelsCleared > 0) {
dispatch( dispatch(
addToast( addToast(

View File

@ -11,6 +11,7 @@ import {
import { forEach, some } from 'lodash-es'; import { forEach, some } from 'lodash-es';
import { modelsApi } from 'services/api/endpoints/models'; import { modelsApi } from 'services/api/endpoints/models';
import { startAppListening } from '..'; import { startAppListening } from '..';
import { controlNetRemoved } from 'features/controlNet/store/controlNetSlice';
const moduleLog = log.child({ module: 'models' }); const moduleLog = log.child({ module: 'models' });
@ -127,7 +128,22 @@ export const addModelsLoadedListener = () => {
matcher: modelsApi.endpoints.getControlNetModels.matchFulfilled, matcher: modelsApi.endpoints.getControlNetModels.matchFulfilled,
effect: async (action, { getState, dispatch }) => { effect: async (action, { getState, dispatch }) => {
// ControlNet models loaded - need to remove missing ControlNets from state // ControlNet models loaded - need to remove missing ControlNets from state
// TODO: pending model manager controlnet support const controlNets = getState().controlNet.controlNets;
forEach(controlNets, (controlNet, controlNetId) => {
const isControlNetAvailable = some(
action.payload.entities,
(m) =>
m?.model_name === controlNet?.model?.model_name &&
m?.base_model === controlNet?.model?.base_model
);
if (isControlNetAvailable) {
return;
}
dispatch(controlNetRemoved({ controlNetId }));
});
}, },
}); });
}; };

View File

@ -1,5 +1,5 @@
import { import {
CONTROLNET_MODELS, // CONTROLNET_MODELS,
CONTROLNET_PROCESSORS, CONTROLNET_PROCESSORS,
} from 'features/controlNet/store/constants'; } from 'features/controlNet/store/constants';
import { InvokeTabName } from 'features/ui/store/tabMap'; import { InvokeTabName } from 'features/ui/store/tabMap';
@ -128,7 +128,7 @@ export type AppConfig = {
canRestoreDeletedImagesFromBin: boolean; canRestoreDeletedImagesFromBin: boolean;
sd: { sd: {
defaultModel?: string; defaultModel?: string;
disabledControlNetModels: (keyof typeof CONTROLNET_MODELS)[]; disabledControlNetModels: string[];
disabledControlNetProcessors: (keyof typeof CONTROLNET_PROCESSORS)[]; disabledControlNetProcessors: (keyof typeof CONTROLNET_PROCESSORS)[];
iterations: { iterations: {
initial: number; initial: number;

View File

@ -170,12 +170,14 @@ const IAIDndImage = (props: IAIDndImageProps) => {
</> </>
)} )}
{!imageDTO && isUploadDisabled && noContentFallback} {!imageDTO && isUploadDisabled && noContentFallback}
<IAIDroppable {!isDropDisabled && (
data={droppableData} <IAIDroppable
disabled={isDropDisabled} data={droppableData}
dropLabel={dropLabel} disabled={isDropDisabled}
/> dropLabel={dropLabel}
{imageDTO && ( />
)}
{imageDTO && !isDragDisabled && (
<IAIDraggable <IAIDraggable
data={draggableData} data={draggableData}
disabled={isDragDisabled || !imageDTO} disabled={isDragDisabled || !imageDTO}

View File

@ -1,17 +1,25 @@
import { Tooltip } from '@chakra-ui/react'; import { FormControl, FormLabel, Tooltip } from '@chakra-ui/react';
import { MultiSelect, MultiSelectProps } from '@mantine/core'; import { MultiSelect, MultiSelectProps } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice'; import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
import { useMantineMultiSelectStyles } from 'mantine-theme/hooks/useMantineMultiSelectStyles'; import { useMantineMultiSelectStyles } from 'mantine-theme/hooks/useMantineMultiSelectStyles';
import { KeyboardEvent, RefObject, memo, useCallback } from 'react'; import { KeyboardEvent, RefObject, memo, useCallback } from 'react';
type IAIMultiSelectProps = MultiSelectProps & { type IAIMultiSelectProps = Omit<MultiSelectProps, 'label'> & {
tooltip?: string; tooltip?: string;
inputRef?: RefObject<HTMLInputElement>; inputRef?: RefObject<HTMLInputElement>;
label?: string;
}; };
const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => { const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
const { searchable = true, tooltip, inputRef, ...rest } = props; const {
searchable = true,
tooltip,
inputRef,
label,
disabled,
...rest
} = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const handleKeyDown = useCallback( const handleKeyDown = useCallback(
@ -37,7 +45,15 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
return ( return (
<Tooltip label={tooltip} placement="top" hasArrow isOpen={true}> <Tooltip label={tooltip} placement="top" hasArrow isOpen={true}>
<MultiSelect <MultiSelect
label={
label ? (
<FormControl isDisabled={disabled}>
<FormLabel>{label}</FormLabel>
</FormControl>
) : undefined
}
ref={inputRef} ref={inputRef}
disabled={disabled}
onKeyDown={handleKeyDown} onKeyDown={handleKeyDown}
onKeyUp={handleKeyUp} onKeyUp={handleKeyUp}
searchable={searchable} searchable={searchable}

View File

@ -1,4 +1,4 @@
import { Tooltip } from '@chakra-ui/react'; import { FormControl, FormLabel, Tooltip } from '@chakra-ui/react';
import { Select, SelectProps } from '@mantine/core'; import { Select, SelectProps } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice'; import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
@ -11,13 +11,22 @@ export type IAISelectDataType = {
tooltip?: string; tooltip?: string;
}; };
type IAISelectProps = SelectProps & { type IAISelectProps = Omit<SelectProps, 'label'> & {
tooltip?: string; tooltip?: string;
label?: string;
inputRef?: RefObject<HTMLInputElement>; inputRef?: RefObject<HTMLInputElement>;
}; };
const IAIMantineSearchableSelect = (props: IAISelectProps) => { const IAIMantineSearchableSelect = (props: IAISelectProps) => {
const { searchable = true, tooltip, inputRef, onChange, ...rest } = props; const {
searchable = true,
tooltip,
inputRef,
onChange,
label,
disabled,
...rest
} = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const [searchValue, setSearchValue] = useState(''); const [searchValue, setSearchValue] = useState('');
@ -61,6 +70,14 @@ const IAIMantineSearchableSelect = (props: IAISelectProps) => {
<Tooltip label={tooltip} placement="top" hasArrow> <Tooltip label={tooltip} placement="top" hasArrow>
<Select <Select
ref={inputRef} ref={inputRef}
label={
label ? (
<FormControl isDisabled={disabled}>
<FormLabel>{label}</FormLabel>
</FormControl>
) : undefined
}
disabled={disabled}
searchValue={searchValue} searchValue={searchValue}
onSearchChange={setSearchValue} onSearchChange={setSearchValue}
onChange={handleChange} onChange={handleChange}

View File

@ -1,4 +1,4 @@
import { Tooltip } from '@chakra-ui/react'; import { FormControl, FormLabel, Tooltip } from '@chakra-ui/react';
import { Select, SelectProps } from '@mantine/core'; import { Select, SelectProps } from '@mantine/core';
import { useMantineSelectStyles } from 'mantine-theme/hooks/useMantineSelectStyles'; import { useMantineSelectStyles } from 'mantine-theme/hooks/useMantineSelectStyles';
import { RefObject, memo } from 'react'; import { RefObject, memo } from 'react';
@ -9,19 +9,32 @@ export type IAISelectDataType = {
tooltip?: string; tooltip?: string;
}; };
type IAISelectProps = SelectProps & { type IAISelectProps = Omit<SelectProps, 'label'> & {
tooltip?: string; tooltip?: string;
inputRef?: RefObject<HTMLInputElement>; inputRef?: RefObject<HTMLInputElement>;
label?: string;
}; };
const IAIMantineSelect = (props: IAISelectProps) => { const IAIMantineSelect = (props: IAISelectProps) => {
const { tooltip, inputRef, ...rest } = props; const { tooltip, inputRef, label, disabled, ...rest } = props;
const styles = useMantineSelectStyles(); const styles = useMantineSelectStyles();
return ( return (
<Tooltip label={tooltip} placement="top" hasArrow> <Tooltip label={tooltip} placement="top" hasArrow>
<Select ref={inputRef} styles={styles} {...rest} /> <Select
label={
label ? (
<FormControl isDisabled={disabled}>
<FormLabel>{label}</FormLabel>
</FormControl>
) : undefined
}
disabled={disabled}
ref={inputRef}
styles={styles}
{...rest}
/>
</Tooltip> </Tooltip>
); );
}; };

View File

@ -28,7 +28,7 @@ import {
useState, useState,
} from 'react'; } from 'react';
const numberStringRegex = /^-?(0\.)?\.?$/; export const numberStringRegex = /^-?(0\.)?\.?$/;
interface Props extends Omit<NumberInputProps, 'onChange'> { interface Props extends Omit<NumberInputProps, 'onChange'> {
label?: string; label?: string;

View File

@ -43,11 +43,6 @@ import { useTranslation } from 'react-i18next';
import { BiReset } from 'react-icons/bi'; import { BiReset } from 'react-icons/bi';
import IAIIconButton, { IAIIconButtonProps } from './IAIIconButton'; import IAIIconButton, { IAIIconButtonProps } from './IAIIconButton';
const SLIDER_MARK_STYLES: ChakraProps['sx'] = {
mt: 1.5,
fontSize: '2xs',
};
export type IAIFullSliderProps = { export type IAIFullSliderProps = {
label?: string; label?: string;
value: number; value: number;
@ -207,7 +202,7 @@ const IAISlider = (props: IAIFullSliderProps) => {
{...sliderFormControlProps} {...sliderFormControlProps}
> >
{label && ( {label && (
<FormLabel {...sliderFormLabelProps} mb={-1}> <FormLabel sx={withInput ? { mb: -1.5 } : {}} {...sliderFormLabelProps}>
{label} {label}
</FormLabel> </FormLabel>
)} )}
@ -233,7 +228,6 @@ const IAISlider = (props: IAIFullSliderProps) => {
sx={{ sx={{
insetInlineStart: '0 !important', insetInlineStart: '0 !important',
insetInlineEnd: 'unset !important', insetInlineEnd: 'unset !important',
...SLIDER_MARK_STYLES,
}} }}
{...sliderMarkProps} {...sliderMarkProps}
> >
@ -244,7 +238,6 @@ const IAISlider = (props: IAIFullSliderProps) => {
sx={{ sx={{
insetInlineStart: 'unset !important', insetInlineStart: 'unset !important',
insetInlineEnd: '0 !important', insetInlineEnd: '0 !important',
...SLIDER_MARK_STYLES,
}} }}
{...sliderMarkProps} {...sliderMarkProps}
> >
@ -263,7 +256,6 @@ const IAISlider = (props: IAIFullSliderProps) => {
sx={{ sx={{
insetInlineStart: '0 !important', insetInlineStart: '0 !important',
insetInlineEnd: 'unset !important', insetInlineEnd: 'unset !important',
...SLIDER_MARK_STYLES,
}} }}
{...sliderMarkProps} {...sliderMarkProps}
> >
@ -278,7 +270,6 @@ const IAISlider = (props: IAIFullSliderProps) => {
sx={{ sx={{
insetInlineStart: 'unset !important', insetInlineStart: 'unset !important',
insetInlineEnd: '0 !important', insetInlineEnd: '0 !important',
...SLIDER_MARK_STYLES,
}} }}
{...sliderMarkProps} {...sliderMarkProps}
> >
@ -291,7 +282,7 @@ const IAISlider = (props: IAIFullSliderProps) => {
key={m} key={m}
value={m} value={m}
sx={{ sx={{
...SLIDER_MARK_STYLES, transform: 'translateX(-50%)',
}} }}
{...sliderMarkProps} {...sliderMarkProps}
> >

View File

@ -5,6 +5,7 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { validateSeedWeights } from 'common/util/seedWeightPairs'; import { validateSeedWeights } from 'common/util/seedWeightPairs';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { modelsApi } from '../../services/api/endpoints/models'; import { modelsApi } from '../../services/api/endpoints/models';
import { forEach } from 'lodash-es';
const readinessSelector = createSelector( const readinessSelector = createSelector(
[stateSelector, activeTabNameSelector], [stateSelector, activeTabNameSelector],
@ -52,6 +53,13 @@ const readinessSelector = createSelector(
reasonsWhyNotReady.push('Seed-Weights badly formatted.'); reasonsWhyNotReady.push('Seed-Weights badly formatted.');
} }
forEach(state.controlNet.controlNets, (controlNet, id) => {
if (!controlNet.model) {
isReady = false;
reasonsWhyNotReady.push('ControlNet ${id} has no model selected.');
}
});
// All good // All good
return { isReady, reasonsWhyNotReady }; return { isReady, reasonsWhyNotReady };
}, },

View File

@ -1,10 +1,9 @@
import { Box, ChakraProps, Flex, useColorMode } from '@chakra-ui/react'; import { Box, Flex } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { FaCopy, FaTrash } from 'react-icons/fa'; import { FaCopy, FaTrash } from 'react-icons/fa';
import { import {
ControlNetConfig, controlNetDuplicated,
controlNetAdded,
controlNetRemoved, controlNetRemoved,
controlNetToggled, controlNetToggled,
} from '../store/controlNetSlice'; } from '../store/controlNetSlice';
@ -12,6 +11,9 @@ import ParamControlNetModel from './parameters/ParamControlNetModel';
import ParamControlNetWeight from './parameters/ParamControlNetWeight'; import ParamControlNetWeight from './parameters/ParamControlNetWeight';
import { ChevronUpIcon } from '@chakra-ui/icons'; import { ChevronUpIcon } from '@chakra-ui/icons';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIIconButton from 'common/components/IAIIconButton'; import IAIIconButton from 'common/components/IAIIconButton';
import IAISwitch from 'common/components/IAISwitch'; import IAISwitch from 'common/components/IAISwitch';
import { useToggle } from 'react-use'; import { useToggle } from 'react-use';
@ -22,41 +24,41 @@ import ParamControlNetShouldAutoConfig from './ParamControlNetShouldAutoConfig';
import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd'; import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd';
import ParamControlNetControlMode from './parameters/ParamControlNetControlMode'; import ParamControlNetControlMode from './parameters/ParamControlNetControlMode';
import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect'; import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect';
import { mode } from 'theme/util/mode';
const expandedControlImageSx: ChakraProps['sx'] = { maxH: 96 };
type ControlNetProps = { type ControlNetProps = {
controlNet: ControlNetConfig; controlNetId: string;
}; };
const ControlNet = (props: ControlNetProps) => { const ControlNet = (props: ControlNetProps) => {
const { const { controlNetId } = props;
controlNetId,
isEnabled,
model,
weight,
beginStepPct,
endStepPct,
controlMode,
controlImage,
processedControlImage,
processorNode,
processorType,
shouldAutoConfig,
} = props.controlNet;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const selector = createSelector(
stateSelector,
({ controlNet }) => {
const { isEnabled, shouldAutoConfig } =
controlNet.controlNets[controlNetId];
return { isEnabled, shouldAutoConfig };
},
defaultSelectorOptions
);
const { isEnabled, shouldAutoConfig } = useAppSelector(selector);
const [isExpanded, toggleIsExpanded] = useToggle(false); const [isExpanded, toggleIsExpanded] = useToggle(false);
const { colorMode } = useColorMode();
const handleDelete = useCallback(() => { const handleDelete = useCallback(() => {
dispatch(controlNetRemoved({ controlNetId })); dispatch(controlNetRemoved({ controlNetId }));
}, [controlNetId, dispatch]); }, [controlNetId, dispatch]);
const handleDuplicate = useCallback(() => { const handleDuplicate = useCallback(() => {
dispatch( dispatch(
controlNetAdded({ controlNetId: uuidv4(), controlNet: props.controlNet }) controlNetDuplicated({
sourceControlNetId: controlNetId,
newControlNetId: uuidv4(),
})
); );
}, [dispatch, props.controlNet]); }, [controlNetId, dispatch]);
const handleToggleIsEnabled = useCallback(() => { const handleToggleIsEnabled = useCallback(() => {
dispatch(controlNetToggled({ controlNetId })); dispatch(controlNetToggled({ controlNetId }));
@ -68,15 +70,18 @@ const ControlNet = (props: ControlNetProps) => {
flexDir: 'column', flexDir: 'column',
gap: 2, gap: 2,
p: 3, p: 3,
bg: mode('base.200', 'base.850')(colorMode),
borderRadius: 'base', borderRadius: 'base',
position: 'relative', position: 'relative',
bg: 'base.200',
_dark: {
bg: 'base.850',
},
}} }}
> >
<Flex sx={{ gap: 2 }}> <Flex sx={{ gap: 2, alignItems: 'center' }}>
<IAISwitch <IAISwitch
tooltip="Toggle" tooltip={'Toggle this ControlNet'}
aria-label="Toggle" aria-label={'Toggle this ControlNet'}
isChecked={isEnabled} isChecked={isEnabled}
onChange={handleToggleIsEnabled} onChange={handleToggleIsEnabled}
/> />
@ -90,7 +95,7 @@ const ControlNet = (props: ControlNetProps) => {
transitionDuration: '0.1s', transitionDuration: '0.1s',
}} }}
> >
<ParamControlNetModel controlNetId={controlNetId} model={model} /> <ParamControlNetModel controlNetId={controlNetId} />
</Box> </Box>
<IAIIconButton <IAIIconButton
size="sm" size="sm"
@ -109,21 +114,26 @@ const ControlNet = (props: ControlNetProps) => {
/> />
<IAIIconButton <IAIIconButton
size="sm" size="sm"
aria-label="Show All Options" tooltip={isExpanded ? 'Hide Advanced' : 'Show Advanced'}
aria-label={isExpanded ? 'Hide Advanced' : 'Show Advanced'}
onClick={toggleIsExpanded} onClick={toggleIsExpanded}
variant="link" variant="link"
icon={ icon={
<ChevronUpIcon <ChevronUpIcon
sx={{ sx={{
boxSize: 4, boxSize: 4,
color: mode('base.700', 'base.300')(colorMode), color: 'base.700',
transform: isExpanded ? 'rotate(0deg)' : 'rotate(180deg)', transform: isExpanded ? 'rotate(0deg)' : 'rotate(180deg)',
transitionProperty: 'common', transitionProperty: 'common',
transitionDuration: 'normal', transitionDuration: 'normal',
_dark: {
color: 'base.300',
},
}} }}
/> />
} }
/> />
{!shouldAutoConfig && ( {!shouldAutoConfig && (
<Box <Box
sx={{ sx={{
@ -131,85 +141,59 @@ const ControlNet = (props: ControlNetProps) => {
w: 1.5, w: 1.5,
h: 1.5, h: 1.5,
borderRadius: 'full', borderRadius: 'full',
bg: mode('error.700', 'error.200')(colorMode),
top: 4, top: 4,
insetInlineEnd: 4, insetInlineEnd: 4,
bg: 'accent.700',
_dark: {
bg: 'accent.400',
},
}} }}
/> />
)} )}
</Flex> </Flex>
{isEnabled && ( <Flex sx={{ w: 'full', flexDirection: 'column' }}>
<> <Flex sx={{ gap: 4, w: 'full', alignItems: 'center' }}>
<Flex sx={{ w: 'full', flexDirection: 'column' }}> <Flex
<Flex sx={{ gap: 4, w: 'full' }}> sx={{
<Flex flexDir: 'column',
sx={{ gap: 3,
flexDir: 'column', h: 28,
gap: 3, w: 'full',
w: 'full', paddingInlineStart: 1,
paddingInlineStart: 1, paddingInlineEnd: isExpanded ? 1 : 0,
paddingInlineEnd: isExpanded ? 1 : 0, pb: 2,
pb: 2, justifyContent: 'space-between',
justifyContent: 'space-between', }}
}} >
> <ParamControlNetWeight controlNetId={controlNetId} />
<ParamControlNetWeight <ParamControlNetBeginEnd controlNetId={controlNetId} />
controlNetId={controlNetId}
weight={weight}
mini={!isExpanded}
/>
<ParamControlNetBeginEnd
controlNetId={controlNetId}
beginStepPct={beginStepPct}
endStepPct={endStepPct}
mini={!isExpanded}
/>
</Flex>
{!isExpanded && (
<Flex
sx={{
alignItems: 'center',
justifyContent: 'center',
h: 24,
w: 24,
aspectRatio: '1/1',
}}
>
<ControlNetImagePreview
controlNet={props.controlNet}
height={24}
/>
</Flex>
)}
</Flex>
<ParamControlNetControlMode
controlNetId={controlNetId}
controlMode={controlMode}
/>
</Flex> </Flex>
{!isExpanded && (
{isExpanded && ( <Flex
<> sx={{
<Box mt={2}> alignItems: 'center',
<ControlNetImagePreview justifyContent: 'center',
controlNet={props.controlNet} h: 28,
height={96} w: 28,
/> aspectRatio: '1/1',
</Box> mt: 3,
<ParamControlNetProcessorSelect }}
controlNetId={controlNetId} >
processorNode={processorNode} <ControlNetImagePreview controlNetId={controlNetId} height={28} />
/> </Flex>
<ControlNetProcessorComponent
controlNetId={controlNetId}
processorNode={processorNode}
/>
<ParamControlNetShouldAutoConfig
controlNetId={controlNetId}
shouldAutoConfig={shouldAutoConfig}
/>
</>
)} )}
</Flex>
<Box mt={2}>
<ParamControlNetControlMode controlNetId={controlNetId} />
</Box>
<ParamControlNetProcessorSelect controlNetId={controlNetId} />
</Flex>
{isExpanded && (
<>
<ControlNetImagePreview controlNetId={controlNetId} height="392px" />
<ParamControlNetShouldAutoConfig controlNetId={controlNetId} />
<ControlNetProcessorComponent controlNetId={controlNetId} />
</> </>
)} )}
</Flex> </Flex>

View File

@ -5,42 +5,57 @@ import {
TypesafeDraggableData, TypesafeDraggableData,
TypesafeDroppableData, TypesafeDroppableData,
} from 'app/components/ImageDnd/typesafeDnd'; } from 'app/components/ImageDnd/typesafeDnd';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIDndImage from 'common/components/IAIDndImage'; import IAIDndImage from 'common/components/IAIDndImage';
import { memo, useCallback, useMemo, useState } from 'react'; import { memo, useCallback, useMemo, useState } from 'react';
import { useGetImageDTOQuery } from 'services/api/endpoints/images'; import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { PostUploadAction } from 'services/api/thunks/image'; import { PostUploadAction } from 'services/api/thunks/image';
import { import { controlNetImageChanged } from '../store/controlNetSlice';
ControlNetConfig,
controlNetImageChanged,
controlNetSelector,
} from '../store/controlNetSlice';
const selector = createSelector(
controlNetSelector,
(controlNet) => {
const { pendingControlImages } = controlNet;
return { pendingControlImages };
},
defaultSelectorOptions
);
type Props = { type Props = {
controlNet: ControlNetConfig; controlNetId: string;
height: SystemStyleObject['h']; height: SystemStyleObject['h'];
}; };
const ControlNetImagePreview = (props: Props) => { const ControlNetImagePreview = (props: Props) => {
const { height } = props; const { height, controlNetId } = props;
const {
controlNetId,
controlImage: controlImageName,
processedControlImage: processedControlImageName,
processorType,
} = props.controlNet;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { pendingControlImages } = useAppSelector(selector);
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlNet }) => {
const { pendingControlImages } = controlNet;
const {
controlImage,
processedControlImage,
processorType,
isEnabled,
} = controlNet.controlNets[controlNetId];
return {
controlImageName: controlImage,
processedControlImageName: processedControlImage,
processorType,
isEnabled,
pendingControlImages,
};
},
defaultSelectorOptions
),
[controlNetId]
);
const {
controlImageName,
processedControlImageName,
processorType,
pendingControlImages,
isEnabled,
} = useAppSelector(selector);
const [isMouseOverImage, setIsMouseOverImage] = useState(false); const [isMouseOverImage, setIsMouseOverImage] = useState(false);
@ -110,13 +125,15 @@ const ControlNetImagePreview = (props: Props) => {
h: height, h: height,
alignItems: 'center', alignItems: 'center',
justifyContent: 'center', justifyContent: 'center',
pointerEvents: isEnabled ? 'auto' : 'none',
opacity: isEnabled ? 1 : 0.5,
}} }}
> >
<IAIDndImage <IAIDndImage
draggableData={draggableData} draggableData={draggableData}
droppableData={droppableData} droppableData={droppableData}
imageDTO={controlImage} imageDTO={controlImage}
isDropDisabled={shouldShowProcessedImage} isDropDisabled={shouldShowProcessedImage || !isEnabled}
onClickReset={handleResetControlImage} onClickReset={handleResetControlImage}
postUploadAction={postUploadAction} postUploadAction={postUploadAction}
resetTooltip="Reset Control Image" resetTooltip="Reset Control Image"
@ -140,6 +157,7 @@ const ControlNetImagePreview = (props: Props) => {
droppableData={droppableData} droppableData={droppableData}
imageDTO={processedControlImage} imageDTO={processedControlImage}
isUploadDisabled={true} isUploadDisabled={true}
isDropDisabled={!isEnabled}
onClickReset={handleResetControlImage} onClickReset={handleResetControlImage}
resetTooltip="Reset Control Image" resetTooltip="Reset Control Image"
withResetIcon={Boolean(controlImage)} withResetIcon={Boolean(controlImage)}

View File

@ -1,10 +1,13 @@
import { memo } from 'react'; import { createSelector } from '@reduxjs/toolkit';
import { RequiredControlNetProcessorNode } from '../store/types'; import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { memo, useMemo } from 'react';
import CannyProcessor from './processors/CannyProcessor'; import CannyProcessor from './processors/CannyProcessor';
import HedProcessor from './processors/HedProcessor';
import LineartProcessor from './processors/LineartProcessor';
import LineartAnimeProcessor from './processors/LineartAnimeProcessor';
import ContentShuffleProcessor from './processors/ContentShuffleProcessor'; import ContentShuffleProcessor from './processors/ContentShuffleProcessor';
import HedProcessor from './processors/HedProcessor';
import LineartAnimeProcessor from './processors/LineartAnimeProcessor';
import LineartProcessor from './processors/LineartProcessor';
import MediapipeFaceProcessor from './processors/MediapipeFaceProcessor'; import MediapipeFaceProcessor from './processors/MediapipeFaceProcessor';
import MidasDepthProcessor from './processors/MidasDepthProcessor'; import MidasDepthProcessor from './processors/MidasDepthProcessor';
import MlsdImageProcessor from './processors/MlsdImageProcessor'; import MlsdImageProcessor from './processors/MlsdImageProcessor';
@ -15,23 +18,45 @@ import ZoeDepthProcessor from './processors/ZoeDepthProcessor';
export type ControlNetProcessorProps = { export type ControlNetProcessorProps = {
controlNetId: string; controlNetId: string;
processorNode: RequiredControlNetProcessorNode;
}; };
const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => { const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
const { controlNetId, processorNode } = props; const { controlNetId } = props;
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlNet }) => {
const { isEnabled, processorNode } =
controlNet.controlNets[controlNetId];
return { isEnabled, processorNode };
},
defaultSelectorOptions
),
[controlNetId]
);
const { isEnabled, processorNode } = useAppSelector(selector);
if (processorNode.type === 'canny_image_processor') { if (processorNode.type === 'canny_image_processor') {
return ( return (
<CannyProcessor <CannyProcessor
controlNetId={controlNetId} controlNetId={controlNetId}
processorNode={processorNode} processorNode={processorNode}
isEnabled={isEnabled}
/> />
); );
} }
if (processorNode.type === 'hed_image_processor') { if (processorNode.type === 'hed_image_processor') {
return ( return (
<HedProcessor controlNetId={controlNetId} processorNode={processorNode} /> <HedProcessor
controlNetId={controlNetId}
processorNode={processorNode}
isEnabled={isEnabled}
/>
); );
} }
@ -40,6 +65,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<LineartProcessor <LineartProcessor
controlNetId={controlNetId} controlNetId={controlNetId}
processorNode={processorNode} processorNode={processorNode}
isEnabled={isEnabled}
/> />
); );
} }
@ -49,6 +75,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<ContentShuffleProcessor <ContentShuffleProcessor
controlNetId={controlNetId} controlNetId={controlNetId}
processorNode={processorNode} processorNode={processorNode}
isEnabled={isEnabled}
/> />
); );
} }
@ -58,6 +85,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<LineartAnimeProcessor <LineartAnimeProcessor
controlNetId={controlNetId} controlNetId={controlNetId}
processorNode={processorNode} processorNode={processorNode}
isEnabled={isEnabled}
/> />
); );
} }
@ -67,6 +95,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<MediapipeFaceProcessor <MediapipeFaceProcessor
controlNetId={controlNetId} controlNetId={controlNetId}
processorNode={processorNode} processorNode={processorNode}
isEnabled={isEnabled}
/> />
); );
} }
@ -76,6 +105,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<MidasDepthProcessor <MidasDepthProcessor
controlNetId={controlNetId} controlNetId={controlNetId}
processorNode={processorNode} processorNode={processorNode}
isEnabled={isEnabled}
/> />
); );
} }
@ -85,6 +115,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<MlsdImageProcessor <MlsdImageProcessor
controlNetId={controlNetId} controlNetId={controlNetId}
processorNode={processorNode} processorNode={processorNode}
isEnabled={isEnabled}
/> />
); );
} }
@ -94,6 +125,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<NormalBaeProcessor <NormalBaeProcessor
controlNetId={controlNetId} controlNetId={controlNetId}
processorNode={processorNode} processorNode={processorNode}
isEnabled={isEnabled}
/> />
); );
} }
@ -103,6 +135,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<OpenposeProcessor <OpenposeProcessor
controlNetId={controlNetId} controlNetId={controlNetId}
processorNode={processorNode} processorNode={processorNode}
isEnabled={isEnabled}
/> />
); );
} }
@ -112,6 +145,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<PidiProcessor <PidiProcessor
controlNetId={controlNetId} controlNetId={controlNetId}
processorNode={processorNode} processorNode={processorNode}
isEnabled={isEnabled}
/> />
); );
} }
@ -121,6 +155,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
<ZoeDepthProcessor <ZoeDepthProcessor
controlNetId={controlNetId} controlNetId={controlNetId}
processorNode={processorNode} processorNode={processorNode}
isEnabled={isEnabled}
/> />
); );
} }

View File

@ -1,18 +1,36 @@
import { useAppDispatch } from 'app/store/storeHooks'; import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISwitch from 'common/components/IAISwitch'; import IAISwitch from 'common/components/IAISwitch';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import { controlNetAutoConfigToggled } from 'features/controlNet/store/controlNetSlice'; import { controlNetAutoConfigToggled } from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react'; import { selectIsBusy } from 'features/system/store/systemSelectors';
import { memo, useCallback, useMemo } from 'react';
type Props = { type Props = {
controlNetId: string; controlNetId: string;
shouldAutoConfig: boolean;
}; };
const ParamControlNetShouldAutoConfig = (props: Props) => { const ParamControlNetShouldAutoConfig = (props: Props) => {
const { controlNetId, shouldAutoConfig } = props; const { controlNetId } = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const isReady = useIsReadyToInvoke(); const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlNet }) => {
const { isEnabled, shouldAutoConfig } =
controlNet.controlNets[controlNetId];
return { isEnabled, shouldAutoConfig };
},
defaultSelectorOptions
),
[controlNetId]
);
const { isEnabled, shouldAutoConfig } = useAppSelector(selector);
const isBusy = useAppSelector(selectIsBusy);
const handleShouldAutoConfigChanged = useCallback(() => { const handleShouldAutoConfigChanged = useCallback(() => {
dispatch(controlNetAutoConfigToggled({ controlNetId })); dispatch(controlNetAutoConfigToggled({ controlNetId }));
}, [controlNetId, dispatch]); }, [controlNetId, dispatch]);
@ -23,7 +41,7 @@ const ParamControlNetShouldAutoConfig = (props: Props) => {
aria-label="Auto configure processor" aria-label="Auto configure processor"
isChecked={shouldAutoConfig} isChecked={shouldAutoConfig}
onChange={handleShouldAutoConfigChanged} onChange={handleShouldAutoConfigChanged}
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
); );
}; };

View File

@ -1,5 +1,4 @@
import { import {
ChakraProps,
FormControl, FormControl,
FormLabel, FormLabel,
HStack, HStack,
@ -10,34 +9,41 @@ import {
RangeSliderTrack, RangeSliderTrack,
Tooltip, Tooltip,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks'; import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { import {
controlNetBeginStepPctChanged, controlNetBeginStepPctChanged,
controlNetEndStepPctChanged, controlNetEndStepPctChanged,
} from 'features/controlNet/store/controlNetSlice'; } from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react'; import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
const SLIDER_MARK_STYLES: ChakraProps['sx'] = {
mt: 1.5,
fontSize: '2xs',
fontWeight: '500',
color: 'base.400',
};
type Props = { type Props = {
controlNetId: string; controlNetId: string;
beginStepPct: number;
endStepPct: number;
mini?: boolean;
}; };
const formatPct = (v: number) => `${Math.round(v * 100)}%`; const formatPct = (v: number) => `${Math.round(v * 100)}%`;
const ParamControlNetBeginEnd = (props: Props) => { const ParamControlNetBeginEnd = (props: Props) => {
const { controlNetId, beginStepPct, mini = false, endStepPct } = props; const { controlNetId } = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation();
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlNet }) => {
const { beginStepPct, endStepPct, isEnabled } =
controlNet.controlNets[controlNetId];
return { beginStepPct, endStepPct, isEnabled };
},
defaultSelectorOptions
),
[controlNetId]
);
const { beginStepPct, endStepPct, isEnabled } = useAppSelector(selector);
const handleStepPctChanged = useCallback( const handleStepPctChanged = useCallback(
(v: number[]) => { (v: number[]) => {
@ -55,7 +61,7 @@ const ParamControlNetBeginEnd = (props: Props) => {
}, [controlNetId, dispatch]); }, [controlNetId, dispatch]);
return ( return (
<FormControl> <FormControl isDisabled={!isEnabled}>
<FormLabel>Begin / End Step Percentage</FormLabel> <FormLabel>Begin / End Step Percentage</FormLabel>
<HStack w="100%" gap={2} alignItems="center"> <HStack w="100%" gap={2} alignItems="center">
<RangeSlider <RangeSlider
@ -66,6 +72,7 @@ const ParamControlNetBeginEnd = (props: Props) => {
max={1} max={1}
step={0.01} step={0.01}
minStepsBetweenThumbs={5} minStepsBetweenThumbs={5}
isDisabled={!isEnabled}
> >
<RangeSliderTrack> <RangeSliderTrack>
<RangeSliderFilledTrack /> <RangeSliderFilledTrack />
@ -76,38 +83,33 @@ const ParamControlNetBeginEnd = (props: Props) => {
<Tooltip label={formatPct(endStepPct)} placement="top" hasArrow> <Tooltip label={formatPct(endStepPct)} placement="top" hasArrow>
<RangeSliderThumb index={1} /> <RangeSliderThumb index={1} />
</Tooltip> </Tooltip>
{!mini && ( <RangeSliderMark
<> value={0}
<RangeSliderMark sx={{
value={0} insetInlineStart: '0 !important',
sx={{ insetInlineEnd: 'unset !important',
insetInlineStart: '0 !important', }}
insetInlineEnd: 'unset !important', >
...SLIDER_MARK_STYLES, 0%
}} </RangeSliderMark>
> <RangeSliderMark
0% value={0.5}
</RangeSliderMark> sx={{
<RangeSliderMark insetInlineStart: '50% !important',
value={0.5} transform: 'translateX(-50%)',
sx={{ }}
...SLIDER_MARK_STYLES, >
}} 50%
> </RangeSliderMark>
50% <RangeSliderMark
</RangeSliderMark> value={1}
<RangeSliderMark sx={{
value={1} insetInlineStart: 'unset !important',
sx={{ insetInlineEnd: '0 !important',
insetInlineStart: 'unset !important', }}
insetInlineEnd: '0 !important', >
...SLIDER_MARK_STYLES, 100%
}} </RangeSliderMark>
>
100%
</RangeSliderMark>
</>
)}
</RangeSlider> </RangeSlider>
</HStack> </HStack>
</FormControl> </FormControl>

View File

@ -1,15 +1,17 @@
import { useAppDispatch } from 'app/store/storeHooks'; import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { import {
ControlModes, ControlModes,
controlNetControlModeChanged, controlNetControlModeChanged,
} from 'features/controlNet/store/controlNetSlice'; } from 'features/controlNet/store/controlNetSlice';
import { useCallback } from 'react'; import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
type ParamControlNetControlModeProps = { type ParamControlNetControlModeProps = {
controlNetId: string; controlNetId: string;
controlMode: string;
}; };
const CONTROL_MODE_DATA = [ const CONTROL_MODE_DATA = [
@ -22,8 +24,23 @@ const CONTROL_MODE_DATA = [
export default function ParamControlNetControlMode( export default function ParamControlNetControlMode(
props: ParamControlNetControlModeProps props: ParamControlNetControlModeProps
) { ) {
const { controlNetId, controlMode = false } = props; const { controlNetId } = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlNet }) => {
const { controlMode, isEnabled } =
controlNet.controlNets[controlNetId];
return { controlMode, isEnabled };
},
defaultSelectorOptions
),
[controlNetId]
);
const { controlMode, isEnabled } = useAppSelector(selector);
const { t } = useTranslation(); const { t } = useTranslation();
@ -36,7 +53,8 @@ export default function ParamControlNetControlMode(
return ( return (
<IAIMantineSelect <IAIMantineSelect
label={t('parameters.controlNetControlMode')} disabled={!isEnabled}
label="Control Mode"
data={CONTROL_MODE_DATA} data={CONTROL_MODE_DATA}
value={String(controlMode)} value={String(controlMode)}
onChange={handleControlModeChange} onChange={handleControlModeChange}

View File

@ -29,6 +29,9 @@ const ParamControlNetFeatureToggle = () => {
label="Enable ControlNet" label="Enable ControlNet"
isChecked={isEnabled} isChecked={isEnabled}
onChange={handleChange} onChange={handleChange}
formControlProps={{
width: '100%',
}}
/> />
); );
}; };

View File

@ -1,28 +0,0 @@
import { useAppDispatch } from 'app/store/storeHooks';
import IAISwitch from 'common/components/IAISwitch';
import { controlNetToggled } from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react';
type ParamControlNetIsEnabledProps = {
controlNetId: string;
isEnabled: boolean;
};
const ParamControlNetIsEnabled = (props: ParamControlNetIsEnabledProps) => {
const { controlNetId, isEnabled } = props;
const dispatch = useAppDispatch();
const handleIsEnabledChanged = useCallback(() => {
dispatch(controlNetToggled({ controlNetId }));
}, [dispatch, controlNetId]);
return (
<IAISwitch
label="Enabled"
isChecked={isEnabled}
onChange={handleIsEnabledChanged}
/>
);
};
export default memo(ParamControlNetIsEnabled);

View File

@ -1,36 +0,0 @@
import { useAppDispatch } from 'app/store/storeHooks';
import IAIFullCheckbox from 'common/components/IAIFullCheckbox';
import IAISwitch from 'common/components/IAISwitch';
import {
controlNetToggled,
isControlNetImagePreprocessedToggled,
} from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react';
type ParamControlNetIsEnabledProps = {
controlNetId: string;
isControlImageProcessed: boolean;
};
const ParamControlNetIsEnabled = (props: ParamControlNetIsEnabledProps) => {
const { controlNetId, isControlImageProcessed } = props;
const dispatch = useAppDispatch();
const handleIsControlImageProcessedToggled = useCallback(() => {
dispatch(
isControlNetImagePreprocessedToggled({
controlNetId,
})
);
}, [controlNetId, dispatch]);
return (
<IAISwitch
label="Preprocess"
isChecked={isControlImageProcessed}
onChange={handleIsControlImageProcessedToggled}
/>
);
};
export default memo(ParamControlNetIsEnabled);

View File

@ -1,59 +1,118 @@
import { SelectItem } from '@mantine/core';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSearchableSelect, { import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
IAISelectDataType, import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
} from 'common/components/IAIMantineSearchableSelect'; import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import {
CONTROLNET_MODELS,
ControlNetModelName,
} from 'features/controlNet/store/constants';
import { controlNetModelChanged } from 'features/controlNet/store/controlNetSlice'; import { controlNetModelChanged } from 'features/controlNet/store/controlNetSlice';
import { configSelector } from 'features/system/store/configSelectors'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { map } from 'lodash-es'; import { modelIdToControlNetModelParam } from 'features/parameters/util/modelIdToControlNetModelParam';
import { memo, useCallback } from 'react'; import { selectIsBusy } from 'features/system/store/systemSelectors';
import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
type ParamControlNetModelProps = { type ParamControlNetModelProps = {
controlNetId: string; controlNetId: string;
model: ControlNetModelName;
}; };
const selector = createSelector(configSelector, (config) => { const ParamControlNetModel = (props: ParamControlNetModelProps) => {
const controlNetModels: IAISelectDataType[] = map(CONTROLNET_MODELS, (m) => ({ const { controlNetId } = props;
label: m.label, const dispatch = useAppDispatch();
value: m.type, const isBusy = useAppSelector(selectIsBusy);
})).filter(
(d) => const selector = useMemo(
!config.sd.disabledControlNetModels.includes( () =>
d.value as ControlNetModelName createSelector(
) stateSelector,
({ generation, controlNet }) => {
const { model } = generation;
const controlNetModel = controlNet.controlNets[controlNetId]?.model;
const isEnabled = controlNet.controlNets[controlNetId]?.isEnabled;
return { mainModel: model, controlNetModel, isEnabled };
},
defaultSelectorOptions
),
[controlNetId]
); );
return controlNetModels; const { mainModel, controlNetModel, isEnabled } = useAppSelector(selector);
});
const ParamControlNetModel = (props: ParamControlNetModelProps) => { const { data: controlNetModels } = useGetControlNetModelsQuery();
const { controlNetId, model } = props;
const controlNetModels = useAppSelector(selector); const data = useMemo(() => {
const dispatch = useAppDispatch(); if (!controlNetModels) {
const isReady = useIsReadyToInvoke(); return [];
}
const data: SelectItem[] = [];
forEach(controlNetModels.entities, (model, id) => {
if (!model) {
return;
}
const disabled = model?.base_model !== mainModel?.base_model;
data.push({
value: id,
label: model.model_name,
group: MODEL_TYPE_MAP[model.base_model],
disabled,
tooltip: disabled
? `Incompatible base model: ${model.base_model}`
: undefined,
});
});
return data;
}, [controlNetModels, mainModel?.base_model]);
// grab the full model entity from the RTK Query cache
const selectedModel = useMemo(
() =>
controlNetModels?.entities[
`${controlNetModel?.base_model}/controlnet/${controlNetModel?.model_name}`
] ?? null,
[
controlNetModel?.base_model,
controlNetModel?.model_name,
controlNetModels?.entities,
]
);
const handleModelChanged = useCallback( const handleModelChanged = useCallback(
(val: string | null) => { (v: string | null) => {
// TODO: do not cast if (!v) {
const model = val as ControlNetModelName; return;
dispatch(controlNetModelChanged({ controlNetId, model })); }
const newControlNetModel = modelIdToControlNetModelParam(v);
if (!newControlNetModel) {
return;
}
dispatch(
controlNetModelChanged({ controlNetId, model: newControlNetModel })
);
}, },
[controlNetId, dispatch] [controlNetId, dispatch]
); );
return ( return (
<IAIMantineSearchableSelect <IAIMantineSearchableSelect
data={controlNetModels} itemComponent={IAIMantineSelectItemWithTooltip}
value={model} data={data}
error={
!selectedModel || mainModel?.base_model !== selectedModel.base_model
}
placeholder="Select a model"
value={selectedModel?.id ?? null}
onChange={handleModelChanged} onChange={handleModelChanged}
disabled={!isReady} disabled={isBusy || !isEnabled}
tooltip={model} tooltip={selectedModel?.description}
/> />
); );
}; };

View File

@ -1,24 +1,22 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSearchableSelect, { import IAIMantineSearchableSelect, {
IAISelectDataType, IAISelectDataType,
} from 'common/components/IAIMantineSearchableSelect'; } from 'common/components/IAIMantineSearchableSelect';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import { configSelector } from 'features/system/store/configSelectors'; import { configSelector } from 'features/system/store/configSelectors';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { map } from 'lodash-es'; import { map } from 'lodash-es';
import { memo, useCallback } from 'react'; import { memo, useCallback, useMemo } from 'react';
import { CONTROLNET_PROCESSORS } from '../../store/constants'; import { CONTROLNET_PROCESSORS } from '../../store/constants';
import { controlNetProcessorTypeChanged } from '../../store/controlNetSlice'; import { controlNetProcessorTypeChanged } from '../../store/controlNetSlice';
import { import { ControlNetProcessorType } from '../../store/types';
ControlNetProcessorNode, import { FormControl, FormLabel } from '@chakra-ui/react';
ControlNetProcessorType,
} from '../../store/types';
type ParamControlNetProcessorSelectProps = { type ParamControlNetProcessorSelectProps = {
controlNetId: string; controlNetId: string;
processorNode: ControlNetProcessorNode;
}; };
const selector = createSelector( const selector = createSelector(
@ -54,10 +52,24 @@ const selector = createSelector(
const ParamControlNetProcessorSelect = ( const ParamControlNetProcessorSelect = (
props: ParamControlNetProcessorSelectProps props: ParamControlNetProcessorSelectProps
) => { ) => {
const { controlNetId, processorNode } = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const isReady = useIsReadyToInvoke(); const { controlNetId } = props;
const processorNodeSelector = useMemo(
() =>
createSelector(
stateSelector,
({ controlNet }) => {
const { isEnabled, processorNode } =
controlNet.controlNets[controlNetId];
return { isEnabled, processorNode };
},
defaultSelectorOptions
),
[controlNetId]
);
const isBusy = useAppSelector(selectIsBusy);
const controlNetProcessors = useAppSelector(selector); const controlNetProcessors = useAppSelector(selector);
const { isEnabled, processorNode } = useAppSelector(processorNodeSelector);
const handleProcessorTypeChanged = useCallback( const handleProcessorTypeChanged = useCallback(
(v: string | null) => { (v: string | null) => {
@ -77,7 +89,7 @@ const ParamControlNetProcessorSelect = (
value={processorNode.type ?? 'canny_image_processor'} value={processorNode.type ?? 'canny_image_processor'}
data={controlNetProcessors} data={controlNetProcessors}
onChange={handleProcessorTypeChanged} onChange={handleProcessorTypeChanged}
disabled={!isReady} disabled={isBusy || !isEnabled}
/> />
); );
}; };

View File

@ -1,18 +1,32 @@
import { useAppDispatch } from 'app/store/storeHooks'; import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
import { controlNetWeightChanged } from 'features/controlNet/store/controlNetSlice'; import { controlNetWeightChanged } from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react'; import { memo, useCallback, useMemo } from 'react';
type ParamControlNetWeightProps = { type ParamControlNetWeightProps = {
controlNetId: string; controlNetId: string;
weight: number;
mini?: boolean;
}; };
const ParamControlNetWeight = (props: ParamControlNetWeightProps) => { const ParamControlNetWeight = (props: ParamControlNetWeightProps) => {
const { controlNetId, weight, mini = false } = props; const { controlNetId } = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlNet }) => {
const { weight, isEnabled } = controlNet.controlNets[controlNetId];
return { weight, isEnabled };
},
defaultSelectorOptions
),
[controlNetId]
);
const { weight, isEnabled } = useAppSelector(selector);
const handleWeightChanged = useCallback( const handleWeightChanged = useCallback(
(weight: number) => { (weight: number) => {
dispatch(controlNetWeightChanged({ controlNetId, weight })); dispatch(controlNetWeightChanged({ controlNetId, weight }));
@ -22,15 +36,15 @@ const ParamControlNetWeight = (props: ParamControlNetWeightProps) => {
return ( return (
<IAISlider <IAISlider
isDisabled={!isEnabled}
label={'Weight'} label={'Weight'}
sliderFormLabelProps={{ pb: 2 }}
value={weight} value={weight}
onChange={handleWeightChanged} onChange={handleWeightChanged}
min={-1} min={0}
max={1} max={2}
step={0.01} step={0.01}
withSliderMarks={!mini} withSliderMarks
sliderMarks={[-1, 0, 1]} sliderMarks={[0, 1, 2]}
/> />
); );
}; };

View File

@ -1,22 +1,25 @@
import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants'; import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredCannyImageProcessorInvocation } from 'features/controlNet/store/types'; import { RequiredCannyImageProcessorInvocation } from 'features/controlNet/store/types';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper'; import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.canny_image_processor.default; const DEFAULTS = CONTROLNET_PROCESSORS.canny_image_processor
.default as RequiredCannyImageProcessorInvocation;
type CannyProcessorProps = { type CannyProcessorProps = {
controlNetId: string; controlNetId: string;
processorNode: RequiredCannyImageProcessorInvocation; processorNode: RequiredCannyImageProcessorInvocation;
isEnabled: boolean;
}; };
const CannyProcessor = (props: CannyProcessorProps) => { const CannyProcessor = (props: CannyProcessorProps) => {
const { controlNetId, processorNode } = props; const { controlNetId, processorNode, isEnabled } = props;
const { low_threshold, high_threshold } = processorNode; const { low_threshold, high_threshold } = processorNode;
const isReady = useIsReadyToInvoke(); const isBusy = useAppSelector(selectIsBusy);
const processorChanged = useProcessorNodeChanged(); const processorChanged = useProcessorNodeChanged();
const handleLowThresholdChanged = useCallback( const handleLowThresholdChanged = useCallback(
@ -48,7 +51,7 @@ const CannyProcessor = (props: CannyProcessorProps) => {
return ( return (
<ProcessorWrapper> <ProcessorWrapper>
<IAISlider <IAISlider
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
label="Low Threshold" label="Low Threshold"
value={low_threshold} value={low_threshold}
onChange={handleLowThresholdChanged} onChange={handleLowThresholdChanged}
@ -60,7 +63,7 @@ const CannyProcessor = (props: CannyProcessorProps) => {
withSliderMarks withSliderMarks
/> />
<IAISlider <IAISlider
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
label="High Threshold" label="High Threshold"
value={high_threshold} value={high_threshold}
onChange={handleHighThresholdChanged} onChange={handleHighThresholdChanged}

View File

@ -4,20 +4,23 @@ import { RequiredContentShuffleImageProcessorInvocation } from 'features/control
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper'; import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke'; import { useAppSelector } from 'app/store/storeHooks';
import { selectIsBusy } from 'features/system/store/systemSelectors';
const DEFAULTS = CONTROLNET_PROCESSORS.content_shuffle_image_processor.default; const DEFAULTS = CONTROLNET_PROCESSORS.content_shuffle_image_processor
.default as RequiredContentShuffleImageProcessorInvocation;
type Props = { type Props = {
controlNetId: string; controlNetId: string;
processorNode: RequiredContentShuffleImageProcessorInvocation; processorNode: RequiredContentShuffleImageProcessorInvocation;
isEnabled: boolean;
}; };
const ContentShuffleProcessor = (props: Props) => { const ContentShuffleProcessor = (props: Props) => {
const { controlNetId, processorNode } = props; const { controlNetId, processorNode, isEnabled } = props;
const { image_resolution, detect_resolution, w, h, f } = processorNode; const { image_resolution, detect_resolution, w, h, f } = processorNode;
const processorChanged = useProcessorNodeChanged(); const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke(); const isBusy = useAppSelector(selectIsBusy);
const handleDetectResolutionChanged = useCallback( const handleDetectResolutionChanged = useCallback(
(v: number) => { (v: number) => {
@ -96,7 +99,7 @@ const ContentShuffleProcessor = (props: Props) => {
max={4096} max={4096}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
<IAISlider <IAISlider
label="Image Resolution" label="Image Resolution"
@ -108,7 +111,7 @@ const ContentShuffleProcessor = (props: Props) => {
max={4096} max={4096}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
<IAISlider <IAISlider
label="W" label="W"
@ -120,7 +123,7 @@ const ContentShuffleProcessor = (props: Props) => {
max={4096} max={4096}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
<IAISlider <IAISlider
label="H" label="H"
@ -132,7 +135,7 @@ const ContentShuffleProcessor = (props: Props) => {
max={4096} max={4096}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
<IAISlider <IAISlider
label="F" label="F"
@ -144,7 +147,7 @@ const ContentShuffleProcessor = (props: Props) => {
max={4096} max={4096}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
</ProcessorWrapper> </ProcessorWrapper>
); );

View File

@ -1,25 +1,29 @@
import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
import IAISwitch from 'common/components/IAISwitch'; import IAISwitch from 'common/components/IAISwitch';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants'; import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredHedImageProcessorInvocation } from 'features/controlNet/store/types'; import { RequiredHedImageProcessorInvocation } from 'features/controlNet/store/types';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { ChangeEvent, memo, useCallback } from 'react'; import { ChangeEvent, memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper'; import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.hed_image_processor.default; const DEFAULTS = CONTROLNET_PROCESSORS.hed_image_processor
.default as RequiredHedImageProcessorInvocation;
type HedProcessorProps = { type HedProcessorProps = {
controlNetId: string; controlNetId: string;
processorNode: RequiredHedImageProcessorInvocation; processorNode: RequiredHedImageProcessorInvocation;
isEnabled: boolean;
}; };
const HedPreprocessor = (props: HedProcessorProps) => { const HedPreprocessor = (props: HedProcessorProps) => {
const { const {
controlNetId, controlNetId,
processorNode: { detect_resolution, image_resolution, scribble }, processorNode: { detect_resolution, image_resolution, scribble },
isEnabled,
} = props; } = props;
const isReady = useIsReadyToInvoke(); const isBusy = useAppSelector(selectIsBusy);
const processorChanged = useProcessorNodeChanged(); const processorChanged = useProcessorNodeChanged();
const handleDetectResolutionChanged = useCallback( const handleDetectResolutionChanged = useCallback(
@ -67,7 +71,7 @@ const HedPreprocessor = (props: HedProcessorProps) => {
max={4096} max={4096}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
<IAISlider <IAISlider
label="Image Resolution" label="Image Resolution"
@ -79,13 +83,13 @@ const HedPreprocessor = (props: HedProcessorProps) => {
max={4096} max={4096}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
<IAISwitch <IAISwitch
label="Scribble" label="Scribble"
isChecked={scribble} isChecked={scribble}
onChange={handleScribbleChanged} onChange={handleScribbleChanged}
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
</ProcessorWrapper> </ProcessorWrapper>
); );

View File

@ -1,23 +1,26 @@
import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants'; import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredLineartAnimeImageProcessorInvocation } from 'features/controlNet/store/types'; import { RequiredLineartAnimeImageProcessorInvocation } from 'features/controlNet/store/types';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper'; import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.lineart_anime_image_processor.default; const DEFAULTS = CONTROLNET_PROCESSORS.lineart_anime_image_processor
.default as RequiredLineartAnimeImageProcessorInvocation;
type Props = { type Props = {
controlNetId: string; controlNetId: string;
processorNode: RequiredLineartAnimeImageProcessorInvocation; processorNode: RequiredLineartAnimeImageProcessorInvocation;
isEnabled: boolean;
}; };
const LineartAnimeProcessor = (props: Props) => { const LineartAnimeProcessor = (props: Props) => {
const { controlNetId, processorNode } = props; const { controlNetId, processorNode, isEnabled } = props;
const { image_resolution, detect_resolution } = processorNode; const { image_resolution, detect_resolution } = processorNode;
const processorChanged = useProcessorNodeChanged(); const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke(); const isBusy = useAppSelector(selectIsBusy);
const handleDetectResolutionChanged = useCallback( const handleDetectResolutionChanged = useCallback(
(v: number) => { (v: number) => {
@ -57,7 +60,7 @@ const LineartAnimeProcessor = (props: Props) => {
max={4096} max={4096}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
<IAISlider <IAISlider
label="Image Resolution" label="Image Resolution"
@ -69,7 +72,7 @@ const LineartAnimeProcessor = (props: Props) => {
max={4096} max={4096}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
</ProcessorWrapper> </ProcessorWrapper>
); );

View File

@ -1,24 +1,27 @@
import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
import IAISwitch from 'common/components/IAISwitch'; import IAISwitch from 'common/components/IAISwitch';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants'; import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredLineartImageProcessorInvocation } from 'features/controlNet/store/types'; import { RequiredLineartImageProcessorInvocation } from 'features/controlNet/store/types';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { ChangeEvent, memo, useCallback } from 'react'; import { ChangeEvent, memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper'; import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.lineart_image_processor.default; const DEFAULTS = CONTROLNET_PROCESSORS.lineart_image_processor
.default as RequiredLineartImageProcessorInvocation;
type LineartProcessorProps = { type LineartProcessorProps = {
controlNetId: string; controlNetId: string;
processorNode: RequiredLineartImageProcessorInvocation; processorNode: RequiredLineartImageProcessorInvocation;
isEnabled: boolean;
}; };
const LineartProcessor = (props: LineartProcessorProps) => { const LineartProcessor = (props: LineartProcessorProps) => {
const { controlNetId, processorNode } = props; const { controlNetId, processorNode, isEnabled } = props;
const { image_resolution, detect_resolution, coarse } = processorNode; const { image_resolution, detect_resolution, coarse } = processorNode;
const processorChanged = useProcessorNodeChanged(); const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke(); const isBusy = useAppSelector(selectIsBusy);
const handleDetectResolutionChanged = useCallback( const handleDetectResolutionChanged = useCallback(
(v: number) => { (v: number) => {
@ -65,7 +68,7 @@ const LineartProcessor = (props: LineartProcessorProps) => {
max={4096} max={4096}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
<IAISlider <IAISlider
label="Image Resolution" label="Image Resolution"
@ -77,13 +80,13 @@ const LineartProcessor = (props: LineartProcessorProps) => {
max={4096} max={4096}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
<IAISwitch <IAISwitch
label="Coarse" label="Coarse"
isChecked={coarse} isChecked={coarse}
onChange={handleCoarseChanged} onChange={handleCoarseChanged}
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
</ProcessorWrapper> </ProcessorWrapper>
); );

View File

@ -1,23 +1,26 @@
import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants'; import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredMediapipeFaceProcessorInvocation } from 'features/controlNet/store/types'; import { RequiredMediapipeFaceProcessorInvocation } from 'features/controlNet/store/types';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper'; import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.mediapipe_face_processor.default; const DEFAULTS = CONTROLNET_PROCESSORS.mediapipe_face_processor
.default as RequiredMediapipeFaceProcessorInvocation;
type Props = { type Props = {
controlNetId: string; controlNetId: string;
processorNode: RequiredMediapipeFaceProcessorInvocation; processorNode: RequiredMediapipeFaceProcessorInvocation;
isEnabled: boolean;
}; };
const MediapipeFaceProcessor = (props: Props) => { const MediapipeFaceProcessor = (props: Props) => {
const { controlNetId, processorNode } = props; const { controlNetId, processorNode, isEnabled } = props;
const { max_faces, min_confidence } = processorNode; const { max_faces, min_confidence } = processorNode;
const processorChanged = useProcessorNodeChanged(); const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke(); const isBusy = useAppSelector(selectIsBusy);
const handleMaxFacesChanged = useCallback( const handleMaxFacesChanged = useCallback(
(v: number) => { (v: number) => {
@ -53,7 +56,7 @@ const MediapipeFaceProcessor = (props: Props) => {
max={20} max={20}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
<IAISlider <IAISlider
label="Min Confidence" label="Min Confidence"
@ -66,7 +69,7 @@ const MediapipeFaceProcessor = (props: Props) => {
step={0.01} step={0.01}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
</ProcessorWrapper> </ProcessorWrapper>
); );

View File

@ -1,23 +1,26 @@
import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants'; import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredMidasDepthImageProcessorInvocation } from 'features/controlNet/store/types'; import { RequiredMidasDepthImageProcessorInvocation } from 'features/controlNet/store/types';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper'; import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.midas_depth_image_processor.default; const DEFAULTS = CONTROLNET_PROCESSORS.midas_depth_image_processor
.default as RequiredMidasDepthImageProcessorInvocation;
type Props = { type Props = {
controlNetId: string; controlNetId: string;
processorNode: RequiredMidasDepthImageProcessorInvocation; processorNode: RequiredMidasDepthImageProcessorInvocation;
isEnabled: boolean;
}; };
const MidasDepthProcessor = (props: Props) => { const MidasDepthProcessor = (props: Props) => {
const { controlNetId, processorNode } = props; const { controlNetId, processorNode, isEnabled } = props;
const { a_mult, bg_th } = processorNode; const { a_mult, bg_th } = processorNode;
const processorChanged = useProcessorNodeChanged(); const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke(); const isBusy = useAppSelector(selectIsBusy);
const handleAMultChanged = useCallback( const handleAMultChanged = useCallback(
(v: number) => { (v: number) => {
@ -54,7 +57,7 @@ const MidasDepthProcessor = (props: Props) => {
step={0.01} step={0.01}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
<IAISlider <IAISlider
label="bg_th" label="bg_th"
@ -67,7 +70,7 @@ const MidasDepthProcessor = (props: Props) => {
step={0.01} step={0.01}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
</ProcessorWrapper> </ProcessorWrapper>
); );

View File

@ -1,23 +1,26 @@
import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants'; import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredMlsdImageProcessorInvocation } from 'features/controlNet/store/types'; import { RequiredMlsdImageProcessorInvocation } from 'features/controlNet/store/types';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper'; import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.mlsd_image_processor.default; const DEFAULTS = CONTROLNET_PROCESSORS.mlsd_image_processor
.default as RequiredMlsdImageProcessorInvocation;
type Props = { type Props = {
controlNetId: string; controlNetId: string;
processorNode: RequiredMlsdImageProcessorInvocation; processorNode: RequiredMlsdImageProcessorInvocation;
isEnabled: boolean;
}; };
const MlsdImageProcessor = (props: Props) => { const MlsdImageProcessor = (props: Props) => {
const { controlNetId, processorNode } = props; const { controlNetId, processorNode, isEnabled } = props;
const { image_resolution, detect_resolution, thr_d, thr_v } = processorNode; const { image_resolution, detect_resolution, thr_d, thr_v } = processorNode;
const processorChanged = useProcessorNodeChanged(); const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke(); const isBusy = useAppSelector(selectIsBusy);
const handleDetectResolutionChanged = useCallback( const handleDetectResolutionChanged = useCallback(
(v: number) => { (v: number) => {
@ -79,7 +82,7 @@ const MlsdImageProcessor = (props: Props) => {
max={4096} max={4096}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
<IAISlider <IAISlider
label="Image Resolution" label="Image Resolution"
@ -91,7 +94,7 @@ const MlsdImageProcessor = (props: Props) => {
max={4096} max={4096}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
<IAISlider <IAISlider
label="W" label="W"
@ -104,7 +107,7 @@ const MlsdImageProcessor = (props: Props) => {
step={0.01} step={0.01}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
<IAISlider <IAISlider
label="H" label="H"
@ -117,7 +120,7 @@ const MlsdImageProcessor = (props: Props) => {
step={0.01} step={0.01}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
</ProcessorWrapper> </ProcessorWrapper>
); );

View File

@ -1,23 +1,26 @@
import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants'; import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredNormalbaeImageProcessorInvocation } from 'features/controlNet/store/types'; import { RequiredNormalbaeImageProcessorInvocation } from 'features/controlNet/store/types';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper'; import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.normalbae_image_processor.default; const DEFAULTS = CONTROLNET_PROCESSORS.normalbae_image_processor
.default as RequiredNormalbaeImageProcessorInvocation;
type Props = { type Props = {
controlNetId: string; controlNetId: string;
processorNode: RequiredNormalbaeImageProcessorInvocation; processorNode: RequiredNormalbaeImageProcessorInvocation;
isEnabled: boolean;
}; };
const NormalBaeProcessor = (props: Props) => { const NormalBaeProcessor = (props: Props) => {
const { controlNetId, processorNode } = props; const { controlNetId, processorNode, isEnabled } = props;
const { image_resolution, detect_resolution } = processorNode; const { image_resolution, detect_resolution } = processorNode;
const processorChanged = useProcessorNodeChanged(); const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke(); const isBusy = useAppSelector(selectIsBusy);
const handleDetectResolutionChanged = useCallback( const handleDetectResolutionChanged = useCallback(
(v: number) => { (v: number) => {
@ -57,7 +60,7 @@ const NormalBaeProcessor = (props: Props) => {
max={4096} max={4096}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
<IAISlider <IAISlider
label="Image Resolution" label="Image Resolution"
@ -69,7 +72,7 @@ const NormalBaeProcessor = (props: Props) => {
max={4096} max={4096}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
</ProcessorWrapper> </ProcessorWrapper>
); );

View File

@ -1,24 +1,27 @@
import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
import IAISwitch from 'common/components/IAISwitch'; import IAISwitch from 'common/components/IAISwitch';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants'; import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredOpenposeImageProcessorInvocation } from 'features/controlNet/store/types'; import { RequiredOpenposeImageProcessorInvocation } from 'features/controlNet/store/types';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { ChangeEvent, memo, useCallback } from 'react'; import { ChangeEvent, memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper'; import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.openpose_image_processor.default; const DEFAULTS = CONTROLNET_PROCESSORS.openpose_image_processor
.default as RequiredOpenposeImageProcessorInvocation;
type Props = { type Props = {
controlNetId: string; controlNetId: string;
processorNode: RequiredOpenposeImageProcessorInvocation; processorNode: RequiredOpenposeImageProcessorInvocation;
isEnabled: boolean;
}; };
const OpenposeProcessor = (props: Props) => { const OpenposeProcessor = (props: Props) => {
const { controlNetId, processorNode } = props; const { controlNetId, processorNode, isEnabled } = props;
const { image_resolution, detect_resolution, hand_and_face } = processorNode; const { image_resolution, detect_resolution, hand_and_face } = processorNode;
const processorChanged = useProcessorNodeChanged(); const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke(); const isBusy = useAppSelector(selectIsBusy);
const handleDetectResolutionChanged = useCallback( const handleDetectResolutionChanged = useCallback(
(v: number) => { (v: number) => {
@ -65,7 +68,7 @@ const OpenposeProcessor = (props: Props) => {
max={4096} max={4096}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
<IAISlider <IAISlider
label="Image Resolution" label="Image Resolution"
@ -77,13 +80,13 @@ const OpenposeProcessor = (props: Props) => {
max={4096} max={4096}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
<IAISwitch <IAISwitch
label="Hand and Face" label="Hand and Face"
isChecked={hand_and_face} isChecked={hand_and_face}
onChange={handleHandAndFaceChanged} onChange={handleHandAndFaceChanged}
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
</ProcessorWrapper> </ProcessorWrapper>
); );

View File

@ -1,24 +1,27 @@
import { useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
import IAISwitch from 'common/components/IAISwitch'; import IAISwitch from 'common/components/IAISwitch';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants'; import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { RequiredPidiImageProcessorInvocation } from 'features/controlNet/store/types'; import { RequiredPidiImageProcessorInvocation } from 'features/controlNet/store/types';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { ChangeEvent, memo, useCallback } from 'react'; import { ChangeEvent, memo, useCallback } from 'react';
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged'; import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
import ProcessorWrapper from './common/ProcessorWrapper'; import ProcessorWrapper from './common/ProcessorWrapper';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
const DEFAULTS = CONTROLNET_PROCESSORS.pidi_image_processor.default; const DEFAULTS = CONTROLNET_PROCESSORS.pidi_image_processor
.default as RequiredPidiImageProcessorInvocation;
type Props = { type Props = {
controlNetId: string; controlNetId: string;
processorNode: RequiredPidiImageProcessorInvocation; processorNode: RequiredPidiImageProcessorInvocation;
isEnabled: boolean;
}; };
const PidiProcessor = (props: Props) => { const PidiProcessor = (props: Props) => {
const { controlNetId, processorNode } = props; const { controlNetId, processorNode, isEnabled } = props;
const { image_resolution, detect_resolution, scribble, safe } = processorNode; const { image_resolution, detect_resolution, scribble, safe } = processorNode;
const processorChanged = useProcessorNodeChanged(); const processorChanged = useProcessorNodeChanged();
const isReady = useIsReadyToInvoke(); const isBusy = useAppSelector(selectIsBusy);
const handleDetectResolutionChanged = useCallback( const handleDetectResolutionChanged = useCallback(
(v: number) => { (v: number) => {
@ -72,7 +75,7 @@ const PidiProcessor = (props: Props) => {
max={4096} max={4096}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
<IAISlider <IAISlider
label="Image Resolution" label="Image Resolution"
@ -84,7 +87,7 @@ const PidiProcessor = (props: Props) => {
max={4096} max={4096}
withInput withInput
withSliderMarks withSliderMarks
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
<IAISwitch <IAISwitch
label="Scribble" label="Scribble"
@ -95,7 +98,7 @@ const PidiProcessor = (props: Props) => {
label="Safe" label="Safe"
isChecked={safe} isChecked={safe}
onChange={handleSafeChanged} onChange={handleSafeChanged}
isDisabled={!isReady} isDisabled={isBusy || !isEnabled}
/> />
</ProcessorWrapper> </ProcessorWrapper>
); );

View File

@ -4,6 +4,7 @@ import { memo } from 'react';
type Props = { type Props = {
controlNetId: string; controlNetId: string;
processorNode: RequiredZoeDepthImageProcessorInvocation; processorNode: RequiredZoeDepthImageProcessorInvocation;
isEnabled: boolean;
}; };
const ZoeDepthProcessor = (props: Props) => { const ZoeDepthProcessor = (props: Props) => {

View File

@ -173,91 +173,17 @@ export const CONTROLNET_PROCESSORS: ControlNetProcessorsDict = {
}, },
}; };
type ControlNetModelsDict = Record<string, ControlNetModel>; export const CONTROLNET_MODEL_DEFAULT_PROCESSORS: {
[key: string]: ControlNetProcessorType;
type ControlNetModel = { } = {
type: string; canny: 'canny_image_processor',
label: string; mlsd: 'mlsd_image_processor',
description?: string; depth: 'midas_depth_image_processor',
defaultProcessor?: ControlNetProcessorType; bae: 'normalbae_image_processor',
lineart: 'lineart_image_processor',
lineart_anime: 'lineart_anime_image_processor',
softedge: 'hed_image_processor',
shuffle: 'content_shuffle_image_processor',
openpose: 'openpose_image_processor',
mediapipe: 'mediapipe_face_processor',
}; };
export const CONTROLNET_MODELS: ControlNetModelsDict = {
'lllyasviel/control_v11p_sd15_canny': {
type: 'lllyasviel/control_v11p_sd15_canny',
label: 'Canny',
defaultProcessor: 'canny_image_processor',
},
'lllyasviel/control_v11p_sd15_inpaint': {
type: 'lllyasviel/control_v11p_sd15_inpaint',
label: 'Inpaint',
defaultProcessor: 'none',
},
'lllyasviel/control_v11p_sd15_mlsd': {
type: 'lllyasviel/control_v11p_sd15_mlsd',
label: 'M-LSD',
defaultProcessor: 'mlsd_image_processor',
},
'lllyasviel/control_v11f1p_sd15_depth': {
type: 'lllyasviel/control_v11f1p_sd15_depth',
label: 'Depth',
defaultProcessor: 'midas_depth_image_processor',
},
'lllyasviel/control_v11p_sd15_normalbae': {
type: 'lllyasviel/control_v11p_sd15_normalbae',
label: 'Normal Map (BAE)',
defaultProcessor: 'normalbae_image_processor',
},
'lllyasviel/control_v11p_sd15_seg': {
type: 'lllyasviel/control_v11p_sd15_seg',
label: 'Segmentation',
defaultProcessor: 'none',
},
'lllyasviel/control_v11p_sd15_lineart': {
type: 'lllyasviel/control_v11p_sd15_lineart',
label: 'Lineart',
defaultProcessor: 'lineart_image_processor',
},
'lllyasviel/control_v11p_sd15s2_lineart_anime': {
type: 'lllyasviel/control_v11p_sd15s2_lineart_anime',
label: 'Lineart Anime',
defaultProcessor: 'lineart_anime_image_processor',
},
'lllyasviel/control_v11p_sd15_scribble': {
type: 'lllyasviel/control_v11p_sd15_scribble',
label: 'Scribble',
defaultProcessor: 'none',
},
'lllyasviel/control_v11p_sd15_softedge': {
type: 'lllyasviel/control_v11p_sd15_softedge',
label: 'Soft Edge',
defaultProcessor: 'hed_image_processor',
},
'lllyasviel/control_v11e_sd15_shuffle': {
type: 'lllyasviel/control_v11e_sd15_shuffle',
label: 'Content Shuffle',
defaultProcessor: 'content_shuffle_image_processor',
},
'lllyasviel/control_v11p_sd15_openpose': {
type: 'lllyasviel/control_v11p_sd15_openpose',
label: 'Openpose',
defaultProcessor: 'openpose_image_processor',
},
'lllyasviel/control_v11f1e_sd15_tile': {
type: 'lllyasviel/control_v11f1e_sd15_tile',
label: 'Tile (experimental)',
defaultProcessor: 'none',
},
'lllyasviel/control_v11e_sd15_ip2p': {
type: 'lllyasviel/control_v11e_sd15_ip2p',
label: 'Pix2Pix (experimental)',
defaultProcessor: 'none',
},
'CrucibleAI/ControlNetMediaPipeFace': {
type: 'CrucibleAI/ControlNetMediaPipeFace',
label: 'Mediapipe Face',
defaultProcessor: 'mediapipe_face_processor',
},
};
export type ControlNetModelName = keyof typeof CONTROLNET_MODELS;

View File

@ -1,22 +1,20 @@
import { PayloadAction } from '@reduxjs/toolkit'; import { PayloadAction, createSlice } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { ImageDTO } from 'services/api/types'; import { ControlNetModelParam } from 'features/parameters/types/parameterSchemas';
import { cloneDeep, forEach } from 'lodash-es';
import { imageDeleted } from 'services/api/thunks/image';
import { isAnySessionRejected } from 'services/api/thunks/session';
import { appSocketInvocationError } from 'services/events/actions';
import { controlNetImageProcessed } from './actions';
import {
CONTROLNET_MODEL_DEFAULT_PROCESSORS,
CONTROLNET_PROCESSORS,
} from './constants';
import { import {
ControlNetProcessorType, ControlNetProcessorType,
RequiredCannyImageProcessorInvocation, RequiredCannyImageProcessorInvocation,
RequiredControlNetProcessorNode, RequiredControlNetProcessorNode,
} from './types'; } from './types';
import {
CONTROLNET_MODELS,
CONTROLNET_PROCESSORS,
ControlNetModelName,
} from './constants';
import { controlNetImageProcessed } from './actions';
import { imageDeleted, imageUrlsReceived } from 'services/api/thunks/image';
import { forEach } from 'lodash-es';
import { isAnySessionRejected } from 'services/api/thunks/session';
import { appSocketInvocationError } from 'services/events/actions';
export type ControlModes = export type ControlModes =
| 'balanced' | 'balanced'
@ -26,7 +24,7 @@ export type ControlModes =
export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = { export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
isEnabled: true, isEnabled: true,
model: CONTROLNET_MODELS['lllyasviel/control_v11p_sd15_canny'].type, model: null,
weight: 1, weight: 1,
beginStepPct: 0, beginStepPct: 0,
endStepPct: 1, endStepPct: 1,
@ -42,7 +40,7 @@ export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
export type ControlNetConfig = { export type ControlNetConfig = {
controlNetId: string; controlNetId: string;
isEnabled: boolean; isEnabled: boolean;
model: ControlNetModelName; model: ControlNetModelParam | null;
weight: number; weight: number;
beginStepPct: number; beginStepPct: number;
endStepPct: number; endStepPct: number;
@ -86,6 +84,19 @@ export const controlNetSlice = createSlice({
controlNetId, controlNetId,
}; };
}, },
controlNetDuplicated: (
state,
action: PayloadAction<{
sourceControlNetId: string;
newControlNetId: string;
}>
) => {
const { sourceControlNetId, newControlNetId } = action.payload;
const newControlnet = cloneDeep(state.controlNets[sourceControlNetId]);
newControlnet.controlNetId = newControlNetId;
state.controlNets[newControlNetId] = newControlnet;
},
controlNetAddedFromImage: ( controlNetAddedFromImage: (
state, state,
action: PayloadAction<{ controlNetId: string; controlImage: string }> action: PayloadAction<{ controlNetId: string; controlImage: string }>
@ -147,7 +158,7 @@ export const controlNetSlice = createSlice({
state, state,
action: PayloadAction<{ action: PayloadAction<{
controlNetId: string; controlNetId: string;
model: ControlNetModelName; model: ControlNetModelParam;
}> }>
) => { ) => {
const { controlNetId, model } = action.payload; const { controlNetId, model } = action.payload;
@ -155,7 +166,15 @@ export const controlNetSlice = createSlice({
state.controlNets[controlNetId].processedControlImage = null; state.controlNets[controlNetId].processedControlImage = null;
if (state.controlNets[controlNetId].shouldAutoConfig) { if (state.controlNets[controlNetId].shouldAutoConfig) {
const processorType = CONTROLNET_MODELS[model].defaultProcessor; let processorType: ControlNetProcessorType | undefined = undefined;
for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
if (model.model_name.includes(modelSubstring)) {
processorType = CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring];
break;
}
}
if (processorType) { if (processorType) {
state.controlNets[controlNetId].processorType = processorType; state.controlNets[controlNetId].processorType = processorType;
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[ state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[
@ -241,9 +260,19 @@ export const controlNetSlice = createSlice({
if (newShouldAutoConfig) { if (newShouldAutoConfig) {
// manage the processor for the user // manage the processor for the user
const processorType = let processorType: ControlNetProcessorType | undefined = undefined;
CONTROLNET_MODELS[state.controlNets[controlNetId].model]
.defaultProcessor; for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
if (
state.controlNets[controlNetId].model?.model_name.includes(
modelSubstring
)
) {
processorType = CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring];
break;
}
}
if (processorType) { if (processorType) {
state.controlNets[controlNetId].processorType = processorType; state.controlNets[controlNetId].processorType = processorType;
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[ state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[
@ -272,7 +301,8 @@ export const controlNetSlice = createSlice({
}); });
builder.addCase(imageDeleted.pending, (state, action) => { builder.addCase(imageDeleted.pending, (state, action) => {
// Preemptively remove the image from the gallery // Preemptively remove the image from all controlnets
// TODO: doesn't the imageusage stuff do this for us?
const { image_name } = action.meta.arg; const { image_name } = action.meta.arg;
forEach(state.controlNets, (c) => { forEach(state.controlNets, (c) => {
if (c.controlImage === image_name) { if (c.controlImage === image_name) {
@ -285,21 +315,6 @@ export const controlNetSlice = createSlice({
}); });
}); });
// builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
// const { image_name, image_url, thumbnail_url } = action.payload;
// forEach(state.controlNets, (c) => {
// if (c.controlImage?.image_name === image_name) {
// c.controlImage.image_url = image_url;
// c.controlImage.thumbnail_url = thumbnail_url;
// }
// if (c.processedControlImage?.image_name === image_name) {
// c.processedControlImage.image_url = image_url;
// c.processedControlImage.thumbnail_url = thumbnail_url;
// }
// });
// });
builder.addCase(appSocketInvocationError, (state, action) => { builder.addCase(appSocketInvocationError, (state, action) => {
state.pendingControlImages = []; state.pendingControlImages = [];
}); });
@ -313,6 +328,7 @@ export const controlNetSlice = createSlice({
export const { export const {
isControlNetEnabledToggled, isControlNetEnabledToggled,
controlNetAdded, controlNetAdded,
controlNetDuplicated,
controlNetAddedFromImage, controlNetAddedFromImage,
controlNetRemoved, controlNetRemoved,
controlNetImageChanged, controlNetImageChanged,

View File

@ -118,6 +118,20 @@ const GalleryImageGrid = () => {
); );
}, [dispatch, imageNames.length, galleryView]); }, [dispatch, imageNames.length, galleryView]);
useEffect(() => {
// Set up gallery scroler
const { current: root } = rootRef;
if (scroller && root) {
initialize({
target: root,
elements: {
viewport: scroller,
},
});
}
return () => osInstance()?.destroy();
}, [scroller, initialize, osInstance]);
const handleEndReached = useMemo(() => { const handleEndReached = useMemo(() => {
if (areMoreAvailable) { if (areMoreAvailable) {
return handleLoadMoreImages; return handleLoadMoreImages;

View File

@ -3,6 +3,7 @@ import { LoRAModelParam } from 'features/parameters/types/parameterSchemas';
import { LoRAModelConfigEntity } from 'services/api/endpoints/models'; import { LoRAModelConfigEntity } from 'services/api/endpoints/models';
export type LoRA = LoRAModelParam & { export type LoRA = LoRAModelParam & {
id: string;
weight: number; weight: number;
}; };
@ -24,7 +25,7 @@ export const loraSlice = createSlice({
reducers: { reducers: {
loraAdded: (state, action: PayloadAction<LoRAModelConfigEntity>) => { loraAdded: (state, action: PayloadAction<LoRAModelConfigEntity>) => {
const { model_name, id, base_model } = action.payload; const { model_name, id, base_model } = action.payload;
state.loras[id] = { model_name, base_model, ...defaultLoRAConfig }; state.loras[id] = { id, model_name, base_model, ...defaultLoRAConfig };
}, },
loraRemoved: (state, action: PayloadAction<string>) => { loraRemoved: (state, action: PayloadAction<string>) => {
const id = action.payload; const id = action.payload;

View File

@ -1,4 +1,5 @@
import { Flex, Heading, Icon, Tooltip } from '@chakra-ui/react'; import { Flex, Heading, Icon, Tooltip } from '@chakra-ui/react';
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/hooks/useBuildInvocation';
import { memo } from 'react'; import { memo } from 'react';
import { FaInfoCircle } from 'react-icons/fa'; import { FaInfoCircle } from 'react-icons/fa';
@ -12,6 +13,7 @@ const IAINodeHeader = (props: IAINodeHeaderProps) => {
const { nodeId, title, description } = props; const { nodeId, title, description } = props;
return ( return (
<Flex <Flex
className={DRAG_HANDLE_CLASSNAME}
sx={{ sx={{
borderTopRadius: 'md', borderTopRadius: 'md',
alignItems: 'center', alignItems: 'center',

View File

@ -1,25 +1,25 @@
import {
InputFieldTemplate,
InputFieldValue,
InvocationTemplate,
} from 'features/nodes/types/types';
import { memo, ReactNode, useCallback } from 'react';
import { map } from 'lodash-es';
import { useAppSelector } from 'app/store/storeHooks';
import { RootState } from 'app/store/store';
import { import {
Box, Box,
Divider,
Flex, Flex,
FormControl, FormControl,
FormLabel, FormLabel,
HStack, HStack,
Tooltip, Tooltip,
Divider,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import FieldHandle from '../FieldHandle'; import { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { useIsValidConnection } from 'features/nodes/hooks/useIsValidConnection'; import { useIsValidConnection } from 'features/nodes/hooks/useIsValidConnection';
import InputFieldComponent from '../InputFieldComponent';
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants'; import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
import {
InputFieldTemplate,
InputFieldValue,
InvocationTemplate,
} from 'features/nodes/types/types';
import { map } from 'lodash-es';
import { ReactNode, memo, useCallback } from 'react';
import FieldHandle from '../FieldHandle';
import InputFieldComponent from '../InputFieldComponent';
interface IAINodeInputProps { interface IAINodeInputProps {
nodeId: string; nodeId: string;
@ -35,6 +35,7 @@ function IAINodeInput(props: IAINodeInputProps) {
return ( return (
<Box <Box
className="nopan"
position="relative" position="relative"
borderColor={ borderColor={
!template !template
@ -136,7 +137,7 @@ const IAINodeInputs = (props: IAINodeInputsProps) => {
}); });
return ( return (
<Flex flexDir="column" gap={2} p={2}> <Flex className="nopan" flexDir="column" gap={2} p={2}>
{IAINodeInputsToRender} {IAINodeInputsToRender}
</Flex> </Flex>
); );

View File

@ -7,6 +7,7 @@ import ClipInputFieldComponent from './fields/ClipInputFieldComponent';
import ColorInputFieldComponent from './fields/ColorInputFieldComponent'; import ColorInputFieldComponent from './fields/ColorInputFieldComponent';
import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent'; import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent';
import ControlInputFieldComponent from './fields/ControlInputFieldComponent'; import ControlInputFieldComponent from './fields/ControlInputFieldComponent';
import ControlNetModelInputFieldComponent from './fields/ControlNetModelInputFieldComponent';
import EnumInputFieldComponent from './fields/EnumInputFieldComponent'; import EnumInputFieldComponent from './fields/EnumInputFieldComponent';
import ImageCollectionInputFieldComponent from './fields/ImageCollectionInputFieldComponent'; import ImageCollectionInputFieldComponent from './fields/ImageCollectionInputFieldComponent';
import ImageInputFieldComponent from './fields/ImageInputFieldComponent'; import ImageInputFieldComponent from './fields/ImageInputFieldComponent';
@ -174,6 +175,16 @@ const InputFieldComponent = (props: InputFieldComponentProps) => {
); );
} }
if (type === 'controlnet_model' && template.type === 'controlnet_model') {
return (
<ControlNetModelInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'array' && template.type === 'array') { if (type === 'array' && template.type === 'array') {
return ( return (
<ArrayInputFieldComponent <ArrayInputFieldComponent

View File

@ -23,7 +23,14 @@ export const InvocationComponent = memo((props: NodeProps<InvocationValue>) => {
if (!template) { if (!template) {
return ( return (
<NodeWrapper selected={selected}> <NodeWrapper selected={selected}>
<Flex sx={{ alignItems: 'center', justifyContent: 'center' }}> <Flex
className="nopan"
sx={{
alignItems: 'center',
justifyContent: 'center',
cursor: 'auto',
}}
>
<Icon <Icon
as={FaExclamationCircle} as={FaExclamationCircle}
sx={{ sx={{
@ -46,7 +53,9 @@ export const InvocationComponent = memo((props: NodeProps<InvocationValue>) => {
description={template.description} description={template.description}
/> />
<Flex <Flex
className={'nopan'}
sx={{ sx={{
cursor: 'auto',
flexDirection: 'column', flexDirection: 'column',
borderBottomRadius: 'md', borderBottomRadius: 'md',
py: 2, py: 2,

View File

@ -2,6 +2,8 @@ import { Box, useToken } from '@chakra-ui/react';
import { NODE_MIN_WIDTH } from 'app/constants'; import { NODE_MIN_WIDTH } from 'app/constants';
import { PropsWithChildren } from 'react'; import { PropsWithChildren } from 'react';
import { DRAG_HANDLE_CLASSNAME } from '../hooks/useBuildInvocation';
import { useAppSelector } from 'app/store/storeHooks';
type NodeWrapperProps = PropsWithChildren & { type NodeWrapperProps = PropsWithChildren & {
selected: boolean; selected: boolean;
@ -13,8 +15,11 @@ const NodeWrapper = (props: NodeWrapperProps) => {
'dark-lg', 'dark-lg',
]); ]);
const shift = useAppSelector((state) => state.hotkeys.shift);
return ( return (
<Box <Box
className={shift ? DRAG_HANDLE_CLASSNAME : 'nopan'}
sx={{ sx={{
position: 'relative', position: 'relative',
borderRadius: 'md', borderRadius: 'md',

View File

@ -21,6 +21,7 @@ const ProgressImageNode = (props: NodeProps<InvocationValue>) => {
/> />
<Flex <Flex
className="nopan"
sx={{ sx={{
flexDirection: 'column', flexDirection: 'column',
borderBottomRadius: 'md', borderBottomRadius: 'md',

View File

@ -0,0 +1,103 @@
import { SelectItem } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import {
ControlNetModelInputFieldTemplate,
ControlNetModelInputFieldValue,
} from 'features/nodes/types/types';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToControlNetModelParam } from 'features/parameters/util/modelIdToControlNetModelParam';
import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
import { FieldComponentProps } from './types';
const ControlNetModelInputFieldComponent = (
props: FieldComponentProps<
ControlNetModelInputFieldValue,
ControlNetModelInputFieldTemplate
>
) => {
const { nodeId, field } = props;
const controlNetModel = field.value;
const dispatch = useAppDispatch();
const { t } = useTranslation();
const { data: controlNetModels } = useGetControlNetModelsQuery();
// grab the full model entity from the RTK Query cache
const selectedModel = useMemo(
() =>
controlNetModels?.entities[
`${controlNetModel?.base_model}/controlnet/${controlNetModel?.model_name}`
] ?? null,
[
controlNetModel?.base_model,
controlNetModel?.model_name,
controlNetModels?.entities,
]
);
const data = useMemo(() => {
if (!controlNetModels) {
return [];
}
const data: SelectItem[] = [];
forEach(controlNetModels.entities, (model, id) => {
if (!model) {
return;
}
data.push({
value: id,
label: model.model_name,
group: MODEL_TYPE_MAP[model.base_model],
});
});
return data;
}, [controlNetModels]);
const handleValueChanged = useCallback(
(v: string | null) => {
if (!v) {
return;
}
const newControlNetModel = modelIdToControlNetModelParam(v);
if (!newControlNetModel) {
return;
}
dispatch(
fieldValueChanged({
nodeId,
fieldName: field.name,
value: newControlNetModel,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<IAIMantineSelect
tooltip={selectedModel?.description}
label={
selectedModel?.base_model && MODEL_TYPE_MAP[selectedModel?.base_model]
}
value={selectedModel?.id ?? null}
placeholder="Pick one"
error={!selectedModel}
data={data}
onChange={handleValueChanged}
/>
);
};
export default memo(ControlNetModelInputFieldComponent);

View File

@ -6,6 +6,7 @@ import {
NumberInputStepper, NumberInputStepper,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import { numberStringRegex } from 'common/components/IAINumberInput';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice'; import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import { import {
FloatInputFieldTemplate, FloatInputFieldTemplate,
@ -13,7 +14,7 @@ import {
IntegerInputFieldTemplate, IntegerInputFieldTemplate,
IntegerInputFieldValue, IntegerInputFieldValue,
} from 'features/nodes/types/types'; } from 'features/nodes/types/types';
import { memo } from 'react'; import { memo, useEffect, useState } from 'react';
import { FieldComponentProps } from './types'; import { FieldComponentProps } from './types';
const NumberInputFieldComponent = ( const NumberInputFieldComponent = (
@ -23,17 +24,42 @@ const NumberInputFieldComponent = (
> >
) => { ) => {
const { nodeId, field } = props; const { nodeId, field } = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const [valueAsString, setValueAsString] = useState<string>(
String(field.value)
);
const handleValueChanged = (_: string, value: number) => { const handleValueChanged = (v: string) => {
dispatch(fieldValueChanged({ nodeId, fieldName: field.name, value })); setValueAsString(v);
// This allows negatives and decimals e.g. '-123', `.5`, `-0.2`, etc.
if (!v.match(numberStringRegex)) {
// Cast the value to number. Floor it if it should be an integer.
dispatch(
fieldValueChanged({
nodeId,
fieldName: field.name,
value:
props.template.type === 'integer'
? Math.floor(Number(v))
: Number(v),
})
);
}
}; };
useEffect(() => {
if (
!valueAsString.match(numberStringRegex) &&
field.value !== Number(valueAsString)
) {
setValueAsString(String(field.value));
}
}, [field.value, valueAsString]);
return ( return (
<NumberInput <NumberInput
onChange={handleValueChanged} onChange={handleValueChanged}
value={field.value} value={valueAsString}
step={props.template.type === 'integer' ? 1 : 0.1} step={props.template.type === 'integer' ? 1 : 0.1}
precision={props.template.type === 'integer' ? 0 : 3} precision={props.template.type === 'integer' ? 0 : 3}
> >

View File

@ -18,6 +18,12 @@ const templatesSelector = createSelector(
(nodes) => nodes.invocationTemplates (nodes) => nodes.invocationTemplates
); );
export const DRAG_HANDLE_CLASSNAME = 'node-drag-handle';
export const SHARED_NODE_PROPERTIES: Partial<Node> = {
dragHandle: `.${DRAG_HANDLE_CLASSNAME}`,
};
export const useBuildInvocation = () => { export const useBuildInvocation = () => {
const invocationTemplates = useAppSelector(templatesSelector); const invocationTemplates = useAppSelector(templatesSelector);
@ -32,6 +38,7 @@ export const useBuildInvocation = () => {
}); });
const node: Node = { const node: Node = {
...SHARED_NODE_PROPERTIES,
id: 'progress_image', id: 'progress_image',
type: 'progress_image', type: 'progress_image',
position: { x: x, y: y }, position: { x: x, y: y },
@ -91,6 +98,7 @@ export const useBuildInvocation = () => {
}); });
const invocation: Node<InvocationValue> = { const invocation: Node<InvocationValue> = {
...SHARED_NODE_PROPERTIES,
id: nodeId, id: nodeId,
type: 'invocation', type: 'invocation',
position: { x: x, y: y }, position: { x: x, y: y },

View File

@ -1,6 +1,7 @@
import { createSlice, PayloadAction } from '@reduxjs/toolkit'; import { createSlice, PayloadAction } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { import {
ControlNetModelParam,
LoRAModelParam, LoRAModelParam,
MainModelParam, MainModelParam,
VaeModelParam, VaeModelParam,
@ -81,7 +82,8 @@ const nodesSlice = createSlice({
| ImageField[] | ImageField[]
| MainModelParam | MainModelParam
| VaeModelParam | VaeModelParam
| LoRAModelParam; | LoRAModelParam
| ControlNetModelParam;
}> }>
) => { ) => {
const { nodeId, fieldName, value } = action.payload; const { nodeId, fieldName, value } = action.payload;

View File

@ -19,6 +19,8 @@ export const FIELD_TYPE_MAP: Record<string, FieldType> = {
model: 'model', model: 'model',
vae_model: 'vae_model', vae_model: 'vae_model',
lora_model: 'lora_model', lora_model: 'lora_model',
controlnet_model: 'controlnet_model',
ControlNetModelField: 'controlnet_model',
array: 'array', array: 'array',
item: 'item', item: 'item',
ColorField: 'color', ColorField: 'color',
@ -130,6 +132,12 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
title: 'LoRA', title: 'LoRA',
description: 'Models are models.', description: 'Models are models.',
}, },
controlnet_model: {
color: 'teal',
colorCssVar: getColorTokenCssVariable('teal'),
title: 'ControlNet',
description: 'Models are models.',
},
array: { array: {
color: 'gray', color: 'gray',
colorCssVar: getColorTokenCssVariable('gray'), colorCssVar: getColorTokenCssVariable('gray'),

View File

@ -1,4 +1,5 @@
import { import {
ControlNetModelParam,
LoRAModelParam, LoRAModelParam,
MainModelParam, MainModelParam,
VaeModelParam, VaeModelParam,
@ -71,6 +72,7 @@ export type FieldType =
| 'model' | 'model'
| 'vae_model' | 'vae_model'
| 'lora_model' | 'lora_model'
| 'controlnet_model'
| 'array' | 'array'
| 'item' | 'item'
| 'color' | 'color'
@ -100,6 +102,7 @@ export type InputFieldValue =
| MainModelInputFieldValue | MainModelInputFieldValue
| VaeModelInputFieldValue | VaeModelInputFieldValue
| LoRAModelInputFieldValue | LoRAModelInputFieldValue
| ControlNetModelInputFieldValue
| ArrayInputFieldValue | ArrayInputFieldValue
| ItemInputFieldValue | ItemInputFieldValue
| ColorInputFieldValue | ColorInputFieldValue
@ -127,6 +130,7 @@ export type InputFieldTemplate =
| ModelInputFieldTemplate | ModelInputFieldTemplate
| VaeModelInputFieldTemplate | VaeModelInputFieldTemplate
| LoRAModelInputFieldTemplate | LoRAModelInputFieldTemplate
| ControlNetModelInputFieldTemplate
| ArrayInputFieldTemplate | ArrayInputFieldTemplate
| ItemInputFieldTemplate | ItemInputFieldTemplate
| ColorInputFieldTemplate | ColorInputFieldTemplate
@ -249,6 +253,11 @@ export type LoRAModelInputFieldValue = FieldValueBase & {
value?: LoRAModelParam; value?: LoRAModelParam;
}; };
export type ControlNetModelInputFieldValue = FieldValueBase & {
type: 'controlnet_model';
value?: ControlNetModelParam;
};
export type ArrayInputFieldValue = FieldValueBase & { export type ArrayInputFieldValue = FieldValueBase & {
type: 'array'; type: 'array';
value?: (string | number)[]; value?: (string | number)[];
@ -368,6 +377,11 @@ export type LoRAModelInputFieldTemplate = InputFieldTemplateBase & {
type: 'lora_model'; type: 'lora_model';
}; };
export type ControlNetModelInputFieldTemplate = InputFieldTemplateBase & {
default: string;
type: 'controlnet_model';
};
export type ArrayInputFieldTemplate = InputFieldTemplateBase & { export type ArrayInputFieldTemplate = InputFieldTemplateBase & {
default: []; default: [];
type: 'array'; type: 'array';

View File

@ -9,6 +9,7 @@ import {
ColorInputFieldTemplate, ColorInputFieldTemplate,
ConditioningInputFieldTemplate, ConditioningInputFieldTemplate,
ControlInputFieldTemplate, ControlInputFieldTemplate,
ControlNetModelInputFieldTemplate,
EnumInputFieldTemplate, EnumInputFieldTemplate,
FieldType, FieldType,
FloatInputFieldTemplate, FloatInputFieldTemplate,
@ -207,6 +208,21 @@ const buildLoRAModelInputFieldTemplate = ({
return template; return template;
}; };
const buildControlNetModelInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): ControlNetModelInputFieldTemplate => {
const template: ControlNetModelInputFieldTemplate = {
...baseField,
type: 'controlnet_model',
inputRequirement: 'always',
inputKind: 'direct',
default: schemaObject.default ?? undefined,
};
return template;
};
const buildImageInputFieldTemplate = ({ const buildImageInputFieldTemplate = ({
schemaObject, schemaObject,
baseField, baseField,
@ -479,6 +495,9 @@ export const buildInputFieldTemplate = (
if (['lora_model'].includes(fieldType)) { if (['lora_model'].includes(fieldType)) {
return buildLoRAModelInputFieldTemplate({ schemaObject, baseField }); return buildLoRAModelInputFieldTemplate({ schemaObject, baseField });
} }
if (['controlnet_model'].includes(fieldType)) {
return buildControlNetModelInputFieldTemplate({ schemaObject, baseField });
}
if (['enum'].includes(fieldType)) { if (['enum'].includes(fieldType)) {
return buildEnumInputFieldTemplate({ schemaObject, baseField }); return buildEnumInputFieldTemplate({ schemaObject, baseField });
} }

View File

@ -83,6 +83,10 @@ export const buildInputFieldValue = (
if (template.type === 'lora_model') { if (template.type === 'lora_model') {
fieldValue.value = undefined; fieldValue.value = undefined;
} }
if (template.type === 'controlnet_model') {
fieldValue.value = undefined;
}
} }
return fieldValue; return fieldValue;

View File

@ -60,7 +60,7 @@ export const addLoRAsToGraph = (
const loraLoaderNode: LoraLoaderInvocation = { const loraLoaderNode: LoraLoaderInvocation = {
type: 'lora_loader', type: 'lora_loader',
id: currentLoraNodeId, id: currentLoraNodeId,
lora, lora: { model_name, base_model },
weight, weight,
}; };

View File

@ -2,12 +2,13 @@ import { Divider, Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIButton from 'common/components/IAIButton';
import IAICollapse from 'common/components/IAICollapse'; import IAICollapse from 'common/components/IAICollapse';
import IAIIconButton from 'common/components/IAIIconButton';
import ControlNet from 'features/controlNet/components/ControlNet'; import ControlNet from 'features/controlNet/components/ControlNet';
import ParamControlNetFeatureToggle from 'features/controlNet/components/parameters/ParamControlNetFeatureToggle'; import ParamControlNetFeatureToggle from 'features/controlNet/components/parameters/ParamControlNetFeatureToggle';
import { import {
controlNetAdded, controlNetAdded,
controlNetModelChanged,
controlNetSelector, controlNetSelector,
} from 'features/controlNet/store/controlNetSlice'; } from 'features/controlNet/store/controlNetSlice';
import { getValidControlNets } from 'features/controlNet/util/getValidControlNets'; import { getValidControlNets } from 'features/controlNet/util/getValidControlNets';
@ -15,6 +16,8 @@ import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { map } from 'lodash-es'; import { map } from 'lodash-es';
import { Fragment, memo, useCallback } from 'react'; import { Fragment, memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { FaPlus } from 'react-icons/fa';
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
import { v4 as uuidv4 } from 'uuid'; import { v4 as uuidv4 } from 'uuid';
const selector = createSelector( const selector = createSelector(
@ -39,10 +42,23 @@ const ParamControlNetCollapse = () => {
const { controlNetsArray, activeLabel } = useAppSelector(selector); const { controlNetsArray, activeLabel } = useAppSelector(selector);
const isControlNetDisabled = useFeatureStatus('controlNet').isFeatureDisabled; const isControlNetDisabled = useFeatureStatus('controlNet').isFeatureDisabled;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { firstModel } = useGetControlNetModelsQuery(undefined, {
selectFromResult: (result) => {
const firstModel = result.data?.entities[result.data?.ids[0]];
return {
firstModel,
};
},
});
const handleClickedAddControlNet = useCallback(() => { const handleClickedAddControlNet = useCallback(() => {
dispatch(controlNetAdded({ controlNetId: uuidv4() })); if (!firstModel) {
}, [dispatch]); return;
}
const controlNetId = uuidv4();
dispatch(controlNetAdded({ controlNetId }));
dispatch(controlNetModelChanged({ controlNetId, model: firstModel }));
}, [dispatch, firstModel]);
if (isControlNetDisabled) { if (isControlNetDisabled) {
return null; return null;
@ -51,16 +67,39 @@ const ParamControlNetCollapse = () => {
return ( return (
<IAICollapse label="ControlNet" activeLabel={activeLabel}> <IAICollapse label="ControlNet" activeLabel={activeLabel}>
<Flex sx={{ flexDir: 'column', gap: 3 }}> <Flex sx={{ flexDir: 'column', gap: 3 }}>
<ParamControlNetFeatureToggle /> <Flex gap={2} alignItems="center">
<Flex
sx={{
flexDirection: 'column',
w: '100%',
gap: 2,
px: 4,
py: 2,
borderRadius: 4,
bg: 'base.200',
_dark: {
bg: 'base.850',
},
}}
>
<ParamControlNetFeatureToggle />
</Flex>
<IAIIconButton
tooltip="Add ControlNet"
aria-label="Add ControlNet"
icon={<FaPlus />}
isDisabled={!firstModel}
flexGrow={1}
size="md"
onClick={handleClickedAddControlNet}
/>
</Flex>
{controlNetsArray.map((c, i) => ( {controlNetsArray.map((c, i) => (
<Fragment key={c.controlNetId}> <Fragment key={c.controlNetId}>
{i > 0 && <Divider />} {i > 0 && <Divider />}
<ControlNet controlNet={c} /> <ControlNet controlNetId={c.controlNetId} />
</Fragment> </Fragment>
))} ))}
<IAIButton flexGrow={1} onClick={handleClickedAddControlNet}>
Add ControlNet
</IAIButton>
</Flex> </Flex>
</IAICollapse> </IAICollapse>
); );

View File

@ -37,6 +37,7 @@ const ParamVAEModelSelect = () => {
return []; return [];
} }
// add a "default" option, this means use the main model's included VAE
const data: SelectItem[] = [ const data: SelectItem[] = [
{ {
value: 'default', value: 'default',

View File

@ -17,8 +17,10 @@ import { FaPlay } from 'react-icons/fa';
const IN_PROGRESS_STYLES: ChakraProps['sx'] = { const IN_PROGRESS_STYLES: ChakraProps['sx'] = {
_disabled: { _disabled: {
bg: 'none', bg: 'none',
color: 'base.600',
cursor: 'not-allowed', cursor: 'not-allowed',
_hover: { _hover: {
color: 'base.600',
bg: 'none', bg: 'none',
}, },
}, },

View File

@ -180,6 +180,23 @@ export type LoRAModelParam = z.infer<typeof zLoRAModel>;
*/ */
export const isValidLoRAModel = (val: unknown): val is LoRAModelParam => export const isValidLoRAModel = (val: unknown): val is LoRAModelParam =>
zLoRAModel.safeParse(val).success; zLoRAModel.safeParse(val).success;
/**
* Zod schema for ControlNet models
*/
export const zControlNetModel = z.object({
model_name: z.string().min(1),
base_model: zBaseModel,
});
/**
* Type alias for model parameter, inferred from its zod schema
*/
export type ControlNetModelParam = z.infer<typeof zLoRAModel>;
/**
* Validates/type-guards a value as a model parameter
*/
export const isValidControlNetModel = (
val: unknown
): val is ControlNetModelParam => zControlNetModel.safeParse(val).success;
/** /**
* Zod schema for l2l strength parameter * Zod schema for l2l strength parameter

View File

@ -0,0 +1,30 @@
import { log } from 'app/logging/useLogger';
import { zControlNetModel } from 'features/parameters/types/parameterSchemas';
import { ControlNetModelField } from 'services/api/types';
const moduleLog = log.child({ module: 'models' });
export const modelIdToControlNetModelParam = (
controlNetModelId: string
): ControlNetModelField | undefined => {
const [base_model, model_type, model_name] = controlNetModelId.split('/');
const result = zControlNetModel.safeParse({
base_model,
model_name,
});
if (!result.success) {
moduleLog.error(
{
controlNetModelId,
errors: result.error.format(),
},
'Failed to parse ControlNet model id'
);
return;
}
return result.data;
};

View File

@ -1,9 +1,12 @@
import { LoRAModelParam, zLoRAModel } from '../types/parameterSchemas'; import { LoRAModelParam, zLoRAModel } from '../types/parameterSchemas';
import { log } from 'app/logging/useLogger';
const moduleLog = log.child({ module: 'models' });
export const modelIdToLoRAModelParam = ( export const modelIdToLoRAModelParam = (
loraId: string loraModelId: string
): LoRAModelParam | undefined => { ): LoRAModelParam | undefined => {
const [base_model, model_type, model_name] = loraId.split('/'); const [base_model, model_type, model_name] = loraModelId.split('/');
const result = zLoRAModel.safeParse({ const result = zLoRAModel.safeParse({
base_model, base_model,
@ -11,6 +14,13 @@ export const modelIdToLoRAModelParam = (
}); });
if (!result.success) { if (!result.success) {
moduleLog.error(
{
loraModelId,
errors: result.error.format(),
},
'Failed to parse LoRA model id'
);
return; return;
} }

View File

@ -2,11 +2,14 @@ import {
MainModelParam, MainModelParam,
zMainModel, zMainModel,
} from 'features/parameters/types/parameterSchemas'; } from 'features/parameters/types/parameterSchemas';
import { log } from 'app/logging/useLogger';
const moduleLog = log.child({ module: 'models' });
export const modelIdToMainModelParam = ( export const modelIdToMainModelParam = (
modelId: string mainModelId: string
): MainModelParam | undefined => { ): MainModelParam | undefined => {
const [base_model, model_type, model_name] = modelId.split('/'); const [base_model, model_type, model_name] = mainModelId.split('/');
const result = zMainModel.safeParse({ const result = zMainModel.safeParse({
base_model, base_model,
@ -14,6 +17,13 @@ export const modelIdToMainModelParam = (
}); });
if (!result.success) { if (!result.success) {
moduleLog.error(
{
mainModelId,
errors: result.error.format(),
},
'Failed to parse main model id'
);
return; return;
} }

View File

@ -1,9 +1,12 @@
import { VaeModelParam, zVaeModel } from '../types/parameterSchemas'; import { VaeModelParam, zVaeModel } from '../types/parameterSchemas';
import { log } from 'app/logging/useLogger';
const moduleLog = log.child({ module: 'models' });
export const modelIdToVAEModelParam = ( export const modelIdToVAEModelParam = (
modelId: string vaeModelId: string
): VaeModelParam | undefined => { ): VaeModelParam | undefined => {
const [base_model, model_type, model_name] = modelId.split('/'); const [base_model, model_type, model_name] = vaeModelId.split('/');
const result = zVaeModel.safeParse({ const result = zVaeModel.safeParse({
base_model, base_model,
@ -11,6 +14,13 @@ export const modelIdToVAEModelParam = (
}); });
if (!result.success) { if (!result.success) {
moduleLog.error(
{
vaeModelId,
errors: result.error.format(),
},
'Failed to parse VAE model id'
);
return; return;
} }

View File

@ -19,9 +19,9 @@ const ImageToImageTabParameters = () => {
<ParamNegativeConditioning /> <ParamNegativeConditioning />
<ProcessButtons /> <ProcessButtons />
<ImageToImageTabCoreParameters /> <ImageToImageTabCoreParameters />
<ParamControlNetCollapse />
<ParamLoraCollapse /> <ParamLoraCollapse />
<ParamDynamicPromptsCollapse /> <ParamDynamicPromptsCollapse />
<ParamControlNetCollapse />
<ParamVariationCollapse /> <ParamVariationCollapse />
<ParamNoiseCollapse /> <ParamNoiseCollapse />
<ParamSymmetryCollapse /> <ParamSymmetryCollapse />

View File

@ -20,9 +20,9 @@ const TextToImageTabParameters = () => {
<ParamNegativeConditioning /> <ParamNegativeConditioning />
<ProcessButtons /> <ProcessButtons />
<TextToImageTabCoreParameters /> <TextToImageTabCoreParameters />
<ParamControlNetCollapse />
<ParamLoraCollapse /> <ParamLoraCollapse />
<ParamDynamicPromptsCollapse /> <ParamDynamicPromptsCollapse />
<ParamControlNetCollapse />
<ParamVariationCollapse /> <ParamVariationCollapse />
<ParamNoiseCollapse /> <ParamNoiseCollapse />
<ParamSymmetryCollapse /> <ParamSymmetryCollapse />

View File

@ -19,9 +19,9 @@ const UnifiedCanvasParameters = () => {
<ParamNegativeConditioning /> <ParamNegativeConditioning />
<ProcessButtons /> <ProcessButtons />
<UnifiedCanvasCoreParameters /> <UnifiedCanvasCoreParameters />
<ParamControlNetCollapse />
<ParamLoraCollapse /> <ParamLoraCollapse />
<ParamDynamicPromptsCollapse /> <ParamDynamicPromptsCollapse />
<ParamControlNetCollapse />
<ParamVariationCollapse /> <ParamVariationCollapse />
<ParamSymmetryCollapse /> <ParamSymmetryCollapse />
<ParamSeamCorrectionCollapse /> <ParamSeamCorrectionCollapse />

View File

@ -734,7 +734,7 @@ export type components = {
* Control Model * Control Model
* @description The ControlNet model to use * @description The ControlNet model to use
*/ */
control_model: string; control_model: components["schemas"]["ControlNetModelField"];
/** /**
* Control Weight * Control Weight
* @description The weight given to the ControlNet * @description The weight given to the ControlNet
@ -792,9 +792,8 @@ export type components = {
* Control Model * Control Model
* @description control model used * @description control model used
* @default lllyasviel/sd-controlnet-canny * @default lllyasviel/sd-controlnet-canny
* @enum {string}
*/ */
control_model?: "lllyasviel/sd-controlnet-canny" | "lllyasviel/sd-controlnet-depth" | "lllyasviel/sd-controlnet-hed" | "lllyasviel/sd-controlnet-seg" | "lllyasviel/sd-controlnet-openpose" | "lllyasviel/sd-controlnet-scribble" | "lllyasviel/sd-controlnet-normal" | "lllyasviel/sd-controlnet-mlsd" | "lllyasviel/control_v11p_sd15_canny" | "lllyasviel/control_v11p_sd15_openpose" | "lllyasviel/control_v11p_sd15_seg" | "lllyasviel/control_v11f1p_sd15_depth" | "lllyasviel/control_v11p_sd15_normalbae" | "lllyasviel/control_v11p_sd15_scribble" | "lllyasviel/control_v11p_sd15_mlsd" | "lllyasviel/control_v11p_sd15_softedge" | "lllyasviel/control_v11p_sd15s2_lineart_anime" | "lllyasviel/control_v11p_sd15_lineart" | "lllyasviel/control_v11p_sd15_inpaint" | "lllyasviel/control_v11e_sd15_shuffle" | "lllyasviel/control_v11e_sd15_ip2p" | "lllyasviel/control_v11f1e_sd15_tile" | "thibaud/controlnet-sd21-openpose-diffusers" | "thibaud/controlnet-sd21-canny-diffusers" | "thibaud/controlnet-sd21-depth-diffusers" | "thibaud/controlnet-sd21-scribble-diffusers" | "thibaud/controlnet-sd21-hed-diffusers" | "thibaud/controlnet-sd21-zoedepth-diffusers" | "thibaud/controlnet-sd21-color-diffusers" | "thibaud/controlnet-sd21-openposev2-diffusers" | "thibaud/controlnet-sd21-lineart-diffusers" | "thibaud/controlnet-sd21-normalbae-diffusers" | "thibaud/controlnet-sd21-ade20k-diffusers" | "CrucibleAI/ControlNetMediaPipeFace,diffusion_sd15" | "CrucibleAI/ControlNetMediaPipeFace"; control_model?: components["schemas"]["ControlNetModelField"];
/** /**
* Control Weight * Control Weight
* @description The weight given to the ControlNet * @description The weight given to the ControlNet
@ -838,6 +837,19 @@ export type components = {
model_format: components["schemas"]["ControlNetModelFormat"]; model_format: components["schemas"]["ControlNetModelFormat"];
error?: components["schemas"]["ModelError"]; error?: components["schemas"]["ModelError"];
}; };
/**
* ControlNetModelField
* @description ControlNet model field
*/
ControlNetModelField: {
/**
* Model Name
* @description Name of the ControlNet model
*/
model_name: string;
/** @description Base model */
base_model: components["schemas"]["BaseModelType"];
};
/** /**
* ControlNetModelFormat * ControlNetModelFormat
* @description An enumeration. * @description An enumeration.
@ -1923,12 +1935,12 @@ export type components = {
* Width * Width
* @description The width to resize to (px) * @description The width to resize to (px)
*/ */
width: number; width?: number;
/** /**
* Height * Height
* @description The height to resize to (px) * @description The height to resize to (px)
*/ */
height: number; height?: number;
/** /**
* Resample Mode * Resample Mode
* @description The resampling mode * @description The resampling mode
@ -3911,13 +3923,15 @@ export type components = {
/** /**
* Width * Width
* @description The width to resize to (px) * @description The width to resize to (px)
* @default 512
*/ */
width: number; width?: number;
/** /**
* Height * Height
* @description The height to resize to (px) * @description The height to resize to (px)
* @default 512
*/ */
height: number; height?: number;
/** /**
* Mode * Mode
* @description The interpolation mode * @description The interpolation mode
@ -4605,18 +4619,18 @@ export type components = {
*/ */
image?: components["schemas"]["ImageField"]; image?: components["schemas"]["ImageField"];
}; };
/**
* StableDiffusion2ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
/** /**
* StableDiffusion1ModelFormat * StableDiffusion1ModelFormat
* @description An enumeration. * @description An enumeration.
* @enum {string} * @enum {string}
*/ */
StableDiffusion1ModelFormat: "checkpoint" | "diffusers"; StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusion2ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
}; };
responses: never; responses: never;
parameters: never; parameters: never;

View File

@ -32,6 +32,8 @@ export type BaseModelType = components['schemas']['BaseModelType'];
export type MainModelField = components['schemas']['MainModelField']; export type MainModelField = components['schemas']['MainModelField'];
export type VAEModelField = components['schemas']['VAEModelField']; export type VAEModelField = components['schemas']['VAEModelField'];
export type LoRAModelField = components['schemas']['LoRAModelField']; export type LoRAModelField = components['schemas']['LoRAModelField'];
export type ControlNetModelField =
components['schemas']['ControlNetModelField'];
export type ModelsList = components['schemas']['ModelsList']; export type ModelsList = components['schemas']['ModelsList'];
export type ControlField = components['schemas']['ControlField']; export type ControlField = components['schemas']['ControlField'];

View File

@ -30,7 +30,7 @@ const invokeAIThumb = defineStyle((props) => {
const invokeAIMark = defineStyle((props) => { const invokeAIMark = defineStyle((props) => {
return { return {
fontSize: 'xs', fontSize: '2xs',
fontWeight: '500', fontWeight: '500',
color: mode('base.700', 'base.400')(props), color: mode('base.700', 'base.400')(props),
mt: 2, mt: 2,