Merge branch 'main' into mm-ui

This commit is contained in:
blessedcoolant 2023-07-14 15:46:53 +12:00
commit 16e93c6455
160 changed files with 4147 additions and 4892 deletions

View File

@ -13,7 +13,6 @@ from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
from invokeai.app.services.boards import BoardService, BoardServiceDependencies from invokeai.app.services.boards import BoardService, BoardServiceDependencies
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService, ImageServiceDependencies from invokeai.app.services.images import ImageService, ImageServiceDependencies
from invokeai.app.services.metadata import CoreMetadataService
from invokeai.app.services.resource_name import SimpleNameService from invokeai.app.services.resource_name import SimpleNameService
from invokeai.app.services.urls import LocalUrlService from invokeai.app.services.urls import LocalUrlService
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
@ -75,7 +74,6 @@ class ApiDependencies:
) )
urls = LocalUrlService() urls = LocalUrlService()
metadata = CoreMetadataService()
image_record_storage = SqliteImageRecordStorage(db_location) image_record_storage = SqliteImageRecordStorage(db_location)
image_file_storage = DiskImageFileStorage(f"{output_folder}/images") image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
names = SimpleNameService() names = SimpleNameService()
@ -111,7 +109,6 @@ class ApiDependencies:
board_image_record_storage=board_image_record_storage, board_image_record_storage=board_image_record_storage,
image_record_storage=image_record_storage, image_record_storage=image_record_storage,
image_file_storage=image_file_storage, image_file_storage=image_file_storage,
metadata=metadata,
url=urls, url=urls,
logger=logger, logger=logger,
names=names, names=names,

View File

@ -1,18 +1,36 @@
from fastapi.routing import APIRouter from fastapi.routing import APIRouter
from pydantic import BaseModel from pydantic import BaseModel, Field
from invokeai.backend.image_util.patchmatch import PatchMatch
from invokeai.version import __version__ from invokeai.version import __version__
app_router = APIRouter(prefix="/v1/app", tags=['app']) app_router = APIRouter(prefix="/v1/app", tags=["app"])
class AppVersion(BaseModel): class AppVersion(BaseModel):
"""App Version Response""" """App Version Response"""
version: str
version: str = Field(description="App version")
@app_router.get('/version', operation_id="app_version", class AppConfig(BaseModel):
status_code=200, """App Config Response"""
response_model=AppVersion)
infill_methods: list[str] = Field(description="List of available infill methods")
@app_router.get(
"/version", operation_id="app_version", status_code=200, response_model=AppVersion
)
async def get_version() -> AppVersion: async def get_version() -> AppVersion:
return AppVersion(version=__version__) return AppVersion(version=__version__)
@app_router.get(
"/config", operation_id="get_config", status_code=200, response_model=AppConfig
)
async def get_config() -> AppConfig:
infill_methods = ['tile']
if PatchMatch.patchmatch_available():
infill_methods.append('patchmatch')
return AppConfig(infill_methods=infill_methods)

View File

@ -1,20 +1,19 @@
import io import io
from typing import Optional from typing import Optional
from fastapi import Body, HTTPException, Path, Query, Request, Response, UploadFile
from fastapi.routing import APIRouter from fastapi import (Body, HTTPException, Path, Query, Request, Response,
UploadFile)
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
from fastapi.routing import APIRouter
from PIL import Image from PIL import Image
from invokeai.app.models.image import (
ImageCategory, from invokeai.app.invocations.metadata import ImageMetadata
ResourceOrigin, from invokeai.app.models.image import ImageCategory, ResourceOrigin
)
from invokeai.app.services.image_record_storage import OffsetPaginatedResults from invokeai.app.services.image_record_storage import OffsetPaginatedResults
from invokeai.app.services.models.image_record import (
ImageDTO,
ImageRecordChanges,
ImageUrlsDTO,
)
from invokeai.app.services.item_storage import PaginatedResults from invokeai.app.services.item_storage import PaginatedResults
from invokeai.app.services.models.image_record import (ImageDTO,
ImageRecordChanges,
ImageUrlsDTO)
from ..dependencies import ApiDependencies from ..dependencies import ApiDependencies
@ -103,23 +102,38 @@ async def update_image(
@images_router.get( @images_router.get(
"/{image_name}/metadata", "/{image_name}",
operation_id="get_image_metadata", operation_id="get_image_dto",
response_model=ImageDTO, response_model=ImageDTO,
) )
async def get_image_metadata( async def get_image_dto(
image_name: str = Path(description="The name of image to get"), image_name: str = Path(description="The name of image to get"),
) -> ImageDTO: ) -> ImageDTO:
"""Gets an image's metadata""" """Gets an image's DTO"""
try: try:
return ApiDependencies.invoker.services.images.get_dto(image_name) return ApiDependencies.invoker.services.images.get_dto(image_name)
except Exception as e: except Exception as e:
raise HTTPException(status_code=404) raise HTTPException(status_code=404)
@images_router.get(
"/{image_name}/metadata",
operation_id="get_image_metadata",
response_model=ImageMetadata,
)
async def get_image_metadata(
image_name: str = Path(description="The name of image to get"),
) -> ImageMetadata:
"""Gets an image's metadata"""
try:
return ApiDependencies.invoker.services.images.get_metadata(image_name)
except Exception as e:
raise HTTPException(status_code=404)
@images_router.get( @images_router.get(
"/{image_name}", "/{image_name}/full",
operation_id="get_image_full", operation_id="get_image_full",
response_class=Response, response_class=Response,
responses={ responses={
@ -208,10 +222,10 @@ async def get_image_urls(
@images_router.get( @images_router.get(
"/", "/",
operation_id="list_images_with_metadata", operation_id="list_image_dtos",
response_model=OffsetPaginatedResults[ImageDTO], response_model=OffsetPaginatedResults[ImageDTO],
) )
async def list_images_with_metadata( async def list_image_dtos(
image_origin: Optional[ResourceOrigin] = Query( image_origin: Optional[ResourceOrigin] = Query(
default=None, description="The origin of images to list" default=None, description="The origin of images to list"
), ),
@ -227,7 +241,7 @@ async def list_images_with_metadata(
offset: int = Query(default=0, description="The page offset"), offset: int = Query(default=0, description="The page offset"),
limit: int = Query(default=10, description="The number of images per page"), limit: int = Query(default=10, description="The number of images per page"),
) -> OffsetPaginatedResults[ImageDTO]: ) -> OffsetPaginatedResults[ImageDTO]:
"""Gets a list of images""" """Gets a list of image DTOs"""
image_dtos = ApiDependencies.invoker.services.images.get_many( image_dtos = ApiDependencies.invoker.services.images.get_many(
offset, offset,

View File

@ -34,7 +34,6 @@ from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
from invokeai.app.services.boards import BoardService, BoardServiceDependencies from invokeai.app.services.boards import BoardService, BoardServiceDependencies
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService, ImageServiceDependencies from invokeai.app.services.images import ImageService, ImageServiceDependencies
from invokeai.app.services.metadata import CoreMetadataService
from invokeai.app.services.resource_name import SimpleNameService from invokeai.app.services.resource_name import SimpleNameService
from invokeai.app.services.urls import LocalUrlService from invokeai.app.services.urls import LocalUrlService
from .services.default_graphs import (default_text_to_image_graph_id, from .services.default_graphs import (default_text_to_image_graph_id,
@ -244,7 +243,6 @@ def invoke_cli():
) )
urls = LocalUrlService() urls = LocalUrlService()
metadata = CoreMetadataService()
image_record_storage = SqliteImageRecordStorage(db_location) image_record_storage = SqliteImageRecordStorage(db_location)
image_file_storage = DiskImageFileStorage(f"{output_folder}/images") image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
names = SimpleNameService() names = SimpleNameService()
@ -277,7 +275,6 @@ def invoke_cli():
board_image_record_storage=board_image_record_storage, board_image_record_storage=board_image_record_storage,
image_record_storage=image_record_storage, image_record_storage=image_record_storage,
image_file_storage=image_file_storage, image_file_storage=image_file_storage,
metadata=metadata,
url=urls, url=urls,
logger=logger, logger=logger,
names=names, names=names,

View File

@ -154,40 +154,42 @@ class InpaintInvocation(BaseInvocation):
@contextmanager @contextmanager
def load_model_old_way(self, context, scheduler): def load_model_old_way(self, context, scheduler):
def _lora_loader():
for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}))
yield (lora_info.context.model, lora.weight)
del lora_info
return
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict()) unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
vae_info = context.services.model_manager.get_model(**self.vae.vae.dict()) vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
#unet = unet_info.context.model with vae_info as vae,\
#vae = vae_info.context.model ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
unet_info as unet:
with ExitStack() as stack: device = context.services.model_manager.mgr.cache.execution_device
loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras] dtype = context.services.model_manager.mgr.cache.precision
with vae_info as vae,\ pipeline = StableDiffusionGeneratorPipeline(
unet_info as unet,\ vae=vae,
ModelPatcher.apply_lora_unet(unet, loras): text_encoder=None,
tokenizer=None,
unet=unet,
scheduler=scheduler,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
precision="float16" if dtype == torch.float16 else "float32",
execution_device=device,
)
device = context.services.model_manager.mgr.cache.execution_device yield OldModelInfo(
dtype = context.services.model_manager.mgr.cache.precision name=self.unet.unet.model_name,
hash="<NO-HASH>",
pipeline = StableDiffusionGeneratorPipeline( model=pipeline,
vae=vae, )
text_encoder=None,
tokenizer=None,
unet=unet,
scheduler=scheduler,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
precision="float16" if dtype == torch.float16 else "float32",
execution_device=device,
)
yield OldModelInfo(
name=self.unet.unet.model_name,
hash="<NO-HASH>",
model=pipeline,
)
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = ( image = (
@ -226,21 +228,21 @@ class InpaintInvocation(BaseInvocation):
), # Shorthand for passing all of the parameters above manually ), # Shorthand for passing all of the parameters above manually
) )
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object # Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
# each time it is called. We only need the first one. # each time it is called. We only need the first one.
generator_output = next(outputs) generator_output = next(outputs)
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=generator_output.image, image=generator_output.image,
image_origin=ResourceOrigin.INTERNAL, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
node_id=self.id, node_id=self.id,
is_intermediate=self.is_intermediate, is_intermediate=self.is_intermediate,
) )
return ImageOutput( return ImageOutput(
image=ImageField(image_name=image_dto.image_name), image=ImageField(image_name=image_dto.image_name),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
) )

View File

@ -9,9 +9,9 @@ from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import SchedulerMixin as Scheduler from diffusers.schedulers import SchedulerMixin as Scheduler
from pydantic import BaseModel, Field, validator from pydantic import BaseModel, Field, validator
from invokeai.app.invocations.metadata import CoreMetadata
from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.app.util.step_callback import stable_diffusion_step_callback
from ..models.image import ImageCategory, ImageField, ResourceOrigin
from ...backend.model_management.lora import ModelPatcher from ...backend.model_management.lora import ModelPatcher
from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.stable_diffusion.diffusers_pipeline import ( from ...backend.stable_diffusion.diffusers_pipeline import (
@ -21,6 +21,7 @@ from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import \
PostprocessingSettings PostprocessingSettings
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from ...backend.util.devices import torch_dtype from ...backend.util.devices import torch_dtype
from ..models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import (BaseInvocation, BaseInvocationOutput, from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
InvocationConfig, InvocationContext) InvocationConfig, InvocationContext)
from .compel import ConditioningField from .compel import ConditioningField
@ -449,6 +450,8 @@ class LatentsToImageInvocation(BaseInvocation):
tiled: bool = Field( tiled: bool = Field(
default=False, default=False,
description="Decode latents by overlaping tiles(less memory consumption)") description="Decode latents by overlaping tiles(less memory consumption)")
metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image")
# Schema customisation # Schema customisation
class Config(InvocationConfig): class Config(InvocationConfig):
@ -493,7 +496,8 @@ class LatentsToImageInvocation(BaseInvocation):
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate is_intermediate=self.is_intermediate,
metadata=self.metadata.dict() if self.metadata else None,
) )
return ImageOutput( return ImageOutput(

View File

@ -0,0 +1,124 @@
from typing import Literal, Optional, Union
from pydantic import BaseModel, Field
from invokeai.app.invocations.baseinvocation import (BaseInvocation,
BaseInvocationOutput,
InvocationContext)
from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.model import (LoRAModelField, MainModelField,
VAEModelField)
class LoRAMetadataField(BaseModel):
"""LoRA metadata for an image generated in InvokeAI."""
lora: LoRAModelField = Field(description="The LoRA model")
weight: float = Field(description="The weight of the LoRA model")
class CoreMetadata(BaseModel):
"""Core generation metadata for an image generated in InvokeAI."""
generation_mode: str = Field(description="The generation mode that output this image",)
positive_prompt: str = Field(description="The positive prompt parameter")
negative_prompt: str = Field(description="The negative prompt parameter")
width: int = Field(description="The width parameter")
height: int = Field(description="The height parameter")
seed: int = Field(description="The seed used for noise generation")
rand_device: str = Field(description="The device used for random number generation")
cfg_scale: float = Field(description="The classifier-free guidance scale parameter")
steps: int = Field(description="The number of steps used for inference")
scheduler: str = Field(description="The scheduler used for inference")
clip_skip: int = Field(description="The number of skipped CLIP layers",)
model: MainModelField = Field(description="The main model used for inference")
controlnets: list[ControlField]= Field(description="The ControlNets used for inference")
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
strength: Union[float, None] = Field(
default=None,
description="The strength used for latents-to-latents",
)
init_image: Union[str, None] = Field(
default=None, description="The name of the initial image"
)
vae: Union[VAEModelField, None] = Field(
default=None,
description="The VAE used for decoding, if the main model's default was not used",
)
class ImageMetadata(BaseModel):
"""An image's generation metadata"""
metadata: Optional[dict] = Field(
default=None,
description="The image's core metadata, if it was created in the Linear or Canvas UI",
)
graph: Optional[dict] = Field(
default=None, description="The graph that created the image"
)
class MetadataAccumulatorOutput(BaseInvocationOutput):
"""The output of the MetadataAccumulator node"""
type: Literal["metadata_accumulator_output"] = "metadata_accumulator_output"
metadata: CoreMetadata = Field(description="The core metadata for the image")
class MetadataAccumulatorInvocation(BaseInvocation):
"""Outputs a Core Metadata Object"""
type: Literal["metadata_accumulator"] = "metadata_accumulator"
generation_mode: str = Field(description="The generation mode that output this image",)
positive_prompt: str = Field(description="The positive prompt parameter")
negative_prompt: str = Field(description="The negative prompt parameter")
width: int = Field(description="The width parameter")
height: int = Field(description="The height parameter")
seed: int = Field(description="The seed used for noise generation")
rand_device: str = Field(description="The device used for random number generation")
cfg_scale: float = Field(description="The classifier-free guidance scale parameter")
steps: int = Field(description="The number of steps used for inference")
scheduler: str = Field(description="The scheduler used for inference")
clip_skip: int = Field(description="The number of skipped CLIP layers",)
model: MainModelField = Field(description="The main model used for inference")
controlnets: list[ControlField]= Field(description="The ControlNets used for inference")
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
strength: Union[float, None] = Field(
default=None,
description="The strength used for latents-to-latents",
)
init_image: Union[str, None] = Field(
default=None, description="The name of the initial image"
)
vae: Union[VAEModelField, None] = Field(
default=None,
description="The VAE used for decoding, if the main model's default was not used",
)
def invoke(self, context: InvocationContext) -> MetadataAccumulatorOutput:
"""Collects and outputs a CoreMetadata object"""
return MetadataAccumulatorOutput(
metadata=CoreMetadata(
generation_mode=self.generation_mode,
positive_prompt=self.positive_prompt,
negative_prompt=self.negative_prompt,
width=self.width,
height=self.height,
seed=self.seed,
rand_device=self.rand_device,
cfg_scale=self.cfg_scale,
steps=self.steps,
scheduler=self.scheduler,
model=self.model,
strength=self.strength,
init_image=self.init_image,
vae=self.vae,
controlnets=self.controlnets,
loras=self.loras,
clip_skip=self.clip_skip,
)
)

View File

@ -1,93 +0,0 @@
from typing import Optional, Union, List
from pydantic import BaseModel, Extra, Field, StrictFloat, StrictInt, StrictStr
class ImageMetadata(BaseModel):
"""
Core generation metadata for an image/tensor generated in InvokeAI.
Also includes any metadata from the image's PNG tEXt chunks.
Generated by traversing the execution graph, collecting the parameters of the nearest ancestors
of a given node.
Full metadata may be accessed by querying for the session in the `graph_executions` table.
"""
class Config:
extra = Extra.allow
"""
This lets the ImageMetadata class accept arbitrary additional fields. The CoreMetadataService
won't add any fields that are not already defined, but other a different metadata service
implementation might.
"""
type: Optional[StrictStr] = Field(
default=None,
description="The type of the ancestor node of the image output node.",
)
"""The type of the ancestor node of the image output node."""
positive_conditioning: Optional[StrictStr] = Field(
default=None, description="The positive conditioning."
)
"""The positive conditioning"""
negative_conditioning: Optional[StrictStr] = Field(
default=None, description="The negative conditioning."
)
"""The negative conditioning"""
width: Optional[StrictInt] = Field(
default=None, description="Width of the image/latents in pixels."
)
"""Width of the image/latents in pixels"""
height: Optional[StrictInt] = Field(
default=None, description="Height of the image/latents in pixels."
)
"""Height of the image/latents in pixels"""
seed: Optional[StrictInt] = Field(
default=None, description="The seed used for noise generation."
)
"""The seed used for noise generation"""
# cfg_scale: Optional[StrictFloat] = Field(
# cfg_scale: Union[float, list[float]] = Field(
cfg_scale: Union[StrictFloat, List[StrictFloat]] = Field(
default=None, description="The classifier-free guidance scale."
)
"""The classifier-free guidance scale"""
steps: Optional[StrictInt] = Field(
default=None, description="The number of steps used for inference."
)
"""The number of steps used for inference"""
scheduler: Optional[StrictStr] = Field(
default=None, description="The scheduler used for inference."
)
"""The scheduler used for inference"""
model: Optional[StrictStr] = Field(
default=None, description="The model used for inference."
)
"""The model used for inference"""
strength: Optional[StrictFloat] = Field(
default=None,
description="The strength used for image-to-image/latents-to-latents.",
)
"""The strength used for image-to-image/latents-to-latents."""
latents: Optional[StrictStr] = Field(
default=None, description="The ID of the initial latents."
)
"""The ID of the initial latents"""
vae: Optional[StrictStr] = Field(
default=None, description="The VAE used for decoding."
)
"""The VAE used for decoding"""
unet: Optional[StrictStr] = Field(
default=None, description="The UNet used dor inference."
)
"""The UNet used dor inference"""
clip: Optional[StrictStr] = Field(
default=None, description="The CLIP Encoder used for conditioning."
)
"""The CLIP Encoder used for conditioning"""
extra: Optional[StrictStr] = Field(
default=None,
description="Uploaded image metadata, extracted from the PNG tEXt chunk.",
)
"""Uploaded image metadata, extracted from the PNG tEXt chunk."""

View File

@ -1,14 +1,14 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
import json
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from queue import Queue from queue import Queue
from typing import Dict, Optional, Union from typing import Dict, Optional, Union
from PIL.Image import Image as PILImageType
from PIL import Image, PngImagePlugin from PIL import Image, PngImagePlugin
from PIL.Image import Image as PILImageType
from send2trash import send2trash from send2trash import send2trash
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
@ -59,7 +59,8 @@ class ImageFileStorageBase(ABC):
self, self,
image: PILImageType, image: PILImageType,
image_name: str, image_name: str,
metadata: Optional[ImageMetadata] = None, metadata: Optional[dict] = None,
graph: Optional[dict] = None,
thumbnail_size: int = 256, thumbnail_size: int = 256,
) -> None: ) -> None:
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp.""" """Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
@ -110,20 +111,22 @@ class DiskImageFileStorage(ImageFileStorageBase):
self, self,
image: PILImageType, image: PILImageType,
image_name: str, image_name: str,
metadata: Optional[ImageMetadata] = None, metadata: Optional[dict] = None,
graph: Optional[dict] = None,
thumbnail_size: int = 256, thumbnail_size: int = 256,
) -> None: ) -> None:
try: try:
self.__validate_storage_folders() self.__validate_storage_folders()
image_path = self.get_path(image_name) image_path = self.get_path(image_name)
pnginfo = PngImagePlugin.PngInfo()
if metadata is not None: if metadata is not None:
pnginfo = PngImagePlugin.PngInfo() pnginfo.add_text("invokeai_metadata", json.dumps(metadata))
pnginfo.add_text("invokeai", metadata.json()) if graph is not None:
image.save(image_path, "PNG", pnginfo=pnginfo) pnginfo.add_text("invokeai_graph", json.dumps(graph))
else:
image.save(image_path, "PNG")
image.save(image_path, "PNG", pnginfo=pnginfo)
thumbnail_name = get_thumbnail_name(image_name) thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(thumbnail_name, thumbnail=True) thumbnail_path = self.get_path(thumbnail_name, thumbnail=True)
thumbnail_image = make_thumbnail(image, thumbnail_size) thumbnail_image = make_thumbnail(image, thumbnail_size)

View File

@ -1,3 +1,4 @@
import json
import sqlite3 import sqlite3
import threading import threading
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
@ -8,7 +9,6 @@ from pydantic import BaseModel, Field
from pydantic.generics import GenericModel from pydantic.generics import GenericModel
from invokeai.app.models.image import ImageCategory, ResourceOrigin from invokeai.app.models.image import ImageCategory, ResourceOrigin
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.services.models.image_record import ( from invokeai.app.services.models.image_record import (
ImageRecord, ImageRecordChanges, deserialize_image_record) ImageRecord, ImageRecordChanges, deserialize_image_record)
@ -48,6 +48,28 @@ class ImageRecordDeleteException(Exception):
super().__init__(message) super().__init__(message)
IMAGE_DTO_COLS = ", ".join(
list(
map(
lambda c: "images." + c,
[
"image_name",
"image_origin",
"image_category",
"width",
"height",
"session_id",
"node_id",
"is_intermediate",
"created_at",
"updated_at",
"deleted_at",
],
)
)
)
class ImageRecordStorageBase(ABC): class ImageRecordStorageBase(ABC):
"""Low-level service responsible for interfacing with the image record store.""" """Low-level service responsible for interfacing with the image record store."""
@ -58,6 +80,11 @@ class ImageRecordStorageBase(ABC):
"""Gets an image record.""" """Gets an image record."""
pass pass
@abstractmethod
def get_metadata(self, image_name: str) -> Optional[dict]:
"""Gets an image's metadata'."""
pass
@abstractmethod @abstractmethod
def update( def update(
self, self,
@ -102,7 +129,7 @@ class ImageRecordStorageBase(ABC):
height: int, height: int,
session_id: Optional[str], session_id: Optional[str],
node_id: Optional[str], node_id: Optional[str],
metadata: Optional[ImageMetadata], metadata: Optional[dict],
is_intermediate: bool = False, is_intermediate: bool = False,
) -> datetime: ) -> datetime:
"""Saves an image record.""" """Saves an image record."""
@ -206,7 +233,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
self._cursor.execute( self._cursor.execute(
f"""--sql f"""--sql
SELECT * FROM images SELECT {IMAGE_DTO_COLS} FROM images
WHERE image_name = ?; WHERE image_name = ?;
""", """,
(image_name,), (image_name,),
@ -224,6 +251,28 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
return deserialize_image_record(dict(result)) return deserialize_image_record(dict(result))
def get_metadata(self, image_name: str) -> Optional[dict]:
try:
self._lock.acquire()
self._cursor.execute(
f"""--sql
SELECT images.metadata FROM images
WHERE image_name = ?;
""",
(image_name,),
)
result = cast(Optional[sqlite3.Row], self._cursor.fetchone())
if not result or not result[0]:
return None
return json.loads(result[0])
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordNotFoundException from e
finally:
self._lock.release()
def update( def update(
self, self,
image_name: str, image_name: str,
@ -291,8 +340,8 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
WHERE 1=1 WHERE 1=1
""" """
images_query = """--sql images_query = f"""--sql
SELECT images.* SELECT {IMAGE_DTO_COLS}
FROM images FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1 WHERE 1=1
@ -410,12 +459,12 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
width: int, width: int,
height: int, height: int,
node_id: Optional[str], node_id: Optional[str],
metadata: Optional[ImageMetadata], metadata: Optional[dict],
is_intermediate: bool = False, is_intermediate: bool = False,
) -> datetime: ) -> datetime:
try: try:
metadata_json = ( metadata_json = (
None if metadata is None else metadata.json(exclude_none=True) None if metadata is None else json.dumps(metadata)
) )
self._lock.acquire() self._lock.acquire()
self._cursor.execute( self._cursor.execute(
@ -465,9 +514,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
finally: finally:
self._lock.release() self._lock.release()
def get_most_recent_image_for_board( def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]:
self, board_id: str
) -> Optional[ImageRecord]:
try: try:
self._lock.acquire() self._lock.acquire()
self._cursor.execute( self._cursor.execute(

View File

@ -1,39 +1,30 @@
import json
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from logging import Logger from logging import Logger
from typing import Optional, TYPE_CHECKING, Union from typing import TYPE_CHECKING, Optional
from PIL.Image import Image as PILImageType from PIL.Image import Image as PILImageType
from invokeai.app.models.image import ( from invokeai.app.invocations.metadata import ImageMetadata
ImageCategory, from invokeai.app.models.image import (ImageCategory,
ResourceOrigin, InvalidImageCategoryException,
InvalidImageCategoryException, InvalidOriginException, ResourceOrigin)
InvalidOriginException, from invokeai.app.services.board_image_record_storage import \
) BoardImageRecordStorageBase
from invokeai.app.models.metadata import ImageMetadata from invokeai.app.services.graph import Graph
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
from invokeai.app.services.image_record_storage import (
ImageRecordDeleteException,
ImageRecordNotFoundException,
ImageRecordSaveException,
ImageRecordStorageBase,
OffsetPaginatedResults,
)
from invokeai.app.services.models.image_record import (
ImageRecord,
ImageDTO,
ImageRecordChanges,
image_record_to_dto,
)
from invokeai.app.services.image_file_storage import ( from invokeai.app.services.image_file_storage import (
ImageFileDeleteException, ImageFileDeleteException, ImageFileNotFoundException,
ImageFileNotFoundException, ImageFileSaveException, ImageFileStorageBase)
ImageFileSaveException, from invokeai.app.services.image_record_storage import (
ImageFileStorageBase, ImageRecordDeleteException, ImageRecordNotFoundException,
) ImageRecordSaveException, ImageRecordStorageBase, OffsetPaginatedResults)
from invokeai.app.services.item_storage import ItemStorageABC, PaginatedResults from invokeai.app.services.item_storage import ItemStorageABC
from invokeai.app.services.metadata import MetadataServiceBase from invokeai.app.services.models.image_record import (ImageDTO, ImageRecord,
ImageRecordChanges,
image_record_to_dto)
from invokeai.app.services.resource_name import NameServiceBase from invokeai.app.services.resource_name import NameServiceBase
from invokeai.app.services.urls import UrlServiceBase from invokeai.app.services.urls import UrlServiceBase
from invokeai.app.util.metadata import get_metadata_graph_from_raw_session
if TYPE_CHECKING: if TYPE_CHECKING:
from invokeai.app.services.graph import GraphExecutionState from invokeai.app.services.graph import GraphExecutionState
@ -51,6 +42,7 @@ class ImageServiceABC(ABC):
node_id: Optional[str] = None, node_id: Optional[str] = None,
session_id: Optional[str] = None, session_id: Optional[str] = None,
is_intermediate: bool = False, is_intermediate: bool = False,
metadata: Optional[dict] = None,
) -> ImageDTO: ) -> ImageDTO:
"""Creates an image, storing the file and its metadata.""" """Creates an image, storing the file and its metadata."""
pass pass
@ -79,6 +71,11 @@ class ImageServiceABC(ABC):
"""Gets an image DTO.""" """Gets an image DTO."""
pass pass
@abstractmethod
def get_metadata(self, image_name: str) -> ImageMetadata:
"""Gets an image's metadata."""
pass
@abstractmethod @abstractmethod
def get_path(self, image_name: str, thumbnail: bool = False) -> str: def get_path(self, image_name: str, thumbnail: bool = False) -> str:
"""Gets an image's path.""" """Gets an image's path."""
@ -124,7 +121,6 @@ class ImageServiceDependencies:
image_records: ImageRecordStorageBase image_records: ImageRecordStorageBase
image_files: ImageFileStorageBase image_files: ImageFileStorageBase
board_image_records: BoardImageRecordStorageBase board_image_records: BoardImageRecordStorageBase
metadata: MetadataServiceBase
urls: UrlServiceBase urls: UrlServiceBase
logger: Logger logger: Logger
names: NameServiceBase names: NameServiceBase
@ -135,7 +131,6 @@ class ImageServiceDependencies:
image_record_storage: ImageRecordStorageBase, image_record_storage: ImageRecordStorageBase,
image_file_storage: ImageFileStorageBase, image_file_storage: ImageFileStorageBase,
board_image_record_storage: BoardImageRecordStorageBase, board_image_record_storage: BoardImageRecordStorageBase,
metadata: MetadataServiceBase,
url: UrlServiceBase, url: UrlServiceBase,
logger: Logger, logger: Logger,
names: NameServiceBase, names: NameServiceBase,
@ -144,7 +139,6 @@ class ImageServiceDependencies:
self.image_records = image_record_storage self.image_records = image_record_storage
self.image_files = image_file_storage self.image_files = image_file_storage
self.board_image_records = board_image_record_storage self.board_image_records = board_image_record_storage
self.metadata = metadata
self.urls = url self.urls = url
self.logger = logger self.logger = logger
self.names = names self.names = names
@ -165,6 +159,7 @@ class ImageService(ImageServiceABC):
node_id: Optional[str] = None, node_id: Optional[str] = None,
session_id: Optional[str] = None, session_id: Optional[str] = None,
is_intermediate: bool = False, is_intermediate: bool = False,
metadata: Optional[dict] = None,
) -> ImageDTO: ) -> ImageDTO:
if image_origin not in ResourceOrigin: if image_origin not in ResourceOrigin:
raise InvalidOriginException raise InvalidOriginException
@ -174,7 +169,16 @@ class ImageService(ImageServiceABC):
image_name = self._services.names.create_image_name() image_name = self._services.names.create_image_name()
metadata = self._get_metadata(session_id, node_id) graph = None
if session_id is not None:
session_raw = self._services.graph_execution_manager.get_raw(session_id)
if session_raw is not None:
try:
graph = get_metadata_graph_from_raw_session(session_raw)
except Exception as e:
self._services.logger.warn(f"Failed to parse session graph: {e}")
graph = None
(width, height) = image.size (width, height) = image.size
@ -191,14 +195,12 @@ class ImageService(ImageServiceABC):
is_intermediate=is_intermediate, is_intermediate=is_intermediate,
# Nullable fields # Nullable fields
node_id=node_id, node_id=node_id,
session_id=session_id,
metadata=metadata, metadata=metadata,
session_id=session_id,
) )
self._services.image_files.save( self._services.image_files.save(
image_name=image_name, image_name=image_name, image=image, metadata=metadata, graph=graph
image=image,
metadata=metadata,
) )
image_dto = self.get_dto(image_name) image_dto = self.get_dto(image_name)
@ -268,6 +270,34 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem getting image DTO") self._services.logger.error("Problem getting image DTO")
raise e raise e
def get_metadata(self, image_name: str) -> Optional[ImageMetadata]:
try:
image_record = self._services.image_records.get(image_name)
if not image_record.session_id:
return ImageMetadata()
session_raw = self._services.graph_execution_manager.get_raw(
image_record.session_id
)
graph = None
if session_raw:
try:
graph = get_metadata_graph_from_raw_session(session_raw)
except Exception as e:
self._services.logger.warn(f"Failed to parse session graph: {e}")
graph = None
metadata = self._services.image_records.get_metadata(image_name)
return ImageMetadata(graph=graph, metadata=metadata)
except ImageRecordNotFoundException:
self._services.logger.error("Image record not found")
raise
except Exception as e:
self._services.logger.error("Problem getting image DTO")
raise e
def get_path(self, image_name: str, thumbnail: bool = False) -> str: def get_path(self, image_name: str, thumbnail: bool = False) -> str:
try: try:
return self._services.image_files.get_path(image_name, thumbnail) return self._services.image_files.get_path(image_name, thumbnail)
@ -367,15 +397,3 @@ class ImageService(ImageServiceABC):
except Exception as e: except Exception as e:
self._services.logger.error("Problem deleting image records and files") self._services.logger.error("Problem deleting image records and files")
raise e raise e
def _get_metadata(
self, session_id: Optional[str] = None, node_id: Optional[str] = None
) -> Optional[ImageMetadata]:
"""Get the metadata for a node."""
metadata = None
if node_id is not None and session_id is not None:
session = self._services.graph_execution_manager.get(session_id)
metadata = self._services.metadata.create_image_metadata(session, node_id)
return metadata

