resolve conflicts with main

This commit is contained in:
Lincoln Stein 2023-05-11 00:19:20 -04:00
commit 8ad8c5c67a
303 changed files with 6191 additions and 3575 deletions

View File

@ -83,7 +83,7 @@ async def get_thumbnail(
status_code=201,
)
async def upload_image(
file: UploadFile, request: Request, response: Response
file: UploadFile, image_type: ImageType, request: Request, response: Response
) -> ImageResponse:
if not file.content_type.startswith("image"):
raise HTTPException(status_code=415, detail="Not an image")
@ -99,21 +99,21 @@ async def upload_image(
filename = f"{uuid.uuid4()}_{str(int(datetime.now(timezone.utc).timestamp()))}.png"
saved_image = ApiDependencies.invoker.services.images.save(
ImageType.UPLOAD, filename, img
image_type, filename, img
)
invokeai_metadata = ApiDependencies.invoker.services.metadata.get_metadata(img)
image_url = ApiDependencies.invoker.services.images.get_uri(
ImageType.UPLOAD, saved_image.image_name
image_type, saved_image.image_name
)
thumbnail_url = ApiDependencies.invoker.services.images.get_uri(
ImageType.UPLOAD, saved_image.image_name, True
image_type, saved_image.image_name, True
)
res = ImageResponse(
image_type=ImageType.UPLOAD,
image_type=image_type,
image_name=saved_image.image_name,
image_url=image_url,
thumbnail_url=thumbnail_url,

View File

@ -3,12 +3,12 @@
from typing import Literal, Optional
import numpy as np
import numpy.random
from pydantic import Field
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from .baseinvocation import (
BaseInvocation,
InvocationConfig,
InvocationContext,
BaseInvocationOutput,
)
@ -50,11 +50,11 @@ class RandomRangeInvocation(BaseInvocation):
default=np.iinfo(np.int32).max, description="The exclusive high value"
)
size: int = Field(default=1, description="The number of values to generate")
seed: Optional[int] = Field(
seed: int = Field(
ge=0,
le=np.iinfo(np.int32).max,
description="The seed for the RNG",
default_factory=lambda: numpy.random.randint(0, np.iinfo(np.int32).max),
le=SEED_MAX,
description="The seed for the RNG (omit for random)",
default_factory=get_random_seed,
)
def invoke(self, context: InvocationContext) -> IntCollectionOutput:

View File

@ -1,15 +1,17 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from functools import partial
from typing import Literal, Optional, Union
from typing import Literal, Optional, Union, get_args
import numpy as np
from torch import Tensor
from pydantic import BaseModel, Field
from invokeai.app.models.image import ImageField, ImageType
from invokeai.app.models.image import ColorField, ImageField, ImageType
from invokeai.app.invocations.util.choose_model import choose_model
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.backend.generator.inpaint import infill_methods
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput, build_image_output
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
@ -17,7 +19,8 @@ from ...backend.stable_diffusion import PipelineIntermediateState
from ..util.step_callback import stable_diffusion_step_callback
SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
INFILL_METHODS = Literal[tuple(infill_methods())]
DEFAULT_INFILL_METHOD = 'patchmatch' if 'patchmatch' in get_args(INFILL_METHODS) else 'tile'
class SDImageInvocation(BaseModel):
"""Helper class to provide all Stable Diffusion raster image invocations with additional config"""
@ -44,15 +47,13 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
# TODO: consider making prompt optional to enable providing prompt through a link
# fmt: off
prompt: Optional[str] = Field(description="The prompt to generate an image from")
seed: int = Field(default=-1,ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)", )
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
seed: int = Field(ge=0, le=SEED_MAX, description="The seed to use (omit for random)", default_factory=get_random_seed)
steps: int = Field(default=30, gt=0, description="The number of steps to use to generate the image")
width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting image", )
height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting image", )
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
scheduler: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The scheduler to use" )
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
model: str = Field(default="", description="The model to use (currently ignored)")
progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
# fmt: on
# TODO: pass this an emitter method or something? or a session for dispatching?
@ -148,7 +149,6 @@ class ImageToImageInvocation(TextToImageInvocation):
self.image.image_type, self.image.image_name
)
)
mask = None
if self.fit:
image = image.resize((self.width, self.height))
@ -165,7 +165,6 @@ class ImageToImageInvocation(TextToImageInvocation):
outputs = Img2Img(model).generate(
prompt=self.prompt,
init_image=image,
init_mask=mask,
step_callback=partial(self.dispatch_progress, context, source_node_id),
**self.dict(
exclude={"prompt", "image", "mask"}
@ -197,7 +196,6 @@ class ImageToImageInvocation(TextToImageInvocation):
image=result_image,
)
class InpaintInvocation(ImageToImageInvocation):
"""Generates an image using inpaint."""
@ -205,6 +203,17 @@ class InpaintInvocation(ImageToImageInvocation):
# Inputs
mask: Union[ImageField, None] = Field(description="The mask")
seam_size: int = Field(default=96, ge=1, description="The seam inpaint size (px)")
seam_blur: int = Field(default=16, ge=0, description="The seam inpaint blur radius (px)")
seam_strength: float = Field(
default=0.75, gt=0, le=1, description="The seam inpaint strength"
)
seam_steps: int = Field(default=30, ge=1, description="The number of steps to use for seam inpaint")
tile_size: int = Field(default=32, ge=1, description="The tile infill method size (px)")
infill_method: INFILL_METHODS = Field(default=DEFAULT_INFILL_METHOD, description="The method used to infill empty regions (px)")
inpaint_width: Optional[int] = Field(default=None, multiple_of=8, gt=0, description="The width of the inpaint region (px)")
inpaint_height: Optional[int] = Field(default=None, multiple_of=8, gt=0, description="The height of the inpaint region (px)")
inpaint_fill: Optional[ColorField] = Field(default=ColorField(r=127, g=127, b=127, a=255), description="The solid infill method color")
inpaint_replace: float = Field(
default=0.0,
ge=0.0,

View File

@ -1,5 +1,6 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import io
from typing import Literal, Optional
import numpy
@ -37,9 +38,7 @@ class ImageOutput(BaseInvocationOutput):
# fmt: on
class Config:
schema_extra = {
"required": ["type", "image", "width", "height", "mode"]
}
schema_extra = {"required": ["type", "image", "width", "height"]}
def build_image_output(
@ -54,7 +53,6 @@ def build_image_output(
image=image_field,
width=image.width,
height=image.height,
mode=image.mode,
)
@ -151,7 +149,7 @@ class CropImageInvocation(BaseInvocation, PILInvocationConfig):
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, image_crop, metadata)
return build_image_output(
image_type=image_type,
@ -209,7 +207,7 @@ class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, new_image, metadata)
return build_image_output(
image_type=image_type,

View File

@ -0,0 +1,233 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Literal, Optional, Union, get_args
import numpy as np
import math
from PIL import Image, ImageOps
from pydantic import Field
from invokeai.app.invocations.image import ImageOutput, build_image_output
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.backend.image_util.patchmatch import PatchMatch
from ..models.image import ColorField, ImageField, ImageType
from .baseinvocation import (
BaseInvocation,
InvocationContext,
)
def infill_methods() -> list[str]:
methods = [
"tile",
"solid",
]
if PatchMatch.patchmatch_available():
methods.insert(0, "patchmatch")
return methods
INFILL_METHODS = Literal[tuple(infill_methods())]
DEFAULT_INFILL_METHOD = (
"patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"
)
def infill_patchmatch(im: Image.Image) -> Image.Image:
if im.mode != "RGBA":
return im
# Skip patchmatch if patchmatch isn't available
if not PatchMatch.patchmatch_available():
return im
# Patchmatch (note, we may want to expose patch_size? Increasing it significantly impacts performance though)
im_patched_np = PatchMatch.inpaint(
im.convert("RGB"), ImageOps.invert(im.split()[-1]), patch_size=3
)
im_patched = Image.fromarray(im_patched_np, mode="RGB")
return im_patched
def get_tile_images(image: np.ndarray, width=8, height=8):
_nrows, _ncols, depth = image.shape
_strides = image.strides
nrows, _m = divmod(_nrows, height)
ncols, _n = divmod(_ncols, width)
if _m != 0 or _n != 0:
return None
return np.lib.stride_tricks.as_strided(
np.ravel(image),
shape=(nrows, ncols, height, width, depth),
strides=(height * _strides[0], width * _strides[1], *_strides),
writeable=False,
)
def tile_fill_missing(
im: Image.Image, tile_size: int = 16, seed: Union[int, None] = None
) -> Image.Image:
# Only fill if there's an alpha layer
if im.mode != "RGBA":
return im
a = np.asarray(im, dtype=np.uint8)
tile_size_tuple = (tile_size, tile_size)
# Get the image as tiles of a specified size
tiles = get_tile_images(a, *tile_size_tuple).copy()
# Get the mask as tiles
tiles_mask = tiles[:, :, :, :, 3]
# Find any mask tiles with any fully transparent pixels (we will be replacing these later)
tmask_shape = tiles_mask.shape
tiles_mask = tiles_mask.reshape(math.prod(tiles_mask.shape))
n, ny = (math.prod(tmask_shape[0:2])), math.prod(tmask_shape[2:])
tiles_mask = tiles_mask > 0
tiles_mask = tiles_mask.reshape((n, ny)).all(axis=1)
# Get RGB tiles in single array and filter by the mask
tshape = tiles.shape
tiles_all = tiles.reshape((math.prod(tiles.shape[0:2]), *tiles.shape[2:]))
filtered_tiles = tiles_all[tiles_mask]
if len(filtered_tiles) == 0:
return im
# Find all invalid tiles and replace with a random valid tile
replace_count = (tiles_mask == False).sum()
rng = np.random.default_rng(seed=seed)
tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[
rng.choice(filtered_tiles.shape[0], replace_count), :, :, :
]
# Convert back to an image
tiles_all = tiles_all.reshape(tshape)
tiles_all = tiles_all.swapaxes(1, 2)
st = tiles_all.reshape(
(
math.prod(tiles_all.shape[0:2]),
math.prod(tiles_all.shape[2:4]),
tiles_all.shape[4],
)
)
si = Image.fromarray(st, mode="RGBA")
return si
class InfillColorInvocation(BaseInvocation):
"""Infills transparent areas of an image with a solid color"""
type: Literal["infill_rgba"] = "infill_rgba"
image: Optional[ImageField] = Field(default=None, description="The image to infill")
color: Optional[ColorField] = Field(
default=ColorField(r=127, g=127, b=127, a=255),
description="The color to use to infill",
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
self.image.image_type, self.image.image_name
)
solid_bg = Image.new("RGBA", image.size, self.color.tuple())
infilled = Image.alpha_composite(solid_bg, image)
infilled.paste(image, (0, 0), image.split()[-1])
image_type = ImageType.RESULT
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, infilled, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=image,
)
class InfillTileInvocation(BaseInvocation):
"""Infills transparent areas of an image with tiles of the image"""
type: Literal["infill_tile"] = "infill_tile"
image: Optional[ImageField] = Field(default=None, description="The image to infill")
tile_size: int = Field(default=32, ge=1, description="The tile size (px)")
seed: int = Field(
ge=0,
le=SEED_MAX,
description="The seed to use for tile generation (omit for random)",
default_factory=get_random_seed,
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
self.image.image_type, self.image.image_name
)
infilled = tile_fill_missing(
image.copy(), seed=self.seed, tile_size=self.tile_size
)
infilled.paste(image, (0, 0), image.split()[-1])
image_type = ImageType.RESULT
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, infilled, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=image,
)
class InfillPatchMatchInvocation(BaseInvocation):
"""Infills transparent areas of an image using the PatchMatch algorithm"""
type: Literal["infill_patchmatch"] = "infill_patchmatch"
image: Optional[ImageField] = Field(default=None, description="The image to infill")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
self.image.image_type, self.image.image_name
)
if PatchMatch.patchmatch_available():
infilled = infill_patchmatch(image.copy())
else:
raise ValueError("PatchMatch is not available on this system")
image_type = ImageType.RESULT
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, infilled, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=image,
)

View File

@ -1,11 +1,13 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
import random
from typing import Literal, Optional
from typing import Literal, Optional, Union
import einops
from pydantic import BaseModel, Field
import torch
from invokeai.app.invocations.util.choose_model import choose_model
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.app.util.step_callback import stable_diffusion_step_callback
@ -13,7 +15,8 @@ from ...backend.model_management.model_manager import ModelManager
from ...backend.util.devices import choose_torch_device, torch_dtype
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
from ...backend.image_util.seamless import configure_model_padding
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline
from ...backend.prompting.conditioning import get_uc_and_c_and_ec
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline, image_resized_to_grid_as_tensor
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
import numpy as np
from ..services.image_storage import ImageType
@ -102,17 +105,13 @@ def get_noise(width:int, height:int, device:torch.device, seed:int = 0, latent_c
return x
def random_seed():
return random.randint(0, np.iinfo(np.uint32).max)
class NoiseInvocation(BaseInvocation):
"""Generates latent noise."""
type: Literal["noise"] = "noise"
# Inputs
seed: int = Field(ge=0, le=np.iinfo(np.uint32).max, description="The seed to use", default_factory=random_seed)
seed: int = Field(ge=0, le=SEED_MAX, description="The seed to use", default_factory=get_random_seed)
width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting noise", )
height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting noise", )
@ -150,10 +149,9 @@ class TextToLatentsInvocation(BaseInvocation):
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
scheduler: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The scheduler to use" )
model: str = Field(default="", description="The model to use (currently ignored)")
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
model: str = Field(default="", description="The model to use (currently ignored)")
progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
# fmt: on
# Schema customisation
@ -260,6 +258,10 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
type: Literal["l2l"] = "l2l"
# Inputs
latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
strength: float = Field(default=0.5, description="The strength of the latents to use")
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
@ -271,10 +273,6 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
},
}
# Inputs
latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
strength: float = Field(default=0.5, description="The strength of the latents to use")
def invoke(self, context: InvocationContext) -> LatentsOutput:
noise = context.services.latents.get(self.noise.latents_name)
latent = context.services.latents.get(self.latents.latents_name)
@ -433,3 +431,47 @@ class ScaleLatentsInvocation(BaseInvocation):
name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.set(name, resized_latents)
return LatentsOutput(latents=LatentsField(latents_name=name))
class ImageToLatentsInvocation(BaseInvocation):
"""Encodes an image into latents."""
type: Literal["i2l"] = "i2l"
# Inputs
image: Union[ImageField, None] = Field(description="The image to encode")
model: str = Field(default="", description="The model to use")
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["latents", "image"],
"type_hints": {"model": "model"},
},
}
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
image = context.services.images.get(
self.image.image_type, self.image.image_name
)
# TODO: this only really needs the vae
model_info = choose_model(context.services.model_manager, self.model)
model: StableDiffusionGeneratorPipeline = model_info["model"]
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
latents = model.non_noised_latents_from_image(
image_tensor,
device=model._model_group.device_for(model.unet),
dtype=model.unet.dtype,
)
name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.set(name, latents)
return LatentsOutput(latents=LatentsField(latents_name=name))

View File

@ -1,13 +1,14 @@
from invokeai.backend.model_management.model_manager import ModelManager
from invokeai.backend.model_management.model_manager_service import ModelManagerService, SDModelType
def choose_model(model_manager: ModelManager, model_name: str):
def choose_model(model_manager: ModelManagerService, model_name: str, model_type: SDModelType=SDModelType.diffusers):
"""Returns the default model if the `model_name` not a valid model, else returns the selected model."""
logger = model_manager.logger
if model_manager.valid_model(model_name):
model = model_manager.get_model(model_name)
if model_name and not model_manager.valid_model(model_name, model_type):
default_model_name = model_manager.default_model()
logger.warning(f"\'{model_name}\' is not a valid model name. Using default model \'{default_model_name}\' instead.")
model = model_manager.get_model()
else:
model = model_manager.get_model(model_manager.default_model())
logger.warning(f"'{model_name}' is not a valid model name. Using default model \'{model.name}\' instead.")
model = model_manager.get_model(model_name, model_type)
return model

View File

@ -1,5 +1,5 @@
from enum import Enum
from typing import Optional
from typing import Optional, Tuple
from pydantic import BaseModel, Field
@ -27,3 +27,13 @@ class ImageField(BaseModel):
class Config:
schema_extra = {"required": ["image_type", "image_name"]}
class ColorField(BaseModel):
r: int = Field(ge=0, le=255, description="The red component")
g: int = Field(ge=0, le=255, description="The green component")
b: int = Field(ge=0, le=255, description="The blue component")
a: int = Field(ge=0, le=255, description="The alpha component")
def tuple(self) -> Tuple[int, int, int, int]:
return (self.r, self.g, self.b, self.a)

View File

