mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into stalker-modular_inpaint-2
This commit is contained in:
commit
693a3eaff5
@ -55,6 +55,7 @@ RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
FROM node:20-slim AS web-builder
|
||||
ENV PNPM_HOME="/pnpm"
|
||||
ENV PATH="$PNPM_HOME:$PATH"
|
||||
RUN corepack use pnpm@8.x
|
||||
RUN corepack enable
|
||||
|
||||
WORKDIR /build
|
||||
|
@ -6,7 +6,7 @@ import pathlib
|
||||
import traceback
|
||||
from copy import deepcopy
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
from typing import List, Optional, Type
|
||||
|
||||
from fastapi import Body, Path, Query, Response, UploadFile
|
||||
from fastapi.responses import FileResponse, HTMLResponse
|
||||
@ -430,13 +430,11 @@ async def delete_model_image(
|
||||
async def install_model(
|
||||
source: str = Query(description="Model source to install, can be a local path, repo_id, or remote URL"),
|
||||
inplace: Optional[bool] = Query(description="Whether or not to install a local model in place", default=False),
|
||||
# TODO(MM2): Can we type this?
|
||||
config: Optional[Dict[str, Any]] = Body(
|
||||
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
|
||||
default=None,
|
||||
access_token: Optional[str] = Query(description="access token for the remote resource", default=None),
|
||||
config: ModelRecordChanges = Body(
|
||||
description="Object containing fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
|
||||
example={"name": "string", "description": "string"},
|
||||
),
|
||||
access_token: Optional[str] = None,
|
||||
) -> ModelInstallJob:
|
||||
"""Install a model using a string identifier.
|
||||
|
||||
@ -451,8 +449,9 @@ async def install_model(
|
||||
- model/name:fp16:path/to/model.safetensors
|
||||
- model/name::path/to/model.safetensors
|
||||
|
||||
`config` is an optional dict containing model configuration values that will override
|
||||
the ones that are probed automatically.
|
||||
`config` is a ModelRecordChanges object. Fields in this object will override
|
||||
the ones that are probed automatically. Pass an empty object to accept
|
||||
all the defaults.
|
||||
|
||||
`access_token` is an optional access token for use with Urls that require
|
||||
authentication.
|
||||
@ -737,7 +736,7 @@ async def convert_model(
|
||||
# write the converted file to the convert path
|
||||
raw_model = converted_model.model
|
||||
assert hasattr(raw_model, "save_pretrained")
|
||||
raw_model.save_pretrained(convert_path)
|
||||
raw_model.save_pretrained(convert_path) # type: ignore
|
||||
assert convert_path.exists()
|
||||
|
||||
# temporarily rename the original safetensors file so that there is no naming conflict
|
||||
@ -750,12 +749,12 @@ async def convert_model(
|
||||
try:
|
||||
new_key = installer.install_path(
|
||||
convert_path,
|
||||
config={
|
||||
"name": original_name,
|
||||
"description": model_config.description,
|
||||
"hash": model_config.hash,
|
||||
"source": model_config.source,
|
||||
},
|
||||
config=ModelRecordChanges(
|
||||
name=original_name,
|
||||
description=model_config.description,
|
||||
hash=model_config.hash,
|
||||
source=model_config.source,
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
|
@ -39,7 +39,7 @@ from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelVariantType
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
|
||||
from invokeai.backend.stable_diffusion import PipelineIntermediateState
|
||||
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import (
|
||||
ControlNetData,
|
||||
@ -58,9 +58,14 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
|
||||
from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
|
||||
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||
from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetExt
|
||||
from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
|
||||
from invokeai.backend.stable_diffusion.extensions.inpaint import InpaintExt
|
||||
from invokeai.backend.stable_diffusion.extensions.inpaint_model import InpaintModelExt
|
||||
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
|
||||
from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt
|
||||
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
|
||||
from invokeai.backend.stable_diffusion.extensions.t2i_adapter import T2IAdapterExt
|
||||
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
|
||||
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
|
||||
@ -465,6 +470,65 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
return controlnet_data
|
||||
|
||||
@staticmethod
|
||||
def parse_controlnet_field(
|
||||
exit_stack: ExitStack,
|
||||
context: InvocationContext,
|
||||
control_input: ControlField | list[ControlField] | None,
|
||||
ext_manager: ExtensionsManager,
|
||||
) -> None:
|
||||
# Normalize control_input to a list.
|
||||
control_list: list[ControlField]
|
||||
if isinstance(control_input, ControlField):
|
||||
control_list = [control_input]
|
||||
elif isinstance(control_input, list):
|
||||
control_list = control_input
|
||||
elif control_input is None:
|
||||
control_list = []
|
||||
else:
|
||||
raise ValueError(f"Unexpected control_input type: {type(control_input)}")
|
||||
|
||||
for control_info in control_list:
|
||||
model = exit_stack.enter_context(context.models.load(control_info.control_model))
|
||||
ext_manager.add_extension(
|
||||
ControlNetExt(
|
||||
model=model,
|
||||
image=context.images.get_pil(control_info.image.image_name),
|
||||
weight=control_info.control_weight,
|
||||
begin_step_percent=control_info.begin_step_percent,
|
||||
end_step_percent=control_info.end_step_percent,
|
||||
control_mode=control_info.control_mode,
|
||||
resize_mode=control_info.resize_mode,
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def parse_t2i_adapter_field(
|
||||
exit_stack: ExitStack,
|
||||
context: InvocationContext,
|
||||
t2i_adapters: Optional[Union[T2IAdapterField, list[T2IAdapterField]]],
|
||||
ext_manager: ExtensionsManager,
|
||||
) -> None:
|
||||
if t2i_adapters is None:
|
||||
return
|
||||
|
||||
# Handle the possibility that t2i_adapters could be a list or a single T2IAdapterField.
|
||||
if isinstance(t2i_adapters, T2IAdapterField):
|
||||
t2i_adapters = [t2i_adapters]
|
||||
|
||||
for t2i_adapter_field in t2i_adapters:
|
||||
ext_manager.add_extension(
|
||||
T2IAdapterExt(
|
||||
node_context=context,
|
||||
model_id=t2i_adapter_field.t2i_adapter_model,
|
||||
image=context.images.get_pil(t2i_adapter_field.image.image_name),
|
||||
weight=t2i_adapter_field.weight,
|
||||
begin_step_percent=t2i_adapter_field.begin_step_percent,
|
||||
end_step_percent=t2i_adapter_field.end_step_percent,
|
||||
resize_mode=t2i_adapter_field.resize_mode,
|
||||
)
|
||||
)
|
||||
|
||||
def prep_ip_adapter_image_prompts(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
@ -773,6 +837,18 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
ext_manager.add_extension(PreviewExt(step_callback))
|
||||
|
||||
### cfg rescale
|
||||
if self.cfg_rescale_multiplier > 0:
|
||||
ext_manager.add_extension(RescaleCFGExt(self.cfg_rescale_multiplier))
|
||||
|
||||
### freeu
|
||||
if self.unet.freeu_config:
|
||||
ext_manager.add_extension(FreeUExt(self.unet.freeu_config))
|
||||
|
||||
### seamless
|
||||
if self.unet.seamless_axes:
|
||||
ext_manager.add_extension(SeamlessExt(self.unet.seamless_axes))
|
||||
|
||||
### inpaint
|
||||
mask, masked_latents, is_gradient_mask = self.prep_inpaint_mask(context, latents)
|
||||
# NOTE: We used to identify inpainting models by inpecting the shape of the loaded UNet model weights. Now we
|
||||
@ -788,7 +864,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
latents = latents.to(device=device, dtype=dtype)
|
||||
if noise is not None:
|
||||
noise = noise.to(device=device, dtype=dtype)
|
||||
|
||||
denoise_ctx = DenoiseContext(
|
||||
inputs=DenoiseInputs(
|
||||
orig_latents=latents,
|
||||
@ -804,18 +879,27 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
# context for loading additional models
|
||||
with ExitStack() as exit_stack:
|
||||
# later should be smth like:
|
||||
# for extension_field in self.extensions:
|
||||
# ext = extension_field.to_extension(exit_stack, context, ext_manager)
|
||||
# ext_manager.add_extension(ext)
|
||||
self.parse_controlnet_field(exit_stack, context, self.control, ext_manager)
|
||||
self.parse_t2i_adapter_field(exit_stack, context, self.t2i_adapter, ext_manager)
|
||||
|
||||
# ext: t2i/ip adapter
|
||||
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
|
||||
|
||||
unet_info = context.models.load(self.unet.unet)
|
||||
assert isinstance(unet_info.model, UNet2DConditionModel)
|
||||
with (
|
||||
unet_info.model_on_device() as (model_state_dict, unet),
|
||||
unet_info.model_on_device() as (cached_weights, unet),
|
||||
ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
|
||||
# ext: controlnet
|
||||
ext_manager.patch_extensions(unet),
|
||||
ext_manager.patch_extensions(denoise_ctx),
|
||||
# ext: freeu, seamless, ip adapter, lora
|
||||
ext_manager.patch_unet(model_state_dict, unet),
|
||||
ext_manager.patch_unet(unet, cached_weights),
|
||||
):
|
||||
sd_backend = StableDiffusionBackend(unet, scheduler)
|
||||
denoise_ctx.unet = unet
|
||||
@ -882,7 +966,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
ExitStack() as exit_stack,
|
||||
unet_info.model_on_device() as (model_state_dict, unet),
|
||||
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
|
||||
set_seamless(unet, self.unet.seamless_axes), # FIXME
|
||||
SeamlessExt.static_patch_model(unet, self.unet.seamless_axes), # FIXME
|
||||
# Apply the LoRA after unet has been moved to its target device for faster patching.
|
||||
ModelPatcher.apply_lora_unet(
|
||||
unet,
|
||||
|
@ -24,7 +24,7 @@ from invokeai.app.invocations.fields import (
|
||||
from invokeai.app.invocations.model import VAEField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.stable_diffusion import set_seamless
|
||||
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
|
||||
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
@ -59,7 +59,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
|
||||
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
|
||||
with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
|
||||
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
|
||||
latents = latents.to(vae.device)
|
||||
if self.fp32:
|
||||
|
@ -1,3 +1,5 @@
|
||||
from typing import Callable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
@ -21,7 +23,7 @@ from invokeai.backend.tiles.tiles import calc_tiles_min_overlap
|
||||
from invokeai.backend.tiles.utils import TBLR, Tile
|
||||
|
||||
|
||||
@invocation("spandrel_image_to_image", title="Image-to-Image", tags=["upscale"], category="upscale", version="1.1.0")
|
||||
@invocation("spandrel_image_to_image", title="Image-to-Image", tags=["upscale"], category="upscale", version="1.3.0")
|
||||
class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Run any spandrel image-to-image model (https://github.com/chaiNNer-org/spandrel)."""
|
||||
|
||||
@ -35,7 +37,8 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
default=512, description="The tile size for tiled image-to-image. Set to 0 to disable tiling."
|
||||
)
|
||||
|
||||
def _scale_tile(self, tile: Tile, scale: int) -> Tile:
|
||||
@classmethod
|
||||
def scale_tile(cls, tile: Tile, scale: int) -> Tile:
|
||||
return Tile(
|
||||
coords=TBLR(
|
||||
top=tile.coords.top * scale,
|
||||
@ -51,20 +54,22 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
),
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
# Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to
|
||||
# revisit this.
|
||||
image = context.images.get_pil(self.image.image_name, mode="RGB")
|
||||
|
||||
@classmethod
|
||||
def upscale_image(
|
||||
cls,
|
||||
image: Image.Image,
|
||||
tile_size: int,
|
||||
spandrel_model: SpandrelImageToImageModel,
|
||||
is_canceled: Callable[[], bool],
|
||||
) -> Image.Image:
|
||||
# Compute the image tiles.
|
||||
if self.tile_size > 0:
|
||||
if tile_size > 0:
|
||||
min_overlap = 20
|
||||
tiles = calc_tiles_min_overlap(
|
||||
image_height=image.height,
|
||||
image_width=image.width,
|
||||
tile_height=self.tile_size,
|
||||
tile_width=self.tile_size,
|
||||
tile_height=tile_size,
|
||||
tile_width=tile_size,
|
||||
min_overlap=min_overlap,
|
||||
)
|
||||
else:
|
||||
@ -85,16 +90,9 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
# Prepare input image for inference.
|
||||
image_tensor = SpandrelImageToImageModel.pil_to_tensor(image)
|
||||
|
||||
# Load the model.
|
||||
spandrel_model_info = context.models.load(self.image_to_image_model)
|
||||
|
||||
# Run the model on each tile.
|
||||
with spandrel_model_info as spandrel_model:
|
||||
assert isinstance(spandrel_model, SpandrelImageToImageModel)
|
||||
|
||||
# Scale the tiles for re-assembling the final image.
|
||||
scale = spandrel_model.scale
|
||||
scaled_tiles = [self._scale_tile(tile, scale=scale) for tile in tiles]
|
||||
scaled_tiles = [cls.scale_tile(tile, scale=scale) for tile in tiles]
|
||||
|
||||
# Prepare the output tensor.
|
||||
_, channels, height, width = image_tensor.shape
|
||||
@ -104,9 +102,10 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype)
|
||||
|
||||
# Run the model on each tile.
|
||||
for tile, scaled_tile in tqdm(list(zip(tiles, scaled_tiles, strict=True)), desc="Upscaling Tiles"):
|
||||
# Exit early if the invocation has been canceled.
|
||||
if context.util.is_canceled():
|
||||
if is_canceled():
|
||||
raise CanceledException
|
||||
|
||||
# Extract the current tile from the input tensor.
|
||||
@ -140,5 +139,115 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
# Convert the output tensor to a PIL image.
|
||||
np_image = output_tensor.detach().numpy().astype(np.uint8)
|
||||
pil_image = Image.fromarray(np_image)
|
||||
|
||||
return pil_image
|
||||
|
||||
@torch.inference_mode()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
# Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to
|
||||
# revisit this.
|
||||
image = context.images.get_pil(self.image.image_name, mode="RGB")
|
||||
|
||||
# Load the model.
|
||||
spandrel_model_info = context.models.load(self.image_to_image_model)
|
||||
|
||||
# Do the upscaling.
|
||||
with spandrel_model_info as spandrel_model:
|
||||
assert isinstance(spandrel_model, SpandrelImageToImageModel)
|
||||
|
||||
# Upscale the image
|
||||
pil_image = self.upscale_image(image, self.tile_size, spandrel_model, context.util.is_canceled)
|
||||
|
||||
image_dto = context.images.save(image=pil_image)
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
|
||||
@invocation(
|
||||
"spandrel_image_to_image_autoscale",
|
||||
title="Image-to-Image (Autoscale)",
|
||||
tags=["upscale"],
|
||||
category="upscale",
|
||||
version="1.0.0",
|
||||
)
|
||||
class SpandrelImageToImageAutoscaleInvocation(SpandrelImageToImageInvocation):
|
||||
"""Run any spandrel image-to-image model (https://github.com/chaiNNer-org/spandrel) until the target scale is reached."""
|
||||
|
||||
scale: float = InputField(
|
||||
default=4.0,
|
||||
gt=0.0,
|
||||
le=16.0,
|
||||
description="The final scale of the output image. If the model does not upscale the image, this will be ignored.",
|
||||
)
|
||||
fit_to_multiple_of_8: bool = InputField(
|
||||
default=False,
|
||||
description="If true, the output image will be resized to the nearest multiple of 8 in both dimensions.",
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
# Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to
|
||||
# revisit this.
|
||||
image = context.images.get_pil(self.image.image_name, mode="RGB")
|
||||
|
||||
# Load the model.
|
||||
spandrel_model_info = context.models.load(self.image_to_image_model)
|
||||
|
||||
# The target size of the image, determined by the provided scale. We'll run the upscaler until we hit this size.
|
||||
# Later, we may mutate this value if the model doesn't upscale the image or if the user requested a multiple of 8.
|
||||
target_width = int(image.width * self.scale)
|
||||
target_height = int(image.height * self.scale)
|
||||
|
||||
# Do the upscaling.
|
||||
with spandrel_model_info as spandrel_model:
|
||||
assert isinstance(spandrel_model, SpandrelImageToImageModel)
|
||||
|
||||
# First pass of upscaling. Note: `pil_image` will be mutated.
|
||||
pil_image = self.upscale_image(image, self.tile_size, spandrel_model, context.util.is_canceled)
|
||||
|
||||
# Some models don't upscale the image, but we have no way to know this in advance. We'll check if the model
|
||||
# upscaled the image and run the loop below if it did. We'll require the model to upscale both dimensions
|
||||
# to be considered an upscale model.
|
||||
is_upscale_model = pil_image.width > image.width and pil_image.height > image.height
|
||||
|
||||
if is_upscale_model:
|
||||
# This is an upscale model, so we should keep upscaling until we reach the target size.
|
||||
iterations = 1
|
||||
while pil_image.width < target_width or pil_image.height < target_height:
|
||||
pil_image = self.upscale_image(pil_image, self.tile_size, spandrel_model, context.util.is_canceled)
|
||||
iterations += 1
|
||||
|
||||
# Sanity check to prevent excessive or infinite loops. All known upscaling models are at least 2x.
|
||||
# Our max scale is 16x, so with a 2x model, we should never exceed 16x == 2^4 -> 4 iterations.
|
||||
# We'll allow one extra iteration "just in case" and bail at 5 upscaling iterations. In practice,
|
||||
# we should never reach this limit.
|
||||
if iterations >= 5:
|
||||
context.logger.warning(
|
||||
"Upscale loop reached maximum iteration count of 5, stopping upscaling early."
|
||||
)
|
||||
break
|
||||
else:
|
||||
# This model doesn't upscale the image. We should ignore the scale parameter, modifying the output size
|
||||
# to be the same as the processed image size.
|
||||
|
||||
# The output size is now the size of the processed image.
|
||||
target_width = pil_image.width
|
||||
target_height = pil_image.height
|
||||
|
||||
# Warn the user if they requested a scale greater than 1.
|
||||
if self.scale > 1:
|
||||
context.logger.warning(
|
||||
"Model does not increase the size of the image, but a greater scale than 1 was requested. Image will not be scaled."
|
||||
)
|
||||
|
||||
# We may need to resize the image to a multiple of 8. Use floor division to ensure we don't scale the image up
|
||||
# in the final resize
|
||||
if self.fit_to_multiple_of_8:
|
||||
target_width = int(target_width // 8 * 8)
|
||||
target_height = int(target_height // 8 * 8)
|
||||
|
||||
# Final resize. Per PIL documentation, Lanczos provides the best quality for both upscale and downscale.
|
||||
# See: https://pillow.readthedocs.io/en/stable/handbook/concepts.html#filters-comparison-table
|
||||
pil_image = pil_image.resize((target_width, target_height), resample=Image.Resampling.LANCZOS)
|
||||
|
||||
image_dto = context.images.save(image=pil_image)
|
||||
return ImageOutput.build(image_dto)
|
||||
|
@ -3,7 +3,7 @@
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
|
||||
@ -12,7 +12,7 @@ from invokeai.app.services.download import DownloadQueueServiceBase
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_install.model_install_common import ModelInstallJob, ModelSource
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase
|
||||
from invokeai.app.services.model_records import ModelRecordChanges, ModelRecordServiceBase
|
||||
from invokeai.backend.model_manager import AnyModelConfig
|
||||
|
||||
|
||||
@ -64,7 +64,7 @@ class ModelInstallServiceBase(ABC):
|
||||
def register_path(
|
||||
self,
|
||||
model_path: Union[Path, str],
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
config: Optional[ModelRecordChanges] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Probe and register the model at model_path.
|
||||
@ -72,7 +72,7 @@ class ModelInstallServiceBase(ABC):
|
||||
This keeps the model in its current location.
|
||||
|
||||
:param model_path: Filesystem Path to the model.
|
||||
:param config: Dict of attributes that will override autoassigned values.
|
||||
:param config: ModelRecordChanges object that will override autoassigned model record values.
|
||||
:returns id: The string ID of the registered model.
|
||||
"""
|
||||
|
||||
@ -92,7 +92,7 @@ class ModelInstallServiceBase(ABC):
|
||||
def install_path(
|
||||
self,
|
||||
model_path: Union[Path, str],
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
config: Optional[ModelRecordChanges] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Probe, register and install the model in the models directory.
|
||||
@ -101,7 +101,7 @@ class ModelInstallServiceBase(ABC):
|
||||
the models directory handled by InvokeAI.
|
||||
|
||||
:param model_path: Filesystem Path to the model.
|
||||
:param config: Dict of attributes that will override autoassigned values.
|
||||
:param config: ModelRecordChanges object that will override autoassigned model record values.
|
||||
:returns id: The string ID of the registered model.
|
||||
"""
|
||||
|
||||
@ -109,14 +109,14 @@ class ModelInstallServiceBase(ABC):
|
||||
def heuristic_import(
|
||||
self,
|
||||
source: str,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
config: Optional[ModelRecordChanges] = None,
|
||||
access_token: Optional[str] = None,
|
||||
inplace: Optional[bool] = False,
|
||||
) -> ModelInstallJob:
|
||||
r"""Install the indicated model using heuristics to interpret user intentions.
|
||||
|
||||
:param source: String source
|
||||
:param config: Optional dict. Any fields in this dict
|
||||
:param config: Optional ModelRecordChanges object. Any fields in this object
|
||||
will override corresponding autoassigned probe fields in the
|
||||
model's config record as described in `import_model()`.
|
||||
:param access_token: Optional access token for remote sources.
|
||||
@ -147,7 +147,7 @@ class ModelInstallServiceBase(ABC):
|
||||
def import_model(
|
||||
self,
|
||||
source: ModelSource,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
config: Optional[ModelRecordChanges] = None,
|
||||
) -> ModelInstallJob:
|
||||
"""Install the indicated model.
|
||||
|
||||
|
@ -2,13 +2,14 @@ import re
|
||||
import traceback
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Literal, Optional, Set, Union
|
||||
from typing import Literal, Optional, Set, Union
|
||||
|
||||
from pydantic import BaseModel, Field, PrivateAttr, field_validator
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.services.download import DownloadJob, MultiFileDownloadJob
|
||||
from invokeai.app.services.model_records import ModelRecordChanges
|
||||
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
|
||||
from invokeai.backend.model_manager.config import ModelSourceType
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
@ -133,8 +134,9 @@ class ModelInstallJob(BaseModel):
|
||||
id: int = Field(description="Unique ID for this job")
|
||||
status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process")
|
||||
error_reason: Optional[str] = Field(default=None, description="Information about why the job failed")
|
||||
config_in: Dict[str, Any] = Field(
|
||||
default_factory=dict, description="Configuration information (e.g. 'description') to apply to model."
|
||||
config_in: ModelRecordChanges = Field(
|
||||
default_factory=ModelRecordChanges,
|
||||
description="Configuration information (e.g. 'description') to apply to model.",
|
||||
)
|
||||
config_out: Optional[AnyModelConfig] = Field(
|
||||
default=None, description="After successful installation, this will hold the configuration object."
|
||||
|
@ -163,26 +163,27 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
def register_path(
|
||||
self,
|
||||
model_path: Union[Path, str],
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
config: Optional[ModelRecordChanges] = None,
|
||||
) -> str: # noqa D102
|
||||
model_path = Path(model_path)
|
||||
config = config or {}
|
||||
if not config.get("source"):
|
||||
config["source"] = model_path.resolve().as_posix()
|
||||
config["source_type"] = ModelSourceType.Path
|
||||
config = config or ModelRecordChanges()
|
||||
if not config.source:
|
||||
config.source = model_path.resolve().as_posix()
|
||||
config.source_type = ModelSourceType.Path
|
||||
return self._register(model_path, config)
|
||||
|
||||
def install_path(
|
||||
self,
|
||||
model_path: Union[Path, str],
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
config: Optional[ModelRecordChanges] = None,
|
||||
) -> str: # noqa D102
|
||||
model_path = Path(model_path)
|
||||
config = config or {}
|
||||
config = config or ModelRecordChanges()
|
||||
info: AnyModelConfig = ModelProbe.probe(
|
||||
Path(model_path), config.model_dump(), hash_algo=self._app_config.hashing_algorithm
|
||||
) # type: ignore
|
||||
|
||||
info: AnyModelConfig = ModelProbe.probe(Path(model_path), config, hash_algo=self._app_config.hashing_algorithm)
|
||||
|
||||
if preferred_name := config.get("name"):
|
||||
if preferred_name := config.name:
|
||||
preferred_name = Path(preferred_name).with_suffix(model_path.suffix)
|
||||
|
||||
dest_path = (
|
||||
@ -204,7 +205,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
def heuristic_import(
|
||||
self,
|
||||
source: str,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
config: Optional[ModelRecordChanges] = None,
|
||||
access_token: Optional[str] = None,
|
||||
inplace: Optional[bool] = False,
|
||||
) -> ModelInstallJob:
|
||||
@ -216,7 +217,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
source_obj.access_token = access_token
|
||||
return self.import_model(source_obj, config)
|
||||
|
||||
def import_model(self, source: ModelSource, config: Optional[Dict[str, Any]] = None) -> ModelInstallJob: # noqa D102
|
||||
def import_model(self, source: ModelSource, config: Optional[ModelRecordChanges] = None) -> ModelInstallJob: # noqa D102
|
||||
similar_jobs = [x for x in self.list_jobs() if x.source == source and not x.in_terminal_state]
|
||||
if similar_jobs:
|
||||
self._logger.warning(f"There is already an active install job for {source}. Not enqueuing.")
|
||||
@ -318,16 +319,17 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
model_path = self._app_config.models_path / model_path
|
||||
model_path = model_path.resolve()
|
||||
|
||||
config: dict[str, Any] = {}
|
||||
config["name"] = model_name
|
||||
config["description"] = stanza.get("description")
|
||||
config = ModelRecordChanges(
|
||||
name=model_name,
|
||||
description=stanza.get("description"),
|
||||
)
|
||||
legacy_config_path = stanza.get("config")
|
||||
if legacy_config_path:
|
||||
# In v3, these paths were relative to the root. Migrate them to be relative to the legacy_conf_dir.
|
||||
legacy_config_path = self._app_config.root_path / legacy_config_path
|
||||
if legacy_config_path.is_relative_to(self._app_config.legacy_conf_path):
|
||||
legacy_config_path = legacy_config_path.relative_to(self._app_config.legacy_conf_path)
|
||||
config["config_path"] = str(legacy_config_path)
|
||||
config.config_path = str(legacy_config_path)
|
||||
try:
|
||||
id = self.register_path(model_path=model_path, config=config)
|
||||
self._logger.info(f"Migrated {model_name} with id {id}")
|
||||
@ -500,11 +502,11 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
job.total_bytes = self._stat_size(job.local_path)
|
||||
job.bytes = job.total_bytes
|
||||
self._signal_job_running(job)
|
||||
job.config_in["source"] = str(job.source)
|
||||
job.config_in["source_type"] = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__]
|
||||
job.config_in.source = str(job.source)
|
||||
job.config_in.source_type = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__]
|
||||
# enter the metadata, if there is any
|
||||
if isinstance(job.source_metadata, (HuggingFaceMetadata)):
|
||||
job.config_in["source_api_response"] = job.source_metadata.api_response
|
||||
job.config_in.source_api_response = job.source_metadata.api_response
|
||||
|
||||
if job.inplace:
|
||||
key = self.register_path(job.local_path, job.config_in)
|
||||
@ -639,11 +641,11 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
return new_path
|
||||
|
||||
def _register(
|
||||
self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None
|
||||
self, model_path: Path, config: Optional[ModelRecordChanges] = None, info: Optional[AnyModelConfig] = None
|
||||
) -> str:
|
||||
config = config or {}
|
||||
config = config or ModelRecordChanges()
|
||||
|
||||
info = info or ModelProbe.probe(model_path, config, hash_algo=self._app_config.hashing_algorithm)
|
||||
info = info or ModelProbe.probe(model_path, config.model_dump(), hash_algo=self._app_config.hashing_algorithm) # type: ignore
|
||||
|
||||
model_path = model_path.resolve()
|
||||
|
||||
@ -674,11 +676,13 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
precision = TorchDevice.choose_torch_dtype()
|
||||
return ModelRepoVariant.FP16 if precision == torch.float16 else None
|
||||
|
||||
def _import_local_model(self, source: LocalModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
|
||||
def _import_local_model(
|
||||
self, source: LocalModelSource, config: Optional[ModelRecordChanges] = None
|
||||
) -> ModelInstallJob:
|
||||
return ModelInstallJob(
|
||||
id=self._next_id(),
|
||||
source=source,
|
||||
config_in=config or {},
|
||||
config_in=config or ModelRecordChanges(),
|
||||
local_path=Path(source.path),
|
||||
inplace=source.inplace or False,
|
||||
)
|
||||
@ -686,7 +690,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
def _import_from_hf(
|
||||
self,
|
||||
source: HFModelSource,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
config: Optional[ModelRecordChanges] = None,
|
||||
) -> ModelInstallJob:
|
||||
# Add user's cached access token to HuggingFace requests
|
||||
if source.access_token is None:
|
||||
@ -702,7 +706,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
def _import_from_url(
|
||||
self,
|
||||
source: URLModelSource,
|
||||
config: Optional[Dict[str, Any]],
|
||||
config: Optional[ModelRecordChanges] = None,
|
||||
) -> ModelInstallJob:
|
||||
remote_files, metadata = self._remote_files_from_source(source)
|
||||
return self._import_remote_model(
|
||||
@ -717,7 +721,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
source: HFModelSource | URLModelSource,
|
||||
remote_files: List[RemoteModelFile],
|
||||
metadata: Optional[AnyModelRepoMetadata],
|
||||
config: Optional[Dict[str, Any]],
|
||||
config: Optional[ModelRecordChanges],
|
||||
) -> ModelInstallJob:
|
||||
if len(remote_files) == 0:
|
||||
raise ValueError(f"{source}: No downloadable files found")
|
||||
@ -730,7 +734,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
install_job = ModelInstallJob(
|
||||
id=self._next_id(),
|
||||
source=source,
|
||||
config_in=config or {},
|
||||
config_in=config or ModelRecordChanges(),
|
||||
source_metadata=metadata,
|
||||
local_path=destdir, # local path may change once the download has started due to content-disposition handling
|
||||
bytes=0,
|
||||
|
@ -18,6 +18,7 @@ from invokeai.backend.model_manager.config import (
|
||||
ControlAdapterDefaultSettings,
|
||||
MainModelDefaultSettings,
|
||||
ModelFormat,
|
||||
ModelSourceType,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
SchedulerPredictionType,
|
||||
@ -66,10 +67,16 @@ class ModelRecordChanges(BaseModelExcludeNull):
|
||||
"""A set of changes to apply to a model."""
|
||||
|
||||
# Changes applicable to all models
|
||||
source: Optional[str] = Field(description="original source of the model", default=None)
|
||||
source_type: Optional[ModelSourceType] = Field(description="type of model source", default=None)
|
||||
source_api_response: Optional[str] = Field(description="metadata from remote source", default=None)
|
||||
name: Optional[str] = Field(description="Name of the model.", default=None)
|
||||
path: Optional[str] = Field(description="Path to the model.", default=None)
|
||||
description: Optional[str] = Field(description="Model description", default=None)
|
||||
base: Optional[BaseModelType] = Field(description="The base model.", default=None)
|
||||
type: Optional[ModelType] = Field(description="Type of model", default=None)
|
||||
key: Optional[str] = Field(description="Database ID for this model", default=None)
|
||||
hash: Optional[str] = Field(description="hash of model file", default=None)
|
||||
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
|
||||
default_settings: Optional[MainModelDefaultSettings | ControlAdapterDefaultSettings] = Field(
|
||||
description="Default settings for this model", default=None
|
||||
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -354,7 +354,7 @@ class CLIPVisionDiffusersConfig(DiffusersConfigBase):
|
||||
"""Model config for CLIPVision."""
|
||||
|
||||
type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision
|
||||
format: Literal[ModelFormat.Diffusers]
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
@ -365,7 +365,7 @@ class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase):
|
||||
"""Model config for T2I."""
|
||||
|
||||
type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter
|
||||
format: Literal[ModelFormat.Diffusers]
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
|
@ -98,6 +98,9 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
|
||||
ModelVariantType.Normal: StableDiffusionXLPipeline,
|
||||
ModelVariantType.Inpaint: StableDiffusionXLInpaintPipeline,
|
||||
},
|
||||
BaseModelType.StableDiffusionXLRefiner: {
|
||||
ModelVariantType.Normal: StableDiffusionXLPipeline,
|
||||
},
|
||||
}
|
||||
assert isinstance(config, MainCheckpointConfig)
|
||||
try:
|
||||
|
@ -187,164 +187,171 @@ STARTER_MODELS: list[StarterModel] = [
|
||||
# endregion
|
||||
# region ControlNet
|
||||
StarterModel(
|
||||
name="QRCode Monster",
|
||||
name="QRCode Monster v2 (SD1.5)",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="monster-labs/control_v1p_sd15_qrcode_monster",
|
||||
description="Controlnet model that generates scannable creative QR codes",
|
||||
source="monster-labs/control_v1p_sd15_qrcode_monster::v2",
|
||||
description="ControlNet model that generates scannable creative QR codes",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="QRCode Monster (SDXL)",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="monster-labs/control_v1p_sdxl_qrcode_monster",
|
||||
description="ControlNet model that generates scannable creative QR codes",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="canny",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15_canny",
|
||||
description="Controlnet weights trained on sd-1.5 with canny conditioning.",
|
||||
description="ControlNet weights trained on sd-1.5 with canny conditioning.",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="inpaint",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15_inpaint",
|
||||
description="Controlnet weights trained on sd-1.5 with canny conditioning, inpaint version",
|
||||
description="ControlNet weights trained on sd-1.5 with canny conditioning, inpaint version",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="mlsd",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15_mlsd",
|
||||
description="Controlnet weights trained on sd-1.5 with canny conditioning, MLSD version",
|
||||
description="ControlNet weights trained on sd-1.5 with canny conditioning, MLSD version",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="depth",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11f1p_sd15_depth",
|
||||
description="Controlnet weights trained on sd-1.5 with depth conditioning",
|
||||
description="ControlNet weights trained on sd-1.5 with depth conditioning",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="normal_bae",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15_normalbae",
|
||||
description="Controlnet weights trained on sd-1.5 with normalbae image conditioning",
|
||||
description="ControlNet weights trained on sd-1.5 with normalbae image conditioning",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="seg",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15_seg",
|
||||
description="Controlnet weights trained on sd-1.5 with seg image conditioning",
|
||||
description="ControlNet weights trained on sd-1.5 with seg image conditioning",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="lineart",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15_lineart",
|
||||
description="Controlnet weights trained on sd-1.5 with lineart image conditioning",
|
||||
description="ControlNet weights trained on sd-1.5 with lineart image conditioning",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="lineart_anime",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15s2_lineart_anime",
|
||||
description="Controlnet weights trained on sd-1.5 with anime image conditioning",
|
||||
description="ControlNet weights trained on sd-1.5 with anime image conditioning",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="openpose",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15_openpose",
|
||||
description="Controlnet weights trained on sd-1.5 with openpose image conditioning",
|
||||
description="ControlNet weights trained on sd-1.5 with openpose image conditioning",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="scribble",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15_scribble",
|
||||
description="Controlnet weights trained on sd-1.5 with scribble image conditioning",
|
||||
description="ControlNet weights trained on sd-1.5 with scribble image conditioning",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="softedge",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15_softedge",
|
||||
description="Controlnet weights trained on sd-1.5 with soft edge conditioning",
|
||||
description="ControlNet weights trained on sd-1.5 with soft edge conditioning",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="shuffle",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11e_sd15_shuffle",
|
||||
description="Controlnet weights trained on sd-1.5 with shuffle image conditioning",
|
||||
description="ControlNet weights trained on sd-1.5 with shuffle image conditioning",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="tile",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11f1e_sd15_tile",
|
||||
description="Controlnet weights trained on sd-1.5 with tiled image conditioning",
|
||||
description="ControlNet weights trained on sd-1.5 with tiled image conditioning",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="ip2p",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11e_sd15_ip2p",
|
||||
description="Controlnet weights trained on sd-1.5 with ip2p conditioning.",
|
||||
description="ControlNet weights trained on sd-1.5 with ip2p conditioning.",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="canny-sdxl",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="xinsir/controlnet-canny-sdxl-1.0",
|
||||
description="Controlnet weights trained on sdxl-1.0 with canny conditioning, by Xinsir.",
|
||||
source="xinsir/controlNet-canny-sdxl-1.0",
|
||||
description="ControlNet weights trained on sdxl-1.0 with canny conditioning, by Xinsir.",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="depth-sdxl",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="diffusers/controlnet-depth-sdxl-1.0",
|
||||
description="Controlnet weights trained on sdxl-1.0 with depth conditioning.",
|
||||
source="diffusers/controlNet-depth-sdxl-1.0",
|
||||
description="ControlNet weights trained on sdxl-1.0 with depth conditioning.",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="softedge-dexined-sdxl",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="SargeZT/controlnet-sd-xl-1.0-softedge-dexined",
|
||||
description="Controlnet weights trained on sdxl-1.0 with dexined soft edge preprocessing.",
|
||||
source="SargeZT/controlNet-sd-xl-1.0-softedge-dexined",
|
||||
description="ControlNet weights trained on sdxl-1.0 with dexined soft edge preprocessing.",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="depth-16bit-zoe-sdxl",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="SargeZT/controlnet-sd-xl-1.0-depth-16bit-zoe",
|
||||
description="Controlnet weights trained on sdxl-1.0 with Zoe's preprocessor (16 bits).",
|
||||
source="SargeZT/controlNet-sd-xl-1.0-depth-16bit-zoe",
|
||||
description="ControlNet weights trained on sdxl-1.0 with Zoe's preprocessor (16 bits).",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="depth-zoe-sdxl",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="diffusers/controlnet-zoe-depth-sdxl-1.0",
|
||||
description="Controlnet weights trained on sdxl-1.0 with Zoe's preprocessor (32 bits).",
|
||||
source="diffusers/controlNet-zoe-depth-sdxl-1.0",
|
||||
description="ControlNet weights trained on sdxl-1.0 with Zoe's preprocessor (32 bits).",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="openpose-sdxl",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="xinsir/controlnet-openpose-sdxl-1.0",
|
||||
description="Controlnet weights trained on sdxl-1.0 compatible with the DWPose processor by Xinsir.",
|
||||
source="xinsir/controlNet-openpose-sdxl-1.0",
|
||||
description="ControlNet weights trained on sdxl-1.0 compatible with the DWPose processor by Xinsir.",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="scribble-sdxl",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="xinsir/controlnet-scribble-sdxl-1.0",
|
||||
description="Controlnet weights trained on sdxl-1.0 compatible with various lineart processors and black/white sketches by Xinsir.",
|
||||
source="xinsir/controlNet-scribble-sdxl-1.0",
|
||||
description="ControlNet weights trained on sdxl-1.0 compatible with various lineart processors and black/white sketches by Xinsir.",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="tile-sdxl",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="xinsir/controlnet-tile-sdxl-1.0",
|
||||
description="Controlnet weights trained on sdxl-1.0 with tiled image conditioning",
|
||||
source="xinsir/controlNet-tile-sdxl-1.0",
|
||||
description="ControlNet weights trained on sdxl-1.0 with tiled image conditioning",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
# endregion
|
||||
|
@ -7,11 +7,9 @@ from invokeai.backend.stable_diffusion.diffusers_pipeline import ( # noqa: F401
|
||||
StableDiffusionGeneratorPipeline,
|
||||
)
|
||||
from invokeai.backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent # noqa: F401
|
||||
from invokeai.backend.stable_diffusion.seamless import set_seamless # noqa: F401
|
||||
|
||||
__all__ = [
|
||||
"PipelineIntermediateState",
|
||||
"StableDiffusionGeneratorPipeline",
|
||||
"InvokeAIDiffuserComponent",
|
||||
"set_seamless",
|
||||
]
|
||||
|
@ -83,47 +83,47 @@ class DenoiseContext:
|
||||
unet: Optional[UNet2DConditionModel] = None
|
||||
|
||||
# Current state of latent-space image in denoising process.
|
||||
# None until `pre_denoise_loop` callback.
|
||||
# None until `PRE_DENOISE_LOOP` callback.
|
||||
# Shape: [batch, channels, latent_height, latent_width]
|
||||
latents: Optional[torch.Tensor] = None
|
||||
|
||||
# Current denoising step index.
|
||||
# None until `pre_step` callback.
|
||||
# None until `PRE_STEP` callback.
|
||||
step_index: Optional[int] = None
|
||||
|
||||
# Current denoising step timestep.
|
||||
# None until `pre_step` callback.
|
||||
# None until `PRE_STEP` callback.
|
||||
timestep: Optional[torch.Tensor] = None
|
||||
|
||||
# Arguments which will be passed to UNet model.
|
||||
# Available in `pre_unet`/`post_unet` callbacks, otherwise will be None.
|
||||
# Available in `PRE_UNET`/`POST_UNET` callbacks, otherwise will be None.
|
||||
unet_kwargs: Optional[UNetKwargs] = None
|
||||
|
||||
# SchedulerOutput class returned from step function(normally, generated by scheduler).
|
||||
# Supposed to be used only in `post_step` callback, otherwise can be None.
|
||||
# Supposed to be used only in `POST_STEP` callback, otherwise can be None.
|
||||
step_output: Optional[SchedulerOutput] = None
|
||||
|
||||
# Scaled version of `latents`, which will be passed to unet_kwargs initialization.
|
||||
# Available in events inside step(between `pre_step` and `post_stop`).
|
||||
# Available in events inside step(between `PRE_STEP` and `POST_STEP`).
|
||||
# Shape: [batch, channels, latent_height, latent_width]
|
||||
latent_model_input: Optional[torch.Tensor] = None
|
||||
|
||||
# [TMP] Defines on which conditionings current unet call will be runned.
|
||||
# Available in `pre_unet`/`post_unet` callbacks, otherwise will be None.
|
||||
# Available in `PRE_UNET`/`POST_UNET` callbacks, otherwise will be None.
|
||||
conditioning_mode: Optional[ConditioningMode] = None
|
||||
|
||||
# [TMP] Noise predictions from negative conditioning.
|
||||
# Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None.
|
||||
# Available in `POST_COMBINE_NOISE_PREDS` callback, otherwise will be None.
|
||||
# Shape: [batch, channels, latent_height, latent_width]
|
||||
negative_noise_pred: Optional[torch.Tensor] = None
|
||||
|
||||
# [TMP] Noise predictions from positive conditioning.
|
||||
# Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None.
|
||||
# Available in `POST_COMBINE_NOISE_PREDS` callback, otherwise will be None.
|
||||
# Shape: [batch, channels, latent_height, latent_width]
|
||||
positive_noise_pred: Optional[torch.Tensor] = None
|
||||
|
||||
# Combined noise prediction from passed conditionings.
|
||||
# Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None.
|
||||
# Available in `POST_COMBINE_NOISE_PREDS` callback, otherwise will be None.
|
||||
# Shape: [batch, channels, latent_height, latent_width]
|
||||
noise_pred: Optional[torch.Tensor] = None
|
||||
|
||||
|
@ -76,12 +76,12 @@ class StableDiffusionBackend:
|
||||
both_noise_pred = self.run_unet(ctx, ext_manager, ConditioningMode.Both)
|
||||
ctx.negative_noise_pred, ctx.positive_noise_pred = both_noise_pred.chunk(2)
|
||||
|
||||
# ext: override apply_cfg
|
||||
ctx.noise_pred = self.apply_cfg(ctx)
|
||||
# ext: override combine_noise_preds
|
||||
ctx.noise_pred = self.combine_noise_preds(ctx)
|
||||
|
||||
# ext: cfg_rescale [modify_noise_prediction]
|
||||
# TODO: rename
|
||||
ext_manager.run_callback(ExtensionCallbackType.POST_APPLY_CFG, ctx)
|
||||
ext_manager.run_callback(ExtensionCallbackType.POST_COMBINE_NOISE_PREDS, ctx)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
step_output = ctx.scheduler.step(ctx.noise_pred, ctx.timestep, ctx.latents, **ctx.inputs.scheduler_step_kwargs)
|
||||
@ -95,13 +95,15 @@ class StableDiffusionBackend:
|
||||
return step_output
|
||||
|
||||
@staticmethod
|
||||
def apply_cfg(ctx: DenoiseContext) -> torch.Tensor:
|
||||
def combine_noise_preds(ctx: DenoiseContext) -> torch.Tensor:
|
||||
guidance_scale = ctx.inputs.conditioning_data.guidance_scale
|
||||
if isinstance(guidance_scale, list):
|
||||
guidance_scale = guidance_scale[ctx.step_index]
|
||||
|
||||
return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale)
|
||||
# return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred)
|
||||
# Note: Although this `torch.lerp(...)` line is logically equivalent to the current CFG line, it seems to result
|
||||
# in slightly different outputs. It is suspected that this is caused by small precision differences.
|
||||
# return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale)
|
||||
return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred)
|
||||
|
||||
def run_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager, conditioning_mode: ConditioningMode):
|
||||
sample = ctx.latent_model_input
|
||||
|
@ -9,4 +9,4 @@ class ExtensionCallbackType(Enum):
|
||||
POST_STEP = "post_step"
|
||||
PRE_UNET = "pre_unet"
|
||||
POST_UNET = "post_unet"
|
||||
POST_APPLY_CFG = "post_apply_cfg"
|
||||
POST_COMBINE_NOISE_PREDS = "post_combine_noise_preds"
|
||||
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Callable, Dict, List
|
||||
from typing import TYPE_CHECKING, Callable, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from diffusers import UNet2DConditionModel
|
||||
@ -52,9 +52,9 @@ class ExtensionBase:
|
||||
return self._callbacks
|
||||
|
||||
@contextmanager
|
||||
def patch_extension(self, context: DenoiseContext):
|
||||
def patch_extension(self, ctx: DenoiseContext):
|
||||
yield None
|
||||
|
||||
@contextmanager
|
||||
def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
|
||||
def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
|
||||
yield None
|
||||
|
158
invokeai/backend/stable_diffusion/extensions/controlnet.py
Normal file
158
invokeai/backend/stable_diffusion/extensions/controlnet.py
Normal file
@ -0,0 +1,158 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from PIL.Image import Image
|
||||
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, prepare_control_image
|
||||
from invokeai.backend.stable_diffusion.denoise_context import UNetKwargs
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode
|
||||
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
||||
from invokeai.backend.util.hotfixes import ControlNetModel
|
||||
|
||||
|
||||
class ControlNetExt(ExtensionBase):
|
||||
def __init__(
|
||||
self,
|
||||
model: ControlNetModel,
|
||||
image: Image,
|
||||
weight: Union[float, List[float]],
|
||||
begin_step_percent: float,
|
||||
end_step_percent: float,
|
||||
control_mode: CONTROLNET_MODE_VALUES,
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES,
|
||||
):
|
||||
super().__init__()
|
||||
self._model = model
|
||||
self._image = image
|
||||
self._weight = weight
|
||||
self._begin_step_percent = begin_step_percent
|
||||
self._end_step_percent = end_step_percent
|
||||
self._control_mode = control_mode
|
||||
self._resize_mode = resize_mode
|
||||
|
||||
self._image_tensor: Optional[torch.Tensor] = None
|
||||
|
||||
@contextmanager
|
||||
def patch_extension(self, ctx: DenoiseContext):
|
||||
original_processors = self._model.attn_processors
|
||||
try:
|
||||
self._model.set_attn_processor(ctx.inputs.attention_processor_cls())
|
||||
|
||||
yield None
|
||||
finally:
|
||||
self._model.set_attn_processor(original_processors)
|
||||
|
||||
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
|
||||
def resize_image(self, ctx: DenoiseContext):
|
||||
_, _, latent_height, latent_width = ctx.latents.shape
|
||||
image_height = latent_height * LATENT_SCALE_FACTOR
|
||||
image_width = latent_width * LATENT_SCALE_FACTOR
|
||||
|
||||
self._image_tensor = prepare_control_image(
|
||||
image=self._image,
|
||||
do_classifier_free_guidance=False,
|
||||
width=image_width,
|
||||
height=image_height,
|
||||
device=ctx.latents.device,
|
||||
dtype=ctx.latents.dtype,
|
||||
control_mode=self._control_mode,
|
||||
resize_mode=self._resize_mode,
|
||||
)
|
||||
|
||||
@callback(ExtensionCallbackType.PRE_UNET)
|
||||
def pre_unet_step(self, ctx: DenoiseContext):
|
||||
# skip if model not active in current step
|
||||
total_steps = len(ctx.inputs.timesteps)
|
||||
first_step = math.floor(self._begin_step_percent * total_steps)
|
||||
last_step = math.ceil(self._end_step_percent * total_steps)
|
||||
if ctx.step_index < first_step or ctx.step_index > last_step:
|
||||
return
|
||||
|
||||
# convert mode to internal flags
|
||||
soft_injection = self._control_mode in ["more_prompt", "more_control"]
|
||||
cfg_injection = self._control_mode in ["more_control", "unbalanced"]
|
||||
|
||||
# no negative conditioning in cfg_injection mode
|
||||
if cfg_injection:
|
||||
if ctx.conditioning_mode == ConditioningMode.Negative:
|
||||
return
|
||||
down_samples, mid_sample = self._run(ctx, soft_injection, ConditioningMode.Positive)
|
||||
|
||||
if ctx.conditioning_mode == ConditioningMode.Both:
|
||||
# add zeros as samples for negative conditioning
|
||||
down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples]
|
||||
mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample])
|
||||
|
||||
else:
|
||||
down_samples, mid_sample = self._run(ctx, soft_injection, ctx.conditioning_mode)
|
||||
|
||||
if (
|
||||
ctx.unet_kwargs.down_block_additional_residuals is None
|
||||
and ctx.unet_kwargs.mid_block_additional_residual is None
|
||||
):
|
||||
ctx.unet_kwargs.down_block_additional_residuals = down_samples
|
||||
ctx.unet_kwargs.mid_block_additional_residual = mid_sample
|
||||
else:
|
||||
# add controlnet outputs together if have multiple controlnets
|
||||
ctx.unet_kwargs.down_block_additional_residuals = [
|
||||
samples_prev + samples_curr
|
||||
for samples_prev, samples_curr in zip(
|
||||
ctx.unet_kwargs.down_block_additional_residuals, down_samples, strict=True
|
||||
)
|
||||
]
|
||||
ctx.unet_kwargs.mid_block_additional_residual += mid_sample
|
||||
|
||||
def _run(self, ctx: DenoiseContext, soft_injection: bool, conditioning_mode: ConditioningMode):
|
||||
total_steps = len(ctx.inputs.timesteps)
|
||||
|
||||
model_input = ctx.latent_model_input
|
||||
image_tensor = self._image_tensor
|
||||
if conditioning_mode == ConditioningMode.Both:
|
||||
model_input = torch.cat([model_input] * 2)
|
||||
image_tensor = torch.cat([image_tensor] * 2)
|
||||
|
||||
cn_unet_kwargs = UNetKwargs(
|
||||
sample=model_input,
|
||||
timestep=ctx.timestep,
|
||||
encoder_hidden_states=None, # set later by conditioning
|
||||
cross_attention_kwargs=dict( # noqa: C408
|
||||
percent_through=ctx.step_index / total_steps,
|
||||
),
|
||||
)
|
||||
|
||||
ctx.inputs.conditioning_data.to_unet_kwargs(cn_unet_kwargs, conditioning_mode=conditioning_mode)
|
||||
|
||||
# get static weight, or weight corresponding to current step
|
||||
weight = self._weight
|
||||
if isinstance(weight, list):
|
||||
weight = weight[ctx.step_index]
|
||||
|
||||
tmp_kwargs = vars(cn_unet_kwargs)
|
||||
|
||||
# Remove kwargs not related to ControlNet unet
|
||||
# ControlNet guidance fields
|
||||
del tmp_kwargs["down_block_additional_residuals"]
|
||||
del tmp_kwargs["mid_block_additional_residual"]
|
||||
|
||||
# T2i Adapter guidance fields
|
||||
del tmp_kwargs["down_intrablock_additional_residuals"]
|
||||
|
||||
# controlnet(s) inference
|
||||
down_samples, mid_sample = self._model(
|
||||
controlnet_cond=image_tensor,
|
||||
conditioning_scale=weight, # controlnet specific, NOT the guidance scale
|
||||
guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel
|
||||
return_dict=False,
|
||||
**vars(cn_unet_kwargs),
|
||||
)
|
||||
|
||||
return down_samples, mid_sample
|
35
invokeai/backend/stable_diffusion/extensions/freeu.py
Normal file
35
invokeai/backend/stable_diffusion/extensions/freeu.py
Normal file
@ -0,0 +1,35 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING, Dict, Optional
|
||||
|
||||
import torch
|
||||
from diffusers import UNet2DConditionModel
|
||||
|
||||
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.shared.models import FreeUConfig
|
||||
|
||||
|
||||
class FreeUExt(ExtensionBase):
|
||||
def __init__(
|
||||
self,
|
||||
freeu_config: FreeUConfig,
|
||||
):
|
||||
super().__init__()
|
||||
self._freeu_config = freeu_config
|
||||
|
||||
@contextmanager
|
||||
def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
|
||||
unet.enable_freeu(
|
||||
b1=self._freeu_config.b1,
|
||||
b2=self._freeu_config.b2,
|
||||
s1=self._freeu_config.s1,
|
||||
s2=self._freeu_config.s2,
|
||||
)
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
unet.disable_freeu()
|
36
invokeai/backend/stable_diffusion/extensions/rescale_cfg.py
Normal file
36
invokeai/backend/stable_diffusion/extensions/rescale_cfg.py
Normal file
@ -0,0 +1,36 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
||||
|
||||
|
||||
class RescaleCFGExt(ExtensionBase):
|
||||
def __init__(self, rescale_multiplier: float):
|
||||
super().__init__()
|
||||
self._rescale_multiplier = rescale_multiplier
|
||||
|
||||
@staticmethod
|
||||
def _rescale_cfg(total_noise_pred: torch.Tensor, pos_noise_pred: torch.Tensor, multiplier: float = 0.7):
|
||||
"""Implementation of Algorithm 2 from https://arxiv.org/pdf/2305.08891.pdf."""
|
||||
ro_pos = torch.std(pos_noise_pred, dim=(1, 2, 3), keepdim=True)
|
||||
ro_cfg = torch.std(total_noise_pred, dim=(1, 2, 3), keepdim=True)
|
||||
|
||||
x_rescaled = total_noise_pred * (ro_pos / ro_cfg)
|
||||
x_final = multiplier * x_rescaled + (1.0 - multiplier) * total_noise_pred
|
||||
return x_final
|
||||
|
||||
@callback(ExtensionCallbackType.POST_COMBINE_NOISE_PREDS)
|
||||
def rescale_noise_pred(self, ctx: DenoiseContext):
|
||||
if self._rescale_multiplier > 0:
|
||||
ctx.noise_pred = self._rescale_cfg(
|
||||
ctx.noise_pred,
|
||||
ctx.positive_noise_pred,
|
||||
self._rescale_multiplier,
|
||||
)
|
71
invokeai/backend/stable_diffusion/extensions/seamless.py
Normal file
71
invokeai/backend/stable_diffusion/extensions/seamless.py
Normal file
@ -0,0 +1,71 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from diffusers import UNet2DConditionModel
|
||||
from diffusers.models.lora import LoRACompatibleConv
|
||||
|
||||
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
|
||||
|
||||
|
||||
class SeamlessExt(ExtensionBase):
|
||||
def __init__(
|
||||
self,
|
||||
seamless_axes: List[str],
|
||||
):
|
||||
super().__init__()
|
||||
self._seamless_axes = seamless_axes
|
||||
|
||||
@contextmanager
|
||||
def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
|
||||
with self.static_patch_model(
|
||||
model=unet,
|
||||
seamless_axes=self._seamless_axes,
|
||||
):
|
||||
yield
|
||||
|
||||
@staticmethod
|
||||
@contextmanager
|
||||
def static_patch_model(
|
||||
model: torch.nn.Module,
|
||||
seamless_axes: List[str],
|
||||
):
|
||||
if not seamless_axes:
|
||||
yield
|
||||
return
|
||||
|
||||
x_mode = "circular" if "x" in seamless_axes else "constant"
|
||||
y_mode = "circular" if "y" in seamless_axes else "constant"
|
||||
|
||||
# override conv_forward
|
||||
# https://github.com/huggingface/diffusers/issues/556#issuecomment-1993287019
|
||||
def _conv_forward_asymmetric(
|
||||
self, input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None
|
||||
):
|
||||
self.paddingX = (self._reversed_padding_repeated_twice[0], self._reversed_padding_repeated_twice[1], 0, 0)
|
||||
self.paddingY = (0, 0, self._reversed_padding_repeated_twice[2], self._reversed_padding_repeated_twice[3])
|
||||
working = torch.nn.functional.pad(input, self.paddingX, mode=x_mode)
|
||||
working = torch.nn.functional.pad(working, self.paddingY, mode=y_mode)
|
||||
return torch.nn.functional.conv2d(
|
||||
working, weight, bias, self.stride, torch.nn.modules.utils._pair(0), self.dilation, self.groups
|
||||
)
|
||||
|
||||
original_layers: List[Tuple[nn.Conv2d, Callable]] = []
|
||||
try:
|
||||
for layer in model.modules():
|
||||
if not isinstance(layer, torch.nn.Conv2d):
|
||||
continue
|
||||
|
||||
if isinstance(layer, LoRACompatibleConv) and layer.lora_layer is None:
|
||||
layer.lora_layer = lambda *x: 0
|
||||
original_layers.append((layer, layer._conv_forward))
|
||||
layer._conv_forward = _conv_forward_asymmetric.__get__(layer, torch.nn.Conv2d)
|
||||
|
||||
yield
|
||||
|
||||
finally:
|
||||
for layer, orig_conv_forward in original_layers:
|
||||
layer._conv_forward = orig_conv_forward
|
120
invokeai/backend/stable_diffusion/extensions/t2i_adapter.py
Normal file
120
invokeai/backend/stable_diffusion/extensions/t2i_adapter.py
Normal file
@ -0,0 +1,120 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from diffusers import T2IAdapter
|
||||
from PIL.Image import Image
|
||||
|
||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||
from invokeai.backend.model_manager import BaseModelType
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode
|
||||
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.controlnet_utils import CONTROLNET_RESIZE_VALUES
|
||||
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
||||
|
||||
|
||||
class T2IAdapterExt(ExtensionBase):
|
||||
def __init__(
|
||||
self,
|
||||
node_context: InvocationContext,
|
||||
model_id: ModelIdentifierField,
|
||||
image: Image,
|
||||
weight: Union[float, List[float]],
|
||||
begin_step_percent: float,
|
||||
end_step_percent: float,
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES,
|
||||
):
|
||||
super().__init__()
|
||||
self._node_context = node_context
|
||||
self._model_id = model_id
|
||||
self._image = image
|
||||
self._weight = weight
|
||||
self._resize_mode = resize_mode
|
||||
self._begin_step_percent = begin_step_percent
|
||||
self._end_step_percent = end_step_percent
|
||||
|
||||
self._adapter_state: Optional[List[torch.Tensor]] = None
|
||||
|
||||
# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
|
||||
model_config = self._node_context.models.get_config(self._model_id.key)
|
||||
if model_config.base == BaseModelType.StableDiffusion1:
|
||||
self._max_unet_downscale = 8
|
||||
elif model_config.base == BaseModelType.StableDiffusionXL:
|
||||
self._max_unet_downscale = 4
|
||||
else:
|
||||
raise ValueError(f"Unexpected T2I-Adapter base model type: '{model_config.base}'.")
|
||||
|
||||
@callback(ExtensionCallbackType.SETUP)
|
||||
def setup(self, ctx: DenoiseContext):
|
||||
t2i_model: T2IAdapter
|
||||
with self._node_context.models.load(self._model_id) as t2i_model:
|
||||
_, _, latents_height, latents_width = ctx.inputs.orig_latents.shape
|
||||
|
||||
self._adapter_state = self._run_model(
|
||||
model=t2i_model,
|
||||
image=self._image,
|
||||
latents_height=latents_height,
|
||||
latents_width=latents_width,
|
||||
)
|
||||
|
||||
def _run_model(
|
||||
self,
|
||||
model: T2IAdapter,
|
||||
image: Image,
|
||||
latents_height: int,
|
||||
latents_width: int,
|
||||
):
|
||||
# Resize the T2I-Adapter input image.
|
||||
# We select the resize dimensions so that after the T2I-Adapter's total_downscale_factor is applied, the
|
||||
# result will match the latent image's dimensions after max_unet_downscale is applied.
|
||||
input_height = latents_height // self._max_unet_downscale * model.total_downscale_factor
|
||||
input_width = latents_width // self._max_unet_downscale * model.total_downscale_factor
|
||||
|
||||
# Note: We have hard-coded `do_classifier_free_guidance=False`. This is because we only want to prepare
|
||||
# a single image. If CFG is enabled, we will duplicate the resultant tensor after applying the
|
||||
# T2I-Adapter model.
|
||||
#
|
||||
# Note: We re-use the `prepare_control_image(...)` from ControlNet for T2I-Adapter, because it has many
|
||||
# of the same requirements (e.g. preserving binary masks during resize).
|
||||
t2i_image = prepare_control_image(
|
||||
image=image,
|
||||
do_classifier_free_guidance=False,
|
||||
width=input_width,
|
||||
height=input_height,
|
||||
num_channels=model.config["in_channels"],
|
||||
device=model.device,
|
||||
dtype=model.dtype,
|
||||
resize_mode=self._resize_mode,
|
||||
)
|
||||
|
||||
return model(t2i_image)
|
||||
|
||||
@callback(ExtensionCallbackType.PRE_UNET)
|
||||
def pre_unet_step(self, ctx: DenoiseContext):
|
||||
# skip if model not active in current step
|
||||
total_steps = len(ctx.inputs.timesteps)
|
||||
first_step = math.floor(self._begin_step_percent * total_steps)
|
||||
last_step = math.ceil(self._end_step_percent * total_steps)
|
||||
if ctx.step_index < first_step or ctx.step_index > last_step:
|
||||
return
|
||||
|
||||
weight = self._weight
|
||||
if isinstance(weight, list):
|
||||
weight = weight[ctx.step_index]
|
||||
|
||||
adapter_state = self._adapter_state
|
||||
if ctx.conditioning_mode == ConditioningMode.Both:
|
||||
adapter_state = [torch.cat([v] * 2) for v in adapter_state]
|
||||
|
||||
if ctx.unet_kwargs.down_intrablock_additional_residuals is None:
|
||||
ctx.unet_kwargs.down_intrablock_additional_residuals = [v * weight for v in adapter_state]
|
||||
else:
|
||||
for i, value in enumerate(adapter_state):
|
||||
ctx.unet_kwargs.down_intrablock_additional_residuals[i] += value * weight
|
@ -52,20 +52,24 @@ class ExtensionsManager:
|
||||
cb.function(ctx)
|
||||
|
||||
@contextmanager
|
||||
def patch_extensions(self, context: DenoiseContext):
|
||||
def patch_extensions(self, ctx: DenoiseContext):
|
||||
if self._is_canceled and self._is_canceled():
|
||||
raise CanceledException
|
||||
|
||||
with ExitStack() as exit_stack:
|
||||
for ext in self._extensions:
|
||||
exit_stack.enter_context(ext.patch_extension(context))
|
||||
exit_stack.enter_context(ext.patch_extension(ctx))
|
||||
|
||||
yield None
|
||||
|
||||
@contextmanager
|
||||
def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
|
||||
def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
|
||||
if self._is_canceled and self._is_canceled():
|
||||
raise CanceledException
|
||||
|
||||
# TODO: create logic in PR with extension which uses it
|
||||
# TODO: create weight patch logic in PR with extension which uses it
|
||||
with ExitStack() as exit_stack:
|
||||
for ext in self._extensions:
|
||||
exit_stack.enter_context(ext.patch_unet(unet, cached_weights))
|
||||
|
||||
yield None
|
||||
|
@ -1,51 +0,0 @@
|
||||
from contextlib import contextmanager
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
|
||||
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
|
||||
from diffusers.models.lora import LoRACompatibleConv
|
||||
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||
|
||||
|
||||
@contextmanager
|
||||
def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL, AutoencoderTiny], seamless_axes: List[str]):
|
||||
if not seamless_axes:
|
||||
yield
|
||||
return
|
||||
|
||||
# override conv_forward
|
||||
# https://github.com/huggingface/diffusers/issues/556#issuecomment-1993287019
|
||||
def _conv_forward_asymmetric(self, input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None):
|
||||
self.paddingX = (self._reversed_padding_repeated_twice[0], self._reversed_padding_repeated_twice[1], 0, 0)
|
||||
self.paddingY = (0, 0, self._reversed_padding_repeated_twice[2], self._reversed_padding_repeated_twice[3])
|
||||
working = torch.nn.functional.pad(input, self.paddingX, mode=x_mode)
|
||||
working = torch.nn.functional.pad(working, self.paddingY, mode=y_mode)
|
||||
return torch.nn.functional.conv2d(
|
||||
working, weight, bias, self.stride, torch.nn.modules.utils._pair(0), self.dilation, self.groups
|
||||
)
|
||||
|
||||
original_layers: List[Tuple[nn.Conv2d, Callable]] = []
|
||||
|
||||
try:
|
||||
x_mode = "circular" if "x" in seamless_axes else "constant"
|
||||
y_mode = "circular" if "y" in seamless_axes else "constant"
|
||||
|
||||
conv_layers: List[torch.nn.Conv2d] = []
|
||||
|
||||
for module in model.modules():
|
||||
if isinstance(module, torch.nn.Conv2d):
|
||||
conv_layers.append(module)
|
||||
|
||||
for layer in conv_layers:
|
||||
if isinstance(layer, LoRACompatibleConv) and layer.lora_layer is None:
|
||||
layer.lora_layer = lambda *x: 0
|
||||
original_layers.append((layer, layer._conv_forward))
|
||||
layer._conv_forward = _conv_forward_asymmetric.__get__(layer, torch.nn.Conv2d)
|
||||
|
||||
yield
|
||||
|
||||
finally:
|
||||
for layer, orig_conv_forward in original_layers:
|
||||
layer._conv_forward = orig_conv_forward
|
@ -155,5 +155,8 @@
|
||||
"vite-plugin-eslint": "^1.8.1",
|
||||
"vite-tsconfig-paths": "^4.3.2",
|
||||
"vitest": "^1.6.0"
|
||||
},
|
||||
"engines": {
|
||||
"pnpm": "8"
|
||||
}
|
||||
}
|
||||
|
@ -77,10 +77,6 @@
|
||||
"title": "استعادة الوجوه",
|
||||
"desc": "استعادة الصورة الحالية"
|
||||
},
|
||||
"upscale": {
|
||||
"title": "تحسين الحجم",
|
||||
"desc": "تحسين حجم الصورة الحالية"
|
||||
},
|
||||
"showInfo": {
|
||||
"title": "عرض المعلومات",
|
||||
"desc": "عرض معلومات البيانات الخاصة بالصورة الحالية"
|
||||
@ -255,8 +251,6 @@
|
||||
"type": "نوع",
|
||||
"strength": "قوة",
|
||||
"upscaling": "تصغير",
|
||||
"upscale": "تصغير",
|
||||
"upscaleImage": "تصغير الصورة",
|
||||
"scale": "مقياس",
|
||||
"imageFit": "ملائمة الصورة الأولية لحجم الخرج",
|
||||
"scaleBeforeProcessing": "تحجيم قبل المعالجة",
|
||||
|
@ -187,10 +187,6 @@
|
||||
"title": "Gesicht restaurieren",
|
||||
"desc": "Das aktuelle Bild restaurieren"
|
||||
},
|
||||
"upscale": {
|
||||
"title": "Hochskalieren",
|
||||
"desc": "Das aktuelle Bild hochskalieren"
|
||||
},
|
||||
"showInfo": {
|
||||
"title": "Info anzeigen",
|
||||
"desc": "Metadaten des aktuellen Bildes anzeigen"
|
||||
@ -433,8 +429,6 @@
|
||||
"type": "Art",
|
||||
"strength": "Stärke",
|
||||
"upscaling": "Hochskalierung",
|
||||
"upscale": "Hochskalieren (Shift + U)",
|
||||
"upscaleImage": "Bild hochskalieren",
|
||||
"scale": "Maßstab",
|
||||
"imageFit": "Ausgangsbild an Ausgabegröße anpassen",
|
||||
"scaleBeforeProcessing": "Skalieren vor der Verarbeitung",
|
||||
|
@ -32,12 +32,14 @@
|
||||
"deleteBoardAndImages": "Delete Board and Images",
|
||||
"deleteBoardOnly": "Delete Board Only",
|
||||
"deletedBoardsCannotbeRestored": "Deleted boards cannot be restored",
|
||||
"hideBoards": "Hide Boards",
|
||||
"loading": "Loading...",
|
||||
"menuItemAutoAdd": "Auto-add to this Board",
|
||||
"move": "Move",
|
||||
"movingImagesToBoard_one": "Moving {{count}} image to board:",
|
||||
"movingImagesToBoard_other": "Moving {{count}} images to board:",
|
||||
"myBoard": "My Board",
|
||||
"noBoards": "No {{boardType}} Boards",
|
||||
"noMatching": "No matching Boards",
|
||||
"private": "Private Boards",
|
||||
"searchBoard": "Search Boards...",
|
||||
@ -46,6 +48,7 @@
|
||||
"topMessage": "This board contains images used in the following features:",
|
||||
"unarchiveBoard": "Unarchive Board",
|
||||
"uncategorized": "Uncategorized",
|
||||
"viewBoards": "View Boards",
|
||||
"downloadBoard": "Download Board",
|
||||
"imagesWithCount_one": "{{count}} image",
|
||||
"imagesWithCount_other": "{{count}} images",
|
||||
@ -102,6 +105,7 @@
|
||||
"negativePrompt": "Negative Prompt",
|
||||
"discordLabel": "Discord",
|
||||
"dontAskMeAgain": "Don't ask me again",
|
||||
"dontShowMeThese": "Don't show me these",
|
||||
"editor": "Editor",
|
||||
"error": "Error",
|
||||
"file": "File",
|
||||
@ -373,10 +377,14 @@
|
||||
"displayBoardSearch": "Display Board Search",
|
||||
"displaySearch": "Display Search",
|
||||
"download": "Download",
|
||||
"exitBoardSearch": "Exit Board Search",
|
||||
"exitSearch": "Exit Search",
|
||||
"featuresWillReset": "If you delete this image, those features will immediately be reset.",
|
||||
"galleryImageSize": "Image Size",
|
||||
"gallerySettings": "Gallery Settings",
|
||||
"go": "Go",
|
||||
"image": "image",
|
||||
"jump": "Jump",
|
||||
"loading": "Loading",
|
||||
"loadMore": "Load More",
|
||||
"newestFirst": "Newest First",
|
||||
@ -636,9 +644,9 @@
|
||||
"title": "Undo Stroke"
|
||||
},
|
||||
"unifiedCanvasHotkeys": "Unified Canvas",
|
||||
"upscale": {
|
||||
"desc": "Upscale the current image",
|
||||
"title": "Upscale"
|
||||
"postProcess": {
|
||||
"desc": "Process the current image using the selected post-processing model",
|
||||
"title": "Process Image"
|
||||
},
|
||||
"toggleViewer": {
|
||||
"desc": "Switches between the Image Viewer and workspace for the current tab.",
|
||||
@ -1027,6 +1035,7 @@
|
||||
"imageActions": "Image Actions",
|
||||
"sendToImg2Img": "Send to Image to Image",
|
||||
"sendToUnifiedCanvas": "Send To Unified Canvas",
|
||||
"sendToUpscale": "Send To Upscale",
|
||||
"showOptionsPanel": "Show Side Panel (O or T)",
|
||||
"shuffle": "Shuffle Seed",
|
||||
"steps": "Steps",
|
||||
@ -1034,8 +1043,8 @@
|
||||
"symmetry": "Symmetry",
|
||||
"tileSize": "Tile Size",
|
||||
"type": "Type",
|
||||
"upscale": "Upscale (Shift + U)",
|
||||
"upscaleImage": "Upscale Image",
|
||||
"postProcessing": "Post-Processing (Shift + U)",
|
||||
"processImage": "Process Image",
|
||||
"upscaling": "Upscaling",
|
||||
"useAll": "Use All",
|
||||
"useSize": "Use Size",
|
||||
@ -1091,6 +1100,8 @@
|
||||
"displayInProgress": "Display Progress Images",
|
||||
"enableImageDebugging": "Enable Image Debugging",
|
||||
"enableInformationalPopovers": "Enable Informational Popovers",
|
||||
"informationalPopoversDisabled": "Informational Popovers Disabled",
|
||||
"informationalPopoversDisabledDesc": "Informational popovers have been disabled. Enable them in Settings.",
|
||||
"enableInvisibleWatermark": "Enable Invisible Watermark",
|
||||
"enableNSFWChecker": "Enable NSFW Checker",
|
||||
"general": "General",
|
||||
@ -1498,6 +1509,30 @@
|
||||
"seamlessTilingYAxis": {
|
||||
"heading": "Seamless Tiling Y Axis",
|
||||
"paragraphs": ["Seamlessly tile an image along the vertical axis."]
|
||||
},
|
||||
"upscaleModel": {
|
||||
"heading": "Upscale Model",
|
||||
"paragraphs": [
|
||||
"The upscale model scales the image to the output size before details are added. Any supported upscale model may be used, but some are specialized for different kinds of images, like photos or line drawings."
|
||||
]
|
||||
},
|
||||
"scale": {
|
||||
"heading": "Scale",
|
||||
"paragraphs": [
|
||||
"Scale controls the output image size, and is based on a multiple of the input image resolution. For example a 2x upscale on a 1024x1024 image would produce a 2048 x 2048 output."
|
||||
]
|
||||
},
|
||||
"creativity": {
|
||||
"heading": "Creativity",
|
||||
"paragraphs": [
|
||||
"Creativity controls the amount of freedom granted to the model when adding details. Low creativity stays close to the original image, while high creativity allows for more change. When using a prompt, high creativity increases the influence of the prompt."
|
||||
]
|
||||
},
|
||||
"structure": {
|
||||
"heading": "Structure",
|
||||
"paragraphs": [
|
||||
"Structure controls how closely the output image will keep to the layout of the original. Low structure allows major changes, while high structure strictly maintains the original composition and layout."
|
||||
]
|
||||
}
|
||||
},
|
||||
"unifiedCanvas": {
|
||||
@ -1640,6 +1675,27 @@
|
||||
"layers_one": "Layer",
|
||||
"layers_other": "Layers"
|
||||
},
|
||||
"upscaling": {
|
||||
"creativity": "Creativity",
|
||||
"structure": "Structure",
|
||||
"upscaleModel": "Upscale Model",
|
||||
"postProcessingModel": "Post-Processing Model",
|
||||
"scale": "Scale",
|
||||
"postProcessingMissingModelWarning": "Visit the <LinkComponent>Model Manager</LinkComponent> to install a post-processing (image to image) model.",
|
||||
"missingModelsWarning": "Visit the <LinkComponent>Model Manager</LinkComponent> to install the required models:",
|
||||
"mainModelDesc": "Main model (SD1.5 or SDXL architecture)",
|
||||
"tileControlNetModelDesc": "Tile ControlNet model for the chosen main model architecture",
|
||||
"upscaleModelDesc": "Upscale (image to image) model",
|
||||
"missingUpscaleInitialImage": "Missing initial image for upscaling",
|
||||
"missingUpscaleModel": "Missing upscale model",
|
||||
"missingTileControlNetModel": "No valid tile ControlNet models installed"
|
||||
},
|
||||
"upsell": {
|
||||
"inviteTeammates": "Invite Teammates",
|
||||
"professional": "Professional",
|
||||
"professionalUpsell": "Available in Invoke’s Professional Edition. Click here or visit invoke.com/pricing for more details.",
|
||||
"shareAccess": "Share Access"
|
||||
},
|
||||
"ui": {
|
||||
"tabs": {
|
||||
"generation": "Generation",
|
||||
@ -1651,7 +1707,9 @@
|
||||
"models": "Models",
|
||||
"modelsTab": "$t(ui.tabs.models) $t(common.tab)",
|
||||
"queue": "Queue",
|
||||
"queueTab": "$t(ui.tabs.queue) $t(common.tab)"
|
||||
"queueTab": "$t(ui.tabs.queue) $t(common.tab)",
|
||||
"upscaling": "Upscaling",
|
||||
"upscalingTab": "$t(ui.tabs.upscaling) $t(common.tab)"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -151,10 +151,6 @@
|
||||
"title": "Restaurar rostros",
|
||||
"desc": "Restaurar rostros en la imagen actual"
|
||||
},
|
||||
"upscale": {
|
||||
"title": "Aumentar resolución",
|
||||
"desc": "Aumentar la resolución de la imagen actual"
|
||||
},
|
||||
"showInfo": {
|
||||
"title": "Mostrar información",
|
||||
"desc": "Mostar metadatos de la imagen actual"
|
||||
@ -360,8 +356,6 @@
|
||||
"type": "Tipo",
|
||||
"strength": "Fuerza",
|
||||
"upscaling": "Aumento de resolución",
|
||||
"upscale": "Aumentar resolución",
|
||||
"upscaleImage": "Aumentar la resolución de la imagen",
|
||||
"scale": "Escala",
|
||||
"imageFit": "Ajuste tamaño de imagen inicial al tamaño objetivo",
|
||||
"scaleBeforeProcessing": "Redimensionar antes de procesar",
|
||||
@ -408,7 +402,12 @@
|
||||
"showProgressInViewer": "Mostrar las imágenes del progreso en el visor",
|
||||
"ui": "Interfaz del usuario",
|
||||
"generation": "Generación",
|
||||
"beta": "Beta"
|
||||
"beta": "Beta",
|
||||
"reloadingIn": "Recargando en",
|
||||
"intermediatesClearedFailed": "Error limpiando los intermediarios",
|
||||
"intermediatesCleared_one": "Borrado {{count}} intermediario",
|
||||
"intermediatesCleared_many": "Borrados {{count}} intermediarios",
|
||||
"intermediatesCleared_other": "Borrados {{count}} intermediarios"
|
||||
},
|
||||
"toast": {
|
||||
"uploadFailed": "Error al subir archivo",
|
||||
@ -426,7 +425,12 @@
|
||||
"parameterSet": "Conjunto de parámetros",
|
||||
"parameterNotSet": "Parámetro no configurado",
|
||||
"problemCopyingImage": "No se puede copiar la imagen",
|
||||
"errorCopied": "Error al copiar"
|
||||
"errorCopied": "Error al copiar",
|
||||
"baseModelChanged": "Modelo base cambiado",
|
||||
"addedToBoard": "Añadido al tablero",
|
||||
"baseModelChangedCleared_one": "Borrado o desactivado {{count}} submodelo incompatible",
|
||||
"baseModelChangedCleared_many": "Borrados o desactivados {{count}} submodelos incompatibles",
|
||||
"baseModelChangedCleared_other": "Borrados o desactivados {{count}} submodelos incompatibles"
|
||||
},
|
||||
"tooltip": {
|
||||
"feature": {
|
||||
@ -540,7 +544,13 @@
|
||||
"downloadBoard": "Descargar panel",
|
||||
"deleteBoardOnly": "Borrar solo el panel",
|
||||
"myBoard": "Mi panel",
|
||||
"noMatching": "No hay paneles que coincidan"
|
||||
"noMatching": "No hay paneles que coincidan",
|
||||
"imagesWithCount_one": "{{count}} imagen",
|
||||
"imagesWithCount_many": "{{count}} imágenes",
|
||||
"imagesWithCount_other": "{{count}} imágenes",
|
||||
"assetsWithCount_one": "{{count}} activo",
|
||||
"assetsWithCount_many": "{{count}} activos",
|
||||
"assetsWithCount_other": "{{count}} activos"
|
||||
},
|
||||
"accordions": {
|
||||
"compositing": {
|
||||
@ -590,6 +600,27 @@
|
||||
"balanced": "Equilibrado",
|
||||
"beginEndStepPercent": "Inicio / Final Porcentaje de pasos",
|
||||
"detectResolution": "Detectar resolución",
|
||||
"beginEndStepPercentShort": "Inicio / Final %"
|
||||
"beginEndStepPercentShort": "Inicio / Final %",
|
||||
"t2i_adapter": "$t(controlnet.controlAdapter_one) #{{number}} ($t(common.t2iAdapter))",
|
||||
"controlnet": "$t(controlnet.controlAdapter_one) #{{number}} ($t(common.controlNet))",
|
||||
"ip_adapter": "$t(controlnet.controlAdapter_one) #{{number}} ($t(common.ipAdapter))",
|
||||
"addControlNet": "Añadir $t(common.controlNet)",
|
||||
"addIPAdapter": "Añadir $t(common.ipAdapter)",
|
||||
"controlAdapter_one": "Adaptador de control",
|
||||
"controlAdapter_many": "Adaptadores de control",
|
||||
"controlAdapter_other": "Adaptadores de control",
|
||||
"addT2IAdapter": "Añadir $t(common.t2iAdapter)"
|
||||
},
|
||||
"queue": {
|
||||
"back": "Atrás",
|
||||
"front": "Delante",
|
||||
"batchQueuedDesc_one": "Se agregó {{count}} sesión a {{direction}} la cola",
|
||||
"batchQueuedDesc_many": "Se agregaron {{count}} sesiones a {{direction}} la cola",
|
||||
"batchQueuedDesc_other": "Se agregaron {{count}} sesiones a {{direction}} la cola"
|
||||
},
|
||||
"upsell": {
|
||||
"inviteTeammates": "Invitar compañeros de equipo",
|
||||
"shareAccess": "Compartir acceso",
|
||||
"professionalUpsell": "Disponible en la edición profesional de Invoke. Haz clic aquí o visita invoke.com/pricing para obtener más detalles."
|
||||
}
|
||||
}
|
||||
|
@ -130,10 +130,6 @@
|
||||
"title": "Restaurer les visages",
|
||||
"desc": "Restaurer l'image actuelle"
|
||||
},
|
||||
"upscale": {
|
||||
"title": "Agrandir",
|
||||
"desc": "Agrandir l'image actuelle"
|
||||
},
|
||||
"showInfo": {
|
||||
"title": "Afficher les informations",
|
||||
"desc": "Afficher les informations de métadonnées de l'image actuelle"
|
||||
@ -308,8 +304,6 @@
|
||||
"type": "Type",
|
||||
"strength": "Force",
|
||||
"upscaling": "Agrandissement",
|
||||
"upscale": "Agrandir",
|
||||
"upscaleImage": "Image en Agrandissement",
|
||||
"scale": "Echelle",
|
||||
"imageFit": "Ajuster Image Initiale à la Taille de Sortie",
|
||||
"scaleBeforeProcessing": "Echelle Avant Traitement",
|
||||
|
@ -90,10 +90,6 @@
|
||||
"desc": "שחזור התמונה הנוכחית",
|
||||
"title": "שחזור פרצופים"
|
||||
},
|
||||
"upscale": {
|
||||
"title": "הגדלת קנה מידה",
|
||||
"desc": "הגדל את התמונה הנוכחית"
|
||||
},
|
||||
"showInfo": {
|
||||
"title": "הצג מידע",
|
||||
"desc": "הצגת פרטי מטא-נתונים של התמונה הנוכחית"
|
||||
@ -263,8 +259,6 @@
|
||||
"seed": "זרע",
|
||||
"type": "סוג",
|
||||
"strength": "חוזק",
|
||||
"upscale": "הגדלת קנה מידה",
|
||||
"upscaleImage": "הגדלת קנה מידת התמונה",
|
||||
"denoisingStrength": "חוזק מנטרל הרעש",
|
||||
"scaleBeforeProcessing": "שנה קנה מידה לפני עיבוד",
|
||||
"scaledWidth": "קנה מידה לאחר שינוי W",
|
||||
|
@ -150,7 +150,11 @@
|
||||
"showArchivedBoards": "Mostra le bacheche archiviate",
|
||||
"searchImages": "Ricerca per metadati",
|
||||
"displayBoardSearch": "Mostra la ricerca nelle Bacheche",
|
||||
"displaySearch": "Mostra la ricerca"
|
||||
"displaySearch": "Mostra la ricerca",
|
||||
"selectAllOnPage": "Seleziona tutto nella pagina",
|
||||
"selectAllOnBoard": "Seleziona tutto nella bacheca",
|
||||
"exitBoardSearch": "Esci da Ricerca bacheca",
|
||||
"exitSearch": "Esci dalla ricerca"
|
||||
},
|
||||
"hotkeys": {
|
||||
"keyboardShortcuts": "Tasti di scelta rapida",
|
||||
@ -210,10 +214,6 @@
|
||||
"title": "Restaura volti",
|
||||
"desc": "Restaura l'immagine corrente"
|
||||
},
|
||||
"upscale": {
|
||||
"title": "Amplia",
|
||||
"desc": "Amplia l'immagine corrente"
|
||||
},
|
||||
"showInfo": {
|
||||
"title": "Mostra informazioni",
|
||||
"desc": "Mostra le informazioni sui metadati dell'immagine corrente"
|
||||
@ -377,6 +377,10 @@
|
||||
"toggleViewer": {
|
||||
"title": "Attiva/disattiva il visualizzatore di immagini",
|
||||
"desc": "Passa dal visualizzatore immagini all'area di lavoro per la scheda corrente."
|
||||
},
|
||||
"postProcess": {
|
||||
"desc": "Elabora l'immagine corrente utilizzando il modello di post-elaborazione selezionato",
|
||||
"title": "Elabora immagine"
|
||||
}
|
||||
},
|
||||
"modelManager": {
|
||||
@ -505,8 +509,6 @@
|
||||
"type": "Tipo",
|
||||
"strength": "Forza",
|
||||
"upscaling": "Ampliamento",
|
||||
"upscale": "Amplia (Shift + U)",
|
||||
"upscaleImage": "Amplia Immagine",
|
||||
"scale": "Scala",
|
||||
"imageFit": "Adatta l'immagine iniziale alle dimensioni di output",
|
||||
"scaleBeforeProcessing": "Scala prima dell'elaborazione",
|
||||
@ -591,7 +593,10 @@
|
||||
"infillColorValue": "Colore di riempimento",
|
||||
"globalSettings": "Impostazioni globali",
|
||||
"globalPositivePromptPlaceholder": "Prompt positivo globale",
|
||||
"globalNegativePromptPlaceholder": "Prompt negativo globale"
|
||||
"globalNegativePromptPlaceholder": "Prompt negativo globale",
|
||||
"processImage": "Elabora Immagine",
|
||||
"sendToUpscale": "Invia a Ampliare",
|
||||
"postProcessing": "Post-elaborazione (Shift + U)"
|
||||
},
|
||||
"settings": {
|
||||
"models": "Modelli",
|
||||
@ -964,7 +969,10 @@
|
||||
"boards": "Bacheche",
|
||||
"private": "Bacheche private",
|
||||
"shared": "Bacheche condivise",
|
||||
"addPrivateBoard": "Aggiungi una Bacheca Privata"
|
||||
"addPrivateBoard": "Aggiungi una Bacheca Privata",
|
||||
"noBoards": "Nessuna bacheca {{boardType}}",
|
||||
"hideBoards": "Nascondi bacheche",
|
||||
"viewBoards": "Visualizza bacheche"
|
||||
},
|
||||
"controlnet": {
|
||||
"contentShuffleDescription": "Rimescola il contenuto di un'immagine",
|
||||
@ -1684,7 +1692,30 @@
|
||||
"models": "Modelli",
|
||||
"modelsTab": "$t(ui.tabs.models) $t(common.tab)",
|
||||
"queue": "Coda",
|
||||
"queueTab": "$t(ui.tabs.queue) $t(common.tab)"
|
||||
"queueTab": "$t(ui.tabs.queue) $t(common.tab)",
|
||||
"upscaling": "Ampliamento",
|
||||
"upscalingTab": "$t(ui.tabs.upscaling) $t(common.tab)"
|
||||
}
|
||||
},
|
||||
"upscaling": {
|
||||
"creativity": "Creatività",
|
||||
"structure": "Struttura",
|
||||
"upscaleModel": "Modello di Ampliamento",
|
||||
"scale": "Scala",
|
||||
"missingModelsWarning": "Visita <LinkComponent>Gestione modelli</LinkComponent> per installare i modelli richiesti:",
|
||||
"mainModelDesc": "Modello principale (architettura SD1.5 o SDXL)",
|
||||
"tileControlNetModelDesc": "Modello Tile ControlNet per l'architettura del modello principale scelto",
|
||||
"upscaleModelDesc": "Modello per l'ampliamento (da immagine a immagine)",
|
||||
"missingUpscaleInitialImage": "Immagine iniziale mancante per l'ampliamento",
|
||||
"missingUpscaleModel": "Modello per l’ampliamento mancante",
|
||||
"missingTileControlNetModel": "Nessun modello ControlNet Tile valido installato",
|
||||
"postProcessingModel": "Modello di post-elaborazione",
|
||||
"postProcessingMissingModelWarning": "Visita <LinkComponent>Gestione modelli</LinkComponent> per installare un modello di post-elaborazione (da immagine a immagine)."
|
||||
},
|
||||
"upsell": {
|
||||
"inviteTeammates": "Invita collaboratori",
|
||||
"shareAccess": "Condividi l'accesso",
|
||||
"professional": "Professionale",
|
||||
"professionalUpsell": "Disponibile nell'edizione Professional di Invoke. Fai clic qui o visita invoke.com/pricing per ulteriori dettagli."
|
||||
}
|
||||
}
|
||||
|
@ -199,10 +199,6 @@
|
||||
"title": "顔の修復",
|
||||
"desc": "現在の画像を修復"
|
||||
},
|
||||
"upscale": {
|
||||
"title": "アップスケール",
|
||||
"desc": "現在の画像をアップスケール"
|
||||
},
|
||||
"showInfo": {
|
||||
"title": "情報を見る",
|
||||
"desc": "現在の画像のメタデータ情報を表示"
|
||||
@ -427,8 +423,6 @@
|
||||
"shuffle": "シャッフル",
|
||||
"strength": "強度",
|
||||
"upscaling": "アップスケーリング",
|
||||
"upscale": "アップスケール",
|
||||
"upscaleImage": "画像をアップスケール",
|
||||
"scale": "Scale",
|
||||
"scaleBeforeProcessing": "処理前のスケール",
|
||||
"scaledWidth": "幅のスケール",
|
||||
|
@ -258,10 +258,6 @@
|
||||
"desc": "캔버스 브러시를 선택",
|
||||
"title": "브러시 선택"
|
||||
},
|
||||
"upscale": {
|
||||
"desc": "현재 이미지를 업스케일",
|
||||
"title": "업스케일"
|
||||
},
|
||||
"previousImage": {
|
||||
"title": "이전 이미지",
|
||||
"desc": "갤러리에 이전 이미지 표시"
|
||||
|
@ -168,10 +168,6 @@
|
||||
"title": "Herstel gezichten",
|
||||
"desc": "Herstelt de huidige afbeelding"
|
||||
},
|
||||
"upscale": {
|
||||
"title": "Schaal op",
|
||||
"desc": "Schaalt de huidige afbeelding op"
|
||||
},
|
||||
"showInfo": {
|
||||
"title": "Toon info",
|
||||
"desc": "Toont de metagegevens van de huidige afbeelding"
|
||||
@ -412,8 +408,6 @@
|
||||
"type": "Soort",
|
||||
"strength": "Sterkte",
|
||||
"upscaling": "Opschalen",
|
||||
"upscale": "Vergroot (Shift + U)",
|
||||
"upscaleImage": "Schaal afbeelding op",
|
||||
"scale": "Schaal",
|
||||
"imageFit": "Pas initiële afbeelding in uitvoergrootte",
|
||||
"scaleBeforeProcessing": "Schalen voor verwerking",
|
||||
|
@ -78,10 +78,6 @@
|
||||
"title": "Popraw twarze",
|
||||
"desc": "Uruchamia proces poprawiania twarzy dla aktywnego obrazu"
|
||||
},
|
||||
"upscale": {
|
||||
"title": "Powiększ",
|
||||
"desc": "Uruchamia proces powiększania aktywnego obrazu"
|
||||
},
|
||||
"showInfo": {
|
||||
"title": "Pokaż informacje",
|
||||
"desc": "Pokazuje metadane zapisane w aktywnym obrazie"
|
||||
@ -232,8 +228,6 @@
|
||||
"type": "Metoda",
|
||||
"strength": "Siła",
|
||||
"upscaling": "Powiększanie",
|
||||
"upscale": "Powiększ",
|
||||
"upscaleImage": "Powiększ obraz",
|
||||
"scale": "Skala",
|
||||
"imageFit": "Przeskaluj oryginalny obraz",
|
||||
"scaleBeforeProcessing": "Tryb skalowania",
|
||||
|
@ -160,10 +160,6 @@
|
||||
"title": "Restaurar Rostos",
|
||||
"desc": "Restaurar a imagem atual"
|
||||
},
|
||||
"upscale": {
|
||||
"title": "Redimensionar",
|
||||
"desc": "Redimensionar a imagem atual"
|
||||
},
|
||||
"showInfo": {
|
||||
"title": "Mostrar Informações",
|
||||
"desc": "Mostrar metadados de informações da imagem atual"
|
||||
@ -275,8 +271,6 @@
|
||||
"showOptionsPanel": "Mostrar Painel de Opções",
|
||||
"strength": "Força",
|
||||
"upscaling": "Redimensionando",
|
||||
"upscale": "Redimensionar",
|
||||
"upscaleImage": "Redimensionar Imagem",
|
||||
"scaleBeforeProcessing": "Escala Antes do Processamento",
|
||||
"images": "Imagems",
|
||||
"steps": "Passos",
|
||||
|
@ -80,10 +80,6 @@
|
||||
"title": "Restaurar Rostos",
|
||||
"desc": "Restaurar a imagem atual"
|
||||
},
|
||||
"upscale": {
|
||||
"title": "Redimensionar",
|
||||
"desc": "Redimensionar a imagem atual"
|
||||
},
|
||||
"showInfo": {
|
||||
"title": "Mostrar Informações",
|
||||
"desc": "Mostrar metadados de informações da imagem atual"
|
||||
@ -268,8 +264,6 @@
|
||||
"type": "Tipo",
|
||||
"strength": "Força",
|
||||
"upscaling": "Redimensionando",
|
||||
"upscale": "Redimensionar",
|
||||
"upscaleImage": "Redimensionar Imagem",
|
||||
"scale": "Escala",
|
||||
"imageFit": "Caber Imagem Inicial No Tamanho de Saída",
|
||||
"scaleBeforeProcessing": "Escala Antes do Processamento",
|
||||
|
@ -214,10 +214,6 @@
|
||||
"title": "Восстановить лица",
|
||||
"desc": "Восстановить лица на текущем изображении"
|
||||
},
|
||||
"upscale": {
|
||||
"title": "Увеличение",
|
||||
"desc": "Увеличить текущеее изображение"
|
||||
},
|
||||
"showInfo": {
|
||||
"title": "Показать метаданные",
|
||||
"desc": "Показать метаданные из текущего изображения"
|
||||
@ -512,8 +508,6 @@
|
||||
"type": "Тип",
|
||||
"strength": "Сила",
|
||||
"upscaling": "Увеличение",
|
||||
"upscale": "Увеличить",
|
||||
"upscaleImage": "Увеличить изображение",
|
||||
"scale": "Масштаб",
|
||||
"imageFit": "Уместить изображение",
|
||||
"scaleBeforeProcessing": "Масштабировать",
|
||||
|
@ -90,10 +90,6 @@
|
||||
"title": "Återskapa ansikten",
|
||||
"desc": "Återskapa nuvarande bild"
|
||||
},
|
||||
"upscale": {
|
||||
"title": "Skala upp",
|
||||
"desc": "Skala upp nuvarande bild"
|
||||
},
|
||||
"showInfo": {
|
||||
"title": "Visa info",
|
||||
"desc": "Visa metadata för nuvarande bild"
|
||||
|
@ -416,10 +416,6 @@
|
||||
"desc": "Maske/Taban katmanları arasında geçiş yapar",
|
||||
"title": "Katmanı Gizle-Göster"
|
||||
},
|
||||
"upscale": {
|
||||
"title": "Büyüt",
|
||||
"desc": "Seçili görseli büyüt"
|
||||
},
|
||||
"setSeed": {
|
||||
"title": "Tohumu Kullan",
|
||||
"desc": "Seçili görselin tohumunu kullan"
|
||||
@ -641,7 +637,6 @@
|
||||
"copyImage": "Görseli Kopyala",
|
||||
"height": "Boy",
|
||||
"width": "En",
|
||||
"upscale": "Büyüt (Shift + U)",
|
||||
"useSize": "Boyutu Kullan",
|
||||
"symmetry": "Bakışım",
|
||||
"tileSize": "Döşeme Boyutu",
|
||||
@ -657,7 +652,6 @@
|
||||
"showOptionsPanel": "Yan Paneli Göster (O ya da T)",
|
||||
"shuffle": "Kar",
|
||||
"usePrompt": "İstemi Kullan",
|
||||
"upscaleImage": "Görseli Büyüt",
|
||||
"setToOptimalSizeTooSmall": "$t(parameters.setToOptimalSize) (çok küçük olabilir)",
|
||||
"setToOptimalSizeTooLarge": "$t(parameters.setToOptimalSize) (çok büyük olabilir)",
|
||||
"cfgRescaleMultiplier": "CFG Rescale Çarpanı",
|
||||
|
@ -85,10 +85,6 @@
|
||||
"title": "Відновити обличчя",
|
||||
"desc": "Відновити обличчя на поточному зображенні"
|
||||
},
|
||||
"upscale": {
|
||||
"title": "Збільшення",
|
||||
"desc": "Збільшити поточне зображення"
|
||||
},
|
||||
"showInfo": {
|
||||
"title": "Показати метадані",
|
||||
"desc": "Показати метадані з поточного зображення"
|
||||
@ -276,8 +272,6 @@
|
||||
"type": "Тип",
|
||||
"strength": "Сила",
|
||||
"upscaling": "Збільшення",
|
||||
"upscale": "Збільшити",
|
||||
"upscaleImage": "Збільшити зображення",
|
||||
"scale": "Масштаб",
|
||||
"imageFit": "Вмістити зображення",
|
||||
"scaleBeforeProcessing": "Масштабувати",
|
||||
|
@ -193,10 +193,6 @@
|
||||
"title": "面部修复",
|
||||
"desc": "对当前图像进行面部修复"
|
||||
},
|
||||
"upscale": {
|
||||
"title": "放大",
|
||||
"desc": "对当前图像进行放大"
|
||||
},
|
||||
"showInfo": {
|
||||
"title": "显示信息",
|
||||
"desc": "显示当前图像的元数据"
|
||||
@ -422,8 +418,6 @@
|
||||
"type": "种类",
|
||||
"strength": "强度",
|
||||
"upscaling": "放大",
|
||||
"upscale": "放大 (Shift + U)",
|
||||
"upscaleImage": "放大图像",
|
||||
"scale": "等级",
|
||||
"imageFit": "使生成图像长宽适配初始图像",
|
||||
"scaleBeforeProcessing": "处理前缩放",
|
||||
|
@ -1,5 +1,6 @@
|
||||
import type { TypedStartListening } from '@reduxjs/toolkit';
|
||||
import { createListenerMiddleware } from '@reduxjs/toolkit';
|
||||
import { addAdHocPostProcessingRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/addAdHocPostProcessingRequestedListener';
|
||||
import { addCommitStagingAreaImageListener } from 'app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener';
|
||||
import { addAnyEnqueuedListener } from 'app/store/middleware/listenerMiddleware/listeners/anyEnqueued';
|
||||
import { addAppConfigReceivedListener } from 'app/store/middleware/listenerMiddleware/listeners/appConfigReceived';
|
||||
@ -47,11 +48,11 @@ import { addModelLoadEventListener } from 'app/store/middleware/listenerMiddlewa
|
||||
import { addSocketQueueItemStatusChangedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueItemStatusChanged';
|
||||
import { addStagingAreaImageSavedListener } from 'app/store/middleware/listenerMiddleware/listeners/stagingAreaImageSaved';
|
||||
import { addUpdateAllNodesRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested';
|
||||
import { addUpscaleRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/upscaleRequested';
|
||||
import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested';
|
||||
import type { AppDispatch, RootState } from 'app/store/store';
|
||||
|
||||
import { addArchivedOrDeletedBoardListener } from './listeners/addArchivedOrDeletedBoardListener';
|
||||
import { addEnqueueRequestedUpscale } from './listeners/enqueueRequestedUpscale';
|
||||
|
||||
export const listenerMiddleware = createListenerMiddleware();
|
||||
|
||||
@ -85,6 +86,7 @@ addGalleryOffsetChangedListener(startAppListening);
|
||||
addEnqueueRequestedCanvasListener(startAppListening);
|
||||
addEnqueueRequestedNodes(startAppListening);
|
||||
addEnqueueRequestedLinear(startAppListening);
|
||||
addEnqueueRequestedUpscale(startAppListening);
|
||||
addAnyEnqueuedListener(startAppListening);
|
||||
addBatchEnqueuedListener(startAppListening);
|
||||
|
||||
@ -140,7 +142,7 @@ addModelsLoadedListener(startAppListening);
|
||||
addAppConfigReceivedListener(startAppListening);
|
||||
|
||||
// Ad-hoc upscale workflwo
|
||||
addUpscaleRequestedListener(startAppListening);
|
||||
addAdHocPostProcessingRequestedListener(startAppListening);
|
||||
|
||||
// Prompts
|
||||
addDynamicPromptsListener(startAppListening);
|
||||
|
@ -2,46 +2,28 @@ import { createAction } from '@reduxjs/toolkit';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { buildAdHocUpscaleGraph } from 'features/nodes/util/graph/buildAdHocUpscaleGraph';
|
||||
import { createIsAllowedToUpscaleSelector } from 'features/parameters/hooks/useIsAllowedToUpscale';
|
||||
import { buildAdHocPostProcessingGraph } from 'features/nodes/util/graph/buildAdHocPostProcessingGraph';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { t } from 'i18next';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
import type { BatchConfig, ImageDTO } from 'services/api/types';
|
||||
|
||||
export const upscaleRequested = createAction<{ imageDTO: ImageDTO }>(`upscale/upscaleRequested`);
|
||||
export const adHocPostProcessingRequested = createAction<{ imageDTO: ImageDTO }>(`upscaling/postProcessingRequested`);
|
||||
|
||||
export const addUpscaleRequestedListener = (startAppListening: AppStartListening) => {
|
||||
export const addAdHocPostProcessingRequestedListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
actionCreator: upscaleRequested,
|
||||
actionCreator: adHocPostProcessingRequested,
|
||||
effect: async (action, { dispatch, getState }) => {
|
||||
const log = logger('session');
|
||||
|
||||
const { imageDTO } = action.payload;
|
||||
const { image_name } = imageDTO;
|
||||
const state = getState();
|
||||
|
||||
const { isAllowedToUpscale, detailTKey } = createIsAllowedToUpscaleSelector(imageDTO)(state);
|
||||
|
||||
// if we can't upscale, show a toast and return
|
||||
if (!isAllowedToUpscale) {
|
||||
log.error(
|
||||
{ imageDTO },
|
||||
t(detailTKey ?? 'parameters.isAllowedToUpscale.tooLarge') // should never coalesce
|
||||
);
|
||||
toast({
|
||||
id: 'NOT_ALLOWED_TO_UPSCALE',
|
||||
title: t(detailTKey ?? 'parameters.isAllowedToUpscale.tooLarge'), // should never coalesce
|
||||
status: 'error',
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
const enqueueBatchArg: BatchConfig = {
|
||||
prepend: true,
|
||||
batch: {
|
||||
graph: buildAdHocUpscaleGraph({
|
||||
image_name,
|
||||
graph: await buildAdHocPostProcessingGraph({
|
||||
image: imageDTO,
|
||||
state,
|
||||
}),
|
||||
runs: 1,
|
@ -10,32 +10,32 @@ import {
|
||||
import { boardsApi } from 'services/api/endpoints/boards';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
|
||||
// Type inference doesn't work for this if you inline it in the listener for some reason
|
||||
const matchAnyBoardDeleted = isAnyOf(
|
||||
imagesApi.endpoints.deleteBoard.matchFulfilled,
|
||||
imagesApi.endpoints.deleteBoardAndImages.matchFulfilled
|
||||
);
|
||||
|
||||
export const addArchivedOrDeletedBoardListener = (startAppListening: AppStartListening) => {
|
||||
/**
|
||||
* The auto-add board shouldn't be set to an archived board or deleted board. When we archive a board, delete
|
||||
* a board, or change a the archived board visibility flag, we may need to reset the auto-add board.
|
||||
*/
|
||||
startAppListening({
|
||||
matcher: isAnyOf(
|
||||
// If a board is deleted, we'll need to reset the auto-add board
|
||||
imagesApi.endpoints.deleteBoard.matchFulfilled,
|
||||
imagesApi.endpoints.deleteBoardAndImages.matchFulfilled
|
||||
),
|
||||
matcher: matchAnyBoardDeleted,
|
||||
effect: async (action, { dispatch, getState }) => {
|
||||
const state = getState();
|
||||
const queryArgs = selectListBoardsQueryArgs(state);
|
||||
const queryResult = boardsApi.endpoints.listAllBoards.select(queryArgs)(state);
|
||||
const deletedBoardId = action.meta.arg.originalArgs;
|
||||
const { autoAddBoardId, selectedBoardId } = state.gallery;
|
||||
|
||||
if (!queryResult.data) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!queryResult.data.find((board) => board.board_id === selectedBoardId)) {
|
||||
// If the deleted board was currently selected, we should reset the selected board to uncategorized
|
||||
if (deletedBoardId === selectedBoardId) {
|
||||
dispatch(boardIdSelected({ boardId: 'none' }));
|
||||
dispatch(galleryViewChanged('images'));
|
||||
}
|
||||
if (!queryResult.data.find((board) => board.board_id === autoAddBoardId)) {
|
||||
|
||||
// If the deleted board was selected for auto-add, we should reset the auto-add board to uncategorized
|
||||
if (deletedBoardId === autoAddBoardId) {
|
||||
dispatch(autoAddBoardIdChanged('none'));
|
||||
}
|
||||
},
|
||||
@ -46,14 +46,8 @@ export const addArchivedOrDeletedBoardListener = (startAppListening: AppStartLis
|
||||
matcher: boardsApi.endpoints.updateBoard.matchFulfilled,
|
||||
effect: async (action, { dispatch, getState }) => {
|
||||
const state = getState();
|
||||
const queryArgs = selectListBoardsQueryArgs(state);
|
||||
const queryResult = boardsApi.endpoints.listAllBoards.select(queryArgs)(state);
|
||||
const { shouldShowArchivedBoards } = state.gallery;
|
||||
|
||||
if (!queryResult.data) {
|
||||
return;
|
||||
}
|
||||
|
||||
const wasArchived = action.meta.arg.originalArgs.changes.archived === true;
|
||||
|
||||
if (wasArchived && !shouldShowArchivedBoards) {
|
||||
@ -71,7 +65,7 @@ export const addArchivedOrDeletedBoardListener = (startAppListening: AppStartLis
|
||||
const shouldShowArchivedBoards = action.payload;
|
||||
|
||||
// We only need to take action if we have just hidden archived boards.
|
||||
if (!shouldShowArchivedBoards) {
|
||||
if (shouldShowArchivedBoards) {
|
||||
return;
|
||||
}
|
||||
|
||||
@ -86,14 +80,16 @@ export const addArchivedOrDeletedBoardListener = (startAppListening: AppStartLis
|
||||
|
||||
// Handle the case where selected board is archived
|
||||
const selectedBoard = queryResult.data.find((b) => b.board_id === selectedBoardId);
|
||||
if (selectedBoard && selectedBoard.archived) {
|
||||
if (!selectedBoard || selectedBoard.archived) {
|
||||
// If we can't find the selected board or it's archived, we should reset the selected board to uncategorized
|
||||
dispatch(boardIdSelected({ boardId: 'none' }));
|
||||
dispatch(galleryViewChanged('images'));
|
||||
}
|
||||
|
||||
// Handle the case where auto-add board is archived
|
||||
const autoAddBoard = queryResult.data.find((b) => b.board_id === autoAddBoardId);
|
||||
if (autoAddBoard && autoAddBoard.archived) {
|
||||
if (!autoAddBoard || autoAddBoard.archived) {
|
||||
// If we can't find the auto-add board or it's archived, we should reset the selected board to uncategorized
|
||||
dispatch(autoAddBoardIdChanged('none'));
|
||||
}
|
||||
},
|
||||
|
@ -0,0 +1,36 @@
|
||||
import { enqueueRequested } from 'app/store/actions';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
|
||||
import { buildMultidiffusionUpscaleGraph } from 'features/nodes/util/graph/buildMultidiffusionUpscaleGraph';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
|
||||
export const addEnqueueRequestedUpscale = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
predicate: (action): action is ReturnType<typeof enqueueRequested> =>
|
||||
enqueueRequested.match(action) && action.payload.tabName === 'upscaling',
|
||||
effect: async (action, { getState, dispatch }) => {
|
||||
const state = getState();
|
||||
const { shouldShowProgressInViewer } = state.ui;
|
||||
const { prepend } = action.payload;
|
||||
|
||||
const graph = await buildMultidiffusionUpscaleGraph(state);
|
||||
|
||||
const batchConfig = prepareLinearUIBatch(state, graph, prepend);
|
||||
|
||||
const req = dispatch(
|
||||
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
|
||||
fixedCacheKey: 'enqueueBatch',
|
||||
})
|
||||
);
|
||||
try {
|
||||
await req.unwrap();
|
||||
if (shouldShowProgressInViewer) {
|
||||
dispatch(isImageViewerOpenChanged(true));
|
||||
}
|
||||
} finally {
|
||||
req.reset();
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
@ -23,6 +23,7 @@ import {
|
||||
} from 'features/gallery/store/gallerySlice';
|
||||
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
|
||||
import { upscaleInitialImageChanged } from 'features/parameters/store/upscaleSlice';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
|
||||
export const dndDropped = createAction<{
|
||||
@ -243,6 +244,20 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
|
||||
return;
|
||||
}
|
||||
|
||||
/**
|
||||
* Image dropped on upscale initial image
|
||||
*/
|
||||
if (
|
||||
overData.actionType === 'SET_UPSCALE_INITIAL_IMAGE' &&
|
||||
activeData.payloadType === 'IMAGE_DTO' &&
|
||||
activeData.payload.imageDTO
|
||||
) {
|
||||
const { imageDTO } = activeData.payload;
|
||||
|
||||
dispatch(upscaleInitialImageChanged(imageDTO));
|
||||
return;
|
||||
}
|
||||
|
||||
/**
|
||||
* Multiple images dropped on user board
|
||||
*/
|
||||
|
@ -14,6 +14,7 @@ import {
|
||||
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
|
||||
import { upscaleInitialImageChanged } from 'features/parameters/store/upscaleSlice';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { t } from 'i18next';
|
||||
import { omit } from 'lodash-es';
|
||||
@ -89,6 +90,15 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis
|
||||
return;
|
||||
}
|
||||
|
||||
if (postUploadAction?.type === 'SET_UPSCALE_INITIAL_IMAGE') {
|
||||
dispatch(upscaleInitialImageChanged(imageDTO));
|
||||
toast({
|
||||
...DEFAULT_UPLOADED_TOAST,
|
||||
description: 'set as upscale initial image',
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
if (postUploadAction?.type === 'SET_CONTROL_ADAPTER_IMAGE') {
|
||||
const { id } = postUploadAction;
|
||||
dispatch(
|
||||
|
@ -10,6 +10,7 @@ import { heightChanged, widthChanged } from 'features/controlLayers/store/contro
|
||||
import { loraRemoved } from 'features/lora/store/loraSlice';
|
||||
import { calculateNewSize } from 'features/parameters/components/ImageSize/calculateNewSize';
|
||||
import { modelChanged, vaeSelected } from 'features/parameters/store/generationSlice';
|
||||
import { postProcessingModelChanged, upscaleModelChanged } from 'features/parameters/store/upscaleSlice';
|
||||
import { zParameterModel, zParameterVAEModel } from 'features/parameters/types/parameterSchemas';
|
||||
import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
|
||||
import { refinerModelChanged } from 'features/sdxl/store/sdxlSlice';
|
||||
@ -17,7 +18,12 @@ import { forEach } from 'lodash-es';
|
||||
import type { Logger } from 'roarr';
|
||||
import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
import { isNonRefinerMainModelConfig, isRefinerMainModelModelConfig, isVAEModelConfig } from 'services/api/types';
|
||||
import {
|
||||
isNonRefinerMainModelConfig,
|
||||
isRefinerMainModelModelConfig,
|
||||
isSpandrelImageToImageModelConfig,
|
||||
isVAEModelConfig,
|
||||
} from 'services/api/types';
|
||||
|
||||
export const addModelsLoadedListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
@ -36,6 +42,7 @@ export const addModelsLoadedListener = (startAppListening: AppStartListening) =>
|
||||
handleVAEModels(models, state, dispatch, log);
|
||||
handleLoRAModels(models, state, dispatch, log);
|
||||
handleControlAdapterModels(models, state, dispatch, log);
|
||||
handleSpandrelImageToImageModels(models, state, dispatch, log);
|
||||
},
|
||||
});
|
||||
};
|
||||
@ -177,3 +184,25 @@ const handleControlAdapterModels: ModelHandler = (models, state, dispatch, _log)
|
||||
dispatch(controlAdapterModelCleared({ id: ca.id }));
|
||||
});
|
||||
};
|
||||
|
||||
const handleSpandrelImageToImageModels: ModelHandler = (models, state, dispatch, _log) => {
|
||||
const { upscaleModel: currentUpscaleModel, postProcessingModel: currentPostProcessingModel } = state.upscale;
|
||||
const upscaleModels = models.filter(isSpandrelImageToImageModelConfig);
|
||||
const firstModel = upscaleModels[0] || null;
|
||||
|
||||
const isCurrentUpscaleModelAvailable = currentUpscaleModel
|
||||
? upscaleModels.some((m) => m.key === currentUpscaleModel.key)
|
||||
: false;
|
||||
|
||||
if (!isCurrentUpscaleModelAvailable) {
|
||||
dispatch(upscaleModelChanged(firstModel));
|
||||
}
|
||||
|
||||
const isCurrentPostProcessingModelAvailable = currentPostProcessingModel
|
||||
? upscaleModels.some((m) => m.key === currentPostProcessingModel.key)
|
||||
: false;
|
||||
|
||||
if (!isCurrentPostProcessingModelAvailable) {
|
||||
dispatch(postProcessingModelChanged(firstModel));
|
||||
}
|
||||
};
|
||||
|
@ -25,7 +25,7 @@ import { nodesPersistConfig, nodesSlice, nodesUndoableConfig } from 'features/no
|
||||
import { workflowSettingsPersistConfig, workflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice';
|
||||
import { workflowPersistConfig, workflowSlice } from 'features/nodes/store/workflowSlice';
|
||||
import { generationPersistConfig, generationSlice } from 'features/parameters/store/generationSlice';
|
||||
import { postprocessingPersistConfig, postprocessingSlice } from 'features/parameters/store/postprocessingSlice';
|
||||
import { upscalePersistConfig, upscaleSlice } from 'features/parameters/store/upscaleSlice';
|
||||
import { queueSlice } from 'features/queue/store/queueSlice';
|
||||
import { sdxlPersistConfig, sdxlSlice } from 'features/sdxl/store/sdxlSlice';
|
||||
import { configSlice } from 'features/system/store/configSlice';
|
||||
@ -52,7 +52,6 @@ const allReducers = {
|
||||
[gallerySlice.name]: gallerySlice.reducer,
|
||||
[generationSlice.name]: generationSlice.reducer,
|
||||
[nodesSlice.name]: undoable(nodesSlice.reducer, nodesUndoableConfig),
|
||||
[postprocessingSlice.name]: postprocessingSlice.reducer,
|
||||
[systemSlice.name]: systemSlice.reducer,
|
||||
[configSlice.name]: configSlice.reducer,
|
||||
[uiSlice.name]: uiSlice.reducer,
|
||||
@ -69,6 +68,7 @@ const allReducers = {
|
||||
[controlLayersSlice.name]: undoable(controlLayersSlice.reducer, controlLayersUndoableConfig),
|
||||
[workflowSettingsSlice.name]: workflowSettingsSlice.reducer,
|
||||
[api.reducerPath]: api.reducer,
|
||||
[upscaleSlice.name]: upscaleSlice.reducer,
|
||||
};
|
||||
|
||||
const rootReducer = combineReducers(allReducers);
|
||||
@ -102,7 +102,6 @@ const persistConfigs: { [key in keyof typeof allReducers]?: PersistConfig } = {
|
||||
[galleryPersistConfig.name]: galleryPersistConfig,
|
||||
[generationPersistConfig.name]: generationPersistConfig,
|
||||
[nodesPersistConfig.name]: nodesPersistConfig,
|
||||
[postprocessingPersistConfig.name]: postprocessingPersistConfig,
|
||||
[systemPersistConfig.name]: systemPersistConfig,
|
||||
[workflowPersistConfig.name]: workflowPersistConfig,
|
||||
[uiPersistConfig.name]: uiPersistConfig,
|
||||
@ -114,6 +113,7 @@ const persistConfigs: { [key in keyof typeof allReducers]?: PersistConfig } = {
|
||||
[hrfPersistConfig.name]: hrfPersistConfig,
|
||||
[controlLayersPersistConfig.name]: controlLayersPersistConfig,
|
||||
[workflowSettingsPersistConfig.name]: workflowSettingsPersistConfig,
|
||||
[upscalePersistConfig.name]: upscalePersistConfig,
|
||||
};
|
||||
|
||||
const unserialize: UnserializeFunction = (data, key) => {
|
||||
|
@ -72,7 +72,6 @@ export type AppConfig = {
|
||||
canRestoreDeletedImagesFromBin: boolean;
|
||||
nodesAllowlist: string[] | undefined;
|
||||
nodesDenylist: string[] | undefined;
|
||||
maxUpscalePixels?: number;
|
||||
metadataFetchDebounce?: number;
|
||||
workflowFetchDebounce?: number;
|
||||
isLocal?: boolean;
|
||||
|
@ -10,9 +10,12 @@ import {
|
||||
PopoverContent,
|
||||
PopoverTrigger,
|
||||
Portal,
|
||||
Spacer,
|
||||
Text,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { setShouldEnableInformationalPopovers } from 'features/system/store/systemSlice';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { merge, omit } from 'lodash-es';
|
||||
import type { ReactElement } from 'react';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
@ -71,7 +74,7 @@ type ContentProps = {
|
||||
|
||||
const Content = ({ data, feature }: ContentProps) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const heading = useMemo<string | undefined>(() => t(`popovers.${feature}.heading`), [feature, t]);
|
||||
|
||||
const paragraphs = useMemo<string[]>(
|
||||
@ -82,16 +85,25 @@ const Content = ({ data, feature }: ContentProps) => {
|
||||
[feature, t]
|
||||
);
|
||||
|
||||
const handleClick = useCallback(() => {
|
||||
const onClickLearnMore = useCallback(() => {
|
||||
if (!data?.href) {
|
||||
return;
|
||||
}
|
||||
window.open(data.href);
|
||||
}, [data?.href]);
|
||||
|
||||
const onClickDontShowMeThese = useCallback(() => {
|
||||
dispatch(setShouldEnableInformationalPopovers(false));
|
||||
toast({
|
||||
title: t('settings.informationalPopoversDisabled'),
|
||||
description: t('settings.informationalPopoversDisabledDesc'),
|
||||
status: 'info',
|
||||
});
|
||||
}, [dispatch, t]);
|
||||
|
||||
return (
|
||||
<PopoverContent w={96}>
|
||||
<PopoverCloseButton />
|
||||
<PopoverContent maxW={300}>
|
||||
<PopoverCloseButton top={2} />
|
||||
<PopoverBody>
|
||||
<Flex gap={2} flexDirection="column" alignItems="flex-start">
|
||||
{heading && (
|
||||
@ -116,21 +128,20 @@ const Content = ({ data, feature }: ContentProps) => {
|
||||
{paragraphs.map((p) => (
|
||||
<Text key={p}>{p}</Text>
|
||||
))}
|
||||
{data?.href && (
|
||||
<>
|
||||
|
||||
<Divider />
|
||||
<Button
|
||||
pt={1}
|
||||
onClick={handleClick}
|
||||
leftIcon={<PiArrowSquareOutBold />}
|
||||
alignSelf="flex-end"
|
||||
variant="link"
|
||||
>
|
||||
<Flex alignItems="center" justifyContent="space-between" w="full">
|
||||
<Button onClick={onClickDontShowMeThese} variant="link" size="sm">
|
||||
{t('common.dontShowMeThese')}
|
||||
</Button>
|
||||
<Spacer />
|
||||
{data?.href && (
|
||||
<Button onClick={onClickLearnMore} leftIcon={<PiArrowSquareOutBold />} variant="link" size="sm">
|
||||
{t('common.learnMore') ?? heading}
|
||||
</Button>
|
||||
</>
|
||||
)}
|
||||
</Flex>
|
||||
</Flex>
|
||||
</PopoverBody>
|
||||
</PopoverContent>
|
||||
);
|
||||
|
@ -53,7 +53,11 @@ export type Feature =
|
||||
| 'refinerCfgScale'
|
||||
| 'scaleBeforeProcessing'
|
||||
| 'seamlessTilingXAxis'
|
||||
| 'seamlessTilingYAxis';
|
||||
| 'seamlessTilingYAxis'
|
||||
| 'upscaleModel'
|
||||
| 'scale'
|
||||
| 'creativity'
|
||||
| 'structure';
|
||||
|
||||
export type PopoverData = PopoverProps & {
|
||||
image?: string;
|
||||
|
18
invokeai/frontend/web/src/common/hooks/useAssertSingleton.ts
Normal file
18
invokeai/frontend/web/src/common/hooks/useAssertSingleton.ts
Normal file
@ -0,0 +1,18 @@
|
||||
import { useEffect } from 'react';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
const IDS = new Set<string>();
|
||||
|
||||
/**
|
||||
* Asserts that there is only one instance of a singleton entity. It can be a hook or a component.
|
||||
* @param id The ID of the singleton entity.
|
||||
*/
|
||||
export function useAssertSingleton(id: string) {
|
||||
useEffect(() => {
|
||||
assert(!IDS.has(id), `There should be only one instance of ${id}`);
|
||||
IDS.add(id);
|
||||
return () => {
|
||||
IDS.delete(id);
|
||||
};
|
||||
}, [id]);
|
||||
}
|
@ -21,6 +21,10 @@ const selectPostUploadAction = createMemoizedSelector(activeTabNameSelector, (ac
|
||||
postUploadAction = { type: 'SET_CANVAS_INITIAL_IMAGE' };
|
||||
}
|
||||
|
||||
if (activeTabName === 'upscaling') {
|
||||
postUploadAction = { type: 'SET_UPSCALE_INITIAL_IMAGE' };
|
||||
}
|
||||
|
||||
return postUploadAction;
|
||||
});
|
||||
|
||||
|
@ -15,6 +15,7 @@ import type { Templates } from 'features/nodes/store/types';
|
||||
import { selectWorkflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
|
||||
import { selectUpscalelice } from 'features/parameters/store/upscaleSlice';
|
||||
import { selectSystemSlice } from 'features/system/store/systemSlice';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import i18n from 'i18next';
|
||||
@ -40,8 +41,19 @@ const createSelector = (templates: Templates) =>
|
||||
selectDynamicPromptsSlice,
|
||||
selectControlLayersSlice,
|
||||
activeTabNameSelector,
|
||||
selectUpscalelice,
|
||||
],
|
||||
(controlAdapters, generation, system, nodes, workflowSettings, dynamicPrompts, controlLayers, activeTabName) => {
|
||||
(
|
||||
controlAdapters,
|
||||
generation,
|
||||
system,
|
||||
nodes,
|
||||
workflowSettings,
|
||||
dynamicPrompts,
|
||||
controlLayers,
|
||||
activeTabName,
|
||||
upscale
|
||||
) => {
|
||||
const { model } = generation;
|
||||
const { size } = controlLayers.present;
|
||||
const { positivePrompt } = controlLayers.present;
|
||||
@ -194,6 +206,16 @@ const createSelector = (templates: Templates) =>
|
||||
reasons.push({ prefix, content });
|
||||
}
|
||||
});
|
||||
} else if (activeTabName === 'upscaling') {
|
||||
if (!upscale.upscaleInitialImage) {
|
||||
reasons.push({ content: i18n.t('upscaling.missingUpscaleInitialImage') });
|
||||
}
|
||||
if (!upscale.upscaleModel) {
|
||||
reasons.push({ content: i18n.t('upscaling.missingUpscaleModel') });
|
||||
}
|
||||
if (!upscale.tileControlnetModel) {
|
||||
reasons.push({ content: i18n.t('upscaling.missingTileControlNetModel') });
|
||||
}
|
||||
} else {
|
||||
// Handling for all other tabs
|
||||
selectControlAdapterAll(controlAdapters)
|
||||
|
@ -62,6 +62,10 @@ export type CanvasInitialImageDropData = BaseDropData & {
|
||||
actionType: 'SET_CANVAS_INITIAL_IMAGE';
|
||||
};
|
||||
|
||||
type UpscaleInitialImageDropData = BaseDropData & {
|
||||
actionType: 'SET_UPSCALE_INITIAL_IMAGE';
|
||||
};
|
||||
|
||||
type NodesImageDropData = BaseDropData & {
|
||||
actionType: 'SET_NODES_IMAGE';
|
||||
context: {
|
||||
@ -98,7 +102,8 @@ export type TypesafeDroppableData =
|
||||
| IPALayerImageDropData
|
||||
| RGLayerIPAdapterImageDropData
|
||||
| IILayerImageDropData
|
||||
| SelectForCompareDropData;
|
||||
| SelectForCompareDropData
|
||||
| UpscaleInitialImageDropData;
|
||||
|
||||
type BaseDragData = {
|
||||
id: string;
|
||||
|
@ -27,6 +27,8 @@ export const isValidDrop = (overData?: TypesafeDroppableData | null, activeData?
|
||||
return payloadType === 'IMAGE_DTO';
|
||||
case 'SET_CANVAS_INITIAL_IMAGE':
|
||||
return payloadType === 'IMAGE_DTO';
|
||||
case 'SET_UPSCALE_INITIAL_IMAGE':
|
||||
return payloadType === 'IMAGE_DTO';
|
||||
case 'SET_NODES_IMAGE':
|
||||
return payloadType === 'IMAGE_DTO';
|
||||
case 'SELECT_FOR_COMPARE':
|
||||
|
@ -0,0 +1,47 @@
|
||||
import { Flex, Image, Text } from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetBoardAssetsTotalQuery, useGetBoardImagesTotalQuery } from 'services/api/endpoints/boards';
|
||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||
import type { BoardDTO } from 'services/api/types';
|
||||
|
||||
type Props = {
|
||||
board: BoardDTO | null;
|
||||
};
|
||||
|
||||
export const BoardTooltip = ({ board }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const { imagesTotal } = useGetBoardImagesTotalQuery(board?.board_id || 'none', {
|
||||
selectFromResult: ({ data }) => {
|
||||
return { imagesTotal: data?.total ?? 0 };
|
||||
},
|
||||
});
|
||||
const { assetsTotal } = useGetBoardAssetsTotalQuery(board?.board_id || 'none', {
|
||||
selectFromResult: ({ data }) => {
|
||||
return { assetsTotal: data?.total ?? 0 };
|
||||
},
|
||||
});
|
||||
const { currentData: coverImage } = useGetImageDTOQuery(board?.cover_image_name ?? skipToken);
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" alignItems="center" gap={1}>
|
||||
{coverImage && (
|
||||
<Image
|
||||
src={coverImage.thumbnail_url}
|
||||
draggable={false}
|
||||
objectFit="cover"
|
||||
maxW={150}
|
||||
aspectRatio="1/1"
|
||||
borderRadius="base"
|
||||
borderBottomRadius="lg"
|
||||
/>
|
||||
)}
|
||||
<Flex flexDir="column" alignItems="center">
|
||||
<Text noOfLines={1}>
|
||||
{t('boards.imagesWithCount', { count: imagesTotal })}, {t('boards.assetsWithCount', { count: assetsTotal })}
|
||||
</Text>
|
||||
{board?.archived && <Text>({t('boards.archived')})</Text>}
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
};
|
@ -1,22 +0,0 @@
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetBoardAssetsTotalQuery, useGetBoardImagesTotalQuery } from 'services/api/endpoints/boards';
|
||||
|
||||
type Props = {
|
||||
board_id: string;
|
||||
isArchived: boolean;
|
||||
};
|
||||
|
||||
export const BoardTotalsTooltip = ({ board_id, isArchived }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const { imagesTotal } = useGetBoardImagesTotalQuery(board_id, {
|
||||
selectFromResult: ({ data }) => {
|
||||
return { imagesTotal: data?.total ?? 0 };
|
||||
},
|
||||
});
|
||||
const { assetsTotal } = useGetBoardAssetsTotalQuery(board_id, {
|
||||
selectFromResult: ({ data }) => {
|
||||
return { assetsTotal: data?.total ?? 0 };
|
||||
},
|
||||
});
|
||||
return `${t('boards.imagesWithCount', { count: imagesTotal })}, ${t('boards.assetsWithCount', { count: assetsTotal })}${isArchived ? ` (${t('boards.archived')})` : ''}`;
|
||||
};
|
@ -1,13 +1,10 @@
|
||||
import { Box, Flex, Text } from '@invoke-ai/ui-library';
|
||||
import { Button, Collapse, Flex, Icon, Text, useDisclosure } from '@invoke-ai/ui-library';
|
||||
import { EMPTY_ARRAY } from 'app/store/constants';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { overlayScrollbarsParams } from 'common/components/OverlayScrollbars/constants';
|
||||
import DeleteBoardModal from 'features/gallery/components/Boards/DeleteBoardModal';
|
||||
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
||||
import type { CSSProperties } from 'react';
|
||||
import { memo, useMemo, useState } from 'react';
|
||||
import { useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiCaretDownBold } from 'react-icons/pi';
|
||||
import { useListAllBoardsQuery } from 'services/api/endpoints/boards';
|
||||
import type { BoardDTO } from 'services/api/types';
|
||||
|
||||
@ -15,101 +12,111 @@ import AddBoardButton from './AddBoardButton';
|
||||
import GalleryBoard from './GalleryBoard';
|
||||
import NoBoardBoard from './NoBoardBoard';
|
||||
|
||||
const overlayScrollbarsStyles: CSSProperties = {
|
||||
height: '100%',
|
||||
width: '100%',
|
||||
type Props = {
|
||||
isPrivate: boolean;
|
||||
setBoardToDelete: (board?: BoardDTO) => void;
|
||||
};
|
||||
|
||||
const BoardsList = () => {
|
||||
export const BoardsList = ({ isPrivate, setBoardToDelete }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const selectedBoardId = useAppSelector((s) => s.gallery.selectedBoardId);
|
||||
const boardSearchText = useAppSelector((s) => s.gallery.boardSearchText);
|
||||
const allowPrivateBoards = useAppSelector((s) => s.config.allowPrivateBoards);
|
||||
const queryArgs = useAppSelector(selectListBoardsQueryArgs);
|
||||
const { data: boards } = useListAllBoardsQuery(queryArgs);
|
||||
const [boardToDelete, setBoardToDelete] = useState<BoardDTO>();
|
||||
const { t } = useTranslation();
|
||||
const allowPrivateBoards = useAppSelector((s) => s.config.allowPrivateBoards);
|
||||
const { isOpen, onToggle } = useDisclosure({ defaultIsOpen: true });
|
||||
|
||||
const { filteredPrivateBoards, filteredSharedBoards } = useMemo(() => {
|
||||
const filteredBoards = boardSearchText
|
||||
? boards?.filter((board) => board.board_name.toLowerCase().includes(boardSearchText.toLowerCase()))
|
||||
: boards;
|
||||
const filteredPrivateBoards = filteredBoards?.filter((board) => board.is_private) ?? EMPTY_ARRAY;
|
||||
const filteredSharedBoards = filteredBoards?.filter((board) => !board.is_private) ?? EMPTY_ARRAY;
|
||||
return { filteredPrivateBoards, filteredSharedBoards };
|
||||
}, [boardSearchText, boards]);
|
||||
const filteredBoards = useMemo(() => {
|
||||
if (!boards) {
|
||||
return EMPTY_ARRAY;
|
||||
}
|
||||
|
||||
return boards.filter((board) => {
|
||||
if (boardSearchText.length) {
|
||||
return board.is_private === isPrivate && board.board_name.toLowerCase().includes(boardSearchText.toLowerCase());
|
||||
} else {
|
||||
return board.is_private === isPrivate;
|
||||
}
|
||||
});
|
||||
}, [boardSearchText, boards, isPrivate]);
|
||||
|
||||
const boardElements = useMemo(() => {
|
||||
const elements = [];
|
||||
if (allowPrivateBoards && isPrivate && !boardSearchText.length) {
|
||||
elements.push(<NoBoardBoard key="none" isSelected={selectedBoardId === 'none'} />);
|
||||
}
|
||||
|
||||
if (!allowPrivateBoards && !boardSearchText.length) {
|
||||
elements.push(<NoBoardBoard key="none" isSelected={selectedBoardId === 'none'} />);
|
||||
}
|
||||
|
||||
filteredBoards.forEach((board) => {
|
||||
elements.push(
|
||||
<GalleryBoard
|
||||
board={board}
|
||||
isSelected={selectedBoardId === board.board_id}
|
||||
setBoardToDelete={setBoardToDelete}
|
||||
key={board.board_id}
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
return elements;
|
||||
}, [allowPrivateBoards, isPrivate, boardSearchText.length, filteredBoards, selectedBoardId, setBoardToDelete]);
|
||||
|
||||
const boardListTitle = useMemo(() => {
|
||||
if (allowPrivateBoards) {
|
||||
return isPrivate ? t('boards.private') : t('boards.shared');
|
||||
} else {
|
||||
return t('boards.boards');
|
||||
}
|
||||
}, [isPrivate, allowPrivateBoards, t]);
|
||||
|
||||
return (
|
||||
<>
|
||||
<Box position="relative" w="full" h="full">
|
||||
<Box position="absolute" top={0} right={0} bottom={0} left={0}>
|
||||
<OverlayScrollbarsComponent defer style={overlayScrollbarsStyles} options={overlayScrollbarsParams.options}>
|
||||
{allowPrivateBoards && (
|
||||
<Flex direction="column" gap={1}>
|
||||
<Flex direction="column">
|
||||
<Flex
|
||||
position="sticky"
|
||||
w="full"
|
||||
justifyContent="space-between"
|
||||
alignItems="center"
|
||||
ps={2}
|
||||
pb={1}
|
||||
pt={2}
|
||||
py={1}
|
||||
zIndex={1}
|
||||
top={0}
|
||||
bg="base.900"
|
||||
>
|
||||
<Text fontSize="md" fontWeight="semibold" userSelect="none">
|
||||
{t('boards.private')}
|
||||
</Text>
|
||||
<AddBoardButton isPrivateBoard={true} />
|
||||
</Flex>
|
||||
<Flex direction="column" gap={1}>
|
||||
<NoBoardBoard isSelected={selectedBoardId === 'none'} />
|
||||
{filteredPrivateBoards.map((board) => (
|
||||
<GalleryBoard
|
||||
board={board}
|
||||
isSelected={selectedBoardId === board.board_id}
|
||||
setBoardToDelete={setBoardToDelete}
|
||||
key={board.board_id}
|
||||
{allowPrivateBoards ? (
|
||||
<Button variant="unstyled" onClick={onToggle}>
|
||||
<Flex gap="2" alignItems="center">
|
||||
<Icon
|
||||
boxSize={4}
|
||||
as={PiCaretDownBold}
|
||||
transform={isOpen ? undefined : 'rotate(-90deg)'}
|
||||
fill="base.500"
|
||||
/>
|
||||
))}
|
||||
</Flex>
|
||||
<Text fontSize="sm" fontWeight="semibold" userSelect="none" color="base.500">
|
||||
{boardListTitle}
|
||||
</Text>
|
||||
</Flex>
|
||||
</Button>
|
||||
) : (
|
||||
<Text fontSize="sm" fontWeight="semibold" userSelect="none" color="base.500">
|
||||
{boardListTitle}
|
||||
</Text>
|
||||
)}
|
||||
<AddBoardButton isPrivateBoard={isPrivate} />
|
||||
</Flex>
|
||||
<Collapse in={isOpen}>
|
||||
<Flex direction="column" gap={1}>
|
||||
<Flex
|
||||
position="sticky"
|
||||
w="full"
|
||||
justifyContent="space-between"
|
||||
alignItems="center"
|
||||
ps={2}
|
||||
pb={1}
|
||||
pt={2}
|
||||
zIndex={1}
|
||||
top={0}
|
||||
bg="base.900"
|
||||
>
|
||||
<Text fontSize="md" fontWeight="semibold" userSelect="none">
|
||||
{allowPrivateBoards ? t('boards.shared') : t('boards.boards')}
|
||||
{boardElements.length ? (
|
||||
boardElements
|
||||
) : (
|
||||
<Text variant="subtext" textAlign="center">
|
||||
{t('boards.noBoards', { boardType: boardSearchText.length ? 'Matching' : '' })}
|
||||
</Text>
|
||||
<AddBoardButton isPrivateBoard={false} />
|
||||
)}
|
||||
</Flex>
|
||||
<Flex direction="column" gap={1}>
|
||||
{!allowPrivateBoards && <NoBoardBoard isSelected={selectedBoardId === 'none'} />}
|
||||
{filteredSharedBoards.map((board) => (
|
||||
<GalleryBoard
|
||||
board={board}
|
||||
isSelected={selectedBoardId === board.board_id}
|
||||
setBoardToDelete={setBoardToDelete}
|
||||
key={board.board_id}
|
||||
/>
|
||||
))}
|
||||
</Collapse>
|
||||
</Flex>
|
||||
</Flex>
|
||||
</OverlayScrollbarsComponent>
|
||||
</Box>
|
||||
</Box>
|
||||
<DeleteBoardModal boardToDelete={boardToDelete} setBoardToDelete={setBoardToDelete} />
|
||||
</>
|
||||
);
|
||||
};
|
||||
export default memo(BoardsList);
|
||||
|
@ -0,0 +1,35 @@
|
||||
import { Box } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { overlayScrollbarsParams } from 'common/components/OverlayScrollbars/constants';
|
||||
import DeleteBoardModal from 'features/gallery/components/Boards/DeleteBoardModal';
|
||||
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
||||
import type { CSSProperties } from 'react';
|
||||
import { memo, useState } from 'react';
|
||||
import type { BoardDTO } from 'services/api/types';
|
||||
|
||||
import { BoardsList } from './BoardsList';
|
||||
|
||||
const overlayScrollbarsStyles: CSSProperties = {
|
||||
height: '100%',
|
||||
width: '100%',
|
||||
};
|
||||
|
||||
const BoardsListWrapper = () => {
|
||||
const allowPrivateBoards = useAppSelector((s) => s.config.allowPrivateBoards);
|
||||
const [boardToDelete, setBoardToDelete] = useState<BoardDTO>();
|
||||
|
||||
return (
|
||||
<>
|
||||
<Box position="relative" w="full" h="full">
|
||||
<Box position="absolute" top={0} right={0} bottom={0} left={0}>
|
||||
<OverlayScrollbarsComponent defer style={overlayScrollbarsStyles} options={overlayScrollbarsParams.options}>
|
||||
{allowPrivateBoards && <BoardsList isPrivate={true} setBoardToDelete={setBoardToDelete} />}
|
||||
<BoardsList isPrivate={false} setBoardToDelete={setBoardToDelete} />
|
||||
</OverlayScrollbarsComponent>
|
||||
</Box>
|
||||
</Box>
|
||||
<DeleteBoardModal boardToDelete={boardToDelete} setBoardToDelete={setBoardToDelete} />
|
||||
</>
|
||||
);
|
||||
};
|
||||
export default memo(BoardsListWrapper);
|
@ -40,7 +40,7 @@ const BoardsSearch = () => {
|
||||
);
|
||||
|
||||
return (
|
||||
<InputGroup pt={2}>
|
||||
<InputGroup>
|
||||
<Input
|
||||
placeholder={t('boards.searchBoard')}
|
||||
value={boardSearchText}
|
||||
|
@ -17,7 +17,7 @@ import IAIDroppable from 'common/components/IAIDroppable';
|
||||
import type { AddToBoardDropData } from 'features/dnd/types';
|
||||
import { AutoAddBadge } from 'features/gallery/components/Boards/AutoAddBadge';
|
||||
import BoardContextMenu from 'features/gallery/components/Boards/BoardContextMenu';
|
||||
import { BoardTotalsTooltip } from 'features/gallery/components/Boards/BoardsList/BoardTotalsTooltip';
|
||||
import { BoardTooltip } from 'features/gallery/components/Boards/BoardsList/BoardTooltip';
|
||||
import { autoAddBoardIdChanged, boardIdSelected } from 'features/gallery/store/gallerySlice';
|
||||
import type { MouseEvent, MouseEventHandler, MutableRefObject } from 'react';
|
||||
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
|
||||
@ -115,12 +115,7 @@ const GalleryBoard = ({ board, isSelected, setBoardToDelete }: GalleryBoardProps
|
||||
return (
|
||||
<BoardContextMenu board={board} setBoardToDelete={setBoardToDelete}>
|
||||
{(ref) => (
|
||||
<Tooltip
|
||||
label={<BoardTotalsTooltip board_id={board.board_id} isArchived={Boolean(board.archived)} />}
|
||||
openDelay={1000}
|
||||
placement="left"
|
||||
closeOnScroll
|
||||
>
|
||||
<Tooltip label={<BoardTooltip board={board} />} openDelay={1000} placement="left" closeOnScroll p={2}>
|
||||
<Flex
|
||||
position="relative"
|
||||
ref={ref}
|
||||
@ -131,10 +126,12 @@ const GalleryBoard = ({ board, isSelected, setBoardToDelete }: GalleryBoardProps
|
||||
borderRadius="base"
|
||||
cursor="pointer"
|
||||
py={1}
|
||||
px={2}
|
||||
gap={2}
|
||||
ps={1}
|
||||
pe={4}
|
||||
gap={4}
|
||||
bg={isSelected ? 'base.850' : undefined}
|
||||
_hover={_hover}
|
||||
h={12}
|
||||
>
|
||||
<CoverImage board={board} />
|
||||
<Editable
|
||||
@ -149,17 +146,17 @@ const GalleryBoard = ({ board, isSelected, setBoardToDelete }: GalleryBoardProps
|
||||
onChange={onChange}
|
||||
onSubmit={onSubmit}
|
||||
isPreviewFocusable={false}
|
||||
fontSize="sm"
|
||||
>
|
||||
<EditablePreview
|
||||
cursor="pointer"
|
||||
p={0}
|
||||
fontSize="md"
|
||||
fontSize="sm"
|
||||
textOverflow="ellipsis"
|
||||
noOfLines={1}
|
||||
w="fit-content"
|
||||
wordBreak="break-all"
|
||||
color={isSelected ? 'base.100' : 'base.400'}
|
||||
fontWeight={isSelected ? 'semibold' : 'normal'}
|
||||
fontWeight={isSelected ? 'bold' : 'normal'}
|
||||
/>
|
||||
<EditableInput sx={editableInputStyles} />
|
||||
<JankEditableHijack onStartEditingRef={onStartEditingRef} />
|
||||
@ -168,7 +165,7 @@ const GalleryBoard = ({ board, isSelected, setBoardToDelete }: GalleryBoardProps
|
||||
{board.archived && !editingDisclosure.isOpen && <Icon as={PiArchiveBold} fill="base.300" />}
|
||||
{!editingDisclosure.isOpen && <Text variant="subtext">{board.image_count}</Text>}
|
||||
|
||||
<IAIDroppable data={droppableData} dropLabel={<Text fontSize="md">{t('unifiedCanvas.move')}</Text>} />
|
||||
<IAIDroppable data={droppableData} dropLabel={<Text fontSize="lg">{t('unifiedCanvas.move')}</Text>} />
|
||||
</Flex>
|
||||
</Tooltip>
|
||||
)}
|
||||
@ -197,8 +194,8 @@ const CoverImage = ({ board }: { board: BoardDTO }) => {
|
||||
src={coverImage.thumbnail_url}
|
||||
draggable={false}
|
||||
objectFit="cover"
|
||||
w={8}
|
||||
h={8}
|
||||
w={10}
|
||||
h={10}
|
||||
borderRadius="base"
|
||||
borderBottomRadius="lg"
|
||||
/>
|
||||
@ -206,8 +203,8 @@ const CoverImage = ({ board }: { board: BoardDTO }) => {
|
||||
}
|
||||
|
||||
return (
|
||||
<Flex w={8} h={8} justifyContent="center" alignItems="center">
|
||||
<Icon boxSize={8} as={PiImageSquare} opacity={0.7} color="base.500" />
|
||||
<Flex w={10} h={10} justifyContent="center" alignItems="center">
|
||||
<Icon boxSize={10} as={PiImageSquare} opacity={0.7} color="base.500" />
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
@ -4,7 +4,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIDroppable from 'common/components/IAIDroppable';
|
||||
import type { RemoveFromBoardDropData } from 'features/dnd/types';
|
||||
import { AutoAddBadge } from 'features/gallery/components/Boards/AutoAddBadge';
|
||||
import { BoardTotalsTooltip } from 'features/gallery/components/Boards/BoardsList/BoardTotalsTooltip';
|
||||
import { BoardTooltip } from 'features/gallery/components/Boards/BoardsList/BoardTooltip';
|
||||
import NoBoardBoardContextMenu from 'features/gallery/components/Boards/NoBoardBoardContextMenu';
|
||||
import { autoAddBoardIdChanged, boardIdSelected } from 'features/gallery/store/gallerySlice';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
@ -46,25 +46,16 @@ const NoBoardBoard = memo(({ isSelected }: Props) => {
|
||||
[]
|
||||
);
|
||||
|
||||
const filteredOut = useMemo(() => {
|
||||
return boardSearchText ? !boardName.toLowerCase().includes(boardSearchText.toLowerCase()) : false;
|
||||
}, [boardName, boardSearchText]);
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
if (filteredOut) {
|
||||
if (boardSearchText.length) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<NoBoardBoardContextMenu>
|
||||
{(ref) => (
|
||||
<Tooltip
|
||||
label={<BoardTotalsTooltip board_id="none" isArchived={false} />}
|
||||
openDelay={1000}
|
||||
placement="left"
|
||||
closeOnScroll
|
||||
>
|
||||
<Tooltip label={<BoardTooltip board={null} />} openDelay={1000} placement="left" closeOnScroll>
|
||||
<Flex
|
||||
position="relative"
|
||||
ref={ref}
|
||||
@ -73,15 +64,17 @@ const NoBoardBoard = memo(({ isSelected }: Props) => {
|
||||
alignItems="center"
|
||||
borderRadius="base"
|
||||
cursor="pointer"
|
||||
px={2}
|
||||
py={1}
|
||||
gap={2}
|
||||
ps={1}
|
||||
pe={4}
|
||||
gap={4}
|
||||
bg={isSelected ? 'base.850' : undefined}
|
||||
_hover={_hover}
|
||||
h={12}
|
||||
>
|
||||
<Flex w={8} h={8} justifyContent="center" alignItems="center">
|
||||
<Flex w="10" justifyContent="space-around">
|
||||
{/* iconified from public/assets/images/invoke-symbol-wht-lrg.svg */}
|
||||
<Icon boxSize={6} opacity={1} stroke="base.500" viewBox="0 0 66 66" fill="none">
|
||||
<Icon boxSize={8} opacity={1} stroke="base.500" viewBox="0 0 66 66" fill="none">
|
||||
<path
|
||||
d="M43.9137 16H63.1211V3H3.12109V16H22.3285L43.9137 50H63.1211V63H3.12109V50H22.3285"
|
||||
strokeWidth="5"
|
||||
@ -89,18 +82,12 @@ const NoBoardBoard = memo(({ isSelected }: Props) => {
|
||||
</Icon>
|
||||
</Flex>
|
||||
|
||||
<Text
|
||||
fontSize="md"
|
||||
color={isSelected ? 'base.100' : 'base.400'}
|
||||
fontWeight={isSelected ? 'semibold' : 'normal'}
|
||||
noOfLines={1}
|
||||
flexGrow={1}
|
||||
>
|
||||
<Text fontSize="sm" fontWeight={isSelected ? 'bold' : 'normal'} noOfLines={1} flexGrow={1}>
|
||||
{boardName}
|
||||
</Text>
|
||||
{autoAddBoardId === 'none' && <AutoAddBadge />}
|
||||
<Text variant="subtext">{imagesTotal}</Text>
|
||||
<IAIDroppable data={droppableData} dropLabel={<Text fontSize="md">{t('unifiedCanvas.move')}</Text>} />
|
||||
<IAIDroppable data={droppableData} dropLabel={<Text fontSize="lg">{t('unifiedCanvas.move')}</Text>} />
|
||||
</Flex>
|
||||
</Tooltip>
|
||||
)}
|
||||
|
@ -0,0 +1,105 @@
|
||||
import type { ChakraProps } from '@invoke-ai/ui-library';
|
||||
import {
|
||||
Box,
|
||||
Collapse,
|
||||
Flex,
|
||||
IconButton,
|
||||
Spacer,
|
||||
Tab,
|
||||
TabList,
|
||||
Tabs,
|
||||
Text,
|
||||
useDisclosure,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useGallerySearchTerm } from 'features/gallery/components/ImageGrid/useGallerySearchTerm';
|
||||
import { galleryViewChanged } from 'features/gallery/store/gallerySlice';
|
||||
import type { CSSProperties } from 'react';
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiMagnifyingGlassBold } from 'react-icons/pi';
|
||||
import { useBoardName } from 'services/api/hooks/useBoardName';
|
||||
|
||||
import GalleryImageGrid from './ImageGrid/GalleryImageGrid';
|
||||
import { GalleryPagination } from './ImageGrid/GalleryPagination';
|
||||
import { GallerySearch } from './ImageGrid/GallerySearch';
|
||||
|
||||
const BASE_STYLES: ChakraProps['sx'] = {
|
||||
fontWeight: 'semibold',
|
||||
fontSize: 'sm',
|
||||
color: 'base.300',
|
||||
};
|
||||
|
||||
const SELECTED_STYLES: ChakraProps['sx'] = {
|
||||
borderColor: 'base.800',
|
||||
borderBottomColor: 'base.900',
|
||||
color: 'invokeBlue.300',
|
||||
};
|
||||
|
||||
const COLLAPSE_STYLES: CSSProperties = { flexShrink: 0, minHeight: 0 };
|
||||
|
||||
export const Gallery = () => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const galleryView = useAppSelector((s) => s.gallery.galleryView);
|
||||
const initialSearchTerm = useAppSelector((s) => s.gallery.searchTerm);
|
||||
const searchDisclosure = useDisclosure({ defaultIsOpen: initialSearchTerm.length > 0 });
|
||||
const [searchTerm, onChangeSearchTerm, onResetSearchTerm] = useGallerySearchTerm();
|
||||
|
||||
const handleClickImages = useCallback(() => {
|
||||
dispatch(galleryViewChanged('images'));
|
||||
}, [dispatch]);
|
||||
|
||||
const handleClickAssets = useCallback(() => {
|
||||
dispatch(galleryViewChanged('assets'));
|
||||
}, [dispatch]);
|
||||
|
||||
const handleClickSearch = useCallback(() => {
|
||||
searchDisclosure.onToggle();
|
||||
onResetSearchTerm();
|
||||
}, [onResetSearchTerm, searchDisclosure]);
|
||||
|
||||
const selectedBoardId = useAppSelector((s) => s.gallery.selectedBoardId);
|
||||
const boardName = useBoardName(selectedBoardId);
|
||||
|
||||
return (
|
||||
<Flex flexDirection="column" alignItems="center" justifyContent="space-between" h="full" w="full" pt={1}>
|
||||
<Tabs index={galleryView === 'images' ? 0 : 1} variant="enclosed" display="flex" flexDir="column" w="full">
|
||||
<TabList gap={2} fontSize="sm" borderColor="base.800" alignItems="center" w="full">
|
||||
<Text fontSize="sm" fontWeight="semibold" noOfLines={1} px="2">
|
||||
{boardName}
|
||||
</Text>
|
||||
<Spacer />
|
||||
<Tab sx={BASE_STYLES} _selected={SELECTED_STYLES} onClick={handleClickImages} data-testid="images-tab">
|
||||
{t('parameters.images')}
|
||||
</Tab>
|
||||
<Tab sx={BASE_STYLES} _selected={SELECTED_STYLES} onClick={handleClickAssets} data-testid="assets-tab">
|
||||
{t('gallery.assets')}
|
||||
</Tab>
|
||||
<IconButton
|
||||
onClick={handleClickSearch}
|
||||
tooltip={searchDisclosure.isOpen ? `${t('gallery.exitSearch')}` : `${t('gallery.displaySearch')}`}
|
||||
aria-label={t('gallery.displaySearch')}
|
||||
icon={<PiMagnifyingGlassBold />}
|
||||
colorScheme={searchDisclosure.isOpen ? 'invokeBlue' : 'base'}
|
||||
variant="link"
|
||||
/>
|
||||
</TabList>
|
||||
</Tabs>
|
||||
|
||||
<Box w="full">
|
||||
<Collapse in={searchDisclosure.isOpen} style={COLLAPSE_STYLES}>
|
||||
<Box w="full" pt={2}>
|
||||
<GallerySearch
|
||||
searchTerm={searchTerm}
|
||||
onChangeSearchTerm={onChangeSearchTerm}
|
||||
onResetSearchTerm={onResetSearchTerm}
|
||||
/>
|
||||
</Box>
|
||||
</Collapse>
|
||||
</Box>
|
||||
<GalleryImageGrid />
|
||||
<GalleryPagination />
|
||||
</Flex>
|
||||
);
|
||||
};
|
@ -1,33 +0,0 @@
|
||||
import { Flex, Text } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { memo } from 'react';
|
||||
import { useBoardName } from 'services/api/hooks/useBoardName';
|
||||
|
||||
type Props = {
|
||||
onClick: () => void;
|
||||
};
|
||||
|
||||
const GalleryBoardName = (props: Props) => {
|
||||
const selectedBoardId = useAppSelector((s) => s.gallery.selectedBoardId);
|
||||
const boardName = useBoardName(selectedBoardId);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
onClick={props.onClick}
|
||||
as="button"
|
||||
h="full"
|
||||
w="full"
|
||||
borderWidth={1}
|
||||
borderRadius="base"
|
||||
alignItems="center"
|
||||
justifyContent="center"
|
||||
px={2}
|
||||
>
|
||||
<Text fontWeight="semibold" fontSize="md" noOfLines={1} wordBreak="break-all" color="base.200">
|
||||
{boardName}
|
||||
</Text>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(GalleryBoardName);
|
@ -3,32 +3,21 @@ import { useStore } from '@nanostores/react';
|
||||
import { $projectName, $projectUrl } from 'app/store/nanostores/projectId';
|
||||
import { memo } from 'react';
|
||||
|
||||
import GalleryBoardName from './GalleryBoardName';
|
||||
|
||||
type Props = {
|
||||
onClickBoardName: () => void;
|
||||
};
|
||||
|
||||
export const GalleryHeader = memo((props: Props) => {
|
||||
export const GalleryHeader = memo(() => {
|
||||
const projectName = useStore($projectName);
|
||||
const projectUrl = useStore($projectUrl);
|
||||
|
||||
if (projectName && projectUrl) {
|
||||
return (
|
||||
<Flex gap={2} w="full" alignItems="center" justifyContent="space-evenly" pe={2}>
|
||||
<Text fontSize="md" fontWeight="semibold" noOfLines={1} w="full" textAlign="center">
|
||||
<Text fontSize="md" fontWeight="semibold" noOfLines={1} wordBreak="break-all" w="full" textAlign="center">
|
||||
<Link href={projectUrl}>{projectName}</Link>
|
||||
</Text>
|
||||
<GalleryBoardName onClick={props.onClickBoardName} />
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<Flex w="full" pe={2}>
|
||||
<GalleryBoardName onClick={props.onClickBoardName} />
|
||||
</Flex>
|
||||
);
|
||||
return null;
|
||||
});
|
||||
|
||||
GalleryHeader.displayName = 'GalleryHeader';
|
||||
|
@ -13,6 +13,7 @@ import { sentImageToCanvas, sentImageToImg2Img } from 'features/gallery/store/ac
|
||||
import { imageToCompareChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { $templates } from 'features/nodes/store/nodesSlice';
|
||||
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
|
||||
import { upscaleInitialImageChanged } from 'features/parameters/store/upscaleSlice';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { setActiveTab } from 'features/ui/store/uiSlice';
|
||||
@ -124,6 +125,11 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
||||
dispatch(imageToCompareChanged(imageDTO));
|
||||
}, [dispatch, imageDTO]);
|
||||
|
||||
const handleSendToUpscale = useCallback(() => {
|
||||
dispatch(upscaleInitialImageChanged(imageDTO));
|
||||
dispatch(setActiveTab('upscaling'));
|
||||
}, [dispatch, imageDTO]);
|
||||
|
||||
return (
|
||||
<>
|
||||
<MenuItem as="a" href={imageDTO.image_url} target="_blank" icon={<PiShareFatBold />}>
|
||||
@ -185,6 +191,9 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
||||
{t('parameters.sendToUnifiedCanvas')}
|
||||
</MenuItem>
|
||||
)}
|
||||
<MenuItem icon={<PiShareFatBold />} onClickCapture={handleSendToUpscale} id="send-to-upscale">
|
||||
{t('parameters.sendToUpscale')}
|
||||
</MenuItem>
|
||||
<MenuDivider />
|
||||
<MenuItem icon={<PiFoldersBold />} onClickCapture={handleChangeBoard}>
|
||||
{t('boards.changeBoard')}
|
||||
|
@ -1,57 +1,28 @@
|
||||
import type { ChakraProps } from '@invoke-ai/ui-library';
|
||||
import {
|
||||
Box,
|
||||
Collapse,
|
||||
Divider,
|
||||
Flex,
|
||||
IconButton,
|
||||
Spacer,
|
||||
Tab,
|
||||
TabList,
|
||||
Tabs,
|
||||
useDisclosure,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { Box, Button, Collapse, Divider, Flex, IconButton, useDisclosure } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { GalleryHeader } from 'features/gallery/components/GalleryHeader';
|
||||
import { galleryViewChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { boardSearchTextChanged } from 'features/gallery/store/gallerySlice';
|
||||
import ResizeHandle from 'features/ui/components/tabs/ResizeHandle';
|
||||
import { usePanel, type UsePanelOptions } from 'features/ui/hooks/usePanel';
|
||||
import type { CSSProperties } from 'react';
|
||||
import { memo, useCallback, useMemo, useRef } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiMagnifyingGlassBold } from 'react-icons/pi';
|
||||
import { PiCaretDownBold, PiCaretUpBold, PiMagnifyingGlassBold } from 'react-icons/pi';
|
||||
import type { ImperativePanelGroupHandle } from 'react-resizable-panels';
|
||||
import { Panel, PanelGroup } from 'react-resizable-panels';
|
||||
|
||||
import BoardsList from './Boards/BoardsList/BoardsList';
|
||||
import BoardsListWrapper from './Boards/BoardsList/BoardsListWrapper';
|
||||
import BoardsSearch from './Boards/BoardsList/BoardsSearch';
|
||||
import { Gallery } from './Gallery';
|
||||
import GallerySettingsPopover from './GallerySettingsPopover/GallerySettingsPopover';
|
||||
import GalleryImageGrid from './ImageGrid/GalleryImageGrid';
|
||||
import { GalleryPagination } from './ImageGrid/GalleryPagination';
|
||||
import { GallerySearch } from './ImageGrid/GallerySearch';
|
||||
|
||||
const COLLAPSE_STYLES: CSSProperties = { flexShrink: 0, minHeight: 0 };
|
||||
|
||||
const BASE_STYLES: ChakraProps['sx'] = {
|
||||
fontWeight: 'semibold',
|
||||
fontSize: 'sm',
|
||||
color: 'base.300',
|
||||
};
|
||||
|
||||
const SELECTED_STYLES: ChakraProps['sx'] = {
|
||||
borderColor: 'base.800',
|
||||
borderBottomColor: 'base.900',
|
||||
color: 'invokeBlue.300',
|
||||
};
|
||||
|
||||
const ImageGalleryContent = () => {
|
||||
const { t } = useTranslation();
|
||||
const galleryView = useAppSelector((s) => s.gallery.galleryView);
|
||||
const searchTerm = useAppSelector((s) => s.gallery.searchTerm);
|
||||
const boardSearchText = useAppSelector((s) => s.gallery.boardSearchText);
|
||||
const dispatch = useAppDispatch();
|
||||
const searchDisclosure = useDisclosure({ defaultIsOpen: false });
|
||||
const boardSearchDisclosure = useDisclosure({ defaultIsOpen: false });
|
||||
const boardSearchDisclosure = useDisclosure({ defaultIsOpen: !!boardSearchText.length });
|
||||
const panelGroupRef = useRef<ImperativePanelGroupHandle>(null);
|
||||
|
||||
const boardsListPanelOptions = useMemo<UsePanelOptions>(
|
||||
@ -67,42 +38,58 @@ const ImageGalleryContent = () => {
|
||||
);
|
||||
const boardsListPanel = usePanel(boardsListPanelOptions);
|
||||
|
||||
const handleClickImages = useCallback(() => {
|
||||
dispatch(galleryViewChanged('images'));
|
||||
}, [dispatch]);
|
||||
const handleClickBoardSearch = useCallback(() => {
|
||||
if (boardSearchText.length) {
|
||||
dispatch(boardSearchTextChanged(''));
|
||||
}
|
||||
boardSearchDisclosure.onToggle();
|
||||
boardsListPanel.expand();
|
||||
}, [boardSearchText.length, boardSearchDisclosure, boardsListPanel, dispatch]);
|
||||
|
||||
const handleClickAssets = useCallback(() => {
|
||||
dispatch(galleryViewChanged('assets'));
|
||||
}, [dispatch]);
|
||||
const handleToggleBoardPanel = useCallback(() => {
|
||||
if (boardSearchText.length) {
|
||||
dispatch(boardSearchTextChanged(''));
|
||||
}
|
||||
boardSearchDisclosure.onClose();
|
||||
boardsListPanel.toggle();
|
||||
}, [boardSearchText.length, boardSearchDisclosure, boardsListPanel, dispatch]);
|
||||
|
||||
return (
|
||||
<Flex position="relative" flexDirection="column" h="full" w="full" pt={2}>
|
||||
<Flex alignItems="center" gap={2}>
|
||||
<GalleryHeader onClickBoardName={boardsListPanel.toggle} />
|
||||
<Flex alignItems="center" gap={0}>
|
||||
<GalleryHeader />
|
||||
<Flex alignItems="center" justifyContent="space-between" w="full">
|
||||
<Button
|
||||
w={112}
|
||||
size="sm"
|
||||
variant="ghost"
|
||||
onClick={handleToggleBoardPanel}
|
||||
rightIcon={boardsListPanel.isCollapsed ? <PiCaretDownBold /> : <PiCaretUpBold />}
|
||||
>
|
||||
{boardsListPanel.isCollapsed ? t('boards.viewBoards') : t('boards.hideBoards')}
|
||||
</Button>
|
||||
<Flex alignItems="center" justifyContent="space-between">
|
||||
<GallerySettingsPopover />
|
||||
<Box position="relative" h="full">
|
||||
<Flex>
|
||||
<IconButton
|
||||
w="full"
|
||||
h="full"
|
||||
onClick={boardSearchDisclosure.onToggle}
|
||||
tooltip={`${t('gallery.displayBoardSearch')}`}
|
||||
onClick={handleClickBoardSearch}
|
||||
tooltip={
|
||||
boardSearchDisclosure.isOpen
|
||||
? `${t('gallery.exitBoardSearch')}`
|
||||
: `${t('gallery.displayBoardSearch')}`
|
||||
}
|
||||
aria-label={t('gallery.displayBoardSearch')}
|
||||
icon={<PiMagnifyingGlassBold />}
|
||||
colorScheme={boardSearchDisclosure.isOpen ? 'invokeBlue' : 'base'}
|
||||
variant="link"
|
||||
/>
|
||||
{boardSearchText && (
|
||||
<Box
|
||||
position="absolute"
|
||||
w={2}
|
||||
h={2}
|
||||
bg="invokeBlue.300"
|
||||
borderRadius="full"
|
||||
insetBlockStart={2}
|
||||
insetInlineEnd={2}
|
||||
/>
|
||||
)}
|
||||
</Box>
|
||||
</Flex>
|
||||
</Flex>
|
||||
</Flex>
|
||||
</Flex>
|
||||
|
||||
<PanelGroup ref={panelGroupRef} direction="vertical">
|
||||
<Panel
|
||||
id="boards-list-panel"
|
||||
@ -115,10 +102,12 @@ const ImageGalleryContent = () => {
|
||||
>
|
||||
<Flex flexDir="column" w="full" h="full">
|
||||
<Collapse in={boardSearchDisclosure.isOpen} style={COLLAPSE_STYLES}>
|
||||
<Box w="full" pt={2}>
|
||||
<BoardsSearch />
|
||||
</Box>
|
||||
</Collapse>
|
||||
<Divider pt={2} />
|
||||
<BoardsList />
|
||||
<BoardsListWrapper />
|
||||
</Flex>
|
||||
</Panel>
|
||||
<ResizeHandle
|
||||
@ -127,50 +116,7 @@ const ImageGalleryContent = () => {
|
||||
onDoubleClick={boardsListPanel.onDoubleClickHandle}
|
||||
/>
|
||||
<Panel id="gallery-wrapper-panel" minSize={20}>
|
||||
<Flex flexDirection="column" alignItems="center" justifyContent="space-between" h="full" w="full">
|
||||
<Tabs index={galleryView === 'images' ? 0 : 1} variant="enclosed" display="flex" flexDir="column" w="full">
|
||||
<TabList gap={2} fontSize="sm" borderColor="base.800">
|
||||
<Tab sx={BASE_STYLES} _selected={SELECTED_STYLES} onClick={handleClickImages} data-testid="images-tab">
|
||||
{t('parameters.images')}
|
||||
</Tab>
|
||||
<Tab sx={BASE_STYLES} _selected={SELECTED_STYLES} onClick={handleClickAssets} data-testid="assets-tab">
|
||||
{t('gallery.assets')}
|
||||
</Tab>
|
||||
<Spacer />
|
||||
<Box position="relative">
|
||||
<IconButton
|
||||
w="full"
|
||||
h="full"
|
||||
onClick={searchDisclosure.onToggle}
|
||||
tooltip={`${t('gallery.displaySearch')}`}
|
||||
aria-label={t('gallery.displaySearch')}
|
||||
icon={<PiMagnifyingGlassBold />}
|
||||
variant="link"
|
||||
/>
|
||||
{searchTerm && (
|
||||
<Box
|
||||
position="absolute"
|
||||
w={2}
|
||||
h={2}
|
||||
bg="invokeBlue.300"
|
||||
borderRadius="full"
|
||||
insetBlockStart={2}
|
||||
insetInlineEnd={2}
|
||||
/>
|
||||
)}
|
||||
</Box>
|
||||
</TabList>
|
||||
</Tabs>
|
||||
<Box w="full">
|
||||
<Collapse in={searchDisclosure.isOpen} style={COLLAPSE_STYLES}>
|
||||
<Box w="full" pt={2}>
|
||||
<GallerySearch />
|
||||
</Box>
|
||||
</Collapse>
|
||||
</Box>
|
||||
<GalleryImageGrid />
|
||||
<GalleryPagination />
|
||||
</Flex>
|
||||
<Gallery />
|
||||
</Panel>
|
||||
</PanelGroup>
|
||||
</Flex>
|
||||
|
@ -3,6 +3,8 @@ import { ELLIPSIS, useGalleryPagination } from 'features/gallery/hooks/useGaller
|
||||
import { useCallback } from 'react';
|
||||
import { PiCaretLeftBold, PiCaretRightBold } from 'react-icons/pi';
|
||||
|
||||
import { JumpTo } from './JumpTo';
|
||||
|
||||
export const GalleryPagination = () => {
|
||||
const { goPrev, goNext, isPrevEnabled, isNextEnabled, pageButtons, goToPage, currentPage, total } =
|
||||
useGalleryPagination();
|
||||
@ -20,7 +22,7 @@ export const GalleryPagination = () => {
|
||||
}
|
||||
|
||||
return (
|
||||
<Flex gap={2} alignItems="center" w="full">
|
||||
<Flex justifyContent="center" alignItems="center" w="full" gap={1} pt={2}>
|
||||
<IconButton
|
||||
size="sm"
|
||||
aria-label="prev"
|
||||
@ -30,25 +32,9 @@ export const GalleryPagination = () => {
|
||||
variant="ghost"
|
||||
/>
|
||||
<Spacer />
|
||||
{pageButtons.map((page, i) => {
|
||||
if (page === ELLIPSIS) {
|
||||
return (
|
||||
<Button size="sm" key={`ellipsis_${i}`} variant="link" isDisabled>
|
||||
...
|
||||
</Button>
|
||||
);
|
||||
}
|
||||
return (
|
||||
<Button
|
||||
size="sm"
|
||||
key={page}
|
||||
onClick={goToPage.bind(null, page - 1)}
|
||||
variant={currentPage === page - 1 ? 'solid' : 'outline'}
|
||||
>
|
||||
{page}
|
||||
</Button>
|
||||
);
|
||||
})}
|
||||
{pageButtons.map((page, i) => (
|
||||
<PageButton key={`${page}_${i}`} page={page} currentPage={currentPage} goToPage={goToPage} />
|
||||
))}
|
||||
<Spacer />
|
||||
<IconButton
|
||||
size="sm"
|
||||
@ -58,6 +44,28 @@ export const GalleryPagination = () => {
|
||||
isDisabled={!isNextEnabled}
|
||||
variant="ghost"
|
||||
/>
|
||||
<JumpTo />
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
type PageButtonProps = {
|
||||
page: number | typeof ELLIPSIS;
|
||||
currentPage: number;
|
||||
goToPage: (page: number) => void;
|
||||
};
|
||||
|
||||
const PageButton = ({ page, currentPage, goToPage }: PageButtonProps) => {
|
||||
if (page === ELLIPSIS) {
|
||||
return (
|
||||
<Button size="sm" variant="link" isDisabled>
|
||||
...
|
||||
</Button>
|
||||
);
|
||||
}
|
||||
return (
|
||||
<Button size="sm" onClick={goToPage.bind(null, page - 1)} variant={currentPage === page - 1 ? 'solid' : 'outline'}>
|
||||
{page}
|
||||
</Button>
|
||||
);
|
||||
};
|
||||
|
@ -1,59 +1,60 @@
|
||||
import { IconButton, Input, InputGroup, InputRightElement, Spinner } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||
import { searchTermChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { debounce } from 'lodash-es';
|
||||
import type { ChangeEvent } from 'react';
|
||||
import { useCallback, useMemo, useState } from 'react';
|
||||
import type { ChangeEvent, KeyboardEvent } from 'react';
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiXBold } from 'react-icons/pi';
|
||||
import { useListImagesQuery } from 'services/api/endpoints/images';
|
||||
|
||||
export const GallerySearch = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const searchTerm = useAppSelector((s) => s.gallery.searchTerm);
|
||||
type Props = {
|
||||
searchTerm: string;
|
||||
onChangeSearchTerm: (value: string) => void;
|
||||
onResetSearchTerm: () => void;
|
||||
};
|
||||
|
||||
export const GallerySearch = ({ searchTerm, onChangeSearchTerm, onResetSearchTerm }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const [searchTermInput, setSearchTermInput] = useState(searchTerm);
|
||||
const queryArgs = useAppSelector(selectListImagesQueryArgs);
|
||||
const { isPending } = useListImagesQuery(queryArgs, {
|
||||
selectFromResult: ({ isLoading, isFetching }) => ({ isPending: isLoading || isFetching }),
|
||||
});
|
||||
const debouncedSetSearchTerm = useMemo(() => {
|
||||
return debounce((value: string) => {
|
||||
dispatch(searchTermChanged(value));
|
||||
}, 1000);
|
||||
}, [dispatch]);
|
||||
|
||||
const handleChangeInput = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) => {
|
||||
setSearchTermInput(e.target.value);
|
||||
debouncedSetSearchTerm(e.target.value);
|
||||
onChangeSearchTerm(e.target.value);
|
||||
},
|
||||
[debouncedSetSearchTerm]
|
||||
[onChangeSearchTerm]
|
||||
);
|
||||
|
||||
const handleClearInput = useCallback(() => {
|
||||
setSearchTermInput('');
|
||||
dispatch(searchTermChanged(''));
|
||||
}, [dispatch]);
|
||||
const handleKeydown = useCallback(
|
||||
(e: KeyboardEvent<HTMLInputElement>) => {
|
||||
// exit search mode on escape
|
||||
if (e.key === 'Escape') {
|
||||
onResetSearchTerm();
|
||||
}
|
||||
},
|
||||
[onResetSearchTerm]
|
||||
);
|
||||
|
||||
return (
|
||||
<InputGroup>
|
||||
<Input
|
||||
placeholder={t('gallery.searchImages')}
|
||||
value={searchTermInput}
|
||||
value={searchTerm}
|
||||
onChange={handleChangeInput}
|
||||
data-testid="image-search-input"
|
||||
onKeyDown={handleKeydown}
|
||||
/>
|
||||
{isPending && (
|
||||
<InputRightElement h="full" pe={2}>
|
||||
<Spinner size="sm" opacity={0.5} />
|
||||
</InputRightElement>
|
||||
)}
|
||||
{!isPending && searchTermInput.length && (
|
||||
{!isPending && searchTerm.length && (
|
||||
<InputRightElement h="full" pe={2}>
|
||||
<IconButton
|
||||
onClick={handleClearInput}
|
||||
onClick={onResetSearchTerm}
|
||||
size="sm"
|
||||
variant="link"
|
||||
aria-label={t('boards.clearSearch')}
|
||||
|
@ -0,0 +1,97 @@
|
||||
import {
|
||||
Button,
|
||||
CompositeNumberInput,
|
||||
Flex,
|
||||
FormControl,
|
||||
Popover,
|
||||
PopoverArrow,
|
||||
PopoverBody,
|
||||
PopoverContent,
|
||||
PopoverTrigger,
|
||||
useDisclosure,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { useGalleryPagination } from 'features/gallery/hooks/useGalleryPagination';
|
||||
import { useCallback, useEffect, useRef, useState } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
export const JumpTo = () => {
|
||||
const { t } = useTranslation();
|
||||
const { goToPage, currentPage, pages } = useGalleryPagination();
|
||||
const [newPage, setNewPage] = useState(currentPage);
|
||||
const { isOpen, onToggle, onClose } = useDisclosure();
|
||||
const ref = useRef<HTMLInputElement>(null);
|
||||
|
||||
const onOpen = useCallback(() => {
|
||||
setNewPage(currentPage);
|
||||
setTimeout(() => {
|
||||
const input = ref.current?.querySelector('input');
|
||||
input?.focus();
|
||||
input?.select();
|
||||
}, 0);
|
||||
}, [currentPage]);
|
||||
|
||||
const onChangeJumpTo = useCallback((v: number) => {
|
||||
setNewPage(v - 1);
|
||||
}, []);
|
||||
|
||||
const onClickGo = useCallback(() => {
|
||||
goToPage(newPage);
|
||||
onClose();
|
||||
}, [newPage, goToPage, onClose]);
|
||||
|
||||
useHotkeys(
|
||||
'enter',
|
||||
() => {
|
||||
onClickGo();
|
||||
},
|
||||
{ enabled: isOpen, enableOnFormTags: ['input'] },
|
||||
[isOpen, onClickGo]
|
||||
);
|
||||
|
||||
useHotkeys(
|
||||
'esc',
|
||||
() => {
|
||||
setNewPage(currentPage);
|
||||
onClose();
|
||||
},
|
||||
{ enabled: isOpen, enableOnFormTags: ['input'] },
|
||||
[isOpen, onClose]
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
setNewPage(currentPage);
|
||||
}, [currentPage]);
|
||||
|
||||
return (
|
||||
<Popover isOpen={isOpen} onClose={onClose} onOpen={onOpen}>
|
||||
<PopoverTrigger>
|
||||
<Button aria-label={t('gallery.jump')} size="sm" onClick={onToggle} variant="outline">
|
||||
{t('gallery.jump')}
|
||||
</Button>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent>
|
||||
<PopoverArrow />
|
||||
<PopoverBody>
|
||||
<Flex gap={2} alignItems="center">
|
||||
<FormControl>
|
||||
<CompositeNumberInput
|
||||
ref={ref}
|
||||
size="sm"
|
||||
maxW="60px"
|
||||
value={newPage + 1}
|
||||
min={1}
|
||||
max={pages}
|
||||
step={1}
|
||||
onChange={onChangeJumpTo}
|
||||
/>
|
||||
</FormControl>
|
||||
<Button h="full" size="sm" onClick={onClickGo}>
|
||||
{t('gallery.go')}
|
||||
</Button>
|
||||
</Flex>
|
||||
</PopoverBody>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
);
|
||||
};
|
@ -0,0 +1,37 @@
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
|
||||
import { searchTermChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { debounce } from 'lodash-es';
|
||||
import { useCallback, useMemo, useState } from 'react';
|
||||
|
||||
export const useGallerySearchTerm = () => {
|
||||
// Highlander!
|
||||
useAssertSingleton('gallery-search-state');
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const searchTerm = useAppSelector((s) => s.gallery.searchTerm);
|
||||
|
||||
const [localSearchTerm, setLocalSearchTerm] = useState(searchTerm);
|
||||
|
||||
const debouncedSetSearchTerm = useMemo(() => {
|
||||
return debounce((val: string) => {
|
||||
dispatch(searchTermChanged(val));
|
||||
}, 1000);
|
||||
}, [dispatch]);
|
||||
|
||||
const onChange = useCallback(
|
||||
(val: string) => {
|
||||
setLocalSearchTerm(val);
|
||||
debouncedSetSearchTerm(val);
|
||||
},
|
||||
[debouncedSetSearchTerm]
|
||||
);
|
||||
|
||||
const onReset = useCallback(() => {
|
||||
debouncedSetSearchTerm.cancel();
|
||||
setLocalSearchTerm('');
|
||||
dispatch(searchTermChanged(''));
|
||||
}, [debouncedSetSearchTerm, dispatch]);
|
||||
|
||||
return [localSearchTerm, onChange, onReset] as const;
|
||||
};
|
@ -2,7 +2,7 @@ import { ButtonGroup, IconButton, Menu, MenuButton, MenuList } from '@invoke-ai/
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { upscaleRequested } from 'app/store/middleware/listenerMiddleware/listeners/upscaleRequested';
|
||||
import { adHocPostProcessingRequested } from 'app/store/middleware/listenerMiddleware/listeners/addAdHocPostProcessingRequestedListener';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { iiLayerAdded } from 'features/controlLayers/store/controlLayersSlice';
|
||||
import { DeleteImageButton } from 'features/deleteImageModal/components/DeleteImageButton';
|
||||
@ -14,7 +14,7 @@ import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors
|
||||
import { selectGallerySlice } from 'features/gallery/store/gallerySlice';
|
||||
import { parseAndRecallImageDimensions } from 'features/metadata/util/handlers';
|
||||
import { $templates } from 'features/nodes/store/nodesSlice';
|
||||
import ParamUpscalePopover from 'features/parameters/components/Upscale/ParamUpscaleSettings';
|
||||
import { PostProcessingPopover } from 'features/parameters/components/PostProcessing/PostProcessingPopover';
|
||||
import { useIsQueueMutationInProgress } from 'features/queue/hooks/useIsQueueMutationInProgress';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
import { selectSystemSlice } from 'features/system/store/systemSlice';
|
||||
@ -97,7 +97,7 @@ const CurrentImageButtons = () => {
|
||||
if (!imageDTO) {
|
||||
return;
|
||||
}
|
||||
dispatch(upscaleRequested({ imageDTO }));
|
||||
dispatch(adHocPostProcessingRequested({ imageDTO }));
|
||||
}, [dispatch, imageDTO]);
|
||||
|
||||
const handleDelete = useCallback(() => {
|
||||
@ -193,7 +193,7 @@ const CurrentImageButtons = () => {
|
||||
|
||||
{isUpscalingEnabled && (
|
||||
<ButtonGroup isDisabled={isQueueMutationInProgress}>
|
||||
{isUpscalingEnabled && <ParamUpscalePopover imageDTO={imageDTO} />}
|
||||
{isUpscalingEnabled && <PostProcessingPopover imageDTO={imageDTO} />}
|
||||
</ButtonGroup>
|
||||
)}
|
||||
|
||||
|
@ -9,7 +9,13 @@ import CurrentImageButtons from './CurrentImageButtons';
|
||||
import { ViewerToggleMenu } from './ViewerToggleMenu';
|
||||
|
||||
export const ViewerToolbar = memo(() => {
|
||||
const tab = useAppSelector(activeTabNameSelector);
|
||||
const showToggle = useAppSelector((s) => {
|
||||
const tab = activeTabNameSelector(s);
|
||||
if (tab === 'upscaling' || tab === 'workflows') {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
return (
|
||||
<Flex w="full" gap={2}>
|
||||
<Flex flex={1} justifyContent="center">
|
||||
@ -23,7 +29,7 @@ export const ViewerToolbar = memo(() => {
|
||||
</Flex>
|
||||
<Flex flex={1} justifyContent="center">
|
||||
<Flex gap={2} marginInlineStart="auto">
|
||||
{tab !== 'workflows' && <ViewerToggleMenu />}
|
||||
{showToggle && <ViewerToggleMenu />}
|
||||
</Flex>
|
||||
</Flex>
|
||||
</Flex>
|
||||
|
@ -1,6 +1,7 @@
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||
import { offsetChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { throttle } from 'lodash-es';
|
||||
import { useCallback, useEffect, useMemo } from 'react';
|
||||
import { useListImagesQuery } from 'services/api/endpoints/images';
|
||||
|
||||
@ -80,32 +81,41 @@ export const useGalleryPagination = () => {
|
||||
return offset > 0;
|
||||
}, [count, offset]);
|
||||
|
||||
const onOffsetChanged = useCallback(
|
||||
(arg: Parameters<typeof offsetChanged>[0]) => {
|
||||
dispatch(offsetChanged(arg));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const throttledOnOffsetChanged = useMemo(() => throttle(onOffsetChanged, 500), [onOffsetChanged]);
|
||||
|
||||
const goNext = useCallback(
|
||||
(withHotkey?: 'arrow' | 'alt+arrow') => {
|
||||
dispatch(offsetChanged({ offset: offset + (limit || 0), withHotkey }));
|
||||
throttledOnOffsetChanged({ offset: offset + (limit || 0), withHotkey });
|
||||
},
|
||||
[dispatch, offset, limit]
|
||||
[throttledOnOffsetChanged, offset, limit]
|
||||
);
|
||||
|
||||
const goPrev = useCallback(
|
||||
(withHotkey?: 'arrow' | 'alt+arrow') => {
|
||||
dispatch(offsetChanged({ offset: Math.max(offset - (limit || 0), 0), withHotkey }));
|
||||
throttledOnOffsetChanged({ offset: Math.max(offset - (limit || 0), 0), withHotkey });
|
||||
},
|
||||
[dispatch, offset, limit]
|
||||
[throttledOnOffsetChanged, offset, limit]
|
||||
);
|
||||
|
||||
const goToPage = useCallback(
|
||||
(page: number) => {
|
||||
dispatch(offsetChanged({ offset: page * (limit || 0) }));
|
||||
throttledOnOffsetChanged({ offset: page * (limit || 0) });
|
||||
},
|
||||
[dispatch, limit]
|
||||
[throttledOnOffsetChanged, limit]
|
||||
);
|
||||
const goToFirst = useCallback(() => {
|
||||
dispatch(offsetChanged({ offset: 0 }));
|
||||
}, [dispatch]);
|
||||
throttledOnOffsetChanged({ offset: 0 });
|
||||
}, [throttledOnOffsetChanged]);
|
||||
const goToLast = useCallback(() => {
|
||||
dispatch(offsetChanged({ offset: (pages - 1) * (limit || 0) }));
|
||||
}, [dispatch, pages, limit]);
|
||||
throttledOnOffsetChanged({ offset: (pages - 1) * (limit || 0) });
|
||||
}, [throttledOnOffsetChanged, pages, limit]);
|
||||
|
||||
// handle when total/pages decrease and user is on high page number (ie bulk removing or deleting)
|
||||
useEffect(() => {
|
||||
|
@ -1,15 +1,10 @@
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { isNil } from 'lodash-es';
|
||||
import { useMemo } from 'react';
|
||||
import { useGetModelConfigWithTypeGuard } from 'services/api/hooks/useGetModelConfigWithTypeGuard';
|
||||
import { isControlNetOrT2IAdapterModelConfig } from 'services/api/types';
|
||||
|
||||
export const useControlNetOrT2IAdapterDefaultSettings = (modelKey?: string | null) => {
|
||||
const { modelConfig, isLoading } = useGetModelConfigWithTypeGuard(
|
||||
modelKey ?? skipToken,
|
||||
isControlNetOrT2IAdapterModelConfig
|
||||
);
|
||||
import type { ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
|
||||
|
||||
export const useControlNetOrT2IAdapterDefaultSettings = (
|
||||
modelConfig: ControlNetModelConfig | T2IAdapterModelConfig
|
||||
) => {
|
||||
const defaultSettingsDefaults = useMemo(() => {
|
||||
return {
|
||||
preprocessor: {
|
||||
@ -19,5 +14,5 @@ export const useControlNetOrT2IAdapterDefaultSettings = (modelKey?: string | nul
|
||||
};
|
||||
}, [modelConfig?.default_settings]);
|
||||
|
||||
return { defaultSettingsDefaults, isLoading };
|
||||
return defaultSettingsDefaults;
|
||||
};
|
||||
|
@ -1,11 +1,9 @@
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useInstallModelMutation } from 'services/api/endpoints/models';
|
||||
import { type InstallModelArg, useInstallModelMutation } from 'services/api/endpoints/models';
|
||||
|
||||
type InstallModelArg = {
|
||||
source: string;
|
||||
inplace?: boolean;
|
||||
type InstallModelArgWithCallbacks = InstallModelArg & {
|
||||
onSuccess?: () => void;
|
||||
onError?: (error: unknown) => void;
|
||||
};
|
||||
@ -15,8 +13,9 @@ export const useInstallModel = () => {
|
||||
const [_installModel, request] = useInstallModelMutation();
|
||||
|
||||
const installModel = useCallback(
|
||||
({ source, inplace, onSuccess, onError }: InstallModelArg) => {
|
||||
_installModel({ source, inplace })
|
||||
({ source, inplace, config, onSuccess, onError }: InstallModelArgWithCallbacks) => {
|
||||
config ||= {};
|
||||
_installModel({ source, inplace, config })
|
||||
.unwrap()
|
||||
.then((_) => {
|
||||
if (onSuccess) {
|
||||
|
@ -1,12 +1,9 @@
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { getOptimalDimension } from 'features/parameters/util/optimalDimension';
|
||||
import { selectConfigSlice } from 'features/system/store/configSlice';
|
||||
import { isNil } from 'lodash-es';
|
||||
import { useMemo } from 'react';
|
||||
import { useGetModelConfigWithTypeGuard } from 'services/api/hooks/useGetModelConfigWithTypeGuard';
|
||||
import { isNonRefinerMainModelConfig } from 'services/api/types';
|
||||
import type { MainModelConfig } from 'services/api/types';
|
||||
|
||||
const initialStatesSelector = createMemoizedSelector(selectConfigSlice, (config) => {
|
||||
const { steps, guidance, scheduler, cfgRescaleMultiplier, vaePrecision, width, height } = config.sd;
|
||||
@ -22,9 +19,7 @@ const initialStatesSelector = createMemoizedSelector(selectConfigSlice, (config)
|
||||
};
|
||||
});
|
||||
|
||||
export const useMainModelDefaultSettings = (modelKey?: string | null) => {
|
||||
const { modelConfig, isLoading } = useGetModelConfigWithTypeGuard(modelKey ?? skipToken, isNonRefinerMainModelConfig);
|
||||
|
||||
export const useMainModelDefaultSettings = (modelConfig: MainModelConfig) => {
|
||||
const {
|
||||
initialSteps,
|
||||
initialCfg,
|
||||
@ -81,5 +76,5 @@ export const useMainModelDefaultSettings = (modelKey?: string | null) => {
|
||||
initialHeight,
|
||||
]);
|
||||
|
||||
return { defaultSettingsDefaults, isLoading, optimalDimension: getOptimalDimension(modelConfig) };
|
||||
return defaultSettingsDefaults;
|
||||
};
|
||||
|
@ -1,5 +1,6 @@
|
||||
import { Button, Text, useToast } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { $installModelsTab } from 'features/modelManagerV2/subpanels/InstallModels';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
import { setActiveTab } from 'features/ui/store/uiSlice';
|
||||
import { useCallback, useEffect, useState } from 'react';
|
||||
@ -44,6 +45,7 @@ const ToastDescription = () => {
|
||||
|
||||
const onClick = useCallback(() => {
|
||||
dispatch(setActiveTab('models'));
|
||||
$installModelsTab.set(3);
|
||||
toast.close(TOAST_ID);
|
||||
}, [dispatch, toast]);
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import type { PersistConfig } from 'app/store/store';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import type { ModelType } from 'services/api/types';
|
||||
|
||||
export type FilterableModelType = Exclude<ModelType, 'onnx' | 'clip_vision'> | 'refiner';
|
||||
@ -50,6 +50,8 @@ export const modelManagerV2Slice = createSlice({
|
||||
export const { setSelectedModelKey, setSearchTerm, setFilteredModelType, setSelectedModelMode, setScanPath } =
|
||||
modelManagerV2Slice.actions;
|
||||
|
||||
export const selectModelManagerV2Slice = (state: RootState) => state.modelmanagerV2;
|
||||
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
const migrateModelManagerState = (state: any): any => {
|
||||
if (!('_version' in state)) {
|
||||
|
@ -1,13 +1,13 @@
|
||||
import { Button, Flex, FormControl, FormErrorMessage, FormHelperText, FormLabel, Input } from '@invoke-ai/ui-library';
|
||||
import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel';
|
||||
import type { ChangeEventHandler } from 'react';
|
||||
import { useCallback, useState } from 'react';
|
||||
import { memo, useCallback, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useLazyGetHuggingFaceModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
import { HuggingFaceResults } from './HuggingFaceResults';
|
||||
|
||||
export const HuggingFaceForm = () => {
|
||||
export const HuggingFaceForm = memo(() => {
|
||||
const [huggingFaceRepo, setHuggingFaceRepo] = useState('');
|
||||
const [displayResults, setDisplayResults] = useState(false);
|
||||
const [errorMessage, setErrorMessage] = useState('');
|
||||
@ -66,4 +66,6 @@ export const HuggingFaceForm = () => {
|
||||
{data && data.urls && displayResults && <HuggingFaceResults results={data.urls} />}
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
});
|
||||
|
||||
HuggingFaceForm.displayName = 'HuggingFaceForm';
|
||||
|
@ -1,13 +1,13 @@
|
||||
import { Flex, IconButton, Text } from '@invoke-ai/ui-library';
|
||||
import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel';
|
||||
import { useCallback } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiPlusBold } from 'react-icons/pi';
|
||||
|
||||
type Props = {
|
||||
result: string;
|
||||
};
|
||||
export const HuggingFaceResultItem = ({ result }: Props) => {
|
||||
export const HuggingFaceResultItem = memo(({ result }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const [installModel] = useInstallModel();
|
||||
@ -27,4 +27,6 @@ export const HuggingFaceResultItem = ({ result }: Props) => {
|
||||
<IconButton aria-label={t('modelManager.install')} icon={<PiPlusBold />} onClick={onClick} size="sm" />
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
});
|
||||
|
||||
HuggingFaceResultItem.displayName = 'HuggingFaceResultItem';
|
||||
|
@ -11,7 +11,7 @@ import {
|
||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||
import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel';
|
||||
import type { ChangeEventHandler } from 'react';
|
||||
import { useCallback, useMemo, useState } from 'react';
|
||||
import { memo, useCallback, useMemo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiXBold } from 'react-icons/pi';
|
||||
|
||||
@ -21,7 +21,7 @@ type HuggingFaceResultsProps = {
|
||||
results: string[];
|
||||
};
|
||||
|
||||
export const HuggingFaceResults = ({ results }: HuggingFaceResultsProps) => {
|
||||
export const HuggingFaceResults = memo(({ results }: HuggingFaceResultsProps) => {
|
||||
const { t } = useTranslation();
|
||||
const [searchTerm, setSearchTerm] = useState('');
|
||||
|
||||
@ -93,4 +93,6 @@ export const HuggingFaceResults = ({ results }: HuggingFaceResultsProps) => {
|
||||
</Flex>
|
||||
</>
|
||||
);
|
||||
};
|
||||
});
|
||||
|
||||
HuggingFaceResults.displayName = 'HuggingFaceResults';
|
||||
|
@ -1,7 +1,7 @@
|
||||
import { Button, Checkbox, Flex, FormControl, FormHelperText, FormLabel, Input } from '@invoke-ai/ui-library';
|
||||
import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel';
|
||||
import { t } from 'i18next';
|
||||
import { useCallback } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
import type { SubmitHandler } from 'react-hook-form';
|
||||
import { useForm } from 'react-hook-form';
|
||||
|
||||
@ -10,7 +10,7 @@ type SimpleImportModelConfig = {
|
||||
inplace: boolean;
|
||||
};
|
||||
|
||||
export const InstallModelForm = () => {
|
||||
export const InstallModelForm = memo(() => {
|
||||
const [installModel, { isLoading }] = useInstallModel();
|
||||
|
||||
const { register, handleSubmit, formState, reset } = useForm<SimpleImportModelConfig>({
|
||||
@ -74,4 +74,6 @@ export const InstallModelForm = () => {
|
||||
</Flex>
|
||||
</form>
|
||||
);
|
||||
};
|
||||
});
|
||||
|
||||
InstallModelForm.displayName = 'InstallModelForm';
|
||||
|
@ -2,12 +2,12 @@ import { Box, Button, Flex, Heading } from '@invoke-ai/ui-library';
|
||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { t } from 'i18next';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useListModelInstallsQuery, usePruneCompletedModelInstallsMutation } from 'services/api/endpoints/models';
|
||||
|
||||
import { ModelInstallQueueItem } from './ModelInstallQueueItem';
|
||||
|
||||
export const ModelInstallQueue = () => {
|
||||
export const ModelInstallQueue = memo(() => {
|
||||
const { data } = useListModelInstallsQuery();
|
||||
|
||||
const [_pruneCompletedModelInstalls] = usePruneCompletedModelInstallsMutation();
|
||||
@ -61,4 +61,6 @@ export const ModelInstallQueue = () => {
|
||||
</Box>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
});
|
||||
|
||||
ModelInstallQueue.displayName = 'ModelInstallQueue';
|
||||
|
@ -2,7 +2,7 @@ import { Flex, IconButton, Progress, Text, Tooltip } from '@invoke-ai/ui-library
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { t } from 'i18next';
|
||||
import { isNil } from 'lodash-es';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { PiXBold } from 'react-icons/pi';
|
||||
import { useCancelModelInstallMutation } from 'services/api/endpoints/models';
|
||||
import type { ModelInstallJob } from 'services/api/types';
|
||||
@ -25,7 +25,7 @@ const formatBytes = (bytes: number) => {
|
||||
return `${bytes.toFixed(2)} ${units[i]}`;
|
||||
};
|
||||
|
||||
export const ModelInstallQueueItem = (props: ModelListItemProps) => {
|
||||
export const ModelInstallQueueItem = memo((props: ModelListItemProps) => {
|
||||
const { installJob } = props;
|
||||
|
||||
const [deleteImportModel] = useCancelModelInstallMutation();
|
||||
@ -124,7 +124,9 @@ export const ModelInstallQueueItem = (props: ModelListItemProps) => {
|
||||
/>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
});
|
||||
|
||||
ModelInstallQueueItem.displayName = 'ModelInstallQueueItem';
|
||||
|
||||
type TooltipLabelProps = {
|
||||
installJob: ModelInstallJob;
|
||||
@ -132,7 +134,7 @@ type TooltipLabelProps = {
|
||||
source: string;
|
||||
};
|
||||
|
||||
const TooltipLabel = ({ name, source, installJob }: TooltipLabelProps) => {
|
||||
const TooltipLabel = memo(({ name, source, installJob }: TooltipLabelProps) => {
|
||||
const progressString = useMemo(() => {
|
||||
if (installJob.status !== 'downloading' || installJob.bytes === undefined || installJob.total_bytes === undefined) {
|
||||
return '';
|
||||
@ -156,4 +158,6 @@ const TooltipLabel = ({ name, source, installJob }: TooltipLabelProps) => {
|
||||
)}
|
||||
</>
|
||||
);
|
||||
};
|
||||
});
|
||||
|
||||
TooltipLabel.displayName = 'TooltipLabel';
|
||||
|
@ -2,13 +2,13 @@ import { Button, Flex, FormControl, FormErrorMessage, FormHelperText, FormLabel,
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { setScanPath } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||
import type { ChangeEventHandler } from 'react';
|
||||
import { useCallback, useState } from 'react';
|
||||
import { memo, useCallback, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useLazyScanFolderQuery } from 'services/api/endpoints/models';
|
||||
|
||||
import { ScanModelsResults } from './ScanFolderResults';
|
||||
|
||||
export const ScanModelsForm = () => {
|
||||
export const ScanModelsForm = memo(() => {
|
||||
const scanPath = useAppSelector((state) => state.modelmanagerV2.scanPath);
|
||||
const dispatch = useAppDispatch();
|
||||
const [errorMessage, setErrorMessage] = useState('');
|
||||
@ -56,4 +56,6 @@ export const ScanModelsForm = () => {
|
||||
{data && <ScanModelsResults results={data} />}
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
});
|
||||
|
||||
ScanModelsForm.displayName = 'ScanModelsForm';
|
||||
|
@ -1,5 +1,5 @@
|
||||
import { Badge, Box, Flex, IconButton, Text } from '@invoke-ai/ui-library';
|
||||
import { useCallback } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiPlusBold } from 'react-icons/pi';
|
||||
import type { ScanFolderResponse } from 'services/api/endpoints/models';
|
||||
@ -8,7 +8,7 @@ type Props = {
|
||||
result: ScanFolderResponse[number];
|
||||
installModel: (source: string) => void;
|
||||
};
|
||||
export const ScanModelResultItem = ({ result, installModel }: Props) => {
|
||||
export const ScanModelResultItem = memo(({ result, installModel }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const handleInstall = useCallback(() => {
|
||||
@ -30,4 +30,6 @@ export const ScanModelResultItem = ({ result, installModel }: Props) => {
|
||||
</Box>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
});
|
||||
|
||||
ScanModelResultItem.displayName = 'ScanModelResultItem';
|
||||
|
@ -14,7 +14,7 @@ import {
|
||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||
import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel';
|
||||
import type { ChangeEvent, ChangeEventHandler } from 'react';
|
||||
import { useCallback, useMemo, useState } from 'react';
|
||||
import { memo, useCallback, useMemo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiXBold } from 'react-icons/pi';
|
||||
import type { ScanFolderResponse } from 'services/api/endpoints/models';
|
||||
@ -25,7 +25,7 @@ type ScanModelResultsProps = {
|
||||
results: ScanFolderResponse;
|
||||
};
|
||||
|
||||
export const ScanModelsResults = ({ results }: ScanModelResultsProps) => {
|
||||
export const ScanModelsResults = memo(({ results }: ScanModelResultsProps) => {
|
||||
const { t } = useTranslation();
|
||||
const [searchTerm, setSearchTerm] = useState('');
|
||||
const [inplace, setInplace] = useState(true);
|
||||
@ -116,4 +116,6 @@ export const ScanModelsResults = ({ results }: ScanModelResultsProps) => {
|
||||
</Flex>
|
||||
</>
|
||||
);
|
||||
};
|
||||
});
|
||||
|
||||
ScanModelsResults.displayName = 'ScanModelsResults';
|
||||
|
@ -1,7 +1,7 @@
|
||||
import { Badge, Box, Flex, IconButton, Text } from '@invoke-ai/ui-library';
|
||||
import { useInstallModel } from 'features/modelManagerV2/hooks/useInstallModel';
|
||||
import ModelBaseBadge from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiPlusBold } from 'react-icons/pi';
|
||||
import type { GetStarterModelsResponse } from 'services/api/endpoints/models';
|
||||
@ -9,20 +9,22 @@ import type { GetStarterModelsResponse } from 'services/api/endpoints/models';
|
||||
type Props = {
|
||||
result: GetStarterModelsResponse[number];
|
||||
};
|
||||
export const StarterModelsResultItem = ({ result }: Props) => {
|
||||
export const StarterModelsResultItem = memo(({ result }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const allSources = useMemo(() => {
|
||||
const _allSources = [result.source];
|
||||
const _allSources = [{ source: result.source, config: { name: result.name, description: result.description } }];
|
||||
if (result.dependencies) {
|
||||
_allSources.push(...result.dependencies.map((d) => d.source));
|
||||
for (const d of result.dependencies) {
|
||||
_allSources.push({ source: d.source, config: { name: d.name, description: d.description } });
|
||||
}
|
||||
}
|
||||
return _allSources;
|
||||
}, [result]);
|
||||
const [installModel] = useInstallModel();
|
||||
|
||||
const onClick = useCallback(() => {
|
||||
for (const source of allSources) {
|
||||
installModel({ source });
|
||||
for (const { config, source } of allSources) {
|
||||
installModel({ config, source });
|
||||
}
|
||||
}, [allSources, installModel]);
|
||||
|
||||
@ -30,7 +32,7 @@ export const StarterModelsResultItem = ({ result }: Props) => {
|
||||
<Flex alignItems="center" justifyContent="space-between" w="100%" gap={3}>
|
||||
<Flex fontSize="sm" flexDir="column">
|
||||
<Flex gap={3}>
|
||||
<Badge h="min-content">{result.type.replace('_', ' ')}</Badge>
|
||||
<Badge h="min-content">{result.type.replaceAll('_', ' ')}</Badge>
|
||||
<ModelBaseBadge base={result.base} />
|
||||
<Text fontWeight="semibold">{result.name}</Text>
|
||||
</Flex>
|
||||
@ -45,4 +47,6 @@ export const StarterModelsResultItem = ({ result }: Props) => {
|
||||
</Box>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
});
|
||||
|
||||
StarterModelsResultItem.displayName = 'StarterModelsResultItem';
|
||||
|
@ -1,10 +1,11 @@
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import { FetchingModelsLoader } from 'features/modelManagerV2/subpanels/ModelManagerPanel/FetchingModelsLoader';
|
||||
import { memo } from 'react';
|
||||
import { useGetStarterModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
import { StarterModelsResults } from './StarterModelsResults';
|
||||
|
||||
export const StarterModelsForm = () => {
|
||||
export const StarterModelsForm = memo(() => {
|
||||
const { isLoading, data } = useGetStarterModelsQuery();
|
||||
|
||||
return (
|
||||
@ -13,4 +14,6 @@ export const StarterModelsForm = () => {
|
||||
{data && <StarterModelsResults results={data} />}
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
});
|
||||
|
||||
StarterModelsForm.displayName = 'StarterModelsForm';
|
||||
|
@ -1,7 +1,7 @@
|
||||
import { Flex, IconButton, Input, InputGroup, InputRightElement } from '@invoke-ai/ui-library';
|
||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||
import type { ChangeEventHandler } from 'react';
|
||||
import { useCallback, useMemo, useState } from 'react';
|
||||
import { memo, useCallback, useMemo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiXBold } from 'react-icons/pi';
|
||||
import type { GetStarterModelsResponse } from 'services/api/endpoints/models';
|
||||
@ -12,20 +12,30 @@ type StarterModelsResultsProps = {
|
||||
results: NonNullable<GetStarterModelsResponse>;
|
||||
};
|
||||
|
||||
export const StarterModelsResults = ({ results }: StarterModelsResultsProps) => {
|
||||
export const StarterModelsResults = memo(({ results }: StarterModelsResultsProps) => {
|
||||
const { t } = useTranslation();
|
||||
const [searchTerm, setSearchTerm] = useState('');
|
||||
|
||||
const filteredResults = useMemo(() => {
|
||||
return results.filter((result) => {
|
||||
const name = result.name.toLowerCase();
|
||||
const type = result.type.toLowerCase();
|
||||
return name.includes(searchTerm.toLowerCase()) || type.includes(searchTerm.toLowerCase());
|
||||
const trimmedSearchTerm = searchTerm.trim().toLowerCase();
|
||||
const matchStrings = [
|
||||
result.name.toLowerCase(),
|
||||
result.type.toLowerCase().replaceAll('_', ' '),
|
||||
result.description.toLowerCase(),
|
||||
];
|
||||
if (result.type === 'spandrel_image_to_image') {
|
||||
matchStrings.push('upscale');
|
||||
matchStrings.push('post-processing');
|
||||
matchStrings.push('postprocessing');
|
||||
matchStrings.push('post processing');
|
||||
}
|
||||
return matchStrings.some((matchString) => matchString.includes(trimmedSearchTerm));
|
||||
});
|
||||
}, [results, searchTerm]);
|
||||
|
||||
const handleSearch: ChangeEventHandler<HTMLInputElement> = useCallback((e) => {
|
||||
setSearchTerm(e.target.value.trim());
|
||||
setSearchTerm(e.target.value);
|
||||
}, []);
|
||||
|
||||
const clearSearch = useCallback(() => {
|
||||
@ -69,4 +79,6 @@ export const StarterModelsResults = ({ results }: StarterModelsResultsProps) =>
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
});
|
||||
|
||||
StarterModelsResults.displayName = 'StarterModelsResults';
|
||||
|
@ -1,28 +1,28 @@
|
||||
import { Box, Flex, Heading, Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { StarterModelsForm } from 'features/modelManagerV2/subpanels/AddModelPanel/StarterModels/StarterModelsForm';
|
||||
import { useMemo } from 'react';
|
||||
import { atom } from 'nanostores';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useMainModels } from 'services/api/hooks/modelsByType';
|
||||
|
||||
import { HuggingFaceForm } from './AddModelPanel/HuggingFaceFolder/HuggingFaceForm';
|
||||
import { InstallModelForm } from './AddModelPanel/InstallModelForm';
|
||||
import { ModelInstallQueue } from './AddModelPanel/ModelInstallQueue/ModelInstallQueue';
|
||||
import { ScanModelsForm } from './AddModelPanel/ScanFolder/ScanFolderForm';
|
||||
|
||||
export const InstallModels = () => {
|
||||
export const $installModelsTab = atom(0);
|
||||
|
||||
export const InstallModels = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const [mainModels, { data }] = useMainModels();
|
||||
const defaultIndex = useMemo(() => {
|
||||
if (data && mainModels.length) {
|
||||
return 0;
|
||||
}
|
||||
return 3;
|
||||
}, [data, mainModels.length]);
|
||||
const index = useStore($installModelsTab);
|
||||
const onChange = useCallback((index: number) => {
|
||||
$installModelsTab.set(index);
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<Flex layerStyle="first" borderRadius="base" w="full" h="full" flexDir="column" gap={4}>
|
||||
<Heading fontSize="xl">{t('modelManager.addModel')}</Heading>
|
||||
<Tabs variant="collapse" height="50%" display="flex" flexDir="column" defaultIndex={defaultIndex}>
|
||||
<Tabs variant="collapse" height="50%" display="flex" flexDir="column" index={index} onChange={onChange}>
|
||||
<TabList>
|
||||
<Tab>{t('modelManager.urlOrLocalPath')}</Tab>
|
||||
<Tab>{t('modelManager.huggingFace')}</Tab>
|
||||
@ -49,4 +49,6 @@ export const InstallModels = () => {
|
||||
</Box>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
});
|
||||
|
||||
InstallModels.displayName = 'InstallModels';
|
||||
|
@ -1,14 +1,14 @@
|
||||
import { Button, Flex, Heading } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { setSelectedModelKey } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||
import { useCallback } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiPlusBold } from 'react-icons/pi';
|
||||
|
||||
import ModelList from './ModelManagerPanel/ModelList';
|
||||
import { ModelListNavigation } from './ModelManagerPanel/ModelListNavigation';
|
||||
|
||||
export const ModelManager = () => {
|
||||
export const ModelManager = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const handleClickAddModel = useCallback(() => {
|
||||
@ -29,4 +29,6 @@ export const ModelManager = () => {
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
});
|
||||
|
||||
ModelManager.displayName = 'ModelManager';
|
||||
|
@ -21,7 +21,8 @@ import { FetchingModelsLoader } from './FetchingModelsLoader';
|
||||
import { ModelListWrapper } from './ModelListWrapper';
|
||||
|
||||
const ModelList = () => {
|
||||
const { searchTerm, filteredModelType } = useAppSelector((s) => s.modelmanagerV2);
|
||||
const filteredModelType = useAppSelector((s) => s.modelmanagerV2.filteredModelType);
|
||||
const searchTerm = useAppSelector((s) => s.modelmanagerV2.searchTerm);
|
||||
const { t } = useTranslation();
|
||||
|
||||
const [mainModels, { isLoading: isLoadingMainModels }] = useMainModels();
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user