View File

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Callable, Generic, TypeVar from typing import Callable, Generic, Optional, TypeVar
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from pydantic.generics import GenericModel from pydantic.generics import GenericModel
@ -29,14 +29,22 @@ class ItemStorageABC(ABC, Generic[T]):
@abstractmethod @abstractmethod
def get(self, item_id: str) -> T: def get(self, item_id: str) -> T:
"""Gets the item, parsing it into a Pydantic model"""
pass
@abstractmethod
def get_raw(self, item_id: str) -> Optional[str]:
"""Gets the raw item as a string, skipping Pydantic parsing"""
pass pass
@abstractmethod @abstractmethod
def set(self, item: T) -> None: def set(self, item: T) -> None:
"""Sets the item"""
pass pass
@abstractmethod @abstractmethod
def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]: def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
"""Gets a paginated list of items"""
pass pass
@abstractmethod @abstractmethod

View File

@ -1,142 +0,0 @@
from abc import ABC, abstractmethod
from typing import Any, Optional
import networkx as nx
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.services.graph import Graph, GraphExecutionState
class MetadataServiceBase(ABC):
"""Handles building metadata for nodes, images, and outputs."""
@abstractmethod
def create_image_metadata(
self, session: GraphExecutionState, node_id: str
) -> ImageMetadata:
"""Builds an ImageMetadata object for a node."""
pass
class CoreMetadataService(MetadataServiceBase):
_ANCESTOR_TYPES = ["t2l", "l2l"]
"""The ancestor types that contain the core metadata"""
_ANCESTOR_PARAMS = ["type", "steps", "model", "cfg_scale", "scheduler", "strength"]
"""The core metadata parameters in the ancestor types"""
_NOISE_FIELDS = ["seed", "width", "height"]
"""The core metadata parameters in the noise node"""
def create_image_metadata(
self, session: GraphExecutionState, node_id: str
) -> ImageMetadata:
metadata = self._build_metadata_from_graph(session, node_id)
return metadata
def _find_nearest_ancestor(self, G: nx.DiGraph, node_id: str) -> Optional[str]:
"""
Finds the id of the nearest ancestor (of a valid type) of a given node.
Parameters:
G (nx.DiGraph): The execution graph, converted in to a networkx DiGraph. Its nodes must
have the same data as the execution graph.
node_id (str): The ID of the node.
Returns:
str | None: The ID of the nearest ancestor, or None if there are no valid ancestors.
"""
# Retrieve the node from the graph
node = G.nodes[node_id]
# If the node type is one of the core metadata node types, return its id
if node.get("type") in self._ANCESTOR_TYPES:
return node.get("id")
# Else, look for the ancestor in the predecessor nodes
for predecessor in G.predecessors(node_id):
result = self._find_nearest_ancestor(G, predecessor)
if result:
return result
# If there are no valid ancestors, return None
return None
def _get_additional_metadata(
self, graph: Graph, node_id: str
) -> Optional[dict[str, Any]]:
"""
Returns additional metadata for a given node.
Parameters:
graph (Graph): The execution graph.
node_id (str): The ID of the node.
Returns:
dict[str, Any] | None: A dictionary of additional metadata.
"""
metadata = {}
# Iterate over all edges in the graph
for edge in graph.edges:
dest_node_id = edge.destination.node_id
dest_field = edge.destination.field
source_node_dict = graph.nodes[edge.source.node_id].dict()
# If the destination node ID matches the given node ID, gather necessary metadata
if dest_node_id == node_id:
# Prompt
if dest_field == "positive_conditioning":
metadata["positive_conditioning"] = source_node_dict.get("prompt")
# Negative prompt
if dest_field == "negative_conditioning":
metadata["negative_conditioning"] = source_node_dict.get("prompt")
# Seed, width and height
if dest_field == "noise":
for field in self._NOISE_FIELDS:
metadata[field] = source_node_dict.get(field)
return metadata
def _build_metadata_from_graph(
self, session: GraphExecutionState, node_id: str
) -> ImageMetadata:
"""
Builds an ImageMetadata object for a node.
Parameters:
session (GraphExecutionState): The session.
node_id (str): The ID of the node.
Returns:
ImageMetadata: The metadata for the node.
"""
# We need to do all the traversal on the execution graph
graph = session.execution_graph
# Find the nearest `t2l`/`l2l` ancestor of the given node
ancestor_id = self._find_nearest_ancestor(graph.nx_graph_with_data(), node_id)
# If no ancestor was found, return an empty ImageMetadata object
if ancestor_id is None:
return ImageMetadata()
ancestor_node = graph.get_node(ancestor_id)
# Grab all the core metadata from the ancestor node
ancestor_metadata = {
param: val
for param, val in ancestor_node.dict().items()
if param in self._ANCESTOR_PARAMS
}
# Get this image's prompts and noise parameters
addl_metadata = self._get_additional_metadata(graph, ancestor_id)
# If additional metadata was found, add it to the main metadata
if addl_metadata is not None:
ancestor_metadata.update(addl_metadata)
return ImageMetadata(**ancestor_metadata)

View File

