Merge branch 'release/invokeai-3-0-alpha' of github.com:invoke-ai/InvokeAI into release/invokeai-3-0-alpha

This commit is contained in:
Lincoln Stein 2023-07-03 14:11:28 -04:00
commit 3937428563
135 changed files with 3655 additions and 2219 deletions

View File

@ -47,7 +47,7 @@ def add_parsers(
commands: list[type], commands: list[type],
command_field: str = "type", command_field: str = "type",
exclude_fields: list[str] = ["id", "type"], exclude_fields: list[str] = ["id", "type"],
add_arguments: Callable[[argparse.ArgumentParser], None]|None = None add_arguments: Union[Callable[[argparse.ArgumentParser], None],None] = None
): ):
"""Adds parsers for each command to the subparsers""" """Adds parsers for each command to the subparsers"""
@ -72,7 +72,7 @@ def add_parsers(
def add_graph_parsers( def add_graph_parsers(
subparsers, subparsers,
graphs: list[LibraryGraph], graphs: list[LibraryGraph],
add_arguments: Callable[[argparse.ArgumentParser], None]|None = None add_arguments: Union[Callable[[argparse.ArgumentParser], None], None] = None
): ):
for graph in graphs: for graph in graphs:
command_parser = subparsers.add_parser(graph.name, help=graph.description) command_parser = subparsers.add_parser(graph.name, help=graph.description)

View File

@ -1,12 +1,11 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import argparse import argparse
import os
import re import re
import shlex import shlex
import sys import sys
import time import time
from typing import Union, get_type_hints from typing import Union, get_type_hints, Optional
from pydantic import BaseModel, ValidationError from pydantic import BaseModel, ValidationError
from pydantic.fields import Field from pydantic.fields import Field
@ -348,7 +347,7 @@ def invoke_cli():
# Parse invocation # Parse invocation
command: CliCommand = None # type:ignore command: CliCommand = None # type:ignore
system_graph: LibraryGraph|None = None system_graph: Optional[LibraryGraph] = None
if args['type'] in system_graph_names: if args['type'] in system_graph_names:
system_graph = next(filter(lambda g: g.name == args['type'], system_graphs)) system_graph = next(filter(lambda g: g.name == args['type'], system_graphs))
invocation = GraphInvocation(graph=system_graph.graph, id=str(current_id)) invocation = GraphInvocation(graph=system_graph.graph, id=str(current_id))

View File

@ -97,6 +97,7 @@ class UIConfig(TypedDict, total=False):
"latents", "latents",
"model", "model",
"control", "control",
"image_collection",
], ],
] ]
tags: List[str] tags: List[str]

View File

@ -4,13 +4,16 @@ from typing import Literal
import numpy as np import numpy as np
from pydantic import Field, validator from pydantic import Field, validator
from invokeai.app.models.image import ImageField
from invokeai.app.util.misc import SEED_MAX, get_random_seed from invokeai.app.util.misc import SEED_MAX, get_random_seed
from .baseinvocation import ( from .baseinvocation import (
BaseInvocation, BaseInvocation,
InvocationConfig,
InvocationContext, InvocationContext,
BaseInvocationOutput, BaseInvocationOutput,
UIConfig,
) )
@ -22,6 +25,7 @@ class IntCollectionOutput(BaseInvocationOutput):
# Outputs # Outputs
collection: list[int] = Field(default=[], description="The int collection") collection: list[int] = Field(default=[], description="The int collection")
class FloatCollectionOutput(BaseInvocationOutput): class FloatCollectionOutput(BaseInvocationOutput):
"""A collection of floats""" """A collection of floats"""
@ -31,6 +35,18 @@ class FloatCollectionOutput(BaseInvocationOutput):
collection: list[float] = Field(default=[], description="The float collection") collection: list[float] = Field(default=[], description="The float collection")
class ImageCollectionOutput(BaseInvocationOutput):
"""A collection of images"""
type: Literal["image_collection"] = "image_collection"
# Outputs
collection: list[ImageField] = Field(default=[], description="The output images")
class Config:
schema_extra = {"required": ["type", "collection"]}
class RangeInvocation(BaseInvocation): class RangeInvocation(BaseInvocation):
"""Creates a range of numbers from start to stop with step""" """Creates a range of numbers from start to stop with step"""
@ -92,3 +108,27 @@ class RandomRangeInvocation(BaseInvocation):
return IntCollectionOutput( return IntCollectionOutput(
collection=list(rng.integers(low=self.low, high=self.high, size=self.size)) collection=list(rng.integers(low=self.low, high=self.high, size=self.size))
) )
class ImageCollectionInvocation(BaseInvocation):
"""Load a collection of images and provide it as output."""
# fmt: off
type: Literal["image_collection"] = "image_collection"
# Inputs
images: list[ImageField] = Field(
default=[], description="The image collection to load"
)
# fmt: on
def invoke(self, context: InvocationContext) -> ImageCollectionOutput:
return ImageCollectionOutput(collection=self.images)
class Config(InvocationConfig):
schema_extra = {
"ui": {
"type_hints": {
"images": "image_collection",
}
},
}

View File

@ -1,6 +1,5 @@
from typing import Literal, Optional, Union from typing import Literal, Optional, Union
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from contextlib import ExitStack
import re import re
import torch import torch
@ -9,7 +8,7 @@ from .model import ClipField
from ...backend.util.devices import torch_dtype from ...backend.util.devices import torch_dtype
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
from ...backend.model_management import BaseModelType, ModelType, SubModelType from ...backend.model_management import ModelType
from ...backend.model_management.lora import ModelPatcher from ...backend.model_management.lora import ModelPatcher
from compel import Compel from compel import Compel

View File

@ -6,7 +6,7 @@ from builtins import float, bool
import cv2 import cv2
import numpy as np import numpy as np
from typing import Literal, Optional, Union, List, Dict from typing import Literal, Optional, Union, List, Dict
from PIL import Image, ImageFilter, ImageOps from PIL import Image
from pydantic import BaseModel, Field, validator from pydantic import BaseModel, Field, validator
from ..models.image import ImageField, ImageCategory, ResourceOrigin from ..models.image import ImageField, ImageCategory, ResourceOrigin
@ -422,9 +422,9 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation, PILInvoca
# Inputs # Inputs
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection") detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image") image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
h: Union[int, None] = Field(default=512, ge=0, description="Content shuffle `h` parameter") h: Optional[int] = Field(default=512, ge=0, description="Content shuffle `h` parameter")
w: Union[int, None] = Field(default=512, ge=0, description="Content shuffle `w` parameter") w: Optional[int] = Field(default=512, ge=0, description="Content shuffle `w` parameter")
f: Union[int, None] = Field(default=256, ge=0, description="Content shuffle `f` parameter") f: Optional[int] = Field(default=256, ge=0, description="Content shuffle `f` parameter")
# fmt: on # fmt: on
def run_processor(self, image): def run_processor(self, image):

View File

@ -1,11 +1,10 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from functools import partial from functools import partial
from typing import Literal, Optional, Union, get_args from typing import Literal, Optional, get_args
import torch import torch
from diffusers import ControlNetModel from pydantic import Field
from pydantic import BaseModel, Field
from invokeai.app.models.image import (ColorField, ImageCategory, ImageField, from invokeai.app.models.image import (ColorField, ImageCategory, ImageField,
ResourceOrigin) ResourceOrigin)
@ -18,7 +17,6 @@ from ..util.step_callback import stable_diffusion_step_callback
from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext
from .image import ImageOutput from .image import ImageOutput
import re
from ...backend.model_management.lora import ModelPatcher from ...backend.model_management.lora import ModelPatcher
from ...backend.stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline from ...backend.stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
from .model import UNetField, VaeField from .model import UNetField, VaeField
@ -76,7 +74,7 @@ class InpaintInvocation(BaseInvocation):
vae: VaeField = Field(default=None, description="Vae model") vae: VaeField = Field(default=None, description="Vae model")
# Inputs # Inputs
image: Union[ImageField, None] = Field(description="The input image") image: Optional[ImageField] = Field(description="The input image")
strength: float = Field( strength: float = Field(
default=0.75, gt=0, le=1, description="The strength of the original image" default=0.75, gt=0, le=1, description="The strength of the original image"
) )
@ -86,7 +84,7 @@ class InpaintInvocation(BaseInvocation):
) )
# Inputs # Inputs
mask: Union[ImageField, None] = Field(description="The mask") mask: Optional[ImageField] = Field(description="The mask")
seam_size: int = Field(default=96, ge=1, description="The seam inpaint size (px)") seam_size: int = Field(default=96, ge=1, description="The seam inpaint size (px)")
seam_blur: int = Field( seam_blur: int = Field(
default=16, ge=0, description="The seam inpaint blur radius (px)" default=16, ge=0, description="The seam inpaint blur radius (px)"

View File

@ -1,7 +1,6 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import io from typing import Literal, Optional
from typing import Literal, Optional, Union
import numpy import numpy
from PIL import Image, ImageFilter, ImageOps, ImageChops from PIL import Image, ImageFilter, ImageOps, ImageChops
@ -67,7 +66,7 @@ class LoadImageInvocation(BaseInvocation):
type: Literal["load_image"] = "load_image" type: Literal["load_image"] = "load_image"
# Inputs # Inputs
image: Union[ImageField, None] = Field( image: Optional[ImageField] = Field(
default=None, description="The image to load" default=None, description="The image to load"
) )
# fmt: on # fmt: on
@ -87,7 +86,7 @@ class ShowImageInvocation(BaseInvocation):
type: Literal["show_image"] = "show_image" type: Literal["show_image"] = "show_image"
# Inputs # Inputs
image: Union[ImageField, None] = Field( image: Optional[ImageField] = Field(
default=None, description="The image to show" default=None, description="The image to show"
) )
@ -112,7 +111,7 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
type: Literal["img_crop"] = "img_crop" type: Literal["img_crop"] = "img_crop"
# Inputs # Inputs
image: Union[ImageField, None] = Field(default=None, description="The image to crop") image: Optional[ImageField] = Field(default=None, description="The image to crop")
x: int = Field(default=0, description="The left x coordinate of the crop rectangle") x: int = Field(default=0, description="The left x coordinate of the crop rectangle")
y: int = Field(default=0, description="The top y coordinate of the crop rectangle") y: int = Field(default=0, description="The top y coordinate of the crop rectangle")
width: int = Field(default=512, gt=0, description="The width of the crop rectangle") width: int = Field(default=512, gt=0, description="The width of the crop rectangle")
@ -150,8 +149,8 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
type: Literal["img_paste"] = "img_paste" type: Literal["img_paste"] = "img_paste"
# Inputs # Inputs
base_image: Union[ImageField, None] = Field(default=None, description="The base image") base_image: Optional[ImageField] = Field(default=None, description="The base image")
image: Union[ImageField, None] = Field(default=None, description="The image to paste") image: Optional[ImageField] = Field(default=None, description="The image to paste")
mask: Optional[ImageField] = Field(default=None, description="The mask to use when pasting") mask: Optional[ImageField] = Field(default=None, description="The mask to use when pasting")
x: int = Field(default=0, description="The left x coordinate at which to paste the image") x: int = Field(default=0, description="The left x coordinate at which to paste the image")
y: int = Field(default=0, description="The top y coordinate at which to paste the image") y: int = Field(default=0, description="The top y coordinate at which to paste the image")
@ -203,7 +202,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
type: Literal["tomask"] = "tomask" type: Literal["tomask"] = "tomask"
# Inputs # Inputs
image: Union[ImageField, None] = Field(default=None, description="The image to create the mask from") image: Optional[ImageField] = Field(default=None, description="The image to create the mask from")
invert: bool = Field(default=False, description="Whether or not to invert the mask") invert: bool = Field(default=False, description="Whether or not to invert the mask")
# fmt: on # fmt: on
@ -237,8 +236,8 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
type: Literal["img_mul"] = "img_mul" type: Literal["img_mul"] = "img_mul"
# Inputs # Inputs
image1: Union[ImageField, None] = Field(default=None, description="The first image to multiply") image1: Optional[ImageField] = Field(default=None, description="The first image to multiply")
image2: Union[ImageField, None] = Field(default=None, description="The second image to multiply") image2: Optional[ImageField] = Field(default=None, description="The second image to multiply")
# fmt: on # fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
@ -273,7 +272,7 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
type: Literal["img_chan"] = "img_chan" type: Literal["img_chan"] = "img_chan"
# Inputs # Inputs
image: Union[ImageField, None] = Field(default=None, description="The image to get the channel from") image: Optional[ImageField] = Field(default=None, description="The image to get the channel from")
channel: IMAGE_CHANNELS = Field(default="A", description="The channel to get") channel: IMAGE_CHANNELS = Field(default="A", description="The channel to get")
# fmt: on # fmt: on
@ -308,7 +307,7 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
type: Literal["img_conv"] = "img_conv" type: Literal["img_conv"] = "img_conv"
# Inputs # Inputs
image: Union[ImageField, None] = Field(default=None, description="The image to convert") image: Optional[ImageField] = Field(default=None, description="The image to convert")
mode: IMAGE_MODES = Field(default="L", description="The mode to convert to") mode: IMAGE_MODES = Field(default="L", description="The mode to convert to")
# fmt: on # fmt: on
@ -340,7 +339,7 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
type: Literal["img_blur"] = "img_blur" type: Literal["img_blur"] = "img_blur"
# Inputs # Inputs
image: Union[ImageField, None] = Field(default=None, description="The image to blur") image: Optional[ImageField] = Field(default=None, description="The image to blur")
radius: float = Field(default=8.0, ge=0, description="The blur radius") radius: float = Field(default=8.0, ge=0, description="The blur radius")
blur_type: Literal["gaussian", "box"] = Field(default="gaussian", description="The type of blur") blur_type: Literal["gaussian", "box"] = Field(default="gaussian", description="The type of blur")
# fmt: on # fmt: on
@ -398,7 +397,7 @@ class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
type: Literal["img_resize"] = "img_resize" type: Literal["img_resize"] = "img_resize"
# Inputs # Inputs
image: Union[ImageField, None] = Field(default=None, description="The image to resize") image: Optional[ImageField] = Field(default=None, description="The image to resize")
width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)") width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)") height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode") resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
@ -437,7 +436,7 @@ class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
type: Literal["img_scale"] = "img_scale" type: Literal["img_scale"] = "img_scale"
# Inputs # Inputs
image: Union[ImageField, None] = Field(default=None, description="The image to scale") image: Optional[ImageField] = Field(default=None, description="The image to scale")
scale_factor: float = Field(gt=0, description="The factor by which to scale the image") scale_factor: float = Field(gt=0, description="The factor by which to scale the image")
resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode") resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
# fmt: on # fmt: on
@ -477,7 +476,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
type: Literal["img_lerp"] = "img_lerp" type: Literal["img_lerp"] = "img_lerp"
# Inputs # Inputs
image: Union[ImageField, None] = Field(default=None, description="The image to lerp") image: Optional[ImageField] = Field(default=None, description="The image to lerp")
min: int = Field(default=0, ge=0, le=255, description="The minimum output value") min: int = Field(default=0, ge=0, le=255, description="The minimum output value")
max: int = Field(default=255, ge=0, le=255, description="The maximum output value") max: int = Field(default=255, ge=0, le=255, description="The maximum output value")
# fmt: on # fmt: on
@ -513,7 +512,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
type: Literal["img_ilerp"] = "img_ilerp" type: Literal["img_ilerp"] = "img_ilerp"
# Inputs # Inputs
image: Union[ImageField, None] = Field(default=None, description="The image to lerp") image: Optional[ImageField] = Field(default=None, description="The image to lerp")
min: int = Field(default=0, ge=0, le=255, description="The minimum input value") min: int = Field(default=0, ge=0, le=255, description="The minimum input value")
max: int = Field(default=255, ge=0, le=255, description="The maximum input value") max: int = Field(default=255, ge=0, le=255, description="The maximum input value")
# fmt: on # fmt: on

View File

@ -1,6 +1,6 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
from typing import Literal, Union, get_args from typing import Literal, Optional, get_args
import numpy as np import numpy as np
import math import math
@ -68,7 +68,7 @@ def get_tile_images(image: np.ndarray, width=8, height=8):
def tile_fill_missing( def tile_fill_missing(
im: Image.Image, tile_size: int = 16, seed: Union[int, None] = None im: Image.Image, tile_size: int = 16, seed: Optional[int] = None
) -> Image.Image: ) -> Image.Image:
# Only fill if there's an alpha layer # Only fill if there's an alpha layer
if im.mode != "RGBA": if im.mode != "RGBA":
@ -125,7 +125,7 @@ class InfillColorInvocation(BaseInvocation):
"""Infills transparent areas of an image with a solid color""" """Infills transparent areas of an image with a solid color"""
type: Literal["infill_rgba"] = "infill_rgba" type: Literal["infill_rgba"] = "infill_rgba"
image: Union[ImageField, None] = Field( image: Optional[ImageField] = Field(
default=None, description="The image to infill" default=None, description="The image to infill"
) )
color: ColorField = Field( color: ColorField = Field(
@ -162,7 +162,7 @@ class InfillTileInvocation(BaseInvocation):
type: Literal["infill_tile"] = "infill_tile" type: Literal["infill_tile"] = "infill_tile"
image: Union[ImageField, None] = Field( image: Optional[ImageField] = Field(
default=None, description="The image to infill" default=None, description="The image to infill"
) )
tile_size: int = Field(default=32, ge=1, description="The tile size (px)") tile_size: int = Field(default=32, ge=1, description="The tile size (px)")
@ -202,7 +202,7 @@ class InfillPatchMatchInvocation(BaseInvocation):
type: Literal["infill_patchmatch"] = "infill_patchmatch" type: Literal["infill_patchmatch"] = "infill_patchmatch"
image: Union[ImageField, None] = Field( image: Optional[ImageField] = Field(
default=None, description="The image to infill" default=None, description="The image to infill"
) )

View File

@ -1,21 +1,18 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from contextlib import ExitStack
from typing import List, Literal, Optional, Union from typing import List, Literal, Optional, Union
import einops import einops
from pydantic import BaseModel, Field, validator from pydantic import BaseModel, Field, validator
import torch import torch
from diffusers import ControlNetModel, DPMSolverMultistepScheduler from diffusers import ControlNetModel
from diffusers.image_processor import VaeImageProcessor from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import SchedulerMixin as Scheduler from diffusers.schedulers import SchedulerMixin as Scheduler
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.app.util.step_callback import stable_diffusion_step_callback
from ..models.image import ImageCategory, ImageField, ResourceOrigin from ..models.image import ImageCategory, ImageField, ResourceOrigin
from ...backend.image_util.seamless import configure_model_padding
from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.stable_diffusion.diffusers_pipeline import ( from ...backend.stable_diffusion.diffusers_pipeline import (
ConditioningData, ControlNetData, StableDiffusionGeneratorPipeline, ConditioningData, ControlNetData, StableDiffusionGeneratorPipeline,
@ -546,7 +543,7 @@ class ImageToLatentsInvocation(BaseInvocation):
type: Literal["i2l"] = "i2l" type: Literal["i2l"] = "i2l"
# Inputs # Inputs
image: Union[ImageField, None] = Field(description="The image to encode") image: Optional[ImageField] = Field(description="The image to encode")
vae: VaeField = Field(default=None, description="Vae submodel") vae: VaeField = Field(default=None, description="Vae submodel")
tiled: bool = Field(default=False, description="Encode latents by overlaping tiles(less memory consumption)") tiled: bool = Field(default=False, description="Encode latents by overlaping tiles(less memory consumption)")

View File

@ -1,4 +1,4 @@
from typing import Literal, Union from typing import Literal, Optional
from pydantic import Field from pydantic import Field
@ -15,7 +15,7 @@ class RestoreFaceInvocation(BaseInvocation):
type: Literal["restore_face"] = "restore_face" type: Literal["restore_face"] = "restore_face"
# Inputs # Inputs
image: Union[ImageField, None] = Field(description="The input image") image: Optional[ImageField] = Field(description="The input image")
strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the restoration" ) strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the restoration" )
# fmt: on # fmt: on

View File

@ -1,6 +1,6 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Literal, Union from typing import Literal, Optional
from pydantic import Field from pydantic import Field
@ -16,7 +16,7 @@ class UpscaleInvocation(BaseInvocation):
type: Literal["upscale"] = "upscale" type: Literal["upscale"] = "upscale"
# Inputs # Inputs
image: Union[ImageField, None] = Field(description="The input image", default=None) image: Optional[ImageField] = Field(description="The input image", default=None)
strength: float = Field(default=0.75, gt=0, le=1, description="The strength") strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
level: Literal[2, 4] = Field(default=2, description="The upscale level") level: Literal[2, 4] = Field(default=2, description="The upscale level")
# fmt: on # fmt: on

View File

@ -1,8 +1,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import sqlite3 import sqlite3
import threading import threading
from typing import Union, cast from typing import Optional, cast
from invokeai.app.services.board_record_storage import BoardRecord
from invokeai.app.services.image_record_storage import OffsetPaginatedResults from invokeai.app.services.image_record_storage import OffsetPaginatedResults
from invokeai.app.services.models.image_record import ( from invokeai.app.services.models.image_record import (
@ -44,7 +43,7 @@ class BoardImageRecordStorageBase(ABC):
def get_board_for_image( def get_board_for_image(
self, self,
image_name: str, image_name: str,
) -> Union[str, None]: ) -> Optional[str]:
"""Gets an image's board id, if it has one.""" """Gets an image's board id, if it has one."""
pass pass
@ -215,7 +214,7 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
def get_board_for_image( def get_board_for_image(
self, self,
image_name: str, image_name: str,
) -> Union[str, None]: ) -> Optional[str]:
try: try:
self._lock.acquire() self._lock.acquire()
self._cursor.execute( self._cursor.execute(

View File

@ -1,6 +1,6 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from logging import Logger from logging import Logger
from typing import List, Union from typing import List, Union, Optional
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
from invokeai.app.services.board_record_storage import ( from invokeai.app.services.board_record_storage import (
BoardRecord, BoardRecord,
@ -49,7 +49,7 @@ class BoardImagesServiceABC(ABC):
def get_board_for_image( def get_board_for_image(
self, self,
image_name: str, image_name: str,
) -> Union[str, None]: ) -> Optional[str]:
"""Gets an image's board id, if it has one.""" """Gets an image's board id, if it has one."""
pass pass
@ -126,13 +126,13 @@ class BoardImagesService(BoardImagesServiceABC):
def get_board_for_image( def get_board_for_image(
self, self,
image_name: str, image_name: str,
) -> Union[str, None]: ) -> Optional[str]:
board_id = self._services.board_image_records.get_board_for_image(image_name) board_id = self._services.board_image_records.get_board_for_image(image_name)
return board_id return board_id
def board_record_to_dto( def board_record_to_dto(
board_record: BoardRecord, cover_image_name: str | None, image_count: int board_record: BoardRecord, cover_image_name: Optional[str], image_count: int
) -> BoardDTO: ) -> BoardDTO:
"""Converts a board record to a board DTO.""" """Converts a board record to a board DTO."""
return BoardDTO( return BoardDTO(

View File

@ -1,10 +1,9 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Any from typing import Any, Optional
from invokeai.app.models.image import ProgressImage from invokeai.app.models.image import ProgressImage
from invokeai.app.util.misc import get_timestamp from invokeai.app.util.misc import get_timestamp
from invokeai.app.services.model_manager_service import BaseModelType, ModelType, SubModelType, ModelInfo from invokeai.app.services.model_manager_service import BaseModelType, ModelType, SubModelType, ModelInfo
from invokeai.app.models.exceptions import CanceledException
class EventServiceBase: class EventServiceBase:
session_event: str = "session_event" session_event: str = "session_event"
@ -28,7 +27,7 @@ class EventServiceBase:
graph_execution_state_id: str, graph_execution_state_id: str,
node: dict, node: dict,
source_node_id: str, source_node_id: str,
progress_image: ProgressImage | None, progress_image: Optional[ProgressImage],
step: int, step: int,
total_steps: int, total_steps: int,
) -> None: ) -> None:

View File

@ -3,7 +3,6 @@
import copy import copy
import itertools import itertools
import uuid import uuid
from types import NoneType
from typing import ( from typing import (
Annotated, Annotated,
Any, Any,
@ -26,6 +25,8 @@ from ..invocations.baseinvocation import (
InvocationContext, InvocationContext,
) )
# in 3.10 this would be "from types import NoneType"
NoneType = type(None)
class EdgeConnection(BaseModel): class EdgeConnection(BaseModel):
node_id: str = Field(description="The id of the node for this edge connection") node_id: str = Field(description="The id of the node for this edge connection")
@ -60,8 +61,6 @@ def get_input_field(node: BaseInvocation, field: str) -> Any:
node_input_field = node_inputs.get(field) or None node_input_field = node_inputs.get(field) or None
return node_input_field return node_input_field
from typing import Optional, Union, List, get_args
def is_union_subtype(t1, t2): def is_union_subtype(t1, t2):
t1_args = get_args(t1) t1_args = get_args(t1)
t2_args = get_args(t2) t2_args = get_args(t2)
@ -846,7 +845,7 @@ class GraphExecutionState(BaseModel):
] ]
} }
def next(self) -> BaseInvocation | None: def next(self) -> Optional[BaseInvocation]:
"""Gets the next node ready to execute.""" """Gets the next node ready to execute."""
# TODO: enable multiple nodes to execute simultaneously by tracking currently executing nodes # TODO: enable multiple nodes to execute simultaneously by tracking currently executing nodes

View File

@ -2,13 +2,12 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from queue import Queue from queue import Queue
from typing import Dict, Optional from typing import Dict, Optional, Union
from PIL.Image import Image as PILImageType from PIL.Image import Image as PILImageType
from PIL import Image, PngImagePlugin from PIL import Image, PngImagePlugin
from send2trash import send2trash from send2trash import send2trash
from invokeai.app.models.image import ResourceOrigin
from invokeai.app.models.metadata import ImageMetadata from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
@ -80,7 +79,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
__cache: Dict[Path, PILImageType] __cache: Dict[Path, PILImageType]
__max_cache_size: int __max_cache_size: int
def __init__(self, output_folder: str | Path): def __init__(self, output_folder: Union[str, Path]):
self.__cache = dict() self.__cache = dict()
self.__cache_ids = Queue() self.__cache_ids = Queue()
self.__max_cache_size = 10 # TODO: get this from config self.__max_cache_size = 10 # TODO: get this from config
@ -164,7 +163,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
return path return path
def validate_path(self, path: str | Path) -> bool: def validate_path(self, path: Union[str, Path]) -> bool:
"""Validates the path given for an image or thumbnail.""" """Validates the path given for an image or thumbnail."""
path = path if isinstance(path, Path) else Path(path) path = path if isinstance(path, Path) else Path(path)
return path.exists() return path.exists()
@ -175,7 +174,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
for folder in folders: for folder in folders:
folder.mkdir(parents=True, exist_ok=True) folder.mkdir(parents=True, exist_ok=True)
def __get_cache(self, image_name: Path) -> PILImageType | None: def __get_cache(self, image_name: Path) -> Optional[PILImageType]:
return None if image_name not in self.__cache else self.__cache[image_name] return None if image_name not in self.__cache else self.__cache[image_name]
def __set_cache(self, image_name: Path, image: PILImageType): def __set_cache(self, image_name: Path, image: PILImageType):

View File

@ -3,7 +3,6 @@ from datetime import datetime
from typing import Generic, Optional, TypeVar, cast from typing import Generic, Optional, TypeVar, cast
import sqlite3 import sqlite3
import threading import threading
from typing import Optional, Union
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from pydantic.generics import GenericModel from pydantic.generics import GenericModel
@ -116,7 +115,7 @@ class ImageRecordStorageBase(ABC):
pass pass
@abstractmethod @abstractmethod
def get_most_recent_image_for_board(self, board_id: str) -> ImageRecord | None: def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]:
"""Gets the most recent image for a board.""" """Gets the most recent image for a board."""
pass pass
@ -208,7 +207,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
""" """
) )
def get(self, image_name: str) -> Union[ImageRecord, None]: def get(self, image_name: str) -> Optional[ImageRecord]:
try: try:
self._lock.acquire() self._lock.acquire()
@ -220,7 +219,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
(image_name,), (image_name,),
) )
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone()) result = cast(Optional[sqlite3.Row], self._cursor.fetchone())
except sqlite3.Error as e: except sqlite3.Error as e:
self._conn.rollback() self._conn.rollback()
raise ImageRecordNotFoundException from e raise ImageRecordNotFoundException from e
@ -475,7 +474,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
def get_most_recent_image_for_board( def get_most_recent_image_for_board(
self, board_id: str self, board_id: str
) -> Union[ImageRecord, None]: ) -> Optional[ImageRecord]:
try: try:
self._lock.acquire() self._lock.acquire()
self._cursor.execute( self._cursor.execute(
@ -490,7 +489,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
(board_id,), (board_id,),
) )
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone()) result = cast(Optional[sqlite3.Row], self._cursor.fetchone())
finally: finally:
self._lock.release() self._lock.release()
if result is None: if result is None:

