mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into lstein/default-model-install
This commit is contained in:
commit
32e7e52d69
@ -1,6 +1,7 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2024 Lincoln Stein
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2023 Lincoln D. Stein
|
||||
|
||||
|
||||
import pathlib
|
||||
from typing import Literal, List, Optional, Union
|
||||
|
||||
from fastapi import Body, Path, Query, Response
|
||||
@ -22,6 +23,7 @@ UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
|
||||
class ModelsList(BaseModel):
|
||||
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
|
||||
@ -78,7 +80,7 @@ async def update_model(
|
||||
return model_response
|
||||
|
||||
@models_router.post(
|
||||
"/",
|
||||
"/import",
|
||||
operation_id="import_model",
|
||||
responses= {
|
||||
201: {"description" : "The model imported successfully"},
|
||||
@ -94,7 +96,7 @@ async def import_model(
|
||||
prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = \
|
||||
Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
|
||||
) -> ImportModelResponse:
|
||||
""" Add a model using its local path, repo_id, or remote URL """
|
||||
""" Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically """
|
||||
|
||||
items_to_import = {location}
|
||||
prediction_types = { x.value: x for x in SchedulerPredictionType }
|
||||
@ -126,18 +128,100 @@ async def import_model(
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
|
||||
@models_router.post(
|
||||
"/add",
|
||||
operation_id="add_model",
|
||||
responses= {
|
||||
201: {"description" : "The model added successfully"},
|
||||
404: {"description" : "The model could not be found"},
|
||||
424: {"description" : "The model appeared to add successfully, but could not be found in the model manager"},
|
||||
409: {"description" : "There is already a model corresponding to this path or repo_id"},
|
||||
},
|
||||
status_code=201,
|
||||
response_model=ImportModelResponse
|
||||
)
|
||||
async def add_model(
|
||||
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
|
||||
) -> ImportModelResponse:
|
||||
""" Add a model using the configuration information appropriate for its type. Only local models can be added by path"""
|
||||
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
try:
|
||||
ApiDependencies.invoker.services.model_manager.add_model(
|
||||
info.model_name,
|
||||
info.base_model,
|
||||
info.model_type,
|
||||
model_attributes = info.dict()
|
||||
)
|
||||
logger.info(f'Successfully added {info.model_name}')
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
model_name=info.model_name,
|
||||
base_model=info.base_model,
|
||||
model_type=info.model_type
|
||||
)
|
||||
return parse_obj_as(ImportModelResponse, model_raw)
|
||||
except KeyError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except ValueError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
|
||||
@models_router.post(
|
||||
"/rename/{base_model}/{model_type}/{model_name}",
|
||||
operation_id="rename_model",
|
||||
responses= {
|
||||
201: {"description" : "The model was renamed successfully"},
|
||||
404: {"description" : "The model could not be found"},
|
||||
409: {"description" : "There is already a model corresponding to the new name"},
|
||||
},
|
||||
status_code=201,
|
||||
response_model=ImportModelResponse
|
||||
)
|
||||
async def rename_model(
|
||||
base_model: BaseModelType = Path(description="Base model"),
|
||||
model_type: ModelType = Path(description="The type of model"),
|
||||
model_name: str = Path(description="current model name"),
|
||||
new_name: Optional[str] = Query(description="new model name", default=None),
|
||||
new_base: Optional[BaseModelType] = Query(description="new model base", default=None),
|
||||
) -> ImportModelResponse:
|
||||
""" Rename a model"""
|
||||
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
try:
|
||||
result = ApiDependencies.invoker.services.model_manager.rename_model(
|
||||
base_model = base_model,
|
||||
model_type = model_type,
|
||||
model_name = model_name,
|
||||
new_name = new_name,
|
||||
new_base = new_base,
|
||||
)
|
||||
logger.debug(result)
|
||||
logger.info(f'Successfully renamed {model_name}=>{new_name}')
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
model_name=new_name or model_name,
|
||||
base_model=new_base or base_model,
|
||||
model_type=model_type
|
||||
)
|
||||
return parse_obj_as(ImportModelResponse, model_raw)
|
||||
except KeyError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except ValueError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
|
||||
@models_router.delete(
|
||||
"/{base_model}/{model_type}/{model_name}",
|
||||
operation_id="del_model",
|
||||
responses={
|
||||
204: {
|
||||
"description": "Model deleted successfully"
|
||||
},
|
||||
404: {
|
||||
"description": "Model not found"
|
||||
}
|
||||
204: { "description": "Model deleted successfully" },
|
||||
404: { "description": "Model not found" }
|
||||
},
|
||||
status_code = 204,
|
||||
response_model = None,
|
||||
)
|
||||
async def delete_model(
|
||||
base_model: BaseModelType = Path(description="Base model"),
|
||||
@ -173,14 +257,17 @@ async def convert_model(
|
||||
base_model: BaseModelType = Path(description="Base model"),
|
||||
model_type: ModelType = Path(description="The type of model"),
|
||||
model_name: str = Path(description="model name"),
|
||||
convert_dest_directory: Optional[str] = Query(default=None, description="Save the converted model to the designated directory"),
|
||||
) -> ConvertModelResponse:
|
||||
"""Convert a checkpoint model into a diffusers model"""
|
||||
"""Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
try:
|
||||
logger.info(f"Converting model: {model_name}")
|
||||
dest = pathlib.Path(convert_dest_directory) if convert_dest_directory else None
|
||||
ApiDependencies.invoker.services.model_manager.convert_model(model_name,
|
||||
base_model = base_model,
|
||||
model_type = model_type
|
||||
model_type = model_type,
|
||||
convert_dest_directory = dest,
|
||||
)
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(model_name,
|
||||
base_model = base_model,
|
||||
@ -191,6 +278,53 @@ async def convert_model(
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
return response
|
||||
|
||||
@models_router.get(
|
||||
"/search",
|
||||
operation_id="search_for_models",
|
||||
responses={
|
||||
200: { "description": "Directory searched successfully" },
|
||||
404: { "description": "Invalid directory path" },
|
||||
},
|
||||
status_code = 200,
|
||||
response_model = List[pathlib.Path]
|
||||
)
|
||||
async def search_for_models(
|
||||
search_path: pathlib.Path = Query(description="Directory path to search for models")
|
||||
)->List[pathlib.Path]:
|
||||
if not search_path.is_dir():
|
||||
raise HTTPException(status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory")
|
||||
return ApiDependencies.invoker.services.model_manager.search_for_models([search_path])
|
||||
|
||||
@models_router.get(
|
||||
"/ckpt_confs",
|
||||
operation_id="list_ckpt_configs",
|
||||
responses={
|
||||
200: { "description" : "paths retrieved successfully" },
|
||||
},
|
||||
status_code = 200,
|
||||
response_model = List[pathlib.Path]
|
||||
)
|
||||
async def list_ckpt_configs(
|
||||
)->List[pathlib.Path]:
|
||||
"""Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT."""
|
||||
return ApiDependencies.invoker.services.model_manager.list_checkpoint_configs()
|
||||
|
||||
|
||||
@models_router.get(
|
||||
"/sync",
|
||||
operation_id="sync_to_config",
|
||||
responses={
|
||||
201: { "description": "synchronization successful" },
|
||||
},
|
||||
status_code = 201,
|
||||
response_model = None
|
||||
)
|
||||
async def sync_to_config(
|
||||
)->None:
|
||||
"""Call after making changes to models.yaml, autoimport directories or models directory to synchronize
|
||||
in-memory data structures with disk data structures."""
|
||||
return ApiDependencies.invoker.services.model_manager.sync_to_config()
|
||||
|
||||
@models_router.put(
|
||||
"/merge/{base_model}",
|
||||
@ -210,17 +344,21 @@ async def merge_models(
|
||||
alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
|
||||
interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"),
|
||||
force: Optional[bool] = Body(description="Force merging of models created with different versions of diffusers", default=False),
|
||||
merge_dest_directory: Optional[str] = Body(description="Save the merged model to the designated directory (with 'merged_model_name' appended)", default=None)
|
||||
) -> MergeModelResponse:
|
||||
"""Convert a checkpoint model into a diffusers model"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
try:
|
||||
logger.info(f"Merging models: {model_names}")
|
||||
logger.info(f"Merging models: {model_names} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
|
||||
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
|
||||
result = ApiDependencies.invoker.services.model_manager.merge_models(model_names,
|
||||
base_model,
|
||||
merged_model_name or "+".join(model_names),
|
||||
alpha,
|
||||
interp,
|
||||
force)
|
||||
merged_model_name=merged_model_name or "+".join(model_names),
|
||||
alpha=alpha,
|
||||
interp=interp,
|
||||
force=force,
|
||||
merge_dest_directory = dest
|
||||
)
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(result.name,
|
||||
base_model = base_model,
|
||||
model_type = ModelType.Main,
|
||||
|
@ -100,7 +100,7 @@ class CompelInvocation(BaseInvocation):
|
||||
text_encoder=text_encoder,
|
||||
textual_inversion_manager=ti_manager,
|
||||
dtype_for_device_getter=torch_dtype,
|
||||
truncate_long_prompts=True, # TODO:
|
||||
truncate_long_prompts=False,
|
||||
)
|
||||
|
||||
conjunction = Compel.parse_prompt_string(self.prompt)
|
||||
@ -112,9 +112,6 @@ class CompelInvocation(BaseInvocation):
|
||||
c, options = compel.build_conditioning_tensor_for_prompt_object(
|
||||
prompt)
|
||||
|
||||
# TODO: long prompt support
|
||||
# if not self.truncate_long_prompts:
|
||||
# [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
|
||||
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
|
||||
tokens_count_including_eos_bos=get_max_token_count(
|
||||
tokenizer, conjunction),
|
||||
|
@ -9,6 +9,7 @@ from typing import Literal, Optional, Union, List, Dict
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field, validator
|
||||
|
||||
from ...backend.model_management import BaseModelType, ModelType
|
||||
from ..models.image import ImageField, ImageCategory, ResourceOrigin
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
@ -105,9 +106,15 @@ CONTROLNET_MODE_VALUES = Literal[tuple(["balanced", "more_prompt", "more_control
|
||||
# CONTROLNET_RESIZE_VALUES = Literal[tuple(["just_resize", "crop_resize", "fill_resize"])]
|
||||
|
||||
|
||||
class ControlNetModelField(BaseModel):
|
||||
"""ControlNet model field"""
|
||||
|
||||
model_name: str = Field(description="Name of the ControlNet model")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
|
||||
class ControlField(BaseModel):
|
||||
image: ImageField = Field(default=None, description="The control image")
|
||||
control_model: Optional[str] = Field(default=None, description="The ControlNet model to use")
|
||||
control_model: Optional[ControlNetModelField] = Field(default=None, description="The ControlNet model to use")
|
||||
# control_weight: Optional[float] = Field(default=1, description="weight given to controlnet")
|
||||
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
|
||||
begin_step_percent: float = Field(default=0, ge=0, le=1,
|
||||
@ -118,15 +125,15 @@ class ControlField(BaseModel):
|
||||
# resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
|
||||
|
||||
@validator("control_weight")
|
||||
def abs_le_one(cls, v):
|
||||
"""validate that all abs(values) are <=1"""
|
||||
def validate_control_weight(cls, v):
|
||||
"""Validate that all control weights in the valid range"""
|
||||
if isinstance(v, list):
|
||||
for i in v:
|
||||
if abs(i) > 1:
|
||||
raise ValueError('all abs(control_weight) must be <= 1')
|
||||
if i < -1 or i > 2:
|
||||
raise ValueError('Control weights must be within -1 to 2 range')
|
||||
else:
|
||||
if abs(v) > 1:
|
||||
raise ValueError('abs(control_weight) must be <= 1')
|
||||
if v < -1 or v > 2:
|
||||
raise ValueError('Control weights must be within -1 to 2 range')
|
||||
return v
|
||||
class Config:
|
||||
schema_extra = {
|
||||
@ -134,6 +141,7 @@ class ControlField(BaseModel):
|
||||
"ui": {
|
||||
"type_hints": {
|
||||
"control_weight": "float",
|
||||
"control_model": "controlnet_model",
|
||||
# "control_weight": "number",
|
||||
}
|
||||
}
|
||||
@ -154,10 +162,10 @@ class ControlNetInvocation(BaseInvocation):
|
||||
type: Literal["controlnet"] = "controlnet"
|
||||
# Inputs
|
||||
image: ImageField = Field(default=None, description="The control image")
|
||||
control_model: CONTROLNET_NAME_VALUES = Field(default="lllyasviel/sd-controlnet-canny",
|
||||
control_model: ControlNetModelField = Field(default="lllyasviel/sd-controlnet-canny",
|
||||
description="control model used")
|
||||
control_weight: Union[float, List[float]] = Field(default=1.0, description="The weight given to the ControlNet")
|
||||
begin_step_percent: float = Field(default=0, ge=0, le=1,
|
||||
begin_step_percent: float = Field(default=0, ge=-1, le=2,
|
||||
description="When the ControlNet is first applied (% of total steps)")
|
||||
end_step_percent: float = Field(default=1, ge=0, le=1,
|
||||
description="When the ControlNet is last applied (% of total steps)")
|
||||
|
@ -1,5 +1,6 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from contextlib import ExitStack
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
import einops
|
||||
@ -11,6 +12,7 @@ from pydantic import BaseModel, Field, validator
|
||||
|
||||
from invokeai.app.invocations.metadata import CoreMetadata
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
from invokeai.backend.model_management.models.base import ModelType
|
||||
|
||||
from ...backend.model_management.lora import ModelPatcher
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
@ -71,16 +73,21 @@ def get_scheduler(
|
||||
scheduler_name: str,
|
||||
) -> Scheduler:
|
||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(
|
||||
scheduler_name, SCHEDULER_MAP['ddim'])
|
||||
scheduler_name, SCHEDULER_MAP['ddim']
|
||||
)
|
||||
orig_scheduler_info = context.services.model_manager.get_model(
|
||||
**scheduler_info.dict())
|
||||
**scheduler_info.dict()
|
||||
)
|
||||
with orig_scheduler_info as orig_scheduler:
|
||||
scheduler_config = orig_scheduler.config
|
||||
|
||||
if "_backup" in scheduler_config:
|
||||
scheduler_config = scheduler_config["_backup"]
|
||||
scheduler_config = {**scheduler_config, **
|
||||
scheduler_extra_config, "_backup": scheduler_config}
|
||||
scheduler_config = {
|
||||
**scheduler_config,
|
||||
**scheduler_extra_config,
|
||||
"_backup": scheduler_config,
|
||||
}
|
||||
scheduler = scheduler_class.from_config(scheduler_config)
|
||||
|
||||
# hack copied over from generate.py
|
||||
@ -137,8 +144,11 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
|
||||
# TODO: pass this an emitter method or something? or a session for dispatching?
|
||||
def dispatch_progress(
|
||||
self, context: InvocationContext, source_node_id: str,
|
||||
intermediate_state: PipelineIntermediateState) -> None:
|
||||
self,
|
||||
context: InvocationContext,
|
||||
source_node_id: str,
|
||||
intermediate_state: PipelineIntermediateState,
|
||||
) -> None:
|
||||
stable_diffusion_step_callback(
|
||||
context=context,
|
||||
intermediate_state=intermediate_state,
|
||||
@ -147,11 +157,16 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
def get_conditioning_data(
|
||||
self, context: InvocationContext, scheduler) -> ConditioningData:
|
||||
self,
|
||||
context: InvocationContext,
|
||||
scheduler,
|
||||
) -> ConditioningData:
|
||||
c, extra_conditioning_info = context.services.latents.get(
|
||||
self.positive_conditioning.conditioning_name)
|
||||
self.positive_conditioning.conditioning_name
|
||||
)
|
||||
uc, _ = context.services.latents.get(
|
||||
self.negative_conditioning.conditioning_name)
|
||||
self.negative_conditioning.conditioning_name
|
||||
)
|
||||
|
||||
conditioning_data = ConditioningData(
|
||||
unconditioned_embeddings=uc,
|
||||
@ -178,7 +193,10 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
return conditioning_data
|
||||
|
||||
def create_pipeline(
|
||||
self, unet, scheduler) -> StableDiffusionGeneratorPipeline:
|
||||
self,
|
||||
unet,
|
||||
scheduler,
|
||||
) -> StableDiffusionGeneratorPipeline:
|
||||
# TODO:
|
||||
# configure_model_padding(
|
||||
# unet,
|
||||
@ -213,6 +231,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
model: StableDiffusionGeneratorPipeline,
|
||||
control_input: List[ControlField],
|
||||
latents_shape: List[int],
|
||||
exit_stack: ExitStack,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
) -> List[ControlNetData]:
|
||||
|
||||
@ -238,25 +257,19 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
control_data = []
|
||||
control_models = []
|
||||
for control_info in control_list:
|
||||
# handle control models
|
||||
if ("," in control_info.control_model):
|
||||
control_model_split = control_info.control_model.split(",")
|
||||
control_name = control_model_split[0]
|
||||
control_subfolder = control_model_split[1]
|
||||
print("Using HF model subfolders")
|
||||
print(" control_name: ", control_name)
|
||||
print(" control_subfolder: ", control_subfolder)
|
||||
control_model = ControlNetModel.from_pretrained(
|
||||
control_name, subfolder=control_subfolder,
|
||||
torch_dtype=model.unet.dtype).to(
|
||||
model.device)
|
||||
else:
|
||||
control_model = ControlNetModel.from_pretrained(
|
||||
control_info.control_model, torch_dtype=model.unet.dtype).to(model.device)
|
||||
control_model = exit_stack.enter_context(
|
||||
context.services.model_manager.get_model(
|
||||
model_name=control_info.control_model.model_name,
|
||||
model_type=ModelType.ControlNet,
|
||||
base_model=control_info.control_model.base_model,
|
||||
)
|
||||
)
|
||||
|
||||
control_models.append(control_model)
|
||||
control_image_field = control_info.image
|
||||
input_image = context.services.images.get_pil_image(
|
||||
control_image_field.image_name)
|
||||
control_image_field.image_name
|
||||
)
|
||||
# self.image.image_type, self.image.image_name
|
||||
# FIXME: still need to test with different widths, heights, devices, dtypes
|
||||
# and add in batch_size, num_images_per_prompt?
|
||||
@ -278,7 +291,8 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
weight=control_info.control_weight,
|
||||
begin_step_percent=control_info.begin_step_percent,
|
||||
end_step_percent=control_info.end_step_percent,
|
||||
control_mode=control_info.control_mode,)
|
||||
control_mode=control_info.control_mode,
|
||||
)
|
||||
control_data.append(control_item)
|
||||
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
|
||||
return control_data
|
||||
@ -289,7 +303,8 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
|
||||
# Get the source node id (we are invoking the prepared node)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(
|
||||
context.graph_execution_state_id)
|
||||
context.graph_execution_state_id
|
||||
)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
|
||||
def step_callback(state: PipelineIntermediateState):
|
||||
@ -298,14 +313,17 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
def _lora_loader():
|
||||
for lora in self.unet.loras:
|
||||
lora_info = context.services.model_manager.get_model(
|
||||
**lora.dict(exclude={"weight"}))
|
||||
**lora.dict(exclude={"weight"})
|
||||
)
|
||||
yield (lora_info.context.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
|
||||
unet_info = context.services.model_manager.get_model(
|
||||
**self.unet.unet.dict())
|
||||
with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
|
||||
**self.unet.unet.dict()
|
||||
)
|
||||
with ExitStack() as exit_stack,\
|
||||
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
|
||||
unet_info as unet:
|
||||
|
||||
scheduler = get_scheduler(
|
||||
@ -322,6 +340,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
latents_shape=noise.shape,
|
||||
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
||||
do_classifier_free_guidance=True,
|
||||
exit_stack=exit_stack,
|
||||
)
|
||||
|
||||
# TODO: Verify the noise is the right size
|
||||
@ -374,7 +393,8 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
|
||||
# Get the source node id (we are invoking the prepared node)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(
|
||||
context.graph_execution_state_id)
|
||||
context.graph_execution_state_id
|
||||
)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
|
||||
def step_callback(state: PipelineIntermediateState):
|
||||
@ -383,14 +403,17 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
def _lora_loader():
|
||||
for lora in self.unet.loras:
|
||||
lora_info = context.services.model_manager.get_model(
|
||||
**lora.dict(exclude={"weight"}))
|
||||
**lora.dict(exclude={"weight"})
|
||||
)
|
||||
yield (lora_info.context.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
|
||||
unet_info = context.services.model_manager.get_model(
|
||||
**self.unet.unet.dict())
|
||||
with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
|
||||
**self.unet.unet.dict()
|
||||
)
|
||||
with ExitStack() as exit_stack,\
|
||||
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
|
||||
unet_info as unet:
|
||||
|
||||
scheduler = get_scheduler(
|
||||
@ -407,11 +430,13 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
latents_shape=noise.shape,
|
||||
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
||||
do_classifier_free_guidance=True,
|
||||
exit_stack=exit_stack,
|
||||
)
|
||||
|
||||
# TODO: Verify the noise is the right size
|
||||
initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
|
||||
latent, device=unet.device, dtype=latent.dtype)
|
||||
latent, device=unet.device, dtype=latent.dtype
|
||||
)
|
||||
|
||||
timesteps, _ = pipeline.get_img2img_timesteps(
|
||||
self.steps,
|
||||
@ -535,7 +560,8 @@ class ResizeLatentsInvocation(BaseInvocation):
|
||||
resized_latents = torch.nn.functional.interpolate(
|
||||
latents, size=(self.height // 8, self.width // 8),
|
||||
mode=self.mode, antialias=self.antialias
|
||||
if self.mode in ["bilinear", "bicubic"] else False,)
|
||||
if self.mode in ["bilinear", "bicubic"] else False,
|
||||
)
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
torch.cuda.empty_cache()
|
||||
@ -569,7 +595,8 @@ class ScaleLatentsInvocation(BaseInvocation):
|
||||
resized_latents = torch.nn.functional.interpolate(
|
||||
latents, scale_factor=self.scale_factor, mode=self.mode,
|
||||
antialias=self.antialias
|
||||
if self.mode in ["bilinear", "bicubic"] else False,)
|
||||
if self.mode in ["bilinear", "bicubic"] else False,
|
||||
)
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
torch.cuda.empty_cache()
|
||||
|
@ -19,7 +19,7 @@ from invokeai.backend.model_management import (
|
||||
ModelMerger,
|
||||
MergeInterpolationMethod,
|
||||
)
|
||||
|
||||
from invokeai.backend.model_management.model_search import FindModels
|
||||
|
||||
import torch
|
||||
from invokeai.app.models.exceptions import CanceledException
|
||||
@ -167,6 +167,27 @@ class ModelManagerServiceBase(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def rename_model(self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
new_name: str,
|
||||
):
|
||||
"""
|
||||
Rename the indicated model.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_checkpoint_configs(
|
||||
self
|
||||
)->List[Path]:
|
||||
"""
|
||||
List the checkpoint config paths from ROOT/configs/stable-diffusion.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def convert_model(
|
||||
self,
|
||||
@ -220,6 +241,7 @@ class ModelManagerServiceBase(ABC):
|
||||
alpha: Optional[float] = 0.5,
|
||||
interp: Optional[MergeInterpolationMethod] = None,
|
||||
force: Optional[bool] = False,
|
||||
merge_dest_directory: Optional[Path] = None
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Merge two to three diffusrs pipeline models and save as a new model.
|
||||
@ -228,9 +250,26 @@ class ModelManagerServiceBase(ABC):
|
||||
:param merged_model_name: Name of destination merged model
|
||||
:param alpha: Alpha strength to apply to 2d and 3d model
|
||||
:param interp: Interpolation method. None (default)
|
||||
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def search_for_models(self, directory: Path)->List[Path]:
|
||||
"""
|
||||
Return list of all models found in the designated directory.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def sync_to_config(self):
|
||||
"""
|
||||
Re-read models.yaml, rescan the models directory, and reimport models
|
||||
in the autoimport directories. Call after making changes outside the
|
||||
model manager API.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def commit(self, conf_file: Optional[Path] = None) -> None:
|
||||
"""
|
||||
@ -431,16 +470,18 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
"""
|
||||
Delete the named model from configuration. If delete_files is true,
|
||||
then the underlying weight file or diffusers directory will be deleted
|
||||
as well. Call commit() to write to disk.
|
||||
as well.
|
||||
"""
|
||||
self.logger.debug(f'delete model {model_name}')
|
||||
self.mgr.del_model(model_name, base_model, model_type)
|
||||
self.mgr.commit()
|
||||
|
||||
def convert_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: Union[ModelType.Main,ModelType.Vae],
|
||||
convert_dest_directory: Optional[Path] = Field(default=None, description="Optional directory location for merged model"),
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Convert a checkpoint file into a diffusers folder, deleting the cached
|
||||
@ -449,13 +490,14 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
:param model_name: Name of the model to convert
|
||||
:param base_model: Base model type
|
||||
:param model_type: Type of model ['vae' or 'main']
|
||||
:param convert_dest_directory: Save the converted model to the designated directory (`models/etc/etc` by default)
|
||||
|
||||
This will raise a ValueError unless the model is not a checkpoint. It will
|
||||
also raise a ValueError in the event that there is a similarly-named diffusers
|
||||
directory already in place.
|
||||
"""
|
||||
self.logger.debug(f'convert model {model_name}')
|
||||
return self.mgr.convert_model(model_name, base_model, model_type)
|
||||
return self.mgr.convert_model(model_name, base_model, model_type, convert_dest_directory)
|
||||
|
||||
def commit(self, conf_file: Optional[Path]=None):
|
||||
"""
|
||||
@ -536,6 +578,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
alpha: Optional[float] = 0.5,
|
||||
interp: Optional[MergeInterpolationMethod] = None,
|
||||
force: Optional[bool] = False,
|
||||
merge_dest_directory: Optional[Path] = Field(default=None, description="Optional directory location for merged model"),
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Merge two to three diffusrs pipeline models and save as a new model.
|
||||
@ -544,6 +587,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
:param merged_model_name: Name of destination merged model
|
||||
:param alpha: Alpha strength to apply to 2d and 3d model
|
||||
:param interp: Interpolation method. None (default)
|
||||
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
|
||||
"""
|
||||
merger = ModelMerger(self.mgr)
|
||||
try:
|
||||
@ -554,7 +598,55 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
alpha = alpha,
|
||||
interp = interp,
|
||||
force = force,
|
||||
merge_dest_directory=merge_dest_directory,
|
||||
)
|
||||
except AssertionError as e:
|
||||
raise ValueError(e)
|
||||
return result
|
||||
|
||||
def search_for_models(self, directory: Path)->List[Path]:
|
||||
"""
|
||||
Return list of all models found in the designated directory.
|
||||
"""
|
||||
search = FindModels(directory,self.logger)
|
||||
return search.list_models()
|
||||
|
||||
def sync_to_config(self):
|
||||
"""
|
||||
Re-read models.yaml, rescan the models directory, and reimport models
|
||||
in the autoimport directories. Call after making changes outside the
|
||||
model manager API.
|
||||
"""
|
||||
return self.mgr.sync_to_config()
|
||||
|
||||
def list_checkpoint_configs(self)->List[Path]:
|
||||
"""
|
||||
List the checkpoint config paths from ROOT/configs/stable-diffusion.
|
||||
"""
|
||||
config = self.mgr.app_config
|
||||
conf_path = config.legacy_conf_path
|
||||
root_path = config.root_path
|
||||
return [(conf_path / x).relative_to(root_path) for x in conf_path.glob('**/*.yaml')]
|
||||
|
||||
def rename_model(self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
new_name: str = None,
|
||||
new_base: BaseModelType = None,
|
||||
):
|
||||
"""
|
||||
Rename the indicated model. Can provide a new name and/or a new base.
|
||||
:param model_name: Current name of the model
|
||||
:param base_model: Current base of the model
|
||||
:param model_type: Model type (can't be changed)
|
||||
:param new_name: New name for the model
|
||||
:param new_base: New base for the model
|
||||
"""
|
||||
self.mgr.rename_model(base_model = base_model,
|
||||
model_type = model_type,
|
||||
model_name = model_name,
|
||||
new_name = new_name,
|
||||
new_base = new_base,
|
||||
)
|
||||
|
||||
|
@ -71,8 +71,6 @@ class ModelInstallList:
|
||||
class InstallSelections():
|
||||
install_models: List[str]= field(default_factory=list)
|
||||
remove_models: List[str]=field(default_factory=list)
|
||||
# scan_directory: Path = None
|
||||
# autoscan_on_startup: bool=False
|
||||
|
||||
@dataclass
|
||||
class ModelLoadInfo():
|
||||
|
@ -247,6 +247,7 @@ import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.util import CUDA_DEVICE, Chdir
|
||||
from .model_cache import ModelCache, ModelLocker
|
||||
from .model_search import ModelSearch
|
||||
from .models import (
|
||||
BaseModelType, ModelType, SubModelType,
|
||||
ModelError, SchedulerPredictionType, MODEL_CLASSES,
|
||||
@ -322,16 +323,7 @@ class ModelManager(object):
|
||||
self.config_meta = ConfigMeta(**config.pop("__metadata__"))
|
||||
# TODO: metadata not found
|
||||
# TODO: version check
|
||||
|
||||
self.models = dict()
|
||||
for model_key, model_config in config.items():
|
||||
model_name, base_model, model_type = self.parse_key(model_key)
|
||||
model_class = MODEL_CLASSES[base_model][model_type]
|
||||
# alias for config file
|
||||
model_config["model_format"] = model_config.pop("format")
|
||||
self.models[model_key] = model_class.create_config(**model_config)
|
||||
|
||||
# check config version number and update on disk/RAM if necessary
|
||||
|
||||
self.app_config = InvokeAIAppConfig.get_config()
|
||||
self.logger = logger
|
||||
self.cache = ModelCache(
|
||||
@ -342,11 +334,41 @@ class ModelManager(object):
|
||||
sequential_offload = sequential_offload,
|
||||
logger = logger,
|
||||
)
|
||||
|
||||
self._read_models(config)
|
||||
|
||||
def _read_models(self, config: Optional[DictConfig] = None):
|
||||
if not config:
|
||||
if self.config_path:
|
||||
config = OmegaConf.load(self.config_path)
|
||||
else:
|
||||
return
|
||||
|
||||
self.models = dict()
|
||||
for model_key, model_config in config.items():
|
||||
if model_key.startswith('_'):
|
||||
continue
|
||||
model_name, base_model, model_type = self.parse_key(model_key)
|
||||
model_class = MODEL_CLASSES[base_model][model_type]
|
||||
# alias for config file
|
||||
model_config["model_format"] = model_config.pop("format")
|
||||
self.models[model_key] = model_class.create_config(**model_config)
|
||||
|
||||
# check config version number and update on disk/RAM if necessary
|
||||
self.cache_keys = dict()
|
||||
|
||||
# add controlnet, lora and textual_inversion models from disk
|
||||
self.scan_models_directory()
|
||||
|
||||
def sync_to_config(self):
|
||||
"""
|
||||
Call this when `models.yaml` has been changed externally.
|
||||
This will reinitialize internal data structures
|
||||
"""
|
||||
# Reread models directory; note that this will reinitialize the cache,
|
||||
# causing otherwise unreferenced models to be removed from memory
|
||||
self._read_models()
|
||||
|
||||
def model_exists(
|
||||
self,
|
||||
model_name: str,
|
||||
@ -527,7 +549,10 @@ class ModelManager(object):
|
||||
model_keys = [self.create_key(model_name, base_model, model_type)] if model_name else sorted(self.models, key=str.casefold)
|
||||
models = []
|
||||
for model_key in model_keys:
|
||||
model_config = self.models[model_key]
|
||||
model_config = self.models.get(model_key)
|
||||
if not model_config:
|
||||
self.logger.error(f'Unknown model {model_name}')
|
||||
raise KeyError(f'Unknown model {model_name}')
|
||||
|
||||
cur_model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
|
||||
if base_model is not None and cur_base_model != base_model:
|
||||
@ -646,11 +671,61 @@ class ModelManager(object):
|
||||
config = model_config,
|
||||
)
|
||||
|
||||
def rename_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
new_name: str = None,
|
||||
new_base: BaseModelType = None,
|
||||
):
|
||||
'''
|
||||
Rename or rebase a model.
|
||||
'''
|
||||
if new_name is None and new_base is None:
|
||||
self.logger.error("rename_model() called with neither a new_name nor a new_base. {model_name} unchanged.")
|
||||
return
|
||||
|
||||
model_key = self.create_key(model_name, base_model, model_type)
|
||||
model_cfg = self.models.get(model_key, None)
|
||||
if not model_cfg:
|
||||
raise KeyError(f"Unknown model: {model_key}")
|
||||
|
||||
old_path = self.app_config.root_path / model_cfg.path
|
||||
new_name = new_name or model_name
|
||||
new_base = new_base or base_model
|
||||
new_key = self.create_key(new_name, new_base, model_type)
|
||||
if new_key in self.models:
|
||||
raise ValueError(f'Attempt to overwrite existing model definition "{new_key}"')
|
||||
|
||||
# if this is a model file/directory that we manage ourselves, we need to move it
|
||||
if old_path.is_relative_to(self.app_config.models_path):
|
||||
new_path = self.app_config.root_path / 'models' / new_base.value / model_type.value / new_name
|
||||
move(old_path, new_path)
|
||||
model_cfg.path = str(new_path.relative_to(self.app_config.root_path))
|
||||
|
||||
# clean up caches
|
||||
old_model_cache = self._get_model_cache_path(old_path)
|
||||
if old_model_cache.exists():
|
||||
if old_model_cache.is_dir():
|
||||
rmtree(str(old_model_cache))
|
||||
else:
|
||||
old_model_cache.unlink()
|
||||
|
||||
cache_ids = self.cache_keys.pop(model_key, [])
|
||||
for cache_id in cache_ids:
|
||||
self.cache.uncache_model(cache_id)
|
||||
|
||||
self.models.pop(model_key, None) # delete
|
||||
self.models[new_key] = model_cfg
|
||||
self.commit()
|
||||
|
||||
def convert_model (
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: Union[ModelType.Main,ModelType.Vae],
|
||||
dest_directory: Optional[Path]=None,
|
||||
) -> AddModelResult:
|
||||
'''
|
||||
Convert a checkpoint file into a diffusers folder, deleting the cached
|
||||
@ -677,14 +752,14 @@ class ModelManager(object):
|
||||
)
|
||||
checkpoint_path = self.app_config.root_path / info["path"]
|
||||
old_diffusers_path = self.app_config.models_path / model.location
|
||||
new_diffusers_path = self.app_config.models_path / base_model.value / model_type.value / model_name
|
||||
new_diffusers_path = (dest_directory or self.app_config.models_path / base_model.value / model_type.value) / model_name
|
||||
if new_diffusers_path.exists():
|
||||
raise ValueError(f"A diffusers model already exists at {new_diffusers_path}")
|
||||
|
||||
try:
|
||||
move(old_diffusers_path,new_diffusers_path)
|
||||
info["model_format"] = "diffusers"
|
||||
info["path"] = str(new_diffusers_path.relative_to(self.app_config.root_path))
|
||||
info["path"] = str(new_diffusers_path) if dest_directory else str(new_diffusers_path.relative_to(self.app_config.root_path))
|
||||
info.pop('config')
|
||||
|
||||
result = self.add_model(model_name, base_model, model_type,
|
||||
@ -824,6 +899,7 @@ class ModelManager(object):
|
||||
if (new_models_found or imported_models) and self.config_path:
|
||||
self.commit()
|
||||
|
||||
|
||||
def autoimport(self)->Dict[str, AddModelResult]:
|
||||
'''
|
||||
Scan the autoimport directory (if defined) and import new models, delete defunct models.
|
||||
@ -831,62 +907,41 @@ class ModelManager(object):
|
||||
# avoid circular import
|
||||
from invokeai.backend.install.model_install_backend import ModelInstall
|
||||
from invokeai.frontend.install.model_install import ask_user_for_prediction_type
|
||||
|
||||
|
||||
class ScanAndImport(ModelSearch):
|
||||
def __init__(self, directories, logger, ignore: Set[Path], installer: ModelInstall):
|
||||
super().__init__(directories, logger)
|
||||
self.installer = installer
|
||||
self.ignore = ignore
|
||||
|
||||
def on_search_started(self):
|
||||
self.new_models_found = dict()
|
||||
|
||||
def on_model_found(self, model: Path):
|
||||
if model not in self.ignore:
|
||||
self.new_models_found.update(self.installer.heuristic_import(model))
|
||||
|
||||
def on_search_completed(self):
|
||||
self.logger.info(f'Scanned {self._items_scanned} files and directories, imported {len(self.new_models_found)} models')
|
||||
|
||||
def models_found(self):
|
||||
return self.new_models_found
|
||||
|
||||
|
||||
installer = ModelInstall(config = self.app_config,
|
||||
model_manager = self,
|
||||
prediction_type_helper = ask_user_for_prediction_type,
|
||||
)
|
||||
|
||||
scanned_dirs = set()
|
||||
|
||||
config = self.app_config
|
||||
known_paths = {(self.app_config.root_path / x['path']) for x in self.list_models()}
|
||||
|
||||
for autodir in [config.autoimport_dir,
|
||||
config.lora_dir,
|
||||
config.embedding_dir,
|
||||
config.controlnet_dir]:
|
||||
if autodir is None:
|
||||
continue
|
||||
|
||||
installed = dict()
|
||||
|
||||
autodir = self.app_config.root_path / autodir
|
||||
if not autodir.exists():
|
||||
continue
|
||||
|
||||
items_scanned = 0
|
||||
new_models_found = dict()
|
||||
|
||||
for root, dirs, files in os.walk(autodir):
|
||||
items_scanned += len(dirs) + len(files)
|
||||
for d in dirs:
|
||||
path = Path(root) / d
|
||||
if path in known_paths or path.parent in scanned_dirs:
|
||||
scanned_dirs.add(path)
|
||||
continue
|
||||
if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}]):
|
||||
try:
|
||||
new_models_found.update(installer.heuristic_import(path))
|
||||
scanned_dirs.add(path)
|
||||
except ValueError as e:
|
||||
self.logger.warning(str(e))
|
||||
|
||||
for f in files:
|
||||
path = Path(root) / f
|
||||
if path in known_paths or path.parent in scanned_dirs:
|
||||
continue
|
||||
if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}:
|
||||
try:
|
||||
import_result = installer.heuristic_import(path)
|
||||
new_models_found.update(import_result)
|
||||
except ValueError as e:
|
||||
self.logger.warning(str(e))
|
||||
|
||||
installed.update(new_models_found)
|
||||
|
||||
self.logger.info(f'Scanned {items_scanned} files and directories, imported {len(new_models_found)} models')
|
||||
return installed
|
||||
known_paths = {config.root_path / x['path'] for x in self.list_models()}
|
||||
directories = {config.root_path / x for x in [config.autoimport_dir,
|
||||
config.lora_dir,
|
||||
config.embedding_dir,
|
||||
config.controlnet_dir]
|
||||
}
|
||||
scanner = ScanAndImport(directories, self.logger, ignore=known_paths, installer=installer)
|
||||
scanner.search()
|
||||
return scanner.models_found()
|
||||
|
||||
def heuristic_import(self,
|
||||
items_to_import: Set[str],
|
||||
@ -924,3 +979,4 @@ class ModelManager(object):
|
||||
successfully_installed.update(installed)
|
||||
self.commit()
|
||||
return successfully_installed
|
||||
|
||||
|
@ -11,7 +11,7 @@ from enum import Enum
|
||||
from pathlib import Path
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers import logging as dlogging
|
||||
from typing import List, Union
|
||||
from typing import List, Union, Optional
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
@ -74,6 +74,7 @@ class ModelMerger(object):
|
||||
alpha: float = 0.5,
|
||||
interp: MergeInterpolationMethod = None,
|
||||
force: bool = False,
|
||||
merge_dest_directory: Optional[Path] = None,
|
||||
**kwargs,
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
@ -85,7 +86,7 @@ class ModelMerger(object):
|
||||
:param interp: The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None.
|
||||
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C).
|
||||
:param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
|
||||
|
||||
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
|
||||
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
|
||||
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
|
||||
"""
|
||||
@ -111,7 +112,7 @@ class ModelMerger(object):
|
||||
merged_pipe = self.merge_diffusion_models(
|
||||
model_paths, alpha, merge_method, force, **kwargs
|
||||
)
|
||||
dump_path = config.models_path / base_model.value / ModelType.Main.value
|
||||
dump_path = Path(merge_dest_directory) if merge_dest_directory else config.models_path / base_model.value / ModelType.Main.value
|
||||
dump_path.mkdir(parents=True, exist_ok=True)
|
||||
dump_path = dump_path / merged_model_name
|
||||
|
||||
|
103
invokeai/backend/model_management/model_search.py
Normal file
103
invokeai/backend/model_management/model_search.py
Normal file
@ -0,0 +1,103 @@
|
||||
# Copyright 2023, Lincoln D. Stein and the InvokeAI Team
|
||||
"""
|
||||
Abstract base class for recursive directory search for models.
|
||||
"""
|
||||
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Set, types
|
||||
from pathlib import Path
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
class ModelSearch(ABC):
|
||||
def __init__(self, directories: List[Path], logger: types.ModuleType=logger):
|
||||
"""
|
||||
Initialize a recursive model directory search.
|
||||
:param directories: List of directory Paths to recurse through
|
||||
:param logger: Logger to use
|
||||
"""
|
||||
self.directories = directories
|
||||
self.logger = logger
|
||||
self._items_scanned = 0
|
||||
self._models_found = 0
|
||||
self._scanned_dirs = set()
|
||||
self._scanned_paths = set()
|
||||
self._pruned_paths = set()
|
||||
|
||||
@abstractmethod
|
||||
def on_search_started(self):
|
||||
"""
|
||||
Called before the scan starts.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def on_model_found(self, model: Path):
|
||||
"""
|
||||
Process a found model. Raise an exception if something goes wrong.
|
||||
:param model: Model to process - could be a directory or checkpoint.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def on_search_completed(self):
|
||||
"""
|
||||
Perform some activity when the scan is completed. May use instance
|
||||
variables, items_scanned and models_found
|
||||
"""
|
||||
pass
|
||||
|
||||
def search(self):
|
||||
self.on_search_started()
|
||||
for dir in self.directories:
|
||||
self.walk_directory(dir)
|
||||
self.on_search_completed()
|
||||
|
||||
def walk_directory(self, path: Path):
|
||||
for root, dirs, files in os.walk(path):
|
||||
if str(Path(root).name).startswith('.'):
|
||||
self._pruned_paths.add(root)
|
||||
if any([Path(root).is_relative_to(x) for x in self._pruned_paths]):
|
||||
continue
|
||||
|
||||
self._items_scanned += len(dirs) + len(files)
|
||||
for d in dirs:
|
||||
path = Path(root) / d
|
||||
if path in self._scanned_paths or path.parent in self._scanned_dirs:
|
||||
self._scanned_dirs.add(path)
|
||||
continue
|
||||
if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}]):
|
||||
try:
|
||||
self.on_model_found(path)
|
||||
self._models_found += 1
|
||||
self._scanned_dirs.add(path)
|
||||
except Exception as e:
|
||||
self.logger.warning(str(e))
|
||||
|
||||
for f in files:
|
||||
path = Path(root) / f
|
||||
if path.parent in self._scanned_dirs:
|
||||
continue
|
||||
if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}:
|
||||
try:
|
||||
self.on_model_found(path)
|
||||
self._models_found += 1
|
||||
except Exception as e:
|
||||
self.logger.warning(str(e))
|
||||
|
||||
class FindModels(ModelSearch):
|
||||
def on_search_started(self):
|
||||
self.models_found: Set[Path] = set()
|
||||
|
||||
def on_model_found(self,model: Path):
|
||||
self.models_found.add(model)
|
||||
|
||||
def on_search_completed(self):
|
||||
pass
|
||||
|
||||
def list_models(self) -> List[Path]:
|
||||
self.search()
|
||||
return self.models_found
|
||||
|
||||
|
@ -48,7 +48,9 @@ for base_model, models in MODEL_CLASSES.items():
|
||||
model_configs.discard(None)
|
||||
MODEL_CONFIGS.extend(model_configs)
|
||||
|
||||
for cfg in model_configs:
|
||||
# LS: sort to get the checkpoint configs first, which makes
|
||||
# for a better template in the Swagger docs
|
||||
for cfg in sorted(model_configs, key=lambda x: str(x)):
|
||||
model_name, cfg_name = cfg.__qualname__.split('.')[-2:]
|
||||
openapi_cfg_name = model_name + cfg_name
|
||||
if openapi_cfg_name in vars():
|
||||
|
@ -59,7 +59,6 @@ class ModelConfigBase(BaseModel):
|
||||
path: str # or Path
|
||||
description: Optional[str] = Field(None)
|
||||
model_format: Optional[str] = Field(None)
|
||||
# do not save to config
|
||||
error: Optional[ModelError] = Field(None)
|
||||
|
||||
class Config:
|
||||
|
@ -37,8 +37,7 @@ class StableDiffusion1Model(DiffusersModel):
|
||||
vae: Optional[str] = Field(None)
|
||||
config: str
|
||||
variant: ModelVariantType
|
||||
|
||||
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert base_model == BaseModelType.StableDiffusion1
|
||||
assert model_type == ModelType.Main
|
||||
|
@ -241,11 +241,45 @@ class InvokeAIDiffuserComponent:
|
||||
|
||||
def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
|
||||
# fast batched path
|
||||
|
||||
def _pad_conditioning(cond, target_len, encoder_attention_mask):
|
||||
conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
|
||||
|
||||
if cond.shape[1] < max_len:
|
||||
conditioning_attention_mask = torch.cat([
|
||||
conditioning_attention_mask,
|
||||
torch.zeros((cond.shape[0], max_len - cond.shape[1]), device=cond.device, dtype=cond.dtype),
|
||||
], dim=1)
|
||||
|
||||
cond = torch.cat([
|
||||
cond,
|
||||
torch.zeros((cond.shape[0], max_len - cond.shape[1], cond.shape[2]), device=cond.device, dtype=cond.dtype),
|
||||
], dim=1)
|
||||
|
||||
if encoder_attention_mask is None:
|
||||
encoder_attention_mask = conditioning_attention_mask
|
||||
else:
|
||||
encoder_attention_mask = torch.cat([
|
||||
encoder_attention_mask,
|
||||
conditioning_attention_mask,
|
||||
])
|
||||
|
||||
return cond, encoder_attention_mask
|
||||
|
||||
x_twice = torch.cat([x] * 2)
|
||||
sigma_twice = torch.cat([sigma] * 2)
|
||||
|
||||
encoder_attention_mask = None
|
||||
if unconditioning.shape[1] != conditioning.shape[1]:
|
||||
max_len = max(unconditioning.shape[1], conditioning.shape[1])
|
||||
unconditioning, encoder_attention_mask = _pad_conditioning(unconditioning, max_len, encoder_attention_mask)
|
||||
conditioning, encoder_attention_mask = _pad_conditioning(conditioning, max_len, encoder_attention_mask)
|
||||
|
||||
both_conditionings = torch.cat([unconditioning, conditioning])
|
||||
both_results = self.model_forward_callback(
|
||||
x_twice, sigma_twice, both_conditionings, **kwargs,
|
||||
x_twice, sigma_twice, both_conditionings,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
**kwargs,
|
||||
)
|
||||
unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
|
||||
return unconditioned_next_x, conditioned_next_x
|
||||
|
@ -13,7 +13,11 @@ import { RootState } from 'app/store/store';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'controlNet' });
|
||||
|
||||
const predicate: AnyListenerPredicate<RootState> = (action, state) => {
|
||||
const predicate: AnyListenerPredicate<RootState> = (
|
||||
action,
|
||||
state,
|
||||
prevState
|
||||
) => {
|
||||
const isActionMatched =
|
||||
controlNetProcessorParamsChanged.match(action) ||
|
||||
controlNetModelChanged.match(action) ||
|
||||
@ -25,6 +29,16 @@ const predicate: AnyListenerPredicate<RootState> = (action, state) => {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (controlNetAutoConfigToggled.match(action)) {
|
||||
// do not process if the user just disabled auto-config
|
||||
if (
|
||||
prevState.controlNet.controlNets[action.payload.controlNetId]
|
||||
.shouldAutoConfig === true
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
const { controlImage, processorType, shouldAutoConfig } =
|
||||
state.controlNet.controlNets[action.payload.controlNetId];
|
||||
|
||||
|
@ -10,6 +10,7 @@ import { zMainModel } from 'features/parameters/types/parameterSchemas';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { startAppListening } from '..';
|
||||
import { controlNetRemoved } from 'features/controlNet/store/controlNetSlice';
|
||||
|
||||
const moduleLog = log.child({ module: 'models' });
|
||||
|
||||
@ -51,7 +52,14 @@ export const addModelSelectedListener = () => {
|
||||
modelsCleared += 1;
|
||||
}
|
||||
|
||||
// TODO: handle incompatible controlnet; pending model manager support
|
||||
const { controlNets } = state.controlNet;
|
||||
forEach(controlNets, (controlNet, controlNetId) => {
|
||||
if (controlNet.model?.base_model !== base_model) {
|
||||
dispatch(controlNetRemoved({ controlNetId }));
|
||||
modelsCleared += 1;
|
||||
}
|
||||
});
|
||||
|
||||
if (modelsCleared > 0) {
|
||||
dispatch(
|
||||
addToast(
|
||||
|
@ -11,6 +11,7 @@ import {
|
||||
import { forEach, some } from 'lodash-es';
|
||||
import { modelsApi } from 'services/api/endpoints/models';
|
||||
import { startAppListening } from '..';
|
||||
import { controlNetRemoved } from 'features/controlNet/store/controlNetSlice';
|
||||
|
||||
const moduleLog = log.child({ module: 'models' });
|
||||
|
||||
@ -127,7 +128,22 @@ export const addModelsLoadedListener = () => {
|
||||
matcher: modelsApi.endpoints.getControlNetModels.matchFulfilled,
|
||||
effect: async (action, { getState, dispatch }) => {
|
||||
// ControlNet models loaded - need to remove missing ControlNets from state
|
||||
// TODO: pending model manager controlnet support
|
||||
const controlNets = getState().controlNet.controlNets;
|
||||
|
||||
forEach(controlNets, (controlNet, controlNetId) => {
|
||||
const isControlNetAvailable = some(
|
||||
action.payload.entities,
|
||||
(m) =>
|
||||
m?.model_name === controlNet?.model?.model_name &&
|
||||
m?.base_model === controlNet?.model?.base_model
|
||||
);
|
||||
|
||||
if (isControlNetAvailable) {
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(controlNetRemoved({ controlNetId }));
|
||||
});
|
||||
},
|
||||
});
|
||||
};
|
||||
|
@ -1,5 +1,5 @@
|
||||
import {
|
||||
CONTROLNET_MODELS,
|
||||
// CONTROLNET_MODELS,
|
||||
CONTROLNET_PROCESSORS,
|
||||
} from 'features/controlNet/store/constants';
|
||||
import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||
@ -128,7 +128,7 @@ export type AppConfig = {
|
||||
canRestoreDeletedImagesFromBin: boolean;
|
||||
sd: {
|
||||
defaultModel?: string;
|
||||
disabledControlNetModels: (keyof typeof CONTROLNET_MODELS)[];
|
||||
disabledControlNetModels: string[];
|
||||
disabledControlNetProcessors: (keyof typeof CONTROLNET_PROCESSORS)[];
|
||||
iterations: {
|
||||
initial: number;
|
||||
|
@ -170,12 +170,14 @@ const IAIDndImage = (props: IAIDndImageProps) => {
|
||||
</>
|
||||
)}
|
||||
{!imageDTO && isUploadDisabled && noContentFallback}
|
||||
<IAIDroppable
|
||||
data={droppableData}
|
||||
disabled={isDropDisabled}
|
||||
dropLabel={dropLabel}
|
||||
/>
|
||||
{imageDTO && (
|
||||
{!isDropDisabled && (
|
||||
<IAIDroppable
|
||||
data={droppableData}
|
||||
disabled={isDropDisabled}
|
||||
dropLabel={dropLabel}
|
||||
/>
|
||||
)}
|
||||
{imageDTO && !isDragDisabled && (
|
||||
<IAIDraggable
|
||||
data={draggableData}
|
||||
disabled={isDragDisabled || !imageDTO}
|
||||
|
@ -1,17 +1,25 @@
|
||||
import { Tooltip } from '@chakra-ui/react';
|
||||
import { FormControl, FormLabel, Tooltip } from '@chakra-ui/react';
|
||||
import { MultiSelect, MultiSelectProps } from '@mantine/core';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
|
||||
import { useMantineMultiSelectStyles } from 'mantine-theme/hooks/useMantineMultiSelectStyles';
|
||||
import { KeyboardEvent, RefObject, memo, useCallback } from 'react';
|
||||
|
||||
type IAIMultiSelectProps = MultiSelectProps & {
|
||||
type IAIMultiSelectProps = Omit<MultiSelectProps, 'label'> & {
|
||||
tooltip?: string;
|
||||
inputRef?: RefObject<HTMLInputElement>;
|
||||
label?: string;
|
||||
};
|
||||
|
||||
const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
|
||||
const { searchable = true, tooltip, inputRef, ...rest } = props;
|
||||
const {
|
||||
searchable = true,
|
||||
tooltip,
|
||||
inputRef,
|
||||
label,
|
||||
disabled,
|
||||
...rest
|
||||
} = props;
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const handleKeyDown = useCallback(
|
||||
@ -37,7 +45,15 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
|
||||
return (
|
||||
<Tooltip label={tooltip} placement="top" hasArrow isOpen={true}>
|
||||
<MultiSelect
|
||||
label={
|
||||
label ? (
|
||||
<FormControl isDisabled={disabled}>
|
||||
<FormLabel>{label}</FormLabel>
|
||||
</FormControl>
|
||||
) : undefined
|
||||
}
|
||||
ref={inputRef}
|
||||
disabled={disabled}
|
||||
onKeyDown={handleKeyDown}
|
||||
onKeyUp={handleKeyUp}
|
||||
searchable={searchable}
|
||||
|
@ -1,4 +1,4 @@
|
||||
import { Tooltip } from '@chakra-ui/react';
|
||||
import { FormControl, FormLabel, Tooltip } from '@chakra-ui/react';
|
||||
import { Select, SelectProps } from '@mantine/core';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
|
||||
@ -11,13 +11,22 @@ export type IAISelectDataType = {
|
||||
tooltip?: string;
|
||||
};
|
||||
|
||||
type IAISelectProps = SelectProps & {
|
||||
type IAISelectProps = Omit<SelectProps, 'label'> & {
|
||||
tooltip?: string;
|
||||
label?: string;
|
||||
inputRef?: RefObject<HTMLInputElement>;
|
||||
};
|
||||
|
||||
const IAIMantineSearchableSelect = (props: IAISelectProps) => {
|
||||
const { searchable = true, tooltip, inputRef, onChange, ...rest } = props;
|
||||
const {
|
||||
searchable = true,
|
||||
tooltip,
|
||||
inputRef,
|
||||
onChange,
|
||||
label,
|
||||
disabled,
|
||||
...rest
|
||||
} = props;
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const [searchValue, setSearchValue] = useState('');
|
||||
@ -61,6 +70,14 @@ const IAIMantineSearchableSelect = (props: IAISelectProps) => {
|
||||
<Tooltip label={tooltip} placement="top" hasArrow>
|
||||
<Select
|
||||
ref={inputRef}
|
||||
label={
|
||||
label ? (
|
||||
<FormControl isDisabled={disabled}>
|
||||
<FormLabel>{label}</FormLabel>
|
||||
</FormControl>
|
||||
) : undefined
|
||||
}
|
||||
disabled={disabled}
|
||||
searchValue={searchValue}
|
||||
onSearchChange={setSearchValue}
|
||||
onChange={handleChange}
|
||||
|
@ -1,4 +1,4 @@
|
||||
import { Tooltip } from '@chakra-ui/react';
|
||||
import { FormControl, FormLabel, Tooltip } from '@chakra-ui/react';
|
||||
import { Select, SelectProps } from '@mantine/core';
|
||||
import { useMantineSelectStyles } from 'mantine-theme/hooks/useMantineSelectStyles';
|
||||
import { RefObject, memo } from 'react';
|
||||
@ -9,19 +9,32 @@ export type IAISelectDataType = {
|
||||
tooltip?: string;
|
||||
};
|
||||
|
||||
type IAISelectProps = SelectProps & {
|
||||
type IAISelectProps = Omit<SelectProps, 'label'> & {
|
||||
tooltip?: string;
|
||||
inputRef?: RefObject<HTMLInputElement>;
|
||||
label?: string;
|
||||
};
|
||||
|
||||
const IAIMantineSelect = (props: IAISelectProps) => {
|
||||
const { tooltip, inputRef, ...rest } = props;
|
||||
const { tooltip, inputRef, label, disabled, ...rest } = props;
|
||||
|
||||
const styles = useMantineSelectStyles();
|
||||
|
||||
return (
|
||||
<Tooltip label={tooltip} placement="top" hasArrow>
|
||||
<Select ref={inputRef} styles={styles} {...rest} />
|
||||
<Select
|
||||
label={
|
||||
label ? (
|
||||
<FormControl isDisabled={disabled}>
|
||||
<FormLabel>{label}</FormLabel>
|
||||
</FormControl>
|
||||
) : undefined
|
||||
}
|
||||
disabled={disabled}
|
||||
ref={inputRef}
|
||||
styles={styles}
|
||||
{...rest}
|
||||
/>
|
||||
</Tooltip>
|
||||
);
|
||||
};
|
||||
|
@ -28,7 +28,7 @@ import {
|
||||
useState,
|
||||
} from 'react';
|
||||
|
||||
const numberStringRegex = /^-?(0\.)?\.?$/;
|
||||
export const numberStringRegex = /^-?(0\.)?\.?$/;
|
||||
|
||||
interface Props extends Omit<NumberInputProps, 'onChange'> {
|
||||
label?: string;
|
||||
|
@ -43,11 +43,6 @@ import { useTranslation } from 'react-i18next';
|
||||
import { BiReset } from 'react-icons/bi';
|
||||
import IAIIconButton, { IAIIconButtonProps } from './IAIIconButton';
|
||||
|
||||
const SLIDER_MARK_STYLES: ChakraProps['sx'] = {
|
||||
mt: 1.5,
|
||||
fontSize: '2xs',
|
||||
};
|
||||
|
||||
export type IAIFullSliderProps = {
|
||||
label?: string;
|
||||
value: number;
|
||||
@ -207,7 +202,7 @@ const IAISlider = (props: IAIFullSliderProps) => {
|
||||
{...sliderFormControlProps}
|
||||
>
|
||||
{label && (
|
||||
<FormLabel {...sliderFormLabelProps} mb={-1}>
|
||||
<FormLabel sx={withInput ? { mb: -1.5 } : {}} {...sliderFormLabelProps}>
|
||||
{label}
|
||||
</FormLabel>
|
||||
)}
|
||||
@ -233,7 +228,6 @@ const IAISlider = (props: IAIFullSliderProps) => {
|
||||
sx={{
|
||||
insetInlineStart: '0 !important',
|
||||
insetInlineEnd: 'unset !important',
|
||||
...SLIDER_MARK_STYLES,
|
||||
}}
|
||||
{...sliderMarkProps}
|
||||
>
|
||||
@ -244,7 +238,6 @@ const IAISlider = (props: IAIFullSliderProps) => {
|
||||
sx={{
|
||||
insetInlineStart: 'unset !important',
|
||||
insetInlineEnd: '0 !important',
|
||||
...SLIDER_MARK_STYLES,
|
||||
}}
|
||||
{...sliderMarkProps}
|
||||
>
|
||||
@ -263,7 +256,6 @@ const IAISlider = (props: IAIFullSliderProps) => {
|
||||
sx={{
|
||||
insetInlineStart: '0 !important',
|
||||
insetInlineEnd: 'unset !important',
|
||||
...SLIDER_MARK_STYLES,
|
||||
}}
|
||||
{...sliderMarkProps}
|
||||
>
|
||||
@ -278,7 +270,6 @@ const IAISlider = (props: IAIFullSliderProps) => {
|
||||
sx={{
|
||||
insetInlineStart: 'unset !important',
|
||||
insetInlineEnd: '0 !important',
|
||||
...SLIDER_MARK_STYLES,
|
||||
}}
|
||||
{...sliderMarkProps}
|
||||
>
|
||||
@ -291,7 +282,7 @@ const IAISlider = (props: IAIFullSliderProps) => {
|
||||
key={m}
|
||||
value={m}
|
||||
sx={{
|
||||
...SLIDER_MARK_STYLES,
|
||||
transform: 'translateX(-50%)',
|
||||
}}
|
||||
{...sliderMarkProps}
|
||||
>
|
||||
|
@ -5,6 +5,7 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { validateSeedWeights } from 'common/util/seedWeightPairs';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import { modelsApi } from '../../services/api/endpoints/models';
|
||||
import { forEach } from 'lodash-es';
|
||||
|
||||
const readinessSelector = createSelector(
|
||||
[stateSelector, activeTabNameSelector],
|
||||
@ -52,6 +53,13 @@ const readinessSelector = createSelector(
|
||||
reasonsWhyNotReady.push('Seed-Weights badly formatted.');
|
||||
}
|
||||
|
||||
forEach(state.controlNet.controlNets, (controlNet, id) => {
|
||||
if (!controlNet.model) {
|
||||
isReady = false;
|
||||
reasonsWhyNotReady.push('ControlNet ${id} has no model selected.');
|
||||
}
|
||||
});
|
||||
|
||||
// All good
|
||||
return { isReady, reasonsWhyNotReady };
|
||||
},
|
||||
|
@ -1,10 +1,9 @@
|
||||
import { Box, ChakraProps, Flex, useColorMode } from '@chakra-ui/react';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { Box, Flex } from '@chakra-ui/react';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { FaCopy, FaTrash } from 'react-icons/fa';
|
||||
import {
|
||||
ControlNetConfig,
|
||||
controlNetAdded,
|
||||
controlNetDuplicated,
|
||||
controlNetRemoved,
|
||||
controlNetToggled,
|
||||
} from '../store/controlNetSlice';
|
||||
@ -12,6 +11,9 @@ import ParamControlNetModel from './parameters/ParamControlNetModel';
|
||||
import ParamControlNetWeight from './parameters/ParamControlNetWeight';
|
||||
|
||||
import { ChevronUpIcon } from '@chakra-ui/icons';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import IAISwitch from 'common/components/IAISwitch';
|
||||
import { useToggle } from 'react-use';
|
||||
@ -22,41 +24,41 @@ import ParamControlNetShouldAutoConfig from './ParamControlNetShouldAutoConfig';
|
||||
import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd';
|
||||
import ParamControlNetControlMode from './parameters/ParamControlNetControlMode';
|
||||
import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect';
|
||||
import { mode } from 'theme/util/mode';
|
||||
|
||||
const expandedControlImageSx: ChakraProps['sx'] = { maxH: 96 };
|
||||
|
||||
type ControlNetProps = {
|
||||
controlNet: ControlNetConfig;
|
||||
controlNetId: string;
|
||||
};
|
||||
|
||||
const ControlNet = (props: ControlNetProps) => {
|
||||
const {
|
||||
controlNetId,
|
||||
isEnabled,
|
||||
model,
|
||||
weight,
|
||||
beginStepPct,
|
||||
endStepPct,
|
||||
controlMode,
|
||||
controlImage,
|
||||
processedControlImage,
|
||||
processorNode,
|
||||
processorType,
|
||||
shouldAutoConfig,
|
||||
} = props.controlNet;
|
||||
const { controlNetId } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const selector = createSelector(
|
||||
stateSelector,
|
||||
({ controlNet }) => {
|
||||
const { isEnabled, shouldAutoConfig } =
|
||||
controlNet.controlNets[controlNetId];
|
||||
|
||||
return { isEnabled, shouldAutoConfig };
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const { isEnabled, shouldAutoConfig } = useAppSelector(selector);
|
||||
const [isExpanded, toggleIsExpanded] = useToggle(false);
|
||||
const { colorMode } = useColorMode();
|
||||
|
||||
const handleDelete = useCallback(() => {
|
||||
dispatch(controlNetRemoved({ controlNetId }));
|
||||
}, [controlNetId, dispatch]);
|
||||
|
||||
const handleDuplicate = useCallback(() => {
|
||||
dispatch(
|
||||
controlNetAdded({ controlNetId: uuidv4(), controlNet: props.controlNet })
|
||||
controlNetDuplicated({
|
||||
sourceControlNetId: controlNetId,
|
||||
newControlNetId: uuidv4(),
|
||||
})
|
||||
);
|
||||
}, [dispatch, props.controlNet]);
|
||||
}, [controlNetId, dispatch]);
|
||||
|
||||
const handleToggleIsEnabled = useCallback(() => {
|
||||
dispatch(controlNetToggled({ controlNetId }));
|
||||
@ -68,15 +70,18 @@ const ControlNet = (props: ControlNetProps) => {
|
||||
flexDir: 'column',
|
||||
gap: 2,
|
||||
p: 3,
|
||||
bg: mode('base.200', 'base.850')(colorMode),
|
||||
borderRadius: 'base',
|
||||
position: 'relative',
|
||||
bg: 'base.200',
|
||||
_dark: {
|
||||
bg: 'base.850',
|
||||
},
|
||||
}}
|
||||
>
|
||||
<Flex sx={{ gap: 2 }}>
|
||||
<Flex sx={{ gap: 2, alignItems: 'center' }}>
|
||||
<IAISwitch
|
||||
tooltip="Toggle"
|
||||
aria-label="Toggle"
|
||||
tooltip={'Toggle this ControlNet'}
|
||||
aria-label={'Toggle this ControlNet'}
|
||||
isChecked={isEnabled}
|
||||
onChange={handleToggleIsEnabled}
|
||||
/>
|
||||
@ -90,7 +95,7 @@ const ControlNet = (props: ControlNetProps) => {
|
||||
transitionDuration: '0.1s',
|
||||
}}
|
||||
>
|
||||
<ParamControlNetModel controlNetId={controlNetId} model={model} />
|
||||
<ParamControlNetModel controlNetId={controlNetId} />
|
||||
</Box>
|
||||
<IAIIconButton
|
||||
size="sm"
|
||||
@ -109,21 +114,26 @@ const ControlNet = (props: ControlNetProps) => {
|
||||
/>
|
||||
<IAIIconButton
|
||||
size="sm"
|
||||
aria-label="Show All Options"
|
||||
tooltip={isExpanded ? 'Hide Advanced' : 'Show Advanced'}
|
||||
aria-label={isExpanded ? 'Hide Advanced' : 'Show Advanced'}
|
||||
onClick={toggleIsExpanded}
|
||||
variant="link"
|
||||
icon={
|
||||
<ChevronUpIcon
|
||||
sx={{
|
||||
boxSize: 4,
|
||||
color: mode('base.700', 'base.300')(colorMode),
|
||||
color: 'base.700',
|
||||
transform: isExpanded ? 'rotate(0deg)' : 'rotate(180deg)',
|
||||
transitionProperty: 'common',
|
||||
transitionDuration: 'normal',
|
||||
_dark: {
|
||||
color: 'base.300',
|
||||
},
|
||||
}}
|
||||
/>
|
||||
}
|
||||
/>
|
||||
|
||||
{!shouldAutoConfig && (
|
||||
<Box
|
||||
sx={{
|
||||
@ -131,85 +141,59 @@ const ControlNet = (props: ControlNetProps) => {
|
||||
w: 1.5,
|
||||
h: 1.5,
|
||||
borderRadius: 'full',
|
||||
bg: mode('error.700', 'error.200')(colorMode),
|
||||
top: 4,
|
||||
insetInlineEnd: 4,
|
||||
bg: 'accent.700',
|
||||
_dark: {
|
||||
bg: 'accent.400',
|
||||
},
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
</Flex>
|
||||
{isEnabled && (
|
||||
<>
|
||||
<Flex sx={{ w: 'full', flexDirection: 'column' }}>
|
||||
<Flex sx={{ gap: 4, w: 'full' }}>
|
||||
<Flex
|
||||
sx={{
|
||||
flexDir: 'column',
|
||||
gap: 3,
|
||||
w: 'full',
|
||||
paddingInlineStart: 1,
|
||||
paddingInlineEnd: isExpanded ? 1 : 0,
|
||||
pb: 2,
|
||||
justifyContent: 'space-between',
|
||||
}}
|
||||
>
|
||||
<ParamControlNetWeight
|
||||
controlNetId={controlNetId}
|
||||
weight={weight}
|
||||
mini={!isExpanded}
|
||||
/>
|
||||
<ParamControlNetBeginEnd
|
||||
controlNetId={controlNetId}
|
||||
beginStepPct={beginStepPct}
|
||||
endStepPct={endStepPct}
|
||||
mini={!isExpanded}
|
||||
/>
|
||||
</Flex>
|
||||
{!isExpanded && (
|
||||
<Flex
|
||||
sx={{
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
h: 24,
|
||||
w: 24,
|
||||
aspectRatio: '1/1',
|
||||
}}
|
||||
>
|
||||
<ControlNetImagePreview
|
||||
controlNet={props.controlNet}
|
||||
height={24}
|
||||
/>
|
||||
</Flex>
|
||||
)}
|
||||
</Flex>
|
||||
<ParamControlNetControlMode
|
||||
controlNetId={controlNetId}
|
||||
controlMode={controlMode}
|
||||
/>
|
||||
<Flex sx={{ w: 'full', flexDirection: 'column' }}>
|
||||
<Flex sx={{ gap: 4, w: 'full', alignItems: 'center' }}>
|
||||
<Flex
|
||||
sx={{
|
||||
flexDir: 'column',
|
||||
gap: 3,
|
||||
h: 28,
|
||||
w: 'full',
|
||||
paddingInlineStart: 1,
|
||||
paddingInlineEnd: isExpanded ? 1 : 0,
|
||||
pb: 2,
|
||||
justifyContent: 'space-between',
|
||||
}}
|
||||
>
|
||||
<ParamControlNetWeight controlNetId={controlNetId} />
|
||||
<ParamControlNetBeginEnd controlNetId={controlNetId} />
|
||||
</Flex>
|
||||
|
||||
{isExpanded && (
|
||||
<>
|
||||
<Box mt={2}>
|
||||
<ControlNetImagePreview
|
||||
controlNet={props.controlNet}
|
||||
height={96}
|
||||
/>
|
||||
</Box>
|
||||
<ParamControlNetProcessorSelect
|
||||
controlNetId={controlNetId}
|
||||
processorNode={processorNode}
|
||||
/>
|
||||
<ControlNetProcessorComponent
|
||||
controlNetId={controlNetId}
|
||||
processorNode={processorNode}
|
||||
/>
|
||||
<ParamControlNetShouldAutoConfig
|
||||
controlNetId={controlNetId}
|
||||
shouldAutoConfig={shouldAutoConfig}
|
||||
/>
|
||||
</>
|
||||
{!isExpanded && (
|
||||
<Flex
|
||||
sx={{
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
h: 28,
|
||||
w: 28,
|
||||
aspectRatio: '1/1',
|
||||
mt: 3,
|
||||
}}
|
||||
>
|
||||
<ControlNetImagePreview controlNetId={controlNetId} height={28} />
|
||||
</Flex>
|
||||
)}
|
||||
</Flex>
|
||||
<Box mt={2}>
|
||||
<ParamControlNetControlMode controlNetId={controlNetId} />
|
||||
</Box>
|
||||
<ParamControlNetProcessorSelect controlNetId={controlNetId} />
|
||||
</Flex>
|
||||
|
||||
{isExpanded && (
|
||||
<>
|
||||
<ControlNetImagePreview controlNetId={controlNetId} height="392px" />
|
||||
<ParamControlNetShouldAutoConfig controlNetId={controlNetId} />
|
||||
<ControlNetProcessorComponent controlNetId={controlNetId} />
|
||||
</>
|
||||
)}
|
||||
</Flex>
|
||||
|
@ -5,42 +5,57 @@ import {
|
||||
TypesafeDraggableData,
|
||||
TypesafeDroppableData,
|
||||
} from 'app/components/ImageDnd/typesafeDnd';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIDndImage from 'common/components/IAIDndImage';
|
||||
import { memo, useCallback, useMemo, useState } from 'react';
|
||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||
import { PostUploadAction } from 'services/api/thunks/image';
|
||||
import {
|
||||
ControlNetConfig,
|
||||
controlNetImageChanged,
|
||||
controlNetSelector,
|
||||
} from '../store/controlNetSlice';
|
||||
|
||||
const selector = createSelector(
|
||||
controlNetSelector,
|
||||
(controlNet) => {
|
||||
const { pendingControlImages } = controlNet;
|
||||
return { pendingControlImages };
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
import { controlNetImageChanged } from '../store/controlNetSlice';
|
||||
|
||||
type Props = {
|
||||
controlNet: ControlNetConfig;
|
||||
controlNetId: string;
|
||||
height: SystemStyleObject['h'];
|
||||
};
|
||||
|
||||
const ControlNetImagePreview = (props: Props) => {
|
||||
const { height } = props;
|
||||
const {
|
||||
controlNetId,
|
||||
controlImage: controlImageName,
|
||||
processedControlImage: processedControlImageName,
|
||||
processorType,
|
||||
} = props.controlNet;
|
||||
const { height, controlNetId } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const { pendingControlImages } = useAppSelector(selector);
|
||||
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
stateSelector,
|
||||
({ controlNet }) => {
|
||||
const { pendingControlImages } = controlNet;
|
||||
const {
|
||||
controlImage,
|
||||
processedControlImage,
|
||||
processorType,
|
||||
isEnabled,
|
||||
} = controlNet.controlNets[controlNetId];
|
||||
|
||||
return {
|
||||
controlImageName: controlImage,
|
||||
processedControlImageName: processedControlImage,
|
||||
processorType,
|
||||
isEnabled,
|
||||
pendingControlImages,
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
),
|
||||
[controlNetId]
|
||||
);
|
||||
|
||||
const {
|
||||
controlImageName,
|
||||
processedControlImageName,
|
||||
processorType,
|
||||
pendingControlImages,
|
||||
isEnabled,
|
||||
} = useAppSelector(selector);
|
||||
|
||||
const [isMouseOverImage, setIsMouseOverImage] = useState(false);
|
||||
|
||||
@ -110,13 +125,15 @@ const ControlNetImagePreview = (props: Props) => {
|
||||
h: height,
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
pointerEvents: isEnabled ? 'auto' : 'none',
|
||||
opacity: isEnabled ? 1 : 0.5,
|
||||
}}
|
||||
>
|
||||
<IAIDndImage
|
||||
draggableData={draggableData}
|
||||
droppableData={droppableData}
|
||||
imageDTO={controlImage}
|
||||
isDropDisabled={shouldShowProcessedImage}
|
||||
isDropDisabled={shouldShowProcessedImage || !isEnabled}
|
||||
onClickReset={handleResetControlImage}
|
||||
postUploadAction={postUploadAction}
|
||||
resetTooltip="Reset Control Image"
|
||||
@ -140,6 +157,7 @@ const ControlNetImagePreview = (props: Props) => {
|
||||
droppableData={droppableData}
|
||||
imageDTO={processedControlImage}
|
||||
isUploadDisabled={true}
|
||||
isDropDisabled={!isEnabled}
|
||||
onClickReset={handleResetControlImage}
|
||||
resetTooltip="Reset Control Image"
|
||||
withResetIcon={Boolean(controlImage)}
|
||||
|
@ -1,10 +1,13 @@
|
||||
import { memo } from 'react';
|
||||
import { RequiredControlNetProcessorNode } from '../store/types';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { memo, useMemo } from 'react';
|
||||
import CannyProcessor from './processors/CannyProcessor';
|
||||
import HedProcessor from './processors/HedProcessor';
|
||||
import LineartProcessor from './processors/LineartProcessor';
|
||||
import LineartAnimeProcessor from './processors/LineartAnimeProcessor';
|
||||
import ContentShuffleProcessor from './processors/ContentShuffleProcessor';
|
||||
import HedProcessor from './processors/HedProcessor';
|
||||
import LineartAnimeProcessor from './processors/LineartAnimeProcessor';
|
||||
import LineartProcessor from './processors/LineartProcessor';
|
||||
import MediapipeFaceProcessor from './processors/MediapipeFaceProcessor';
|
||||
import MidasDepthProcessor from './processors/MidasDepthProcessor';
|
||||
import MlsdImageProcessor from './processors/MlsdImageProcessor';
|
||||
@ -15,23 +18,45 @@ import ZoeDepthProcessor from './processors/ZoeDepthProcessor';
|
||||
|
||||
export type ControlNetProcessorProps = {
|
||||
controlNetId: string;
|
||||
processorNode: RequiredControlNetProcessorNode;
|
||||
};
|
||||
|
||||
const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
|
||||
const { controlNetId, processorNode } = props;
|
||||
const { controlNetId } = props;
|
||||
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
stateSelector,
|
||||
({ controlNet }) => {
|
||||
const { isEnabled, processorNode } =
|
||||
controlNet.controlNets[controlNetId];
|
||||
|
||||
return { isEnabled, processorNode };
|
||||
},
|
||||
defaultSelectorOptions
|
||||
),
|
||||
[controlNetId]
|
||||
);
|
||||
|
||||
const { isEnabled, processorNode } = useAppSelector(selector);
|
||||
|
||||
if (processorNode.type === 'canny_image_processor') {
|
||||
return (
|
||||
<CannyProcessor
|
||||
controlNetId={controlNetId}
|
||||
processorNode={processorNode}
|
||||
isEnabled={isEnabled}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (processorNode.type === 'hed_image_processor') {
|
||||
return (
|
||||
<HedProcessor controlNetId={controlNetId} processorNode={processorNode} />
|
||||
<HedProcessor
|
||||
controlNetId={controlNetId}
|
||||
processorNode={processorNode}
|
||||
isEnabled={isEnabled}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
@ -40,6 +65,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
|
||||
<LineartProcessor
|
||||
controlNetId={controlNetId}
|
||||
processorNode={processorNode}
|
||||
isEnabled={isEnabled}
|
||||
/>
|
||||
);
|
||||
}
|
||||
@ -49,6 +75,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
|
||||
<ContentShuffleProcessor
|
||||
controlNetId={controlNetId}
|
||||
processorNode={processorNode}
|
||||
isEnabled={isEnabled}
|
||||
/>
|
||||
);
|
||||
}
|
||||
@ -58,6 +85,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
|
||||
<LineartAnimeProcessor
|
||||
controlNetId={controlNetId}
|
||||
processorNode={processorNode}
|
||||
isEnabled={isEnabled}
|
||||
/>
|
||||
);
|
||||
}
|
||||
@ -67,6 +95,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
|
||||
<MediapipeFaceProcessor
|
||||
controlNetId={controlNetId}
|
||||
processorNode={processorNode}
|
||||
isEnabled={isEnabled}
|
||||
/>
|
||||
);
|
||||
}
|
||||
@ -76,6 +105,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
|
||||
<MidasDepthProcessor
|
||||
controlNetId={controlNetId}
|
||||
processorNode={processorNode}
|
||||
isEnabled={isEnabled}
|
||||
/>
|
||||
);
|
||||
}
|
||||
@ -85,6 +115,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
|
||||
<MlsdImageProcessor
|
||||
controlNetId={controlNetId}
|
||||
processorNode={processorNode}
|
||||
isEnabled={isEnabled}
|
||||
/>
|
||||
);
|
||||
}
|
||||
@ -94,6 +125,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
|
||||
<NormalBaeProcessor
|
||||
controlNetId={controlNetId}
|
||||
processorNode={processorNode}
|
||||
isEnabled={isEnabled}
|
||||
/>
|
||||
);
|
||||
}
|
||||
@ -103,6 +135,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
|
||||
<OpenposeProcessor
|
||||
controlNetId={controlNetId}
|
||||
processorNode={processorNode}
|
||||
isEnabled={isEnabled}
|
||||
/>
|
||||
);
|
||||
}
|
||||
@ -112,6 +145,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
|
||||
<PidiProcessor
|
||||
controlNetId={controlNetId}
|
||||
processorNode={processorNode}
|
||||
isEnabled={isEnabled}
|
||||
/>
|
||||
);
|
||||
}
|
||||
@ -121,6 +155,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
|
||||
<ZoeDepthProcessor
|
||||
controlNetId={controlNetId}
|
||||
processorNode={processorNode}
|
||||
isEnabled={isEnabled}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
@ -1,18 +1,36 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAISwitch from 'common/components/IAISwitch';
|
||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||
import { controlNetAutoConfigToggled } from 'features/controlNet/store/controlNetSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
|
||||
type Props = {
|
||||
controlNetId: string;
|
||||
shouldAutoConfig: boolean;
|
||||
};
|
||||
|
||||
const ParamControlNetShouldAutoConfig = (props: Props) => {
|
||||
const { controlNetId, shouldAutoConfig } = props;
|
||||
const { controlNetId } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const isReady = useIsReadyToInvoke();
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
stateSelector,
|
||||
({ controlNet }) => {
|
||||
const { isEnabled, shouldAutoConfig } =
|
||||
controlNet.controlNets[controlNetId];
|
||||
return { isEnabled, shouldAutoConfig };
|
||||
},
|
||||
defaultSelectorOptions
|
||||
),
|
||||
[controlNetId]
|
||||
);
|
||||
|
||||
const { isEnabled, shouldAutoConfig } = useAppSelector(selector);
|
||||
const isBusy = useAppSelector(selectIsBusy);
|
||||
|
||||
const handleShouldAutoConfigChanged = useCallback(() => {
|
||||
dispatch(controlNetAutoConfigToggled({ controlNetId }));
|
||||
}, [controlNetId, dispatch]);
|
||||
@ -23,7 +41,7 @@ const ParamControlNetShouldAutoConfig = (props: Props) => {
|
||||
aria-label="Auto configure processor"
|
||||
isChecked={shouldAutoConfig}
|
||||
onChange={handleShouldAutoConfigChanged}
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
@ -1,5 +1,4 @@
|
||||
import {
|
||||
ChakraProps,
|
||||
FormControl,
|
||||
FormLabel,
|
||||
HStack,
|
||||
@ -10,34 +9,41 @@ import {
|
||||
RangeSliderTrack,
|
||||
Tooltip,
|
||||
} from '@chakra-ui/react';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import {
|
||||
controlNetBeginStepPctChanged,
|
||||
controlNetEndStepPctChanged,
|
||||
} from 'features/controlNet/store/controlNetSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const SLIDER_MARK_STYLES: ChakraProps['sx'] = {
|
||||
mt: 1.5,
|
||||
fontSize: '2xs',
|
||||
fontWeight: '500',
|
||||
color: 'base.400',
|
||||
};
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
|
||||
type Props = {
|
||||
controlNetId: string;
|
||||
beginStepPct: number;
|
||||
endStepPct: number;
|
||||
mini?: boolean;
|
||||
};
|
||||
|
||||
const formatPct = (v: number) => `${Math.round(v * 100)}%`;
|
||||
|
||||
const ParamControlNetBeginEnd = (props: Props) => {
|
||||
const { controlNetId, beginStepPct, mini = false, endStepPct } = props;
|
||||
const { controlNetId } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
stateSelector,
|
||||
({ controlNet }) => {
|
||||
const { beginStepPct, endStepPct, isEnabled } =
|
||||
controlNet.controlNets[controlNetId];
|
||||
return { beginStepPct, endStepPct, isEnabled };
|
||||
},
|
||||
defaultSelectorOptions
|
||||
),
|
||||
[controlNetId]
|
||||
);
|
||||
|
||||
const { beginStepPct, endStepPct, isEnabled } = useAppSelector(selector);
|
||||
|
||||
const handleStepPctChanged = useCallback(
|
||||
(v: number[]) => {
|
||||
@ -55,7 +61,7 @@ const ParamControlNetBeginEnd = (props: Props) => {
|
||||
}, [controlNetId, dispatch]);
|
||||
|
||||
return (
|
||||
<FormControl>
|
||||
<FormControl isDisabled={!isEnabled}>
|
||||
<FormLabel>Begin / End Step Percentage</FormLabel>
|
||||
<HStack w="100%" gap={2} alignItems="center">
|
||||
<RangeSlider
|
||||
@ -66,6 +72,7 @@ const ParamControlNetBeginEnd = (props: Props) => {
|
||||
max={1}
|
||||
step={0.01}
|
||||
minStepsBetweenThumbs={5}
|
||||
isDisabled={!isEnabled}
|
||||
>
|
||||
<RangeSliderTrack>
|
||||
<RangeSliderFilledTrack />
|
||||
@ -76,38 +83,33 @@ const ParamControlNetBeginEnd = (props: Props) => {
|
||||
<Tooltip label={formatPct(endStepPct)} placement="top" hasArrow>
|
||||
<RangeSliderThumb index={1} />
|
||||
</Tooltip>
|
||||
{!mini && (
|
||||
<>
|
||||
<RangeSliderMark
|
||||
value={0}
|
||||
sx={{
|
||||
insetInlineStart: '0 !important',
|
||||
insetInlineEnd: 'unset !important',
|
||||
...SLIDER_MARK_STYLES,
|
||||
}}
|
||||
>
|
||||
0%
|
||||
</RangeSliderMark>
|
||||
<RangeSliderMark
|
||||
value={0.5}
|
||||
sx={{
|
||||
...SLIDER_MARK_STYLES,
|
||||
}}
|
||||
>
|
||||
50%
|
||||
</RangeSliderMark>
|
||||
<RangeSliderMark
|
||||
value={1}
|
||||
sx={{
|
||||
insetInlineStart: 'unset !important',
|
||||
insetInlineEnd: '0 !important',
|
||||
...SLIDER_MARK_STYLES,
|
||||
}}
|
||||
>
|
||||
100%
|
||||
</RangeSliderMark>
|
||||
</>
|
||||
)}
|
||||
<RangeSliderMark
|
||||
value={0}
|
||||
sx={{
|
||||
insetInlineStart: '0 !important',
|
||||
insetInlineEnd: 'unset !important',
|
||||
}}
|
||||
>
|
||||
0%
|
||||
</RangeSliderMark>
|
||||
<RangeSliderMark
|
||||
value={0.5}
|
||||
sx={{
|
||||
insetInlineStart: '50% !important',
|
||||
transform: 'translateX(-50%)',
|
||||
}}
|
||||
>
|
||||
50%
|
||||
</RangeSliderMark>
|
||||
<RangeSliderMark
|
||||
value={1}
|
||||
sx={{
|
||||
insetInlineStart: 'unset !important',
|
||||
insetInlineEnd: '0 !important',
|
||||
}}
|
||||
>
|
||||
100%
|
||||
</RangeSliderMark>
|
||||
</RangeSlider>
|
||||
</HStack>
|
||||
</FormControl>
|
||||
|
@ -1,15 +1,17 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
import {
|
||||
ControlModes,
|
||||
controlNetControlModeChanged,
|
||||
} from 'features/controlNet/store/controlNetSlice';
|
||||
import { useCallback } from 'react';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
type ParamControlNetControlModeProps = {
|
||||
controlNetId: string;
|
||||
controlMode: string;
|
||||
};
|
||||
|
||||
const CONTROL_MODE_DATA = [
|
||||
@ -22,8 +24,23 @@ const CONTROL_MODE_DATA = [
|
||||
export default function ParamControlNetControlMode(
|
||||
props: ParamControlNetControlModeProps
|
||||
) {
|
||||
const { controlNetId, controlMode = false } = props;
|
||||
const { controlNetId } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
stateSelector,
|
||||
({ controlNet }) => {
|
||||
const { controlMode, isEnabled } =
|
||||
controlNet.controlNets[controlNetId];
|
||||
return { controlMode, isEnabled };
|
||||
},
|
||||
defaultSelectorOptions
|
||||
),
|
||||
[controlNetId]
|
||||
);
|
||||
|
||||
const { controlMode, isEnabled } = useAppSelector(selector);
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
@ -36,7 +53,8 @@ export default function ParamControlNetControlMode(
|
||||
|
||||
return (
|
||||
<IAIMantineSelect
|
||||
label={t('parameters.controlNetControlMode')}
|
||||
disabled={!isEnabled}
|
||||
label="Control Mode"
|
||||
data={CONTROL_MODE_DATA}
|
||||
value={String(controlMode)}
|
||||
onChange={handleControlModeChange}
|
||||
|
@ -29,6 +29,9 @@ const ParamControlNetFeatureToggle = () => {
|
||||
label="Enable ControlNet"
|
||||
isChecked={isEnabled}
|
||||
onChange={handleChange}
|
||||
formControlProps={{
|
||||
width: '100%',
|
||||
}}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
@ -1,28 +0,0 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAISwitch from 'common/components/IAISwitch';
|
||||
import { controlNetToggled } from 'features/controlNet/store/controlNetSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
|
||||
type ParamControlNetIsEnabledProps = {
|
||||
controlNetId: string;
|
||||
isEnabled: boolean;
|
||||
};
|
||||
|
||||
const ParamControlNetIsEnabled = (props: ParamControlNetIsEnabledProps) => {
|
||||
const { controlNetId, isEnabled } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const handleIsEnabledChanged = useCallback(() => {
|
||||
dispatch(controlNetToggled({ controlNetId }));
|
||||
}, [dispatch, controlNetId]);
|
||||
|
||||
return (
|
||||
<IAISwitch
|
||||
label="Enabled"
|
||||
isChecked={isEnabled}
|
||||
onChange={handleIsEnabledChanged}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ParamControlNetIsEnabled);
|
@ -1,36 +0,0 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAIFullCheckbox from 'common/components/IAIFullCheckbox';
|
||||
import IAISwitch from 'common/components/IAISwitch';
|
||||
import {
|
||||
controlNetToggled,
|
||||
isControlNetImagePreprocessedToggled,
|
||||
} from 'features/controlNet/store/controlNetSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
|
||||
type ParamControlNetIsEnabledProps = {
|
||||
controlNetId: string;
|
||||
isControlImageProcessed: boolean;
|
||||
};
|
||||
|
||||
const ParamControlNetIsEnabled = (props: ParamControlNetIsEnabledProps) => {
|
||||
const { controlNetId, isControlImageProcessed } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const handleIsControlImageProcessedToggled = useCallback(() => {
|
||||
dispatch(
|
||||
isControlNetImagePreprocessedToggled({
|
||||
controlNetId,
|
||||
})
|
||||
);
|
||||
}, [controlNetId, dispatch]);
|
||||
|
||||
return (
|
||||
<IAISwitch
|
||||
label="Preprocess"
|
||||
isChecked={isControlImageProcessed}
|
||||
onChange={handleIsControlImageProcessedToggled}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ParamControlNetIsEnabled);
|
@ -1,59 +1,118 @@
|
||||
import { SelectItem } from '@mantine/core';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIMantineSearchableSelect, {
|
||||
IAISelectDataType,
|
||||
} from 'common/components/IAIMantineSearchableSelect';
|
||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||
import {
|
||||
CONTROLNET_MODELS,
|
||||
ControlNetModelName,
|
||||
} from 'features/controlNet/store/constants';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
|
||||
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
|
||||
import { controlNetModelChanged } from 'features/controlNet/store/controlNetSlice';
|
||||
import { configSelector } from 'features/system/store/configSelectors';
|
||||
import { map } from 'lodash-es';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||
import { modelIdToControlNetModelParam } from 'features/parameters/util/modelIdToControlNetModelParam';
|
||||
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
type ParamControlNetModelProps = {
|
||||
controlNetId: string;
|
||||
model: ControlNetModelName;
|
||||
};
|
||||
|
||||
const selector = createSelector(configSelector, (config) => {
|
||||
const controlNetModels: IAISelectDataType[] = map(CONTROLNET_MODELS, (m) => ({
|
||||
label: m.label,
|
||||
value: m.type,
|
||||
})).filter(
|
||||
(d) =>
|
||||
!config.sd.disabledControlNetModels.includes(
|
||||
d.value as ControlNetModelName
|
||||
)
|
||||
const ParamControlNetModel = (props: ParamControlNetModelProps) => {
|
||||
const { controlNetId } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const isBusy = useAppSelector(selectIsBusy);
|
||||
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
stateSelector,
|
||||
({ generation, controlNet }) => {
|
||||
const { model } = generation;
|
||||
const controlNetModel = controlNet.controlNets[controlNetId]?.model;
|
||||
const isEnabled = controlNet.controlNets[controlNetId]?.isEnabled;
|
||||
return { mainModel: model, controlNetModel, isEnabled };
|
||||
},
|
||||
defaultSelectorOptions
|
||||
),
|
||||
[controlNetId]
|
||||
);
|
||||
|
||||
return controlNetModels;
|
||||
});
|
||||
const { mainModel, controlNetModel, isEnabled } = useAppSelector(selector);
|
||||
|
||||
const ParamControlNetModel = (props: ParamControlNetModelProps) => {
|
||||
const { controlNetId, model } = props;
|
||||
const controlNetModels = useAppSelector(selector);
|
||||
const dispatch = useAppDispatch();
|
||||
const isReady = useIsReadyToInvoke();
|
||||
const { data: controlNetModels } = useGetControlNetModelsQuery();
|
||||
|
||||
const data = useMemo(() => {
|
||||
if (!controlNetModels) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const data: SelectItem[] = [];
|
||||
|
||||
forEach(controlNetModels.entities, (model, id) => {
|
||||
if (!model) {
|
||||
return;
|
||||
}
|
||||
|
||||
const disabled = model?.base_model !== mainModel?.base_model;
|
||||
|
||||
data.push({
|
||||
value: id,
|
||||
label: model.model_name,
|
||||
group: MODEL_TYPE_MAP[model.base_model],
|
||||
disabled,
|
||||
tooltip: disabled
|
||||
? `Incompatible base model: ${model.base_model}`
|
||||
: undefined,
|
||||
});
|
||||
});
|
||||
|
||||
return data;
|
||||
}, [controlNetModels, mainModel?.base_model]);
|
||||
|
||||
// grab the full model entity from the RTK Query cache
|
||||
const selectedModel = useMemo(
|
||||
() =>
|
||||
controlNetModels?.entities[
|
||||
`${controlNetModel?.base_model}/controlnet/${controlNetModel?.model_name}`
|
||||
] ?? null,
|
||||
[
|
||||
controlNetModel?.base_model,
|
||||
controlNetModel?.model_name,
|
||||
controlNetModels?.entities,
|
||||
]
|
||||
);
|
||||
|
||||
const handleModelChanged = useCallback(
|
||||
(val: string | null) => {
|
||||
// TODO: do not cast
|
||||
const model = val as ControlNetModelName;
|
||||
dispatch(controlNetModelChanged({ controlNetId, model }));
|
||||
(v: string | null) => {
|
||||
if (!v) {
|
||||
return;
|
||||
}
|
||||
|
||||
const newControlNetModel = modelIdToControlNetModelParam(v);
|
||||
|
||||
if (!newControlNetModel) {
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(
|
||||
controlNetModelChanged({ controlNetId, model: newControlNetModel })
|
||||
);
|
||||
},
|
||||
[controlNetId, dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<IAIMantineSearchableSelect
|
||||
data={controlNetModels}
|
||||
value={model}
|
||||
itemComponent={IAIMantineSelectItemWithTooltip}
|
||||
data={data}
|
||||
error={
|
||||
!selectedModel || mainModel?.base_model !== selectedModel.base_model
|
||||
}
|
||||
placeholder="Select a model"
|
||||
value={selectedModel?.id ?? null}
|
||||
onChange={handleModelChanged}
|
||||
disabled={!isReady}
|
||||
tooltip={model}
|
||||
disabled={isBusy || !isEnabled}
|
||||
tooltip={selectedModel?.description}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
@ -1,24 +1,22 @@
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIMantineSearchableSelect, {
|
||||
IAISelectDataType,
|
||||
} from 'common/components/IAIMantineSearchableSelect';
|
||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||
import { configSelector } from 'features/system/store/configSelectors';
|
||||
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||
import { map } from 'lodash-es';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { CONTROLNET_PROCESSORS } from '../../store/constants';
|
||||
import { controlNetProcessorTypeChanged } from '../../store/controlNetSlice';
|
||||
import {
|
||||
ControlNetProcessorNode,
|
||||
ControlNetProcessorType,
|
||||
} from '../../store/types';
|
||||
import { ControlNetProcessorType } from '../../store/types';
|
||||
import { FormControl, FormLabel } from '@chakra-ui/react';
|
||||
|
||||
type ParamControlNetProcessorSelectProps = {
|
||||
controlNetId: string;
|
||||
processorNode: ControlNetProcessorNode;
|
||||
};
|
||||
|
||||
const selector = createSelector(
|
||||
@ -54,10 +52,24 @@ const selector = createSelector(
|
||||
const ParamControlNetProcessorSelect = (
|
||||
props: ParamControlNetProcessorSelectProps
|
||||
) => {
|
||||
const { controlNetId, processorNode } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const isReady = useIsReadyToInvoke();
|
||||
const { controlNetId } = props;
|
||||
const processorNodeSelector = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
stateSelector,
|
||||
({ controlNet }) => {
|
||||
const { isEnabled, processorNode } =
|
||||
controlNet.controlNets[controlNetId];
|
||||
return { isEnabled, processorNode };
|
||||
},
|
||||
defaultSelectorOptions
|
||||
),
|
||||
[controlNetId]
|
||||
);
|
||||
const isBusy = useAppSelector(selectIsBusy);
|
||||
const controlNetProcessors = useAppSelector(selector);
|
||||
const { isEnabled, processorNode } = useAppSelector(processorNodeSelector);
|
||||
|
||||
const handleProcessorTypeChanged = useCallback(
|
||||
(v: string | null) => {
|
||||
@ -77,7 +89,7 @@ const ParamControlNetProcessorSelect = (
|
||||
value={processorNode.type ?? 'canny_image_processor'}
|
||||
data={controlNetProcessors}
|
||||
onChange={handleProcessorTypeChanged}
|
||||
disabled={!isReady}
|
||||
disabled={isBusy || !isEnabled}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
@ -1,18 +1,32 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { controlNetWeightChanged } from 'features/controlNet/store/controlNetSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
|
||||
type ParamControlNetWeightProps = {
|
||||
controlNetId: string;
|
||||
weight: number;
|
||||
mini?: boolean;
|
||||
};
|
||||
|
||||
const ParamControlNetWeight = (props: ParamControlNetWeightProps) => {
|
||||
const { controlNetId, weight, mini = false } = props;
|
||||
const { controlNetId } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
stateSelector,
|
||||
({ controlNet }) => {
|
||||
const { weight, isEnabled } = controlNet.controlNets[controlNetId];
|
||||
return { weight, isEnabled };
|
||||
},
|
||||
defaultSelectorOptions
|
||||
),
|
||||
[controlNetId]
|
||||
);
|
||||
|
||||
const { weight, isEnabled } = useAppSelector(selector);
|
||||
const handleWeightChanged = useCallback(
|
||||
(weight: number) => {
|
||||
dispatch(controlNetWeightChanged({ controlNetId, weight }));
|
||||
@ -22,15 +36,15 @@ const ParamControlNetWeight = (props: ParamControlNetWeightProps) => {
|
||||
|
||||
return (
|
||||
<IAISlider
|
||||
isDisabled={!isEnabled}
|
||||
label={'Weight'}
|
||||
sliderFormLabelProps={{ pb: 2 }}
|
||||
value={weight}
|
||||
onChange={handleWeightChanged}
|
||||
min={-1}
|
||||
max={1}
|
||||
min={0}
|
||||
max={2}
|
||||
step={0.01}
|
||||
withSliderMarks={!mini}
|
||||
sliderMarks={[-1, 0, 1]}
|
||||
withSliderMarks
|
||||
sliderMarks={[0, 1, 2]}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
@ -1,22 +1,25 @@
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
||||
import { RequiredCannyImageProcessorInvocation } from 'features/controlNet/store/types';
|
||||
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||
|
||||
const DEFAULTS = CONTROLNET_PROCESSORS.canny_image_processor.default;
|
||||
const DEFAULTS = CONTROLNET_PROCESSORS.canny_image_processor
|
||||
.default as RequiredCannyImageProcessorInvocation;
|
||||
|
||||
type CannyProcessorProps = {
|
||||
controlNetId: string;
|
||||
processorNode: RequiredCannyImageProcessorInvocation;
|
||||
isEnabled: boolean;
|
||||
};
|
||||
|
||||
const CannyProcessor = (props: CannyProcessorProps) => {
|
||||
const { controlNetId, processorNode } = props;
|
||||
const { controlNetId, processorNode, isEnabled } = props;
|
||||
const { low_threshold, high_threshold } = processorNode;
|
||||
const isReady = useIsReadyToInvoke();
|
||||
const isBusy = useAppSelector(selectIsBusy);
|
||||
const processorChanged = useProcessorNodeChanged();
|
||||
|
||||
const handleLowThresholdChanged = useCallback(
|
||||
@ -48,7 +51,7 @@ const CannyProcessor = (props: CannyProcessorProps) => {
|
||||
return (
|
||||
<ProcessorWrapper>
|
||||
<IAISlider
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
label="Low Threshold"
|
||||
value={low_threshold}
|
||||
onChange={handleLowThresholdChanged}
|
||||
@ -60,7 +63,7 @@ const CannyProcessor = (props: CannyProcessorProps) => {
|
||||
withSliderMarks
|
||||
/>
|
||||
<IAISlider
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
label="High Threshold"
|
||||
value={high_threshold}
|
||||
onChange={handleHighThresholdChanged}
|
||||
|
@ -4,20 +4,23 @@ import { RequiredContentShuffleImageProcessorInvocation } from 'features/control
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||
|
||||
const DEFAULTS = CONTROLNET_PROCESSORS.content_shuffle_image_processor.default;
|
||||
const DEFAULTS = CONTROLNET_PROCESSORS.content_shuffle_image_processor
|
||||
.default as RequiredContentShuffleImageProcessorInvocation;
|
||||
|
||||
type Props = {
|
||||
controlNetId: string;
|
||||
processorNode: RequiredContentShuffleImageProcessorInvocation;
|
||||
isEnabled: boolean;
|
||||
};
|
||||
|
||||
const ContentShuffleProcessor = (props: Props) => {
|
||||
const { controlNetId, processorNode } = props;
|
||||
const { controlNetId, processorNode, isEnabled } = props;
|
||||
const { image_resolution, detect_resolution, w, h, f } = processorNode;
|
||||
const processorChanged = useProcessorNodeChanged();
|
||||
const isReady = useIsReadyToInvoke();
|
||||
const isBusy = useAppSelector(selectIsBusy);
|
||||
|
||||
const handleDetectResolutionChanged = useCallback(
|
||||
(v: number) => {
|
||||
@ -96,7 +99,7 @@ const ContentShuffleProcessor = (props: Props) => {
|
||||
max={4096}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
<IAISlider
|
||||
label="Image Resolution"
|
||||
@ -108,7 +111,7 @@ const ContentShuffleProcessor = (props: Props) => {
|
||||
max={4096}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
<IAISlider
|
||||
label="W"
|
||||
@ -120,7 +123,7 @@ const ContentShuffleProcessor = (props: Props) => {
|
||||
max={4096}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
<IAISlider
|
||||
label="H"
|
||||
@ -132,7 +135,7 @@ const ContentShuffleProcessor = (props: Props) => {
|
||||
max={4096}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
<IAISlider
|
||||
label="F"
|
||||
@ -144,7 +147,7 @@ const ContentShuffleProcessor = (props: Props) => {
|
||||
max={4096}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
</ProcessorWrapper>
|
||||
);
|
||||
|
@ -1,25 +1,29 @@
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import IAISwitch from 'common/components/IAISwitch';
|
||||
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
||||
import { RequiredHedImageProcessorInvocation } from 'features/controlNet/store/types';
|
||||
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||
import { ChangeEvent, memo, useCallback } from 'react';
|
||||
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||
|
||||
const DEFAULTS = CONTROLNET_PROCESSORS.hed_image_processor.default;
|
||||
const DEFAULTS = CONTROLNET_PROCESSORS.hed_image_processor
|
||||
.default as RequiredHedImageProcessorInvocation;
|
||||
|
||||
type HedProcessorProps = {
|
||||
controlNetId: string;
|
||||
processorNode: RequiredHedImageProcessorInvocation;
|
||||
isEnabled: boolean;
|
||||
};
|
||||
|
||||
const HedPreprocessor = (props: HedProcessorProps) => {
|
||||
const {
|
||||
controlNetId,
|
||||
processorNode: { detect_resolution, image_resolution, scribble },
|
||||
isEnabled,
|
||||
} = props;
|
||||
const isReady = useIsReadyToInvoke();
|
||||
const isBusy = useAppSelector(selectIsBusy);
|
||||
const processorChanged = useProcessorNodeChanged();
|
||||
|
||||
const handleDetectResolutionChanged = useCallback(
|
||||
@ -67,7 +71,7 @@ const HedPreprocessor = (props: HedProcessorProps) => {
|
||||
max={4096}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
<IAISlider
|
||||
label="Image Resolution"
|
||||
@ -79,13 +83,13 @@ const HedPreprocessor = (props: HedProcessorProps) => {
|
||||
max={4096}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
<IAISwitch
|
||||
label="Scribble"
|
||||
isChecked={scribble}
|
||||
onChange={handleScribbleChanged}
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
</ProcessorWrapper>
|
||||
);
|
||||
|
@ -1,23 +1,26 @@
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
||||
import { RequiredLineartAnimeImageProcessorInvocation } from 'features/controlNet/store/types';
|
||||
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||
|
||||
const DEFAULTS = CONTROLNET_PROCESSORS.lineart_anime_image_processor.default;
|
||||
const DEFAULTS = CONTROLNET_PROCESSORS.lineart_anime_image_processor
|
||||
.default as RequiredLineartAnimeImageProcessorInvocation;
|
||||
|
||||
type Props = {
|
||||
controlNetId: string;
|
||||
processorNode: RequiredLineartAnimeImageProcessorInvocation;
|
||||
isEnabled: boolean;
|
||||
};
|
||||
|
||||
const LineartAnimeProcessor = (props: Props) => {
|
||||
const { controlNetId, processorNode } = props;
|
||||
const { controlNetId, processorNode, isEnabled } = props;
|
||||
const { image_resolution, detect_resolution } = processorNode;
|
||||
const processorChanged = useProcessorNodeChanged();
|
||||
const isReady = useIsReadyToInvoke();
|
||||
const isBusy = useAppSelector(selectIsBusy);
|
||||
|
||||
const handleDetectResolutionChanged = useCallback(
|
||||
(v: number) => {
|
||||
@ -57,7 +60,7 @@ const LineartAnimeProcessor = (props: Props) => {
|
||||
max={4096}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
<IAISlider
|
||||
label="Image Resolution"
|
||||
@ -69,7 +72,7 @@ const LineartAnimeProcessor = (props: Props) => {
|
||||
max={4096}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
</ProcessorWrapper>
|
||||
);
|
||||
|
@ -1,24 +1,27 @@
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import IAISwitch from 'common/components/IAISwitch';
|
||||
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
||||
import { RequiredLineartImageProcessorInvocation } from 'features/controlNet/store/types';
|
||||
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||
import { ChangeEvent, memo, useCallback } from 'react';
|
||||
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||
|
||||
const DEFAULTS = CONTROLNET_PROCESSORS.lineart_image_processor.default;
|
||||
const DEFAULTS = CONTROLNET_PROCESSORS.lineart_image_processor
|
||||
.default as RequiredLineartImageProcessorInvocation;
|
||||
|
||||
type LineartProcessorProps = {
|
||||
controlNetId: string;
|
||||
processorNode: RequiredLineartImageProcessorInvocation;
|
||||
isEnabled: boolean;
|
||||
};
|
||||
|
||||
const LineartProcessor = (props: LineartProcessorProps) => {
|
||||
const { controlNetId, processorNode } = props;
|
||||
const { controlNetId, processorNode, isEnabled } = props;
|
||||
const { image_resolution, detect_resolution, coarse } = processorNode;
|
||||
const processorChanged = useProcessorNodeChanged();
|
||||
const isReady = useIsReadyToInvoke();
|
||||
const isBusy = useAppSelector(selectIsBusy);
|
||||
|
||||
const handleDetectResolutionChanged = useCallback(
|
||||
(v: number) => {
|
||||
@ -65,7 +68,7 @@ const LineartProcessor = (props: LineartProcessorProps) => {
|
||||
max={4096}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
<IAISlider
|
||||
label="Image Resolution"
|
||||
@ -77,13 +80,13 @@ const LineartProcessor = (props: LineartProcessorProps) => {
|
||||
max={4096}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
<IAISwitch
|
||||
label="Coarse"
|
||||
isChecked={coarse}
|
||||
onChange={handleCoarseChanged}
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
</ProcessorWrapper>
|
||||
);
|
||||
|
@ -1,23 +1,26 @@
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
||||
import { RequiredMediapipeFaceProcessorInvocation } from 'features/controlNet/store/types';
|
||||
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||
|
||||
const DEFAULTS = CONTROLNET_PROCESSORS.mediapipe_face_processor.default;
|
||||
const DEFAULTS = CONTROLNET_PROCESSORS.mediapipe_face_processor
|
||||
.default as RequiredMediapipeFaceProcessorInvocation;
|
||||
|
||||
type Props = {
|
||||
controlNetId: string;
|
||||
processorNode: RequiredMediapipeFaceProcessorInvocation;
|
||||
isEnabled: boolean;
|
||||
};
|
||||
|
||||
const MediapipeFaceProcessor = (props: Props) => {
|
||||
const { controlNetId, processorNode } = props;
|
||||
const { controlNetId, processorNode, isEnabled } = props;
|
||||
const { max_faces, min_confidence } = processorNode;
|
||||
const processorChanged = useProcessorNodeChanged();
|
||||
const isReady = useIsReadyToInvoke();
|
||||
const isBusy = useAppSelector(selectIsBusy);
|
||||
|
||||
const handleMaxFacesChanged = useCallback(
|
||||
(v: number) => {
|
||||
@ -53,7 +56,7 @@ const MediapipeFaceProcessor = (props: Props) => {
|
||||
max={20}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
<IAISlider
|
||||
label="Min Confidence"
|
||||
@ -66,7 +69,7 @@ const MediapipeFaceProcessor = (props: Props) => {
|
||||
step={0.01}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
</ProcessorWrapper>
|
||||
);
|
||||
|
@ -1,23 +1,26 @@
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
||||
import { RequiredMidasDepthImageProcessorInvocation } from 'features/controlNet/store/types';
|
||||
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||
|
||||
const DEFAULTS = CONTROLNET_PROCESSORS.midas_depth_image_processor.default;
|
||||
const DEFAULTS = CONTROLNET_PROCESSORS.midas_depth_image_processor
|
||||
.default as RequiredMidasDepthImageProcessorInvocation;
|
||||
|
||||
type Props = {
|
||||
controlNetId: string;
|
||||
processorNode: RequiredMidasDepthImageProcessorInvocation;
|
||||
isEnabled: boolean;
|
||||
};
|
||||
|
||||
const MidasDepthProcessor = (props: Props) => {
|
||||
const { controlNetId, processorNode } = props;
|
||||
const { controlNetId, processorNode, isEnabled } = props;
|
||||
const { a_mult, bg_th } = processorNode;
|
||||
const processorChanged = useProcessorNodeChanged();
|
||||
const isReady = useIsReadyToInvoke();
|
||||
const isBusy = useAppSelector(selectIsBusy);
|
||||
|
||||
const handleAMultChanged = useCallback(
|
||||
(v: number) => {
|
||||
@ -54,7 +57,7 @@ const MidasDepthProcessor = (props: Props) => {
|
||||
step={0.01}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
<IAISlider
|
||||
label="bg_th"
|
||||
@ -67,7 +70,7 @@ const MidasDepthProcessor = (props: Props) => {
|
||||
step={0.01}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
</ProcessorWrapper>
|
||||
);
|
||||
|
@ -1,23 +1,26 @@
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
||||
import { RequiredMlsdImageProcessorInvocation } from 'features/controlNet/store/types';
|
||||
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||
|
||||
const DEFAULTS = CONTROLNET_PROCESSORS.mlsd_image_processor.default;
|
||||
const DEFAULTS = CONTROLNET_PROCESSORS.mlsd_image_processor
|
||||
.default as RequiredMlsdImageProcessorInvocation;
|
||||
|
||||
type Props = {
|
||||
controlNetId: string;
|
||||
processorNode: RequiredMlsdImageProcessorInvocation;
|
||||
isEnabled: boolean;
|
||||
};
|
||||
|
||||
const MlsdImageProcessor = (props: Props) => {
|
||||
const { controlNetId, processorNode } = props;
|
||||
const { controlNetId, processorNode, isEnabled } = props;
|
||||
const { image_resolution, detect_resolution, thr_d, thr_v } = processorNode;
|
||||
const processorChanged = useProcessorNodeChanged();
|
||||
const isReady = useIsReadyToInvoke();
|
||||
const isBusy = useAppSelector(selectIsBusy);
|
||||
|
||||
const handleDetectResolutionChanged = useCallback(
|
||||
(v: number) => {
|
||||
@ -79,7 +82,7 @@ const MlsdImageProcessor = (props: Props) => {
|
||||
max={4096}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
<IAISlider
|
||||
label="Image Resolution"
|
||||
@ -91,7 +94,7 @@ const MlsdImageProcessor = (props: Props) => {
|
||||
max={4096}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
<IAISlider
|
||||
label="W"
|
||||
@ -104,7 +107,7 @@ const MlsdImageProcessor = (props: Props) => {
|
||||
step={0.01}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
<IAISlider
|
||||
label="H"
|
||||
@ -117,7 +120,7 @@ const MlsdImageProcessor = (props: Props) => {
|
||||
step={0.01}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
</ProcessorWrapper>
|
||||
);
|
||||
|
@ -1,23 +1,26 @@
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
||||
import { RequiredNormalbaeImageProcessorInvocation } from 'features/controlNet/store/types';
|
||||
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||
|
||||
const DEFAULTS = CONTROLNET_PROCESSORS.normalbae_image_processor.default;
|
||||
const DEFAULTS = CONTROLNET_PROCESSORS.normalbae_image_processor
|
||||
.default as RequiredNormalbaeImageProcessorInvocation;
|
||||
|
||||
type Props = {
|
||||
controlNetId: string;
|
||||
processorNode: RequiredNormalbaeImageProcessorInvocation;
|
||||
isEnabled: boolean;
|
||||
};
|
||||
|
||||
const NormalBaeProcessor = (props: Props) => {
|
||||
const { controlNetId, processorNode } = props;
|
||||
const { controlNetId, processorNode, isEnabled } = props;
|
||||
const { image_resolution, detect_resolution } = processorNode;
|
||||
const processorChanged = useProcessorNodeChanged();
|
||||
const isReady = useIsReadyToInvoke();
|
||||
const isBusy = useAppSelector(selectIsBusy);
|
||||
|
||||
const handleDetectResolutionChanged = useCallback(
|
||||
(v: number) => {
|
||||
@ -57,7 +60,7 @@ const NormalBaeProcessor = (props: Props) => {
|
||||
max={4096}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
<IAISlider
|
||||
label="Image Resolution"
|
||||
@ -69,7 +72,7 @@ const NormalBaeProcessor = (props: Props) => {
|
||||
max={4096}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
</ProcessorWrapper>
|
||||
);
|
||||
|
@ -1,24 +1,27 @@
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import IAISwitch from 'common/components/IAISwitch';
|
||||
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
||||
import { RequiredOpenposeImageProcessorInvocation } from 'features/controlNet/store/types';
|
||||
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||
import { ChangeEvent, memo, useCallback } from 'react';
|
||||
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||
|
||||
const DEFAULTS = CONTROLNET_PROCESSORS.openpose_image_processor.default;
|
||||
const DEFAULTS = CONTROLNET_PROCESSORS.openpose_image_processor
|
||||
.default as RequiredOpenposeImageProcessorInvocation;
|
||||
|
||||
type Props = {
|
||||
controlNetId: string;
|
||||
processorNode: RequiredOpenposeImageProcessorInvocation;
|
||||
isEnabled: boolean;
|
||||
};
|
||||
|
||||
const OpenposeProcessor = (props: Props) => {
|
||||
const { controlNetId, processorNode } = props;
|
||||
const { controlNetId, processorNode, isEnabled } = props;
|
||||
const { image_resolution, detect_resolution, hand_and_face } = processorNode;
|
||||
const processorChanged = useProcessorNodeChanged();
|
||||
const isReady = useIsReadyToInvoke();
|
||||
const isBusy = useAppSelector(selectIsBusy);
|
||||
|
||||
const handleDetectResolutionChanged = useCallback(
|
||||
(v: number) => {
|
||||
@ -65,7 +68,7 @@ const OpenposeProcessor = (props: Props) => {
|
||||
max={4096}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
<IAISlider
|
||||
label="Image Resolution"
|
||||
@ -77,13 +80,13 @@ const OpenposeProcessor = (props: Props) => {
|
||||
max={4096}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
<IAISwitch
|
||||
label="Hand and Face"
|
||||
isChecked={hand_and_face}
|
||||
onChange={handleHandAndFaceChanged}
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
</ProcessorWrapper>
|
||||
);
|
||||
|
@ -1,24 +1,27 @@
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import IAISwitch from 'common/components/IAISwitch';
|
||||
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
||||
import { RequiredPidiImageProcessorInvocation } from 'features/controlNet/store/types';
|
||||
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||
import { ChangeEvent, memo, useCallback } from 'react';
|
||||
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||
|
||||
const DEFAULTS = CONTROLNET_PROCESSORS.pidi_image_processor.default;
|
||||
const DEFAULTS = CONTROLNET_PROCESSORS.pidi_image_processor
|
||||
.default as RequiredPidiImageProcessorInvocation;
|
||||
|
||||
type Props = {
|
||||
controlNetId: string;
|
||||
processorNode: RequiredPidiImageProcessorInvocation;
|
||||
isEnabled: boolean;
|
||||
};
|
||||
|
||||
const PidiProcessor = (props: Props) => {
|
||||
const { controlNetId, processorNode } = props;
|
||||
const { controlNetId, processorNode, isEnabled } = props;
|
||||
const { image_resolution, detect_resolution, scribble, safe } = processorNode;
|
||||
const processorChanged = useProcessorNodeChanged();
|
||||
const isReady = useIsReadyToInvoke();
|
||||
const isBusy = useAppSelector(selectIsBusy);
|
||||
|
||||
const handleDetectResolutionChanged = useCallback(
|
||||
(v: number) => {
|
||||
@ -72,7 +75,7 @@ const PidiProcessor = (props: Props) => {
|
||||
max={4096}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
<IAISlider
|
||||
label="Image Resolution"
|
||||
@ -84,7 +87,7 @@ const PidiProcessor = (props: Props) => {
|
||||
max={4096}
|
||||
withInput
|
||||
withSliderMarks
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
<IAISwitch
|
||||
label="Scribble"
|
||||
@ -95,7 +98,7 @@ const PidiProcessor = (props: Props) => {
|
||||
label="Safe"
|
||||
isChecked={safe}
|
||||
onChange={handleSafeChanged}
|
||||
isDisabled={!isReady}
|
||||
isDisabled={isBusy || !isEnabled}
|
||||
/>
|
||||
</ProcessorWrapper>
|
||||
);
|
||||
|
@ -4,6 +4,7 @@ import { memo } from 'react';
|
||||
type Props = {
|
||||
controlNetId: string;
|
||||
processorNode: RequiredZoeDepthImageProcessorInvocation;
|
||||
isEnabled: boolean;
|
||||
};
|
||||
|
||||
const ZoeDepthProcessor = (props: Props) => {
|
||||
|
@ -173,91 +173,17 @@ export const CONTROLNET_PROCESSORS: ControlNetProcessorsDict = {
|
||||
},
|
||||
};
|
||||
|
||||
type ControlNetModelsDict = Record<string, ControlNetModel>;
|
||||
|
||||
type ControlNetModel = {
|
||||
type: string;
|
||||
label: string;
|
||||
description?: string;
|
||||
defaultProcessor?: ControlNetProcessorType;
|
||||
export const CONTROLNET_MODEL_DEFAULT_PROCESSORS: {
|
||||
[key: string]: ControlNetProcessorType;
|
||||
} = {
|
||||
canny: 'canny_image_processor',
|
||||
mlsd: 'mlsd_image_processor',
|
||||
depth: 'midas_depth_image_processor',
|
||||
bae: 'normalbae_image_processor',
|
||||
lineart: 'lineart_image_processor',
|
||||
lineart_anime: 'lineart_anime_image_processor',
|
||||
softedge: 'hed_image_processor',
|
||||
shuffle: 'content_shuffle_image_processor',
|
||||
openpose: 'openpose_image_processor',
|
||||
mediapipe: 'mediapipe_face_processor',
|
||||
};
|
||||
|
||||
export const CONTROLNET_MODELS: ControlNetModelsDict = {
|
||||
'lllyasviel/control_v11p_sd15_canny': {
|
||||
type: 'lllyasviel/control_v11p_sd15_canny',
|
||||
label: 'Canny',
|
||||
defaultProcessor: 'canny_image_processor',
|
||||
},
|
||||
'lllyasviel/control_v11p_sd15_inpaint': {
|
||||
type: 'lllyasviel/control_v11p_sd15_inpaint',
|
||||
label: 'Inpaint',
|
||||
defaultProcessor: 'none',
|
||||
},
|
||||
'lllyasviel/control_v11p_sd15_mlsd': {
|
||||
type: 'lllyasviel/control_v11p_sd15_mlsd',
|
||||
label: 'M-LSD',
|
||||
defaultProcessor: 'mlsd_image_processor',
|
||||
},
|
||||
'lllyasviel/control_v11f1p_sd15_depth': {
|
||||
type: 'lllyasviel/control_v11f1p_sd15_depth',
|
||||
label: 'Depth',
|
||||
defaultProcessor: 'midas_depth_image_processor',
|
||||
},
|
||||
'lllyasviel/control_v11p_sd15_normalbae': {
|
||||
type: 'lllyasviel/control_v11p_sd15_normalbae',
|
||||
label: 'Normal Map (BAE)',
|
||||
defaultProcessor: 'normalbae_image_processor',
|
||||
},
|
||||
'lllyasviel/control_v11p_sd15_seg': {
|
||||
type: 'lllyasviel/control_v11p_sd15_seg',
|
||||
label: 'Segmentation',
|
||||
defaultProcessor: 'none',
|
||||
},
|
||||
'lllyasviel/control_v11p_sd15_lineart': {
|
||||
type: 'lllyasviel/control_v11p_sd15_lineart',
|
||||
label: 'Lineart',
|
||||
defaultProcessor: 'lineart_image_processor',
|
||||
},
|
||||
'lllyasviel/control_v11p_sd15s2_lineart_anime': {
|
||||
type: 'lllyasviel/control_v11p_sd15s2_lineart_anime',
|
||||
label: 'Lineart Anime',
|
||||
defaultProcessor: 'lineart_anime_image_processor',
|
||||
},
|
||||
'lllyasviel/control_v11p_sd15_scribble': {
|
||||
type: 'lllyasviel/control_v11p_sd15_scribble',
|
||||
label: 'Scribble',
|
||||
defaultProcessor: 'none',
|
||||
},
|
||||
'lllyasviel/control_v11p_sd15_softedge': {
|
||||
type: 'lllyasviel/control_v11p_sd15_softedge',
|
||||
label: 'Soft Edge',
|
||||
defaultProcessor: 'hed_image_processor',
|
||||
},
|
||||
'lllyasviel/control_v11e_sd15_shuffle': {
|
||||
type: 'lllyasviel/control_v11e_sd15_shuffle',
|
||||
label: 'Content Shuffle',
|
||||
defaultProcessor: 'content_shuffle_image_processor',
|
||||
},
|
||||
'lllyasviel/control_v11p_sd15_openpose': {
|
||||
type: 'lllyasviel/control_v11p_sd15_openpose',
|
||||
label: 'Openpose',
|
||||
defaultProcessor: 'openpose_image_processor',
|
||||
},
|
||||
'lllyasviel/control_v11f1e_sd15_tile': {
|
||||
type: 'lllyasviel/control_v11f1e_sd15_tile',
|
||||
label: 'Tile (experimental)',
|
||||
defaultProcessor: 'none',
|
||||
},
|
||||
'lllyasviel/control_v11e_sd15_ip2p': {
|
||||
type: 'lllyasviel/control_v11e_sd15_ip2p',
|
||||
label: 'Pix2Pix (experimental)',
|
||||
defaultProcessor: 'none',
|
||||
},
|
||||
'CrucibleAI/ControlNetMediaPipeFace': {
|
||||
type: 'CrucibleAI/ControlNetMediaPipeFace',
|
||||
label: 'Mediapipe Face',
|
||||
defaultProcessor: 'mediapipe_face_processor',
|
||||
},
|
||||
};
|
||||
|
||||
export type ControlNetModelName = keyof typeof CONTROLNET_MODELS;
|
||||
|
@ -1,22 +1,20 @@
|
||||
import { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
import { ControlNetModelParam } from 'features/parameters/types/parameterSchemas';
|
||||
import { cloneDeep, forEach } from 'lodash-es';
|
||||
import { imageDeleted } from 'services/api/thunks/image';
|
||||
import { isAnySessionRejected } from 'services/api/thunks/session';
|
||||
import { appSocketInvocationError } from 'services/events/actions';
|
||||
import { controlNetImageProcessed } from './actions';
|
||||
import {
|
||||
CONTROLNET_MODEL_DEFAULT_PROCESSORS,
|
||||
CONTROLNET_PROCESSORS,
|
||||
} from './constants';
|
||||
import {
|
||||
ControlNetProcessorType,
|
||||
RequiredCannyImageProcessorInvocation,
|
||||
RequiredControlNetProcessorNode,
|
||||
} from './types';
|
||||
import {
|
||||
CONTROLNET_MODELS,
|
||||
CONTROLNET_PROCESSORS,
|
||||
ControlNetModelName,
|
||||
} from './constants';
|
||||
import { controlNetImageProcessed } from './actions';
|
||||
import { imageDeleted, imageUrlsReceived } from 'services/api/thunks/image';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { isAnySessionRejected } from 'services/api/thunks/session';
|
||||
import { appSocketInvocationError } from 'services/events/actions';
|
||||
|
||||
export type ControlModes =
|
||||
| 'balanced'
|
||||
@ -26,7 +24,7 @@ export type ControlModes =
|
||||
|
||||
export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
|
||||
isEnabled: true,
|
||||
model: CONTROLNET_MODELS['lllyasviel/control_v11p_sd15_canny'].type,
|
||||
model: null,
|
||||
weight: 1,
|
||||
beginStepPct: 0,
|
||||
endStepPct: 1,
|
||||
@ -42,7 +40,7 @@ export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
|
||||
export type ControlNetConfig = {
|
||||
controlNetId: string;
|
||||
isEnabled: boolean;
|
||||
model: ControlNetModelName;
|
||||
model: ControlNetModelParam | null;
|
||||
weight: number;
|
||||
beginStepPct: number;
|
||||
endStepPct: number;
|
||||
@ -86,6 +84,19 @@ export const controlNetSlice = createSlice({
|
||||
controlNetId,
|
||||
};
|
||||
},
|
||||
controlNetDuplicated: (
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
sourceControlNetId: string;
|
||||
newControlNetId: string;
|
||||
}>
|
||||
) => {
|
||||
const { sourceControlNetId, newControlNetId } = action.payload;
|
||||
|
||||
const newControlnet = cloneDeep(state.controlNets[sourceControlNetId]);
|
||||
newControlnet.controlNetId = newControlNetId;
|
||||
state.controlNets[newControlNetId] = newControlnet;
|
||||
},
|
||||
controlNetAddedFromImage: (
|
||||
state,
|
||||
action: PayloadAction<{ controlNetId: string; controlImage: string }>
|
||||
@ -147,7 +158,7 @@ export const controlNetSlice = createSlice({
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
controlNetId: string;
|
||||
model: ControlNetModelName;
|
||||
model: ControlNetModelParam;
|
||||
}>
|
||||
) => {
|
||||
const { controlNetId, model } = action.payload;
|
||||
@ -155,7 +166,15 @@ export const controlNetSlice = createSlice({
|
||||
state.controlNets[controlNetId].processedControlImage = null;
|
||||
|
||||
if (state.controlNets[controlNetId].shouldAutoConfig) {
|
||||
const processorType = CONTROLNET_MODELS[model].defaultProcessor;
|
||||
let processorType: ControlNetProcessorType | undefined = undefined;
|
||||
|
||||
for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
|
||||
if (model.model_name.includes(modelSubstring)) {
|
||||
processorType = CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (processorType) {
|
||||
state.controlNets[controlNetId].processorType = processorType;
|
||||
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[
|
||||
@ -241,9 +260,19 @@ export const controlNetSlice = createSlice({
|
||||
|
||||
if (newShouldAutoConfig) {
|
||||
// manage the processor for the user
|
||||
const processorType =
|
||||
CONTROLNET_MODELS[state.controlNets[controlNetId].model]
|
||||
.defaultProcessor;
|
||||
let processorType: ControlNetProcessorType | undefined = undefined;
|
||||
|
||||
for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
|
||||
if (
|
||||
state.controlNets[controlNetId].model?.model_name.includes(
|
||||
modelSubstring
|
||||
)
|
||||
) {
|
||||
processorType = CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (processorType) {
|
||||
state.controlNets[controlNetId].processorType = processorType;
|
||||
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[
|
||||
@ -272,7 +301,8 @@ export const controlNetSlice = createSlice({
|
||||
});
|
||||
|
||||
builder.addCase(imageDeleted.pending, (state, action) => {
|
||||
// Preemptively remove the image from the gallery
|
||||
// Preemptively remove the image from all controlnets
|
||||
// TODO: doesn't the imageusage stuff do this for us?
|
||||
const { image_name } = action.meta.arg;
|
||||
forEach(state.controlNets, (c) => {
|
||||
if (c.controlImage === image_name) {
|
||||
@ -285,21 +315,6 @@ export const controlNetSlice = createSlice({
|
||||
});
|
||||
});
|
||||
|
||||
// builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
|
||||
// const { image_name, image_url, thumbnail_url } = action.payload;
|
||||
|
||||
// forEach(state.controlNets, (c) => {
|
||||
// if (c.controlImage?.image_name === image_name) {
|
||||
// c.controlImage.image_url = image_url;
|
||||
// c.controlImage.thumbnail_url = thumbnail_url;
|
||||
// }
|
||||
// if (c.processedControlImage?.image_name === image_name) {
|
||||
// c.processedControlImage.image_url = image_url;
|
||||
// c.processedControlImage.thumbnail_url = thumbnail_url;
|
||||
// }
|
||||
// });
|
||||
// });
|
||||
|
||||
builder.addCase(appSocketInvocationError, (state, action) => {
|
||||
state.pendingControlImages = [];
|
||||
});
|
||||
@ -313,6 +328,7 @@ export const controlNetSlice = createSlice({
|
||||
export const {
|
||||
isControlNetEnabledToggled,
|
||||
controlNetAdded,
|
||||
controlNetDuplicated,
|
||||
controlNetAddedFromImage,
|
||||
controlNetRemoved,
|
||||
controlNetImageChanged,
|
||||
|
@ -118,6 +118,20 @@ const GalleryImageGrid = () => {
|
||||
);
|
||||
}, [dispatch, imageNames.length, galleryView]);
|
||||
|
||||
useEffect(() => {
|
||||
// Set up gallery scroler
|
||||
const { current: root } = rootRef;
|
||||
if (scroller && root) {
|
||||
initialize({
|
||||
target: root,
|
||||
elements: {
|
||||
viewport: scroller,
|
||||
},
|
||||
});
|
||||
}
|
||||
return () => osInstance()?.destroy();
|
||||
}, [scroller, initialize, osInstance]);
|
||||
|
||||
const handleEndReached = useMemo(() => {
|
||||
if (areMoreAvailable) {
|
||||
return handleLoadMoreImages;
|
||||
|
@ -3,6 +3,7 @@ import { LoRAModelParam } from 'features/parameters/types/parameterSchemas';
|
||||
import { LoRAModelConfigEntity } from 'services/api/endpoints/models';
|
||||
|
||||
export type LoRA = LoRAModelParam & {
|
||||
id: string;
|
||||
weight: number;
|
||||
};
|
||||
|
||||
@ -24,7 +25,7 @@ export const loraSlice = createSlice({
|
||||
reducers: {
|
||||
loraAdded: (state, action: PayloadAction<LoRAModelConfigEntity>) => {
|
||||
const { model_name, id, base_model } = action.payload;
|
||||
state.loras[id] = { model_name, base_model, ...defaultLoRAConfig };
|
||||
state.loras[id] = { id, model_name, base_model, ...defaultLoRAConfig };
|
||||
},
|
||||
loraRemoved: (state, action: PayloadAction<string>) => {
|
||||
const id = action.payload;
|
||||
|
@ -1,4 +1,5 @@
|
||||
import { Flex, Heading, Icon, Tooltip } from '@chakra-ui/react';
|
||||
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/hooks/useBuildInvocation';
|
||||
import { memo } from 'react';
|
||||
import { FaInfoCircle } from 'react-icons/fa';
|
||||
|
||||
@ -12,6 +13,7 @@ const IAINodeHeader = (props: IAINodeHeaderProps) => {
|
||||
const { nodeId, title, description } = props;
|
||||
return (
|
||||
<Flex
|
||||
className={DRAG_HANDLE_CLASSNAME}
|
||||
sx={{
|
||||
borderTopRadius: 'md',
|
||||
alignItems: 'center',
|
||||
|
@ -1,25 +1,25 @@
|
||||
import {
|
||||
InputFieldTemplate,
|
||||
InputFieldValue,
|
||||
InvocationTemplate,
|
||||
} from 'features/nodes/types/types';
|
||||
import { memo, ReactNode, useCallback } from 'react';
|
||||
import { map } from 'lodash-es';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { RootState } from 'app/store/store';
|
||||
import {
|
||||
Box,
|
||||
Divider,
|
||||
Flex,
|
||||
FormControl,
|
||||
FormLabel,
|
||||
HStack,
|
||||
Tooltip,
|
||||
Divider,
|
||||
} from '@chakra-ui/react';
|
||||
import FieldHandle from '../FieldHandle';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useIsValidConnection } from 'features/nodes/hooks/useIsValidConnection';
|
||||
import InputFieldComponent from '../InputFieldComponent';
|
||||
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
|
||||
import {
|
||||
InputFieldTemplate,
|
||||
InputFieldValue,
|
||||
InvocationTemplate,
|
||||
} from 'features/nodes/types/types';
|
||||
import { map } from 'lodash-es';
|
||||
import { ReactNode, memo, useCallback } from 'react';
|
||||
import FieldHandle from '../FieldHandle';
|
||||
import InputFieldComponent from '../InputFieldComponent';
|
||||
|
||||
interface IAINodeInputProps {
|
||||
nodeId: string;
|
||||
@ -35,6 +35,7 @@ function IAINodeInput(props: IAINodeInputProps) {
|
||||
|
||||
return (
|
||||
<Box
|
||||
className="nopan"
|
||||
position="relative"
|
||||
borderColor={
|
||||
!template
|
||||
@ -136,7 +137,7 @@ const IAINodeInputs = (props: IAINodeInputsProps) => {
|
||||
});
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" gap={2} p={2}>
|
||||
<Flex className="nopan" flexDir="column" gap={2} p={2}>
|
||||
{IAINodeInputsToRender}
|
||||
</Flex>
|
||||
);
|
||||
|
@ -7,6 +7,7 @@ import ClipInputFieldComponent from './fields/ClipInputFieldComponent';
|
||||
import ColorInputFieldComponent from './fields/ColorInputFieldComponent';
|
||||
import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent';
|
||||
import ControlInputFieldComponent from './fields/ControlInputFieldComponent';
|
||||
import ControlNetModelInputFieldComponent from './fields/ControlNetModelInputFieldComponent';
|
||||
import EnumInputFieldComponent from './fields/EnumInputFieldComponent';
|
||||
import ImageCollectionInputFieldComponent from './fields/ImageCollectionInputFieldComponent';
|
||||
import ImageInputFieldComponent from './fields/ImageInputFieldComponent';
|
||||
@ -174,6 +175,16 @@ const InputFieldComponent = (props: InputFieldComponentProps) => {
|
||||
);
|
||||
}
|
||||
|
||||
if (type === 'controlnet_model' && template.type === 'controlnet_model') {
|
||||
return (
|
||||
<ControlNetModelInputFieldComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
template={template}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (type === 'array' && template.type === 'array') {
|
||||
return (
|
||||
<ArrayInputFieldComponent
|
||||
|
@ -23,7 +23,14 @@ export const InvocationComponent = memo((props: NodeProps<InvocationValue>) => {
|
||||
if (!template) {
|
||||
return (
|
||||
<NodeWrapper selected={selected}>
|
||||
<Flex sx={{ alignItems: 'center', justifyContent: 'center' }}>
|
||||
<Flex
|
||||
className="nopan"
|
||||
sx={{
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
cursor: 'auto',
|
||||
}}
|
||||
>
|
||||
<Icon
|
||||
as={FaExclamationCircle}
|
||||
sx={{
|
||||
@ -46,7 +53,9 @@ export const InvocationComponent = memo((props: NodeProps<InvocationValue>) => {
|
||||
description={template.description}
|
||||
/>
|
||||
<Flex
|
||||
className={'nopan'}
|
||||
sx={{
|
||||
cursor: 'auto',
|
||||
flexDirection: 'column',
|
||||
borderBottomRadius: 'md',
|
||||
py: 2,
|
||||
|
@ -2,6 +2,8 @@ import { Box, useToken } from '@chakra-ui/react';
|
||||
import { NODE_MIN_WIDTH } from 'app/constants';
|
||||
|
||||
import { PropsWithChildren } from 'react';
|
||||
import { DRAG_HANDLE_CLASSNAME } from '../hooks/useBuildInvocation';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
|
||||
type NodeWrapperProps = PropsWithChildren & {
|
||||
selected: boolean;
|
||||
@ -13,8 +15,11 @@ const NodeWrapper = (props: NodeWrapperProps) => {
|
||||
'dark-lg',
|
||||
]);
|
||||
|
||||
const shift = useAppSelector((state) => state.hotkeys.shift);
|
||||
|
||||
return (
|
||||
<Box
|
||||
className={shift ? DRAG_HANDLE_CLASSNAME : 'nopan'}
|
||||
sx={{
|
||||
position: 'relative',
|
||||
borderRadius: 'md',
|
||||
|
@ -21,6 +21,7 @@ const ProgressImageNode = (props: NodeProps<InvocationValue>) => {
|
||||
/>
|
||||
|
||||
<Flex
|
||||
className="nopan"
|
||||
sx={{
|
||||
flexDirection: 'column',
|
||||
borderBottomRadius: 'md',
|
||||
|
@ -0,0 +1,103 @@
|
||||
import { SelectItem } from '@mantine/core';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
ControlNetModelInputFieldTemplate,
|
||||
ControlNetModelInputFieldValue,
|
||||
} from 'features/nodes/types/types';
|
||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||
import { modelIdToControlNetModelParam } from 'features/parameters/util/modelIdToControlNetModelParam';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
|
||||
import { FieldComponentProps } from './types';
|
||||
|
||||
const ControlNetModelInputFieldComponent = (
|
||||
props: FieldComponentProps<
|
||||
ControlNetModelInputFieldValue,
|
||||
ControlNetModelInputFieldTemplate
|
||||
>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
const controlNetModel = field.value;
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const { data: controlNetModels } = useGetControlNetModelsQuery();
|
||||
|
||||
// grab the full model entity from the RTK Query cache
|
||||
const selectedModel = useMemo(
|
||||
() =>
|
||||
controlNetModels?.entities[
|
||||
`${controlNetModel?.base_model}/controlnet/${controlNetModel?.model_name}`
|
||||
] ?? null,
|
||||
[
|
||||
controlNetModel?.base_model,
|
||||
controlNetModel?.model_name,
|
||||
controlNetModels?.entities,
|
||||
]
|
||||
);
|
||||
|
||||
const data = useMemo(() => {
|
||||
if (!controlNetModels) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const data: SelectItem[] = [];
|
||||
|
||||
forEach(controlNetModels.entities, (model, id) => {
|
||||
if (!model) {
|
||||
return;
|
||||
}
|
||||
|
||||
data.push({
|
||||
value: id,
|
||||
label: model.model_name,
|
||||
group: MODEL_TYPE_MAP[model.base_model],
|
||||
});
|
||||
});
|
||||
|
||||
return data;
|
||||
}, [controlNetModels]);
|
||||
|
||||
const handleValueChanged = useCallback(
|
||||
(v: string | null) => {
|
||||
if (!v) {
|
||||
return;
|
||||
}
|
||||
|
||||
const newControlNetModel = modelIdToControlNetModelParam(v);
|
||||
|
||||
if (!newControlNetModel) {
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(
|
||||
fieldValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value: newControlNetModel,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<IAIMantineSelect
|
||||
tooltip={selectedModel?.description}
|
||||
label={
|
||||
selectedModel?.base_model && MODEL_TYPE_MAP[selectedModel?.base_model]
|
||||
}
|
||||
value={selectedModel?.id ?? null}
|
||||
placeholder="Pick one"
|
||||
error={!selectedModel}
|
||||
data={data}
|
||||
onChange={handleValueChanged}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ControlNetModelInputFieldComponent);
|
@ -6,6 +6,7 @@ import {
|
||||
NumberInputStepper,
|
||||
} from '@chakra-ui/react';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { numberStringRegex } from 'common/components/IAINumberInput';
|
||||
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
FloatInputFieldTemplate,
|
||||
@ -13,7 +14,7 @@ import {
|
||||
IntegerInputFieldTemplate,
|
||||
IntegerInputFieldValue,
|
||||
} from 'features/nodes/types/types';
|
||||
import { memo } from 'react';
|
||||
import { memo, useEffect, useState } from 'react';
|
||||
import { FieldComponentProps } from './types';
|
||||
|
||||
const NumberInputFieldComponent = (
|
||||
@ -23,17 +24,42 @@ const NumberInputFieldComponent = (
|
||||
>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const [valueAsString, setValueAsString] = useState<string>(
|
||||
String(field.value)
|
||||
);
|
||||
|
||||
const handleValueChanged = (_: string, value: number) => {
|
||||
dispatch(fieldValueChanged({ nodeId, fieldName: field.name, value }));
|
||||
const handleValueChanged = (v: string) => {
|
||||
setValueAsString(v);
|
||||
// This allows negatives and decimals e.g. '-123', `.5`, `-0.2`, etc.
|
||||
if (!v.match(numberStringRegex)) {
|
||||
// Cast the value to number. Floor it if it should be an integer.
|
||||
dispatch(
|
||||
fieldValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value:
|
||||
props.template.type === 'integer'
|
||||
? Math.floor(Number(v))
|
||||
: Number(v),
|
||||
})
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
if (
|
||||
!valueAsString.match(numberStringRegex) &&
|
||||
field.value !== Number(valueAsString)
|
||||
) {
|
||||
setValueAsString(String(field.value));
|
||||
}
|
||||
}, [field.value, valueAsString]);
|
||||
|
||||
return (
|
||||
<NumberInput
|
||||
onChange={handleValueChanged}
|
||||
value={field.value}
|
||||
value={valueAsString}
|
||||
step={props.template.type === 'integer' ? 1 : 0.1}
|
||||
precision={props.template.type === 'integer' ? 0 : 3}
|
||||
>
|
||||
|
@ -18,6 +18,12 @@ const templatesSelector = createSelector(
|
||||
(nodes) => nodes.invocationTemplates
|
||||
);
|
||||
|
||||
export const DRAG_HANDLE_CLASSNAME = 'node-drag-handle';
|
||||
|
||||
export const SHARED_NODE_PROPERTIES: Partial<Node> = {
|
||||
dragHandle: `.${DRAG_HANDLE_CLASSNAME}`,
|
||||
};
|
||||
|
||||
export const useBuildInvocation = () => {
|
||||
const invocationTemplates = useAppSelector(templatesSelector);
|
||||
|
||||
@ -32,6 +38,7 @@ export const useBuildInvocation = () => {
|
||||
});
|
||||
|
||||
const node: Node = {
|
||||
...SHARED_NODE_PROPERTIES,
|
||||
id: 'progress_image',
|
||||
type: 'progress_image',
|
||||
position: { x: x, y: y },
|
||||
@ -91,6 +98,7 @@ export const useBuildInvocation = () => {
|
||||
});
|
||||
|
||||
const invocation: Node<InvocationValue> = {
|
||||
...SHARED_NODE_PROPERTIES,
|
||||
id: nodeId,
|
||||
type: 'invocation',
|
||||
position: { x: x, y: y },
|
||||
|
@ -1,6 +1,7 @@
|
||||
import { createSlice, PayloadAction } from '@reduxjs/toolkit';
|
||||
import { RootState } from 'app/store/store';
|
||||
import {
|
||||
ControlNetModelParam,
|
||||
LoRAModelParam,
|
||||
MainModelParam,
|
||||
VaeModelParam,
|
||||
@ -81,7 +82,8 @@ const nodesSlice = createSlice({
|
||||
| ImageField[]
|
||||
| MainModelParam
|
||||
| VaeModelParam
|
||||
| LoRAModelParam;
|
||||
| LoRAModelParam
|
||||
| ControlNetModelParam;
|
||||
}>
|
||||
) => {
|
||||
const { nodeId, fieldName, value } = action.payload;
|
||||
|
@ -19,6 +19,8 @@ export const FIELD_TYPE_MAP: Record<string, FieldType> = {
|
||||
model: 'model',
|
||||
vae_model: 'vae_model',
|
||||
lora_model: 'lora_model',
|
||||
controlnet_model: 'controlnet_model',
|
||||
ControlNetModelField: 'controlnet_model',
|
||||
array: 'array',
|
||||
item: 'item',
|
||||
ColorField: 'color',
|
||||
@ -130,6 +132,12 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
|
||||
title: 'LoRA',
|
||||
description: 'Models are models.',
|
||||
},
|
||||
controlnet_model: {
|
||||
color: 'teal',
|
||||
colorCssVar: getColorTokenCssVariable('teal'),
|
||||
title: 'ControlNet',
|
||||
description: 'Models are models.',
|
||||
},
|
||||
array: {
|
||||
color: 'gray',
|
||||
colorCssVar: getColorTokenCssVariable('gray'),
|
||||
|
@ -1,4 +1,5 @@
|
||||
import {
|
||||
ControlNetModelParam,
|
||||
LoRAModelParam,
|
||||
MainModelParam,
|
||||
VaeModelParam,
|
||||
@ -71,6 +72,7 @@ export type FieldType =
|
||||
| 'model'
|
||||
| 'vae_model'
|
||||
| 'lora_model'
|
||||
| 'controlnet_model'
|
||||
| 'array'
|
||||
| 'item'
|
||||
| 'color'
|
||||
@ -100,6 +102,7 @@ export type InputFieldValue =
|
||||
| MainModelInputFieldValue
|
||||
| VaeModelInputFieldValue
|
||||
| LoRAModelInputFieldValue
|
||||
| ControlNetModelInputFieldValue
|
||||
| ArrayInputFieldValue
|
||||
| ItemInputFieldValue
|
||||
| ColorInputFieldValue
|
||||
@ -127,6 +130,7 @@ export type InputFieldTemplate =
|
||||
| ModelInputFieldTemplate
|
||||
| VaeModelInputFieldTemplate
|
||||
| LoRAModelInputFieldTemplate
|
||||
| ControlNetModelInputFieldTemplate
|
||||
| ArrayInputFieldTemplate
|
||||
| ItemInputFieldTemplate
|
||||
| ColorInputFieldTemplate
|
||||
@ -249,6 +253,11 @@ export type LoRAModelInputFieldValue = FieldValueBase & {
|
||||
value?: LoRAModelParam;
|
||||
};
|
||||
|
||||
export type ControlNetModelInputFieldValue = FieldValueBase & {
|
||||
type: 'controlnet_model';
|
||||
value?: ControlNetModelParam;
|
||||
};
|
||||
|
||||
export type ArrayInputFieldValue = FieldValueBase & {
|
||||
type: 'array';
|
||||
value?: (string | number)[];
|
||||
@ -368,6 +377,11 @@ export type LoRAModelInputFieldTemplate = InputFieldTemplateBase & {
|
||||
type: 'lora_model';
|
||||
};
|
||||
|
||||
export type ControlNetModelInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: string;
|
||||
type: 'controlnet_model';
|
||||
};
|
||||
|
||||
export type ArrayInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: [];
|
||||
type: 'array';
|
||||
|
@ -9,6 +9,7 @@ import {
|
||||
ColorInputFieldTemplate,
|
||||
ConditioningInputFieldTemplate,
|
||||
ControlInputFieldTemplate,
|
||||
ControlNetModelInputFieldTemplate,
|
||||
EnumInputFieldTemplate,
|
||||
FieldType,
|
||||
FloatInputFieldTemplate,
|
||||
@ -207,6 +208,21 @@ const buildLoRAModelInputFieldTemplate = ({
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildControlNetModelInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): ControlNetModelInputFieldTemplate => {
|
||||
const template: ControlNetModelInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'controlnet_model',
|
||||
inputRequirement: 'always',
|
||||
inputKind: 'direct',
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildImageInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@ -479,6 +495,9 @@ export const buildInputFieldTemplate = (
|
||||
if (['lora_model'].includes(fieldType)) {
|
||||
return buildLoRAModelInputFieldTemplate({ schemaObject, baseField });
|
||||
}
|
||||
if (['controlnet_model'].includes(fieldType)) {
|
||||
return buildControlNetModelInputFieldTemplate({ schemaObject, baseField });
|
||||
}
|
||||
if (['enum'].includes(fieldType)) {
|
||||
return buildEnumInputFieldTemplate({ schemaObject, baseField });
|
||||
}
|
||||
|
@ -83,6 +83,10 @@ export const buildInputFieldValue = (
|
||||
if (template.type === 'lora_model') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
if (template.type === 'controlnet_model') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
}
|
||||
|
||||
return fieldValue;
|
||||
|
@ -60,7 +60,7 @@ export const addLoRAsToGraph = (
|
||||
const loraLoaderNode: LoraLoaderInvocation = {
|
||||
type: 'lora_loader',
|
||||
id: currentLoraNodeId,
|
||||
lora,
|
||||
lora: { model_name, base_model },
|
||||
weight,
|
||||
};
|
||||
|
||||
|
@ -2,12 +2,13 @@ import { Divider, Flex } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import IAICollapse from 'common/components/IAICollapse';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import ControlNet from 'features/controlNet/components/ControlNet';
|
||||
import ParamControlNetFeatureToggle from 'features/controlNet/components/parameters/ParamControlNetFeatureToggle';
|
||||
import {
|
||||
controlNetAdded,
|
||||
controlNetModelChanged,
|
||||
controlNetSelector,
|
||||
} from 'features/controlNet/store/controlNetSlice';
|
||||
import { getValidControlNets } from 'features/controlNet/util/getValidControlNets';
|
||||
@ -15,6 +16,8 @@ import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
import { map } from 'lodash-es';
|
||||
import { Fragment, memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FaPlus } from 'react-icons/fa';
|
||||
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
|
||||
const selector = createSelector(
|
||||
@ -39,10 +42,23 @@ const ParamControlNetCollapse = () => {
|
||||
const { controlNetsArray, activeLabel } = useAppSelector(selector);
|
||||
const isControlNetDisabled = useFeatureStatus('controlNet').isFeatureDisabled;
|
||||
const dispatch = useAppDispatch();
|
||||
const { firstModel } = useGetControlNetModelsQuery(undefined, {
|
||||
selectFromResult: (result) => {
|
||||
const firstModel = result.data?.entities[result.data?.ids[0]];
|
||||
return {
|
||||
firstModel,
|
||||
};
|
||||
},
|
||||
});
|
||||
|
||||
const handleClickedAddControlNet = useCallback(() => {
|
||||
dispatch(controlNetAdded({ controlNetId: uuidv4() }));
|
||||
}, [dispatch]);
|
||||
if (!firstModel) {
|
||||
return;
|
||||
}
|
||||
const controlNetId = uuidv4();
|
||||
dispatch(controlNetAdded({ controlNetId }));
|
||||
dispatch(controlNetModelChanged({ controlNetId, model: firstModel }));
|
||||
}, [dispatch, firstModel]);
|
||||
|
||||
if (isControlNetDisabled) {
|
||||
return null;
|
||||
@ -51,16 +67,39 @@ const ParamControlNetCollapse = () => {
|
||||
return (
|
||||
<IAICollapse label="ControlNet" activeLabel={activeLabel}>
|
||||
<Flex sx={{ flexDir: 'column', gap: 3 }}>
|
||||
<ParamControlNetFeatureToggle />
|
||||
<Flex gap={2} alignItems="center">
|
||||
<Flex
|
||||
sx={{
|
||||
flexDirection: 'column',
|
||||
w: '100%',
|
||||
gap: 2,
|
||||
px: 4,
|
||||
py: 2,
|
||||
borderRadius: 4,
|
||||
bg: 'base.200',
|
||||
_dark: {
|
||||
bg: 'base.850',
|
||||
},
|
||||
}}
|
||||
>
|
||||
<ParamControlNetFeatureToggle />
|
||||
</Flex>
|
||||
<IAIIconButton
|
||||
tooltip="Add ControlNet"
|
||||
aria-label="Add ControlNet"
|
||||
icon={<FaPlus />}
|
||||
isDisabled={!firstModel}
|
||||
flexGrow={1}
|
||||
size="md"
|
||||
onClick={handleClickedAddControlNet}
|
||||
/>
|
||||
</Flex>
|
||||
{controlNetsArray.map((c, i) => (
|
||||
<Fragment key={c.controlNetId}>
|
||||
{i > 0 && <Divider />}
|
||||
<ControlNet controlNet={c} />
|
||||
<ControlNet controlNetId={c.controlNetId} />
|
||||
</Fragment>
|
||||
))}
|
||||
<IAIButton flexGrow={1} onClick={handleClickedAddControlNet}>
|
||||
Add ControlNet
|
||||
</IAIButton>
|
||||
</Flex>
|
||||
</IAICollapse>
|
||||
);
|
||||
|
@ -37,6 +37,7 @@ const ParamVAEModelSelect = () => {
|
||||
return [];
|
||||
}
|
||||
|
||||
// add a "default" option, this means use the main model's included VAE
|
||||
const data: SelectItem[] = [
|
||||
{
|
||||
value: 'default',
|
||||
|
@ -17,8 +17,10 @@ import { FaPlay } from 'react-icons/fa';
|
||||
const IN_PROGRESS_STYLES: ChakraProps['sx'] = {
|
||||
_disabled: {
|
||||
bg: 'none',
|
||||
color: 'base.600',
|
||||
cursor: 'not-allowed',
|
||||
_hover: {
|
||||
color: 'base.600',
|
||||
bg: 'none',
|
||||
},
|
||||
},
|
||||
|
@ -180,6 +180,23 @@ export type LoRAModelParam = z.infer<typeof zLoRAModel>;
|
||||
*/
|
||||
export const isValidLoRAModel = (val: unknown): val is LoRAModelParam =>
|
||||
zLoRAModel.safeParse(val).success;
|
||||
/**
|
||||
* Zod schema for ControlNet models
|
||||
*/
|
||||
export const zControlNetModel = z.object({
|
||||
model_name: z.string().min(1),
|
||||
base_model: zBaseModel,
|
||||
});
|
||||
/**
|
||||
* Type alias for model parameter, inferred from its zod schema
|
||||
*/
|
||||
export type ControlNetModelParam = z.infer<typeof zLoRAModel>;
|
||||
/**
|
||||
* Validates/type-guards a value as a model parameter
|
||||
*/
|
||||
export const isValidControlNetModel = (
|
||||
val: unknown
|
||||
): val is ControlNetModelParam => zControlNetModel.safeParse(val).success;
|
||||
|
||||
/**
|
||||
* Zod schema for l2l strength parameter
|
||||
|
@ -0,0 +1,30 @@
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { zControlNetModel } from 'features/parameters/types/parameterSchemas';
|
||||
import { ControlNetModelField } from 'services/api/types';
|
||||
|
||||
const moduleLog = log.child({ module: 'models' });
|
||||
|
||||
export const modelIdToControlNetModelParam = (
|
||||
controlNetModelId: string
|
||||
): ControlNetModelField | undefined => {
|
||||
const [base_model, model_type, model_name] = controlNetModelId.split('/');
|
||||
|
||||
const result = zControlNetModel.safeParse({
|
||||
base_model,
|
||||
model_name,
|
||||
});
|
||||
|
||||
if (!result.success) {
|
||||
moduleLog.error(
|
||||
{
|
||||
controlNetModelId,
|
||||
errors: result.error.format(),
|
||||
},
|
||||
'Failed to parse ControlNet model id'
|
||||
);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
return result.data;
|
||||
};
|
@ -1,9 +1,12 @@
|
||||
import { LoRAModelParam, zLoRAModel } from '../types/parameterSchemas';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
|
||||
const moduleLog = log.child({ module: 'models' });
|
||||
|
||||
export const modelIdToLoRAModelParam = (
|
||||
loraId: string
|
||||
loraModelId: string
|
||||
): LoRAModelParam | undefined => {
|
||||
const [base_model, model_type, model_name] = loraId.split('/');
|
||||
const [base_model, model_type, model_name] = loraModelId.split('/');
|
||||
|
||||
const result = zLoRAModel.safeParse({
|
||||
base_model,
|
||||
@ -11,6 +14,13 @@ export const modelIdToLoRAModelParam = (
|
||||
});
|
||||
|
||||
if (!result.success) {
|
||||
moduleLog.error(
|
||||
{
|
||||
loraModelId,
|
||||
errors: result.error.format(),
|
||||
},
|
||||
'Failed to parse LoRA model id'
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -2,11 +2,14 @@ import {
|
||||
MainModelParam,
|
||||
zMainModel,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
|
||||
const moduleLog = log.child({ module: 'models' });
|
||||
|
||||
export const modelIdToMainModelParam = (
|
||||
modelId: string
|
||||
mainModelId: string
|
||||
): MainModelParam | undefined => {
|
||||
const [base_model, model_type, model_name] = modelId.split('/');
|
||||
const [base_model, model_type, model_name] = mainModelId.split('/');
|
||||
|
||||
const result = zMainModel.safeParse({
|
||||
base_model,
|
||||
@ -14,6 +17,13 @@ export const modelIdToMainModelParam = (
|
||||
});
|
||||
|
||||
if (!result.success) {
|
||||
moduleLog.error(
|
||||
{
|
||||
mainModelId,
|
||||
errors: result.error.format(),
|
||||
},
|
||||
'Failed to parse main model id'
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -1,9 +1,12 @@
|
||||
import { VaeModelParam, zVaeModel } from '../types/parameterSchemas';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
|
||||
const moduleLog = log.child({ module: 'models' });
|
||||
|
||||
export const modelIdToVAEModelParam = (
|
||||
modelId: string
|
||||
vaeModelId: string
|
||||
): VaeModelParam | undefined => {
|
||||
const [base_model, model_type, model_name] = modelId.split('/');
|
||||
const [base_model, model_type, model_name] = vaeModelId.split('/');
|
||||
|
||||
const result = zVaeModel.safeParse({
|
||||
base_model,
|
||||
@ -11,6 +14,13 @@ export const modelIdToVAEModelParam = (
|
||||
});
|
||||
|
||||
if (!result.success) {
|
||||
moduleLog.error(
|
||||
{
|
||||
vaeModelId,
|
||||
errors: result.error.format(),
|
||||
},
|
||||
'Failed to parse VAE model id'
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -19,9 +19,9 @@ const ImageToImageTabParameters = () => {
|
||||
<ParamNegativeConditioning />
|
||||
<ProcessButtons />
|
||||
<ImageToImageTabCoreParameters />
|
||||
<ParamControlNetCollapse />
|
||||
<ParamLoraCollapse />
|
||||
<ParamDynamicPromptsCollapse />
|
||||
<ParamControlNetCollapse />
|
||||
<ParamVariationCollapse />
|
||||
<ParamNoiseCollapse />
|
||||
<ParamSymmetryCollapse />
|
||||
|
@ -20,9 +20,9 @@ const TextToImageTabParameters = () => {
|
||||
<ParamNegativeConditioning />
|
||||
<ProcessButtons />
|
||||
<TextToImageTabCoreParameters />
|
||||
<ParamControlNetCollapse />
|
||||
<ParamLoraCollapse />
|
||||
<ParamDynamicPromptsCollapse />
|
||||
<ParamControlNetCollapse />
|
||||
<ParamVariationCollapse />
|
||||
<ParamNoiseCollapse />
|
||||
<ParamSymmetryCollapse />
|
||||
|
@ -19,9 +19,9 @@ const UnifiedCanvasParameters = () => {
|
||||
<ParamNegativeConditioning />
|
||||
<ProcessButtons />
|
||||
<UnifiedCanvasCoreParameters />
|
||||
<ParamControlNetCollapse />
|
||||
<ParamLoraCollapse />
|
||||
<ParamDynamicPromptsCollapse />
|
||||
<ParamControlNetCollapse />
|
||||
<ParamVariationCollapse />
|
||||
<ParamSymmetryCollapse />
|
||||
<ParamSeamCorrectionCollapse />
|
||||
|
@ -734,7 +734,7 @@ export type components = {
|
||||
* Control Model
|
||||
* @description The ControlNet model to use
|
||||
*/
|
||||
control_model: string;
|
||||
control_model: components["schemas"]["ControlNetModelField"];
|
||||
/**
|
||||
* Control Weight
|
||||
* @description The weight given to the ControlNet
|
||||
@ -791,10 +791,9 @@ export type components = {
|
||||
/**
|
||||
* Control Model
|
||||
* @description control model used
|
||||
* @default lllyasviel/sd-controlnet-canny
|
||||
* @enum {string}
|
||||
* @default lllyasviel/sd-controlnet-canny
|
||||
*/
|
||||
control_model?: "lllyasviel/sd-controlnet-canny" | "lllyasviel/sd-controlnet-depth" | "lllyasviel/sd-controlnet-hed" | "lllyasviel/sd-controlnet-seg" | "lllyasviel/sd-controlnet-openpose" | "lllyasviel/sd-controlnet-scribble" | "lllyasviel/sd-controlnet-normal" | "lllyasviel/sd-controlnet-mlsd" | "lllyasviel/control_v11p_sd15_canny" | "lllyasviel/control_v11p_sd15_openpose" | "lllyasviel/control_v11p_sd15_seg" | "lllyasviel/control_v11f1p_sd15_depth" | "lllyasviel/control_v11p_sd15_normalbae" | "lllyasviel/control_v11p_sd15_scribble" | "lllyasviel/control_v11p_sd15_mlsd" | "lllyasviel/control_v11p_sd15_softedge" | "lllyasviel/control_v11p_sd15s2_lineart_anime" | "lllyasviel/control_v11p_sd15_lineart" | "lllyasviel/control_v11p_sd15_inpaint" | "lllyasviel/control_v11e_sd15_shuffle" | "lllyasviel/control_v11e_sd15_ip2p" | "lllyasviel/control_v11f1e_sd15_tile" | "thibaud/controlnet-sd21-openpose-diffusers" | "thibaud/controlnet-sd21-canny-diffusers" | "thibaud/controlnet-sd21-depth-diffusers" | "thibaud/controlnet-sd21-scribble-diffusers" | "thibaud/controlnet-sd21-hed-diffusers" | "thibaud/controlnet-sd21-zoedepth-diffusers" | "thibaud/controlnet-sd21-color-diffusers" | "thibaud/controlnet-sd21-openposev2-diffusers" | "thibaud/controlnet-sd21-lineart-diffusers" | "thibaud/controlnet-sd21-normalbae-diffusers" | "thibaud/controlnet-sd21-ade20k-diffusers" | "CrucibleAI/ControlNetMediaPipeFace,diffusion_sd15" | "CrucibleAI/ControlNetMediaPipeFace";
|
||||
control_model?: components["schemas"]["ControlNetModelField"];
|
||||
/**
|
||||
* Control Weight
|
||||
* @description The weight given to the ControlNet
|
||||
@ -838,6 +837,19 @@ export type components = {
|
||||
model_format: components["schemas"]["ControlNetModelFormat"];
|
||||
error?: components["schemas"]["ModelError"];
|
||||
};
|
||||
/**
|
||||
* ControlNetModelField
|
||||
* @description ControlNet model field
|
||||
*/
|
||||
ControlNetModelField: {
|
||||
/**
|
||||
* Model Name
|
||||
* @description Name of the ControlNet model
|
||||
*/
|
||||
model_name: string;
|
||||
/** @description Base model */
|
||||
base_model: components["schemas"]["BaseModelType"];
|
||||
};
|
||||
/**
|
||||
* ControlNetModelFormat
|
||||
* @description An enumeration.
|
||||
@ -1923,12 +1935,12 @@ export type components = {
|
||||
* Width
|
||||
* @description The width to resize to (px)
|
||||
*/
|
||||
width: number;
|
||||
width?: number;
|
||||
/**
|
||||
* Height
|
||||
* @description The height to resize to (px)
|
||||
*/
|
||||
height: number;
|
||||
height?: number;
|
||||
/**
|
||||
* Resample Mode
|
||||
* @description The resampling mode
|
||||
@ -3910,14 +3922,16 @@ export type components = {
|
||||
latents?: components["schemas"]["LatentsField"];
|
||||
/**
|
||||
* Width
|
||||
* @description The width to resize to (px)
|
||||
* @description The width to resize to (px)
|
||||
* @default 512
|
||||
*/
|
||||
width: number;
|
||||
width?: number;
|
||||
/**
|
||||
* Height
|
||||
* @description The height to resize to (px)
|
||||
* @description The height to resize to (px)
|
||||
* @default 512
|
||||
*/
|
||||
height: number;
|
||||
height?: number;
|
||||
/**
|
||||
* Mode
|
||||
* @description The interpolation mode
|
||||
@ -4605,18 +4619,18 @@ export type components = {
|
||||
*/
|
||||
image?: components["schemas"]["ImageField"];
|
||||
};
|
||||
/**
|
||||
* StableDiffusion2ModelFormat
|
||||
* @description An enumeration.
|
||||
* @enum {string}
|
||||
*/
|
||||
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
|
||||
/**
|
||||
* StableDiffusion1ModelFormat
|
||||
* @description An enumeration.
|
||||
* @enum {string}
|
||||
*/
|
||||
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
|
||||
/**
|
||||
* StableDiffusion2ModelFormat
|
||||
* @description An enumeration.
|
||||
* @enum {string}
|
||||
*/
|
||||
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
|
||||
};
|
||||
responses: never;
|
||||
parameters: never;
|
||||
|
@ -32,6 +32,8 @@ export type BaseModelType = components['schemas']['BaseModelType'];
|
||||
export type MainModelField = components['schemas']['MainModelField'];
|
||||
export type VAEModelField = components['schemas']['VAEModelField'];
|
||||
export type LoRAModelField = components['schemas']['LoRAModelField'];
|
||||
export type ControlNetModelField =
|
||||
components['schemas']['ControlNetModelField'];
|
||||
export type ModelsList = components['schemas']['ModelsList'];
|
||||
export type ControlField = components['schemas']['ControlField'];
|
||||
|
||||
|
@ -30,7 +30,7 @@ const invokeAIThumb = defineStyle((props) => {
|
||||
|
||||
const invokeAIMark = defineStyle((props) => {
|
||||
return {
|
||||
fontSize: 'xs',
|
||||
fontSize: '2xs',
|
||||
fontWeight: '500',
|
||||
color: mode('base.700', 'base.400')(props),
|
||||
mt: 2,
|
||||
|
Loading…
Reference in New Issue
Block a user