@ -1,13 +1,14 @@
import datetime import datetime
from typing import Optional, Union from typing import Optional, Union
from pydantic import BaseModel, Extra, Field, StrictBool, StrictStr from pydantic import BaseModel, Extra, Field, StrictBool, StrictStr
from invokeai.app.models.image import ImageCategory, ResourceOrigin from invokeai.app.models.image import ImageCategory, ResourceOrigin
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.util.misc import get_iso_timestamp from invokeai.app.util.misc import get_iso_timestamp
class ImageRecord(BaseModel): class ImageRecord(BaseModel):
"""Deserialized image record.""" """Deserialized image record without metadata."""
image_name: str = Field(description="The unique name of the image.") image_name: str = Field(description="The unique name of the image.")
"""The unique name of the image.""" """The unique name of the image."""
@ -43,11 +44,6 @@ class ImageRecord(BaseModel):
description="The node ID that generated this image, if it is a generated image.", description="The node ID that generated this image, if it is a generated image.",
) )
"""The node ID that generated this image, if it is a generated image.""" """The node ID that generated this image, if it is a generated image."""
metadata: Optional[ImageMetadata] = Field(
default=None,
description="A limited subset of the image's generation metadata. Retrieve the image's session for full metadata.",
)
"""A limited subset of the image's generation metadata. Retrieve the image's session for full metadata."""
class ImageRecordChanges(BaseModel, extra=Extra.forbid): class ImageRecordChanges(BaseModel, extra=Extra.forbid):
@ -112,6 +108,7 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
# Retrieve all the values, setting "reasonable" defaults if they are not present. # Retrieve all the values, setting "reasonable" defaults if they are not present.
# TODO: do we really need to handle default values here? ideally the data is the correct shape...
image_name = image_dict.get("image_name", "unknown") image_name = image_dict.get("image_name", "unknown")
image_origin = ResourceOrigin( image_origin = ResourceOrigin(
image_dict.get("image_origin", ResourceOrigin.INTERNAL.value) image_dict.get("image_origin", ResourceOrigin.INTERNAL.value)
@ -128,13 +125,6 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
deleted_at = image_dict.get("deleted_at", get_iso_timestamp()) deleted_at = image_dict.get("deleted_at", get_iso_timestamp())
is_intermediate = image_dict.get("is_intermediate", False) is_intermediate = image_dict.get("is_intermediate", False)
raw_metadata = image_dict.get("metadata")
if raw_metadata is not None:
metadata = ImageMetadata.parse_raw(raw_metadata)
else:
metadata = None
return ImageRecord( return ImageRecord(
image_name=image_name, image_name=image_name,
image_origin=image_origin, image_origin=image_origin,
@ -143,7 +133,6 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
height=height, height=height,
session_id=session_id, session_id=session_id,
node_id=node_id, node_id=node_id,
metadata=metadata,
created_at=created_at, created_at=created_at,
updated_at=updated_at, updated_at=updated_at,
deleted_at=deleted_at, deleted_at=deleted_at,

View File

@ -1,6 +1,6 @@
import sqlite3 import sqlite3
from threading import Lock from threading import Lock
from typing import Generic, TypeVar, Optional, Union, get_args from typing import Generic, Optional, TypeVar, get_args
from pydantic import BaseModel, parse_raw_as from pydantic import BaseModel, parse_raw_as
@ -78,6 +78,21 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
return self._parse_item(result[0]) return self._parse_item(result[0])
def get_raw(self, id: str) -> Optional[str]:
try:
self._lock.acquire()
self._cursor.execute(
f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),)
)
result = self._cursor.fetchone()
finally:
self._lock.release()
if not result:
return None
return result[0]
def delete(self, id: str): def delete(self, id: str):
try: try:
self._lock.acquire() self._lock.acquire()

View File

@ -22,4 +22,4 @@ class LocalUrlService(UrlServiceBase):
if thumbnail: if thumbnail:
return f"{self._base_url}/images/{image_basename}/thumbnail" return f"{self._base_url}/images/{image_basename}/thumbnail"
return f"{self._base_url}/images/{image_basename}" return f"{self._base_url}/images/{image_basename}/full"

View File

@ -0,0 +1,55 @@
import json
from typing import Optional
from pydantic import ValidationError
from invokeai.app.services.graph import Edge
def get_metadata_graph_from_raw_session(session_raw: str) -> Optional[dict]:
"""
Parses raw session string, returning a dict of the graph.
Only the general graph shape is validated; none of the fields are validated.
Any `metadata_accumulator` nodes and edges are removed.
Any validation failure will return None.
"""
graph = json.loads(session_raw).get("graph", None)
# sanity check make sure the graph is at least reasonably shaped
if (
type(graph) is not dict
or "nodes" not in graph
or type(graph["nodes"]) is not dict
or "edges" not in graph
or type(graph["edges"]) is not list
):
# something has gone terribly awry, return an empty dict
return None
try:
# delete the `metadata_accumulator` node
del graph["nodes"]["metadata_accumulator"]
except KeyError:
# no accumulator node, all good
pass
# delete any edges to or from it
for i, edge in enumerate(graph["edges"]):
try:
# try to parse the edge
Edge(**edge)
except ValidationError:
# something has gone terribly awry, return an empty dict
return None
if (
edge["source"]["node_id"] == "metadata_accumulator"
or edge["destination"]["node_id"] == "metadata_accumulator"
):
del graph["edges"][i]
return graph

View File

@ -121,8 +121,8 @@ class ModelInstall(object):
installed_models = self.mgr.list_models() installed_models = self.mgr.list_models()
for md in installed_models: for md in installed_models:
base = md['base_model'] base = md['base_model']
model_type = md['type'] model_type = md['model_type']
name = md['name'] name = md['model_name']
key = ModelManager.create_key(name, base, model_type) key = ModelManager.create_key(name, base, model_type)
if key in model_dict: if key in model_dict:
model_dict[key].installed = True model_dict[key].installed = True

View File

@ -538,9 +538,9 @@ class ModelManager(object):
model_dict = dict( model_dict = dict(
**model_config.dict(exclude_defaults=True), **model_config.dict(exclude_defaults=True),
# OpenAPIModelInfoBase # OpenAPIModelInfoBase
name=cur_model_name, model_name=cur_model_name,
base_model=cur_base_model, base_model=cur_base_model,
type=cur_model_type, model_type=cur_model_type,
) )
models.append(model_dict) models.append(model_dict)

View File

@ -37,9 +37,9 @@ MODEL_CONFIGS = list()
OPENAPI_MODEL_CONFIGS = list() OPENAPI_MODEL_CONFIGS = list()
class OpenAPIModelInfoBase(BaseModel): class OpenAPIModelInfoBase(BaseModel):
name: str model_name: str
base_model: BaseModelType base_model: BaseModelType
type: ModelType model_type: ModelType
for base_model, models in MODEL_CLASSES.items(): for base_model, models in MODEL_CLASSES.items():
@ -56,7 +56,7 @@ for base_model, models in MODEL_CLASSES.items():
api_wrapper = type(openapi_cfg_name, (cfg, OpenAPIModelInfoBase), dict( api_wrapper = type(openapi_cfg_name, (cfg, OpenAPIModelInfoBase), dict(
__annotations__ = dict( __annotations__ = dict(
type=Literal[model_type.value], model_type=Literal[model_type.value],
), ),
)) ))

View File

@ -127,7 +127,7 @@ class AddsMaskGuidance:
def _t_for_field(self, field_name: str, t): def _t_for_field(self, field_name: str, t):
if field_name == "pred_original_sample": if field_name == "pred_original_sample":
return torch.zeros_like(t, dtype=t.dtype) # it represents t=0 return self.scheduler.timesteps[-1]
return t return t
def apply_mask(self, latents: torch.Tensor, t) -> torch.Tensor: def apply_mask(self, latents: torch.Tensor, t) -> torch.Tensor:
@ -631,7 +631,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
control_latent_input = torch.cat([unet_latent_input] * 2) control_latent_input = torch.cat([unet_latent_input] * 2)
if cfg_injection: # only applying ControlNet to conditional instead of in unconditioned if cfg_injection: # only applying ControlNet to conditional instead of in unconditioned
encoder_hidden_states = torch.cat([conditioning_data.unconditioned_embeddings]) encoder_hidden_states = conditioning_data.text_embeddings
else: else:
encoder_hidden_states = torch.cat([conditioning_data.unconditioned_embeddings, encoder_hidden_states = torch.cat([conditioning_data.unconditioned_embeddings,
conditioning_data.text_embeddings]) conditioning_data.text_embeddings])

View File

@ -36,6 +36,7 @@ module.exports = {
], ],
'prettier/prettier': ['error', { endOfLine: 'auto' }], 'prettier/prettier': ['error', { endOfLine: 'auto' }],
'@typescript-eslint/ban-ts-comment': 'warn', '@typescript-eslint/ban-ts-comment': 'warn',
'@typescript-eslint/no-explicit-any': 'warn',
'@typescript-eslint/no-empty-interface': [ '@typescript-eslint/no-empty-interface': [
'error', 'error',
{ {

View File

@ -108,6 +108,7 @@
"roarr": "^7.15.0", "roarr": "^7.15.0",
"serialize-error": "^11.0.0", "serialize-error": "^11.0.0",
"socket.io-client": "^4.7.0", "socket.io-client": "^4.7.0",
"use-debounce": "^9.0.4",
"use-image": "^1.1.1", "use-image": "^1.1.1",
"uuid": "^9.0.0", "uuid": "^9.0.0",
"zod": "^3.21.4" "zod": "^3.21.4"

View File

@ -102,7 +102,8 @@
"openInNewTab": "Open in New Tab", "openInNewTab": "Open in New Tab",
"dontAskMeAgain": "Don't ask me again", "dontAskMeAgain": "Don't ask me again",
"areYouSure": "Are you sure?", "areYouSure": "Are you sure?",
"imagePrompt": "Image Prompt" "imagePrompt": "Image Prompt",
"clearNodes": "Are you sure you want to clear all nodes?"
}, },
"gallery": { "gallery": {
"generations": "Generations", "generations": "Generations",
@ -118,7 +119,7 @@
"pinGallery": "Pin Gallery", "pinGallery": "Pin Gallery",
"allImagesLoaded": "All Images Loaded", "allImagesLoaded": "All Images Loaded",
"loadMore": "Load More", "loadMore": "Load More",
"noImagesInGallery": "No Images In Gallery", "noImagesInGallery": "No Images to Display",
"deleteImage": "Delete Image", "deleteImage": "Delete Image",
"deleteImageBin": "Deleted images will be sent to your operating system's Bin.", "deleteImageBin": "Deleted images will be sent to your operating system's Bin.",
"deleteImagePermanent": "Deleted images cannot be restored.", "deleteImagePermanent": "Deleted images cannot be restored.",
@ -600,7 +601,8 @@
"initialImageNotSetDesc": "Could not load initial image", "initialImageNotSetDesc": "Could not load initial image",
"nodesSaved": "Nodes Saved", "nodesSaved": "Nodes Saved",
"nodesLoaded": "Nodes Loaded", "nodesLoaded": "Nodes Loaded",
"nodesLoadedFailed": "Failed To Load Nodes" "nodesLoadedFailed": "Failed To Load Nodes",
"nodesCleared": "Nodes Cleared"
}, },
"tooltip": { "tooltip": {
"feature": { "feature": {
@ -685,6 +687,7 @@
"nodes": { "nodes": {
"reloadSchema": "Reload Schema", "reloadSchema": "Reload Schema",
"saveNodes": "Save Nodes", "saveNodes": "Save Nodes",
"loadNodes": "Load Nodes" "loadNodes": "Load Nodes",
"clearNodes": "Clear Nodes"
} }
} }

View File

@ -6,9 +6,7 @@ import { PartialAppConfig } from 'app/types/invokeai';
import ImageUploader from 'common/components/ImageUploader'; import ImageUploader from 'common/components/ImageUploader';
import GalleryDrawer from 'features/gallery/components/GalleryPanel'; import GalleryDrawer from 'features/gallery/components/GalleryPanel';
import DeleteImageModal from 'features/imageDeletion/components/DeleteImageModal'; import DeleteImageModal from 'features/imageDeletion/components/DeleteImageModal';
import Lightbox from 'features/lightbox/components/Lightbox';
import SiteHeader from 'features/system/components/SiteHeader'; import SiteHeader from 'features/system/components/SiteHeader';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { configChanged } from 'features/system/store/configSlice'; import { configChanged } from 'features/system/store/configSlice';
import { languageSelector } from 'features/system/store/systemSelectors'; import { languageSelector } from 'features/system/store/systemSelectors';
import FloatingGalleryButton from 'features/ui/components/FloatingGalleryButton'; import FloatingGalleryButton from 'features/ui/components/FloatingGalleryButton';
@ -34,8 +32,6 @@ const App = ({ config = DEFAULT_CONFIG, headerComponent }: Props) => {
const log = useLogger(); const log = useLogger();
const isLightboxEnabled = useFeatureStatus('lightbox').isFeatureEnabled;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
useEffect(() => { useEffect(() => {
@ -54,7 +50,6 @@ const App = ({ config = DEFAULT_CONFIG, headerComponent }: Props) => {
return ( return (
<> <>
<Grid w="100vw" h="100vh" position="relative" overflow="hidden"> <Grid w="100vw" h="100vh" position="relative" overflow="hidden">
{isLightboxEnabled && <Lightbox />}
<ImageUploader> <ImageUploader>
<Grid <Grid
sx={{ sx={{

View File

@ -1,8 +1,4 @@
import { Box, ChakraProps, Flex, Heading, Image } from '@chakra-ui/react'; import { Box, ChakraProps, Flex, Heading, Image } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { memo } from 'react'; import { memo } from 'react';
import { TypesafeDraggableData } from './typesafeDnd'; import { TypesafeDraggableData } from './typesafeDnd';
@ -32,24 +28,7 @@ const STYLES: ChakraProps['sx'] = {
}, },
}; };
const selector = createSelector(
stateSelector,
(state) => {
const gallerySelectionCount = state.gallery.selection.length;
const batchSelectionCount = state.batch.selection.length;
return {
gallerySelectionCount,
batchSelectionCount,
};
},
defaultSelectorOptions
);
const DragPreview = (props: OverlayDragImageProps) => { const DragPreview = (props: OverlayDragImageProps) => {
const { gallerySelectionCount, batchSelectionCount } =
useAppSelector(selector);
if (!props.dragData) { if (!props.dragData) {
return; return;
} }
@ -82,7 +61,7 @@ const DragPreview = (props: OverlayDragImageProps) => {
); );
} }
if (props.dragData.payloadType === 'BATCH_SELECTION') { if (props.dragData.payloadType === 'IMAGE_NAMES') {
return ( return (
<Flex <Flex
sx={{ sx={{
@ -95,26 +74,7 @@ const DragPreview = (props: OverlayDragImageProps) => {
...STYLES, ...STYLES,
}} }}
> >
<Heading>{batchSelectionCount}</Heading> <Heading>{props.dragData.payload.image_names.length}</Heading>
<Heading size="sm">Images</Heading>
</Flex>
);
}
if (props.dragData.payloadType === 'GALLERY_SELECTION') {
return (
<Flex
sx={{
cursor: 'none',
userSelect: 'none',
position: 'relative',
alignItems: 'center',
justifyContent: 'center',
flexDir: 'column',
...STYLES,
}}
>
<Heading>{gallerySelectionCount}</Heading>
<Heading size="sm">Images</Heading> <Heading size="sm">Images</Heading>
</Flex> </Flex>
); );

View File

@ -6,18 +6,18 @@ import {
useSensor, useSensor,
useSensors, useSensors,
} from '@dnd-kit/core'; } from '@dnd-kit/core';
import { snapCenterToCursor } from '@dnd-kit/modifiers';
import { dndDropped } from 'app/store/middleware/listenerMiddleware/listeners/imageDropped';
import { useAppDispatch } from 'app/store/storeHooks';
import { AnimatePresence, motion } from 'framer-motion';
import { PropsWithChildren, memo, useCallback, useState } from 'react'; import { PropsWithChildren, memo, useCallback, useState } from 'react';
import DragPreview from './DragPreview'; import DragPreview from './DragPreview';
import { snapCenterToCursor } from '@dnd-kit/modifiers';
import { AnimatePresence, motion } from 'framer-motion';
import { import {
DndContext, DndContext,
DragEndEvent, DragEndEvent,
DragStartEvent, DragStartEvent,
TypesafeDraggableData, TypesafeDraggableData,
} from './typesafeDnd'; } from './typesafeDnd';
import { useAppDispatch } from 'app/store/storeHooks';
import { imageDropped } from 'app/store/middleware/listenerMiddleware/listeners/imageDropped';
type ImageDndContextProps = PropsWithChildren; type ImageDndContextProps = PropsWithChildren;
@ -42,18 +42,18 @@ const ImageDndContext = (props: ImageDndContextProps) => {
if (!activeData || !overData) { if (!activeData || !overData) {
return; return;
} }
dispatch(imageDropped({ overData, activeData })); dispatch(dndDropped({ overData, activeData }));
setActiveDragData(null); setActiveDragData(null);
}, },
[dispatch] [dispatch]
); );
const mouseSensor = useSensor(MouseSensor, { const mouseSensor = useSensor(MouseSensor, {
activationConstraint: { delay: 150, tolerance: 5 }, activationConstraint: { distance: 10 },
}); });
const touchSensor = useSensor(TouchSensor, { const touchSensor = useSensor(TouchSensor, {
activationConstraint: { delay: 150, tolerance: 5 }, activationConstraint: { distance: 10 },
}); });
// TODO: Use KeyboardSensor - needs composition of multiple collisionDetection algos // TODO: Use KeyboardSensor - needs composition of multiple collisionDetection algos

View File

@ -77,18 +77,14 @@ export type ImageDraggableData = BaseDragData & {
payload: { imageDTO: ImageDTO }; payload: { imageDTO: ImageDTO };
}; };
export type GallerySelectionDraggableData = BaseDragData & { export type ImageNamesDraggableData = BaseDragData & {
payloadType: 'GALLERY_SELECTION'; payloadType: 'IMAGE_NAMES';
}; payload: { image_names: string[] };
export type BatchSelectionDraggableData = BaseDragData & {
payloadType: 'BATCH_SELECTION';
}; };
export type TypesafeDraggableData = export type TypesafeDraggableData =
| ImageDraggableData | ImageDraggableData
| GallerySelectionDraggableData | ImageNamesDraggableData;
| BatchSelectionDraggableData;
interface UseDroppableTypesafeArguments interface UseDroppableTypesafeArguments
extends Omit<UseDroppableArguments, 'data'> { extends Omit<UseDroppableArguments, 'data'> {
@ -159,13 +155,11 @@ export const isValidDrop = (
case 'SET_NODES_IMAGE': case 'SET_NODES_IMAGE':
return payloadType === 'IMAGE_DTO'; return payloadType === 'IMAGE_DTO';
case 'SET_MULTI_NODES_IMAGE': case 'SET_MULTI_NODES_IMAGE':
return payloadType === 'IMAGE_DTO' || 'GALLERY_SELECTION'; return payloadType === 'IMAGE_DTO' || 'IMAGE_NAMES';
case 'ADD_TO_BATCH': case 'ADD_TO_BATCH':
return payloadType === 'IMAGE_DTO' || 'GALLERY_SELECTION'; return payloadType === 'IMAGE_DTO' || 'IMAGE_NAMES';
case 'MOVE_BOARD': case 'MOVE_BOARD':
return ( return payloadType === 'IMAGE_DTO' || 'IMAGE_NAMES';
payloadType === 'IMAGE_DTO' || 'GALLERY_SELECTION' || 'BATCH_SELECTION'
);
default: default:
return false; return false;
} }

View File

@ -1,67 +0,0 @@
// import { createAction } from '@reduxjs/toolkit';
// import * as InvokeAI from 'app/types/invokeai';
// import { GalleryCategory } from 'features/gallery/store/gallerySlice';
// import { InvokeTabName } from 'features/ui/store/tabMap';
// /**
// * We can't use redux-toolkit's createSlice() to make these actions,
// * because they have no associated reducer. They only exist to dispatch
// * requests to the server via socketio. These actions will be handled
// * by the middleware.
// */
// export const generateImage = createAction<InvokeTabName>(
// 'socketio/generateImage'
// );
// export const runESRGAN = createAction<InvokeAI._Image>('socketio/runESRGAN');
// export const runFacetool = createAction<InvokeAI._Image>(
// 'socketio/runFacetool'
// );
// export const deleteImage = createAction<InvokeAI._Image>(
// 'socketio/deleteImage'
// );
// export const requestImages = createAction<GalleryCategory>(
// 'socketio/requestImages'
// );
// export const requestNewImages = createAction<GalleryCategory>(
// 'socketio/requestNewImages'
// );
// export const cancelProcessing = createAction<undefined>(
// 'socketio/cancelProcessing'
// );
// export const requestSystemConfig = createAction<undefined>(
// 'socketio/requestSystemConfig'
// );
// export const searchForModels = createAction<string>('socketio/searchForModels');
// export const addNewModel = createAction<
// InvokeAI.InvokeModelConfigProps | InvokeAI.InvokeDiffusersModelConfigProps
// >('socketio/addNewModel');
// export const deleteModel = createAction<string>('socketio/deleteModel');
// export const convertToDiffusers =
// createAction<InvokeAI.InvokeModelConversionProps>(
// 'socketio/convertToDiffusers'
// );
// export const mergeDiffusersModels =
// createAction<InvokeAI.InvokeModelMergingProps>(
// 'socketio/mergeDiffusersModels'
// );
// export const requestModelChange = createAction<string>(
// 'socketio/requestModelChange'
// );
// export const saveStagingAreaImageToGallery = createAction<string>(
// 'socketio/saveStagingAreaImageToGallery'
// );
// export const emptyTempFolder = createAction<undefined>(
// 'socketio/requestEmptyTempFolder'
// );
export default {};

View File

@ -1,209 +0,0 @@
import { AnyAction, Dispatch, MiddlewareAPI } from '@reduxjs/toolkit';
import * as InvokeAI from 'app/types/invokeai';
import type { RootState } from 'app/store/store';
import {
frontendToBackendParameters,
FrontendToBackendParametersConfig,
} from 'common/util/parameterTranslation';
import dateFormat from 'dateformat';
import {
GalleryCategory,
GalleryState,
removeImage,
} from 'features/gallery/store/gallerySlice';
import {
generationRequested,
modelChangeRequested,
modelConvertRequested,
modelMergingRequested,
setIsProcessing,
} from 'features/system/store/systemSlice';
import { InvokeTabName } from 'features/ui/store/tabMap';
import { Socket } from 'socket.io-client';
/**
* Returns an object containing all functions which use `socketio.emit()`.
* i.e. those which make server requests.
*/
const makeSocketIOEmitters = (
store: MiddlewareAPI<Dispatch<AnyAction>, RootState>,
socketio: Socket
) => {
// We need to dispatch actions to redux and get pieces of state from the store.
const { dispatch, getState } = store;
return {
emitGenerateImage: (generationMode: InvokeTabName) => {
dispatch(setIsProcessing(true));
const state: RootState = getState();
const {
generation: generationState,
postprocessing: postprocessingState,
system: systemState,
canvas: canvasState,
} = state;
const frontendToBackendParametersConfig: FrontendToBackendParametersConfig =
{
generationMode,
generationState,
postprocessingState,
canvasState,
systemState,
};
dispatch(generationRequested());
const { generationParameters, esrganParameters, facetoolParameters } =
frontendToBackendParameters(frontendToBackendParametersConfig);
socketio.emit(
'generateImage',
generationParameters,
esrganParameters,
facetoolParameters
);
// we need to truncate the init_mask base64 else it takes up the whole log
// TODO: handle maintaining masks for reproducibility in future
if (generationParameters.init_mask) {
generationParameters.init_mask = generationParameters.init_mask
.substr(0, 64)
.concat('...');
}
if (generationParameters.init_img) {
generationParameters.init_img = generationParameters.init_img
.substr(0, 64)
.concat('...');
}
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Image generation requested: ${JSON.stringify({
...generationParameters,
...esrganParameters,
...facetoolParameters,
})}`,
})
);
},
emitRunESRGAN: (imageToProcess: InvokeAI._Image) => {
dispatch(setIsProcessing(true));
const {
postprocessing: {
upscalingLevel,
upscalingDenoising,
upscalingStrength,
},
} = getState();
const esrganParameters = {
upscale: [upscalingLevel, upscalingDenoising, upscalingStrength],
};
socketio.emit('runPostprocessing', imageToProcess, {
type: 'esrgan',
...esrganParameters,
});
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `ESRGAN upscale requested: ${JSON.stringify({
file: imageToProcess.url,
...esrganParameters,
})}`,
})
);
},
emitRunFacetool: (imageToProcess: InvokeAI._Image) => {
dispatch(setIsProcessing(true));
const {
postprocessing: { facetoolType, facetoolStrength, codeformerFidelity },
} = getState();
const facetoolParameters: Record<string, unknown> = {
facetool_strength: facetoolStrength,
};
if (facetoolType === 'codeformer') {
facetoolParameters.codeformer_fidelity = codeformerFidelity;
}
socketio.emit('runPostprocessing', imageToProcess, {
type: facetoolType,
...facetoolParameters,
});
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Face restoration (${facetoolType}) requested: ${JSON.stringify(
{
file: imageToProcess.url,
...facetoolParameters,
}
)}`,
})
);
},
emitDeleteImage: (imageToDelete: InvokeAI._Image) => {
const { url, uuid, category, thumbnail } = imageToDelete;
dispatch(removeImage(imageToDelete));
socketio.emit('deleteImage', url, thumbnail, uuid, category);
},
emitRequestImages: (category: GalleryCategory) => {
const gallery: GalleryState = getState().gallery;
const { earliest_mtime } = gallery.categories[category];
socketio.emit('requestImages', category, earliest_mtime);
},
emitRequestNewImages: (category: GalleryCategory) => {
const gallery: GalleryState = getState().gallery;
const { latest_mtime } = gallery.categories[category];
socketio.emit('requestLatestImages', category, latest_mtime);
},
emitCancelProcessing: () => {
socketio.emit('cancel');
},
emitRequestSystemConfig: () => {
socketio.emit('requestSystemConfig');
},
emitSearchForModels: (modelFolder: string) => {
socketio.emit('searchForModels', modelFolder);
},
emitAddNewModel: (modelConfig: InvokeAI.InvokeModelConfigProps) => {
socketio.emit('addNewModel', modelConfig);
},
emitDeleteModel: (modelName: string) => {
socketio.emit('deleteModel', modelName);
},
emitConvertToDiffusers: (
modelToConvert: InvokeAI.InvokeModelConversionProps
) => {
dispatch(modelConvertRequested());
socketio.emit('convertToDiffusers', modelToConvert);
},
emitMergeDiffusersModels: (
modelMergeInfo: InvokeAI.InvokeModelMergingProps
) => {
dispatch(modelMergingRequested());
socketio.emit('mergeDiffusersModels', modelMergeInfo);
},
emitRequestModelChange: (modelName: string) => {
dispatch(modelChangeRequested());
socketio.emit('requestModelChange', modelName);
},
emitSaveStagingAreaImageToGallery: (url: string) => {
socketio.emit('requestSaveStagingAreaImageToGallery', url);
},
emitRequestEmptyTempFolder: () => {
socketio.emit('requestEmptyTempFolder');
},
};
};
export default makeSocketIOEmitters;
export default {};

View File

@ -1,502 +0,0 @@
// import { AnyAction, Dispatch, MiddlewareAPI } from '@reduxjs/toolkit';
// import dateFormat from 'dateformat';
// import i18n from 'i18n';
// import { v4 as uuidv4 } from 'uuid';
// import * as InvokeAI from 'app/types/invokeai';
// import {
// addToast,
// errorOccurred,
// processingCanceled,
// setCurrentStatus,
// setFoundModels,
// setIsCancelable,
// setIsConnected,
// setIsProcessing,
// setModelList,
// setSearchFolder,
// setSystemConfig,
// setSystemStatus,
// } from 'features/system/store/systemSlice';
// import {
// addGalleryImages,
// addImage,
// clearIntermediateImage,
// GalleryState,
// removeImage,
// setIntermediateImage,
// } from 'features/gallery/store/gallerySlice';
// import type { RootState } from 'app/store/store';
// import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
// import {
// clearInitialImage,
// initialImageSelected,
// setInfillMethod,
// // setInitialImage,
// setMaskPath,
// } from 'features/parameters/store/generationSlice';
// import { tabMap } from 'features/ui/store/tabMap';
// import {
// requestImages,
// requestNewImages,
// requestSystemConfig,
// } from './actions';
// /**
// * Returns an object containing listener callbacks for socketio events.
// * TODO: This file is large, but simple. Should it be split up further?
// */
// const makeSocketIOListeners = (
// store: MiddlewareAPI<Dispatch<AnyAction>, RootState>
// ) => {
// const { dispatch, getState } = store;
// return {
// /**
// * Callback to run when we receive a 'connect' event.
// */
// onConnect: () => {
// try {
// dispatch(setIsConnected(true));
// dispatch(setCurrentStatus(i18n.t('common.statusConnected')));
// dispatch(requestSystemConfig());
// const gallery: GalleryState = getState().gallery;
// if (gallery.categories.result.latest_mtime) {
// dispatch(requestNewImages('result'));
// } else {
// dispatch(requestImages('result'));
// }
// if (gallery.categories.user.latest_mtime) {
// dispatch(requestNewImages('user'));
// } else {
// dispatch(requestImages('user'));
// }
// } catch (e) {
// console.error(e);
// }
// },
// /**
// * Callback to run when we receive a 'disconnect' event.
// */
// onDisconnect: () => {
// try {
// dispatch(setIsConnected(false));
// dispatch(setCurrentStatus(i18n.t('common.statusDisconnected')));
// dispatch(
// addLogEntry({
// timestamp: dateFormat(new Date(), 'isoDateTime'),
// message: `Disconnected from server`,
// level: 'warning',
// })
// );
// } catch (e) {
// console.error(e);
// }
// },
// /**
// * Callback to run when we receive a 'generationResult' event.
// */
// onGenerationResult: (data: InvokeAI.ImageResultResponse) => {
// try {
// const state = getState();
// const { activeTab } = state.ui;
// const { shouldLoopback } = state.postprocessing;
// const { boundingBox: _, generationMode, ...rest } = data;
// const newImage = {
// uuid: uuidv4(),
// ...rest,
// };
// if (['txt2img', 'img2img'].includes(generationMode)) {
// dispatch(
// addImage({
// category: 'result',
// image: { ...newImage, category: 'result' },
// })
// );
// }
// if (generationMode === 'unifiedCanvas' && data.boundingBox) {
// const { boundingBox } = data;
// dispatch(
// addImageToStagingArea({
// image: { ...newImage, category: 'temp' },
// boundingBox,
// })
// );
// if (state.canvas.shouldAutoSave) {
// dispatch(
// addImage({
// image: { ...newImage, category: 'result' },
// category: 'result',
// })
// );
// }
// }
// // TODO: fix
// // if (shouldLoopback) {
// // const activeTabName = tabMap[activeTab];
// // switch (activeTabName) {
// // case 'img2img': {
// // dispatch(initialImageSelected(newImage.uuid));
// // // dispatch(setInitialImage(newImage));
// // break;
// // }
// // }
// // }
// dispatch(clearIntermediateImage());
// dispatch(
// addLogEntry({
// timestamp: dateFormat(new Date(), 'isoDateTime'),
// message: `Image generated: ${data.url}`,
// })
// );
// } catch (e) {
// console.error(e);
// }
// },
// /**
// * Callback to run when we receive a 'intermediateResult' event.
// */
// onIntermediateResult: (data: InvokeAI.ImageResultResponse) => {
// try {
// dispatch(
// setIntermediateImage({
// uuid: uuidv4(),
// ...data,
// category: 'result',
// })
// );
// if (!data.isBase64) {
// dispatch(
// addLogEntry({
// timestamp: dateFormat(new Date(), 'isoDateTime'),
// message: `Intermediate image generated: ${data.url}`,
// })
// );
// }
// } catch (e) {
// console.error(e);
// }
// },
// /**
// * Callback to run when we receive an 'esrganResult' event.
// */
// onPostprocessingResult: (data: InvokeAI.ImageResultResponse) => {
// try {
// dispatch(
// addImage({
// category: 'result',
// image: {
// uuid: uuidv4(),
// ...data,
// category: 'result',
// },
// })
// );
// dispatch(
// addLogEntry({
// timestamp: dateFormat(new Date(), 'isoDateTime'),
// message: `Postprocessed: ${data.url}`,
// })
// );
// } catch (e) {
// console.error(e);
// }
// },
// /**
// * Callback to run when we receive a 'progressUpdate' event.
// * TODO: Add additional progress phases
// */
// onProgressUpdate: (data: InvokeAI.SystemStatus) => {
// try {
// dispatch(setIsProcessing(true));
// dispatch(setSystemStatus(data));
// } catch (e) {
// console.error(e);
// }
// },
// /**
// * Callback to run when we receive a 'progressUpdate' event.
// */
// onError: (data: InvokeAI.ErrorResponse) => {
// const { message, additionalData } = data;
// if (additionalData) {
// // TODO: handle more data than short message
// }
// try {
// dispatch(
// addLogEntry({
// timestamp: dateFormat(new Date(), 'isoDateTime'),
// message: `Server error: ${message}`,
// level: 'error',
// })
// );
// dispatch(errorOccurred());
// dispatch(clearIntermediateImage());
// } catch (e) {
// console.error(e);
// }
// },
// /**
// * Callback to run when we receive a 'galleryImages' event.
// */
// onGalleryImages: (data: InvokeAI.GalleryImagesResponse) => {
// const { images, areMoreImagesAvailable, category } = data;
// /**
// * the logic here ideally would be in the reducer but we have a side effect:
// * generating a uuid. so the logic needs to be here, outside redux.
// */
// // Generate a UUID for each image
// const preparedImages = images.map((image): InvokeAI._Image => {
// return {
// uuid: uuidv4(),
// ...image,
// };
// });
// dispatch(
// addGalleryImages({
// images: preparedImages,
// areMoreImagesAvailable,
// category,
// })
// );
// dispatch(
// addLogEntry({
// timestamp: dateFormat(new Date(), 'isoDateTime'),
// message: `Loaded ${images.length} images`,
// })
// );
// },
// /**
// * Callback to run when we receive a 'processingCanceled' event.
// */
// onProcessingCanceled: () => {
// dispatch(processingCanceled());
// const { intermediateImage } = getState().gallery;
// if (intermediateImage) {
// if (!intermediateImage.isBase64) {
// dispatch(
// addImage({
// category: 'result',
// image: intermediateImage,
// })
// );
// dispatch(
// addLogEntry({
// timestamp: dateFormat(new Date(), 'isoDateTime'),
// message: `Intermediate image saved: ${intermediateImage.url}`,
// })
// );
// }
// dispatch(clearIntermediateImage());
// }
// dispatch(
// addLogEntry({
// timestamp: dateFormat(new Date(), 'isoDateTime'),
// message: `Processing canceled`,
// level: 'warning',
// })
// );
// },
// /**
// * Callback to run when we receive a 'imageDeleted' event.
// */
// onImageDeleted: (data: InvokeAI.ImageDeletedResponse) => {
// const { url } = data;
// // remove image from gallery
// dispatch(removeImage(data));
// // remove references to image in options
// const {
// generation: { initialImage, maskPath },
// } = getState();
// if (
// initialImage === url ||
// (initialImage as InvokeAI._Image)?.url === url
// ) {
// dispatch(clearInitialImage());
// }
// if (maskPath === url) {
// dispatch(setMaskPath(''));
// }
// dispatch(
// addLogEntry({
// timestamp: dateFormat(new Date(), 'isoDateTime'),
// message: `Image deleted: ${url}`,
// })
// );
// },
// onSystemConfig: (data: InvokeAI.SystemConfig) => {
// dispatch(setSystemConfig(data));
// if (!data.infill_methods.includes('patchmatch')) {
// dispatch(setInfillMethod(data.infill_methods[0]));
// }
// },
// onFoundModels: (data: InvokeAI.FoundModelResponse) => {
// const { search_folder, found_models } = data;
// dispatch(setSearchFolder(search_folder));
// dispatch(setFoundModels(found_models));
// },
// onNewModelAdded: (data: InvokeAI.ModelAddedResponse) => {
// const { new_model_name, model_list, update } = data;
// dispatch(setModelList(model_list));
// dispatch(setIsProcessing(false));
// dispatch(setCurrentStatus(i18n.t('modelManager.modelAdded')));
// dispatch(
// addLogEntry({
// timestamp: dateFormat(new Date(), 'isoDateTime'),
// message: `Model Added: ${new_model_name}`,
// level: 'info',
// })
// );
// dispatch(
// addToast({
// title: !update
// ? `${i18n.t('modelManager.modelAdded')}: ${new_model_name}`
// : `${i18n.t('modelManager.modelUpdated')}: ${new_model_name}`,
// status: 'success',
// duration: 2500,
// isClosable: true,
// })
// );
// },
// onModelDeleted: (data: InvokeAI.ModelDeletedResponse) => {
// const { deleted_model_name, model_list } = data;
// dispatch(setModelList(model_list));
// dispatch(setIsProcessing(false));
// dispatch(
// addLogEntry({
// timestamp: dateFormat(new Date(), 'isoDateTime'),
// message: `${i18n.t(
// 'modelManager.modelAdded'
// )}: ${deleted_model_name}`,
// level: 'info',
// })
// );
// dispatch(
// addToast({
// title: `${i18n.t(
// 'modelManager.modelEntryDeleted'
// )}: ${deleted_model_name}`,
// status: 'success',
// duration: 2500,
// isClosable: true,
// })
// );
// },
// onModelConverted: (data: InvokeAI.ModelConvertedResponse) => {
// const { converted_model_name, model_list } = data;
// dispatch(setModelList(model_list));
// dispatch(setCurrentStatus(i18n.t('common.statusModelConverted')));
// dispatch(setIsProcessing(false));
// dispatch(setIsCancelable(true));
// dispatch(
// addLogEntry({
// timestamp: dateFormat(new Date(), 'isoDateTime'),
// message: `Model converted: ${converted_model_name}`,
// level: 'info',
// })
// );
// dispatch(
// addToast({
// title: `${i18n.t(
// 'modelManager.modelConverted'
// )}: ${converted_model_name}`,
// status: 'success',
// duration: 2500,
// isClosable: true,
// })
// );
// },
// onModelsMerged: (data: InvokeAI.ModelsMergedResponse) => {
// const { merged_models, merged_model_name, model_list } = data;
// dispatch(setModelList(model_list));
// dispatch(setCurrentStatus(i18n.t('common.statusMergedModels')));
// dispatch(setIsProcessing(false));
// dispatch(setIsCancelable(true));
// dispatch(
// addLogEntry({
// timestamp: dateFormat(new Date(), 'isoDateTime'),
// message: `Models merged: ${merged_models}`,
// level: 'info',
// })
// );
// dispatch(
// addToast({
// title: `${i18n.t('modelManager.modelsMerged')}: ${merged_model_name}`,
// status: 'success',
// duration: 2500,
// isClosable: true,
// })
// );
// },
// onModelChanged: (data: InvokeAI.ModelChangeResponse) => {
// const { model_name, model_list } = data;
// dispatch(setModelList(model_list));
// dispatch(setCurrentStatus(i18n.t('common.statusModelChanged')));
// dispatch(setIsProcessing(false));
// dispatch(setIsCancelable(true));
// dispatch(
// addLogEntry({
// timestamp: dateFormat(new Date(), 'isoDateTime'),
// message: `Model changed: ${model_name}`,
// level: 'info',
// })
// );
// },
// onModelChangeFailed: (data: InvokeAI.ModelChangeResponse) => {
// const { model_name, model_list } = data;
// dispatch(setModelList(model_list));
// dispatch(setIsProcessing(false));
// dispatch(setIsCancelable(true));
// dispatch(errorOccurred());
// dispatch(
// addLogEntry({
// timestamp: dateFormat(new Date(), 'isoDateTime'),
// message: `Model change failed: ${model_name}`,
// level: 'error',
// })
// );
// },
// onTempFolderEmptied: () => {
// dispatch(
// addToast({
// title: i18n.t('toast.tempFoldersEmptied'),
// status: 'success',
// duration: 2500,
// isClosable: true,
// })
// );
// },
// };
// };
// export default makeSocketIOListeners;
export default {};

View File

@ -1,248 +0,0 @@
// import { Middleware } from '@reduxjs/toolkit';
// import { io } from 'socket.io-client';
// import makeSocketIOEmitters from './emitters';
// import makeSocketIOListeners from './listeners';
// import * as InvokeAI from 'app/types/invokeai';
// /**
// * Creates a socketio middleware to handle communication with server.
// *
// * Special `socketio/actionName` actions are created in actions.ts and
// * exported for use by the application, which treats them like any old
// * action, using `dispatch` to dispatch them.
// *
// * These actions are intercepted here, where `socketio.emit()` calls are
// * made on their behalf - see `emitters.ts`. The emitter functions
// * are the outbound communication to the server.
// *
// * Listeners are also established here - see `listeners.ts`. The listener
// * functions receive communication from the server and usually dispatch
// * some new action to handle whatever data was sent from the server.
// */
// export const socketioMiddleware = () => {
// const { origin } = new URL(window.location.href);
// const socketio = io(origin, {
// timeout: 60000,
// path: `${window.location.pathname}socket.io`,
// });
// socketio.disconnect();
// let areListenersSet = false;
// const middleware: Middleware = (store) => (next) => (action) => {
// const {
// onConnect,
// onDisconnect,
// onError,
// onPostprocessingResult,
// onGenerationResult,
// onIntermediateResult,
// onProgressUpdate,
// onGalleryImages,
// onProcessingCanceled,
// onImageDeleted,
// onSystemConfig,
// onModelChanged,
// onFoundModels,
// onNewModelAdded,
// onModelDeleted,
// onModelConverted,
// onModelsMerged,
// onModelChangeFailed,
// onTempFolderEmptied,
// } = makeSocketIOListeners(store);
// const {
// emitGenerateImage,
// emitRunESRGAN,
// emitRunFacetool,
// emitDeleteImage,
// emitRequestImages,
// emitRequestNewImages,
// emitCancelProcessing,
// emitRequestSystemConfig,
// emitSearchForModels,
// emitAddNewModel,
// emitDeleteModel,
// emitConvertToDiffusers,
// emitMergeDiffusersModels,
// emitRequestModelChange,
// emitSaveStagingAreaImageToGallery,
// emitRequestEmptyTempFolder,
// } = makeSocketIOEmitters(store, socketio);
// /**
// * If this is the first time the middleware has been called (e.g. during store setup),
// * initialize all our socket.io listeners.
// */
// if (!areListenersSet) {
// socketio.on('connect', () => onConnect());
// socketio.on('disconnect', () => onDisconnect());
// socketio.on('error', (data: InvokeAI.ErrorResponse) => onError(data));
// socketio.on('generationResult', (data: InvokeAI.ImageResultResponse) =>
// onGenerationResult(data)
// );
// socketio.on(
// 'postprocessingResult',
// (data: InvokeAI.ImageResultResponse) => onPostprocessingResult(data)
// );
// socketio.on('intermediateResult', (data: InvokeAI.ImageResultResponse) =>
// onIntermediateResult(data)
// );
// socketio.on('progressUpdate', (data: InvokeAI.SystemStatus) =>
// onProgressUpdate(data)
// );
// socketio.on('galleryImages', (data: InvokeAI.GalleryImagesResponse) =>
// onGalleryImages(data)
// );
// socketio.on('processingCanceled', () => {
// onProcessingCanceled();
// });
// socketio.on('imageDeleted', (data: InvokeAI.ImageDeletedResponse) => {
// onImageDeleted(data);
// });
// socketio.on('systemConfig', (data: InvokeAI.SystemConfig) => {
// onSystemConfig(data);
// });
// socketio.on('foundModels', (data: InvokeAI.FoundModelResponse) => {
// onFoundModels(data);
// });
// socketio.on('newModelAdded', (data: InvokeAI.ModelAddedResponse) => {
// onNewModelAdded(data);
// });
// socketio.on('modelDeleted', (data: InvokeAI.ModelDeletedResponse) => {
// onModelDeleted(data);
// });
// socketio.on('modelConverted', (data: InvokeAI.ModelConvertedResponse) => {
// onModelConverted(data);
// });
// socketio.on('modelsMerged', (data: InvokeAI.ModelsMergedResponse) => {
// onModelsMerged(data);
// });
// socketio.on('modelChanged', (data: InvokeAI.ModelChangeResponse) => {
// onModelChanged(data);
// });
// socketio.on('modelChangeFailed', (data: InvokeAI.ModelChangeResponse) => {
// onModelChangeFailed(data);
// });
// socketio.on('tempFolderEmptied', () => {
// onTempFolderEmptied();
// });
// areListenersSet = true;
// }
// /**
// * Handle redux actions caught by middleware.
// */
// switch (action.type) {
// case 'socketio/generateImage': {
// emitGenerateImage(action.payload);
// break;
// }
// case 'socketio/runESRGAN': {
// emitRunESRGAN(action.payload);
// break;
// }
// case 'socketio/runFacetool': {
// emitRunFacetool(action.payload);
// break;
// }
// case 'socketio/deleteImage': {
// emitDeleteImage(action.payload);
// break;
// }
// case 'socketio/requestImages': {
// emitRequestImages(action.payload);
// break;
// }
// case 'socketio/requestNewImages': {
// emitRequestNewImages(action.payload);
// break;
// }
// case 'socketio/cancelProcessing': {
// emitCancelProcessing();
// break;
// }
// case 'socketio/requestSystemConfig': {
// emitRequestSystemConfig();
// break;
// }
// case 'socketio/searchForModels': {
// emitSearchForModels(action.payload);
// break;
// }
// case 'socketio/addNewModel': {
// emitAddNewModel(action.payload);
// break;
// }
// case 'socketio/deleteModel': {
// emitDeleteModel(action.payload);
// break;
// }
// case 'socketio/convertToDiffusers': {
// emitConvertToDiffusers(action.payload);
// break;
// }
// case 'socketio/mergeDiffusersModels': {
// emitMergeDiffusersModels(action.payload);
// break;
// }
// case 'socketio/requestModelChange': {
// emitRequestModelChange(action.payload);
// break;
// }
// case 'socketio/saveStagingAreaImageToGallery': {
// emitSaveStagingAreaImageToGallery(action.payload);
// break;
// }
// case 'socketio/requestEmptyTempFolder': {
// emitRequestEmptyTempFolder();
// break;
// }
// }
// next(action);
// };
// return middleware;
// };
export default {};

View File

@ -1,7 +1,6 @@
import { canvasPersistDenylist } from 'features/canvas/store/canvasPersistDenylist'; import { canvasPersistDenylist } from 'features/canvas/store/canvasPersistDenylist';
import { controlNetDenylist } from 'features/controlNet/store/controlNetDenylist'; import { controlNetDenylist } from 'features/controlNet/store/controlNetDenylist';
import { galleryPersistDenylist } from 'features/gallery/store/galleryPersistDenylist'; import { galleryPersistDenylist } from 'features/gallery/store/galleryPersistDenylist';
import { lightboxPersistDenylist } from 'features/lightbox/store/lightboxPersistDenylist';
import { nodesPersistDenylist } from 'features/nodes/store/nodesPersistDenylist'; import { nodesPersistDenylist } from 'features/nodes/store/nodesPersistDenylist';
import { generationPersistDenylist } from 'features/parameters/store/generationPersistDenylist'; import { generationPersistDenylist } from 'features/parameters/store/generationPersistDenylist';
import { postprocessingPersistDenylist } from 'features/parameters/store/postprocessingPersistDenylist'; import { postprocessingPersistDenylist } from 'features/parameters/store/postprocessingPersistDenylist';
@ -16,7 +15,6 @@ const serializationDenylist: {
canvas: canvasPersistDenylist, canvas: canvasPersistDenylist,
gallery: galleryPersistDenylist, gallery: galleryPersistDenylist,
generation: generationPersistDenylist, generation: generationPersistDenylist,
lightbox: lightboxPersistDenylist,
nodes: nodesPersistDenylist, nodes: nodesPersistDenylist,
postprocessing: postprocessingPersistDenylist, postprocessing: postprocessingPersistDenylist,
system: systemPersistDenylist, system: systemPersistDenylist,

View File

@ -1,7 +1,6 @@
import { initialCanvasState } from 'features/canvas/store/canvasSlice'; import { initialCanvasState } from 'features/canvas/store/canvasSlice';
import { initialControlNetState } from 'features/controlNet/store/controlNetSlice'; import { initialControlNetState } from 'features/controlNet/store/controlNetSlice';
import { initialGalleryState } from 'features/gallery/store/gallerySlice'; import { initialGalleryState } from 'features/gallery/store/gallerySlice';
import { initialLightboxState } from 'features/lightbox/store/lightboxSlice';
import { initialNodesState } from 'features/nodes/store/nodesSlice'; import { initialNodesState } from 'features/nodes/store/nodesSlice';
import { initialGenerationState } from 'features/parameters/store/generationSlice'; import { initialGenerationState } from 'features/parameters/store/generationSlice';
import { initialPostprocessingState } from 'features/parameters/store/postprocessingSlice'; import { initialPostprocessingState } from 'features/parameters/store/postprocessingSlice';
@ -18,7 +17,6 @@ const initialStates: {
canvas: initialCanvasState, canvas: initialCanvasState,
gallery: initialGalleryState, gallery: initialGalleryState,
generation: initialGenerationState, generation: initialGenerationState,
lightbox: initialLightboxState,
nodes: initialNodesState, nodes: initialNodesState,
postprocessing: initialPostprocessingState, postprocessing: initialPostprocessingState,
system: initialSystemState, system: initialSystemState,

View File

@ -1,4 +1,8 @@
/**
* This is a list of actions that should be excluded in the Redux DevTools.
*/
export const actionsDenylist = [ export const actionsDenylist = [
// very spammy canvas actions
'canvas/setCursorPosition', 'canvas/setCursorPosition',
'canvas/setStageCoordinates', 'canvas/setStageCoordinates',
'canvas/setStageScale', 'canvas/setStageScale',
@ -7,7 +11,11 @@ export const actionsDenylist = [
'canvas/setBoundingBoxDimensions', 'canvas/setBoundingBoxDimensions',
'canvas/setIsDrawing', 'canvas/setIsDrawing',
'canvas/addPointToCurrentLine', 'canvas/addPointToCurrentLine',
// bazillions during generation
'socket/socketGeneratorProgress', 'socket/socketGeneratorProgress',
'socket/appSocketGeneratorProgress', 'socket/appSocketGeneratorProgress',
// every time user presses shift
'hotkeys/shiftKeyPressed', 'hotkeys/shiftKeyPressed',
// this happens after every state change
'@@REMEMBER_PERSISTED',
]; ];

View File

@ -8,6 +8,7 @@ import {
import type { AppDispatch, RootState } from '../../store'; import type { AppDispatch, RootState } from '../../store';
import { addCommitStagingAreaImageListener } from './listeners/addCommitStagingAreaImageListener'; import { addCommitStagingAreaImageListener } from './listeners/addCommitStagingAreaImageListener';
import { addAppConfigReceivedListener } from './listeners/appConfigReceived';
import { addAppStartedListener } from './listeners/appStarted'; import { addAppStartedListener } from './listeners/appStarted';
import { addBoardIdSelectedListener } from './listeners/boardIdSelected'; import { addBoardIdSelectedListener } from './listeners/boardIdSelected';
import { addRequestedBoardImageDeletionListener } from './listeners/boardImagesDeleted'; import { addRequestedBoardImageDeletionListener } from './listeners/boardImagesDeleted';
@ -51,12 +52,12 @@ import {
} from './listeners/imageUrlsReceived'; } from './listeners/imageUrlsReceived';
import { addInitialImageSelectedListener } from './listeners/initialImageSelected'; import { addInitialImageSelectedListener } from './listeners/initialImageSelected';
import { addModelSelectedListener } from './listeners/modelSelected'; import { addModelSelectedListener } from './listeners/modelSelected';
import { addModelsLoadedListener } from './listeners/modelsLoaded';
import { addReceivedOpenAPISchemaListener } from './listeners/receivedOpenAPISchema'; import { addReceivedOpenAPISchemaListener } from './listeners/receivedOpenAPISchema';
import { import {
addReceivedPageOfImagesFulfilledListener, addReceivedPageOfImagesFulfilledListener,
addReceivedPageOfImagesRejectedListener, addReceivedPageOfImagesRejectedListener,
} from './listeners/receivedPageOfImages'; } from './listeners/receivedPageOfImages';
import { addSelectionAddedToBatchListener } from './listeners/selectionAddedToBatch';
import { import {
addSessionCanceledFulfilledListener, addSessionCanceledFulfilledListener,
addSessionCanceledPendingListener, addSessionCanceledPendingListener,
@ -213,9 +214,6 @@ addBoardIdSelectedListener();
// Node schemas // Node schemas
addReceivedOpenAPISchemaListener(); addReceivedOpenAPISchemaListener();
// Batches
addSelectionAddedToBatchListener();
// DND // DND
addImageDroppedListener(); addImageDroppedListener();
@ -224,3 +222,5 @@ addModelSelectedListener();
// app startup // app startup
addAppStartedListener(); addAppStartedListener();
addModelsLoadedListener();
addAppConfigReceivedListener();

View File

@ -0,0 +1,17 @@
import { setInfillMethod } from 'features/parameters/store/generationSlice';
import { appInfoApi } from 'services/api/endpoints/appInfo';
import { startAppListening } from '..';
export const addAppConfigReceivedListener = () => {
startAppListening({
matcher: appInfoApi.endpoints.getAppConfig.matchFulfilled,
effect: async (action, { getState, dispatch }) => {
const { infill_methods } = action.payload;
const infillMethod = getState().generation.infillMethod;
if (!infill_methods.includes(infillMethod)) {
dispatch(setInfillMethod(infill_methods[0]));
}
},
});
};

View File

@ -1,5 +1,7 @@
import { createAction } from '@reduxjs/toolkit'; import { createAction } from '@reduxjs/toolkit';
import { import {
ASSETS_CATEGORIES,
IMAGE_CATEGORIES,
INITIAL_IMAGE_LIMIT, INITIAL_IMAGE_LIMIT,
isLoadingChanged, isLoadingChanged,
} from 'features/gallery/store/gallerySlice'; } from 'features/gallery/store/gallerySlice';
@ -20,7 +22,7 @@ export const addAppStartedListener = () => {
// fill up the gallery tab with images // fill up the gallery tab with images
await dispatch( await dispatch(
receivedPageOfImages({ receivedPageOfImages({
categories: ['general'], categories: IMAGE_CATEGORIES,
is_intermediate: false, is_intermediate: false,
offset: 0, offset: 0,
limit: INITIAL_IMAGE_LIMIT, limit: INITIAL_IMAGE_LIMIT,
@ -30,7 +32,7 @@ export const addAppStartedListener = () => {
// fill up the assets tab with images // fill up the assets tab with images
await dispatch( await dispatch(
receivedPageOfImages({ receivedPageOfImages({
categories: ['control', 'mask', 'user', 'other'], categories: ASSETS_CATEGORIES,
is_intermediate: false, is_intermediate: false,
offset: 0, offset: 0,
limit: INITIAL_IMAGE_LIMIT, limit: INITIAL_IMAGE_LIMIT,

View File

@ -1,15 +1,18 @@
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { startAppListening } from '..'; import { selectFilteredImages } from 'features/gallery/store/gallerySelectors';
import { import {
ASSETS_CATEGORIES,
IMAGE_CATEGORIES,
boardIdSelected,
imageSelected, imageSelected,
selectImagesAll, selectImagesAll,
boardIdSelected,
} from 'features/gallery/store/gallerySlice'; } from 'features/gallery/store/gallerySlice';
import { boardsApi } from 'services/api/endpoints/boards';
import { import {
IMAGES_PER_PAGE, IMAGES_PER_PAGE,
receivedPageOfImages, receivedPageOfImages,
} from 'services/api/thunks/image'; } from 'services/api/thunks/image';
import { boardsApi } from 'services/api/endpoints/boards'; import { startAppListening } from '..';
const moduleLog = log.child({ namespace: 'boards' }); const moduleLog = log.child({ namespace: 'boards' });
@ -24,19 +27,24 @@ export const addBoardIdSelectedListener = () => {
const state = getState(); const state = getState();
const allImages = selectImagesAll(state); const allImages = selectImagesAll(state);
if (!board_id) { if (board_id === 'all') {
// a board was unselected // Selected all images
dispatch(imageSelected(allImages[0]?.image_name)); dispatch(imageSelected(allImages[0]?.image_name ?? null));
return; return;
} }
const { categories } = state.gallery; if (board_id === 'batch') {
// Selected the batch
dispatch(imageSelected(state.gallery.batchImageNames[0] ?? null));
return;
}
const filteredImages = allImages.filter((i) => { const filteredImages = selectFilteredImages(state);
const isInCategory = categories.includes(i.image_category);
const isInSelectedBoard = board_id ? i.board_id === board_id : true; const categories =
return isInCategory && isInSelectedBoard; state.gallery.galleryView === 'images'
}); ? IMAGE_CATEGORIES
: ASSETS_CATEGORIES;
// get the board from the cache // get the board from the cache
const { data: boards } = const { data: boards } =
@ -45,7 +53,7 @@ export const addBoardIdSelectedListener = () => {
if (!board) { if (!board) {
// can't find the board in cache... // can't find the board in cache...
dispatch(imageSelected(allImages[0]?.image_name)); dispatch(boardIdSelected('all'));
return; return;
} }
@ -63,48 +71,3 @@ export const addBoardIdSelectedListener = () => {
}, },
}); });
}; };
export const addBoardIdSelected_changeSelectedImage_listener = () => {
startAppListening({
actionCreator: boardIdSelected,
effect: (action, { getState, dispatch }) => {
const board_id = action.payload;
const state = getState();
// we need to check if we need to fetch more images
if (!board_id) {
// a board was unselected - we don't need to do anything
return;
}
const { categories } = state.gallery;
const filteredImages = selectImagesAll(state).filter((i) => {
const isInCategory = categories.includes(i.image_category);
const isInSelectedBoard = board_id ? i.board_id === board_id : true;
return isInCategory && isInSelectedBoard;
});
// get the board from the cache
const { data: boards } =
boardsApi.endpoints.listAllBoards.select()(state);
const board = boards?.find((b) => b.board_id === board_id);
if (!board) {
// can't find the board in cache...
return;
}
// if we haven't loaded one full page of images from this board, load more
if (
filteredImages.length < board.image_count &&
filteredImages.length < IMAGES_PER_PAGE
) {
dispatch(
receivedPageOfImages({ categories, board_id, is_intermediate: false })
);
}
},
});
};

View File

@ -1,13 +1,13 @@
import { startAppListening } from '..';
import { imageMetadataReceived } from 'services/api/thunks/image';
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { controlNetImageProcessed } from 'features/controlNet/store/actions'; import { controlNetImageProcessed } from 'features/controlNet/store/actions';
import { Graph } from 'services/api/types';
import { sessionCreated } from 'services/api/thunks/session';
import { sessionReadyToInvoke } from 'features/system/store/actions';
import { socketInvocationComplete } from 'services/events/actions';
import { isImageOutput } from 'services/api/guards';
import { controlNetProcessedImageChanged } from 'features/controlNet/store/controlNetSlice'; import { controlNetProcessedImageChanged } from 'features/controlNet/store/controlNetSlice';
import { sessionReadyToInvoke } from 'features/system/store/actions';
import { isImageOutput } from 'services/api/guards';
import { imageDTOReceived } from 'services/api/thunks/image';
import { sessionCreated } from 'services/api/thunks/session';
import { Graph } from 'services/api/types';
import { socketInvocationComplete } from 'services/events/actions';
import { startAppListening } from '..';
const moduleLog = log.child({ namespace: 'controlNet' }); const moduleLog = log.child({ namespace: 'controlNet' });
@ -63,10 +63,8 @@ export const addControlNetImageProcessedListener = () => {
// Wait for the ImageDTO to be received // Wait for the ImageDTO to be received
const [imageMetadataReceivedAction] = await take( const [imageMetadataReceivedAction] = await take(
( (action): action is ReturnType<typeof imageDTOReceived.fulfilled> =>
action imageDTOReceived.fulfilled.match(action) &&
): action is ReturnType<typeof imageMetadataReceived.fulfilled> =>
imageMetadataReceived.fulfilled.match(action) &&
action.payload.image_name === image_name action.payload.image_name === image_name
); );
const processedControlImage = imageMetadataReceivedAction.payload; const processedControlImage = imageMetadataReceivedAction.payload;

View File

@ -1,7 +1,6 @@
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { startAppListening } from '..';
import { imageMetadataReceived } from 'services/api/thunks/image';
import { boardImagesApi } from 'services/api/endpoints/boardImages'; import { boardImagesApi } from 'services/api/endpoints/boardImages';
import { startAppListening } from '..';
const moduleLog = log.child({ namespace: 'boards' }); const moduleLog = log.child({ namespace: 'boards' });
@ -15,12 +14,6 @@ export const addImageAddedToBoardFulfilledListener = () => {
{ data: { board_id, image_name } }, { data: { board_id, image_name } },
'Image added to board' 'Image added to board'
); );
dispatch(
imageMetadataReceived({
image_name,
})
);
}, },
}); });
}; };

View File

@ -1,10 +1,10 @@
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { resetCanvas } from 'features/canvas/store/canvasSlice'; import { resetCanvas } from 'features/canvas/store/canvasSlice';
import { controlNetReset } from 'features/controlNet/store/controlNetSlice'; import { controlNetReset } from 'features/controlNet/store/controlNetSlice';
import { selectNextImageToSelect } from 'features/gallery/store/gallerySelectors';
import { import {
imageRemoved, imageRemoved,
imageSelected, imageSelected,
selectFilteredImages,
} from 'features/gallery/store/gallerySlice'; } from 'features/gallery/store/gallerySlice';
import { import {
imageDeletionConfirmed, imageDeletionConfirmed,
@ -12,7 +12,6 @@ import {
} from 'features/imageDeletion/store/imageDeletionSlice'; } from 'features/imageDeletion/store/imageDeletionSlice';
import { nodeEditorReset } from 'features/nodes/store/nodesSlice'; import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
import { clearInitialImage } from 'features/parameters/store/generationSlice'; import { clearInitialImage } from 'features/parameters/store/generationSlice';
import { clamp } from 'lodash-es';
import { api } from 'services/api'; import { api } from 'services/api';
import { imageDeleted } from 'services/api/thunks/image'; import { imageDeleted } from 'services/api/thunks/image';
import { startAppListening } from '..'; import { startAppListening } from '..';
@ -37,26 +36,10 @@ export const addRequestedImageDeletionListener = () => {
state.gallery.selection[state.gallery.selection.length - 1]; state.gallery.selection[state.gallery.selection.length - 1];
if (lastSelectedImage === image_name) { if (lastSelectedImage === image_name) {
const filteredImages = selectFilteredImages(state); const newSelectedImageId = selectNextImageToSelect(state, image_name);
const ids = filteredImages.map((i) => i.image_name);
const deletedImageIndex = ids.findIndex(
(result) => result.toString() === image_name
);
const filteredIds = ids.filter((id) => id.toString() !== image_name);
const newSelectedImageIndex = clamp(
deletedImageIndex,
0,
filteredIds.length - 1
);
const newSelectedImageId = filteredIds[newSelectedImageIndex];
if (newSelectedImageId) { if (newSelectedImageId) {
dispatch(imageSelected(newSelectedImageId as string)); dispatch(imageSelected(newSelectedImageId));
} else { } else {
dispatch(imageSelected(null)); dispatch(imageSelected(null));
} }

View File

@ -4,13 +4,12 @@ import {
TypesafeDroppableData, TypesafeDroppableData,
} from 'app/components/ImageDnd/typesafeDnd'; } from 'app/components/ImageDnd/typesafeDnd';
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import {
imageAddedToBatch,
imagesAddedToBatch,
} from 'features/batch/store/batchSlice';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice'; import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice'; import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
import { imageSelected } from 'features/gallery/store/gallerySlice'; import {
imageSelected,
imagesAddedToBatch,
} from 'features/gallery/store/gallerySlice';
import { import {
fieldValueChanged, fieldValueChanged,
imageCollectionFieldValueChanged, imageCollectionFieldValueChanged,
@ -21,57 +20,66 @@ import { startAppListening } from '../';
const moduleLog = log.child({ namespace: 'dnd' }); const moduleLog = log.child({ namespace: 'dnd' });
export const imageDropped = createAction<{ export const dndDropped = createAction<{
overData: TypesafeDroppableData; overData: TypesafeDroppableData;
activeData: TypesafeDraggableData; activeData: TypesafeDraggableData;
}>('dnd/imageDropped'); }>('dnd/dndDropped');
export const addImageDroppedListener = () => { export const addImageDroppedListener = () => {
startAppListening({ startAppListening({
actionCreator: imageDropped, actionCreator: dndDropped,
effect: (action, { dispatch, getState }) => { effect: async (action, { dispatch, getState, take }) => {
const { activeData, overData } = action.payload; const { activeData, overData } = action.payload;
const { actionType } = overData;
const state = getState(); const state = getState();
moduleLog.debug(
{ data: { activeData, overData } },
'Image or selection dropped'
);
// set current image // set current image
if ( if (
actionType === 'SET_CURRENT_IMAGE' && overData.actionType === 'SET_CURRENT_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' && activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO activeData.payload.imageDTO
) { ) {
dispatch(imageSelected(activeData.payload.imageDTO.image_name)); dispatch(imageSelected(activeData.payload.imageDTO.image_name));
return;
} }
// set initial image // set initial image
if ( if (
actionType === 'SET_INITIAL_IMAGE' && overData.actionType === 'SET_INITIAL_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' && activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO activeData.payload.imageDTO
) { ) {
dispatch(initialImageChanged(activeData.payload.imageDTO)); dispatch(initialImageChanged(activeData.payload.imageDTO));
return;
} }
// add image to batch // add image to batch
if ( if (
actionType === 'ADD_TO_BATCH' && overData.actionType === 'ADD_TO_BATCH' &&
activeData.payloadType === 'IMAGE_DTO' && activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO activeData.payload.imageDTO
) { ) {
dispatch(imageAddedToBatch(activeData.payload.imageDTO.image_name)); dispatch(imagesAddedToBatch([activeData.payload.imageDTO.image_name]));
return;
} }
// add multiple images to batch // add multiple images to batch
if ( if (
actionType === 'ADD_TO_BATCH' && overData.actionType === 'ADD_TO_BATCH' &&
activeData.payloadType === 'GALLERY_SELECTION' activeData.payloadType === 'IMAGE_NAMES'
) { ) {
dispatch(imagesAddedToBatch(state.gallery.selection)); dispatch(imagesAddedToBatch(activeData.payload.image_names));
return;
} }
// set control image // set control image
if ( if (
actionType === 'SET_CONTROLNET_IMAGE' && overData.actionType === 'SET_CONTROLNET_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' && activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO activeData.payload.imageDTO
) { ) {
@ -82,20 +90,22 @@ export const addImageDroppedListener = () => {
controlNetId, controlNetId,
}) })
); );
return;
} }
// set canvas image // set canvas image
if ( if (
actionType === 'SET_CANVAS_INITIAL_IMAGE' && overData.actionType === 'SET_CANVAS_INITIAL_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' && activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO activeData.payload.imageDTO
) { ) {
dispatch(setInitialCanvasImage(activeData.payload.imageDTO)); dispatch(setInitialCanvasImage(activeData.payload.imageDTO));
return;
} }
// set nodes image // set nodes image
if ( if (
actionType === 'SET_NODES_IMAGE' && overData.actionType === 'SET_NODES_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' && activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO activeData.payload.imageDTO
) { ) {
@ -107,11 +117,12 @@ export const addImageDroppedListener = () => {
value: activeData.payload.imageDTO, value: activeData.payload.imageDTO,
}) })
); );
return;
} }
// set multiple nodes images (single image handler) // set multiple nodes images (single image handler)
if ( if (
actionType === 'SET_MULTI_NODES_IMAGE' && overData.actionType === 'SET_MULTI_NODES_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' && activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO activeData.payload.imageDTO
) { ) {
@ -123,43 +134,30 @@ export const addImageDroppedListener = () => {
value: [activeData.payload.imageDTO], value: [activeData.payload.imageDTO],
}) })
); );
return;
} }
// set multiple nodes images (multiple images handler) // set multiple nodes images (multiple images handler)
if ( if (
actionType === 'SET_MULTI_NODES_IMAGE' && overData.actionType === 'SET_MULTI_NODES_IMAGE' &&
activeData.payloadType === 'GALLERY_SELECTION' activeData.payloadType === 'IMAGE_NAMES'
) { ) {
const { fieldName, nodeId } = overData.context; const { fieldName, nodeId } = overData.context;
dispatch( dispatch(
imageCollectionFieldValueChanged({ imageCollectionFieldValueChanged({
nodeId, nodeId,
fieldName, fieldName,
value: state.gallery.selection.map((image_name) => ({ value: activeData.payload.image_names.map((image_name) => ({
image_name, image_name,
})), })),
}) })
); );
return;
} }
// remove image from board
// TODO: remove board_id from `removeImageFromBoard()` endpoint
// TODO: handle multiple images
// if (
// actionType === 'MOVE_BOARD' &&
// activeData.payloadType === 'IMAGE_DTO' &&
// activeData.payload.imageDTO &&
// overData.boardId !== null
// ) {
// const { image_name } = activeData.payload.imageDTO;
// dispatch(
// boardImagesApi.endpoints.removeImageFromBoard.initiate({ image_name })
// );
// }
// add image to board // add image to board
if ( if (
actionType === 'MOVE_BOARD' && overData.actionType === 'MOVE_BOARD' &&
activeData.payloadType === 'IMAGE_DTO' && activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO && activeData.payload.imageDTO &&
overData.context.boardId overData.context.boardId
@ -172,17 +170,89 @@ export const addImageDroppedListener = () => {
board_id: boardId, board_id: boardId,
}) })
); );
return;
} }
// add multiple images to board // remove image from board
// TODO: add endpoint if (
// if ( overData.actionType === 'MOVE_BOARD' &&
// actionType === 'ADD_TO_BATCH' && activeData.payloadType === 'IMAGE_DTO' &&
// activeData.payloadType === 'IMAGE_NAMES' && activeData.payload.imageDTO &&
// activeData.payload.imageDTONames overData.context.boardId === null
// ) { ) {
// dispatch(boardImagesApi.endpoints.addImagesToBoard.intiate({})); const { image_name, board_id } = activeData.payload.imageDTO;
// } if (board_id) {
dispatch(
boardImagesApi.endpoints.removeImageFromBoard.initiate({
image_name,
board_id,
})
);
}
return;
}
// add gallery selection to board
if (
overData.actionType === 'MOVE_BOARD' &&
activeData.payloadType === 'IMAGE_NAMES' &&
overData.context.boardId
) {
console.log('adding gallery selection to board');
const board_id = overData.context.boardId;
dispatch(
boardImagesApi.endpoints.addManyBoardImages.initiate({
board_id,
image_names: activeData.payload.image_names,
})
);
return;
}
// remove gallery selection from board
if (
overData.actionType === 'MOVE_BOARD' &&
activeData.payloadType === 'IMAGE_NAMES' &&
overData.context.boardId === null
) {
console.log('removing gallery selection to board');
dispatch(
boardImagesApi.endpoints.deleteManyBoardImages.initiate({
image_names: activeData.payload.image_names,
})
);
return;
}
// add batch selection to board
if (
overData.actionType === 'MOVE_BOARD' &&
activeData.payloadType === 'IMAGE_NAMES' &&
overData.context.boardId
) {
const board_id = overData.context.boardId;
dispatch(
boardImagesApi.endpoints.addManyBoardImages.initiate({
board_id,
image_names: activeData.payload.image_names,
})
);
return;
}
// remove batch selection from board
if (
overData.actionType === 'MOVE_BOARD' &&
activeData.payloadType === 'IMAGE_NAMES' &&
overData.context.boardId === null
) {
dispatch(
boardImagesApi.endpoints.deleteManyBoardImages.initiate({
image_names: activeData.payload.image_names,
})
);
return;
}
}, },
}); });
}; };

