mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into depth_anything_v2
This commit is contained in:
commit
f170697ebe
@ -55,6 +55,7 @@ RUN --mount=type=cache,target=/root/.cache/pip \
|
|||||||
FROM node:20-slim AS web-builder
|
FROM node:20-slim AS web-builder
|
||||||
ENV PNPM_HOME="/pnpm"
|
ENV PNPM_HOME="/pnpm"
|
||||||
ENV PATH="$PNPM_HOME:$PATH"
|
ENV PATH="$PNPM_HOME:$PATH"
|
||||||
|
RUN corepack use pnpm@8.x
|
||||||
RUN corepack enable
|
RUN corepack enable
|
||||||
|
|
||||||
WORKDIR /build
|
WORKDIR /build
|
||||||
|
@ -37,9 +37,9 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
|
|||||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||||
from invokeai.backend.lora import LoRAModelRaw
|
from invokeai.backend.lora import LoRAModelRaw
|
||||||
from invokeai.backend.model_manager import BaseModelType
|
from invokeai.backend.model_manager import BaseModelType, ModelVariantType
|
||||||
from invokeai.backend.model_patcher import ModelPatcher
|
from invokeai.backend.model_patcher import ModelPatcher
|
||||||
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
|
from invokeai.backend.stable_diffusion import PipelineIntermediateState
|
||||||
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
|
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
|
||||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import (
|
from invokeai.backend.stable_diffusion.diffusers_pipeline import (
|
||||||
ControlNetData,
|
ControlNetData,
|
||||||
@ -60,8 +60,12 @@ from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionB
|
|||||||
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||||
from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetExt
|
from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetExt
|
||||||
from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
|
from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
|
||||||
|
from invokeai.backend.stable_diffusion.extensions.inpaint import InpaintExt
|
||||||
|
from invokeai.backend.stable_diffusion.extensions.inpaint_model import InpaintModelExt
|
||||||
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
|
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
|
||||||
from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt
|
from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt
|
||||||
|
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
|
||||||
|
from invokeai.backend.stable_diffusion.extensions.t2i_adapter import T2IAdapterExt
|
||||||
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
|
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
|
||||||
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||||
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
|
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
|
||||||
@ -498,6 +502,33 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def parse_t2i_adapter_field(
|
||||||
|
exit_stack: ExitStack,
|
||||||
|
context: InvocationContext,
|
||||||
|
t2i_adapters: Optional[Union[T2IAdapterField, list[T2IAdapterField]]],
|
||||||
|
ext_manager: ExtensionsManager,
|
||||||
|
) -> None:
|
||||||
|
if t2i_adapters is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Handle the possibility that t2i_adapters could be a list or a single T2IAdapterField.
|
||||||
|
if isinstance(t2i_adapters, T2IAdapterField):
|
||||||
|
t2i_adapters = [t2i_adapters]
|
||||||
|
|
||||||
|
for t2i_adapter_field in t2i_adapters:
|
||||||
|
ext_manager.add_extension(
|
||||||
|
T2IAdapterExt(
|
||||||
|
node_context=context,
|
||||||
|
model_id=t2i_adapter_field.t2i_adapter_model,
|
||||||
|
image=context.images.get_pil(t2i_adapter_field.image.image_name),
|
||||||
|
weight=t2i_adapter_field.weight,
|
||||||
|
begin_step_percent=t2i_adapter_field.begin_step_percent,
|
||||||
|
end_step_percent=t2i_adapter_field.end_step_percent,
|
||||||
|
resize_mode=t2i_adapter_field.resize_mode,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def prep_ip_adapter_image_prompts(
|
def prep_ip_adapter_image_prompts(
|
||||||
self,
|
self,
|
||||||
context: InvocationContext,
|
context: InvocationContext,
|
||||||
@ -707,7 +738,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
else:
|
else:
|
||||||
masked_latents = torch.where(mask < 0.5, 0.0, latents)
|
masked_latents = torch.where(mask < 0.5, 0.0, latents)
|
||||||
|
|
||||||
return 1 - mask, masked_latents, self.denoise_mask.gradient
|
return mask, masked_latents, self.denoise_mask.gradient
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def prepare_noise_and_latents(
|
def prepare_noise_and_latents(
|
||||||
@ -765,10 +796,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
dtype = TorchDevice.choose_torch_dtype()
|
dtype = TorchDevice.choose_torch_dtype()
|
||||||
|
|
||||||
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
|
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
|
||||||
latents = latents.to(device=device, dtype=dtype)
|
|
||||||
if noise is not None:
|
|
||||||
noise = noise.to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
_, _, latent_height, latent_width = latents.shape
|
_, _, latent_height, latent_width = latents.shape
|
||||||
|
|
||||||
conditioning_data = self.get_conditioning_data(
|
conditioning_data = self.get_conditioning_data(
|
||||||
@ -801,21 +828,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
denoising_end=self.denoising_end,
|
denoising_end=self.denoising_end,
|
||||||
)
|
)
|
||||||
|
|
||||||
denoise_ctx = DenoiseContext(
|
|
||||||
inputs=DenoiseInputs(
|
|
||||||
orig_latents=latents,
|
|
||||||
timesteps=timesteps,
|
|
||||||
init_timestep=init_timestep,
|
|
||||||
noise=noise,
|
|
||||||
seed=seed,
|
|
||||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
|
||||||
conditioning_data=conditioning_data,
|
|
||||||
attention_processor_cls=CustomAttnProcessor2_0,
|
|
||||||
),
|
|
||||||
unet=None,
|
|
||||||
scheduler=scheduler,
|
|
||||||
)
|
|
||||||
|
|
||||||
# get the unet's config so that we can pass the base to sd_step_callback()
|
# get the unet's config so that we can pass the base to sd_step_callback()
|
||||||
unet_config = context.models.get_config(self.unet.unet.key)
|
unet_config = context.models.get_config(self.unet.unet.key)
|
||||||
|
|
||||||
@ -833,6 +845,40 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
if self.unet.freeu_config:
|
if self.unet.freeu_config:
|
||||||
ext_manager.add_extension(FreeUExt(self.unet.freeu_config))
|
ext_manager.add_extension(FreeUExt(self.unet.freeu_config))
|
||||||
|
|
||||||
|
### seamless
|
||||||
|
if self.unet.seamless_axes:
|
||||||
|
ext_manager.add_extension(SeamlessExt(self.unet.seamless_axes))
|
||||||
|
|
||||||
|
### inpaint
|
||||||
|
mask, masked_latents, is_gradient_mask = self.prep_inpaint_mask(context, latents)
|
||||||
|
# NOTE: We used to identify inpainting models by inpecting the shape of the loaded UNet model weights. Now we
|
||||||
|
# use the ModelVariantType config. During testing, there was a report of a user with models that had an
|
||||||
|
# incorrect ModelVariantType value. Re-installing the model fixed the issue. If this issue turns out to be
|
||||||
|
# prevalent, we will have to revisit how we initialize the inpainting extensions.
|
||||||
|
if unet_config.variant == ModelVariantType.Inpaint:
|
||||||
|
ext_manager.add_extension(InpaintModelExt(mask, masked_latents, is_gradient_mask))
|
||||||
|
elif mask is not None:
|
||||||
|
ext_manager.add_extension(InpaintExt(mask, is_gradient_mask))
|
||||||
|
|
||||||
|
# Initialize context for modular denoise
|
||||||
|
latents = latents.to(device=device, dtype=dtype)
|
||||||
|
if noise is not None:
|
||||||
|
noise = noise.to(device=device, dtype=dtype)
|
||||||
|
denoise_ctx = DenoiseContext(
|
||||||
|
inputs=DenoiseInputs(
|
||||||
|
orig_latents=latents,
|
||||||
|
timesteps=timesteps,
|
||||||
|
init_timestep=init_timestep,
|
||||||
|
noise=noise,
|
||||||
|
seed=seed,
|
||||||
|
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||||
|
conditioning_data=conditioning_data,
|
||||||
|
attention_processor_cls=CustomAttnProcessor2_0,
|
||||||
|
),
|
||||||
|
unet=None,
|
||||||
|
scheduler=scheduler,
|
||||||
|
)
|
||||||
|
|
||||||
# context for loading additional models
|
# context for loading additional models
|
||||||
with ExitStack() as exit_stack:
|
with ExitStack() as exit_stack:
|
||||||
# later should be smth like:
|
# later should be smth like:
|
||||||
@ -840,6 +886,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
# ext = extension_field.to_extension(exit_stack, context, ext_manager)
|
# ext = extension_field.to_extension(exit_stack, context, ext_manager)
|
||||||
# ext_manager.add_extension(ext)
|
# ext_manager.add_extension(ext)
|
||||||
self.parse_controlnet_field(exit_stack, context, self.control, ext_manager)
|
self.parse_controlnet_field(exit_stack, context, self.control, ext_manager)
|
||||||
|
self.parse_t2i_adapter_field(exit_stack, context, self.t2i_adapter, ext_manager)
|
||||||
|
|
||||||
# ext: t2i/ip adapter
|
# ext: t2i/ip adapter
|
||||||
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
|
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
|
||||||
@ -871,6 +918,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
|
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
|
||||||
|
|
||||||
mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
|
mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
|
||||||
|
# At this point, the mask ranges from 0 (leave unchanged) to 1 (inpaint).
|
||||||
|
# We invert the mask here for compatibility with the old backend implementation.
|
||||||
|
if mask is not None:
|
||||||
|
mask = 1 - mask
|
||||||
|
|
||||||
# TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
|
# TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
|
||||||
# below. Investigate whether this is appropriate.
|
# below. Investigate whether this is appropriate.
|
||||||
@ -915,7 +966,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
ExitStack() as exit_stack,
|
ExitStack() as exit_stack,
|
||||||
unet_info.model_on_device() as (model_state_dict, unet),
|
unet_info.model_on_device() as (model_state_dict, unet),
|
||||||
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
|
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
|
||||||
set_seamless(unet, self.unet.seamless_axes), # FIXME
|
SeamlessExt.static_patch_model(unet, self.unet.seamless_axes), # FIXME
|
||||||
# Apply the LoRA after unet has been moved to its target device for faster patching.
|
# Apply the LoRA after unet has been moved to its target device for faster patching.
|
||||||
ModelPatcher.apply_lora_unet(
|
ModelPatcher.apply_lora_unet(
|
||||||
unet,
|
unet,
|
||||||
|
@ -24,7 +24,7 @@ from invokeai.app.invocations.fields import (
|
|||||||
from invokeai.app.invocations.model import VAEField
|
from invokeai.app.invocations.model import VAEField
|
||||||
from invokeai.app.invocations.primitives import ImageOutput
|
from invokeai.app.invocations.primitives import ImageOutput
|
||||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
from invokeai.backend.stable_diffusion import set_seamless
|
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
|
||||||
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
|
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
|
||||||
from invokeai.backend.util.devices import TorchDevice
|
from invokeai.backend.util.devices import TorchDevice
|
||||||
|
|
||||||
@ -59,7 +59,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
|
|
||||||
vae_info = context.models.load(self.vae.vae)
|
vae_info = context.models.load(self.vae.vae)
|
||||||
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
|
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
|
||||||
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
|
with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
|
||||||
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
|
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
|
||||||
latents = latents.to(vae.device)
|
latents = latents.to(vae.device)
|
||||||
if self.fp32:
|
if self.fp32:
|
||||||
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -98,6 +98,9 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
|
|||||||
ModelVariantType.Normal: StableDiffusionXLPipeline,
|
ModelVariantType.Normal: StableDiffusionXLPipeline,
|
||||||
ModelVariantType.Inpaint: StableDiffusionXLInpaintPipeline,
|
ModelVariantType.Inpaint: StableDiffusionXLInpaintPipeline,
|
||||||
},
|
},
|
||||||
|
BaseModelType.StableDiffusionXLRefiner: {
|
||||||
|
ModelVariantType.Normal: StableDiffusionXLPipeline,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
assert isinstance(config, MainCheckpointConfig)
|
assert isinstance(config, MainCheckpointConfig)
|
||||||
try:
|
try:
|
||||||
|
@ -187,164 +187,171 @@ STARTER_MODELS: list[StarterModel] = [
|
|||||||
# endregion
|
# endregion
|
||||||
# region ControlNet
|
# region ControlNet
|
||||||
StarterModel(
|
StarterModel(
|
||||||
name="QRCode Monster",
|
name="QRCode Monster v2 (SD1.5)",
|
||||||
base=BaseModelType.StableDiffusion1,
|
base=BaseModelType.StableDiffusion1,
|
||||||
source="monster-labs/control_v1p_sd15_qrcode_monster",
|
source="monster-labs/control_v1p_sd15_qrcode_monster::v2",
|
||||||
description="Controlnet model that generates scannable creative QR codes",
|
description="ControlNet model that generates scannable creative QR codes",
|
||||||
|
type=ModelType.ControlNet,
|
||||||
|
),
|
||||||
|
StarterModel(
|
||||||
|
name="QRCode Monster (SDXL)",
|
||||||
|
base=BaseModelType.StableDiffusionXL,
|
||||||
|
source="monster-labs/control_v1p_sdxl_qrcode_monster",
|
||||||
|
description="ControlNet model that generates scannable creative QR codes",
|
||||||
type=ModelType.ControlNet,
|
type=ModelType.ControlNet,
|
||||||
),
|
),
|
||||||
StarterModel(
|
StarterModel(
|
||||||
name="canny",
|
name="canny",
|
||||||
base=BaseModelType.StableDiffusion1,
|
base=BaseModelType.StableDiffusion1,
|
||||||
source="lllyasviel/control_v11p_sd15_canny",
|
source="lllyasviel/control_v11p_sd15_canny",
|
||||||
description="Controlnet weights trained on sd-1.5 with canny conditioning.",
|
description="ControlNet weights trained on sd-1.5 with canny conditioning.",
|
||||||
type=ModelType.ControlNet,
|
type=ModelType.ControlNet,
|
||||||
),
|
),
|
||||||
StarterModel(
|
StarterModel(
|
||||||
name="inpaint",
|
name="inpaint",
|
||||||
base=BaseModelType.StableDiffusion1,
|
base=BaseModelType.StableDiffusion1,
|
||||||
source="lllyasviel/control_v11p_sd15_inpaint",
|
source="lllyasviel/control_v11p_sd15_inpaint",
|
||||||
description="Controlnet weights trained on sd-1.5 with canny conditioning, inpaint version",
|
description="ControlNet weights trained on sd-1.5 with canny conditioning, inpaint version",
|
||||||
type=ModelType.ControlNet,
|
type=ModelType.ControlNet,
|
||||||
),
|
),
|
||||||
StarterModel(
|
StarterModel(
|
||||||
name="mlsd",
|
name="mlsd",
|
||||||
base=BaseModelType.StableDiffusion1,
|
base=BaseModelType.StableDiffusion1,
|
||||||
source="lllyasviel/control_v11p_sd15_mlsd",
|
source="lllyasviel/control_v11p_sd15_mlsd",
|
||||||
description="Controlnet weights trained on sd-1.5 with canny conditioning, MLSD version",
|
description="ControlNet weights trained on sd-1.5 with canny conditioning, MLSD version",
|
||||||
type=ModelType.ControlNet,
|
type=ModelType.ControlNet,
|
||||||
),
|
),
|
||||||
StarterModel(
|
StarterModel(
|
||||||
name="depth",
|
name="depth",
|
||||||
base=BaseModelType.StableDiffusion1,
|
base=BaseModelType.StableDiffusion1,
|
||||||
source="lllyasviel/control_v11f1p_sd15_depth",
|
source="lllyasviel/control_v11f1p_sd15_depth",
|
||||||
description="Controlnet weights trained on sd-1.5 with depth conditioning",
|
description="ControlNet weights trained on sd-1.5 with depth conditioning",
|
||||||
type=ModelType.ControlNet,
|
type=ModelType.ControlNet,
|
||||||
),
|
),
|
||||||
StarterModel(
|
StarterModel(
|
||||||
name="normal_bae",
|
name="normal_bae",
|
||||||
base=BaseModelType.StableDiffusion1,
|
base=BaseModelType.StableDiffusion1,
|
||||||
source="lllyasviel/control_v11p_sd15_normalbae",
|
source="lllyasviel/control_v11p_sd15_normalbae",
|
||||||
description="Controlnet weights trained on sd-1.5 with normalbae image conditioning",
|
description="ControlNet weights trained on sd-1.5 with normalbae image conditioning",
|
||||||
type=ModelType.ControlNet,
|
type=ModelType.ControlNet,
|
||||||
),
|
),
|
||||||
StarterModel(
|
StarterModel(
|
||||||
name="seg",
|
name="seg",
|
||||||
base=BaseModelType.StableDiffusion1,
|
base=BaseModelType.StableDiffusion1,
|
||||||
source="lllyasviel/control_v11p_sd15_seg",
|
source="lllyasviel/control_v11p_sd15_seg",
|
||||||
description="Controlnet weights trained on sd-1.5 with seg image conditioning",
|
description="ControlNet weights trained on sd-1.5 with seg image conditioning",
|
||||||
type=ModelType.ControlNet,
|
type=ModelType.ControlNet,
|
||||||
),
|
),
|
||||||
StarterModel(
|
StarterModel(
|
||||||
name="lineart",
|
name="lineart",
|
||||||
base=BaseModelType.StableDiffusion1,
|
base=BaseModelType.StableDiffusion1,
|
||||||
source="lllyasviel/control_v11p_sd15_lineart",
|
source="lllyasviel/control_v11p_sd15_lineart",
|
||||||
description="Controlnet weights trained on sd-1.5 with lineart image conditioning",
|
description="ControlNet weights trained on sd-1.5 with lineart image conditioning",
|
||||||
type=ModelType.ControlNet,
|
type=ModelType.ControlNet,
|
||||||
),
|
),
|
||||||
StarterModel(
|
StarterModel(
|
||||||
name="lineart_anime",
|
name="lineart_anime",
|
||||||
base=BaseModelType.StableDiffusion1,
|
base=BaseModelType.StableDiffusion1,
|
||||||
source="lllyasviel/control_v11p_sd15s2_lineart_anime",
|
source="lllyasviel/control_v11p_sd15s2_lineart_anime",
|
||||||
description="Controlnet weights trained on sd-1.5 with anime image conditioning",
|
description="ControlNet weights trained on sd-1.5 with anime image conditioning",
|
||||||
type=ModelType.ControlNet,
|
type=ModelType.ControlNet,
|
||||||
),
|
),
|
||||||
StarterModel(
|
StarterModel(
|
||||||
name="openpose",
|
name="openpose",
|
||||||
base=BaseModelType.StableDiffusion1,
|
base=BaseModelType.StableDiffusion1,
|
||||||
source="lllyasviel/control_v11p_sd15_openpose",
|
source="lllyasviel/control_v11p_sd15_openpose",
|
||||||
description="Controlnet weights trained on sd-1.5 with openpose image conditioning",
|
description="ControlNet weights trained on sd-1.5 with openpose image conditioning",
|
||||||
type=ModelType.ControlNet,
|
type=ModelType.ControlNet,
|
||||||
),
|
),
|
||||||
StarterModel(
|
StarterModel(
|
||||||
name="scribble",
|
name="scribble",
|
||||||
base=BaseModelType.StableDiffusion1,
|
base=BaseModelType.StableDiffusion1,
|
||||||
source="lllyasviel/control_v11p_sd15_scribble",
|
source="lllyasviel/control_v11p_sd15_scribble",
|
||||||
description="Controlnet weights trained on sd-1.5 with scribble image conditioning",
|
description="ControlNet weights trained on sd-1.5 with scribble image conditioning",
|
||||||
type=ModelType.ControlNet,
|
type=ModelType.ControlNet,
|
||||||
),
|
),
|
||||||
StarterModel(
|
StarterModel(
|
||||||
name="softedge",
|
name="softedge",
|
||||||
base=BaseModelType.StableDiffusion1,
|
base=BaseModelType.StableDiffusion1,
|
||||||
source="lllyasviel/control_v11p_sd15_softedge",
|
source="lllyasviel/control_v11p_sd15_softedge",
|
||||||
description="Controlnet weights trained on sd-1.5 with soft edge conditioning",
|
description="ControlNet weights trained on sd-1.5 with soft edge conditioning",
|
||||||
type=ModelType.ControlNet,
|
type=ModelType.ControlNet,
|
||||||
),
|
),
|
||||||
StarterModel(
|
StarterModel(
|
||||||
name="shuffle",
|
name="shuffle",
|
||||||
base=BaseModelType.StableDiffusion1,
|
base=BaseModelType.StableDiffusion1,
|
||||||
source="lllyasviel/control_v11e_sd15_shuffle",
|
source="lllyasviel/control_v11e_sd15_shuffle",
|
||||||
description="Controlnet weights trained on sd-1.5 with shuffle image conditioning",
|
description="ControlNet weights trained on sd-1.5 with shuffle image conditioning",
|
||||||
type=ModelType.ControlNet,
|
type=ModelType.ControlNet,
|
||||||
),
|
),
|
||||||
StarterModel(
|
StarterModel(
|
||||||
name="tile",
|
name="tile",
|
||||||
base=BaseModelType.StableDiffusion1,
|
base=BaseModelType.StableDiffusion1,
|
||||||
source="lllyasviel/control_v11f1e_sd15_tile",
|
source="lllyasviel/control_v11f1e_sd15_tile",
|
||||||
description="Controlnet weights trained on sd-1.5 with tiled image conditioning",
|
description="ControlNet weights trained on sd-1.5 with tiled image conditioning",
|
||||||
type=ModelType.ControlNet,
|
type=ModelType.ControlNet,
|
||||||
),
|
),
|
||||||
StarterModel(
|
StarterModel(
|
||||||
name="ip2p",
|
name="ip2p",
|
||||||
base=BaseModelType.StableDiffusion1,
|
base=BaseModelType.StableDiffusion1,
|
||||||
source="lllyasviel/control_v11e_sd15_ip2p",
|
source="lllyasviel/control_v11e_sd15_ip2p",
|
||||||
description="Controlnet weights trained on sd-1.5 with ip2p conditioning.",
|
description="ControlNet weights trained on sd-1.5 with ip2p conditioning.",
|
||||||
type=ModelType.ControlNet,
|
type=ModelType.ControlNet,
|
||||||
),
|
),
|
||||||
StarterModel(
|
StarterModel(
|
||||||
name="canny-sdxl",
|
name="canny-sdxl",
|
||||||
base=BaseModelType.StableDiffusionXL,
|
base=BaseModelType.StableDiffusionXL,
|
||||||
source="xinsir/controlnet-canny-sdxl-1.0",
|
source="xinsir/controlNet-canny-sdxl-1.0",
|
||||||
description="Controlnet weights trained on sdxl-1.0 with canny conditioning, by Xinsir.",
|
description="ControlNet weights trained on sdxl-1.0 with canny conditioning, by Xinsir.",
|
||||||
type=ModelType.ControlNet,
|
type=ModelType.ControlNet,
|
||||||
),
|
),
|
||||||
StarterModel(
|
StarterModel(
|
||||||
name="depth-sdxl",
|
name="depth-sdxl",
|
||||||
base=BaseModelType.StableDiffusionXL,
|
base=BaseModelType.StableDiffusionXL,
|
||||||
source="diffusers/controlnet-depth-sdxl-1.0",
|
source="diffusers/controlNet-depth-sdxl-1.0",
|
||||||
description="Controlnet weights trained on sdxl-1.0 with depth conditioning.",
|
description="ControlNet weights trained on sdxl-1.0 with depth conditioning.",
|
||||||
type=ModelType.ControlNet,
|
type=ModelType.ControlNet,
|
||||||
),
|
),
|
||||||
StarterModel(
|
StarterModel(
|
||||||
name="softedge-dexined-sdxl",
|
name="softedge-dexined-sdxl",
|
||||||
base=BaseModelType.StableDiffusionXL,
|
base=BaseModelType.StableDiffusionXL,
|
||||||
source="SargeZT/controlnet-sd-xl-1.0-softedge-dexined",
|
source="SargeZT/controlNet-sd-xl-1.0-softedge-dexined",
|
||||||
description="Controlnet weights trained on sdxl-1.0 with dexined soft edge preprocessing.",
|
description="ControlNet weights trained on sdxl-1.0 with dexined soft edge preprocessing.",
|
||||||
type=ModelType.ControlNet,
|
type=ModelType.ControlNet,
|
||||||
),
|
),
|
||||||
StarterModel(
|
StarterModel(
|
||||||
name="depth-16bit-zoe-sdxl",
|
name="depth-16bit-zoe-sdxl",
|
||||||
base=BaseModelType.StableDiffusionXL,
|
base=BaseModelType.StableDiffusionXL,
|
||||||
source="SargeZT/controlnet-sd-xl-1.0-depth-16bit-zoe",
|
source="SargeZT/controlNet-sd-xl-1.0-depth-16bit-zoe",
|
||||||
description="Controlnet weights trained on sdxl-1.0 with Zoe's preprocessor (16 bits).",
|
description="ControlNet weights trained on sdxl-1.0 with Zoe's preprocessor (16 bits).",
|
||||||
type=ModelType.ControlNet,
|
type=ModelType.ControlNet,
|
||||||
),
|
),
|
||||||
StarterModel(
|
StarterModel(
|
||||||
name="depth-zoe-sdxl",
|
name="depth-zoe-sdxl",
|
||||||
base=BaseModelType.StableDiffusionXL,
|
base=BaseModelType.StableDiffusionXL,
|
||||||
source="diffusers/controlnet-zoe-depth-sdxl-1.0",
|
source="diffusers/controlNet-zoe-depth-sdxl-1.0",
|
||||||
description="Controlnet weights trained on sdxl-1.0 with Zoe's preprocessor (32 bits).",
|
description="ControlNet weights trained on sdxl-1.0 with Zoe's preprocessor (32 bits).",
|
||||||
type=ModelType.ControlNet,
|
type=ModelType.ControlNet,
|
||||||
),
|
),
|
||||||
StarterModel(
|
StarterModel(
|
||||||
name="openpose-sdxl",
|
name="openpose-sdxl",
|
||||||
base=BaseModelType.StableDiffusionXL,
|
base=BaseModelType.StableDiffusionXL,
|
||||||
source="xinsir/controlnet-openpose-sdxl-1.0",
|
source="xinsir/controlNet-openpose-sdxl-1.0",
|
||||||
description="Controlnet weights trained on sdxl-1.0 compatible with the DWPose processor by Xinsir.",
|
description="ControlNet weights trained on sdxl-1.0 compatible with the DWPose processor by Xinsir.",
|
||||||
type=ModelType.ControlNet,
|
type=ModelType.ControlNet,
|
||||||
),
|
),
|
||||||
StarterModel(
|
StarterModel(
|
||||||
name="scribble-sdxl",
|
name="scribble-sdxl",
|
||||||
base=BaseModelType.StableDiffusionXL,
|
base=BaseModelType.StableDiffusionXL,
|
||||||
source="xinsir/controlnet-scribble-sdxl-1.0",
|
source="xinsir/controlNet-scribble-sdxl-1.0",
|
||||||
description="Controlnet weights trained on sdxl-1.0 compatible with various lineart processors and black/white sketches by Xinsir.",
|
description="ControlNet weights trained on sdxl-1.0 compatible with various lineart processors and black/white sketches by Xinsir.",
|
||||||
type=ModelType.ControlNet,
|
type=ModelType.ControlNet,
|
||||||
),
|
),
|
||||||
StarterModel(
|
StarterModel(
|
||||||
name="tile-sdxl",
|
name="tile-sdxl",
|
||||||
base=BaseModelType.StableDiffusionXL,
|
base=BaseModelType.StableDiffusionXL,
|
||||||
source="xinsir/controlnet-tile-sdxl-1.0",
|
source="xinsir/controlNet-tile-sdxl-1.0",
|
||||||
description="Controlnet weights trained on sdxl-1.0 with tiled image conditioning",
|
description="ControlNet weights trained on sdxl-1.0 with tiled image conditioning",
|
||||||
type=ModelType.ControlNet,
|
type=ModelType.ControlNet,
|
||||||
),
|
),
|
||||||
# endregion
|
# endregion
|
||||||
|
@ -7,11 +7,9 @@ from invokeai.backend.stable_diffusion.diffusers_pipeline import ( # noqa: F401
|
|||||||
StableDiffusionGeneratorPipeline,
|
StableDiffusionGeneratorPipeline,
|
||||||
)
|
)
|
||||||
from invokeai.backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent # noqa: F401
|
from invokeai.backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent # noqa: F401
|
||||||
from invokeai.backend.stable_diffusion.seamless import set_seamless # noqa: F401
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"PipelineIntermediateState",
|
"PipelineIntermediateState",
|
||||||
"StableDiffusionGeneratorPipeline",
|
"StableDiffusionGeneratorPipeline",
|
||||||
"InvokeAIDiffuserComponent",
|
"InvokeAIDiffuserComponent",
|
||||||
"set_seamless",
|
|
||||||
]
|
]
|
||||||
|
120
invokeai/backend/stable_diffusion/extensions/inpaint.py
Normal file
120
invokeai/backend/stable_diffusion/extensions/inpaint.py
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
|
import einops
|
||||||
|
import torch
|
||||||
|
from diffusers import UNet2DConditionModel
|
||||||
|
|
||||||
|
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||||
|
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
||||||
|
|
||||||
|
|
||||||
|
class InpaintExt(ExtensionBase):
|
||||||
|
"""An extension for inpainting with non-inpainting models. See `InpaintModelExt` for inpainting with inpainting
|
||||||
|
models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
mask: torch.Tensor,
|
||||||
|
is_gradient_mask: bool,
|
||||||
|
):
|
||||||
|
"""Initialize InpaintExt.
|
||||||
|
Args:
|
||||||
|
mask (torch.Tensor): The inpainting mask. Shape: (1, 1, latent_height, latent_width). Values are
|
||||||
|
expected to be in the range [0, 1]. A value of 1 means that the corresponding 'pixel' should not be
|
||||||
|
inpainted.
|
||||||
|
is_gradient_mask (bool): If True, mask is interpreted as a gradient mask meaning that the mask values range
|
||||||
|
from 0 to 1. If False, mask is interpreted as binary mask meaning that the mask values are either 0 or
|
||||||
|
1.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self._mask = mask
|
||||||
|
self._is_gradient_mask = is_gradient_mask
|
||||||
|
|
||||||
|
# Noise, which used to noisify unmasked part of image
|
||||||
|
# if noise provided to context, then it will be used
|
||||||
|
# if no noise provided, then noise will be generated based on seed
|
||||||
|
self._noise: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _is_normal_model(unet: UNet2DConditionModel):
|
||||||
|
"""Checks if the provided UNet belongs to a regular model.
|
||||||
|
The `in_channels` of a UNet vary depending on model type:
|
||||||
|
- normal - 4
|
||||||
|
- depth - 5
|
||||||
|
- inpaint - 9
|
||||||
|
"""
|
||||||
|
return unet.conv_in.in_channels == 4
|
||||||
|
|
||||||
|
def _apply_mask(self, ctx: DenoiseContext, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
|
||||||
|
batch_size = latents.size(0)
|
||||||
|
mask = einops.repeat(self._mask, "b c h w -> (repeat b) c h w", repeat=batch_size)
|
||||||
|
if t.dim() == 0:
|
||||||
|
# some schedulers expect t to be one-dimensional.
|
||||||
|
# TODO: file diffusers bug about inconsistency?
|
||||||
|
t = einops.repeat(t, "-> batch", batch=batch_size)
|
||||||
|
# Noise shouldn't be re-randomized between steps here. The multistep schedulers
|
||||||
|
# get very confused about what is happening from step to step when we do that.
|
||||||
|
mask_latents = ctx.scheduler.add_noise(ctx.inputs.orig_latents, self._noise, t)
|
||||||
|
# TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already?
|
||||||
|
# mask_latents = self.scheduler.scale_model_input(mask_latents, t)
|
||||||
|
mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
|
||||||
|
if self._is_gradient_mask:
|
||||||
|
threshold = (t.item()) / ctx.scheduler.config.num_train_timesteps
|
||||||
|
mask_bool = mask < 1 - threshold
|
||||||
|
masked_input = torch.where(mask_bool, latents, mask_latents)
|
||||||
|
else:
|
||||||
|
masked_input = torch.lerp(latents, mask_latents.to(dtype=latents.dtype), mask.to(dtype=latents.dtype))
|
||||||
|
return masked_input
|
||||||
|
|
||||||
|
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
|
||||||
|
def init_tensors(self, ctx: DenoiseContext):
|
||||||
|
if not self._is_normal_model(ctx.unet):
|
||||||
|
raise ValueError(
|
||||||
|
"InpaintExt should be used only on normal (non-inpainting) models. This could be caused by an "
|
||||||
|
"inpainting model that was incorrectly marked as a non-inpainting model. In some cases, this can be "
|
||||||
|
"fixed by removing and re-adding the model (so that it gets re-probed)."
|
||||||
|
)
|
||||||
|
|
||||||
|
self._mask = self._mask.to(device=ctx.latents.device, dtype=ctx.latents.dtype)
|
||||||
|
|
||||||
|
self._noise = ctx.inputs.noise
|
||||||
|
# 'noise' might be None if the latents have already been noised (e.g. when running the SDXL refiner).
|
||||||
|
# We still need noise for inpainting, so we generate it from the seed here.
|
||||||
|
if self._noise is None:
|
||||||
|
self._noise = torch.randn(
|
||||||
|
ctx.latents.shape,
|
||||||
|
dtype=torch.float32,
|
||||||
|
device="cpu",
|
||||||
|
generator=torch.Generator(device="cpu").manual_seed(ctx.seed),
|
||||||
|
).to(device=ctx.latents.device, dtype=ctx.latents.dtype)
|
||||||
|
|
||||||
|
# Use negative order to make extensions with default order work with patched latents
|
||||||
|
@callback(ExtensionCallbackType.PRE_STEP, order=-100)
|
||||||
|
def apply_mask_to_initial_latents(self, ctx: DenoiseContext):
|
||||||
|
ctx.latents = self._apply_mask(ctx, ctx.latents, ctx.timestep)
|
||||||
|
|
||||||
|
# TODO: redo this with preview events rewrite
|
||||||
|
# Use negative order to make extensions with default order work with patched latents
|
||||||
|
@callback(ExtensionCallbackType.POST_STEP, order=-100)
|
||||||
|
def apply_mask_to_step_output(self, ctx: DenoiseContext):
|
||||||
|
timestep = ctx.scheduler.timesteps[-1]
|
||||||
|
if hasattr(ctx.step_output, "denoised"):
|
||||||
|
ctx.step_output.denoised = self._apply_mask(ctx, ctx.step_output.denoised, timestep)
|
||||||
|
elif hasattr(ctx.step_output, "pred_original_sample"):
|
||||||
|
ctx.step_output.pred_original_sample = self._apply_mask(ctx, ctx.step_output.pred_original_sample, timestep)
|
||||||
|
else:
|
||||||
|
ctx.step_output.pred_original_sample = self._apply_mask(ctx, ctx.step_output.prev_sample, timestep)
|
||||||
|
|
||||||
|
# Restore unmasked part after the last step is completed
|
||||||
|
@callback(ExtensionCallbackType.POST_DENOISE_LOOP)
|
||||||
|
def restore_unmasked(self, ctx: DenoiseContext):
|
||||||
|
if self._is_gradient_mask:
|
||||||
|
ctx.latents = torch.where(self._mask < 1, ctx.latents, ctx.inputs.orig_latents)
|
||||||
|
else:
|
||||||
|
ctx.latents = torch.lerp(ctx.latents, ctx.inputs.orig_latents, self._mask)
|
@ -0,0 +1,88 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from diffusers import UNet2DConditionModel
|
||||||
|
|
||||||
|
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||||
|
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
||||||
|
|
||||||
|
|
||||||
|
class InpaintModelExt(ExtensionBase):
|
||||||
|
"""An extension for inpainting with inpainting models. See `InpaintExt` for inpainting with non-inpainting
|
||||||
|
models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
mask: Optional[torch.Tensor],
|
||||||
|
masked_latents: Optional[torch.Tensor],
|
||||||
|
is_gradient_mask: bool,
|
||||||
|
):
|
||||||
|
"""Initialize InpaintModelExt.
|
||||||
|
Args:
|
||||||
|
mask (Optional[torch.Tensor]): The inpainting mask. Shape: (1, 1, latent_height, latent_width). Values are
|
||||||
|
expected to be in the range [0, 1]. A value of 1 means that the corresponding 'pixel' should not be
|
||||||
|
inpainted.
|
||||||
|
masked_latents (Optional[torch.Tensor]): Latents of initial image, with masked out by black color inpainted area.
|
||||||
|
If mask provided, then too should be provided. Shape: (1, 1, latent_height, latent_width)
|
||||||
|
is_gradient_mask (bool): If True, mask is interpreted as a gradient mask meaning that the mask values range
|
||||||
|
from 0 to 1. If False, mask is interpreted as binary mask meaning that the mask values are either 0 or
|
||||||
|
1.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
if mask is not None and masked_latents is None:
|
||||||
|
raise ValueError("Source image required for inpaint mask when inpaint model used!")
|
||||||
|
|
||||||
|
# Inverse mask, because inpaint models treat mask as: 0 - remain same, 1 - inpaint
|
||||||
|
self._mask = None
|
||||||
|
if mask is not None:
|
||||||
|
self._mask = 1 - mask
|
||||||
|
self._masked_latents = masked_latents
|
||||||
|
self._is_gradient_mask = is_gradient_mask
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _is_inpaint_model(unet: UNet2DConditionModel):
|
||||||
|
"""Checks if the provided UNet belongs to a regular model.
|
||||||
|
The `in_channels` of a UNet vary depending on model type:
|
||||||
|
- normal - 4
|
||||||
|
- depth - 5
|
||||||
|
- inpaint - 9
|
||||||
|
"""
|
||||||
|
return unet.conv_in.in_channels == 9
|
||||||
|
|
||||||
|
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
|
||||||
|
def init_tensors(self, ctx: DenoiseContext):
|
||||||
|
if not self._is_inpaint_model(ctx.unet):
|
||||||
|
raise ValueError("InpaintModelExt should be used only on inpaint models!")
|
||||||
|
|
||||||
|
if self._mask is None:
|
||||||
|
self._mask = torch.ones_like(ctx.latents[:1, :1])
|
||||||
|
self._mask = self._mask.to(device=ctx.latents.device, dtype=ctx.latents.dtype)
|
||||||
|
|
||||||
|
if self._masked_latents is None:
|
||||||
|
self._masked_latents = torch.zeros_like(ctx.latents[:1])
|
||||||
|
self._masked_latents = self._masked_latents.to(device=ctx.latents.device, dtype=ctx.latents.dtype)
|
||||||
|
|
||||||
|
# Do last so that other extensions works with normal latents
|
||||||
|
@callback(ExtensionCallbackType.PRE_UNET, order=1000)
|
||||||
|
def append_inpaint_layers(self, ctx: DenoiseContext):
|
||||||
|
batch_size = ctx.unet_kwargs.sample.shape[0]
|
||||||
|
b_mask = torch.cat([self._mask] * batch_size)
|
||||||
|
b_masked_latents = torch.cat([self._masked_latents] * batch_size)
|
||||||
|
ctx.unet_kwargs.sample = torch.cat(
|
||||||
|
[ctx.unet_kwargs.sample, b_mask, b_masked_latents],
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Restore unmasked part as inpaint model can change unmasked part slightly
|
||||||
|
@callback(ExtensionCallbackType.POST_DENOISE_LOOP)
|
||||||
|
def restore_unmasked(self, ctx: DenoiseContext):
|
||||||
|
if self._is_gradient_mask:
|
||||||
|
ctx.latents = torch.where(self._mask > 0, ctx.latents, ctx.inputs.orig_latents)
|
||||||
|
else:
|
||||||
|
ctx.latents = torch.lerp(ctx.inputs.orig_latents, ctx.latents, self._mask)
|
71
invokeai/backend/stable_diffusion/extensions/seamless.py
Normal file
71
invokeai/backend/stable_diffusion/extensions/seamless.py
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from typing import Callable, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from diffusers import UNet2DConditionModel
|
||||||
|
from diffusers.models.lora import LoRACompatibleConv
|
||||||
|
|
||||||
|
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
|
||||||
|
|
||||||
|
|
||||||
|
class SeamlessExt(ExtensionBase):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
seamless_axes: List[str],
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self._seamless_axes = seamless_axes
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
|
||||||
|
with self.static_patch_model(
|
||||||
|
model=unet,
|
||||||
|
seamless_axes=self._seamless_axes,
|
||||||
|
):
|
||||||
|
yield
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@contextmanager
|
||||||
|
def static_patch_model(
|
||||||
|
model: torch.nn.Module,
|
||||||
|
seamless_axes: List[str],
|
||||||
|
):
|
||||||
|
if not seamless_axes:
|
||||||
|
yield
|
||||||
|
return
|
||||||
|
|
||||||
|
x_mode = "circular" if "x" in seamless_axes else "constant"
|
||||||
|
y_mode = "circular" if "y" in seamless_axes else "constant"
|
||||||
|
|
||||||
|
# override conv_forward
|
||||||
|
# https://github.com/huggingface/diffusers/issues/556#issuecomment-1993287019
|
||||||
|
def _conv_forward_asymmetric(
|
||||||
|
self, input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None
|
||||||
|
):
|
||||||
|
self.paddingX = (self._reversed_padding_repeated_twice[0], self._reversed_padding_repeated_twice[1], 0, 0)
|
||||||
|
self.paddingY = (0, 0, self._reversed_padding_repeated_twice[2], self._reversed_padding_repeated_twice[3])
|
||||||
|
working = torch.nn.functional.pad(input, self.paddingX, mode=x_mode)
|
||||||
|
working = torch.nn.functional.pad(working, self.paddingY, mode=y_mode)
|
||||||
|
return torch.nn.functional.conv2d(
|
||||||
|
working, weight, bias, self.stride, torch.nn.modules.utils._pair(0), self.dilation, self.groups
|
||||||
|
)
|
||||||
|
|
||||||
|
original_layers: List[Tuple[nn.Conv2d, Callable]] = []
|
||||||
|
try:
|
||||||
|
for layer in model.modules():
|
||||||
|
if not isinstance(layer, torch.nn.Conv2d):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if isinstance(layer, LoRACompatibleConv) and layer.lora_layer is None:
|
||||||
|
layer.lora_layer = lambda *x: 0
|
||||||
|
original_layers.append((layer, layer._conv_forward))
|
||||||
|
layer._conv_forward = _conv_forward_asymmetric.__get__(layer, torch.nn.Conv2d)
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
finally:
|
||||||
|
for layer, orig_conv_forward in original_layers:
|
||||||
|
layer._conv_forward = orig_conv_forward
|
120
invokeai/backend/stable_diffusion/extensions/t2i_adapter.py
Normal file
120
invokeai/backend/stable_diffusion/extensions/t2i_adapter.py
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
|
from typing import TYPE_CHECKING, List, Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from diffusers import T2IAdapter
|
||||||
|
from PIL.Image import Image
|
||||||
|
|
||||||
|
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||||
|
from invokeai.backend.model_manager import BaseModelType
|
||||||
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode
|
||||||
|
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||||
|
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from invokeai.app.invocations.model import ModelIdentifierField
|
||||||
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
|
from invokeai.app.util.controlnet_utils import CONTROLNET_RESIZE_VALUES
|
||||||
|
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
||||||
|
|
||||||
|
|
||||||
|
class T2IAdapterExt(ExtensionBase):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
node_context: InvocationContext,
|
||||||
|
model_id: ModelIdentifierField,
|
||||||
|
image: Image,
|
||||||
|
weight: Union[float, List[float]],
|
||||||
|
begin_step_percent: float,
|
||||||
|
end_step_percent: float,
|
||||||
|
resize_mode: CONTROLNET_RESIZE_VALUES,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self._node_context = node_context
|
||||||
|
self._model_id = model_id
|
||||||
|
self._image = image
|
||||||
|
self._weight = weight
|
||||||
|
self._resize_mode = resize_mode
|
||||||
|
self._begin_step_percent = begin_step_percent
|
||||||
|
self._end_step_percent = end_step_percent
|
||||||
|
|
||||||
|
self._adapter_state: Optional[List[torch.Tensor]] = None
|
||||||
|
|
||||||
|
# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
|
||||||
|
model_config = self._node_context.models.get_config(self._model_id.key)
|
||||||
|
if model_config.base == BaseModelType.StableDiffusion1:
|
||||||
|
self._max_unet_downscale = 8
|
||||||
|
elif model_config.base == BaseModelType.StableDiffusionXL:
|
||||||
|
self._max_unet_downscale = 4
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unexpected T2I-Adapter base model type: '{model_config.base}'.")
|
||||||
|
|
||||||
|
@callback(ExtensionCallbackType.SETUP)
|
||||||
|
def setup(self, ctx: DenoiseContext):
|
||||||
|
t2i_model: T2IAdapter
|
||||||
|
with self._node_context.models.load(self._model_id) as t2i_model:
|
||||||
|
_, _, latents_height, latents_width = ctx.inputs.orig_latents.shape
|
||||||
|
|
||||||
|
self._adapter_state = self._run_model(
|
||||||
|
model=t2i_model,
|
||||||
|
image=self._image,
|
||||||
|
latents_height=latents_height,
|
||||||
|
latents_width=latents_width,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _run_model(
|
||||||
|
self,
|
||||||
|
model: T2IAdapter,
|
||||||
|
image: Image,
|
||||||
|
latents_height: int,
|
||||||
|
latents_width: int,
|
||||||
|
):
|
||||||
|
# Resize the T2I-Adapter input image.
|
||||||
|
# We select the resize dimensions so that after the T2I-Adapter's total_downscale_factor is applied, the
|
||||||
|
# result will match the latent image's dimensions after max_unet_downscale is applied.
|
||||||
|
input_height = latents_height // self._max_unet_downscale * model.total_downscale_factor
|
||||||
|
input_width = latents_width // self._max_unet_downscale * model.total_downscale_factor
|
||||||
|
|
||||||
|
# Note: We have hard-coded `do_classifier_free_guidance=False`. This is because we only want to prepare
|
||||||
|
# a single image. If CFG is enabled, we will duplicate the resultant tensor after applying the
|
||||||
|
# T2I-Adapter model.
|
||||||
|
#
|
||||||
|
# Note: We re-use the `prepare_control_image(...)` from ControlNet for T2I-Adapter, because it has many
|
||||||
|
# of the same requirements (e.g. preserving binary masks during resize).
|
||||||
|
t2i_image = prepare_control_image(
|
||||||
|
image=image,
|
||||||
|
do_classifier_free_guidance=False,
|
||||||
|
width=input_width,
|
||||||
|
height=input_height,
|
||||||
|
num_channels=model.config["in_channels"],
|
||||||
|
device=model.device,
|
||||||
|
dtype=model.dtype,
|
||||||
|
resize_mode=self._resize_mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
return model(t2i_image)
|
||||||
|
|
||||||
|
@callback(ExtensionCallbackType.PRE_UNET)
|
||||||
|
def pre_unet_step(self, ctx: DenoiseContext):
|
||||||
|
# skip if model not active in current step
|
||||||
|
total_steps = len(ctx.inputs.timesteps)
|
||||||
|
first_step = math.floor(self._begin_step_percent * total_steps)
|
||||||
|
last_step = math.ceil(self._end_step_percent * total_steps)
|
||||||
|
if ctx.step_index < first_step or ctx.step_index > last_step:
|
||||||
|
return
|
||||||
|
|
||||||
|
weight = self._weight
|
||||||
|
if isinstance(weight, list):
|
||||||
|
weight = weight[ctx.step_index]
|
||||||
|
|
||||||
|
adapter_state = self._adapter_state
|
||||||
|
if ctx.conditioning_mode == ConditioningMode.Both:
|
||||||
|
adapter_state = [torch.cat([v] * 2) for v in adapter_state]
|
||||||
|
|
||||||
|
if ctx.unet_kwargs.down_intrablock_additional_residuals is None:
|
||||||
|
ctx.unet_kwargs.down_intrablock_additional_residuals = [v * weight for v in adapter_state]
|
||||||
|
else:
|
||||||
|
for i, value in enumerate(adapter_state):
|
||||||
|
ctx.unet_kwargs.down_intrablock_additional_residuals[i] += value * weight
|
@ -1,51 +0,0 @@
|
|||||||
from contextlib import contextmanager
|
|
||||||
from typing import Callable, List, Optional, Tuple, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
|
|
||||||
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
|
|
||||||
from diffusers.models.lora import LoRACompatibleConv
|
|
||||||
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL, AutoencoderTiny], seamless_axes: List[str]):
|
|
||||||
if not seamless_axes:
|
|
||||||
yield
|
|
||||||
return
|
|
||||||
|
|
||||||
# override conv_forward
|
|
||||||
# https://github.com/huggingface/diffusers/issues/556#issuecomment-1993287019
|
|
||||||
def _conv_forward_asymmetric(self, input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None):
|
|
||||||
self.paddingX = (self._reversed_padding_repeated_twice[0], self._reversed_padding_repeated_twice[1], 0, 0)
|
|
||||||
self.paddingY = (0, 0, self._reversed_padding_repeated_twice[2], self._reversed_padding_repeated_twice[3])
|
|
||||||
working = torch.nn.functional.pad(input, self.paddingX, mode=x_mode)
|
|
||||||
working = torch.nn.functional.pad(working, self.paddingY, mode=y_mode)
|
|
||||||
return torch.nn.functional.conv2d(
|
|
||||||
working, weight, bias, self.stride, torch.nn.modules.utils._pair(0), self.dilation, self.groups
|
|
||||||
)
|
|
||||||
|
|
||||||
original_layers: List[Tuple[nn.Conv2d, Callable]] = []
|
|
||||||
|
|
||||||
try:
|
|
||||||
x_mode = "circular" if "x" in seamless_axes else "constant"
|
|
||||||
y_mode = "circular" if "y" in seamless_axes else "constant"
|
|
||||||
|
|
||||||
conv_layers: List[torch.nn.Conv2d] = []
|
|
||||||
|
|
||||||
for module in model.modules():
|
|
||||||
if isinstance(module, torch.nn.Conv2d):
|
|
||||||
conv_layers.append(module)
|
|
||||||
|
|
||||||
for layer in conv_layers:
|
|
||||||
if isinstance(layer, LoRACompatibleConv) and layer.lora_layer is None:
|
|
||||||
layer.lora_layer = lambda *x: 0
|
|
||||||
original_layers.append((layer, layer._conv_forward))
|
|
||||||
layer._conv_forward = _conv_forward_asymmetric.__get__(layer, torch.nn.Conv2d)
|
|
||||||
|
|
||||||
yield
|
|
||||||
|
|
||||||
finally:
|
|
||||||
for layer, orig_conv_forward in original_layers:
|
|
||||||
layer._conv_forward = orig_conv_forward
|
|
@ -31,7 +31,8 @@
|
|||||||
"deleteBoard": "Delete Board",
|
"deleteBoard": "Delete Board",
|
||||||
"deleteBoardAndImages": "Delete Board and Images",
|
"deleteBoardAndImages": "Delete Board and Images",
|
||||||
"deleteBoardOnly": "Delete Board Only",
|
"deleteBoardOnly": "Delete Board Only",
|
||||||
"deletedBoardsCannotbeRestored": "Deleted boards cannot be restored",
|
"deletedBoardsCannotbeRestored": "Deleted boards cannot be restored. Selecting 'Delete Board Only' will move images to an uncategorized state.",
|
||||||
|
"deletedPrivateBoardsCannotbeRestored": "Deleted boards cannot be restored. Selecting 'Delete Board Only' will move images to a private uncategorized state for the image's creator.",
|
||||||
"hideBoards": "Hide Boards",
|
"hideBoards": "Hide Boards",
|
||||||
"loading": "Loading...",
|
"loading": "Loading...",
|
||||||
"menuItemAutoAdd": "Auto-add to this Board",
|
"menuItemAutoAdd": "Auto-add to this Board",
|
||||||
@ -105,6 +106,7 @@
|
|||||||
"negativePrompt": "Negative Prompt",
|
"negativePrompt": "Negative Prompt",
|
||||||
"discordLabel": "Discord",
|
"discordLabel": "Discord",
|
||||||
"dontAskMeAgain": "Don't ask me again",
|
"dontAskMeAgain": "Don't ask me again",
|
||||||
|
"dontShowMeThese": "Don't show me these",
|
||||||
"editor": "Editor",
|
"editor": "Editor",
|
||||||
"error": "Error",
|
"error": "Error",
|
||||||
"file": "File",
|
"file": "File",
|
||||||
@ -1099,6 +1101,8 @@
|
|||||||
"displayInProgress": "Display Progress Images",
|
"displayInProgress": "Display Progress Images",
|
||||||
"enableImageDebugging": "Enable Image Debugging",
|
"enableImageDebugging": "Enable Image Debugging",
|
||||||
"enableInformationalPopovers": "Enable Informational Popovers",
|
"enableInformationalPopovers": "Enable Informational Popovers",
|
||||||
|
"informationalPopoversDisabled": "Informational Popovers Disabled",
|
||||||
|
"informationalPopoversDisabledDesc": "Informational popovers have been disabled. Enable them in Settings.",
|
||||||
"enableInvisibleWatermark": "Enable Invisible Watermark",
|
"enableInvisibleWatermark": "Enable Invisible Watermark",
|
||||||
"enableNSFWChecker": "Enable NSFW Checker",
|
"enableNSFWChecker": "Enable NSFW Checker",
|
||||||
"general": "General",
|
"general": "General",
|
||||||
@ -1506,6 +1510,30 @@
|
|||||||
"seamlessTilingYAxis": {
|
"seamlessTilingYAxis": {
|
||||||
"heading": "Seamless Tiling Y Axis",
|
"heading": "Seamless Tiling Y Axis",
|
||||||
"paragraphs": ["Seamlessly tile an image along the vertical axis."]
|
"paragraphs": ["Seamlessly tile an image along the vertical axis."]
|
||||||
|
},
|
||||||
|
"upscaleModel": {
|
||||||
|
"heading": "Upscale Model",
|
||||||
|
"paragraphs": [
|
||||||
|
"The upscale model scales the image to the output size before details are added. Any supported upscale model may be used, but some are specialized for different kinds of images, like photos or line drawings."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"scale": {
|
||||||
|
"heading": "Scale",
|
||||||
|
"paragraphs": [
|
||||||
|
"Scale controls the output image size, and is based on a multiple of the input image resolution. For example a 2x upscale on a 1024x1024 image would produce a 2048 x 2048 output."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"creativity": {
|
||||||
|
"heading": "Creativity",
|
||||||
|
"paragraphs": [
|
||||||
|
"Creativity controls the amount of freedom granted to the model when adding details. Low creativity stays close to the original image, while high creativity allows for more change. When using a prompt, high creativity increases the influence of the prompt."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"structure": {
|
||||||
|
"heading": "Structure",
|
||||||
|
"paragraphs": [
|
||||||
|
"Structure controls how closely the output image will keep to the layout of the original. Low structure allows major changes, while high structure strictly maintains the original composition and layout."
|
||||||
|
]
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"unifiedCanvas": {
|
"unifiedCanvas": {
|
||||||
|
@ -10,32 +10,32 @@ import {
|
|||||||
import { boardsApi } from 'services/api/endpoints/boards';
|
import { boardsApi } from 'services/api/endpoints/boards';
|
||||||
import { imagesApi } from 'services/api/endpoints/images';
|
import { imagesApi } from 'services/api/endpoints/images';
|
||||||
|
|
||||||
|
// Type inference doesn't work for this if you inline it in the listener for some reason
|
||||||
|
const matchAnyBoardDeleted = isAnyOf(
|
||||||
|
imagesApi.endpoints.deleteBoard.matchFulfilled,
|
||||||
|
imagesApi.endpoints.deleteBoardAndImages.matchFulfilled
|
||||||
|
);
|
||||||
|
|
||||||
export const addArchivedOrDeletedBoardListener = (startAppListening: AppStartListening) => {
|
export const addArchivedOrDeletedBoardListener = (startAppListening: AppStartListening) => {
|
||||||
/**
|
/**
|
||||||
* The auto-add board shouldn't be set to an archived board or deleted board. When we archive a board, delete
|
* The auto-add board shouldn't be set to an archived board or deleted board. When we archive a board, delete
|
||||||
* a board, or change a the archived board visibility flag, we may need to reset the auto-add board.
|
* a board, or change a the archived board visibility flag, we may need to reset the auto-add board.
|
||||||
*/
|
*/
|
||||||
startAppListening({
|
startAppListening({
|
||||||
matcher: isAnyOf(
|
matcher: matchAnyBoardDeleted,
|
||||||
// If a board is deleted, we'll need to reset the auto-add board
|
|
||||||
imagesApi.endpoints.deleteBoard.matchFulfilled,
|
|
||||||
imagesApi.endpoints.deleteBoardAndImages.matchFulfilled
|
|
||||||
),
|
|
||||||
effect: async (action, { dispatch, getState }) => {
|
effect: async (action, { dispatch, getState }) => {
|
||||||
const state = getState();
|
const state = getState();
|
||||||
const queryArgs = selectListBoardsQueryArgs(state);
|
const deletedBoardId = action.meta.arg.originalArgs;
|
||||||
const queryResult = boardsApi.endpoints.listAllBoards.select(queryArgs)(state);
|
|
||||||
const { autoAddBoardId, selectedBoardId } = state.gallery;
|
const { autoAddBoardId, selectedBoardId } = state.gallery;
|
||||||
|
|
||||||
if (!queryResult.data) {
|
// If the deleted board was currently selected, we should reset the selected board to uncategorized
|
||||||
return;
|
if (deletedBoardId === selectedBoardId) {
|
||||||
}
|
|
||||||
|
|
||||||
if (!queryResult.data.find((board) => board.board_id === selectedBoardId)) {
|
|
||||||
dispatch(boardIdSelected({ boardId: 'none' }));
|
dispatch(boardIdSelected({ boardId: 'none' }));
|
||||||
dispatch(galleryViewChanged('images'));
|
dispatch(galleryViewChanged('images'));
|
||||||
}
|
}
|
||||||
if (!queryResult.data.find((board) => board.board_id === autoAddBoardId)) {
|
|
||||||
|
// If the deleted board was selected for auto-add, we should reset the auto-add board to uncategorized
|
||||||
|
if (deletedBoardId === autoAddBoardId) {
|
||||||
dispatch(autoAddBoardIdChanged('none'));
|
dispatch(autoAddBoardIdChanged('none'));
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@ -46,14 +46,8 @@ export const addArchivedOrDeletedBoardListener = (startAppListening: AppStartLis
|
|||||||
matcher: boardsApi.endpoints.updateBoard.matchFulfilled,
|
matcher: boardsApi.endpoints.updateBoard.matchFulfilled,
|
||||||
effect: async (action, { dispatch, getState }) => {
|
effect: async (action, { dispatch, getState }) => {
|
||||||
const state = getState();
|
const state = getState();
|
||||||
const queryArgs = selectListBoardsQueryArgs(state);
|
|
||||||
const queryResult = boardsApi.endpoints.listAllBoards.select(queryArgs)(state);
|
|
||||||
const { shouldShowArchivedBoards } = state.gallery;
|
const { shouldShowArchivedBoards } = state.gallery;
|
||||||
|
|
||||||
if (!queryResult.data) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const wasArchived = action.meta.arg.originalArgs.changes.archived === true;
|
const wasArchived = action.meta.arg.originalArgs.changes.archived === true;
|
||||||
|
|
||||||
if (wasArchived && !shouldShowArchivedBoards) {
|
if (wasArchived && !shouldShowArchivedBoards) {
|
||||||
@ -71,7 +65,7 @@ export const addArchivedOrDeletedBoardListener = (startAppListening: AppStartLis
|
|||||||
const shouldShowArchivedBoards = action.payload;
|
const shouldShowArchivedBoards = action.payload;
|
||||||
|
|
||||||
// We only need to take action if we have just hidden archived boards.
|
// We only need to take action if we have just hidden archived boards.
|
||||||
if (!shouldShowArchivedBoards) {
|
if (shouldShowArchivedBoards) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -86,14 +80,16 @@ export const addArchivedOrDeletedBoardListener = (startAppListening: AppStartLis
|
|||||||
|
|
||||||
// Handle the case where selected board is archived
|
// Handle the case where selected board is archived
|
||||||
const selectedBoard = queryResult.data.find((b) => b.board_id === selectedBoardId);
|
const selectedBoard = queryResult.data.find((b) => b.board_id === selectedBoardId);
|
||||||
if (selectedBoard && selectedBoard.archived) {
|
if (!selectedBoard || selectedBoard.archived) {
|
||||||
|
// If we can't find the selected board or it's archived, we should reset the selected board to uncategorized
|
||||||
dispatch(boardIdSelected({ boardId: 'none' }));
|
dispatch(boardIdSelected({ boardId: 'none' }));
|
||||||
dispatch(galleryViewChanged('images'));
|
dispatch(galleryViewChanged('images'));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle the case where auto-add board is archived
|
// Handle the case where auto-add board is archived
|
||||||
const autoAddBoard = queryResult.data.find((b) => b.board_id === autoAddBoardId);
|
const autoAddBoard = queryResult.data.find((b) => b.board_id === autoAddBoardId);
|
||||||
if (autoAddBoard && autoAddBoard.archived) {
|
if (!autoAddBoard || autoAddBoard.archived) {
|
||||||
|
// If we can't find the auto-add board or it's archived, we should reset the selected board to uncategorized
|
||||||
dispatch(autoAddBoardIdChanged('none'));
|
dispatch(autoAddBoardIdChanged('none'));
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -10,9 +10,12 @@ import {
|
|||||||
PopoverContent,
|
PopoverContent,
|
||||||
PopoverTrigger,
|
PopoverTrigger,
|
||||||
Portal,
|
Portal,
|
||||||
|
Spacer,
|
||||||
Text,
|
Text,
|
||||||
} from '@invoke-ai/ui-library';
|
} from '@invoke-ai/ui-library';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { setShouldEnableInformationalPopovers } from 'features/system/store/systemSlice';
|
||||||
|
import { toast } from 'features/toast/toast';
|
||||||
import { merge, omit } from 'lodash-es';
|
import { merge, omit } from 'lodash-es';
|
||||||
import type { ReactElement } from 'react';
|
import type { ReactElement } from 'react';
|
||||||
import { memo, useCallback, useMemo } from 'react';
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
@ -71,7 +74,7 @@ type ContentProps = {
|
|||||||
|
|
||||||
const Content = ({ data, feature }: ContentProps) => {
|
const Content = ({ data, feature }: ContentProps) => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
const heading = useMemo<string | undefined>(() => t(`popovers.${feature}.heading`), [feature, t]);
|
const heading = useMemo<string | undefined>(() => t(`popovers.${feature}.heading`), [feature, t]);
|
||||||
|
|
||||||
const paragraphs = useMemo<string[]>(
|
const paragraphs = useMemo<string[]>(
|
||||||
@ -82,16 +85,25 @@ const Content = ({ data, feature }: ContentProps) => {
|
|||||||
[feature, t]
|
[feature, t]
|
||||||
);
|
);
|
||||||
|
|
||||||
const handleClick = useCallback(() => {
|
const onClickLearnMore = useCallback(() => {
|
||||||
if (!data?.href) {
|
if (!data?.href) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
window.open(data.href);
|
window.open(data.href);
|
||||||
}, [data?.href]);
|
}, [data?.href]);
|
||||||
|
|
||||||
|
const onClickDontShowMeThese = useCallback(() => {
|
||||||
|
dispatch(setShouldEnableInformationalPopovers(false));
|
||||||
|
toast({
|
||||||
|
title: t('settings.informationalPopoversDisabled'),
|
||||||
|
description: t('settings.informationalPopoversDisabledDesc'),
|
||||||
|
status: 'info',
|
||||||
|
});
|
||||||
|
}, [dispatch, t]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<PopoverContent w={96}>
|
<PopoverContent maxW={300}>
|
||||||
<PopoverCloseButton />
|
<PopoverCloseButton top={2} />
|
||||||
<PopoverBody>
|
<PopoverBody>
|
||||||
<Flex gap={2} flexDirection="column" alignItems="flex-start">
|
<Flex gap={2} flexDirection="column" alignItems="flex-start">
|
||||||
{heading && (
|
{heading && (
|
||||||
@ -116,21 +128,20 @@ const Content = ({ data, feature }: ContentProps) => {
|
|||||||
{paragraphs.map((p) => (
|
{paragraphs.map((p) => (
|
||||||
<Text key={p}>{p}</Text>
|
<Text key={p}>{p}</Text>
|
||||||
))}
|
))}
|
||||||
{data?.href && (
|
|
||||||
<>
|
|
||||||
<Divider />
|
<Divider />
|
||||||
<Button
|
<Flex alignItems="center" justifyContent="space-between" w="full">
|
||||||
pt={1}
|
<Button onClick={onClickDontShowMeThese} variant="link" size="sm">
|
||||||
onClick={handleClick}
|
{t('common.dontShowMeThese')}
|
||||||
leftIcon={<PiArrowSquareOutBold />}
|
</Button>
|
||||||
alignSelf="flex-end"
|
<Spacer />
|
||||||
variant="link"
|
{data?.href && (
|
||||||
>
|
<Button onClick={onClickLearnMore} leftIcon={<PiArrowSquareOutBold />} variant="link" size="sm">
|
||||||
{t('common.learnMore') ?? heading}
|
{t('common.learnMore') ?? heading}
|
||||||
</Button>
|
</Button>
|
||||||
</>
|
|
||||||
)}
|
)}
|
||||||
</Flex>
|
</Flex>
|
||||||
|
</Flex>
|
||||||
</PopoverBody>
|
</PopoverBody>
|
||||||
</PopoverContent>
|
</PopoverContent>
|
||||||
);
|
);
|
||||||
|
@ -53,7 +53,11 @@ export type Feature =
|
|||||||
| 'refinerCfgScale'
|
| 'refinerCfgScale'
|
||||||
| 'scaleBeforeProcessing'
|
| 'scaleBeforeProcessing'
|
||||||
| 'seamlessTilingXAxis'
|
| 'seamlessTilingXAxis'
|
||||||
| 'seamlessTilingYAxis';
|
| 'seamlessTilingYAxis'
|
||||||
|
| 'upscaleModel'
|
||||||
|
| 'scale'
|
||||||
|
| 'creativity'
|
||||||
|
| 'structure';
|
||||||
|
|
||||||
export type PopoverData = PopoverProps & {
|
export type PopoverData = PopoverProps & {
|
||||||
image?: string;
|
image?: string;
|
||||||
|
@ -120,7 +120,11 @@ const DeleteBoardModal = (props: Props) => {
|
|||||||
bottomMessage={t('boards.bottomMessage')}
|
bottomMessage={t('boards.bottomMessage')}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
<Text>{t('boards.deletedBoardsCannotbeRestored')}</Text>
|
<Text>
|
||||||
|
{boardToDelete.is_private
|
||||||
|
? t('boards.deletedPrivateBoardsCannotbeRestored')
|
||||||
|
: t('boards.deletedBoardsCannotbeRestored')}
|
||||||
|
</Text>
|
||||||
<Text>
|
<Text>
|
||||||
{canRestoreDeletedImagesFromBin ? t('gallery.deleteImageBin') : t('gallery.deleteImagePermanent')}
|
{canRestoreDeletedImagesFromBin ? t('gallery.deleteImageBin') : t('gallery.deleteImagePermanent')}
|
||||||
</Text>
|
</Text>
|
||||||
|
@ -1,15 +1,10 @@
|
|||||||
import { skipToken } from '@reduxjs/toolkit/query';
|
|
||||||
import { isNil } from 'lodash-es';
|
import { isNil } from 'lodash-es';
|
||||||
import { useMemo } from 'react';
|
import { useMemo } from 'react';
|
||||||
import { useGetModelConfigWithTypeGuard } from 'services/api/hooks/useGetModelConfigWithTypeGuard';
|
import type { ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
|
||||||
import { isControlNetOrT2IAdapterModelConfig } from 'services/api/types';
|
|
||||||
|
|
||||||
export const useControlNetOrT2IAdapterDefaultSettings = (modelKey?: string | null) => {
|
|
||||||
const { modelConfig, isLoading } = useGetModelConfigWithTypeGuard(
|
|
||||||
modelKey ?? skipToken,
|
|
||||||
isControlNetOrT2IAdapterModelConfig
|
|
||||||
);
|
|
||||||
|
|
||||||
|
export const useControlNetOrT2IAdapterDefaultSettings = (
|
||||||
|
modelConfig: ControlNetModelConfig | T2IAdapterModelConfig
|
||||||
|
) => {
|
||||||
const defaultSettingsDefaults = useMemo(() => {
|
const defaultSettingsDefaults = useMemo(() => {
|
||||||
return {
|
return {
|
||||||
preprocessor: {
|
preprocessor: {
|
||||||
@ -19,5 +14,5 @@ export const useControlNetOrT2IAdapterDefaultSettings = (modelKey?: string | nul
|
|||||||
};
|
};
|
||||||
}, [modelConfig?.default_settings]);
|
}, [modelConfig?.default_settings]);
|
||||||
|
|
||||||
return { defaultSettingsDefaults, isLoading };
|
return defaultSettingsDefaults;
|
||||||
};
|
};
|
||||||
|
@ -1,12 +1,9 @@
|
|||||||
import { skipToken } from '@reduxjs/toolkit/query';
|
|
||||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { getOptimalDimension } from 'features/parameters/util/optimalDimension';
|
|
||||||
import { selectConfigSlice } from 'features/system/store/configSlice';
|
import { selectConfigSlice } from 'features/system/store/configSlice';
|
||||||
import { isNil } from 'lodash-es';
|
import { isNil } from 'lodash-es';
|
||||||
import { useMemo } from 'react';
|
import { useMemo } from 'react';
|
||||||
import { useGetModelConfigWithTypeGuard } from 'services/api/hooks/useGetModelConfigWithTypeGuard';
|
import type { MainModelConfig } from 'services/api/types';
|
||||||
import { isNonRefinerMainModelConfig } from 'services/api/types';
|
|
||||||
|
|
||||||
const initialStatesSelector = createMemoizedSelector(selectConfigSlice, (config) => {
|
const initialStatesSelector = createMemoizedSelector(selectConfigSlice, (config) => {
|
||||||
const { steps, guidance, scheduler, cfgRescaleMultiplier, vaePrecision, width, height } = config.sd;
|
const { steps, guidance, scheduler, cfgRescaleMultiplier, vaePrecision, width, height } = config.sd;
|
||||||
@ -22,9 +19,7 @@ const initialStatesSelector = createMemoizedSelector(selectConfigSlice, (config)
|
|||||||
};
|
};
|
||||||
});
|
});
|
||||||
|
|
||||||
export const useMainModelDefaultSettings = (modelKey?: string | null) => {
|
export const useMainModelDefaultSettings = (modelConfig: MainModelConfig) => {
|
||||||
const { modelConfig, isLoading } = useGetModelConfigWithTypeGuard(modelKey ?? skipToken, isNonRefinerMainModelConfig);
|
|
||||||
|
|
||||||
const {
|
const {
|
||||||
initialSteps,
|
initialSteps,
|
||||||
initialCfg,
|
initialCfg,
|
||||||
@ -81,5 +76,5 @@ export const useMainModelDefaultSettings = (modelKey?: string | null) => {
|
|||||||
initialHeight,
|
initialHeight,
|
||||||
]);
|
]);
|
||||||
|
|
||||||
return { defaultSettingsDefaults, isLoading, optimalDimension: getOptimalDimension(modelConfig) };
|
return defaultSettingsDefaults;
|
||||||
};
|
};
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||||
import { createSlice } from '@reduxjs/toolkit';
|
import { createSlice } from '@reduxjs/toolkit';
|
||||||
import type { PersistConfig } from 'app/store/store';
|
import type { PersistConfig, RootState } from 'app/store/store';
|
||||||
import type { ModelType } from 'services/api/types';
|
import type { ModelType } from 'services/api/types';
|
||||||
|
|
||||||
export type FilterableModelType = Exclude<ModelType, 'onnx' | 'clip_vision'> | 'refiner';
|
export type FilterableModelType = Exclude<ModelType, 'onnx' | 'clip_vision'> | 'refiner';
|
||||||
@ -50,6 +50,8 @@ export const modelManagerV2Slice = createSlice({
|
|||||||
export const { setSelectedModelKey, setSearchTerm, setFilteredModelType, setSelectedModelMode, setScanPath } =
|
export const { setSelectedModelKey, setSearchTerm, setFilteredModelType, setSelectedModelMode, setScanPath } =
|
||||||
modelManagerV2Slice.actions;
|
modelManagerV2Slice.actions;
|
||||||
|
|
||||||
|
export const selectModelManagerV2Slice = (state: RootState) => state.modelmanagerV2;
|
||||||
|
|
||||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||||
const migrateModelManagerState = (state: any): any => {
|
const migrateModelManagerState = (state: any): any => {
|
||||||
if (!('_version' in state)) {
|
if (!('_version' in state)) {
|
||||||
|
@ -1,13 +1,13 @@
|
|||||||
import { Button, Flex, FormControl, FormErrorMessage, FormHelperText, FormLabel, Input } from '@invoke-ai/ui-library';
|
import { Button, Flex, FormControl, FormErrorMessage, FormHelperText, FormLabel, Input } from '@invoke-ai/ui-library';
|
||||||
import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel';
|
import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel';
|
||||||
import type { ChangeEventHandler } from 'react';
|
import type { ChangeEventHandler } from 'react';
|
||||||
import { useCallback, useState } from 'react';
|
import { memo, useCallback, useState } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { useLazyGetHuggingFaceModelsQuery } from 'services/api/endpoints/models';
|
import { useLazyGetHuggingFaceModelsQuery } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
import { HuggingFaceResults } from './HuggingFaceResults';
|
import { HuggingFaceResults } from './HuggingFaceResults';
|
||||||
|
|
||||||
export const HuggingFaceForm = () => {
|
export const HuggingFaceForm = memo(() => {
|
||||||
const [huggingFaceRepo, setHuggingFaceRepo] = useState('');
|
const [huggingFaceRepo, setHuggingFaceRepo] = useState('');
|
||||||
const [displayResults, setDisplayResults] = useState(false);
|
const [displayResults, setDisplayResults] = useState(false);
|
||||||
const [errorMessage, setErrorMessage] = useState('');
|
const [errorMessage, setErrorMessage] = useState('');
|
||||||
@ -66,4 +66,6 @@ export const HuggingFaceForm = () => {
|
|||||||
{data && data.urls && displayResults && <HuggingFaceResults results={data.urls} />}
|
{data && data.urls && displayResults && <HuggingFaceResults results={data.urls} />}
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
};
|
});
|
||||||
|
|
||||||
|
HuggingFaceForm.displayName = 'HuggingFaceForm';
|
||||||
|
@ -1,13 +1,13 @@
|
|||||||
import { Flex, IconButton, Text } from '@invoke-ai/ui-library';
|
import { Flex, IconButton, Text } from '@invoke-ai/ui-library';
|
||||||
import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel';
|
import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel';
|
||||||
import { useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { PiPlusBold } from 'react-icons/pi';
|
import { PiPlusBold } from 'react-icons/pi';
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
result: string;
|
result: string;
|
||||||
};
|
};
|
||||||
export const HuggingFaceResultItem = ({ result }: Props) => {
|
export const HuggingFaceResultItem = memo(({ result }: Props) => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const [installModel] = useInstallModel();
|
const [installModel] = useInstallModel();
|
||||||
@ -27,4 +27,6 @@ export const HuggingFaceResultItem = ({ result }: Props) => {
|
|||||||
<IconButton aria-label={t('modelManager.install')} icon={<PiPlusBold />} onClick={onClick} size="sm" />
|
<IconButton aria-label={t('modelManager.install')} icon={<PiPlusBold />} onClick={onClick} size="sm" />
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
};
|
});
|
||||||
|
|
||||||
|
HuggingFaceResultItem.displayName = 'HuggingFaceResultItem';
|
||||||
|
@ -11,7 +11,7 @@ import {
|
|||||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||||
import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel';
|
import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel';
|
||||||
import type { ChangeEventHandler } from 'react';
|
import type { ChangeEventHandler } from 'react';
|
||||||
import { useCallback, useMemo, useState } from 'react';
|
import { memo, useCallback, useMemo, useState } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { PiXBold } from 'react-icons/pi';
|
import { PiXBold } from 'react-icons/pi';
|
||||||
|
|
||||||
@ -21,7 +21,7 @@ type HuggingFaceResultsProps = {
|
|||||||
results: string[];
|
results: string[];
|
||||||
};
|
};
|
||||||
|
|
||||||
export const HuggingFaceResults = ({ results }: HuggingFaceResultsProps) => {
|
export const HuggingFaceResults = memo(({ results }: HuggingFaceResultsProps) => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const [searchTerm, setSearchTerm] = useState('');
|
const [searchTerm, setSearchTerm] = useState('');
|
||||||
|
|
||||||
@ -93,4 +93,6 @@ export const HuggingFaceResults = ({ results }: HuggingFaceResultsProps) => {
|
|||||||
</Flex>
|
</Flex>
|
||||||
</>
|
</>
|
||||||
);
|
);
|
||||||
};
|
});
|
||||||
|
|
||||||
|
HuggingFaceResults.displayName = 'HuggingFaceResults';
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import { Button, Checkbox, Flex, FormControl, FormHelperText, FormLabel, Input } from '@invoke-ai/ui-library';
|
import { Button, Checkbox, Flex, FormControl, FormHelperText, FormLabel, Input } from '@invoke-ai/ui-library';
|
||||||
import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel';
|
import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel';
|
||||||
import { t } from 'i18next';
|
import { t } from 'i18next';
|
||||||
import { useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import type { SubmitHandler } from 'react-hook-form';
|
import type { SubmitHandler } from 'react-hook-form';
|
||||||
import { useForm } from 'react-hook-form';
|
import { useForm } from 'react-hook-form';
|
||||||
|
|
||||||
@ -10,7 +10,7 @@ type SimpleImportModelConfig = {
|
|||||||
inplace: boolean;
|
inplace: boolean;
|
||||||
};
|
};
|
||||||
|
|
||||||
export const InstallModelForm = () => {
|
export const InstallModelForm = memo(() => {
|
||||||
const [installModel, { isLoading }] = useInstallModel();
|
const [installModel, { isLoading }] = useInstallModel();
|
||||||
|
|
||||||
const { register, handleSubmit, formState, reset } = useForm<SimpleImportModelConfig>({
|
const { register, handleSubmit, formState, reset } = useForm<SimpleImportModelConfig>({
|
||||||
@ -74,4 +74,6 @@ export const InstallModelForm = () => {
|
|||||||
</Flex>
|
</Flex>
|
||||||
</form>
|
</form>
|
||||||
);
|
);
|
||||||
};
|
});
|
||||||
|
|
||||||
|
InstallModelForm.displayName = 'InstallModelForm';
|
||||||
|
@ -2,12 +2,12 @@ import { Box, Button, Flex, Heading } from '@invoke-ai/ui-library';
|
|||||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||||
import { toast } from 'features/toast/toast';
|
import { toast } from 'features/toast/toast';
|
||||||
import { t } from 'i18next';
|
import { t } from 'i18next';
|
||||||
import { useCallback, useMemo } from 'react';
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
import { useListModelInstallsQuery, usePruneCompletedModelInstallsMutation } from 'services/api/endpoints/models';
|
import { useListModelInstallsQuery, usePruneCompletedModelInstallsMutation } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
import { ModelInstallQueueItem } from './ModelInstallQueueItem';
|
import { ModelInstallQueueItem } from './ModelInstallQueueItem';
|
||||||
|
|
||||||
export const ModelInstallQueue = () => {
|
export const ModelInstallQueue = memo(() => {
|
||||||
const { data } = useListModelInstallsQuery();
|
const { data } = useListModelInstallsQuery();
|
||||||
|
|
||||||
const [_pruneCompletedModelInstalls] = usePruneCompletedModelInstallsMutation();
|
const [_pruneCompletedModelInstalls] = usePruneCompletedModelInstallsMutation();
|
||||||
@ -61,4 +61,6 @@ export const ModelInstallQueue = () => {
|
|||||||
</Box>
|
</Box>
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
};
|
});
|
||||||
|
|
||||||
|
ModelInstallQueue.displayName = 'ModelInstallQueue';
|
||||||
|
@ -2,7 +2,7 @@ import { Flex, IconButton, Progress, Text, Tooltip } from '@invoke-ai/ui-library
|
|||||||
import { toast } from 'features/toast/toast';
|
import { toast } from 'features/toast/toast';
|
||||||
import { t } from 'i18next';
|
import { t } from 'i18next';
|
||||||
import { isNil } from 'lodash-es';
|
import { isNil } from 'lodash-es';
|
||||||
import { useCallback, useMemo } from 'react';
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
import { PiXBold } from 'react-icons/pi';
|
import { PiXBold } from 'react-icons/pi';
|
||||||
import { useCancelModelInstallMutation } from 'services/api/endpoints/models';
|
import { useCancelModelInstallMutation } from 'services/api/endpoints/models';
|
||||||
import type { ModelInstallJob } from 'services/api/types';
|
import type { ModelInstallJob } from 'services/api/types';
|
||||||
@ -25,7 +25,7 @@ const formatBytes = (bytes: number) => {
|
|||||||
return `${bytes.toFixed(2)} ${units[i]}`;
|
return `${bytes.toFixed(2)} ${units[i]}`;
|
||||||
};
|
};
|
||||||
|
|
||||||
export const ModelInstallQueueItem = (props: ModelListItemProps) => {
|
export const ModelInstallQueueItem = memo((props: ModelListItemProps) => {
|
||||||
const { installJob } = props;
|
const { installJob } = props;
|
||||||
|
|
||||||
const [deleteImportModel] = useCancelModelInstallMutation();
|
const [deleteImportModel] = useCancelModelInstallMutation();
|
||||||
@ -124,7 +124,9 @@ export const ModelInstallQueueItem = (props: ModelListItemProps) => {
|
|||||||
/>
|
/>
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
};
|
});
|
||||||
|
|
||||||
|
ModelInstallQueueItem.displayName = 'ModelInstallQueueItem';
|
||||||
|
|
||||||
type TooltipLabelProps = {
|
type TooltipLabelProps = {
|
||||||
installJob: ModelInstallJob;
|
installJob: ModelInstallJob;
|
||||||
@ -132,7 +134,7 @@ type TooltipLabelProps = {
|
|||||||
source: string;
|
source: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
const TooltipLabel = ({ name, source, installJob }: TooltipLabelProps) => {
|
const TooltipLabel = memo(({ name, source, installJob }: TooltipLabelProps) => {
|
||||||
const progressString = useMemo(() => {
|
const progressString = useMemo(() => {
|
||||||
if (installJob.status !== 'downloading' || installJob.bytes === undefined || installJob.total_bytes === undefined) {
|
if (installJob.status !== 'downloading' || installJob.bytes === undefined || installJob.total_bytes === undefined) {
|
||||||
return '';
|
return '';
|
||||||
@ -156,4 +158,6 @@ const TooltipLabel = ({ name, source, installJob }: TooltipLabelProps) => {
|
|||||||
)}
|
)}
|
||||||
</>
|
</>
|
||||||
);
|
);
|
||||||
};
|
});
|
||||||
|
|
||||||
|
TooltipLabel.displayName = 'TooltipLabel';
|
||||||
|
@ -2,13 +2,13 @@ import { Button, Flex, FormControl, FormErrorMessage, FormHelperText, FormLabel,
|
|||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { setScanPath } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
import { setScanPath } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||||
import type { ChangeEventHandler } from 'react';
|
import type { ChangeEventHandler } from 'react';
|
||||||
import { useCallback, useState } from 'react';
|
import { memo, useCallback, useState } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { useLazyScanFolderQuery } from 'services/api/endpoints/models';
|
import { useLazyScanFolderQuery } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
import { ScanModelsResults } from './ScanFolderResults';
|
import { ScanModelsResults } from './ScanFolderResults';
|
||||||
|
|
||||||
export const ScanModelsForm = () => {
|
export const ScanModelsForm = memo(() => {
|
||||||
const scanPath = useAppSelector((state) => state.modelmanagerV2.scanPath);
|
const scanPath = useAppSelector((state) => state.modelmanagerV2.scanPath);
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const [errorMessage, setErrorMessage] = useState('');
|
const [errorMessage, setErrorMessage] = useState('');
|
||||||
@ -56,4 +56,6 @@ export const ScanModelsForm = () => {
|
|||||||
{data && <ScanModelsResults results={data} />}
|
{data && <ScanModelsResults results={data} />}
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
};
|
});
|
||||||
|
|
||||||
|
ScanModelsForm.displayName = 'ScanModelsForm';
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import { Badge, Box, Flex, IconButton, Text } from '@invoke-ai/ui-library';
|
import { Badge, Box, Flex, IconButton, Text } from '@invoke-ai/ui-library';
|
||||||
import { useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { PiPlusBold } from 'react-icons/pi';
|
import { PiPlusBold } from 'react-icons/pi';
|
||||||
import type { ScanFolderResponse } from 'services/api/endpoints/models';
|
import type { ScanFolderResponse } from 'services/api/endpoints/models';
|
||||||
@ -8,7 +8,7 @@ type Props = {
|
|||||||
result: ScanFolderResponse[number];
|
result: ScanFolderResponse[number];
|
||||||
installModel: (source: string) => void;
|
installModel: (source: string) => void;
|
||||||
};
|
};
|
||||||
export const ScanModelResultItem = ({ result, installModel }: Props) => {
|
export const ScanModelResultItem = memo(({ result, installModel }: Props) => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const handleInstall = useCallback(() => {
|
const handleInstall = useCallback(() => {
|
||||||
@ -30,4 +30,6 @@ export const ScanModelResultItem = ({ result, installModel }: Props) => {
|
|||||||
</Box>
|
</Box>
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
};
|
});
|
||||||
|
|
||||||
|
ScanModelResultItem.displayName = 'ScanModelResultItem';
|
||||||
|
@ -14,7 +14,7 @@ import {
|
|||||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||||
import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel';
|
import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel';
|
||||||
import type { ChangeEvent, ChangeEventHandler } from 'react';
|
import type { ChangeEvent, ChangeEventHandler } from 'react';
|
||||||
import { useCallback, useMemo, useState } from 'react';
|
import { memo, useCallback, useMemo, useState } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { PiXBold } from 'react-icons/pi';
|
import { PiXBold } from 'react-icons/pi';
|
||||||
import type { ScanFolderResponse } from 'services/api/endpoints/models';
|
import type { ScanFolderResponse } from 'services/api/endpoints/models';
|
||||||
@ -25,7 +25,7 @@ type ScanModelResultsProps = {
|
|||||||
results: ScanFolderResponse;
|
results: ScanFolderResponse;
|
||||||
};
|
};
|
||||||
|
|
||||||
export const ScanModelsResults = ({ results }: ScanModelResultsProps) => {
|
export const ScanModelsResults = memo(({ results }: ScanModelResultsProps) => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const [searchTerm, setSearchTerm] = useState('');
|
const [searchTerm, setSearchTerm] = useState('');
|
||||||
const [inplace, setInplace] = useState(true);
|
const [inplace, setInplace] = useState(true);
|
||||||
@ -116,4 +116,6 @@ export const ScanModelsResults = ({ results }: ScanModelResultsProps) => {
|
|||||||
</Flex>
|
</Flex>
|
||||||
</>
|
</>
|
||||||
);
|
);
|
||||||
};
|
});
|
||||||
|
|
||||||
|
ScanModelsResults.displayName = 'ScanModelsResults';
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import { Badge, Box, Flex, IconButton, Text } from '@invoke-ai/ui-library';
|
import { Badge, Box, Flex, IconButton, Text } from '@invoke-ai/ui-library';
|
||||||
import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel';
|
import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel';
|
||||||
import ModelBaseBadge from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge';
|
import ModelBaseBadge from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge';
|
||||||
import { useCallback, useMemo } from 'react';
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { PiPlusBold } from 'react-icons/pi';
|
import { PiPlusBold } from 'react-icons/pi';
|
||||||
import type { GetStarterModelsResponse } from 'services/api/endpoints/models';
|
import type { GetStarterModelsResponse } from 'services/api/endpoints/models';
|
||||||
@ -9,7 +9,7 @@ import type { GetStarterModelsResponse } from 'services/api/endpoints/models';
|
|||||||
type Props = {
|
type Props = {
|
||||||
result: GetStarterModelsResponse[number];
|
result: GetStarterModelsResponse[number];
|
||||||
};
|
};
|
||||||
export const StarterModelsResultItem = ({ result }: Props) => {
|
export const StarterModelsResultItem = memo(({ result }: Props) => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const allSources = useMemo(() => {
|
const allSources = useMemo(() => {
|
||||||
const _allSources = [{ source: result.source, config: { name: result.name, description: result.description } }];
|
const _allSources = [{ source: result.source, config: { name: result.name, description: result.description } }];
|
||||||
@ -47,4 +47,6 @@ export const StarterModelsResultItem = ({ result }: Props) => {
|
|||||||
</Box>
|
</Box>
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
};
|
});
|
||||||
|
|
||||||
|
StarterModelsResultItem.displayName = 'StarterModelsResultItem';
|
||||||
|
@ -1,10 +1,11 @@
|
|||||||
import { Flex } from '@invoke-ai/ui-library';
|
import { Flex } from '@invoke-ai/ui-library';
|
||||||
import { FetchingModelsLoader } from 'features/modelManagerV2/subpanels/ModelManagerPanel/FetchingModelsLoader';
|
import { FetchingModelsLoader } from 'features/modelManagerV2/subpanels/ModelManagerPanel/FetchingModelsLoader';
|
||||||
|
import { memo } from 'react';
|
||||||
import { useGetStarterModelsQuery } from 'services/api/endpoints/models';
|
import { useGetStarterModelsQuery } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
import { StarterModelsResults } from './StarterModelsResults';
|
import { StarterModelsResults } from './StarterModelsResults';
|
||||||
|
|
||||||
export const StarterModelsForm = () => {
|
export const StarterModelsForm = memo(() => {
|
||||||
const { isLoading, data } = useGetStarterModelsQuery();
|
const { isLoading, data } = useGetStarterModelsQuery();
|
||||||
|
|
||||||
return (
|
return (
|
||||||
@ -13,4 +14,6 @@ export const StarterModelsForm = () => {
|
|||||||
{data && <StarterModelsResults results={data} />}
|
{data && <StarterModelsResults results={data} />}
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
};
|
});
|
||||||
|
|
||||||
|
StarterModelsForm.displayName = 'StarterModelsForm';
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import { Flex, IconButton, Input, InputGroup, InputRightElement } from '@invoke-ai/ui-library';
|
import { Flex, IconButton, Input, InputGroup, InputRightElement } from '@invoke-ai/ui-library';
|
||||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||||
import type { ChangeEventHandler } from 'react';
|
import type { ChangeEventHandler } from 'react';
|
||||||
import { useCallback, useMemo, useState } from 'react';
|
import { memo, useCallback, useMemo, useState } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { PiXBold } from 'react-icons/pi';
|
import { PiXBold } from 'react-icons/pi';
|
||||||
import type { GetStarterModelsResponse } from 'services/api/endpoints/models';
|
import type { GetStarterModelsResponse } from 'services/api/endpoints/models';
|
||||||
@ -12,7 +12,7 @@ type StarterModelsResultsProps = {
|
|||||||
results: NonNullable<GetStarterModelsResponse>;
|
results: NonNullable<GetStarterModelsResponse>;
|
||||||
};
|
};
|
||||||
|
|
||||||
export const StarterModelsResults = ({ results }: StarterModelsResultsProps) => {
|
export const StarterModelsResults = memo(({ results }: StarterModelsResultsProps) => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const [searchTerm, setSearchTerm] = useState('');
|
const [searchTerm, setSearchTerm] = useState('');
|
||||||
|
|
||||||
@ -79,4 +79,6 @@ export const StarterModelsResults = ({ results }: StarterModelsResultsProps) =>
|
|||||||
</Flex>
|
</Flex>
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
};
|
});
|
||||||
|
|
||||||
|
StarterModelsResults.displayName = 'StarterModelsResults';
|
||||||
|
@ -2,7 +2,7 @@ import { Box, Flex, Heading, Tab, TabList, TabPanel, TabPanels, Tabs } from '@in
|
|||||||
import { useStore } from '@nanostores/react';
|
import { useStore } from '@nanostores/react';
|
||||||
import { StarterModelsForm } from 'features/modelManagerV2/subpanels/AddModelPanel/StarterModels/StarterModelsForm';
|
import { StarterModelsForm } from 'features/modelManagerV2/subpanels/AddModelPanel/StarterModels/StarterModelsForm';
|
||||||
import { atom } from 'nanostores';
|
import { atom } from 'nanostores';
|
||||||
import { useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
import { HuggingFaceForm } from './AddModelPanel/HuggingFaceFolder/HuggingFaceForm';
|
import { HuggingFaceForm } from './AddModelPanel/HuggingFaceFolder/HuggingFaceForm';
|
||||||
@ -12,7 +12,7 @@ import { ScanModelsForm } from './AddModelPanel/ScanFolder/ScanFolderForm';
|
|||||||
|
|
||||||
export const $installModelsTab = atom(0);
|
export const $installModelsTab = atom(0);
|
||||||
|
|
||||||
export const InstallModels = () => {
|
export const InstallModels = memo(() => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const index = useStore($installModelsTab);
|
const index = useStore($installModelsTab);
|
||||||
const onChange = useCallback((index: number) => {
|
const onChange = useCallback((index: number) => {
|
||||||
@ -49,4 +49,6 @@ export const InstallModels = () => {
|
|||||||
</Box>
|
</Box>
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
};
|
});
|
||||||
|
|
||||||
|
InstallModels.displayName = 'InstallModels';
|
||||||
|
@ -1,14 +1,14 @@
|
|||||||
import { Button, Flex, Heading } from '@invoke-ai/ui-library';
|
import { Button, Flex, Heading } from '@invoke-ai/ui-library';
|
||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
import { setSelectedModelKey } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
import { setSelectedModelKey } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||||
import { useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { PiPlusBold } from 'react-icons/pi';
|
import { PiPlusBold } from 'react-icons/pi';
|
||||||
|
|
||||||
import ModelList from './ModelManagerPanel/ModelList';
|
import ModelList from './ModelManagerPanel/ModelList';
|
||||||
import { ModelListNavigation } from './ModelManagerPanel/ModelListNavigation';
|
import { ModelListNavigation } from './ModelManagerPanel/ModelListNavigation';
|
||||||
|
|
||||||
export const ModelManager = () => {
|
export const ModelManager = memo(() => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const handleClickAddModel = useCallback(() => {
|
const handleClickAddModel = useCallback(() => {
|
||||||
@ -29,4 +29,6 @@ export const ModelManager = () => {
|
|||||||
</Flex>
|
</Flex>
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
};
|
});
|
||||||
|
|
||||||
|
ModelManager.displayName = 'ModelManager';
|
||||||
|
@ -21,7 +21,8 @@ import { FetchingModelsLoader } from './FetchingModelsLoader';
|
|||||||
import { ModelListWrapper } from './ModelListWrapper';
|
import { ModelListWrapper } from './ModelListWrapper';
|
||||||
|
|
||||||
const ModelList = () => {
|
const ModelList = () => {
|
||||||
const { searchTerm, filteredModelType } = useAppSelector((s) => s.modelmanagerV2);
|
const filteredModelType = useAppSelector((s) => s.modelmanagerV2.filteredModelType);
|
||||||
|
const searchTerm = useAppSelector((s) => s.modelmanagerV2.searchTerm);
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const [mainModels, { isLoading: isLoadingMainModels }] = useMainModels();
|
const [mainModels, { isLoading: isLoadingMainModels }] = useMainModels();
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||||
import { ConfirmationAlertDialog, Flex, IconButton, Spacer, Text, useDisclosure } from '@invoke-ai/ui-library';
|
import { ConfirmationAlertDialog, Flex, IconButton, Spacer, Text, useDisclosure } from '@invoke-ai/ui-library';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { setSelectedModelKey } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
import { selectModelManagerV2Slice, setSelectedModelKey } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||||
import ModelBaseBadge from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge';
|
import ModelBaseBadge from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge';
|
||||||
import ModelFormatBadge from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelFormatBadge';
|
import ModelFormatBadge from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelFormatBadge';
|
||||||
import { toast } from 'features/toast/toast';
|
import { toast } from 'features/toast/toast';
|
||||||
@ -23,15 +24,21 @@ const sx: SystemStyleObject = {
|
|||||||
"&[aria-selected='true']": { bg: 'base.700' },
|
"&[aria-selected='true']": { bg: 'base.700' },
|
||||||
};
|
};
|
||||||
|
|
||||||
const ModelListItem = (props: ModelListItemProps) => {
|
const ModelListItem = ({ model }: ModelListItemProps) => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
const selectIsSelected = useMemo(
|
||||||
|
() =>
|
||||||
|
createSelector(
|
||||||
|
selectModelManagerV2Slice,
|
||||||
|
(modelManagerV2Slice) => modelManagerV2Slice.selectedModelKey === model.key
|
||||||
|
),
|
||||||
|
[model.key]
|
||||||
|
);
|
||||||
|
const isSelected = useAppSelector(selectIsSelected);
|
||||||
const [deleteModel] = useDeleteModelsMutation();
|
const [deleteModel] = useDeleteModelsMutation();
|
||||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||||
|
|
||||||
const { model } = props;
|
|
||||||
|
|
||||||
const handleSelectModel = useCallback(() => {
|
const handleSelectModel = useCallback(() => {
|
||||||
dispatch(setSelectedModelKey(model.key));
|
dispatch(setSelectedModelKey(model.key));
|
||||||
}, [model.key, dispatch]);
|
}, [model.key, dispatch]);
|
||||||
@ -43,11 +50,6 @@ const ModelListItem = (props: ModelListItemProps) => {
|
|||||||
},
|
},
|
||||||
[onOpen]
|
[onOpen]
|
||||||
);
|
);
|
||||||
|
|
||||||
const isSelected = useMemo(() => {
|
|
||||||
return selectedModelKey === model.key;
|
|
||||||
}, [selectedModelKey, model.key]);
|
|
||||||
|
|
||||||
const handleModelDelete = useCallback(() => {
|
const handleModelDelete = useCallback(() => {
|
||||||
deleteModel({ key: model.key })
|
deleteModel({ key: model.key })
|
||||||
.unwrap()
|
.unwrap()
|
||||||
|
@ -3,12 +3,12 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|||||||
import { setSearchTerm } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
import { setSearchTerm } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||||
import { t } from 'i18next';
|
import { t } from 'i18next';
|
||||||
import type { ChangeEventHandler } from 'react';
|
import type { ChangeEventHandler } from 'react';
|
||||||
import { useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { PiXBold } from 'react-icons/pi';
|
import { PiXBold } from 'react-icons/pi';
|
||||||
|
|
||||||
import { ModelTypeFilter } from './ModelTypeFilter';
|
import { ModelTypeFilter } from './ModelTypeFilter';
|
||||||
|
|
||||||
export const ModelListNavigation = () => {
|
export const ModelListNavigation = memo(() => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const searchTerm = useAppSelector((s) => s.modelmanagerV2.searchTerm);
|
const searchTerm = useAppSelector((s) => s.modelmanagerV2.searchTerm);
|
||||||
|
|
||||||
@ -49,4 +49,6 @@ export const ModelListNavigation = () => {
|
|||||||
</InputGroup>
|
</InputGroup>
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
};
|
});
|
||||||
|
|
||||||
|
ModelListNavigation.displayName = 'ModelListNavigation';
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import { StickyScrollable } from 'features/system/components/StickyScrollable';
|
import { StickyScrollable } from 'features/system/components/StickyScrollable';
|
||||||
|
import { memo } from 'react';
|
||||||
import type { AnyModelConfig } from 'services/api/types';
|
import type { AnyModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
import ModelListItem from './ModelListItem';
|
import ModelListItem from './ModelListItem';
|
||||||
@ -8,7 +9,7 @@ type ModelListWrapperProps = {
|
|||||||
modelList: AnyModelConfig[];
|
modelList: AnyModelConfig[];
|
||||||
};
|
};
|
||||||
|
|
||||||
export const ModelListWrapper = (props: ModelListWrapperProps) => {
|
export const ModelListWrapper = memo((props: ModelListWrapperProps) => {
|
||||||
const { title, modelList } = props;
|
const { title, modelList } = props;
|
||||||
return (
|
return (
|
||||||
<StickyScrollable title={title} contentSx={{ gap: 1, p: 2 }}>
|
<StickyScrollable title={title} contentSx={{ gap: 1, p: 2 }}>
|
||||||
@ -17,4 +18,6 @@ export const ModelListWrapper = (props: ModelListWrapperProps) => {
|
|||||||
))}
|
))}
|
||||||
</StickyScrollable>
|
</StickyScrollable>
|
||||||
);
|
);
|
||||||
};
|
});
|
||||||
|
|
||||||
|
ModelListWrapper.displayName = 'ModelListWrapper';
|
||||||
|
@ -2,12 +2,12 @@ import { Button, Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-libr
|
|||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import type { FilterableModelType } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
import type { FilterableModelType } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||||
import { setFilteredModelType } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
import { setFilteredModelType } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||||
import { useCallback, useMemo } from 'react';
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { PiFunnelBold } from 'react-icons/pi';
|
import { PiFunnelBold } from 'react-icons/pi';
|
||||||
import { objectKeys } from 'tsafe';
|
import { objectKeys } from 'tsafe';
|
||||||
|
|
||||||
export const ModelTypeFilter = () => {
|
export const ModelTypeFilter = memo(() => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const MODEL_TYPE_LABELS: Record<FilterableModelType, string> = useMemo(
|
const MODEL_TYPE_LABELS: Record<FilterableModelType, string> = useMemo(
|
||||||
@ -57,4 +57,6 @@ export const ModelTypeFilter = () => {
|
|||||||
</MenuList>
|
</MenuList>
|
||||||
</Menu>
|
</Menu>
|
||||||
);
|
);
|
||||||
};
|
});
|
||||||
|
|
||||||
|
ModelTypeFilter.displayName = 'ModelTypeFilter';
|
||||||
|
@ -1,14 +1,17 @@
|
|||||||
import { Box } from '@invoke-ai/ui-library';
|
import { Box } from '@invoke-ai/ui-library';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { memo } from 'react';
|
||||||
|
|
||||||
import { InstallModels } from './InstallModels';
|
import { InstallModels } from './InstallModels';
|
||||||
import { Model } from './ModelPanel/Model';
|
import { Model } from './ModelPanel/Model';
|
||||||
|
|
||||||
export const ModelPane = () => {
|
export const ModelPane = memo(() => {
|
||||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||||
return (
|
return (
|
||||||
<Box layerStyle="first" p={4} borderRadius="base" w="50%" h="full">
|
<Box layerStyle="first" p={4} borderRadius="base" w="50%" h="full">
|
||||||
{selectedModelKey ? <Model key={selectedModelKey} /> : <InstallModels />}
|
{selectedModelKey ? <Model key={selectedModelKey} /> : <InstallModels />}
|
||||||
</Box>
|
</Box>
|
||||||
);
|
);
|
||||||
};
|
});
|
||||||
|
|
||||||
|
ModelPane.displayName = 'ModelPane';
|
||||||
|
@ -1,26 +1,28 @@
|
|||||||
import { Button, Flex, Heading, SimpleGrid, Text } from '@invoke-ai/ui-library';
|
import { Button, Flex, Heading, SimpleGrid } from '@invoke-ai/ui-library';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import { useControlNetOrT2IAdapterDefaultSettings } from 'features/modelManagerV2/hooks/useControlNetOrT2IAdapterDefaultSettings';
|
import { useControlNetOrT2IAdapterDefaultSettings } from 'features/modelManagerV2/hooks/useControlNetOrT2IAdapterDefaultSettings';
|
||||||
import { DefaultPreprocessor } from 'features/modelManagerV2/subpanels/ModelPanel/ControlNetOrT2IAdapterDefaultSettings/DefaultPreprocessor';
|
import { DefaultPreprocessor } from 'features/modelManagerV2/subpanels/ModelPanel/ControlNetOrT2IAdapterDefaultSettings/DefaultPreprocessor';
|
||||||
import type { FormField } from 'features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/MainModelDefaultSettings';
|
import type { FormField } from 'features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/MainModelDefaultSettings';
|
||||||
import { toast } from 'features/toast/toast';
|
import { toast } from 'features/toast/toast';
|
||||||
import { useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import type { SubmitHandler } from 'react-hook-form';
|
import type { SubmitHandler } from 'react-hook-form';
|
||||||
import { useForm } from 'react-hook-form';
|
import { useForm } from 'react-hook-form';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { PiCheckBold } from 'react-icons/pi';
|
import { PiCheckBold } from 'react-icons/pi';
|
||||||
import { useUpdateModelMutation } from 'services/api/endpoints/models';
|
import { useUpdateModelMutation } from 'services/api/endpoints/models';
|
||||||
|
import type { ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
export type ControlNetOrT2IAdapterDefaultSettingsFormData = {
|
export type ControlNetOrT2IAdapterDefaultSettingsFormData = {
|
||||||
preprocessor: FormField<string>;
|
preprocessor: FormField<string>;
|
||||||
};
|
};
|
||||||
|
|
||||||
export const ControlNetOrT2IAdapterDefaultSettings = () => {
|
type Props = {
|
||||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
modelConfig: ControlNetModelConfig | T2IAdapterModelConfig;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const ControlNetOrT2IAdapterDefaultSettings = memo(({ modelConfig }: Props) => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const { defaultSettingsDefaults, isLoading: isLoadingDefaultSettings } =
|
const defaultSettingsDefaults = useControlNetOrT2IAdapterDefaultSettings(modelConfig);
|
||||||
useControlNetOrT2IAdapterDefaultSettings(selectedModelKey);
|
|
||||||
|
|
||||||
const [updateModel, { isLoading: isLoadingUpdateModel }] = useUpdateModelMutation();
|
const [updateModel, { isLoading: isLoadingUpdateModel }] = useUpdateModelMutation();
|
||||||
|
|
||||||
@ -30,16 +32,12 @@ export const ControlNetOrT2IAdapterDefaultSettings = () => {
|
|||||||
|
|
||||||
const onSubmit = useCallback<SubmitHandler<ControlNetOrT2IAdapterDefaultSettingsFormData>>(
|
const onSubmit = useCallback<SubmitHandler<ControlNetOrT2IAdapterDefaultSettingsFormData>>(
|
||||||
(data) => {
|
(data) => {
|
||||||
if (!selectedModelKey) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const body = {
|
const body = {
|
||||||
preprocessor: data.preprocessor.isEnabled ? data.preprocessor.value : null,
|
preprocessor: data.preprocessor.isEnabled ? data.preprocessor.value : null,
|
||||||
};
|
};
|
||||||
|
|
||||||
updateModel({
|
updateModel({
|
||||||
key: selectedModelKey,
|
key: modelConfig.key,
|
||||||
body: { default_settings: body },
|
body: { default_settings: body },
|
||||||
})
|
})
|
||||||
.unwrap()
|
.unwrap()
|
||||||
@ -61,13 +59,9 @@ export const ControlNetOrT2IAdapterDefaultSettings = () => {
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
},
|
},
|
||||||
[selectedModelKey, reset, updateModel, t]
|
[updateModel, modelConfig.key, t, reset]
|
||||||
);
|
);
|
||||||
|
|
||||||
if (isLoadingDefaultSettings) {
|
|
||||||
return <Text>{t('common.loading')}</Text>;
|
|
||||||
}
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
<Flex gap="4" justifyContent="space-between" w="full" pb={4}>
|
<Flex gap="4" justifyContent="space-between" w="full" pb={4}>
|
||||||
@ -89,4 +83,6 @@ export const ControlNetOrT2IAdapterDefaultSettings = () => {
|
|||||||
</SimpleGrid>
|
</SimpleGrid>
|
||||||
</>
|
</>
|
||||||
);
|
);
|
||||||
};
|
});
|
||||||
|
|
||||||
|
ControlNetOrT2IAdapterDefaultSettings.displayName = 'ControlNetOrT2IAdapterDefaultSettings';
|
||||||
|
@ -4,7 +4,7 @@ import { InformationalPopover } from 'common/components/InformationalPopover/Inf
|
|||||||
import type { ControlNetOrT2IAdapterDefaultSettingsFormData } from 'features/modelManagerV2/subpanels/ModelPanel/ControlNetOrT2IAdapterDefaultSettings/ControlNetOrT2IAdapterDefaultSettings';
|
import type { ControlNetOrT2IAdapterDefaultSettingsFormData } from 'features/modelManagerV2/subpanels/ModelPanel/ControlNetOrT2IAdapterDefaultSettings/ControlNetOrT2IAdapterDefaultSettings';
|
||||||
import type { FormField } from 'features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/MainModelDefaultSettings';
|
import type { FormField } from 'features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/MainModelDefaultSettings';
|
||||||
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
|
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
|
||||||
import { useCallback, useMemo } from 'react';
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
import type { UseControllerProps } from 'react-hook-form';
|
import type { UseControllerProps } from 'react-hook-form';
|
||||||
import { useController } from 'react-hook-form';
|
import { useController } from 'react-hook-form';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
@ -28,7 +28,7 @@ const OPTIONS = [
|
|||||||
|
|
||||||
type DefaultSchedulerType = ControlNetOrT2IAdapterDefaultSettingsFormData['preprocessor'];
|
type DefaultSchedulerType = ControlNetOrT2IAdapterDefaultSettingsFormData['preprocessor'];
|
||||||
|
|
||||||
export function DefaultPreprocessor(props: UseControllerProps<ControlNetOrT2IAdapterDefaultSettingsFormData>) {
|
export const DefaultPreprocessor = memo((props: UseControllerProps<ControlNetOrT2IAdapterDefaultSettingsFormData>) => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const { field } = useController(props);
|
const { field } = useController(props);
|
||||||
|
|
||||||
@ -63,4 +63,6 @@ export function DefaultPreprocessor(props: UseControllerProps<ControlNetOrT2IAda
|
|||||||
<Combobox isDisabled={isDisabled} value={value} options={OPTIONS} onChange={onChange} />
|
<Combobox isDisabled={isDisabled} value={value} options={OPTIONS} onChange={onChange} />
|
||||||
</FormControl>
|
</FormControl>
|
||||||
);
|
);
|
||||||
}
|
});
|
||||||
|
|
||||||
|
DefaultPreprocessor.displayName = 'DefaultPreprocessor';
|
||||||
|
@ -2,7 +2,7 @@ import { CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } f
|
|||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||||
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
|
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
|
||||||
import { useCallback, useMemo } from 'react';
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
import type { UseControllerProps } from 'react-hook-form';
|
import type { UseControllerProps } from 'react-hook-form';
|
||||||
import { useController } from 'react-hook-form';
|
import { useController } from 'react-hook-form';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
@ -11,7 +11,7 @@ import type { MainModelDefaultSettingsFormData } from './MainModelDefaultSetting
|
|||||||
|
|
||||||
type DefaultCfgRescaleMultiplierType = MainModelDefaultSettingsFormData['cfgRescaleMultiplier'];
|
type DefaultCfgRescaleMultiplierType = MainModelDefaultSettingsFormData['cfgRescaleMultiplier'];
|
||||||
|
|
||||||
export function DefaultCfgRescaleMultiplier(props: UseControllerProps<MainModelDefaultSettingsFormData>) {
|
export const DefaultCfgRescaleMultiplier = memo((props: UseControllerProps<MainModelDefaultSettingsFormData>) => {
|
||||||
const { field } = useController(props);
|
const { field } = useController(props);
|
||||||
|
|
||||||
const sliderMin = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.sliderMin);
|
const sliderMin = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.sliderMin);
|
||||||
@ -74,4 +74,6 @@ export function DefaultCfgRescaleMultiplier(props: UseControllerProps<MainModelD
|
|||||||
</Flex>
|
</Flex>
|
||||||
</FormControl>
|
</FormControl>
|
||||||
);
|
);
|
||||||
}
|
});
|
||||||
|
|
||||||
|
DefaultCfgRescaleMultiplier.displayName = 'DefaultCfgRescaleMultiplier';
|
||||||
|
@ -2,7 +2,7 @@ import { CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } f
|
|||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||||
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
|
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
|
||||||
import { useCallback, useMemo } from 'react';
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
import type { UseControllerProps } from 'react-hook-form';
|
import type { UseControllerProps } from 'react-hook-form';
|
||||||
import { useController } from 'react-hook-form';
|
import { useController } from 'react-hook-form';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
@ -11,7 +11,7 @@ import type { MainModelDefaultSettingsFormData } from './MainModelDefaultSetting
|
|||||||
|
|
||||||
type DefaultCfgType = MainModelDefaultSettingsFormData['cfgScale'];
|
type DefaultCfgType = MainModelDefaultSettingsFormData['cfgScale'];
|
||||||
|
|
||||||
export function DefaultCfgScale(props: UseControllerProps<MainModelDefaultSettingsFormData>) {
|
export const DefaultCfgScale = memo((props: UseControllerProps<MainModelDefaultSettingsFormData>) => {
|
||||||
const { field } = useController(props);
|
const { field } = useController(props);
|
||||||
|
|
||||||
const sliderMin = useAppSelector((s) => s.config.sd.guidance.sliderMin);
|
const sliderMin = useAppSelector((s) => s.config.sd.guidance.sliderMin);
|
||||||
@ -74,4 +74,6 @@ export function DefaultCfgScale(props: UseControllerProps<MainModelDefaultSettin
|
|||||||
</Flex>
|
</Flex>
|
||||||
</FormControl>
|
</FormControl>
|
||||||
);
|
);
|
||||||
}
|
});
|
||||||
|
|
||||||
|
DefaultCfgScale.displayName = 'DefaultCfgScale';
|
||||||
|
@ -2,7 +2,7 @@ import { CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } f
|
|||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||||
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
|
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
|
||||||
import { useCallback, useMemo } from 'react';
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
import type { UseControllerProps } from 'react-hook-form';
|
import type { UseControllerProps } from 'react-hook-form';
|
||||||
import { useController } from 'react-hook-form';
|
import { useController } from 'react-hook-form';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
@ -16,7 +16,7 @@ type Props = {
|
|||||||
optimalDimension: number;
|
optimalDimension: number;
|
||||||
};
|
};
|
||||||
|
|
||||||
export function DefaultHeight({ control, optimalDimension }: Props) {
|
export const DefaultHeight = memo(({ control, optimalDimension }: Props) => {
|
||||||
const { field } = useController({ control, name: 'height' });
|
const { field } = useController({ control, name: 'height' });
|
||||||
const sliderMin = useAppSelector((s) => s.config.sd.height.sliderMin);
|
const sliderMin = useAppSelector((s) => s.config.sd.height.sliderMin);
|
||||||
const sliderMax = useAppSelector((s) => s.config.sd.height.sliderMax);
|
const sliderMax = useAppSelector((s) => s.config.sd.height.sliderMax);
|
||||||
@ -78,4 +78,6 @@ export function DefaultHeight({ control, optimalDimension }: Props) {
|
|||||||
</Flex>
|
</Flex>
|
||||||
</FormControl>
|
</FormControl>
|
||||||
);
|
);
|
||||||
}
|
});
|
||||||
|
|
||||||
|
DefaultHeight.displayName = 'DefaultHeight';
|
||||||
|
@ -4,7 +4,7 @@ import { InformationalPopover } from 'common/components/InformationalPopover/Inf
|
|||||||
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
|
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
|
||||||
import { SCHEDULER_OPTIONS } from 'features/parameters/types/constants';
|
import { SCHEDULER_OPTIONS } from 'features/parameters/types/constants';
|
||||||
import { isParameterScheduler } from 'features/parameters/types/parameterSchemas';
|
import { isParameterScheduler } from 'features/parameters/types/parameterSchemas';
|
||||||
import { useCallback, useMemo } from 'react';
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
import type { UseControllerProps } from 'react-hook-form';
|
import type { UseControllerProps } from 'react-hook-form';
|
||||||
import { useController } from 'react-hook-form';
|
import { useController } from 'react-hook-form';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
@ -13,7 +13,7 @@ import type { MainModelDefaultSettingsFormData } from './MainModelDefaultSetting
|
|||||||
|
|
||||||
type DefaultSchedulerType = MainModelDefaultSettingsFormData['scheduler'];
|
type DefaultSchedulerType = MainModelDefaultSettingsFormData['scheduler'];
|
||||||
|
|
||||||
export function DefaultScheduler(props: UseControllerProps<MainModelDefaultSettingsFormData>) {
|
export const DefaultScheduler = memo((props: UseControllerProps<MainModelDefaultSettingsFormData>) => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const { field } = useController(props);
|
const { field } = useController(props);
|
||||||
|
|
||||||
@ -51,4 +51,6 @@ export function DefaultScheduler(props: UseControllerProps<MainModelDefaultSetti
|
|||||||
<Combobox isDisabled={isDisabled} value={value} options={SCHEDULER_OPTIONS} onChange={onChange} />
|
<Combobox isDisabled={isDisabled} value={value} options={SCHEDULER_OPTIONS} onChange={onChange} />
|
||||||
</FormControl>
|
</FormControl>
|
||||||
);
|
);
|
||||||
}
|
});
|
||||||
|
|
||||||
|
DefaultScheduler.displayName = 'DefaultScheduler';
|
||||||
|
@ -2,7 +2,7 @@ import { CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } f
|
|||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||||
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
|
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
|
||||||
import { useCallback, useMemo } from 'react';
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
import type { UseControllerProps } from 'react-hook-form';
|
import type { UseControllerProps } from 'react-hook-form';
|
||||||
import { useController } from 'react-hook-form';
|
import { useController } from 'react-hook-form';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
@ -11,7 +11,7 @@ import type { MainModelDefaultSettingsFormData } from './MainModelDefaultSetting
|
|||||||
|
|
||||||
type DefaultSteps = MainModelDefaultSettingsFormData['steps'];
|
type DefaultSteps = MainModelDefaultSettingsFormData['steps'];
|
||||||
|
|
||||||
export function DefaultSteps(props: UseControllerProps<MainModelDefaultSettingsFormData>) {
|
export const DefaultSteps = memo((props: UseControllerProps<MainModelDefaultSettingsFormData>) => {
|
||||||
const { field } = useController(props);
|
const { field } = useController(props);
|
||||||
|
|
||||||
const sliderMin = useAppSelector((s) => s.config.sd.steps.sliderMin);
|
const sliderMin = useAppSelector((s) => s.config.sd.steps.sliderMin);
|
||||||
@ -74,4 +74,6 @@ export function DefaultSteps(props: UseControllerProps<MainModelDefaultSettingsF
|
|||||||
</Flex>
|
</Flex>
|
||||||
</FormControl>
|
</FormControl>
|
||||||
);
|
);
|
||||||
}
|
});
|
||||||
|
|
||||||
|
DefaultSteps.displayName = 'DefaultSteps';
|
||||||
|
@ -4,7 +4,7 @@ import { skipToken } from '@reduxjs/toolkit/query';
|
|||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||||
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
|
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
|
||||||
import { useCallback, useMemo } from 'react';
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
import type { UseControllerProps } from 'react-hook-form';
|
import type { UseControllerProps } from 'react-hook-form';
|
||||||
import { useController } from 'react-hook-form';
|
import { useController } from 'react-hook-form';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
@ -15,7 +15,7 @@ import type { MainModelDefaultSettingsFormData } from './MainModelDefaultSetting
|
|||||||
|
|
||||||
type DefaultVaeType = MainModelDefaultSettingsFormData['vae'];
|
type DefaultVaeType = MainModelDefaultSettingsFormData['vae'];
|
||||||
|
|
||||||
export function DefaultVae(props: UseControllerProps<MainModelDefaultSettingsFormData>) {
|
export const DefaultVae = memo((props: UseControllerProps<MainModelDefaultSettingsFormData>) => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const { field } = useController(props);
|
const { field } = useController(props);
|
||||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||||
@ -64,4 +64,6 @@ export function DefaultVae(props: UseControllerProps<MainModelDefaultSettingsFor
|
|||||||
<Combobox isDisabled={isDisabled} value={value} options={compatibleOptions} onChange={onChange} />
|
<Combobox isDisabled={isDisabled} value={value} options={compatibleOptions} onChange={onChange} />
|
||||||
</FormControl>
|
</FormControl>
|
||||||
);
|
);
|
||||||
}
|
});
|
||||||
|
|
||||||
|
DefaultVae.displayName = 'DefaultVae';
|
||||||
|
@ -3,7 +3,7 @@ import { Combobox, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
|||||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||||
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
|
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
|
||||||
import { isParameterPrecision } from 'features/parameters/types/parameterSchemas';
|
import { isParameterPrecision } from 'features/parameters/types/parameterSchemas';
|
||||||
import { useCallback, useMemo } from 'react';
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
import type { UseControllerProps } from 'react-hook-form';
|
import type { UseControllerProps } from 'react-hook-form';
|
||||||
import { useController } from 'react-hook-form';
|
import { useController } from 'react-hook-form';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
@ -17,7 +17,7 @@ const options = [
|
|||||||
|
|
||||||
type DefaultVaePrecisionType = MainModelDefaultSettingsFormData['vaePrecision'];
|
type DefaultVaePrecisionType = MainModelDefaultSettingsFormData['vaePrecision'];
|
||||||
|
|
||||||
export function DefaultVaePrecision(props: UseControllerProps<MainModelDefaultSettingsFormData>) {
|
export const DefaultVaePrecision = memo((props: UseControllerProps<MainModelDefaultSettingsFormData>) => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const { field } = useController(props);
|
const { field } = useController(props);
|
||||||
|
|
||||||
@ -52,4 +52,6 @@ export function DefaultVaePrecision(props: UseControllerProps<MainModelDefaultSe
|
|||||||
<Combobox isDisabled={isDisabled} value={value} options={options} onChange={onChange} />
|
<Combobox isDisabled={isDisabled} value={value} options={options} onChange={onChange} />
|
||||||
</FormControl>
|
</FormControl>
|
||||||
);
|
);
|
||||||
}
|
});
|
||||||
|
|
||||||
|
DefaultVaePrecision.displayName = 'DefaultVaePrecision';
|
||||||
|
@ -2,7 +2,7 @@ import { CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } f
|
|||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||||
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
|
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
|
||||||
import { useCallback, useMemo } from 'react';
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
import type { UseControllerProps } from 'react-hook-form';
|
import type { UseControllerProps } from 'react-hook-form';
|
||||||
import { useController } from 'react-hook-form';
|
import { useController } from 'react-hook-form';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
@ -16,7 +16,7 @@ type Props = {
|
|||||||
optimalDimension: number;
|
optimalDimension: number;
|
||||||
};
|
};
|
||||||
|
|
||||||
export function DefaultWidth({ control, optimalDimension }: Props) {
|
export const DefaultWidth = memo(({ control, optimalDimension }: Props) => {
|
||||||
const { field } = useController({ control, name: 'width' });
|
const { field } = useController({ control, name: 'width' });
|
||||||
const sliderMin = useAppSelector((s) => s.config.sd.width.sliderMin);
|
const sliderMin = useAppSelector((s) => s.config.sd.width.sliderMin);
|
||||||
const sliderMax = useAppSelector((s) => s.config.sd.width.sliderMax);
|
const sliderMax = useAppSelector((s) => s.config.sd.width.sliderMax);
|
||||||
@ -78,4 +78,6 @@ export function DefaultWidth({ control, optimalDimension }: Props) {
|
|||||||
</Flex>
|
</Flex>
|
||||||
</FormControl>
|
</FormControl>
|
||||||
);
|
);
|
||||||
}
|
});
|
||||||
|
|
||||||
|
DefaultWidth.displayName = 'DefaultWidth';
|
||||||
|
@ -1,16 +1,18 @@
|
|||||||
import { Button, Flex, Heading, SimpleGrid, Text } from '@invoke-ai/ui-library';
|
import { Button, Flex, Heading, SimpleGrid } from '@invoke-ai/ui-library';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { useMainModelDefaultSettings } from 'features/modelManagerV2/hooks/useMainModelDefaultSettings';
|
import { useMainModelDefaultSettings } from 'features/modelManagerV2/hooks/useMainModelDefaultSettings';
|
||||||
import { DefaultHeight } from 'features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/DefaultHeight';
|
import { DefaultHeight } from 'features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/DefaultHeight';
|
||||||
import { DefaultWidth } from 'features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/DefaultWidth';
|
import { DefaultWidth } from 'features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/DefaultWidth';
|
||||||
import type { ParameterScheduler } from 'features/parameters/types/parameterSchemas';
|
import type { ParameterScheduler } from 'features/parameters/types/parameterSchemas';
|
||||||
|
import { getOptimalDimension } from 'features/parameters/util/optimalDimension';
|
||||||
import { toast } from 'features/toast/toast';
|
import { toast } from 'features/toast/toast';
|
||||||
import { useCallback } from 'react';
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
import type { SubmitHandler } from 'react-hook-form';
|
import type { SubmitHandler } from 'react-hook-form';
|
||||||
import { useForm } from 'react-hook-form';
|
import { useForm } from 'react-hook-form';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { PiCheckBold } from 'react-icons/pi';
|
import { PiCheckBold } from 'react-icons/pi';
|
||||||
import { useUpdateModelMutation } from 'services/api/endpoints/models';
|
import { useUpdateModelMutation } from 'services/api/endpoints/models';
|
||||||
|
import type { MainModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
import { DefaultCfgRescaleMultiplier } from './DefaultCfgRescaleMultiplier';
|
import { DefaultCfgRescaleMultiplier } from './DefaultCfgRescaleMultiplier';
|
||||||
import { DefaultCfgScale } from './DefaultCfgScale';
|
import { DefaultCfgScale } from './DefaultCfgScale';
|
||||||
@ -35,16 +37,16 @@ export type MainModelDefaultSettingsFormData = {
|
|||||||
height: FormField<number>;
|
height: FormField<number>;
|
||||||
};
|
};
|
||||||
|
|
||||||
export const MainModelDefaultSettings = () => {
|
type Props = {
|
||||||
|
modelConfig: MainModelConfig;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const MainModelDefaultSettings = memo(({ modelConfig }: Props) => {
|
||||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const {
|
const defaultSettingsDefaults = useMainModelDefaultSettings(modelConfig);
|
||||||
defaultSettingsDefaults,
|
const optimalDimension = useMemo(() => getOptimalDimension(modelConfig), [modelConfig]);
|
||||||
isLoading: isLoadingDefaultSettings,
|
|
||||||
optimalDimension,
|
|
||||||
} = useMainModelDefaultSettings(selectedModelKey);
|
|
||||||
|
|
||||||
const [updateModel, { isLoading: isLoadingUpdateModel }] = useUpdateModelMutation();
|
const [updateModel, { isLoading: isLoadingUpdateModel }] = useUpdateModelMutation();
|
||||||
|
|
||||||
const { handleSubmit, control, formState, reset } = useForm<MainModelDefaultSettingsFormData>({
|
const { handleSubmit, control, formState, reset } = useForm<MainModelDefaultSettingsFormData>({
|
||||||
@ -94,10 +96,6 @@ export const MainModelDefaultSettings = () => {
|
|||||||
[selectedModelKey, reset, updateModel, t]
|
[selectedModelKey, reset, updateModel, t]
|
||||||
);
|
);
|
||||||
|
|
||||||
if (isLoadingDefaultSettings) {
|
|
||||||
return <Text>{t('common.loading')}</Text>;
|
|
||||||
}
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
<Flex gap="4" justifyContent="space-between" w="full" pb={4}>
|
<Flex gap="4" justifyContent="space-between" w="full" pb={4}>
|
||||||
@ -126,4 +124,6 @@ export const MainModelDefaultSettings = () => {
|
|||||||
</SimpleGrid>
|
</SimpleGrid>
|
||||||
</>
|
</>
|
||||||
);
|
);
|
||||||
};
|
});
|
||||||
|
|
||||||
|
MainModelDefaultSettings.displayName = 'MainModelDefaultSettings';
|
||||||
|
@ -1,120 +1,47 @@
|
|||||||
import { Button, Flex, Heading, Spacer, Text } from '@invoke-ai/ui-library';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { skipToken } from '@reduxjs/toolkit/query';
|
import { IAINoContentFallback, IAINoContentFallbackWithSpinner } from 'common/components/IAIImageFallback';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { memo, useMemo } from 'react';
|
||||||
import { setSelectedModelMode } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
|
||||||
import { ModelConvertButton } from 'features/modelManagerV2/subpanels/ModelPanel/ModelConvertButton';
|
|
||||||
import { ModelEditButton } from 'features/modelManagerV2/subpanels/ModelPanel/ModelEditButton';
|
|
||||||
import { toast } from 'features/toast/toast';
|
|
||||||
import { useCallback } from 'react';
|
|
||||||
import type { SubmitHandler } from 'react-hook-form';
|
|
||||||
import { useForm } from 'react-hook-form';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { PiCheckBold, PiXBold } from 'react-icons/pi';
|
import { PiExclamationMarkBold } from 'react-icons/pi';
|
||||||
import type { UpdateModelArg } from 'services/api/endpoints/models';
|
import { modelConfigsAdapterSelectors, useGetModelConfigsQuery } from 'services/api/endpoints/models';
|
||||||
import { useGetModelConfigQuery, useUpdateModelMutation } from 'services/api/endpoints/models';
|
|
||||||
|
|
||||||
import ModelImageUpload from './Fields/ModelImageUpload';
|
|
||||||
import { ModelEdit } from './ModelEdit';
|
import { ModelEdit } from './ModelEdit';
|
||||||
import { ModelView } from './ModelView';
|
import { ModelView } from './ModelView';
|
||||||
|
|
||||||
export const Model = () => {
|
export const Model = memo(() => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const selectedModelMode = useAppSelector((s) => s.modelmanagerV2.selectedModelMode);
|
const selectedModelMode = useAppSelector((s) => s.modelmanagerV2.selectedModelMode);
|
||||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||||
const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
|
const { data: modelConfigs, isLoading } = useGetModelConfigsQuery();
|
||||||
const [updateModel, { isLoading: isSubmitting }] = useUpdateModelMutation();
|
const modelConfig = useMemo(() => {
|
||||||
const dispatch = useAppDispatch();
|
if (!modelConfigs) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
if (selectedModelKey === null) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
const modelConfig = modelConfigsAdapterSelectors.selectById(modelConfigs, selectedModelKey);
|
||||||
|
|
||||||
const form = useForm<UpdateModelArg['body']>({
|
if (!modelConfig) {
|
||||||
defaultValues: data,
|
return null;
|
||||||
mode: 'onChange',
|
|
||||||
});
|
|
||||||
|
|
||||||
const onSubmit = useCallback<SubmitHandler<UpdateModelArg['body']>>(
|
|
||||||
(values) => {
|
|
||||||
if (!data?.key) {
|
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const responseBody: UpdateModelArg = {
|
return modelConfig;
|
||||||
key: data.key,
|
}, [modelConfigs, selectedModelKey]);
|
||||||
body: values,
|
|
||||||
};
|
|
||||||
|
|
||||||
updateModel(responseBody)
|
|
||||||
.unwrap()
|
|
||||||
.then((payload) => {
|
|
||||||
form.reset(payload, { keepDefaultValues: true });
|
|
||||||
dispatch(setSelectedModelMode('view'));
|
|
||||||
toast({
|
|
||||||
id: 'MODEL_UPDATED',
|
|
||||||
title: t('modelManager.modelUpdated'),
|
|
||||||
status: 'success',
|
|
||||||
});
|
|
||||||
})
|
|
||||||
.catch((_) => {
|
|
||||||
form.reset();
|
|
||||||
toast({
|
|
||||||
id: 'MODEL_UPDATE_FAILED',
|
|
||||||
title: t('modelManager.modelUpdateFailed'),
|
|
||||||
status: 'error',
|
|
||||||
});
|
|
||||||
});
|
|
||||||
},
|
|
||||||
[dispatch, data?.key, form, t, updateModel]
|
|
||||||
);
|
|
||||||
|
|
||||||
const handleClickCancel = useCallback(() => {
|
|
||||||
dispatch(setSelectedModelMode('view'));
|
|
||||||
}, [dispatch]);
|
|
||||||
|
|
||||||
if (isLoading) {
|
if (isLoading) {
|
||||||
return <Text>{t('common.loading')}</Text>;
|
return <IAINoContentFallbackWithSpinner label={t('common.loading')} />;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!data) {
|
if (!modelConfig) {
|
||||||
return <Text>{t('common.somethingWentWrong')}</Text>;
|
return <IAINoContentFallback label={t('common.somethingWentWrong')} icon={PiExclamationMarkBold} />;
|
||||||
}
|
}
|
||||||
|
|
||||||
return (
|
if (selectedModelMode === 'view') {
|
||||||
<Flex flexDir="column" gap={4}>
|
return <ModelView modelConfig={modelConfig} />;
|
||||||
<Flex alignItems="flex-start" gap={4}>
|
}
|
||||||
<ModelImageUpload model_key={selectedModelKey} model_image={data.cover_image} />
|
|
||||||
<Flex flexDir="column" gap={1} flexGrow={1} minW={0}>
|
return <ModelEdit modelConfig={modelConfig} />;
|
||||||
<Flex gap={2}>
|
});
|
||||||
<Heading as="h2" fontSize="lg" noOfLines={1} wordBreak="break-all">
|
|
||||||
{data.name}
|
Model.displayName = 'Model';
|
||||||
</Heading>
|
|
||||||
<Spacer />
|
|
||||||
{selectedModelMode === 'view' && <ModelConvertButton modelKey={selectedModelKey} />}
|
|
||||||
{selectedModelMode === 'view' && <ModelEditButton />}
|
|
||||||
{selectedModelMode === 'edit' && (
|
|
||||||
<Button size="sm" onClick={handleClickCancel} leftIcon={<PiXBold />}>
|
|
||||||
{t('common.cancel')}
|
|
||||||
</Button>
|
|
||||||
)}
|
|
||||||
{selectedModelMode === 'edit' && (
|
|
||||||
<Button
|
|
||||||
size="sm"
|
|
||||||
colorScheme="invokeYellow"
|
|
||||||
leftIcon={<PiCheckBold />}
|
|
||||||
onClick={form.handleSubmit(onSubmit)}
|
|
||||||
isLoading={isSubmitting}
|
|
||||||
isDisabled={Boolean(Object.keys(form.formState.errors).length)}
|
|
||||||
>
|
|
||||||
{t('common.save')}
|
|
||||||
</Button>
|
|
||||||
)}
|
|
||||||
</Flex>
|
|
||||||
{data.source && (
|
|
||||||
<Text variant="subtext" noOfLines={1} wordBreak="break-all">
|
|
||||||
{t('modelManager.source')}: {data?.source}
|
|
||||||
</Text>
|
|
||||||
)}
|
|
||||||
<Text noOfLines={3}>{data.description}</Text>
|
|
||||||
</Flex>
|
|
||||||
</Flex>
|
|
||||||
{selectedModelMode === 'view' ? <ModelView /> : <ModelEdit form={form} onSubmit={onSubmit} />}
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
@ -1,11 +1,12 @@
|
|||||||
import { FormControl, FormLabel, Text } from '@invoke-ai/ui-library';
|
import { FormControl, FormLabel, Text } from '@invoke-ai/ui-library';
|
||||||
|
import { memo } from 'react';
|
||||||
|
|
||||||
interface Props {
|
interface Props {
|
||||||
label: string;
|
label: string;
|
||||||
value: string | null | undefined;
|
value: string | null | undefined;
|
||||||
}
|
}
|
||||||
|
|
||||||
export const ModelAttrView = ({ label, value }: Props) => {
|
export const ModelAttrView = memo(({ label, value }: Props) => {
|
||||||
return (
|
return (
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={0}>
|
<FormControl flexDir="column" alignItems="flex-start" gap={0}>
|
||||||
<FormLabel>{label}</FormLabel>
|
<FormLabel>{label}</FormLabel>
|
||||||
@ -14,4 +15,6 @@ export const ModelAttrView = ({ label, value }: Props) => {
|
|||||||
</Text>
|
</Text>
|
||||||
</FormControl>
|
</FormControl>
|
||||||
);
|
);
|
||||||
};
|
});
|
||||||
|
|
||||||
|
ModelAttrView.displayName = 'ModelAttrView';
|
||||||
|
@ -8,52 +8,46 @@ import {
|
|||||||
UnorderedList,
|
UnorderedList,
|
||||||
useDisclosure,
|
useDisclosure,
|
||||||
} from '@invoke-ai/ui-library';
|
} from '@invoke-ai/ui-library';
|
||||||
import { skipToken } from '@reduxjs/toolkit/query';
|
|
||||||
import { toast } from 'features/toast/toast';
|
import { toast } from 'features/toast/toast';
|
||||||
import { useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { useConvertModelMutation, useGetModelConfigQuery } from 'services/api/endpoints/models';
|
import { useConvertModelMutation } from 'services/api/endpoints/models';
|
||||||
|
import type { CheckpointModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
interface ModelConvertProps {
|
interface ModelConvertProps {
|
||||||
modelKey: string | null;
|
modelConfig: CheckpointModelConfig;
|
||||||
}
|
}
|
||||||
|
|
||||||
export const ModelConvertButton = (props: ModelConvertProps) => {
|
export const ModelConvertButton = memo(({ modelConfig }: ModelConvertProps) => {
|
||||||
const { modelKey } = props;
|
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const { data } = useGetModelConfigQuery(modelKey ?? skipToken);
|
|
||||||
const [convertModel, { isLoading }] = useConvertModelMutation();
|
const [convertModel, { isLoading }] = useConvertModelMutation();
|
||||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||||
|
|
||||||
const modelConvertHandler = useCallback(() => {
|
const modelConvertHandler = useCallback(() => {
|
||||||
if (!data || isLoading) {
|
if (!modelConfig || isLoading) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const toastId = `CONVERTING_MODEL_${data.key}`;
|
const toastId = `CONVERTING_MODEL_${modelConfig.key}`;
|
||||||
toast({
|
toast({
|
||||||
id: toastId,
|
id: toastId,
|
||||||
title: `${t('modelManager.convertingModelBegin')}: ${data?.name}`,
|
title: `${t('modelManager.convertingModelBegin')}: ${modelConfig.name}`,
|
||||||
status: 'info',
|
status: 'info',
|
||||||
});
|
});
|
||||||
|
|
||||||
convertModel(data?.key)
|
convertModel(modelConfig.key)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.then(() => {
|
.then(() => {
|
||||||
toast({ id: toastId, title: `${t('modelManager.modelConverted')}: ${data?.name}`, status: 'success' });
|
toast({ id: toastId, title: `${t('modelManager.modelConverted')}: ${modelConfig.name}`, status: 'success' });
|
||||||
})
|
})
|
||||||
.catch(() => {
|
.catch(() => {
|
||||||
toast({
|
toast({
|
||||||
id: toastId,
|
id: toastId,
|
||||||
title: `${t('modelManager.modelConversionFailed')}: ${data?.name}`,
|
title: `${t('modelManager.modelConversionFailed')}: ${modelConfig.name}`,
|
||||||
status: 'error',
|
status: 'error',
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
}, [data, isLoading, t, convertModel]);
|
}, [modelConfig, isLoading, t, convertModel]);
|
||||||
|
|
||||||
if (data?.format !== 'checkpoint') {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
@ -68,7 +62,7 @@ export const ModelConvertButton = (props: ModelConvertProps) => {
|
|||||||
🧨 {t('modelManager.convert')}
|
🧨 {t('modelManager.convert')}
|
||||||
</Button>
|
</Button>
|
||||||
<ConfirmationAlertDialog
|
<ConfirmationAlertDialog
|
||||||
title={`${t('modelManager.convert')} ${data?.name}`}
|
title={`${t('modelManager.convert')} ${modelConfig.name}`}
|
||||||
acceptCallback={modelConvertHandler}
|
acceptCallback={modelConvertHandler}
|
||||||
acceptButtonText={`${t('modelManager.convert')}`}
|
acceptButtonText={`${t('modelManager.convert')}`}
|
||||||
isOpen={isOpen}
|
isOpen={isOpen}
|
||||||
@ -96,4 +90,6 @@ export const ModelConvertButton = (props: ModelConvertProps) => {
|
|||||||
</ConfirmationAlertDialog>
|
</ConfirmationAlertDialog>
|
||||||
</>
|
</>
|
||||||
);
|
);
|
||||||
};
|
});
|
||||||
|
|
||||||
|
ModelConvertButton.displayName = 'ModelConvertButton';
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import {
|
import {
|
||||||
|
Button,
|
||||||
Checkbox,
|
Checkbox,
|
||||||
Flex,
|
Flex,
|
||||||
FormControl,
|
FormControl,
|
||||||
@ -7,47 +8,102 @@ import {
|
|||||||
Heading,
|
Heading,
|
||||||
Input,
|
Input,
|
||||||
SimpleGrid,
|
SimpleGrid,
|
||||||
Text,
|
|
||||||
Textarea,
|
Textarea,
|
||||||
} from '@invoke-ai/ui-library';
|
} from '@invoke-ai/ui-library';
|
||||||
import { skipToken } from '@reduxjs/toolkit/query';
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { setSelectedModelMode } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||||
import type { SubmitHandler, UseFormReturn } from 'react-hook-form';
|
import { ModelHeader } from 'features/modelManagerV2/subpanels/ModelPanel/ModelHeader';
|
||||||
|
import { toast } from 'features/toast/toast';
|
||||||
|
import { memo, useCallback } from 'react';
|
||||||
|
import { type SubmitHandler, useForm } from 'react-hook-form';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import type { UpdateModelArg } from 'services/api/endpoints/models';
|
import { PiCheckBold, PiXBold } from 'react-icons/pi';
|
||||||
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
|
import { type UpdateModelArg, useUpdateModelMutation } from 'services/api/endpoints/models';
|
||||||
|
import type { AnyModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
import BaseModelSelect from './Fields/BaseModelSelect';
|
import BaseModelSelect from './Fields/BaseModelSelect';
|
||||||
import ModelVariantSelect from './Fields/ModelVariantSelect';
|
import ModelVariantSelect from './Fields/ModelVariantSelect';
|
||||||
import PredictionTypeSelect from './Fields/PredictionTypeSelect';
|
import PredictionTypeSelect from './Fields/PredictionTypeSelect';
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
form: UseFormReturn<UpdateModelArg['body']>;
|
modelConfig: AnyModelConfig;
|
||||||
onSubmit: SubmitHandler<UpdateModelArg['body']>;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
const stringFieldOptions = {
|
const stringFieldOptions = {
|
||||||
validate: (value?: string | null) => (value && value.trim().length > 3) || 'Must be at least 3 characters',
|
validate: (value?: string | null) => (value && value.trim().length > 3) || 'Must be at least 3 characters',
|
||||||
};
|
};
|
||||||
|
|
||||||
export const ModelEdit = ({ form }: Props) => {
|
export const ModelEdit = memo(({ modelConfig }: Props) => {
|
||||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
|
||||||
const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
|
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
const [updateModel, { isLoading: isSubmitting }] = useUpdateModelMutation();
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
if (isLoading) {
|
const form = useForm<UpdateModelArg['body']>({
|
||||||
return <Text>{t('common.loading')}</Text>;
|
defaultValues: modelConfig,
|
||||||
}
|
mode: 'onChange',
|
||||||
|
});
|
||||||
|
|
||||||
if (!data) {
|
const onSubmit = useCallback<SubmitHandler<UpdateModelArg['body']>>(
|
||||||
return <Text>{t('common.somethingWentWrong')}</Text>;
|
(values) => {
|
||||||
}
|
const responseBody: UpdateModelArg = {
|
||||||
|
key: modelConfig.key,
|
||||||
|
body: values,
|
||||||
|
};
|
||||||
|
|
||||||
|
updateModel(responseBody)
|
||||||
|
.unwrap()
|
||||||
|
.then((payload) => {
|
||||||
|
form.reset(payload, { keepDefaultValues: true });
|
||||||
|
dispatch(setSelectedModelMode('view'));
|
||||||
|
toast({
|
||||||
|
id: 'MODEL_UPDATED',
|
||||||
|
title: t('modelManager.modelUpdated'),
|
||||||
|
status: 'success',
|
||||||
|
});
|
||||||
|
})
|
||||||
|
.catch((_) => {
|
||||||
|
form.reset();
|
||||||
|
toast({
|
||||||
|
id: 'MODEL_UPDATE_FAILED',
|
||||||
|
title: t('modelManager.modelUpdateFailed'),
|
||||||
|
status: 'error',
|
||||||
|
});
|
||||||
|
});
|
||||||
|
},
|
||||||
|
[dispatch, modelConfig.key, form, t, updateModel]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleClickCancel = useCallback(() => {
|
||||||
|
dispatch(setSelectedModelMode('view'));
|
||||||
|
}, [dispatch]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
<Flex flexDir="column" gap={4}>
|
||||||
|
<ModelHeader modelConfig={modelConfig}>
|
||||||
|
<Button flexShrink={0} size="sm" onClick={handleClickCancel} leftIcon={<PiXBold />}>
|
||||||
|
{t('common.cancel')}
|
||||||
|
</Button>
|
||||||
|
<Button
|
||||||
|
flexShrink={0}
|
||||||
|
size="sm"
|
||||||
|
colorScheme="invokeYellow"
|
||||||
|
leftIcon={<PiCheckBold />}
|
||||||
|
onClick={form.handleSubmit(onSubmit)}
|
||||||
|
isLoading={isSubmitting}
|
||||||
|
isDisabled={Boolean(Object.keys(form.formState.errors).length)}
|
||||||
|
>
|
||||||
|
{t('common.save')}
|
||||||
|
</Button>
|
||||||
|
</ModelHeader>
|
||||||
<Flex flexDir="column" h="full">
|
<Flex flexDir="column" h="full">
|
||||||
<form>
|
<form>
|
||||||
<Flex w="full" justifyContent="space-between" gap={4} alignItems="center">
|
<Flex w="full" justifyContent="space-between" gap={4} alignItems="center">
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(form.formState.errors.name)}>
|
<FormControl
|
||||||
|
flexDir="column"
|
||||||
|
alignItems="flex-start"
|
||||||
|
gap={1}
|
||||||
|
isInvalid={Boolean(form.formState.errors.name)}
|
||||||
|
>
|
||||||
<FormLabel>{t('modelManager.modelName')}</FormLabel>
|
<FormLabel>{t('modelManager.modelName')}</FormLabel>
|
||||||
<Input {...form.register('name', stringFieldOptions)} size="md" />
|
<Input {...form.register('name', stringFieldOptions)} size="md" />
|
||||||
|
|
||||||
@ -72,13 +128,13 @@ export const ModelEdit = ({ form }: Props) => {
|
|||||||
<FormLabel>{t('modelManager.baseModel')}</FormLabel>
|
<FormLabel>{t('modelManager.baseModel')}</FormLabel>
|
||||||
<BaseModelSelect control={form.control} />
|
<BaseModelSelect control={form.control} />
|
||||||
</FormControl>
|
</FormControl>
|
||||||
{data.type === 'main' && (
|
{modelConfig.type === 'main' && (
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
<FormLabel>{t('modelManager.variant')}</FormLabel>
|
<FormLabel>{t('modelManager.variant')}</FormLabel>
|
||||||
<ModelVariantSelect control={form.control} />
|
<ModelVariantSelect control={form.control} />
|
||||||
</FormControl>
|
</FormControl>
|
||||||
)}
|
)}
|
||||||
{data.type === 'main' && data.format === 'checkpoint' && (
|
{modelConfig.type === 'main' && modelConfig.format === 'checkpoint' && (
|
||||||
<>
|
<>
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
<FormLabel>{t('modelManager.pathToConfig')}</FormLabel>
|
<FormLabel>{t('modelManager.pathToConfig')}</FormLabel>
|
||||||
@ -98,5 +154,8 @@ export const ModelEdit = ({ form }: Props) => {
|
|||||||
</Flex>
|
</Flex>
|
||||||
</form>
|
</form>
|
||||||
</Flex>
|
</Flex>
|
||||||
|
</Flex>
|
||||||
);
|
);
|
||||||
};
|
});
|
||||||
|
|
||||||
|
ModelEdit.displayName = 'ModelEdit';
|
||||||
|
@ -1,11 +1,11 @@
|
|||||||
import { Button } from '@invoke-ai/ui-library';
|
import { Button } from '@invoke-ai/ui-library';
|
||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
import { setSelectedModelMode } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
import { setSelectedModelMode } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||||
import { useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { IoPencil } from 'react-icons/io5';
|
import { IoPencil } from 'react-icons/io5';
|
||||||
|
|
||||||
export const ModelEditButton = () => {
|
export const ModelEditButton = memo(() => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
@ -18,4 +18,6 @@ export const ModelEditButton = () => {
|
|||||||
{t('modelManager.edit')}
|
{t('modelManager.edit')}
|
||||||
</Button>
|
</Button>
|
||||||
);
|
);
|
||||||
};
|
});
|
||||||
|
|
||||||
|
ModelEditButton.displayName = 'ModelEditButton';
|
||||||
|
@ -0,0 +1,36 @@
|
|||||||
|
import { Flex, Heading, Spacer, Text } from '@invoke-ai/ui-library';
|
||||||
|
import ModelImageUpload from 'features/modelManagerV2/subpanels/ModelPanel/Fields/ModelImageUpload';
|
||||||
|
import type { PropsWithChildren } from 'react';
|
||||||
|
import { memo } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import type { AnyModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
|
type Props = PropsWithChildren<{
|
||||||
|
modelConfig: AnyModelConfig;
|
||||||
|
}>;
|
||||||
|
|
||||||
|
export const ModelHeader = memo(({ modelConfig, children }: Props) => {
|
||||||
|
const { t } = useTranslation();
|
||||||
|
return (
|
||||||
|
<Flex alignItems="flex-start" gap={4}>
|
||||||
|
<ModelImageUpload model_key={modelConfig.key} model_image={modelConfig.cover_image} />
|
||||||
|
<Flex flexDir="column" gap={1} flexGrow={1} minW={0}>
|
||||||
|
<Flex gap={2}>
|
||||||
|
<Heading as="h2" fontSize="lg" noOfLines={1} wordBreak="break-all">
|
||||||
|
{modelConfig.name}
|
||||||
|
</Heading>
|
||||||
|
<Spacer />
|
||||||
|
{children}
|
||||||
|
</Flex>
|
||||||
|
{modelConfig.source && (
|
||||||
|
<Text variant="subtext" noOfLines={1} wordBreak="break-all">
|
||||||
|
{t('modelManager.source')}: {modelConfig.source}
|
||||||
|
</Text>
|
||||||
|
)}
|
||||||
|
<Text noOfLines={3}>{modelConfig.description}</Text>
|
||||||
|
</Flex>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
ModelHeader.displayName = 'ModelHeader';
|
@ -1,55 +1,67 @@
|
|||||||
import { Box, Flex, SimpleGrid, Text } from '@invoke-ai/ui-library';
|
import { Box, Flex, SimpleGrid } from '@invoke-ai/ui-library';
|
||||||
import { skipToken } from '@reduxjs/toolkit/query';
|
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import { ControlNetOrT2IAdapterDefaultSettings } from 'features/modelManagerV2/subpanels/ModelPanel/ControlNetOrT2IAdapterDefaultSettings/ControlNetOrT2IAdapterDefaultSettings';
|
import { ControlNetOrT2IAdapterDefaultSettings } from 'features/modelManagerV2/subpanels/ModelPanel/ControlNetOrT2IAdapterDefaultSettings/ControlNetOrT2IAdapterDefaultSettings';
|
||||||
|
import { ModelConvertButton } from 'features/modelManagerV2/subpanels/ModelPanel/ModelConvertButton';
|
||||||
|
import { ModelEditButton } from 'features/modelManagerV2/subpanels/ModelPanel/ModelEditButton';
|
||||||
|
import { ModelHeader } from 'features/modelManagerV2/subpanels/ModelPanel/ModelHeader';
|
||||||
import { TriggerPhrases } from 'features/modelManagerV2/subpanels/ModelPanel/TriggerPhrases';
|
import { TriggerPhrases } from 'features/modelManagerV2/subpanels/ModelPanel/TriggerPhrases';
|
||||||
|
import { memo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
|
import type { AnyModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
import { MainModelDefaultSettings } from './MainModelDefaultSettings/MainModelDefaultSettings';
|
import { MainModelDefaultSettings } from './MainModelDefaultSettings/MainModelDefaultSettings';
|
||||||
import { ModelAttrView } from './ModelAttrView';
|
import { ModelAttrView } from './ModelAttrView';
|
||||||
|
|
||||||
export const ModelView = () => {
|
type Props = {
|
||||||
|
modelConfig: AnyModelConfig;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const ModelView = memo(({ modelConfig }: Props) => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
|
||||||
const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
|
|
||||||
|
|
||||||
if (isLoading) {
|
|
||||||
return <Text>{t('common.loading')}</Text>;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!data) {
|
|
||||||
return <Text>{t('common.somethingWentWrong')}</Text>;
|
|
||||||
}
|
|
||||||
return (
|
return (
|
||||||
|
<Flex flexDir="column" gap={4}>
|
||||||
|
<ModelHeader modelConfig={modelConfig}>
|
||||||
|
{modelConfig.format === 'checkpoint' && modelConfig.type === 'main' && (
|
||||||
|
<ModelConvertButton modelConfig={modelConfig} />
|
||||||
|
)}
|
||||||
|
<ModelEditButton />
|
||||||
|
</ModelHeader>
|
||||||
<Flex flexDir="column" h="full" gap={4}>
|
<Flex flexDir="column" h="full" gap={4}>
|
||||||
<Box layerStyle="second" borderRadius="base" p={4}>
|
<Box layerStyle="second" borderRadius="base" p={4}>
|
||||||
<SimpleGrid columns={2} gap={4}>
|
<SimpleGrid columns={2} gap={4}>
|
||||||
<ModelAttrView label={t('modelManager.baseModel')} value={data.base} />
|
<ModelAttrView label={t('modelManager.baseModel')} value={modelConfig.base} />
|
||||||
<ModelAttrView label={t('modelManager.modelType')} value={data.type} />
|
<ModelAttrView label={t('modelManager.modelType')} value={modelConfig.type} />
|
||||||
<ModelAttrView label={t('common.format')} value={data.format} />
|
<ModelAttrView label={t('common.format')} value={modelConfig.format} />
|
||||||
<ModelAttrView label={t('modelManager.path')} value={data.path} />
|
<ModelAttrView label={t('modelManager.path')} value={modelConfig.path} />
|
||||||
{data.type === 'main' && <ModelAttrView label={t('modelManager.variant')} value={data.variant} />}
|
{modelConfig.type === 'main' && (
|
||||||
{data.type === 'main' && data.format === 'diffusers' && data.repo_variant && (
|
<ModelAttrView label={t('modelManager.variant')} value={modelConfig.variant} />
|
||||||
<ModelAttrView label={t('modelManager.repoVariant')} value={data.repo_variant} />
|
|
||||||
)}
|
)}
|
||||||
{data.type === 'main' && data.format === 'checkpoint' && (
|
{modelConfig.type === 'main' && modelConfig.format === 'diffusers' && modelConfig.repo_variant && (
|
||||||
|
<ModelAttrView label={t('modelManager.repoVariant')} value={modelConfig.repo_variant} />
|
||||||
|
)}
|
||||||
|
{modelConfig.type === 'main' && modelConfig.format === 'checkpoint' && (
|
||||||
<>
|
<>
|
||||||
<ModelAttrView label={t('modelManager.pathToConfig')} value={data.config_path} />
|
<ModelAttrView label={t('modelManager.pathToConfig')} value={modelConfig.config_path} />
|
||||||
<ModelAttrView label={t('modelManager.predictionType')} value={data.prediction_type} />
|
<ModelAttrView label={t('modelManager.predictionType')} value={modelConfig.prediction_type} />
|
||||||
<ModelAttrView label={t('modelManager.upcastAttention')} value={`${data.upcast_attention}`} />
|
<ModelAttrView label={t('modelManager.upcastAttention')} value={`${modelConfig.upcast_attention}`} />
|
||||||
</>
|
</>
|
||||||
)}
|
)}
|
||||||
{data.type === 'ip_adapter' && data.format === 'invokeai' && (
|
{modelConfig.type === 'ip_adapter' && modelConfig.format === 'invokeai' && (
|
||||||
<ModelAttrView label={t('modelManager.imageEncoderModelId')} value={data.image_encoder_model_id} />
|
<ModelAttrView label={t('modelManager.imageEncoderModelId')} value={modelConfig.image_encoder_model_id} />
|
||||||
)}
|
)}
|
||||||
</SimpleGrid>
|
</SimpleGrid>
|
||||||
</Box>
|
</Box>
|
||||||
<Box layerStyle="second" borderRadius="base" p={4}>
|
<Box layerStyle="second" borderRadius="base" p={4}>
|
||||||
{data.type === 'main' && data.base !== 'sdxl-refiner' && <MainModelDefaultSettings />}
|
{modelConfig.type === 'main' && modelConfig.base !== 'sdxl-refiner' && (
|
||||||
{(data.type === 'controlnet' || data.type === 't2i_adapter') && <ControlNetOrT2IAdapterDefaultSettings />}
|
<MainModelDefaultSettings modelConfig={modelConfig} />
|
||||||
{(data.type === 'main' || data.type === 'lora') && <TriggerPhrases />}
|
)}
|
||||||
|
{(modelConfig.type === 'controlnet' || modelConfig.type === 't2i_adapter') && (
|
||||||
|
<ControlNetOrT2IAdapterDefaultSettings modelConfig={modelConfig} />
|
||||||
|
)}
|
||||||
|
{(modelConfig.type === 'main' || modelConfig.type === 'lora') && <TriggerPhrases modelConfig={modelConfig} />}
|
||||||
</Box>
|
</Box>
|
||||||
</Flex>
|
</Flex>
|
||||||
|
</Flex>
|
||||||
);
|
);
|
||||||
};
|
});
|
||||||
|
|
||||||
|
ModelView.displayName = 'ModelView';
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
import { Switch } from '@invoke-ai/ui-library';
|
import { Switch, typedMemo } from '@invoke-ai/ui-library';
|
||||||
import type { ChangeEvent } from 'react';
|
import type { ChangeEvent } from 'react';
|
||||||
import { useCallback, useMemo } from 'react';
|
import { useCallback, useMemo } from 'react';
|
||||||
import type { UseControllerProps } from 'react-hook-form';
|
import type { UseControllerProps } from 'react-hook-form';
|
||||||
@ -6,7 +6,7 @@ import { useController } from 'react-hook-form';
|
|||||||
|
|
||||||
import type { FormField } from './MainModelDefaultSettings/MainModelDefaultSettings';
|
import type { FormField } from './MainModelDefaultSettings/MainModelDefaultSettings';
|
||||||
|
|
||||||
export function SettingToggle<T, F extends Record<string, FormField<T>>>(props: UseControllerProps<F>) {
|
export const SettingToggle = typedMemo(<T, F extends Record<string, FormField<T>>>(props: UseControllerProps<F>) => {
|
||||||
const { field } = useController(props);
|
const { field } = useController(props);
|
||||||
|
|
||||||
const value = useMemo(() => {
|
const value = useMemo(() => {
|
||||||
@ -25,4 +25,6 @@ export function SettingToggle<T, F extends Record<string, FormField<T>>>(props:
|
|||||||
);
|
);
|
||||||
|
|
||||||
return <Switch size="sm" isChecked={value} onChange={onChange} />;
|
return <Switch size="sm" isChecked={value} onChange={onChange} />;
|
||||||
}
|
});
|
||||||
|
|
||||||
|
SettingToggle.displayName = 'SettingToggle';
|
||||||
|
@ -9,19 +9,19 @@ import {
|
|||||||
TagCloseButton,
|
TagCloseButton,
|
||||||
TagLabel,
|
TagLabel,
|
||||||
} from '@invoke-ai/ui-library';
|
} from '@invoke-ai/ui-library';
|
||||||
import { skipToken } from '@reduxjs/toolkit/query';
|
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import type { ChangeEvent } from 'react';
|
import type { ChangeEvent } from 'react';
|
||||||
import { useCallback, useMemo, useState } from 'react';
|
import { memo, useCallback, useMemo, useState } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { PiPlusBold } from 'react-icons/pi';
|
import { PiPlusBold } from 'react-icons/pi';
|
||||||
import { useGetModelConfigQuery, useUpdateModelMutation } from 'services/api/endpoints/models';
|
import { useUpdateModelMutation } from 'services/api/endpoints/models';
|
||||||
import { isLoRAModelConfig, isNonRefinerMainModelConfig } from 'services/api/types';
|
import type { LoRAModelConfig, MainModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
export const TriggerPhrases = () => {
|
type Props = {
|
||||||
|
modelConfig: MainModelConfig | LoRAModelConfig;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const TriggerPhrases = memo(({ modelConfig }: Props) => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
|
||||||
const { currentData: modelConfig } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
|
|
||||||
const [phrase, setPhrase] = useState('');
|
const [phrase, setPhrase] = useState('');
|
||||||
|
|
||||||
const [updateModel, { isLoading }] = useUpdateModelMutation();
|
const [updateModel, { isLoading }] = useUpdateModelMutation();
|
||||||
@ -31,9 +31,6 @@ export const TriggerPhrases = () => {
|
|||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
const triggerPhrases = useMemo(() => {
|
const triggerPhrases = useMemo(() => {
|
||||||
if (!modelConfig || (!isNonRefinerMainModelConfig(modelConfig) && !isLoRAModelConfig(modelConfig))) {
|
|
||||||
return [];
|
|
||||||
}
|
|
||||||
return modelConfig?.trigger_phrases || [];
|
return modelConfig?.trigger_phrases || [];
|
||||||
}, [modelConfig]);
|
}, [modelConfig]);
|
||||||
|
|
||||||
@ -48,10 +45,6 @@ export const TriggerPhrases = () => {
|
|||||||
}, [phrase, triggerPhrases]);
|
}, [phrase, triggerPhrases]);
|
||||||
|
|
||||||
const addTriggerPhrase = useCallback(async () => {
|
const addTriggerPhrase = useCallback(async () => {
|
||||||
if (!selectedModelKey) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!phrase.length || triggerPhrases.includes(phrase)) {
|
if (!phrase.length || triggerPhrases.includes(phrase)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -59,22 +52,18 @@ export const TriggerPhrases = () => {
|
|||||||
setPhrase('');
|
setPhrase('');
|
||||||
|
|
||||||
await updateModel({
|
await updateModel({
|
||||||
key: selectedModelKey,
|
key: modelConfig.key,
|
||||||
body: { trigger_phrases: [...triggerPhrases, phrase] },
|
body: { trigger_phrases: [...triggerPhrases, phrase] },
|
||||||
}).unwrap();
|
}).unwrap();
|
||||||
}, [updateModel, selectedModelKey, phrase, triggerPhrases]);
|
}, [phrase, triggerPhrases, updateModel, modelConfig.key]);
|
||||||
|
|
||||||
const removeTriggerPhrase = useCallback(
|
const removeTriggerPhrase = useCallback(
|
||||||
async (phraseToRemove: string) => {
|
async (phraseToRemove: string) => {
|
||||||
if (!selectedModelKey) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const filteredPhrases = triggerPhrases.filter((p) => p !== phraseToRemove);
|
const filteredPhrases = triggerPhrases.filter((p) => p !== phraseToRemove);
|
||||||
|
|
||||||
await updateModel({ key: selectedModelKey, body: { trigger_phrases: filteredPhrases } }).unwrap();
|
await updateModel({ key: modelConfig.key, body: { trigger_phrases: filteredPhrases } }).unwrap();
|
||||||
},
|
},
|
||||||
[updateModel, selectedModelKey, triggerPhrases]
|
[triggerPhrases, updateModel, modelConfig]
|
||||||
);
|
);
|
||||||
|
|
||||||
const onTriggerPhraseAddFormSubmit = useCallback(
|
const onTriggerPhraseAddFormSubmit = useCallback(
|
||||||
@ -103,7 +92,9 @@ export const TriggerPhrases = () => {
|
|||||||
{t('common.add')}
|
{t('common.add')}
|
||||||
</Button>
|
</Button>
|
||||||
</Flex>
|
</Flex>
|
||||||
{!!errors.length && errors.map((error) => <FormErrorMessage key={error}>{error}</FormErrorMessage>)}
|
{errors.map((error) => (
|
||||||
|
<FormErrorMessage key={error}>{error}</FormErrorMessage>
|
||||||
|
))}
|
||||||
</Flex>
|
</Flex>
|
||||||
</FormControl>
|
</FormControl>
|
||||||
</form>
|
</form>
|
||||||
@ -118,4 +109,6 @@ export const TriggerPhrases = () => {
|
|||||||
</Flex>
|
</Flex>
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
};
|
});
|
||||||
|
|
||||||
|
TriggerPhrases.displayName = 'TriggerPhrases';
|
||||||
|
@ -59,17 +59,19 @@ const pasteSelection = (withEdgesToCopiedNodes?: boolean) => {
|
|||||||
for (const edge of copiedEdges) {
|
for (const edge of copiedEdges) {
|
||||||
if (edge.source === node.id) {
|
if (edge.source === node.id) {
|
||||||
edge.source = id;
|
edge.source = id;
|
||||||
edge.id = edge.id.replace(node.data.id, id);
|
} else if (edge.target === node.id) {
|
||||||
}
|
|
||||||
if (edge.target === node.id) {
|
|
||||||
edge.target = id;
|
edge.target = id;
|
||||||
edge.id = edge.id.replace(node.data.id, id);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
node.id = id;
|
node.id = id;
|
||||||
node.data.id = id;
|
node.data.id = id;
|
||||||
});
|
});
|
||||||
|
|
||||||
|
copiedEdges.forEach((edge) => {
|
||||||
|
// Copied edges need a fresh id too
|
||||||
|
edge.id = uuidv4();
|
||||||
|
});
|
||||||
|
|
||||||
const nodeChanges: NodeChange[] = [];
|
const nodeChanges: NodeChange[] = [];
|
||||||
const edgeChanges: EdgeChange[] = [];
|
const edgeChanges: EdgeChange[] = [];
|
||||||
// Deselect existing nodes
|
// Deselect existing nodes
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||||
import { creativityChanged } from 'features/parameters/store/upscaleSlice';
|
import { creativityChanged } from 'features/parameters/store/upscaleSlice';
|
||||||
import { memo, useCallback, useMemo } from 'react';
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
@ -25,7 +26,9 @@ const ParamCreativity = () => {
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
<FormControl>
|
<FormControl>
|
||||||
|
<InformationalPopover feature="creativity">
|
||||||
<FormLabel>{t('upscaling.creativity')}</FormLabel>
|
<FormLabel>{t('upscaling.creativity')}</FormLabel>
|
||||||
|
</InformationalPopover>
|
||||||
<CompositeSlider
|
<CompositeSlider
|
||||||
value={creativity}
|
value={creativity}
|
||||||
defaultValue={initial}
|
defaultValue={initial}
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import { Box, Combobox, FormControl, FormLabel, Tooltip } from '@invoke-ai/ui-library';
|
import { Box, Combobox, FormControl, FormLabel, Tooltip } from '@invoke-ai/ui-library';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||||
import { useModelCombobox } from 'common/hooks/useModelCombobox';
|
import { useModelCombobox } from 'common/hooks/useModelCombobox';
|
||||||
import { upscaleModelChanged } from 'features/parameters/store/upscaleSlice';
|
import { upscaleModelChanged } from 'features/parameters/store/upscaleSlice';
|
||||||
import { memo, useCallback, useMemo } from 'react';
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
@ -37,7 +38,9 @@ const ParamSpandrelModel = () => {
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
<FormControl orientation="vertical">
|
<FormControl orientation="vertical">
|
||||||
|
<InformationalPopover feature="upscaleModel">
|
||||||
<FormLabel>{t('upscaling.upscaleModel')}</FormLabel>
|
<FormLabel>{t('upscaling.upscaleModel')}</FormLabel>
|
||||||
|
</InformationalPopover>
|
||||||
<Tooltip label={tooltipLabel}>
|
<Tooltip label={tooltipLabel}>
|
||||||
<Box w="full">
|
<Box w="full">
|
||||||
<Combobox
|
<Combobox
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||||
import { structureChanged } from 'features/parameters/store/upscaleSlice';
|
import { structureChanged } from 'features/parameters/store/upscaleSlice';
|
||||||
import { memo, useCallback, useMemo } from 'react';
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
@ -25,7 +26,9 @@ const ParamStructure = () => {
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
<FormControl>
|
<FormControl>
|
||||||
|
<InformationalPopover feature="structure">
|
||||||
<FormLabel>{t('upscaling.structure')}</FormLabel>
|
<FormLabel>{t('upscaling.structure')}</FormLabel>
|
||||||
|
</InformationalPopover>
|
||||||
<CompositeSlider
|
<CompositeSlider
|
||||||
value={structure}
|
value={structure}
|
||||||
defaultValue={initial}
|
defaultValue={initial}
|
||||||
|
@ -64,7 +64,7 @@ export const AdvancedSettingsAccordion = memo(() => {
|
|||||||
const badges = useAppSelector(selectBadges);
|
const badges = useAppSelector(selectBadges);
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const { isOpen, onToggle } = useStandaloneAccordionToggle({
|
const { isOpen, onToggle } = useStandaloneAccordionToggle({
|
||||||
id: 'advanced-settings',
|
id: `'advanced-settings-${activeTabName}`,
|
||||||
defaultIsOpen: false,
|
defaultIsOpen: false,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -14,6 +14,7 @@ import ParamMainModelSelect from 'features/parameters/components/MainModel/Param
|
|||||||
import { UseDefaultSettingsButton } from 'features/parameters/components/MainModel/UseDefaultSettingsButton';
|
import { UseDefaultSettingsButton } from 'features/parameters/components/MainModel/UseDefaultSettingsButton';
|
||||||
import { useExpanderToggle } from 'features/settingsAccordions/hooks/useExpanderToggle';
|
import { useExpanderToggle } from 'features/settingsAccordions/hooks/useExpanderToggle';
|
||||||
import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle';
|
import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle';
|
||||||
|
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||||
import { filter } from 'lodash-es';
|
import { filter } from 'lodash-es';
|
||||||
import { memo, useMemo } from 'react';
|
import { memo, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
@ -26,6 +27,7 @@ const formLabelProps: FormLabelProps = {
|
|||||||
export const GenerationSettingsAccordion = memo(() => {
|
export const GenerationSettingsAccordion = memo(() => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const modelConfig = useSelectedModelConfig();
|
const modelConfig = useSelectedModelConfig();
|
||||||
|
const activeTabName = useAppSelector(activeTabNameSelector);
|
||||||
const selectBadges = useMemo(
|
const selectBadges = useMemo(
|
||||||
() =>
|
() =>
|
||||||
createMemoizedSelector(selectLoraSlice, (lora) => {
|
createMemoizedSelector(selectLoraSlice, (lora) => {
|
||||||
@ -42,8 +44,8 @@ export const GenerationSettingsAccordion = memo(() => {
|
|||||||
defaultIsOpen: false,
|
defaultIsOpen: false,
|
||||||
});
|
});
|
||||||
const { isOpen: isOpenAccordion, onToggle: onToggleAccordion } = useStandaloneAccordionToggle({
|
const { isOpen: isOpenAccordion, onToggle: onToggleAccordion } = useStandaloneAccordionToggle({
|
||||||
id: 'generation-settings',
|
id: `generation-settings-${activeTabName}`,
|
||||||
defaultIsOpen: true,
|
defaultIsOpen: activeTabName !== 'upscaling',
|
||||||
});
|
});
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import { CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
import { CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||||
import { scaleChanged } from 'features/parameters/store/upscaleSlice';
|
import { scaleChanged } from 'features/parameters/store/upscaleSlice';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
@ -22,7 +23,9 @@ export const UpscaleScaleSlider = memo(() => {
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
<FormControl orientation="vertical" gap={0}>
|
<FormControl orientation="vertical" gap={0}>
|
||||||
|
<InformationalPopover feature="scale">
|
||||||
<FormLabel m={0}>{t('upscaling.scale')}</FormLabel>
|
<FormLabel m={0}>{t('upscaling.scale')}</FormLabel>
|
||||||
|
</InformationalPopover>
|
||||||
<Flex w="full" gap={4}>
|
<Flex w="full" gap={4}>
|
||||||
<CompositeSlider
|
<CompositeSlider
|
||||||
min={2}
|
min={2}
|
||||||
|
@ -18,5 +18,6 @@ export const useStandaloneAccordionToggle = (arg: UseStandaloneAccordionToggleAr
|
|||||||
const onToggle = useCallback(() => {
|
const onToggle = useCallback(() => {
|
||||||
dispatch(accordionStateChanged({ id: arg.id, isOpen: !isOpen }));
|
dispatch(accordionStateChanged({ id: arg.id, isOpen: !isOpen }));
|
||||||
}, [arg.id, dispatch, isOpen]);
|
}, [arg.id, dispatch, isOpen]);
|
||||||
|
|
||||||
return { isOpen, onToggle };
|
return { isOpen, onToggle };
|
||||||
};
|
};
|
||||||
|
@ -27,7 +27,7 @@ const initialSystemState: SystemState = {
|
|||||||
language: 'en',
|
language: 'en',
|
||||||
shouldUseNSFWChecker: false,
|
shouldUseNSFWChecker: false,
|
||||||
shouldUseWatermarker: false,
|
shouldUseWatermarker: false,
|
||||||
shouldEnableInformationalPopovers: false,
|
shouldEnableInformationalPopovers: true,
|
||||||
status: 'DISCONNECTED',
|
status: 'DISCONNECTED',
|
||||||
cancellations: [],
|
cancellations: [],
|
||||||
};
|
};
|
||||||
|
@ -242,7 +242,6 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
}
|
}
|
||||||
return tags;
|
return tags;
|
||||||
},
|
},
|
||||||
keepUnusedDataFor: 60 * 60 * 1000 * 24, // 1 day (infinite)
|
|
||||||
transformResponse: (response: GetModelConfigsResponse) => {
|
transformResponse: (response: GetModelConfigsResponse) => {
|
||||||
return modelConfigsAdapter.setAll(modelConfigsAdapter.getInitialState(), response.models);
|
return modelConfigsAdapter.setAll(modelConfigsAdapter.getInitialState(), response.models);
|
||||||
},
|
},
|
||||||
|
@ -54,7 +54,7 @@ export type T2IAdapterModelConfig = S['T2IAdapterConfig'];
|
|||||||
export type SpandrelImageToImageModelConfig = S['SpandrelImageToImageConfig'];
|
export type SpandrelImageToImageModelConfig = S['SpandrelImageToImageConfig'];
|
||||||
type TextualInversionModelConfig = S['TextualInversionFileConfig'] | S['TextualInversionFolderConfig'];
|
type TextualInversionModelConfig = S['TextualInversionFileConfig'] | S['TextualInversionFolderConfig'];
|
||||||
type DiffusersModelConfig = S['MainDiffusersConfig'];
|
type DiffusersModelConfig = S['MainDiffusersConfig'];
|
||||||
type CheckpointModelConfig = S['MainCheckpointConfig'];
|
export type CheckpointModelConfig = S['MainCheckpointConfig'];
|
||||||
type CLIPVisionDiffusersConfig = S['CLIPVisionDiffusersConfig'];
|
type CLIPVisionDiffusersConfig = S['CLIPVisionDiffusersConfig'];
|
||||||
export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig;
|
export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig;
|
||||||
export type AnyModelConfig =
|
export type AnyModelConfig =
|
||||||
|
@ -1 +1 @@
|
|||||||
__version__ = "4.2.7rc1"
|
__version__ = "4.2.7"
|
||||||
|
Loading…
Reference in New Issue
Block a user