@ -51,7 +51,7 @@ def create_system_graphs(graph_library: ItemStorageABC[LibraryGraph]) -> list[Li
graphs: list[LibraryGraph] = list()
text_to_image = graph_library.get(default_text_to_image_graph_id)
# text_to_image = graph_library.get(default_text_to_image_graph_id)
# TODO: Check if the graph is the same as the default one, and if not, update it
#if text_to_image is None:

View File

@ -20,9 +20,18 @@ class MetadataLatentsField(TypedDict):
latents_name: str
class MetadataColorField(TypedDict):
"""Pydantic-less ColorField, used for metadata parsing"""
r: int
g: int
b: int
a: int
# TODO: This is a placeholder for `InvocationsUnion` pending resolution of circular imports
NodeMetadata = Dict[
str, str | int | float | bool | MetadataImageField | MetadataLatentsField
str, None | str | int | float | bool | MetadataImageField | MetadataLatentsField | MetadataColorField
]

View File

@ -243,12 +243,14 @@ class ModelManagerService(ModelManagerServiceBase):
submodel,
)
def valid_model(self, *args, **kwargs) -> bool:
def valid_model(self, model_name: str, model_type: SDModelType=SDModelType.diffusers) -> bool:
"""
Given a model name, returns True if it is a valid
identifier.
"""
return self.mgr.valid_model(*args, **kwargs)
return self.mgr.valid_model(
model_name,
model_type)
def default_model(self) -> Union[str,None]:
"""

View File

@ -1,5 +1,13 @@
import datetime
import numpy as np
def get_timestamp():
return int(datetime.datetime.now(datetime.timezone.utc).timestamp())
SEED_MAX = np.iinfo(np.int32).max
def get_random_seed():
return np.random.randint(0, SEED_MAX)

View File

@ -226,10 +226,10 @@ class Inpaint(Img2Img):
def generate(self,
mask_image: Image.Image | torch.FloatTensor,
# Seam settings - when 0, doesn't fill seam
seam_size: int = 0,
seam_blur: int = 0,
seam_size: int = 96,
seam_blur: int = 16,
seam_strength: float = 0.7,
seam_steps: int = 10,
seam_steps: int = 30,
tile_size: int = 32,
inpaint_replace=False,
infill_method=None,

View File

@ -4,6 +4,7 @@ invokeai.backend.generator.inpaint descends from .generator
from __future__ import annotations
import math
from typing import Tuple, Union
import cv2
import numpy as np
@ -59,7 +60,7 @@ class Inpaint(Img2Img):
writeable=False,
)
def infill_patchmatch(self, im: Image.Image) -> Image:
def infill_patchmatch(self, im: Image.Image) -> Image.Image:
if im.mode != "RGBA":
return im
@ -75,18 +76,18 @@ class Inpaint(Img2Img):
return im_patched
def tile_fill_missing(
self, im: Image.Image, tile_size: int = 16, seed: int = None
) -> Image:
self, im: Image.Image, tile_size: int = 16, seed: Union[int, None] = None
) -> Image.Image:
# Only fill if there's an alpha layer
if im.mode != "RGBA":
return im
a = np.asarray(im, dtype=np.uint8)
tile_size = (tile_size, tile_size)
tile_size_tuple = (tile_size, tile_size)
# Get the image as tiles of a specified size
tiles = self.get_tile_images(a, *tile_size).copy()
tiles = self.get_tile_images(a, *tile_size_tuple).copy()
# Get the mask as tiles
tiles_mask = tiles[:, :, :, :, 3]
@ -127,7 +128,9 @@ class Inpaint(Img2Img):
return si
def mask_edge(self, mask: Image, edge_size: int, edge_blur: int) -> Image:
def mask_edge(
self, mask: Image.Image, edge_size: int, edge_blur: int
) -> Image.Image:
npimg = np.asarray(mask, dtype=np.uint8)
# Detect any partially transparent regions
@ -206,15 +209,15 @@ class Inpaint(Img2Img):
cfg_scale,
ddim_eta,
conditioning,
init_image: PIL.Image.Image | torch.FloatTensor,
mask_image: PIL.Image.Image | torch.FloatTensor,
init_image: Image.Image | torch.FloatTensor,
mask_image: Image.Image | torch.FloatTensor,
strength: float,
mask_blur_radius: int = 8,
# Seam settings - when 0, doesn't fill seam
seam_size: int = 0,
seam_blur: int = 0,
seam_size: int = 96,
seam_blur: int = 16,
seam_strength: float = 0.7,
seam_steps: int = 10,
seam_steps: int = 30,
tile_size: int = 32,
step_callback=None,
inpaint_replace=False,
@ -222,7 +225,7 @@ class Inpaint(Img2Img):
infill_method=None,
inpaint_width=None,
inpaint_height=None,
inpaint_fill: tuple(int) = (0x7F, 0x7F, 0x7F, 0xFF),
inpaint_fill: Tuple[int, int, int, int] = (0x7F, 0x7F, 0x7F, 0xFF),
attention_maps_callback=None,
**kwargs,
):
@ -239,7 +242,7 @@ class Inpaint(Img2Img):
self.inpaint_width = inpaint_width
self.inpaint_height = inpaint_height
if isinstance(init_image, PIL.Image.Image):
if isinstance(init_image, Image.Image):
self.pil_image = init_image.copy()
# Do infill
@ -250,8 +253,8 @@ class Inpaint(Img2Img):
self.pil_image.copy(), seed=self.seed, tile_size=tile_size
)
elif infill_method == "solid":
solid_bg = PIL.Image.new("RGBA", init_image.size, inpaint_fill)
init_filled = PIL.Image.alpha_composite(solid_bg, init_image)
solid_bg = Image.new("RGBA", init_image.size, inpaint_fill)
init_filled = Image.alpha_composite(solid_bg, init_image)
else:
raise ValueError(
f"Non-supported infill type {infill_method}", infill_method
@ -269,7 +272,7 @@ class Inpaint(Img2Img):
# Create init tensor
init_image = image_resized_to_grid_as_tensor(init_filled.convert("RGB"))
if isinstance(mask_image, PIL.Image.Image):
if isinstance(mask_image, Image.Image):
self.pil_mask = mask_image.copy()
debug_image(
mask_image,

View File

@ -1,13 +0,0 @@
{
"plugins": [
[
"transform-imports",
{
"lodash": {
"transform": "lodash/${member}",
"preventFullImport": true
}
}
]
]
}

View File

@ -34,4 +34,8 @@ stats.html
!.yarn/plugins
!.yarn/releases
!.yarn/sdks
!.yarn/versions
!.yarn/versions
# Yalc
.yalc
yalc.lock

View File

@ -21,7 +21,6 @@
"scripts": {
"prepare": "cd ../../../ && husky install invokeai/frontend/web/.husky",
"dev": "concurrently \"vite dev\" \"yarn run theme:watch\"",
"dev:nodes": "concurrently \"vite dev --mode nodes\" \"yarn run theme:watch\"",
"dev:host": "concurrently \"vite dev --host\" \"yarn run theme:watch\"",
"build": "yarn run lint && vite build",
"api:web": "openapi -i http://localhost:9090/openapi.json -o src/services/api --client axios --useOptions --useUnionTypes --exportSchemas true --indent 2 --request src/services/fixtures/request.ts",
@ -90,6 +89,7 @@
"react-konva": "^18.2.7",
"react-konva-utils": "^1.0.4",
"react-redux": "^8.0.5",
"react-resizable-panels": "^0.0.42",
"react-rnd": "^10.4.1",
"react-transition-group": "^4.4.5",
"react-use": "^17.4.0",
@ -99,6 +99,7 @@
"redux-deep-persist": "^1.0.7",
"redux-dynamic-middlewares": "^2.2.0",
"redux-persist": "^6.0.0",
"redux-remember": "^3.3.1",
"roarr": "^7.15.0",
"serialize-error": "^11.0.0",
"socket.io-client": "^4.6.0",
@ -118,6 +119,7 @@
"@types/node": "^18.16.2",
"@types/react": "^18.2.0",
"@types/react-dom": "^18.2.1",
"@types/react-redux": "^7.1.25",
"@types/react-transition-group": "^4.4.5",
"@types/uuid": "^9.0.0",
"@typescript-eslint/eslint-plugin": "^5.59.1",

View File

@ -54,7 +54,7 @@
"img2img": "Image To Image",
"unifiedCanvas": "Unified Canvas",
"linear": "Linear",
"nodes": "Nodes",
"nodes": "Node Editor",
"postprocessing": "Post Processing",
"nodesDesc": "A node based system for the generation of images is under development currently. Stay tuned for updates about this amazing feature.",
"postProcessing": "Post Processing",
@ -102,7 +102,8 @@
"generate": "Generate",
"openInNewTab": "Open in New Tab",
"dontAskMeAgain": "Don't ask me again",
"areYouSure": "Are you sure?"
"areYouSure": "Are you sure?",
"imagePrompt": "Image Prompt"
},
"gallery": {
"generations": "Generations",
@ -453,9 +454,10 @@
"seed": "Seed",
"imageToImage": "Image to Image",
"randomizeSeed": "Randomize Seed",
"shuffle": "Shuffle",
"shuffle": "Shuffle Seed",
"noiseThreshold": "Noise Threshold",
"perlinNoise": "Perlin Noise",
"noiseSettings": "Noise",
"variations": "Variations",
"variationAmount": "Variation Amount",
"seedWeights": "Seed Weights",
@ -470,6 +472,8 @@
"scale": "Scale",
"otherOptions": "Other Options",
"seamlessTiling": "Seamless Tiling",
"seamlessXAxis": "X Axis",
"seamlessYAxis": "Y Axis",
"hiresOptim": "High Res Optimization",
"hiresStrength": "High Res Strength",
"imageFit": "Fit Initial Image To Output Size",
@ -527,7 +531,8 @@
"useCanvasBeta": "Use Canvas Beta Layout",
"enableImageDebugging": "Enable Image Debugging",
"useSlidersForAll": "Use Sliders For All Options",
"autoShowProgress": "Auto Show Progress Images",
"showProgressInViewer": "Show Progress Images in Viewer",
"antialiasProgressImages": "Antialias Progress Images",
"resetWebUI": "Reset Web UI",
"resetWebUIDesc1": "Resetting the web UI only resets the browser's local cache of your images and remembered settings. It does not delete any images from disk.",
"resetWebUIDesc2": "If images aren't showing up in the gallery or something else isn't working, please try resetting before submitting an issue on GitHub.",
@ -549,8 +554,9 @@
"downloadImageStarted": "Image Download Started",
"imageCopied": "Image Copied",
"imageLinkCopied": "Image Link Copied",
"problemCopyingImageLink": "Unable to Copy Image Link",
"imageNotLoaded": "No Image Loaded",
"imageNotLoadedDesc": "No image found to send to image to image module",
"imageNotLoadedDesc": "Could not find image",
"imageSavedToGallery": "Image Saved to Gallery",
"canvasMerged": "Canvas Merged",
"sentToImageToImage": "Sent To Image To Image",
@ -645,7 +651,8 @@
"betaClear": "Clear",
"betaDarkenOutside": "Darken Outside",
"betaLimitToBox": "Limit To Box",
"betaPreserveMasked": "Preserve Masked"
"betaPreserveMasked": "Preserve Masked",
"antialiasing": "Antialiasing"
},
"ui": {
"showProgressImages": "Show Progress Images",

View File

@ -1,6 +1,6 @@
import ImageUploader from 'common/components/ImageUploader';
import ProgressBar from 'features/system/components/ProgressBar';
import SiteHeader from 'features/system/components/SiteHeader';
import ProgressBar from 'features/system/components/ProgressBar';
import InvokeTabs from 'features/ui/components/InvokeTabs';
import useToastWatcher from 'features/system/hooks/useToastWatcher';
@ -9,7 +9,7 @@ import FloatingGalleryButton from 'features/ui/components/FloatingGalleryButton'
import FloatingParametersPanelButtons from 'features/ui/components/FloatingParametersPanelButtons';
import { Box, Flex, Grid, Portal, useColorMode } from '@chakra-ui/react';
import { APP_HEIGHT, APP_WIDTH } from 'theme/util/constants';
import ImageGalleryPanel from 'features/gallery/components/ImageGalleryPanel';
import GalleryDrawer from 'features/gallery/components/ImageGalleryPanel';
import Lightbox from 'features/lightbox/components/Lightbox';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
@ -27,7 +27,8 @@ import { useGlobalHotkeys } from 'common/hooks/useGlobalHotkeys';
import { configChanged } from 'features/system/store/configSlice';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { useLogger } from 'app/logging/useLogger';
import ProgressImagePreview from 'features/parameters/components/ProgressImagePreview';
import ProgressImagePreview from 'features/parameters/components/_ProgressImagePreview';
import ParametersDrawer from 'features/ui/components/ParametersDrawer';
const DEFAULT_CONFIG = {};
@ -84,11 +85,13 @@ const App = ({ config = DEFAULT_CONFIG, children }: Props) => {
flexDir={{ base: 'column', xl: 'row' }}
>
<InvokeTabs />
<ImageGalleryPanel />
</Flex>
</Grid>
</ImageUploader>
<GalleryDrawer />
<ParametersDrawer />
<AnimatePresence>
{!isApplicationReady && !loadingOverridden && (
<motion.div
@ -121,7 +124,6 @@ const App = ({ config = DEFAULT_CONFIG, children }: Props) => {
<Portal>
<FloatingGalleryButton />
</Portal>
<ProgressImagePreview />
</Grid>
);
};

View File

@ -1,8 +1,6 @@
import React, { lazy, memo, PropsWithChildren, useEffect } from 'react';
import { Provider } from 'react-redux';
import { PersistGate } from 'redux-persist/integration/react';
import { store } from 'app/store/store';
import { persistor } from '../store/persistor';
import { OpenAPI } from 'services/api';
import '@fontsource/inter/100.css';
import '@fontsource/inter/200.css';
@ -57,13 +55,11 @@ const InvokeAIUI = ({ apiUrl, token, config, children }: Props) => {
return (
<React.StrictMode>
<Provider store={store}>
<PersistGate loading={<Loading />} persistor={persistor}>
<React.Suspense fallback={<Loading />}>
<ThemeLocaleProvider>
<App config={config}>{children}</App>
</ThemeLocaleProvider>
</React.Suspense>
</PersistGate>
<React.Suspense fallback={<Loading />}>
<ThemeLocaleProvider>
<App config={config}>{children}</App>
</ThemeLocaleProvider>
</React.Suspense>
</Provider>
</React.StrictMode>
);

View File

@ -1,26 +1,20 @@
import { createSelector } from '@reduxjs/toolkit';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { validateSeedWeights } from 'common/util/seedWeightPairs';
import { initialCanvasImageSelector } from 'features/canvas/store/canvasSelectors';
import { generationSelector } from 'features/parameters/store/generationSelectors';
import { systemSelector } from 'features/system/store/systemSelectors';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { isEqual } from 'lodash-es';
export const readinessSelector = createSelector(
[
generationSelector,
systemSelector,
initialCanvasImageSelector,
activeTabNameSelector,
],
(generation, system, initialCanvasImage, activeTabName) => {
[generationSelector, systemSelector, activeTabNameSelector],
(generation, system, activeTabName) => {
const {
prompt,
shouldGenerateVariations,
seedWeights,
initialImage,
seed,
isImageToImageEnabled,
} = generation;
const { isProcessing, isConnected } = system;
@ -34,7 +28,7 @@ export const readinessSelector = createSelector(
reasonsWhyNotReady.push('Missing prompt');
}
if (isImageToImageEnabled && !initialImage) {
if (activeTabName === 'img2img' && !initialImage) {
isReady = false;
reasonsWhyNotReady.push('No initial image selected');
}
@ -64,10 +58,5 @@ export const readinessSelector = createSelector(
// All good
return { isReady, reasonsWhyNotReady };
},
{
memoizeOptions: {
equalityCheck: isEqual,
resultEqualityCheck: isEqual,
},
}
defaultSelectorOptions
);

View File

@ -1,209 +1,209 @@
// import { AnyAction, Dispatch, MiddlewareAPI } from '@reduxjs/toolkit';
// import * as InvokeAI from 'app/types/invokeai';
// import type { RootState } from 'app/store/store';
// import {
// frontendToBackendParameters,
// FrontendToBackendParametersConfig,
// } from 'common/util/parameterTranslation';
// import dateFormat from 'dateformat';
// import {
// GalleryCategory,
// GalleryState,
// removeImage,
// } from 'features/gallery/store/gallerySlice';
// import {
// generationRequested,
// modelChangeRequested,
// modelConvertRequested,
// modelMergingRequested,
// setIsProcessing,
// } from 'features/system/store/systemSlice';
// import { InvokeTabName } from 'features/ui/store/tabMap';
// import { Socket } from 'socket.io-client';
import { AnyAction, Dispatch, MiddlewareAPI } from '@reduxjs/toolkit';
import * as InvokeAI from 'app/types/invokeai';
import type { RootState } from 'app/store/store';
import {
frontendToBackendParameters,
FrontendToBackendParametersConfig,
} from 'common/util/parameterTranslation';
import dateFormat from 'dateformat';
import {
GalleryCategory,
GalleryState,
removeImage,
} from 'features/gallery/store/gallerySlice';
import {
generationRequested,
modelChangeRequested,
modelConvertRequested,
modelMergingRequested,
setIsProcessing,
} from 'features/system/store/systemSlice';
import { InvokeTabName } from 'features/ui/store/tabMap';
import { Socket } from 'socket.io-client';
// /**
// * Returns an object containing all functions which use `socketio.emit()`.
// * i.e. those which make server requests.
// */
// const makeSocketIOEmitters = (
// store: MiddlewareAPI<Dispatch<AnyAction>, RootState>,
// socketio: Socket
// ) => {
// // We need to dispatch actions to redux and get pieces of state from the store.
// const { dispatch, getState } = store;
/**
* Returns an object containing all functions which use `socketio.emit()`.
* i.e. those which make server requests.
*/
const makeSocketIOEmitters = (
store: MiddlewareAPI<Dispatch<AnyAction>, RootState>,
socketio: Socket
) => {
// We need to dispatch actions to redux and get pieces of state from the store.
const { dispatch, getState } = store;
// return {
// emitGenerateImage: (generationMode: InvokeTabName) => {
// dispatch(setIsProcessing(true));
return {
emitGenerateImage: (generationMode: InvokeTabName) => {
dispatch(setIsProcessing(true));
// const state: RootState = getState();
const state: RootState = getState();
// const {
// generation: generationState,
// postprocessing: postprocessingState,
// system: systemState,
// canvas: canvasState,
// } = state;
const {
generation: generationState,
postprocessing: postprocessingState,
system: systemState,
canvas: canvasState,
} = state;
// const frontendToBackendParametersConfig: FrontendToBackendParametersConfig =
// {
// generationMode,
// generationState,
// postprocessingState,
// canvasState,
// systemState,
// };
const frontendToBackendParametersConfig: FrontendToBackendParametersConfig =
{
generationMode,
generationState,
postprocessingState,
canvasState,
systemState,
};
// dispatch(generationRequested());
dispatch(generationRequested());
// const { generationParameters, esrganParameters, facetoolParameters } =
// frontendToBackendParameters(frontendToBackendParametersConfig);
const { generationParameters, esrganParameters, facetoolParameters } =
frontendToBackendParameters(frontendToBackendParametersConfig);
// socketio.emit(
// 'generateImage',
// generationParameters,
// esrganParameters,
// facetoolParameters
// );
socketio.emit(
'generateImage',
generationParameters,
esrganParameters,
facetoolParameters
);
// // we need to truncate the init_mask base64 else it takes up the whole log
// // TODO: handle maintaining masks for reproducibility in future
// if (generationParameters.init_mask) {
// generationParameters.init_mask = generationParameters.init_mask
// .substr(0, 64)
// .concat('...');
// }
// if (generationParameters.init_img) {
// generationParameters.init_img = generationParameters.init_img
// .substr(0, 64)
// .concat('...');
// }
// we need to truncate the init_mask base64 else it takes up the whole log
// TODO: handle maintaining masks for reproducibility in future
if (generationParameters.init_mask) {
generationParameters.init_mask = generationParameters.init_mask
.substr(0, 64)
.concat('...');
}
if (generationParameters.init_img) {
generationParameters.init_img = generationParameters.init_img
.substr(0, 64)
.concat('...');
}
// dispatch(
// addLogEntry({
// timestamp: dateFormat(new Date(), 'isoDateTime'),
// message: `Image generation requested: ${JSON.stringify({
// ...generationParameters,
// ...esrganParameters,
// ...facetoolParameters,
// })}`,
// })
// );
// },
// emitRunESRGAN: (imageToProcess: InvokeAI._Image) => {
// dispatch(setIsProcessing(true));
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Image generation requested: ${JSON.stringify({
...generationParameters,
...esrganParameters,
...facetoolParameters,
})}`,
})
);
},
emitRunESRGAN: (imageToProcess: InvokeAI._Image) => {
dispatch(setIsProcessing(true));
// const {
// postprocessing: {
// upscalingLevel,
// upscalingDenoising,
// upscalingStrength,
// },
// } = getState();
const {
postprocessing: {
upscalingLevel,
upscalingDenoising,
upscalingStrength,
},
} = getState();
// const esrganParameters = {
// upscale: [upscalingLevel, upscalingDenoising, upscalingStrength],
// };
// socketio.emit('runPostprocessing', imageToProcess, {
// type: 'esrgan',
// ...esrganParameters,
// });
// dispatch(
// addLogEntry({
// timestamp: dateFormat(new Date(), 'isoDateTime'),
// message: `ESRGAN upscale requested: ${JSON.stringify({
// file: imageToProcess.url,
// ...esrganParameters,
// })}`,
// })
// );
// },
// emitRunFacetool: (imageToProcess: InvokeAI._Image) => {
// dispatch(setIsProcessing(true));
const esrganParameters = {
upscale: [upscalingLevel, upscalingDenoising, upscalingStrength],
};
socketio.emit('runPostprocessing', imageToProcess, {
type: 'esrgan',
...esrganParameters,
});
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `ESRGAN upscale requested: ${JSON.stringify({
file: imageToProcess.url,
...esrganParameters,
})}`,
})
);
},
emitRunFacetool: (imageToProcess: InvokeAI._Image) => {
dispatch(setIsProcessing(true));
// const {
// postprocessing: { facetoolType, facetoolStrength, codeformerFidelity },
// } = getState();
const {
postprocessing: { facetoolType, facetoolStrength, codeformerFidelity },
} = getState();
// const facetoolParameters: Record<string, unknown> = {
// facetool_strength: facetoolStrength,
// };
const facetoolParameters: Record<string, unknown> = {
facetool_strength: facetoolStrength,
};
// if (facetoolType === 'codeformer') {
// facetoolParameters.codeformer_fidelity = codeformerFidelity;
// }
if (facetoolType === 'codeformer') {
facetoolParameters.codeformer_fidelity = codeformerFidelity;
}
// socketio.emit('runPostprocessing', imageToProcess, {
// type: facetoolType,
// ...facetoolParameters,
// });
// dispatch(
// addLogEntry({
// timestamp: dateFormat(new Date(), 'isoDateTime'),
// message: `Face restoration (${facetoolType}) requested: ${JSON.stringify(
// {
// file: imageToProcess.url,
// ...facetoolParameters,
// }
// )}`,
// })
// );
// },
// emitDeleteImage: (imageToDelete: InvokeAI._Image) => {
// const { url, uuid, category, thumbnail } = imageToDelete;
// dispatch(removeImage(imageToDelete));
// socketio.emit('deleteImage', url, thumbnail, uuid, category);
// },
// emitRequestImages: (category: GalleryCategory) => {
// const gallery: GalleryState = getState().gallery;
// const { earliest_mtime } = gallery.categories[category];
// socketio.emit('requestImages', category, earliest_mtime);
// },
// emitRequestNewImages: (category: GalleryCategory) => {
// const gallery: GalleryState = getState().gallery;
// const { latest_mtime } = gallery.categories[category];
// socketio.emit('requestLatestImages', category, latest_mtime);
// },
// emitCancelProcessing: () => {
// socketio.emit('cancel');
// },
// emitRequestSystemConfig: () => {
// socketio.emit('requestSystemConfig');
// },
// emitSearchForModels: (modelFolder: string) => {
// socketio.emit('searchForModels', modelFolder);
// },
// emitAddNewModel: (modelConfig: InvokeAI.InvokeModelConfigProps) => {
// socketio.emit('addNewModel', modelConfig);
// },
// emitDeleteModel: (modelName: string) => {
// socketio.emit('deleteModel', modelName);
// },
// emitConvertToDiffusers: (
// modelToConvert: InvokeAI.InvokeModelConversionProps
// ) => {
// dispatch(modelConvertRequested());
// socketio.emit('convertToDiffusers', modelToConvert);
// },
// emitMergeDiffusersModels: (
// modelMergeInfo: InvokeAI.InvokeModelMergingProps
// ) => {
// dispatch(modelMergingRequested());
// socketio.emit('mergeDiffusersModels', modelMergeInfo);
// },
// emitRequestModelChange: (modelName: string) => {
// dispatch(modelChangeRequested());
// socketio.emit('requestModelChange', modelName);
// },
// emitSaveStagingAreaImageToGallery: (url: string) => {
// socketio.emit('requestSaveStagingAreaImageToGallery', url);
// },
// emitRequestEmptyTempFolder: () => {
// socketio.emit('requestEmptyTempFolder');
// },
// };
// };
socketio.emit('runPostprocessing', imageToProcess, {
type: facetoolType,
...facetoolParameters,
});
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Face restoration (${facetoolType}) requested: ${JSON.stringify(
{
file: imageToProcess.url,
...facetoolParameters,
}
)}`,
})
);
},
emitDeleteImage: (imageToDelete: InvokeAI._Image) => {
const { url, uuid, category, thumbnail } = imageToDelete;
dispatch(removeImage(imageToDelete));
socketio.emit('deleteImage', url, thumbnail, uuid, category);
},
emitRequestImages: (category: GalleryCategory) => {
const gallery: GalleryState = getState().gallery;
const { earliest_mtime } = gallery.categories[category];
socketio.emit('requestImages', category, earliest_mtime);
},
emitRequestNewImages: (category: GalleryCategory) => {
const gallery: GalleryState = getState().gallery;
const { latest_mtime } = gallery.categories[category];
socketio.emit('requestLatestImages', category, latest_mtime);
},
emitCancelProcessing: () => {
socketio.emit('cancel');
},
emitRequestSystemConfig: () => {
socketio.emit('requestSystemConfig');
},
emitSearchForModels: (modelFolder: string) => {
socketio.emit('searchForModels', modelFolder);
},
emitAddNewModel: (modelConfig: InvokeAI.InvokeModelConfigProps) => {
socketio.emit('addNewModel', modelConfig);
},
emitDeleteModel: (modelName: string) => {
socketio.emit('deleteModel', modelName);
},
emitConvertToDiffusers: (
modelToConvert: InvokeAI.InvokeModelConversionProps
) => {
dispatch(modelConvertRequested());
socketio.emit('convertToDiffusers', modelToConvert);
},
emitMergeDiffusersModels: (
modelMergeInfo: InvokeAI.InvokeModelMergingProps
) => {
dispatch(modelMergingRequested());
socketio.emit('mergeDiffusersModels', modelMergeInfo);
},
emitRequestModelChange: (modelName: string) => {
dispatch(modelChangeRequested());
socketio.emit('requestModelChange', modelName);
},
emitSaveStagingAreaImageToGallery: (url: string) => {
socketio.emit('requestSaveStagingAreaImageToGallery', url);
},
emitRequestEmptyTempFolder: () => {
socketio.emit('requestEmptyTempFolder');
},
};
};
// export default makeSocketIOEmitters;
export default makeSocketIOEmitters;
export default {};

View File

@ -0,0 +1,4 @@
import { createAction } from '@reduxjs/toolkit';
import { InvokeTabName } from 'features/ui/store/tabMap';
export const userInvoked = createAction<InvokeTabName>('app/userInvoked');

View File

@ -0,0 +1,8 @@
export const LOCALSTORAGE_KEYS = [
'chakra-ui-color-mode',
'i18nextLng',
'ROARR_FILTER',
'ROARR_LOG',
];
export const LOCALSTORAGE_PREFIX = '@@invokeai-';

View File

@ -0,0 +1,36 @@
import { canvasPersistDenylist } from 'features/canvas/store/canvasPersistDenylist';
import { galleryPersistDenylist } from 'features/gallery/store/galleryPersistDenylist';
import { resultsPersistDenylist } from 'features/gallery/store/resultsPersistDenylist';
import { uploadsPersistDenylist } from 'features/gallery/store/uploadsPersistDenylist';
import { lightboxPersistDenylist } from 'features/lightbox/store/lightboxPersistDenylist';
import { nodesPersistDenylist } from 'features/nodes/store/nodesPersistDenylist';
import { generationPersistDenylist } from 'features/parameters/store/generationPersistDenylist';
import { postprocessingPersistDenylist } from 'features/parameters/store/postprocessingPersistDenylist';
import { modelsPersistDenylist } from 'features/system/store/modelsPersistDenylist';
import { systemPersistDenylist } from 'features/system/store/systemPersistDenylist';
import { uiPersistDenylist } from 'features/ui/store/uiPersistDenylist';
import { omit } from 'lodash-es';
import { SerializeFunction } from 'redux-remember';
const serializationDenylist: {
[key: string]: string[];
} = {
canvas: canvasPersistDenylist,
gallery: galleryPersistDenylist,
generation: generationPersistDenylist,
lightbox: lightboxPersistDenylist,
models: modelsPersistDenylist,
nodes: nodesPersistDenylist,
postprocessing: postprocessingPersistDenylist,
results: resultsPersistDenylist,
system: systemPersistDenylist,
// config: configPersistDenyList,
ui: uiPersistDenylist,
uploads: uploadsPersistDenylist,
// hotkeys: hotkeysPersistDenylist,
};
export const serialize: SerializeFunction = (data, key) => {
const result = omit(data, serializationDenylist[key]);
return JSON.stringify(result);
};

View File

@ -0,0 +1,38 @@
import { initialCanvasState } from 'features/canvas/store/canvasSlice';
import { initialGalleryState } from 'features/gallery/store/gallerySlice';
import { initialResultsState } from 'features/gallery/store/resultsSlice';
import { initialUploadsState } from 'features/gallery/store/uploadsSlice';
import { initialLightboxState } from 'features/lightbox/store/lightboxSlice';
import { initialNodesState } from 'features/nodes/store/nodesSlice';
import { initialGenerationState } from 'features/parameters/store/generationSlice';
import { initialPostprocessingState } from 'features/parameters/store/postprocessingSlice';
import { initialConfigState } from 'features/system/store/configSlice';
import { initialModelsState } from 'features/system/store/modelSlice';
import { initialSystemState } from 'features/system/store/systemSlice';
import { initialHotkeysState } from 'features/ui/store/hotkeysSlice';
import { initialUIState } from 'features/ui/store/uiSlice';
import { defaultsDeep } from 'lodash-es';
import { UnserializeFunction } from 'redux-remember';
const initialStates: {
[key: string]: any;
} = {
canvas: initialCanvasState,
gallery: initialGalleryState,
generation: initialGenerationState,
lightbox: initialLightboxState,
models: initialModelsState,
nodes: initialNodesState,
postprocessing: initialPostprocessingState,
results: initialResultsState,
system: initialSystemState,
config: initialConfigState,
ui: initialUIState,
uploads: initialUploadsState,
hotkeys: initialHotkeysState,
};
export const unserialize: UnserializeFunction = (data, key) => {
const result = defaultsDeep(JSON.parse(data), initialStates[key]);
return result;
};

View File

@ -0,0 +1,30 @@
import { AnyAction } from '@reduxjs/toolkit';
import { isAnyGraphBuilt } from 'features/nodes/store/actions';
import { forEach } from 'lodash-es';
import { Graph } from 'services/api';
export const actionSanitizer = <A extends AnyAction>(action: A): A => {
if (isAnyGraphBuilt(action)) {
if (action.payload.nodes) {
const sanitizedNodes: Graph['nodes'] = {};
// Sanitize nodes as needed
forEach(action.payload.nodes, (node, key) => {
// Don't log the whole freaking dataURL
if (node.type === 'dataURL_image') {
const { dataURL, ...rest } = node;
sanitizedNodes[key] = { ...rest, dataURL: '<dataURL>' };
} else {
sanitizedNodes[key] = { ...node };
}
});
return {
...action,
payload: { ...action.payload, nodes: sanitizedNodes },
};
}
}
return action;
};

View File

@ -0,0 +1,11 @@
export const actionsDenylist = [
'canvas/setCursorPosition',
'canvas/setStageCoordinates',
'canvas/setStageScale',
'canvas/setIsDrawing',
'canvas/setBoundingBoxCoordinates',
'canvas/setBoundingBoxDimensions',
'canvas/setIsDrawing',
'canvas/addPointToCurrentLine',
'socket/generatorProgress',
];

View File

@ -0,0 +1,3 @@
export const stateSanitizer = <S>(state: S): S => {
return state;
};

View File

@ -0,0 +1,45 @@
import {
createListenerMiddleware,
addListener,
ListenerEffect,
AnyAction,
} from '@reduxjs/toolkit';
import type { TypedStartListening, TypedAddListener } from '@reduxjs/toolkit';
import type { RootState, AppDispatch } from '../../store';
import { addInitialImageSelectedListener } from './listeners/initialImageSelected';
import { addImageResultReceivedListener } from './listeners/invocationComplete';
import { addImageUploadedListener } from './listeners/imageUploaded';
import { addRequestedImageDeletionListener } from './listeners/imageDeleted';
import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas';
import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage';
export const listenerMiddleware = createListenerMiddleware();
export type AppStartListening = TypedStartListening<RootState, AppDispatch>;
export const startAppListening =
listenerMiddleware.startListening as AppStartListening;
export const addAppListener = addListener as TypedAddListener<
RootState,
AppDispatch
>;
export type AppListenerEffect = ListenerEffect<
AnyAction,
RootState,
AppDispatch
>;
addImageUploadedListener();
addInitialImageSelectedListener();
addImageResultReceivedListener();
addRequestedImageDeletionListener();
addUserInvokedCanvasListener();
addUserInvokedNodesListener();
addUserInvokedTextToImageListener();
addUserInvokedImageToImageListener();

View File

@ -0,0 +1,31 @@
import { canvasGraphBuilt } from 'features/nodes/store/actions';
import { startAppListening } from '..';
import {
canvasSessionIdChanged,
stagingAreaInitialized,
} from 'features/canvas/store/canvasSlice';
import { sessionInvoked } from 'services/thunks/session';
export const addCanvasGraphBuiltListener = () =>
startAppListening({
actionCreator: canvasGraphBuilt,
effect: async (action, { dispatch, getState, take }) => {
const [{ meta }] = await take(sessionInvoked.fulfilled.match);
const { sessionId } = meta.arg;
const state = getState();
if (!state.canvas.layerState.stagingArea.boundingBox) {
dispatch(
stagingAreaInitialized({
sessionId,
boundingBox: {
...state.canvas.boundingBoxCoordinates,
...state.canvas.boundingBoxDimensions,
},
})
);
}
dispatch(canvasSessionIdChanged(sessionId));
},
});

View File

@ -0,0 +1,59 @@
import { requestedImageDeletion } from 'features/gallery/store/actions';
import { startAppListening } from '..';
import { imageDeleted } from 'services/thunks/image';
import { log } from 'app/logging/useLogger';
import { clamp } from 'lodash-es';
import { imageSelected } from 'features/gallery/store/gallerySlice';
const moduleLog = log.child({ namespace: 'addRequestedImageDeletionListener' });
export const addRequestedImageDeletionListener = () => {
startAppListening({
actionCreator: requestedImageDeletion,
effect: (action, { dispatch, getState }) => {
const image = action.payload;
if (!image) {
moduleLog.warn('No image provided');
return;
}
const { name, type } = image;
if (type !== 'uploads' && type !== 'results') {
moduleLog.warn({ data: image }, `Invalid image type ${type}`);
return;
}
const selectedImageName = getState().gallery.selectedImage?.name;
if (selectedImageName === name) {
const allIds = getState()[type].ids;
const allEntities = getState()[type].entities;
const deletedImageIndex = allIds.findIndex(
(result) => result.toString() === name
);
const filteredIds = allIds.filter((id) => id.toString() !== name);
const newSelectedImageIndex = clamp(
deletedImageIndex,
0,
filteredIds.length - 1
);
const newSelectedImageId = filteredIds[newSelectedImageIndex];
const newSelectedImage = allEntities[newSelectedImageId];
if (newSelectedImageId) {
dispatch(imageSelected(newSelectedImage));
} else {
dispatch(imageSelected());
}
}
dispatch(imageDeleted({ imageName: name, imageType: type }));
},
});
};

View File

@ -0,0 +1,22 @@
import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
import { startAppListening } from '..';
import { uploadAdded } from 'features/gallery/store/uploadsSlice';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { imageUploaded } from 'services/thunks/image';
export const addImageUploadedListener = () => {
startAppListening({
actionCreator: imageUploaded.fulfilled,
effect: (action, { dispatch, getState }) => {
const { response } = action.payload;
const state = getState();
const image = deserializeImageResponse(response);
dispatch(uploadAdded(image));
if (state.gallery.shouldAutoSwitchToNewImages) {
dispatch(imageSelected(image));
}
},
});
};

View File

@ -0,0 +1,54 @@
import { initialImageChanged } from 'features/parameters/store/generationSlice';
import { Image, isInvokeAIImage } from 'app/types/invokeai';
import { selectResultsById } from 'features/gallery/store/resultsSlice';
import { selectUploadsById } from 'features/gallery/store/uploadsSlice';
import { makeToast } from 'features/system/hooks/useToastWatcher';
import { t } from 'i18next';
import { addToast } from 'features/system/store/systemSlice';
import { startAppListening } from '..';
import { initialImageSelected } from 'features/parameters/store/actions';
export const addInitialImageSelectedListener = () => {
startAppListening({
actionCreator: initialImageSelected,
effect: (action, { getState, dispatch }) => {
if (!action.payload) {
dispatch(
addToast(
makeToast({ title: t('toast.imageNotLoadedDesc'), status: 'error' })
)
);
return;
}
if (isInvokeAIImage(action.payload)) {
dispatch(initialImageChanged(action.payload));
dispatch(addToast(makeToast(t('toast.sentToImageToImage'))));
return;
}
const { name, type } = action.payload;
let image: Image | undefined;
const state = getState();
if (type === 'results') {
image = selectResultsById(state, name);
} else if (type === 'uploads') {
image = selectUploadsById(state, name);
}
if (!image) {
dispatch(
addToast(
makeToast({ title: t('toast.imageNotLoadedDesc'), status: 'error' })
)
);
return;
}
dispatch(initialImageChanged(image));
dispatch(addToast(makeToast(t('toast.sentToImageToImage'))));
},
});
};

View File

@ -0,0 +1,88 @@
import { invocationComplete } from 'services/events/actions';
import { isImageOutput } from 'services/types/guards';
import {
buildImageUrls,
extractTimestampFromImageName,
} from 'services/util/deserializeImageField';
import { Image } from 'app/types/invokeai';
import { resultAdded } from 'features/gallery/store/resultsSlice';
import { imageReceived, thumbnailReceived } from 'services/thunks/image';
import { startAppListening } from '..';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
const nodeDenylist = ['dataURL_image'];
export const addImageResultReceivedListener = () => {
startAppListening({
predicate: (action) => {
if (
invocationComplete.match(action) &&
isImageOutput(action.payload.data.result)
) {
return true;
}
return false;
},
effect: (action, { getState, dispatch }) => {
if (!invocationComplete.match(action)) {
return;
}
const { data, shouldFetchImages } = action.payload;
const { result, node, graph_execution_state_id } = data;
if (isImageOutput(result) && !nodeDenylist.includes(node.type)) {
const name = result.image.image_name;
const type = result.image.image_type;
const state = getState();
// if we need to refetch, set URLs to placeholder for now
const { url, thumbnail } = shouldFetchImages
? { url: '', thumbnail: '' }
: buildImageUrls(type, name);
const timestamp = extractTimestampFromImageName(name);
const image: Image = {
name,
type,
url,
thumbnail,
metadata: {
created: timestamp,
width: result.width,
height: result.height,
invokeai: {
session_id: graph_execution_state_id,
...(node ? { node } : {}),
},
},
};
dispatch(resultAdded(image));
if (state.gallery.shouldAutoSwitchToNewImages) {
dispatch(imageSelected(image));
}
if (state.config.shouldFetchImages) {
dispatch(imageReceived({ imageName: name, imageType: type }));
dispatch(
thumbnailReceived({
thumbnailName: name,
thumbnailType: type,
})
);
}
if (
graph_execution_state_id ===
state.canvas.layerState.stagingArea.sessionId
) {
dispatch(addImageToStagingArea(image));
}
}
},
});
};

View File

@ -0,0 +1,126 @@
import { startAppListening } from '..';
import { sessionCreated, sessionInvoked } from 'services/thunks/session';
import { buildCanvasGraphAndBlobs } from 'features/nodes/util/graphBuilders/buildCanvasGraph';
import { log } from 'app/logging/useLogger';
import { canvasGraphBuilt } from 'features/nodes/store/actions';
import { imageUploaded } from 'services/thunks/image';
import { v4 as uuidv4 } from 'uuid';
import { Graph } from 'services/api';
import {
canvasSessionIdChanged,
stagingAreaInitialized,
} from 'features/canvas/store/canvasSlice';
import { userInvoked } from 'app/store/actions';
const moduleLog = log.child({ namespace: 'invoke' });
export const addUserInvokedCanvasListener = () => {
startAppListening({
predicate: (action): action is ReturnType<typeof userInvoked> =>
userInvoked.match(action) && action.payload === 'unifiedCanvas',
effect: async (action, { getState, dispatch, take }) => {
const state = getState();
const data = await buildCanvasGraphAndBlobs(state);
if (!data) {
moduleLog.error('Problem building graph');
return;
}
const {
rangeNode,
iterateNode,
baseNode,
edges,
baseBlob,
maskBlob,
generationMode,
} = data;
const baseFilename = `${uuidv4()}.png`;
const maskFilename = `${uuidv4()}.png`;
dispatch(
imageUploaded({
imageType: 'intermediates',
formData: {
file: new File([baseBlob], baseFilename, { type: 'image/png' }),
},
})
);
if (baseNode.type === 'img2img' || baseNode.type === 'inpaint') {
const [{ payload: basePayload }] = await take(
(action): action is ReturnType<typeof imageUploaded.fulfilled> =>
imageUploaded.fulfilled.match(action) &&
action.meta.arg.formData.file.name === baseFilename
);
const { image_name: baseName, image_type: baseType } =
basePayload.response;
baseNode.image = {
image_name: baseName,
image_type: baseType,
};
}
if (baseNode.type === 'inpaint') {
dispatch(
imageUploaded({
imageType: 'intermediates',
formData: {
file: new File([maskBlob], maskFilename, { type: 'image/png' }),
},
})
);
const [{ payload: maskPayload }] = await take(
(action): action is ReturnType<typeof imageUploaded.fulfilled> =>
imageUploaded.fulfilled.match(action) &&
action.meta.arg.formData.file.name === maskFilename
);
const { image_name: maskName, image_type: maskType } =
maskPayload.response;
baseNode.mask = {
image_name: maskName,
image_type: maskType,
};
}
// Assemble!
const nodes: Graph['nodes'] = {
[rangeNode.id]: rangeNode,
[iterateNode.id]: iterateNode,
[baseNode.id]: baseNode,
};
const graph = { nodes, edges };
dispatch(canvasGraphBuilt(graph));
moduleLog({ data: graph }, 'Canvas graph built');
dispatch(sessionCreated({ graph }));
const [{ meta }] = await take(sessionInvoked.fulfilled.match);
const { sessionId } = meta.arg;
if (!state.canvas.layerState.stagingArea.boundingBox) {
dispatch(
stagingAreaInitialized({
sessionId,
boundingBox: {
...state.canvas.boundingBoxCoordinates,
...state.canvas.boundingBoxDimensions,
},
})
);
}
dispatch(canvasSessionIdChanged(sessionId));
},
});
};

View File

@ -0,0 +1,24 @@
import { startAppListening } from '..';
import { buildImageToImageGraph } from 'features/nodes/util/graphBuilders/buildImageToImageGraph';
import { sessionCreated } from 'services/thunks/session';
import { log } from 'app/logging/useLogger';
import { imageToImageGraphBuilt } from 'features/nodes/store/actions';
import { userInvoked } from 'app/store/actions';
const moduleLog = log.child({ namespace: 'invoke' });
export const addUserInvokedImageToImageListener = () => {
startAppListening({
predicate: (action): action is ReturnType<typeof userInvoked> =>
userInvoked.match(action) && action.payload === 'img2img',
effect: (action, { getState, dispatch }) => {
const state = getState();
const graph = buildImageToImageGraph(state);
dispatch(imageToImageGraphBuilt(graph));
moduleLog({ data: graph }, 'Image to Image graph built');
dispatch(sessionCreated({ graph }));
},
});
};

View File

@ -0,0 +1,24 @@
import { startAppListening } from '..';
import { sessionCreated } from 'services/thunks/session';
import { buildNodesGraph } from 'features/nodes/util/graphBuilders/buildNodesGraph';
import { log } from 'app/logging/useLogger';
import { nodesGraphBuilt } from 'features/nodes/store/actions';
import { userInvoked } from 'app/store/actions';
const moduleLog = log.child({ namespace: 'invoke' });
export const addUserInvokedNodesListener = () => {
startAppListening({
predicate: (action): action is ReturnType<typeof userInvoked> =>
userInvoked.match(action) && action.payload === 'nodes',
effect: (action, { getState, dispatch }) => {
const state = getState();
const graph = buildNodesGraph(state);
dispatch(nodesGraphBuilt(graph));
moduleLog({ data: graph }, 'Nodes graph built');
dispatch(sessionCreated({ graph }));
},
});
};

View File

@ -0,0 +1,24 @@
import { startAppListening } from '..';
import { buildTextToImageGraph } from 'features/nodes/util/graphBuilders/buildTextToImageGraph';
import { sessionCreated } from 'services/thunks/session';
import { log } from 'app/logging/useLogger';
import { textToImageGraphBuilt } from 'features/nodes/store/actions';
import { userInvoked } from 'app/store/actions';
const moduleLog = log.child({ namespace: 'invoke' });
export const addUserInvokedTextToImageListener = () => {
startAppListening({
predicate: (action): action is ReturnType<typeof userInvoked> =>
userInvoked.match(action) && action.payload === 'txt2img',
effect: (action, { getState, dispatch }) => {
const state = getState();
const graph = buildTextToImageGraph(state);
dispatch(textToImageGraphBuilt(graph));
moduleLog({ data: graph }, 'Text to Image graph built');
dispatch(sessionCreated({ graph }));
},
});
};

View File

@ -1,4 +0,0 @@
import { store } from 'app/store/store';
import { persistStore } from 'redux-persist';
export const persistor = persistStore(store);

View File

@ -1,9 +1,12 @@
import { combineReducers, configureStore } from '@reduxjs/toolkit';
import {
AnyAction,
ThunkDispatch,
combineReducers,
configureStore,
} from '@reduxjs/toolkit';
import { persistReducer } from 'redux-persist';
import storage from 'redux-persist/lib/storage'; // defaults to localStorage for web
import { rememberReducer, rememberEnhancer } from 'redux-remember';
import dynamicMiddlewares from 'redux-dynamic-middlewares';
import { getPersistConfig } from 'redux-deep-persist';
import canvasReducer from 'features/canvas/store/canvasSlice';
import galleryReducer from 'features/gallery/store/gallerySlice';
@ -19,33 +22,17 @@ import hotkeysReducer from 'features/ui/store/hotkeysSlice';
import modelsReducer from 'features/system/store/modelSlice';
import nodesReducer from 'features/nodes/store/nodesSlice';
import { canvasDenylist } from 'features/canvas/store/canvasPersistDenylist';
import { galleryDenylist } from 'features/gallery/store/galleryPersistDenylist';
import { generationDenylist } from 'features/parameters/store/generationPersistDenylist';
import { lightboxDenylist } from 'features/lightbox/store/lightboxPersistDenylist';
import { modelsDenylist } from 'features/system/store/modelsPersistDenylist';
import { nodesDenylist } from 'features/nodes/store/nodesPersistDenylist';
import { postprocessingDenylist } from 'features/parameters/store/postprocessingPersistDenylist';
import { systemDenylist } from 'features/system/store/systemPersistDenylist';
import { uiDenylist } from 'features/ui/store/uiPersistDenylist';
import { resultsDenylist } from 'features/gallery/store/resultsPersistDenylist';
import { uploadsDenylist } from 'features/gallery/store/uploadsPersistDenylist';
import { listenerMiddleware } from './middleware/listenerMiddleware';
/**
* redux-persist provides an easy and reliable way to persist state across reloads.
*
* While we definitely want generation parameters to be persisted, there are a number
* of things we do *not* want to be persisted across reloads:
* - Gallery/selected image (user may add/delete images from disk between page loads)
* - Connection/processing status
* - Availability of external libraries like ESRGAN/GFPGAN
*
* These can be denylisted in redux-persist.
*
* The necesssary nested persistors with denylists are configured below.
*/
import { actionSanitizer } from './middleware/devtools/actionSanitizer';
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
import { actionsDenylist } from './middleware/devtools/actionsDenylist';
const rootReducer = combineReducers({
import { serialize } from './enhancers/reduxRemember/serialize';
import { unserialize } from './enhancers/reduxRemember/unserialize';
import { LOCALSTORAGE_PREFIX } from './constants';
const allReducers = {
canvas: canvasReducer,
gallery: galleryReducer,
generation: generationReducer,
@ -59,65 +46,54 @@ const rootReducer = combineReducers({
ui: uiReducer,
uploads: uploadsReducer,
hotkeys: hotkeysReducer,
});
};
const rootPersistConfig = getPersistConfig({
key: 'root',
storage,
rootReducer,
blacklist: [
...canvasDenylist,
...galleryDenylist,
...generationDenylist,
...lightboxDenylist,
...modelsDenylist,
...nodesDenylist,
...postprocessingDenylist,
// ...resultsDenylist,
'results',
...systemDenylist,
...uiDenylist,
// ...uploadsDenylist,
'uploads',
'hotkeys',
'config',
],
});
const rootReducer = combineReducers(allReducers);
const persistedReducer = persistReducer(rootPersistConfig, rootReducer);
const rememberedRootReducer = rememberReducer(rootReducer);
// TODO: rip the old middleware out when nodes is complete
// export function buildMiddleware() {
// if (import.meta.env.MODE === 'nodes' || import.meta.env.MODE === 'package') {
// return socketMiddleware();
// } else {
// return socketioMiddleware();
// }
// }
const rememberedKeys: (keyof typeof allReducers)[] = [
'canvas',
'gallery',
'generation',
'lightbox',
// 'models',
'nodes',
'postprocessing',
'system',
'ui',
// 'hotkeys',
// 'results',
// 'uploads',
// 'config',
];
export const store = configureStore({
reducer: persistedReducer,
reducer: rememberedRootReducer,
enhancers: [
rememberEnhancer(window.localStorage, rememberedKeys, {
persistDebounce: 300,
serialize,
unserialize,
prefix: LOCALSTORAGE_PREFIX,
}),
],
middleware: (getDefaultMiddleware) =>
getDefaultMiddleware({
immutableCheck: false,
serializableCheck: false,
}).concat(dynamicMiddlewares),
})
.concat(dynamicMiddlewares)
.prepend(listenerMiddleware.middleware),
devTools: {
// Uncommenting these very rapidly called actions makes the redux dev tools output much more readable
actionsDenylist: [
'canvas/setCursorPosition',
'canvas/setStageCoordinates',
'canvas/setStageScale',
'canvas/setIsDrawing',
'canvas/setBoundingBoxCoordinates',
'canvas/setBoundingBoxDimensions',
'canvas/setIsDrawing',
'canvas/addPointToCurrentLine',
'socket/generatorProgress',
],
actionsDenylist,
actionSanitizer,
stateSanitizer,
trace: true,
},
});
export type AppGetState = typeof store.getState;
export type RootState = ReturnType<typeof store.getState>;
export type AppThunkDispatch = ThunkDispatch<RootState, any, AnyAction>;
export type AppDispatch = typeof store.dispatch;

View File

@ -1,6 +1,6 @@
import { TypedUseSelectorHook, useDispatch, useSelector } from 'react-redux';
import { AppDispatch, RootState } from 'app/store/store';
import { AppThunkDispatch, RootState } from 'app/store/store';
// Use throughout your app instead of plain `useDispatch` and `useSelector`
export const useAppDispatch: () => AppDispatch = useDispatch;
export const useAppDispatch = () => useDispatch<AppThunkDispatch>();
export const useAppSelector: TypedUseSelectorHook<RootState> = useSelector;

View File

@ -0,0 +1,7 @@
import { isEqual } from 'lodash-es';
export const defaultSelectorOptions = {
memoizeOptions: {
resultEqualityCheck: isEqual,
},
};

View File

@ -12,12 +12,10 @@
* 'gfpgan'.
*/
import { GalleryCategory } from 'features/gallery/store/gallerySlice';
import { FacetoolType } from 'features/parameters/store/postprocessingSlice';
import { SelectedImage } from 'features/parameters/store/actions';
import { InvokeTabName } from 'features/ui/store/tabMap';
import { IRect } from 'konva/lib/types';
import { ImageResponseMetadata, ImageType } from 'services/api';
import { AnyInvocation } from 'services/events/types';
import { O } from 'ts-toolbelt';
/**
@ -126,6 +124,14 @@ export type Image = {
metadata: ImageResponseMetadata;
};
export const isInvokeAIImage = (obj: Image | SelectedImage): obj is Image => {
if ('url' in obj && 'thumbnail' in obj) {
return true;
}
return false;
};
/**
* Types related to the system status.
*/
@ -270,7 +276,7 @@ export type FoundModelResponse = {
// export type SystemConfigResponse = SystemConfig;
export type ImageResultResponse = Omit<_Image, 'uuid'> & {
export type ImageResultResponse = Omit<Image, 'uuid'> & {
boundingBox?: IRect;
generationMode: InvokeTabName;
};

View File

@ -0,0 +1,61 @@
import { ChevronUpIcon } from '@chakra-ui/icons';
import { Box, Collapse, Flex, Spacer, Switch } from '@chakra-ui/react';
import { PropsWithChildren, memo } from 'react';
export type IAIToggleCollapseProps = PropsWithChildren & {
label: string;
isOpen: boolean;
onToggle: () => void;
withSwitch?: boolean;
};
const IAICollapse = (props: IAIToggleCollapseProps) => {
const { label, isOpen, onToggle, children, withSwitch = false } = props;
return (
<Box>
<Flex
onClick={onToggle}
sx={{
alignItems: 'center',
p: 2,
px: 4,
borderTopRadius: 'base',
borderBottomRadius: isOpen ? 0 : 'base',
bg: isOpen ? 'base.750' : 'base.800',
color: 'base.100',
_hover: {
bg: isOpen ? 'base.700' : 'base.750',
},
fontSize: 'sm',
fontWeight: 600,
cursor: 'pointer',
transitionProperty: 'common',
transitionDuration: 'normal',
userSelect: 'none',
}}
>
{label}
<Spacer />
{withSwitch && <Switch isChecked={isOpen} pointerEvents="none" />}
{!withSwitch && (
<ChevronUpIcon
sx={{
w: '1rem',
h: '1rem',
transform: isOpen ? 'rotate(0deg)' : 'rotate(180deg)',
transitionProperty: 'common',
transitionDuration: 'normal',
}}
/>
)}
</Flex>
<Collapse in={isOpen} animateOpacity>
<Box sx={{ p: 4, borderBottomRadius: 'base', bg: 'base.800' }}>
{children}
</Box>
</Collapse>
</Box>
);
};
export default memo(IAICollapse);

View File

@ -27,7 +27,7 @@ const IAIPopover = (props: IAIPopoverProps) => {
return (
<Popover isLazy={isLazy} {...rest}>
<PopoverTrigger>{triggerComponent}</PopoverTrigger>
<PopoverContent>
<PopoverContent shadow="dark-lg">
{hasArrow && <PopoverArrow />}
{children}
</PopoverContent>

View File

@ -7,7 +7,7 @@ import { useAppDispatch } from 'app/store/storeHooks';
import { useCallback } from 'react';
import { clearInitialImage } from 'features/parameters/store/generationSlice';
const ImageToImageSettingsHeader = () => {
const InitialImageButtons = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
@ -18,24 +18,19 @@ const ImageToImageSettingsHeader = () => {
return (
<Flex w="full" alignItems="center">
<Text size="sm" fontWeight={500} color="base.300">
Image to Image
{t('parameters.initialImage')}
</Text>
<Spacer />
<ButtonGroup>
<IAIIconButton
size="sm"
icon={<FaUndo />}
aria-label={t('accessibility.reset')}
onClick={handleResetInitialImage}
/>
<IAIIconButton
size="sm"
icon={<FaUpload />}
aria-label={t('common.upload')}
/>
<IAIIconButton icon={<FaUpload />} aria-label={t('common.upload')} />
</ButtonGroup>
</Flex>
);
};
export default ImageToImageSettingsHeader;
export default InitialImageButtons;

View File

@ -14,6 +14,7 @@ const ImageToImageOverlay = ({ image }: ImageToImageOverlayProps) => {
w: 'full',
h: 'full',
position: 'absolute',
pointerEvents: 'none',
}}
>
<Flex

View File

@ -49,7 +49,7 @@ const ImageUploader = (props: ImageUploaderProps) => {
const fileAcceptedCallback = useCallback(
async (file: File) => {
dispatch(imageUploaded({ formData: { file } }));
dispatch(imageUploaded({ imageType: 'uploads', formData: { file } }));
},
[dispatch]
);
@ -124,7 +124,7 @@ const ImageUploader = (props: ImageUploaderProps) => {
return;
}
dispatch(imageUploaded({ formData: { file } }));
dispatch(imageUploaded({ imageType: 'uploads', formData: { file } }));
};
document.addEventListener('paste', pasteImageListener);
return () => {

View File

@ -7,7 +7,7 @@ const SelectImagePlaceholder = () => {
sx={{
w: 'full',
h: 'full',
bg: 'base.800',
// bg: 'base.800',
borderRadius: 'base',
alignItems: 'center',
justifyContent: 'center',

View File

@ -2,6 +2,13 @@ import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
import {
setActiveTab,
toggleGalleryPanel,
toggleParametersPanel,
togglePinGalleryPanel,
togglePinParametersPanel,
} from 'features/ui/store/uiSlice';
import { isEqual } from 'lodash-es';
import { isHotkeyPressed, useHotkeys } from 'react-hotkeys-hook';
@ -36,4 +43,36 @@ export const useGlobalHotkeys = () => {
{ keyup: true, keydown: true },
[shift]
);
useHotkeys('o', () => {
dispatch(toggleParametersPanel());
});
useHotkeys(['shift+o'], () => {
dispatch(togglePinParametersPanel());
});
useHotkeys('g', () => {
dispatch(toggleGalleryPanel());
});
useHotkeys(['shift+g'], () => {
dispatch(togglePinGalleryPanel());
});
useHotkeys('1', () => {
dispatch(setActiveTab('txt2img'));
});
useHotkeys('2', () => {
dispatch(setActiveTab('img2img'));
});
useHotkeys('3', () => {
dispatch(setActiveTab('unifiedCanvas'));
});
useHotkeys('4', () => {
dispatch(setActiveTab('nodes'));
});
};

View File

@ -0,0 +1,33 @@
export const getImageDataTransparency = (pixels: Uint8ClampedArray) => {
let isFullyTransparent = true;
let isPartiallyTransparent = false;
const len = pixels.length;
let i = 3;
for (i; i < len; i += 4) {
if (pixels[i] === 255) {
isFullyTransparent = false;
} else {
isPartiallyTransparent = true;
}
if (!isFullyTransparent && isPartiallyTransparent) {
return { isFullyTransparent, isPartiallyTransparent };
}
}
return { isFullyTransparent, isPartiallyTransparent };
};
export const areAnyPixelsBlack = (pixels: Uint8ClampedArray) => {
const len = pixels.length;
let i = 0;
for (i; i < len; ) {
if (
pixels[i++] === 0 &&
pixels[i++] === 0 &&
pixels[i++] === 0 &&
pixels[i++] === 255
) {
return true;
}
}
return false;
};

View File

@ -19,6 +19,7 @@ import { InvokeTabName } from 'features/ui/store/tabMap';
import openBase64ImageInTab from './openBase64ImageInTab';
import randomInt from './randomInt';
import { stringToSeedWeightsArray } from './seedWeightPairs';
import { getIsImageDataTransparent, getIsImageDataWhite } from './arrayBuffer';
export type FrontendToBackendParametersConfig = {
generationMode: InvokeTabName;
@ -256,7 +257,7 @@ export const frontendToBackendParameters = (
...boundingBoxDimensions,
};
const maskDataURL = generateMask(
const { dataURL: maskDataURL, imageData: maskImageData } = generateMask(
isMaskEnabled ? objects.filter(isCanvasMaskLine) : [],
boundingBox
);
@ -287,6 +288,17 @@ export const frontendToBackendParameters = (
height: boundingBox.height,
});
const ctx = canvasBaseLayer.getContext();
const imageData = ctx.getImageData(
boundingBox.x + absPos.x,
boundingBox.y + absPos.y,
boundingBox.width,
boundingBox.height
);
const doesBaseHaveTransparency = getIsImageDataTransparent(imageData);
const doesMaskHaveTransparency = getIsImageDataWhite(maskImageData);
if (enableImageDebugging) {
openBase64ImageInTab([
{ base64: maskDataURL, caption: 'mask sent as init_mask' },

View File

@ -34,6 +34,7 @@ import IAICanvasStagingAreaToolbar from './IAICanvasStagingAreaToolbar';
import IAICanvasStatusText from './IAICanvasStatusText';
import IAICanvasBoundingBox from './IAICanvasToolbar/IAICanvasBoundingBox';
import IAICanvasToolPreview from './IAICanvasToolPreview';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
const selector = createSelector(
[canvasSelector, isStagingSelector],
@ -52,6 +53,7 @@ const selector = createSelector(
shouldShowIntermediates,
shouldShowGrid,
shouldRestrictStrokesToBox,
shouldAntialias,
} = canvas;
let stageCursor: string | undefined = 'none';
@ -80,13 +82,10 @@ const selector = createSelector(
tool,
isStaging,
shouldShowIntermediates,
shouldAntialias,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
defaultSelectorOptions
);
const ChakraStage = chakra(Stage, {
@ -106,6 +105,7 @@ const IAICanvas = () => {
tool,
isStaging,
shouldShowIntermediates,
shouldAntialias,
} = useAppSelector(selector);
useCanvasHotkeys();
@ -190,7 +190,7 @@ const IAICanvas = () => {
id="base"
ref={canvasBaseLayerRefCallback}
listening={false}
imageSmoothingEnabled={false}
imageSmoothingEnabled={shouldAntialias}
>
<IAICanvasObjectRenderer />
</Layer>
@ -201,7 +201,7 @@ const IAICanvas = () => {
<Layer>
<IAICanvasBoundingBoxOverlay />
</Layer>
<Layer id="preview" imageSmoothingEnabled={false}>
<Layer id="preview" imageSmoothingEnabled={shouldAntialias}>
{!isStaging && (
<IAICanvasToolPreview
visible={tool !== 'move'}

View File

@ -12,18 +12,20 @@ const selector = createSelector(
[canvasSelector],
(canvas) => {
const {
layerState: {
stagingArea: { images, selectedImageIndex },
},
layerState,
shouldShowStagingImage,
shouldShowStagingOutline,
boundingBoxCoordinates: { x, y },
boundingBoxDimensions: { width, height },
} = canvas;
const { selectedImageIndex, images } = layerState.stagingArea;
return {
currentStagingAreaImage:
images.length > 0 ? images[selectedImageIndex] : undefined,
images.length > 0 && selectedImageIndex !== undefined
? images[selectedImageIndex]
: undefined,
isOnFirstImage: selectedImageIndex === 0,
isOnLastImage: selectedImageIndex === images.length - 1,
shouldShowStagingImage,

View File

@ -6,6 +6,7 @@ import IAIIconButton from 'common/components/IAIIconButton';
import IAIPopover from 'common/components/IAIPopover';
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
import {
setShouldAntialias,
setShouldAutoSave,
setShouldCropToBoundingBoxOnSave,
setShouldDarkenOutsideBoundingBox,
@ -36,6 +37,7 @@ export const canvasControlsSelector = createSelector(
shouldShowIntermediates,
shouldSnapToGrid,
shouldRestrictStrokesToBox,
shouldAntialias,
} = canvas;
return {
@ -47,6 +49,7 @@ export const canvasControlsSelector = createSelector(
shouldShowIntermediates,
shouldSnapToGrid,
shouldRestrictStrokesToBox,
shouldAntialias,
};
},
{
@ -69,6 +72,7 @@ const IAICanvasSettingsButtonPopover = () => {
shouldShowIntermediates,
shouldSnapToGrid,
shouldRestrictStrokesToBox,
shouldAntialias,
} = useAppSelector(canvasControlsSelector);
useHotkeys(
@ -148,6 +152,12 @@ const IAICanvasSettingsButtonPopover = () => {
dispatch(setShouldShowCanvasDebugInfo(e.target.checked))
}
/>
<IAICheckbox
label={t('unifiedCanvas.antialiasing')}
isChecked={shouldAntialias}
onChange={(e) => dispatch(setShouldAntialias(e.target.checked))}
/>
<ClearCanvasHistoryButtonModal />
<EmptyTempFolderButtonModal />
</Flex>

View File

@ -9,6 +9,12 @@ const itemsToDenylist: (keyof CanvasState)[] = [
'doesCanvasNeedScaling',
];
export const canvasPersistDenylist: (keyof CanvasState)[] = [
'cursorPosition',
'isCanvasInitialized',
'doesCanvasNeedScaling',
];
export const canvasDenylist = itemsToDenylist.map(
(denylistItem) => `canvas.${denylistItem}`
);

View File

@ -38,7 +38,7 @@ export const initialLayerState: CanvasLayerState = {
},
};
const initialCanvasState: CanvasState = {
export const initialCanvasState: CanvasState = {
boundingBoxCoordinates: { x: 0, y: 0 },
boundingBoxDimensions: { width: 512, height: 512 },
boundingBoxPreviewFill: { r: 0, g: 0, b: 0, a: 0.5 },
@ -66,6 +66,7 @@ const initialCanvasState: CanvasState = {
minimumStageScale: 1,
pastLayerStates: [],
scaledBoundingBoxDimensions: { width: 512, height: 512 },
shouldAntialias: true,
shouldAutoSave: false,
shouldCropToBoundingBoxOnSave: false,
shouldDarkenOutsideBoundingBox: false,
@ -156,22 +157,20 @@ export const canvasSlice = createSlice({
setCursorPosition: (state, action: PayloadAction<Vector2d | null>) => {
state.cursorPosition = action.payload;
},
setInitialCanvasImage: (state, action: PayloadAction<InvokeAI._Image>) => {
setInitialCanvasImage: (state, action: PayloadAction<InvokeAI.Image>) => {
const image = action.payload;
const { width, height } = image.metadata;
const { stageDimensions } = state;
const newBoundingBoxDimensions = {
width: roundDownToMultiple(clamp(image.width, 64, 512), 64),
height: roundDownToMultiple(clamp(image.height, 64, 512), 64),
width: roundDownToMultiple(clamp(width, 64, 512), 64),
height: roundDownToMultiple(clamp(height, 64, 512), 64),
};
const newBoundingBoxCoordinates = {
x: roundToMultiple(
image.width / 2 - newBoundingBoxDimensions.width / 2,
64
),
x: roundToMultiple(width / 2 - newBoundingBoxDimensions.width / 2, 64),
y: roundToMultiple(
image.height / 2 - newBoundingBoxDimensions.height / 2,
height / 2 - newBoundingBoxDimensions.height / 2,
64
),
};
@ -196,8 +195,8 @@ export const canvasSlice = createSlice({
layer: 'base',
x: 0,
y: 0,
width: image.width,
height: image.height,
width: width,
height: height,
image: image,
},
],
@ -208,8 +207,8 @@ export const canvasSlice = createSlice({
const newScale = calculateScale(
stageDimensions.width,
stageDimensions.height,
image.width,
image.height,
width,
height,
STAGE_PADDING_PERCENTAGE
);
@ -218,8 +217,8 @@ export const canvasSlice = createSlice({
stageDimensions.height,
0,
0,
image.width,
image.height,
width,
height,
newScale
);
state.stageScale = newScale;
@ -287,16 +286,28 @@ export const canvasSlice = createSlice({
setIsMoveStageKeyHeld: (state, action: PayloadAction<boolean>) => {
state.isMoveStageKeyHeld = action.payload;
},
addImageToStagingArea: (
canvasSessionIdChanged: (state, action: PayloadAction<string>) => {
state.layerState.stagingArea.sessionId = action.payload;
},
stagingAreaInitialized: (
state,
action: PayloadAction<{
boundingBox: IRect;
image: InvokeAI._Image;
}>
action: PayloadAction<{ sessionId: string; boundingBox: IRect }>
) => {
const { boundingBox, image } = action.payload;
const { sessionId, boundingBox } = action.payload;
if (!boundingBox || !image) return;
state.layerState.stagingArea = {
boundingBox,
sessionId,
images: [],
selectedImageIndex: -1,
};
},
addImageToStagingArea: (state, action: PayloadAction<InvokeAI.Image>) => {
const image = action.payload;
if (!image || !state.layerState.stagingArea.boundingBox) {
return;
}
state.pastLayerStates.push(cloneDeep(state.layerState));
@ -307,7 +318,7 @@ export const canvasSlice = createSlice({
state.layerState.stagingArea.images.push({
kind: 'image',
layer: 'base',
...boundingBox,
...state.layerState.stagingArea.boundingBox,
image,
});
@ -323,9 +334,7 @@ export const canvasSlice = createSlice({
state.pastLayerStates.shift();
}
state.layerState.stagingArea = {
...initialLayerState.stagingArea,
};
state.layerState.stagingArea = { ...initialLayerState.stagingArea };
state.futureLayerStates = [];
state.shouldShowStagingOutline = true;
@ -663,6 +672,10 @@ export const canvasSlice = createSlice({
}
},
nextStagingAreaImage: (state) => {
if (!state.layerState.stagingArea.images.length) {
return;
}
const currentIndex = state.layerState.stagingArea.selectedImageIndex;
const length = state.layerState.stagingArea.images.length;
@ -672,6 +685,10 @@ export const canvasSlice = createSlice({
);
},
prevStagingAreaImage: (state) => {
if (!state.layerState.stagingArea.images.length) {
return;
}
const currentIndex = state.layerState.stagingArea.selectedImageIndex;
state.layerState.stagingArea.selectedImageIndex = Math.max(
@ -680,6 +697,10 @@ export const canvasSlice = createSlice({
);
},
commitStagingAreaImage: (state) => {
if (!state.layerState.stagingArea.images.length) {
return;
}
const { images, selectedImageIndex } = state.layerState.stagingArea;
state.pastLayerStates.push(cloneDeep(state.layerState));
@ -776,6 +797,9 @@ export const canvasSlice = createSlice({
setShouldRestrictStrokesToBox: (state, action: PayloadAction<boolean>) => {
state.shouldRestrictStrokesToBox = action.payload;
},
setShouldAntialias: (state, action: PayloadAction<boolean>) => {
state.shouldAntialias = action.payload;
},
setShouldCropToBoundingBoxOnSave: (
state,
action: PayloadAction<boolean>
@ -885,6 +909,9 @@ export const {
undo,
setScaledBoundingBoxDimensions,
setShouldRestrictStrokesToBox,
stagingAreaInitialized,
canvasSessionIdChanged,
setShouldAntialias,
} = canvasSlice.actions;
export default canvasSlice.reducer;

View File

@ -37,7 +37,7 @@ export type CanvasImage = {
y: number;
width: number;
height: number;
image: InvokeAI._Image;
image: InvokeAI.Image;
};
export type CanvasMaskLine = {
@ -90,9 +90,16 @@ export type CanvasLayerState = {
stagingArea: {
images: CanvasImage[];
selectedImageIndex: number;
sessionId?: string;
boundingBox?: IRect;
};
};
export type CanvasSession = {
sessionId: string;
boundingBox: IRect;
};
// type guards
export const isCanvasMaskLine = (obj: CanvasObject): obj is CanvasMaskLine =>
obj.kind === 'line' && obj.layer === 'mask';
@ -125,7 +132,7 @@ export interface CanvasState {
cursorPosition: Vector2d | null;
doesCanvasNeedScaling: boolean;
futureLayerStates: CanvasLayerState[];
intermediateImage?: InvokeAI._Image;
intermediateImage?: InvokeAI.Image;
isCanvasInitialized: boolean;
isDrawing: boolean;
isMaskEnabled: boolean;
@ -142,6 +149,7 @@ export interface CanvasState {
minimumStageScale: number;
pastLayerStates: CanvasLayerState[];
scaledBoundingBoxDimensions: Dimensions;
shouldAntialias: boolean;
shouldAutoSave: boolean;
shouldCropToBoundingBoxOnSave: boolean;
shouldDarkenOutsideBoundingBox: boolean;

View File

@ -0,0 +1,13 @@
/**
* Gets a Blob from a canvas.
*/
export const canvasToBlob = async (canvas: HTMLCanvasElement): Promise<Blob> =>
new Promise((resolve, reject) => {
canvas.toBlob((blob) => {
if (blob) {
resolve(blob);
return;
}
reject('Unable to create Blob');
});
});

View File

@ -0,0 +1,29 @@
/**
* Gets an ImageData object from an image dataURL by drawing it to a canvas.
*/
export const dataURLToImageData = async (
dataURL: string,
width: number,
height: number
): Promise<ImageData> =>
new Promise((resolve, reject) => {
const canvas = document.createElement('canvas');
canvas.width = width;
canvas.height = height;
const ctx = canvas.getContext('2d');
const image = new Image();
if (!ctx) {
canvas.remove();
reject('Unable to get context');
return;
}
image.onload = function () {
ctx.drawImage(image, 0, 0);
canvas.remove();
resolve(ctx.getImageData(0, 0, width, height));
};
image.src = dataURL;
});

View File

@ -1,6 +1,110 @@
// import { CanvasMaskLine } from 'features/canvas/store/canvasTypes';
// import Konva from 'konva';
// import { Stage } from 'konva/lib/Stage';
// import { IRect } from 'konva/lib/types';
// /**
// * Generating a mask image from InpaintingCanvas.tsx is not as simple
// * as calling toDataURL() on the canvas, because the mask may be represented
// * by colored lines or transparency, or the user may have inverted the mask
// * display.
// *
// * So we need to regenerate the mask image by creating an offscreen canvas,
// * drawing the mask and compositing everything correctly to output a valid
// * mask image.
// */
// export const getStageDataURL = (stage: Stage, boundingBox: IRect): string => {
// // create an offscreen canvas and add the mask to it
// // const { stage, offscreenContainer } = buildMaskStage(lines, boundingBox);
// const dataURL = stage.toDataURL({ ...boundingBox });
// // const imageData = stage
// // .toCanvas()
// // .getContext('2d')
// // ?.getImageData(
// // boundingBox.x,
// // boundingBox.y,
// // boundingBox.width,
// // boundingBox.height
// // );
// // offscreenContainer.remove();
// // return { dataURL, imageData };
// return dataURL;
// };
// export const getStageImageData = (
// stage: Stage,
// boundingBox: IRect
// ): ImageData | undefined => {
// const imageData = stage
// .toCanvas()
// .getContext('2d')
// ?.getImageData(
// boundingBox.x,
// boundingBox.y,
// boundingBox.width,
// boundingBox.height
// );
// return imageData;
// };
// export const buildMaskStage = (
// lines: CanvasMaskLine[],
// boundingBox: IRect
// ): { stage: Stage; offscreenContainer: HTMLDivElement } => {
// // create an offscreen canvas and add the mask to it
// const { width, height } = boundingBox;
// const offscreenContainer = document.createElement('div');
// const stage = new Konva.Stage({
// container: offscreenContainer,
// width: width,
// height: height,
// });
// const baseLayer = new Konva.Layer();
// const maskLayer = new Konva.Layer();
// // composite the image onto the mask layer
// baseLayer.add(
// new Konva.Rect({
// ...boundingBox,
// fill: 'white',
// })
// );
// lines.forEach((line) =>
// maskLayer.add(
// new Konva.Line({
// points: line.points,
// stroke: 'black',
// strokeWidth: line.strokeWidth * 2,
// tension: 0,
// lineCap: 'round',
// lineJoin: 'round',
// shadowForStrokeEnabled: false,
// globalCompositeOperation:
// line.tool === 'brush' ? 'source-over' : 'destination-out',
// })
// )
// );
// stage.add(baseLayer);
// stage.add(maskLayer);
// return { stage, offscreenContainer };
// };
import { CanvasMaskLine } from 'features/canvas/store/canvasTypes';
import Konva from 'konva';
import { IRect } from 'konva/lib/types';
import { canvasToBlob } from './canvasToBlob';
/**
* Generating a mask image from InpaintingCanvas.tsx is not as simple
@ -12,7 +116,7 @@ import { IRect } from 'konva/lib/types';
* drawing the mask and compositing everything correctly to output a valid
* mask image.
*/
const generateMask = (lines: CanvasMaskLine[], boundingBox: IRect): string => {
const generateMask = async (lines: CanvasMaskLine[], boundingBox: IRect) => {
// create an offscreen canvas and add the mask to it
const { width, height } = boundingBox;
@ -54,11 +158,13 @@ const generateMask = (lines: CanvasMaskLine[], boundingBox: IRect): string => {
stage.add(baseLayer);
stage.add(maskLayer);
const dataURL = stage.toDataURL({ ...boundingBox });
const maskDataURL = stage.toDataURL(boundingBox);
const maskBlob = await canvasToBlob(stage.toCanvas(boundingBox));
offscreenContainer.remove();
return dataURL;
return { maskDataURL, maskBlob };
};
export default generateMask;

View File

@ -0,0 +1,128 @@
import { RootState } from 'app/store/store';
import { getCanvasBaseLayer, getCanvasStage } from './konvaInstanceProvider';
import { isCanvasMaskLine } from '../store/canvasTypes';
import { log } from 'app/logging/useLogger';
import {
areAnyPixelsBlack,
getImageDataTransparency,
} from 'common/util/arrayBuffer';
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
import generateMask from './generateMask';
import { dataURLToImageData } from './dataURLToImageData';
import { canvasToBlob } from './canvasToBlob';
const moduleLog = log.child({ namespace: 'getCanvasDataURLs' });
export const getCanvasData = async (state: RootState) => {
const canvasBaseLayer = getCanvasBaseLayer();
const canvasStage = getCanvasStage();
if (!canvasBaseLayer || !canvasStage) {
moduleLog.error('Unable to find canvas / stage');
return;
}
const {
layerState: { objects },
boundingBoxCoordinates,
boundingBoxDimensions,
stageScale,
isMaskEnabled,
shouldPreserveMaskedArea,
boundingBoxScaleMethod: boundingBoxScale,
scaledBoundingBoxDimensions,
} = state.canvas;
const boundingBox = {
...boundingBoxCoordinates,
...boundingBoxDimensions,
};
// generationParameters.fit = false;
// generationParameters.strength = img2imgStrength;
// generationParameters.invert_mask = shouldPreserveMaskedArea;
// generationParameters.bounding_box = boundingBox;
const tempScale = canvasBaseLayer.scale();
canvasBaseLayer.scale({
x: 1 / stageScale,
y: 1 / stageScale,
});
const absPos = canvasBaseLayer.getAbsolutePosition();
const offsetBoundingBox = {
x: boundingBox.x + absPos.x,
y: boundingBox.y + absPos.y,
width: boundingBox.width,
height: boundingBox.height,
};
const baseDataURL = canvasBaseLayer.toDataURL(offsetBoundingBox);
const baseBlob = await canvasToBlob(
canvasBaseLayer.toCanvas(offsetBoundingBox)
);
canvasBaseLayer.scale(tempScale);
const { maskDataURL, maskBlob } = await generateMask(
isMaskEnabled ? objects.filter(isCanvasMaskLine) : [],
boundingBox
);
const baseImageData = await dataURLToImageData(
baseDataURL,
boundingBox.width,
boundingBox.height
);
const maskImageData = await dataURLToImageData(
maskDataURL,
boundingBox.width,
boundingBox.height
);
const {
isPartiallyTransparent: baseIsPartiallyTransparent,
isFullyTransparent: baseIsFullyTransparent,
} = getImageDataTransparency(baseImageData.data);
const doesMaskHaveBlackPixels = areAnyPixelsBlack(maskImageData.data);
if (state.system.enableImageDebugging) {
openBase64ImageInTab([
{ base64: maskDataURL, caption: 'mask b64' },
{ base64: baseDataURL, caption: 'image b64' },
]);
}
// generationParameters.init_img = imageDataURL;
// generationParameters.progress_images = false;
// if (boundingBoxScale !== 'none') {
// generationParameters.inpaint_width = scaledBoundingBoxDimensions.width;
// generationParameters.inpaint_height = scaledBoundingBoxDimensions.height;
// }
// generationParameters.seam_size = seamSize;
// generationParameters.seam_blur = seamBlur;
// generationParameters.seam_strength = seamStrength;
// generationParameters.seam_steps = seamSteps;
// generationParameters.tile_size = tileSize;
// generationParameters.infill_method = infillMethod;
// generationParameters.force_outpaint = false;
return {
baseDataURL,
baseBlob,
maskDataURL,
maskBlob,
baseIsPartiallyTransparent,
baseIsFullyTransparent,
doesMaskHaveBlackPixels,
};
};

View File

@ -1,12 +1,17 @@
import { createSelector } from '@reduxjs/toolkit';
import { get, isEqual, isNumber, isString } from 'lodash-es';
import { isEqual, isString } from 'lodash-es';
import {
ButtonGroup,
Flex,
FlexProps,
FormControl,
IconButton,
Link,
Menu,
MenuButton,
MenuItemOption,
MenuList,
MenuOptionGroup,
useDisclosure,
useToast,
} from '@chakra-ui/react';
@ -15,21 +20,12 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIIconButton from 'common/components/IAIIconButton';
import IAIPopover from 'common/components/IAIPopover';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import { GalleryState } from 'features/gallery/store/gallerySlice';
import { lightboxSelector } from 'features/lightbox/store/lightboxSelectors';
import { setIsLightboxOpen } from 'features/lightbox/store/lightboxSlice';
import FaceRestoreSettings from 'features/parameters/components/AdvancedParameters/FaceRestore/FaceRestoreSettings';
import UpscaleSettings from 'features/parameters/components/AdvancedParameters/Upscale/UpscaleSettings';
import {
initialImageSelected,
setAllParameters,
// setInitialImage,
setSeed,
} from 'features/parameters/store/generationSlice';
import { postprocessingSelector } from 'features/parameters/store/postprocessingSelectors';
import { systemSelector } from 'features/system/store/systemSelectors';
import { SystemState } from 'features/system/store/systemSlice';
import {
activeTabNameSelector,
uiSelector,
@ -56,6 +52,7 @@ import {
FaShare,
FaShareAlt,
FaTrash,
FaWrench,
} from 'react-icons/fa';
import {
gallerySelector,
@ -66,8 +63,13 @@ import { useCallback } from 'react';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import { useGetUrl } from 'common/util/getUrl';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { imageDeleted } from 'services/thunks/image';
import { useParameters } from 'features/parameters/hooks/useParameters';
import { initialImageSelected } from 'features/parameters/store/actions';
import { requestedImageDeletion } from '../store/actions';
import FaceRestoreSettings from 'features/parameters/components/Parameters/FaceRestore/FaceRestoreSettings';
import UpscaleSettings from 'features/parameters/components/Parameters/Upscale/UpscaleSettings';
import { allParametersSet } from 'features/parameters/store/generationSlice';
import DeleteImageButton from './ImageActionButtons/DeleteImageButton';
const currentImageButtonsSelector = createSelector(
[
@ -164,40 +166,59 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
const toast = useToast();
const { t } = useTranslation();
const { recallPrompt, recallSeed, sendToImageToImage } = useParameters();
const { recallPrompt, recallSeed, recallAllParameters } = useParameters();
const handleCopyImage = useCallback(async () => {
if (!image?.url) {
return;
}
// const handleCopyImage = useCallback(async () => {
// if (!image?.url) {
// return;
// }
const url = getUrl(image.url);
// const url = getUrl(image.url);
if (!url) {
return;
}
// if (!url) {
// return;
// }
const blob = await fetch(url).then((res) => res.blob());
const data = [new ClipboardItem({ [blob.type]: blob })];
// const blob = await fetch(url).then((res) => res.blob());
// const data = [new ClipboardItem({ [blob.type]: blob })];
await navigator.clipboard.write(data);
// await navigator.clipboard.write(data);
toast({
title: t('toast.imageCopied'),
status: 'success',
duration: 2500,
isClosable: true,
});
}, [getUrl, t, image?.url, toast]);
// toast({
// title: t('toast.imageCopied'),
// status: 'success',
// duration: 2500,
// isClosable: true,
// });
// }, [getUrl, t, image?.url, toast]);
const handleCopyImageLink = useCallback(() => {
const url = image
? shouldTransformUrls
? getUrl(image.url)
: window.location.toString() + image.url
: '';
const getImageUrl = () => {
if (!image) {
return;
}
if (shouldTransformUrls) {
return getUrl(image.url);
}
if (image.url.startsWith('http')) {
return image.url;
}
return window.location.toString() + image.url;
};
const url = getImageUrl();
if (!url) {
toast({
title: t('toast.problemCopyingImageLink'),
status: 'error',
duration: 2500,
isClosable: true,
});
return;
}
@ -216,39 +237,15 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
}, [dispatch, shouldHidePreview]);
const handleClickUseAllParameters = useCallback(() => {
if (!image) return;
// selectedImage.metadata &&
// dispatch(setAllParameters(selectedImage.metadata));
// if (selectedImage.metadata?.image.type === 'img2img') {
// dispatch(setActiveTab('img2img'));
// } else if (selectedImage.metadata?.image.type === 'txt2img') {
// dispatch(setActiveTab('txt2img'));
// }
}, [image]);
recallAllParameters(image);
}, [image, recallAllParameters]);
useHotkeys(
'a',
() => {
const type = image?.metadata?.invokeai?.node?.types;
if (isString(type) && ['txt2img', 'img2img'].includes(type)) {
handleClickUseAllParameters();
toast({
title: t('toast.parametersSet'),
status: 'success',
duration: 2500,
isClosable: true,
});
} else {
toast({
title: t('toast.parametersNotSet'),
description: t('toast.parametersNotSetDesc'),
status: 'error',
duration: 2500,
isClosable: true,
});
}
handleClickUseAllParameters;
},
[image]
[image, recallAllParameters]
);
const handleUseSeed = useCallback(() => {
@ -264,8 +261,8 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
useHotkeys('p', handleUsePrompt, [image]);
const handleSendToImageToImage = useCallback(() => {
sendToImageToImage(image);
}, [image, sendToImageToImage]);
dispatch(initialImageSelected(image));
}, [dispatch, image]);
useHotkeys('shift+i', handleSendToImageToImage, [image]);
@ -375,7 +372,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
const handleDelete = useCallback(() => {
if (canDeleteImage && image) {
dispatch(imageDeleted({ imageType: image.type, imageName: image.name }));
dispatch(requestedImageDeletion(image));
}
}, [image, canDeleteImage, dispatch]);
@ -440,13 +437,13 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
{t('parameters.sendToUnifiedCanvas')}
</IAIButton>
<IAIButton
{/* <IAIButton
size="sm"
onClick={handleCopyImage}
leftIcon={<FaCopy />}
>
{t('parameters.copyImage')}
</IAIButton>
</IAIButton> */}
<IAIButton
size="sm"
onClick={handleCopyImageLink}
@ -462,7 +459,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
</Link>
</Flex>
</IAIPopover>
<IAIIconButton
{/* <IAIIconButton
icon={shouldHidePreview ? <FaEyeSlash /> : <FaEye />}
tooltip={
!shouldHidePreview
@ -476,7 +473,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
}
isChecked={shouldHidePreview}
onClick={handlePreviewVisibility}
/>
/> */}
{isLightboxEnabled && (
<IAIIconButton
icon={<FaExpand />}
@ -518,8 +515,8 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
tooltip={`${t('parameters.useAll')} (A)`}
aria-label={`${t('parameters.useAll')} (A)`}
isDisabled={
!['txt2img', 'img2img'].includes(
image?.metadata?.sd_metadata?.type
!['txt2img', 'img2img', 'inpaint'].includes(
String(image?.metadata?.invokeai?.node?.type)
)
}
onClick={handleClickUseAllParameters}
@ -602,22 +599,10 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
/>
</ButtonGroup>
<IAIIconButton
onClick={handleInitiateDelete}
icon={<FaTrash />}
tooltip={`${t('gallery.deleteImage')} (Del)`}
aria-label={`${t('gallery.deleteImage')} (Del)`}
isDisabled={!image || !isConnected}
colorScheme="error"
/>
<ButtonGroup isAttached={true}>
<DeleteImageButton image={image} />
</ButtonGroup>
</Flex>
{image && (
<DeleteImageModal
isOpen={isDeleteDialogOpen}
onClose={onDeleteDialogClose}
handleDelete={handleDelete}
/>
)}
</>
);
};

View File

@ -1,27 +1,35 @@
import { Box, Flex, Image } from '@chakra-ui/react';
import { Box, Flex, Image, Skeleton, useBoolean } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { useGetUrl } from 'common/util/getUrl';
import { systemSelector } from 'features/system/store/systemSelectors';
import { uiSelector } from 'features/ui/store/uiSelectors';
import { isEqual } from 'lodash-es';
import { selectedImageSelector } from '../store/gallerySelectors';
import CurrentImageFallback from './CurrentImageFallback';
import { gallerySelector } from '../store/gallerySelectors';
import ImageMetadataViewer from './ImageMetaDataViewer/ImageMetadataViewer';
import NextPrevImageButtons from './NextPrevImageButtons';
import CurrentImageHidden from './CurrentImageHidden';
import { memo } from 'react';
import { DragEvent, memo, useCallback } from 'react';
import { systemSelector } from 'features/system/store/systemSelectors';
import CurrentImageFallback from './CurrentImageFallback';
export const imagesSelector = createSelector(
[uiSelector, selectedImageSelector, systemSelector],
(ui, selectedImage, system) => {
const { shouldShowImageDetails, shouldHidePreview } = ui;
[uiSelector, gallerySelector, systemSelector],
(ui, gallery, system) => {
const {
shouldShowImageDetails,
shouldHidePreview,
shouldShowProgressInViewer,
} = ui;
const { selectedImage } = gallery;
const { progressImage, shouldAntialiasProgressImage } = system;
return {
shouldShowImageDetails,
shouldHidePreview,
image: selectedImage,
progressImage,
shouldShowProgressInViewer,
shouldAntialiasProgressImage,
};
},
{
@ -32,10 +40,30 @@ export const imagesSelector = createSelector(
);
const CurrentImagePreview = () => {
const { shouldShowImageDetails, image, shouldHidePreview } =
useAppSelector(imagesSelector);
const {
shouldShowImageDetails,
image,
shouldHidePreview,
progressImage,
shouldShowProgressInViewer,
shouldAntialiasProgressImage,
} = useAppSelector(imagesSelector);
const { getUrl } = useGetUrl();
const [isLoaded, { on, off }] = useBoolean();
const handleDragStart = useCallback(
(e: DragEvent<HTMLDivElement>) => {
if (!image) {
return;
}
e.dataTransfer.setData('invokeai/imageName', image.name);
e.dataTransfer.setData('invokeai/imageType', image.type);
e.dataTransfer.effectAllowed = 'move';
},
[image]
);
return (
<Flex
sx={{
@ -46,12 +74,11 @@ const CurrentImagePreview = () => {
height: '100%',
}}
>
{image && (
{progressImage && shouldShowProgressInViewer ? (
<Image
src={shouldHidePreview ? undefined : getUrl(image.url)}
width={image.metadata.width}
height={image.metadata.height}
fallback={shouldHidePreview ? <CurrentImageHidden /> : undefined}
src={progressImage.dataURL}
width={progressImage.width}
height={progressImage.height}
sx={{
objectFit: 'contain',
maxWidth: '100%',
@ -59,8 +86,34 @@ const CurrentImagePreview = () => {
height: 'auto',
position: 'absolute',
borderRadius: 'base',
imageRendering: shouldAntialiasProgressImage ? 'auto' : 'pixelated',
}}
/>
) : (
image && (
<Image
onDragStart={handleDragStart}
fallbackStrategy="beforeLoadOrError"
src={shouldHidePreview ? undefined : getUrl(image.url)}
width={image.metadata.width || 'auto'}
height={image.metadata.height || 'auto'}
fallback={
shouldHidePreview ? (
<CurrentImageHidden />
) : (
<CurrentImageFallback />
)
}
sx={{
objectFit: 'contain',
maxWidth: '100%',
maxHeight: '100%',
height: 'auto',
position: 'absolute',
borderRadius: 'base',
}}
/>
)
)}
{shouldShowImageDetails && image && 'metadata' in image && (
<Box

View File

@ -0,0 +1,68 @@
import { Box, Flex, Image } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { systemSelector } from 'features/system/store/systemSelectors';
import { memo } from 'react';
import { gallerySelector } from '../store/gallerySelectors';
const selector = createSelector(
[systemSelector, gallerySelector],
(system, gallery) => {
const { shouldUseSingleGalleryColumn, galleryImageObjectFit } = gallery;
const { progressImage, shouldAntialiasProgressImage } = system;
return {
progressImage,
shouldUseSingleGalleryColumn,
galleryImageObjectFit,
shouldAntialiasProgressImage,
};
},
defaultSelectorOptions
);
const GalleryProgressImage = () => {
const {
progressImage,
shouldUseSingleGalleryColumn,
galleryImageObjectFit,
shouldAntialiasProgressImage,
} = useAppSelector(selector);
if (!progressImage) {
return null;
}
return (
<Flex
sx={{
w: 'full',
h: 'full',
alignItems: 'center',
justifyContent: 'center',
aspectRatio: '1/1',
}}
>
<Image
draggable={false}
src={progressImage.dataURL}
width={progressImage.width}
height={progressImage.height}
sx={{
objectFit: shouldUseSingleGalleryColumn
? 'contain'
: galleryImageObjectFit,
width: '100%',
height: '100%',
maxWidth: '100%',
maxHeight: '100%',
borderRadius: 'base',
imageRendering: shouldAntialiasProgressImage ? 'auto' : 'pixelated',
}}
/>
</Flex>
);
};
export default memo(GalleryProgressImage);

View File

@ -5,19 +5,20 @@ import {
Image,
MenuItem,
MenuList,
Skeleton,
useDisclosure,
useTheme,
useToast,
} from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { DragEvent, memo, useCallback, useState } from 'react';
import { DragEvent, MouseEvent, memo, useCallback, useState } from 'react';
import { FaCheck, FaExpand, FaImage, FaShare, FaTrash } from 'react-icons/fa';
import DeleteImageModal from './DeleteImageModal';
import { ContextMenu } from 'chakra-ui-contextmenu';
import * as InvokeAI from 'app/types/invokeai';
import { resizeAndScaleCanvas } from 'features/canvas/store/canvasSlice';
import {
resizeAndScaleCanvas,
setInitialCanvasImage,
} from 'features/canvas/store/canvasSlice';
import { gallerySelector } from 'features/gallery/store/gallerySelectors';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { useTranslation } from 'react-i18next';
@ -25,7 +26,6 @@ import IAIIconButton from 'common/components/IAIIconButton';
import { useGetUrl } from 'common/util/getUrl';
import { ExternalLinkIcon } from '@chakra-ui/icons';
import { IoArrowUndoCircleOutline } from 'react-icons/io5';
import { imageDeleted } from 'services/thunks/image';
import { createSelector } from '@reduxjs/toolkit';
import { systemSelector } from 'features/system/store/systemSelectors';
import { lightboxSelector } from 'features/lightbox/store/lightboxSelectors';
@ -33,6 +33,8 @@ import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { isEqual } from 'lodash-es';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { useParameters } from 'features/parameters/hooks/useParameters';
import { initialImageSelected } from 'features/parameters/store/actions';
import { requestedImageDeletion } from '../store/actions';
export const selector = createSelector(
[gallerySelector, systemSelector, lightboxSelector, activeTabNameSelector],
@ -94,16 +96,16 @@ const HoverableImage = memo((props: HoverableImageProps) => {
} = useDisclosure();
const { image, isSelected } = props;
const { url, thumbnail, name, metadata } = image;
const { url, thumbnail, name } = image;
const { getUrl } = useGetUrl();
const [isHovered, setIsHovered] = useState<boolean>(false);
const toast = useToast();
const { direction } = useTheme();
const { t } = useTranslation();
const { isFeatureEnabled: isLightboxEnabled } = useFeatureStatus('lightbox');
const { recallSeed, recallPrompt, sendToImageToImage, recallInitialImage } =
const { recallSeed, recallPrompt, recallInitialImage, recallAllParameters } =
useParameters();
const handleMouseOver = () => setIsHovered(true);
@ -112,18 +114,22 @@ const HoverableImage = memo((props: HoverableImageProps) => {
// Immediately deletes an image
const handleDelete = useCallback(() => {
if (canDeleteImage && image) {
dispatch(imageDeleted({ imageType: image.type, imageName: image.name }));
dispatch(requestedImageDeletion(image));
}
}, [dispatch, image, canDeleteImage]);
// Opens the alert dialog to check if user is sure they want to delete
const handleInitiateDelete = useCallback(() => {
if (shouldConfirmOnDelete) {
onDeleteDialogOpen();
} else {
handleDelete();
}
}, [handleDelete, onDeleteDialogOpen, shouldConfirmOnDelete]);
const handleInitiateDelete = useCallback(
(e: MouseEvent) => {
e.stopPropagation();
if (shouldConfirmOnDelete) {
onDeleteDialogOpen();
} else {
handleDelete();
}
},
[handleDelete, onDeleteDialogOpen, shouldConfirmOnDelete]
);
const handleSelectImage = useCallback(() => {
dispatch(imageSelected(image));
@ -148,8 +154,8 @@ const HoverableImage = memo((props: HoverableImageProps) => {
}, [image, recallSeed]);
const handleSendToImageToImage = useCallback(() => {
sendToImageToImage(image);
}, [image, sendToImageToImage]);
dispatch(initialImageSelected(image));
}, [dispatch, image]);
const handleRecallInitialImage = useCallback(() => {
recallInitialImage(image.metadata.invokeai?.node?.image);
@ -159,7 +165,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
* TODO: the rest of these
*/
const handleSendToCanvas = () => {
// dispatch(setInitialCanvasImage(image));
dispatch(setInitialCanvasImage(image));
dispatch(resizeAndScaleCanvas());
@ -175,16 +181,9 @@ const HoverableImage = memo((props: HoverableImageProps) => {
});
};
const handleUseAllParameters = () => {
// metadata.invokeai?.node &&
// dispatch(setAllParameters(metadata.invokeai?.node));
// toast({
// title: t('toast.parametersSet'),
// status: 'success',
// duration: 2500,
// isClosable: true,
// });
};
const handleUseAllParameters = useCallback(() => {
recallAllParameters(image);
}, [image, recallAllParameters]);
const handleLightBox = () => {
// dispatch(setCurrentImage(image));
@ -238,7 +237,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
icon={<IoArrowUndoCircleOutline />}
onClickCapture={handleUseAllParameters}
isDisabled={
!['txt2img', 'img2img'].includes(
!['txt2img', 'img2img', 'inpaint'].includes(
String(image?.metadata?.invokeai?.node?.type)
)
}
@ -315,6 +314,8 @@ const HoverableImage = memo((props: HoverableImageProps) => {
sx={{
width: '50%',
height: '50%',
maxWidth: '4rem',
maxHeight: '4rem',
fill: 'ok.500',
}}
/>

View File

@ -0,0 +1,92 @@
import { createSelector } from '@reduxjs/toolkit';
import { useDisclosure } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import { systemSelector } from 'features/system/store/systemSelectors';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import { FaTrash } from 'react-icons/fa';
import { memo, useCallback } from 'react';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import DeleteImageModal from '../DeleteImageModal';
import { requestedImageDeletion } from 'features/gallery/store/actions';
import { Image } from 'app/types/invokeai';
const selector = createSelector(
[systemSelector],
(system) => {
const { isProcessing, isConnected, shouldConfirmOnDelete } = system;
return {
canDeleteImage: isConnected && !isProcessing,
shouldConfirmOnDelete,
isProcessing,
isConnected,
};
},
defaultSelectorOptions
);
type DeleteImageButtonProps = {
image: Image | undefined;
};
const DeleteImageButton = (props: DeleteImageButtonProps) => {
const { image } = props;
const dispatch = useAppDispatch();
const { isProcessing, isConnected, canDeleteImage, shouldConfirmOnDelete } =
useAppSelector(selector);
const {
isOpen: isDeleteDialogOpen,
onOpen: onDeleteDialogOpen,
onClose: onDeleteDialogClose,
} = useDisclosure();
const { t } = useTranslation();
const handleDelete = useCallback(() => {
if (canDeleteImage && image) {
dispatch(requestedImageDeletion(image));
}
}, [image, canDeleteImage, dispatch]);
const handleInitiateDelete = useCallback(() => {
if (shouldConfirmOnDelete) {
onDeleteDialogOpen();
} else {
handleDelete();
}
}, [shouldConfirmOnDelete, onDeleteDialogOpen, handleDelete]);
useHotkeys('delete', handleInitiateDelete, [
image,
shouldConfirmOnDelete,
isConnected,
isProcessing,
]);
return (
<>
<IAIIconButton
onClick={handleInitiateDelete}
icon={<FaTrash />}
tooltip={`${t('gallery.deleteImage')} (Del)`}
aria-label={`${t('gallery.deleteImage')} (Del)`}
isDisabled={!image || !isConnected}
colorScheme="error"
/>
{image && (
<DeleteImageModal
isOpen={isDeleteDialogOpen}
onClose={onDeleteDialogClose}
handleDelete={handleDelete}
/>
)}
</>
);
};
export default memo(DeleteImageButton);

View File

@ -5,6 +5,7 @@ import {
FlexProps,
Grid,
Icon,
Image,
Text,
forwardRef,
} from '@chakra-ui/react';
@ -14,7 +15,10 @@ import IAICheckbox from 'common/components/IAICheckbox';
import IAIIconButton from 'common/components/IAIIconButton';
import IAIPopover from 'common/components/IAIPopover';
import IAISlider from 'common/components/IAISlider';
import { imageGallerySelector } from 'features/gallery/store/gallerySelectors';
import {
gallerySelector,
imageGallerySelector,
} from 'features/gallery/store/gallerySelectors';
import {
setCurrentCategory,
setGalleryImageMinimumWidth,
@ -50,30 +54,48 @@ import { uploadsAdapter } from '../store/uploadsSlice';
import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import { Virtuoso, VirtuosoGrid } from 'react-virtuoso';
import ProgressImagePreview from 'features/parameters/components/_ProgressImagePreview';
import ProgressImage from 'features/parameters/components/ProgressImage';
import { systemSelector } from 'features/system/store/systemSelectors';
import { Image as ImageType } from 'app/types/invokeai';
import { ProgressImage as ProgressImageType } from 'services/events/types';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import GalleryProgressImage from './GalleryProgressImage';
const GALLERY_SHOW_BUTTONS_MIN_WIDTH = 290;
const PROGRESS_IMAGE_PLACEHOLDER = 'PROGRESS_IMAGE_PLACEHOLDER';
const gallerySelector = createSelector(
[
(state: RootState) => state.uploads,
(state: RootState) => state.results,
(state: RootState) => state.gallery,
],
(uploads, results, gallery) => {
const selector = createSelector(
[(state: RootState) => state],
(state) => {
const { results, uploads, system, gallery } = state;
const { currentCategory } = gallery;
return currentCategory === 'results'
? {
images: resultsAdapter.getSelectors().selectAll(results),
isLoading: results.isLoading,
areMoreImagesAvailable: results.page < results.pages - 1,
}
: {
images: uploadsAdapter.getSelectors().selectAll(uploads),
isLoading: uploads.isLoading,
areMoreImagesAvailable: uploads.page < uploads.pages - 1,
};
}
const tempImages: (ImageType | typeof PROGRESS_IMAGE_PLACEHOLDER)[] = [];
if (system.progressImage) {
tempImages.push(PROGRESS_IMAGE_PLACEHOLDER);
}
if (currentCategory === 'results') {
return {
images: tempImages.concat(
resultsAdapter.getSelectors().selectAll(results)
),
isLoading: results.isLoading,
areMoreImagesAvailable: results.page < results.pages - 1,
};
}
return {
images: tempImages.concat(
uploadsAdapter.getSelectors().selectAll(uploads)
),
isLoading: uploads.isLoading,
areMoreImagesAvailable: uploads.page < uploads.pages - 1,
};
},
defaultSelectorOptions
);
const ImageGalleryContent = () => {
@ -108,7 +130,7 @@ const ImageGalleryContent = () => {
} = useAppSelector(imageGallerySelector);
const { images, areMoreImagesAvailable, isLoading } =
useAppSelector(gallerySelector);
useAppSelector(selector);
const handleClickLoadMore = () => {
if (currentCategory === 'results') {
@ -170,8 +192,24 @@ const ImageGalleryContent = () => {
}
}, []);
const handleEndReached = useCallback(() => {
if (currentCategory === 'results') {
dispatch(receivedResultImagesPage());
} else if (currentCategory === 'uploads') {
dispatch(receivedUploadImagesPage());
}
}, [dispatch, currentCategory]);
return (
<Flex flexDirection="column" w="full" h="full" gap={4}>
<Flex
sx={{
gap: 2,
flexDirection: 'column',
h: 'full',
w: 'full',
borderRadius: 'base',
}}
>
<Flex
ref={resizeObserverRef}
alignItems="center"
@ -290,18 +328,27 @@ const ImageGalleryContent = () => {
<Virtuoso
style={{ height: '100%' }}
data={images}
endReached={handleEndReached}
scrollerRef={(ref) => setScrollerRef(ref)}
itemContent={(index, image) => {
const { name } = image;
const isSelected = selectedImage?.name === name;
const isSelected =
image === PROGRESS_IMAGE_PLACEHOLDER
? false
: selectedImage?.name === image?.name;
return (
<Flex sx={{ pb: 2 }}>
<HoverableImage
key={`${name}-${image.thumbnail}`}
image={image}
isSelected={isSelected}
/>
{image === PROGRESS_IMAGE_PLACEHOLDER ? (
<GalleryProgressImage
key={PROGRESS_IMAGE_PLACEHOLDER}
/>
) : (
<HoverableImage
key={`${image.name}-${image.thumbnail}`}
image={image}
isSelected={isSelected}
/>
)}
</Flex>
);
}}
@ -310,18 +357,23 @@ const ImageGalleryContent = () => {
<VirtuosoGrid
style={{ height: '100%' }}
data={images}
endReached={handleEndReached}
components={{
Item: ItemContainer,
List: ListContainer,
}}
scrollerRef={setScroller}
itemContent={(index, image) => {
const { name } = image;
const isSelected = selectedImage?.name === name;
const isSelected =
image === PROGRESS_IMAGE_PLACEHOLDER
? false
: selectedImage?.name === image?.name;
return (
return image === PROGRESS_IMAGE_PLACEHOLDER ? (
<GalleryProgressImage key={PROGRESS_IMAGE_PLACEHOLDER} />
) : (
<HoverableImage
key={`${name}-${image.thumbnail}`}
key={`${image.name}-${image.thumbnail}`}
image={image}
isSelected={isSelected}
/>
@ -334,6 +386,7 @@ const ImageGalleryContent = () => {
onClick={handleClickLoadMore}
isDisabled={!areMoreImagesAvailable}
isLoading={isLoading}
loadingText="Loading"
flexShrink={0}
>
{areMoreImagesAvailable

View File

@ -5,7 +5,6 @@ import {
// selectPrevImage,
setGalleryImageMinimumWidth,
} from 'features/gallery/store/gallerySlice';
import { InvokeTabName } from 'features/ui/store/tabMap';
import { clamp, isEqual } from 'lodash-es';
import { useHotkeys } from 'react-hotkeys-hook';
@ -13,11 +12,7 @@ import { useHotkeys } from 'react-hotkeys-hook';
import './ImageGallery.css';
import ImageGalleryContent from './ImageGalleryContent';
import ResizableDrawer from 'features/ui/components/common/ResizableDrawer/ResizableDrawer';
import {
setShouldShowGallery,
toggleGalleryPanel,
togglePinGalleryPanel,
} from 'features/ui/store/uiSlice';
import { setShouldShowGallery } from 'features/ui/store/uiSlice';
import { createSelector } from '@reduxjs/toolkit';
import {
activeTabNameSelector,
@ -26,22 +21,20 @@ import {
import { isStagingSelector } from 'features/canvas/store/canvasSelectors';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import { lightboxSelector } from 'features/lightbox/store/lightboxSelectors';
import useResolution from 'common/hooks/useResolution';
import { Flex } from '@chakra-ui/react';
import { memo } from 'react';
const GALLERY_TAB_WIDTHS: Record<
InvokeTabName,
{ galleryMinWidth: number; galleryMaxWidth: number }
> = {
// txt2img: { galleryMinWidth: 200, galleryMaxWidth: 500 },
// img2img: { galleryMinWidth: 200, galleryMaxWidth: 500 },
generate: { galleryMinWidth: 200, galleryMaxWidth: 500 },
unifiedCanvas: { galleryMinWidth: 200, galleryMaxWidth: 200 },
nodes: { galleryMinWidth: 200, galleryMaxWidth: 500 },
// postprocessing: { galleryMinWidth: 200, galleryMaxWidth: 500 },
// training: { galleryMinWidth: 200, galleryMaxWidth: 500 },
};
// const GALLERY_TAB_WIDTHS: Record<
// InvokeTabName,
// { galleryMinWidth: number; galleryMaxWidth: number }
// > = {
// txt2img: { galleryMinWidth: 200, galleryMaxWidth: 500 },
// img2img: { galleryMinWidth: 200, galleryMaxWidth: 500 },
// generate: { galleryMinWidth: 200, galleryMaxWidth: 500 },
// unifiedCanvas: { galleryMinWidth: 200, galleryMaxWidth: 200 },
// nodes: { galleryMinWidth: 200, galleryMaxWidth: 500 },
// postprocessing: { galleryMinWidth: 200, galleryMaxWidth: 500 },
// training: { galleryMinWidth: 200, galleryMaxWidth: 500 },
// };
const galleryPanelSelector = createSelector(
[
@ -73,50 +66,50 @@ const galleryPanelSelector = createSelector(
}
);
export const ImageGalleryPanel = () => {
const GalleryDrawer = () => {
const dispatch = useAppDispatch();
const {
shouldPinGallery,
shouldShowGallery,
galleryImageMinimumWidth,
activeTabName,
isStaging,
isResizable,
isLightboxOpen,
// activeTabName,
// isStaging,
// isResizable,
// isLightboxOpen,
} = useAppSelector(galleryPanelSelector);
const handleSetShouldPinGallery = () => {
dispatch(togglePinGalleryPanel());
dispatch(requestCanvasRescale());
};
// const handleSetShouldPinGallery = () => {
// dispatch(togglePinGalleryPanel());
// dispatch(requestCanvasRescale());
// };
const handleToggleGallery = () => {
dispatch(toggleGalleryPanel());
shouldPinGallery && dispatch(requestCanvasRescale());
};
// const handleToggleGallery = () => {
// dispatch(toggleGalleryPanel());
// shouldPinGallery && dispatch(requestCanvasRescale());
// };
const handleCloseGallery = () => {
dispatch(setShouldShowGallery(false));
shouldPinGallery && dispatch(requestCanvasRescale());
};
const resolution = useResolution();
// const resolution = useResolution();
useHotkeys(
'g',
() => {
handleToggleGallery();
},
[shouldPinGallery]
);
// useHotkeys(
// 'g',
// () => {
// handleToggleGallery();
// },
// [shouldPinGallery]
// );
useHotkeys(
'shift+g',
() => {
handleSetShouldPinGallery();
},
[shouldPinGallery]
);
// useHotkeys(
// 'shift+g',
// () => {
// handleSetShouldPinGallery();
// },
// [shouldPinGallery]
// );
useHotkeys(
'esc',
@ -162,55 +155,71 @@ export const ImageGalleryPanel = () => {
[galleryImageMinimumWidth]
);
const calcGalleryMinHeight = () => {
if (resolution === 'desktop') return;
return 300;
};
// const calcGalleryMinHeight = () => {
// if (resolution === 'desktop') return;
// return 300;
// };
const imageGalleryContent = () => {
return (
<Flex
w="100vw"
h={{ base: 300, xl: '100vh' }}
paddingRight={{ base: 8, xl: 0 }}
paddingBottom={{ base: 4, xl: 0 }}
>
<ImageGalleryContent />
</Flex>
);
};
// const imageGalleryContent = () => {
// return (
// <Flex
// w="100vw"
// h={{ base: 300, xl: '100vh' }}
// paddingRight={{ base: 8, xl: 0 }}
// paddingBottom={{ base: 4, xl: 0 }}
// >
// <ImageGalleryContent />
// </Flex>
// );
// };
const resizableImageGalleryContent = () => {
return (
<ResizableDrawer
direction="right"
isResizable={isResizable || !shouldPinGallery}
isOpen={shouldShowGallery}
onClose={handleCloseGallery}
isPinned={shouldPinGallery && !isLightboxOpen}
minWidth={
shouldPinGallery
? GALLERY_TAB_WIDTHS[activeTabName].galleryMinWidth
: 200
}
maxWidth={
shouldPinGallery
? GALLERY_TAB_WIDTHS[activeTabName].galleryMaxWidth
: undefined
}
minHeight={calcGalleryMinHeight()}
>
<ImageGalleryContent />
</ResizableDrawer>
);
};
// const resizableImageGalleryContent = () => {
// return (
// <ResizableDrawer
// direction="right"
// isResizable={isResizable || !shouldPinGallery}
// isOpen={shouldShowGallery}
// onClose={handleCloseGallery}
// isPinned={shouldPinGallery && !isLightboxOpen}
// minWidth={
// shouldPinGallery
// ? GALLERY_TAB_WIDTHS[activeTabName].galleryMinWidth
// : 200
// }
// maxWidth={
// shouldPinGallery
// ? GALLERY_TAB_WIDTHS[activeTabName].galleryMaxWidth
// : undefined
// }
// minHeight={calcGalleryMinHeight()}
// >
// <ImageGalleryContent />
// </ResizableDrawer>
// );
// };
const renderImageGallery = () => {
if (['mobile', 'tablet'].includes(resolution)) return imageGalleryContent();
return resizableImageGalleryContent();
};
// const renderImageGallery = () => {
// if (['mobile', 'tablet'].includes(resolution)) return imageGalleryContent();
// return resizableImageGalleryContent();
// };
return renderImageGallery();
if (shouldPinGallery) {
return null;
}
return (
<ResizableDrawer
direction="right"
isResizable={true}
isOpen={shouldShowGallery}
onClose={handleCloseGallery}
minWidth={200}
>
<ImageGalleryContent />
</ResizableDrawer>
);
// return renderImageGallery();
};
export default memo(ImageGalleryPanel);
export default memo(GalleryDrawer);

View File

@ -3,7 +3,6 @@ import {
Box,
Center,
Flex,
Heading,
IconButton,
Link,
Text,
@ -19,8 +18,6 @@ import {
setCfgScale,
setHeight,
setImg2imgStrength,
// setInitialImage,
setMaskPath,
setPerlin,
setSampler,
setSeamless,
@ -31,21 +28,14 @@ import {
setThreshold,
setWidth,
} from 'features/parameters/store/generationSlice';
import {
setCodeformerFidelity,
setFacetoolStrength,
setFacetoolType,
setHiresFix,
setUpscalingDenoising,
setUpscalingLevel,
setUpscalingStrength,
} from 'features/parameters/store/postprocessingSlice';
import { setHiresFix } from 'features/parameters/store/postprocessingSlice';
import { setShouldShowImageDetails } from 'features/ui/store/uiSlice';
import { memo } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import { FaCopy } from 'react-icons/fa';
import { IoArrowUndoCircleOutline } from 'react-icons/io5';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
type MetadataItemProps = {
isLink?: boolean;
@ -300,7 +290,7 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
</Text>
</Center>
)}
<Flex gap={2} direction="column">
<Flex gap={2} direction="column" overflow="auto">
<Flex gap={2}>
<Tooltip label="Copy metadata JSON">
<IconButton
@ -314,22 +304,19 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
</Tooltip>
<Text fontWeight="semibold">Metadata JSON:</Text>
</Flex>
<Box
sx={{
mt: 0,
mr: 2,
mb: 4,
ml: 2,
padding: 4,
borderRadius: 'base',
overflowX: 'scroll',
wordBreak: 'break-all',
bg: 'whiteAlpha.500',
_dark: { bg: 'blackAlpha.500' },
}}
>
<pre>{metadataJSON}</pre>
</Box>
<OverlayScrollbarsComponent defer>
<Box
sx={{
padding: 4,
borderRadius: 'base',
bg: 'whiteAlpha.500',
_dark: { bg: 'blackAlpha.500' },
w: 'max-content',
}}
>
<pre>{metadataJSON}</pre>
</Box>
</OverlayScrollbarsComponent>
</Flex>
</Flex>
);

View File

@ -0,0 +1,7 @@
import { createAction } from '@reduxjs/toolkit';
import { Image } from 'app/types/invokeai';
import { SelectedImage } from 'features/parameters/store/actions';
export const requestedImageDeletion = createAction<
Image | SelectedImage | undefined
>('gallery/requestedImageDeletion');

View File

@ -4,12 +4,13 @@ import { GalleryState } from './gallerySlice';
* Gallery slice persist denylist
*/
const itemsToDenylist: (keyof GalleryState)[] = [
'categories',
'currentCategory',
'currentImage',
'currentImageUuid',
'shouldAutoSwitchToNewImages',
'intermediateImage',
];
export const galleryPersistDenylist: (keyof GalleryState)[] = [
'currentCategory',
'shouldAutoSwitchToNewImages',
];
export const galleryDenylist = itemsToDenylist.map(

View File

@ -1,23 +1,14 @@
import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import { lightboxSelector } from 'features/lightbox/store/lightboxSelectors';
import { configSelector } from 'features/system/store/configSelectors';
import { systemSelector } from 'features/system/store/systemSelectors';
import {
activeTabNameSelector,
uiSelector,
} from 'features/ui/store/uiSelectors';
import { isEqual } from 'lodash-es';
import {
selectResultsAll,
selectResultsById,
selectResultsEntities,
} from './resultsSlice';
import {
selectUploadsAll,
selectUploadsById,
selectUploadsEntities,
} from './uploadsSlice';
import { selectResultsById, selectResultsEntities } from './resultsSlice';
import { selectUploadsAll, selectUploadsById } from './uploadsSlice';
export const gallerySelector = (state: RootState) => state.gallery;
@ -44,6 +35,11 @@ export const imageGallerySelector = createSelector(
const { isLightboxOpen } = lightbox;
const images =
currentCategory === 'results'
? selectResultsEntities(state)
: selectUploadsAll(state);
return {
shouldPinGallery,
galleryImageMinimumWidth,
@ -53,7 +49,7 @@ export const imageGallerySelector = createSelector(
: `repeat(auto-fill, minmax(${galleryImageMinimumWidth}px, auto))`,
shouldAutoSwitchToNewImages,
currentCategory,
images: state[currentCategory].entities,
images,
galleryWidth,
shouldEnableResize:
isLightboxOpen ||

View File

@ -1,10 +1,6 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import { invocationComplete } from 'services/events/actions';
import { isImageOutput } from 'services/types/guards';
import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
import { imageUploaded } from 'services/thunks/image';
import { SelectedImage } from 'features/parameters/store/generationSlice';
import { Image } from 'app/types/invokeai';
type GalleryImageObjectFitType = 'contain' | 'cover';
@ -12,7 +8,7 @@ export interface GalleryState {
/**
* The selected image
*/
selectedImage?: SelectedImage;
selectedImage?: Image;
galleryImageMinimumWidth: number;
galleryImageObjectFit: GalleryImageObjectFitType;
shouldAutoSwitchToNewImages: boolean;
@ -21,8 +17,7 @@ export interface GalleryState {
currentCategory: 'results' | 'uploads';
}
const initialState: GalleryState = {
selectedImage: undefined,
export const initialGalleryState: GalleryState = {
galleryImageMinimumWidth: 64,
galleryImageObjectFit: 'cover',
shouldAutoSwitchToNewImages: true,
@ -33,12 +28,9 @@ const initialState: GalleryState = {
export const gallerySlice = createSlice({
name: 'gallery',
initialState,
initialState: initialGalleryState,
reducers: {
imageSelected: (
state,
action: PayloadAction<SelectedImage | undefined>
) => {
imageSelected: (state, action: PayloadAction<Image | undefined>) => {
state.selectedImage = action.payload;
// TODO: if the user selects an image, disable the auto switch?
// state.shouldAutoSwitchToNewImages = false;
@ -71,30 +63,6 @@ export const gallerySlice = createSlice({
state.shouldUseSingleGalleryColumn = action.payload;
},
},
extraReducers(builder) {
/**
* Invocation Complete
*/
builder.addCase(invocationComplete, (state, action) => {
const { data } = action.payload;
if (isImageOutput(data.result) && state.shouldAutoSwitchToNewImages) {
state.selectedImage = {
name: data.result.image.image_name,
type: 'results',
};
}
});
/**
* Upload Image - FULFILLED
*/
builder.addCase(imageUploaded.fulfilled, (state, action) => {
const { response } = action.payload;
const uploadedImage = deserializeImageResponse(response);
state.selectedImage = { name: uploadedImage.name, type: 'uploads' };
});
},
});
export const {

View File

@ -5,7 +5,9 @@ import { ResultsState } from './resultsSlice';
*
* Currently denylisting results slice entirely, see persist config in store.ts
*/
const itemsToDenylist: (keyof ResultsState)[] = ['isLoading'];
const itemsToDenylist: (keyof ResultsState)[] = [];
export const resultsPersistDenylist: (keyof ResultsState)[] = [];
export const resultsDenylist = itemsToDenylist.map(
(denylistItem) => `results.${denylistItem}`

View File

@ -1,17 +1,11 @@
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
import { Image } from 'app/types/invokeai';
import { invocationComplete } from 'services/events/actions';
import { RootState } from 'app/store/store';
import {
receivedResultImagesPage,
IMAGES_PER_PAGE,
} from 'services/thunks/gallery';
import { isImageOutput } from 'services/types/guards';
import {
buildImageUrls,
extractTimestampFromImageName,
} from 'services/util/deserializeImageField';
import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
import {
imageDeleted,
@ -73,44 +67,6 @@ const resultsSlice = createSlice({
state.isLoading = false;
});
/**
* Invocation Complete
*/
builder.addCase(invocationComplete, (state, action) => {
const { data, shouldFetchImages } = action.payload;
const { result, node, graph_execution_state_id } = data;
if (isImageOutput(result)) {
const name = result.image.image_name;
const type = result.image.image_type;
// if we need to refetch, set URLs to placeholder for now
const { url, thumbnail } = shouldFetchImages
? { url: '', thumbnail: '' }
: buildImageUrls(type, name);
const timestamp = extractTimestampFromImageName(name);
const image: Image = {
name,
type,
url,
thumbnail,
metadata: {
created: timestamp,
width: result.width,
height: result.height,
invokeai: {
session_id: graph_execution_state_id,
...(node ? { node } : {}),
},
},
};
resultsAdapter.setOne(state, image);
}
});
/**
* Image Received - FULFILLED
*/
@ -142,9 +98,10 @@ const resultsSlice = createSlice({
});
/**
* Delete Image - FULFILLED
* Delete Image - PENDING
* Pre-emptively remove the image from the gallery
*/
builder.addCase(imageDeleted.fulfilled, (state, action) => {
builder.addCase(imageDeleted.pending, (state, action) => {
const { imageType, imageName } = action.meta.arg;
if (imageType === 'results') {

View File

@ -5,7 +5,8 @@ import { UploadsState } from './uploadsSlice';
*
* Currently denylisting uploads slice entirely, see persist config in store.ts
*/
const itemsToDenylist: (keyof UploadsState)[] = ['isLoading'];
const itemsToDenylist: (keyof UploadsState)[] = [];
export const uploadsPersistDenylist: (keyof UploadsState)[] = [];
export const uploadsDenylist = itemsToDenylist.map(
(denylistItem) => `uploads.${denylistItem}`

View File

@ -6,7 +6,7 @@ import {
receivedUploadImagesPage,
IMAGES_PER_PAGE,
} from 'services/thunks/gallery';
import { imageDeleted, imageUploaded } from 'services/thunks/image';
import { imageDeleted } from 'services/thunks/image';
import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
export const uploadsAdapter = createEntityAdapter<Image>({
@ -21,7 +21,7 @@ type AdditionalUploadsState = {
nextPage: number;
};
const initialUploadsState =
export const initialUploadsState =
uploadsAdapter.getInitialState<AdditionalUploadsState>({
page: 0,
pages: 0,
@ -35,7 +35,7 @@ const uploadsSlice = createSlice({
name: 'uploads',
initialState: initialUploadsState,
reducers: {
uploadAdded: uploadsAdapter.addOne,
uploadAdded: uploadsAdapter.upsertOne,
},
extraReducers: (builder) => {
/**
@ -62,20 +62,10 @@ const uploadsSlice = createSlice({
});
/**
* Upload Image - FULFILLED
* Delete Image - pending
* Pre-emptively remove the image from the gallery
*/
builder.addCase(imageUploaded.fulfilled, (state, action) => {
const { location, response } = action.payload;
const uploadedImage = deserializeImageResponse(response);
uploadsAdapter.setOne(state, uploadedImage);
});
/**
* Delete Image - FULFILLED
*/
builder.addCase(imageDeleted.fulfilled, (state, action) => {
builder.addCase(imageDeleted.pending, (state, action) => {
const { imageType, imageName } = action.meta.arg;
if (imageType === 'uploads') {

View File

@ -4,7 +4,7 @@ import * as InvokeAI from 'app/types/invokeai';
import { useGetUrl } from 'common/util/getUrl';
type ReactPanZoomProps = {
image: InvokeAI._Image;
image: InvokeAI.Image;
styleClass?: string;
alt?: string;
ref?: React.Ref<HTMLImageElement>;

View File

@ -4,6 +4,9 @@ import { LightboxState } from './lightboxSlice';
* Lightbox slice persist denylist
*/
const itemsToDenylist: (keyof LightboxState)[] = ['isLightboxOpen'];
export const lightboxPersistDenylist: (keyof LightboxState)[] = [
'isLightboxOpen',
];
export const lightboxDenylist = itemsToDenylist.map(
(denylistItem) => `lightbox.${denylistItem}`

View File

@ -5,7 +5,7 @@ export interface LightboxState {
isLightboxOpen: boolean;
}
const initialLightboxState: LightboxState = {
export const initialLightboxState: LightboxState = {
isLightboxOpen: false,
};

View File

@ -1,5 +1,3 @@
import { v4 as uuidv4 } from 'uuid';
import 'reactflow/dist/style.css';
import { memo, useCallback } from 'react';
import {
@ -8,12 +6,11 @@ import {
MenuButton,
MenuList,
MenuItem,
IconButton,
} from '@chakra-ui/react';
import { FaEllipsisV, FaPlus } from 'react-icons/fa';
import { FaEllipsisV } from 'react-icons/fa';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { nodeAdded } from '../store/nodesSlice';
import { cloneDeep, map } from 'lodash-es';
import { map } from 'lodash-es';
import { RootState } from 'app/store/store';
import { useBuildInvocation } from '../hooks/useBuildInvocation';
import { addToast } from 'features/system/store/systemSlice';

View File

@ -1,12 +1,6 @@
import { Tooltip } from '@chakra-ui/react';
import { CSSProperties, memo, useMemo } from 'react';
import {
Handle,
Position,
Connection,
HandleType,
useReactFlow,
} from 'reactflow';
import { CSSProperties, memo } from 'react';
import { Handle, Position, Connection, HandleType } from 'reactflow';
import { FIELDS, HANDLE_TOOLTIP_OPEN_DELAY } from '../types/constants';
// import { useConnectionEventStyles } from '../hooks/useConnectionEventStyles';
import { InputFieldTemplate, OutputFieldTemplate } from '../types/types';
@ -26,9 +20,9 @@ const outputHandleStyles: CSSProperties = {
right: '-0.5rem',
};
const requiredConnectionStyles: CSSProperties = {
boxShadow: '0 0 0.5rem 0.5rem var(--invokeai-colors-error-400)',
};
// const requiredConnectionStyles: CSSProperties = {
// boxShadow: '0 0 0.5rem 0.5rem var(--invokeai-colors-error-400)',
// };
type FieldHandleProps = {
nodeId: string;
@ -39,8 +33,8 @@ type FieldHandleProps = {
};
const FieldHandle = (props: FieldHandleProps) => {
const { nodeId, field, isValidConnection, handleType, styles } = props;
const { name, title, type, description } = field;
const { field, isValidConnection, handleType, styles } = props;
const { name, type } = field;
return (
<Tooltip

View File

@ -23,7 +23,6 @@ import TopRightPanel from './panels/TopRightPanel';
import TopCenterPanel from './panels/TopCenterPanel';
import BottomLeftPanel from './panels/BottomLeftPanel.tsx';
import MinimapPanel from './panels/MinimapPanel';
import NodeSearch from './search/NodeSearch';
const nodeTypes = { invocation: InvocationComponent };
@ -78,8 +77,7 @@ export const Flow = () => {
style: { strokeWidth: 2 },
}}
>
<NodeSearch />
{/* <TopLeftPanel /> */}
<TopLeftPanel />
<TopCenterPanel />
<TopRightPanel />
<BottomLeftPanel />

View File

@ -1,6 +1,6 @@
import { Flex, Heading, Tooltip, Icon } from '@chakra-ui/react';
import { InvocationTemplate } from 'features/nodes/types/types';
import { memo, MutableRefObject } from 'react';
import { memo } from 'react';
import { FaInfoCircle } from 'react-icons/fa';
interface IAINodeHeaderProps {

View File

@ -10,6 +10,7 @@ import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComp
import ModelInputFieldComponent from './fields/ModelInputFieldComponent';
import NumberInputFieldComponent from './fields/NumberInputFieldComponent';
import StringInputFieldComponent from './fields/StringInputFieldComponent';
import ColorInputFieldComponent from './fields/ColorInputFieldComponent';
import ItemInputFieldComponent from './fields/ItemInputFieldComponent';
type InputFieldComponentProps = {
@ -21,7 +22,7 @@ type InputFieldComponentProps = {
// build an individual input element based on the schema
const InputFieldComponent = (props: InputFieldComponentProps) => {
const { nodeId, field, template } = props;
const { type, value } = field;
const { type } = field;
if (type === 'string' && template.type === 'string') {
return (
@ -126,6 +127,26 @@ const InputFieldComponent = (props: InputFieldComponentProps) => {
);
}
if (type === 'color' && template.type === 'color') {
return (
<ColorInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'item' && template.type === 'item') {
return (
<ItemInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
return <Box p={2}>Unknown field type: {type}</Box>;
};

View File

@ -2,7 +2,7 @@ import { Box } from '@chakra-ui/react';
import { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { memo } from 'react';
import { buildNodesGraph } from '../util/nodesGraphBuilder/buildNodesGraph';
import { buildNodesGraph } from '../util/graphBuilders/buildNodesGraph';
const NodeGraphOverlay = () => {
const state = useAppSelector((state: RootState) => state);

View File

@ -0,0 +1,31 @@
import {
ColorInputFieldTemplate,
ColorInputFieldValue,
} from 'features/nodes/types/types';
import { memo } from 'react';
import { FieldComponentProps } from './types';
import { RgbaColor, RgbaColorPicker } from 'react-colorful';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import { useAppDispatch } from 'app/store/storeHooks';
const ColorInputFieldComponent = (
props: FieldComponentProps<ColorInputFieldValue, ColorInputFieldTemplate>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const handleValueChanged = (value: RgbaColor) => {
dispatch(fieldValueChanged({ nodeId, fieldName: field.name, value }));
};
return (
<RgbaColorPicker
className="nodrag"
color={field.value}
onChange={handleValueChanged}
/>
);
};
export default memo(ColorInputFieldComponent);

View File

@ -1,16 +1,16 @@
import { Box, Image, Icon, Flex } from '@chakra-ui/react';
import { Box, Image } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import SelectImagePlaceholder from 'common/components/SelectImagePlaceholder';
import { useGetUrl } from 'common/util/getUrl';
import useGetImageByNameAndType from 'features/gallery/hooks/useGetImageByName';
import useGetImageByUuid from 'features/gallery/hooks/useGetImageByUuid';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import {
ImageInputFieldTemplate,
ImageInputFieldValue,
} from 'features/nodes/types/types';
import { DragEvent, memo, useCallback, useState } from 'react';
import { FaImage } from 'react-icons/fa';
import { ImageType } from 'services/api';
import { FieldComponentProps } from './types';
@ -18,7 +18,6 @@ const ImageInputFieldComponent = (
props: FieldComponentProps<ImageInputFieldValue, ImageInputFieldTemplate>
) => {
const { nodeId, field } = props;
const { value } = field;
const getImageByNameAndType = useGetImageByNameAndType();
const dispatch = useAppDispatch();

View File

@ -3,7 +3,7 @@ import {
ItemInputFieldValue,
} from 'features/nodes/types/types';
import { memo } from 'react';
import { FaAddressCard, FaList } from 'react-icons/fa';
import { FaAddressCard } from 'react-icons/fa';
import { FieldComponentProps } from './types';
const ItemInputFieldComponent = (

View File

@ -1,17 +1,13 @@
import { Select } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import {
ModelInputFieldTemplate,
ModelInputFieldValue,
} from 'features/nodes/types/types';
import {
selectModelsById,
selectModelsIds,
} from 'features/system/store/modelSlice';
import { isEqual, map } from 'lodash-es';
import { selectModelsIds } from 'features/system/store/modelSlice';
import { isEqual } from 'lodash-es';
import { ChangeEvent, memo } from 'react';
import { FieldComponentProps } from './types';
@ -48,7 +44,10 @@ const ModelInputFieldComponent = (
};
return (
<Select onChange={handleValueChanged} value={field.value}>
<Select
onChange={handleValueChanged}
value={field.value || allModelNames[0]}
>
{allModelNames.map((option) => (
<option key={option}>{option}</option>
))}

View File

@ -1,16 +1,16 @@
import { HStack } from '@chakra-ui/react';
import { userInvoked } from 'app/store/actions';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import { memo, useCallback } from 'react';
import { Panel } from 'reactflow';
import { receivedOpenAPISchema } from 'services/thunks/schema';
import { nodesGraphBuilt } from 'services/thunks/session';
const TopCenterPanel = () => {
const dispatch = useAppDispatch();
const handleInvoke = useCallback(() => {
dispatch(nodesGraphBuilt());
dispatch(userInvoked('nodes'));
}, [dispatch]);
const handleReloadSchema = useCallback(() => {

View File

@ -1,10 +1,10 @@
import { memo } from 'react';
import { Panel } from 'reactflow';
import AddNodeMenu from '../AddNodeMenu';
import NodeSearch from '../search/NodeSearch';
const TopLeftPanel = () => (
<Panel position="top-left">
<AddNodeMenu />
<NodeSearch />
</Panel>
);

View File

@ -2,7 +2,6 @@ import { Box, Flex } from '@chakra-ui/layout';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIInput from 'common/components/IAIInput';
import { Panel } from 'reactflow';
import { map } from 'lodash-es';
import {
ChangeEvent,
@ -192,19 +191,17 @@ const NodeSearch = () => {
};
return (
<Panel position="top-left">
<Flex
flexDirection="column"
tabIndex={1}
onKeyDown={searchKeyHandler}
onFocus={() => setShowNodeList(true)}
onBlur={searchInputBlurHandler}
ref={nodeSearchRef}
>
<IAIInput value={searchText} onChange={findNode} />
{showNodeList && renderNodeList()}
</Flex>
</Panel>
<Flex
flexDirection="column"
tabIndex={1}
onKeyDown={searchKeyHandler}
onFocus={() => setShowNodeList(true)}
onBlur={searchInputBlurHandler}
ref={nodeSearchRef}
>
<IAIInput value={searchText} onChange={findNode} />
{showNodeList && renderNodeList()}
</Flex>
);
};

View File

@ -0,0 +1,18 @@
import { createAction, isAnyOf } from '@reduxjs/toolkit';
import { Graph } from 'services/api';
export const textToImageGraphBuilt = createAction<Graph>(
'nodes/textToImageGraphBuilt'
);
export const imageToImageGraphBuilt = createAction<Graph>(
'nodes/imageToImageGraphBuilt'
);
export const canvasGraphBuilt = createAction<Graph>('nodes/canvasGraphBuilt');
export const nodesGraphBuilt = createAction<Graph>('nodes/nodesGraphBuilt');
export const isAnyGraphBuilt = isAnyOf(
textToImageGraphBuilt,
imageToImageGraphBuilt,
canvasGraphBuilt,
nodesGraphBuilt
);

View File

@ -4,6 +4,10 @@ import { NodesState } from './nodesSlice';
* Nodes slice persist denylist
*/
const itemsToDenylist: (keyof NodesState)[] = ['schema', 'invocationTemplates'];
export const nodesPersistDenylist: (keyof NodesState)[] = [
'schema',
'invocationTemplates',
];
export const nodesDenylist = itemsToDenylist.map(
(denylistItem) => `nodes.${denylistItem}`

View File

@ -11,13 +11,14 @@ import {
NodeChange,
OnConnectStartParams,
} from 'reactflow';
import { Graph, ImageField } from 'services/api';
import { ImageField } from 'services/api';
import { receivedOpenAPISchema } from 'services/thunks/schema';
import { isFulfilledAnyGraphBuilt } from 'services/thunks/session';
import { InvocationTemplate, InvocationValue } from '../types/types';
import { parseSchema } from '../util/parseSchema';
import { log } from 'app/logging/useLogger';
import { size } from 'lodash-es';
import { isAnyGraphBuilt } from './actions';
import { RgbaColor } from 'react-colorful';
export type NodesState = {
nodes: Node<InvocationValue>[];
@ -25,7 +26,6 @@ export type NodesState = {
schema: OpenAPIV3.Document | null;
invocationTemplates: Record<string, InvocationTemplate>;
connectionStartParams: OnConnectStartParams | null;
lastGraph: Graph | null;
shouldShowGraphOverlay: boolean;
};
@ -35,7 +35,6 @@ export const initialNodesState: NodesState = {
schema: null,
invocationTemplates: {},
connectionStartParams: null,
lastGraph: null,
shouldShowGraphOverlay: false,
};
@ -71,6 +70,7 @@ const nodesSlice = createSlice({
| number
| boolean
| Pick<ImageField, 'image_name' | 'image_type'>
| RgbaColor
| undefined;
}>
) => {
@ -104,8 +104,9 @@ const nodesSlice = createSlice({
state.schema = action.payload;
});
builder.addMatcher(isFulfilledAnyGraphBuilt, (state, action) => {
state.lastGraph = action.payload;
builder.addMatcher(isAnyGraphBuilt, (state, action) => {
// TODO: Achtung! Side effect in a reducer!
log.info({ namespace: 'nodes', data: action.payload }, 'Graph built');
});
},
});

View File

@ -1,4 +1,3 @@
import { getCSSVar } from '@chakra-ui/utils';
import { FieldType, FieldUIConfig } from './types';
export const HANDLE_TOOLTIP_OPEN_DELAY = 500;
@ -15,6 +14,7 @@ export const FIELD_TYPE_MAP: Record<string, FieldType> = {
model: 'model',
array: 'array',
item: 'item',
ColorField: 'color',
};
const COLOR_TOKEN_VALUE = 500;
@ -89,4 +89,10 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
title: 'Collection Item',
description: 'TODO: Collection Item type description.',
},
color: {
color: 'gray',
colorCssVar: getColorTokenCssVariable('gray'),
title: 'Color',
description: 'A RGBA color.',
},
};

View File

@ -1,4 +1,5 @@
import { OpenAPIV3 } from 'openapi-types';
import { RgbaColor } from 'react-colorful';
import { ImageField } from 'services/api';
import { AnyInvocationType } from 'services/events/types';
@ -59,7 +60,8 @@ export type FieldType =
| 'conditioning'
| 'model'
| 'array'
| 'item';
| 'item'
| 'color';
/**
* An input field is persisted across reloads as part of the user's local state.
@ -80,7 +82,8 @@ export type InputFieldValue =
| EnumInputFieldValue
| ModelInputFieldValue
| ArrayInputFieldValue
| ItemInputFieldValue;
| ItemInputFieldValue
| ColorInputFieldValue;
/**
* An input field template is generated on each page load from the OpenAPI schema.
@ -99,7 +102,8 @@ export type InputFieldTemplate =
| EnumInputFieldTemplate
| ModelInputFieldTemplate
| ArrayInputFieldTemplate
| ItemInputFieldTemplate;
| ItemInputFieldTemplate
| ColorInputFieldTemplate;
/**
* An output field is persisted across as part of the user's local state.
@ -193,6 +197,11 @@ export type ItemInputFieldValue = FieldValueBase & {
value?: undefined;
};
export type ColorInputFieldValue = FieldValueBase & {
type: 'color';
value?: RgbaColor;
};
export type InputFieldTemplateBase = {
name: string;
title: string;
@ -241,7 +250,7 @@ export type ImageInputFieldTemplate = InputFieldTemplateBase & {
};
export type LatentsInputFieldTemplate = InputFieldTemplateBase & {
default: undefined;
default: string;
type: 'latents';
};
@ -272,6 +281,11 @@ export type ItemInputFieldTemplate = InputFieldTemplateBase & {
type: 'item';
};
export type ColorInputFieldTemplate = InputFieldTemplateBase & {
default: RgbaColor;
type: 'color';
};
/**
* JANKY CUSTOMISATION OF OpenAPI SCHEMA TYPES
*/

Some files were not shown because too many files have changed in this diff Show More