View File

@ -370,7 +370,7 @@ class ImageService(ImageServiceABC):
def _get_metadata( def _get_metadata(
self, session_id: Optional[str] = None, node_id: Optional[str] = None self, session_id: Optional[str] = None, node_id: Optional[str] = None
) -> Union[ImageMetadata, None]: ) -> Optional[ImageMetadata]:
"""Get the metadata for a node.""" """Get the metadata for a node."""
metadata = None metadata = None

View File

@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
from queue import Queue from queue import Queue
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import Optional
class InvocationQueueItem(BaseModel): class InvocationQueueItem(BaseModel):
graph_execution_state_id: str = Field(description="The ID of the graph execution state") graph_execution_state_id: str = Field(description="The ID of the graph execution state")
@ -22,7 +22,7 @@ class InvocationQueueABC(ABC):
pass pass
@abstractmethod @abstractmethod
def put(self, item: InvocationQueueItem | None) -> None: def put(self, item: Optional[InvocationQueueItem]) -> None:
pass pass
@abstractmethod @abstractmethod
@ -57,7 +57,7 @@ class MemoryInvocationQueue(InvocationQueueABC):
return item return item
def put(self, item: InvocationQueueItem | None) -> None: def put(self, item: Optional[InvocationQueueItem]) -> None:
self.__queue.put(item) self.__queue.put(item)
def cancel(self, graph_execution_state_id: str) -> None: def cancel(self, graph_execution_state_id: str) -> None:

View File

@ -1,14 +1,11 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from abc import ABC from abc import ABC
from threading import Event, Thread from typing import Optional
from ..invocations.baseinvocation import InvocationContext
from .graph import Graph, GraphExecutionState from .graph import Graph, GraphExecutionState
from .invocation_queue import InvocationQueueABC, InvocationQueueItem from .invocation_queue import InvocationQueueItem
from .invocation_services import InvocationServices from .invocation_services import InvocationServices
from .item_storage import ItemStorageABC
class Invoker: class Invoker:
"""The invoker, used to execute invocations""" """The invoker, used to execute invocations"""
@ -21,7 +18,7 @@ class Invoker:
def invoke( def invoke(
self, graph_execution_state: GraphExecutionState, invoke_all: bool = False self, graph_execution_state: GraphExecutionState, invoke_all: bool = False
) -> str | None: ) -> Optional[str]:
"""Determines the next node to invoke and enqueues it, preparing if needed. """Determines the next node to invoke and enqueues it, preparing if needed.
Returns the id of the queued node, or `None` if there are no nodes left to enqueue.""" Returns the id of the queued node, or `None` if there are no nodes left to enqueue."""
@ -45,7 +42,7 @@ class Invoker:
return invocation.id return invocation.id
def create_execution_state(self, graph: Graph | None = None) -> GraphExecutionState: def create_execution_state(self, graph: Optional[Graph] = None) -> GraphExecutionState:
"""Creates a new execution state for the given graph""" """Creates a new execution state for the given graph"""
new_state = GraphExecutionState(graph=Graph() if graph is None else graph) new_state = GraphExecutionState(graph=Graph() if graph is None else graph)
self.services.graph_execution_manager.set(new_state) self.services.graph_execution_manager.set(new_state)

View File

@ -3,7 +3,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from queue import Queue from queue import Queue
from typing import Dict from typing import Dict, Union, Optional
import torch import torch
@ -55,7 +55,7 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
if name in self.__cache: if name in self.__cache:
del self.__cache[name] del self.__cache[name]
def __get_cache(self, name: str) -> torch.Tensor|None: def __get_cache(self, name: str) -> Optional[torch.Tensor]:
return None if name not in self.__cache else self.__cache[name] return None if name not in self.__cache else self.__cache[name]
def __set_cache(self, name: str, data: torch.Tensor): def __set_cache(self, name: str, data: torch.Tensor):
@ -69,9 +69,9 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
class DiskLatentsStorage(LatentsStorageBase): class DiskLatentsStorage(LatentsStorageBase):
"""Stores latents in a folder on disk without caching""" """Stores latents in a folder on disk without caching"""
__output_folder: str | Path __output_folder: Union[str, Path]
def __init__(self, output_folder: str | Path): def __init__(self, output_folder: Union[str, Path]):
self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder) self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder)
self.__output_folder.mkdir(parents=True, exist_ok=True) self.__output_folder.mkdir(parents=True, exist_ok=True)

View File

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Union from typing import Any, Optional
import networkx as nx import networkx as nx
from invokeai.app.models.metadata import ImageMetadata from invokeai.app.models.metadata import ImageMetadata
@ -34,7 +34,7 @@ class CoreMetadataService(MetadataServiceBase):
return metadata return metadata
def _find_nearest_ancestor(self, G: nx.DiGraph, node_id: str) -> Union[str, None]: def _find_nearest_ancestor(self, G: nx.DiGraph, node_id: str) -> Optional[str]:
""" """
Finds the id of the nearest ancestor (of a valid type) of a given node. Finds the id of the nearest ancestor (of a valid type) of a given node.
@ -65,7 +65,7 @@ class CoreMetadataService(MetadataServiceBase):
def _get_additional_metadata( def _get_additional_metadata(
self, graph: Graph, node_id: str self, graph: Graph, node_id: str
) -> Union[dict[str, Any], None]: ) -> Optional[dict[str, Any]]:
""" """
Returns additional metadata for a given node. Returns additional metadata for a given node.

View File

