mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into lstein/default-model-install
This commit is contained in:
commit
32e7e52d69
@ -1,6 +1,7 @@
|
|||||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2024 Lincoln Stein
|
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2023 Lincoln D. Stein
|
||||||
|
|
||||||
|
|
||||||
|
import pathlib
|
||||||
from typing import Literal, List, Optional, Union
|
from typing import Literal, List, Optional, Union
|
||||||
|
|
||||||
from fastapi import Body, Path, Query, Response
|
from fastapi import Body, Path, Query, Response
|
||||||
@ -22,6 +23,7 @@ UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
|||||||
ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||||
ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||||
MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||||
|
ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||||
|
|
||||||
class ModelsList(BaseModel):
|
class ModelsList(BaseModel):
|
||||||
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
|
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
|
||||||
@ -78,7 +80,7 @@ async def update_model(
|
|||||||
return model_response
|
return model_response
|
||||||
|
|
||||||
@models_router.post(
|
@models_router.post(
|
||||||
"/",
|
"/import",
|
||||||
operation_id="import_model",
|
operation_id="import_model",
|
||||||
responses= {
|
responses= {
|
||||||
201: {"description" : "The model imported successfully"},
|
201: {"description" : "The model imported successfully"},
|
||||||
@ -94,7 +96,7 @@ async def import_model(
|
|||||||
prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = \
|
prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = \
|
||||||
Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
|
Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
|
||||||
) -> ImportModelResponse:
|
) -> ImportModelResponse:
|
||||||
""" Add a model using its local path, repo_id, or remote URL """
|
""" Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically """
|
||||||
|
|
||||||
items_to_import = {location}
|
items_to_import = {location}
|
||||||
prediction_types = { x.value: x for x in SchedulerPredictionType }
|
prediction_types = { x.value: x for x in SchedulerPredictionType }
|
||||||
@ -126,18 +128,100 @@ async def import_model(
|
|||||||
logger.error(str(e))
|
logger.error(str(e))
|
||||||
raise HTTPException(status_code=409, detail=str(e))
|
raise HTTPException(status_code=409, detail=str(e))
|
||||||
|
|
||||||
|
@models_router.post(
|
||||||
|
"/add",
|
||||||
|
operation_id="add_model",
|
||||||
|
responses= {
|
||||||
|
201: {"description" : "The model added successfully"},
|
||||||
|
404: {"description" : "The model could not be found"},
|
||||||
|
424: {"description" : "The model appeared to add successfully, but could not be found in the model manager"},
|
||||||
|
409: {"description" : "There is already a model corresponding to this path or repo_id"},
|
||||||
|
},
|
||||||
|
status_code=201,
|
||||||
|
response_model=ImportModelResponse
|
||||||
|
)
|
||||||
|
async def add_model(
|
||||||
|
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
|
||||||
|
) -> ImportModelResponse:
|
||||||
|
""" Add a model using the configuration information appropriate for its type. Only local models can be added by path"""
|
||||||
|
|
||||||
|
logger = ApiDependencies.invoker.services.logger
|
||||||
|
|
||||||
|
try:
|
||||||
|
ApiDependencies.invoker.services.model_manager.add_model(
|
||||||
|
info.model_name,
|
||||||
|
info.base_model,
|
||||||
|
info.model_type,
|
||||||
|
model_attributes = info.dict()
|
||||||
|
)
|
||||||
|
logger.info(f'Successfully added {info.model_name}')
|
||||||
|
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||||
|
model_name=info.model_name,
|
||||||
|
base_model=info.base_model,
|
||||||
|
model_type=info.model_type
|
||||||
|
)
|
||||||
|
return parse_obj_as(ImportModelResponse, model_raw)
|
||||||
|
except KeyError as e:
|
||||||
|
logger.error(str(e))
|
||||||
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
|
except ValueError as e:
|
||||||
|
logger.error(str(e))
|
||||||
|
raise HTTPException(status_code=409, detail=str(e))
|
||||||
|
|
||||||
|
@models_router.post(
|
||||||
|
"/rename/{base_model}/{model_type}/{model_name}",
|
||||||
|
operation_id="rename_model",
|
||||||
|
responses= {
|
||||||
|
201: {"description" : "The model was renamed successfully"},
|
||||||
|
404: {"description" : "The model could not be found"},
|
||||||
|
409: {"description" : "There is already a model corresponding to the new name"},
|
||||||
|
},
|
||||||
|
status_code=201,
|
||||||
|
response_model=ImportModelResponse
|
||||||
|
)
|
||||||
|
async def rename_model(
|
||||||
|
base_model: BaseModelType = Path(description="Base model"),
|
||||||
|
model_type: ModelType = Path(description="The type of model"),
|
||||||
|
model_name: str = Path(description="current model name"),
|
||||||
|
new_name: Optional[str] = Query(description="new model name", default=None),
|
||||||
|
new_base: Optional[BaseModelType] = Query(description="new model base", default=None),
|
||||||
|
) -> ImportModelResponse:
|
||||||
|
""" Rename a model"""
|
||||||
|
|
||||||
|
logger = ApiDependencies.invoker.services.logger
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = ApiDependencies.invoker.services.model_manager.rename_model(
|
||||||
|
base_model = base_model,
|
||||||
|
model_type = model_type,
|
||||||
|
model_name = model_name,
|
||||||
|
new_name = new_name,
|
||||||
|
new_base = new_base,
|
||||||
|
)
|
||||||
|
logger.debug(result)
|
||||||
|
logger.info(f'Successfully renamed {model_name}=>{new_name}')
|
||||||
|
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||||
|
model_name=new_name or model_name,
|
||||||
|
base_model=new_base or base_model,
|
||||||
|
model_type=model_type
|
||||||
|
)
|
||||||
|
return parse_obj_as(ImportModelResponse, model_raw)
|
||||||
|
except KeyError as e:
|
||||||
|
logger.error(str(e))
|
||||||
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
|
except ValueError as e:
|
||||||
|
logger.error(str(e))
|
||||||
|
raise HTTPException(status_code=409, detail=str(e))
|
||||||
|
|
||||||
@models_router.delete(
|
@models_router.delete(
|
||||||
"/{base_model}/{model_type}/{model_name}",
|
"/{base_model}/{model_type}/{model_name}",
|
||||||
operation_id="del_model",
|
operation_id="del_model",
|
||||||
responses={
|
responses={
|
||||||
204: {
|
204: { "description": "Model deleted successfully" },
|
||||||
"description": "Model deleted successfully"
|
404: { "description": "Model not found" }
|
||||||
},
|
|
||||||
404: {
|
|
||||||
"description": "Model not found"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
|
status_code = 204,
|
||||||
|
response_model = None,
|
||||||
)
|
)
|
||||||
async def delete_model(
|
async def delete_model(
|
||||||
base_model: BaseModelType = Path(description="Base model"),
|
base_model: BaseModelType = Path(description="Base model"),
|
||||||
@ -173,14 +257,17 @@ async def convert_model(
|
|||||||
base_model: BaseModelType = Path(description="Base model"),
|
base_model: BaseModelType = Path(description="Base model"),
|
||||||
model_type: ModelType = Path(description="The type of model"),
|
model_type: ModelType = Path(description="The type of model"),
|
||||||
model_name: str = Path(description="model name"),
|
model_name: str = Path(description="model name"),
|
||||||
|
convert_dest_directory: Optional[str] = Query(default=None, description="Save the converted model to the designated directory"),
|
||||||
) -> ConvertModelResponse:
|
) -> ConvertModelResponse:
|
||||||
"""Convert a checkpoint model into a diffusers model"""
|
"""Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none."""
|
||||||
logger = ApiDependencies.invoker.services.logger
|
logger = ApiDependencies.invoker.services.logger
|
||||||
try:
|
try:
|
||||||
logger.info(f"Converting model: {model_name}")
|
logger.info(f"Converting model: {model_name}")
|
||||||
|
dest = pathlib.Path(convert_dest_directory) if convert_dest_directory else None
|
||||||
ApiDependencies.invoker.services.model_manager.convert_model(model_name,
|
ApiDependencies.invoker.services.model_manager.convert_model(model_name,
|
||||||
base_model = base_model,
|
base_model = base_model,
|
||||||
model_type = model_type
|
model_type = model_type,
|
||||||
|
convert_dest_directory = dest,
|
||||||
)
|
)
|
||||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(model_name,
|
model_raw = ApiDependencies.invoker.services.model_manager.list_model(model_name,
|
||||||
base_model = base_model,
|
base_model = base_model,
|
||||||
@ -192,6 +279,53 @@ async def convert_model(
|
|||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
@models_router.get(
|
||||||
|
"/search",
|
||||||
|
operation_id="search_for_models",
|
||||||
|
responses={
|
||||||
|
200: { "description": "Directory searched successfully" },
|
||||||
|
404: { "description": "Invalid directory path" },
|
||||||
|
},
|
||||||
|
status_code = 200,
|
||||||
|
response_model = List[pathlib.Path]
|
||||||
|
)
|
||||||
|
async def search_for_models(
|
||||||
|
search_path: pathlib.Path = Query(description="Directory path to search for models")
|
||||||
|
)->List[pathlib.Path]:
|
||||||
|
if not search_path.is_dir():
|
||||||
|
raise HTTPException(status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory")
|
||||||
|
return ApiDependencies.invoker.services.model_manager.search_for_models([search_path])
|
||||||
|
|
||||||
|
@models_router.get(
|
||||||
|
"/ckpt_confs",
|
||||||
|
operation_id="list_ckpt_configs",
|
||||||
|
responses={
|
||||||
|
200: { "description" : "paths retrieved successfully" },
|
||||||
|
},
|
||||||
|
status_code = 200,
|
||||||
|
response_model = List[pathlib.Path]
|
||||||
|
)
|
||||||
|
async def list_ckpt_configs(
|
||||||
|
)->List[pathlib.Path]:
|
||||||
|
"""Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT."""
|
||||||
|
return ApiDependencies.invoker.services.model_manager.list_checkpoint_configs()
|
||||||
|
|
||||||
|
|
||||||
|
@models_router.get(
|
||||||
|
"/sync",
|
||||||
|
operation_id="sync_to_config",
|
||||||
|
responses={
|
||||||
|
201: { "description": "synchronization successful" },
|
||||||
|
},
|
||||||
|
status_code = 201,
|
||||||
|
response_model = None
|
||||||
|
)
|
||||||
|
async def sync_to_config(
|
||||||
|
)->None:
|
||||||
|
"""Call after making changes to models.yaml, autoimport directories or models directory to synchronize
|
||||||
|
in-memory data structures with disk data structures."""
|
||||||
|
return ApiDependencies.invoker.services.model_manager.sync_to_config()
|
||||||
|
|
||||||
@models_router.put(
|
@models_router.put(
|
||||||
"/merge/{base_model}",
|
"/merge/{base_model}",
|
||||||
operation_id="merge_models",
|
operation_id="merge_models",
|
||||||
@ -210,17 +344,21 @@ async def merge_models(
|
|||||||
alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
|
alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
|
||||||
interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"),
|
interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"),
|
||||||
force: Optional[bool] = Body(description="Force merging of models created with different versions of diffusers", default=False),
|
force: Optional[bool] = Body(description="Force merging of models created with different versions of diffusers", default=False),
|
||||||
|
merge_dest_directory: Optional[str] = Body(description="Save the merged model to the designated directory (with 'merged_model_name' appended)", default=None)
|
||||||
) -> MergeModelResponse:
|
) -> MergeModelResponse:
|
||||||
"""Convert a checkpoint model into a diffusers model"""
|
"""Convert a checkpoint model into a diffusers model"""
|
||||||
logger = ApiDependencies.invoker.services.logger
|
logger = ApiDependencies.invoker.services.logger
|
||||||
try:
|
try:
|
||||||
logger.info(f"Merging models: {model_names}")
|
logger.info(f"Merging models: {model_names} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
|
||||||
|
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
|
||||||
result = ApiDependencies.invoker.services.model_manager.merge_models(model_names,
|
result = ApiDependencies.invoker.services.model_manager.merge_models(model_names,
|
||||||
base_model,
|
base_model,
|
||||||
merged_model_name or "+".join(model_names),
|
merged_model_name=merged_model_name or "+".join(model_names),
|
||||||
alpha,
|
alpha=alpha,
|
||||||
interp,
|
interp=interp,
|
||||||
force)
|
force=force,
|
||||||
|
merge_dest_directory = dest
|
||||||
|
)
|
||||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(result.name,
|
model_raw = ApiDependencies.invoker.services.model_manager.list_model(result.name,
|
||||||
base_model = base_model,
|
base_model = base_model,
|
||||||
model_type = ModelType.Main,
|
model_type = ModelType.Main,
|
||||||
|
@ -100,7 +100,7 @@ class CompelInvocation(BaseInvocation):
|
|||||||
text_encoder=text_encoder,
|
text_encoder=text_encoder,
|
||||||
textual_inversion_manager=ti_manager,
|
textual_inversion_manager=ti_manager,
|
||||||
dtype_for_device_getter=torch_dtype,
|
dtype_for_device_getter=torch_dtype,
|
||||||
truncate_long_prompts=True, # TODO:
|
truncate_long_prompts=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
conjunction = Compel.parse_prompt_string(self.prompt)
|
conjunction = Compel.parse_prompt_string(self.prompt)
|
||||||
@ -112,9 +112,6 @@ class CompelInvocation(BaseInvocation):
|
|||||||
c, options = compel.build_conditioning_tensor_for_prompt_object(
|
c, options = compel.build_conditioning_tensor_for_prompt_object(
|
||||||
prompt)
|
prompt)
|
||||||
|
|
||||||
# TODO: long prompt support
|
|
||||||
# if not self.truncate_long_prompts:
|
|
||||||
# [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
|
|
||||||
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
|
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
|
||||||
tokens_count_including_eos_bos=get_max_token_count(
|
tokens_count_including_eos_bos=get_max_token_count(
|
||||||
tokenizer, conjunction),
|
tokenizer, conjunction),
|
||||||
|
@ -9,6 +9,7 @@ from typing import Literal, Optional, Union, List, Dict
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
from pydantic import BaseModel, Field, validator
|
from pydantic import BaseModel, Field, validator
|
||||||
|
|
||||||
|
from ...backend.model_management import BaseModelType, ModelType
|
||||||
from ..models.image import ImageField, ImageCategory, ResourceOrigin
|
from ..models.image import ImageField, ImageCategory, ResourceOrigin
|
||||||
from .baseinvocation import (
|
from .baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
@ -105,9 +106,15 @@ CONTROLNET_MODE_VALUES = Literal[tuple(["balanced", "more_prompt", "more_control
|
|||||||
# CONTROLNET_RESIZE_VALUES = Literal[tuple(["just_resize", "crop_resize", "fill_resize"])]
|
# CONTROLNET_RESIZE_VALUES = Literal[tuple(["just_resize", "crop_resize", "fill_resize"])]
|
||||||
|
|
||||||
|
|
||||||
|
class ControlNetModelField(BaseModel):
|
||||||
|
"""ControlNet model field"""
|
||||||
|
|
||||||
|
model_name: str = Field(description="Name of the ControlNet model")
|
||||||
|
base_model: BaseModelType = Field(description="Base model")
|
||||||
|
|
||||||
class ControlField(BaseModel):
|
class ControlField(BaseModel):
|
||||||
image: ImageField = Field(default=None, description="The control image")
|
image: ImageField = Field(default=None, description="The control image")
|
||||||
control_model: Optional[str] = Field(default=None, description="The ControlNet model to use")
|
control_model: Optional[ControlNetModelField] = Field(default=None, description="The ControlNet model to use")
|
||||||
# control_weight: Optional[float] = Field(default=1, description="weight given to controlnet")
|
# control_weight: Optional[float] = Field(default=1, description="weight given to controlnet")
|
||||||
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
|
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
|
||||||
begin_step_percent: float = Field(default=0, ge=0, le=1,
|
begin_step_percent: float = Field(default=0, ge=0, le=1,
|
||||||
@ -118,15 +125,15 @@ class ControlField(BaseModel):
|
|||||||
# resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
|
# resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
|
||||||
|
|
||||||
@validator("control_weight")
|
@validator("control_weight")
|
||||||
def abs_le_one(cls, v):
|
def validate_control_weight(cls, v):
|
||||||
"""validate that all abs(values) are <=1"""
|
"""Validate that all control weights in the valid range"""
|
||||||
if isinstance(v, list):
|
if isinstance(v, list):
|
||||||
for i in v:
|
for i in v:
|
||||||
if abs(i) > 1:
|
if i < -1 or i > 2:
|
||||||
raise ValueError('all abs(control_weight) must be <= 1')
|
raise ValueError('Control weights must be within -1 to 2 range')
|
||||||
else:
|
else:
|
||||||
if abs(v) > 1:
|
if v < -1 or v > 2:
|
||||||
raise ValueError('abs(control_weight) must be <= 1')
|
raise ValueError('Control weights must be within -1 to 2 range')
|
||||||
return v
|
return v
|
||||||
class Config:
|
class Config:
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
@ -134,6 +141,7 @@ class ControlField(BaseModel):
|
|||||||
"ui": {
|
"ui": {
|
||||||
"type_hints": {
|
"type_hints": {
|
||||||
"control_weight": "float",
|
"control_weight": "float",
|
||||||
|
"control_model": "controlnet_model",
|
||||||
# "control_weight": "number",
|
# "control_weight": "number",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -154,10 +162,10 @@ class ControlNetInvocation(BaseInvocation):
|
|||||||
type: Literal["controlnet"] = "controlnet"
|
type: Literal["controlnet"] = "controlnet"
|
||||||
# Inputs
|
# Inputs
|
||||||
image: ImageField = Field(default=None, description="The control image")
|
image: ImageField = Field(default=None, description="The control image")
|
||||||
control_model: CONTROLNET_NAME_VALUES = Field(default="lllyasviel/sd-controlnet-canny",
|
control_model: ControlNetModelField = Field(default="lllyasviel/sd-controlnet-canny",
|
||||||
description="control model used")
|
description="control model used")
|
||||||
control_weight: Union[float, List[float]] = Field(default=1.0, description="The weight given to the ControlNet")
|
control_weight: Union[float, List[float]] = Field(default=1.0, description="The weight given to the ControlNet")
|
||||||
begin_step_percent: float = Field(default=0, ge=0, le=1,
|
begin_step_percent: float = Field(default=0, ge=-1, le=2,
|
||||||
description="When the ControlNet is first applied (% of total steps)")
|
description="When the ControlNet is first applied (% of total steps)")
|
||||||
end_step_percent: float = Field(default=1, ge=0, le=1,
|
end_step_percent: float = Field(default=1, ge=0, le=1,
|
||||||
description="When the ControlNet is last applied (% of total steps)")
|
description="When the ControlNet is last applied (% of total steps)")
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
|
from contextlib import ExitStack
|
||||||
from typing import List, Literal, Optional, Union
|
from typing import List, Literal, Optional, Union
|
||||||
|
|
||||||
import einops
|
import einops
|
||||||
@ -11,6 +12,7 @@ from pydantic import BaseModel, Field, validator
|
|||||||
|
|
||||||
from invokeai.app.invocations.metadata import CoreMetadata
|
from invokeai.app.invocations.metadata import CoreMetadata
|
||||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||||
|
from invokeai.backend.model_management.models.base import ModelType
|
||||||
|
|
||||||
from ...backend.model_management.lora import ModelPatcher
|
from ...backend.model_management.lora import ModelPatcher
|
||||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||||
@ -71,16 +73,21 @@ def get_scheduler(
|
|||||||
scheduler_name: str,
|
scheduler_name: str,
|
||||||
) -> Scheduler:
|
) -> Scheduler:
|
||||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(
|
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(
|
||||||
scheduler_name, SCHEDULER_MAP['ddim'])
|
scheduler_name, SCHEDULER_MAP['ddim']
|
||||||
|
)
|
||||||
orig_scheduler_info = context.services.model_manager.get_model(
|
orig_scheduler_info = context.services.model_manager.get_model(
|
||||||
**scheduler_info.dict())
|
**scheduler_info.dict()
|
||||||
|
)
|
||||||
with orig_scheduler_info as orig_scheduler:
|
with orig_scheduler_info as orig_scheduler:
|
||||||
scheduler_config = orig_scheduler.config
|
scheduler_config = orig_scheduler.config
|
||||||
|
|
||||||
if "_backup" in scheduler_config:
|
if "_backup" in scheduler_config:
|
||||||
scheduler_config = scheduler_config["_backup"]
|
scheduler_config = scheduler_config["_backup"]
|
||||||
scheduler_config = {**scheduler_config, **
|
scheduler_config = {
|
||||||
scheduler_extra_config, "_backup": scheduler_config}
|
**scheduler_config,
|
||||||
|
**scheduler_extra_config,
|
||||||
|
"_backup": scheduler_config,
|
||||||
|
}
|
||||||
scheduler = scheduler_class.from_config(scheduler_config)
|
scheduler = scheduler_class.from_config(scheduler_config)
|
||||||
|
|
||||||
# hack copied over from generate.py
|
# hack copied over from generate.py
|
||||||
@ -137,8 +144,11 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
# TODO: pass this an emitter method or something? or a session for dispatching?
|
# TODO: pass this an emitter method or something? or a session for dispatching?
|
||||||
def dispatch_progress(
|
def dispatch_progress(
|
||||||
self, context: InvocationContext, source_node_id: str,
|
self,
|
||||||
intermediate_state: PipelineIntermediateState) -> None:
|
context: InvocationContext,
|
||||||
|
source_node_id: str,
|
||||||
|
intermediate_state: PipelineIntermediateState,
|
||||||
|
) -> None:
|
||||||
stable_diffusion_step_callback(
|
stable_diffusion_step_callback(
|
||||||
context=context,
|
context=context,
|
||||||
intermediate_state=intermediate_state,
|
intermediate_state=intermediate_state,
|
||||||
@ -147,11 +157,16 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def get_conditioning_data(
|
def get_conditioning_data(
|
||||||
self, context: InvocationContext, scheduler) -> ConditioningData:
|
self,
|
||||||
|
context: InvocationContext,
|
||||||
|
scheduler,
|
||||||
|
) -> ConditioningData:
|
||||||
c, extra_conditioning_info = context.services.latents.get(
|
c, extra_conditioning_info = context.services.latents.get(
|
||||||
self.positive_conditioning.conditioning_name)
|
self.positive_conditioning.conditioning_name
|
||||||
|
)
|
||||||
uc, _ = context.services.latents.get(
|
uc, _ = context.services.latents.get(
|
||||||
self.negative_conditioning.conditioning_name)
|
self.negative_conditioning.conditioning_name
|
||||||
|
)
|
||||||
|
|
||||||
conditioning_data = ConditioningData(
|
conditioning_data = ConditioningData(
|
||||||
unconditioned_embeddings=uc,
|
unconditioned_embeddings=uc,
|
||||||
@ -178,7 +193,10 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
return conditioning_data
|
return conditioning_data
|
||||||
|
|
||||||
def create_pipeline(
|
def create_pipeline(
|
||||||
self, unet, scheduler) -> StableDiffusionGeneratorPipeline:
|
self,
|
||||||
|
unet,
|
||||||
|
scheduler,
|
||||||
|
) -> StableDiffusionGeneratorPipeline:
|
||||||
# TODO:
|
# TODO:
|
||||||
# configure_model_padding(
|
# configure_model_padding(
|
||||||
# unet,
|
# unet,
|
||||||
@ -213,6 +231,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
model: StableDiffusionGeneratorPipeline,
|
model: StableDiffusionGeneratorPipeline,
|
||||||
control_input: List[ControlField],
|
control_input: List[ControlField],
|
||||||
latents_shape: List[int],
|
latents_shape: List[int],
|
||||||
|
exit_stack: ExitStack,
|
||||||
do_classifier_free_guidance: bool = True,
|
do_classifier_free_guidance: bool = True,
|
||||||
) -> List[ControlNetData]:
|
) -> List[ControlNetData]:
|
||||||
|
|
||||||
@ -238,25 +257,19 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
control_data = []
|
control_data = []
|
||||||
control_models = []
|
control_models = []
|
||||||
for control_info in control_list:
|
for control_info in control_list:
|
||||||
# handle control models
|
control_model = exit_stack.enter_context(
|
||||||
if ("," in control_info.control_model):
|
context.services.model_manager.get_model(
|
||||||
control_model_split = control_info.control_model.split(",")
|
model_name=control_info.control_model.model_name,
|
||||||
control_name = control_model_split[0]
|
model_type=ModelType.ControlNet,
|
||||||
control_subfolder = control_model_split[1]
|
base_model=control_info.control_model.base_model,
|
||||||
print("Using HF model subfolders")
|
)
|
||||||
print(" control_name: ", control_name)
|
)
|
||||||
print(" control_subfolder: ", control_subfolder)
|
|
||||||
control_model = ControlNetModel.from_pretrained(
|
|
||||||
control_name, subfolder=control_subfolder,
|
|
||||||
torch_dtype=model.unet.dtype).to(
|
|
||||||
model.device)
|
|
||||||
else:
|
|
||||||
control_model = ControlNetModel.from_pretrained(
|
|
||||||
control_info.control_model, torch_dtype=model.unet.dtype).to(model.device)
|
|
||||||
control_models.append(control_model)
|
control_models.append(control_model)
|
||||||
control_image_field = control_info.image
|
control_image_field = control_info.image
|
||||||
input_image = context.services.images.get_pil_image(
|
input_image = context.services.images.get_pil_image(
|
||||||
control_image_field.image_name)
|
control_image_field.image_name
|
||||||
|
)
|
||||||
# self.image.image_type, self.image.image_name
|
# self.image.image_type, self.image.image_name
|
||||||
# FIXME: still need to test with different widths, heights, devices, dtypes
|
# FIXME: still need to test with different widths, heights, devices, dtypes
|
||||||
# and add in batch_size, num_images_per_prompt?
|
# and add in batch_size, num_images_per_prompt?
|
||||||
@ -278,7 +291,8 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
weight=control_info.control_weight,
|
weight=control_info.control_weight,
|
||||||
begin_step_percent=control_info.begin_step_percent,
|
begin_step_percent=control_info.begin_step_percent,
|
||||||
end_step_percent=control_info.end_step_percent,
|
end_step_percent=control_info.end_step_percent,
|
||||||
control_mode=control_info.control_mode,)
|
control_mode=control_info.control_mode,
|
||||||
|
)
|
||||||
control_data.append(control_item)
|
control_data.append(control_item)
|
||||||
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
|
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
|
||||||
return control_data
|
return control_data
|
||||||
@ -289,7 +303,8 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
# Get the source node id (we are invoking the prepared node)
|
# Get the source node id (we are invoking the prepared node)
|
||||||
graph_execution_state = context.services.graph_execution_manager.get(
|
graph_execution_state = context.services.graph_execution_manager.get(
|
||||||
context.graph_execution_state_id)
|
context.graph_execution_state_id
|
||||||
|
)
|
||||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||||
|
|
||||||
def step_callback(state: PipelineIntermediateState):
|
def step_callback(state: PipelineIntermediateState):
|
||||||
@ -298,14 +313,17 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
def _lora_loader():
|
def _lora_loader():
|
||||||
for lora in self.unet.loras:
|
for lora in self.unet.loras:
|
||||||
lora_info = context.services.model_manager.get_model(
|
lora_info = context.services.model_manager.get_model(
|
||||||
**lora.dict(exclude={"weight"}))
|
**lora.dict(exclude={"weight"})
|
||||||
|
)
|
||||||
yield (lora_info.context.model, lora.weight)
|
yield (lora_info.context.model, lora.weight)
|
||||||
del lora_info
|
del lora_info
|
||||||
return
|
return
|
||||||
|
|
||||||
unet_info = context.services.model_manager.get_model(
|
unet_info = context.services.model_manager.get_model(
|
||||||
**self.unet.unet.dict())
|
**self.unet.unet.dict()
|
||||||
with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
|
)
|
||||||
|
with ExitStack() as exit_stack,\
|
||||||
|
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
|
||||||
unet_info as unet:
|
unet_info as unet:
|
||||||
|
|
||||||
scheduler = get_scheduler(
|
scheduler = get_scheduler(
|
||||||
@ -322,6 +340,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
latents_shape=noise.shape,
|
latents_shape=noise.shape,
|
||||||
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
||||||
do_classifier_free_guidance=True,
|
do_classifier_free_guidance=True,
|
||||||
|
exit_stack=exit_stack,
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: Verify the noise is the right size
|
# TODO: Verify the noise is the right size
|
||||||
@ -374,7 +393,8 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
|||||||
|
|
||||||
# Get the source node id (we are invoking the prepared node)
|
# Get the source node id (we are invoking the prepared node)
|
||||||
graph_execution_state = context.services.graph_execution_manager.get(
|
graph_execution_state = context.services.graph_execution_manager.get(
|
||||||
context.graph_execution_state_id)
|
context.graph_execution_state_id
|
||||||
|
)
|
||||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||||
|
|
||||||
def step_callback(state: PipelineIntermediateState):
|
def step_callback(state: PipelineIntermediateState):
|
||||||
@ -383,14 +403,17 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
|||||||
def _lora_loader():
|
def _lora_loader():
|
||||||
for lora in self.unet.loras:
|
for lora in self.unet.loras:
|
||||||
lora_info = context.services.model_manager.get_model(
|
lora_info = context.services.model_manager.get_model(
|
||||||
**lora.dict(exclude={"weight"}))
|
**lora.dict(exclude={"weight"})
|
||||||
|
)
|
||||||
yield (lora_info.context.model, lora.weight)
|
yield (lora_info.context.model, lora.weight)
|
||||||
del lora_info
|
del lora_info
|
||||||
return
|
return
|
||||||
|
|
||||||
unet_info = context.services.model_manager.get_model(
|
unet_info = context.services.model_manager.get_model(
|
||||||
**self.unet.unet.dict())
|
**self.unet.unet.dict()
|
||||||
with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
|
)
|
||||||
|
with ExitStack() as exit_stack,\
|
||||||
|
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
|
||||||
unet_info as unet:
|
unet_info as unet:
|
||||||
|
|
||||||
scheduler = get_scheduler(
|
scheduler = get_scheduler(
|
||||||
@ -407,11 +430,13 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
|||||||
latents_shape=noise.shape,
|
latents_shape=noise.shape,
|
||||||
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
||||||
do_classifier_free_guidance=True,
|
do_classifier_free_guidance=True,
|
||||||
|
exit_stack=exit_stack,
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: Verify the noise is the right size
|
# TODO: Verify the noise is the right size
|
||||||
initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
|
initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
|
||||||
latent, device=unet.device, dtype=latent.dtype)
|
latent, device=unet.device, dtype=latent.dtype
|
||||||
|
)
|
||||||
|
|
||||||
timesteps, _ = pipeline.get_img2img_timesteps(
|
timesteps, _ = pipeline.get_img2img_timesteps(
|
||||||
self.steps,
|
self.steps,
|
||||||
@ -535,7 +560,8 @@ class ResizeLatentsInvocation(BaseInvocation):
|
|||||||
resized_latents = torch.nn.functional.interpolate(
|
resized_latents = torch.nn.functional.interpolate(
|
||||||
latents, size=(self.height // 8, self.width // 8),
|
latents, size=(self.height // 8, self.width // 8),
|
||||||
mode=self.mode, antialias=self.antialias
|
mode=self.mode, antialias=self.antialias
|
||||||
if self.mode in ["bilinear", "bicubic"] else False,)
|
if self.mode in ["bilinear", "bicubic"] else False,
|
||||||
|
)
|
||||||
|
|
||||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
@ -569,7 +595,8 @@ class ScaleLatentsInvocation(BaseInvocation):
|
|||||||
resized_latents = torch.nn.functional.interpolate(
|
resized_latents = torch.nn.functional.interpolate(
|
||||||
latents, scale_factor=self.scale_factor, mode=self.mode,
|
latents, scale_factor=self.scale_factor, mode=self.mode,
|
||||||
antialias=self.antialias
|
antialias=self.antialias
|
||||||
if self.mode in ["bilinear", "bicubic"] else False,)
|
if self.mode in ["bilinear", "bicubic"] else False,
|
||||||
|
)
|
||||||
|
|
||||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
@ -19,7 +19,7 @@ from invokeai.backend.model_management import (
|
|||||||
ModelMerger,
|
ModelMerger,
|
||||||
MergeInterpolationMethod,
|
MergeInterpolationMethod,
|
||||||
)
|
)
|
||||||
|
from invokeai.backend.model_management.model_search import FindModels
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from invokeai.app.models.exceptions import CanceledException
|
from invokeai.app.models.exceptions import CanceledException
|
||||||
@ -167,6 +167,27 @@ class ModelManagerServiceBase(ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def rename_model(self,
|
||||||
|
model_name: str,
|
||||||
|
base_model: BaseModelType,
|
||||||
|
model_type: ModelType,
|
||||||
|
new_name: str,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Rename the indicated model.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def list_checkpoint_configs(
|
||||||
|
self
|
||||||
|
)->List[Path]:
|
||||||
|
"""
|
||||||
|
List the checkpoint config paths from ROOT/configs/stable-diffusion.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def convert_model(
|
def convert_model(
|
||||||
self,
|
self,
|
||||||
@ -220,6 +241,7 @@ class ModelManagerServiceBase(ABC):
|
|||||||
alpha: Optional[float] = 0.5,
|
alpha: Optional[float] = 0.5,
|
||||||
interp: Optional[MergeInterpolationMethod] = None,
|
interp: Optional[MergeInterpolationMethod] = None,
|
||||||
force: Optional[bool] = False,
|
force: Optional[bool] = False,
|
||||||
|
merge_dest_directory: Optional[Path] = None
|
||||||
) -> AddModelResult:
|
) -> AddModelResult:
|
||||||
"""
|
"""
|
||||||
Merge two to three diffusrs pipeline models and save as a new model.
|
Merge two to three diffusrs pipeline models and save as a new model.
|
||||||
@ -228,6 +250,23 @@ class ModelManagerServiceBase(ABC):
|
|||||||
:param merged_model_name: Name of destination merged model
|
:param merged_model_name: Name of destination merged model
|
||||||
:param alpha: Alpha strength to apply to 2d and 3d model
|
:param alpha: Alpha strength to apply to 2d and 3d model
|
||||||
:param interp: Interpolation method. None (default)
|
:param interp: Interpolation method. None (default)
|
||||||
|
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def search_for_models(self, directory: Path)->List[Path]:
|
||||||
|
"""
|
||||||
|
Return list of all models found in the designated directory.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def sync_to_config(self):
|
||||||
|
"""
|
||||||
|
Re-read models.yaml, rescan the models directory, and reimport models
|
||||||
|
in the autoimport directories. Call after making changes outside the
|
||||||
|
model manager API.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -431,16 +470,18 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
"""
|
"""
|
||||||
Delete the named model from configuration. If delete_files is true,
|
Delete the named model from configuration. If delete_files is true,
|
||||||
then the underlying weight file or diffusers directory will be deleted
|
then the underlying weight file or diffusers directory will be deleted
|
||||||
as well. Call commit() to write to disk.
|
as well.
|
||||||
"""
|
"""
|
||||||
self.logger.debug(f'delete model {model_name}')
|
self.logger.debug(f'delete model {model_name}')
|
||||||
self.mgr.del_model(model_name, base_model, model_type)
|
self.mgr.del_model(model_name, base_model, model_type)
|
||||||
|
self.mgr.commit()
|
||||||
|
|
||||||
def convert_model(
|
def convert_model(
|
||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
model_type: Union[ModelType.Main,ModelType.Vae],
|
model_type: Union[ModelType.Main,ModelType.Vae],
|
||||||
|
convert_dest_directory: Optional[Path] = Field(default=None, description="Optional directory location for merged model"),
|
||||||
) -> AddModelResult:
|
) -> AddModelResult:
|
||||||
"""
|
"""
|
||||||
Convert a checkpoint file into a diffusers folder, deleting the cached
|
Convert a checkpoint file into a diffusers folder, deleting the cached
|
||||||
@ -449,13 +490,14 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
:param model_name: Name of the model to convert
|
:param model_name: Name of the model to convert
|
||||||
:param base_model: Base model type
|
:param base_model: Base model type
|
||||||
:param model_type: Type of model ['vae' or 'main']
|
:param model_type: Type of model ['vae' or 'main']
|
||||||
|
:param convert_dest_directory: Save the converted model to the designated directory (`models/etc/etc` by default)
|
||||||
|
|
||||||
This will raise a ValueError unless the model is not a checkpoint. It will
|
This will raise a ValueError unless the model is not a checkpoint. It will
|
||||||
also raise a ValueError in the event that there is a similarly-named diffusers
|
also raise a ValueError in the event that there is a similarly-named diffusers
|
||||||
directory already in place.
|
directory already in place.
|
||||||
"""
|
"""
|
||||||
self.logger.debug(f'convert model {model_name}')
|
self.logger.debug(f'convert model {model_name}')
|
||||||
return self.mgr.convert_model(model_name, base_model, model_type)
|
return self.mgr.convert_model(model_name, base_model, model_type, convert_dest_directory)
|
||||||
|
|
||||||
def commit(self, conf_file: Optional[Path]=None):
|
def commit(self, conf_file: Optional[Path]=None):
|
||||||
"""
|
"""
|
||||||
@ -536,6 +578,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
alpha: Optional[float] = 0.5,
|
alpha: Optional[float] = 0.5,
|
||||||
interp: Optional[MergeInterpolationMethod] = None,
|
interp: Optional[MergeInterpolationMethod] = None,
|
||||||
force: Optional[bool] = False,
|
force: Optional[bool] = False,
|
||||||
|
merge_dest_directory: Optional[Path] = Field(default=None, description="Optional directory location for merged model"),
|
||||||
) -> AddModelResult:
|
) -> AddModelResult:
|
||||||
"""
|
"""
|
||||||
Merge two to three diffusrs pipeline models and save as a new model.
|
Merge two to three diffusrs pipeline models and save as a new model.
|
||||||
@ -544,6 +587,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
:param merged_model_name: Name of destination merged model
|
:param merged_model_name: Name of destination merged model
|
||||||
:param alpha: Alpha strength to apply to 2d and 3d model
|
:param alpha: Alpha strength to apply to 2d and 3d model
|
||||||
:param interp: Interpolation method. None (default)
|
:param interp: Interpolation method. None (default)
|
||||||
|
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
|
||||||
"""
|
"""
|
||||||
merger = ModelMerger(self.mgr)
|
merger = ModelMerger(self.mgr)
|
||||||
try:
|
try:
|
||||||
@ -554,7 +598,55 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
alpha = alpha,
|
alpha = alpha,
|
||||||
interp = interp,
|
interp = interp,
|
||||||
force = force,
|
force = force,
|
||||||
|
merge_dest_directory=merge_dest_directory,
|
||||||
)
|
)
|
||||||
except AssertionError as e:
|
except AssertionError as e:
|
||||||
raise ValueError(e)
|
raise ValueError(e)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def search_for_models(self, directory: Path)->List[Path]:
|
||||||
|
"""
|
||||||
|
Return list of all models found in the designated directory.
|
||||||
|
"""
|
||||||
|
search = FindModels(directory,self.logger)
|
||||||
|
return search.list_models()
|
||||||
|
|
||||||
|
def sync_to_config(self):
|
||||||
|
"""
|
||||||
|
Re-read models.yaml, rescan the models directory, and reimport models
|
||||||
|
in the autoimport directories. Call after making changes outside the
|
||||||
|
model manager API.
|
||||||
|
"""
|
||||||
|
return self.mgr.sync_to_config()
|
||||||
|
|
||||||
|
def list_checkpoint_configs(self)->List[Path]:
|
||||||
|
"""
|
||||||
|
List the checkpoint config paths from ROOT/configs/stable-diffusion.
|
||||||
|
"""
|
||||||
|
config = self.mgr.app_config
|
||||||
|
conf_path = config.legacy_conf_path
|
||||||
|
root_path = config.root_path
|
||||||
|
return [(conf_path / x).relative_to(root_path) for x in conf_path.glob('**/*.yaml')]
|
||||||
|
|
||||||
|
def rename_model(self,
|
||||||
|
model_name: str,
|
||||||
|
base_model: BaseModelType,
|
||||||
|
model_type: ModelType,
|
||||||
|
new_name: str = None,
|
||||||
|
new_base: BaseModelType = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Rename the indicated model. Can provide a new name and/or a new base.
|
||||||
|
:param model_name: Current name of the model
|
||||||
|
:param base_model: Current base of the model
|
||||||
|
:param model_type: Model type (can't be changed)
|
||||||
|
:param new_name: New name for the model
|
||||||
|
:param new_base: New base for the model
|
||||||
|
"""
|
||||||
|
self.mgr.rename_model(base_model = base_model,
|
||||||
|
model_type = model_type,
|
||||||
|
model_name = model_name,
|
||||||
|
new_name = new_name,
|
||||||
|
new_base = new_base,
|
||||||
|
)
|
||||||
|
|
||||||
|
@ -71,8 +71,6 @@ class ModelInstallList:
|
|||||||
class InstallSelections():
|
class InstallSelections():
|
||||||
install_models: List[str]= field(default_factory=list)
|
install_models: List[str]= field(default_factory=list)
|
||||||
remove_models: List[str]=field(default_factory=list)
|
remove_models: List[str]=field(default_factory=list)
|
||||||
# scan_directory: Path = None
|
|
||||||
# autoscan_on_startup: bool=False
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelLoadInfo():
|
class ModelLoadInfo():
|
||||||
|
@ -247,6 +247,7 @@ import invokeai.backend.util.logging as logger
|
|||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.backend.util import CUDA_DEVICE, Chdir
|
from invokeai.backend.util import CUDA_DEVICE, Chdir
|
||||||
from .model_cache import ModelCache, ModelLocker
|
from .model_cache import ModelCache, ModelLocker
|
||||||
|
from .model_search import ModelSearch
|
||||||
from .models import (
|
from .models import (
|
||||||
BaseModelType, ModelType, SubModelType,
|
BaseModelType, ModelType, SubModelType,
|
||||||
ModelError, SchedulerPredictionType, MODEL_CLASSES,
|
ModelError, SchedulerPredictionType, MODEL_CLASSES,
|
||||||
@ -323,15 +324,6 @@ class ModelManager(object):
|
|||||||
# TODO: metadata not found
|
# TODO: metadata not found
|
||||||
# TODO: version check
|
# TODO: version check
|
||||||
|
|
||||||
self.models = dict()
|
|
||||||
for model_key, model_config in config.items():
|
|
||||||
model_name, base_model, model_type = self.parse_key(model_key)
|
|
||||||
model_class = MODEL_CLASSES[base_model][model_type]
|
|
||||||
# alias for config file
|
|
||||||
model_config["model_format"] = model_config.pop("format")
|
|
||||||
self.models[model_key] = model_class.create_config(**model_config)
|
|
||||||
|
|
||||||
# check config version number and update on disk/RAM if necessary
|
|
||||||
self.app_config = InvokeAIAppConfig.get_config()
|
self.app_config = InvokeAIAppConfig.get_config()
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
self.cache = ModelCache(
|
self.cache = ModelCache(
|
||||||
@ -342,11 +334,41 @@ class ModelManager(object):
|
|||||||
sequential_offload = sequential_offload,
|
sequential_offload = sequential_offload,
|
||||||
logger = logger,
|
logger = logger,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._read_models(config)
|
||||||
|
|
||||||
|
def _read_models(self, config: Optional[DictConfig] = None):
|
||||||
|
if not config:
|
||||||
|
if self.config_path:
|
||||||
|
config = OmegaConf.load(self.config_path)
|
||||||
|
else:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.models = dict()
|
||||||
|
for model_key, model_config in config.items():
|
||||||
|
if model_key.startswith('_'):
|
||||||
|
continue
|
||||||
|
model_name, base_model, model_type = self.parse_key(model_key)
|
||||||
|
model_class = MODEL_CLASSES[base_model][model_type]
|
||||||
|
# alias for config file
|
||||||
|
model_config["model_format"] = model_config.pop("format")
|
||||||
|
self.models[model_key] = model_class.create_config(**model_config)
|
||||||
|
|
||||||
|
# check config version number and update on disk/RAM if necessary
|
||||||
self.cache_keys = dict()
|
self.cache_keys = dict()
|
||||||
|
|
||||||
# add controlnet, lora and textual_inversion models from disk
|
# add controlnet, lora and textual_inversion models from disk
|
||||||
self.scan_models_directory()
|
self.scan_models_directory()
|
||||||
|
|
||||||
|
def sync_to_config(self):
|
||||||
|
"""
|
||||||
|
Call this when `models.yaml` has been changed externally.
|
||||||
|
This will reinitialize internal data structures
|
||||||
|
"""
|
||||||
|
# Reread models directory; note that this will reinitialize the cache,
|
||||||
|
# causing otherwise unreferenced models to be removed from memory
|
||||||
|
self._read_models()
|
||||||
|
|
||||||
def model_exists(
|
def model_exists(
|
||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
@ -527,7 +549,10 @@ class ModelManager(object):
|
|||||||
model_keys = [self.create_key(model_name, base_model, model_type)] if model_name else sorted(self.models, key=str.casefold)
|
model_keys = [self.create_key(model_name, base_model, model_type)] if model_name else sorted(self.models, key=str.casefold)
|
||||||
models = []
|
models = []
|
||||||
for model_key in model_keys:
|
for model_key in model_keys:
|
||||||
model_config = self.models[model_key]
|
model_config = self.models.get(model_key)
|
||||||
|
if not model_config:
|
||||||
|
self.logger.error(f'Unknown model {model_name}')
|
||||||
|
raise KeyError(f'Unknown model {model_name}')
|
||||||
|
|
||||||
cur_model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
|
cur_model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
|
||||||
if base_model is not None and cur_base_model != base_model:
|
if base_model is not None and cur_base_model != base_model:
|
||||||
@ -646,11 +671,61 @@ class ModelManager(object):
|
|||||||
config = model_config,
|
config = model_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def rename_model(
|
||||||
|
self,
|
||||||
|
model_name: str,
|
||||||
|
base_model: BaseModelType,
|
||||||
|
model_type: ModelType,
|
||||||
|
new_name: str = None,
|
||||||
|
new_base: BaseModelType = None,
|
||||||
|
):
|
||||||
|
'''
|
||||||
|
Rename or rebase a model.
|
||||||
|
'''
|
||||||
|
if new_name is None and new_base is None:
|
||||||
|
self.logger.error("rename_model() called with neither a new_name nor a new_base. {model_name} unchanged.")
|
||||||
|
return
|
||||||
|
|
||||||
|
model_key = self.create_key(model_name, base_model, model_type)
|
||||||
|
model_cfg = self.models.get(model_key, None)
|
||||||
|
if not model_cfg:
|
||||||
|
raise KeyError(f"Unknown model: {model_key}")
|
||||||
|
|
||||||
|
old_path = self.app_config.root_path / model_cfg.path
|
||||||
|
new_name = new_name or model_name
|
||||||
|
new_base = new_base or base_model
|
||||||
|
new_key = self.create_key(new_name, new_base, model_type)
|
||||||
|
if new_key in self.models:
|
||||||
|
raise ValueError(f'Attempt to overwrite existing model definition "{new_key}"')
|
||||||
|
|
||||||
|
# if this is a model file/directory that we manage ourselves, we need to move it
|
||||||
|
if old_path.is_relative_to(self.app_config.models_path):
|
||||||
|
new_path = self.app_config.root_path / 'models' / new_base.value / model_type.value / new_name
|
||||||
|
move(old_path, new_path)
|
||||||
|
model_cfg.path = str(new_path.relative_to(self.app_config.root_path))
|
||||||
|
|
||||||
|
# clean up caches
|
||||||
|
old_model_cache = self._get_model_cache_path(old_path)
|
||||||
|
if old_model_cache.exists():
|
||||||
|
if old_model_cache.is_dir():
|
||||||
|
rmtree(str(old_model_cache))
|
||||||
|
else:
|
||||||
|
old_model_cache.unlink()
|
||||||
|
|
||||||
|
cache_ids = self.cache_keys.pop(model_key, [])
|
||||||
|
for cache_id in cache_ids:
|
||||||
|
self.cache.uncache_model(cache_id)
|
||||||
|
|
||||||
|
self.models.pop(model_key, None) # delete
|
||||||
|
self.models[new_key] = model_cfg
|
||||||
|
self.commit()
|
||||||
|
|
||||||
def convert_model (
|
def convert_model (
|
||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
model_type: Union[ModelType.Main,ModelType.Vae],
|
model_type: Union[ModelType.Main,ModelType.Vae],
|
||||||
|
dest_directory: Optional[Path]=None,
|
||||||
) -> AddModelResult:
|
) -> AddModelResult:
|
||||||
'''
|
'''
|
||||||
Convert a checkpoint file into a diffusers folder, deleting the cached
|
Convert a checkpoint file into a diffusers folder, deleting the cached
|
||||||
@ -677,14 +752,14 @@ class ModelManager(object):
|
|||||||
)
|
)
|
||||||
checkpoint_path = self.app_config.root_path / info["path"]
|
checkpoint_path = self.app_config.root_path / info["path"]
|
||||||
old_diffusers_path = self.app_config.models_path / model.location
|
old_diffusers_path = self.app_config.models_path / model.location
|
||||||
new_diffusers_path = self.app_config.models_path / base_model.value / model_type.value / model_name
|
new_diffusers_path = (dest_directory or self.app_config.models_path / base_model.value / model_type.value) / model_name
|
||||||
if new_diffusers_path.exists():
|
if new_diffusers_path.exists():
|
||||||
raise ValueError(f"A diffusers model already exists at {new_diffusers_path}")
|
raise ValueError(f"A diffusers model already exists at {new_diffusers_path}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
move(old_diffusers_path,new_diffusers_path)
|
move(old_diffusers_path,new_diffusers_path)
|
||||||
info["model_format"] = "diffusers"
|
info["model_format"] = "diffusers"
|
||||||
info["path"] = str(new_diffusers_path.relative_to(self.app_config.root_path))
|
info["path"] = str(new_diffusers_path) if dest_directory else str(new_diffusers_path.relative_to(self.app_config.root_path))
|
||||||
info.pop('config')
|
info.pop('config')
|
||||||
|
|
||||||
result = self.add_model(model_name, base_model, model_type,
|
result = self.add_model(model_name, base_model, model_type,
|
||||||
@ -824,6 +899,7 @@ class ModelManager(object):
|
|||||||
if (new_models_found or imported_models) and self.config_path:
|
if (new_models_found or imported_models) and self.config_path:
|
||||||
self.commit()
|
self.commit()
|
||||||
|
|
||||||
|
|
||||||
def autoimport(self)->Dict[str, AddModelResult]:
|
def autoimport(self)->Dict[str, AddModelResult]:
|
||||||
'''
|
'''
|
||||||
Scan the autoimport directory (if defined) and import new models, delete defunct models.
|
Scan the autoimport directory (if defined) and import new models, delete defunct models.
|
||||||
@ -832,61 +908,40 @@ class ModelManager(object):
|
|||||||
from invokeai.backend.install.model_install_backend import ModelInstall
|
from invokeai.backend.install.model_install_backend import ModelInstall
|
||||||
from invokeai.frontend.install.model_install import ask_user_for_prediction_type
|
from invokeai.frontend.install.model_install import ask_user_for_prediction_type
|
||||||
|
|
||||||
|
class ScanAndImport(ModelSearch):
|
||||||
|
def __init__(self, directories, logger, ignore: Set[Path], installer: ModelInstall):
|
||||||
|
super().__init__(directories, logger)
|
||||||
|
self.installer = installer
|
||||||
|
self.ignore = ignore
|
||||||
|
|
||||||
|
def on_search_started(self):
|
||||||
|
self.new_models_found = dict()
|
||||||
|
|
||||||
|
def on_model_found(self, model: Path):
|
||||||
|
if model not in self.ignore:
|
||||||
|
self.new_models_found.update(self.installer.heuristic_import(model))
|
||||||
|
|
||||||
|
def on_search_completed(self):
|
||||||
|
self.logger.info(f'Scanned {self._items_scanned} files and directories, imported {len(self.new_models_found)} models')
|
||||||
|
|
||||||
|
def models_found(self):
|
||||||
|
return self.new_models_found
|
||||||
|
|
||||||
|
|
||||||
installer = ModelInstall(config = self.app_config,
|
installer = ModelInstall(config = self.app_config,
|
||||||
model_manager = self,
|
model_manager = self,
|
||||||
prediction_type_helper = ask_user_for_prediction_type,
|
prediction_type_helper = ask_user_for_prediction_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
scanned_dirs = set()
|
|
||||||
|
|
||||||
config = self.app_config
|
config = self.app_config
|
||||||
known_paths = {(self.app_config.root_path / x['path']) for x in self.list_models()}
|
known_paths = {config.root_path / x['path'] for x in self.list_models()}
|
||||||
|
directories = {config.root_path / x for x in [config.autoimport_dir,
|
||||||
for autodir in [config.autoimport_dir,
|
config.lora_dir,
|
||||||
config.lora_dir,
|
config.embedding_dir,
|
||||||
config.embedding_dir,
|
config.controlnet_dir]
|
||||||
config.controlnet_dir]:
|
}
|
||||||
if autodir is None:
|
scanner = ScanAndImport(directories, self.logger, ignore=known_paths, installer=installer)
|
||||||
continue
|
scanner.search()
|
||||||
|
return scanner.models_found()
|
||||||
installed = dict()
|
|
||||||
|
|
||||||
autodir = self.app_config.root_path / autodir
|
|
||||||
if not autodir.exists():
|
|
||||||
continue
|
|
||||||
|
|
||||||
items_scanned = 0
|
|
||||||
new_models_found = dict()
|
|
||||||
|
|
||||||
for root, dirs, files in os.walk(autodir):
|
|
||||||
items_scanned += len(dirs) + len(files)
|
|
||||||
for d in dirs:
|
|
||||||
path = Path(root) / d
|
|
||||||
if path in known_paths or path.parent in scanned_dirs:
|
|
||||||
scanned_dirs.add(path)
|
|
||||||
continue
|
|
||||||
if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}]):
|
|
||||||
try:
|
|
||||||
new_models_found.update(installer.heuristic_import(path))
|
|
||||||
scanned_dirs.add(path)
|
|
||||||
except ValueError as e:
|
|
||||||
self.logger.warning(str(e))
|
|
||||||
|
|
||||||
for f in files:
|
|
||||||
path = Path(root) / f
|
|
||||||
if path in known_paths or path.parent in scanned_dirs:
|
|
||||||
continue
|
|
||||||
if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}:
|
|
||||||
try:
|
|
||||||
import_result = installer.heuristic_import(path)
|
|
||||||
new_models_found.update(import_result)
|
|
||||||
except ValueError as e:
|
|
||||||
self.logger.warning(str(e))
|
|
||||||
|
|
||||||
installed.update(new_models_found)
|
|
||||||
|
|
||||||
self.logger.info(f'Scanned {items_scanned} files and directories, imported {len(new_models_found)} models')
|
|
||||||
return installed
|
|
||||||
|
|
||||||
def heuristic_import(self,
|
def heuristic_import(self,
|
||||||
items_to_import: Set[str],
|
items_to_import: Set[str],
|
||||||
@ -924,3 +979,4 @@ class ModelManager(object):
|
|||||||
successfully_installed.update(installed)
|
successfully_installed.update(installed)
|
||||||
self.commit()
|
self.commit()
|
||||||
return successfully_installed
|
return successfully_installed
|
||||||
|
|
||||||
|
@ -11,7 +11,7 @@ from enum import Enum
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from diffusers import DiffusionPipeline
|
from diffusers import DiffusionPipeline
|
||||||
from diffusers import logging as dlogging
|
from diffusers import logging as dlogging
|
||||||
from typing import List, Union
|
from typing import List, Union, Optional
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
|
|
||||||
@ -74,6 +74,7 @@ class ModelMerger(object):
|
|||||||
alpha: float = 0.5,
|
alpha: float = 0.5,
|
||||||
interp: MergeInterpolationMethod = None,
|
interp: MergeInterpolationMethod = None,
|
||||||
force: bool = False,
|
force: bool = False,
|
||||||
|
merge_dest_directory: Optional[Path] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> AddModelResult:
|
) -> AddModelResult:
|
||||||
"""
|
"""
|
||||||
@ -85,7 +86,7 @@ class ModelMerger(object):
|
|||||||
:param interp: The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None.
|
:param interp: The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None.
|
||||||
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C).
|
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C).
|
||||||
:param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
|
:param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
|
||||||
|
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
|
||||||
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
|
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
|
||||||
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
|
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
|
||||||
"""
|
"""
|
||||||
@ -111,7 +112,7 @@ class ModelMerger(object):
|
|||||||
merged_pipe = self.merge_diffusion_models(
|
merged_pipe = self.merge_diffusion_models(
|
||||||
model_paths, alpha, merge_method, force, **kwargs
|
model_paths, alpha, merge_method, force, **kwargs
|
||||||
)
|
)
|
||||||
dump_path = config.models_path / base_model.value / ModelType.Main.value
|
dump_path = Path(merge_dest_directory) if merge_dest_directory else config.models_path / base_model.value / ModelType.Main.value
|
||||||
dump_path.mkdir(parents=True, exist_ok=True)
|
dump_path.mkdir(parents=True, exist_ok=True)
|
||||||
dump_path = dump_path / merged_model_name
|
dump_path = dump_path / merged_model_name
|
||||||
|
|
||||||
|
103
invokeai/backend/model_management/model_search.py
Normal file
103
invokeai/backend/model_management/model_search.py
Normal file
@ -0,0 +1,103 @@
|
|||||||
|
# Copyright 2023, Lincoln D. Stein and the InvokeAI Team
|
||||||
|
"""
|
||||||
|
Abstract base class for recursive directory search for models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import List, Set, types
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import invokeai.backend.util.logging as logger
|
||||||
|
|
||||||
|
class ModelSearch(ABC):
|
||||||
|
def __init__(self, directories: List[Path], logger: types.ModuleType=logger):
|
||||||
|
"""
|
||||||
|
Initialize a recursive model directory search.
|
||||||
|
:param directories: List of directory Paths to recurse through
|
||||||
|
:param logger: Logger to use
|
||||||
|
"""
|
||||||
|
self.directories = directories
|
||||||
|
self.logger = logger
|
||||||
|
self._items_scanned = 0
|
||||||
|
self._models_found = 0
|
||||||
|
self._scanned_dirs = set()
|
||||||
|
self._scanned_paths = set()
|
||||||
|
self._pruned_paths = set()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def on_search_started(self):
|
||||||
|
"""
|
||||||
|
Called before the scan starts.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def on_model_found(self, model: Path):
|
||||||
|
"""
|
||||||
|
Process a found model. Raise an exception if something goes wrong.
|
||||||
|
:param model: Model to process - could be a directory or checkpoint.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def on_search_completed(self):
|
||||||
|
"""
|
||||||
|
Perform some activity when the scan is completed. May use instance
|
||||||
|
variables, items_scanned and models_found
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def search(self):
|
||||||
|
self.on_search_started()
|
||||||
|
for dir in self.directories:
|
||||||
|
self.walk_directory(dir)
|
||||||
|
self.on_search_completed()
|
||||||
|
|
||||||
|
def walk_directory(self, path: Path):
|
||||||
|
for root, dirs, files in os.walk(path):
|
||||||
|
if str(Path(root).name).startswith('.'):
|
||||||
|
self._pruned_paths.add(root)
|
||||||
|
if any([Path(root).is_relative_to(x) for x in self._pruned_paths]):
|
||||||
|
continue
|
||||||
|
|
||||||
|
self._items_scanned += len(dirs) + len(files)
|
||||||
|
for d in dirs:
|
||||||
|
path = Path(root) / d
|
||||||
|
if path in self._scanned_paths or path.parent in self._scanned_dirs:
|
||||||
|
self._scanned_dirs.add(path)
|
||||||
|
continue
|
||||||
|
if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}]):
|
||||||
|
try:
|
||||||
|
self.on_model_found(path)
|
||||||
|
self._models_found += 1
|
||||||
|
self._scanned_dirs.add(path)
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.warning(str(e))
|
||||||
|
|
||||||
|
for f in files:
|
||||||
|
path = Path(root) / f
|
||||||
|
if path.parent in self._scanned_dirs:
|
||||||
|
continue
|
||||||
|
if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}:
|
||||||
|
try:
|
||||||
|
self.on_model_found(path)
|
||||||
|
self._models_found += 1
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.warning(str(e))
|
||||||
|
|
||||||
|
class FindModels(ModelSearch):
|
||||||
|
def on_search_started(self):
|
||||||
|
self.models_found: Set[Path] = set()
|
||||||
|
|
||||||
|
def on_model_found(self,model: Path):
|
||||||
|
self.models_found.add(model)
|
||||||
|
|
||||||
|
def on_search_completed(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def list_models(self) -> List[Path]:
|
||||||
|
self.search()
|
||||||
|
return self.models_found
|
||||||
|
|
||||||
|
|
@ -48,7 +48,9 @@ for base_model, models in MODEL_CLASSES.items():
|
|||||||
model_configs.discard(None)
|
model_configs.discard(None)
|
||||||
MODEL_CONFIGS.extend(model_configs)
|
MODEL_CONFIGS.extend(model_configs)
|
||||||
|
|
||||||
for cfg in model_configs:
|
# LS: sort to get the checkpoint configs first, which makes
|
||||||
|
# for a better template in the Swagger docs
|
||||||
|
for cfg in sorted(model_configs, key=lambda x: str(x)):
|
||||||
model_name, cfg_name = cfg.__qualname__.split('.')[-2:]
|
model_name, cfg_name = cfg.__qualname__.split('.')[-2:]
|
||||||
openapi_cfg_name = model_name + cfg_name
|
openapi_cfg_name = model_name + cfg_name
|
||||||
if openapi_cfg_name in vars():
|
if openapi_cfg_name in vars():
|
||||||
|
@ -59,7 +59,6 @@ class ModelConfigBase(BaseModel):
|
|||||||
path: str # or Path
|
path: str # or Path
|
||||||
description: Optional[str] = Field(None)
|
description: Optional[str] = Field(None)
|
||||||
model_format: Optional[str] = Field(None)
|
model_format: Optional[str] = Field(None)
|
||||||
# do not save to config
|
|
||||||
error: Optional[ModelError] = Field(None)
|
error: Optional[ModelError] = Field(None)
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
|
@ -38,7 +38,6 @@ class StableDiffusion1Model(DiffusersModel):
|
|||||||
config: str
|
config: str
|
||||||
variant: ModelVariantType
|
variant: ModelVariantType
|
||||||
|
|
||||||
|
|
||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||||
assert base_model == BaseModelType.StableDiffusion1
|
assert base_model == BaseModelType.StableDiffusion1
|
||||||
assert model_type == ModelType.Main
|
assert model_type == ModelType.Main
|
||||||
|
@ -241,11 +241,45 @@ class InvokeAIDiffuserComponent:
|
|||||||
|
|
||||||
def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
|
def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
|
||||||
# fast batched path
|
# fast batched path
|
||||||
|
|
||||||
|
def _pad_conditioning(cond, target_len, encoder_attention_mask):
|
||||||
|
conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
|
||||||
|
|
||||||
|
if cond.shape[1] < max_len:
|
||||||
|
conditioning_attention_mask = torch.cat([
|
||||||
|
conditioning_attention_mask,
|
||||||
|
torch.zeros((cond.shape[0], max_len - cond.shape[1]), device=cond.device, dtype=cond.dtype),
|
||||||
|
], dim=1)
|
||||||
|
|
||||||
|
cond = torch.cat([
|
||||||
|
cond,
|
||||||
|
torch.zeros((cond.shape[0], max_len - cond.shape[1], cond.shape[2]), device=cond.device, dtype=cond.dtype),
|
||||||
|
], dim=1)
|
||||||
|
|
||||||
|
if encoder_attention_mask is None:
|
||||||
|
encoder_attention_mask = conditioning_attention_mask
|
||||||
|
else:
|
||||||
|
encoder_attention_mask = torch.cat([
|
||||||
|
encoder_attention_mask,
|
||||||
|
conditioning_attention_mask,
|
||||||
|
])
|
||||||
|
|
||||||
|
return cond, encoder_attention_mask
|
||||||
|
|
||||||
x_twice = torch.cat([x] * 2)
|
x_twice = torch.cat([x] * 2)
|
||||||
sigma_twice = torch.cat([sigma] * 2)
|
sigma_twice = torch.cat([sigma] * 2)
|
||||||
|
|
||||||
|
encoder_attention_mask = None
|
||||||
|
if unconditioning.shape[1] != conditioning.shape[1]:
|
||||||
|
max_len = max(unconditioning.shape[1], conditioning.shape[1])
|
||||||
|
unconditioning, encoder_attention_mask = _pad_conditioning(unconditioning, max_len, encoder_attention_mask)
|
||||||
|
conditioning, encoder_attention_mask = _pad_conditioning(conditioning, max_len, encoder_attention_mask)
|
||||||
|
|
||||||
both_conditionings = torch.cat([unconditioning, conditioning])
|
both_conditionings = torch.cat([unconditioning, conditioning])
|
||||||
both_results = self.model_forward_callback(
|
both_results = self.model_forward_callback(
|
||||||
x_twice, sigma_twice, both_conditionings, **kwargs,
|
x_twice, sigma_twice, both_conditionings,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
|
unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
|
||||||
return unconditioned_next_x, conditioned_next_x
|
return unconditioned_next_x, conditioned_next_x
|
||||||
|
@ -13,7 +13,11 @@ import { RootState } from 'app/store/store';
|
|||||||
|
|
||||||
const moduleLog = log.child({ namespace: 'controlNet' });
|
const moduleLog = log.child({ namespace: 'controlNet' });
|
||||||
|
|
||||||
const predicate: AnyListenerPredicate<RootState> = (action, state) => {
|
const predicate: AnyListenerPredicate<RootState> = (
|
||||||
|
action,
|
||||||
|
state,
|
||||||
|
prevState
|
||||||
|
) => {
|
||||||
const isActionMatched =
|
const isActionMatched =
|
||||||
controlNetProcessorParamsChanged.match(action) ||
|
controlNetProcessorParamsChanged.match(action) ||
|
||||||
controlNetModelChanged.match(action) ||
|
controlNetModelChanged.match(action) ||
|
||||||
@ -25,6 +29,16 @@ const predicate: AnyListenerPredicate<RootState> = (action, state) => {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (controlNetAutoConfigToggled.match(action)) {
|
||||||
|
// do not process if the user just disabled auto-config
|
||||||
|
if (
|
||||||
|
prevState.controlNet.controlNets[action.payload.controlNetId]
|
||||||
|
.shouldAutoConfig === true
|
||||||
|
) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const { controlImage, processorType, shouldAutoConfig } =
|
const { controlImage, processorType, shouldAutoConfig } =
|
||||||
state.controlNet.controlNets[action.payload.controlNetId];
|
state.controlNet.controlNets[action.payload.controlNetId];
|
||||||
|
|
||||||
|
@ -10,6 +10,7 @@ import { zMainModel } from 'features/parameters/types/parameterSchemas';
|
|||||||
import { addToast } from 'features/system/store/systemSlice';
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
import { forEach } from 'lodash-es';
|
import { forEach } from 'lodash-es';
|
||||||
import { startAppListening } from '..';
|
import { startAppListening } from '..';
|
||||||
|
import { controlNetRemoved } from 'features/controlNet/store/controlNetSlice';
|
||||||
|
|
||||||
const moduleLog = log.child({ module: 'models' });
|
const moduleLog = log.child({ module: 'models' });
|
||||||
|
|
||||||
@ -51,7 +52,14 @@ export const addModelSelectedListener = () => {
|
|||||||
modelsCleared += 1;
|
modelsCleared += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: handle incompatible controlnet; pending model manager support
|
const { controlNets } = state.controlNet;
|
||||||
|
forEach(controlNets, (controlNet, controlNetId) => {
|
||||||
|
if (controlNet.model?.base_model !== base_model) {
|
||||||
|
dispatch(controlNetRemoved({ controlNetId }));
|
||||||
|
modelsCleared += 1;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
if (modelsCleared > 0) {
|
if (modelsCleared > 0) {
|
||||||
dispatch(
|
dispatch(
|
||||||
addToast(
|
addToast(
|
||||||
|
@ -11,6 +11,7 @@ import {
|
|||||||
import { forEach, some } from 'lodash-es';
|
import { forEach, some } from 'lodash-es';
|
||||||
import { modelsApi } from 'services/api/endpoints/models';
|
import { modelsApi } from 'services/api/endpoints/models';
|
||||||
import { startAppListening } from '..';
|
import { startAppListening } from '..';
|
||||||
|
import { controlNetRemoved } from 'features/controlNet/store/controlNetSlice';
|
||||||
|
|
||||||
const moduleLog = log.child({ module: 'models' });
|
const moduleLog = log.child({ module: 'models' });
|
||||||
|
|
||||||
@ -127,7 +128,22 @@ export const addModelsLoadedListener = () => {
|
|||||||
matcher: modelsApi.endpoints.getControlNetModels.matchFulfilled,
|
matcher: modelsApi.endpoints.getControlNetModels.matchFulfilled,
|
||||||
effect: async (action, { getState, dispatch }) => {
|
effect: async (action, { getState, dispatch }) => {
|
||||||
// ControlNet models loaded - need to remove missing ControlNets from state
|
// ControlNet models loaded - need to remove missing ControlNets from state
|
||||||
// TODO: pending model manager controlnet support
|
const controlNets = getState().controlNet.controlNets;
|
||||||
|
|
||||||
|
forEach(controlNets, (controlNet, controlNetId) => {
|
||||||
|
const isControlNetAvailable = some(
|
||||||
|
action.payload.entities,
|
||||||
|
(m) =>
|
||||||
|
m?.model_name === controlNet?.model?.model_name &&
|
||||||
|
m?.base_model === controlNet?.model?.base_model
|
||||||
|
);
|
||||||
|
|
||||||
|
if (isControlNetAvailable) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatch(controlNetRemoved({ controlNetId }));
|
||||||
|
});
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import {
|
import {
|
||||||
CONTROLNET_MODELS,
|
// CONTROLNET_MODELS,
|
||||||
CONTROLNET_PROCESSORS,
|
CONTROLNET_PROCESSORS,
|
||||||
} from 'features/controlNet/store/constants';
|
} from 'features/controlNet/store/constants';
|
||||||
import { InvokeTabName } from 'features/ui/store/tabMap';
|
import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||||
@ -128,7 +128,7 @@ export type AppConfig = {
|
|||||||
canRestoreDeletedImagesFromBin: boolean;
|
canRestoreDeletedImagesFromBin: boolean;
|
||||||
sd: {
|
sd: {
|
||||||
defaultModel?: string;
|
defaultModel?: string;
|
||||||
disabledControlNetModels: (keyof typeof CONTROLNET_MODELS)[];
|
disabledControlNetModels: string[];
|
||||||
disabledControlNetProcessors: (keyof typeof CONTROLNET_PROCESSORS)[];
|
disabledControlNetProcessors: (keyof typeof CONTROLNET_PROCESSORS)[];
|
||||||
iterations: {
|
iterations: {
|
||||||
initial: number;
|
initial: number;
|
||||||
|
@ -170,12 +170,14 @@ const IAIDndImage = (props: IAIDndImageProps) => {
|
|||||||
</>
|
</>
|
||||||
)}
|
)}
|
||||||
{!imageDTO && isUploadDisabled && noContentFallback}
|
{!imageDTO && isUploadDisabled && noContentFallback}
|
||||||
<IAIDroppable
|
{!isDropDisabled && (
|
||||||
data={droppableData}
|
<IAIDroppable
|
||||||
disabled={isDropDisabled}
|
data={droppableData}
|
||||||
dropLabel={dropLabel}
|
disabled={isDropDisabled}
|
||||||
/>
|
dropLabel={dropLabel}
|
||||||
{imageDTO && (
|
/>
|
||||||
|
)}
|
||||||
|
{imageDTO && !isDragDisabled && (
|
||||||
<IAIDraggable
|
<IAIDraggable
|
||||||
data={draggableData}
|
data={draggableData}
|
||||||
disabled={isDragDisabled || !imageDTO}
|
disabled={isDragDisabled || !imageDTO}
|
||||||
|
@ -1,17 +1,25 @@
|
|||||||
import { Tooltip } from '@chakra-ui/react';
|
import { FormControl, FormLabel, Tooltip } from '@chakra-ui/react';
|
||||||
import { MultiSelect, MultiSelectProps } from '@mantine/core';
|
import { MultiSelect, MultiSelectProps } from '@mantine/core';
|
||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
|
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
|
||||||
import { useMantineMultiSelectStyles } from 'mantine-theme/hooks/useMantineMultiSelectStyles';
|
import { useMantineMultiSelectStyles } from 'mantine-theme/hooks/useMantineMultiSelectStyles';
|
||||||
import { KeyboardEvent, RefObject, memo, useCallback } from 'react';
|
import { KeyboardEvent, RefObject, memo, useCallback } from 'react';
|
||||||
|
|
||||||
type IAIMultiSelectProps = MultiSelectProps & {
|
type IAIMultiSelectProps = Omit<MultiSelectProps, 'label'> & {
|
||||||
tooltip?: string;
|
tooltip?: string;
|
||||||
inputRef?: RefObject<HTMLInputElement>;
|
inputRef?: RefObject<HTMLInputElement>;
|
||||||
|
label?: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
|
const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
|
||||||
const { searchable = true, tooltip, inputRef, ...rest } = props;
|
const {
|
||||||
|
searchable = true,
|
||||||
|
tooltip,
|
||||||
|
inputRef,
|
||||||
|
label,
|
||||||
|
disabled,
|
||||||
|
...rest
|
||||||
|
} = props;
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
const handleKeyDown = useCallback(
|
const handleKeyDown = useCallback(
|
||||||
@ -37,7 +45,15 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
|
|||||||
return (
|
return (
|
||||||
<Tooltip label={tooltip} placement="top" hasArrow isOpen={true}>
|
<Tooltip label={tooltip} placement="top" hasArrow isOpen={true}>
|
||||||
<MultiSelect
|
<MultiSelect
|
||||||
|
label={
|
||||||
|
label ? (
|
||||||
|
<FormControl isDisabled={disabled}>
|
||||||
|
<FormLabel>{label}</FormLabel>
|
||||||
|
</FormControl>
|
||||||
|
) : undefined
|
||||||
|
}
|
||||||
ref={inputRef}
|
ref={inputRef}
|
||||||
|
disabled={disabled}
|
||||||
onKeyDown={handleKeyDown}
|
onKeyDown={handleKeyDown}
|
||||||
onKeyUp={handleKeyUp}
|
onKeyUp={handleKeyUp}
|
||||||
searchable={searchable}
|
searchable={searchable}
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
import { Tooltip } from '@chakra-ui/react';
|
import { FormControl, FormLabel, Tooltip } from '@chakra-ui/react';
|
||||||
import { Select, SelectProps } from '@mantine/core';
|
import { Select, SelectProps } from '@mantine/core';
|
||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
|
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
|
||||||
@ -11,13 +11,22 @@ export type IAISelectDataType = {
|
|||||||
tooltip?: string;
|
tooltip?: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
type IAISelectProps = SelectProps & {
|
type IAISelectProps = Omit<SelectProps, 'label'> & {
|
||||||
tooltip?: string;
|
tooltip?: string;
|
||||||
|
label?: string;
|
||||||
inputRef?: RefObject<HTMLInputElement>;
|
inputRef?: RefObject<HTMLInputElement>;
|
||||||
};
|
};
|
||||||
|
|
||||||
const IAIMantineSearchableSelect = (props: IAISelectProps) => {
|
const IAIMantineSearchableSelect = (props: IAISelectProps) => {
|
||||||
const { searchable = true, tooltip, inputRef, onChange, ...rest } = props;
|
const {
|
||||||
|
searchable = true,
|
||||||
|
tooltip,
|
||||||
|
inputRef,
|
||||||
|
onChange,
|
||||||
|
label,
|
||||||
|
disabled,
|
||||||
|
...rest
|
||||||
|
} = props;
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
const [searchValue, setSearchValue] = useState('');
|
const [searchValue, setSearchValue] = useState('');
|
||||||
@ -61,6 +70,14 @@ const IAIMantineSearchableSelect = (props: IAISelectProps) => {
|
|||||||
<Tooltip label={tooltip} placement="top" hasArrow>
|
<Tooltip label={tooltip} placement="top" hasArrow>
|
||||||
<Select
|
<Select
|
||||||
ref={inputRef}
|
ref={inputRef}
|
||||||
|
label={
|
||||||
|
label ? (
|
||||||
|
<FormControl isDisabled={disabled}>
|
||||||
|
<FormLabel>{label}</FormLabel>
|
||||||
|
</FormControl>
|
||||||
|
) : undefined
|
||||||
|
}
|
||||||
|
disabled={disabled}
|
||||||
searchValue={searchValue}
|
searchValue={searchValue}
|
||||||
onSearchChange={setSearchValue}
|
onSearchChange={setSearchValue}
|
||||||
onChange={handleChange}
|
onChange={handleChange}
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
import { Tooltip } from '@chakra-ui/react';
|
import { FormControl, FormLabel, Tooltip } from '@chakra-ui/react';
|
||||||
import { Select, SelectProps } from '@mantine/core';
|
import { Select, SelectProps } from '@mantine/core';
|
||||||
import { useMantineSelectStyles } from 'mantine-theme/hooks/useMantineSelectStyles';
|
import { useMantineSelectStyles } from 'mantine-theme/hooks/useMantineSelectStyles';
|
||||||
import { RefObject, memo } from 'react';
|
import { RefObject, memo } from 'react';
|
||||||
@ -9,19 +9,32 @@ export type IAISelectDataType = {
|
|||||||
tooltip?: string;
|
tooltip?: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
type IAISelectProps = SelectProps & {
|
type IAISelectProps = Omit<SelectProps, 'label'> & {
|
||||||
tooltip?: string;
|
tooltip?: string;
|
||||||
inputRef?: RefObject<HTMLInputElement>;
|
inputRef?: RefObject<HTMLInputElement>;
|
||||||
|
label?: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
const IAIMantineSelect = (props: IAISelectProps) => {
|
const IAIMantineSelect = (props: IAISelectProps) => {
|
||||||
const { tooltip, inputRef, ...rest } = props;
|
const { tooltip, inputRef, label, disabled, ...rest } = props;
|
||||||
|
|
||||||
const styles = useMantineSelectStyles();
|
const styles = useMantineSelectStyles();
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Tooltip label={tooltip} placement="top" hasArrow>
|
<Tooltip label={tooltip} placement="top" hasArrow>
|
||||||
<Select ref={inputRef} styles={styles} {...rest} />
|
<Select
|
||||||
|
label={
|
||||||
|
label ? (
|
||||||
|
<FormControl isDisabled={disabled}>
|
||||||
|
<FormLabel>{label}</FormLabel>
|
||||||
|
</FormControl>
|
||||||
|
) : undefined
|
||||||
|
}
|
||||||
|
disabled={disabled}
|
||||||
|
ref={inputRef}
|
||||||
|
styles={styles}
|
||||||
|
{...rest}
|
||||||
|
/>
|
||||||
</Tooltip>
|
</Tooltip>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -28,7 +28,7 @@ import {
|
|||||||
useState,
|
useState,
|
||||||
} from 'react';
|
} from 'react';
|
||||||
|
|
||||||
const numberStringRegex = /^-?(0\.)?\.?$/;
|
export const numberStringRegex = /^-?(0\.)?\.?$/;
|
||||||
|
|
||||||
interface Props extends Omit<NumberInputProps, 'onChange'> {
|
interface Props extends Omit<NumberInputProps, 'onChange'> {
|
||||||
label?: string;
|
label?: string;
|
||||||
|
@ -43,11 +43,6 @@ import { useTranslation } from 'react-i18next';
|
|||||||
import { BiReset } from 'react-icons/bi';
|
import { BiReset } from 'react-icons/bi';
|
||||||
import IAIIconButton, { IAIIconButtonProps } from './IAIIconButton';
|
import IAIIconButton, { IAIIconButtonProps } from './IAIIconButton';
|
||||||
|
|
||||||
const SLIDER_MARK_STYLES: ChakraProps['sx'] = {
|
|
||||||
mt: 1.5,
|
|
||||||
fontSize: '2xs',
|
|
||||||
};
|
|
||||||
|
|
||||||
export type IAIFullSliderProps = {
|
export type IAIFullSliderProps = {
|
||||||
label?: string;
|
label?: string;
|
||||||
value: number;
|
value: number;
|
||||||
@ -207,7 +202,7 @@ const IAISlider = (props: IAIFullSliderProps) => {
|
|||||||
{...sliderFormControlProps}
|
{...sliderFormControlProps}
|
||||||
>
|
>
|
||||||
{label && (
|
{label && (
|
||||||
<FormLabel {...sliderFormLabelProps} mb={-1}>
|
<FormLabel sx={withInput ? { mb: -1.5 } : {}} {...sliderFormLabelProps}>
|
||||||
{label}
|
{label}
|
||||||
</FormLabel>
|
</FormLabel>
|
||||||
)}
|
)}
|
||||||
@ -233,7 +228,6 @@ const IAISlider = (props: IAIFullSliderProps) => {
|
|||||||
sx={{
|
sx={{
|
||||||
insetInlineStart: '0 !important',
|
insetInlineStart: '0 !important',
|
||||||
insetInlineEnd: 'unset !important',
|
insetInlineEnd: 'unset !important',
|
||||||
...SLIDER_MARK_STYLES,
|
|
||||||
}}
|
}}
|
||||||
{...sliderMarkProps}
|
{...sliderMarkProps}
|
||||||
>
|
>
|
||||||
@ -244,7 +238,6 @@ const IAISlider = (props: IAIFullSliderProps) => {
|
|||||||
sx={{
|
sx={{
|
||||||
insetInlineStart: 'unset !important',
|
insetInlineStart: 'unset !important',
|
||||||
insetInlineEnd: '0 !important',
|
insetInlineEnd: '0 !important',
|
||||||
...SLIDER_MARK_STYLES,
|
|
||||||
}}
|
}}
|
||||||
{...sliderMarkProps}
|
{...sliderMarkProps}
|
||||||
>
|
>
|
||||||
@ -263,7 +256,6 @@ const IAISlider = (props: IAIFullSliderProps) => {
|
|||||||
sx={{
|
sx={{
|
||||||
insetInlineStart: '0 !important',
|
insetInlineStart: '0 !important',
|
||||||
insetInlineEnd: 'unset !important',
|
insetInlineEnd: 'unset !important',
|
||||||
...SLIDER_MARK_STYLES,
|
|
||||||
}}
|
}}
|
||||||
{...sliderMarkProps}
|
{...sliderMarkProps}
|
||||||
>
|
>
|
||||||
@ -278,7 +270,6 @@ const IAISlider = (props: IAIFullSliderProps) => {
|
|||||||
sx={{
|
sx={{
|
||||||
insetInlineStart: 'unset !important',
|
insetInlineStart: 'unset !important',
|
||||||
insetInlineEnd: '0 !important',
|
insetInlineEnd: '0 !important',
|
||||||
...SLIDER_MARK_STYLES,
|
|
||||||
}}
|
}}
|
||||||
{...sliderMarkProps}
|
{...sliderMarkProps}
|
||||||
>
|
>
|
||||||
@ -291,7 +282,7 @@ const IAISlider = (props: IAIFullSliderProps) => {
|
|||||||
key={m}
|
key={m}
|
||||||
value={m}
|
value={m}
|
||||||
sx={{
|
sx={{
|
||||||
...SLIDER_MARK_STYLES,
|
transform: 'translateX(-50%)',
|
||||||
}}
|
}}
|
||||||
{...sliderMarkProps}
|
{...sliderMarkProps}
|
||||||
>
|
>
|
||||||
|
@ -5,6 +5,7 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
|||||||
import { validateSeedWeights } from 'common/util/seedWeightPairs';
|
import { validateSeedWeights } from 'common/util/seedWeightPairs';
|
||||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||||
import { modelsApi } from '../../services/api/endpoints/models';
|
import { modelsApi } from '../../services/api/endpoints/models';
|
||||||
|
import { forEach } from 'lodash-es';
|
||||||
|
|
||||||
const readinessSelector = createSelector(
|
const readinessSelector = createSelector(
|
||||||
[stateSelector, activeTabNameSelector],
|
[stateSelector, activeTabNameSelector],
|
||||||
@ -52,6 +53,13 @@ const readinessSelector = createSelector(
|
|||||||
reasonsWhyNotReady.push('Seed-Weights badly formatted.');
|
reasonsWhyNotReady.push('Seed-Weights badly formatted.');
|
||||||
}
|
}
|
||||||
|
|
||||||
|
forEach(state.controlNet.controlNets, (controlNet, id) => {
|
||||||
|
if (!controlNet.model) {
|
||||||
|
isReady = false;
|
||||||
|
reasonsWhyNotReady.push('ControlNet ${id} has no model selected.');
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
// All good
|
// All good
|
||||||
return { isReady, reasonsWhyNotReady };
|
return { isReady, reasonsWhyNotReady };
|
||||||
},
|
},
|
||||||
|
@ -1,10 +1,9 @@
|
|||||||
import { Box, ChakraProps, Flex, useColorMode } from '@chakra-ui/react';
|
import { Box, Flex } from '@chakra-ui/react';
|
||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { FaCopy, FaTrash } from 'react-icons/fa';
|
import { FaCopy, FaTrash } from 'react-icons/fa';
|
||||||
import {
|
import {
|
||||||
ControlNetConfig,
|
controlNetDuplicated,
|
||||||
controlNetAdded,
|
|
||||||
controlNetRemoved,
|
controlNetRemoved,
|
||||||
controlNetToggled,
|
controlNetToggled,
|
||||||
} from '../store/controlNetSlice';
|
} from '../store/controlNetSlice';
|
||||||
@ -12,6 +11,9 @@ import ParamControlNetModel from './parameters/ParamControlNetModel';
|
|||||||
import ParamControlNetWeight from './parameters/ParamControlNetWeight';
|
import ParamControlNetWeight from './parameters/ParamControlNetWeight';
|
||||||
|
|
||||||
import { ChevronUpIcon } from '@chakra-ui/icons';
|
import { ChevronUpIcon } from '@chakra-ui/icons';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import IAIIconButton from 'common/components/IAIIconButton';
|
import IAIIconButton from 'common/components/IAIIconButton';
|
||||||
import IAISwitch from 'common/components/IAISwitch';
|
import IAISwitch from 'common/components/IAISwitch';
|
||||||
import { useToggle } from 'react-use';
|
import { useToggle } from 'react-use';
|
||||||
@ -22,41 +24,41 @@ import ParamControlNetShouldAutoConfig from './ParamControlNetShouldAutoConfig';
|
|||||||
import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd';
|
import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd';
|
||||||
import ParamControlNetControlMode from './parameters/ParamControlNetControlMode';
|
import ParamControlNetControlMode from './parameters/ParamControlNetControlMode';
|
||||||
import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect';
|
import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect';
|
||||||
import { mode } from 'theme/util/mode';
|
|
||||||
|
|
||||||
const expandedControlImageSx: ChakraProps['sx'] = { maxH: 96 };
|
|
||||||
|
|
||||||
type ControlNetProps = {
|
type ControlNetProps = {
|
||||||
controlNet: ControlNetConfig;
|
controlNetId: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
const ControlNet = (props: ControlNetProps) => {
|
const ControlNet = (props: ControlNetProps) => {
|
||||||
const {
|
const { controlNetId } = props;
|
||||||
controlNetId,
|
|
||||||
isEnabled,
|
|
||||||
model,
|
|
||||||
weight,
|
|
||||||
beginStepPct,
|
|
||||||
endStepPct,
|
|
||||||
controlMode,
|
|
||||||
controlImage,
|
|
||||||
processedControlImage,
|
|
||||||
processorNode,
|
|
||||||
processorType,
|
|
||||||
shouldAutoConfig,
|
|
||||||
} = props.controlNet;
|
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
const selector = createSelector(
|
||||||
|
stateSelector,
|
||||||
|
({ controlNet }) => {
|
||||||
|
const { isEnabled, shouldAutoConfig } =
|
||||||
|
controlNet.controlNets[controlNetId];
|
||||||
|
|
||||||
|
return { isEnabled, shouldAutoConfig };
|
||||||
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
);
|
||||||
|
|
||||||
|
const { isEnabled, shouldAutoConfig } = useAppSelector(selector);
|
||||||
const [isExpanded, toggleIsExpanded] = useToggle(false);
|
const [isExpanded, toggleIsExpanded] = useToggle(false);
|
||||||
const { colorMode } = useColorMode();
|
|
||||||
const handleDelete = useCallback(() => {
|
const handleDelete = useCallback(() => {
|
||||||
dispatch(controlNetRemoved({ controlNetId }));
|
dispatch(controlNetRemoved({ controlNetId }));
|
||||||
}, [controlNetId, dispatch]);
|
}, [controlNetId, dispatch]);
|
||||||
|
|
||||||
const handleDuplicate = useCallback(() => {
|
const handleDuplicate = useCallback(() => {
|
||||||
dispatch(
|
dispatch(
|
||||||
controlNetAdded({ controlNetId: uuidv4(), controlNet: props.controlNet })
|
controlNetDuplicated({
|
||||||
|
sourceControlNetId: controlNetId,
|
||||||
|
newControlNetId: uuidv4(),
|
||||||
|
})
|
||||||
);
|
);
|
||||||
}, [dispatch, props.controlNet]);
|
}, [controlNetId, dispatch]);
|
||||||
|
|
||||||
const handleToggleIsEnabled = useCallback(() => {
|
const handleToggleIsEnabled = useCallback(() => {
|
||||||
dispatch(controlNetToggled({ controlNetId }));
|
dispatch(controlNetToggled({ controlNetId }));
|
||||||
@ -68,15 +70,18 @@ const ControlNet = (props: ControlNetProps) => {
|
|||||||
flexDir: 'column',
|
flexDir: 'column',
|
||||||
gap: 2,
|
gap: 2,
|
||||||
p: 3,
|
p: 3,
|
||||||
bg: mode('base.200', 'base.850')(colorMode),
|
|
||||||
borderRadius: 'base',
|
borderRadius: 'base',
|
||||||
position: 'relative',
|
position: 'relative',
|
||||||
|
bg: 'base.200',
|
||||||
|
_dark: {
|
||||||
|
bg: 'base.850',
|
||||||
|
},
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<Flex sx={{ gap: 2 }}>
|
<Flex sx={{ gap: 2, alignItems: 'center' }}>
|
||||||
<IAISwitch
|
<IAISwitch
|
||||||
tooltip="Toggle"
|
tooltip={'Toggle this ControlNet'}
|
||||||
aria-label="Toggle"
|
aria-label={'Toggle this ControlNet'}
|
||||||
isChecked={isEnabled}
|
isChecked={isEnabled}
|
||||||
onChange={handleToggleIsEnabled}
|
onChange={handleToggleIsEnabled}
|
||||||
/>
|
/>
|
||||||
@ -90,7 +95,7 @@ const ControlNet = (props: ControlNetProps) => {
|
|||||||
transitionDuration: '0.1s',
|
transitionDuration: '0.1s',
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<ParamControlNetModel controlNetId={controlNetId} model={model} />
|
<ParamControlNetModel controlNetId={controlNetId} />
|
||||||
</Box>
|
</Box>
|
||||||
<IAIIconButton
|
<IAIIconButton
|
||||||
size="sm"
|
size="sm"
|
||||||
@ -109,21 +114,26 @@ const ControlNet = (props: ControlNetProps) => {
|
|||||||
/>
|
/>
|
||||||
<IAIIconButton
|
<IAIIconButton
|
||||||
size="sm"
|
size="sm"
|
||||||
aria-label="Show All Options"
|
tooltip={isExpanded ? 'Hide Advanced' : 'Show Advanced'}
|
||||||
|
aria-label={isExpanded ? 'Hide Advanced' : 'Show Advanced'}
|
||||||
onClick={toggleIsExpanded}
|
onClick={toggleIsExpanded}
|
||||||
variant="link"
|
variant="link"
|
||||||
icon={
|
icon={
|
||||||
<ChevronUpIcon
|
<ChevronUpIcon
|
||||||
sx={{
|
sx={{
|
||||||
boxSize: 4,
|
boxSize: 4,
|
||||||
color: mode('base.700', 'base.300')(colorMode),
|
color: 'base.700',
|
||||||
transform: isExpanded ? 'rotate(0deg)' : 'rotate(180deg)',
|
transform: isExpanded ? 'rotate(0deg)' : 'rotate(180deg)',
|
||||||
transitionProperty: 'common',
|
transitionProperty: 'common',
|
||||||
transitionDuration: 'normal',
|
transitionDuration: 'normal',
|
||||||
|
_dark: {
|
||||||
|
color: 'base.300',
|
||||||
|
},
|
||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
}
|
}
|
||||||
/>
|
/>
|
||||||
|
|
||||||
{!shouldAutoConfig && (
|
{!shouldAutoConfig && (
|
||||||
<Box
|
<Box
|
||||||
sx={{
|
sx={{
|
||||||
@ -131,85 +141,59 @@ const ControlNet = (props: ControlNetProps) => {
|
|||||||
w: 1.5,
|
w: 1.5,
|
||||||
h: 1.5,
|
h: 1.5,
|
||||||
borderRadius: 'full',
|
borderRadius: 'full',
|
||||||
bg: mode('error.700', 'error.200')(colorMode),
|
|
||||||
top: 4,
|
top: 4,
|
||||||
insetInlineEnd: 4,
|
insetInlineEnd: 4,
|
||||||
|
bg: 'accent.700',
|
||||||
|
_dark: {
|
||||||
|
bg: 'accent.400',
|
||||||
|
},
|
||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
</Flex>
|
</Flex>
|
||||||
{isEnabled && (
|
<Flex sx={{ w: 'full', flexDirection: 'column' }}>
|
||||||
<>
|
<Flex sx={{ gap: 4, w: 'full', alignItems: 'center' }}>
|
||||||
<Flex sx={{ w: 'full', flexDirection: 'column' }}>
|
<Flex
|
||||||
<Flex sx={{ gap: 4, w: 'full' }}>
|
sx={{
|
||||||
<Flex
|
flexDir: 'column',
|
||||||
sx={{
|
gap: 3,
|
||||||
flexDir: 'column',
|
h: 28,
|
||||||
gap: 3,
|
w: 'full',
|
||||||
w: 'full',
|
paddingInlineStart: 1,
|
||||||
paddingInlineStart: 1,
|
paddingInlineEnd: isExpanded ? 1 : 0,
|
||||||
paddingInlineEnd: isExpanded ? 1 : 0,
|
pb: 2,
|
||||||
pb: 2,
|
justifyContent: 'space-between',
|
||||||
justifyContent: 'space-between',
|
}}
|
||||||
}}
|
>
|
||||||
>
|
<ParamControlNetWeight controlNetId={controlNetId} />
|
||||||
<ParamControlNetWeight
|
<ParamControlNetBeginEnd controlNetId={controlNetId} />
|
||||||
controlNetId={controlNetId}
|
|
||||||
weight={weight}
|
|
||||||
mini={!isExpanded}
|
|
||||||
/>
|
|
||||||
<ParamControlNetBeginEnd
|
|
||||||
controlNetId={controlNetId}
|
|
||||||
beginStepPct={beginStepPct}
|
|
||||||
endStepPct={endStepPct}
|
|
||||||
mini={!isExpanded}
|
|
||||||
/>
|
|
||||||
</Flex>
|
|
||||||
{!isExpanded && (
|
|
||||||
<Flex
|
|
||||||
sx={{
|
|
||||||
alignItems: 'center',
|
|
||||||
justifyContent: 'center',
|
|
||||||
h: 24,
|
|
||||||
w: 24,
|
|
||||||
aspectRatio: '1/1',
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
<ControlNetImagePreview
|
|
||||||
controlNet={props.controlNet}
|
|
||||||
height={24}
|
|
||||||
/>
|
|
||||||
</Flex>
|
|
||||||
)}
|
|
||||||
</Flex>
|
|
||||||
<ParamControlNetControlMode
|
|
||||||
controlNetId={controlNetId}
|
|
||||||
controlMode={controlMode}
|
|
||||||
/>
|
|
||||||
</Flex>
|
</Flex>
|
||||||
|
{!isExpanded && (
|
||||||
{isExpanded && (
|
<Flex
|
||||||
<>
|
sx={{
|
||||||
<Box mt={2}>
|
alignItems: 'center',
|
||||||
<ControlNetImagePreview
|
justifyContent: 'center',
|
||||||
controlNet={props.controlNet}
|
h: 28,
|
||||||
height={96}
|
w: 28,
|
||||||
/>
|
aspectRatio: '1/1',
|
||||||
</Box>
|
mt: 3,
|
||||||
<ParamControlNetProcessorSelect
|
}}
|
||||||
controlNetId={controlNetId}
|
>
|
||||||
processorNode={processorNode}
|
<ControlNetImagePreview controlNetId={controlNetId} height={28} />
|
||||||
/>
|
</Flex>
|
||||||
<ControlNetProcessorComponent
|
|
||||||
controlNetId={controlNetId}
|
|
||||||
processorNode={processorNode}
|
|
||||||
/>
|
|
||||||
<ParamControlNetShouldAutoConfig
|
|
||||||
controlNetId={controlNetId}
|
|
||||||
shouldAutoConfig={shouldAutoConfig}
|
|
||||||
/>
|
|
||||||
</>
|
|
||||||
)}
|
)}
|
||||||
|
</Flex>
|
||||||
|
<Box mt={2}>
|
||||||
|
<ParamControlNetControlMode controlNetId={controlNetId} />
|
||||||
|
</Box>
|
||||||
|
<ParamControlNetProcessorSelect controlNetId={controlNetId} />
|
||||||
|
</Flex>
|
||||||
|
|
||||||
|
{isExpanded && (
|
||||||
|
<>
|
||||||
|
<ControlNetImagePreview controlNetId={controlNetId} height="392px" />
|
||||||
|
<ParamControlNetShouldAutoConfig controlNetId={controlNetId} />
|
||||||
|
<ControlNetProcessorComponent controlNetId={controlNetId} />
|
||||||
</>
|
</>
|
||||||
)}
|
)}
|
||||||
</Flex>
|
</Flex>
|
||||||
|
@ -5,42 +5,57 @@ import {
|
|||||||
TypesafeDraggableData,
|
TypesafeDraggableData,
|
||||||
TypesafeDroppableData,
|
TypesafeDroppableData,
|
||||||
} from 'app/components/ImageDnd/typesafeDnd';
|
} from 'app/components/ImageDnd/typesafeDnd';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import IAIDndImage from 'common/components/IAIDndImage';
|
import IAIDndImage from 'common/components/IAIDndImage';
|
||||||
import { memo, useCallback, useMemo, useState } from 'react';
|
import { memo, useCallback, useMemo, useState } from 'react';
|
||||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||||
import { PostUploadAction } from 'services/api/thunks/image';
|
import { PostUploadAction } from 'services/api/thunks/image';
|
||||||
import {
|
import { controlNetImageChanged } from '../store/controlNetSlice';
|
||||||
ControlNetConfig,
|
|
||||||
controlNetImageChanged,
|
|
||||||
controlNetSelector,
|
|
||||||
} from '../store/controlNetSlice';
|
|
||||||
|
|
||||||
const selector = createSelector(
|
|
||||||
controlNetSelector,
|
|
||||||
(controlNet) => {
|
|
||||||
const { pendingControlImages } = controlNet;
|
|
||||||
return { pendingControlImages };
|
|
||||||
},
|
|
||||||
defaultSelectorOptions
|
|
||||||
);
|
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
controlNet: ControlNetConfig;
|
controlNetId: string;
|
||||||
height: SystemStyleObject['h'];
|
height: SystemStyleObject['h'];
|
||||||
};
|
};
|
||||||
|
|
||||||
const ControlNetImagePreview = (props: Props) => {
|
const ControlNetImagePreview = (props: Props) => {
|
||||||
const { height } = props;
|
const { height, controlNetId } = props;
|
||||||
const {
|
|
||||||
controlNetId,
|
|
||||||
controlImage: controlImageName,
|
|
||||||
processedControlImage: processedControlImageName,
|
|
||||||
processorType,
|
|
||||||
} = props.controlNet;
|
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { pendingControlImages } = useAppSelector(selector);
|
|
||||||
|
const selector = useMemo(
|
||||||
|
() =>
|
||||||
|
createSelector(
|
||||||
|
stateSelector,
|
||||||
|
({ controlNet }) => {
|
||||||
|
const { pendingControlImages } = controlNet;
|
||||||
|
const {
|
||||||
|
controlImage,
|
||||||
|
processedControlImage,
|
||||||
|
processorType,
|
||||||
|
isEnabled,
|
||||||
|
} = controlNet.controlNets[controlNetId];
|
||||||
|
|
||||||
|
return {
|
||||||
|
controlImageName: controlImage,
|
||||||
|
processedControlImageName: processedControlImage,
|
||||||
|
processorType,
|
||||||
|
isEnabled,
|
||||||
|
pendingControlImages,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
),
|
||||||
|
[controlNetId]
|
||||||
|
);
|
||||||
|
|
||||||
|
const {
|
||||||
|
controlImageName,
|
||||||
|
processedControlImageName,
|
||||||
|
processorType,
|
||||||
|
pendingControlImages,
|
||||||
|
isEnabled,
|
||||||
|
} = useAppSelector(selector);
|
||||||
|
|
||||||
const [isMouseOverImage, setIsMouseOverImage] = useState(false);
|
const [isMouseOverImage, setIsMouseOverImage] = useState(false);
|
||||||
|
|
||||||
@ -110,13 +125,15 @@ const ControlNetImagePreview = (props: Props) => {
|
|||||||
h: height,
|
h: height,
|
||||||
alignItems: 'center',
|
alignItems: 'center',
|
||||||
justifyContent: 'center',
|
justifyContent: 'center',
|
||||||
|
pointerEvents: isEnabled ? 'auto' : 'none',
|
||||||
|
opacity: isEnabled ? 1 : 0.5,
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<IAIDndImage
|
<IAIDndImage
|
||||||
draggableData={draggableData}
|
draggableData={draggableData}
|
||||||
droppableData={droppableData}
|
droppableData={droppableData}
|
||||||
imageDTO={controlImage}
|
imageDTO={controlImage}
|
||||||
isDropDisabled={shouldShowProcessedImage}
|
isDropDisabled={shouldShowProcessedImage || !isEnabled}
|
||||||
onClickReset={handleResetControlImage}
|
onClickReset={handleResetControlImage}
|
||||||
postUploadAction={postUploadAction}
|
postUploadAction={postUploadAction}
|
||||||
resetTooltip="Reset Control Image"
|
resetTooltip="Reset Control Image"
|
||||||
@ -140,6 +157,7 @@ const ControlNetImagePreview = (props: Props) => {
|
|||||||
droppableData={droppableData}
|
droppableData={droppableData}
|
||||||
imageDTO={processedControlImage}
|
imageDTO={processedControlImage}
|
||||||
isUploadDisabled={true}
|
isUploadDisabled={true}
|
||||||
|
isDropDisabled={!isEnabled}
|
||||||
onClickReset={handleResetControlImage}
|
onClickReset={handleResetControlImage}
|
||||||
resetTooltip="Reset Control Image"
|
resetTooltip="Reset Control Image"
|
||||||
withResetIcon={Boolean(controlImage)}
|
withResetIcon={Boolean(controlImage)}
|
||||||
|
@ -1,10 +1,13 @@
|
|||||||
import { memo } from 'react';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { RequiredControlNetProcessorNode } from '../store/types';
|
import { stateSelector } from 'app/store/store';
|
||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
|
import { memo, useMemo } from 'react';
|
||||||
import CannyProcessor from './processors/CannyProcessor';
|
import CannyProcessor from './processors/CannyProcessor';
|
||||||
import HedProcessor from './processors/HedProcessor';
|
|
||||||
import LineartProcessor from './processors/LineartProcessor';
|
|
||||||
import LineartAnimeProcessor from './processors/LineartAnimeProcessor';
|
|
||||||
import ContentShuffleProcessor from './processors/ContentShuffleProcessor';
|
import ContentShuffleProcessor from './processors/ContentShuffleProcessor';
|
||||||
|
import HedProcessor from './processors/HedProcessor';
|
||||||
|
import LineartAnimeProcessor from './processors/LineartAnimeProcessor';
|
||||||
|
import LineartProcessor from './processors/LineartProcessor';
|
||||||
import MediapipeFaceProcessor from './processors/MediapipeFaceProcessor';
|
import MediapipeFaceProcessor from './processors/MediapipeFaceProcessor';
|
||||||
import MidasDepthProcessor from './processors/MidasDepthProcessor';
|
import MidasDepthProcessor from './processors/MidasDepthProcessor';
|
||||||
import MlsdImageProcessor from './processors/MlsdImageProcessor';
|
import MlsdImageProcessor from './processors/MlsdImageProcessor';
|
||||||
@ -15,23 +18,45 @@ import ZoeDepthProcessor from './processors/ZoeDepthProcessor';
|
|||||||
|
|
||||||
export type ControlNetProcessorProps = {
|
export type ControlNetProcessorProps = {
|
||||||
controlNetId: string;
|
controlNetId: string;
|
||||||
processorNode: RequiredControlNetProcessorNode;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
|
const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
|
||||||
const { controlNetId, processorNode } = props;
|
const { controlNetId } = props;
|
||||||
|
|
||||||
|
const selector = useMemo(
|
||||||
|
() =>
|
||||||
|
createSelector(
|
||||||
|
stateSelector,
|
||||||
|
({ controlNet }) => {
|
||||||
|
const { isEnabled, processorNode } =
|
||||||
|
controlNet.controlNets[controlNetId];
|
||||||
|
|
||||||
|
return { isEnabled, processorNode };
|
||||||
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
),
|
||||||
|
[controlNetId]
|
||||||
|
);
|
||||||
|
|
||||||
|
const { isEnabled, processorNode } = useAppSelector(selector);
|
||||||
|
|
||||||
if (processorNode.type === 'canny_image_processor') {
|
if (processorNode.type === 'canny_image_processor') {
|
||||||
return (
|
return (
|
||||||
<CannyProcessor
|
<CannyProcessor
|
||||||
controlNetId={controlNetId}
|
controlNetId={controlNetId}
|
||||||
processorNode={processorNode}
|
processorNode={processorNode}
|
||||||
|
isEnabled={isEnabled}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (processorNode.type === 'hed_image_processor') {
|
if (processorNode.type === 'hed_image_processor') {
|
||||||
return (
|
return (
|
||||||
<HedProcessor controlNetId={controlNetId} processorNode={processorNode} />
|
<HedProcessor
|
||||||
|
controlNetId={controlNetId}
|
||||||
|
processorNode={processorNode}
|
||||||
|
isEnabled={isEnabled}
|
||||||
|
/>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -40,6 +65,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
|
|||||||
<LineartProcessor
|
<LineartProcessor
|
||||||
controlNetId={controlNetId}
|
controlNetId={controlNetId}
|
||||||
processorNode={processorNode}
|
processorNode={processorNode}
|
||||||
|
isEnabled={isEnabled}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@ -49,6 +75,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
|
|||||||
<ContentShuffleProcessor
|
<ContentShuffleProcessor
|
||||||
controlNetId={controlNetId}
|
controlNetId={controlNetId}
|
||||||
processorNode={processorNode}
|
processorNode={processorNode}
|
||||||
|
isEnabled={isEnabled}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@ -58,6 +85,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
|
|||||||
<LineartAnimeProcessor
|
<LineartAnimeProcessor
|
||||||
controlNetId={controlNetId}
|
controlNetId={controlNetId}
|
||||||
processorNode={processorNode}
|
processorNode={processorNode}
|
||||||
|
isEnabled={isEnabled}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@ -67,6 +95,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
|
|||||||
<MediapipeFaceProcessor
|
<MediapipeFaceProcessor
|
||||||
controlNetId={controlNetId}
|
controlNetId={controlNetId}
|
||||||
processorNode={processorNode}
|
processorNode={processorNode}
|
||||||
|
isEnabled={isEnabled}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@ -76,6 +105,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
|
|||||||
<MidasDepthProcessor
|
<MidasDepthProcessor
|
||||||
controlNetId={controlNetId}
|
controlNetId={controlNetId}
|
||||||
processorNode={processorNode}
|
processorNode={processorNode}
|
||||||
|
isEnabled={isEnabled}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@ -85,6 +115,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
|
|||||||
<MlsdImageProcessor
|
<MlsdImageProcessor
|
||||||
controlNetId={controlNetId}
|
controlNetId={controlNetId}
|
||||||
processorNode={processorNode}
|
processorNode={processorNode}
|
||||||
|
isEnabled={isEnabled}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@ -94,6 +125,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
|
|||||||
<NormalBaeProcessor
|
<NormalBaeProcessor
|
||||||
controlNetId={controlNetId}
|
controlNetId={controlNetId}
|
||||||
processorNode={processorNode}
|
processorNode={processorNode}
|
||||||
|
isEnabled={isEnabled}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@ -103,6 +135,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
|
|||||||
<OpenposeProcessor
|
<OpenposeProcessor
|
||||||
controlNetId={controlNetId}
|
controlNetId={controlNetId}
|
||||||
processorNode={processorNode}
|
processorNode={processorNode}
|
||||||
|
isEnabled={isEnabled}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@ -112,6 +145,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
|
|||||||
<PidiProcessor
|
<PidiProcessor
|
||||||
controlNetId={controlNetId}
|
controlNetId={controlNetId}
|
||||||
processorNode={processorNode}
|
processorNode={processorNode}
|
||||||
|
isEnabled={isEnabled}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@ -121,6 +155,7 @@ const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
|
|||||||
<ZoeDepthProcessor
|
<ZoeDepthProcessor
|
||||||
controlNetId={controlNetId}
|
controlNetId={controlNetId}
|
||||||
processorNode={processorNode}
|
processorNode={processorNode}
|
||||||
|
isEnabled={isEnabled}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
@ -1,18 +1,36 @@
|
|||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import IAISwitch from 'common/components/IAISwitch';
|
import IAISwitch from 'common/components/IAISwitch';
|
||||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
|
||||||
import { controlNetAutoConfigToggled } from 'features/controlNet/store/controlNetSlice';
|
import { controlNetAutoConfigToggled } from 'features/controlNet/store/controlNetSlice';
|
||||||
import { memo, useCallback } from 'react';
|
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||||
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
controlNetId: string;
|
controlNetId: string;
|
||||||
shouldAutoConfig: boolean;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
const ParamControlNetShouldAutoConfig = (props: Props) => {
|
const ParamControlNetShouldAutoConfig = (props: Props) => {
|
||||||
const { controlNetId, shouldAutoConfig } = props;
|
const { controlNetId } = props;
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const isReady = useIsReadyToInvoke();
|
const selector = useMemo(
|
||||||
|
() =>
|
||||||
|
createSelector(
|
||||||
|
stateSelector,
|
||||||
|
({ controlNet }) => {
|
||||||
|
const { isEnabled, shouldAutoConfig } =
|
||||||
|
controlNet.controlNets[controlNetId];
|
||||||
|
return { isEnabled, shouldAutoConfig };
|
||||||
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
),
|
||||||
|
[controlNetId]
|
||||||
|
);
|
||||||
|
|
||||||
|
const { isEnabled, shouldAutoConfig } = useAppSelector(selector);
|
||||||
|
const isBusy = useAppSelector(selectIsBusy);
|
||||||
|
|
||||||
const handleShouldAutoConfigChanged = useCallback(() => {
|
const handleShouldAutoConfigChanged = useCallback(() => {
|
||||||
dispatch(controlNetAutoConfigToggled({ controlNetId }));
|
dispatch(controlNetAutoConfigToggled({ controlNetId }));
|
||||||
}, [controlNetId, dispatch]);
|
}, [controlNetId, dispatch]);
|
||||||
@ -23,7 +41,7 @@ const ParamControlNetShouldAutoConfig = (props: Props) => {
|
|||||||
aria-label="Auto configure processor"
|
aria-label="Auto configure processor"
|
||||||
isChecked={shouldAutoConfig}
|
isChecked={shouldAutoConfig}
|
||||||
onChange={handleShouldAutoConfigChanged}
|
onChange={handleShouldAutoConfigChanged}
|
||||||
isDisabled={!isReady}
|
isDisabled={isBusy || !isEnabled}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
import {
|
import {
|
||||||
ChakraProps,
|
|
||||||
FormControl,
|
FormControl,
|
||||||
FormLabel,
|
FormLabel,
|
||||||
HStack,
|
HStack,
|
||||||
@ -10,34 +9,41 @@ import {
|
|||||||
RangeSliderTrack,
|
RangeSliderTrack,
|
||||||
Tooltip,
|
Tooltip,
|
||||||
} from '@chakra-ui/react';
|
} from '@chakra-ui/react';
|
||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import {
|
import {
|
||||||
controlNetBeginStepPctChanged,
|
controlNetBeginStepPctChanged,
|
||||||
controlNetEndStepPctChanged,
|
controlNetEndStepPctChanged,
|
||||||
} from 'features/controlNet/store/controlNetSlice';
|
} from 'features/controlNet/store/controlNetSlice';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
|
|
||||||
const SLIDER_MARK_STYLES: ChakraProps['sx'] = {
|
|
||||||
mt: 1.5,
|
|
||||||
fontSize: '2xs',
|
|
||||||
fontWeight: '500',
|
|
||||||
color: 'base.400',
|
|
||||||
};
|
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
controlNetId: string;
|
controlNetId: string;
|
||||||
beginStepPct: number;
|
|
||||||
endStepPct: number;
|
|
||||||
mini?: boolean;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
const formatPct = (v: number) => `${Math.round(v * 100)}%`;
|
const formatPct = (v: number) => `${Math.round(v * 100)}%`;
|
||||||
|
|
||||||
const ParamControlNetBeginEnd = (props: Props) => {
|
const ParamControlNetBeginEnd = (props: Props) => {
|
||||||
const { controlNetId, beginStepPct, mini = false, endStepPct } = props;
|
const { controlNetId } = props;
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { t } = useTranslation();
|
|
||||||
|
const selector = useMemo(
|
||||||
|
() =>
|
||||||
|
createSelector(
|
||||||
|
stateSelector,
|
||||||
|
({ controlNet }) => {
|
||||||
|
const { beginStepPct, endStepPct, isEnabled } =
|
||||||
|
controlNet.controlNets[controlNetId];
|
||||||
|
return { beginStepPct, endStepPct, isEnabled };
|
||||||
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
),
|
||||||
|
[controlNetId]
|
||||||
|
);
|
||||||
|
|
||||||
|
const { beginStepPct, endStepPct, isEnabled } = useAppSelector(selector);
|
||||||
|
|
||||||
const handleStepPctChanged = useCallback(
|
const handleStepPctChanged = useCallback(
|
||||||
(v: number[]) => {
|
(v: number[]) => {
|
||||||
@ -55,7 +61,7 @@ const ParamControlNetBeginEnd = (props: Props) => {
|
|||||||
}, [controlNetId, dispatch]);
|
}, [controlNetId, dispatch]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<FormControl>
|
<FormControl isDisabled={!isEnabled}>
|
||||||
<FormLabel>Begin / End Step Percentage</FormLabel>
|
<FormLabel>Begin / End Step Percentage</FormLabel>
|
||||||
<HStack w="100%" gap={2} alignItems="center">
|
<HStack w="100%" gap={2} alignItems="center">
|
||||||
<RangeSlider
|
<RangeSlider
|
||||||
@ -66,6 +72,7 @@ const ParamControlNetBeginEnd = (props: Props) => {
|
|||||||
max={1}
|
max={1}
|
||||||
step={0.01}
|
step={0.01}
|
||||||
minStepsBetweenThumbs={5}
|
minStepsBetweenThumbs={5}
|
||||||
|
isDisabled={!isEnabled}
|
||||||
>
|
>
|
||||||
<RangeSliderTrack>
|
<RangeSliderTrack>
|
||||||
<RangeSliderFilledTrack />
|
<RangeSliderFilledTrack />
|
||||||
@ -76,38 +83,33 @@ const ParamControlNetBeginEnd = (props: Props) => {
|
|||||||
<Tooltip label={formatPct(endStepPct)} placement="top" hasArrow>
|
<Tooltip label={formatPct(endStepPct)} placement="top" hasArrow>
|
||||||
<RangeSliderThumb index={1} />
|
<RangeSliderThumb index={1} />
|
||||||
</Tooltip>
|
</Tooltip>
|
||||||
{!mini && (
|
<RangeSliderMark
|
||||||
<>
|
value={0}
|
||||||
<RangeSliderMark
|
sx={{
|
||||||
value={0}
|
insetInlineStart: '0 !important',
|
||||||
sx={{
|
insetInlineEnd: 'unset !important',
|
||||||
insetInlineStart: '0 !important',
|
}}
|
||||||
insetInlineEnd: 'unset !important',
|
>
|
||||||
...SLIDER_MARK_STYLES,
|
0%
|
||||||
}}
|
</RangeSliderMark>
|
||||||
>
|
<RangeSliderMark
|
||||||
0%
|
value={0.5}
|
||||||
</RangeSliderMark>
|
sx={{
|
||||||
<RangeSliderMark
|
insetInlineStart: '50% !important',
|
||||||
value={0.5}
|
transform: 'translateX(-50%)',
|
||||||
sx={{
|
}}
|
||||||
...SLIDER_MARK_STYLES,
|
>
|
||||||
}}
|
50%
|
||||||
>
|
</RangeSliderMark>
|
||||||
50%
|
<RangeSliderMark
|
||||||
</RangeSliderMark>
|
value={1}
|
||||||
<RangeSliderMark
|
sx={{
|
||||||
value={1}
|
insetInlineStart: 'unset !important',
|
||||||
sx={{
|
insetInlineEnd: '0 !important',
|
||||||
insetInlineStart: 'unset !important',
|
}}
|
||||||
insetInlineEnd: '0 !important',
|
>
|
||||||
...SLIDER_MARK_STYLES,
|
100%
|
||||||
}}
|
</RangeSliderMark>
|
||||||
>
|
|
||||||
100%
|
|
||||||
</RangeSliderMark>
|
|
||||||
</>
|
|
||||||
)}
|
|
||||||
</RangeSlider>
|
</RangeSlider>
|
||||||
</HStack>
|
</HStack>
|
||||||
</FormControl>
|
</FormControl>
|
||||||
|
@ -1,15 +1,17 @@
|
|||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||||
import {
|
import {
|
||||||
ControlModes,
|
ControlModes,
|
||||||
controlNetControlModeChanged,
|
controlNetControlModeChanged,
|
||||||
} from 'features/controlNet/store/controlNetSlice';
|
} from 'features/controlNet/store/controlNetSlice';
|
||||||
import { useCallback } from 'react';
|
import { useCallback, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
type ParamControlNetControlModeProps = {
|
type ParamControlNetControlModeProps = {
|
||||||
controlNetId: string;
|
controlNetId: string;
|
||||||
controlMode: string;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
const CONTROL_MODE_DATA = [
|
const CONTROL_MODE_DATA = [
|
||||||
@ -22,8 +24,23 @@ const CONTROL_MODE_DATA = [
|
|||||||
export default function ParamControlNetControlMode(
|
export default function ParamControlNetControlMode(
|
||||||
props: ParamControlNetControlModeProps
|
props: ParamControlNetControlModeProps
|
||||||
) {
|
) {
|
||||||
const { controlNetId, controlMode = false } = props;
|
const { controlNetId } = props;
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
const selector = useMemo(
|
||||||
|
() =>
|
||||||
|
createSelector(
|
||||||
|
stateSelector,
|
||||||
|
({ controlNet }) => {
|
||||||
|
const { controlMode, isEnabled } =
|
||||||
|
controlNet.controlNets[controlNetId];
|
||||||
|
return { controlMode, isEnabled };
|
||||||
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
),
|
||||||
|
[controlNetId]
|
||||||
|
);
|
||||||
|
|
||||||
|
const { controlMode, isEnabled } = useAppSelector(selector);
|
||||||
|
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
@ -36,7 +53,8 @@ export default function ParamControlNetControlMode(
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
<IAIMantineSelect
|
<IAIMantineSelect
|
||||||
label={t('parameters.controlNetControlMode')}
|
disabled={!isEnabled}
|
||||||
|
label="Control Mode"
|
||||||
data={CONTROL_MODE_DATA}
|
data={CONTROL_MODE_DATA}
|
||||||
value={String(controlMode)}
|
value={String(controlMode)}
|
||||||
onChange={handleControlModeChange}
|
onChange={handleControlModeChange}
|
||||||
|
@ -29,6 +29,9 @@ const ParamControlNetFeatureToggle = () => {
|
|||||||
label="Enable ControlNet"
|
label="Enable ControlNet"
|
||||||
isChecked={isEnabled}
|
isChecked={isEnabled}
|
||||||
onChange={handleChange}
|
onChange={handleChange}
|
||||||
|
formControlProps={{
|
||||||
|
width: '100%',
|
||||||
|
}}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -1,28 +0,0 @@
|
|||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
|
||||||
import IAISwitch from 'common/components/IAISwitch';
|
|
||||||
import { controlNetToggled } from 'features/controlNet/store/controlNetSlice';
|
|
||||||
import { memo, useCallback } from 'react';
|
|
||||||
|
|
||||||
type ParamControlNetIsEnabledProps = {
|
|
||||||
controlNetId: string;
|
|
||||||
isEnabled: boolean;
|
|
||||||
};
|
|
||||||
|
|
||||||
const ParamControlNetIsEnabled = (props: ParamControlNetIsEnabledProps) => {
|
|
||||||
const { controlNetId, isEnabled } = props;
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
|
|
||||||
const handleIsEnabledChanged = useCallback(() => {
|
|
||||||
dispatch(controlNetToggled({ controlNetId }));
|
|
||||||
}, [dispatch, controlNetId]);
|
|
||||||
|
|
||||||
return (
|
|
||||||
<IAISwitch
|
|
||||||
label="Enabled"
|
|
||||||
isChecked={isEnabled}
|
|
||||||
onChange={handleIsEnabledChanged}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default memo(ParamControlNetIsEnabled);
|
|
@ -1,36 +0,0 @@
|
|||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
|
||||||
import IAIFullCheckbox from 'common/components/IAIFullCheckbox';
|
|
||||||
import IAISwitch from 'common/components/IAISwitch';
|
|
||||||
import {
|
|
||||||
controlNetToggled,
|
|
||||||
isControlNetImagePreprocessedToggled,
|
|
||||||
} from 'features/controlNet/store/controlNetSlice';
|
|
||||||
import { memo, useCallback } from 'react';
|
|
||||||
|
|
||||||
type ParamControlNetIsEnabledProps = {
|
|
||||||
controlNetId: string;
|
|
||||||
isControlImageProcessed: boolean;
|
|
||||||
};
|
|
||||||
|
|
||||||
const ParamControlNetIsEnabled = (props: ParamControlNetIsEnabledProps) => {
|
|
||||||
const { controlNetId, isControlImageProcessed } = props;
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
|
|
||||||
const handleIsControlImageProcessedToggled = useCallback(() => {
|
|
||||||
dispatch(
|
|
||||||
isControlNetImagePreprocessedToggled({
|
|
||||||
controlNetId,
|
|
||||||
})
|
|
||||||
);
|
|
||||||
}, [controlNetId, dispatch]);
|
|
||||||
|
|
||||||
return (
|
|
||||||
<IAISwitch
|
|
||||||
label="Preprocess"
|
|
||||||
isChecked={isControlImageProcessed}
|
|
||||||
onChange={handleIsControlImageProcessedToggled}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default memo(ParamControlNetIsEnabled);
|
|
@ -1,59 +1,118 @@
|
|||||||
|
import { SelectItem } from '@mantine/core';
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAIMantineSearchableSelect, {
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
IAISelectDataType,
|
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
|
||||||
} from 'common/components/IAIMantineSearchableSelect';
|
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
|
||||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
|
||||||
import {
|
|
||||||
CONTROLNET_MODELS,
|
|
||||||
ControlNetModelName,
|
|
||||||
} from 'features/controlNet/store/constants';
|
|
||||||
import { controlNetModelChanged } from 'features/controlNet/store/controlNetSlice';
|
import { controlNetModelChanged } from 'features/controlNet/store/controlNetSlice';
|
||||||
import { configSelector } from 'features/system/store/configSelectors';
|
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||||
import { map } from 'lodash-es';
|
import { modelIdToControlNetModelParam } from 'features/parameters/util/modelIdToControlNetModelParam';
|
||||||
import { memo, useCallback } from 'react';
|
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||||
|
import { forEach } from 'lodash-es';
|
||||||
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
|
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
type ParamControlNetModelProps = {
|
type ParamControlNetModelProps = {
|
||||||
controlNetId: string;
|
controlNetId: string;
|
||||||
model: ControlNetModelName;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
const selector = createSelector(configSelector, (config) => {
|
const ParamControlNetModel = (props: ParamControlNetModelProps) => {
|
||||||
const controlNetModels: IAISelectDataType[] = map(CONTROLNET_MODELS, (m) => ({
|
const { controlNetId } = props;
|
||||||
label: m.label,
|
const dispatch = useAppDispatch();
|
||||||
value: m.type,
|
const isBusy = useAppSelector(selectIsBusy);
|
||||||
})).filter(
|
|
||||||
(d) =>
|
const selector = useMemo(
|
||||||
!config.sd.disabledControlNetModels.includes(
|
() =>
|
||||||
d.value as ControlNetModelName
|
createSelector(
|
||||||
)
|
stateSelector,
|
||||||
|
({ generation, controlNet }) => {
|
||||||
|
const { model } = generation;
|
||||||
|
const controlNetModel = controlNet.controlNets[controlNetId]?.model;
|
||||||
|
const isEnabled = controlNet.controlNets[controlNetId]?.isEnabled;
|
||||||
|
return { mainModel: model, controlNetModel, isEnabled };
|
||||||
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
),
|
||||||
|
[controlNetId]
|
||||||
);
|
);
|
||||||
|
|
||||||
return controlNetModels;
|
const { mainModel, controlNetModel, isEnabled } = useAppSelector(selector);
|
||||||
});
|
|
||||||
|
|
||||||
const ParamControlNetModel = (props: ParamControlNetModelProps) => {
|
const { data: controlNetModels } = useGetControlNetModelsQuery();
|
||||||
const { controlNetId, model } = props;
|
|
||||||
const controlNetModels = useAppSelector(selector);
|
const data = useMemo(() => {
|
||||||
const dispatch = useAppDispatch();
|
if (!controlNetModels) {
|
||||||
const isReady = useIsReadyToInvoke();
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
|
const data: SelectItem[] = [];
|
||||||
|
|
||||||
|
forEach(controlNetModels.entities, (model, id) => {
|
||||||
|
if (!model) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const disabled = model?.base_model !== mainModel?.base_model;
|
||||||
|
|
||||||
|
data.push({
|
||||||
|
value: id,
|
||||||
|
label: model.model_name,
|
||||||
|
group: MODEL_TYPE_MAP[model.base_model],
|
||||||
|
disabled,
|
||||||
|
tooltip: disabled
|
||||||
|
? `Incompatible base model: ${model.base_model}`
|
||||||
|
: undefined,
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
return data;
|
||||||
|
}, [controlNetModels, mainModel?.base_model]);
|
||||||
|
|
||||||
|
// grab the full model entity from the RTK Query cache
|
||||||
|
const selectedModel = useMemo(
|
||||||
|
() =>
|
||||||
|
controlNetModels?.entities[
|
||||||
|
`${controlNetModel?.base_model}/controlnet/${controlNetModel?.model_name}`
|
||||||
|
] ?? null,
|
||||||
|
[
|
||||||
|
controlNetModel?.base_model,
|
||||||
|
controlNetModel?.model_name,
|
||||||
|
controlNetModels?.entities,
|
||||||
|
]
|
||||||
|
);
|
||||||
|
|
||||||
const handleModelChanged = useCallback(
|
const handleModelChanged = useCallback(
|
||||||
(val: string | null) => {
|
(v: string | null) => {
|
||||||
// TODO: do not cast
|
if (!v) {
|
||||||
const model = val as ControlNetModelName;
|
return;
|
||||||
dispatch(controlNetModelChanged({ controlNetId, model }));
|
}
|
||||||
|
|
||||||
|
const newControlNetModel = modelIdToControlNetModelParam(v);
|
||||||
|
|
||||||
|
if (!newControlNetModel) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatch(
|
||||||
|
controlNetModelChanged({ controlNetId, model: newControlNetModel })
|
||||||
|
);
|
||||||
},
|
},
|
||||||
[controlNetId, dispatch]
|
[controlNetId, dispatch]
|
||||||
);
|
);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<IAIMantineSearchableSelect
|
<IAIMantineSearchableSelect
|
||||||
data={controlNetModels}
|
itemComponent={IAIMantineSelectItemWithTooltip}
|
||||||
value={model}
|
data={data}
|
||||||
|
error={
|
||||||
|
!selectedModel || mainModel?.base_model !== selectedModel.base_model
|
||||||
|
}
|
||||||
|
placeholder="Select a model"
|
||||||
|
value={selectedModel?.id ?? null}
|
||||||
onChange={handleModelChanged}
|
onChange={handleModelChanged}
|
||||||
disabled={!isReady}
|
disabled={isBusy || !isEnabled}
|
||||||
tooltip={model}
|
tooltip={selectedModel?.description}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -1,24 +1,22 @@
|
|||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import IAIMantineSearchableSelect, {
|
import IAIMantineSearchableSelect, {
|
||||||
IAISelectDataType,
|
IAISelectDataType,
|
||||||
} from 'common/components/IAIMantineSearchableSelect';
|
} from 'common/components/IAIMantineSearchableSelect';
|
||||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
|
||||||
import { configSelector } from 'features/system/store/configSelectors';
|
import { configSelector } from 'features/system/store/configSelectors';
|
||||||
|
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||||
import { map } from 'lodash-es';
|
import { map } from 'lodash-es';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
import { CONTROLNET_PROCESSORS } from '../../store/constants';
|
import { CONTROLNET_PROCESSORS } from '../../store/constants';
|
||||||
import { controlNetProcessorTypeChanged } from '../../store/controlNetSlice';
|
import { controlNetProcessorTypeChanged } from '../../store/controlNetSlice';
|
||||||
import {
|
import { ControlNetProcessorType } from '../../store/types';
|
||||||
ControlNetProcessorNode,
|
import { FormControl, FormLabel } from '@chakra-ui/react';
|
||||||
ControlNetProcessorType,
|
|
||||||
} from '../../store/types';
|
|
||||||
|
|
||||||
type ParamControlNetProcessorSelectProps = {
|
type ParamControlNetProcessorSelectProps = {
|
||||||
controlNetId: string;
|
controlNetId: string;
|
||||||
processorNode: ControlNetProcessorNode;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
@ -54,10 +52,24 @@ const selector = createSelector(
|
|||||||
const ParamControlNetProcessorSelect = (
|
const ParamControlNetProcessorSelect = (
|
||||||
props: ParamControlNetProcessorSelectProps
|
props: ParamControlNetProcessorSelectProps
|
||||||
) => {
|
) => {
|
||||||
const { controlNetId, processorNode } = props;
|
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const isReady = useIsReadyToInvoke();
|
const { controlNetId } = props;
|
||||||
|
const processorNodeSelector = useMemo(
|
||||||
|
() =>
|
||||||
|
createSelector(
|
||||||
|
stateSelector,
|
||||||
|
({ controlNet }) => {
|
||||||
|
const { isEnabled, processorNode } =
|
||||||
|
controlNet.controlNets[controlNetId];
|
||||||
|
return { isEnabled, processorNode };
|
||||||
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
),
|
||||||
|
[controlNetId]
|
||||||
|
);
|
||||||
|
const isBusy = useAppSelector(selectIsBusy);
|
||||||
const controlNetProcessors = useAppSelector(selector);
|
const controlNetProcessors = useAppSelector(selector);
|
||||||
|
const { isEnabled, processorNode } = useAppSelector(processorNodeSelector);
|
||||||
|
|
||||||
const handleProcessorTypeChanged = useCallback(
|
const handleProcessorTypeChanged = useCallback(
|
||||||
(v: string | null) => {
|
(v: string | null) => {
|
||||||
@ -77,7 +89,7 @@ const ParamControlNetProcessorSelect = (
|
|||||||
value={processorNode.type ?? 'canny_image_processor'}
|
value={processorNode.type ?? 'canny_image_processor'}
|
||||||
data={controlNetProcessors}
|
data={controlNetProcessors}
|
||||||
onChange={handleProcessorTypeChanged}
|
onChange={handleProcessorTypeChanged}
|
||||||
disabled={!isReady}
|
disabled={isBusy || !isEnabled}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -1,18 +1,32 @@
|
|||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import IAISlider from 'common/components/IAISlider';
|
import IAISlider from 'common/components/IAISlider';
|
||||||
import { controlNetWeightChanged } from 'features/controlNet/store/controlNetSlice';
|
import { controlNetWeightChanged } from 'features/controlNet/store/controlNetSlice';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
|
|
||||||
type ParamControlNetWeightProps = {
|
type ParamControlNetWeightProps = {
|
||||||
controlNetId: string;
|
controlNetId: string;
|
||||||
weight: number;
|
|
||||||
mini?: boolean;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
const ParamControlNetWeight = (props: ParamControlNetWeightProps) => {
|
const ParamControlNetWeight = (props: ParamControlNetWeightProps) => {
|
||||||
const { controlNetId, weight, mini = false } = props;
|
const { controlNetId } = props;
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
const selector = useMemo(
|
||||||
|
() =>
|
||||||
|
createSelector(
|
||||||
|
stateSelector,
|
||||||
|
({ controlNet }) => {
|
||||||
|
const { weight, isEnabled } = controlNet.controlNets[controlNetId];
|
||||||
|
return { weight, isEnabled };
|
||||||
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
),
|
||||||
|
[controlNetId]
|
||||||
|
);
|
||||||
|
|
||||||
|
const { weight, isEnabled } = useAppSelector(selector);
|
||||||
const handleWeightChanged = useCallback(
|
const handleWeightChanged = useCallback(
|
||||||
(weight: number) => {
|
(weight: number) => {
|
||||||
dispatch(controlNetWeightChanged({ controlNetId, weight }));
|
dispatch(controlNetWeightChanged({ controlNetId, weight }));
|
||||||
@ -22,15 +36,15 @@ const ParamControlNetWeight = (props: ParamControlNetWeightProps) => {
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
<IAISlider
|
<IAISlider
|
||||||
|
isDisabled={!isEnabled}
|
||||||
label={'Weight'}
|
label={'Weight'}
|
||||||
sliderFormLabelProps={{ pb: 2 }}
|
|
||||||
value={weight}
|
value={weight}
|
||||||
onChange={handleWeightChanged}
|
onChange={handleWeightChanged}
|
||||||
min={-1}
|
min={0}
|
||||||
max={1}
|
max={2}
|
||||||
step={0.01}
|
step={0.01}
|
||||||
withSliderMarks={!mini}
|
withSliderMarks
|
||||||
sliderMarks={[-1, 0, 1]}
|
sliderMarks={[0, 1, 2]}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -1,22 +1,25 @@
|
|||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAISlider from 'common/components/IAISlider';
|
import IAISlider from 'common/components/IAISlider';
|
||||||
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
||||||
import { RequiredCannyImageProcessorInvocation } from 'features/controlNet/store/types';
|
import { RequiredCannyImageProcessorInvocation } from 'features/controlNet/store/types';
|
||||||
|
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||||
import ProcessorWrapper from './common/ProcessorWrapper';
|
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
|
||||||
|
|
||||||
const DEFAULTS = CONTROLNET_PROCESSORS.canny_image_processor.default;
|
const DEFAULTS = CONTROLNET_PROCESSORS.canny_image_processor
|
||||||
|
.default as RequiredCannyImageProcessorInvocation;
|
||||||
|
|
||||||
type CannyProcessorProps = {
|
type CannyProcessorProps = {
|
||||||
controlNetId: string;
|
controlNetId: string;
|
||||||
processorNode: RequiredCannyImageProcessorInvocation;
|
processorNode: RequiredCannyImageProcessorInvocation;
|
||||||
|
isEnabled: boolean;
|
||||||
};
|
};
|
||||||
|
|
||||||
const CannyProcessor = (props: CannyProcessorProps) => {
|
const CannyProcessor = (props: CannyProcessorProps) => {
|
||||||
const { controlNetId, processorNode } = props;
|
const { controlNetId, processorNode, isEnabled } = props;
|
||||||
const { low_threshold, high_threshold } = processorNode;
|
const { low_threshold, high_threshold } = processorNode;
|
||||||
const isReady = useIsReadyToInvoke();
|
const isBusy = useAppSelector(selectIsBusy);
|
||||||
const processorChanged = useProcessorNodeChanged();
|
const processorChanged = useProcessorNodeChanged();
|
||||||
|
|
||||||
const handleLowThresholdChanged = useCallback(
|
const handleLowThresholdChanged = useCallback(
|
||||||
@ -48,7 +51,7 @@ const CannyProcessor = (props: CannyProcessorProps) => {
|
|||||||
return (
|
return (
|
||||||
<ProcessorWrapper>
|
<ProcessorWrapper>
|
||||||
<IAISlider
|
<IAISlider
|
||||||
isDisabled={!isReady}
|
isDisabled={isBusy || !isEnabled}
|
||||||
label="Low Threshold"
|
label="Low Threshold"
|
||||||
value={low_threshold}
|
value={low_threshold}
|
||||||
onChange={handleLowThresholdChanged}
|
onChange={handleLowThresholdChanged}
|
||||||
@ -60,7 +63,7 @@ const CannyProcessor = (props: CannyProcessorProps) => {
|
|||||||
withSliderMarks
|
withSliderMarks
|
||||||
/>
|
/>
|
||||||
<IAISlider
|
<IAISlider
|
||||||
isDisabled={!isReady}
|
isDisabled={isBusy || !isEnabled}
|
||||||
label="High Threshold"
|
label="High Threshold"
|
||||||
value={high_threshold}
|
value={high_threshold}
|
||||||
onChange={handleHighThresholdChanged}
|
onChange={handleHighThresholdChanged}
|
||||||
|
@ -4,20 +4,23 @@ import { RequiredContentShuffleImageProcessorInvocation } from 'features/control
|
|||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||||
import ProcessorWrapper from './common/ProcessorWrapper';
|
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||||
|
|
||||||
const DEFAULTS = CONTROLNET_PROCESSORS.content_shuffle_image_processor.default;
|
const DEFAULTS = CONTROLNET_PROCESSORS.content_shuffle_image_processor
|
||||||
|
.default as RequiredContentShuffleImageProcessorInvocation;
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
controlNetId: string;
|
controlNetId: string;
|
||||||
processorNode: RequiredContentShuffleImageProcessorInvocation;
|
processorNode: RequiredContentShuffleImageProcessorInvocation;
|
||||||
|
isEnabled: boolean;
|
||||||
};
|
};
|
||||||
|
|
||||||
const ContentShuffleProcessor = (props: Props) => {
|
const ContentShuffleProcessor = (props: Props) => {
|
||||||
const { controlNetId, processorNode } = props;
|
const { controlNetId, processorNode, isEnabled } = props;
|
||||||
const { image_resolution, detect_resolution, w, h, f } = processorNode;
|
const { image_resolution, detect_resolution, w, h, f } = processorNode;
|
||||||
const processorChanged = useProcessorNodeChanged();
|
const processorChanged = useProcessorNodeChanged();
|
||||||
const isReady = useIsReadyToInvoke();
|
const isBusy = useAppSelector(selectIsBusy);
|
||||||
|
|
||||||
const handleDetectResolutionChanged = useCallback(
|
const handleDetectResolutionChanged = useCallback(
|
||||||
(v: number) => {
|
(v: number) => {
|
||||||
@ -96,7 +99,7 @@ const ContentShuffleProcessor = (props: Props) => {
|
|||||||
max={4096}
|
max={4096}
|
||||||
withInput
|
withInput
|
||||||
withSliderMarks
|
withSliderMarks
|
||||||
isDisabled={!isReady}
|
isDisabled={isBusy || !isEnabled}
|
||||||
/>
|
/>
|
||||||
<IAISlider
|
<IAISlider
|
||||||
label="Image Resolution"
|
label="Image Resolution"
|
||||||
@ -108,7 +111,7 @@ const ContentShuffleProcessor = (props: Props) => {
|
|||||||
max={4096}
|
max={4096}
|
||||||
withInput
|
withInput
|
||||||
withSliderMarks
|
withSliderMarks
|
||||||
isDisabled={!isReady}
|
isDisabled={isBusy || !isEnabled}
|
||||||
/>
|
/>
|
||||||
<IAISlider
|
<IAISlider
|
||||||
label="W"
|
label="W"
|
||||||
@ -120,7 +123,7 @@ const ContentShuffleProcessor = (props: Props) => {
|
|||||||
max={4096}
|
max={4096}
|
||||||
withInput
|
withInput
|
||||||
withSliderMarks
|
withSliderMarks
|
||||||
isDisabled={!isReady}
|
isDisabled={isBusy || !isEnabled}
|
||||||
/>
|
/>
|
||||||
<IAISlider
|
<IAISlider
|
||||||
label="H"
|
label="H"
|
||||||
@ -132,7 +135,7 @@ const ContentShuffleProcessor = (props: Props) => {
|
|||||||
max={4096}
|
max={4096}
|
||||||
withInput
|
withInput
|
||||||
withSliderMarks
|
withSliderMarks
|
||||||
isDisabled={!isReady}
|
isDisabled={isBusy || !isEnabled}
|
||||||
/>
|
/>
|
||||||
<IAISlider
|
<IAISlider
|
||||||
label="F"
|
label="F"
|
||||||
@ -144,7 +147,7 @@ const ContentShuffleProcessor = (props: Props) => {
|
|||||||
max={4096}
|
max={4096}
|
||||||
withInput
|
withInput
|
||||||
withSliderMarks
|
withSliderMarks
|
||||||
isDisabled={!isReady}
|
isDisabled={isBusy || !isEnabled}
|
||||||
/>
|
/>
|
||||||
</ProcessorWrapper>
|
</ProcessorWrapper>
|
||||||
);
|
);
|
||||||
|
@ -1,25 +1,29 @@
|
|||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAISlider from 'common/components/IAISlider';
|
import IAISlider from 'common/components/IAISlider';
|
||||||
import IAISwitch from 'common/components/IAISwitch';
|
import IAISwitch from 'common/components/IAISwitch';
|
||||||
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
||||||
import { RequiredHedImageProcessorInvocation } from 'features/controlNet/store/types';
|
import { RequiredHedImageProcessorInvocation } from 'features/controlNet/store/types';
|
||||||
|
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||||
import { ChangeEvent, memo, useCallback } from 'react';
|
import { ChangeEvent, memo, useCallback } from 'react';
|
||||||
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||||
import ProcessorWrapper from './common/ProcessorWrapper';
|
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
|
||||||
|
|
||||||
const DEFAULTS = CONTROLNET_PROCESSORS.hed_image_processor.default;
|
const DEFAULTS = CONTROLNET_PROCESSORS.hed_image_processor
|
||||||
|
.default as RequiredHedImageProcessorInvocation;
|
||||||
|
|
||||||
type HedProcessorProps = {
|
type HedProcessorProps = {
|
||||||
controlNetId: string;
|
controlNetId: string;
|
||||||
processorNode: RequiredHedImageProcessorInvocation;
|
processorNode: RequiredHedImageProcessorInvocation;
|
||||||
|
isEnabled: boolean;
|
||||||
};
|
};
|
||||||
|
|
||||||
const HedPreprocessor = (props: HedProcessorProps) => {
|
const HedPreprocessor = (props: HedProcessorProps) => {
|
||||||
const {
|
const {
|
||||||
controlNetId,
|
controlNetId,
|
||||||
processorNode: { detect_resolution, image_resolution, scribble },
|
processorNode: { detect_resolution, image_resolution, scribble },
|
||||||
|
isEnabled,
|
||||||
} = props;
|
} = props;
|
||||||
const isReady = useIsReadyToInvoke();
|
const isBusy = useAppSelector(selectIsBusy);
|
||||||
const processorChanged = useProcessorNodeChanged();
|
const processorChanged = useProcessorNodeChanged();
|
||||||
|
|
||||||
const handleDetectResolutionChanged = useCallback(
|
const handleDetectResolutionChanged = useCallback(
|
||||||
@ -67,7 +71,7 @@ const HedPreprocessor = (props: HedProcessorProps) => {
|
|||||||
max={4096}
|
max={4096}
|
||||||
withInput
|
withInput
|
||||||
withSliderMarks
|
withSliderMarks
|
||||||
isDisabled={!isReady}
|
isDisabled={isBusy || !isEnabled}
|
||||||
/>
|
/>
|
||||||
<IAISlider
|
<IAISlider
|
||||||
label="Image Resolution"
|
label="Image Resolution"
|
||||||
@ -79,13 +83,13 @@ const HedPreprocessor = (props: HedProcessorProps) => {
|
|||||||
max={4096}
|
max={4096}
|
||||||
withInput
|
withInput
|
||||||
withSliderMarks
|
withSliderMarks
|
||||||
isDisabled={!isReady}
|
isDisabled={isBusy || !isEnabled}
|
||||||
/>
|
/>
|
||||||
<IAISwitch
|
<IAISwitch
|
||||||
label="Scribble"
|
label="Scribble"
|
||||||
isChecked={scribble}
|
isChecked={scribble}
|
||||||
onChange={handleScribbleChanged}
|
onChange={handleScribbleChanged}
|
||||||
isDisabled={!isReady}
|
isDisabled={isBusy || !isEnabled}
|
||||||
/>
|
/>
|
||||||
</ProcessorWrapper>
|
</ProcessorWrapper>
|
||||||
);
|
);
|
||||||
|
@ -1,23 +1,26 @@
|
|||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAISlider from 'common/components/IAISlider';
|
import IAISlider from 'common/components/IAISlider';
|
||||||
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
||||||
import { RequiredLineartAnimeImageProcessorInvocation } from 'features/controlNet/store/types';
|
import { RequiredLineartAnimeImageProcessorInvocation } from 'features/controlNet/store/types';
|
||||||
|
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||||
import ProcessorWrapper from './common/ProcessorWrapper';
|
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
|
||||||
|
|
||||||
const DEFAULTS = CONTROLNET_PROCESSORS.lineart_anime_image_processor.default;
|
const DEFAULTS = CONTROLNET_PROCESSORS.lineart_anime_image_processor
|
||||||
|
.default as RequiredLineartAnimeImageProcessorInvocation;
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
controlNetId: string;
|
controlNetId: string;
|
||||||
processorNode: RequiredLineartAnimeImageProcessorInvocation;
|
processorNode: RequiredLineartAnimeImageProcessorInvocation;
|
||||||
|
isEnabled: boolean;
|
||||||
};
|
};
|
||||||
|
|
||||||
const LineartAnimeProcessor = (props: Props) => {
|
const LineartAnimeProcessor = (props: Props) => {
|
||||||
const { controlNetId, processorNode } = props;
|
const { controlNetId, processorNode, isEnabled } = props;
|
||||||
const { image_resolution, detect_resolution } = processorNode;
|
const { image_resolution, detect_resolution } = processorNode;
|
||||||
const processorChanged = useProcessorNodeChanged();
|
const processorChanged = useProcessorNodeChanged();
|
||||||
const isReady = useIsReadyToInvoke();
|
const isBusy = useAppSelector(selectIsBusy);
|
||||||
|
|
||||||
const handleDetectResolutionChanged = useCallback(
|
const handleDetectResolutionChanged = useCallback(
|
||||||
(v: number) => {
|
(v: number) => {
|
||||||
@ -57,7 +60,7 @@ const LineartAnimeProcessor = (props: Props) => {
|
|||||||
max={4096}
|
max={4096}
|
||||||
withInput
|
withInput
|
||||||
withSliderMarks
|
withSliderMarks
|
||||||
isDisabled={!isReady}
|
isDisabled={isBusy || !isEnabled}
|
||||||
/>
|
/>
|
||||||
<IAISlider
|
<IAISlider
|
||||||
label="Image Resolution"
|
label="Image Resolution"
|
||||||
@ -69,7 +72,7 @@ const LineartAnimeProcessor = (props: Props) => {
|
|||||||
max={4096}
|
max={4096}
|
||||||
withInput
|
withInput
|
||||||
withSliderMarks
|
withSliderMarks
|
||||||
isDisabled={!isReady}
|
isDisabled={isBusy || !isEnabled}
|
||||||
/>
|
/>
|
||||||
</ProcessorWrapper>
|
</ProcessorWrapper>
|
||||||
);
|
);
|
||||||
|
@ -1,24 +1,27 @@
|
|||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAISlider from 'common/components/IAISlider';
|
import IAISlider from 'common/components/IAISlider';
|
||||||
import IAISwitch from 'common/components/IAISwitch';
|
import IAISwitch from 'common/components/IAISwitch';
|
||||||
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
||||||
import { RequiredLineartImageProcessorInvocation } from 'features/controlNet/store/types';
|
import { RequiredLineartImageProcessorInvocation } from 'features/controlNet/store/types';
|
||||||
|
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||||
import { ChangeEvent, memo, useCallback } from 'react';
|
import { ChangeEvent, memo, useCallback } from 'react';
|
||||||
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||||
import ProcessorWrapper from './common/ProcessorWrapper';
|
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
|
||||||
|
|
||||||
const DEFAULTS = CONTROLNET_PROCESSORS.lineart_image_processor.default;
|
const DEFAULTS = CONTROLNET_PROCESSORS.lineart_image_processor
|
||||||
|
.default as RequiredLineartImageProcessorInvocation;
|
||||||
|
|
||||||
type LineartProcessorProps = {
|
type LineartProcessorProps = {
|
||||||
controlNetId: string;
|
controlNetId: string;
|
||||||
processorNode: RequiredLineartImageProcessorInvocation;
|
processorNode: RequiredLineartImageProcessorInvocation;
|
||||||
|
isEnabled: boolean;
|
||||||
};
|
};
|
||||||
|
|
||||||
const LineartProcessor = (props: LineartProcessorProps) => {
|
const LineartProcessor = (props: LineartProcessorProps) => {
|
||||||
const { controlNetId, processorNode } = props;
|
const { controlNetId, processorNode, isEnabled } = props;
|
||||||
const { image_resolution, detect_resolution, coarse } = processorNode;
|
const { image_resolution, detect_resolution, coarse } = processorNode;
|
||||||
const processorChanged = useProcessorNodeChanged();
|
const processorChanged = useProcessorNodeChanged();
|
||||||
const isReady = useIsReadyToInvoke();
|
const isBusy = useAppSelector(selectIsBusy);
|
||||||
|
|
||||||
const handleDetectResolutionChanged = useCallback(
|
const handleDetectResolutionChanged = useCallback(
|
||||||
(v: number) => {
|
(v: number) => {
|
||||||
@ -65,7 +68,7 @@ const LineartProcessor = (props: LineartProcessorProps) => {
|
|||||||
max={4096}
|
max={4096}
|
||||||
withInput
|
withInput
|
||||||
withSliderMarks
|
withSliderMarks
|
||||||
isDisabled={!isReady}
|
isDisabled={isBusy || !isEnabled}
|
||||||
/>
|
/>
|
||||||
<IAISlider
|
<IAISlider
|
||||||
label="Image Resolution"
|
label="Image Resolution"
|
||||||
@ -77,13 +80,13 @@ const LineartProcessor = (props: LineartProcessorProps) => {
|
|||||||
max={4096}
|
max={4096}
|
||||||
withInput
|
withInput
|
||||||
withSliderMarks
|
withSliderMarks
|
||||||
isDisabled={!isReady}
|
isDisabled={isBusy || !isEnabled}
|
||||||
/>
|
/>
|
||||||
<IAISwitch
|
<IAISwitch
|
||||||
label="Coarse"
|
label="Coarse"
|
||||||
isChecked={coarse}
|
isChecked={coarse}
|
||||||
onChange={handleCoarseChanged}
|
onChange={handleCoarseChanged}
|
||||||
isDisabled={!isReady}
|
isDisabled={isBusy || !isEnabled}
|
||||||
/>
|
/>
|
||||||
</ProcessorWrapper>
|
</ProcessorWrapper>
|
||||||
);
|
);
|
||||||
|
@ -1,23 +1,26 @@
|
|||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAISlider from 'common/components/IAISlider';
|
import IAISlider from 'common/components/IAISlider';
|
||||||
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
||||||
import { RequiredMediapipeFaceProcessorInvocation } from 'features/controlNet/store/types';
|
import { RequiredMediapipeFaceProcessorInvocation } from 'features/controlNet/store/types';
|
||||||
|
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||||
import ProcessorWrapper from './common/ProcessorWrapper';
|
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
|
||||||
|
|
||||||
const DEFAULTS = CONTROLNET_PROCESSORS.mediapipe_face_processor.default;
|
const DEFAULTS = CONTROLNET_PROCESSORS.mediapipe_face_processor
|
||||||
|
.default as RequiredMediapipeFaceProcessorInvocation;
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
controlNetId: string;
|
controlNetId: string;
|
||||||
processorNode: RequiredMediapipeFaceProcessorInvocation;
|
processorNode: RequiredMediapipeFaceProcessorInvocation;
|
||||||
|
isEnabled: boolean;
|
||||||
};
|
};
|
||||||
|
|
||||||
const MediapipeFaceProcessor = (props: Props) => {
|
const MediapipeFaceProcessor = (props: Props) => {
|
||||||
const { controlNetId, processorNode } = props;
|
const { controlNetId, processorNode, isEnabled } = props;
|
||||||
const { max_faces, min_confidence } = processorNode;
|
const { max_faces, min_confidence } = processorNode;
|
||||||
const processorChanged = useProcessorNodeChanged();
|
const processorChanged = useProcessorNodeChanged();
|
||||||
const isReady = useIsReadyToInvoke();
|
const isBusy = useAppSelector(selectIsBusy);
|
||||||
|
|
||||||
const handleMaxFacesChanged = useCallback(
|
const handleMaxFacesChanged = useCallback(
|
||||||
(v: number) => {
|
(v: number) => {
|
||||||
@ -53,7 +56,7 @@ const MediapipeFaceProcessor = (props: Props) => {
|
|||||||
max={20}
|
max={20}
|
||||||
withInput
|
withInput
|
||||||
withSliderMarks
|
withSliderMarks
|
||||||
isDisabled={!isReady}
|
isDisabled={isBusy || !isEnabled}
|
||||||
/>
|
/>
|
||||||
<IAISlider
|
<IAISlider
|
||||||
label="Min Confidence"
|
label="Min Confidence"
|
||||||
@ -66,7 +69,7 @@ const MediapipeFaceProcessor = (props: Props) => {
|
|||||||
step={0.01}
|
step={0.01}
|
||||||
withInput
|
withInput
|
||||||
withSliderMarks
|
withSliderMarks
|
||||||
isDisabled={!isReady}
|
isDisabled={isBusy || !isEnabled}
|
||||||
/>
|
/>
|
||||||
</ProcessorWrapper>
|
</ProcessorWrapper>
|
||||||
);
|
);
|
||||||
|
@ -1,23 +1,26 @@
|
|||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAISlider from 'common/components/IAISlider';
|
import IAISlider from 'common/components/IAISlider';
|
||||||
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
||||||
import { RequiredMidasDepthImageProcessorInvocation } from 'features/controlNet/store/types';
|
import { RequiredMidasDepthImageProcessorInvocation } from 'features/controlNet/store/types';
|
||||||
|
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||||
import ProcessorWrapper from './common/ProcessorWrapper';
|
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
|
||||||
|
|
||||||
const DEFAULTS = CONTROLNET_PROCESSORS.midas_depth_image_processor.default;
|
const DEFAULTS = CONTROLNET_PROCESSORS.midas_depth_image_processor
|
||||||
|
.default as RequiredMidasDepthImageProcessorInvocation;
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
controlNetId: string;
|
controlNetId: string;
|
||||||
processorNode: RequiredMidasDepthImageProcessorInvocation;
|
processorNode: RequiredMidasDepthImageProcessorInvocation;
|
||||||
|
isEnabled: boolean;
|
||||||
};
|
};
|
||||||
|
|
||||||
const MidasDepthProcessor = (props: Props) => {
|
const MidasDepthProcessor = (props: Props) => {
|
||||||
const { controlNetId, processorNode } = props;
|
const { controlNetId, processorNode, isEnabled } = props;
|
||||||
const { a_mult, bg_th } = processorNode;
|
const { a_mult, bg_th } = processorNode;
|
||||||
const processorChanged = useProcessorNodeChanged();
|
const processorChanged = useProcessorNodeChanged();
|
||||||
const isReady = useIsReadyToInvoke();
|
const isBusy = useAppSelector(selectIsBusy);
|
||||||
|
|
||||||
const handleAMultChanged = useCallback(
|
const handleAMultChanged = useCallback(
|
||||||
(v: number) => {
|
(v: number) => {
|
||||||
@ -54,7 +57,7 @@ const MidasDepthProcessor = (props: Props) => {
|
|||||||
step={0.01}
|
step={0.01}
|
||||||
withInput
|
withInput
|
||||||
withSliderMarks
|
withSliderMarks
|
||||||
isDisabled={!isReady}
|
isDisabled={isBusy || !isEnabled}
|
||||||
/>
|
/>
|
||||||
<IAISlider
|
<IAISlider
|
||||||
label="bg_th"
|
label="bg_th"
|
||||||
@ -67,7 +70,7 @@ const MidasDepthProcessor = (props: Props) => {
|
|||||||
step={0.01}
|
step={0.01}
|
||||||
withInput
|
withInput
|
||||||
withSliderMarks
|
withSliderMarks
|
||||||
isDisabled={!isReady}
|
isDisabled={isBusy || !isEnabled}
|
||||||
/>
|
/>
|
||||||
</ProcessorWrapper>
|
</ProcessorWrapper>
|
||||||
);
|
);
|
||||||
|
@ -1,23 +1,26 @@
|
|||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAISlider from 'common/components/IAISlider';
|
import IAISlider from 'common/components/IAISlider';
|
||||||
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
||||||
import { RequiredMlsdImageProcessorInvocation } from 'features/controlNet/store/types';
|
import { RequiredMlsdImageProcessorInvocation } from 'features/controlNet/store/types';
|
||||||
|
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||||
import ProcessorWrapper from './common/ProcessorWrapper';
|
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
|
||||||
|
|
||||||
const DEFAULTS = CONTROLNET_PROCESSORS.mlsd_image_processor.default;
|
const DEFAULTS = CONTROLNET_PROCESSORS.mlsd_image_processor
|
||||||
|
.default as RequiredMlsdImageProcessorInvocation;
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
controlNetId: string;
|
controlNetId: string;
|
||||||
processorNode: RequiredMlsdImageProcessorInvocation;
|
processorNode: RequiredMlsdImageProcessorInvocation;
|
||||||
|
isEnabled: boolean;
|
||||||
};
|
};
|
||||||
|
|
||||||
const MlsdImageProcessor = (props: Props) => {
|
const MlsdImageProcessor = (props: Props) => {
|
||||||
const { controlNetId, processorNode } = props;
|
const { controlNetId, processorNode, isEnabled } = props;
|
||||||
const { image_resolution, detect_resolution, thr_d, thr_v } = processorNode;
|
const { image_resolution, detect_resolution, thr_d, thr_v } = processorNode;
|
||||||
const processorChanged = useProcessorNodeChanged();
|
const processorChanged = useProcessorNodeChanged();
|
||||||
const isReady = useIsReadyToInvoke();
|
const isBusy = useAppSelector(selectIsBusy);
|
||||||
|
|
||||||
const handleDetectResolutionChanged = useCallback(
|
const handleDetectResolutionChanged = useCallback(
|
||||||
(v: number) => {
|
(v: number) => {
|
||||||
@ -79,7 +82,7 @@ const MlsdImageProcessor = (props: Props) => {
|
|||||||
max={4096}
|
max={4096}
|
||||||
withInput
|
withInput
|
||||||
withSliderMarks
|
withSliderMarks
|
||||||
isDisabled={!isReady}
|
isDisabled={isBusy || !isEnabled}
|
||||||
/>
|
/>
|
||||||
<IAISlider
|
<IAISlider
|
||||||
label="Image Resolution"
|
label="Image Resolution"
|
||||||
@ -91,7 +94,7 @@ const MlsdImageProcessor = (props: Props) => {
|
|||||||
max={4096}
|
max={4096}
|
||||||
withInput
|
withInput
|
||||||
withSliderMarks
|
withSliderMarks
|
||||||
isDisabled={!isReady}
|
isDisabled={isBusy || !isEnabled}
|
||||||
/>
|
/>
|
||||||
<IAISlider
|
<IAISlider
|
||||||
label="W"
|
label="W"
|
||||||
@ -104,7 +107,7 @@ const MlsdImageProcessor = (props: Props) => {
|
|||||||
step={0.01}
|
step={0.01}
|
||||||
withInput
|
withInput
|
||||||
withSliderMarks
|
withSliderMarks
|
||||||
isDisabled={!isReady}
|
isDisabled={isBusy || !isEnabled}
|
||||||
/>
|
/>
|
||||||
<IAISlider
|
<IAISlider
|
||||||
label="H"
|
label="H"
|
||||||
@ -117,7 +120,7 @@ const MlsdImageProcessor = (props: Props) => {
|
|||||||
step={0.01}
|
step={0.01}
|
||||||
withInput
|
withInput
|
||||||
withSliderMarks
|
withSliderMarks
|
||||||
isDisabled={!isReady}
|
isDisabled={isBusy || !isEnabled}
|
||||||
/>
|
/>
|
||||||
</ProcessorWrapper>
|
</ProcessorWrapper>
|
||||||
);
|
);
|
||||||
|
@ -1,23 +1,26 @@
|
|||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAISlider from 'common/components/IAISlider';
|
import IAISlider from 'common/components/IAISlider';
|
||||||
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
||||||
import { RequiredNormalbaeImageProcessorInvocation } from 'features/controlNet/store/types';
|
import { RequiredNormalbaeImageProcessorInvocation } from 'features/controlNet/store/types';
|
||||||
|
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||||
import ProcessorWrapper from './common/ProcessorWrapper';
|
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
|
||||||
|
|
||||||
const DEFAULTS = CONTROLNET_PROCESSORS.normalbae_image_processor.default;
|
const DEFAULTS = CONTROLNET_PROCESSORS.normalbae_image_processor
|
||||||
|
.default as RequiredNormalbaeImageProcessorInvocation;
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
controlNetId: string;
|
controlNetId: string;
|
||||||
processorNode: RequiredNormalbaeImageProcessorInvocation;
|
processorNode: RequiredNormalbaeImageProcessorInvocation;
|
||||||
|
isEnabled: boolean;
|
||||||
};
|
};
|
||||||
|
|
||||||
const NormalBaeProcessor = (props: Props) => {
|
const NormalBaeProcessor = (props: Props) => {
|
||||||
const { controlNetId, processorNode } = props;
|
const { controlNetId, processorNode, isEnabled } = props;
|
||||||
const { image_resolution, detect_resolution } = processorNode;
|
const { image_resolution, detect_resolution } = processorNode;
|
||||||
const processorChanged = useProcessorNodeChanged();
|
const processorChanged = useProcessorNodeChanged();
|
||||||
const isReady = useIsReadyToInvoke();
|
const isBusy = useAppSelector(selectIsBusy);
|
||||||
|
|
||||||
const handleDetectResolutionChanged = useCallback(
|
const handleDetectResolutionChanged = useCallback(
|
||||||
(v: number) => {
|
(v: number) => {
|
||||||
@ -57,7 +60,7 @@ const NormalBaeProcessor = (props: Props) => {
|
|||||||
max={4096}
|
max={4096}
|
||||||
withInput
|
withInput
|
||||||
withSliderMarks
|
withSliderMarks
|
||||||
isDisabled={!isReady}
|
isDisabled={isBusy || !isEnabled}
|
||||||
/>
|
/>
|
||||||
<IAISlider
|
<IAISlider
|
||||||
label="Image Resolution"
|
label="Image Resolution"
|
||||||
@ -69,7 +72,7 @@ const NormalBaeProcessor = (props: Props) => {
|
|||||||
max={4096}
|
max={4096}
|
||||||
withInput
|
withInput
|
||||||
withSliderMarks
|
withSliderMarks
|
||||||
isDisabled={!isReady}
|
isDisabled={isBusy || !isEnabled}
|
||||||
/>
|
/>
|
||||||
</ProcessorWrapper>
|
</ProcessorWrapper>
|
||||||
);
|
);
|
||||||
|
@ -1,24 +1,27 @@
|
|||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAISlider from 'common/components/IAISlider';
|
import IAISlider from 'common/components/IAISlider';
|
||||||
import IAISwitch from 'common/components/IAISwitch';
|
import IAISwitch from 'common/components/IAISwitch';
|
||||||
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
||||||
import { RequiredOpenposeImageProcessorInvocation } from 'features/controlNet/store/types';
|
import { RequiredOpenposeImageProcessorInvocation } from 'features/controlNet/store/types';
|
||||||
|
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||||
import { ChangeEvent, memo, useCallback } from 'react';
|
import { ChangeEvent, memo, useCallback } from 'react';
|
||||||
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||||
import ProcessorWrapper from './common/ProcessorWrapper';
|
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
|
||||||
|
|
||||||
const DEFAULTS = CONTROLNET_PROCESSORS.openpose_image_processor.default;
|
const DEFAULTS = CONTROLNET_PROCESSORS.openpose_image_processor
|
||||||
|
.default as RequiredOpenposeImageProcessorInvocation;
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
controlNetId: string;
|
controlNetId: string;
|
||||||
processorNode: RequiredOpenposeImageProcessorInvocation;
|
processorNode: RequiredOpenposeImageProcessorInvocation;
|
||||||
|
isEnabled: boolean;
|
||||||
};
|
};
|
||||||
|
|
||||||
const OpenposeProcessor = (props: Props) => {
|
const OpenposeProcessor = (props: Props) => {
|
||||||
const { controlNetId, processorNode } = props;
|
const { controlNetId, processorNode, isEnabled } = props;
|
||||||
const { image_resolution, detect_resolution, hand_and_face } = processorNode;
|
const { image_resolution, detect_resolution, hand_and_face } = processorNode;
|
||||||
const processorChanged = useProcessorNodeChanged();
|
const processorChanged = useProcessorNodeChanged();
|
||||||
const isReady = useIsReadyToInvoke();
|
const isBusy = useAppSelector(selectIsBusy);
|
||||||
|
|
||||||
const handleDetectResolutionChanged = useCallback(
|
const handleDetectResolutionChanged = useCallback(
|
||||||
(v: number) => {
|
(v: number) => {
|
||||||
@ -65,7 +68,7 @@ const OpenposeProcessor = (props: Props) => {
|
|||||||
max={4096}
|
max={4096}
|
||||||
withInput
|
withInput
|
||||||
withSliderMarks
|
withSliderMarks
|
||||||
isDisabled={!isReady}
|
isDisabled={isBusy || !isEnabled}
|
||||||
/>
|
/>
|
||||||
<IAISlider
|
<IAISlider
|
||||||
label="Image Resolution"
|
label="Image Resolution"
|
||||||
@ -77,13 +80,13 @@ const OpenposeProcessor = (props: Props) => {
|
|||||||
max={4096}
|
max={4096}
|
||||||
withInput
|
withInput
|
||||||
withSliderMarks
|
withSliderMarks
|
||||||
isDisabled={!isReady}
|
isDisabled={isBusy || !isEnabled}
|
||||||
/>
|
/>
|
||||||
<IAISwitch
|
<IAISwitch
|
||||||
label="Hand and Face"
|
label="Hand and Face"
|
||||||
isChecked={hand_and_face}
|
isChecked={hand_and_face}
|
||||||
onChange={handleHandAndFaceChanged}
|
onChange={handleHandAndFaceChanged}
|
||||||
isDisabled={!isReady}
|
isDisabled={isBusy || !isEnabled}
|
||||||
/>
|
/>
|
||||||
</ProcessorWrapper>
|
</ProcessorWrapper>
|
||||||
);
|
);
|
||||||
|
@ -1,24 +1,27 @@
|
|||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAISlider from 'common/components/IAISlider';
|
import IAISlider from 'common/components/IAISlider';
|
||||||
import IAISwitch from 'common/components/IAISwitch';
|
import IAISwitch from 'common/components/IAISwitch';
|
||||||
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
|
||||||
import { RequiredPidiImageProcessorInvocation } from 'features/controlNet/store/types';
|
import { RequiredPidiImageProcessorInvocation } from 'features/controlNet/store/types';
|
||||||
|
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||||
import { ChangeEvent, memo, useCallback } from 'react';
|
import { ChangeEvent, memo, useCallback } from 'react';
|
||||||
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
import { useProcessorNodeChanged } from '../hooks/useProcessorNodeChanged';
|
||||||
import ProcessorWrapper from './common/ProcessorWrapper';
|
import ProcessorWrapper from './common/ProcessorWrapper';
|
||||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
|
||||||
|
|
||||||
const DEFAULTS = CONTROLNET_PROCESSORS.pidi_image_processor.default;
|
const DEFAULTS = CONTROLNET_PROCESSORS.pidi_image_processor
|
||||||
|
.default as RequiredPidiImageProcessorInvocation;
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
controlNetId: string;
|
controlNetId: string;
|
||||||
processorNode: RequiredPidiImageProcessorInvocation;
|
processorNode: RequiredPidiImageProcessorInvocation;
|
||||||
|
isEnabled: boolean;
|
||||||
};
|
};
|
||||||
|
|
||||||
const PidiProcessor = (props: Props) => {
|
const PidiProcessor = (props: Props) => {
|
||||||
const { controlNetId, processorNode } = props;
|
const { controlNetId, processorNode, isEnabled } = props;
|
||||||
const { image_resolution, detect_resolution, scribble, safe } = processorNode;
|
const { image_resolution, detect_resolution, scribble, safe } = processorNode;
|
||||||
const processorChanged = useProcessorNodeChanged();
|
const processorChanged = useProcessorNodeChanged();
|
||||||
const isReady = useIsReadyToInvoke();
|
const isBusy = useAppSelector(selectIsBusy);
|
||||||
|
|
||||||
const handleDetectResolutionChanged = useCallback(
|
const handleDetectResolutionChanged = useCallback(
|
||||||
(v: number) => {
|
(v: number) => {
|
||||||
@ -72,7 +75,7 @@ const PidiProcessor = (props: Props) => {
|
|||||||
max={4096}
|
max={4096}
|
||||||
withInput
|
withInput
|
||||||
withSliderMarks
|
withSliderMarks
|
||||||
isDisabled={!isReady}
|
isDisabled={isBusy || !isEnabled}
|
||||||
/>
|
/>
|
||||||
<IAISlider
|
<IAISlider
|
||||||
label="Image Resolution"
|
label="Image Resolution"
|
||||||
@ -84,7 +87,7 @@ const PidiProcessor = (props: Props) => {
|
|||||||
max={4096}
|
max={4096}
|
||||||
withInput
|
withInput
|
||||||
withSliderMarks
|
withSliderMarks
|
||||||
isDisabled={!isReady}
|
isDisabled={isBusy || !isEnabled}
|
||||||
/>
|
/>
|
||||||
<IAISwitch
|
<IAISwitch
|
||||||
label="Scribble"
|
label="Scribble"
|
||||||
@ -95,7 +98,7 @@ const PidiProcessor = (props: Props) => {
|
|||||||
label="Safe"
|
label="Safe"
|
||||||
isChecked={safe}
|
isChecked={safe}
|
||||||
onChange={handleSafeChanged}
|
onChange={handleSafeChanged}
|
||||||
isDisabled={!isReady}
|
isDisabled={isBusy || !isEnabled}
|
||||||
/>
|
/>
|
||||||
</ProcessorWrapper>
|
</ProcessorWrapper>
|
||||||
);
|
);
|
||||||
|
@ -4,6 +4,7 @@ import { memo } from 'react';
|
|||||||
type Props = {
|
type Props = {
|
||||||
controlNetId: string;
|
controlNetId: string;
|
||||||
processorNode: RequiredZoeDepthImageProcessorInvocation;
|
processorNode: RequiredZoeDepthImageProcessorInvocation;
|
||||||
|
isEnabled: boolean;
|
||||||
};
|
};
|
||||||
|
|
||||||
const ZoeDepthProcessor = (props: Props) => {
|
const ZoeDepthProcessor = (props: Props) => {
|
||||||
|
@ -173,91 +173,17 @@ export const CONTROLNET_PROCESSORS: ControlNetProcessorsDict = {
|
|||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
type ControlNetModelsDict = Record<string, ControlNetModel>;
|
export const CONTROLNET_MODEL_DEFAULT_PROCESSORS: {
|
||||||
|
[key: string]: ControlNetProcessorType;
|
||||||
type ControlNetModel = {
|
} = {
|
||||||
type: string;
|
canny: 'canny_image_processor',
|
||||||
label: string;
|
mlsd: 'mlsd_image_processor',
|
||||||
description?: string;
|
depth: 'midas_depth_image_processor',
|
||||||
defaultProcessor?: ControlNetProcessorType;
|
bae: 'normalbae_image_processor',
|
||||||
|
lineart: 'lineart_image_processor',
|
||||||
|
lineart_anime: 'lineart_anime_image_processor',
|
||||||
|
softedge: 'hed_image_processor',
|
||||||
|
shuffle: 'content_shuffle_image_processor',
|
||||||
|
openpose: 'openpose_image_processor',
|
||||||
|
mediapipe: 'mediapipe_face_processor',
|
||||||
};
|
};
|
||||||
|
|
||||||
export const CONTROLNET_MODELS: ControlNetModelsDict = {
|
|
||||||
'lllyasviel/control_v11p_sd15_canny': {
|
|
||||||
type: 'lllyasviel/control_v11p_sd15_canny',
|
|
||||||
label: 'Canny',
|
|
||||||
defaultProcessor: 'canny_image_processor',
|
|
||||||
},
|
|
||||||
'lllyasviel/control_v11p_sd15_inpaint': {
|
|
||||||
type: 'lllyasviel/control_v11p_sd15_inpaint',
|
|
||||||
label: 'Inpaint',
|
|
||||||
defaultProcessor: 'none',
|
|
||||||
},
|
|
||||||
'lllyasviel/control_v11p_sd15_mlsd': {
|
|
||||||
type: 'lllyasviel/control_v11p_sd15_mlsd',
|
|
||||||
label: 'M-LSD',
|
|
||||||
defaultProcessor: 'mlsd_image_processor',
|
|
||||||
},
|
|
||||||
'lllyasviel/control_v11f1p_sd15_depth': {
|
|
||||||
type: 'lllyasviel/control_v11f1p_sd15_depth',
|
|
||||||
label: 'Depth',
|
|
||||||
defaultProcessor: 'midas_depth_image_processor',
|
|
||||||
},
|
|
||||||
'lllyasviel/control_v11p_sd15_normalbae': {
|
|
||||||
type: 'lllyasviel/control_v11p_sd15_normalbae',
|
|
||||||
label: 'Normal Map (BAE)',
|
|
||||||
defaultProcessor: 'normalbae_image_processor',
|
|
||||||
},
|
|
||||||
'lllyasviel/control_v11p_sd15_seg': {
|
|
||||||
type: 'lllyasviel/control_v11p_sd15_seg',
|
|
||||||
label: 'Segmentation',
|
|
||||||
defaultProcessor: 'none',
|
|
||||||
},
|
|
||||||
'lllyasviel/control_v11p_sd15_lineart': {
|
|
||||||
type: 'lllyasviel/control_v11p_sd15_lineart',
|
|
||||||
label: 'Lineart',
|
|
||||||
defaultProcessor: 'lineart_image_processor',
|
|
||||||
},
|
|
||||||
'lllyasviel/control_v11p_sd15s2_lineart_anime': {
|
|
||||||
type: 'lllyasviel/control_v11p_sd15s2_lineart_anime',
|
|
||||||
label: 'Lineart Anime',
|
|
||||||
defaultProcessor: 'lineart_anime_image_processor',
|
|
||||||
},
|
|
||||||
'lllyasviel/control_v11p_sd15_scribble': {
|
|
||||||
type: 'lllyasviel/control_v11p_sd15_scribble',
|
|
||||||
label: 'Scribble',
|
|
||||||
defaultProcessor: 'none',
|
|
||||||
},
|
|
||||||
'lllyasviel/control_v11p_sd15_softedge': {
|
|
||||||
type: 'lllyasviel/control_v11p_sd15_softedge',
|
|
||||||
label: 'Soft Edge',
|
|
||||||
defaultProcessor: 'hed_image_processor',
|
|
||||||
},
|
|
||||||
'lllyasviel/control_v11e_sd15_shuffle': {
|
|
||||||
type: 'lllyasviel/control_v11e_sd15_shuffle',
|
|
||||||
label: 'Content Shuffle',
|
|
||||||
defaultProcessor: 'content_shuffle_image_processor',
|
|
||||||
},
|
|
||||||
'lllyasviel/control_v11p_sd15_openpose': {
|
|
||||||
type: 'lllyasviel/control_v11p_sd15_openpose',
|
|
||||||
label: 'Openpose',
|
|
||||||
defaultProcessor: 'openpose_image_processor',
|
|
||||||
},
|
|
||||||
'lllyasviel/control_v11f1e_sd15_tile': {
|
|
||||||
type: 'lllyasviel/control_v11f1e_sd15_tile',
|
|
||||||
label: 'Tile (experimental)',
|
|
||||||
defaultProcessor: 'none',
|
|
||||||
},
|
|
||||||
'lllyasviel/control_v11e_sd15_ip2p': {
|
|
||||||
type: 'lllyasviel/control_v11e_sd15_ip2p',
|
|
||||||
label: 'Pix2Pix (experimental)',
|
|
||||||
defaultProcessor: 'none',
|
|
||||||
},
|
|
||||||
'CrucibleAI/ControlNetMediaPipeFace': {
|
|
||||||
type: 'CrucibleAI/ControlNetMediaPipeFace',
|
|
||||||
label: 'Mediapipe Face',
|
|
||||||
defaultProcessor: 'mediapipe_face_processor',
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
export type ControlNetModelName = keyof typeof CONTROLNET_MODELS;
|
|
||||||
|
@ -1,22 +1,20 @@
|
|||||||
import { PayloadAction } from '@reduxjs/toolkit';
|
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
|
||||||
import { createSlice } from '@reduxjs/toolkit';
|
|
||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { ImageDTO } from 'services/api/types';
|
import { ControlNetModelParam } from 'features/parameters/types/parameterSchemas';
|
||||||
|
import { cloneDeep, forEach } from 'lodash-es';
|
||||||
|
import { imageDeleted } from 'services/api/thunks/image';
|
||||||
|
import { isAnySessionRejected } from 'services/api/thunks/session';
|
||||||
|
import { appSocketInvocationError } from 'services/events/actions';
|
||||||
|
import { controlNetImageProcessed } from './actions';
|
||||||
|
import {
|
||||||
|
CONTROLNET_MODEL_DEFAULT_PROCESSORS,
|
||||||
|
CONTROLNET_PROCESSORS,
|
||||||
|
} from './constants';
|
||||||
import {
|
import {
|
||||||
ControlNetProcessorType,
|
ControlNetProcessorType,
|
||||||
RequiredCannyImageProcessorInvocation,
|
RequiredCannyImageProcessorInvocation,
|
||||||
RequiredControlNetProcessorNode,
|
RequiredControlNetProcessorNode,
|
||||||
} from './types';
|
} from './types';
|
||||||
import {
|
|
||||||
CONTROLNET_MODELS,
|
|
||||||
CONTROLNET_PROCESSORS,
|
|
||||||
ControlNetModelName,
|
|
||||||
} from './constants';
|
|
||||||
import { controlNetImageProcessed } from './actions';
|
|
||||||
import { imageDeleted, imageUrlsReceived } from 'services/api/thunks/image';
|
|
||||||
import { forEach } from 'lodash-es';
|
|
||||||
import { isAnySessionRejected } from 'services/api/thunks/session';
|
|
||||||
import { appSocketInvocationError } from 'services/events/actions';
|
|
||||||
|
|
||||||
export type ControlModes =
|
export type ControlModes =
|
||||||
| 'balanced'
|
| 'balanced'
|
||||||
@ -26,7 +24,7 @@ export type ControlModes =
|
|||||||
|
|
||||||
export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
|
export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
|
||||||
isEnabled: true,
|
isEnabled: true,
|
||||||
model: CONTROLNET_MODELS['lllyasviel/control_v11p_sd15_canny'].type,
|
model: null,
|
||||||
weight: 1,
|
weight: 1,
|
||||||
beginStepPct: 0,
|
beginStepPct: 0,
|
||||||
endStepPct: 1,
|
endStepPct: 1,
|
||||||
@ -42,7 +40,7 @@ export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
|
|||||||
export type ControlNetConfig = {
|
export type ControlNetConfig = {
|
||||||
controlNetId: string;
|
controlNetId: string;
|
||||||
isEnabled: boolean;
|
isEnabled: boolean;
|
||||||
model: ControlNetModelName;
|
model: ControlNetModelParam | null;
|
||||||
weight: number;
|
weight: number;
|
||||||
beginStepPct: number;
|
beginStepPct: number;
|
||||||
endStepPct: number;
|
endStepPct: number;
|
||||||
@ -86,6 +84,19 @@ export const controlNetSlice = createSlice({
|
|||||||
controlNetId,
|
controlNetId,
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
|
controlNetDuplicated: (
|
||||||
|
state,
|
||||||
|
action: PayloadAction<{
|
||||||
|
sourceControlNetId: string;
|
||||||
|
newControlNetId: string;
|
||||||
|
}>
|
||||||
|
) => {
|
||||||
|
const { sourceControlNetId, newControlNetId } = action.payload;
|
||||||
|
|
||||||
|
const newControlnet = cloneDeep(state.controlNets[sourceControlNetId]);
|
||||||
|
newControlnet.controlNetId = newControlNetId;
|
||||||
|
state.controlNets[newControlNetId] = newControlnet;
|
||||||
|
},
|
||||||
controlNetAddedFromImage: (
|
controlNetAddedFromImage: (
|
||||||
state,
|
state,
|
||||||
action: PayloadAction<{ controlNetId: string; controlImage: string }>
|
action: PayloadAction<{ controlNetId: string; controlImage: string }>
|
||||||
@ -147,7 +158,7 @@ export const controlNetSlice = createSlice({
|
|||||||
state,
|
state,
|
||||||
action: PayloadAction<{
|
action: PayloadAction<{
|
||||||
controlNetId: string;
|
controlNetId: string;
|
||||||
model: ControlNetModelName;
|
model: ControlNetModelParam;
|
||||||
}>
|
}>
|
||||||
) => {
|
) => {
|
||||||
const { controlNetId, model } = action.payload;
|
const { controlNetId, model } = action.payload;
|
||||||
@ -155,7 +166,15 @@ export const controlNetSlice = createSlice({
|
|||||||
state.controlNets[controlNetId].processedControlImage = null;
|
state.controlNets[controlNetId].processedControlImage = null;
|
||||||
|
|
||||||
if (state.controlNets[controlNetId].shouldAutoConfig) {
|
if (state.controlNets[controlNetId].shouldAutoConfig) {
|
||||||
const processorType = CONTROLNET_MODELS[model].defaultProcessor;
|
let processorType: ControlNetProcessorType | undefined = undefined;
|
||||||
|
|
||||||
|
for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
|
||||||
|
if (model.model_name.includes(modelSubstring)) {
|
||||||
|
processorType = CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring];
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (processorType) {
|
if (processorType) {
|
||||||
state.controlNets[controlNetId].processorType = processorType;
|
state.controlNets[controlNetId].processorType = processorType;
|
||||||
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[
|
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[
|
||||||
@ -241,9 +260,19 @@ export const controlNetSlice = createSlice({
|
|||||||
|
|
||||||
if (newShouldAutoConfig) {
|
if (newShouldAutoConfig) {
|
||||||
// manage the processor for the user
|
// manage the processor for the user
|
||||||
const processorType =
|
let processorType: ControlNetProcessorType | undefined = undefined;
|
||||||
CONTROLNET_MODELS[state.controlNets[controlNetId].model]
|
|
||||||
.defaultProcessor;
|
for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
|
||||||
|
if (
|
||||||
|
state.controlNets[controlNetId].model?.model_name.includes(
|
||||||
|
modelSubstring
|
||||||
|
)
|
||||||
|
) {
|
||||||
|
processorType = CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring];
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (processorType) {
|
if (processorType) {
|
||||||
state.controlNets[controlNetId].processorType = processorType;
|
state.controlNets[controlNetId].processorType = processorType;
|
||||||
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[
|
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[
|
||||||
@ -272,7 +301,8 @@ export const controlNetSlice = createSlice({
|
|||||||
});
|
});
|
||||||
|
|
||||||
builder.addCase(imageDeleted.pending, (state, action) => {
|
builder.addCase(imageDeleted.pending, (state, action) => {
|
||||||
// Preemptively remove the image from the gallery
|
// Preemptively remove the image from all controlnets
|
||||||
|
// TODO: doesn't the imageusage stuff do this for us?
|
||||||
const { image_name } = action.meta.arg;
|
const { image_name } = action.meta.arg;
|
||||||
forEach(state.controlNets, (c) => {
|
forEach(state.controlNets, (c) => {
|
||||||
if (c.controlImage === image_name) {
|
if (c.controlImage === image_name) {
|
||||||
@ -285,21 +315,6 @@ export const controlNetSlice = createSlice({
|
|||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
// builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
|
|
||||||
// const { image_name, image_url, thumbnail_url } = action.payload;
|
|
||||||
|
|
||||||
// forEach(state.controlNets, (c) => {
|
|
||||||
// if (c.controlImage?.image_name === image_name) {
|
|
||||||
// c.controlImage.image_url = image_url;
|
|
||||||
// c.controlImage.thumbnail_url = thumbnail_url;
|
|
||||||
// }
|
|
||||||
// if (c.processedControlImage?.image_name === image_name) {
|
|
||||||
// c.processedControlImage.image_url = image_url;
|
|
||||||
// c.processedControlImage.thumbnail_url = thumbnail_url;
|
|
||||||
// }
|
|
||||||
// });
|
|
||||||
// });
|
|
||||||
|
|
||||||
builder.addCase(appSocketInvocationError, (state, action) => {
|
builder.addCase(appSocketInvocationError, (state, action) => {
|
||||||
state.pendingControlImages = [];
|
state.pendingControlImages = [];
|
||||||
});
|
});
|
||||||
@ -313,6 +328,7 @@ export const controlNetSlice = createSlice({
|
|||||||
export const {
|
export const {
|
||||||
isControlNetEnabledToggled,
|
isControlNetEnabledToggled,
|
||||||
controlNetAdded,
|
controlNetAdded,
|
||||||
|
controlNetDuplicated,
|
||||||
controlNetAddedFromImage,
|
controlNetAddedFromImage,
|
||||||
controlNetRemoved,
|
controlNetRemoved,
|
||||||
controlNetImageChanged,
|
controlNetImageChanged,
|
||||||
|
@ -118,6 +118,20 @@ const GalleryImageGrid = () => {
|
|||||||
);
|
);
|
||||||
}, [dispatch, imageNames.length, galleryView]);
|
}, [dispatch, imageNames.length, galleryView]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
// Set up gallery scroler
|
||||||
|
const { current: root } = rootRef;
|
||||||
|
if (scroller && root) {
|
||||||
|
initialize({
|
||||||
|
target: root,
|
||||||
|
elements: {
|
||||||
|
viewport: scroller,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
return () => osInstance()?.destroy();
|
||||||
|
}, [scroller, initialize, osInstance]);
|
||||||
|
|
||||||
const handleEndReached = useMemo(() => {
|
const handleEndReached = useMemo(() => {
|
||||||
if (areMoreAvailable) {
|
if (areMoreAvailable) {
|
||||||
return handleLoadMoreImages;
|
return handleLoadMoreImages;
|
||||||
|
@ -3,6 +3,7 @@ import { LoRAModelParam } from 'features/parameters/types/parameterSchemas';
|
|||||||
import { LoRAModelConfigEntity } from 'services/api/endpoints/models';
|
import { LoRAModelConfigEntity } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
export type LoRA = LoRAModelParam & {
|
export type LoRA = LoRAModelParam & {
|
||||||
|
id: string;
|
||||||
weight: number;
|
weight: number;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -24,7 +25,7 @@ export const loraSlice = createSlice({
|
|||||||
reducers: {
|
reducers: {
|
||||||
loraAdded: (state, action: PayloadAction<LoRAModelConfigEntity>) => {
|
loraAdded: (state, action: PayloadAction<LoRAModelConfigEntity>) => {
|
||||||
const { model_name, id, base_model } = action.payload;
|
const { model_name, id, base_model } = action.payload;
|
||||||
state.loras[id] = { model_name, base_model, ...defaultLoRAConfig };
|
state.loras[id] = { id, model_name, base_model, ...defaultLoRAConfig };
|
||||||
},
|
},
|
||||||
loraRemoved: (state, action: PayloadAction<string>) => {
|
loraRemoved: (state, action: PayloadAction<string>) => {
|
||||||
const id = action.payload;
|
const id = action.payload;
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import { Flex, Heading, Icon, Tooltip } from '@chakra-ui/react';
|
import { Flex, Heading, Icon, Tooltip } from '@chakra-ui/react';
|
||||||
|
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/hooks/useBuildInvocation';
|
||||||
import { memo } from 'react';
|
import { memo } from 'react';
|
||||||
import { FaInfoCircle } from 'react-icons/fa';
|
import { FaInfoCircle } from 'react-icons/fa';
|
||||||
|
|
||||||
@ -12,6 +13,7 @@ const IAINodeHeader = (props: IAINodeHeaderProps) => {
|
|||||||
const { nodeId, title, description } = props;
|
const { nodeId, title, description } = props;
|
||||||
return (
|
return (
|
||||||
<Flex
|
<Flex
|
||||||
|
className={DRAG_HANDLE_CLASSNAME}
|
||||||
sx={{
|
sx={{
|
||||||
borderTopRadius: 'md',
|
borderTopRadius: 'md',
|
||||||
alignItems: 'center',
|
alignItems: 'center',
|
||||||
|
@ -1,25 +1,25 @@
|
|||||||
import {
|
|
||||||
InputFieldTemplate,
|
|
||||||
InputFieldValue,
|
|
||||||
InvocationTemplate,
|
|
||||||
} from 'features/nodes/types/types';
|
|
||||||
import { memo, ReactNode, useCallback } from 'react';
|
|
||||||
import { map } from 'lodash-es';
|
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import { RootState } from 'app/store/store';
|
|
||||||
import {
|
import {
|
||||||
Box,
|
Box,
|
||||||
|
Divider,
|
||||||
Flex,
|
Flex,
|
||||||
FormControl,
|
FormControl,
|
||||||
FormLabel,
|
FormLabel,
|
||||||
HStack,
|
HStack,
|
||||||
Tooltip,
|
Tooltip,
|
||||||
Divider,
|
|
||||||
} from '@chakra-ui/react';
|
} from '@chakra-ui/react';
|
||||||
import FieldHandle from '../FieldHandle';
|
import { RootState } from 'app/store/store';
|
||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { useIsValidConnection } from 'features/nodes/hooks/useIsValidConnection';
|
import { useIsValidConnection } from 'features/nodes/hooks/useIsValidConnection';
|
||||||
import InputFieldComponent from '../InputFieldComponent';
|
|
||||||
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
|
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
|
||||||
|
import {
|
||||||
|
InputFieldTemplate,
|
||||||
|
InputFieldValue,
|
||||||
|
InvocationTemplate,
|
||||||
|
} from 'features/nodes/types/types';
|
||||||
|
import { map } from 'lodash-es';
|
||||||
|
import { ReactNode, memo, useCallback } from 'react';
|
||||||
|
import FieldHandle from '../FieldHandle';
|
||||||
|
import InputFieldComponent from '../InputFieldComponent';
|
||||||
|
|
||||||
interface IAINodeInputProps {
|
interface IAINodeInputProps {
|
||||||
nodeId: string;
|
nodeId: string;
|
||||||
@ -35,6 +35,7 @@ function IAINodeInput(props: IAINodeInputProps) {
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
<Box
|
<Box
|
||||||
|
className="nopan"
|
||||||
position="relative"
|
position="relative"
|
||||||
borderColor={
|
borderColor={
|
||||||
!template
|
!template
|
||||||
@ -136,7 +137,7 @@ const IAINodeInputs = (props: IAINodeInputsProps) => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex flexDir="column" gap={2} p={2}>
|
<Flex className="nopan" flexDir="column" gap={2} p={2}>
|
||||||
{IAINodeInputsToRender}
|
{IAINodeInputsToRender}
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
|
@ -7,6 +7,7 @@ import ClipInputFieldComponent from './fields/ClipInputFieldComponent';
|
|||||||
import ColorInputFieldComponent from './fields/ColorInputFieldComponent';
|
import ColorInputFieldComponent from './fields/ColorInputFieldComponent';
|
||||||
import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent';
|
import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent';
|
||||||
import ControlInputFieldComponent from './fields/ControlInputFieldComponent';
|
import ControlInputFieldComponent from './fields/ControlInputFieldComponent';
|
||||||
|
import ControlNetModelInputFieldComponent from './fields/ControlNetModelInputFieldComponent';
|
||||||
import EnumInputFieldComponent from './fields/EnumInputFieldComponent';
|
import EnumInputFieldComponent from './fields/EnumInputFieldComponent';
|
||||||
import ImageCollectionInputFieldComponent from './fields/ImageCollectionInputFieldComponent';
|
import ImageCollectionInputFieldComponent from './fields/ImageCollectionInputFieldComponent';
|
||||||
import ImageInputFieldComponent from './fields/ImageInputFieldComponent';
|
import ImageInputFieldComponent from './fields/ImageInputFieldComponent';
|
||||||
@ -174,6 +175,16 @@ const InputFieldComponent = (props: InputFieldComponentProps) => {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (type === 'controlnet_model' && template.type === 'controlnet_model') {
|
||||||
|
return (
|
||||||
|
<ControlNetModelInputFieldComponent
|
||||||
|
nodeId={nodeId}
|
||||||
|
field={field}
|
||||||
|
template={template}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
if (type === 'array' && template.type === 'array') {
|
if (type === 'array' && template.type === 'array') {
|
||||||
return (
|
return (
|
||||||
<ArrayInputFieldComponent
|
<ArrayInputFieldComponent
|
||||||
|
@ -23,7 +23,14 @@ export const InvocationComponent = memo((props: NodeProps<InvocationValue>) => {
|
|||||||
if (!template) {
|
if (!template) {
|
||||||
return (
|
return (
|
||||||
<NodeWrapper selected={selected}>
|
<NodeWrapper selected={selected}>
|
||||||
<Flex sx={{ alignItems: 'center', justifyContent: 'center' }}>
|
<Flex
|
||||||
|
className="nopan"
|
||||||
|
sx={{
|
||||||
|
alignItems: 'center',
|
||||||
|
justifyContent: 'center',
|
||||||
|
cursor: 'auto',
|
||||||
|
}}
|
||||||
|
>
|
||||||
<Icon
|
<Icon
|
||||||
as={FaExclamationCircle}
|
as={FaExclamationCircle}
|
||||||
sx={{
|
sx={{
|
||||||
@ -46,7 +53,9 @@ export const InvocationComponent = memo((props: NodeProps<InvocationValue>) => {
|
|||||||
description={template.description}
|
description={template.description}
|
||||||
/>
|
/>
|
||||||
<Flex
|
<Flex
|
||||||
|
className={'nopan'}
|
||||||
sx={{
|
sx={{
|
||||||
|
cursor: 'auto',
|
||||||
flexDirection: 'column',
|
flexDirection: 'column',
|
||||||
borderBottomRadius: 'md',
|
borderBottomRadius: 'md',
|
||||||
py: 2,
|
py: 2,
|
||||||
|
@ -2,6 +2,8 @@ import { Box, useToken } from '@chakra-ui/react';
|
|||||||
import { NODE_MIN_WIDTH } from 'app/constants';
|
import { NODE_MIN_WIDTH } from 'app/constants';
|
||||||
|
|
||||||
import { PropsWithChildren } from 'react';
|
import { PropsWithChildren } from 'react';
|
||||||
|
import { DRAG_HANDLE_CLASSNAME } from '../hooks/useBuildInvocation';
|
||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
|
||||||
type NodeWrapperProps = PropsWithChildren & {
|
type NodeWrapperProps = PropsWithChildren & {
|
||||||
selected: boolean;
|
selected: boolean;
|
||||||
@ -13,8 +15,11 @@ const NodeWrapper = (props: NodeWrapperProps) => {
|
|||||||
'dark-lg',
|
'dark-lg',
|
||||||
]);
|
]);
|
||||||
|
|
||||||
|
const shift = useAppSelector((state) => state.hotkeys.shift);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Box
|
<Box
|
||||||
|
className={shift ? DRAG_HANDLE_CLASSNAME : 'nopan'}
|
||||||
sx={{
|
sx={{
|
||||||
position: 'relative',
|
position: 'relative',
|
||||||
borderRadius: 'md',
|
borderRadius: 'md',
|
||||||
|
@ -21,6 +21,7 @@ const ProgressImageNode = (props: NodeProps<InvocationValue>) => {
|
|||||||
/>
|
/>
|
||||||
|
|
||||||
<Flex
|
<Flex
|
||||||
|
className="nopan"
|
||||||
sx={{
|
sx={{
|
||||||
flexDirection: 'column',
|
flexDirection: 'column',
|
||||||
borderBottomRadius: 'md',
|
borderBottomRadius: 'md',
|
||||||
|
@ -0,0 +1,103 @@
|
|||||||
|
import { SelectItem } from '@mantine/core';
|
||||||
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
|
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||||
|
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
|
import {
|
||||||
|
ControlNetModelInputFieldTemplate,
|
||||||
|
ControlNetModelInputFieldValue,
|
||||||
|
} from 'features/nodes/types/types';
|
||||||
|
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||||
|
import { modelIdToControlNetModelParam } from 'features/parameters/util/modelIdToControlNetModelParam';
|
||||||
|
import { forEach } from 'lodash-es';
|
||||||
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
|
||||||
|
import { FieldComponentProps } from './types';
|
||||||
|
|
||||||
|
const ControlNetModelInputFieldComponent = (
|
||||||
|
props: FieldComponentProps<
|
||||||
|
ControlNetModelInputFieldValue,
|
||||||
|
ControlNetModelInputFieldTemplate
|
||||||
|
>
|
||||||
|
) => {
|
||||||
|
const { nodeId, field } = props;
|
||||||
|
const controlNetModel = field.value;
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const { t } = useTranslation();
|
||||||
|
|
||||||
|
const { data: controlNetModels } = useGetControlNetModelsQuery();
|
||||||
|
|
||||||
|
// grab the full model entity from the RTK Query cache
|
||||||
|
const selectedModel = useMemo(
|
||||||
|
() =>
|
||||||
|
controlNetModels?.entities[
|
||||||
|
`${controlNetModel?.base_model}/controlnet/${controlNetModel?.model_name}`
|
||||||
|
] ?? null,
|
||||||
|
[
|
||||||
|
controlNetModel?.base_model,
|
||||||
|
controlNetModel?.model_name,
|
||||||
|
controlNetModels?.entities,
|
||||||
|
]
|
||||||
|
);
|
||||||
|
|
||||||
|
const data = useMemo(() => {
|
||||||
|
if (!controlNetModels) {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
|
const data: SelectItem[] = [];
|
||||||
|
|
||||||
|
forEach(controlNetModels.entities, (model, id) => {
|
||||||
|
if (!model) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
data.push({
|
||||||
|
value: id,
|
||||||
|
label: model.model_name,
|
||||||
|
group: MODEL_TYPE_MAP[model.base_model],
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
return data;
|
||||||
|
}, [controlNetModels]);
|
||||||
|
|
||||||
|
const handleValueChanged = useCallback(
|
||||||
|
(v: string | null) => {
|
||||||
|
if (!v) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const newControlNetModel = modelIdToControlNetModelParam(v);
|
||||||
|
|
||||||
|
if (!newControlNetModel) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatch(
|
||||||
|
fieldValueChanged({
|
||||||
|
nodeId,
|
||||||
|
fieldName: field.name,
|
||||||
|
value: newControlNetModel,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
},
|
||||||
|
[dispatch, field.name, nodeId]
|
||||||
|
);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<IAIMantineSelect
|
||||||
|
tooltip={selectedModel?.description}
|
||||||
|
label={
|
||||||
|
selectedModel?.base_model && MODEL_TYPE_MAP[selectedModel?.base_model]
|
||||||
|
}
|
||||||
|
value={selectedModel?.id ?? null}
|
||||||
|
placeholder="Pick one"
|
||||||
|
error={!selectedModel}
|
||||||
|
data={data}
|
||||||
|
onChange={handleValueChanged}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(ControlNetModelInputFieldComponent);
|
@ -6,6 +6,7 @@ import {
|
|||||||
NumberInputStepper,
|
NumberInputStepper,
|
||||||
} from '@chakra-ui/react';
|
} from '@chakra-ui/react';
|
||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
|
import { numberStringRegex } from 'common/components/IAINumberInput';
|
||||||
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
import {
|
import {
|
||||||
FloatInputFieldTemplate,
|
FloatInputFieldTemplate,
|
||||||
@ -13,7 +14,7 @@ import {
|
|||||||
IntegerInputFieldTemplate,
|
IntegerInputFieldTemplate,
|
||||||
IntegerInputFieldValue,
|
IntegerInputFieldValue,
|
||||||
} from 'features/nodes/types/types';
|
} from 'features/nodes/types/types';
|
||||||
import { memo } from 'react';
|
import { memo, useEffect, useState } from 'react';
|
||||||
import { FieldComponentProps } from './types';
|
import { FieldComponentProps } from './types';
|
||||||
|
|
||||||
const NumberInputFieldComponent = (
|
const NumberInputFieldComponent = (
|
||||||
@ -23,17 +24,42 @@ const NumberInputFieldComponent = (
|
|||||||
>
|
>
|
||||||
) => {
|
) => {
|
||||||
const { nodeId, field } = props;
|
const { nodeId, field } = props;
|
||||||
|
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
const [valueAsString, setValueAsString] = useState<string>(
|
||||||
|
String(field.value)
|
||||||
|
);
|
||||||
|
|
||||||
const handleValueChanged = (_: string, value: number) => {
|
const handleValueChanged = (v: string) => {
|
||||||
dispatch(fieldValueChanged({ nodeId, fieldName: field.name, value }));
|
setValueAsString(v);
|
||||||
|
// This allows negatives and decimals e.g. '-123', `.5`, `-0.2`, etc.
|
||||||
|
if (!v.match(numberStringRegex)) {
|
||||||
|
// Cast the value to number. Floor it if it should be an integer.
|
||||||
|
dispatch(
|
||||||
|
fieldValueChanged({
|
||||||
|
nodeId,
|
||||||
|
fieldName: field.name,
|
||||||
|
value:
|
||||||
|
props.template.type === 'integer'
|
||||||
|
? Math.floor(Number(v))
|
||||||
|
: Number(v),
|
||||||
|
})
|
||||||
|
);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (
|
||||||
|
!valueAsString.match(numberStringRegex) &&
|
||||||
|
field.value !== Number(valueAsString)
|
||||||
|
) {
|
||||||
|
setValueAsString(String(field.value));
|
||||||
|
}
|
||||||
|
}, [field.value, valueAsString]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<NumberInput
|
<NumberInput
|
||||||
onChange={handleValueChanged}
|
onChange={handleValueChanged}
|
||||||
value={field.value}
|
value={valueAsString}
|
||||||
step={props.template.type === 'integer' ? 1 : 0.1}
|
step={props.template.type === 'integer' ? 1 : 0.1}
|
||||||
precision={props.template.type === 'integer' ? 0 : 3}
|
precision={props.template.type === 'integer' ? 0 : 3}
|
||||||
>
|
>
|
||||||
|
@ -18,6 +18,12 @@ const templatesSelector = createSelector(
|
|||||||
(nodes) => nodes.invocationTemplates
|
(nodes) => nodes.invocationTemplates
|
||||||
);
|
);
|
||||||
|
|
||||||
|
export const DRAG_HANDLE_CLASSNAME = 'node-drag-handle';
|
||||||
|
|
||||||
|
export const SHARED_NODE_PROPERTIES: Partial<Node> = {
|
||||||
|
dragHandle: `.${DRAG_HANDLE_CLASSNAME}`,
|
||||||
|
};
|
||||||
|
|
||||||
export const useBuildInvocation = () => {
|
export const useBuildInvocation = () => {
|
||||||
const invocationTemplates = useAppSelector(templatesSelector);
|
const invocationTemplates = useAppSelector(templatesSelector);
|
||||||
|
|
||||||
@ -32,6 +38,7 @@ export const useBuildInvocation = () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
const node: Node = {
|
const node: Node = {
|
||||||
|
...SHARED_NODE_PROPERTIES,
|
||||||
id: 'progress_image',
|
id: 'progress_image',
|
||||||
type: 'progress_image',
|
type: 'progress_image',
|
||||||
position: { x: x, y: y },
|
position: { x: x, y: y },
|
||||||
@ -91,6 +98,7 @@ export const useBuildInvocation = () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
const invocation: Node<InvocationValue> = {
|
const invocation: Node<InvocationValue> = {
|
||||||
|
...SHARED_NODE_PROPERTIES,
|
||||||
id: nodeId,
|
id: nodeId,
|
||||||
type: 'invocation',
|
type: 'invocation',
|
||||||
position: { x: x, y: y },
|
position: { x: x, y: y },
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import { createSlice, PayloadAction } from '@reduxjs/toolkit';
|
import { createSlice, PayloadAction } from '@reduxjs/toolkit';
|
||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import {
|
import {
|
||||||
|
ControlNetModelParam,
|
||||||
LoRAModelParam,
|
LoRAModelParam,
|
||||||
MainModelParam,
|
MainModelParam,
|
||||||
VaeModelParam,
|
VaeModelParam,
|
||||||
@ -81,7 +82,8 @@ const nodesSlice = createSlice({
|
|||||||
| ImageField[]
|
| ImageField[]
|
||||||
| MainModelParam
|
| MainModelParam
|
||||||
| VaeModelParam
|
| VaeModelParam
|
||||||
| LoRAModelParam;
|
| LoRAModelParam
|
||||||
|
| ControlNetModelParam;
|
||||||
}>
|
}>
|
||||||
) => {
|
) => {
|
||||||
const { nodeId, fieldName, value } = action.payload;
|
const { nodeId, fieldName, value } = action.payload;
|
||||||
|
@ -19,6 +19,8 @@ export const FIELD_TYPE_MAP: Record<string, FieldType> = {
|
|||||||
model: 'model',
|
model: 'model',
|
||||||
vae_model: 'vae_model',
|
vae_model: 'vae_model',
|
||||||
lora_model: 'lora_model',
|
lora_model: 'lora_model',
|
||||||
|
controlnet_model: 'controlnet_model',
|
||||||
|
ControlNetModelField: 'controlnet_model',
|
||||||
array: 'array',
|
array: 'array',
|
||||||
item: 'item',
|
item: 'item',
|
||||||
ColorField: 'color',
|
ColorField: 'color',
|
||||||
@ -130,6 +132,12 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
|
|||||||
title: 'LoRA',
|
title: 'LoRA',
|
||||||
description: 'Models are models.',
|
description: 'Models are models.',
|
||||||
},
|
},
|
||||||
|
controlnet_model: {
|
||||||
|
color: 'teal',
|
||||||
|
colorCssVar: getColorTokenCssVariable('teal'),
|
||||||
|
title: 'ControlNet',
|
||||||
|
description: 'Models are models.',
|
||||||
|
},
|
||||||
array: {
|
array: {
|
||||||
color: 'gray',
|
color: 'gray',
|
||||||
colorCssVar: getColorTokenCssVariable('gray'),
|
colorCssVar: getColorTokenCssVariable('gray'),
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import {
|
import {
|
||||||
|
ControlNetModelParam,
|
||||||
LoRAModelParam,
|
LoRAModelParam,
|
||||||
MainModelParam,
|
MainModelParam,
|
||||||
VaeModelParam,
|
VaeModelParam,
|
||||||
@ -71,6 +72,7 @@ export type FieldType =
|
|||||||
| 'model'
|
| 'model'
|
||||||
| 'vae_model'
|
| 'vae_model'
|
||||||
| 'lora_model'
|
| 'lora_model'
|
||||||
|
| 'controlnet_model'
|
||||||
| 'array'
|
| 'array'
|
||||||
| 'item'
|
| 'item'
|
||||||
| 'color'
|
| 'color'
|
||||||
@ -100,6 +102,7 @@ export type InputFieldValue =
|
|||||||
| MainModelInputFieldValue
|
| MainModelInputFieldValue
|
||||||
| VaeModelInputFieldValue
|
| VaeModelInputFieldValue
|
||||||
| LoRAModelInputFieldValue
|
| LoRAModelInputFieldValue
|
||||||
|
| ControlNetModelInputFieldValue
|
||||||
| ArrayInputFieldValue
|
| ArrayInputFieldValue
|
||||||
| ItemInputFieldValue
|
| ItemInputFieldValue
|
||||||
| ColorInputFieldValue
|
| ColorInputFieldValue
|
||||||
@ -127,6 +130,7 @@ export type InputFieldTemplate =
|
|||||||
| ModelInputFieldTemplate
|
| ModelInputFieldTemplate
|
||||||
| VaeModelInputFieldTemplate
|
| VaeModelInputFieldTemplate
|
||||||
| LoRAModelInputFieldTemplate
|
| LoRAModelInputFieldTemplate
|
||||||
|
| ControlNetModelInputFieldTemplate
|
||||||
| ArrayInputFieldTemplate
|
| ArrayInputFieldTemplate
|
||||||
| ItemInputFieldTemplate
|
| ItemInputFieldTemplate
|
||||||
| ColorInputFieldTemplate
|
| ColorInputFieldTemplate
|
||||||
@ -249,6 +253,11 @@ export type LoRAModelInputFieldValue = FieldValueBase & {
|
|||||||
value?: LoRAModelParam;
|
value?: LoRAModelParam;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type ControlNetModelInputFieldValue = FieldValueBase & {
|
||||||
|
type: 'controlnet_model';
|
||||||
|
value?: ControlNetModelParam;
|
||||||
|
};
|
||||||
|
|
||||||
export type ArrayInputFieldValue = FieldValueBase & {
|
export type ArrayInputFieldValue = FieldValueBase & {
|
||||||
type: 'array';
|
type: 'array';
|
||||||
value?: (string | number)[];
|
value?: (string | number)[];
|
||||||
@ -368,6 +377,11 @@ export type LoRAModelInputFieldTemplate = InputFieldTemplateBase & {
|
|||||||
type: 'lora_model';
|
type: 'lora_model';
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type ControlNetModelInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
|
default: string;
|
||||||
|
type: 'controlnet_model';
|
||||||
|
};
|
||||||
|
|
||||||
export type ArrayInputFieldTemplate = InputFieldTemplateBase & {
|
export type ArrayInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
default: [];
|
default: [];
|
||||||
type: 'array';
|
type: 'array';
|
||||||
|
@ -9,6 +9,7 @@ import {
|
|||||||
ColorInputFieldTemplate,
|
ColorInputFieldTemplate,
|
||||||
ConditioningInputFieldTemplate,
|
ConditioningInputFieldTemplate,
|
||||||
ControlInputFieldTemplate,
|
ControlInputFieldTemplate,
|
||||||
|
ControlNetModelInputFieldTemplate,
|
||||||
EnumInputFieldTemplate,
|
EnumInputFieldTemplate,
|
||||||
FieldType,
|
FieldType,
|
||||||
FloatInputFieldTemplate,
|
FloatInputFieldTemplate,
|
||||||
@ -207,6 +208,21 @@ const buildLoRAModelInputFieldTemplate = ({
|
|||||||
return template;
|
return template;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const buildControlNetModelInputFieldTemplate = ({
|
||||||
|
schemaObject,
|
||||||
|
baseField,
|
||||||
|
}: BuildInputFieldArg): ControlNetModelInputFieldTemplate => {
|
||||||
|
const template: ControlNetModelInputFieldTemplate = {
|
||||||
|
...baseField,
|
||||||
|
type: 'controlnet_model',
|
||||||
|
inputRequirement: 'always',
|
||||||
|
inputKind: 'direct',
|
||||||
|
default: schemaObject.default ?? undefined,
|
||||||
|
};
|
||||||
|
|
||||||
|
return template;
|
||||||
|
};
|
||||||
|
|
||||||
const buildImageInputFieldTemplate = ({
|
const buildImageInputFieldTemplate = ({
|
||||||
schemaObject,
|
schemaObject,
|
||||||
baseField,
|
baseField,
|
||||||
@ -479,6 +495,9 @@ export const buildInputFieldTemplate = (
|
|||||||
if (['lora_model'].includes(fieldType)) {
|
if (['lora_model'].includes(fieldType)) {
|
||||||
return buildLoRAModelInputFieldTemplate({ schemaObject, baseField });
|
return buildLoRAModelInputFieldTemplate({ schemaObject, baseField });
|
||||||
}
|
}
|
||||||
|
if (['controlnet_model'].includes(fieldType)) {
|
||||||
|
return buildControlNetModelInputFieldTemplate({ schemaObject, baseField });
|
||||||
|
}
|
||||||
if (['enum'].includes(fieldType)) {
|
if (['enum'].includes(fieldType)) {
|
||||||
return buildEnumInputFieldTemplate({ schemaObject, baseField });
|
return buildEnumInputFieldTemplate({ schemaObject, baseField });
|
||||||
}
|
}
|
||||||
|
@ -83,6 +83,10 @@ export const buildInputFieldValue = (
|
|||||||
if (template.type === 'lora_model') {
|
if (template.type === 'lora_model') {
|
||||||
fieldValue.value = undefined;
|
fieldValue.value = undefined;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (template.type === 'controlnet_model') {
|
||||||
|
fieldValue.value = undefined;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return fieldValue;
|
return fieldValue;
|
||||||
|
@ -60,7 +60,7 @@ export const addLoRAsToGraph = (
|
|||||||
const loraLoaderNode: LoraLoaderInvocation = {
|
const loraLoaderNode: LoraLoaderInvocation = {
|
||||||
type: 'lora_loader',
|
type: 'lora_loader',
|
||||||
id: currentLoraNodeId,
|
id: currentLoraNodeId,
|
||||||
lora,
|
lora: { model_name, base_model },
|
||||||
weight,
|
weight,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -2,12 +2,13 @@ import { Divider, Flex } from '@chakra-ui/react';
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import IAIButton from 'common/components/IAIButton';
|
|
||||||
import IAICollapse from 'common/components/IAICollapse';
|
import IAICollapse from 'common/components/IAICollapse';
|
||||||
|
import IAIIconButton from 'common/components/IAIIconButton';
|
||||||
import ControlNet from 'features/controlNet/components/ControlNet';
|
import ControlNet from 'features/controlNet/components/ControlNet';
|
||||||
import ParamControlNetFeatureToggle from 'features/controlNet/components/parameters/ParamControlNetFeatureToggle';
|
import ParamControlNetFeatureToggle from 'features/controlNet/components/parameters/ParamControlNetFeatureToggle';
|
||||||
import {
|
import {
|
||||||
controlNetAdded,
|
controlNetAdded,
|
||||||
|
controlNetModelChanged,
|
||||||
controlNetSelector,
|
controlNetSelector,
|
||||||
} from 'features/controlNet/store/controlNetSlice';
|
} from 'features/controlNet/store/controlNetSlice';
|
||||||
import { getValidControlNets } from 'features/controlNet/util/getValidControlNets';
|
import { getValidControlNets } from 'features/controlNet/util/getValidControlNets';
|
||||||
@ -15,6 +16,8 @@ import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
|||||||
import { map } from 'lodash-es';
|
import { map } from 'lodash-es';
|
||||||
import { Fragment, memo, useCallback } from 'react';
|
import { Fragment, memo, useCallback } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { FaPlus } from 'react-icons/fa';
|
||||||
|
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
|
||||||
import { v4 as uuidv4 } from 'uuid';
|
import { v4 as uuidv4 } from 'uuid';
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
@ -39,10 +42,23 @@ const ParamControlNetCollapse = () => {
|
|||||||
const { controlNetsArray, activeLabel } = useAppSelector(selector);
|
const { controlNetsArray, activeLabel } = useAppSelector(selector);
|
||||||
const isControlNetDisabled = useFeatureStatus('controlNet').isFeatureDisabled;
|
const isControlNetDisabled = useFeatureStatus('controlNet').isFeatureDisabled;
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
const { firstModel } = useGetControlNetModelsQuery(undefined, {
|
||||||
|
selectFromResult: (result) => {
|
||||||
|
const firstModel = result.data?.entities[result.data?.ids[0]];
|
||||||
|
return {
|
||||||
|
firstModel,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
const handleClickedAddControlNet = useCallback(() => {
|
const handleClickedAddControlNet = useCallback(() => {
|
||||||
dispatch(controlNetAdded({ controlNetId: uuidv4() }));
|
if (!firstModel) {
|
||||||
}, [dispatch]);
|
return;
|
||||||
|
}
|
||||||
|
const controlNetId = uuidv4();
|
||||||
|
dispatch(controlNetAdded({ controlNetId }));
|
||||||
|
dispatch(controlNetModelChanged({ controlNetId, model: firstModel }));
|
||||||
|
}, [dispatch, firstModel]);
|
||||||
|
|
||||||
if (isControlNetDisabled) {
|
if (isControlNetDisabled) {
|
||||||
return null;
|
return null;
|
||||||
@ -51,16 +67,39 @@ const ParamControlNetCollapse = () => {
|
|||||||
return (
|
return (
|
||||||
<IAICollapse label="ControlNet" activeLabel={activeLabel}>
|
<IAICollapse label="ControlNet" activeLabel={activeLabel}>
|
||||||
<Flex sx={{ flexDir: 'column', gap: 3 }}>
|
<Flex sx={{ flexDir: 'column', gap: 3 }}>
|
||||||
<ParamControlNetFeatureToggle />
|
<Flex gap={2} alignItems="center">
|
||||||
|
<Flex
|
||||||
|
sx={{
|
||||||
|
flexDirection: 'column',
|
||||||
|
w: '100%',
|
||||||
|
gap: 2,
|
||||||
|
px: 4,
|
||||||
|
py: 2,
|
||||||
|
borderRadius: 4,
|
||||||
|
bg: 'base.200',
|
||||||
|
_dark: {
|
||||||
|
bg: 'base.850',
|
||||||
|
},
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<ParamControlNetFeatureToggle />
|
||||||
|
</Flex>
|
||||||
|
<IAIIconButton
|
||||||
|
tooltip="Add ControlNet"
|
||||||
|
aria-label="Add ControlNet"
|
||||||
|
icon={<FaPlus />}
|
||||||
|
isDisabled={!firstModel}
|
||||||
|
flexGrow={1}
|
||||||
|
size="md"
|
||||||
|
onClick={handleClickedAddControlNet}
|
||||||
|
/>
|
||||||
|
</Flex>
|
||||||
{controlNetsArray.map((c, i) => (
|
{controlNetsArray.map((c, i) => (
|
||||||
<Fragment key={c.controlNetId}>
|
<Fragment key={c.controlNetId}>
|
||||||
{i > 0 && <Divider />}
|
{i > 0 && <Divider />}
|
||||||
<ControlNet controlNet={c} />
|
<ControlNet controlNetId={c.controlNetId} />
|
||||||
</Fragment>
|
</Fragment>
|
||||||
))}
|
))}
|
||||||
<IAIButton flexGrow={1} onClick={handleClickedAddControlNet}>
|
|
||||||
Add ControlNet
|
|
||||||
</IAIButton>
|
|
||||||
</Flex>
|
</Flex>
|
||||||
</IAICollapse>
|
</IAICollapse>
|
||||||
);
|
);
|
||||||
|
@ -37,6 +37,7 @@ const ParamVAEModelSelect = () => {
|
|||||||
return [];
|
return [];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// add a "default" option, this means use the main model's included VAE
|
||||||
const data: SelectItem[] = [
|
const data: SelectItem[] = [
|
||||||
{
|
{
|
||||||
value: 'default',
|
value: 'default',
|
||||||
|
@ -17,8 +17,10 @@ import { FaPlay } from 'react-icons/fa';
|
|||||||
const IN_PROGRESS_STYLES: ChakraProps['sx'] = {
|
const IN_PROGRESS_STYLES: ChakraProps['sx'] = {
|
||||||
_disabled: {
|
_disabled: {
|
||||||
bg: 'none',
|
bg: 'none',
|
||||||
|
color: 'base.600',
|
||||||
cursor: 'not-allowed',
|
cursor: 'not-allowed',
|
||||||
_hover: {
|
_hover: {
|
||||||
|
color: 'base.600',
|
||||||
bg: 'none',
|
bg: 'none',
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -180,6 +180,23 @@ export type LoRAModelParam = z.infer<typeof zLoRAModel>;
|
|||||||
*/
|
*/
|
||||||
export const isValidLoRAModel = (val: unknown): val is LoRAModelParam =>
|
export const isValidLoRAModel = (val: unknown): val is LoRAModelParam =>
|
||||||
zLoRAModel.safeParse(val).success;
|
zLoRAModel.safeParse(val).success;
|
||||||
|
/**
|
||||||
|
* Zod schema for ControlNet models
|
||||||
|
*/
|
||||||
|
export const zControlNetModel = z.object({
|
||||||
|
model_name: z.string().min(1),
|
||||||
|
base_model: zBaseModel,
|
||||||
|
});
|
||||||
|
/**
|
||||||
|
* Type alias for model parameter, inferred from its zod schema
|
||||||
|
*/
|
||||||
|
export type ControlNetModelParam = z.infer<typeof zLoRAModel>;
|
||||||
|
/**
|
||||||
|
* Validates/type-guards a value as a model parameter
|
||||||
|
*/
|
||||||
|
export const isValidControlNetModel = (
|
||||||
|
val: unknown
|
||||||
|
): val is ControlNetModelParam => zControlNetModel.safeParse(val).success;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Zod schema for l2l strength parameter
|
* Zod schema for l2l strength parameter
|
||||||
|
@ -0,0 +1,30 @@
|
|||||||
|
import { log } from 'app/logging/useLogger';
|
||||||
|
import { zControlNetModel } from 'features/parameters/types/parameterSchemas';
|
||||||
|
import { ControlNetModelField } from 'services/api/types';
|
||||||
|
|
||||||
|
const moduleLog = log.child({ module: 'models' });
|
||||||
|
|
||||||
|
export const modelIdToControlNetModelParam = (
|
||||||
|
controlNetModelId: string
|
||||||
|
): ControlNetModelField | undefined => {
|
||||||
|
const [base_model, model_type, model_name] = controlNetModelId.split('/');
|
||||||
|
|
||||||
|
const result = zControlNetModel.safeParse({
|
||||||
|
base_model,
|
||||||
|
model_name,
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!result.success) {
|
||||||
|
moduleLog.error(
|
||||||
|
{
|
||||||
|
controlNetModelId,
|
||||||
|
errors: result.error.format(),
|
||||||
|
},
|
||||||
|
'Failed to parse ControlNet model id'
|
||||||
|
);
|
||||||
|
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
return result.data;
|
||||||
|
};
|
@ -1,9 +1,12 @@
|
|||||||
import { LoRAModelParam, zLoRAModel } from '../types/parameterSchemas';
|
import { LoRAModelParam, zLoRAModel } from '../types/parameterSchemas';
|
||||||
|
import { log } from 'app/logging/useLogger';
|
||||||
|
|
||||||
|
const moduleLog = log.child({ module: 'models' });
|
||||||
|
|
||||||
export const modelIdToLoRAModelParam = (
|
export const modelIdToLoRAModelParam = (
|
||||||
loraId: string
|
loraModelId: string
|
||||||
): LoRAModelParam | undefined => {
|
): LoRAModelParam | undefined => {
|
||||||
const [base_model, model_type, model_name] = loraId.split('/');
|
const [base_model, model_type, model_name] = loraModelId.split('/');
|
||||||
|
|
||||||
const result = zLoRAModel.safeParse({
|
const result = zLoRAModel.safeParse({
|
||||||
base_model,
|
base_model,
|
||||||
@ -11,6 +14,13 @@ export const modelIdToLoRAModelParam = (
|
|||||||
});
|
});
|
||||||
|
|
||||||
if (!result.success) {
|
if (!result.success) {
|
||||||
|
moduleLog.error(
|
||||||
|
{
|
||||||
|
loraModelId,
|
||||||
|
errors: result.error.format(),
|
||||||
|
},
|
||||||
|
'Failed to parse LoRA model id'
|
||||||
|
);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2,11 +2,14 @@ import {
|
|||||||
MainModelParam,
|
MainModelParam,
|
||||||
zMainModel,
|
zMainModel,
|
||||||
} from 'features/parameters/types/parameterSchemas';
|
} from 'features/parameters/types/parameterSchemas';
|
||||||
|
import { log } from 'app/logging/useLogger';
|
||||||
|
|
||||||
|
const moduleLog = log.child({ module: 'models' });
|
||||||
|
|
||||||
export const modelIdToMainModelParam = (
|
export const modelIdToMainModelParam = (
|
||||||
modelId: string
|
mainModelId: string
|
||||||
): MainModelParam | undefined => {
|
): MainModelParam | undefined => {
|
||||||
const [base_model, model_type, model_name] = modelId.split('/');
|
const [base_model, model_type, model_name] = mainModelId.split('/');
|
||||||
|
|
||||||
const result = zMainModel.safeParse({
|
const result = zMainModel.safeParse({
|
||||||
base_model,
|
base_model,
|
||||||
@ -14,6 +17,13 @@ export const modelIdToMainModelParam = (
|
|||||||
});
|
});
|
||||||
|
|
||||||
if (!result.success) {
|
if (!result.success) {
|
||||||
|
moduleLog.error(
|
||||||
|
{
|
||||||
|
mainModelId,
|
||||||
|
errors: result.error.format(),
|
||||||
|
},
|
||||||
|
'Failed to parse main model id'
|
||||||
|
);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,9 +1,12 @@
|
|||||||
import { VaeModelParam, zVaeModel } from '../types/parameterSchemas';
|
import { VaeModelParam, zVaeModel } from '../types/parameterSchemas';
|
||||||
|
import { log } from 'app/logging/useLogger';
|
||||||
|
|
||||||
|
const moduleLog = log.child({ module: 'models' });
|
||||||
|
|
||||||
export const modelIdToVAEModelParam = (
|
export const modelIdToVAEModelParam = (
|
||||||
modelId: string
|
vaeModelId: string
|
||||||
): VaeModelParam | undefined => {
|
): VaeModelParam | undefined => {
|
||||||
const [base_model, model_type, model_name] = modelId.split('/');
|
const [base_model, model_type, model_name] = vaeModelId.split('/');
|
||||||
|
|
||||||
const result = zVaeModel.safeParse({
|
const result = zVaeModel.safeParse({
|
||||||
base_model,
|
base_model,
|
||||||
@ -11,6 +14,13 @@ export const modelIdToVAEModelParam = (
|
|||||||
});
|
});
|
||||||
|
|
||||||
if (!result.success) {
|
if (!result.success) {
|
||||||
|
moduleLog.error(
|
||||||
|
{
|
||||||
|
vaeModelId,
|
||||||
|
errors: result.error.format(),
|
||||||
|
},
|
||||||
|
'Failed to parse VAE model id'
|
||||||
|
);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -19,9 +19,9 @@ const ImageToImageTabParameters = () => {
|
|||||||
<ParamNegativeConditioning />
|
<ParamNegativeConditioning />
|
||||||
<ProcessButtons />
|
<ProcessButtons />
|
||||||
<ImageToImageTabCoreParameters />
|
<ImageToImageTabCoreParameters />
|
||||||
|
<ParamControlNetCollapse />
|
||||||
<ParamLoraCollapse />
|
<ParamLoraCollapse />
|
||||||
<ParamDynamicPromptsCollapse />
|
<ParamDynamicPromptsCollapse />
|
||||||
<ParamControlNetCollapse />
|
|
||||||
<ParamVariationCollapse />
|
<ParamVariationCollapse />
|
||||||
<ParamNoiseCollapse />
|
<ParamNoiseCollapse />
|
||||||
<ParamSymmetryCollapse />
|
<ParamSymmetryCollapse />
|
||||||
|
@ -20,9 +20,9 @@ const TextToImageTabParameters = () => {
|
|||||||
<ParamNegativeConditioning />
|
<ParamNegativeConditioning />
|
||||||
<ProcessButtons />
|
<ProcessButtons />
|
||||||
<TextToImageTabCoreParameters />
|
<TextToImageTabCoreParameters />
|
||||||
|
<ParamControlNetCollapse />
|
||||||
<ParamLoraCollapse />
|
<ParamLoraCollapse />
|
||||||
<ParamDynamicPromptsCollapse />
|
<ParamDynamicPromptsCollapse />
|
||||||
<ParamControlNetCollapse />
|
|
||||||
<ParamVariationCollapse />
|
<ParamVariationCollapse />
|
||||||
<ParamNoiseCollapse />
|
<ParamNoiseCollapse />
|
||||||
<ParamSymmetryCollapse />
|
<ParamSymmetryCollapse />
|
||||||
|
@ -19,9 +19,9 @@ const UnifiedCanvasParameters = () => {
|
|||||||
<ParamNegativeConditioning />
|
<ParamNegativeConditioning />
|
||||||
<ProcessButtons />
|
<ProcessButtons />
|
||||||
<UnifiedCanvasCoreParameters />
|
<UnifiedCanvasCoreParameters />
|
||||||
|
<ParamControlNetCollapse />
|
||||||
<ParamLoraCollapse />
|
<ParamLoraCollapse />
|
||||||
<ParamDynamicPromptsCollapse />
|
<ParamDynamicPromptsCollapse />
|
||||||
<ParamControlNetCollapse />
|
|
||||||
<ParamVariationCollapse />
|
<ParamVariationCollapse />
|
||||||
<ParamSymmetryCollapse />
|
<ParamSymmetryCollapse />
|
||||||
<ParamSeamCorrectionCollapse />
|
<ParamSeamCorrectionCollapse />
|
||||||
|
@ -734,7 +734,7 @@ export type components = {
|
|||||||
* Control Model
|
* Control Model
|
||||||
* @description The ControlNet model to use
|
* @description The ControlNet model to use
|
||||||
*/
|
*/
|
||||||
control_model: string;
|
control_model: components["schemas"]["ControlNetModelField"];
|
||||||
/**
|
/**
|
||||||
* Control Weight
|
* Control Weight
|
||||||
* @description The weight given to the ControlNet
|
* @description The weight given to the ControlNet
|
||||||
@ -792,9 +792,8 @@ export type components = {
|
|||||||
* Control Model
|
* Control Model
|
||||||
* @description control model used
|
* @description control model used
|
||||||
* @default lllyasviel/sd-controlnet-canny
|
* @default lllyasviel/sd-controlnet-canny
|
||||||
* @enum {string}
|
|
||||||
*/
|
*/
|
||||||
control_model?: "lllyasviel/sd-controlnet-canny" | "lllyasviel/sd-controlnet-depth" | "lllyasviel/sd-controlnet-hed" | "lllyasviel/sd-controlnet-seg" | "lllyasviel/sd-controlnet-openpose" | "lllyasviel/sd-controlnet-scribble" | "lllyasviel/sd-controlnet-normal" | "lllyasviel/sd-controlnet-mlsd" | "lllyasviel/control_v11p_sd15_canny" | "lllyasviel/control_v11p_sd15_openpose" | "lllyasviel/control_v11p_sd15_seg" | "lllyasviel/control_v11f1p_sd15_depth" | "lllyasviel/control_v11p_sd15_normalbae" | "lllyasviel/control_v11p_sd15_scribble" | "lllyasviel/control_v11p_sd15_mlsd" | "lllyasviel/control_v11p_sd15_softedge" | "lllyasviel/control_v11p_sd15s2_lineart_anime" | "lllyasviel/control_v11p_sd15_lineart" | "lllyasviel/control_v11p_sd15_inpaint" | "lllyasviel/control_v11e_sd15_shuffle" | "lllyasviel/control_v11e_sd15_ip2p" | "lllyasviel/control_v11f1e_sd15_tile" | "thibaud/controlnet-sd21-openpose-diffusers" | "thibaud/controlnet-sd21-canny-diffusers" | "thibaud/controlnet-sd21-depth-diffusers" | "thibaud/controlnet-sd21-scribble-diffusers" | "thibaud/controlnet-sd21-hed-diffusers" | "thibaud/controlnet-sd21-zoedepth-diffusers" | "thibaud/controlnet-sd21-color-diffusers" | "thibaud/controlnet-sd21-openposev2-diffusers" | "thibaud/controlnet-sd21-lineart-diffusers" | "thibaud/controlnet-sd21-normalbae-diffusers" | "thibaud/controlnet-sd21-ade20k-diffusers" | "CrucibleAI/ControlNetMediaPipeFace,diffusion_sd15" | "CrucibleAI/ControlNetMediaPipeFace";
|
control_model?: components["schemas"]["ControlNetModelField"];
|
||||||
/**
|
/**
|
||||||
* Control Weight
|
* Control Weight
|
||||||
* @description The weight given to the ControlNet
|
* @description The weight given to the ControlNet
|
||||||
@ -838,6 +837,19 @@ export type components = {
|
|||||||
model_format: components["schemas"]["ControlNetModelFormat"];
|
model_format: components["schemas"]["ControlNetModelFormat"];
|
||||||
error?: components["schemas"]["ModelError"];
|
error?: components["schemas"]["ModelError"];
|
||||||
};
|
};
|
||||||
|
/**
|
||||||
|
* ControlNetModelField
|
||||||
|
* @description ControlNet model field
|
||||||
|
*/
|
||||||
|
ControlNetModelField: {
|
||||||
|
/**
|
||||||
|
* Model Name
|
||||||
|
* @description Name of the ControlNet model
|
||||||
|
*/
|
||||||
|
model_name: string;
|
||||||
|
/** @description Base model */
|
||||||
|
base_model: components["schemas"]["BaseModelType"];
|
||||||
|
};
|
||||||
/**
|
/**
|
||||||
* ControlNetModelFormat
|
* ControlNetModelFormat
|
||||||
* @description An enumeration.
|
* @description An enumeration.
|
||||||
@ -1923,12 +1935,12 @@ export type components = {
|
|||||||
* Width
|
* Width
|
||||||
* @description The width to resize to (px)
|
* @description The width to resize to (px)
|
||||||
*/
|
*/
|
||||||
width: number;
|
width?: number;
|
||||||
/**
|
/**
|
||||||
* Height
|
* Height
|
||||||
* @description The height to resize to (px)
|
* @description The height to resize to (px)
|
||||||
*/
|
*/
|
||||||
height: number;
|
height?: number;
|
||||||
/**
|
/**
|
||||||
* Resample Mode
|
* Resample Mode
|
||||||
* @description The resampling mode
|
* @description The resampling mode
|
||||||
@ -3911,13 +3923,15 @@ export type components = {
|
|||||||
/**
|
/**
|
||||||
* Width
|
* Width
|
||||||
* @description The width to resize to (px)
|
* @description The width to resize to (px)
|
||||||
|
* @default 512
|
||||||
*/
|
*/
|
||||||
width: number;
|
width?: number;
|
||||||
/**
|
/**
|
||||||
* Height
|
* Height
|
||||||
* @description The height to resize to (px)
|
* @description The height to resize to (px)
|
||||||
|
* @default 512
|
||||||
*/
|
*/
|
||||||
height: number;
|
height?: number;
|
||||||
/**
|
/**
|
||||||
* Mode
|
* Mode
|
||||||
* @description The interpolation mode
|
* @description The interpolation mode
|
||||||
@ -4605,18 +4619,18 @@ export type components = {
|
|||||||
*/
|
*/
|
||||||
image?: components["schemas"]["ImageField"];
|
image?: components["schemas"]["ImageField"];
|
||||||
};
|
};
|
||||||
/**
|
|
||||||
* StableDiffusion2ModelFormat
|
|
||||||
* @description An enumeration.
|
|
||||||
* @enum {string}
|
|
||||||
*/
|
|
||||||
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
|
|
||||||
/**
|
/**
|
||||||
* StableDiffusion1ModelFormat
|
* StableDiffusion1ModelFormat
|
||||||
* @description An enumeration.
|
* @description An enumeration.
|
||||||
* @enum {string}
|
* @enum {string}
|
||||||
*/
|
*/
|
||||||
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
|
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
|
||||||
|
/**
|
||||||
|
* StableDiffusion2ModelFormat
|
||||||
|
* @description An enumeration.
|
||||||
|
* @enum {string}
|
||||||
|
*/
|
||||||
|
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
|
||||||
};
|
};
|
||||||
responses: never;
|
responses: never;
|
||||||
parameters: never;
|
parameters: never;
|
||||||
|
@ -32,6 +32,8 @@ export type BaseModelType = components['schemas']['BaseModelType'];
|
|||||||
export type MainModelField = components['schemas']['MainModelField'];
|
export type MainModelField = components['schemas']['MainModelField'];
|
||||||
export type VAEModelField = components['schemas']['VAEModelField'];
|
export type VAEModelField = components['schemas']['VAEModelField'];
|
||||||
export type LoRAModelField = components['schemas']['LoRAModelField'];
|
export type LoRAModelField = components['schemas']['LoRAModelField'];
|
||||||
|
export type ControlNetModelField =
|
||||||
|
components['schemas']['ControlNetModelField'];
|
||||||
export type ModelsList = components['schemas']['ModelsList'];
|
export type ModelsList = components['schemas']['ModelsList'];
|
||||||
export type ControlField = components['schemas']['ControlField'];
|
export type ControlField = components['schemas']['ControlField'];
|
||||||
|
|
||||||
|
@ -30,7 +30,7 @@ const invokeAIThumb = defineStyle((props) => {
|
|||||||
|
|
||||||
const invokeAIMark = defineStyle((props) => {
|
const invokeAIMark = defineStyle((props) => {
|
||||||
return {
|
return {
|
||||||
fontSize: 'xs',
|
fontSize: '2xs',
|
||||||
fontWeight: '500',
|
fontWeight: '500',
|
||||||
color: mode('base.700', 'base.400')(props),
|
color: mode('base.700', 'base.400')(props),
|
||||||
mt: 2,
|
mt: 2,
|
||||||
|
Loading…
Reference in New Issue
Block a user