View File

@ -1,13 +1,13 @@
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { startAppListening } from '..';
import { imageMetadataReceived, imageUpdated } from 'services/api/thunks/image';
import { imageUpserted } from 'features/gallery/store/gallerySlice'; import { imageUpserted } from 'features/gallery/store/gallerySlice';
import { imageDTOReceived, imageUpdated } from 'services/api/thunks/image';
import { startAppListening } from '..';
const moduleLog = log.child({ namespace: 'image' }); const moduleLog = log.child({ namespace: 'image' });
export const addImageMetadataReceivedFulfilledListener = () => { export const addImageMetadataReceivedFulfilledListener = () => {
startAppListening({ startAppListening({
actionCreator: imageMetadataReceived.fulfilled, actionCreator: imageDTOReceived.fulfilled,
effect: (action, { getState, dispatch }) => { effect: (action, { getState, dispatch }) => {
const image = action.payload; const image = action.payload;
@ -40,7 +40,7 @@ export const addImageMetadataReceivedFulfilledListener = () => {
export const addImageMetadataReceivedRejectedListener = () => { export const addImageMetadataReceivedRejectedListener = () => {
startAppListening({ startAppListening({
actionCreator: imageMetadataReceived.rejected, actionCreator: imageDTOReceived.rejected,
effect: (action, { getState, dispatch }) => { effect: (action, { getState, dispatch }) => {
moduleLog.debug( moduleLog.debug(
{ data: { image: action.meta.arg } }, { data: { image: action.meta.arg } },

View File

@ -1,7 +1,6 @@
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { startAppListening } from '..';
import { imageMetadataReceived } from 'services/api/thunks/image';
import { boardImagesApi } from 'services/api/endpoints/boardImages'; import { boardImagesApi } from 'services/api/endpoints/boardImages';
import { startAppListening } from '..';
const moduleLog = log.child({ namespace: 'boards' }); const moduleLog = log.child({ namespace: 'boards' });
@ -15,12 +14,6 @@ export const addImageRemovedFromBoardFulfilledListener = () => {
{ data: { board_id, image_name } }, { data: { board_id, image_name } },
'Image added to board' 'Image added to board'
); );
dispatch(
imageMetadataReceived({
image_name,
})
);
}, },
}); });
}; };

View File

@ -1,13 +1,15 @@
import { startAppListening } from '..';
import { imageUploaded } from 'services/api/thunks/image';
import { addToast } from 'features/system/store/systemSlice';
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { imageUpserted } from 'features/gallery/store/gallerySlice';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice'; import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice'; import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
import { initialImageChanged } from 'features/parameters/store/generationSlice'; import {
imageUpserted,
imagesAddedToBatch,
} from 'features/gallery/store/gallerySlice';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice'; import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import { imageAddedToBatch } from 'features/batch/store/batchSlice'; import { initialImageChanged } from 'features/parameters/store/generationSlice';
import { addToast } from 'features/system/store/systemSlice';
import { imageUploaded } from 'services/api/thunks/image';
import { startAppListening } from '..';
const moduleLog = log.child({ namespace: 'image' }); const moduleLog = log.child({ namespace: 'image' });
@ -73,7 +75,7 @@ export const addImageUploadedFulfilledListener = () => {
} }
if (postUploadAction?.type === 'ADD_TO_BATCH') { if (postUploadAction?.type === 'ADD_TO_BATCH') {
dispatch(imageAddedToBatch(image.image_name)); dispatch(imagesAddedToBatch([image.image_name]));
return; return;
} }
}, },

View File

@ -14,7 +14,7 @@ export const addModelSelectedListener = () => {
actionCreator: modelSelected, actionCreator: modelSelected,
effect: (action, { getState, dispatch }) => { effect: (action, { getState, dispatch }) => {
const state = getState(); const state = getState();
const [base_model, type, name] = action.payload.split('/'); const { base_model, model_name } = action.payload;
if (state.generation.model?.base_model !== base_model) { if (state.generation.model?.base_model !== base_model) {
dispatch( dispatch(
@ -30,11 +30,7 @@ export const addModelSelectedListener = () => {
// TODO: controlnet cleared // TODO: controlnet cleared
} }
const newModel = zMainModel.parse({ const newModel = zMainModel.parse(action.payload);
id: action.payload,
base_model,
name,
});
dispatch(modelChanged(newModel)); dispatch(modelChanged(newModel));
}, },

View File

@ -0,0 +1,42 @@
import { modelChanged } from 'features/parameters/store/generationSlice';
import { some } from 'lodash-es';
import { modelsApi } from 'services/api/endpoints/models';
import { startAppListening } from '..';
export const addModelsLoadedListener = () => {
startAppListening({
matcher: modelsApi.endpoints.getMainModels.matchFulfilled,
effect: async (action, { getState, dispatch }) => {
// models loaded, we need to ensure the selected model is available and if not, select the first one
const currentModel = getState().generation.model;
const isCurrentModelAvailable = some(
action.payload.entities,
(m) =>
m?.model_name === currentModel?.model_name &&
m?.base_model === currentModel?.base_model
);
if (isCurrentModelAvailable) {
return;
}
const firstModelId = action.payload.ids[0];
const firstModel = action.payload.entities[firstModelId];
if (!firstModel) {
// No models loaded at all
dispatch(modelChanged(null));
return;
}
dispatch(
modelChanged({
base_model: firstModel.base_model,
model_name: firstModel.model_name,
})
);
},
});
};

View File

@ -1,19 +0,0 @@
import { startAppListening } from '..';
import { log } from 'app/logging/useLogger';
import {
imagesAddedToBatch,
selectionAddedToBatch,
} from 'features/batch/store/batchSlice';
const moduleLog = log.child({ namespace: 'batch' });
export const addSelectionAddedToBatchListener = () => {
startAppListening({
actionCreator: selectionAddedToBatch,
effect: (action, { dispatch, getState }) => {
const { selection } = getState().gallery;
dispatch(imagesAddedToBatch(selection));
},
});
};

View File

@ -30,6 +30,7 @@ export const addSessionCreatedRejectedListener = () => {
effect: (action, { getState, dispatch }) => { effect: (action, { getState, dispatch }) => {
if (action.payload) { if (action.payload) {
const { arg, error } = action.payload; const { arg, error } = action.payload;
const stringifiedError = JSON.stringify(error);
moduleLog.error( moduleLog.error(
{ {
data: { data: {
@ -37,7 +38,7 @@ export const addSessionCreatedRejectedListener = () => {
error: serializeError(error), error: serializeError(error),
}, },
}, },
`Problem creating session` `Problem creating session: ${stringifiedError}`
); );
} }
}, },

View File

@ -33,6 +33,7 @@ export const addSessionInvokedRejectedListener = () => {
effect: (action, { getState, dispatch }) => { effect: (action, { getState, dispatch }) => {
if (action.payload) { if (action.payload) {
const { arg, error } = action.payload; const { arg, error } = action.payload;
const stringifiedError = JSON.stringify(error);
moduleLog.error( moduleLog.error(
{ {
data: { data: {
@ -40,7 +41,7 @@ export const addSessionInvokedRejectedListener = () => {
error: serializeError(error), error: serializeError(error),
}, },
}, },
`Problem invoking session` `Problem invoking session: ${stringifiedError}`
); );
} }
}, },

View File

@ -1,15 +1,15 @@
import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
import { startAppListening } from '../..';
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
import { progressImageSet } from 'features/system/store/systemSlice';
import { boardImagesApi } from 'services/api/endpoints/boardImages';
import { isImageOutput } from 'services/api/guards';
import { imageDTOReceived } from 'services/api/thunks/image';
import { sessionCanceled } from 'services/api/thunks/session';
import { import {
appSocketInvocationComplete, appSocketInvocationComplete,
socketInvocationComplete, socketInvocationComplete,
} from 'services/events/actions'; } from 'services/events/actions';
import { imageMetadataReceived } from 'services/api/thunks/image'; import { startAppListening } from '../..';
import { sessionCanceled } from 'services/api/thunks/session';
import { isImageOutput } from 'services/api/guards';
import { progressImageSet } from 'features/system/store/systemSlice';
import { boardImagesApi } from 'services/api/endpoints/boardImages';
const moduleLog = log.child({ namespace: 'socketio' }); const moduleLog = log.child({ namespace: 'socketio' });
const nodeDenylist = ['dataURL_image']; const nodeDenylist = ['dataURL_image'];
@ -42,13 +42,13 @@ export const addInvocationCompleteEventListener = () => {
// Get its metadata // Get its metadata
dispatch( dispatch(
imageMetadataReceived({ imageDTOReceived({
image_name, image_name,
}) })
); );
const [{ payload: imageDTO }] = await take( const [{ payload: imageDTO }] = await take(
imageMetadataReceived.fulfilled.match imageDTOReceived.fulfilled.match
); );
// Handle canvas image // Handle canvas image

View File

@ -13,7 +13,7 @@ export const addInvocationErrorEventListener = () => {
effect: (action, { dispatch, getState }) => { effect: (action, { dispatch, getState }) => {
moduleLog.error( moduleLog.error(
action.payload, action.payload,
`Invocation error (${action.payload.data.node.type})` `Invocation error (${action.payload.data.node.type}): ${action.payload.data.error}`
); );
dispatch(appSocketInvocationError(action.payload)); dispatch(appSocketInvocationError(action.payload));
}, },

View File

@ -1,6 +1,7 @@
import { import {
AnyAction, AnyAction,
ThunkDispatch, ThunkDispatch,
autoBatchEnhancer,
combineReducers, combineReducers,
configureStore, configureStore,
} from '@reduxjs/toolkit'; } from '@reduxjs/toolkit';
@ -8,14 +9,12 @@ import {
import dynamicMiddlewares from 'redux-dynamic-middlewares'; import dynamicMiddlewares from 'redux-dynamic-middlewares';
import { rememberEnhancer, rememberReducer } from 'redux-remember'; import { rememberEnhancer, rememberReducer } from 'redux-remember';
import batchReducer from 'features/batch/store/batchSlice';
import canvasReducer from 'features/canvas/store/canvasSlice'; import canvasReducer from 'features/canvas/store/canvasSlice';
import controlNetReducer from 'features/controlNet/store/controlNetSlice'; import controlNetReducer from 'features/controlNet/store/controlNetSlice';
import dynamicPromptsReducer from 'features/dynamicPrompts/store/slice'; import dynamicPromptsReducer from 'features/dynamicPrompts/store/slice';
import boardsReducer from 'features/gallery/store/boardSlice'; import boardsReducer from 'features/gallery/store/boardSlice';
import galleryReducer from 'features/gallery/store/gallerySlice'; import galleryReducer from 'features/gallery/store/gallerySlice';
import imageDeletionReducer from 'features/imageDeletion/store/imageDeletionSlice'; import imageDeletionReducer from 'features/imageDeletion/store/imageDeletionSlice';
import lightboxReducer from 'features/lightbox/store/lightboxSlice';
import loraReducer from 'features/lora/store/loraSlice'; import loraReducer from 'features/lora/store/loraSlice';
import nodesReducer from 'features/nodes/store/nodesSlice'; import nodesReducer from 'features/nodes/store/nodesSlice';
import generationReducer from 'features/parameters/store/generationSlice'; import generationReducer from 'features/parameters/store/generationSlice';
@ -39,7 +38,6 @@ const allReducers = {
canvas: canvasReducer, canvas: canvasReducer,
gallery: galleryReducer, gallery: galleryReducer,
generation: generationReducer, generation: generationReducer,
lightbox: lightboxReducer,
nodes: nodesReducer, nodes: nodesReducer,
postprocessing: postprocessingReducer, postprocessing: postprocessingReducer,
system: systemReducer, system: systemReducer,
@ -49,7 +47,6 @@ const allReducers = {
controlNet: controlNetReducer, controlNet: controlNetReducer,
boards: boardsReducer, boards: boardsReducer,
dynamicPrompts: dynamicPromptsReducer, dynamicPrompts: dynamicPromptsReducer,
batch: batchReducer,
imageDeletion: imageDeletionReducer, imageDeletion: imageDeletionReducer,
lora: loraReducer, lora: loraReducer,
[api.reducerPath]: api.reducer, [api.reducerPath]: api.reducer,
@ -63,30 +60,29 @@ const rememberedKeys: (keyof typeof allReducers)[] = [
'canvas', 'canvas',
'gallery', 'gallery',
'generation', 'generation',
'lightbox',
'nodes', 'nodes',
'postprocessing', 'postprocessing',
'system', 'system',
'ui', 'ui',
'controlNet', 'controlNet',
'dynamicPrompts', 'dynamicPrompts',
'batch',
'lora', 'lora',
// 'boards',
// 'hotkeys',
// 'config',
]; ];
export const store = configureStore({ export const store = configureStore({
reducer: rememberedRootReducer, reducer: rememberedRootReducer,
enhancers: [ enhancers: (existingEnhancers) => {
rememberEnhancer(window.localStorage, rememberedKeys, { return existingEnhancers
persistDebounce: 300, .concat(
serialize, rememberEnhancer(window.localStorage, rememberedKeys, {
unserialize, persistDebounce: 300,
prefix: LOCALSTORAGE_PREFIX, serialize,
}), unserialize,
], prefix: LOCALSTORAGE_PREFIX,
})
)
.concat(autoBatchEnhancer());
},
middleware: (getDefaultMiddleware) => middleware: (getDefaultMiddleware) =>
getDefaultMiddleware({ getDefaultMiddleware({
immutableCheck: false, immutableCheck: false,
@ -96,10 +92,26 @@ export const store = configureStore({
.concat(dynamicMiddlewares) .concat(dynamicMiddlewares)
.prepend(listenerMiddleware.middleware), .prepend(listenerMiddleware.middleware),
devTools: { devTools: {
actionsDenylist,
actionSanitizer, actionSanitizer,
stateSanitizer, stateSanitizer,
trace: true, trace: true,
predicate: (state, action) => {
// TODO: hook up to the log level param in system slice
// manually type state, cannot type the arg
// const typedState = state as ReturnType<typeof rootReducer>;
if (action.type.startsWith('api/')) {
// don't log api actions, with manual cache updates they are extremely noisy
return false;
}
if (actionsDenylist.includes(action.type)) {
// don't log other noisy actions
return false;
}
return true;
},
}, },
}); });

View File

@ -94,7 +94,8 @@ export type AppFeature =
| 'bugLink' | 'bugLink'
| 'localization' | 'localization'
| 'consoleLogging' | 'consoleLogging'
| 'dynamicPrompting'; | 'dynamicPrompting'
| 'batches';
/** /**
* A disable-able Stable Diffusion feature * A disable-able Stable Diffusion feature

View File

@ -6,30 +6,21 @@ import {
useColorMode, useColorMode,
useColorModeValue, useColorModeValue,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { useCombinedRefs } from '@dnd-kit/utilities';
import IAIIconButton from 'common/components/IAIIconButton';
import {
IAILoadingImageFallback,
IAINoContentFallback,
} from 'common/components/IAIImageFallback';
import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay';
import { AnimatePresence } from 'framer-motion';
import { MouseEvent, ReactElement, SyntheticEvent } from 'react';
import { memo, useRef } from 'react';
import { FaImage, FaUndo, FaUpload } from 'react-icons/fa';
import { ImageDTO } from 'services/api/types';
import { v4 as uuidv4 } from 'uuid';
import IAIDropOverlay from './IAIDropOverlay';
import { PostUploadAction } from 'services/api/thunks/image';
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
import { mode } from 'theme/util/mode';
import { import {
TypesafeDraggableData, TypesafeDraggableData,
TypesafeDroppableData, TypesafeDroppableData,
isValidDrop,
useDraggable,
useDroppable,
} from 'app/components/ImageDnd/typesafeDnd'; } from 'app/components/ImageDnd/typesafeDnd';
import IAIIconButton from 'common/components/IAIIconButton';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay';
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
import { MouseEvent, ReactElement, SyntheticEvent, memo } from 'react';
import { FaImage, FaUndo, FaUpload } from 'react-icons/fa';
import { PostUploadAction } from 'services/api/thunks/image';
import { ImageDTO } from 'services/api/types';
import { mode } from 'theme/util/mode';
import IAIDraggable from './IAIDraggable';
import IAIDroppable from './IAIDroppable';
type IAIDndImageProps = { type IAIDndImageProps = {
imageDTO: ImageDTO | undefined; imageDTO: ImageDTO | undefined;
@ -83,28 +74,6 @@ const IAIDndImage = (props: IAIDndImageProps) => {
const { colorMode } = useColorMode(); const { colorMode } = useColorMode();
const dndId = useRef(uuidv4());
const {
attributes,
listeners,
setNodeRef: setDraggableRef,
isDragging,
active,
} = useDraggable({
id: dndId.current,
disabled: isDragDisabled || !imageDTO,
data: draggableData,
});
const { isOver, setNodeRef: setDroppableRef } = useDroppable({
id: dndId.current,
disabled: isDropDisabled,
data: droppableData,
});
const setDndRef = useCombinedRefs(setDroppableRef, setDraggableRef);
const { getUploadButtonProps, getUploadInputProps } = useImageUploadButton({ const { getUploadButtonProps, getUploadInputProps } = useImageUploadButton({
postUploadAction, postUploadAction,
isDisabled: isUploadDisabled, isDisabled: isUploadDisabled,
@ -139,9 +108,6 @@ const IAIDndImage = (props: IAIDndImageProps) => {
userSelect: 'none', userSelect: 'none',
cursor: isDragDisabled || !imageDTO ? 'default' : 'pointer', cursor: isDragDisabled || !imageDTO ? 'default' : 'pointer',
}} }}
{...attributes}
{...listeners}
ref={setDndRef}
> >
{imageDTO && ( {imageDTO && (
<Flex <Flex
@ -154,10 +120,13 @@ const IAIDndImage = (props: IAIDndImageProps) => {
}} }}
> >
<Image <Image
onClick={onClick}
src={thumbnail ? imageDTO.thumbnail_url : imageDTO.image_url} src={thumbnail ? imageDTO.thumbnail_url : imageDTO.image_url}
fallbackStrategy="beforeLoadOrError" fallbackStrategy="beforeLoadOrError"
fallback={<IAILoadingImageFallback image={imageDTO} />} // If we fall back to thumbnail, it feels much snappier than the skeleton...
fallbackSrc={imageDTO.thumbnail_url}
// fallback={<IAILoadingImageFallback image={imageDTO} />}
width={imageDTO.width}
height={imageDTO.height}
onError={onError} onError={onError}
draggable={false} draggable={false}
sx={{ sx={{
@ -171,30 +140,6 @@ const IAIDndImage = (props: IAIDndImageProps) => {
}} }}
/> />
{withMetadataOverlay && <ImageMetadataOverlay image={imageDTO} />} {withMetadataOverlay && <ImageMetadataOverlay image={imageDTO} />}
{onClickReset && withResetIcon && (
<IAIIconButton
onClick={onClickReset}
aria-label={resetTooltip}
tooltip={resetTooltip}
icon={resetIcon}
size="sm"
variant="link"
sx={{
position: 'absolute',
top: 1,
insetInlineEnd: 1,
p: 0,
minW: 0,
svg: {
transitionProperty: 'common',
transitionDuration: 'normal',
fill: 'base.100',
_hover: { fill: 'base.50' },
filter: resetIconShadow,
},
}}
/>
)}
</Flex> </Flex>
)} )}
{!imageDTO && !isUploadDisabled && ( {!imageDTO && !isUploadDisabled && (
@ -225,11 +170,42 @@ const IAIDndImage = (props: IAIDndImageProps) => {
</> </>
)} )}
{!imageDTO && isUploadDisabled && noContentFallback} {!imageDTO && isUploadDisabled && noContentFallback}
<AnimatePresence> <IAIDroppable
{isValidDrop(droppableData, active) && !isDragging && ( data={droppableData}
<IAIDropOverlay isOver={isOver} label={dropLabel} /> disabled={isDropDisabled}
)} dropLabel={dropLabel}
</AnimatePresence> />
{imageDTO && (
<IAIDraggable
data={draggableData}
disabled={isDragDisabled || !imageDTO}
onClick={onClick}
/>
)}
{onClickReset && withResetIcon && imageDTO && (
<IAIIconButton
onClick={onClickReset}
aria-label={resetTooltip}
tooltip={resetTooltip}
icon={resetIcon}
size="sm"
variant="link"
sx={{
position: 'absolute',
top: 1,
insetInlineEnd: 1,
p: 0,
minW: 0,
svg: {
transitionProperty: 'common',
transitionDuration: 'normal',
fill: 'base.100',
_hover: { fill: 'base.50' },
filter: resetIconShadow,
},
}}
/>
)}
</Flex> </Flex>
); );
}; };

View File

@ -0,0 +1,40 @@
import { Box } from '@chakra-ui/react';
import {
TypesafeDraggableData,
useDraggable,
} from 'app/components/ImageDnd/typesafeDnd';
import { MouseEvent, memo, useRef } from 'react';
import { v4 as uuidv4 } from 'uuid';
type IAIDraggableProps = {
disabled?: boolean;
data?: TypesafeDraggableData;
onClick?: (event: MouseEvent<HTMLDivElement>) => void;
};
const IAIDraggable = (props: IAIDraggableProps) => {
const { data, disabled, onClick } = props;
const dndId = useRef(uuidv4());
const { attributes, listeners, setNodeRef } = useDraggable({
id: dndId.current,
disabled,
data,
});
return (
<Box
onClick={onClick}
ref={setNodeRef}
position="absolute"
w="full"
h="full"
top={0}
insetInlineStart={0}
{...attributes}
{...listeners}
/>
);
};
export default memo(IAIDraggable);

View File

@ -0,0 +1,47 @@
import { Box } from '@chakra-ui/react';
import {
TypesafeDroppableData,
isValidDrop,
useDroppable,
} from 'app/components/ImageDnd/typesafeDnd';
import { AnimatePresence } from 'framer-motion';
import { memo, useRef } from 'react';
import { v4 as uuidv4 } from 'uuid';
import IAIDropOverlay from './IAIDropOverlay';
type IAIDroppableProps = {
dropLabel?: string;
disabled?: boolean;
data?: TypesafeDroppableData;
};
const IAIDroppable = (props: IAIDroppableProps) => {
const { dropLabel, data, disabled } = props;
const dndId = useRef(uuidv4());
const { isOver, setNodeRef, active } = useDroppable({
id: dndId.current,
disabled,
data,
});
return (
<Box
ref={setNodeRef}
position="absolute"
top={0}
insetInlineStart={0}
w="full"
h="full"
pointerEvents="none"
>
<AnimatePresence>
{isValidDrop(data, active) && (
<IAIDropOverlay isOver={isOver} label={dropLabel} />
)}
</AnimatePresence>
</Box>
);
};
export default memo(IAIDroppable);

View File

@ -0,0 +1,42 @@
import { Box, Flex, Icon } from '@chakra-ui/react';
import { FaExclamation } from 'react-icons/fa';
const IAIErrorLoadingImageFallback = () => {
return (
<Box
sx={{
position: 'relative',
height: 'full',
width: 'full',
'::before': {
content: "''",
display: 'block',
pt: '100%',
},
}}
>
<Flex
sx={{
position: 'absolute',
top: 0,
insetInlineStart: 0,
height: 'full',
width: 'full',
alignItems: 'center',
justifyContent: 'center',
borderRadius: 'base',
bg: 'base.100',
color: 'base.500',
_dark: {
color: 'base.700',
bg: 'base.850',
},
}}
>
<Icon as={FaExclamation} boxSize={16} opacity={0.7} />
</Flex>
</Box>
);
};
export default IAIErrorLoadingImageFallback;

View File

@ -0,0 +1,30 @@
import { Box, Skeleton } from '@chakra-ui/react';
const IAIFillSkeleton = () => {
return (
<Skeleton
sx={{
position: 'relative',
height: 'full',
width: 'full',
'::before': {
content: "''",
display: 'block',
pt: '100%',
},
}}
>
<Box
sx={{
position: 'absolute',
top: 0,
insetInlineStart: 0,
height: 'full',
width: 'full',
}}
/>
</Skeleton>
);
};
export default IAIFillSkeleton;

View File

@ -3,36 +3,22 @@ import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { validateSeedWeights } from 'common/util/seedWeightPairs'; import { validateSeedWeights } from 'common/util/seedWeightPairs';
import { generationSelector } from 'features/parameters/store/generationSelectors';
import { systemSelector } from 'features/system/store/systemSelectors';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { import { modelsApi } from '../../services/api/endpoints/models';
modelsApi,
useGetMainModelsQuery,
} from '../../services/api/endpoints/models';
const readinessSelector = createSelector( const readinessSelector = createSelector(
[stateSelector, activeTabNameSelector], [stateSelector, activeTabNameSelector],
(state, activeTabName) => { (state, activeTabName) => {
const { generation, system, batch } = state; const { generation, system } = state;
const { shouldGenerateVariations, seedWeights, initialImage, seed } = const { shouldGenerateVariations, seedWeights, initialImage, seed } =
generation; generation;
const { isProcessing, isConnected } = system; const { isProcessing, isConnected } = system;
const {
isEnabled: isBatchEnabled,
asInitialImage,
imageNames: batchImageNames,
} = batch;
let isReady = true; let isReady = true;
const reasonsWhyNotReady: string[] = []; const reasonsWhyNotReady: string[] = [];
if ( if (activeTabName === 'img2img' && !initialImage) {
activeTabName === 'img2img' &&
!initialImage &&
!(asInitialImage && batchImageNames.length > 1)
) {
isReady = false; isReady = false;
reasonsWhyNotReady.push('No initial image selected'); reasonsWhyNotReady.push('No initial image selected');
} }

View File

@ -1,67 +0,0 @@
import {
Flex,
FormControl,
FormLabel,
Heading,
Spacer,
Switch,
Text,
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISwitch from 'common/components/IAISwitch';
import { ControlNetConfig } from 'features/controlNet/store/controlNetSlice';
import { ChangeEvent, memo, useCallback } from 'react';
import { controlNetToggled } from '../store/batchSlice';
type Props = {
controlNet: ControlNetConfig;
};
const selector = createSelector(
[stateSelector, (state, controlNetId: string) => controlNetId],
(state, controlNetId) => {
const isControlNetEnabled = state.batch.controlNets.includes(controlNetId);
return { isControlNetEnabled };
},
defaultSelectorOptions
);
const BatchControlNet = (props: Props) => {
const dispatch = useAppDispatch();
const { isControlNetEnabled } = useAppSelector((state) =>
selector(state, props.controlNet.controlNetId)
);
const { processorType, model } = props.controlNet;
const handleChangeAsControlNet = useCallback(() => {
dispatch(controlNetToggled(props.controlNet.controlNetId));
}, [dispatch, props.controlNet.controlNetId]);
return (
<Flex
layerStyle="second"
sx={{ flexDir: 'column', gap: 1, p: 4, borderRadius: 'base' }}
>
<Flex sx={{ justifyContent: 'space-between' }}>
<FormControl as={Flex} onClick={handleChangeAsControlNet}>
<FormLabel>
<Heading size="sm">ControlNet</Heading>
</FormLabel>
<Spacer />
<Switch isChecked={isControlNetEnabled} />
</FormControl>
</Flex>
<Text>
<strong>Model:</strong> {model}
</Text>
<Text>
<strong>Processor:</strong> {processorType}
</Text>
</Flex>
);
};
export default memo(BatchControlNet);

View File

@ -1,116 +0,0 @@
import { Box, Icon, Skeleton } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { TypesafeDraggableData } from 'app/components/ImageDnd/typesafeDnd';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIDndImage from 'common/components/IAIDndImage';
import {
batchImageRangeEndSelected,
batchImageSelected,
batchImageSelectionToggled,
imageRemovedFromBatch,
} from 'features/batch/store/batchSlice';
import { MouseEvent, memo, useCallback, useMemo } from 'react';
import { FaExclamationCircle } from 'react-icons/fa';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
const makeSelector = (image_name: string) =>
createSelector(
[stateSelector],
(state) => ({
selectionCount: state.batch.selection.length,
isSelected: state.batch.selection.includes(image_name),
}),
defaultSelectorOptions
);
type BatchImageProps = {
imageName: string;
};
const BatchImage = (props: BatchImageProps) => {
const {
currentData: imageDTO,
isFetching,
isError,
isSuccess,
} = useGetImageDTOQuery(props.imageName);
const dispatch = useAppDispatch();
const selector = useMemo(
() => makeSelector(props.imageName),
[props.imageName]
);
const { isSelected, selectionCount } = useAppSelector(selector);
const handleClickRemove = useCallback(() => {
dispatch(imageRemovedFromBatch(props.imageName));
}, [dispatch, props.imageName]);
const handleClick = useCallback(
(e: MouseEvent<HTMLDivElement>) => {
if (e.shiftKey) {
dispatch(batchImageRangeEndSelected(props.imageName));
} else if (e.ctrlKey || e.metaKey) {
dispatch(batchImageSelectionToggled(props.imageName));
} else {
dispatch(batchImageSelected(props.imageName));
}
},
[dispatch, props.imageName]
);
const draggableData = useMemo<TypesafeDraggableData | undefined>(() => {
if (selectionCount > 1) {
return {
id: 'batch',
payloadType: 'BATCH_SELECTION',
};
}
if (imageDTO) {
return {
id: 'batch',
payloadType: 'IMAGE_DTO',
payload: { imageDTO },
};
}
}, [imageDTO, selectionCount]);
if (isError) {
return <Icon as={FaExclamationCircle} />;
}
if (isFetching) {
return (
<Skeleton>
<Box w="full" h="full" aspectRatio="1/1" />
</Skeleton>
);
}
return (
<Box sx={{ position: 'relative', aspectRatio: '1/1' }}>
<IAIDndImage
imageDTO={imageDTO}
draggableData={draggableData}
isDropDisabled={true}
isUploadDisabled={true}
imageSx={{
w: 'full',
h: 'full',
}}
onClick={handleClick}
isSelected={isSelected}
onClickReset={handleClickRemove}
resetTooltip="Remove from batch"
withResetIcon
thumbnail
/>
</Box>
);
};
export default memo(BatchImage);

View File

@ -1,31 +0,0 @@
import { Box } from '@chakra-ui/react';
import BatchImageGrid from './BatchImageGrid';
import IAIDropOverlay from 'common/components/IAIDropOverlay';
import {
AddToBatchDropData,
isValidDrop,
useDroppable,
} from 'app/components/ImageDnd/typesafeDnd';
const droppableData: AddToBatchDropData = {
id: 'batch',
actionType: 'ADD_TO_BATCH',
};
const BatchImageContainer = () => {
const { isOver, setNodeRef, active } = useDroppable({
id: 'batch-manager',
data: droppableData,
});
return (
<Box ref={setNodeRef} position="relative" w="full" h="full">
<BatchImageGrid />
{isValidDrop(droppableData, active) && (
<IAIDropOverlay isOver={isOver} label="Add to Batch" />
)}
</Box>
);
};
export default BatchImageContainer;

View File

@ -1,54 +0,0 @@
import { FaImages } from 'react-icons/fa';
import { Grid, GridItem } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import BatchImage from './BatchImage';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
const selector = createSelector(
stateSelector,
(state) => {
const imageNames = state.batch.imageNames.concat().reverse();
return { imageNames };
},
defaultSelectorOptions
);
const BatchImageGrid = () => {
const { imageNames } = useAppSelector(selector);
if (imageNames.length === 0) {
return (
<IAINoContentFallback
icon={FaImages}
boxSize={16}
label="No images in Batch"
/>
);
}
return (
<Grid
sx={{
position: 'absolute',
flexWrap: 'wrap',
w: 'full',
minH: 0,
maxH: 'full',
overflowY: 'scroll',
gridTemplateColumns: `repeat(auto-fill, minmax(128px, 1fr))`,
}}
>
{imageNames.map((imageName) => (
<GridItem key={imageName} sx={{ p: 1.5 }}>
<BatchImage imageName={imageName} />
</GridItem>
))}
</Grid>
);
};
export default BatchImageGrid;

View File

@ -1,103 +0,0 @@
import { Flex, Heading, Spacer } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useCallback } from 'react';
import IAISwitch from 'common/components/IAISwitch';
import {
asInitialImageToggled,
batchReset,
isEnabledChanged,
} from 'features/batch/store/batchSlice';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIButton from 'common/components/IAIButton';
import BatchImageContainer from './BatchImageGrid';
import { map } from 'lodash-es';
import BatchControlNet from './BatchControlNet';
const selector = createSelector(
stateSelector,
(state) => {
const { controlNets } = state.controlNet;
const {
imageNames,
asInitialImage,
controlNets: batchControlNets,
isEnabled,
} = state.batch;
return {
imageCount: imageNames.length,
asInitialImage,
controlNets,
batchControlNets,
isEnabled,
};
},
defaultSelectorOptions
);
const BatchManager = () => {
const dispatch = useAppDispatch();
const { imageCount, isEnabled, controlNets, batchControlNets } =
useAppSelector(selector);
const handleResetBatch = useCallback(() => {
dispatch(batchReset());
}, [dispatch]);
const handleToggle = useCallback(() => {
dispatch(isEnabledChanged(!isEnabled));
}, [dispatch, isEnabled]);
const handleChangeAsInitialImage = useCallback(() => {
dispatch(asInitialImageToggled());
}, [dispatch]);
return (
<Flex
sx={{
h: 'full',
w: 'full',
flexDir: 'column',
position: 'relative',
gap: 2,
minW: 0,
}}
>
<Flex sx={{ alignItems: 'center' }}>
<Heading
size={'md'}
sx={{ color: 'base.800', _dark: { color: 'base.200' } }}
>
{imageCount || 'No'} images
</Heading>
<Spacer />
<IAIButton onClick={handleResetBatch}>Reset</IAIButton>
</Flex>
<Flex
sx={{
alignItems: 'center',
flexDir: 'column',
gap: 4,
}}
>
<IAISwitch
label="Use as Initial Image"
onChange={handleChangeAsInitialImage}
/>
{map(controlNets, (controlNet) => {
return (
<BatchControlNet
key={controlNet.controlNetId}
controlNet={controlNet}
/>
);
})}
</Flex>
<BatchImageContainer />
</Flex>
);
};
export default BatchManager;

View File

@ -1,142 +0,0 @@
import { PayloadAction, createAction, createSlice } from '@reduxjs/toolkit';
import { uniq } from 'lodash-es';
import { imageDeleted } from 'services/api/thunks/image';
type BatchState = {
isEnabled: boolean;
imageNames: string[];
asInitialImage: boolean;
controlNets: string[];
selection: string[];
};
export const initialBatchState: BatchState = {
isEnabled: false,
imageNames: [],
asInitialImage: false,
controlNets: [],
selection: [],
};
const batch = createSlice({
name: 'batch',
initialState: initialBatchState,
reducers: {
isEnabledChanged: (state, action: PayloadAction<boolean>) => {
state.isEnabled = action.payload;
},
imageAddedToBatch: (state, action: PayloadAction<string>) => {
state.imageNames = uniq(state.imageNames.concat(action.payload));
},
imagesAddedToBatch: (state, action: PayloadAction<string[]>) => {
state.imageNames = uniq(state.imageNames.concat(action.payload));
},
imageRemovedFromBatch: (state, action: PayloadAction<string>) => {
state.imageNames = state.imageNames.filter(
(imageName) => action.payload !== imageName
);
state.selection = state.selection.filter(
(imageName) => action.payload !== imageName
);
},
imagesRemovedFromBatch: (state, action: PayloadAction<string[]>) => {
state.imageNames = state.imageNames.filter(
(imageName) => !action.payload.includes(imageName)
);
state.selection = state.selection.filter(
(imageName) => !action.payload.includes(imageName)
);
},
batchImageRangeEndSelected: (state, action: PayloadAction<string>) => {
const rangeEndImageName = action.payload;
const lastSelectedImage = state.selection[state.selection.length - 1];
const lastClickedIndex = state.imageNames.findIndex(
(n) => n === lastSelectedImage
);
const currentClickedIndex = state.imageNames.findIndex(
(n) => n === rangeEndImageName
);
if (lastClickedIndex > -1 && currentClickedIndex > -1) {
// We have a valid range!
const start = Math.min(lastClickedIndex, currentClickedIndex);
const end = Math.max(lastClickedIndex, currentClickedIndex);
const imagesToSelect = state.imageNames.slice(start, end + 1);
state.selection = uniq(state.selection.concat(imagesToSelect));
}
},
batchImageSelectionToggled: (state, action: PayloadAction<string>) => {
if (
state.selection.includes(action.payload) &&
state.selection.length > 1
) {
state.selection = state.selection.filter(
(imageName) => imageName !== action.payload
);
} else {
state.selection = uniq(state.selection.concat(action.payload));
}
},
batchImageSelected: (state, action: PayloadAction<string | null>) => {
state.selection = action.payload
? [action.payload]
: [String(state.imageNames[0])];
},
batchReset: (state) => {
state.imageNames = [];
state.selection = [];
},
asInitialImageToggled: (state) => {
state.asInitialImage = !state.asInitialImage;
},
controlNetAddedToBatch: (state, action: PayloadAction<string>) => {
state.controlNets = uniq(state.controlNets.concat(action.payload));
},
controlNetRemovedFromBatch: (state, action: PayloadAction<string>) => {
state.controlNets = state.controlNets.filter(
(controlNetId) => controlNetId !== action.payload
);
},
controlNetToggled: (state, action: PayloadAction<string>) => {
if (state.controlNets.includes(action.payload)) {
state.controlNets = state.controlNets.filter(
(controlNetId) => controlNetId !== action.payload
);
} else {
state.controlNets = uniq(state.controlNets.concat(action.payload));
}
},
},
extraReducers: (builder) => {
builder.addCase(imageDeleted.fulfilled, (state, action) => {
state.imageNames = state.imageNames.filter(
(imageName) => imageName !== action.meta.arg.image_name
);
state.selection = state.selection.filter(
(imageName) => imageName !== action.meta.arg.image_name
);
});
},
});
export const {
isEnabledChanged,
imageAddedToBatch,
imagesAddedToBatch,
imageRemovedFromBatch,
imagesRemovedFromBatch,
asInitialImageToggled,
controlNetAddedToBatch,
controlNetRemovedFromBatch,
batchReset,
controlNetToggled,
batchImageRangeEndSelected,
batchImageSelectionToggled,
batchImageSelected,
} = batch.actions;
export default batch.reducer;
export const selectionAddedToBatch = createAction(
'batch/selectionAddedToBatch'
);

View File

@ -47,8 +47,8 @@ const ParamEmbeddingPopover = (props: Props) => {
const disabled = currentMainModel?.base_model !== embedding.base_model; const disabled = currentMainModel?.base_model !== embedding.base_model;
data.push({ data.push({
value: embedding.name, value: embedding.model_name,
label: embedding.name, label: embedding.model_name,
group: MODEL_TYPE_MAP[embedding.base_model], group: MODEL_TYPE_MAP[embedding.base_model],
disabled, disabled,
tooltip: disabled tooltip: disabled

View File

@ -1,93 +0,0 @@
import { Flex, useColorMode } from '@chakra-ui/react';
import { FaImages } from 'react-icons/fa';
import { boardIdSelected } from 'features/gallery/store/gallerySlice';
import { useDispatch } from 'react-redux';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { AnimatePresence } from 'framer-motion';
import IAIDropOverlay from 'common/components/IAIDropOverlay';
import { mode } from 'theme/util/mode';
import {
MoveBoardDropData,
isValidDrop,
useDroppable,
} from 'app/components/ImageDnd/typesafeDnd';
const AllImagesBoard = ({ isSelected }: { isSelected: boolean }) => {
const dispatch = useDispatch();
const { colorMode } = useColorMode();
const handleAllImagesBoardClick = () => {
dispatch(boardIdSelected());
};
const droppableData: MoveBoardDropData = {
id: 'all-images-board',
actionType: 'MOVE_BOARD',
context: { boardId: null },
};
const { isOver, setNodeRef, active } = useDroppable({
id: `board_droppable_all_images`,
data: droppableData,
});
return (
<Flex
sx={{
flexDir: 'column',
justifyContent: 'space-between',
alignItems: 'center',
cursor: 'pointer',
w: 'full',
h: 'full',
borderRadius: 'base',
}}
>
<Flex
ref={setNodeRef}
onClick={handleAllImagesBoardClick}
sx={{
position: 'relative',
justifyContent: 'center',
alignItems: 'center',
borderRadius: 'base',
w: 'full',
aspectRatio: '1/1',
overflow: 'hidden',
shadow: isSelected ? 'selected.light' : undefined,
_dark: { shadow: isSelected ? 'selected.dark' : undefined },
flexShrink: 0,
}}
>
<IAINoContentFallback
boxSize={8}
icon={FaImages}
sx={{
border: '2px solid var(--invokeai-colors-base-200)',
_dark: { border: '2px solid var(--invokeai-colors-base-800)' },
}}
/>
<AnimatePresence>
{isValidDrop(droppableData, active) && (
<IAIDropOverlay isOver={isOver} />
)}
</AnimatePresence>
</Flex>
<Flex
sx={{
h: 'full',
alignItems: 'center',
color: isSelected
? mode('base.900', 'base.50')(colorMode)
: mode('base.700', 'base.200')(colorMode),
fontWeight: isSelected ? 600 : undefined,
fontSize: 'xs',
}}
>
All Images
</Flex>
</Flex>
);
};
export default AllImagesBoard;

View File

@ -0,0 +1,31 @@
import { MoveBoardDropData } from 'app/components/ImageDnd/typesafeDnd';
import { boardIdSelected } from 'features/gallery/store/gallerySlice';
import { FaImages } from 'react-icons/fa';
import { useDispatch } from 'react-redux';
import GenericBoard from './GenericBoard';
const AllImagesBoard = ({ isSelected }: { isSelected: boolean }) => {
const dispatch = useDispatch();
const handleAllImagesBoardClick = () => {
dispatch(boardIdSelected('all'));
};
const droppableData: MoveBoardDropData = {
id: 'all-images-board',
actionType: 'MOVE_BOARD',
context: { boardId: null },
};
return (
<GenericBoard
droppableData={droppableData}
onClick={handleAllImagesBoardClick}
isSelected={isSelected}
icon={FaImages}
label="All Images"
/>
);
};
export default AllImagesBoard;

View File

@ -0,0 +1,42 @@
import { createSelector } from '@reduxjs/toolkit';
import { AddToBatchDropData } from 'app/components/ImageDnd/typesafeDnd';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { boardIdSelected } from 'features/gallery/store/gallerySlice';
import { useCallback } from 'react';
import { FaLayerGroup } from 'react-icons/fa';
import { useDispatch } from 'react-redux';
import GenericBoard from './GenericBoard';
const selector = createSelector(stateSelector, (state) => {
return {
count: state.gallery.batchImageNames.length,
};
});
const BatchBoard = ({ isSelected }: { isSelected: boolean }) => {
const dispatch = useDispatch();
const { count } = useAppSelector(selector);
const handleBatchBoardClick = useCallback(() => {
dispatch(boardIdSelected('batch'));
}, [dispatch]);
const droppableData: AddToBatchDropData = {
id: 'batch-board',
actionType: 'ADD_TO_BATCH',
};
return (
<GenericBoard
droppableData={droppableData}
onClick={handleBatchBoardClick}
isSelected={isSelected}
icon={FaLayerGroup}
label="Batch"
badgeCount={count}
/>
);
};
export default BatchBoard;

View File

@ -1,3 +1,4 @@
import { CloseIcon } from '@chakra-ui/icons';
import { import {
Collapse, Collapse,
Flex, Flex,
@ -9,17 +10,18 @@ import {
InputRightElement, InputRightElement,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { setBoardSearchText } from 'features/gallery/store/boardSlice'; import { setBoardSearchText } from 'features/gallery/store/boardSlice';
import { memo, useState } from 'react';
import HoverableBoard from './HoverableBoard';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react'; import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import { memo, useState } from 'react';
import { useListAllBoardsQuery } from 'services/api/endpoints/boards';
import AddBoardButton from './AddBoardButton'; import AddBoardButton from './AddBoardButton';
import AllImagesBoard from './AllImagesBoard'; import AllImagesBoard from './AllImagesBoard';
import { CloseIcon } from '@chakra-ui/icons'; import BatchBoard from './BatchBoard';
import { useListAllBoardsQuery } from 'services/api/endpoints/boards'; import GalleryBoard from './GalleryBoard';
import { stateSelector } from 'app/store/store'; import { useFeatureStatus } from '../../../../system/hooks/useFeatureStatus';
const selector = createSelector( const selector = createSelector(
[stateSelector], [stateSelector],
@ -42,6 +44,8 @@ const BoardsList = (props: Props) => {
const { data: boards } = useListAllBoardsQuery(); const { data: boards } = useListAllBoardsQuery();
const isBatchEnabled = useFeatureStatus('batches').isFeatureEnabled;
const filteredBoards = searchText const filteredBoards = searchText
? boards?.filter((board) => ? boards?.filter((board) =>
board.board_name.toLowerCase().includes(searchText.toLowerCase()) board.board_name.toLowerCase().includes(searchText.toLowerCase())
@ -115,14 +119,21 @@ const BoardsList = (props: Props) => {
}} }}
> >
{!searchMode && ( {!searchMode && (
<GridItem sx={{ p: 1.5 }}> <>
<AllImagesBoard isSelected={!selectedBoardId} /> <GridItem sx={{ p: 1.5 }}>
</GridItem> <AllImagesBoard isSelected={selectedBoardId === 'all'} />
</GridItem>
{isBatchEnabled && (
<GridItem sx={{ p: 1.5 }}>
<BatchBoard isSelected={selectedBoardId === 'batch'} />
</GridItem>
)}
</>
)} )}
{filteredBoards && {filteredBoards &&
filteredBoards.map((board) => ( filteredBoards.map((board) => (
<GridItem key={board.board_id} sx={{ p: 1.5 }}> <GridItem key={board.board_id} sx={{ p: 1.5 }}>
<HoverableBoard <GalleryBoard
board={board} board={board}
isSelected={selectedBoardId === board.board_id} isSelected={selectedBoardId === board.board_id}
/> />

View File

@ -12,35 +12,31 @@ import {
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import { memo, useCallback, useContext } from 'react';
import { FaFolder, FaTrash } from 'react-icons/fa';
import { ContextMenu } from 'chakra-ui-contextmenu'; import { ContextMenu } from 'chakra-ui-contextmenu';
import { BoardDTO } from 'services/api/types';
import { IAINoContentFallback } from 'common/components/IAIImageFallback'; import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { boardIdSelected } from 'features/gallery/store/gallerySlice'; import { boardIdSelected } from 'features/gallery/store/gallerySlice';
import { memo, useCallback, useContext, useMemo } from 'react';
import { FaFolder, FaImages, FaTrash } from 'react-icons/fa';
import { import {
useDeleteBoardMutation, useDeleteBoardMutation,
useUpdateBoardMutation, useUpdateBoardMutation,
} from 'services/api/endpoints/boards'; } from 'services/api/endpoints/boards';
import { useGetImageDTOQuery } from 'services/api/endpoints/images'; import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { BoardDTO } from 'services/api/types';
import { skipToken } from '@reduxjs/toolkit/dist/query'; import { skipToken } from '@reduxjs/toolkit/dist/query';
import { AnimatePresence } from 'framer-motion'; import { MoveBoardDropData } from 'app/components/ImageDnd/typesafeDnd';
import IAIDropOverlay from 'common/components/IAIDropOverlay'; // import { boardAddedToBatch } from 'app/store/middleware/listenerMiddleware/listeners/addBoardToBatch';
import { DeleteBoardImagesContext } from '../../../../app/contexts/DeleteBoardImagesContext'; import IAIDroppable from 'common/components/IAIDroppable';
import { mode } from 'theme/util/mode'; import { mode } from 'theme/util/mode';
import { import { DeleteBoardImagesContext } from '../../../../../app/contexts/DeleteBoardImagesContext';
MoveBoardDropData,
isValidDrop,
useDroppable,
} from 'app/components/ImageDnd/typesafeDnd';
interface HoverableBoardProps { interface GalleryBoardProps {
board: BoardDTO; board: BoardDTO;
isSelected: boolean; isSelected: boolean;
} }
const HoverableBoard = memo(({ board, isSelected }: HoverableBoardProps) => { const GalleryBoard = memo(({ board, isSelected }: GalleryBoardProps) => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { currentData: coverImage } = useGetImageDTOQuery( const { currentData: coverImage } = useGetImageDTOQuery(
@ -71,21 +67,22 @@ const HoverableBoard = memo(({ board, isSelected }: HoverableBoardProps) => {
deleteBoard(board_id); deleteBoard(board_id);
}, [board_id, deleteBoard]); }, [board_id, deleteBoard]);
const handleAddBoardToBatch = useCallback(() => {
// dispatch(boardAddedToBatch({ board_id }));
}, []);
const handleDeleteBoardAndImages = useCallback(() => { const handleDeleteBoardAndImages = useCallback(() => {
console.log({ board });
onClickDeleteBoardImages(board); onClickDeleteBoardImages(board);
}, [board, onClickDeleteBoardImages]); }, [board, onClickDeleteBoardImages]);
const droppableData: MoveBoardDropData = { const droppableData: MoveBoardDropData = useMemo(
id: board_id, () => ({
actionType: 'MOVE_BOARD', id: board_id,
context: { boardId: board_id }, actionType: 'MOVE_BOARD',
}; context: { boardId: board_id },
}),
const { isOver, setNodeRef, active } = useDroppable({ [board_id]
id: `board_droppable_${board_id}`, );
data: droppableData,
});
return ( return (
<Box sx={{ touchAction: 'none', height: 'full' }}> <Box sx={{ touchAction: 'none', height: 'full' }}>
@ -94,16 +91,25 @@ const HoverableBoard = memo(({ board, isSelected }: HoverableBoardProps) => {
renderMenu={() => ( renderMenu={() => (
<MenuList sx={{ visibility: 'visible !important' }}> <MenuList sx={{ visibility: 'visible !important' }}>
{board.image_count > 0 && ( {board.image_count > 0 && (
<MenuItem <>
sx={{ color: 'error.300' }} <MenuItem
icon={<FaTrash />} isDisabled={!board.image_count}
onClickCapture={handleDeleteBoardAndImages} icon={<FaImages />}
> onClickCapture={handleAddBoardToBatch}
Delete Board and Images >
</MenuItem> Add Board to Batch
</MenuItem>
<MenuItem
sx={{ color: 'error.600', _dark: { color: 'error.300' } }}
icon={<FaTrash />}
onClickCapture={handleDeleteBoardAndImages}
>
Delete Board and Images
</MenuItem>
</>
)} )}
<MenuItem <MenuItem
sx={{ color: mode('error.700', 'error.300')(colorMode) }} sx={{ color: 'error.600', _dark: { color: 'error.300' } }}
icon={<FaTrash />} icon={<FaTrash />}
onClickCapture={handleDeleteBoard} onClickCapture={handleDeleteBoard}
> >
@ -127,7 +133,6 @@ const HoverableBoard = memo(({ board, isSelected }: HoverableBoardProps) => {
}} }}
> >
<Flex <Flex
ref={setNodeRef}
onClick={handleSelectBoard} onClick={handleSelectBoard}
sx={{ sx={{
position: 'relative', position: 'relative',
@ -167,11 +172,7 @@ const HoverableBoard = memo(({ board, isSelected }: HoverableBoardProps) => {
> >
<Badge variant="solid">{board.image_count}</Badge> <Badge variant="solid">{board.image_count}</Badge>
</Flex> </Flex>
<AnimatePresence> <IAIDroppable data={droppableData} />
{isValidDrop(droppableData, active) && (
<IAIDropOverlay isOver={isOver} />
)}
</AnimatePresence>
</Flex> </Flex>
<Flex <Flex
@ -219,6 +220,6 @@ const HoverableBoard = memo(({ board, isSelected }: HoverableBoardProps) => {
); );
}); });
HoverableBoard.displayName = 'HoverableBoard'; GalleryBoard.displayName = 'HoverableBoard';
export default HoverableBoard; export default GalleryBoard;

View File

@ -0,0 +1,83 @@
import { As, Badge, Flex } from '@chakra-ui/react';
import { TypesafeDroppableData } from 'app/components/ImageDnd/typesafeDnd';
import IAIDroppable from 'common/components/IAIDroppable';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
type GenericBoardProps = {
droppableData: TypesafeDroppableData;
onClick: () => void;
isSelected: boolean;
icon: As;
label: string;
badgeCount?: number;
};
const GenericBoard = (props: GenericBoardProps) => {
const { droppableData, onClick, isSelected, icon, label, badgeCount } = props;
return (
<Flex
sx={{
flexDir: 'column',
justifyContent: 'space-between',
alignItems: 'center',
cursor: 'pointer',
w: 'full',
h: 'full',
borderRadius: 'base',
}}
>
<Flex
onClick={onClick}
sx={{
position: 'relative',
justifyContent: 'center',
alignItems: 'center',
borderRadius: 'base',
w: 'full',
aspectRatio: '1/1',
overflow: 'hidden',
shadow: isSelected ? 'selected.light' : undefined,
_dark: { shadow: isSelected ? 'selected.dark' : undefined },
flexShrink: 0,
}}
>
<IAINoContentFallback
boxSize={8}
icon={icon}
sx={{
border: '2px solid var(--invokeai-colors-base-200)',
_dark: { border: '2px solid var(--invokeai-colors-base-800)' },
}}
/>
<Flex
sx={{
position: 'absolute',
insetInlineEnd: 0,
top: 0,
p: 1,
}}
>
{badgeCount !== undefined && (
<Badge variant="solid">{badgeCount}</Badge>
)}
</Flex>
<IAIDroppable data={droppableData} />
</Flex>
<Flex
sx={{
h: 'full',
alignItems: 'center',
fontWeight: isSelected ? 600 : undefined,
fontSize: 'xs',
color: isSelected ? 'base.900' : 'base.700',
_dark: { color: isSelected ? 'base.50' : 'base.200' },
}}
>
{label}
</Flex>
</Flex>
);
};
export default GenericBoard;

View File

@ -8,8 +8,6 @@ import IAIButton from 'common/components/IAIButton';
import IAIIconButton from 'common/components/IAIIconButton'; import IAIIconButton from 'common/components/IAIIconButton';
import IAIPopover from 'common/components/IAIPopover'; import IAIPopover from 'common/components/IAIPopover';
import { setIsLightboxOpen } from 'features/lightbox/store/lightboxSlice';
import { skipToken } from '@reduxjs/toolkit/dist/query'; import { skipToken } from '@reduxjs/toolkit/dist/query';
import { useAppToaster } from 'app/components/Toaster'; import { useAppToaster } from 'app/components/Toaster';
import { stateSelector } from 'app/store/store'; import { stateSelector } from 'app/store/store';
@ -36,7 +34,6 @@ import {
FaCode, FaCode,
FaCopy, FaCopy,
FaDownload, FaDownload,
FaExpand,
FaExpandArrowsAlt, FaExpandArrowsAlt,
FaGrinStars, FaGrinStars,
FaHourglassHalf, FaHourglassHalf,
@ -45,12 +42,16 @@ import {
FaShare, FaShare,
FaShareAlt, FaShareAlt,
} from 'react-icons/fa'; } from 'react-icons/fa';
import { useGetImageDTOQuery } from 'services/api/endpoints/images'; import {
import { sentImageToCanvas, sentImageToImg2Img } from '../store/actions'; useGetImageDTOQuery,
useGetImageMetadataQuery,
} from 'services/api/endpoints/images';
import { useDebounce } from 'use-debounce';
import { sentImageToCanvas, sentImageToImg2Img } from '../../store/actions';
const currentImageButtonsSelector = createSelector( const currentImageButtonsSelector = createSelector(
[stateSelector, activeTabNameSelector], [stateSelector, activeTabNameSelector],
({ gallery, system, postprocessing, ui, lightbox }, activeTabName) => { ({ gallery, system, postprocessing, ui }, activeTabName) => {
const { const {
isProcessing, isProcessing,
isConnected, isConnected,
@ -62,8 +63,6 @@ const currentImageButtonsSelector = createSelector(
const { upscalingLevel, facetoolStrength } = postprocessing; const { upscalingLevel, facetoolStrength } = postprocessing;
const { isLightboxOpen } = lightbox;
const { const {
shouldShowImageDetails, shouldShowImageDetails,
shouldHidePreview, shouldHidePreview,
@ -84,7 +83,6 @@ const currentImageButtonsSelector = createSelector(
shouldDisableToolbarButtons: Boolean(progressImage) || !lastSelectedImage, shouldDisableToolbarButtons: Boolean(progressImage) || !lastSelectedImage,
shouldShowImageDetails, shouldShowImageDetails,
activeTabName, activeTabName,
isLightboxOpen,
shouldHidePreview, shouldHidePreview,
shouldShowProgressInViewer, shouldShowProgressInViewer,
lastSelectedImage, lastSelectedImage,
@ -110,14 +108,11 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
facetoolStrength, facetoolStrength,
shouldDisableToolbarButtons, shouldDisableToolbarButtons,
shouldShowImageDetails, shouldShowImageDetails,
isLightboxOpen,
activeTabName, activeTabName,
shouldHidePreview,
lastSelectedImage, lastSelectedImage,
shouldShowProgressInViewer, shouldShowProgressInViewer,
} = useAppSelector(currentImageButtonsSelector); } = useAppSelector(currentImageButtonsSelector);
const isLightboxEnabled = useFeatureStatus('lightbox').isFeatureEnabled;
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled; const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
const isUpscalingEnabled = useFeatureStatus('upscaling').isFeatureEnabled; const isUpscalingEnabled = useFeatureStatus('upscaling').isFeatureEnabled;
const isFaceRestoreEnabled = useFeatureStatus('faceRestore').isFeatureEnabled; const isFaceRestoreEnabled = useFeatureStatus('faceRestore').isFeatureEnabled;
@ -128,33 +123,22 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
const { recallBothPrompts, recallSeed, recallAllParameters } = const { recallBothPrompts, recallSeed, recallAllParameters } =
useRecallParameters(); useRecallParameters();
const { currentData: image } = useGetImageDTOQuery( const [debouncedMetadataQueryArg, debounceState] = useDebounce(
lastSelectedImage,
500
);
const { currentData: image, isFetching } = useGetImageDTOQuery(
lastSelectedImage ?? skipToken lastSelectedImage ?? skipToken
); );
// const handleCopyImage = useCallback(async () => { const { currentData: metadataData } = useGetImageMetadataQuery(
// if (!image?.url) { debounceState.isPending()
// return; ? skipToken
// } : debouncedMetadataQueryArg ?? skipToken
);
// const url = getUrl(image.url); const metadata = metadataData?.metadata;
// if (!url) {
// return;
// }
// const blob = await fetch(url).then((res) => res.blob());
// const data = [new ClipboardItem({ [blob.type]: blob })];
// await navigator.clipboard.write(data);
// toast({
// title: t('toast.imageCopied'),
// status: 'success',
// duration: 2500,
// isClosable: true,
// });
// }, [getUrl, t, image?.url, toast]);
const handleCopyImageLink = useCallback(() => { const handleCopyImageLink = useCallback(() => {
const getImageUrl = () => { const getImageUrl = () => {
@ -193,29 +177,26 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
}, [toaster, t, image]); }, [toaster, t, image]);
const handleClickUseAllParameters = useCallback(() => { const handleClickUseAllParameters = useCallback(() => {
recallAllParameters(image); recallAllParameters(metadata);
}, [image, recallAllParameters]); }, [metadata, recallAllParameters]);
useHotkeys( useHotkeys(
'a', 'a',
() => { () => {
handleClickUseAllParameters; handleClickUseAllParameters;
}, },
[image, recallAllParameters] [metadata, recallAllParameters]
); );
const handleUseSeed = useCallback(() => { const handleUseSeed = useCallback(() => {
recallSeed(image?.metadata?.seed); recallSeed(metadata?.seed);
}, [image, recallSeed]); }, [metadata?.seed, recallSeed]);
useHotkeys('s', handleUseSeed, [image]); useHotkeys('s', handleUseSeed, [image]);
const handleUsePrompt = useCallback(() => { const handleUsePrompt = useCallback(() => {
recallBothPrompts( recallBothPrompts(metadata?.positive_prompt, metadata?.negative_prompt);
image?.metadata?.positive_conditioning, }, [metadata?.negative_prompt, metadata?.positive_prompt, recallBothPrompts]);
image?.metadata?.negative_conditioning
);
}, [image, recallBothPrompts]);
useHotkeys('p', handleUsePrompt, [image]); useHotkeys('p', handleUsePrompt, [image]);
@ -304,7 +285,6 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
const handleSendToCanvas = useCallback(() => { const handleSendToCanvas = useCallback(() => {
if (!image) return; if (!image) return;
dispatch(sentImageToCanvas()); dispatch(sentImageToCanvas());
if (isLightboxOpen) dispatch(setIsLightboxOpen(false));
dispatch(setInitialCanvasImage(image)); dispatch(setInitialCanvasImage(image));
dispatch(requestCanvasRescale()); dispatch(requestCanvasRescale());
@ -319,7 +299,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
duration: 2500, duration: 2500,
isClosable: true, isClosable: true,
}); });
}, [image, isLightboxOpen, dispatch, activeTabName, toaster, t]); }, [image, dispatch, activeTabName, toaster, t]);
useHotkeys( useHotkeys(
'i', 'i',
@ -342,10 +322,6 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
dispatch(setShouldShowProgressInViewer(!shouldShowProgressInViewer)); dispatch(setShouldShowProgressInViewer(!shouldShowProgressInViewer));
}, [dispatch, shouldShowProgressInViewer]); }, [dispatch, shouldShowProgressInViewer]);
const handleLightBox = useCallback(() => {
dispatch(setIsLightboxOpen(!isLightboxOpen));
}, [dispatch, isLightboxOpen]);
return ( return (
<> <>
<Flex <Flex
@ -415,24 +391,6 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
</Link> </Link>
</Flex> </Flex>
</IAIPopover> </IAIPopover>
{isLightboxEnabled && (
<IAIIconButton
icon={<FaExpand />}
tooltip={
!isLightboxOpen
? `${t('parameters.openInViewer')} (Z)`
: `${t('parameters.closeViewer')} (Z)`
}
aria-label={
!isLightboxOpen
? `${t('parameters.openInViewer')} (Z)`
: `${t('parameters.closeViewer')} (Z)`
}
isChecked={isLightboxOpen}
onClick={handleLightBox}
isDisabled={shouldDisableToolbarButtons}
/>
)}
</ButtonGroup> </ButtonGroup>
<ButtonGroup isAttached={true} isDisabled={shouldDisableToolbarButtons}> <ButtonGroup isAttached={true} isDisabled={shouldDisableToolbarButtons}>
@ -440,7 +398,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
icon={<FaQuoteRight />} icon={<FaQuoteRight />}
tooltip={`${t('parameters.usePrompt')} (P)`} tooltip={`${t('parameters.usePrompt')} (P)`}
aria-label={`${t('parameters.usePrompt')} (P)`} aria-label={`${t('parameters.usePrompt')} (P)`}
isDisabled={!image?.metadata?.positive_conditioning} isDisabled={!metadata?.positive_prompt}
onClick={handleUsePrompt} onClick={handleUsePrompt}
/> />
@ -448,7 +406,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
icon={<FaSeedling />} icon={<FaSeedling />}
tooltip={`${t('parameters.useSeed')} (S)`} tooltip={`${t('parameters.useSeed')} (S)`}
aria-label={`${t('parameters.useSeed')} (S)`} aria-label={`${t('parameters.useSeed')} (S)`}
isDisabled={!image?.metadata?.seed} isDisabled={!metadata?.seed}
onClick={handleUseSeed} onClick={handleUseSeed}
/> />
@ -456,10 +414,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
icon={<FaAsterisk />} icon={<FaAsterisk />}
tooltip={`${t('parameters.useAll')} (A)`} tooltip={`${t('parameters.useAll')} (A)`}
aria-label={`${t('parameters.useAll')} (A)`} aria-label={`${t('parameters.useAll')} (A)`}
isDisabled={ isDisabled={!metadata}
// not sure what this list should be
!['t2l', 'l2l', 'inpaint'].includes(String(image?.metadata?.type))
}
onClick={handleClickUseAllParameters} onClick={handleClickUseAllParameters}
/> />
</ButtonGroup> </ButtonGroup>

View File

@ -8,12 +8,15 @@ import {
import { stateSelector } from 'app/store/store'; import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import IAIDndImage from 'common/components/IAIDndImage'; import IAIDndImage from 'common/components/IAIDndImage';
import { selectLastSelectedImage } from 'features/gallery/store/gallerySlice'; import { useNextPrevImage } from 'features/gallery/hooks/useNextPrevImage';
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
import { AnimatePresence, motion } from 'framer-motion';
import { isEqual } from 'lodash-es'; import { isEqual } from 'lodash-es';
import { memo, useMemo } from 'react'; import { memo, useCallback, useMemo, useRef, useState } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useGetImageDTOQuery } from 'services/api/endpoints/images'; import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import ImageMetadataViewer from './ImageMetaDataViewer/ImageMetadataViewer'; import ImageMetadataViewer from '../ImageMetadataViewer/ImageMetadataViewer';
import NextPrevImageButtons from './NextPrevImageButtons'; import NextPrevImageButtons from '../NextPrevImageButtons';
export const imagesSelector = createSelector( export const imagesSelector = createSelector(
[stateSelector, selectLastSelectedImage], [stateSelector, selectLastSelectedImage],
@ -49,6 +52,45 @@ const CurrentImagePreview = () => {
shouldAntialiasProgressImage, shouldAntialiasProgressImage,
} = useAppSelector(imagesSelector); } = useAppSelector(imagesSelector);
const {
handlePrevImage,
handleNextImage,
prevImageId,
nextImageId,
isOnLastImage,
handleLoadMoreImages,
areMoreImagesAvailable,
isFetching,
} = useNextPrevImage();
useHotkeys(
'left',
() => {
handlePrevImage();
},
[prevImageId]
);
useHotkeys(
'right',
() => {
if (isOnLastImage && areMoreImagesAvailable && !isFetching) {
handleLoadMoreImages();
return;
}
if (!isOnLastImage) {
handleNextImage();
}
},
[
nextImageId,
isOnLastImage,
areMoreImagesAvailable,
handleLoadMoreImages,
isFetching,
]
);
const { const {
currentData: imageDTO, currentData: imageDTO,
isLoading, isLoading,
@ -74,8 +116,27 @@ const CurrentImagePreview = () => {
[] []
); );
// Show and hide the next/prev buttons on mouse move
const [shouldShowNextPrevButtons, setShouldShowNextPrevButtons] =
useState<boolean>(false);
const timeoutId = useRef(0);
const handleMouseOver = useCallback(() => {
setShouldShowNextPrevButtons(true);
window.clearTimeout(timeoutId.current);
}, []);
const handleMouseOut = useCallback(() => {
timeoutId.current = window.setTimeout(() => {
setShouldShowNextPrevButtons(false);
}, 500);
}, []);
return ( return (
<Flex <Flex
onMouseOver={handleMouseOver}
onMouseOut={handleMouseOut}
sx={{ sx={{
width: 'full', width: 'full',
height: 'full', height: 'full',
@ -118,25 +179,38 @@ const CurrentImagePreview = () => {
width: 'full', width: 'full',
height: 'full', height: 'full',
borderRadius: 'base', borderRadius: 'base',
overflow: 'scroll',
}} }}
> >
<ImageMetadataViewer image={imageDTO} /> <ImageMetadataViewer image={imageDTO} />
</Box> </Box>
)} )}
{!shouldShowImageDetails && imageDTO && ( <AnimatePresence>
<Box {!shouldShowImageDetails && imageDTO && shouldShowNextPrevButtons && (
sx={{ <motion.div
position: 'absolute', key="nextPrevButtons"
top: '0', initial={{
width: 'full', opacity: 0,
height: 'full', }}
pointerEvents: 'none', animate={{
}} opacity: 1,
> transition: { duration: 0.1 },
<NextPrevImageButtons /> }}
</Box> exit={{
)} opacity: 0,
transition: { duration: 0.1 },
}}
style={{
position: 'absolute',
top: '0',
width: '100%',
height: '100%',
pointerEvents: 'none',
}}
>
<NextPrevImageButtons />
</motion.div>
)}
</AnimatePresence>
</Flex> </Flex>
); );
}; };

View File

@ -5,32 +5,23 @@ import { setGalleryImageMinimumWidth } from 'features/gallery/store/gallerySlice
import { clamp, isEqual } from 'lodash-es'; import { clamp, isEqual } from 'lodash-es';
import { useHotkeys } from 'react-hotkeys-hook'; import { useHotkeys } from 'react-hotkeys-hook';
import './ImageGallery.css';
import ImageGalleryContent from './ImageGalleryContent';
import ResizableDrawer from 'features/ui/components/common/ResizableDrawer/ResizableDrawer';
import { setShouldShowGallery } from 'features/ui/store/uiSlice';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { isStagingSelector } from 'features/canvas/store/canvasSelectors';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import ResizableDrawer from 'features/ui/components/common/ResizableDrawer/ResizableDrawer';
import { import {
activeTabNameSelector, activeTabNameSelector,
uiSelector, uiSelector,
} from 'features/ui/store/uiSelectors'; } from 'features/ui/store/uiSelectors';
import { isStagingSelector } from 'features/canvas/store/canvasSelectors'; import { setShouldShowGallery } from 'features/ui/store/uiSlice';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import { lightboxSelector } from 'features/lightbox/store/lightboxSelectors';
import { memo } from 'react'; import { memo } from 'react';
import ImageGalleryContent from './ImageGalleryContent';
const selector = createSelector( const selector = createSelector(
[ [activeTabNameSelector, uiSelector, gallerySelector, isStagingSelector],
activeTabNameSelector, (activeTabName, ui, gallery, isStaging) => {
uiSelector,
gallerySelector,
isStagingSelector,
lightboxSelector,
],
(activeTabName, ui, gallery, isStaging, lightbox) => {
const { shouldPinGallery, shouldShowGallery } = ui; const { shouldPinGallery, shouldShowGallery } = ui;
const { galleryImageMinimumWidth } = gallery; const { galleryImageMinimumWidth } = gallery;
const { isLightboxOpen } = lightbox;
return { return {
activeTabName, activeTabName,
@ -39,7 +30,6 @@ const selector = createSelector(
shouldShowGallery, shouldShowGallery,
galleryImageMinimumWidth, galleryImageMinimumWidth,
isResizable: activeTabName !== 'unifiedCanvas', isResizable: activeTabName !== 'unifiedCanvas',
isLightboxOpen,
}; };
}, },
{ {
@ -58,7 +48,6 @@ const GalleryDrawer = () => {
// activeTabName, // activeTabName,
// isStaging, // isStaging,
// isResizable, // isResizable,
// isLightboxOpen,
} = useAppSelector(selector); } = useAppSelector(selector);
const handleCloseGallery = () => { const handleCloseGallery = () => {

View File

@ -1,273 +0,0 @@
import { ExternalLinkIcon } from '@chakra-ui/icons';
import { MenuItem, MenuList } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppToaster } from 'app/components/Toaster';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { ContextMenu, ContextMenuProps } from 'chakra-ui-contextmenu';
import {
imagesAddedToBatch,
selectionAddedToBatch,
} from 'features/batch/store/batchSlice';
import {
resizeAndScaleCanvas,
setInitialCanvasImage,
} from 'features/canvas/store/canvasSlice';
import { imageToDeleteSelected } from 'features/imageDeletion/store/imageDeletionSlice';
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { initialImageSelected } from 'features/parameters/store/actions';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { memo, useCallback, useContext, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { FaExpand, FaFolder, FaShare, FaTrash } from 'react-icons/fa';
import { IoArrowUndoCircleOutline } from 'react-icons/io5';
import { useRemoveImageFromBoardMutation } from 'services/api/endpoints/boardImages';
import { ImageDTO } from 'services/api/types';
import { AddImageToBoardContext } from '../../../app/contexts/AddImageToBoardContext';
import { sentImageToCanvas, sentImageToImg2Img } from '../store/actions';
type Props = {
image: ImageDTO;
children: ContextMenuProps<HTMLDivElement>['children'];
};
const ImageContextMenu = ({ image, children }: Props) => {
const selector = useMemo(
() =>
createSelector(
[stateSelector],
({ gallery, batch }) => {
const selectionCount = gallery.selection.length;
const isInBatch = batch.imageNames.includes(image.image_name);
return { selectionCount, isInBatch };
},
defaultSelectorOptions
),
[image.image_name]
);
const { selectionCount, isInBatch } = useAppSelector(selector);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const toaster = useAppToaster();
const isLightboxEnabled = useFeatureStatus('lightbox').isFeatureEnabled;
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
const { onClickAddToBoard } = useContext(AddImageToBoardContext);
const handleDelete = useCallback(() => {
if (!image) {
return;
}
dispatch(imageToDeleteSelected(image));
}, [dispatch, image]);
const { recallBothPrompts, recallSeed, recallAllParameters } =
useRecallParameters();
const [removeFromBoard] = useRemoveImageFromBoardMutation();
// Recall parameters handlers
const handleRecallPrompt = useCallback(() => {
recallBothPrompts(
image.metadata?.positive_conditioning,
image.metadata?.negative_conditioning
);
}, [
image.metadata?.negative_conditioning,
image.metadata?.positive_conditioning,
recallBothPrompts,
]);
const handleRecallSeed = useCallback(() => {
recallSeed(image.metadata?.seed);
}, [image, recallSeed]);
const handleSendToImageToImage = useCallback(() => {
dispatch(sentImageToImg2Img());
dispatch(initialImageSelected(image));
}, [dispatch, image]);
// const handleRecallInitialImage = useCallback(() => {
// recallInitialImage(image.metadata.invokeai?.node?.image);
// }, [image, recallInitialImage]);
const handleSendToCanvas = () => {
dispatch(sentImageToCanvas());
dispatch(setInitialCanvasImage(image));
dispatch(resizeAndScaleCanvas());
dispatch(setActiveTab('unifiedCanvas'));
toaster({
title: t('toast.sentToUnifiedCanvas'),
status: 'success',
duration: 2500,
isClosable: true,
});
};
const handleUseAllParameters = useCallback(() => {
recallAllParameters(image);
}, [image, recallAllParameters]);
const handleLightBox = () => {
// dispatch(setCurrentImage(image));
// dispatch(setIsLightboxOpen(true));
};
const handleAddToBoard = useCallback(() => {
onClickAddToBoard(image);
}, [image, onClickAddToBoard]);
const handleRemoveFromBoard = useCallback(() => {
if (!image.board_id) {
return;
}
removeFromBoard({ board_id: image.board_id, image_name: image.image_name });
}, [image.board_id, image.image_name, removeFromBoard]);
const handleOpenInNewTab = () => {
window.open(image.image_url, '_blank');
};
const handleAddSelectionToBatch = useCallback(() => {
dispatch(selectionAddedToBatch());
}, [dispatch]);
const handleAddToBatch = useCallback(() => {
dispatch(imagesAddedToBatch([image.image_name]));
}, [dispatch, image.image_name]);
return (
<ContextMenu<HTMLDivElement>
menuProps={{ size: 'sm', isLazy: true }}
renderMenu={() => (
<MenuList sx={{ visibility: 'visible !important' }}>
{selectionCount === 1 ? (
<>
<MenuItem
icon={<ExternalLinkIcon />}
onClickCapture={handleOpenInNewTab}
>
{t('common.openInNewTab')}
</MenuItem>
{isLightboxEnabled && (
<MenuItem icon={<FaExpand />} onClickCapture={handleLightBox}>
{t('parameters.openInViewer')}
</MenuItem>
)}
<MenuItem
icon={<IoArrowUndoCircleOutline />}
onClickCapture={handleRecallPrompt}
isDisabled={
image?.metadata?.positive_conditioning === undefined
}
>
{t('parameters.usePrompt')}
</MenuItem>
<MenuItem
icon={<IoArrowUndoCircleOutline />}
onClickCapture={handleRecallSeed}
isDisabled={image?.metadata?.seed === undefined}
>
{t('parameters.useSeed')}
</MenuItem>
{/* <MenuItem
icon={<IoArrowUndoCircleOutline />}
onClickCapture={handleRecallInitialImage}
isDisabled={image?.metadata?.type !== 'img2img'}
>
{t('parameters.useInitImg')}
</MenuItem> */}
<MenuItem
icon={<IoArrowUndoCircleOutline />}
onClickCapture={handleUseAllParameters}
isDisabled={
// what should these be
!['t2l', 'l2l', 'inpaint'].includes(
String(image?.metadata?.type)
)
}
>
{t('parameters.useAll')}
</MenuItem>
<MenuItem
icon={<FaShare />}
onClickCapture={handleSendToImageToImage}
id="send-to-img2img"
>
{t('parameters.sendToImg2Img')}
</MenuItem>
{isCanvasEnabled && (
<MenuItem
icon={<FaShare />}
onClickCapture={handleSendToCanvas}
id="send-to-canvas"
>
{t('parameters.sendToUnifiedCanvas')}
</MenuItem>
)}
{/* <MenuItem
icon={<FaFolder />}
isDisabled={isInBatch}
onClickCapture={handleAddToBatch}
>
Add to Batch
</MenuItem> */}
<MenuItem icon={<FaFolder />} onClickCapture={handleAddToBoard}>
{image.board_id ? 'Change Board' : 'Add to Board'}
</MenuItem>
{image.board_id && (
<MenuItem
icon={<FaFolder />}
onClickCapture={handleRemoveFromBoard}
>
Remove from Board
</MenuItem>
)}
<MenuItem
sx={{ color: 'error.600', _dark: { color: 'error.300' } }}
icon={<FaTrash />}
onClickCapture={handleDelete}
>
{t('gallery.deleteImage')}
</MenuItem>
</>
) : (
<>
<MenuItem
isDisabled={true}
icon={<FaFolder />}
onClickCapture={handleAddToBoard}
>
Move Selection to Board
</MenuItem>
{/* <MenuItem
icon={<FaFolderPlus />}
onClickCapture={handleAddSelectionToBatch}
>
Add Selection to Batch
</MenuItem> */}
<MenuItem
sx={{ color: 'error.600', _dark: { color: 'error.300' } }}
icon={<FaTrash />}
onClickCapture={handleDelete}
>
Delete Selection
</MenuItem>
</>
)}
</MenuList>
)}
>
{children}
</ContextMenu>
);
};
export default memo(ImageContextMenu);

View File

@ -0,0 +1,52 @@
import { MenuList } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { ContextMenu, ContextMenuProps } from 'chakra-ui-contextmenu';
import { memo, useMemo } from 'react';
import { ImageDTO } from 'services/api/types';
import MultipleSelectionMenuItems from './MultipleSelectionMenuItems';
import SingleSelectionMenuItems from './SingleSelectionMenuItems';
type Props = {
imageDTO: ImageDTO;
children: ContextMenuProps<HTMLDivElement>['children'];
};
const ImageContextMenu = ({ imageDTO, children }: Props) => {
const selector = useMemo(
() =>
createSelector(
[stateSelector],
({ gallery }) => {
const selectionCount = gallery.selection.length;
return { selectionCount };
},
defaultSelectorOptions
),
[]
);
const { selectionCount } = useAppSelector(selector);
return (
<ContextMenu<HTMLDivElement>
menuProps={{ size: 'sm', isLazy: true }}
renderMenu={() => (
<MenuList sx={{ visibility: 'visible !important' }}>
{selectionCount === 1 ? (
<SingleSelectionMenuItems imageDTO={imageDTO} />
) : (
<MultipleSelectionMenuItems />
)}
</MenuList>
)}
>
{children}
</ContextMenu>
);
};
export default memo(ImageContextMenu);

View File

@ -0,0 +1,40 @@
import { MenuItem } from '@chakra-ui/react';
import { useCallback } from 'react';
import { FaFolder, FaFolderPlus, FaTrash } from 'react-icons/fa';
const MultipleSelectionMenuItems = () => {
const handleAddSelectionToBoard = useCallback(() => {
// TODO: add selection to board
}, []);
const handleDeleteSelection = useCallback(() => {
// TODO: delete all selected images
}, []);
const handleAddSelectionToBatch = useCallback(() => {
// TODO: add selection to batch
}, []);
return (
<>
<MenuItem icon={<FaFolder />} onClickCapture={handleAddSelectionToBoard}>
Move Selection to Board
</MenuItem>
<MenuItem
icon={<FaFolderPlus />}
onClickCapture={handleAddSelectionToBatch}
>
Add Selection to Batch
</MenuItem>
<MenuItem
sx={{ color: 'error.600', _dark: { color: 'error.300' } }}
icon={<FaTrash />}
onClickCapture={handleDeleteSelection}
>
Delete Selection
</MenuItem>
</>
);
};
export default MultipleSelectionMenuItems;

View File

@ -0,0 +1,207 @@
import { ExternalLinkIcon } from '@chakra-ui/icons';
import { MenuItem } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppToaster } from 'app/components/Toaster';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import {
resizeAndScaleCanvas,
setInitialCanvasImage,
} from 'features/canvas/store/canvasSlice';
import { imagesAddedToBatch } from 'features/gallery/store/gallerySlice';
import { imageToDeleteSelected } from 'features/imageDeletion/store/imageDeletionSlice';
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { initialImageSelected } from 'features/parameters/store/actions';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { memo, useCallback, useContext, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { FaFolder, FaShare, FaTrash } from 'react-icons/fa';
import { IoArrowUndoCircleOutline } from 'react-icons/io5';
import { useRemoveImageFromBoardMutation } from 'services/api/endpoints/boardImages';
import { useGetImageMetadataQuery } from 'services/api/endpoints/images';
import { ImageDTO } from 'services/api/types';
import { AddImageToBoardContext } from '../../../../app/contexts/AddImageToBoardContext';
import { sentImageToCanvas, sentImageToImg2Img } from '../../store/actions';
type SingleSelectionMenuItemsProps = {
imageDTO: ImageDTO;
};
const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
const { imageDTO } = props;
const selector = useMemo(
() =>
createSelector(
[stateSelector],
({ gallery }) => {
const isInBatch = gallery.batchImageNames.includes(
imageDTO.image_name
);
return { isInBatch };
},
defaultSelectorOptions
),
[imageDTO.image_name]
);
const { isInBatch } = useAppSelector(selector);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const toaster = useAppToaster();
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
const isBatchEnabled = useFeatureStatus('batches').isFeatureEnabled;
const { onClickAddToBoard } = useContext(AddImageToBoardContext);
const { currentData } = useGetImageMetadataQuery(imageDTO.image_name);
const metadata = currentData?.metadata;
const handleDelete = useCallback(() => {
if (!imageDTO) {
return;
}
dispatch(imageToDeleteSelected(imageDTO));
}, [dispatch, imageDTO]);
const { recallBothPrompts, recallSeed, recallAllParameters } =
useRecallParameters();
const [removeFromBoard] = useRemoveImageFromBoardMutation();
// Recall parameters handlers
const handleRecallPrompt = useCallback(() => {
recallBothPrompts(metadata?.positive_prompt, metadata?.negative_prompt);
}, [metadata?.negative_prompt, metadata?.positive_prompt, recallBothPrompts]);
const handleRecallSeed = useCallback(() => {
recallSeed(metadata?.seed);
}, [metadata?.seed, recallSeed]);
const handleSendToImageToImage = useCallback(() => {
dispatch(sentImageToImg2Img());
dispatch(initialImageSelected(imageDTO));
}, [dispatch, imageDTO]);
const handleSendToCanvas = useCallback(() => {
dispatch(sentImageToCanvas());
dispatch(setInitialCanvasImage(imageDTO));
dispatch(resizeAndScaleCanvas());
dispatch(setActiveTab('unifiedCanvas'));
toaster({
title: t('toast.sentToUnifiedCanvas'),
status: 'success',
duration: 2500,
isClosable: true,
});
}, [dispatch, imageDTO, t, toaster]);
const handleUseAllParameters = useCallback(() => {
console.log(metadata);
recallAllParameters(metadata);
}, [metadata, recallAllParameters]);
const handleAddToBoard = useCallback(() => {
onClickAddToBoard(imageDTO);
}, [imageDTO, onClickAddToBoard]);
const handleRemoveFromBoard = useCallback(() => {
if (!imageDTO.board_id) {
return;
}
removeFromBoard({
board_id: imageDTO.board_id,
image_name: imageDTO.image_name,
});
}, [imageDTO.board_id, imageDTO.image_name, removeFromBoard]);
const handleOpenInNewTab = useCallback(() => {
window.open(imageDTO.image_url, '_blank');
}, [imageDTO.image_url]);
const handleAddToBatch = useCallback(() => {
dispatch(imagesAddedToBatch([imageDTO.image_name]));
}, [dispatch, imageDTO.image_name]);
return (
<>
<MenuItem icon={<ExternalLinkIcon />} onClickCapture={handleOpenInNewTab}>
{t('common.openInNewTab')}
</MenuItem>
<MenuItem
icon={<IoArrowUndoCircleOutline />}
onClickCapture={handleRecallPrompt}
isDisabled={
metadata?.positive_prompt === undefined &&
metadata?.negative_prompt === undefined
}
>
{t('parameters.usePrompt')}
</MenuItem>
<MenuItem
icon={<IoArrowUndoCircleOutline />}
onClickCapture={handleRecallSeed}
isDisabled={metadata?.seed === undefined}
>
{t('parameters.useSeed')}
</MenuItem>
<MenuItem
icon={<IoArrowUndoCircleOutline />}
onClickCapture={handleUseAllParameters}
isDisabled={!metadata}
>
{t('parameters.useAll')}
</MenuItem>
<MenuItem
icon={<FaShare />}
onClickCapture={handleSendToImageToImage}
id="send-to-img2img"
>
{t('parameters.sendToImg2Img')}
</MenuItem>
{isCanvasEnabled && (
<MenuItem
icon={<FaShare />}
onClickCapture={handleSendToCanvas}
id="send-to-canvas"
>
{t('parameters.sendToUnifiedCanvas')}
</MenuItem>
)}
{isBatchEnabled && (
<MenuItem
icon={<FaFolder />}
isDisabled={isInBatch}
onClickCapture={handleAddToBatch}
>
Add to Batch
</MenuItem>
)}
<MenuItem icon={<FaFolder />} onClickCapture={handleAddToBoard}>
{imageDTO.board_id ? 'Change Board' : 'Add to Board'}
</MenuItem>
{imageDTO.board_id && (
<MenuItem icon={<FaFolder />} onClickCapture={handleRemoveFromBoard}>
Remove from Board
</MenuItem>
)}
<MenuItem
sx={{ color: 'error.600', _dark: { color: 'error.300' } }}
icon={<FaTrash />}
onClickCapture={handleDelete}
>
{t('gallery.deleteImage')}
</MenuItem>
</>
);
};
export default memo(SingleSelectionMenuItems);

View File

@ -1,35 +0,0 @@
.ltr-image-gallery-css-transition-enter {
transform: translateX(150%);
}
.ltr-image-gallery-css-transition-enter-active {
transform: translateX(0);
transition: all 120ms ease-out;
}
.ltr-image-gallery-css-transition-exit {
transform: translateX(0);
}
.ltr-image-gallery-css-transition-exit-active {
transform: translateX(150%);
transition: all 120ms ease-out;
}
.rtl-image-gallery-css-transition-enter {
transform: translateX(-150%);
}
.rtl-image-gallery-css-transition-enter-active {
transform: translateX(0);
transition: all 120ms ease-out;
}
.rtl-image-gallery-css-transition-exit {
transform: translateX(0);
}
.rtl-image-gallery-css-transition-exit-active {
transform: translateX(-150%);
transition: all 120ms ease-out;
}

View File

@ -19,7 +19,7 @@ import {
} from 'features/gallery/store/gallerySlice'; } from 'features/gallery/store/gallerySlice';
import { togglePinGalleryPanel } from 'features/ui/store/uiSlice'; import { togglePinGalleryPanel } from 'features/ui/store/uiSlice';
import { ChangeEvent, memo, useCallback, useRef } from 'react'; import { ChangeEvent, memo, useCallback, useMemo, useRef } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { BsPinAngle, BsPinAngleFill } from 'react-icons/bs'; import { BsPinAngle, BsPinAngleFill } from 'react-icons/bs';
import { FaImage, FaServer, FaWrench } from 'react-icons/fa'; import { FaImage, FaServer, FaWrench } from 'react-icons/fa';
@ -29,16 +29,12 @@ import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store'; import { stateSelector } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale'; import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import { import { shouldAutoSwitchChanged } from 'features/gallery/store/gallerySlice';
ASSETS_CATEGORIES,
IMAGE_CATEGORIES,
imageCategoriesChanged,
shouldAutoSwitchChanged,
} from 'features/gallery/store/gallerySlice';
import { useListAllBoardsQuery } from 'services/api/endpoints/boards'; import { useListAllBoardsQuery } from 'services/api/endpoints/boards';
import { mode } from 'theme/util/mode'; import { mode } from 'theme/util/mode';
import BoardsList from './Boards/BoardsList'; import BoardsList from './Boards/BoardsList/BoardsList';
import ImageGalleryGrid from './ImageGalleryGrid'; import BatchImageGrid from './ImageGrid/BatchImageGrid';
import GalleryImageGrid from './ImageGrid/GalleryImageGrid';
const selector = createSelector( const selector = createSelector(
[stateSelector], [stateSelector],
@ -66,6 +62,7 @@ const ImageGalleryContent = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const resizeObserverRef = useRef<HTMLDivElement>(null); const resizeObserverRef = useRef<HTMLDivElement>(null);
const galleryGridRef = useRef<HTMLDivElement>(null);
const { colorMode } = useColorMode(); const { colorMode } = useColorMode();
@ -83,6 +80,16 @@ const ImageGalleryContent = () => {
}), }),
}); });
const boardTitle = useMemo(() => {
if (selectedBoardId === 'batch') {
return 'Batch';
}
if (selectedBoard) {
return selectedBoard.board_name;
}
return 'All Images';
}, [selectedBoard, selectedBoardId]);
const { isOpen: isBoardListOpen, onToggle } = useDisclosure(); const { isOpen: isBoardListOpen, onToggle } = useDisclosure();
const handleChangeGalleryImageMinimumWidth = (v: number) => { const handleChangeGalleryImageMinimumWidth = (v: number) => {
@ -95,12 +102,10 @@ const ImageGalleryContent = () => {
}; };
const handleClickImagesCategory = useCallback(() => { const handleClickImagesCategory = useCallback(() => {
dispatch(imageCategoriesChanged(IMAGE_CATEGORIES));
dispatch(setGalleryView('images')); dispatch(setGalleryView('images'));
}, [dispatch]); }, [dispatch]);
const handleClickAssetsCategory = useCallback(() => { const handleClickAssetsCategory = useCallback(() => {
dispatch(imageCategoriesChanged(ASSETS_CATEGORIES));
dispatch(setGalleryView('assets')); dispatch(setGalleryView('assets'));
}, [dispatch]); }, [dispatch]);
@ -163,7 +168,7 @@ const ImageGalleryContent = () => {
fontWeight: 600, fontWeight: 600,
}} }}
> >
{selectedBoard ? selectedBoard.board_name : 'All Images'} {boardTitle}
</Text> </Text>
<ChevronUpIcon <ChevronUpIcon
sx={{ sx={{
@ -216,8 +221,12 @@ const ImageGalleryContent = () => {
<BoardsList isOpen={isBoardListOpen} /> <BoardsList isOpen={isBoardListOpen} />
</Box> </Box>
</Box> </Box>
<Flex direction="column" gap={2} h="full" w="full"> <Flex ref={galleryGridRef} direction="column" gap={2} h="full" w="full">
<ImageGalleryGrid /> {selectedBoardId === 'batch' ? (
<BatchImageGrid />
) : (
<GalleryImageGrid />
)}
</Flex> </Flex>
</VStack> </VStack>
); );

View File

@ -1,233 +0,0 @@
import {
Box,
Flex,
FlexProps,
Grid,
Skeleton,
Spinner,
forwardRef,
} from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import { IMAGE_LIMIT } from 'features/gallery/store/gallerySlice';
import { useOverlayScrollbars } from 'overlayscrollbars-react';
import {
PropsWithChildren,
memo,
useCallback,
useEffect,
useMemo,
useRef,
useState,
} from 'react';
import { useTranslation } from 'react-i18next';
import { FaImage } from 'react-icons/fa';
import GalleryImage from './GalleryImage';
import { createSelector } from '@reduxjs/toolkit';
import { RootState, stateSelector } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { selectFilteredImages } from 'features/gallery/store/gallerySlice';
import { VirtuosoGrid } from 'react-virtuoso';
import { useListAllBoardsQuery } from 'services/api/endpoints/boards';
import { receivedPageOfImages } from 'services/api/thunks/image';
import { ImageDTO } from 'services/api/types';
const selector = createSelector(
[stateSelector, selectFilteredImages],
(state, filteredImages) => {
const {
categories,
total: allImagesTotal,
isLoading,
isFetching,
selectedBoardId,
} = state.gallery;
let images = filteredImages as (ImageDTO | 'loading')[];
if (!isLoading && isFetching) {
// loading, not not the initial load
images = images.concat(Array(IMAGE_LIMIT).fill('loading'));
}
return {
images,
allImagesTotal,
isLoading,
isFetching,
categories,
selectedBoardId,
};
},
defaultSelectorOptions
);
const ImageGalleryGrid = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const rootRef = useRef(null);
const [scroller, setScroller] = useState<HTMLElement | null>(null);
const [initialize, osInstance] = useOverlayScrollbars({
defer: true,
options: {
scrollbars: {
visibility: 'auto',
autoHide: 'leave',
autoHideDelay: 1300,
theme: 'os-theme-dark',
},
overflow: { x: 'hidden' },
},
});
const {
images,
isLoading,
isFetching,
allImagesTotal,
categories,
selectedBoardId,
} = useAppSelector(selector);
const { selectedBoard } = useListAllBoardsQuery(undefined, {
selectFromResult: ({ data }) => ({
selectedBoard: data?.find((b) => b.board_id === selectedBoardId),
}),
});
const filteredImagesTotal = useMemo(
() => selectedBoard?.image_count ?? allImagesTotal,
[allImagesTotal, selectedBoard?.image_count]
);
const areMoreAvailable = useMemo(() => {
return images.length < filteredImagesTotal;
}, [images.length, filteredImagesTotal]);
const handleLoadMoreImages = useCallback(() => {
dispatch(
receivedPageOfImages({
categories,
board_id: selectedBoardId,
is_intermediate: false,
})
);
}, [categories, dispatch, selectedBoardId]);
const handleEndReached = useMemo(() => {
if (areMoreAvailable && !isLoading) {
return handleLoadMoreImages;
}
return undefined;
}, [areMoreAvailable, handleLoadMoreImages, isLoading]);
useEffect(() => {
const { current: root } = rootRef;
if (scroller && root) {
initialize({
target: root,
elements: {
viewport: scroller,
},
});
}
return () => osInstance()?.destroy();
}, [scroller, initialize, osInstance]);
if (isLoading) {
return (
<Flex
sx={{
w: 'full',
h: 'full',
alignItems: 'center',
justifyContent: 'center',
}}
>
<Spinner
size="xl"
sx={{ color: 'base.300', _dark: { color: 'base.700' } }}
/>
</Flex>
);
}
if (images.length) {
return (
<>
<Box ref={rootRef} data-overlayscrollbars="" h="100%">
<VirtuosoGrid
style={{ height: '100%' }}
data={images}
endReached={handleEndReached}
components={{
Item: ItemContainer,
List: ListContainer,
}}
scrollerRef={setScroller}
itemContent={(index, item) =>
typeof item === 'string' ? (
<Skeleton sx={{ w: 'full', h: 'full', aspectRatio: '1/1' }} />
) : (
<GalleryImage
key={`${item.image_name}-${item.thumbnail_url}`}
imageDTO={item}
/>
)
}
/>
</Box>
<IAIButton
onClick={handleLoadMoreImages}
isDisabled={!areMoreAvailable}
isLoading={isFetching}
loadingText="Loading"
flexShrink={0}
>
{areMoreAvailable
? t('gallery.loadMore')
: t('gallery.allImagesLoaded')}
</IAIButton>
</>
);
}
return (
<IAINoContentFallback
label={t('gallery.noImagesInGallery')}
icon={FaImage}
/>
);
};
type ItemContainerProps = PropsWithChildren & FlexProps;
const ItemContainer = forwardRef((props: ItemContainerProps, ref) => (
<Box className="item-container" ref={ref} p={1.5}>
{props.children}
</Box>
));
type ListContainerProps = PropsWithChildren & FlexProps;
const ListContainer = forwardRef((props: ListContainerProps, ref) => {
const galleryImageMinimumWidth = useAppSelector(
(state: RootState) => state.gallery.galleryImageMinimumWidth
);
return (
<Grid
{...props}
className="list-container"
ref={ref}
sx={{
gridTemplateColumns: `repeat(auto-fill, minmax(${galleryImageMinimumWidth}px, 1fr));`,
}}
>
{props.children}
</Grid>
);
});
export default memo(ImageGalleryGrid);

View File

@ -0,0 +1,128 @@
import { Box } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { TypesafeDraggableData } from 'app/components/ImageDnd/typesafeDnd';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIDndImage from 'common/components/IAIDndImage';
import IAIErrorLoadingImageFallback from 'common/components/IAIErrorLoadingImageFallback';
import IAIFillSkeleton from 'common/components/IAIFillSkeleton';
import ImageContextMenu from 'features/gallery/components/ImageContextMenu/ImageContextMenu';
import {
imageRangeEndSelected,
imageSelected,
imageSelectionToggled,
imagesRemovedFromBatch,
} from 'features/gallery/store/gallerySlice';
import { MouseEvent, memo, useCallback, useMemo } from 'react';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
const makeSelector = (image_name: string) =>
createSelector(
[stateSelector],
(state) => ({
selectionCount: state.gallery.selection.length,
selection: state.gallery.selection,
isSelected: state.gallery.selection.includes(image_name),
}),
defaultSelectorOptions
);
type BatchImageProps = {
imageName: string;
};
const BatchImage = (props: BatchImageProps) => {
const dispatch = useAppDispatch();
const { imageName } = props;
const {
currentData: imageDTO,
isLoading,
isError,
isSuccess,
} = useGetImageDTOQuery(imageName);
const selector = useMemo(() => makeSelector(imageName), [imageName]);
const { isSelected, selectionCount, selection } = useAppSelector(selector);
const handleClickRemove = useCallback(() => {
dispatch(imagesRemovedFromBatch([imageName]));
}, [dispatch, imageName]);
const handleClick = useCallback(
(e: MouseEvent<HTMLDivElement>) => {
if (e.shiftKey) {
dispatch(imageRangeEndSelected(imageName));
} else if (e.ctrlKey || e.metaKey) {
dispatch(imageSelectionToggled(imageName));
} else {
dispatch(imageSelected(imageName));
}
},
[dispatch, imageName]
);
const draggableData = useMemo<TypesafeDraggableData | undefined>(() => {
if (selectionCount > 1) {
return {
id: 'batch',
payloadType: 'IMAGE_NAMES',
payload: { image_names: selection },
};
}
if (imageDTO) {
return {
id: 'batch',
payloadType: 'IMAGE_DTO',
payload: { imageDTO },
};
}
}, [imageDTO, selection, selectionCount]);
if (isLoading) {
return <IAIFillSkeleton />;
}
if (isError || !imageDTO) {
return <IAIErrorLoadingImageFallback />;
}
return (
<Box sx={{ w: 'full', h: 'full', touchAction: 'none' }}>
<ImageContextMenu imageDTO={imageDTO}>
{(ref) => (
<Box
position="relative"
key={imageName}
userSelect="none"
ref={ref}
sx={{
display: 'flex',
justifyContent: 'center',
alignItems: 'center',
aspectRatio: '1/1',
}}
>
<IAIDndImage
onClick={handleClick}
imageDTO={imageDTO}
draggableData={draggableData}
isSelected={isSelected}
minSize={0}
onClickReset={handleClickRemove}
isDropDisabled={true}
imageSx={{ w: 'full', h: 'full' }}
isUploadDisabled={true}
resetTooltip="Remove from batch"
withResetIcon
thumbnail
/>
</Box>
)}
</ImageContextMenu>
</Box>
);
};
export default memo(BatchImage);

View File

@ -0,0 +1,87 @@
import { Box } from '@chakra-ui/react';
import { useAppSelector } from 'app/store/storeHooks';
import { useOverlayScrollbars } from 'overlayscrollbars-react';
import { memo, useEffect, useRef, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { FaImage } from 'react-icons/fa';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { VirtuosoGrid } from 'react-virtuoso';
import BatchImage from './BatchImage';
import ItemContainer from './ImageGridItemContainer';
import ListContainer from './ImageGridListContainer';
const selector = createSelector(
[stateSelector],
(state) => {
return {
imageNames: state.gallery.batchImageNames,
};
},
defaultSelectorOptions
);
const BatchImageGrid = () => {
const { t } = useTranslation();
const rootRef = useRef(null);
const [scroller, setScroller] = useState<HTMLElement | null>(null);
const [initialize, osInstance] = useOverlayScrollbars({
defer: true,
options: {
scrollbars: {
visibility: 'auto',
autoHide: 'leave',
autoHideDelay: 1300,
theme: 'os-theme-dark',
},
overflow: { x: 'hidden' },
},
});
const { imageNames } = useAppSelector(selector);
useEffect(() => {
const { current: root } = rootRef;
if (scroller && root) {
initialize({
target: root,
elements: {
viewport: scroller,
},
});
}
return () => osInstance()?.destroy();
}, [scroller, initialize, osInstance]);
if (imageNames.length) {
return (
<Box ref={rootRef} data-overlayscrollbars="" h="100%">
<VirtuosoGrid
style={{ height: '100%' }}
data={imageNames}
components={{
Item: ItemContainer,
List: ListContainer,
}}
scrollerRef={setScroller}
itemContent={(index, imageName) => (
<BatchImage key={imageName} imageName={imageName} />
)}
/>
</Box>
);
}
return (
<IAINoContentFallback
label={t('gallery.noImagesInGallery')}
icon={FaImage}
/>
);
};
export default memo(BatchImageGrid);

View File

@ -1,69 +1,57 @@
import { Box } from '@chakra-ui/react'; import { Box, Spinner } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { TypesafeDraggableData } from 'app/components/ImageDnd/typesafeDnd'; import { TypesafeDraggableData } from 'app/components/ImageDnd/typesafeDnd';
import { stateSelector } from 'app/store/store'; import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIDndImage from 'common/components/IAIDndImage'; import IAIDndImage from 'common/components/IAIDndImage';
import { imageToDeleteSelected } from 'features/imageDeletion/store/imageDeletionSlice'; import ImageContextMenu from 'features/gallery/components/ImageContextMenu/ImageContextMenu';
import { MouseEvent, memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { FaTrash } from 'react-icons/fa';
import { ImageDTO } from 'services/api/types';
import { import {
imageRangeEndSelected, imageRangeEndSelected,
imageSelected, imageSelected,
imageSelectionToggled, imageSelectionToggled,
} from '../store/gallerySlice'; } from 'features/gallery/store/gallerySlice';
import ImageContextMenu from './ImageContextMenu'; import { imageToDeleteSelected } from 'features/imageDeletion/store/imageDeletionSlice';
import { MouseEvent, memo, useCallback, useMemo } from 'react';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
export const makeSelector = (image_name: string) => export const makeSelector = (image_name: string) =>
createSelector( createSelector(
[stateSelector], [stateSelector],
({ gallery }) => { ({ gallery }) => ({
const isSelected = gallery.selection.includes(image_name); isSelected: gallery.selection.includes(image_name),
const selectionCount = gallery.selection.length; selectionCount: gallery.selection.length,
selection: gallery.selection,
return { }),
isSelected,
selectionCount,
};
},
defaultSelectorOptions defaultSelectorOptions
); );
interface HoverableImageProps { interface HoverableImageProps {
imageDTO: ImageDTO; imageName: string;
} }
/**
* Gallery image component with delete/use all/use seed buttons on hover.
*/
const GalleryImage = (props: HoverableImageProps) => { const GalleryImage = (props: HoverableImageProps) => {
const { imageDTO } = props;
const { image_url, thumbnail_url, image_name } = imageDTO;
const localSelector = useMemo(() => makeSelector(image_name), [image_name]);
const { isSelected, selectionCount } = useAppSelector(localSelector);
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { imageName } = props;
const { currentData: imageDTO } = useGetImageDTOQuery(imageName);
const localSelector = useMemo(() => makeSelector(imageName), [imageName]);
const { t } = useTranslation(); const { isSelected, selectionCount, selection } =
useAppSelector(localSelector);
const handleClick = useCallback( const handleClick = useCallback(
(e: MouseEvent<HTMLDivElement>) => { (e: MouseEvent<HTMLDivElement>) => {
// multiselect disabled for now // disable multiselect for now
// if (e.shiftKey) { // if (e.shiftKey) {
// dispatch(imageRangeEndSelected(props.imageDTO.image_name)); // dispatch(imageRangeEndSelected(imageName));
// } else if (e.ctrlKey || e.metaKey) { // } else if (e.ctrlKey || e.metaKey) {
// dispatch(imageSelectionToggled(props.imageDTO.image_name)); // dispatch(imageSelectionToggled(imageName));
// } else { // } else {
// dispatch(imageSelected(props.imageDTO.image_name)); // dispatch(imageSelected(imageName));
// } // }
dispatch(imageSelected(props.imageDTO.image_name)); dispatch(imageSelected(imageName));
}, },
[dispatch, props.imageDTO.image_name] [dispatch, imageName]
); );
const handleDelete = useCallback( const handleDelete = useCallback(
@ -81,7 +69,8 @@ const GalleryImage = (props: HoverableImageProps) => {
if (selectionCount > 1) { if (selectionCount > 1) {
return { return {
id: 'gallery-image', id: 'gallery-image',
payloadType: 'GALLERY_SELECTION', payloadType: 'IMAGE_NAMES',
payload: { image_names: selection },
}; };
} }
@ -92,15 +81,19 @@ const GalleryImage = (props: HoverableImageProps) => {
payload: { imageDTO }, payload: { imageDTO },
}; };
} }
}, [imageDTO, selectionCount]); }, [imageDTO, selection, selectionCount]);
if (!imageDTO) {
return <Spinner />;
}
return ( return (
<Box sx={{ w: 'full', h: 'full', touchAction: 'none' }}> <Box sx={{ w: 'full', h: 'full', touchAction: 'none' }}>
<ImageContextMenu image={imageDTO}> <ImageContextMenu imageDTO={imageDTO}>
{(ref) => ( {(ref) => (
<Box <Box
position="relative" position="relative"
key={image_name} key={imageName}
userSelect="none" userSelect="none"
ref={ref} ref={ref}
sx={{ sx={{
@ -117,13 +110,13 @@ const GalleryImage = (props: HoverableImageProps) => {
isSelected={isSelected} isSelected={isSelected}
minSize={0} minSize={0}
onClickReset={handleDelete} onClickReset={handleDelete}
resetIcon={<FaTrash />}
resetTooltip="Delete image"
imageSx={{ w: 'full', h: 'full' }} imageSx={{ w: 'full', h: 'full' }}
// withResetIcon // removed bc it's too easy to accidentally delete images
isDropDisabled={true} isDropDisabled={true}
isUploadDisabled={true} isUploadDisabled={true}
thumbnail={true} thumbnail={true}
// resetIcon={<FaTrash />}
// resetTooltip="Delete image"
// withResetIcon // removed bc it's too easy to accidentally delete images
/> />
</Box> </Box>
)} )}

View File

@ -0,0 +1,204 @@
import { Box } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import { useOverlayScrollbars } from 'overlayscrollbars-react';
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { FaImage } from 'react-icons/fa';
import GalleryImage from './GalleryImage';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import {
ASSETS_CATEGORIES,
IMAGE_CATEGORIES,
IMAGE_LIMIT,
selectImagesAll,
} from 'features/gallery//store/gallerySlice';
import { selectFilteredImages } from 'features/gallery/store/gallerySelectors';
import { VirtuosoGrid } from 'react-virtuoso';
import { receivedPageOfImages } from 'services/api/thunks/image';
import ImageGridItemContainer from './ImageGridItemContainer';
import ImageGridListContainer from './ImageGridListContainer';
import { useListBoardImagesQuery } from '../../../../services/api/endpoints/boardImages';
const selector = createSelector(
[stateSelector, selectFilteredImages],
(state, filteredImages) => {
const {
galleryImageMinimumWidth,
selectedBoardId,
galleryView,
total,
isLoading,
} = state.gallery;
return {
imageNames: filteredImages.map((i) => i.image_name),
total,
selectedBoardId,
galleryView,
galleryImageMinimumWidth,
isLoading,
};
},
defaultSelectorOptions
);
const GalleryImageGrid = () => {
const { t } = useTranslation();
const rootRef = useRef<HTMLDivElement>(null);
const emptyGalleryRef = useRef<HTMLDivElement>(null);
const [scroller, setScroller] = useState<HTMLElement | null>(null);
const [initialize, osInstance] = useOverlayScrollbars({
defer: true,
options: {
scrollbars: {
visibility: 'auto',
autoHide: 'leave',
autoHideDelay: 1300,
theme: 'os-theme-dark',
},
overflow: { x: 'hidden' },
},
});
const [didInitialFetch, setDidInitialFetch] = useState(false);
const dispatch = useAppDispatch();
const {
galleryImageMinimumWidth,
imageNames: imageNamesAll, //all images names loaded on main tab,
total: totalAll,
selectedBoardId,
galleryView,
isLoading: isLoadingAll,
} = useAppSelector(selector);
const { data: imagesForBoard, isLoading: isLoadingImagesForBoard } =
useListBoardImagesQuery(
{ board_id: selectedBoardId },
{ skip: selectedBoardId === 'all' }
);
const imageNames = useMemo(() => {
if (selectedBoardId === 'all') {
return imageNamesAll; // already sorted by images/uploads in gallery selector
} else {
const categories =
galleryView === 'images' ? IMAGE_CATEGORIES : ASSETS_CATEGORIES;
const imageList = (imagesForBoard?.items || []).filter((img) =>
categories.includes(img.image_category)
);
return imageList.map((img) => img.image_name);
}
}, [selectedBoardId, galleryView, imagesForBoard, imageNamesAll]);
const areMoreAvailable = useMemo(() => {
return selectedBoardId === 'all' ? totalAll > imageNamesAll.length : false;
}, [selectedBoardId, imageNamesAll.length, totalAll]);
const isLoading = useMemo(() => {
return selectedBoardId === 'all' ? isLoadingAll : isLoadingImagesForBoard;
}, [selectedBoardId, isLoadingAll, isLoadingImagesForBoard]);
const handleLoadMoreImages = useCallback(() => {
dispatch(
receivedPageOfImages({
categories:
galleryView === 'images' ? IMAGE_CATEGORIES : ASSETS_CATEGORIES,
is_intermediate: false,
offset: imageNames.length,
limit: IMAGE_LIMIT,
})
);
}, [dispatch, imageNames.length, galleryView]);
const handleEndReached = useMemo(() => {
if (areMoreAvailable) {
return handleLoadMoreImages;
}
return undefined;
}, [areMoreAvailable, handleLoadMoreImages]);
// useEffect(() => {
// if (!didInitialFetch) {
// return;
// }
// // rough, conservative calculation of how many images fit in the gallery
// // TODO: this gets an incorrect value on first load...
// const galleryHeight = rootRef.current?.clientHeight ?? 0;
// const galleryWidth = rootRef.current?.clientHeight ?? 0;
// const rows = galleryHeight / galleryImageMinimumWidth;
// const columns = galleryWidth / galleryImageMinimumWidth;
// const imagesToLoad = Math.ceil(rows * columns);
// setDidInitialFetch(true);
// // load up that many images
// dispatch(
// receivedPageOfImages({
// offset: 0,
// limit: 10,
// })
// );
// }, [
// didInitialFetch,
// dispatch,
// galleryImageMinimumWidth,
// galleryView,
// selectedBoardId,
// ]);
if (!isLoading && imageNames.length === 0) {
return (
<Box ref={emptyGalleryRef} sx={{ w: 'full', h: 'full' }}>
<IAINoContentFallback
label={t('gallery.noImagesInGallery')}
icon={FaImage}
/>
</Box>
);
}
console.log({ selectedBoardId });
if (status !== 'rejected') {
return (
<>
<Box ref={rootRef} data-overlayscrollbars="" h="100%">
<VirtuosoGrid
style={{ height: '100%' }}
data={imageNames}
components={{
Item: ImageGridItemContainer,
List: ImageGridListContainer,
}}
scrollerRef={setScroller}
itemContent={(index, imageName) => (
<GalleryImage key={imageName} imageName={imageName} />
)}
/>
</Box>
<IAIButton
onClick={handleLoadMoreImages}
isDisabled={!areMoreAvailable}
isLoading={status === 'pending'}
loadingText="Loading"
flexShrink={0}
>
{areMoreAvailable
? t('gallery.loadMore')
: t('gallery.allImagesLoaded')}
</IAIButton>
</>
);
}
};
export default memo(GalleryImageGrid);

View File

@ -0,0 +1,11 @@
import { Box, FlexProps, forwardRef } from '@chakra-ui/react';
import { PropsWithChildren } from 'react';
type ItemContainerProps = PropsWithChildren & FlexProps;
const ItemContainer = forwardRef((props: ItemContainerProps, ref) => (
<Box className="item-container" ref={ref} p={1.5}>
{props.children}
</Box>
));
export default ItemContainer;

View File

@ -0,0 +1,26 @@
import { FlexProps, Grid, forwardRef } from '@chakra-ui/react';
import { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { PropsWithChildren } from 'react';
type ListContainerProps = PropsWithChildren & FlexProps;
const ListContainer = forwardRef((props: ListContainerProps, ref) => {
const galleryImageMinimumWidth = useAppSelector(
(state: RootState) => state.gallery.galleryImageMinimumWidth
);
return (
<Grid
{...props}
className="list-container"
ref={ref}
sx={{
gridTemplateColumns: `repeat(auto-fill, minmax(${galleryImageMinimumWidth}px, 1fr));`,
}}
>
{props.children}
</Grid>
);
});
export default ListContainer;

View File

@ -1,330 +0,0 @@
import { ExternalLinkIcon } from '@chakra-ui/icons';
import {
Box,
Center,
Flex,
IconButton,
Link,
Text,
Tooltip,
} from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { setShouldShowImageDetails } from 'features/ui/store/uiSlice';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import { memo } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import { FaCopy } from 'react-icons/fa';
import { IoArrowUndoCircleOutline } from 'react-icons/io5';
import { ImageDTO } from 'services/api/types';
type MetadataItemProps = {
isLink?: boolean;
label: string;
onClick?: () => void;
value: number | string | boolean;
labelPosition?: string;
withCopy?: boolean;
};
/**
* Component to display an individual metadata item or parameter.
*/
const MetadataItem = ({
label,
value,
onClick,
isLink,
labelPosition,
withCopy = false,
}: MetadataItemProps) => {
const { t } = useTranslation();
if (!value) {
return null;
}
return (
<Flex gap={2}>
{onClick && (
<Tooltip label={`Recall ${label}`}>
<IconButton
aria-label={t('accessibility.useThisParameter')}
icon={<IoArrowUndoCircleOutline />}
size="xs"
variant="ghost"
fontSize={20}
onClick={onClick}
/>
</Tooltip>
)}
{withCopy && (
<Tooltip label={`Copy ${label}`}>
<IconButton
aria-label={`Copy ${label}`}
icon={<FaCopy />}
size="xs"
variant="ghost"
fontSize={14}
onClick={() => navigator.clipboard.writeText(value.toString())}
/>
</Tooltip>
)}
<Flex direction={labelPosition ? 'column' : 'row'}>
<Text fontWeight="semibold" whiteSpace="pre-wrap" pr={2}>
{label}:
</Text>
{isLink ? (
<Link href={value.toString()} isExternal wordBreak="break-all">
{value.toString()} <ExternalLinkIcon mx="2px" />
</Link>
) : (
<Text overflowY="scroll" wordBreak="break-all">
{value.toString()}
</Text>
)}
</Flex>
</Flex>
);
};
type ImageMetadataViewerProps = {
image: ImageDTO;
};
/**
* Image metadata viewer overlays currently selected image and provides
* access to any of its metadata for use in processing.
*/
const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
const dispatch = useAppDispatch();
const {
recallBothPrompts,
recallPositivePrompt,
recallNegativePrompt,
recallSeed,
recallInitialImage,
recallCfgScale,
recallModel,
recallScheduler,
recallSteps,
recallWidth,
recallHeight,
recallStrength,
recallAllParameters,
} = useRecallParameters();
useHotkeys('esc', () => {
dispatch(setShouldShowImageDetails(false));
});
const sessionId = image?.session_id;
const metadata = image?.metadata;
const { t } = useTranslation();
const metadataJSON = JSON.stringify(image, null, 2);
return (
<Flex
sx={{
padding: 4,
gap: 1,
flexDirection: 'column',
width: 'full',
height: 'full',
backdropFilter: 'blur(20px)',
bg: 'whiteAlpha.600',
_dark: {
bg: 'blackAlpha.600',
},
overflow: 'scroll',
}}
>
<Flex gap={2}>
<Text fontWeight="semibold">File:</Text>
<Link href={image.image_url} isExternal maxW="calc(100% - 3rem)">
{image.image_name}
<ExternalLinkIcon mx="2px" />
</Link>
</Flex>
{metadata && Object.keys(metadata).length > 0 ? (
<>
{metadata.type && (
<MetadataItem label="Invocation type" value={metadata.type} />
)}
{sessionId && <MetadataItem label="Session ID" value={sessionId} />}
{metadata.positive_conditioning && (
<MetadataItem
label="Positive Prompt"
labelPosition="top"
value={metadata.positive_conditioning}
onClick={() =>
recallPositivePrompt(metadata.positive_conditioning)
}
/>
)}
{metadata.negative_conditioning && (
<MetadataItem
label="Negative Prompt"
labelPosition="top"
value={metadata.negative_conditioning}
onClick={() =>
recallNegativePrompt(metadata.negative_conditioning)
}
/>
)}
{metadata.seed !== undefined && (
<MetadataItem
label="Seed"
value={metadata.seed}
onClick={() => recallSeed(metadata.seed)}
/>
)}
{metadata.model !== undefined && (
<MetadataItem
label="Model"
value={metadata.model}
onClick={() => recallModel(metadata.model)}
/>
)}
{metadata.width && (
<MetadataItem
label="Width"
value={metadata.width}
onClick={() => recallWidth(metadata.width)}
/>
)}
{metadata.height && (
<MetadataItem
label="Height"
value={metadata.height}
onClick={() => recallHeight(metadata.height)}
/>
)}
{/* {metadata.threshold !== undefined && (
<MetadataItem
label="Noise Threshold"
value={metadata.threshold}
onClick={() => dispatch(setThreshold(Number(metadata.threshold)))}
/>
)}
{metadata.perlin !== undefined && (
<MetadataItem
label="Perlin Noise"
value={metadata.perlin}
onClick={() => dispatch(setPerlin(Number(metadata.perlin)))}
/>
)} */}
{metadata.scheduler && (
<MetadataItem
label="Scheduler"
value={metadata.scheduler}
onClick={() => recallScheduler(metadata.scheduler)}
/>
)}
{metadata.steps && (
<MetadataItem
label="Steps"
value={metadata.steps}
onClick={() => recallSteps(metadata.steps)}
/>
)}
{metadata.cfg_scale !== undefined && (
<MetadataItem
label="CFG scale"
value={metadata.cfg_scale}
onClick={() => recallCfgScale(metadata.cfg_scale)}
/>
)}
{/* {metadata.variations && metadata.variations.length > 0 && (
<MetadataItem
label="Seed-weight pairs"
value={seedWeightsToString(metadata.variations)}
onClick={() =>
dispatch(
setSeedWeights(seedWeightsToString(metadata.variations))
)
}
/>
)}
{metadata.seamless && (
<MetadataItem
label="Seamless"
value={metadata.seamless}
onClick={() => dispatch(setSeamless(metadata.seamless))}
/>
)}
{metadata.hires_fix && (
<MetadataItem
label="High Resolution Optimization"
value={metadata.hires_fix}
onClick={() => dispatch(setHiresFix(metadata.hires_fix))}
/>
)} */}
{/* {init_image_path && (
<MetadataItem
label="Initial image"
value={init_image_path}
isLink
onClick={() => dispatch(setInitialImage(init_image_path))}
/>
)} */}
{metadata.strength && (
<MetadataItem
label="Image to image strength"
value={metadata.strength}
onClick={() => recallStrength(metadata.strength)}
/>
)}
{/* {metadata.fit && (
<MetadataItem
label="Image to image fit"
value={metadata.fit}
onClick={() => dispatch(setShouldFitToWidthHeight(metadata.fit))}
/>
)} */}
</>
) : (
<Center width="100%" pt={10}>
<Text fontSize="lg" fontWeight="semibold">
No metadata available
</Text>
</Center>
)}
<Flex gap={2} direction="column" overflow="auto">
<Flex gap={2}>
<Tooltip label="Copy metadata JSON">
<IconButton
aria-label={t('accessibility.copyMetadataJson')}
icon={<FaCopy />}
size="xs"
variant="ghost"
fontSize={14}
onClick={() => navigator.clipboard.writeText(metadataJSON)}
/>
</Tooltip>
<Text fontWeight="semibold">Metadata JSON:</Text>
</Flex>
<OverlayScrollbarsComponent defer>
<Box
sx={{
padding: 4,
borderRadius: 'base',
bg: 'whiteAlpha.500',
_dark: { bg: 'blackAlpha.500' },
w: 'full',
}}
>
<pre>{metadataJSON}</pre>
</Box>
</OverlayScrollbarsComponent>
</Flex>
</Flex>
);
};
export default memo(ImageMetadataViewer);

View File

@ -0,0 +1,212 @@
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { useCallback } from 'react';
import { UnsafeImageMetadata } from 'services/api/endpoints/images';
import ImageMetadataItem from './ImageMetadataItem';
type Props = {
metadata?: UnsafeImageMetadata['metadata'];
};
const ImageMetadataActions = (props: Props) => {
const { metadata } = props;
const {
recallBothPrompts,
recallPositivePrompt,
recallNegativePrompt,
recallSeed,
recallInitialImage,
recallCfgScale,
recallModel,
recallScheduler,
recallSteps,
recallWidth,
recallHeight,
recallStrength,
recallAllParameters,
} = useRecallParameters();
const handleRecallPositivePrompt = useCallback(() => {
recallPositivePrompt(metadata?.positive_prompt);
}, [metadata?.positive_prompt, recallPositivePrompt]);
const handleRecallNegativePrompt = useCallback(() => {
recallNegativePrompt(metadata?.negative_prompt);
}, [metadata?.negative_prompt, recallNegativePrompt]);
const handleRecallSeed = useCallback(() => {
recallSeed(metadata?.seed);
}, [metadata?.seed, recallSeed]);
const handleRecallModel = useCallback(() => {
recallModel(metadata?.model);
}, [metadata?.model, recallModel]);
const handleRecallWidth = useCallback(() => {
recallWidth(metadata?.width);
}, [metadata?.width, recallWidth]);
const handleRecallHeight = useCallback(() => {
recallHeight(metadata?.height);
}, [metadata?.height, recallHeight]);
const handleRecallScheduler = useCallback(() => {
recallScheduler(metadata?.scheduler);
}, [metadata?.scheduler, recallScheduler]);
const handleRecallSteps = useCallback(() => {
recallSteps(metadata?.steps);
}, [metadata?.steps, recallSteps]);
const handleRecallCfgScale = useCallback(() => {
recallCfgScale(metadata?.cfg_scale);
}, [metadata?.cfg_scale, recallCfgScale]);
const handleRecallStrength = useCallback(() => {
recallStrength(metadata?.strength);
}, [metadata?.strength, recallStrength]);
if (!metadata || Object.keys(metadata).length === 0) {
return null;
}
return (
<>
{metadata.generation_mode && (
<ImageMetadataItem
label="Generation Mode"
value={metadata.generation_mode}
/>
)}
{metadata.positive_prompt && (
<ImageMetadataItem
label="Positive Prompt"
labelPosition="top"
value={metadata.positive_prompt}
onClick={handleRecallPositivePrompt}
/>
)}
{metadata.negative_prompt && (
<ImageMetadataItem
label="Negative Prompt"
labelPosition="top"
value={metadata.negative_prompt}
onClick={handleRecallNegativePrompt}
/>
)}
{metadata.seed !== undefined && (
<ImageMetadataItem
label="Seed"
value={metadata.seed}
onClick={handleRecallSeed}
/>
)}
{metadata.model !== undefined && (
<ImageMetadataItem
label="Model"
value={metadata.model.model_name}
onClick={handleRecallModel}
/>
)}
{metadata.width && (
<ImageMetadataItem
label="Width"
value={metadata.width}
onClick={handleRecallWidth}
/>
)}
{metadata.height && (
<ImageMetadataItem
label="Height"
value={metadata.height}
onClick={handleRecallHeight}
/>
)}
{/* {metadata.threshold !== undefined && (
<MetadataItem
label="Noise Threshold"
value={metadata.threshold}
onClick={() => dispatch(setThreshold(Number(metadata.threshold)))}
/>
)}
{metadata.perlin !== undefined && (
<MetadataItem
label="Perlin Noise"
value={metadata.perlin}
onClick={() => dispatch(setPerlin(Number(metadata.perlin)))}
/>
)} */}
{metadata.scheduler && (
<ImageMetadataItem
label="Scheduler"
value={metadata.scheduler}
onClick={handleRecallScheduler}
/>
)}
{metadata.steps && (
<ImageMetadataItem
label="Steps"
value={metadata.steps}
onClick={handleRecallSteps}
/>
)}
{metadata.cfg_scale !== undefined && (
<ImageMetadataItem
label="CFG scale"
value={metadata.cfg_scale}
onClick={handleRecallCfgScale}
/>
)}
{/* {metadata.variations && metadata.variations.length > 0 && (
<MetadataItem
label="Seed-weight pairs"
value={seedWeightsToString(metadata.variations)}
onClick={() =>
dispatch(
setSeedWeights(seedWeightsToString(metadata.variations))
)
}
/>
)}
{metadata.seamless && (
<MetadataItem
label="Seamless"
value={metadata.seamless}
onClick={() => dispatch(setSeamless(metadata.seamless))}
/>
)}
{metadata.hires_fix && (
<MetadataItem
label="High Resolution Optimization"
value={metadata.hires_fix}
onClick={() => dispatch(setHiresFix(metadata.hires_fix))}
/>
)} */}
{/* {init_image_path && (
<MetadataItem
label="Initial image"
value={init_image_path}
isLink
onClick={() => dispatch(setInitialImage(init_image_path))}
/>
)} */}
{metadata.strength && (
<ImageMetadataItem
label="Image to image strength"
value={metadata.strength}
onClick={handleRecallStrength}
/>
)}
{/* {metadata.fit && (
<MetadataItem
label="Image to image fit"
value={metadata.fit}
onClick={() => dispatch(setShouldFitToWidthHeight(metadata.fit))}
/>
)} */}
</>
);
};
export default ImageMetadataActions;

View File

@ -0,0 +1,77 @@
import { ExternalLinkIcon } from '@chakra-ui/icons';
import { Flex, IconButton, Link, Text, Tooltip } from '@chakra-ui/react';
import { useTranslation } from 'react-i18next';
import { FaCopy } from 'react-icons/fa';
import { IoArrowUndoCircleOutline } from 'react-icons/io5';
type MetadataItemProps = {
isLink?: boolean;
label: string;
onClick?: () => void;
value: number | string | boolean;
labelPosition?: string;
withCopy?: boolean;
};
/**
* Component to display an individual metadata item or parameter.
*/
const ImageMetadataItem = ({
label,
value,
onClick,
isLink,
labelPosition,
withCopy = false,
}: MetadataItemProps) => {
const { t } = useTranslation();
if (!value) {
return null;
}
return (
<Flex gap={2}>
{onClick && (
<Tooltip label={`Recall ${label}`}>
<IconButton
aria-label={t('accessibility.useThisParameter')}
icon={<IoArrowUndoCircleOutline />}
size="xs"
variant="ghost"
fontSize={20}
onClick={onClick}
/>
</Tooltip>
)}
{withCopy && (
<Tooltip label={`Copy ${label}`}>
<IconButton
aria-label={`Copy ${label}`}
icon={<FaCopy />}
size="xs"
variant="ghost"
fontSize={14}
onClick={() => navigator.clipboard.writeText(value.toString())}
/>
</Tooltip>
)}
<Flex direction={labelPosition ? 'column' : 'row'}>
<Text fontWeight="semibold" whiteSpace="pre-wrap" pr={2}>
{label}:
</Text>
{isLink ? (
<Link href={value.toString()} isExternal wordBreak="break-all">
{value.toString()} <ExternalLinkIcon mx="2px" />
</Link>
) : (
<Text overflowY="scroll" wordBreak="break-all">
{value.toString()}
</Text>
)}
</Flex>
</Flex>
);
};
export default ImageMetadataItem;

View File

@ -0,0 +1,70 @@
import { Box, Flex, IconButton, Tooltip } from '@chakra-ui/react';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import { useMemo } from 'react';
import { FaCopy } from 'react-icons/fa';
type Props = {
copyTooltip: string;
jsonObject: object;
};
const ImageMetadataJSON = (props: Props) => {
const { copyTooltip, jsonObject } = props;
const jsonString = useMemo(
() => JSON.stringify(jsonObject, null, 2),
[jsonObject]
);
return (
<Flex
sx={{
borderRadius: 'base',
bg: 'whiteAlpha.500',
_dark: { bg: 'blackAlpha.500' },
flexGrow: 1,
w: 'full',
h: 'full',
position: 'relative',
}}
>
<Box
sx={{
position: 'absolute',
top: 0,
left: 0,
right: 0,
bottom: 0,
overflow: 'auto',
p: 4,
}}
>
<OverlayScrollbarsComponent
defer
style={{ height: '100%', width: '100%' }}
options={{
scrollbars: {
visibility: 'auto',
autoHide: 'move',
autoHideDelay: 1300,
theme: 'os-theme-dark',
},
}}
>
<pre>{jsonString}</pre>
</OverlayScrollbarsComponent>
</Box>
<Flex sx={{ position: 'absolute', top: 0, insetInlineEnd: 0, p: 2 }}>
<Tooltip label={copyTooltip}>
<IconButton
aria-label={copyTooltip}
icon={<FaCopy />}
variant="ghost"
onClick={() => navigator.clipboard.writeText(jsonString)}
/>
</Tooltip>
</Flex>
</Flex>
);
};
export default ImageMetadataJSON;

View File

@ -0,0 +1,138 @@
import { ExternalLinkIcon } from '@chakra-ui/icons';
import {
Flex,
Link,
Tab,
TabList,
TabPanel,
TabPanels,
Tabs,
Text,
} from '@chakra-ui/react';
import { skipToken } from '@reduxjs/toolkit/dist/query';
import { memo, useMemo } from 'react';
import { useGetImageMetadataQuery } from 'services/api/endpoints/images';
import { ImageDTO } from 'services/api/types';
import { useDebounce } from 'use-debounce';
import ImageMetadataActions from './ImageMetadataActions';
import ImageMetadataJSON from './ImageMetadataJSON';
type ImageMetadataViewerProps = {
image: ImageDTO;
};
const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
// TODO: fix hotkeys
// const dispatch = useAppDispatch();
// useHotkeys('esc', () => {
// dispatch(setShouldShowImageDetails(false));
// });
const [debouncedMetadataQueryArg, debounceState] = useDebounce(
image.image_name,
500
);
const { currentData } = useGetImageMetadataQuery(
debounceState.isPending()
? skipToken
: debouncedMetadataQueryArg ?? skipToken
);
const metadata = currentData?.metadata;
const graph = currentData?.graph;
const tabData = useMemo(() => {
const _tabData: { label: string; data: object; copyTooltip: string }[] = [];
if (metadata) {
_tabData.push({
label: 'Core Metadata',
data: metadata,
copyTooltip: 'Copy Core Metadata JSON',
});
}
if (image) {
_tabData.push({
label: 'Image Details',
data: image,
copyTooltip: 'Copy Image Details JSON',
});
}
if (graph) {
_tabData.push({
label: 'Graph',
data: graph,
copyTooltip: 'Copy Graph JSON',
});
}
return _tabData;
}, [metadata, graph, image]);
return (
<Flex
sx={{
padding: 4,
gap: 1,
flexDirection: 'column',
width: 'full',
height: 'full',
backdropFilter: 'blur(20px)',
bg: 'baseAlpha.200',
_dark: {
bg: 'blackAlpha.600',
},
borderRadius: 'base',
position: 'absolute',
overflow: 'hidden',
}}
>
<Flex gap={2}>
<Text fontWeight="semibold">File:</Text>
<Link href={image.image_url} isExternal maxW="calc(100% - 3rem)">
{image.image_name}
<ExternalLinkIcon mx="2px" />
</Link>
</Flex>
<ImageMetadataActions metadata={metadata} />
<Tabs
variant="line"
sx={{ display: 'flex', flexDir: 'column', w: 'full', h: 'full' }}
>
<TabList>
{tabData.map((tab) => (
<Tab
key={tab.label}
sx={{
borderTopRadius: 'base',
}}
>
<Text sx={{ color: 'base.700', _dark: { color: 'base.300' } }}>
{tab.label}
</Text>
</Tab>
))}
</TabList>
<TabPanels sx={{ w: 'full', h: 'full' }}>
{tabData.map((tab) => (
<TabPanel
key={tab.label}
sx={{ w: 'full', h: 'full', p: 0, pt: 4 }}
>
<ImageMetadataJSON
jsonObject={tab.data}
copyTooltip={tab.copyTooltip}
/>
</TabPanel>
))}
</TabPanels>
</Tabs>
</Flex>
);
};
export default memo(ImageMetadataViewer);

View File

@ -1,171 +1,44 @@
import { ChakraProps, Flex, Grid, IconButton, Spinner } from '@chakra-ui/react'; import { Box, ChakraProps, Flex, IconButton, Spinner } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit'; import { memo } from 'react';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
imageSelected,
selectFilteredImages,
selectImagesById,
} from 'features/gallery/store/gallerySlice';
import { clamp, isEqual } from 'lodash-es';
import { memo, useCallback, useState } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { FaAngleDoubleRight, FaAngleLeft, FaAngleRight } from 'react-icons/fa'; import { FaAngleDoubleRight, FaAngleLeft, FaAngleRight } from 'react-icons/fa';
import { receivedPageOfImages } from 'services/api/thunks/image'; import { useNextPrevImage } from '../hooks/useNextPrevImage';
const nextPrevButtonTriggerAreaStyles: ChakraProps['sx'] = {
height: '100%',
width: '15%',
alignItems: 'center',
pointerEvents: 'auto',
};
const nextPrevButtonStyles: ChakraProps['sx'] = { const nextPrevButtonStyles: ChakraProps['sx'] = {
color: 'base.100', color: 'base.100',
pointerEvents: 'auto',
}; };
export const nextPrevImageButtonsSelector = createSelector(
[stateSelector, selectFilteredImages],
(state, filteredImages) => {
const { total, isFetching } = state.gallery;
const lastSelectedImage =
state.gallery.selection[state.gallery.selection.length - 1];
if (!lastSelectedImage || filteredImages.length === 0) {
return {
isOnFirstImage: true,
isOnLastImage: true,
};
}
const currentImageIndex = filteredImages.findIndex(
(i) => i.image_name === lastSelectedImage
);
const nextImageIndex = clamp(
currentImageIndex + 1,
0,
filteredImages.length - 1
);
const prevImageIndex = clamp(
currentImageIndex - 1,
0,
filteredImages.length - 1
);
const nextImageId = filteredImages[nextImageIndex].image_name;
const prevImageId = filteredImages[prevImageIndex].image_name;
const nextImage = selectImagesById(state, nextImageId);
const prevImage = selectImagesById(state, prevImageId);
const imagesLength = filteredImages.length;
return {
isOnFirstImage: currentImageIndex === 0,
isOnLastImage:
!isNaN(currentImageIndex) && currentImageIndex === imagesLength - 1,
areMoreImagesAvailable: total > imagesLength,
isFetching,
nextImage,
prevImage,
nextImageId,
prevImageId,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const NextPrevImageButtons = () => { const NextPrevImageButtons = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const { const {
handlePrevImage,
handleNextImage,
isOnFirstImage, isOnFirstImage,
isOnLastImage, isOnLastImage,
nextImageId, handleLoadMoreImages,
prevImageId,
areMoreImagesAvailable, areMoreImagesAvailable,
isFetching, isFetching,
} = useAppSelector(nextPrevImageButtonsSelector); } = useNextPrevImage();
const [shouldShowNextPrevButtons, setShouldShowNextPrevButtons] =
useState<boolean>(false);
const handleCurrentImagePreviewMouseOver = useCallback(() => {
setShouldShowNextPrevButtons(true);
}, []);
const handleCurrentImagePreviewMouseOut = useCallback(() => {
setShouldShowNextPrevButtons(false);
}, []);
const handlePrevImage = useCallback(() => {
prevImageId && dispatch(imageSelected(prevImageId));
}, [dispatch, prevImageId]);
const handleNextImage = useCallback(() => {
nextImageId && dispatch(imageSelected(nextImageId));
}, [dispatch, nextImageId]);
const handleLoadMoreImages = useCallback(() => {
dispatch(
receivedPageOfImages({
is_intermediate: false,
})
);
}, [dispatch]);
useHotkeys(
'left',
() => {
handlePrevImage();
},
[prevImageId]
);
useHotkeys(
'right',
() => {
if (isOnLastImage && areMoreImagesAvailable && !isFetching) {
handleLoadMoreImages();
return;
}
if (!isOnLastImage) {
handleNextImage();
}
},
[
nextImageId,
isOnLastImage,
areMoreImagesAvailable,
handleLoadMoreImages,
isFetching,
]
);
return ( return (
<Flex <Box
sx={{ sx={{
justifyContent: 'space-between', position: 'relative',
height: '100%', height: '100%',
width: '100%', width: '100%',
pointerEvents: 'none',
}} }}
> >
<Grid <Box
sx={{ sx={{
...nextPrevButtonTriggerAreaStyles, pos: 'absolute',
justifyContent: 'flex-start', top: '50%',
transform: 'translate(0, -50%)',
insetInlineStart: 0,
}} }}
onMouseOver={handleCurrentImagePreviewMouseOver}
onMouseOut={handleCurrentImagePreviewMouseOut}
> >
{shouldShowNextPrevButtons && !isOnFirstImage && ( {!isOnFirstImage && (
<IconButton <IconButton
aria-label={t('accessibility.previousImage')} aria-label={t('accessibility.previousImage')}
icon={<FaAngleLeft size={64} />} icon={<FaAngleLeft size={64} />}
@ -175,16 +48,16 @@ const NextPrevImageButtons = () => {
sx={nextPrevButtonStyles} sx={nextPrevButtonStyles}
/> />
)} )}
</Grid> </Box>
<Grid <Box
sx={{ sx={{
...nextPrevButtonTriggerAreaStyles, pos: 'absolute',
justifyContent: 'flex-end', top: '50%',
transform: 'translate(0, -50%)',
insetInlineEnd: 0,
}} }}
onMouseOver={handleCurrentImagePreviewMouseOver}
onMouseOut={handleCurrentImagePreviewMouseOut}
> >
{shouldShowNextPrevButtons && !isOnLastImage && ( {!isOnLastImage && (
<IconButton <IconButton
aria-label={t('accessibility.nextImage')} aria-label={t('accessibility.nextImage')}
icon={<FaAngleRight size={64} />} icon={<FaAngleRight size={64} />}
@ -194,36 +67,30 @@ const NextPrevImageButtons = () => {
sx={nextPrevButtonStyles} sx={nextPrevButtonStyles}
/> />
)} )}
{shouldShowNextPrevButtons && {isOnLastImage && areMoreImagesAvailable && !isFetching && (
isOnLastImage && <IconButton
areMoreImagesAvailable && aria-label={t('accessibility.loadMore')}
!isFetching && ( icon={<FaAngleDoubleRight size={64} />}
<IconButton variant="unstyled"
aria-label={t('accessibility.loadMore')} onClick={handleLoadMoreImages}
icon={<FaAngleDoubleRight size={64} />} boxSize={16}
variant="unstyled" sx={nextPrevButtonStyles}
onClick={handleLoadMoreImages} />
boxSize={16} )}
sx={nextPrevButtonStyles} {isOnLastImage && areMoreImagesAvailable && isFetching && (
/> <Flex
)} sx={{
{shouldShowNextPrevButtons && w: 16,
isOnLastImage && h: 16,
areMoreImagesAvailable && alignItems: 'center',
isFetching && ( justifyContent: 'center',
<Flex }}
sx={{ >
w: 16, <Spinner opacity={0.5} size="xl" />
h: 16, </Flex>
alignItems: 'center', )}
justifyContent: 'center', </Box>
}} </Box>
>
<Spinner opacity={0.5} size="xl" />
</Flex>
)}
</Grid>
</Flex>
); );
}; };

View File

@ -0,0 +1,108 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
imageSelected,
selectImagesById,
} from 'features/gallery/store/gallerySlice';
import { clamp, isEqual } from 'lodash-es';
import { useCallback } from 'react';
import { receivedPageOfImages } from 'services/api/thunks/image';
import { selectFilteredImages } from '../store/gallerySelectors';
export const nextPrevImageButtonsSelector = createSelector(
[stateSelector, selectFilteredImages],
(state, filteredImages) => {
const { total, isFetching } = state.gallery;
const lastSelectedImage =
state.gallery.selection[state.gallery.selection.length - 1];
if (!lastSelectedImage || filteredImages.length === 0) {
return {
isOnFirstImage: true,
isOnLastImage: true,
};
}
const currentImageIndex = filteredImages.findIndex(
(i) => i.image_name === lastSelectedImage
);
const nextImageIndex = clamp(
currentImageIndex + 1,
0,
filteredImages.length - 1
);
const prevImageIndex = clamp(
currentImageIndex - 1,
0,
filteredImages.length - 1
);
const nextImageId = filteredImages[nextImageIndex].image_name;
const prevImageId = filteredImages[prevImageIndex].image_name;
const nextImage = selectImagesById(state, nextImageId);
const prevImage = selectImagesById(state, prevImageId);
const imagesLength = filteredImages.length;
return {
isOnFirstImage: currentImageIndex === 0,
isOnLastImage:
!isNaN(currentImageIndex) && currentImageIndex === imagesLength - 1,
areMoreImagesAvailable: total > imagesLength,
isFetching,
nextImage,
prevImage,
nextImageId,
prevImageId,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
export const useNextPrevImage = () => {
const dispatch = useAppDispatch();
const {
isOnFirstImage,
isOnLastImage,
nextImageId,
prevImageId,
areMoreImagesAvailable,
isFetching,
} = useAppSelector(nextPrevImageButtonsSelector);
const handlePrevImage = useCallback(() => {
prevImageId && dispatch(imageSelected(prevImageId));
}, [dispatch, prevImageId]);
const handleNextImage = useCallback(() => {
nextImageId && dispatch(imageSelected(nextImageId));
}, [dispatch, nextImageId]);
const handleLoadMoreImages = useCallback(() => {
dispatch(
receivedPageOfImages({
is_intermediate: false,
})
);
}, [dispatch]);
return {
handlePrevImage,
handleNextImage,
isOnFirstImage,
isOnLastImage,
nextImageId,
prevImageId,
areMoreImagesAvailable,
handleLoadMoreImages,
isFetching,
};
};

Some files were not shown because too many files have changed in this diff Show More