@ -88,7 +88,7 @@ class ImageUrlsDTO(BaseModel):
class ImageDTO(ImageRecord, ImageUrlsDTO): class ImageDTO(ImageRecord, ImageUrlsDTO):
"""Deserialized image record, enriched for the frontend.""" """Deserialized image record, enriched for the frontend."""
board_id: Union[str, None] = Field( board_id: Optional[str] = Field(
description="The id of the board the image belongs to, if one exists." description="The id of the board the image belongs to, if one exists."
) )
"""The id of the board the image belongs to, if one exists.""" """The id of the board the image belongs to, if one exists."""
@ -96,7 +96,7 @@ class ImageDTO(ImageRecord, ImageUrlsDTO):
def image_record_to_dto( def image_record_to_dto(
image_record: ImageRecord, image_url: str, thumbnail_url: str, board_id: Union[str, None] image_record: ImageRecord, image_url: str, thumbnail_url: str, board_id: Optional[str]
) -> ImageDTO: ) -> ImageDTO:
"""Converts an image record to an image DTO.""" """Converts an image record to an image DTO."""
return ImageDTO( return ImageDTO(

View File

@ -1,6 +1,6 @@
import sqlite3 import sqlite3
from threading import Lock from threading import Lock
from typing import Generic, TypeVar, Union, get_args from typing import Generic, TypeVar, Optional, Union, get_args
from pydantic import BaseModel, parse_raw_as from pydantic import BaseModel, parse_raw_as
@ -63,7 +63,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
self._lock.release() self._lock.release()
self._on_changed(item) self._on_changed(item)
def get(self, id: str) -> Union[T, None]: def get(self, id: str) -> Optional[T]:
try: try:
self._lock.acquire() self._lock.acquire()
self._cursor.execute( self._cursor.execute(

View File

@ -21,7 +21,7 @@ from PIL import Image, ImageChops, ImageFilter
from accelerate.utils import set_seed from accelerate.utils import set_seed
from diffusers import DiffusionPipeline from diffusers import DiffusionPipeline
from tqdm import trange from tqdm import trange
from typing import Callable, List, Iterator, Optional, Type from typing import Callable, List, Iterator, Optional, Type, Union
from dataclasses import dataclass, field from dataclasses import dataclass, field
from diffusers.schedulers import SchedulerMixin as Scheduler from diffusers.schedulers import SchedulerMixin as Scheduler
@ -178,7 +178,7 @@ class InvokeAIGenerator(metaclass=ABCMeta):
# ------------------------------------ # ------------------------------------
class Img2Img(InvokeAIGenerator): class Img2Img(InvokeAIGenerator):
def generate(self, def generate(self,
init_image: Image.Image | torch.FloatTensor, init_image: Union[Image.Image, torch.FloatTensor],
strength: float=0.75, strength: float=0.75,
**keyword_args **keyword_args
)->Iterator[InvokeAIGeneratorOutput]: )->Iterator[InvokeAIGeneratorOutput]:
@ -195,7 +195,7 @@ class Img2Img(InvokeAIGenerator):
# Takes all the arguments of Img2Img and adds the mask image and the seam/infill stuff # Takes all the arguments of Img2Img and adds the mask image and the seam/infill stuff
class Inpaint(Img2Img): class Inpaint(Img2Img):
def generate(self, def generate(self,
mask_image: Image.Image | torch.FloatTensor, mask_image: Union[Image.Image, torch.FloatTensor],
# Seam settings - when 0, doesn't fill seam # Seam settings - when 0, doesn't fill seam
seam_size: int = 96, seam_size: int = 96,
seam_blur: int = 16, seam_blur: int = 16,

View File

@ -4,11 +4,10 @@ invokeai.backend.generator.inpaint descends from .generator
from __future__ import annotations from __future__ import annotations
import math import math
from typing import Tuple, Union from typing import Tuple, Union, Optional
import cv2 import cv2
import numpy as np import numpy as np
import PIL
import torch import torch
from PIL import Image, ImageChops, ImageFilter, ImageOps from PIL import Image, ImageChops, ImageFilter, ImageOps
@ -76,7 +75,7 @@ class Inpaint(Img2Img):
return im_patched return im_patched
def tile_fill_missing( def tile_fill_missing(
self, im: Image.Image, tile_size: int = 16, seed: Union[int, None] = None self, im: Image.Image, tile_size: int = 16, seed: Optional[int] = None
) -> Image.Image: ) -> Image.Image:
# Only fill if there's an alpha layer # Only fill if there's an alpha layer
if im.mode != "RGBA": if im.mode != "RGBA":
@ -203,8 +202,8 @@ class Inpaint(Img2Img):
cfg_scale, cfg_scale,
ddim_eta, ddim_eta,
conditioning, conditioning,
init_image: Image.Image | torch.FloatTensor, init_image: Union[Image.Image, torch.FloatTensor],
mask_image: Image.Image | torch.FloatTensor, mask_image: Union[Image.Image, torch.FloatTensor],
strength: float, strength: float,
mask_blur_radius: int = 8, mask_blur_radius: int = 8,
# Seam settings - when 0, doesn't fill seam # Seam settings - when 0, doesn't fill seam

View File

@ -306,7 +306,6 @@ class ModelManager(object):
and sequential_offload boolean. Note that the default device and sequential_offload boolean. Note that the default device
type and precision are set up for a CUDA system running at half precision. type and precision are set up for a CUDA system running at half precision.
""" """
self.config_path = None self.config_path = None
if isinstance(config, (str, Path)): if isinstance(config, (str, Path)):
self.config_path = Path(config) self.config_path = Path(config)
@ -423,7 +422,7 @@ class ModelManager(object):
if submodel_type is not None and hasattr(model_config, submodel_type): if submodel_type is not None and hasattr(model_config, submodel_type):
override_path = getattr(model_config, submodel_type) override_path = getattr(model_config, submodel_type)
if override_path: if override_path:
model_path = override_path model_path = self.app_config.root_path / override_path
model_type = submodel_type model_type = submodel_type
submodel_type = None submodel_type = None
model_class = MODEL_CLASSES[base_model][model_type] model_class = MODEL_CLASSES[base_model][model_type]
@ -431,6 +430,7 @@ class ModelManager(object):
# TODO: path # TODO: path
# TODO: is it accurate to use path as id # TODO: is it accurate to use path as id
dst_convert_path = self._get_model_cache_path(model_path) dst_convert_path = self._get_model_cache_path(model_path)
model_path = model_class.convert_if_required( model_path = model_class.convert_if_required(
base_model=base_model, base_model=base_model,
model_path=str(model_path), # TODO: refactor str/Path types logic model_path=str(model_path), # TODO: refactor str/Path types logic

View File

@ -6,7 +6,7 @@ from dataclasses import dataclass
from diffusers import ModelMixin, ConfigMixin from diffusers import ModelMixin, ConfigMixin
from pathlib import Path from pathlib import Path
from typing import Callable, Literal, Union, Dict from typing import Callable, Literal, Union, Dict, Optional
from picklescan.scanner import scan_file_path from picklescan.scanner import scan_file_path
from .models import ( from .models import (
@ -64,7 +64,7 @@ class ModelProbe(object):
@classmethod @classmethod
def probe(cls, def probe(cls,
model_path: Path, model_path: Path,
model: Union[Dict, ModelMixin] = None, model: Optional[Union[Dict, ModelMixin]],
prediction_type_helper: Callable[[Path],SchedulerPredictionType] = None)->ModelProbeInfo: prediction_type_helper: Callable[[Path],SchedulerPredictionType] = None)->ModelProbeInfo:
''' '''
Probe the model at model_path and return sufficient information about it Probe the model at model_path and return sufficient information about it

View File

@ -68,7 +68,11 @@ def get_model_config_enums():
enums = list() enums = list()
for model_config in MODEL_CONFIGS: for model_config in MODEL_CONFIGS:
if hasattr(inspect,'get_annotations'):
fields = inspect.get_annotations(model_config) fields = inspect.get_annotations(model_config)
else:
fields = model_config.__annotations__
try: try:
field = fields["model_format"] field = fields["model_format"]
except: except:

View File

@ -7,7 +7,7 @@ import secrets
from collections.abc import Sequence from collections.abc import Sequence
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union
from pydantic import BaseModel, Field from pydantic import Field
import einops import einops
import PIL.Image import PIL.Image
@ -17,12 +17,11 @@ import psutil
import torch import torch
import torchvision.transforms as T import torchvision.transforms as T
from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.controlnet import ControlNetModel, ControlNetOutput from diffusers.models.controlnet import ControlNetModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import ( from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
StableDiffusionPipeline, StableDiffusionPipeline,
) )
from diffusers.pipelines.controlnet import MultiControlNetModel
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import ( from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import (
StableDiffusionImg2ImgPipeline, StableDiffusionImg2ImgPipeline,
@ -46,7 +45,7 @@ from .diffusion import (
InvokeAIDiffuserComponent, InvokeAIDiffuserComponent,
PostprocessingSettings, PostprocessingSettings,
) )
from .offloading import FullyLoadedModelGroup, LazilyLoadedModelGroup, ModelGroup from .offloading import FullyLoadedModelGroup, ModelGroup
@dataclass @dataclass
class PipelineIntermediateState: class PipelineIntermediateState:
@ -105,7 +104,7 @@ class AddsMaskGuidance:
_debug: Optional[Callable] = None _debug: Optional[Callable] = None
def __call__( def __call__(
self, step_output: BaseOutput | SchedulerOutput, t: torch.Tensor, conditioning self, step_output: Union[BaseOutput, SchedulerOutput], t: torch.Tensor, conditioning
) -> BaseOutput: ) -> BaseOutput:
output_class = step_output.__class__ # We'll create a new one with masked data. output_class = step_output.__class__ # We'll create a new one with masked data.

View File

@ -4,7 +4,7 @@ import warnings
import weakref import weakref
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from collections.abc import MutableMapping from collections.abc import MutableMapping
from typing import Callable from typing import Callable, Union
import torch import torch
from accelerate.utils import send_to_device from accelerate.utils import send_to_device
@ -117,7 +117,7 @@ class LazilyLoadedModelGroup(ModelGroup):
""" """
_hooks: MutableMapping[torch.nn.Module, RemovableHandle] _hooks: MutableMapping[torch.nn.Module, RemovableHandle]
_current_model_ref: Callable[[], torch.nn.Module | _NoModel] _current_model_ref: Callable[[], Union[torch.nn.Module, _NoModel]]
def __init__(self, execution_device: torch.device): def __init__(self, execution_device: torch.device):
super().__init__(execution_device) super().__init__(execution_device)

View File

@ -4,6 +4,7 @@ from contextlib import nullcontext
import torch import torch
from torch import autocast from torch import autocast
from typing import Union
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
CPU_DEVICE = torch.device("cpu") CPU_DEVICE = torch.device("cpu")
@ -49,7 +50,7 @@ def choose_autocast(precision):
return nullcontext return nullcontext
def normalize_device(device: str | torch.device) -> torch.device: def normalize_device(device: Union[str, torch.device]) -> torch.device:
"""Ensure device has a device index defined, if appropriate.""" """Ensure device has a device index defined, if appropriate."""
device = torch.device(device) device = torch.device(device)
if device.index is None: if device.index is None:

View File

@ -36,6 +36,12 @@ module.exports = {
], ],
'prettier/prettier': ['error', { endOfLine: 'auto' }], 'prettier/prettier': ['error', { endOfLine: 'auto' }],
'@typescript-eslint/ban-ts-comment': 'warn', '@typescript-eslint/ban-ts-comment': 'warn',
'@typescript-eslint/no-empty-interface': [
'error',
{
allowSingleExtends: true,
},
],
}, },
settings: { settings: {
react: { react: {

View File

@ -83,7 +83,7 @@
"konva": "^9.2.0", "konva": "^9.2.0",
"lodash-es": "^4.17.21", "lodash-es": "^4.17.21",
"nanostores": "^0.9.2", "nanostores": "^0.9.2",
"openapi-fetch": "^0.4.0", "openapi-fetch": "0.4.0",
"overlayscrollbars": "^2.2.0", "overlayscrollbars": "^2.2.0",
"overlayscrollbars-react": "^0.5.0", "overlayscrollbars-react": "^0.5.0",
"patch-package": "^7.0.0", "patch-package": "^7.0.0",

View File

@ -52,6 +52,7 @@
"unifiedCanvas": "Unified Canvas", "unifiedCanvas": "Unified Canvas",
"linear": "Linear", "linear": "Linear",
"nodes": "Node Editor", "nodes": "Node Editor",
"batch": "Batch Manager",
"postprocessing": "Post Processing", "postprocessing": "Post Processing",
"nodesDesc": "A node based system for the generation of images is under development currently. Stay tuned for updates about this amazing feature.", "nodesDesc": "A node based system for the generation of images is under development currently. Stay tuned for updates about this amazing feature.",
"postProcessing": "Post Processing", "postProcessing": "Post Processing",

View File

@ -1,67 +1,40 @@
import { Box, Flex, Grid, Portal } from '@chakra-ui/react'; import { Flex, Grid, Portal } from '@chakra-ui/react';
import { useLogger } from 'app/logging/useLogger'; import { useLogger } from 'app/logging/useLogger';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { PartialAppConfig } from 'app/types/invokeai'; import { PartialAppConfig } from 'app/types/invokeai';
import ImageUploader from 'common/components/ImageUploader'; import ImageUploader from 'common/components/ImageUploader';
import Loading from 'common/components/Loading/Loading';
import GalleryDrawer from 'features/gallery/components/GalleryPanel'; import GalleryDrawer from 'features/gallery/components/GalleryPanel';
import Lightbox from 'features/lightbox/components/Lightbox'; import Lightbox from 'features/lightbox/components/Lightbox';
import SiteHeader from 'features/system/components/SiteHeader'; import SiteHeader from 'features/system/components/SiteHeader';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { useIsApplicationReady } from 'features/system/hooks/useIsApplicationReady';
import { configChanged } from 'features/system/store/configSlice'; import { configChanged } from 'features/system/store/configSlice';
import { languageSelector } from 'features/system/store/systemSelectors'; import { languageSelector } from 'features/system/store/systemSelectors';
import FloatingGalleryButton from 'features/ui/components/FloatingGalleryButton'; import FloatingGalleryButton from 'features/ui/components/FloatingGalleryButton';
import FloatingParametersPanelButtons from 'features/ui/components/FloatingParametersPanelButtons'; import FloatingParametersPanelButtons from 'features/ui/components/FloatingParametersPanelButtons';
import InvokeTabs from 'features/ui/components/InvokeTabs'; import InvokeTabs from 'features/ui/components/InvokeTabs';
import ParametersDrawer from 'features/ui/components/ParametersDrawer'; import ParametersDrawer from 'features/ui/components/ParametersDrawer';
import { AnimatePresence, motion } from 'framer-motion';
import i18n from 'i18n'; import i18n from 'i18n';
import { ReactNode, memo, useCallback, useEffect, useState } from 'react'; import { ReactNode, memo, useEffect } from 'react';
import { APP_HEIGHT, APP_WIDTH } from 'theme/util/constants';
import GlobalHotkeys from './GlobalHotkeys'; import GlobalHotkeys from './GlobalHotkeys';
import Toaster from './Toaster'; import Toaster from './Toaster';
import DeleteImageModal from 'features/gallery/components/DeleteImageModal';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal'; import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal';
import { useListModelsQuery } from 'services/api/endpoints/models';
import DeleteBoardImagesModal from '../../features/gallery/components/Boards/DeleteBoardImagesModal'; import DeleteBoardImagesModal from '../../features/gallery/components/Boards/DeleteBoardImagesModal';
import DeleteImageModal from 'features/imageDeletion/components/DeleteImageModal';
const DEFAULT_CONFIG = {}; const DEFAULT_CONFIG = {};
interface Props { interface Props {
config?: PartialAppConfig; config?: PartialAppConfig;
headerComponent?: ReactNode; headerComponent?: ReactNode;
setIsReady?: (isReady: boolean) => void;
} }
const App = ({ const App = ({ config = DEFAULT_CONFIG, headerComponent }: Props) => {
config = DEFAULT_CONFIG,
headerComponent,
setIsReady,
}: Props) => {
const language = useAppSelector(languageSelector); const language = useAppSelector(languageSelector);
const log = useLogger(); const log = useLogger();
const isLightboxEnabled = useFeatureStatus('lightbox').isFeatureEnabled; const isLightboxEnabled = useFeatureStatus('lightbox').isFeatureEnabled;
const isApplicationReady = useIsApplicationReady();
const { data: pipelineModels } = useListModelsQuery({
model_type: 'main',
});
const { data: controlnetModels } = useListModelsQuery({
model_type: 'controlnet',
});
const { data: vaeModels } = useListModelsQuery({ model_type: 'vae' });
const { data: loraModels } = useListModelsQuery({ model_type: 'lora' });
const { data: embeddingModels } = useListModelsQuery({
model_type: 'embedding',
});
const [loadingOverridden, setLoadingOverridden] = useState(false);
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
useEffect(() => { useEffect(() => {
@ -73,27 +46,6 @@ const App = ({
dispatch(configChanged(config)); dispatch(configChanged(config));
}, [dispatch, config, log]); }, [dispatch, config, log]);
const handleOverrideClicked = useCallback(() => {
setLoadingOverridden(true);
}, []);
useEffect(() => {
if (isApplicationReady && setIsReady) {
setIsReady(true);
}
if (isApplicationReady) {
// TODO: This is a jank fix for canvas not filling the screen on first load
setTimeout(() => {
dispatch(requestCanvasRescale());
}, 200);
}
return () => {
setIsReady && setIsReady(false);
};
}, [dispatch, isApplicationReady, setIsReady]);
return ( return (
<> <>
<Grid w="100vw" h="100vh" position="relative" overflow="hidden"> <Grid w="100vw" h="100vh" position="relative" overflow="hidden">
@ -123,33 +75,6 @@ const App = ({
<GalleryDrawer /> <GalleryDrawer />
<ParametersDrawer /> <ParametersDrawer />
<AnimatePresence>
{!isApplicationReady && !loadingOverridden && (
<motion.div
key="loading"
initial={{ opacity: 1 }}
animate={{ opacity: 1 }}
exit={{ opacity: 0 }}
transition={{ duration: 0.3 }}
style={{ zIndex: 3 }}
>
<Box position="absolute" top={0} left={0} w="100vw" h="100vh">
<Loading />
</Box>
<Box
onClick={handleOverrideClicked}
position="absolute"
top={0}
right={0}
cursor="pointer"
w="2rem"
h="2rem"
/>
</motion.div>
)}
</AnimatePresence>
<Portal> <Portal>
<FloatingParametersPanelButtons /> <FloatingParametersPanelButtons />
</Portal> </Portal>

View File

@ -0,0 +1,82 @@
import { Box, ChakraProps, Flex, Heading, Image } from '@chakra-ui/react';
import { memo } from 'react';
import { TypesafeDraggableData } from './typesafeDnd';
type OverlayDragImageProps = {
dragData: TypesafeDraggableData | null;
};
const BOX_SIZE = 28;
const STYLES: ChakraProps['sx'] = {
w: BOX_SIZE,
h: BOX_SIZE,
maxW: BOX_SIZE,
maxH: BOX_SIZE,
shadow: 'dark-lg',
borderRadius: 'lg',
borderWidth: 2,
borderStyle: 'dashed',
borderColor: 'base.100',
opacity: 0.5,
bg: 'base.800',
color: 'base.50',
_dark: {
borderColor: 'base.200',
bg: 'base.900',
color: 'base.100',
},
};
const DragPreview = (props: OverlayDragImageProps) => {
if (!props.dragData) {
return;
}
if (props.dragData.payloadType === 'IMAGE_DTO') {
return (
<Box
sx={{
position: 'relative',
width: '100%',
height: '100%',
display: 'flex',
alignItems: 'center',
justifyContent: 'center',
userSelect: 'none',
cursor: 'none',
}}
>
<Image
sx={{
...STYLES,
}}
src={props.dragData.payload.imageDTO.thumbnail_url}
/>
</Box>
);
}
if (props.dragData.payloadType === 'IMAGE_NAMES') {
return (
<Flex
sx={{
cursor: 'none',
userSelect: 'none',
position: 'relative',
alignItems: 'center',
justifyContent: 'center',
flexDir: 'column',
...STYLES,
}}
>
<Heading>{props.dragData.payload.imageNames.length}</Heading>
<Heading size="sm">Images</Heading>
</Flex>
);
}
return null;
};
export default memo(DragPreview);

View File

@ -1,8 +1,5 @@
import { import {
DndContext,
DragEndEvent,
DragOverlay, DragOverlay,
DragStartEvent,
MouseSensor, MouseSensor,
TouchSensor, TouchSensor,
pointerWithin, pointerWithin,
@ -10,33 +7,45 @@ import {
useSensors, useSensors,
} from '@dnd-kit/core'; } from '@dnd-kit/core';
import { PropsWithChildren, memo, useCallback, useState } from 'react'; import { PropsWithChildren, memo, useCallback, useState } from 'react';
import OverlayDragImage from './OverlayDragImage'; import DragPreview from './DragPreview';
import { ImageDTO } from 'services/api/types';
import { isImageDTO } from 'services/api/guards';
import { snapCenterToCursor } from '@dnd-kit/modifiers'; import { snapCenterToCursor } from '@dnd-kit/modifiers';
import { AnimatePresence, motion } from 'framer-motion'; import { AnimatePresence, motion } from 'framer-motion';
import {
DndContext,
DragEndEvent,
DragStartEvent,
TypesafeDraggableData,
} from './typesafeDnd';
import { useAppDispatch } from 'app/store/storeHooks';
import { imageDropped } from 'app/store/middleware/listenerMiddleware/listeners/imageDropped';
type ImageDndContextProps = PropsWithChildren; type ImageDndContextProps = PropsWithChildren;
const ImageDndContext = (props: ImageDndContextProps) => { const ImageDndContext = (props: ImageDndContextProps) => {
const [draggedImage, setDraggedImage] = useState<ImageDTO | null>(null); const [activeDragData, setActiveDragData] =
useState<TypesafeDraggableData | null>(null);
const dispatch = useAppDispatch();
const handleDragStart = useCallback((event: DragStartEvent) => { const handleDragStart = useCallback((event: DragStartEvent) => {
const dragData = event.active.data.current; const activeData = event.active.data.current;
if (dragData && 'image' in dragData && isImageDTO(dragData.image)) { if (!activeData) {
setDraggedImage(dragData.image); return;
} }
setActiveDragData(activeData);
}, []); }, []);
const handleDragEnd = useCallback( const handleDragEnd = useCallback(
(event: DragEndEvent) => { (event: DragEndEvent) => {
const handleDrop = event.over?.data.current?.handleDrop; const activeData = event.active.data.current;
if (handleDrop && typeof handleDrop === 'function' && draggedImage) { const overData = event.over?.data.current;
handleDrop(draggedImage); if (!activeData || !overData) {
return;
} }
setDraggedImage(null); dispatch(imageDropped({ overData, activeData }));
setActiveDragData(null);
}, },
[draggedImage] [dispatch]
); );
const mouseSensor = useSensor(MouseSensor, { const mouseSensor = useSensor(MouseSensor, {
@ -46,6 +55,7 @@ const ImageDndContext = (props: ImageDndContextProps) => {
const touchSensor = useSensor(TouchSensor, { const touchSensor = useSensor(TouchSensor, {
activationConstraint: { delay: 150, tolerance: 5 }, activationConstraint: { delay: 150, tolerance: 5 },
}); });
// TODO: Use KeyboardSensor - needs composition of multiple collisionDetection algos // TODO: Use KeyboardSensor - needs composition of multiple collisionDetection algos
// Alternatively, fix `rectIntersection` collection detection to work with the drag overlay // Alternatively, fix `rectIntersection` collection detection to work with the drag overlay
// (currently the drag element collision rect is not correctly calculated) // (currently the drag element collision rect is not correctly calculated)
@ -63,7 +73,7 @@ const ImageDndContext = (props: ImageDndContextProps) => {
{props.children} {props.children}
<DragOverlay dropAnimation={null} modifiers={[snapCenterToCursor]}> <DragOverlay dropAnimation={null} modifiers={[snapCenterToCursor]}>
<AnimatePresence> <AnimatePresence>
{draggedImage && ( {activeDragData && (
<motion.div <motion.div
layout layout
key="overlay-drag-image" key="overlay-drag-image"
@ -77,7 +87,7 @@ const ImageDndContext = (props: ImageDndContextProps) => {
transition: { duration: 0.1 }, transition: { duration: 0.1 },
}} }}
> >
<OverlayDragImage image={draggedImage} /> <DragPreview dragData={activeDragData} />
</motion.div> </motion.div>
)} )}
</AnimatePresence> </AnimatePresence>

View File

@ -1,36 +0,0 @@
import { Box, Image } from '@chakra-ui/react';
import { memo } from 'react';
import { ImageDTO } from 'services/api/types';
type OverlayDragImageProps = {
image: ImageDTO;
};
const OverlayDragImage = (props: OverlayDragImageProps) => {
return (
<Box
style={{
width: '100%',
height: '100%',
display: 'flex',
alignItems: 'center',
justifyContent: 'center',
userSelect: 'none',
cursor: 'grabbing',
opacity: 0.5,
}}
>
<Image
sx={{
maxW: 36,
maxH: 36,
borderRadius: 'base',
shadow: 'dark-lg',
}}
src={props.image.thumbnail_url}
/>
</Box>
);
};
export default memo(OverlayDragImage);

View File

@ -0,0 +1,195 @@
// type-safe dnd from https://github.com/clauderic/dnd-kit/issues/935
import {
Active,
Collision,
DndContextProps,
DndContext as OriginalDndContext,
Over,
Translate,
UseDraggableArguments,
UseDroppableArguments,
useDraggable as useOriginalDraggable,
useDroppable as useOriginalDroppable,
} from '@dnd-kit/core';
import { ImageDTO } from 'services/api/types';
type BaseDropData = {
id: string;
};
export type CurrentImageDropData = BaseDropData & {
actionType: 'SET_CURRENT_IMAGE';
};
export type InitialImageDropData = BaseDropData & {
actionType: 'SET_INITIAL_IMAGE';
};
export type ControlNetDropData = BaseDropData & {
actionType: 'SET_CONTROLNET_IMAGE';
context: {
controlNetId: string;
};
};
export type CanvasInitialImageDropData = BaseDropData & {
actionType: 'SET_CANVAS_INITIAL_IMAGE';
};
export type NodesImageDropData = BaseDropData & {
actionType: 'SET_NODES_IMAGE';
context: {
nodeId: string;
fieldName: string;
};
};
export type NodesMultiImageDropData = BaseDropData & {
actionType: 'SET_MULTI_NODES_IMAGE';
context: { nodeId: string; fieldName: string };
};
export type AddToBatchDropData = BaseDropData & {
actionType: 'ADD_TO_BATCH';
};
export type MoveBoardDropData = BaseDropData & {
actionType: 'MOVE_BOARD';
context: { boardId: string | null };
};
export type TypesafeDroppableData =
| CurrentImageDropData
| InitialImageDropData
| ControlNetDropData
| CanvasInitialImageDropData
| NodesImageDropData
| AddToBatchDropData
| NodesMultiImageDropData
| MoveBoardDropData;
type BaseDragData = {
id: string;
};
export type ImageDraggableData = BaseDragData & {
payloadType: 'IMAGE_DTO';
payload: { imageDTO: ImageDTO };
};
export type ImageNamesDraggableData = BaseDragData & {
payloadType: 'IMAGE_NAMES';
payload: { imageNames: string[] };
};
export type TypesafeDraggableData =
| ImageDraggableData
| ImageNamesDraggableData;
interface UseDroppableTypesafeArguments
extends Omit<UseDroppableArguments, 'data'> {
data?: TypesafeDroppableData;
}
type UseDroppableTypesafeReturnValue = Omit<
ReturnType<typeof useOriginalDroppable>,
'active' | 'over'
> & {
active: TypesafeActive | null;
over: TypesafeOver | null;
};
export function useDroppable(props: UseDroppableTypesafeArguments) {
return useOriginalDroppable(props) as UseDroppableTypesafeReturnValue;
}
interface UseDraggableTypesafeArguments
extends Omit<UseDraggableArguments, 'data'> {
data?: TypesafeDraggableData;
}
type UseDraggableTypesafeReturnValue = Omit<
ReturnType<typeof useOriginalDraggable>,
'active' | 'over'
> & {
active: TypesafeActive | null;
over: TypesafeOver | null;
};
export function useDraggable(props: UseDraggableTypesafeArguments) {
return useOriginalDraggable(props) as UseDraggableTypesafeReturnValue;
}
interface TypesafeActive extends Omit<Active, 'data'> {
data: React.MutableRefObject<TypesafeDraggableData | undefined>;
}
interface TypesafeOver extends Omit<Over, 'data'> {
data: React.MutableRefObject<TypesafeDroppableData | undefined>;
}
export const isValidDrop = (
overData: TypesafeDroppableData | undefined,
active: TypesafeActive | null
) => {
if (!overData || !active?.data.current) {
return false;
}
const { actionType } = overData;
const { payloadType } = active.data.current;
if (overData.id === active.data.current.id) {
return false;
}
switch (actionType) {
case 'SET_CURRENT_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SET_INITIAL_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SET_CONTROLNET_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SET_CANVAS_INITIAL_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SET_NODES_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SET_MULTI_NODES_IMAGE':
return payloadType === 'IMAGE_DTO' || 'IMAGE_NAMES';
case 'ADD_TO_BATCH':
return payloadType === 'IMAGE_DTO' || 'IMAGE_NAMES';
case 'MOVE_BOARD':
return payloadType === 'IMAGE_DTO' || 'IMAGE_NAMES';
default:
return false;
}
};
interface DragEvent {
activatorEvent: Event;
active: TypesafeActive;
collisions: Collision[] | null;
delta: Translate;
over: TypesafeOver | null;
}
export interface DragStartEvent extends Pick<DragEvent, 'active'> {}
export interface DragMoveEvent extends DragEvent {}
export interface DragOverEvent extends DragMoveEvent {}
export interface DragEndEvent extends DragEvent {}
export interface DragCancelEvent extends DragEndEvent {}
export interface DndContextTypesafeProps
extends Omit<
DndContextProps,
'onDragStart' | 'onDragMove' | 'onDragOver' | 'onDragEnd' | 'onDragCancel'
> {
onDragStart?(event: DragStartEvent): void;
onDragMove?(event: DragMoveEvent): void;
onDragOver?(event: DragOverEvent): void;
onDragEnd?(event: DragEndEvent): void;
onDragCancel?(event: DragCancelEvent): void;
}
export function DndContext(props: DndContextTypesafeProps) {
return <OriginalDndContext {...props} />;
}

View File

@ -7,7 +7,6 @@ import React, {
} from 'react'; } from 'react';
import { Provider } from 'react-redux'; import { Provider } from 'react-redux';
import { store } from 'app/store/store'; import { store } from 'app/store/store';
// import { OpenAPI } from 'services/api/types';
import Loading from '../../common/components/Loading/Loading'; import Loading from '../../common/components/Loading/Loading';
import { addMiddleware, resetMiddlewares } from 'redux-dynamic-middlewares'; import { addMiddleware, resetMiddlewares } from 'redux-dynamic-middlewares';
@ -17,11 +16,6 @@ import '../../i18n';
import { socketMiddleware } from 'services/events/middleware'; import { socketMiddleware } from 'services/events/middleware';
import { Middleware } from '@reduxjs/toolkit'; import { Middleware } from '@reduxjs/toolkit';
import ImageDndContext from './ImageDnd/ImageDndContext'; import ImageDndContext from './ImageDnd/ImageDndContext';
import {
DeleteImageContext,
DeleteImageContextProvider,
} from 'app/contexts/DeleteImageContext';
import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal';
import { AddImageToBoardContextProvider } from '../contexts/AddImageToBoardContext'; import { AddImageToBoardContextProvider } from '../contexts/AddImageToBoardContext';
import { $authToken, $baseUrl } from 'services/api/client'; import { $authToken, $baseUrl } from 'services/api/client';
import { DeleteBoardImagesContextProvider } from '../contexts/DeleteBoardImagesContext'; import { DeleteBoardImagesContextProvider } from '../contexts/DeleteBoardImagesContext';
@ -34,7 +28,6 @@ interface Props extends PropsWithChildren {
token?: string; token?: string;
config?: PartialAppConfig; config?: PartialAppConfig;
headerComponent?: ReactNode; headerComponent?: ReactNode;
setIsReady?: (isReady: boolean) => void;
middleware?: Middleware[]; middleware?: Middleware[];
} }
@ -43,7 +36,6 @@ const InvokeAIUI = ({
token, token,
config, config,
headerComponent, headerComponent,
setIsReady,
middleware, middleware,
}: Props) => { }: Props) => {
useEffect(() => { useEffect(() => {
@ -85,17 +77,11 @@ const InvokeAIUI = ({
<React.Suspense fallback={<Loading />}> <React.Suspense fallback={<Loading />}>
<ThemeLocaleProvider> <ThemeLocaleProvider>
<ImageDndContext> <ImageDndContext>
<DeleteImageContextProvider>
<AddImageToBoardContextProvider> <AddImageToBoardContextProvider>
<DeleteBoardImagesContextProvider> <DeleteBoardImagesContextProvider>
<App <App config={config} headerComponent={headerComponent} />
config={config}
headerComponent={headerComponent}
setIsReady={setIsReady}
/>
</DeleteBoardImagesContextProvider> </DeleteBoardImagesContextProvider>
</AddImageToBoardContextProvider> </AddImageToBoardContextProvider>
</DeleteImageContextProvider>
</ImageDndContext> </ImageDndContext>
</ThemeLocaleProvider> </ThemeLocaleProvider>
</React.Suspense> </React.Suspense>

View File

@ -5,15 +5,15 @@ import { useDeleteBoardMutation } from '../../services/api/endpoints/boards';
import { defaultSelectorOptions } from '../store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from '../store/util/defaultMemoizeOptions';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { some } from 'lodash-es'; import { some } from 'lodash-es';
import { canvasSelector } from '../../features/canvas/store/canvasSelectors'; import { canvasSelector } from 'features/canvas/store/canvasSelectors';
import { controlNetSelector } from '../../features/controlNet/store/controlNetSlice'; import { controlNetSelector } from 'features/controlNet/store/controlNetSlice';
import { selectImagesById } from '../../features/gallery/store/imagesSlice'; import { selectImagesById } from 'features/gallery/store/gallerySlice';
import { nodesSelector } from '../../features/nodes/store/nodesSlice'; import { nodesSelector } from 'features/nodes/store/nodesSlice';
import { generationSelector } from '../../features/parameters/store/generationSelectors'; import { generationSelector } from 'features/parameters/store/generationSelectors';
import { RootState } from '../store/store'; import { RootState } from '../store/store';
import { useAppDispatch, useAppSelector } from '../store/storeHooks'; import { useAppDispatch, useAppSelector } from '../store/storeHooks';
import { ImageUsage } from './DeleteImageContext'; import { ImageUsage } from './DeleteImageContext';
import { requestedBoardImagesDeletion } from '../../features/gallery/store/actions'; import { requestedBoardImagesDeletion } from 'features/gallery/store/actions';
export const selectBoardImagesUsage = createSelector( export const selectBoardImagesUsage = createSelector(
[ [

View File

@ -1,201 +0,0 @@
import { useDisclosure } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { requestedImageDeletion } from 'features/gallery/store/actions';
import { systemSelector } from 'features/system/store/systemSelectors';
import {
PropsWithChildren,
createContext,
useCallback,
useEffect,
useState,
} from 'react';
import { ImageDTO } from 'services/api/types';
import { RootState } from 'app/store/store';
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
import { controlNetSelector } from 'features/controlNet/store/controlNetSlice';
import { nodesSelector } from 'features/nodes/store/nodesSlice';
import { generationSelector } from 'features/parameters/store/generationSelectors';
import { some } from 'lodash-es';
export type ImageUsage = {
isInitialImage: boolean;
isCanvasImage: boolean;
isNodesImage: boolean;
isControlNetImage: boolean;
};
export const selectImageUsage = createSelector(
[
generationSelector,
canvasSelector,
nodesSelector,
controlNetSelector,
(state: RootState, image_name?: string) => image_name,
],
(generation, canvas, nodes, controlNet, image_name) => {
const isInitialImage = generation.initialImage?.imageName === image_name;
const isCanvasImage = canvas.layerState.objects.some(
(obj) => obj.kind === 'image' && obj.imageName === image_name
);
const isNodesImage = nodes.nodes.some((node) => {
return some(
node.data.inputs,
(input) => input.type === 'image' && input.value === image_name
);
});
const isControlNetImage = some(
controlNet.controlNets,
(c) =>
c.controlImage === image_name || c.processedControlImage === image_name
);
const imageUsage: ImageUsage = {
isInitialImage,
isCanvasImage,
isNodesImage,
isControlNetImage,
};
return imageUsage;
},
defaultSelectorOptions
);
type DeleteImageContextValue = {
/**
* Whether the delete image dialog is open.
*/
isOpen: boolean;
/**
* Closes the delete image dialog.
*/
onClose: () => void;
/**
* Opens the delete image dialog and handles all deletion-related checks.
*/
onDelete: (image?: ImageDTO) => void;
/**
* The image pending deletion
*/
image?: ImageDTO;
/**
* The features in which this image is used
*/
imageUsage?: ImageUsage;
/**
* Immediately deletes an image.
*
* You probably don't want to use this - use `onDelete` instead.
*/
onImmediatelyDelete: () => void;
};
export const DeleteImageContext = createContext<DeleteImageContextValue>({
isOpen: false,
onClose: () => undefined,
onImmediatelyDelete: () => undefined,
onDelete: () => undefined,
});
const selector = createSelector(
[systemSelector],
(system) => {
const { isProcessing, isConnected, shouldConfirmOnDelete } = system;
return {
canDeleteImage: isConnected && !isProcessing,
shouldConfirmOnDelete,
};
},
defaultSelectorOptions
);
type Props = PropsWithChildren;
export const DeleteImageContextProvider = (props: Props) => {
const { canDeleteImage, shouldConfirmOnDelete } = useAppSelector(selector);
const [imageToDelete, setImageToDelete] = useState<ImageDTO>();
const dispatch = useAppDispatch();
const { isOpen, onOpen, onClose } = useDisclosure();
// Check where the image to be deleted is used (eg init image, controlnet, etc.)
const imageUsage = useAppSelector((state) =>
selectImageUsage(state, imageToDelete?.image_name)
);
// Clean up after deleting or dismissing the modal
const closeAndClearImageToDelete = useCallback(() => {
setImageToDelete(undefined);
onClose();
}, [onClose]);
// Dispatch the actual deletion action, to be handled by listener middleware
const handleActualDeletion = useCallback(
(image: ImageDTO) => {
dispatch(requestedImageDeletion({ image, imageUsage }));
closeAndClearImageToDelete();
},
[closeAndClearImageToDelete, dispatch, imageUsage]
);
// This is intended to be called by the delete button in the dialog
const onImmediatelyDelete = useCallback(() => {
if (canDeleteImage && imageToDelete) {
handleActualDeletion(imageToDelete);
}
closeAndClearImageToDelete();
}, [
canDeleteImage,
imageToDelete,
closeAndClearImageToDelete,
handleActualDeletion,
]);
const handleGatedDeletion = useCallback(
(image: ImageDTO) => {
if (shouldConfirmOnDelete || some(imageUsage)) {
// If we should confirm on delete, or if the image is in use, open the dialog
onOpen();
} else {
handleActualDeletion(image);
}
},
[imageUsage, shouldConfirmOnDelete, onOpen, handleActualDeletion]
);
// Consumers of the context call this to delete an image
const onDelete = useCallback((image?: ImageDTO) => {
if (!image) {
return;
}
// Set the image to delete, then let the effect call the actual deletion
setImageToDelete(image);
}, []);
useEffect(() => {
// We need to use an effect here to trigger the image usage selector, else we get a stale value
if (imageToDelete) {
handleGatedDeletion(imageToDelete);
}
}, [handleGatedDeletion, imageToDelete]);
return (
<DeleteImageContext.Provider
value={{
isOpen,
image: imageToDelete,
onClose: closeAndClearImageToDelete,
onDelete,
onImmediatelyDelete,
imageUsage,
}}
>
{props.children}
</DeleteImageContext.Provider>
);
};

View File

@ -1,7 +1,6 @@
import { initialCanvasState } from 'features/canvas/store/canvasSlice'; import { initialCanvasState } from 'features/canvas/store/canvasSlice';
import { initialControlNetState } from 'features/controlNet/store/controlNetSlice'; import { initialControlNetState } from 'features/controlNet/store/controlNetSlice';
import { initialGalleryState } from 'features/gallery/store/gallerySlice'; import { initialGalleryState } from 'features/gallery/store/gallerySlice';
import { initialImagesState } from 'features/gallery/store/imagesSlice';
import { initialLightboxState } from 'features/lightbox/store/lightboxSlice'; import { initialLightboxState } from 'features/lightbox/store/lightboxSlice';
import { initialNodesState } from 'features/nodes/store/nodesSlice'; import { initialNodesState } from 'features/nodes/store/nodesSlice';
import { initialGenerationState } from 'features/parameters/store/generationSlice'; import { initialGenerationState } from 'features/parameters/store/generationSlice';
@ -26,7 +25,6 @@ const initialStates: {
config: initialConfigState, config: initialConfigState,
ui: initialUIState, ui: initialUIState,
hotkeys: initialHotkeysState, hotkeys: initialHotkeysState,
images: initialImagesState,
controlNet: initialControlNetState, controlNet: initialControlNetState,
}; };

View File

@ -72,7 +72,6 @@ import { addCommitStagingAreaImageListener } from './listeners/addCommitStagingA
import { addImageCategoriesChangedListener } from './listeners/imageCategoriesChanged'; import { addImageCategoriesChangedListener } from './listeners/imageCategoriesChanged';
import { addControlNetImageProcessedListener } from './listeners/controlNetImageProcessed'; import { addControlNetImageProcessedListener } from './listeners/controlNetImageProcessed';
import { addControlNetAutoProcessListener } from './listeners/controlNetAutoProcess'; import { addControlNetAutoProcessListener } from './listeners/controlNetAutoProcess';
import { addUpdateImageUrlsOnConnectListener } from './listeners/updateImageUrlsOnConnect';
import { import {
addImageAddedToBoardFulfilledListener, addImageAddedToBoardFulfilledListener,
addImageAddedToBoardRejectedListener, addImageAddedToBoardRejectedListener,
@ -84,6 +83,9 @@ import {
} from './listeners/imageRemovedFromBoard'; } from './listeners/imageRemovedFromBoard';
import { addReceivedOpenAPISchemaListener } from './listeners/receivedOpenAPISchema'; import { addReceivedOpenAPISchemaListener } from './listeners/receivedOpenAPISchema';
import { addRequestedBoardImageDeletionListener } from './listeners/boardImagesDeleted'; import { addRequestedBoardImageDeletionListener } from './listeners/boardImagesDeleted';
import { addSelectionAddedToBatchListener } from './listeners/selectionAddedToBatch';
import { addImageDroppedListener } from './listeners/imageDropped';
import { addImageToDeleteSelectedListener } from './listeners/imageToDeleteSelected';
export const listenerMiddleware = createListenerMiddleware(); export const listenerMiddleware = createListenerMiddleware();
@ -126,6 +128,7 @@ addImageDeletedPendingListener();
addImageDeletedFulfilledListener(); addImageDeletedFulfilledListener();
addImageDeletedRejectedListener(); addImageDeletedRejectedListener();
addRequestedBoardImageDeletionListener(); addRequestedBoardImageDeletionListener();
addImageToDeleteSelectedListener();
// Image metadata // Image metadata
addImageMetadataReceivedFulfilledListener(); addImageMetadataReceivedFulfilledListener();
@ -211,3 +214,9 @@ addBoardIdSelectedListener();
// Node schemas // Node schemas
addReceivedOpenAPISchemaListener(); addReceivedOpenAPISchemaListener();
// Batches
addSelectionAddedToBatchListener();
// DND
addImageDroppedListener();

View File

@ -1,12 +1,14 @@
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { startAppListening } from '..'; import { startAppListening } from '..';
import { boardIdSelected } from 'features/gallery/store/boardSlice'; import {
import { selectImagesAll } from 'features/gallery/store/imagesSlice'; imageSelected,
selectImagesAll,
boardIdSelected,
} from 'features/gallery/store/gallerySlice';
import { import {
IMAGES_PER_PAGE, IMAGES_PER_PAGE,
receivedPageOfImages, receivedPageOfImages,
} from 'services/api/thunks/image'; } from 'services/api/thunks/image';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { boardsApi } from 'services/api/endpoints/boards'; import { boardsApi } from 'services/api/endpoints/boards';
const moduleLog = log.child({ namespace: 'boards' }); const moduleLog = log.child({ namespace: 'boards' });
@ -28,7 +30,7 @@ export const addBoardIdSelectedListener = () => {
return; return;
} }
const { categories } = state.images; const { categories } = state.gallery;
const filteredImages = allImages.filter((i) => { const filteredImages = allImages.filter((i) => {
const isInCategory = categories.includes(i.image_category); const isInCategory = categories.includes(i.image_category);
@ -47,7 +49,7 @@ export const addBoardIdSelectedListener = () => {
return; return;
} }
dispatch(imageSelected(board.cover_image_name)); dispatch(imageSelected(board.cover_image_name ?? null));
// if we haven't loaded one full page of images from this board, load more // if we haven't loaded one full page of images from this board, load more
if ( if (
@ -77,7 +79,7 @@ export const addBoardIdSelected_changeSelectedImage_listener = () => {
return; return;
} }
const { categories } = state.images; const { categories } = state.gallery;
const filteredImages = selectImagesAll(state).filter((i) => { const filteredImages = selectImagesAll(state).filter((i) => {
const isInCategory = categories.includes(i.image_category); const isInCategory = categories.includes(i.image_category);

View File

@ -1,11 +1,11 @@
import { requestedBoardImagesDeletion } from 'features/gallery/store/actions'; import { requestedBoardImagesDeletion } from 'features/gallery/store/actions';
import { startAppListening } from '..'; import { startAppListening } from '..';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { import {
imageSelected,
imagesRemoved, imagesRemoved,
selectImagesAll, selectImagesAll,
selectImagesById, selectImagesById,
} from 'features/gallery/store/imagesSlice'; } from 'features/gallery/store/gallerySlice';
import { resetCanvas } from 'features/canvas/store/canvasSlice'; import { resetCanvas } from 'features/canvas/store/canvasSlice';
import { controlNetReset } from 'features/controlNet/store/controlNetSlice'; import { controlNetReset } from 'features/controlNet/store/controlNetSlice';
import { clearInitialImage } from 'features/parameters/store/generationSlice'; import { clearInitialImage } from 'features/parameters/store/generationSlice';
@ -22,12 +22,15 @@ export const addRequestedBoardImageDeletionListener = () => {
const { board_id } = board; const { board_id } = board;
const state = getState(); const state = getState();
const selectedImage = state.gallery.selectedImage const selectedImageName =
? selectImagesById(state, state.gallery.selectedImage) state.gallery.selection[state.gallery.selection.length - 1];
const selectedImage = selectedImageName
? selectImagesById(state, selectedImageName)
: undefined; : undefined;
if (selectedImage && selectedImage.board_id === board_id) { if (selectedImage && selectedImage.board_id === board_id) {
dispatch(imageSelected()); dispatch(imageSelected(null));
} }
// We need to reset the features where the board images are in use - none of these work if their image(s) don't exist // We need to reset the features where the board images are in use - none of these work if their image(s) don't exist

View File

@ -4,7 +4,7 @@ import { log } from 'app/logging/useLogger';
import { imageUploaded } from 'services/api/thunks/image'; import { imageUploaded } from 'services/api/thunks/image';
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob'; import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
import { imageUpserted } from 'features/gallery/store/imagesSlice'; import { imageUpserted } from 'features/gallery/store/gallerySlice';
const moduleLog = log.child({ namespace: 'canvasSavedToGalleryListener' }); const moduleLog = log.child({ namespace: 'canvasSavedToGalleryListener' });

View File

@ -3,8 +3,8 @@ import { startAppListening } from '..';
import { receivedPageOfImages } from 'services/api/thunks/image'; import { receivedPageOfImages } from 'services/api/thunks/image';
import { import {
imageCategoriesChanged, imageCategoriesChanged,
selectFilteredImagesAsArray, selectFilteredImages,
} from 'features/gallery/store/imagesSlice'; } from 'features/gallery/store/gallerySlice';
const moduleLog = log.child({ namespace: 'gallery' }); const moduleLog = log.child({ namespace: 'gallery' });
@ -13,7 +13,7 @@ export const addImageCategoriesChangedListener = () => {
actionCreator: imageCategoriesChanged, actionCreator: imageCategoriesChanged,
effect: (action, { getState, dispatch }) => { effect: (action, { getState, dispatch }) => {
const state = getState(); const state = getState();
const filteredImagesCount = selectFilteredImagesAsArray(state).length; const filteredImagesCount = selectFilteredImages(state).length;
if (!filteredImagesCount) { if (!filteredImagesCount) {
dispatch( dispatch(

View File

@ -1,18 +1,21 @@
import { requestedImageDeletion } from 'features/gallery/store/actions';
import { startAppListening } from '..'; import { startAppListening } from '..';
import { imageDeleted } from 'services/api/thunks/image'; import { imageDeleted } from 'services/api/thunks/image';
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { clamp } from 'lodash-es'; import { clamp } from 'lodash-es';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { import {
imageSelected,
imageRemoved, imageRemoved,
selectImagesIds, selectImagesIds,
} from 'features/gallery/store/imagesSlice'; } from 'features/gallery/store/gallerySlice';
import { resetCanvas } from 'features/canvas/store/canvasSlice'; import { resetCanvas } from 'features/canvas/store/canvasSlice';
import { controlNetReset } from 'features/controlNet/store/controlNetSlice'; import { controlNetReset } from 'features/controlNet/store/controlNetSlice';
import { clearInitialImage } from 'features/parameters/store/generationSlice'; import { clearInitialImage } from 'features/parameters/store/generationSlice';
import { nodeEditorReset } from 'features/nodes/store/nodesSlice'; import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
import { api } from 'services/api'; import { api } from 'services/api';
import {
imageDeletionConfirmed,
isModalOpenChanged,
} from 'features/imageDeletion/store/imageDeletionSlice';
const moduleLog = log.child({ namespace: 'image' }); const moduleLog = log.child({ namespace: 'image' });
@ -21,16 +24,19 @@ const moduleLog = log.child({ namespace: 'image' });
*/ */
export const addRequestedImageDeletionListener = () => { export const addRequestedImageDeletionListener = () => {
startAppListening({ startAppListening({
actionCreator: requestedImageDeletion, actionCreator: imageDeletionConfirmed,
effect: async (action, { dispatch, getState, condition }) => { effect: async (action, { dispatch, getState, condition }) => {
const { image, imageUsage } = action.payload; const { imageDTO, imageUsage } = action.payload;
const { image_name } = image; dispatch(isModalOpenChanged(false));
const { image_name } = imageDTO;
const state = getState(); const state = getState();
const selectedImage = state.gallery.selectedImage; const lastSelectedImage =
state.gallery.selection[state.gallery.selection.length - 1];
if (selectedImage === image_name) { if (lastSelectedImage === image_name) {
const ids = selectImagesIds(state); const ids = selectImagesIds(state);
const deletedImageIndex = ids.findIndex( const deletedImageIndex = ids.findIndex(
@ -50,7 +56,7 @@ export const addRequestedImageDeletionListener = () => {
if (newSelectedImageId) { if (newSelectedImageId) {
dispatch(imageSelected(newSelectedImageId as string)); dispatch(imageSelected(newSelectedImageId as string));
} else { } else {
dispatch(imageSelected()); dispatch(imageSelected(null));
} }
} }
@ -88,7 +94,7 @@ export const addRequestedImageDeletionListener = () => {
if (wasImageDeleted) { if (wasImageDeleted) {
dispatch( dispatch(
api.util.invalidateTags([{ type: 'Board', id: image.board_id }]) api.util.invalidateTags([{ type: 'Board', id: imageDTO.board_id }])
); );
} }
}, },

View File

@ -0,0 +1,188 @@
import { createAction } from '@reduxjs/toolkit';
import { startAppListening } from '../';
import { log } from 'app/logging/useLogger';
import {
TypesafeDraggableData,
TypesafeDroppableData,
} from 'app/components/ImageDnd/typesafeDnd';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { initialImageChanged } from 'features/parameters/store/generationSlice';
import {
imageAddedToBatch,
imagesAddedToBatch,
} from 'features/batch/store/batchSlice';
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import {
fieldValueChanged,
imageCollectionFieldValueChanged,
} from 'features/nodes/store/nodesSlice';
import { boardsApi } from 'services/api/endpoints/boards';
import { boardImagesApi } from 'services/api/endpoints/boardImages';
const moduleLog = log.child({ namespace: 'dnd' });
export const imageDropped = createAction<{
overData: TypesafeDroppableData;
activeData: TypesafeDraggableData;
}>('dnd/imageDropped');
export const addImageDroppedListener = () => {
startAppListening({
actionCreator: imageDropped,
effect: (action, { dispatch, getState }) => {
const { activeData, overData } = action.payload;
const { actionType } = overData;
// set current image
if (
actionType === 'SET_CURRENT_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
dispatch(imageSelected(activeData.payload.imageDTO.image_name));
}
// set initial image
if (
actionType === 'SET_INITIAL_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
dispatch(initialImageChanged(activeData.payload.imageDTO));
}
// add image to batch
if (
actionType === 'ADD_TO_BATCH' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
dispatch(imageAddedToBatch(activeData.payload.imageDTO.image_name));
}
// add multiple images to batch
if (
actionType === 'ADD_TO_BATCH' &&
activeData.payloadType === 'IMAGE_NAMES'
) {
dispatch(imagesAddedToBatch(activeData.payload.imageNames));
}
// set control image
if (
actionType === 'SET_CONTROLNET_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const { controlNetId } = overData.context;
dispatch(
controlNetImageChanged({
controlImage: activeData.payload.imageDTO.image_name,
controlNetId,
})
);
}
// set canvas image
if (
actionType === 'SET_CANVAS_INITIAL_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
dispatch(setInitialCanvasImage(activeData.payload.imageDTO));
}
// set nodes image
if (
actionType === 'SET_NODES_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const { fieldName, nodeId } = overData.context;
dispatch(
fieldValueChanged({
nodeId,
fieldName,
value: activeData.payload.imageDTO,
})
);
}
// set multiple nodes images (single image handler)
if (
actionType === 'SET_MULTI_NODES_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const { fieldName, nodeId } = overData.context;
dispatch(
fieldValueChanged({
nodeId,
fieldName,
value: [activeData.payload.imageDTO],
})
);
}
// set multiple nodes images (multiple images handler)
if (
actionType === 'SET_MULTI_NODES_IMAGE' &&
activeData.payloadType === 'IMAGE_NAMES'
) {
const { fieldName, nodeId } = overData.context;
dispatch(
imageCollectionFieldValueChanged({
nodeId,
fieldName,
value: activeData.payload.imageNames.map((image_name) => ({
image_name,
})),
})
);
}
// remove image from board
// TODO: remove board_id from `removeImageFromBoard()` endpoint
// TODO: handle multiple images
// if (
// actionType === 'MOVE_BOARD' &&
// activeData.payloadType === 'IMAGE_DTO' &&
// activeData.payload.imageDTO &&
// overData.boardId !== null
// ) {
// const { image_name } = activeData.payload.imageDTO;
// dispatch(
// boardImagesApi.endpoints.removeImageFromBoard.initiate({ image_name })
// );
// }
// add image to board
if (
actionType === 'MOVE_BOARD' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO &&
overData.context.boardId
) {
const { image_name } = activeData.payload.imageDTO;
const { boardId } = overData.context;
dispatch(
boardImagesApi.endpoints.addImageToBoard.initiate({
image_name,
board_id: boardId,
})
);
}
// add multiple images to board
// TODO: add endpoint
// if (
// actionType === 'ADD_TO_BATCH' &&
// activeData.payloadType === 'IMAGE_NAMES' &&
// activeData.payload.imageDTONames
// ) {
// dispatch(boardImagesApi.endpoints.addImagesToBoard.intiate({}));
// }
},
});
};

View File

@ -1,7 +1,7 @@
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { startAppListening } from '..'; import { startAppListening } from '..';
import { imageMetadataReceived, imageUpdated } from 'services/api/thunks/image'; import { imageMetadataReceived, imageUpdated } from 'services/api/thunks/image';
import { imageUpserted } from 'features/gallery/store/imagesSlice'; import { imageUpserted } from 'features/gallery/store/gallerySlice';
const moduleLog = log.child({ namespace: 'image' }); const moduleLog = log.child({ namespace: 'image' });

View File

@ -0,0 +1,40 @@
import { startAppListening } from '..';
import { log } from 'app/logging/useLogger';
import {
imageDeletionConfirmed,
imageToDeleteSelected,
isModalOpenChanged,
selectImageUsage,
} from 'features/imageDeletion/store/imageDeletionSlice';
const moduleLog = log.child({ namespace: 'image' });
export const addImageToDeleteSelectedListener = () => {
startAppListening({
actionCreator: imageToDeleteSelected,
effect: async (action, { dispatch, getState, condition }) => {
const imageDTO = action.payload;
const state = getState();
const { shouldConfirmOnDelete } = state.system;
const imageUsage = selectImageUsage(getState());
if (!imageUsage) {
// should never happen
return;
}
const isImageInUse =
imageUsage.isCanvasImage ||
imageUsage.isInitialImage ||
imageUsage.isControlNetImage ||
imageUsage.isNodesImage;
if (shouldConfirmOnDelete || isImageInUse) {
dispatch(isModalOpenChanged(true));
return;
}
dispatch(imageDeletionConfirmed({ imageDTO, imageUsage }));
},
});
};

View File

@ -2,11 +2,12 @@ import { startAppListening } from '..';
import { imageUploaded } from 'services/api/thunks/image'; import { imageUploaded } from 'services/api/thunks/image';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { imageUpserted } from 'features/gallery/store/imagesSlice'; import { imageUpserted } from 'features/gallery/store/gallerySlice';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice'; import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice'; import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
import { initialImageChanged } from 'features/parameters/store/generationSlice'; import { initialImageChanged } from 'features/parameters/store/generationSlice';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice'; import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import { imageAddedToBatch } from 'features/batch/store/batchSlice';
const moduleLog = log.child({ namespace: 'image' }); const moduleLog = log.child({ namespace: 'image' });
@ -70,6 +71,11 @@ export const addImageUploadedFulfilledListener = () => {
dispatch(addToast({ title: 'Image Uploaded', status: 'success' })); dispatch(addToast({ title: 'Image Uploaded', status: 'success' }));
return; return;
} }
if (postUploadAction?.type === 'ADD_TO_BATCH') {
dispatch(imageAddedToBatch(image.image_name));
return;
}
}, },
}); });
}; };

View File

@ -1,7 +1,7 @@
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { startAppListening } from '..'; import { startAppListening } from '..';
import { imageUrlsReceived } from 'services/api/thunks/image'; import { imageUrlsReceived } from 'services/api/thunks/image';
import { imageUpdatedOne } from 'features/gallery/store/imagesSlice'; import { imageUpdatedOne } from 'features/gallery/store/gallerySlice';
const moduleLog = log.child({ namespace: 'image' }); const moduleLog = log.child({ namespace: 'image' });

View File

@ -4,7 +4,7 @@ import { addToast } from 'features/system/store/systemSlice';
import { startAppListening } from '..'; import { startAppListening } from '..';
import { initialImageSelected } from 'features/parameters/store/actions'; import { initialImageSelected } from 'features/parameters/store/actions';
import { makeToast } from 'app/components/Toaster'; import { makeToast } from 'app/components/Toaster';
import { selectImagesById } from 'features/gallery/store/imagesSlice'; import { selectImagesById } from 'features/gallery/store/gallerySlice';
import { isImageDTO } from 'services/api/guards'; import { isImageDTO } from 'services/api/guards';
export const addInitialImageSelectedListener = () => { export const addInitialImageSelectedListener = () => {

View File

@ -2,6 +2,7 @@ import { log } from 'app/logging/useLogger';
import { startAppListening } from '..'; import { startAppListening } from '..';
import { serializeError } from 'serialize-error'; import { serializeError } from 'serialize-error';
import { receivedPageOfImages } from 'services/api/thunks/image'; import { receivedPageOfImages } from 'services/api/thunks/image';
import { imagesApi } from 'services/api/endpoints/images';
const moduleLog = log.child({ namespace: 'gallery' }); const moduleLog = log.child({ namespace: 'gallery' });
@ -9,11 +10,17 @@ export const addReceivedPageOfImagesFulfilledListener = () => {
startAppListening({ startAppListening({
actionCreator: receivedPageOfImages.fulfilled, actionCreator: receivedPageOfImages.fulfilled,
effect: (action, { getState, dispatch }) => { effect: (action, { getState, dispatch }) => {
const page = action.payload; const { items } = action.payload;
moduleLog.debug( moduleLog.debug(
{ data: { payload: action.payload } }, { data: { payload: action.payload } },
`Received ${page.items.length} images` `Received ${items.length} images`
); );
items.forEach((image) => {
dispatch(
imagesApi.util.upsertQueryData('getImageDTO', image.image_name, image)
);
});
}, },
}); });
}; };

View File

@ -0,0 +1,19 @@
import { startAppListening } from '..';
import { log } from 'app/logging/useLogger';
import {
imagesAddedToBatch,
selectionAddedToBatch,
} from 'features/batch/store/batchSlice';
const moduleLog = log.child({ namespace: 'batch' });
export const addSelectionAddedToBatchListener = () => {
startAppListening({
actionCreator: selectionAddedToBatch,
effect: (action, { dispatch, getState }) => {
const { selection } = getState().gallery;
dispatch(imagesAddedToBatch(selection));
},
});
};

View File

@ -14,11 +14,11 @@ export const addSocketConnectedEventListener = () => {
moduleLog.debug({ timestamp }, 'Connected'); moduleLog.debug({ timestamp }, 'Connected');
const { nodes, config, images } = getState(); const { nodes, config, gallery } = getState();
const { disabledTabs } = config; const { disabledTabs } = config;
if (!images.ids.length) { if (!gallery.ids.length) {
dispatch( dispatch(
receivedPageOfImages({ receivedPageOfImages({
categories: ['general'], categories: ['general'],

View File

@ -2,7 +2,7 @@ import { stagingAreaImageSaved } from 'features/canvas/store/actions';
import { startAppListening } from '..'; import { startAppListening } from '..';
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { imageUpdated } from 'services/api/thunks/image'; import { imageUpdated } from 'services/api/thunks/image';
import { imageUpserted } from 'features/gallery/store/imagesSlice'; import { imageUpserted } from 'features/gallery/store/gallerySlice';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
const moduleLog = log.child({ namespace: 'canvas' }); const moduleLog = log.child({ namespace: 'canvas' });

View File

@ -8,7 +8,7 @@ import { controlNetSelector } from 'features/controlNet/store/controlNetSlice';
import { forEach, uniqBy } from 'lodash-es'; import { forEach, uniqBy } from 'lodash-es';
import { imageUrlsReceived } from 'services/api/thunks/image'; import { imageUrlsReceived } from 'services/api/thunks/image';
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { selectImagesEntities } from 'features/gallery/store/imagesSlice'; import { selectImagesEntities } from 'features/gallery/store/gallerySlice';
const moduleLog = log.child({ namespace: 'images' }); const moduleLog = log.child({ namespace: 'images' });
@ -36,7 +36,7 @@ const selectAllUsedImages = createSelector(
nodes.nodes.forEach((node) => { nodes.nodes.forEach((node) => {
forEach(node.data.inputs, (input) => { forEach(node.data.inputs, (input) => {
if (input.type === 'image' && input.value) { if (input.type === 'image' && input.value) {
allUsedImages.push(input.value); allUsedImages.push(input.value.image_name);
} }
}); });
}); });

View File

@ -11,18 +11,18 @@ import { rememberEnhancer, rememberReducer } from 'redux-remember';
import canvasReducer from 'features/canvas/store/canvasSlice'; import canvasReducer from 'features/canvas/store/canvasSlice';
import controlNetReducer from 'features/controlNet/store/controlNetSlice'; import controlNetReducer from 'features/controlNet/store/controlNetSlice';
import galleryReducer from 'features/gallery/store/gallerySlice'; import galleryReducer from 'features/gallery/store/gallerySlice';
import imagesReducer from 'features/gallery/store/imagesSlice';
import lightboxReducer from 'features/lightbox/store/lightboxSlice'; import lightboxReducer from 'features/lightbox/store/lightboxSlice';
import generationReducer from 'features/parameters/store/generationSlice'; import generationReducer from 'features/parameters/store/generationSlice';
import postprocessingReducer from 'features/parameters/store/postprocessingSlice'; import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
import systemReducer from 'features/system/store/systemSlice'; import systemReducer from 'features/system/store/systemSlice';
// import sessionReducer from 'features/system/store/sessionSlice';
import nodesReducer from 'features/nodes/store/nodesSlice'; import nodesReducer from 'features/nodes/store/nodesSlice';
import boardsReducer from 'features/gallery/store/boardSlice'; import boardsReducer from 'features/gallery/store/boardSlice';
import configReducer from 'features/system/store/configSlice'; import configReducer from 'features/system/store/configSlice';
import hotkeysReducer from 'features/ui/store/hotkeysSlice'; import hotkeysReducer from 'features/ui/store/hotkeysSlice';
import uiReducer from 'features/ui/store/uiSlice'; import uiReducer from 'features/ui/store/uiSlice';
import dynamicPromptsReducer from 'features/dynamicPrompts/store/slice'; import dynamicPromptsReducer from 'features/dynamicPrompts/store/slice';
import batchReducer from 'features/batch/store/batchSlice';
import imageDeletionReducer from 'features/imageDeletion/store/imageDeletionSlice';
import { listenerMiddleware } from './middleware/listenerMiddleware'; import { listenerMiddleware } from './middleware/listenerMiddleware';
@ -45,11 +45,11 @@ const allReducers = {
config: configReducer, config: configReducer,
ui: uiReducer, ui: uiReducer,
hotkeys: hotkeysReducer, hotkeys: hotkeysReducer,
images: imagesReducer,
controlNet: controlNetReducer, controlNet: controlNetReducer,
boards: boardsReducer, boards: boardsReducer,
// session: sessionReducer,
dynamicPrompts: dynamicPromptsReducer, dynamicPrompts: dynamicPromptsReducer,
batch: batchReducer,
imageDeletion: imageDeletionReducer,
[api.reducerPath]: api.reducer, [api.reducerPath]: api.reducer,
}; };
@ -68,6 +68,7 @@ const rememberedKeys: (keyof typeof allReducers)[] = [
'ui', 'ui',
'controlNet', 'controlNet',
'dynamicPrompts', 'dynamicPrompts',
'batch',
// 'boards', // 'boards',
// 'hotkeys', // 'hotkeys',
// 'config', // 'config',

View File

@ -15,10 +15,25 @@ export interface IAIButtonProps extends ButtonProps {
} }
const IAIButton = forwardRef((props: IAIButtonProps, forwardedRef) => { const IAIButton = forwardRef((props: IAIButtonProps, forwardedRef) => {
const { children, tooltip = '', tooltipProps, isChecked, ...rest } = props; const {
children,
tooltip = '',
tooltipProps: { placement = 'top', hasArrow = true, ...tooltipProps } = {},
isChecked,
...rest
} = props;
return ( return (
<Tooltip label={tooltip} {...tooltipProps}> <Tooltip
<Button ref={forwardedRef} aria-checked={isChecked} {...rest}> label={tooltip}
placement={placement}
hasArrow={hasArrow}
{...tooltipProps}
>
<Button
ref={forwardedRef}
colorScheme={isChecked ? 'accent' : 'base'}
{...rest}
>
{children} {children}
</Button> </Button>
</Tooltip> </Tooltip>

View File

@ -1,19 +1,20 @@
import { import {
Box,
ChakraProps, ChakraProps,
Flex, Flex,
Icon, Icon,
IconButtonProps,
Image, Image,
useColorMode, useColorMode,
useColorModeValue,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { useDraggable, useDroppable } from '@dnd-kit/core';
import { useCombinedRefs } from '@dnd-kit/utilities'; import { useCombinedRefs } from '@dnd-kit/utilities';
import IAIIconButton from 'common/components/IAIIconButton'; import IAIIconButton from 'common/components/IAIIconButton';
import { IAIImageLoadingFallback } from 'common/components/IAIImageFallback'; import {
IAILoadingImageFallback,
IAINoContentFallback,
} from 'common/components/IAIImageFallback';
import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay'; import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay';
import { AnimatePresence } from 'framer-motion'; import { AnimatePresence } from 'framer-motion';
import { ReactElement, SyntheticEvent } from 'react'; import { MouseEvent, ReactElement, SyntheticEvent } from 'react';
import { memo, useRef } from 'react'; import { memo, useRef } from 'react';
import { FaImage, FaUndo, FaUpload } from 'react-icons/fa'; import { FaImage, FaUndo, FaUpload } from 'react-icons/fa';
import { ImageDTO } from 'services/api/types'; import { ImageDTO } from 'services/api/types';
@ -22,81 +23,97 @@ import IAIDropOverlay from './IAIDropOverlay';
import { PostUploadAction } from 'services/api/thunks/image'; import { PostUploadAction } from 'services/api/thunks/image';
import { useImageUploadButton } from 'common/hooks/useImageUploadButton'; import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
import { mode } from 'theme/util/mode'; import { mode } from 'theme/util/mode';
import {
TypesafeDraggableData,
TypesafeDroppableData,
isValidDrop,
useDraggable,
useDroppable,
} from 'app/components/ImageDnd/typesafeDnd';
type IAIDndImageProps = { type IAIDndImageProps = {
image: ImageDTO | null | undefined; imageDTO: ImageDTO | undefined;
onDrop: (droppedImage: ImageDTO) => void;
onReset?: () => void;
onError?: (event: SyntheticEvent<HTMLImageElement>) => void; onError?: (event: SyntheticEvent<HTMLImageElement>) => void;
onLoad?: (event: SyntheticEvent<HTMLImageElement>) => void; onLoad?: (event: SyntheticEvent<HTMLImageElement>) => void;
resetIconSize?: IconButtonProps['size']; onClick?: (event: MouseEvent<HTMLDivElement>) => void;
onClickReset?: (event: MouseEvent<HTMLButtonElement>) => void;
withResetIcon?: boolean; withResetIcon?: boolean;
resetIcon?: ReactElement;
resetTooltip?: string;
withMetadataOverlay?: boolean; withMetadataOverlay?: boolean;
isDragDisabled?: boolean; isDragDisabled?: boolean;
isDropDisabled?: boolean; isDropDisabled?: boolean;
isUploadDisabled?: boolean; isUploadDisabled?: boolean;
fallback?: ReactElement;
payloadImage?: ImageDTO | null | undefined;
minSize?: number; minSize?: number;
postUploadAction?: PostUploadAction; postUploadAction?: PostUploadAction;
imageSx?: ChakraProps['sx']; imageSx?: ChakraProps['sx'];
fitContainer?: boolean; fitContainer?: boolean;
droppableData?: TypesafeDroppableData;
draggableData?: TypesafeDraggableData;
dropLabel?: string;
isSelected?: boolean;
thumbnail?: boolean;
noContentFallback?: ReactElement;
}; };
const IAIDndImage = (props: IAIDndImageProps) => { const IAIDndImage = (props: IAIDndImageProps) => {
const { const {
image, imageDTO,
onDrop, onClickReset,
onReset,
onError, onError,
resetIconSize = 'md', onClick,
withResetIcon = false, withResetIcon = false,
withMetadataOverlay = false, withMetadataOverlay = false,
isDropDisabled = false, isDropDisabled = false,
isDragDisabled = false, isDragDisabled = false,
isUploadDisabled = false, isUploadDisabled = false,
fallback = <IAIImageLoadingFallback />,
payloadImage,
minSize = 24, minSize = 24,
postUploadAction, postUploadAction,
imageSx, imageSx,
fitContainer = false, fitContainer = false,
droppableData,
draggableData,
dropLabel,
isSelected = false,
thumbnail = false,
resetTooltip = 'Reset',
resetIcon = <FaUndo />,
noContentFallback = <IAINoContentFallback icon={FaImage} />,
} = props; } = props;
const dndId = useRef(uuidv4());
const { colorMode } = useColorMode(); const { colorMode } = useColorMode();
const { const dndId = useRef(uuidv4());
isOver,
setNodeRef: setDroppableRef,
active: isDropActive,
} = useDroppable({
id: dndId.current,
disabled: isDropDisabled,
data: {
handleDrop: onDrop,
},
});
const { const {
attributes, attributes,
listeners, listeners,
setNodeRef: setDraggableRef, setNodeRef: setDraggableRef,
isDragging, isDragging,
active,
} = useDraggable({ } = useDraggable({
id: dndId.current, id: dndId.current,
data: { disabled: isDragDisabled || !imageDTO,
image: payloadImage ? payloadImage : image, data: draggableData,
},
disabled: isDragDisabled || !image,
}); });
const { isOver, setNodeRef: setDroppableRef } = useDroppable({
id: dndId.current,
disabled: isDropDisabled,
data: droppableData,
});
const setDndRef = useCombinedRefs(setDroppableRef, setDraggableRef);
const { getUploadButtonProps, getUploadInputProps } = useImageUploadButton({ const { getUploadButtonProps, getUploadInputProps } = useImageUploadButton({
postUploadAction, postUploadAction,
isDisabled: isUploadDisabled, isDisabled: isUploadDisabled,
}); });
const setNodeRef = useCombinedRefs(setDroppableRef, setDraggableRef); const resetIconShadow = useColorModeValue(
`drop-shadow(0px 0px 0.1rem var(--invokeai-colors-base-600))`,
`drop-shadow(0px 0px 0.1rem var(--invokeai-colors-base-800))`
);
const uploadButtonStyles = isUploadDisabled const uploadButtonStyles = isUploadDisabled
? {} ? {}
@ -117,16 +134,16 @@ const IAIDndImage = (props: IAIDndImageProps) => {
alignItems: 'center', alignItems: 'center',
justifyContent: 'center', justifyContent: 'center',
position: 'relative', position: 'relative',
minW: minSize, minW: minSize ? minSize : undefined,
minH: minSize, minH: minSize ? minSize : undefined,
userSelect: 'none', userSelect: 'none',
cursor: isDragDisabled || !image ? 'auto' : 'grab', cursor: isDragDisabled || !imageDTO ? 'default' : 'pointer',
}} }}
{...attributes} {...attributes}
{...listeners} {...listeners}
ref={setNodeRef} ref={setDndRef}
> >
{image && ( {imageDTO && (
<Flex <Flex
sx={{ sx={{
w: 'full', w: 'full',
@ -137,42 +154,50 @@ const IAIDndImage = (props: IAIDndImageProps) => {
}} }}
> >
<Image <Image
src={image.image_url} onClick={onClick}
fallback={fallback} src={thumbnail ? imageDTO.thumbnail_url : imageDTO.image_url}
fallbackStrategy="beforeLoadOrError"
fallback={<IAILoadingImageFallback image={imageDTO} />}
onError={onError} onError={onError}
objectFit="contain"
draggable={false} draggable={false}
sx={{ sx={{
objectFit: 'contain',
maxW: 'full', maxW: 'full',
maxH: 'full', maxH: 'full',
borderRadius: 'base', borderRadius: 'base',
shadow: isSelected ? 'selected.light' : undefined,
_dark: { shadow: isSelected ? 'selected.dark' : undefined },
...imageSx, ...imageSx,
}} }}
/> />
{withMetadataOverlay && <ImageMetadataOverlay image={image} />} {withMetadataOverlay && <ImageMetadataOverlay image={imageDTO} />}
{onReset && withResetIcon && ( {onClickReset && withResetIcon && (
<Box <IAIIconButton
onClick={onClickReset}
aria-label={resetTooltip}
tooltip={resetTooltip}
icon={resetIcon}
size="sm"
variant="link"
sx={{ sx={{
position: 'absolute', position: 'absolute',
top: 0, top: 1,
right: 0, insetInlineEnd: 1,
p: 0,
minW: 0,
svg: {
transitionProperty: 'common',
transitionDuration: 'normal',
fill: 'base.100',
_hover: { fill: 'base.50' },
filter: resetIconShadow,
},
}} }}
>
<IAIIconButton
size={resetIconSize}
tooltip="Reset Image"
aria-label="Reset Image"
icon={<FaUndo />}
onClick={onReset}
/> />
</Box>
)} )}
<AnimatePresence>
{isDropActive && <IAIDropOverlay isOver={isOver} />}
</AnimatePresence>
</Flex> </Flex>
)} )}
{!image && ( {!imageDTO && !isUploadDisabled && (
<> <>
<Flex <Flex
sx={{ sx={{
@ -191,17 +216,20 @@ const IAIDndImage = (props: IAIDndImageProps) => {
> >
<input {...getUploadInputProps()} /> <input {...getUploadInputProps()} />
<Icon <Icon
as={isUploadDisabled ? FaImage : FaUpload} as={FaUpload}
sx={{ sx={{
boxSize: 12, boxSize: 16,
}} }}
/> />
</Flex> </Flex>
<AnimatePresence>
{isDropActive && <IAIDropOverlay isOver={isOver} />}
</AnimatePresence>
</> </>
)} )}
{!imageDTO && isUploadDisabled && noContentFallback}
<AnimatePresence>
{isValidDrop(droppableData, active) && !isDragging && (
<IAIDropOverlay isOver={isOver} label={dropLabel} />
)}
</AnimatePresence>
</Flex> </Flex>
); );
}; };

View File

@ -62,7 +62,7 @@ export const IAIDropOverlay = (props: Props) => {
w: 'full', w: 'full',
h: 'full', h: 'full',
opacity: 1, opacity: 1,
borderWidth: 2, borderWidth: 3,
borderColor: isOver borderColor: isOver
? mode('base.50', 'base.200')(colorMode) ? mode('base.50', 'base.200')(colorMode)
: mode('base.100', 'base.500')(colorMode), : mode('base.100', 'base.500')(colorMode),
@ -78,10 +78,10 @@ export const IAIDropOverlay = (props: Props) => {
sx={{ sx={{
fontSize: '2xl', fontSize: '2xl',
fontWeight: 600, fontWeight: 600,
transform: isOver ? 'scale(1.1)' : 'scale(1)', transform: isOver ? 'scale(1.02)' : 'scale(1)',
color: isOver color: isOver
? mode('base.100', 'base.100')(colorMode) ? mode('base.50', 'base.50')(colorMode)
: mode('base.200', 'base.500')(colorMode), : mode('base.100', 'base.200')(colorMode),
transitionProperty: 'common', transitionProperty: 'common',
transitionDuration: '0.1s', transitionDuration: '0.1s',
}} }}

View File

@ -29,7 +29,7 @@ const IAIIconButton = forwardRef((props: IAIIconButtonProps, forwardedRef) => {
<IconButton <IconButton
ref={forwardedRef} ref={forwardedRef}
role={role} role={role}
aria-checked={isChecked !== undefined ? isChecked : undefined} colorScheme={isChecked ? 'accent' : 'base'}
{...rest} {...rest}
/> />
</Tooltip> </Tooltip>

View File

@ -1,73 +1,82 @@
import { import {
As, As,
ChakraProps,
Flex, Flex,
FlexProps,
Icon, Icon,
IconProps, Skeleton,
Spinner, Spinner,
SpinnerProps, StyleProps,
useColorMode, Text,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { FaImage } from 'react-icons/fa'; import { FaImage } from 'react-icons/fa';
import { mode } from 'theme/util/mode'; import { ImageDTO } from 'services/api/types';
type Props = FlexProps & { type Props = { image: ImageDTO | undefined };
spinnerProps?: SpinnerProps;
}; export const IAILoadingImageFallback = (props: Props) => {
if (props.image) {
return (
<Skeleton
sx={{
w: `${props.image.width}px`,
h: 'auto',
objectFit: 'contain',
aspectRatio: `${props.image.width}/${props.image.height}`,
}}
/>
);
}
export const IAIImageLoadingFallback = (props: Props) => {
const { spinnerProps, ...rest } = props;
const { sx, ...restFlexProps } = rest;
const { colorMode } = useColorMode();
return ( return (
<Flex <Flex
sx={{ sx={{
bg: mode('base.200', 'base.900')(colorMode),
opacity: 0.7, opacity: 0.7,
w: 'full', w: 'full',
h: 'full', h: 'full',
alignItems: 'center', alignItems: 'center',
justifyContent: 'center', justifyContent: 'center',
borderRadius: 'base', borderRadius: 'base',
...sx, bg: 'base.200',
_dark: {
bg: 'base.900',
},
}} }}
{...restFlexProps}
> >
<Spinner size="xl" {...spinnerProps} /> <Spinner size="xl" />
</Flex> </Flex>
); );
}; };
type IAINoImageFallbackProps = { type IAINoImageFallbackProps = {
flexProps?: FlexProps; label?: string;
iconProps?: IconProps; icon?: As;
as?: As; boxSize?: StyleProps['boxSize'];
sx?: ChakraProps['sx'];
}; };
export const IAINoImageFallback = (props: IAINoImageFallbackProps) => { export const IAINoContentFallback = (props: IAINoImageFallbackProps) => {
const { sx: flexSx, ...restFlexProps } = props.flexProps ?? { sx: {} }; const { icon = FaImage, boxSize = 16 } = props;
const { sx: iconSx, ...restIconProps } = props.iconProps ?? { sx: {} };
const { colorMode } = useColorMode();
return ( return (
<Flex <Flex
sx={{ sx={{
bg: mode('base.200', 'base.900')(colorMode),
opacity: 0.7,
w: 'full', w: 'full',
h: 'full', h: 'full',
alignItems: 'center', alignItems: 'center',
justifyContent: 'center', justifyContent: 'center',
borderRadius: 'base', borderRadius: 'base',
...flexSx, flexDir: 'column',
gap: 2,
userSelect: 'none',
color: 'base.700',
_dark: {
color: 'base.500',
},
...props.sx,
}} }}
{...restFlexProps}
> >
<Icon <Icon as={icon} boxSize={boxSize} opacity={0.7} />
as={props.as ?? FaImage} {props.label && <Text textAlign="center">{props.label}</Text>}
sx={{ color: mode('base.700', 'base.500')(colorMode), ...iconSx }}
{...restIconProps}
/>
</Flex> </Flex>
); );
}; };

View File

@ -1,4 +1,5 @@
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { validateSeedWeights } from 'common/util/seedWeightPairs'; import { validateSeedWeights } from 'common/util/seedWeightPairs';
@ -7,17 +8,26 @@ import { systemSelector } from 'features/system/store/systemSelectors';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
const readinessSelector = createSelector( const readinessSelector = createSelector(
[generationSelector, systemSelector, activeTabNameSelector], [stateSelector, activeTabNameSelector],
(generation, system, activeTabName) => { ({ generation, system, batch }, activeTabName) => {
const { shouldGenerateVariations, seedWeights, initialImage, seed } = const { shouldGenerateVariations, seedWeights, initialImage, seed } =
generation; generation;
const { isProcessing, isConnected } = system; const { isProcessing, isConnected } = system;
const {
isEnabled: isBatchEnabled,
asInitialImage,
imageNames: batchImageNames,
} = batch;
let isReady = true; let isReady = true;
const reasonsWhyNotReady: string[] = []; const reasonsWhyNotReady: string[] = [];
if (activeTabName === 'img2img' && !initialImage) { if (
activeTabName === 'img2img' &&
!initialImage &&
!(asInitialImage && batchImageNames.length > 1)
) {
isReady = false; isReady = false;
reasonsWhyNotReady.push('No initial image selected'); reasonsWhyNotReady.push('No initial image selected');
} }

View File

@ -0,0 +1,67 @@
import {
Flex,
FormControl,
FormLabel,
Heading,
Spacer,
Switch,
Text,
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISwitch from 'common/components/IAISwitch';
import { ControlNetConfig } from 'features/controlNet/store/controlNetSlice';
import { ChangeEvent, memo, useCallback } from 'react';
import { controlNetToggled } from '../store/batchSlice';
type Props = {
controlNet: ControlNetConfig;
};
const selector = createSelector(
[stateSelector, (state, controlNetId: string) => controlNetId],
(state, controlNetId) => {
const isControlNetEnabled = state.batch.controlNets.includes(controlNetId);
return { isControlNetEnabled };
},
defaultSelectorOptions
);
const BatchControlNet = (props: Props) => {
const dispatch = useAppDispatch();
const { isControlNetEnabled } = useAppSelector((state) =>
selector(state, props.controlNet.controlNetId)
);
const { processorType, model } = props.controlNet;
const handleChangeAsControlNet = useCallback(() => {
dispatch(controlNetToggled(props.controlNet.controlNetId));
}, [dispatch, props.controlNet.controlNetId]);
return (
<Flex
layerStyle="second"
sx={{ flexDir: 'column', gap: 1, p: 4, borderRadius: 'base' }}
>
<Flex sx={{ justifyContent: 'space-between' }}>
<FormControl as={Flex} onClick={handleChangeAsControlNet}>
<FormLabel>
<Heading size="sm">ControlNet</Heading>
</FormLabel>
<Spacer />
<Switch isChecked={isControlNetEnabled} />
</FormControl>
</Flex>
<Text>
<strong>Model:</strong> {model}
</Text>
<Text>
<strong>Processor:</strong> {processorType}
</Text>
</Flex>
);
};
export default memo(BatchControlNet);

View File

@ -0,0 +1,115 @@
import { Box, Icon, Skeleton } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { FaExclamationCircle } from 'react-icons/fa';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { MouseEvent, memo, useCallback, useMemo } from 'react';
import {
batchImageRangeEndSelected,
batchImageSelected,
batchImageSelectionToggled,
imageRemovedFromBatch,
} from 'features/batch/store/batchSlice';
import IAIDndImage from 'common/components/IAIDndImage';
import { createSelector } from '@reduxjs/toolkit';
import { RootState, stateSelector } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { TypesafeDraggableData } from 'app/components/ImageDnd/typesafeDnd';
const isSelectedSelector = createSelector(
[stateSelector, (state: RootState, imageName: string) => imageName],
(state, imageName) => ({
selection: state.batch.selection,
isSelected: state.batch.selection.includes(imageName),
}),
defaultSelectorOptions
);
type BatchImageProps = {
imageName: string;
};
const BatchImage = (props: BatchImageProps) => {
const {
currentData: imageDTO,
isFetching,
isError,
isSuccess,
} = useGetImageDTOQuery(props.imageName);
const dispatch = useAppDispatch();
const { isSelected, selection } = useAppSelector((state) =>
isSelectedSelector(state, props.imageName)
);
const handleClickRemove = useCallback(() => {
dispatch(imageRemovedFromBatch(props.imageName));
}, [dispatch, props.imageName]);
const handleClick = useCallback(
(e: MouseEvent<HTMLDivElement>) => {
if (e.shiftKey) {
dispatch(batchImageRangeEndSelected(props.imageName));
} else if (e.ctrlKey || e.metaKey) {
dispatch(batchImageSelectionToggled(props.imageName));
} else {
dispatch(batchImageSelected(props.imageName));
}
},
[dispatch, props.imageName]
);
const draggableData = useMemo<TypesafeDraggableData | undefined>(() => {
if (selection.length > 1) {
return {
id: 'batch',
payloadType: 'IMAGE_NAMES',
payload: {
imageNames: selection,
},
};
}
if (imageDTO) {
return {
id: 'batch',
payloadType: 'IMAGE_DTO',
payload: { imageDTO },
};
}
}, [imageDTO, selection]);
if (isError) {
return <Icon as={FaExclamationCircle} />;
}
if (isFetching) {
return (
<Skeleton>
<Box w="full" h="full" aspectRatio="1/1" />
</Skeleton>
);
}
return (
<Box sx={{ position: 'relative', aspectRatio: '1/1' }}>
<IAIDndImage
imageDTO={imageDTO}
draggableData={draggableData}
isDropDisabled={true}
isUploadDisabled={true}
imageSx={{
w: 'full',
h: 'full',
}}
onClick={handleClick}
isSelected={isSelected}
onClickReset={handleClickRemove}
resetTooltip="Remove from batch"
withResetIcon
thumbnail
/>
</Box>
);
};
export default memo(BatchImage);

View File

@ -0,0 +1,31 @@
import { Box } from '@chakra-ui/react';
import BatchImageGrid from './BatchImageGrid';
import IAIDropOverlay from 'common/components/IAIDropOverlay';
import {
AddToBatchDropData,
isValidDrop,
useDroppable,
} from 'app/components/ImageDnd/typesafeDnd';
const droppableData: AddToBatchDropData = {
id: 'batch',
actionType: 'ADD_TO_BATCH',
};
const BatchImageContainer = () => {
const { isOver, setNodeRef, active } = useDroppable({
id: 'batch-manager',
data: droppableData,
});
return (
<Box ref={setNodeRef} position="relative" w="full" h="full">
<BatchImageGrid />
{isValidDrop(droppableData, active) && (
<IAIDropOverlay isOver={isOver} label="Add to Batch" />
)}
</Box>
);
};
export default BatchImageContainer;

View File

@ -0,0 +1,54 @@
import { FaImages } from 'react-icons/fa';
import { Grid, GridItem } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import BatchImage from './BatchImage';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
const selector = createSelector(
stateSelector,
(state) => {
const imageNames = state.batch.imageNames.concat().reverse();
return { imageNames };
},
defaultSelectorOptions
);
const BatchImageGrid = () => {
const { imageNames } = useAppSelector(selector);
if (imageNames.length === 0) {
return (
<IAINoContentFallback
icon={FaImages}
boxSize={16}
label="No images in Batch"
/>
);
}
return (
<Grid
sx={{
position: 'absolute',
flexWrap: 'wrap',
w: 'full',
minH: 0,
maxH: 'full',
overflowY: 'scroll',
gridTemplateColumns: `repeat(auto-fill, minmax(128px, 1fr))`,
}}
>
{imageNames.map((imageName) => (
<GridItem key={imageName} sx={{ p: 1.5 }}>
<BatchImage imageName={imageName} />
</GridItem>
))}
</Grid>
);
};
export default BatchImageGrid;

View File

@ -0,0 +1,103 @@
import { Flex, Heading, Spacer } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useCallback } from 'react';
import IAISwitch from 'common/components/IAISwitch';
import {
asInitialImageToggled,
batchReset,
isEnabledChanged,
} from 'features/batch/store/batchSlice';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIButton from 'common/components/IAIButton';
import BatchImageContainer from './BatchImageGrid';
import { map } from 'lodash-es';
import BatchControlNet from './BatchControlNet';
const selector = createSelector(
stateSelector,
(state) => {
const { controlNets } = state.controlNet;
const {
imageNames,
asInitialImage,
controlNets: batchControlNets,
isEnabled,
} = state.batch;
return {
imageCount: imageNames.length,
asInitialImage,
controlNets,
batchControlNets,
isEnabled,
};
},
defaultSelectorOptions
);
const BatchManager = () => {
const dispatch = useAppDispatch();
const { imageCount, isEnabled, controlNets, batchControlNets } =
useAppSelector(selector);
const handleResetBatch = useCallback(() => {
dispatch(batchReset());
}, [dispatch]);
const handleToggle = useCallback(() => {
dispatch(isEnabledChanged(!isEnabled));
}, [dispatch, isEnabled]);
const handleChangeAsInitialImage = useCallback(() => {
dispatch(asInitialImageToggled());
}, [dispatch]);
return (
<Flex
sx={{
h: 'full',
w: 'full',
flexDir: 'column',
position: 'relative',
gap: 2,
minW: 0,
}}
>
<Flex sx={{ alignItems: 'center' }}>
<Heading
size={'md'}
sx={{ color: 'base.800', _dark: { color: 'base.200' } }}
>
{imageCount || 'No'} images
</Heading>
<Spacer />
<IAIButton onClick={handleResetBatch}>Reset</IAIButton>
</Flex>
<Flex
sx={{
alignItems: 'center',
flexDir: 'column',
gap: 4,
}}
>
<IAISwitch
label="Use as Initial Image"
onChange={handleChangeAsInitialImage}
/>
{map(controlNets, (controlNet) => {
return (
<BatchControlNet
key={controlNet.controlNetId}
controlNet={controlNet}
/>
);
})}
</Flex>
<BatchImageContainer />
</Flex>
);
};
export default BatchManager;

View File

@ -0,0 +1,142 @@
import { PayloadAction, createAction, createSlice } from '@reduxjs/toolkit';
import { uniq } from 'lodash-es';
import { imageDeleted } from 'services/api/thunks/image';
type BatchState = {
isEnabled: boolean;
imageNames: string[];
asInitialImage: boolean;
controlNets: string[];
selection: string[];
};
export const initialBatchState: BatchState = {
isEnabled: false,
imageNames: [],
asInitialImage: false,
controlNets: [],
selection: [],
};
const batch = createSlice({
name: 'batch',
initialState: initialBatchState,
reducers: {
isEnabledChanged: (state, action: PayloadAction<boolean>) => {
state.isEnabled = action.payload;
},
imageAddedToBatch: (state, action: PayloadAction<string>) => {
state.imageNames = uniq(state.imageNames.concat(action.payload));
},
imagesAddedToBatch: (state, action: PayloadAction<string[]>) => {
state.imageNames = uniq(state.imageNames.concat(action.payload));
},
imageRemovedFromBatch: (state, action: PayloadAction<string>) => {
state.imageNames = state.imageNames.filter(
(imageName) => action.payload !== imageName
);
state.selection = state.selection.filter(
(imageName) => action.payload !== imageName
);
},
imagesRemovedFromBatch: (state, action: PayloadAction<string[]>) => {
state.imageNames = state.imageNames.filter(
(imageName) => !action.payload.includes(imageName)
);
state.selection = state.selection.filter(
(imageName) => !action.payload.includes(imageName)
);
},
batchImageRangeEndSelected: (state, action: PayloadAction<string>) => {
const rangeEndImageName = action.payload;
const lastSelectedImage = state.selection[state.selection.length - 1];
const lastClickedIndex = state.imageNames.findIndex(
(n) => n === lastSelectedImage
);
const currentClickedIndex = state.imageNames.findIndex(
(n) => n === rangeEndImageName
);
if (lastClickedIndex > -1 && currentClickedIndex > -1) {
// We have a valid range!
const start = Math.min(lastClickedIndex, currentClickedIndex);
const end = Math.max(lastClickedIndex, currentClickedIndex);
const imagesToSelect = state.imageNames.slice(start, end + 1);
state.selection = uniq(state.selection.concat(imagesToSelect));
}
},
batchImageSelectionToggled: (state, action: PayloadAction<string>) => {
if (
state.selection.includes(action.payload) &&
state.selection.length > 1
) {
state.selection = state.selection.filter(
(imageName) => imageName !== action.payload
);
} else {
state.selection = uniq(state.selection.concat(action.payload));
}
},
batchImageSelected: (state, action: PayloadAction<string | null>) => {
state.selection = action.payload
? [action.payload]
: [String(state.imageNames[0])];
},
batchReset: (state) => {
state.imageNames = [];
state.selection = [];
},
asInitialImageToggled: (state) => {
state.asInitialImage = !state.asInitialImage;
},
controlNetAddedToBatch: (state, action: PayloadAction<string>) => {
state.controlNets = uniq(state.controlNets.concat(action.payload));
},
controlNetRemovedFromBatch: (state, action: PayloadAction<string>) => {
state.controlNets = state.controlNets.filter(
(controlNetId) => controlNetId !== action.payload
);
},
controlNetToggled: (state, action: PayloadAction<string>) => {
if (state.controlNets.includes(action.payload)) {
state.controlNets = state.controlNets.filter(
(controlNetId) => controlNetId !== action.payload
);
} else {
state.controlNets = uniq(state.controlNets.concat(action.payload));
}
},
},
extraReducers: (builder) => {
builder.addCase(imageDeleted.fulfilled, (state, action) => {
state.imageNames = state.imageNames.filter(
(imageName) => imageName !== action.meta.arg.image_name
);
state.selection = state.selection.filter(
(imageName) => imageName !== action.meta.arg.image_name
);
});
},
});
export const {
isEnabledChanged,
imageAddedToBatch,
imagesAddedToBatch,
imageRemovedFromBatch,
imagesRemovedFromBatch,
asInitialImageToggled,
controlNetAddedToBatch,
controlNetRemovedFromBatch,
batchReset,
controlNetToggled,
batchImageRangeEndSelected,
batchImageSelectionToggled,
batchImageSelected,
} = batch.actions;
export default batch.reducer;
export const selectionAddedToBatch = createAction(
'batch/selectionAddedToBatch'
);

View File

@ -1,4 +1,4 @@
import { memo, useCallback, useState } from 'react'; import { memo, useCallback, useMemo, useState } from 'react';
import { ImageDTO } from 'services/api/types'; import { ImageDTO } from 'services/api/types';
import { import {
ControlNetConfig, ControlNetConfig,
@ -10,11 +10,16 @@ import { Box, Flex, SystemStyleObject } from '@chakra-ui/react';
import IAIDndImage from 'common/components/IAIDndImage'; import IAIDndImage from 'common/components/IAIDndImage';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { IAIImageLoadingFallback } from 'common/components/IAIImageFallback'; import { IAILoadingImageFallback } from 'common/components/IAIImageFallback';
import IAIIconButton from 'common/components/IAIIconButton'; import IAIIconButton from 'common/components/IAIIconButton';
import { FaUndo } from 'react-icons/fa'; import { FaUndo } from 'react-icons/fa';
import { useGetImageDTOQuery } from 'services/api/endpoints/images'; import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { skipToken } from '@reduxjs/toolkit/dist/query'; import { skipToken } from '@reduxjs/toolkit/dist/query';
import {
TypesafeDraggableData,
TypesafeDroppableData,
} from 'app/components/ImageDnd/typesafeDnd';
import { PostUploadAction } from 'services/api/thunks/image';
const selector = createSelector( const selector = createSelector(
controlNetSelector, controlNetSelector,
@ -57,22 +62,6 @@ const ControlNetImagePreview = (props: Props) => {
isSuccess: isSuccessProcessedControlImage, isSuccess: isSuccessProcessedControlImage,
} = useGetImageDTOQuery(processedControlImageName ?? skipToken); } = useGetImageDTOQuery(processedControlImageName ?? skipToken);
const handleDrop = useCallback(
(droppedImage: ImageDTO) => {
if (controlImageName === droppedImage.image_name) {
return;
}
setIsMouseOverImage(false);
dispatch(
controlNetImageChanged({
controlNetId,
controlImage: droppedImage.image_name,
})
);
},
[controlImageName, controlNetId, dispatch]
);
const handleResetControlImage = useCallback(() => { const handleResetControlImage = useCallback(() => {
dispatch(controlNetImageChanged({ controlNetId, controlImage: null })); dispatch(controlNetImageChanged({ controlNetId, controlImage: null }));
}, [controlNetId, dispatch]); }, [controlNetId, dispatch]);
@ -84,6 +73,31 @@ const ControlNetImagePreview = (props: Props) => {
setIsMouseOverImage(false); setIsMouseOverImage(false);
}, []); }, []);
const draggableData = useMemo<TypesafeDraggableData | undefined>(() => {
if (controlImage) {
return {
id: controlNetId,
payloadType: 'IMAGE_DTO',
payload: { imageDTO: controlImage },
};
}
}, [controlImage, controlNetId]);
const droppableData = useMemo<TypesafeDroppableData | undefined>(() => {
if (controlNetId) {
return {
id: controlNetId,
actionType: 'SET_CONTROLNET_IMAGE',
context: { controlNetId },
};
}
}, [controlNetId]);
const postUploadAction = useMemo<PostUploadAction>(
() => ({ type: 'SET_CONTROLNET_IMAGE', controlNetId }),
[controlNetId]
);
const shouldShowProcessedImage = const shouldShowProcessedImage =
controlImage && controlImage &&
processedControlImage && processedControlImage &&
@ -104,14 +118,14 @@ const ControlNetImagePreview = (props: Props) => {
}} }}
> >
<IAIDndImage <IAIDndImage
image={controlImage} draggableData={draggableData}
onDrop={handleDrop} droppableData={droppableData}
imageDTO={controlImage}
isDropDisabled={shouldShowProcessedImage} isDropDisabled={shouldShowProcessedImage}
postUploadAction={{ type: 'SET_CONTROLNET_IMAGE', controlNetId }} onClickReset={handleResetControlImage}
imageSx={{ postUploadAction={postUploadAction}
w: 'full', resetTooltip="Reset Control Image"
h: 'full', withResetIcon={Boolean(controlImage)}
}}
/> />
<Box <Box
sx={{ sx={{
@ -127,14 +141,13 @@ const ControlNetImagePreview = (props: Props) => {
}} }}
> >
<IAIDndImage <IAIDndImage
image={processedControlImage} draggableData={draggableData}
onDrop={handleDrop} droppableData={droppableData}
payloadImage={controlImage} imageDTO={processedControlImage}
isUploadDisabled={true} isUploadDisabled={true}
imageSx={{ onClickReset={handleResetControlImage}
w: 'full', resetTooltip="Reset Control Image"
h: 'full', withResetIcon={Boolean(controlImage)}
}}
/> />
</Box> </Box>
{pendingControlImages.includes(controlNetId) && ( {pendingControlImages.includes(controlNetId) && (
@ -145,27 +158,12 @@ const ControlNetImagePreview = (props: Props) => {
insetInlineStart: 0, insetInlineStart: 0,
w: 'full', w: 'full',
h: 'full', h: 'full',
objectFit: 'contain',
}} }}
> >
<IAIImageLoadingFallback /> <IAILoadingImageFallback image={controlImage} />
</Box> </Box>
)} )}
{controlImage && (
<Flex sx={{ position: 'absolute', top: 0, insetInlineEnd: 0 }}>
<IAIIconButton
aria-label="Reset Control Image"
tooltip="Reset Control Image"
size="sm"
onClick={handleResetControlImage}
icon={<FaUndo />}
variant="link"
sx={{
p: 2,
color: 'base.50',
}}
/>
</Flex>
)}
</Flex> </Flex>
); );
}; };

View File

@ -1,16 +1,16 @@
import { Flex, Text, useColorMode } from '@chakra-ui/react'; import { Flex, useColorMode } from '@chakra-ui/react';
import { FaImages } from 'react-icons/fa'; import { FaImages } from 'react-icons/fa';
import { boardIdSelected } from '../../store/boardSlice'; import { boardIdSelected } from 'features/gallery/store/gallerySlice';
import { useDispatch } from 'react-redux'; import { useDispatch } from 'react-redux';
import { IAINoImageFallback } from 'common/components/IAIImageFallback'; import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { AnimatePresence } from 'framer-motion'; import { AnimatePresence } from 'framer-motion';
import { SelectedItemOverlay } from '../SelectedItemOverlay';
import { useCallback } from 'react';
import { ImageDTO } from 'services/api/types';
import { useRemoveImageFromBoardMutation } from 'services/api/endpoints/boardImages';
import { useDroppable } from '@dnd-kit/core';
import IAIDropOverlay from 'common/components/IAIDropOverlay'; import IAIDropOverlay from 'common/components/IAIDropOverlay';
import { mode } from 'theme/util/mode'; import { mode } from 'theme/util/mode';
import {
MoveBoardDropData,
isValidDrop,
useDroppable,
} from 'app/components/ImageDnd/typesafeDnd';
const AllImagesBoard = ({ isSelected }: { isSelected: boolean }) => { const AllImagesBoard = ({ isSelected }: { isSelected: boolean }) => {
const dispatch = useDispatch(); const dispatch = useDispatch();
@ -20,31 +20,15 @@ const AllImagesBoard = ({ isSelected }: { isSelected: boolean }) => {
dispatch(boardIdSelected()); dispatch(boardIdSelected());
}; };
const [removeImageFromBoard, { isLoading }] = const droppableData: MoveBoardDropData = {
useRemoveImageFromBoardMutation(); id: 'all-images-board',
actionType: 'MOVE_BOARD',
context: { boardId: null },
};
const handleDrop = useCallback( const { isOver, setNodeRef, active } = useDroppable({
(droppedImage: ImageDTO) => {
if (!droppedImage.board_id) {
return;
}
removeImageFromBoard({
board_id: droppedImage.board_id,
image_name: droppedImage.image_name,
});
},
[removeImageFromBoard]
);
const {
isOver,
setNodeRef,
active: isDropActive,
} = useDroppable({
id: `board_droppable_all_images`, id: `board_droppable_all_images`,
data: { data: droppableData,
handleDrop,
},
}); });
return ( return (
@ -58,10 +42,10 @@ const AllImagesBoard = ({ isSelected }: { isSelected: boolean }) => {
h: 'full', h: 'full',
borderRadius: 'base', borderRadius: 'base',
}} }}
onClick={handleAllImagesBoardClick}
> >
<Flex <Flex
ref={setNodeRef} ref={setNodeRef}
onClick={handleAllImagesBoardClick}
sx={{ sx={{
position: 'relative', position: 'relative',
justifyContent: 'center', justifyContent: 'center',
@ -69,18 +53,30 @@ const AllImagesBoard = ({ isSelected }: { isSelected: boolean }) => {
borderRadius: 'base', borderRadius: 'base',
w: 'full', w: 'full',
aspectRatio: '1/1', aspectRatio: '1/1',
overflow: 'hidden',
shadow: isSelected ? 'selected.light' : undefined,
_dark: { shadow: isSelected ? 'selected.dark' : undefined },
flexShrink: 0,
}} }}
> >
<IAINoImageFallback iconProps={{ boxSize: 8 }} as={FaImages} /> <IAINoContentFallback
boxSize={8}
icon={FaImages}
sx={{
border: '2px solid var(--invokeai-colors-base-200)',
_dark: { border: '2px solid var(--invokeai-colors-base-800)' },
}}
/>
<AnimatePresence> <AnimatePresence>
{isSelected && <SelectedItemOverlay />} {isValidDrop(droppableData, active) && (
</AnimatePresence> <IAIDropOverlay isOver={isOver} />
<AnimatePresence> )}
{isDropActive && <IAIDropOverlay isOver={isOver} />}
</AnimatePresence> </AnimatePresence>
</Flex> </Flex>
<Text <Flex
sx={{ sx={{
h: 'full',
alignItems: 'center',
color: isSelected color: isSelected
? mode('base.900', 'base.50')(colorMode) ? mode('base.900', 'base.50')(colorMode)
: mode('base.700', 'base.200')(colorMode), : mode('base.700', 'base.200')(colorMode),
@ -89,7 +85,7 @@ const AllImagesBoard = ({ isSelected }: { isSelected: boolean }) => {
}} }}
> >
All Images All Images
</Text> </Flex>
</Flex> </Flex>
); );
}; };

View File

@ -2,6 +2,7 @@ import {
Collapse, Collapse,
Flex, Flex,
Grid, Grid,
GridItem,
IconButton, IconButton,
Input, Input,
InputGroup, InputGroup,
@ -10,10 +11,7 @@ import {
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { import { setBoardSearchText } from 'features/gallery/store/boardSlice';
boardsSelector,
setBoardSearchText,
} from 'features/gallery/store/boardSlice';
import { memo, useState } from 'react'; import { memo, useState } from 'react';
import HoverableBoard from './HoverableBoard'; import HoverableBoard from './HoverableBoard';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react'; import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
@ -21,11 +19,13 @@ import AddBoardButton from './AddBoardButton';
import AllImagesBoard from './AllImagesBoard'; import AllImagesBoard from './AllImagesBoard';
import { CloseIcon } from '@chakra-ui/icons'; import { CloseIcon } from '@chakra-ui/icons';
import { useListAllBoardsQuery } from 'services/api/endpoints/boards'; import { useListAllBoardsQuery } from 'services/api/endpoints/boards';
import { stateSelector } from 'app/store/store';
const selector = createSelector( const selector = createSelector(
[boardsSelector], [stateSelector],
(boardsState) => { ({ boards, gallery }) => {
const { selectedBoardId, searchText } = boardsState; const { searchText } = boards;
const { selectedBoardId } = gallery;
return { selectedBoardId, searchText }; return { selectedBoardId, searchText };
}, },
defaultSelectorOptions defaultSelectorOptions
@ -109,20 +109,24 @@ const BoardsList = (props: Props) => {
<Grid <Grid
className="list-container" className="list-container"
sx={{ sx={{
gap: 2, gridTemplateRows: '6.5rem 6.5rem',
gridTemplateRows: '5.5rem 5.5rem',
gridAutoFlow: 'column dense', gridAutoFlow: 'column dense',
gridAutoColumns: '4rem', gridAutoColumns: '5rem',
}} }}
> >
{!searchMode && <AllImagesBoard isSelected={!selectedBoardId} />} {!searchMode && (
<GridItem sx={{ p: 1.5 }}>
<AllImagesBoard isSelected={!selectedBoardId} />
</GridItem>
)}
{filteredBoards && {filteredBoards &&
filteredBoards.map((board) => ( filteredBoards.map((board) => (
<GridItem key={board.board_id} sx={{ p: 1.5 }}>
<HoverableBoard <HoverableBoard
key={board.board_id}
board={board} board={board}
isSelected={selectedBoardId === board.board_id} isSelected={selectedBoardId === board.board_id}
/> />
</GridItem>
))} ))}
</Grid> </Grid>
</OverlayScrollbarsComponent> </OverlayScrollbarsComponent>

View File

@ -15,10 +15,9 @@ import { useAppDispatch } from 'app/store/storeHooks';
import { memo, useCallback, useContext } from 'react'; import { memo, useCallback, useContext } from 'react';
import { FaFolder, FaTrash } from 'react-icons/fa'; import { FaFolder, FaTrash } from 'react-icons/fa';
import { ContextMenu } from 'chakra-ui-contextmenu'; import { ContextMenu } from 'chakra-ui-contextmenu';
import { BoardDTO, ImageDTO } from 'services/api/types'; import { BoardDTO } from 'services/api/types';
import { IAINoImageFallback } from 'common/components/IAIImageFallback'; import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { boardIdSelected } from 'features/gallery/store/boardSlice'; import { boardIdSelected } from 'features/gallery/store/gallerySlice';
import { useAddImageToBoardMutation } from 'services/api/endpoints/boardImages';
import { import {
useDeleteBoardMutation, useDeleteBoardMutation,
useUpdateBoardMutation, useUpdateBoardMutation,
@ -26,12 +25,15 @@ import {
import { useGetImageDTOQuery } from 'services/api/endpoints/images'; import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { skipToken } from '@reduxjs/toolkit/dist/query'; import { skipToken } from '@reduxjs/toolkit/dist/query';
import { useDroppable } from '@dnd-kit/core';
import { AnimatePresence } from 'framer-motion'; import { AnimatePresence } from 'framer-motion';
import IAIDropOverlay from 'common/components/IAIDropOverlay'; import IAIDropOverlay from 'common/components/IAIDropOverlay';
import { SelectedItemOverlay } from '../SelectedItemOverlay';
import { DeleteBoardImagesContext } from '../../../../app/contexts/DeleteBoardImagesContext'; import { DeleteBoardImagesContext } from '../../../../app/contexts/DeleteBoardImagesContext';
import { mode } from 'theme/util/mode'; import { mode } from 'theme/util/mode';
import {
MoveBoardDropData,
isValidDrop,
useDroppable,
} from 'app/components/ImageDnd/typesafeDnd';
interface HoverableBoardProps { interface HoverableBoardProps {
board: BoardDTO; board: BoardDTO;
@ -61,9 +63,6 @@ const HoverableBoard = memo(({ board, isSelected }: HoverableBoardProps) => {
const [deleteBoard, { isLoading: isDeleteBoardLoading }] = const [deleteBoard, { isLoading: isDeleteBoardLoading }] =
useDeleteBoardMutation(); useDeleteBoardMutation();
const [addImageToBoard, { isLoading: isAddImageToBoardLoading }] =
useAddImageToBoardMutation();
const handleUpdateBoardName = (newBoardName: string) => { const handleUpdateBoardName = (newBoardName: string) => {
updateBoard({ board_id, changes: { board_name: newBoardName } }); updateBoard({ board_id, changes: { board_name: newBoardName } });
}; };
@ -77,29 +76,19 @@ const HoverableBoard = memo(({ board, isSelected }: HoverableBoardProps) => {
onClickDeleteBoardImages(board); onClickDeleteBoardImages(board);
}, [board, onClickDeleteBoardImages]); }, [board, onClickDeleteBoardImages]);
const handleDrop = useCallback( const droppableData: MoveBoardDropData = {
(droppedImage: ImageDTO) => { id: board_id,
if (droppedImage.board_id === board_id) { actionType: 'MOVE_BOARD',
return; context: { boardId: board_id },
} };
addImageToBoard({ board_id, image_name: droppedImage.image_name });
},
[addImageToBoard, board_id]
);
const { const { isOver, setNodeRef, active } = useDroppable({
isOver,
setNodeRef,
active: isDropActive,
} = useDroppable({
id: `board_droppable_${board_id}`, id: `board_droppable_${board_id}`,
data: { data: droppableData,
handleDrop,
},
}); });
return ( return (
<Box sx={{ touchAction: 'none' }}> <Box sx={{ touchAction: 'none', height: 'full' }}>
<ContextMenu<HTMLDivElement> <ContextMenu<HTMLDivElement>
menuProps={{ size: 'sm', isLazy: true }} menuProps={{ size: 'sm', isLazy: true }}
renderMenu={() => ( renderMenu={() => (
@ -148,13 +137,25 @@ const HoverableBoard = memo(({ board, isSelected }: HoverableBoardProps) => {
w: 'full', w: 'full',
aspectRatio: '1/1', aspectRatio: '1/1',
overflow: 'hidden', overflow: 'hidden',
shadow: isSelected ? 'selected.light' : undefined,
_dark: { shadow: isSelected ? 'selected.dark' : undefined },
flexShrink: 0,
}} }}
> >
{board.cover_image_name && coverImage?.image_url && ( {board.cover_image_name && coverImage?.image_url && (
<Image src={coverImage?.image_url} draggable={false} /> <Image src={coverImage?.image_url} draggable={false} />
)} )}
{!(board.cover_image_name && coverImage?.image_url) && ( {!(board.cover_image_name && coverImage?.image_url) && (
<IAINoImageFallback iconProps={{ boxSize: 8 }} as={FaFolder} /> <IAINoContentFallback
boxSize={8}
icon={FaFolder}
sx={{
border: '2px solid var(--invokeai-colors-base-200)',
_dark: {
border: '2px solid var(--invokeai-colors-base-800)',
},
}}
/>
)} )}
<Flex <Flex
sx={{ sx={{
@ -167,14 +168,20 @@ const HoverableBoard = memo(({ board, isSelected }: HoverableBoardProps) => {
<Badge variant="solid">{board.image_count}</Badge> <Badge variant="solid">{board.image_count}</Badge>
</Flex> </Flex>
<AnimatePresence> <AnimatePresence>
{isSelected && <SelectedItemOverlay />} {isValidDrop(droppableData, active) && (
</AnimatePresence> <IAIDropOverlay isOver={isOver} />
<AnimatePresence> )}
{isDropActive && <IAIDropOverlay isOver={isOver} />}
</AnimatePresence> </AnimatePresence>
</Flex> </Flex>
<Box sx={{ width: 'full' }}> <Flex
sx={{
width: 'full',
height: 'full',
justifyContent: 'center',
alignItems: 'center',
}}
>
<Editable <Editable
defaultValue={board_name} defaultValue={board_name}
submitOnBlur={false} submitOnBlur={false}
@ -204,7 +211,7 @@ const HoverableBoard = memo(({ board, isSelected }: HoverableBoardProps) => {
}} }}
/> />
</Editable> </Editable>
</Box> </Flex>
</Flex> </Flex>
)} )}
</ContextMenu> </ContextMenu>

View File

@ -38,8 +38,7 @@ import {
FaShare, FaShare,
FaShareAlt, FaShareAlt,
} from 'react-icons/fa'; } from 'react-icons/fa';
import { gallerySelector } from '../store/gallerySelectors'; import { useCallback } from 'react';
import { useCallback, useContext } from 'react';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale'; import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters'; import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
@ -49,22 +48,15 @@ import FaceRestoreSettings from 'features/parameters/components/Parameters/FaceR
import UpscaleSettings from 'features/parameters/components/Parameters/Upscale/UpscaleSettings'; import UpscaleSettings from 'features/parameters/components/Parameters/Upscale/UpscaleSettings';
import { useAppToaster } from 'app/components/Toaster'; import { useAppToaster } from 'app/components/Toaster';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice'; import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import { DeleteImageContext } from 'app/contexts/DeleteImageContext'; import { stateSelector } from 'app/store/store';
import { DeleteImageButton } from './DeleteImageModal'; import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { selectImagesById } from '../store/imagesSlice'; import { skipToken } from '@reduxjs/toolkit/dist/query';
import { RootState } from 'app/store/store'; import { imageToDeleteSelected } from 'features/imageDeletion/store/imageDeletionSlice';
import { DeleteImageButton } from 'features/imageDeletion/components/DeleteImageButton';
const currentImageButtonsSelector = createSelector( const currentImageButtonsSelector = createSelector(
[ [stateSelector, activeTabNameSelector],
(state: RootState) => state, ({ gallery, system, postprocessing, ui, lightbox }, activeTabName) => {
systemSelector,
gallerySelector,
postprocessingSelector,
uiSelector,
lightboxSelector,
activeTabNameSelector,
],
(state, system, gallery, postprocessing, ui, lightbox, activeTabName) => {
const { const {
isProcessing, isProcessing,
isConnected, isConnected,
@ -84,9 +76,7 @@ const currentImageButtonsSelector = createSelector(
shouldShowProgressInViewer, shouldShowProgressInViewer,
} = ui; } = ui;
const imageDTO = selectImagesById(state, gallery.selectedImage ?? ''); const lastSelectedImage = gallery.selection[gallery.selection.length - 1];
const { selectedImage } = gallery;
return { return {
canDeleteImage: isConnected && !isProcessing, canDeleteImage: isConnected && !isProcessing,
@ -97,16 +87,13 @@ const currentImageButtonsSelector = createSelector(
isESRGANAvailable, isESRGANAvailable,
upscalingLevel, upscalingLevel,
facetoolStrength, facetoolStrength,
shouldDisableToolbarButtons: Boolean(progressImage) || !selectedImage, shouldDisableToolbarButtons: Boolean(progressImage) || !lastSelectedImage,
shouldShowImageDetails, shouldShowImageDetails,
activeTabName, activeTabName,
isLightboxOpen, isLightboxOpen,
shouldHidePreview, shouldHidePreview,
image: imageDTO,
seed: imageDTO?.metadata?.seed,
prompt: imageDTO?.metadata?.positive_conditioning,
negativePrompt: imageDTO?.metadata?.negative_conditioning,
shouldShowProgressInViewer, shouldShowProgressInViewer,
lastSelectedImage,
}; };
}, },
{ {
@ -132,7 +119,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
isLightboxOpen, isLightboxOpen,
activeTabName, activeTabName,
shouldHidePreview, shouldHidePreview,
image, lastSelectedImage,
shouldShowProgressInViewer, shouldShowProgressInViewer,
} = useAppSelector(currentImageButtonsSelector); } = useAppSelector(currentImageButtonsSelector);
@ -147,7 +134,9 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
const { recallBothPrompts, recallSeed, recallAllParameters } = const { recallBothPrompts, recallSeed, recallAllParameters } =
useRecallParameters(); useRecallParameters();
const { onDelete } = useContext(DeleteImageContext); const { currentData: image } = useGetImageDTOQuery(
lastSelectedImage ?? skipToken
);
// const handleCopyImage = useCallback(async () => { // const handleCopyImage = useCallback(async () => {
// if (!image?.url) { // if (!image?.url) {
@ -248,8 +237,11 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
}, []); }, []);
const handleDelete = useCallback(() => { const handleDelete = useCallback(() => {
onDelete(image); if (!image) {
}, [image, onDelete]); return;
}
dispatch(imageToDeleteSelected(image));
}, [dispatch, image]);
useHotkeys( useHotkeys(
'Shift+U', 'Shift+U',
@ -371,7 +363,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
}} }}
{...props} {...props}
> >
<ButtonGroup isAttached={true}> <ButtonGroup isAttached={true} isDisabled={shouldDisableToolbarButtons}>
<IAIPopover <IAIPopover
triggerComponent={ triggerComponent={
<IAIIconButton <IAIIconButton
@ -444,11 +436,12 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
} }
isChecked={isLightboxOpen} isChecked={isLightboxOpen}
onClick={handleLightBox} onClick={handleLightBox}
isDisabled={shouldDisableToolbarButtons}
/> />
)} )}
</ButtonGroup> </ButtonGroup>
<ButtonGroup isAttached={true}> <ButtonGroup isAttached={true} isDisabled={shouldDisableToolbarButtons}>
<IAIIconButton <IAIIconButton
icon={<FaQuoteRight />} icon={<FaQuoteRight />}
tooltip={`${t('parameters.usePrompt')} (P)`} tooltip={`${t('parameters.usePrompt')} (P)`}
@ -478,7 +471,10 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
</ButtonGroup> </ButtonGroup>
{(isUpscalingEnabled || isFaceRestoreEnabled) && ( {(isUpscalingEnabled || isFaceRestoreEnabled) && (
<ButtonGroup isAttached={true}> <ButtonGroup
isAttached={true}
isDisabled={shouldDisableToolbarButtons}
>
{isFaceRestoreEnabled && ( {isFaceRestoreEnabled && (
<IAIPopover <IAIPopover
triggerComponent={ triggerComponent={
@ -543,7 +539,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
</ButtonGroup> </ButtonGroup>
)} )}
<ButtonGroup isAttached={true}> <ButtonGroup isAttached={true} isDisabled={shouldDisableToolbarButtons}>
<IAIIconButton <IAIIconButton
icon={<FaCode />} icon={<FaCode />}
tooltip={`${t('parameters.info')} (I)`} tooltip={`${t('parameters.info')} (I)`}
@ -553,7 +549,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
/> />
</ButtonGroup> </ButtonGroup>
<ButtonGroup isAttached={true}> <ButtonGroup isAttached={true} isDisabled={shouldDisableToolbarButtons}>
<IAIIconButton <IAIIconButton
aria-label={t('settings.displayInProgress')} aria-label={t('settings.displayInProgress')}
tooltip={t('settings.displayInProgress')} tooltip={t('settings.displayInProgress')}
@ -564,7 +560,10 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
</ButtonGroup> </ButtonGroup>
<ButtonGroup isAttached={true}> <ButtonGroup isAttached={true}>
<DeleteImageButton onClick={handleDelete} /> <DeleteImageButton
onClick={handleDelete}
isDisabled={shouldDisableToolbarButtons}
/>
</ButtonGroup> </ButtonGroup>
</Flex> </Flex>
</> </>

View File

@ -1,29 +1,9 @@
import { Flex } from '@chakra-ui/react'; import { Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { systemSelector } from 'features/system/store/systemSelectors';
import { gallerySelector } from '../store/gallerySelectors';
import CurrentImageButtons from './CurrentImageButtons'; import CurrentImageButtons from './CurrentImageButtons';
import CurrentImagePreview from './CurrentImagePreview'; import CurrentImagePreview from './CurrentImagePreview';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
export const currentImageDisplaySelector = createSelector(
[systemSelector, gallerySelector],
(system, gallery) => {
const { progressImage } = system;
return {
hasSelectedImage: Boolean(gallery.selectedImage),
hasProgressImage: Boolean(progressImage),
};
},
defaultSelectorOptions
);
const CurrentImageDisplay = () => { const CurrentImageDisplay = () => {
const { hasSelectedImage } = useAppSelector(currentImageDisplaySelector);
return ( return (
<Flex <Flex
sx={{ sx={{
@ -36,7 +16,7 @@ const CurrentImageDisplay = () => {
justifyContent: 'center', justifyContent: 'center',
}} }}
> >
{hasSelectedImage && <CurrentImageButtons />} <CurrentImageButtons />
<CurrentImagePreview /> <CurrentImagePreview />
</Flex> </Flex>
); );

View File

@ -1,35 +1,33 @@
import { Box, Flex, Image } from '@chakra-ui/react'; import { Box, Flex, Image } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { uiSelector } from 'features/ui/store/uiSelectors';
import { isEqual } from 'lodash-es'; import { isEqual } from 'lodash-es';
import { gallerySelector } from '../store/gallerySelectors';
import ImageMetadataViewer from './ImageMetaDataViewer/ImageMetadataViewer'; import ImageMetadataViewer from './ImageMetaDataViewer/ImageMetadataViewer';
import NextPrevImageButtons from './NextPrevImageButtons'; import NextPrevImageButtons from './NextPrevImageButtons';
import { memo, useCallback } from 'react'; import { memo, useMemo } from 'react';
import { systemSelector } from 'features/system/store/systemSelectors';
import { imageSelected } from '../store/gallerySlice';
import IAIDndImage from 'common/components/IAIDndImage'; import IAIDndImage from 'common/components/IAIDndImage';
import { ImageDTO } from 'services/api/types';
import { IAIImageLoadingFallback } from 'common/components/IAIImageFallback';
import { useGetImageDTOQuery } from 'services/api/endpoints/images'; import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { skipToken } from '@reduxjs/toolkit/dist/query'; import { skipToken } from '@reduxjs/toolkit/dist/query';
import { stateSelector } from 'app/store/store';
import { selectLastSelectedImage } from 'features/gallery/store/gallerySlice';
import {
TypesafeDraggableData,
TypesafeDroppableData,
} from 'app/components/ImageDnd/typesafeDnd';
export const imagesSelector = createSelector( export const imagesSelector = createSelector(
[uiSelector, gallerySelector, systemSelector], [stateSelector, selectLastSelectedImage],
(ui, gallery, system) => { ({ ui, system }, lastSelectedImage) => {
const { const {
shouldShowImageDetails, shouldShowImageDetails,
shouldHidePreview, shouldHidePreview,
shouldShowProgressInViewer, shouldShowProgressInViewer,
} = ui; } = ui;
const { selectedImage } = gallery;
const { progressImage, shouldAntialiasProgressImage } = system; const { progressImage, shouldAntialiasProgressImage } = system;
return { return {
shouldShowImageDetails, shouldShowImageDetails,
shouldHidePreview, shouldHidePreview,
selectedImage, imageName: lastSelectedImage,
progressImage, progressImage,
shouldShowProgressInViewer, shouldShowProgressInViewer,
shouldAntialiasProgressImage, shouldAntialiasProgressImage,
@ -45,29 +43,35 @@ export const imagesSelector = createSelector(
const CurrentImagePreview = () => { const CurrentImagePreview = () => {
const { const {
shouldShowImageDetails, shouldShowImageDetails,
selectedImage, imageName,
progressImage, progressImage,
shouldShowProgressInViewer, shouldShowProgressInViewer,
shouldAntialiasProgressImage, shouldAntialiasProgressImage,
} = useAppSelector(imagesSelector); } = useAppSelector(imagesSelector);
const { const {
currentData: image, currentData: imageDTO,
isLoading, isLoading,
isError, isError,
isSuccess, isSuccess,
} = useGetImageDTOQuery(selectedImage ?? skipToken); } = useGetImageDTOQuery(imageName ?? skipToken);
const dispatch = useAppDispatch(); const draggableData = useMemo<TypesafeDraggableData | undefined>(() => {
if (imageDTO) {
const handleDrop = useCallback( return {
(droppedImage: ImageDTO) => { id: 'current-image',
if (droppedImage.image_name === image?.image_name) { payloadType: 'IMAGE_DTO',
return; payload: { imageDTO },
};
} }
dispatch(imageSelected(droppedImage.image_name)); }, [imageDTO]);
},
[dispatch, image?.image_name] const droppableData = useMemo<TypesafeDroppableData | undefined>(
() => ({
id: 'current-image',
actionType: 'SET_CURRENT_IMAGE',
}),
[]
); );
return ( return (
@ -98,14 +102,15 @@ const CurrentImagePreview = () => {
/> />
) : ( ) : (
<IAIDndImage <IAIDndImage
image={image} imageDTO={imageDTO}
onDrop={handleDrop} droppableData={droppableData}
fallback={<IAIImageLoadingFallback sx={{ bg: 'none' }} />} draggableData={draggableData}
isUploadDisabled={true} isUploadDisabled={true}
fitContainer fitContainer
dropLabel="Set as Current Image"
/> />
)} )}
{shouldShowImageDetails && image && ( {shouldShowImageDetails && imageDTO && (
<Box <Box
sx={{ sx={{
position: 'absolute', position: 'absolute',
@ -116,10 +121,10 @@ const CurrentImagePreview = () => {
overflow: 'scroll', overflow: 'scroll',
}} }}
> >
<ImageMetadataViewer image={image} /> <ImageMetadataViewer image={imageDTO} />
</Box> </Box>
)} )}
{!shouldShowImageDetails && image && ( {!shouldShowImageDetails && imageDTO && (
<Box <Box
sx={{ sx={{
position: 'absolute', position: 'absolute',

View File

@ -1,166 +0,0 @@
import {
AlertDialog,
AlertDialogBody,
AlertDialogContent,
AlertDialogFooter,
AlertDialogHeader,
AlertDialogOverlay,
Divider,
Flex,
ListItem,
Text,
UnorderedList,
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import {
DeleteImageContext,
ImageUsage,
} from 'app/contexts/DeleteImageContext';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIButton from 'common/components/IAIButton';
import IAIIconButton from 'common/components/IAIIconButton';
import IAISwitch from 'common/components/IAISwitch';
import { configSelector } from 'features/system/store/configSelectors';
import { systemSelector } from 'features/system/store/systemSelectors';
import { setShouldConfirmOnDelete } from 'features/system/store/systemSlice';
import { some } from 'lodash-es';
import { ChangeEvent, memo, useCallback, useContext, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import { FaTrash } from 'react-icons/fa';
const selector = createSelector(
[systemSelector, configSelector],
(system, config) => {
const { shouldConfirmOnDelete } = system;
const { canRestoreDeletedImagesFromBin } = config;
return {
shouldConfirmOnDelete,
canRestoreDeletedImagesFromBin,
};
},
defaultSelectorOptions
);
const ImageInUseMessage = (props: { imageUsage?: ImageUsage }) => {
const { imageUsage } = props;
if (!imageUsage) {
return null;
}
if (!some(imageUsage)) {
return null;
}
return (
<>
<Text>This image is currently in use in the following features:</Text>
<UnorderedList sx={{ paddingInlineStart: 6 }}>
{imageUsage.isInitialImage && <ListItem>Image to Image</ListItem>}
{imageUsage.isCanvasImage && <ListItem>Unified Canvas</ListItem>}
{imageUsage.isControlNetImage && <ListItem>ControlNet</ListItem>}
{imageUsage.isNodesImage && <ListItem>Node Editor</ListItem>}
</UnorderedList>
<Text>
If you delete this image, those features will immediately be reset.
</Text>
</>
);
};
const DeleteImageModal = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const { isOpen, onClose, onImmediatelyDelete, image, imageUsage } =
useContext(DeleteImageContext);
const { shouldConfirmOnDelete, canRestoreDeletedImagesFromBin } =
useAppSelector(selector);
const handleChangeShouldConfirmOnDelete = useCallback(
(e: ChangeEvent<HTMLInputElement>) =>
dispatch(setShouldConfirmOnDelete(!e.target.checked)),
[dispatch]
);
const cancelRef = useRef<HTMLButtonElement>(null);
return (
<AlertDialog
isOpen={isOpen}
leastDestructiveRef={cancelRef}
onClose={onClose}
isCentered
>
<AlertDialogOverlay>
<AlertDialogContent>
<AlertDialogHeader fontSize="lg" fontWeight="bold">
{t('gallery.deleteImage')}
</AlertDialogHeader>
<AlertDialogBody>
<Flex direction="column" gap={3}>
<ImageInUseMessage imageUsage={imageUsage} />
<Divider />
<Text>
{canRestoreDeletedImagesFromBin
? t('gallery.deleteImageBin')
: t('gallery.deleteImagePermanent')}
</Text>
<Text>{t('common.areYouSure')}</Text>
<IAISwitch
label={t('common.dontAskMeAgain')}
isChecked={!shouldConfirmOnDelete}
onChange={handleChangeShouldConfirmOnDelete}
/>
</Flex>
</AlertDialogBody>
<AlertDialogFooter>
<IAIButton ref={cancelRef} onClick={onClose}>
Cancel
</IAIButton>
<IAIButton colorScheme="error" onClick={onImmediatelyDelete} ml={3}>
Delete
</IAIButton>
</AlertDialogFooter>
</AlertDialogContent>
</AlertDialogOverlay>
</AlertDialog>
);
};
export default memo(DeleteImageModal);
const deleteImageButtonsSelector = createSelector(
[systemSelector],
(system) => {
const { isProcessing, isConnected } = system;
return isConnected && !isProcessing;
}
);
type DeleteImageButtonProps = {
onClick: () => void;
};
export const DeleteImageButton = (props: DeleteImageButtonProps) => {
const { onClick } = props;
const { t } = useTranslation();
const canDeleteImage = useAppSelector(deleteImageButtonsSelector);
return (
<IAIIconButton
onClick={onClick}
icon={<FaTrash />}
tooltip={`${t('gallery.deleteImage')} (Del)`}
aria-label={`${t('gallery.deleteImage')} (Del)`}
isDisabled={!canDeleteImage}
colorScheme="error"
/>
);
};

View File

@ -0,0 +1,131 @@
import { Box } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { MouseEvent, memo, useCallback, useMemo } from 'react';
import { FaTrash } from 'react-icons/fa';
import { useTranslation } from 'react-i18next';
import { createSelector } from '@reduxjs/toolkit';
import { ImageDTO } from 'services/api/types';
import { TypesafeDraggableData } from 'app/components/ImageDnd/typesafeDnd';
import { stateSelector } from 'app/store/store';
import ImageContextMenu from './ImageContextMenu';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIDndImage from 'common/components/IAIDndImage';
import {
imageRangeEndSelected,
imageSelected,
imageSelectionToggled,
} from '../store/gallerySlice';
import { imageToDeleteSelected } from 'features/imageDeletion/store/imageDeletionSlice';
export const selector = createSelector(
[stateSelector, (state, { image_name }: ImageDTO) => image_name],
({ gallery }, image_name) => {
const isSelected = gallery.selection.includes(image_name);
const selection = gallery.selection;
return {
isSelected,
selection,
};
},
defaultSelectorOptions
);
interface HoverableImageProps {
imageDTO: ImageDTO;
}
/**
* Gallery image component with delete/use all/use seed buttons on hover.
*/
const GalleryImage = (props: HoverableImageProps) => {
const { isSelected, selection } = useAppSelector((state) =>
selector(state, props.imageDTO)
);
const { imageDTO } = props;
const { image_url, thumbnail_url, image_name } = imageDTO;
const dispatch = useAppDispatch();
const { t } = useTranslation();
const handleClick = useCallback(
(e: MouseEvent<HTMLDivElement>) => {
if (e.shiftKey) {
dispatch(imageRangeEndSelected(props.imageDTO.image_name));
} else if (e.ctrlKey || e.metaKey) {
dispatch(imageSelectionToggled(props.imageDTO.image_name));
} else {
dispatch(imageSelected(props.imageDTO.image_name));
}
},
[dispatch, props.imageDTO.image_name]
);
const handleDelete = useCallback(
(e: MouseEvent<HTMLButtonElement>) => {
e.stopPropagation();
if (!imageDTO) {
return;
}
dispatch(imageToDeleteSelected(imageDTO));
},
[dispatch, imageDTO]
);
const draggableData = useMemo<TypesafeDraggableData | undefined>(() => {
if (selection.length > 1) {
return {
id: 'gallery-image',
payloadType: 'IMAGE_NAMES',
payload: { imageNames: selection },
};
}
if (imageDTO) {
return {
id: 'gallery-image',
payloadType: 'IMAGE_DTO',
payload: { imageDTO },
};
}
}, [imageDTO, selection]);
return (
<Box sx={{ w: 'full', h: 'full', touchAction: 'none' }}>
<ImageContextMenu image={imageDTO}>
{(ref) => (
<Box
position="relative"
key={image_name}
userSelect="none"
ref={ref}
sx={{
display: 'flex',
justifyContent: 'center',
alignItems: 'center',
aspectRatio: '1/1',
}}
>
<IAIDndImage
onClick={handleClick}
imageDTO={imageDTO}
draggableData={draggableData}
isSelected={isSelected}
minSize={0}
onClickReset={handleDelete}
resetIcon={<FaTrash />}
resetTooltip="Delete image"
imageSx={{ w: 'full', h: 'full' }}
withResetIcon
isDropDisabled={true}
isUploadDisabled={true}
/>
</Box>
)}
</ImageContextMenu>
</Box>
);
};
export default memo(GalleryImage);

View File

@ -1,371 +0,0 @@
import { Box, Flex, Icon, Image, MenuItem, MenuList } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { memo, useCallback, useContext, useState } from 'react';
import {
FaCheck,
FaExpand,
FaFolder,
FaImage,
FaShare,
FaTrash,
} from 'react-icons/fa';
import { ContextMenu } from 'chakra-ui-contextmenu';
import {
resizeAndScaleCanvas,
setInitialCanvasImage,
} from 'features/canvas/store/canvasSlice';
import { gallerySelector } from 'features/gallery/store/gallerySelectors';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { useTranslation } from 'react-i18next';
import IAIIconButton from 'common/components/IAIIconButton';
import { ExternalLinkIcon } from '@chakra-ui/icons';
import { IoArrowUndoCircleOutline } from 'react-icons/io5';
import { createSelector } from '@reduxjs/toolkit';
import { systemSelector } from 'features/system/store/systemSelectors';
import { lightboxSelector } from 'features/lightbox/store/lightboxSelectors';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { isEqual } from 'lodash-es';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { initialImageSelected } from 'features/parameters/store/actions';
import { sentImageToCanvas, sentImageToImg2Img } from '../store/actions';
import { useAppToaster } from 'app/components/Toaster';
import { ImageDTO } from 'services/api/types';
import { useDraggable } from '@dnd-kit/core';
import { DeleteImageContext } from 'app/contexts/DeleteImageContext';
import { AddImageToBoardContext } from '../../../app/contexts/AddImageToBoardContext';
import { useRemoveImageFromBoardMutation } from 'services/api/endpoints/boardImages';
export const selector = createSelector(
[gallerySelector, systemSelector, lightboxSelector, activeTabNameSelector],
(gallery, system, lightbox, activeTabName) => {
const {
galleryImageObjectFit,
galleryImageMinimumWidth,
shouldUseSingleGalleryColumn,
} = gallery;
const { isLightboxOpen } = lightbox;
const { isConnected, isProcessing, shouldConfirmOnDelete } = system;
return {
canDeleteImage: isConnected && !isProcessing,
shouldConfirmOnDelete,
galleryImageObjectFit,
galleryImageMinimumWidth,
shouldUseSingleGalleryColumn,
activeTabName,
isLightboxOpen,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
interface HoverableImageProps {
image: ImageDTO;
isSelected: boolean;
}
/**
* Gallery image component with delete/use all/use seed buttons on hover.
*/
const HoverableImage = (props: HoverableImageProps) => {
const dispatch = useAppDispatch();
const {
activeTabName,
galleryImageObjectFit,
galleryImageMinimumWidth,
canDeleteImage,
shouldUseSingleGalleryColumn,
} = useAppSelector(selector);
const { image, isSelected } = props;
const { image_url, thumbnail_url, image_name } = image;
const [isHovered, setIsHovered] = useState<boolean>(false);
const toaster = useAppToaster();
const { t } = useTranslation();
const isLightboxEnabled = useFeatureStatus('lightbox').isFeatureEnabled;
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
const { onDelete } = useContext(DeleteImageContext);
const { onClickAddToBoard } = useContext(AddImageToBoardContext);
const handleDelete = useCallback(() => {
onDelete(image);
}, [image, onDelete]);
const { recallBothPrompts, recallSeed, recallAllParameters } =
useRecallParameters();
const { attributes, listeners, setNodeRef } = useDraggable({
id: `galleryImage_${image_name}`,
data: {
image,
},
});
const [removeFromBoard] = useRemoveImageFromBoardMutation();
const handleMouseOver = () => setIsHovered(true);
const handleMouseOut = () => setIsHovered(false);
const handleSelectImage = useCallback(() => {
dispatch(imageSelected(image.image_name));
}, [image, dispatch]);
// Recall parameters handlers
const handleRecallPrompt = useCallback(() => {
recallBothPrompts(
image.metadata?.positive_conditioning,
image.metadata?.negative_conditioning
);
}, [
image.metadata?.negative_conditioning,
image.metadata?.positive_conditioning,
recallBothPrompts,
]);
const handleRecallSeed = useCallback(() => {
recallSeed(image.metadata?.seed);
}, [image, recallSeed]);
const handleSendToImageToImage = useCallback(() => {
dispatch(sentImageToImg2Img());
dispatch(initialImageSelected(image));
}, [dispatch, image]);
// const handleRecallInitialImage = useCallback(() => {
// recallInitialImage(image.metadata.invokeai?.node?.image);
// }, [image, recallInitialImage]);
/**
* TODO: the rest of these
*/
const handleSendToCanvas = () => {
dispatch(sentImageToCanvas());
dispatch(setInitialCanvasImage(image));
dispatch(resizeAndScaleCanvas());
if (activeTabName !== 'unifiedCanvas') {
dispatch(setActiveTab('unifiedCanvas'));
}
toaster({
title: t('toast.sentToUnifiedCanvas'),
status: 'success',
duration: 2500,
isClosable: true,
});
};
const handleUseAllParameters = useCallback(() => {
recallAllParameters(image);
}, [image, recallAllParameters]);
const handleLightBox = () => {
// dispatch(setCurrentImage(image));
// dispatch(setIsLightboxOpen(true));
};
const handleAddToBoard = useCallback(() => {
onClickAddToBoard(image);
}, [image, onClickAddToBoard]);
const handleRemoveFromBoard = useCallback(() => {
if (!image.board_id) {
return;
}
removeFromBoard({ board_id: image.board_id, image_name: image.image_name });
}, [image.board_id, image.image_name, removeFromBoard]);
const handleOpenInNewTab = () => {
window.open(image.image_url, '_blank');
};
return (
<Box
ref={setNodeRef}
{...listeners}
{...attributes}
sx={{ w: 'full', h: 'full', touchAction: 'none' }}
>
<ContextMenu<HTMLDivElement>
menuProps={{ size: 'sm', isLazy: true }}
renderMenu={() => (
<MenuList sx={{ visibility: 'visible !important' }}>
<MenuItem
icon={<ExternalLinkIcon />}
onClickCapture={handleOpenInNewTab}
>
{t('common.openInNewTab')}
</MenuItem>
{isLightboxEnabled && (
<MenuItem icon={<FaExpand />} onClickCapture={handleLightBox}>
{t('parameters.openInViewer')}
</MenuItem>
)}
<MenuItem
icon={<IoArrowUndoCircleOutline />}
onClickCapture={handleRecallPrompt}
isDisabled={image?.metadata?.positive_conditioning === undefined}
>
{t('parameters.usePrompt')}
</MenuItem>
<MenuItem
icon={<IoArrowUndoCircleOutline />}
onClickCapture={handleRecallSeed}
isDisabled={image?.metadata?.seed === undefined}
>
{t('parameters.useSeed')}
</MenuItem>
{/* <MenuItem
icon={<IoArrowUndoCircleOutline />}
onClickCapture={handleRecallInitialImage}
isDisabled={image?.metadata?.type !== 'img2img'}
>
{t('parameters.useInitImg')}
</MenuItem> */}
<MenuItem
icon={<IoArrowUndoCircleOutline />}
onClickCapture={handleUseAllParameters}
isDisabled={
// what should these be
!['t2l', 'l2l', 'inpaint'].includes(
String(image?.metadata?.type)
)
}
>
{t('parameters.useAll')}
</MenuItem>
<MenuItem
icon={<FaShare />}
onClickCapture={handleSendToImageToImage}
id="send-to-img2img"
>
{t('parameters.sendToImg2Img')}
</MenuItem>
{isCanvasEnabled && (
<MenuItem
icon={<FaShare />}
onClickCapture={handleSendToCanvas}
id="send-to-canvas"
>
{t('parameters.sendToUnifiedCanvas')}
</MenuItem>
)}
<MenuItem icon={<FaFolder />} onClickCapture={handleAddToBoard}>
{image.board_id ? 'Change Board' : 'Add to Board'}
</MenuItem>
{image.board_id && (
<MenuItem
icon={<FaFolder />}
onClickCapture={handleRemoveFromBoard}
>
Remove from Board
</MenuItem>
)}
<MenuItem
sx={{ color: 'error.300' }}
icon={<FaTrash />}
onClickCapture={handleDelete}
>
{t('gallery.deleteImage')}
</MenuItem>
</MenuList>
)}
>
{(ref) => (
<Box
position="relative"
key={image_name}
onMouseOver={handleMouseOver}
onMouseOut={handleMouseOut}
userSelect="none"
onClick={handleSelectImage}
ref={ref}
sx={{
display: 'flex',
justifyContent: 'center',
alignItems: 'center',
w: 'full',
h: 'full',
transition: 'transform 0.2s ease-out',
aspectRatio: '1/1',
cursor: 'pointer',
}}
>
<Image
loading="lazy"
objectFit={
shouldUseSingleGalleryColumn ? 'contain' : galleryImageObjectFit
}
draggable={false}
rounded="md"
src={thumbnail_url || image_url}
fallback={<FaImage />}
sx={{
width: '100%',
height: '100%',
maxWidth: '100%',
maxHeight: '100%',
}}
/>
{isSelected && (
<Flex
sx={{
position: 'absolute',
top: '0',
insetInlineStart: '0',
width: '100%',
height: '100%',
alignItems: 'center',
justifyContent: 'center',
pointerEvents: 'none',
}}
>
<Icon
filter={'drop-shadow(0px 0px 1rem black)'}
as={FaCheck}
sx={{
width: '50%',
height: '50%',
maxWidth: '4rem',
maxHeight: '4rem',
fill: 'ok.500',
}}
/>
</Flex>
)}
{isHovered && galleryImageMinimumWidth >= 100 && (
<Box
sx={{
position: 'absolute',
top: 1,
insetInlineEnd: 1,
}}
>
<IAIIconButton
onClickCapture={handleDelete}
aria-label={t('gallery.deleteImage')}
icon={<FaTrash />}
size="xs"
fontSize={14}
isDisabled={!canDeleteImage}
/>
</Box>
)}
</Box>
)}
</ContextMenu>
</Box>
);
};
export default memo(HoverableImage);

View File

@ -0,0 +1,278 @@
import { MenuItem, MenuList } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { memo, useCallback, useContext } from 'react';
import {
FaExpand,
FaFolder,
FaFolderPlus,
FaShare,
FaTrash,
} from 'react-icons/fa';
import { ContextMenu, ContextMenuProps } from 'chakra-ui-contextmenu';
import {
resizeAndScaleCanvas,
setInitialCanvasImage,
} from 'features/canvas/store/canvasSlice';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { useTranslation } from 'react-i18next';
import { ExternalLinkIcon } from '@chakra-ui/icons';
import { IoArrowUndoCircleOutline } from 'react-icons/io5';
import { createSelector } from '@reduxjs/toolkit';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { initialImageSelected } from 'features/parameters/store/actions';
import { sentImageToCanvas, sentImageToImg2Img } from '../store/actions';
import { useAppToaster } from 'app/components/Toaster';
import { AddImageToBoardContext } from '../../../app/contexts/AddImageToBoardContext';
import { useRemoveImageFromBoardMutation } from 'services/api/endpoints/boardImages';
import { ImageDTO } from 'services/api/types';
import { RootState, stateSelector } from 'app/store/store';
import {
imagesAddedToBatch,
selectionAddedToBatch,
} from 'features/batch/store/batchSlice';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { imageToDeleteSelected } from 'features/imageDeletion/store/imageDeletionSlice';
const selector = createSelector(
[stateSelector, (state: RootState, imageDTO: ImageDTO) => imageDTO],
({ gallery, batch }, imageDTO) => {
const selectionCount = gallery.selection.length;
const isInBatch = batch.imageNames.includes(imageDTO.image_name);
return { selectionCount, isInBatch };
},
defaultSelectorOptions
);
type Props = {
image: ImageDTO;
children: ContextMenuProps<HTMLDivElement>['children'];
};
const ImageContextMenu = ({ image, children }: Props) => {
const { selectionCount, isInBatch } = useAppSelector((state) =>
selector(state, image)
);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const toaster = useAppToaster();
const isLightboxEnabled = useFeatureStatus('lightbox').isFeatureEnabled;
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
const { onClickAddToBoard } = useContext(AddImageToBoardContext);
const handleDelete = useCallback(() => {
if (!image) {
return;
}
dispatch(imageToDeleteSelected(image));
}, [dispatch, image]);
const { recallBothPrompts, recallSeed, recallAllParameters } =
useRecallParameters();
const [removeFromBoard] = useRemoveImageFromBoardMutation();
// Recall parameters handlers
const handleRecallPrompt = useCallback(() => {
recallBothPrompts(
image.metadata?.positive_conditioning,
image.metadata?.negative_conditioning
);
}, [
image.metadata?.negative_conditioning,
image.metadata?.positive_conditioning,
recallBothPrompts,
]);
const handleRecallSeed = useCallback(() => {
recallSeed(image.metadata?.seed);
}, [image, recallSeed]);
const handleSendToImageToImage = useCallback(() => {
dispatch(sentImageToImg2Img());
dispatch(initialImageSelected(image));
}, [dispatch, image]);
// const handleRecallInitialImage = useCallback(() => {
// recallInitialImage(image.metadata.invokeai?.node?.image);
// }, [image, recallInitialImage]);
const handleSendToCanvas = () => {
dispatch(sentImageToCanvas());
dispatch(setInitialCanvasImage(image));
dispatch(resizeAndScaleCanvas());
dispatch(setActiveTab('unifiedCanvas'));
toaster({
title: t('toast.sentToUnifiedCanvas'),
status: 'success',
duration: 2500,
isClosable: true,
});
};
const handleUseAllParameters = useCallback(() => {
recallAllParameters(image);
}, [image, recallAllParameters]);
const handleLightBox = () => {
// dispatch(setCurrentImage(image));
// dispatch(setIsLightboxOpen(true));
};
const handleAddToBoard = useCallback(() => {
onClickAddToBoard(image);
}, [image, onClickAddToBoard]);
const handleRemoveFromBoard = useCallback(() => {
if (!image.board_id) {
return;
}
removeFromBoard({ board_id: image.board_id, image_name: image.image_name });
}, [image.board_id, image.image_name, removeFromBoard]);
const handleOpenInNewTab = () => {
window.open(image.image_url, '_blank');
};
const handleAddSelectionToBatch = useCallback(() => {
dispatch(selectionAddedToBatch());
}, [dispatch]);
const handleAddToBatch = useCallback(() => {
dispatch(imagesAddedToBatch([image.image_name]));
}, [dispatch, image.image_name]);
return (
<ContextMenu<HTMLDivElement>
menuProps={{ size: 'sm', isLazy: true }}
renderMenu={() => (
<MenuList sx={{ visibility: 'visible !important' }}>
{selectionCount === 1 ? (
<>
<MenuItem
icon={<ExternalLinkIcon />}
onClickCapture={handleOpenInNewTab}
>
{t('common.openInNewTab')}
</MenuItem>
{isLightboxEnabled && (
<MenuItem icon={<FaExpand />} onClickCapture={handleLightBox}>
{t('parameters.openInViewer')}
</MenuItem>
)}
<MenuItem
icon={<IoArrowUndoCircleOutline />}
onClickCapture={handleRecallPrompt}
isDisabled={
image?.metadata?.positive_conditioning === undefined
}
>
{t('parameters.usePrompt')}
</MenuItem>
<MenuItem
icon={<IoArrowUndoCircleOutline />}
onClickCapture={handleRecallSeed}
isDisabled={image?.metadata?.seed === undefined}
>
{t('parameters.useSeed')}
</MenuItem>
{/* <MenuItem
icon={<IoArrowUndoCircleOutline />}
onClickCapture={handleRecallInitialImage}
isDisabled={image?.metadata?.type !== 'img2img'}
>
{t('parameters.useInitImg')}
</MenuItem> */}
<MenuItem
icon={<IoArrowUndoCircleOutline />}
onClickCapture={handleUseAllParameters}
isDisabled={
// what should these be
!['t2l', 'l2l', 'inpaint'].includes(
String(image?.metadata?.type)
)
}
>
{t('parameters.useAll')}
</MenuItem>
<MenuItem
icon={<FaShare />}
onClickCapture={handleSendToImageToImage}
id="send-to-img2img"
>
{t('parameters.sendToImg2Img')}
</MenuItem>
{isCanvasEnabled && (
<MenuItem
icon={<FaShare />}
onClickCapture={handleSendToCanvas}
id="send-to-canvas"
>
{t('parameters.sendToUnifiedCanvas')}
</MenuItem>
)}
{/* <MenuItem
icon={<FaFolder />}
isDisabled={isInBatch}
onClickCapture={handleAddToBatch}
>
Add to Batch
</MenuItem> */}
<MenuItem icon={<FaFolder />} onClickCapture={handleAddToBoard}>
{image.board_id ? 'Change Board' : 'Add to Board'}
</MenuItem>
{image.board_id && (
<MenuItem
icon={<FaFolder />}
onClickCapture={handleRemoveFromBoard}
>
Remove from Board
</MenuItem>
)}
<MenuItem
sx={{ color: 'error.600', _dark: { color: 'error.300' } }}
icon={<FaTrash />}
onClickCapture={handleDelete}
>
{t('gallery.deleteImage')}
</MenuItem>
</>
) : (
<>
<MenuItem
isDisabled={true}
icon={<FaFolder />}
onClickCapture={handleAddToBoard}
>
Move Selection to Board
</MenuItem>
{/* <MenuItem
icon={<FaFolderPlus />}
onClickCapture={handleAddSelectionToBatch}
>
Add Selection to Batch
</MenuItem> */}
<MenuItem
sx={{ color: 'error.600', _dark: { color: 'error.300' } }}
icon={<FaTrash />}
onClickCapture={handleDelete}
>
Delete Selection
</MenuItem>
</>
)}
</MenuList>
)}
>
{children}
</ContextMenu>
);
};
export default memo(ImageContextMenu);

View File

@ -5,7 +5,7 @@ import {
Flex, Flex,
FlexProps, FlexProps,
Grid, Grid,
Icon, Skeleton,
Text, Text,
VStack, VStack,
forwardRef, forwardRef,
@ -18,12 +18,8 @@ import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
import IAIIconButton from 'common/components/IAIIconButton'; import IAIIconButton from 'common/components/IAIIconButton';
import IAIPopover from 'common/components/IAIPopover'; import IAIPopover from 'common/components/IAIPopover';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
import { gallerySelector } from 'features/gallery/store/gallerySelectors';
import { import {
setGalleryImageMinimumWidth, setGalleryImageMinimumWidth,
setGalleryImageObjectFit,
setShouldAutoSwitchToNewImages,
setShouldUseSingleGalleryColumn,
setGalleryView, setGalleryView,
} from 'features/gallery/store/gallerySlice'; } from 'features/gallery/store/gallerySlice';
import { togglePinGalleryPanel } from 'features/ui/store/uiSlice'; import { togglePinGalleryPanel } from 'features/ui/store/uiSlice';
@ -42,77 +38,56 @@ import {
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { BsPinAngle, BsPinAngleFill } from 'react-icons/bs'; import { BsPinAngle, BsPinAngleFill } from 'react-icons/bs';
import { FaImage, FaServer, FaWrench } from 'react-icons/fa'; import { FaImage, FaServer, FaWrench } from 'react-icons/fa';
import { MdPhotoLibrary } from 'react-icons/md'; import GalleryImage from './GalleryImage';
import HoverableImage from './HoverableImage';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale'; import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store'; import { RootState, stateSelector } from 'app/store/store';
import { Virtuoso, VirtuosoGrid } from 'react-virtuoso'; import { VirtuosoGrid } from 'react-virtuoso';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { uiSelector } from 'features/ui/store/uiSelectors';
import { import {
ASSETS_CATEGORIES, ASSETS_CATEGORIES,
IMAGE_CATEGORIES, IMAGE_CATEGORIES,
imageCategoriesChanged, imageCategoriesChanged,
selectImagesAll, shouldAutoSwitchChanged,
} from '../store/imagesSlice'; selectFilteredImages,
} from 'features/gallery/store/gallerySlice';
import { receivedPageOfImages } from 'services/api/thunks/image'; import { receivedPageOfImages } from 'services/api/thunks/image';
import BoardsList from './Boards/BoardsList'; import BoardsList from './Boards/BoardsList';
import { boardsSelector } from '../store/boardSlice';
import { ChevronUpIcon } from '@chakra-ui/icons'; import { ChevronUpIcon } from '@chakra-ui/icons';
import { useListAllBoardsQuery } from 'services/api/endpoints/boards'; import { useListAllBoardsQuery } from 'services/api/endpoints/boards';
import { mode } from 'theme/util/mode'; import { mode } from 'theme/util/mode';
import { ImageDTO } from 'services/api/types';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
const itemSelector = createSelector( const LOADING_IMAGE_ARRAY = Array(20).fill('loading');
[(state: RootState) => state],
(state) => {
const { categories, total: allImagesTotal, isLoading } = state.images;
const { selectedBoardId } = state.boards;
const allImages = selectImagesAll(state); const selector = createSelector(
[stateSelector, selectFilteredImages],
(state, filteredImages) => {
const {
categories,
total: allImagesTotal,
isLoading,
selectedBoardId,
galleryImageMinimumWidth,
galleryView,
shouldAutoSwitch,
} = state.gallery;
const { shouldPinGallery } = state.ui;
const images = allImages.filter((i) => { const images = filteredImages as (ImageDTO | string)[];
const isInCategory = categories.includes(i.image_category);
const isInSelectedBoard = selectedBoardId
? i.board_id === selectedBoardId
: true;
return isInCategory && isInSelectedBoard;
});
return { return {
images, images: isLoading ? images.concat(LOADING_IMAGE_ARRAY) : images,
allImagesTotal, allImagesTotal,
isLoading, isLoading,
categories, categories,
selectedBoardId, selectedBoardId,
};
},
defaultSelectorOptions
);
const mainSelector = createSelector(
[gallerySelector, uiSelector, boardsSelector],
(gallery, ui, boards) => {
const {
galleryImageMinimumWidth,
galleryImageObjectFit,
shouldAutoSwitchToNewImages,
shouldUseSingleGalleryColumn,
selectedImage,
galleryView,
} = gallery;
const { shouldPinGallery } = ui;
return {
shouldPinGallery, shouldPinGallery,
galleryImageMinimumWidth, galleryImageMinimumWidth,
galleryImageObjectFit, shouldAutoSwitch,
shouldAutoSwitchToNewImages,
shouldUseSingleGalleryColumn,
selectedImage,
galleryView, galleryView,
selectedBoardId: boards.selectedBoardId,
}; };
}, },
defaultSelectorOptions defaultSelectorOptions
@ -140,17 +115,16 @@ const ImageGalleryContent = () => {
const { colorMode } = useColorMode(); const { colorMode } = useColorMode();
const { const {
images,
isLoading,
allImagesTotal,
categories,
selectedBoardId,
shouldPinGallery, shouldPinGallery,
galleryImageMinimumWidth, galleryImageMinimumWidth,
galleryImageObjectFit, shouldAutoSwitch,
shouldAutoSwitchToNewImages,
shouldUseSingleGalleryColumn,
selectedImage,
galleryView, galleryView,
} = useAppSelector(mainSelector); } = useAppSelector(selector);
const { images, isLoading, allImagesTotal, categories, selectedBoardId } =
useAppSelector(itemSelector);
const { selectedBoard } = useListAllBoardsQuery(undefined, { const { selectedBoard } = useListAllBoardsQuery(undefined, {
selectFromResult: ({ data }) => ({ selectFromResult: ({ data }) => ({
@ -208,12 +182,6 @@ const ImageGalleryContent = () => {
return () => osInstance()?.destroy(); return () => osInstance()?.destroy();
}, [scroller, initialize, osInstance]); }, [scroller, initialize, osInstance]);
const setScrollerRef = useCallback((ref: HTMLElement | Window | null) => {
if (ref instanceof HTMLElement) {
setScroller(ref);
}
}, []);
const handleClickImagesCategory = useCallback(() => { const handleClickImagesCategory = useCallback(() => {
dispatch(imageCategoriesChanged(IMAGE_CATEGORIES)); dispatch(imageCategoriesChanged(IMAGE_CATEGORIES));
dispatch(setGalleryView('images')); dispatch(setGalleryView('images'));
@ -314,29 +282,11 @@ const ImageGalleryContent = () => {
withReset withReset
handleReset={() => dispatch(setGalleryImageMinimumWidth(64))} handleReset={() => dispatch(setGalleryImageMinimumWidth(64))}
/> />
<IAISimpleCheckbox
label={t('gallery.maintainAspectRatio')}
isChecked={galleryImageObjectFit === 'contain'}
onChange={() =>
dispatch(
setGalleryImageObjectFit(
galleryImageObjectFit === 'contain' ? 'cover' : 'contain'
)
)
}
/>
<IAISimpleCheckbox <IAISimpleCheckbox
label={t('gallery.autoSwitchNewImages')} label={t('gallery.autoSwitchNewImages')}
isChecked={shouldAutoSwitchToNewImages} isChecked={shouldAutoSwitch}
onChange={(e: ChangeEvent<HTMLInputElement>) => onChange={(e: ChangeEvent<HTMLInputElement>) =>
dispatch(setShouldAutoSwitchToNewImages(e.target.checked)) dispatch(shouldAutoSwitchChanged(e.target.checked))
}
/>
<IAISimpleCheckbox
label={t('gallery.singleColumnLayout')}
isChecked={shouldUseSingleGalleryColumn}
onChange={(e: ChangeEvent<HTMLInputElement>) =>
dispatch(setShouldUseSingleGalleryColumn(e.target.checked))
} }
/> />
</Flex> </Flex>
@ -358,23 +308,6 @@ const ImageGalleryContent = () => {
{images.length || areMoreAvailable ? ( {images.length || areMoreAvailable ? (
<> <>
<Box ref={rootRef} data-overlayscrollbars="" h="100%"> <Box ref={rootRef} data-overlayscrollbars="" h="100%">
{shouldUseSingleGalleryColumn ? (
<Virtuoso
style={{ height: '100%' }}
data={images}
endReached={handleEndReached}
scrollerRef={(ref) => setScrollerRef(ref)}
itemContent={(index, item) => (
<Flex sx={{ pb: 2 }}>
<HoverableImage
key={`${item.image_name}-${item.thumbnail_url}`}
image={item}
isSelected={selectedImage === item?.image_name}
/>
</Flex>
)}
/>
) : (
<VirtuosoGrid <VirtuosoGrid
style={{ height: '100%' }} style={{ height: '100%' }}
data={images} data={images}
@ -384,15 +317,19 @@ const ImageGalleryContent = () => {
List: ListContainer, List: ListContainer,
}} }}
scrollerRef={setScroller} scrollerRef={setScroller}
itemContent={(index, item) => ( itemContent={(index, item) =>
<HoverableImage typeof item === 'string' ? (
<Skeleton
sx={{ w: 'full', h: 'full', aspectRatio: '1/1' }}
/>
) : (
<GalleryImage
key={`${item.image_name}-${item.thumbnail_url}`} key={`${item.image_name}-${item.thumbnail_url}`}
image={item} imageDTO={item}
isSelected={selectedImage === item?.image_name}
/> />
)} )
}
/> />
)}
</Box> </Box>
<IAIButton <IAIButton
onClick={handleLoadMoreImages} onClick={handleLoadMoreImages}
@ -407,27 +344,10 @@ const ImageGalleryContent = () => {
</IAIButton> </IAIButton>
</> </>
) : ( ) : (
<Flex <IAINoContentFallback
sx={{ label={t('gallery.noImagesInGallery')}
flexDirection: 'column', icon={FaImage}
alignItems: 'center',
justifyContent: 'center',
gap: 2,
padding: 8,
h: '100%',
w: '100%',
color: 'base.500',
}}
>
<Icon
as={MdPhotoLibrary}
sx={{
w: 16,
h: 16,
}}
/> />
<Text textAlign="center">{t('gallery.noImagesInGallery')}</Text>
</Flex>
)} )}
</Flex> </Flex>
</VStack> </VStack>
@ -436,7 +356,7 @@ const ImageGalleryContent = () => {
type ItemContainerProps = PropsWithChildren & FlexProps; type ItemContainerProps = PropsWithChildren & FlexProps;
const ItemContainer = forwardRef((props: ItemContainerProps, ref) => ( const ItemContainer = forwardRef((props: ItemContainerProps, ref) => (
<Box className="item-container" ref={ref}> <Box className="item-container" ref={ref} p={1.5}>
{props.children} {props.children}
</Box> </Box>
)); ));
@ -453,8 +373,7 @@ const ListContainer = forwardRef((props: ListContainerProps, ref) => {
className="list-container" className="list-container"
ref={ref} ref={ref}
sx={{ sx={{
gap: 2, gridTemplateColumns: `repeat(auto-fill, minmax(${galleryImageMinimumWidth}px, 1fr));`,
gridTemplateColumns: `repeat(auto-fit, minmax(${galleryImageMinimumWidth}px, 1fr));`,
}} }}
> >
{props.children} {props.children}

View File

@ -5,14 +5,13 @@ import { clamp, isEqual } from 'lodash-es';
import { useCallback, useState } from 'react'; import { useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { FaAngleLeft, FaAngleRight } from 'react-icons/fa'; import { FaAngleLeft, FaAngleRight } from 'react-icons/fa';
import { gallerySelector } from '../store/gallerySelectors'; import { stateSelector } from 'app/store/store';
import { RootState } from 'app/store/store';
import { imageSelected } from '../store/gallerySlice';
import { useHotkeys } from 'react-hotkeys-hook';
import { import {
selectFilteredImagesAsObject, imageSelected,
selectFilteredImagesIds, selectImagesById,
} from '../store/imagesSlice'; } from 'features/gallery/store/gallerySlice';
import { useHotkeys } from 'react-hotkeys-hook';
import { selectFilteredImages } from 'features/gallery/store/gallerySlice';
const nextPrevButtonTriggerAreaStyles: ChakraProps['sx'] = { const nextPrevButtonTriggerAreaStyles: ChakraProps['sx'] = {
height: '100%', height: '100%',
@ -25,45 +24,40 @@ const nextPrevButtonStyles: ChakraProps['sx'] = {
}; };
export const nextPrevImageButtonsSelector = createSelector( export const nextPrevImageButtonsSelector = createSelector(
[ [stateSelector, selectFilteredImages],
(state: RootState) => state, (state, filteredImages) => {
gallerySelector, const lastSelectedImage =
selectFilteredImagesAsObject, state.gallery.selection[state.gallery.selection.length - 1];
selectFilteredImagesIds,
],
(state, gallery, filteredImagesAsObject, filteredImageIds) => {
const { selectedImage } = gallery;
if (!selectedImage) { if (!lastSelectedImage || filteredImages.length === 0) {
return { return {
isOnFirstImage: true, isOnFirstImage: true,
isOnLastImage: true, isOnLastImage: true,
}; };
} }
const currentImageIndex = filteredImageIds.findIndex( const currentImageIndex = filteredImages.findIndex(
(i) => i === selectedImage (i) => i.image_name === lastSelectedImage
); );
const nextImageIndex = clamp( const nextImageIndex = clamp(
currentImageIndex + 1, currentImageIndex + 1,
0, 0,
filteredImageIds.length - 1 filteredImages.length - 1
); );
const prevImageIndex = clamp( const prevImageIndex = clamp(
currentImageIndex - 1, currentImageIndex - 1,
0, 0,
filteredImageIds.length - 1 filteredImages.length - 1
); );
const nextImageId = filteredImageIds[nextImageIndex]; const nextImageId = filteredImages[nextImageIndex].image_name;
const prevImageId = filteredImageIds[prevImageIndex]; const prevImageId = filteredImages[prevImageIndex].image_name;
const nextImage = filteredImagesAsObject[nextImageId]; const nextImage = selectImagesById(state, nextImageId);
const prevImage = filteredImagesAsObject[prevImageId]; const prevImage = selectImagesById(state, prevImageId);
const imagesLength = filteredImageIds.length; const imagesLength = filteredImages.length;
return { return {
isOnFirstImage: currentImageIndex === 0, isOnFirstImage: currentImageIndex === 0,
@ -101,11 +95,11 @@ const NextPrevImageButtons = () => {
}, []); }, []);
const handlePrevImage = useCallback(() => { const handlePrevImage = useCallback(() => {
dispatch(imageSelected(prevImageId)); prevImageId && dispatch(imageSelected(prevImageId));
}, [dispatch, prevImageId]); }, [dispatch, prevImageId]);
const handleNextImage = useCallback(() => { const handleNextImage = useCallback(() => {
dispatch(imageSelected(nextImageId)); nextImageId && dispatch(imageSelected(nextImageId));
}, [dispatch, nextImageId]); }, [dispatch, nextImageId]);
useHotkeys( useHotkeys(

View File

@ -1,40 +0,0 @@
import { useColorMode, useToken } from '@chakra-ui/react';
import { motion } from 'framer-motion';
import { mode } from 'theme/util/mode';
export const SelectedItemOverlay = () => {
const [accent400, accent500] = useToken('colors', [
'accent.400',
'accent.500',
]);
const { colorMode } = useColorMode();
return (
<motion.div
initial={{
opacity: 0,
}}
animate={{
opacity: 1,
transition: { duration: 0.1 },
}}
exit={{
opacity: 0,
transition: { duration: 0.1 },
}}
style={{
position: 'absolute',
top: 0,
insetInlineStart: 0,
width: '100%',
height: '100%',
boxShadow: `inset 0px 0px 0px 2px ${mode(
accent400,
accent500
)(colorMode)}`,
borderRadius: 'var(--invokeai-radii-base)',
}}
/>
);
};

View File

@ -1,18 +0,0 @@
import { useAppSelector } from 'app/store/storeHooks';
import { selectImagesEntities } from '../store/imagesSlice';
import { useCallback } from 'react';
const useGetImageByName = () => {
const images = useAppSelector(selectImagesEntities);
return useCallback(
(name: string | undefined) => {
if (!name) {
return;
}
return images[name];
},
[images]
);
};
export default useGetImageByName;

View File

@ -1,15 +1,6 @@
import { createAction } from '@reduxjs/toolkit'; import { createAction } from '@reduxjs/toolkit';
import { ImageUsage } from 'app/contexts/DeleteImageContext'; import { ImageUsage } from 'app/contexts/AddImageToBoardContext';
import { ImageDTO, BoardDTO } from 'services/api/types'; import { BoardDTO } from 'services/api/types';
export type RequestedImageDeletionArg = {
image: ImageDTO;
imageUsage: ImageUsage;
};
export const requestedImageDeletion = createAction<RequestedImageDeletionArg>(
'gallery/requestedImageDeletion'
);
export type RequestedBoardImagesDeletionArg = { export type RequestedBoardImagesDeletionArg = {
board: BoardDTO; board: BoardDTO;

View File

@ -1,10 +1,8 @@
import { PayloadAction, createSlice } from '@reduxjs/toolkit'; import { PayloadAction, createSlice } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { boardsApi } from 'services/api/endpoints/boards';
type BoardsState = { type BoardsState = {
searchText: string; searchText: string;
selectedBoardId?: string;
updateBoardModalOpen: boolean; updateBoardModalOpen: boolean;
}; };
@ -17,9 +15,6 @@ const boardsSlice = createSlice({
name: 'boards', name: 'boards',
initialState: initialBoardsState, initialState: initialBoardsState,
reducers: { reducers: {
boardIdSelected: (state, action: PayloadAction<string | undefined>) => {
state.selectedBoardId = action.payload;
},
setBoardSearchText: (state, action: PayloadAction<string>) => { setBoardSearchText: (state, action: PayloadAction<string>) => {
state.searchText = action.payload; state.searchText = action.payload;
}, },
@ -27,19 +22,9 @@ const boardsSlice = createSlice({
state.updateBoardModalOpen = action.payload; state.updateBoardModalOpen = action.payload;
}, },
}, },
extraReducers: (builder) => {
builder.addMatcher(
boardsApi.endpoints.deleteBoard.matchFulfilled,
(state, action) => {
if (action.meta.arg.originalArgs === state.selectedBoardId) {
state.selectedBoardId = undefined;
}
}
);
},
}); });
export const { boardIdSelected, setBoardSearchText, setUpdateBoardModalOpen } = export const { setBoardSearchText, setUpdateBoardModalOpen } =
boardsSlice.actions; boardsSlice.actions;
export const boardsSelector = (state: RootState) => state.boards; export const boardsSelector = (state: RootState) => state.boards;

View File

@ -1,8 +1,15 @@
import { GalleryState } from './gallerySlice'; import { initialGalleryState } from './gallerySlice';
/** /**
* Gallery slice persist denylist * Gallery slice persist denylist
*/ */
export const galleryPersistDenylist: (keyof GalleryState)[] = [ export const galleryPersistDenylist: (keyof typeof initialGalleryState)[] = [
'shouldAutoSwitchToNewImages', 'selection',
'entities',
'ids',
'isLoading',
'limit',
'offset',
'selectedBoardId',
'total',
]; ];

View File

@ -1,87 +1,266 @@
import type { PayloadAction } from '@reduxjs/toolkit'; import type { PayloadAction, Update } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit'; import {
import { imageUpserted } from './imagesSlice'; createEntityAdapter,
createSelector,
createSlice,
} from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { dateComparator } from 'common/util/dateComparator';
import { imageDeletionConfirmed } from 'features/imageDeletion/store/imageDeletionSlice';
import { keyBy, uniq } from 'lodash-es';
import { boardsApi } from 'services/api/endpoints/boards';
import {
imageUrlsReceived,
receivedPageOfImages,
} from 'services/api/thunks/image';
import { ImageCategory, ImageDTO } from 'services/api/types';
type GalleryImageObjectFitType = 'contain' | 'cover'; export const imagesAdapter = createEntityAdapter<ImageDTO>({
selectId: (image) => image.image_name,
sortComparer: (a, b) => dateComparator(b.updated_at, a.updated_at),
});
export interface GalleryState { export const IMAGE_CATEGORIES: ImageCategory[] = ['general'];
selectedImage?: string; export const ASSETS_CATEGORIES: ImageCategory[] = [
'control',
'mask',
'user',
'other',
];
type AdditionaGalleryState = {
offset: number;
limit: number;
total: number;
isLoading: boolean;
categories: ImageCategory[];
selectedBoardId?: string;
selection: string[];
shouldAutoSwitch: boolean;
galleryImageMinimumWidth: number; galleryImageMinimumWidth: number;
galleryImageObjectFit: GalleryImageObjectFitType;
shouldAutoSwitchToNewImages: boolean;
shouldUseSingleGalleryColumn: boolean;
galleryView: 'images' | 'assets' | 'boards'; galleryView: 'images' | 'assets' | 'boards';
}
export const initialGalleryState: GalleryState = {
galleryImageMinimumWidth: 64,
galleryImageObjectFit: 'cover',
shouldAutoSwitchToNewImages: true,
shouldUseSingleGalleryColumn: false,
galleryView: 'images',
}; };
export const initialGalleryState =
imagesAdapter.getInitialState<AdditionaGalleryState>({
offset: 0,
limit: 0,
total: 0,
isLoading: true,
categories: IMAGE_CATEGORIES,
selection: [],
shouldAutoSwitch: true,
galleryImageMinimumWidth: 64,
galleryView: 'images',
});
export const gallerySlice = createSlice({ export const gallerySlice = createSlice({
name: 'gallery', name: 'gallery',
initialState: initialGalleryState, initialState: initialGalleryState,
reducers: { reducers: {
imageSelected: (state, action: PayloadAction<string | undefined>) => { imageUpserted: (state, action: PayloadAction<ImageDTO>) => {
state.selectedImage = action.payload; imagesAdapter.upsertOne(state, action.payload);
// TODO: if the user selects an image, disable the auto switch? if (
// state.shouldAutoSwitchToNewImages = false; state.shouldAutoSwitch &&
action.payload.image_category === 'general'
) {
state.selection = [action.payload.image_name];
}
},
imageUpdatedOne: (state, action: PayloadAction<Update<ImageDTO>>) => {
imagesAdapter.updateOne(state, action.payload);
},
imageRemoved: (state, action: PayloadAction<string>) => {
imagesAdapter.removeOne(state, action.payload);
},
imagesRemoved: (state, action: PayloadAction<string[]>) => {
imagesAdapter.removeMany(state, action.payload);
},
imageCategoriesChanged: (state, action: PayloadAction<ImageCategory[]>) => {
state.categories = action.payload;
},
imageRangeEndSelected: (state, action: PayloadAction<string>) => {
const rangeEndImageName = action.payload;
const lastSelectedImage = state.selection[state.selection.length - 1];
const filteredImages = selectFilteredImagesLocal(state);
const lastClickedIndex = filteredImages.findIndex(
(n) => n.image_name === lastSelectedImage
);
const currentClickedIndex = filteredImages.findIndex(
(n) => n.image_name === rangeEndImageName
);
if (lastClickedIndex > -1 && currentClickedIndex > -1) {
// We have a valid range!
const start = Math.min(lastClickedIndex, currentClickedIndex);
const end = Math.max(lastClickedIndex, currentClickedIndex);
const imagesToSelect = filteredImages
.slice(start, end + 1)
.map((i) => i.image_name);
state.selection = uniq(state.selection.concat(imagesToSelect));
}
},
imageSelectionToggled: (state, action: PayloadAction<string>) => {
if (
state.selection.includes(action.payload) &&
state.selection.length > 1
) {
state.selection = state.selection.filter(
(imageName) => imageName !== action.payload
);
} else {
state.selection = uniq(state.selection.concat(action.payload));
}
},
imageSelected: (state, action: PayloadAction<string | null>) => {
state.selection = action.payload
? [action.payload]
: [String(state.ids[0])];
},
shouldAutoSwitchChanged: (state, action: PayloadAction<boolean>) => {
state.shouldAutoSwitch = action.payload;
}, },
setGalleryImageMinimumWidth: (state, action: PayloadAction<number>) => { setGalleryImageMinimumWidth: (state, action: PayloadAction<number>) => {
state.galleryImageMinimumWidth = action.payload; state.galleryImageMinimumWidth = action.payload;
}, },
setGalleryImageObjectFit: (
state,
action: PayloadAction<GalleryImageObjectFitType>
) => {
state.galleryImageObjectFit = action.payload;
},
setShouldAutoSwitchToNewImages: (state, action: PayloadAction<boolean>) => {
state.shouldAutoSwitchToNewImages = action.payload;
},
setShouldUseSingleGalleryColumn: (
state,
action: PayloadAction<boolean>
) => {
state.shouldUseSingleGalleryColumn = action.payload;
},
setGalleryView: ( setGalleryView: (
state, state,
action: PayloadAction<'images' | 'assets' | 'boards'> action: PayloadAction<'images' | 'assets' | 'boards'>
) => { ) => {
state.galleryView = action.payload; state.galleryView = action.payload;
}, },
boardIdSelected: (state, action: PayloadAction<string | undefined>) => {
state.selectedBoardId = action.payload;
},
}, },
extraReducers: (builder) => { extraReducers: (builder) => {
builder.addCase(imageUpserted, (state, action) => { builder.addCase(receivedPageOfImages.pending, (state) => {
if ( state.isLoading = true;
state.shouldAutoSwitchToNewImages &&
action.payload.image_category === 'general'
) {
state.selectedImage = action.payload.image_name;
}
}); });
// builder.addCase(imageUrlsReceived.fulfilled, (state, action) => { builder.addCase(receivedPageOfImages.rejected, (state) => {
// const { image_name, image_url, thumbnail_url } = action.payload; state.isLoading = false;
});
builder.addCase(receivedPageOfImages.fulfilled, (state, action) => {
state.isLoading = false;
const { board_id, categories, image_origin, is_intermediate } =
action.meta.arg;
// if (state.selectedImage?.image_name === image_name) { const { items, offset, limit, total } = action.payload;
// state.selectedImage.image_url = image_url;
// state.selectedImage.thumbnail_url = thumbnail_url; const transformedItems = items.map((item) => ({
// } ...item,
// }); isSelected: false,
}));
imagesAdapter.upsertMany(state, transformedItems);
if (state.selection.length === 0) {
state.selection = [items[0].image_name];
}
if (!categories?.includes('general') || board_id) {
// need to skip updating the total images count if the images recieved were for a specific board
// TODO: this doesn't work when on the Asset tab/category...
return;
}
state.offset = offset;
state.limit = limit;
state.total = total;
});
builder.addCase(imageDeletionConfirmed, (state, action) => {
// Image deleted
const { image_name } = action.payload.imageDTO;
imagesAdapter.removeOne(state, image_name);
});
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
const { image_name, image_url, thumbnail_url } = action.payload;
imagesAdapter.updateOne(state, {
id: image_name,
changes: { image_url, thumbnail_url },
});
});
builder.addMatcher(
boardsApi.endpoints.deleteBoard.matchFulfilled,
(state, action) => {
if (action.meta.arg.originalArgs === state.selectedBoardId) {
state.selectedBoardId = undefined;
}
}
);
}, },
}); });
export const { export const {
selectAll: selectImagesAll,
selectById: selectImagesById,
selectEntities: selectImagesEntities,
selectIds: selectImagesIds,
selectTotal: selectImagesTotal,
} = imagesAdapter.getSelectors<RootState>((state) => state.gallery);
export const {
imageUpserted,
imageUpdatedOne,
imageRemoved,
imagesRemoved,
imageCategoriesChanged,
imageRangeEndSelected,
imageSelectionToggled,
imageSelected, imageSelected,
shouldAutoSwitchChanged,
setGalleryImageMinimumWidth, setGalleryImageMinimumWidth,
setGalleryImageObjectFit,
setShouldAutoSwitchToNewImages,
setShouldUseSingleGalleryColumn,
setGalleryView, setGalleryView,
boardIdSelected,
} = gallerySlice.actions; } = gallerySlice.actions;
export default gallerySlice.reducer; export default gallerySlice.reducer;
export const selectFilteredImagesLocal = createSelector(
(state: typeof initialGalleryState) => state,
(galleryState) => {
const allImages = imagesAdapter.getSelectors().selectAll(galleryState);
const { categories, selectedBoardId } = galleryState;
const filteredImages = allImages.filter((i) => {
const isInCategory = categories.includes(i.image_category);
const isInSelectedBoard = selectedBoardId
? i.board_id === selectedBoardId
: true;
return isInCategory && isInSelectedBoard;
});
return filteredImages;
}
);
export const selectFilteredImages = createSelector(
(state: RootState) => state,
(state) => {
return selectFilteredImagesLocal(state.gallery);
},
defaultSelectorOptions
);
export const selectFilteredImagesAsObject = createSelector(
selectFilteredImages,
(filteredImages) => keyBy(filteredImages, 'image_name')
);
export const selectFilteredImagesIds = createSelector(
selectFilteredImages,
(filteredImages) => filteredImages.map((i) => i.image_name)
);
export const selectLastSelectedImage = createSelector(
(state: RootState) => state,
(state) => state.gallery.selection[state.gallery.selection.length - 1],
defaultSelectorOptions
);

View File

@ -1,182 +0,0 @@
import {
PayloadAction,
Update,
createEntityAdapter,
createSelector,
createSlice,
} from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import { ImageCategory, ImageDTO } from 'services/api/types';
import { dateComparator } from 'common/util/dateComparator';
import { keyBy } from 'lodash-es';
import {
imageDeleted,
imageUrlsReceived,
receivedPageOfImages,
} from 'services/api/thunks/image';
export const imagesAdapter = createEntityAdapter<ImageDTO>({
selectId: (image) => image.image_name,
sortComparer: (a, b) => dateComparator(b.updated_at, a.updated_at),
});
export const IMAGE_CATEGORIES: ImageCategory[] = ['general'];
export const ASSETS_CATEGORIES: ImageCategory[] = [
'control',
'mask',
'user',
'other',
];
type AdditionaImagesState = {
offset: number;
limit: number;
total: number;
isLoading: boolean;
categories: ImageCategory[];
};
export const initialImagesState =
imagesAdapter.getInitialState<AdditionaImagesState>({
offset: 0,
limit: 0,
total: 0,
isLoading: false,
categories: IMAGE_CATEGORIES,
});
export type ImagesState = typeof initialImagesState;
const imagesSlice = createSlice({
name: 'images',
initialState: initialImagesState,
reducers: {
imageUpserted: (state, action: PayloadAction<ImageDTO>) => {
imagesAdapter.upsertOne(state, action.payload);
},
imageUpdatedOne: (state, action: PayloadAction<Update<ImageDTO>>) => {
imagesAdapter.updateOne(state, action.payload);
},
imageRemoved: (state, action: PayloadAction<string>) => {
imagesAdapter.removeOne(state, action.payload);
},
imagesRemoved: (state, action: PayloadAction<string[]>) => {
imagesAdapter.removeMany(state, action.payload);
},
imageCategoriesChanged: (state, action: PayloadAction<ImageCategory[]>) => {
state.categories = action.payload;
},
},
extraReducers: (builder) => {
builder.addCase(receivedPageOfImages.pending, (state) => {
state.isLoading = true;
});
builder.addCase(receivedPageOfImages.rejected, (state) => {
state.isLoading = false;
});
builder.addCase(receivedPageOfImages.fulfilled, (state, action) => {
state.isLoading = false;
const { board_id, categories, image_origin, is_intermediate } =
action.meta.arg;
const { items, offset, limit, total } = action.payload;
imagesAdapter.upsertMany(state, items);
if (!categories?.includes('general') || board_id) {
// need to skip updating the total images count if the images recieved were for a specific board
// TODO: this doesn't work when on the Asset tab/category...
return;
}
state.offset = offset;
state.limit = limit;
state.total = total;
});
builder.addCase(imageDeleted.pending, (state, action) => {
// Image deleted
const { image_name } = action.meta.arg;
imagesAdapter.removeOne(state, image_name);
});
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
const { image_name, image_url, thumbnail_url } = action.payload;
imagesAdapter.updateOne(state, {
id: image_name,
changes: { image_url, thumbnail_url },
});
});
},
});
export const {
selectAll: selectImagesAll,
selectById: selectImagesById,
selectEntities: selectImagesEntities,
selectIds: selectImagesIds,
selectTotal: selectImagesTotal,
} = imagesAdapter.getSelectors<RootState>((state) => state.images);
export const {
imageUpserted,
imageUpdatedOne,
imageRemoved,
imagesRemoved,
imageCategoriesChanged,
} = imagesSlice.actions;
export default imagesSlice.reducer;
export const selectFilteredImagesAsArray = createSelector(
(state: RootState) => state,
(state) => {
const {
images: { categories },
} = state;
return selectImagesAll(state).filter((i) =>
categories.includes(i.image_category)
);
}
);
export const selectFilteredImagesAsObject = createSelector(
(state: RootState) => state,
(state) => {
const {
images: { categories },
} = state;
return keyBy(
selectImagesAll(state).filter((i) =>
categories.includes(i.image_category)
),
'image_name'
);
}
);
export const selectFilteredImagesIds = createSelector(
(state: RootState) => state,
(state) => {
const {
images: { categories },
} = state;
return selectImagesAll(state)
.filter((i) => categories.includes(i.image_category))
.map((i) => i.image_name);
}
);
// export const selectImageById = createSelector(
// (state: RootState, imageId) => state,
// (state) => {
// const {
// images: { categories },
// } = state;
// return selectImagesAll(state)
// .filter((i) => categories.includes(i.image_category))
// .map((i) => i.image_name);
// }
// );

View File

@ -0,0 +1,37 @@
import { IconButtonProps } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import { useTranslation } from 'react-i18next';
import { FaTrash } from 'react-icons/fa';
const deleteImageButtonsSelector = createSelector(
[stateSelector],
({ system }) => {
const { isProcessing, isConnected } = system;
return isConnected && !isProcessing;
}
);
type DeleteImageButtonProps = Omit<IconButtonProps, 'aria-label'> & {
onClick: () => void;
};
export const DeleteImageButton = (props: DeleteImageButtonProps) => {
const { onClick, isDisabled } = props;
const { t } = useTranslation();
const canDeleteImage = useAppSelector(deleteImageButtonsSelector);
return (
<IAIIconButton
onClick={onClick}
icon={<FaTrash />}
tooltip={`${t('gallery.deleteImage')} (Del)`}
aria-label={`${t('gallery.deleteImage')} (Del)`}
isDisabled={isDisabled || !canDeleteImage}
colorScheme="error"
/>
);
};

View File

@ -0,0 +1,122 @@
import {
AlertDialog,
AlertDialogBody,
AlertDialogContent,
AlertDialogFooter,
AlertDialogHeader,
AlertDialogOverlay,
Divider,
Flex,
Text,
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIButton from 'common/components/IAIButton';
import IAISwitch from 'common/components/IAISwitch';
import { setShouldConfirmOnDelete } from 'features/system/store/systemSlice';
import { ChangeEvent, memo, useCallback, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import ImageUsageMessage from './ImageUsageMessage';
import { stateSelector } from 'app/store/store';
import {
imageDeletionConfirmed,
imageToDeleteCleared,
selectImageUsage,
} from '../store/imageDeletionSlice';
const selector = createSelector(
[stateSelector, selectImageUsage],
({ system, config, imageDeletion }, imageUsage) => {
const { shouldConfirmOnDelete } = system;
const { canRestoreDeletedImagesFromBin } = config;
const { imageToDelete, isModalOpen } = imageDeletion;
return {
shouldConfirmOnDelete,
canRestoreDeletedImagesFromBin,
imageToDelete,
imageUsage,
isModalOpen,
};
},
defaultSelectorOptions
);
const DeleteImageModal = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const {
shouldConfirmOnDelete,
canRestoreDeletedImagesFromBin,
imageToDelete,
imageUsage,
isModalOpen,
} = useAppSelector(selector);
const handleChangeShouldConfirmOnDelete = useCallback(
(e: ChangeEvent<HTMLInputElement>) =>
dispatch(setShouldConfirmOnDelete(!e.target.checked)),
[dispatch]
);
const handleClose = useCallback(() => {
dispatch(imageToDeleteCleared());
}, [dispatch]);
const handleDelete = useCallback(() => {
if (!imageToDelete || !imageUsage) {
return;
}
dispatch(imageToDeleteCleared());
dispatch(imageDeletionConfirmed({ imageDTO: imageToDelete, imageUsage }));
}, [dispatch, imageToDelete, imageUsage]);
const cancelRef = useRef<HTMLButtonElement>(null);
return (
<AlertDialog
isOpen={isModalOpen}
onClose={handleClose}
leastDestructiveRef={cancelRef}
isCentered
>
<AlertDialogOverlay>
<AlertDialogContent>
<AlertDialogHeader fontSize="lg" fontWeight="bold">
{t('gallery.deleteImage')}
</AlertDialogHeader>
<AlertDialogBody>
<Flex direction="column" gap={3}>
<ImageUsageMessage imageUsage={imageUsage} />
<Divider />
<Text>
{canRestoreDeletedImagesFromBin
? t('gallery.deleteImageBin')
: t('gallery.deleteImagePermanent')}
</Text>
<Text>{t('common.areYouSure')}</Text>
<IAISwitch
label={t('common.dontAskMeAgain')}
isChecked={!shouldConfirmOnDelete}
onChange={handleChangeShouldConfirmOnDelete}
/>
</Flex>
</AlertDialogBody>
<AlertDialogFooter>
<IAIButton ref={cancelRef} onClick={handleClose}>
Cancel
</IAIButton>
<IAIButton colorScheme="error" onClick={handleDelete} ml={3}>
Delete
</IAIButton>
</AlertDialogFooter>
</AlertDialogContent>
</AlertDialogOverlay>
</AlertDialog>
);
};
export default memo(DeleteImageModal);

View File

@ -0,0 +1,33 @@
import { some } from 'lodash-es';
import { memo } from 'react';
import { ImageUsage } from '../store/imageDeletionSlice';
import { ListItem, Text, UnorderedList } from '@chakra-ui/react';
const ImageUsageMessage = (props: { imageUsage?: ImageUsage }) => {
const { imageUsage } = props;
if (!imageUsage) {
return null;
}
if (!some(imageUsage)) {
return null;
}
return (
<>
<Text>This image is currently in use in the following features:</Text>
<UnorderedList sx={{ paddingInlineStart: 6 }}>
{imageUsage.isInitialImage && <ListItem>Image to Image</ListItem>}
{imageUsage.isCanvasImage && <ListItem>Unified Canvas</ListItem>}
{imageUsage.isControlNetImage && <ListItem>ControlNet</ListItem>}
{imageUsage.isNodesImage && <ListItem>Node Editor</ListItem>}
</UnorderedList>
<Text>
If you delete this image, those features will immediately be reset.
</Text>
</>
);
};
export default memo(ImageUsageMessage);

View File

@ -0,0 +1,99 @@
import {
PayloadAction,
createAction,
createSelector,
createSlice,
} from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { some } from 'lodash-es';
import { ImageDTO } from 'services/api/types';
type DeleteImageState = {
imageToDelete: ImageDTO | null;
isModalOpen: boolean;
};
export const initialDeleteImageState: DeleteImageState = {
imageToDelete: null,
isModalOpen: false,
};
const imageDeletion = createSlice({
name: 'imageDeletion',
initialState: initialDeleteImageState,
reducers: {
isModalOpenChanged: (state, action: PayloadAction<boolean>) => {
state.isModalOpen = action.payload;
},
imageToDeleteSelected: (state, action: PayloadAction<ImageDTO>) => {
state.imageToDelete = action.payload;
},
imageToDeleteCleared: (state) => {
state.imageToDelete = null;
},
},
});
export const {
isModalOpenChanged,
imageToDeleteSelected,
imageToDeleteCleared,
} = imageDeletion.actions;
export default imageDeletion.reducer;
export type ImageUsage = {
isInitialImage: boolean;
isCanvasImage: boolean;
isNodesImage: boolean;
isControlNetImage: boolean;
};
export const selectImageUsage = createSelector(
[(state: RootState) => state],
({ imageDeletion, generation, canvas, nodes, controlNet }) => {
const { imageToDelete } = imageDeletion;
if (!imageToDelete) {
return;
}
const { image_name } = imageToDelete;
const isInitialImage = generation.initialImage?.imageName === image_name;
const isCanvasImage = canvas.layerState.objects.some(
(obj) => obj.kind === 'image' && obj.imageName === image_name
);
const isNodesImage = nodes.nodes.some((node) => {
return some(
node.data.inputs,
(input) =>
input.type === 'image' && input.value?.image_name === image_name
);
});
const isControlNetImage = some(
controlNet.controlNets,
(c) =>
c.controlImage === image_name || c.processedControlImage === image_name
);
const imageUsage: ImageUsage = {
isInitialImage,
isCanvasImage,
isNodesImage,
isControlNetImage,
};
return imageUsage;
},
defaultSelectorOptions
);
export const imageDeletionConfirmed = createAction<{
imageDTO: ImageDTO;
imageUsage: ImageUsage;
}>('imageDeletion/imageDeletionConfirmed');

View File

@ -16,6 +16,7 @@ import NumberInputFieldComponent from './fields/NumberInputFieldComponent';
import StringInputFieldComponent from './fields/StringInputFieldComponent'; import StringInputFieldComponent from './fields/StringInputFieldComponent';
import ColorInputFieldComponent from './fields/ColorInputFieldComponent'; import ColorInputFieldComponent from './fields/ColorInputFieldComponent';
import ItemInputFieldComponent from './fields/ItemInputFieldComponent'; import ItemInputFieldComponent from './fields/ItemInputFieldComponent';
import ImageCollectionInputFieldComponent from './fields/ImageCollectionInputFieldComponent';
type InputFieldComponentProps = { type InputFieldComponentProps = {
nodeId: string; nodeId: string;
@ -191,6 +192,16 @@ const InputFieldComponent = (props: InputFieldComponentProps) => {
); );
} }
if (type === 'image_collection' && template.type === 'image_collection') {
return (
<ImageCollectionInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
return <Box p={2}>Unknown field type: {type}</Box>; return <Box p={2}>Unknown field type: {type}</Box>;
}; };

Some files were not shown because too many files have changed in this diff Show More