merge with main

Lincoln Stein 2023-05-13 21:35:19 -04:00
commit 1103ab2844
326 changed files with 6753 additions and 3928 deletions

View File

@@ -2,8 +2,7 @@ name: mkdocs-material
 on:
   push:
     branches:
-      - 'main'
-      - 'development'
+      - 'refs/heads/v2.3'
 
 permissions:
   contents: write
@@ -12,6 +11,10 @@ jobs:
   mkdocs-material:
     if: github.event.pull_request.draft == false
     runs-on: ubuntu-latest
+    env:
+      REPO_URL: '${{ github.server_url }}/${{ github.repository }}'
+      REPO_NAME: '${{ github.repository }}'
+      SITE_URL: 'https://${{ github.repository_owner }}.github.io/InvokeAI'
     steps:
       - name: checkout sources
         uses: actions/checkout@v3
@@ -22,11 +25,15 @@ jobs:
         uses: actions/setup-python@v4
         with:
           python-version: '3.10'
+          cache: pip
+          cache-dependency-path: pyproject.toml
       - name: install requirements
+        env:
+          PIP_USE_PEP517: 1
         run: |
           python -m \
-            pip install -r docs/requirements-mkdocs.txt
+            pip install ".[docs]"
       - name: confirm buildability
         run: |

View File

@@ -247,8 +247,8 @@ class InvokeAiInstance:
         pip[
             "install",
             "--require-virtualenv",
-            "torch",
-            "torchvision",
+            "torch~=2.0.0",
+            "torchvision>=0.14.1",
             "--force-reinstall",
             "--find-links" if find_links is not None else None,
             find_links,

View File

@@ -83,7 +83,7 @@ async def get_thumbnail(
     status_code=201,
 )
 async def upload_image(
-    file: UploadFile, request: Request, response: Response
+    file: UploadFile, image_type: ImageType, request: Request, response: Response
 ) -> ImageResponse:
     if not file.content_type.startswith("image"):
         raise HTTPException(status_code=415, detail="Not an image")
@@ -99,21 +99,21 @@ async def upload_image(
     filename = f"{uuid.uuid4()}_{str(int(datetime.now(timezone.utc).timestamp()))}.png"
 
     saved_image = ApiDependencies.invoker.services.images.save(
-        ImageType.UPLOAD, filename, img
+        image_type, filename, img
     )
 
     invokeai_metadata = ApiDependencies.invoker.services.metadata.get_metadata(img)
 
     image_url = ApiDependencies.invoker.services.images.get_uri(
-        ImageType.UPLOAD, saved_image.image_name
+        image_type, saved_image.image_name
     )
 
     thumbnail_url = ApiDependencies.invoker.services.images.get_uri(
-        ImageType.UPLOAD, saved_image.image_name, True
+        image_type, saved_image.image_name, True
     )
 
     res = ImageResponse(
-        image_type=ImageType.UPLOAD,
+        image_type=image_type,
         image_name=saved_image.image_name,
         image_url=image_url,
         thumbnail_url=thumbnail_url,

View File

@@ -122,7 +122,6 @@ app.openapi = custom_openapi
 # Override API doc favicons
 app.mount("/static", StaticFiles(directory="static/dream_web"), name="static")
 
 @app.get("/docs", include_in_schema=False)
 def overridden_swagger():
     return get_swagger_ui_html(
@@ -140,6 +139,8 @@ def overridden_redoc():
         redoc_favicon_url="/static/favicon.ico",
     )
 
+# Must mount *after* the other routes else it borks em
+app.mount("/", StaticFiles(directory="invokeai/frontend/web/dist", html=True), name="ui")
+
 def invoke_api():
     global web_config

View File

@@ -3,12 +3,12 @@
 from typing import Literal, Optional
 
 import numpy as np
-import numpy.random
 from pydantic import Field
 
+from invokeai.app.util.misc import SEED_MAX, get_random_seed
+
 from .baseinvocation import (
     BaseInvocation,
-    InvocationConfig,
     InvocationContext,
     BaseInvocationOutput,
 )
@@ -50,11 +50,11 @@ class RandomRangeInvocation(BaseInvocation):
         default=np.iinfo(np.int32).max, description="The exclusive high value"
     )
     size: int = Field(default=1, description="The number of values to generate")
-    seed: Optional[int] = Field(
+    seed: int = Field(
         ge=0,
-        le=np.iinfo(np.int32).max,
-        description="The seed for the RNG",
-        default_factory=lambda: numpy.random.randint(0, np.iinfo(np.int32).max),
+        le=SEED_MAX,
+        description="The seed for the RNG (omit for random)",
+        default_factory=get_random_seed,
     )
 
     def invoke(self, context: InvocationContext) -> IntCollectionOutput:

View File

@@ -1,15 +1,17 @@
 # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
 
 from functools import partial
-from typing import Literal, Optional, Union
+from typing import Literal, Optional, Union, get_args
 
 import numpy as np
 from torch import Tensor
 from pydantic import BaseModel, Field
 
-from invokeai.app.models.image import ImageField, ImageType
+from invokeai.app.models.image import ColorField, ImageField, ImageType
 from invokeai.app.invocations.util.choose_model import choose_model
+from invokeai.app.util.misc import SEED_MAX, get_random_seed
+from invokeai.backend.generator.inpaint import infill_methods
 
 from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
 from .image import ImageOutput, build_image_output
 from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
@@ -17,7 +19,8 @@ from ...backend.stable_diffusion import PipelineIntermediateState
 from ..util.step_callback import stable_diffusion_step_callback
 
 SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
+INFILL_METHODS = Literal[tuple(infill_methods())]
+DEFAULT_INFILL_METHOD = 'patchmatch' if 'patchmatch' in get_args(INFILL_METHODS) else 'tile'
 
 class SDImageInvocation(BaseModel):
     """Helper class to provide all Stable Diffusion raster image invocations with additional config"""
@@ -44,15 +47,13 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
     # TODO: consider making prompt optional to enable providing prompt through a link
     # fmt: off
     prompt: Optional[str] = Field(description="The prompt to generate an image from")
-    seed: int = Field(default=-1, ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)", )
-    steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
+    seed: int = Field(ge=0, le=SEED_MAX, description="The seed to use (omit for random)", default_factory=get_random_seed)
+    steps: int = Field(default=30, gt=0, description="The number of steps to use to generate the image")
     width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting image", )
     height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting image", )
-    cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
-    scheduler: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The scheduler to use" )
-    seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
+    cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
+    scheduler: SAMPLER_NAME_VALUES = Field(default="lms", description="The scheduler to use" )
     model: str = Field(default="", description="The model to use (currently ignored)")
-    progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
     # fmt: on
 
     # TODO: pass this an emitter method or something? or a session for dispatching?
@@ -148,7 +149,6 @@ class ImageToImageInvocation(TextToImageInvocation):
                 self.image.image_type, self.image.image_name
             )
         )
-        mask = None
 
         if self.fit:
             image = image.resize((self.width, self.height))
@@ -165,7 +165,6 @@ class ImageToImageInvocation(TextToImageInvocation):
         outputs = Img2Img(model).generate(
             prompt=self.prompt,
             init_image=image,
-            init_mask=mask,
             step_callback=partial(self.dispatch_progress, context, source_node_id),
             **self.dict(
                 exclude={"prompt", "image", "mask"}
@@ -197,7 +196,6 @@ class ImageToImageInvocation(TextToImageInvocation):
             image=result_image,
         )
 
 
 class InpaintInvocation(ImageToImageInvocation):
     """Generates an image using inpaint."""
@@ -205,6 +203,17 @@ class InpaintInvocation(ImageToImageInvocation):
     # Inputs
     mask: Union[ImageField, None] = Field(description="The mask")
+    seam_size: int = Field(default=96, ge=1, description="The seam inpaint size (px)")
+    seam_blur: int = Field(default=16, ge=0, description="The seam inpaint blur radius (px)")
+    seam_strength: float = Field(
+        default=0.75, gt=0, le=1, description="The seam inpaint strength"
+    )
+    seam_steps: int = Field(default=30, ge=1, description="The number of steps to use for seam inpaint")
+    tile_size: int = Field(default=32, ge=1, description="The tile infill method size (px)")
+    infill_method: INFILL_METHODS = Field(default=DEFAULT_INFILL_METHOD, description="The method used to infill empty regions (px)")
+    inpaint_width: Optional[int] = Field(default=None, multiple_of=8, gt=0, description="The width of the inpaint region (px)")
+    inpaint_height: Optional[int] = Field(default=None, multiple_of=8, gt=0, description="The height of the inpaint region (px)")
+    inpaint_fill: Optional[ColorField] = Field(default=ColorField(r=127, g=127, b=127, a=255), description="The solid infill method color")
     inpaint_replace: float = Field(
         default=0.0,
         ge=0.0,

View File

@@ -1,5 +1,6 @@
 # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
 
+import io
 from typing import Literal, Optional
 
 import numpy
@@ -32,14 +33,12 @@ class ImageOutput(BaseInvocationOutput):
     # fmt: off
     type: Literal["image"] = "image"
     image: ImageField = Field(default=None, description="The output image")
-    width: Optional[int] = Field(default=None, description="The width of the image in pixels")
-    height: Optional[int] = Field(default=None, description="The height of the image in pixels")
+    width: int = Field(description="The width of the image in pixels")
+    height: int = Field(description="The height of the image in pixels")
     # fmt: on
 
     class Config:
-        schema_extra = {
-            "required": ["type", "image", "width", "height", "mode"]
-        }
+        schema_extra = {"required": ["type", "image", "width", "height"]}
 
 
 def build_image_output(
@@ -54,7 +53,6 @@ def build_image_output(
         image=image_field,
         width=image.width,
         height=image.height,
-        mode=image.mode,
     )

View File

@@ -0,0 +1,233 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

from typing import Literal, Optional, Union, get_args

import numpy as np
import math
from PIL import Image, ImageOps
from pydantic import Field

from invokeai.app.invocations.image import ImageOutput, build_image_output
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.backend.image_util.patchmatch import PatchMatch
from ..models.image import ColorField, ImageField, ImageType
from .baseinvocation import (
    BaseInvocation,
    InvocationContext,
)


def infill_methods() -> list[str]:
    methods = [
        "tile",
        "solid",
    ]
    if PatchMatch.patchmatch_available():
        methods.insert(0, "patchmatch")
    return methods


INFILL_METHODS = Literal[tuple(infill_methods())]
DEFAULT_INFILL_METHOD = (
    "patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"
)


def infill_patchmatch(im: Image.Image) -> Image.Image:
    if im.mode != "RGBA":
        return im

    # Skip patchmatch if patchmatch isn't available
    if not PatchMatch.patchmatch_available():
        return im

    # Patchmatch (note, we may want to expose patch_size? Increasing it significantly impacts performance though)
    im_patched_np = PatchMatch.inpaint(
        im.convert("RGB"), ImageOps.invert(im.split()[-1]), patch_size=3
    )
    im_patched = Image.fromarray(im_patched_np, mode="RGB")
    return im_patched


def get_tile_images(image: np.ndarray, width=8, height=8):
    _nrows, _ncols, depth = image.shape
    _strides = image.strides

    nrows, _m = divmod(_nrows, height)
    ncols, _n = divmod(_ncols, width)
    if _m != 0 or _n != 0:
        return None

    return np.lib.stride_tricks.as_strided(
        np.ravel(image),
        shape=(nrows, ncols, height, width, depth),
        strides=(height * _strides[0], width * _strides[1], *_strides),
        writeable=False,
    )


def tile_fill_missing(
    im: Image.Image, tile_size: int = 16, seed: Union[int, None] = None
) -> Image.Image:
    # Only fill if there's an alpha layer
    if im.mode != "RGBA":
        return im

    a = np.asarray(im, dtype=np.uint8)

    tile_size_tuple = (tile_size, tile_size)

    # Get the image as tiles of a specified size
    tiles = get_tile_images(a, *tile_size_tuple).copy()

    # Get the mask as tiles
    tiles_mask = tiles[:, :, :, :, 3]

    # Find any mask tiles with any fully transparent pixels (we will be replacing these later)
    tmask_shape = tiles_mask.shape
    tiles_mask = tiles_mask.reshape(math.prod(tiles_mask.shape))
    n, ny = (math.prod(tmask_shape[0:2])), math.prod(tmask_shape[2:])
    tiles_mask = tiles_mask > 0
    tiles_mask = tiles_mask.reshape((n, ny)).all(axis=1)

    # Get RGB tiles in single array and filter by the mask
    tshape = tiles.shape
    tiles_all = tiles.reshape((math.prod(tiles.shape[0:2]), *tiles.shape[2:]))
    filtered_tiles = tiles_all[tiles_mask]

    if len(filtered_tiles) == 0:
        return im

    # Find all invalid tiles and replace with a random valid tile
    replace_count = (tiles_mask == False).sum()
    rng = np.random.default_rng(seed=seed)
    tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[
        rng.choice(filtered_tiles.shape[0], replace_count), :, :, :
    ]

    # Convert back to an image
    tiles_all = tiles_all.reshape(tshape)
    tiles_all = tiles_all.swapaxes(1, 2)
    st = tiles_all.reshape(
        (
            math.prod(tiles_all.shape[0:2]),
            math.prod(tiles_all.shape[2:4]),
            tiles_all.shape[4],
        )
    )
    si = Image.fromarray(st, mode="RGBA")

    return si


class InfillColorInvocation(BaseInvocation):
    """Infills transparent areas of an image with a solid color"""

    type: Literal["infill_rgba"] = "infill_rgba"

    image: Optional[ImageField] = Field(default=None, description="The image to infill")
    color: Optional[ColorField] = Field(
        default=ColorField(r=127, g=127, b=127, a=255),
        description="The color to use to infill",
    )

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.services.images.get(
            self.image.image_type, self.image.image_name
        )

        solid_bg = Image.new("RGBA", image.size, self.color.tuple())
        infilled = Image.alpha_composite(solid_bg, image)

        infilled.paste(image, (0, 0), image.split()[-1])

        image_type = ImageType.RESULT
        image_name = context.services.images.create_name(
            context.graph_execution_state_id, self.id
        )

        metadata = context.services.metadata.build_metadata(
            session_id=context.graph_execution_state_id, node=self
        )

        context.services.images.save(image_type, image_name, infilled, metadata)
        return build_image_output(
            image_type=image_type,
            image_name=image_name,
            image=image,
        )


class InfillTileInvocation(BaseInvocation):
    """Infills transparent areas of an image with tiles of the image"""

    type: Literal["infill_tile"] = "infill_tile"

    image: Optional[ImageField] = Field(default=None, description="The image to infill")
    tile_size: int = Field(default=32, ge=1, description="The tile size (px)")
    seed: int = Field(
        ge=0,
        le=SEED_MAX,
        description="The seed to use for tile generation (omit for random)",
        default_factory=get_random_seed,
    )

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.services.images.get(
            self.image.image_type, self.image.image_name
        )

        infilled = tile_fill_missing(
            image.copy(), seed=self.seed, tile_size=self.tile_size
        )
        infilled.paste(image, (0, 0), image.split()[-1])

        image_type = ImageType.RESULT
        image_name = context.services.images.create_name(
            context.graph_execution_state_id, self.id
        )

        metadata = context.services.metadata.build_metadata(
            session_id=context.graph_execution_state_id, node=self
        )

        context.services.images.save(image_type, image_name, infilled, metadata)
        return build_image_output(
            image_type=image_type,
            image_name=image_name,
            image=image,
        )


class InfillPatchMatchInvocation(BaseInvocation):
    """Infills transparent areas of an image using the PatchMatch algorithm"""

    type: Literal["infill_patchmatch"] = "infill_patchmatch"

    image: Optional[ImageField] = Field(default=None, description="The image to infill")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.services.images.get(
            self.image.image_type, self.image.image_name
        )

        if PatchMatch.patchmatch_available():
            infilled = infill_patchmatch(image.copy())
        else:
            raise ValueError("PatchMatch is not available on this system")

        image_type = ImageType.RESULT
        image_name = context.services.images.create_name(
            context.graph_execution_state_id, self.id
        )

        metadata = context.services.metadata.build_metadata(
            session_id=context.graph_execution_state_id, node=self
        )

        context.services.images.save(image_type, image_name, infilled, metadata)
        return build_image_output(
            image_type=image_type,
            image_name=image_name,
            image=image,
        )
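A quick illustration of the strided tile view that tile_fill_missing() builds on, using get_tile_images() from this new file; the image size and tile size are made up for the example:

# Illustrative sketch only; shapes are hypothetical.
import numpy as np

rgba = np.zeros((64, 64, 4), dtype=np.uint8)        # a 64x64 RGBA image
tiles = get_tile_images(rgba, width=32, height=32)  # view the array as a grid of tiles
print(tiles.shape)                                  # (2, 2, 32, 32, 4)
# tile_fill_missing() flattens this grid, keeps only the fully opaque tiles,
# and pastes randomly chosen opaque tiles over the transparent ones.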

View File

@@ -1,11 +1,13 @@
 # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
 
 import random
-from typing import Literal, Optional
+from typing import Literal, Optional, Union
 
+import einops
 from pydantic import BaseModel, Field
 import torch
 
 from invokeai.app.invocations.util.choose_model import choose_model
+from invokeai.app.util.misc import SEED_MAX, get_random_seed
 from invokeai.app.util.step_callback import stable_diffusion_step_callback
@@ -13,7 +15,9 @@ from ...backend.model_management.model_manager import ModelManager
 from ...backend.util.devices import choose_torch_device, torch_dtype
 from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
 from ...backend.image_util.seamless import configure_model_padding
-from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline
+from ...backend.prompting.conditioning import get_uc_and_c_and_ec
+from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline, image_resized_to_grid_as_tensor
+from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
 from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
 import numpy as np
 from ..services.image_storage import ImageType
@@ -37,41 +41,55 @@ class LatentsField(BaseModel):
 class LatentsOutput(BaseInvocationOutput):
     """Base class for invocations that output latents"""
     #fmt: off
-    type: Literal["latent_output"] = "latent_output"
+    type: Literal["latents_output"] = "latents_output"
+
+    # Inputs
     latents: LatentsField = Field(default=None, description="The output latents")
+    width: int = Field(description="The width of the latents in pixels")
+    height: int = Field(description="The height of the latents in pixels")
     #fmt: on
 
+
+def build_latents_output(latents_name: str, latents: torch.Tensor):
+    return LatentsOutput(
+        latents=LatentsField(latents_name=latents_name),
+        width=latents.size()[3] * 8,
+        height=latents.size()[2] * 8,
+    )
+
 
 class NoiseOutput(BaseInvocationOutput):
     """Invocation noise output"""
     #fmt: off
     type: Literal["noise_output"] = "noise_output"
+
+    # Inputs
     noise: LatentsField = Field(default=None, description="The output noise")
+    width: int = Field(description="The width of the noise in pixels")
+    height: int = Field(description="The height of the noise in pixels")
     #fmt: on
 
-# TODO: this seems like a hack
-scheduler_map = dict(
-    ddim=diffusers.DDIMScheduler,
-    dpmpp_2=diffusers.DPMSolverMultistepScheduler,
-    k_dpm_2=diffusers.KDPM2DiscreteScheduler,
-    k_dpm_2_a=diffusers.KDPM2AncestralDiscreteScheduler,
-    k_dpmpp_2=diffusers.DPMSolverMultistepScheduler,
-    k_euler=diffusers.EulerDiscreteScheduler,
-    k_euler_a=diffusers.EulerAncestralDiscreteScheduler,
-    k_heun=diffusers.HeunDiscreteScheduler,
-    k_lms=diffusers.LMSDiscreteScheduler,
-    plms=diffusers.PNDMScheduler,
-)
+
+def build_noise_output(latents_name: str, latents: torch.Tensor):
+    return NoiseOutput(
+        noise=LatentsField(latents_name=latents_name),
+        width=latents.size()[3] * 8,
+        height=latents.size()[2] * 8,
+    )
+
 
 SAMPLER_NAME_VALUES = Literal[
-    tuple(list(scheduler_map.keys()))
+    tuple(list(SCHEDULER_MAP.keys()))
 ]
 
 def get_scheduler(scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
-    scheduler_class = scheduler_map.get(scheduler_name,'ddim')
-    scheduler = scheduler_class.from_config(model.scheduler.config)
+    scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])
+
+    scheduler_config = model.scheduler.config
+    if "_backup" in scheduler_config:
+        scheduler_config = scheduler_config["_backup"]
+    scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config}
+    scheduler = scheduler_class.from_config(scheduler_config)
+
     # hack copied over from generate.py
     if not hasattr(scheduler, 'uses_inpainting_model'):
         scheduler.uses_inpainting_model = lambda: False
@@ -102,17 +120,13 @@ def get_noise(width:int, height:int, device:torch.device, seed:int = 0, latent_c
     return x
 
 
-def random_seed():
-    return random.randint(0, np.iinfo(np.uint32).max)
-
 
 class NoiseInvocation(BaseInvocation):
     """Generates latent noise."""
 
     type: Literal["noise"] = "noise"
 
     # Inputs
-    seed: int = Field(ge=0, le=np.iinfo(np.uint32).max, description="The seed to use", default_factory=random_seed)
+    seed: int = Field(ge=0, le=SEED_MAX, description="The seed to use", default_factory=get_random_seed)
     width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting noise", )
     height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting noise", )
@@ -131,9 +145,7 @@ class NoiseInvocation(BaseInvocation):
         name = f'{context.graph_execution_state_id}__{self.id}'
         context.services.latents.set(name, noise)
-        return NoiseOutput(
-            noise=LatentsField(latents_name=name)
-        )
+        return build_noise_output(latents_name=name, latents=noise)
 
 
 # Text to image
@@ -149,11 +161,10 @@ class TextToLatentsInvocation(BaseInvocation):
     noise: Optional[LatentsField] = Field(description="The noise to use")
     steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
     cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
-    scheduler: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The scheduler to use" )
+    scheduler: SAMPLER_NAME_VALUES = Field(default="lms", description="The scheduler to use" )
+    model: str = Field(default="", description="The model to use (currently ignored)")
     seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
     seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
-    model: str = Field(default="", description="The model to use (currently ignored)")
-    progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
     # fmt: on
 
     # Schema customisation
@@ -218,7 +229,7 @@ class TextToLatentsInvocation(BaseInvocation):
                 h_symmetry_time_pct=None,#h_symmetry_time_pct,
                 v_symmetry_time_pct=None#v_symmetry_time_pct,
             ),
-        ).add_scheduler_args_if_applicable(model.scheduler, eta=None)#ddim_eta)
+        ).add_scheduler_args_if_applicable(model.scheduler, eta=0.0)#ddim_eta)
 
         return conditioning_data
@@ -250,9 +261,7 @@ class TextToLatentsInvocation(BaseInvocation):
         name = f'{context.graph_execution_state_id}__{self.id}'
         context.services.latents.set(name, result_latents)
-        return LatentsOutput(
-            latents=LatentsField(latents_name=name)
-        )
+        return build_latents_output(latents_name=name, latents=result_latents)
 
 
 class LatentsToLatentsInvocation(TextToLatentsInvocation):
@@ -260,6 +269,10 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
     type: Literal["l2l"] = "l2l"
 
+    # Inputs
+    latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
+    strength: float = Field(default=0.5, description="The strength of the latents to use")
+
     # Schema customisation
     class Config(InvocationConfig):
         schema_extra = {
@@ -271,10 +284,6 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
             },
         }
 
-    # Inputs
-    latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
-    strength: float = Field(default=0.5, description="The strength of the latents to use")
-
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         noise = context.services.latents.get(self.noise.latents_name)
         latent = context.services.latents.get(self.latents.latents_name)
@@ -287,7 +296,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
             self.dispatch_progress(context, source_node_id, state)
 
         model = self.get_model(context.services.model_manager)
-        conditioning_data = self.get_conditioning_data(model)
+        conditioning_data = self.get_conditioning_data(context, model)
 
         # TODO: Verify the noise is the right size
@@ -295,11 +304,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
             latent, device=model.device, dtype=latent.dtype
         )
 
-        timesteps, _ = model.get_img2img_timesteps(
-            self.steps,
-            self.strength,
-            device=model.device,
-        )
+        timesteps, _ = model.get_img2img_timesteps(self.steps, self.strength)
 
         result_latents, result_attention_map_saver = model.latents_from_embeddings(
             latents=initial_latents,
@@ -315,9 +320,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
         name = f'{context.graph_execution_state_id}__{self.id}'
         context.services.latents.set(name, result_latents)
-        return LatentsOutput(
-            latents=LatentsField(latents_name=name)
-        )
+        return build_latents_output(latents_name=name, latents=result_latents)
 
 
 # Latent to image
@@ -384,8 +387,8 @@ class ResizeLatentsInvocation(BaseInvocation):
     latents: Optional[LatentsField] = Field(description="The latents to resize")
     width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
     height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
-    mode: Optional[LATENTS_INTERPOLATION_MODE] = Field(default="bilinear", description="The interpolation mode")
-    antialias: Optional[bool] = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
+    mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
+    antialias: bool = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
 
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         latents = context.services.latents.get(self.latents.latents_name)
@@ -402,7 +405,7 @@ class ResizeLatentsInvocation(BaseInvocation):
         name = f"{context.graph_execution_state_id}__{self.id}"
         context.services.latents.set(name, resized_latents)
-        return LatentsOutput(latents=LatentsField(latents_name=name))
+        return build_latents_output(latents_name=name, latents=resized_latents)
 
 
 class ScaleLatentsInvocation(BaseInvocation):
@@ -413,8 +416,8 @@ class ScaleLatentsInvocation(BaseInvocation):
     # Inputs
     latents: Optional[LatentsField] = Field(description="The latents to scale")
     scale_factor: float = Field(gt=0, description="The factor by which to scale the latents")
-    mode: Optional[LATENTS_INTERPOLATION_MODE] = Field(default="bilinear", description="The interpolation mode")
-    antialias: Optional[bool] = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
+    mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
+    antialias: bool = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
 
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         latents = context.services.latents.get(self.latents.latents_name)
@@ -432,4 +435,48 @@ class ScaleLatentsInvocation(BaseInvocation):
         name = f"{context.graph_execution_state_id}__{self.id}"
         context.services.latents.set(name, resized_latents)
-        return LatentsOutput(latents=LatentsField(latents_name=name))
+        return build_latents_output(latents_name=name, latents=resized_latents)
+
+
+class ImageToLatentsInvocation(BaseInvocation):
+    """Encodes an image into latents."""
+
+    type: Literal["i2l"] = "i2l"
+
+    # Inputs
+    image: Union[ImageField, None] = Field(description="The image to encode")
+    model: str = Field(default="", description="The model to use")
+
+    # Schema customisation
+    class Config(InvocationConfig):
+        schema_extra = {
+            "ui": {
+                "tags": ["latents", "image"],
+                "type_hints": {"model": "model"},
+            },
+        }
+
+    @torch.no_grad()
+    def invoke(self, context: InvocationContext) -> LatentsOutput:
+        image = context.services.images.get(
+            self.image.image_type, self.image.image_name
+        )
+
+        # TODO: this only really needs the vae
+        model_info = choose_model(context.services.model_manager, self.model)
+        model: StableDiffusionGeneratorPipeline = model_info["model"]
+
+        image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
+        if image_tensor.dim() == 3:
+            image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
+
+        latents = model.non_noised_latents_from_image(
+            image_tensor,
+            device=model._model_group.device_for(model.unet),
+            dtype=model.unet.dtype,
+        )
+
+        name = f"{context.graph_execution_state_id}__{self.id}"
+        context.services.latents.set(name, latents)
+        return build_latents_output(latents_name=name, latents=latents)
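Note that the new build_latents_output()/build_noise_output() helpers report pixel dimensions derived from the latent tensor, i.e. the latent spatial size times the VAE downsampling factor of 8. A minimal sketch using the helper defined above, with a made-up NCHW tensor:

# Illustrative only; the latent shape below is hypothetical.
import torch

latents = torch.zeros(1, 4, 64, 96)
out = build_latents_output(latents_name="demo", latents=latents)
print(out.width, out.height)   # 768 512, i.e. latent width/height * 8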

View File

@@ -3,6 +3,7 @@
 from typing import Literal
 
 from pydantic import BaseModel, Field
+import numpy as np
 
 from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
 
@@ -73,3 +74,12 @@ class DivideInvocation(BaseInvocation, MathInvocationConfig):
 
     def invoke(self, context: InvocationContext) -> IntOutput:
         return IntOutput(a=int(self.a / self.b))
+
+
+class RandomIntInvocation(BaseInvocation):
+    """Outputs a single random integer."""
+    #fmt: off
+    type: Literal["rand_int"] = "rand_int"
+    #fmt: on
+    def invoke(self, context: InvocationContext) -> IntOutput:
+        return IntOutput(a=np.random.randint(0, np.iinfo(np.int32).max))

View File

@@ -4,10 +4,11 @@ from invokeai.backend.model_management.model_manager import ModelManager
 
 def choose_model(model_manager: ModelManager, model_name: str):
     """Returns the default model if the `model_name` not a valid model, else returns the selected model."""
     logger = model_manager.logger
-    if model_manager.valid_model(model_name):
-        model = model_manager.get_model(model_name)
-    else:
-        model = model_manager.get_model()
-        logger.warning(f"{model_name}' is not a valid model name. Using default model \'{model['model_name']}\' instead.")
+    if model_name and not model_manager.valid_model(model_name):
+        default_model_name = model_manager.default_model()
+        logger.warning(f"\'{model_name}\' is not a valid model name. Using default model \'{default_model_name}\' instead.")
+        model = model_manager.get_model()
+    else:
+        model = model_manager.get_model(model_name)
 
     return model

View File

@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Optional
+from typing import Optional, Tuple
 
 from pydantic import BaseModel, Field
 
@@ -27,3 +27,13 @@ class ImageField(BaseModel):
 
     class Config:
         schema_extra = {"required": ["image_type", "image_name"]}
+
+
+class ColorField(BaseModel):
+    r: int = Field(ge=0, le=255, description="The red component")
+    g: int = Field(ge=0, le=255, description="The green component")
+    b: int = Field(ge=0, le=255, description="The blue component")
+    a: int = Field(ge=0, le=255, description="The alpha component")
+
+    def tuple(self) -> Tuple[int, int, int, int]:
+        return (self.r, self.g, self.b, self.a)
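ColorField.tuple() hands back the components in the (r, g, b, a) order PIL expects, which is how the infill nodes added earlier in this commit consume it. A small sketch of the intended use; the values are arbitrary:

# Sketch mirroring InfillColorInvocation's use of ColorField.
from PIL import Image
from invokeai.app.models.image import ColorField

color = ColorField(r=127, g=127, b=127, a=255)
solid_bg = Image.new("RGBA", (512, 512), color.tuple())   # fill color (127, 127, 127, 255)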

View File

@@ -49,12 +49,13 @@ def create_text_to_image() -> LibraryGraph:
 
 def create_system_graphs(graph_library: ItemStorageABC[LibraryGraph]) -> list[LibraryGraph]:
     """Creates the default system graphs, or adds new versions if the old ones don't match"""
+    # TODO: Uncomment this when we are ready to fix this up to prevent breaking changes
     graphs: list[LibraryGraph] = list()
 
-    text_to_image = graph_library.get(default_text_to_image_graph_id)
+    # text_to_image = graph_library.get(default_text_to_image_graph_id)
 
-    # TODO: Check if the graph is the same as the default one, and if not, update it
-    #if text_to_image is None:
+    # # TODO: Check if the graph is the same as the default one, and if not, update it
+    # #if text_to_image is None:
     text_to_image = create_text_to_image()
     graph_library.set(text_to_image)

View File

@@ -270,4 +270,5 @@ class DiskImageStorage(ImageStorageBase):
             )  # TODO: this should refresh position for LRU cache
             if len(self.__cache) > self.__max_cache_size:
                 cache_id = self.__cache_ids.get()
-                del self.__cache[cache_id]
+                if cache_id in self.__cache:
+                    del self.__cache[cache_id]

View File

@@ -20,9 +20,18 @@ class MetadataLatentsField(TypedDict):
     latents_name: str
 
 
+class MetadataColorField(TypedDict):
+    """Pydantic-less ColorField, used for metadata parsing"""
+    r: int
+    g: int
+    b: int
+    a: int
+
+
 # TODO: This is a placeholder for `InvocationsUnion` pending resolution of circular imports
 NodeMetadata = Dict[
-    str, str | int | float | bool | MetadataImageField | MetadataLatentsField
+    str, None | str | int | float | bool | MetadataImageField | MetadataLatentsField | MetadataColorField
 ]

View File

@@ -1,3 +1,4 @@
+import time
 import traceback
 from threading import Event, Thread, BoundedSemaphore
 
@@ -6,6 +7,7 @@ from .invocation_queue import InvocationQueueItem
 from .invoker import InvocationProcessorABC, Invoker
 from ..models.exceptions import CanceledException
 
+import invokeai.backend.util.logging as logger
 
 class DefaultInvocationProcessor(InvocationProcessorABC):
     __invoker_thread: Thread
     __stop_event: Event
@@ -34,8 +36,14 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
         try:
             self.__threadLimit.acquire()
             while not stop_event.is_set():
-                queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
+                try:
+                    queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
+                except Exception as e:
+                    logger.debug("Exception while getting from queue: %s" % e)
+
                 if not queue_item:  # Probably stopping
+                    # do not hammer the queue
+                    time.sleep(0.5)
                     continue
 
                 graph_execution_state = (
@@ -124,7 +132,16 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                 # Queue any further commands if invoking all
                 is_complete = graph_execution_state.is_complete()
                 if queue_item.invoke_all and not is_complete:
-                    self.__invoker.invoke(graph_execution_state, invoke_all=True)
+                    try:
+                        self.__invoker.invoke(graph_execution_state, invoke_all=True)
+                    except Exception as e:
+                        logger.error("Error while invoking: %s" % e)
+                        self.__invoker.services.events.emit_invocation_error(
+                            graph_execution_state_id=graph_execution_state.id,
+                            node=invocation.dict(),
+                            source_node_id=source_node_id,
+                            error=traceback.format_exc()
+                        )
                 elif is_complete:
                     self.__invoker.services.events.emit_graph_execution_complete(
                         graph_execution_state.id

View File

@@ -1,5 +1,13 @@
 import datetime
+import numpy as np
 
 
 def get_timestamp():
     return int(datetime.datetime.now(datetime.timezone.utc).timestamp())
+
+
+SEED_MAX = np.iinfo(np.int32).max
+
+
+def get_random_seed():
+    return np.random.randint(0, SEED_MAX)
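SEED_MAX and get_random_seed() replace the ad-hoc random_seed()/numpy.random.randint lambdas the invocations used before, so seed fields across the node files now share one bound and one default factory. The pattern, sketched with a hypothetical model class:

# Hypothetical field definition showing the shared pattern from this commit.
from pydantic import BaseModel, Field
from invokeai.app.util.misc import SEED_MAX, get_random_seed

class ExampleNode(BaseModel):
    seed: int = Field(ge=0, le=SEED_MAX, description="The seed to use (omit for random)",
                      default_factory=get_random_seed)

print(ExampleNode().seed)   # a random integer in [0, SEED_MAX)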

View File

@@ -31,6 +31,7 @@ from ..util.util import rand_perlin_2d
 from ..safety_checker import SafetyChecker
 from ..prompting.conditioning import get_uc_and_c_and_ec
 from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
+from ..stable_diffusion.schedulers import SCHEDULER_MAP
 
 downsampling = 8
 
@@ -71,19 +72,6 @@ class InvokeAIGeneratorOutput:
 # we are interposing a wrapper around the original Generator classes so that
 # old code that calls Generate will continue to work.
 class InvokeAIGenerator(metaclass=ABCMeta):
-    scheduler_map = dict(
-        ddim=diffusers.DDIMScheduler,
-        dpmpp_2=diffusers.DPMSolverMultistepScheduler,
-        k_dpm_2=diffusers.KDPM2DiscreteScheduler,
-        k_dpm_2_a=diffusers.KDPM2AncestralDiscreteScheduler,
-        k_dpmpp_2=diffusers.DPMSolverMultistepScheduler,
-        k_euler=diffusers.EulerDiscreteScheduler,
-        k_euler_a=diffusers.EulerAncestralDiscreteScheduler,
-        k_heun=diffusers.HeunDiscreteScheduler,
-        k_lms=diffusers.LMSDiscreteScheduler,
-        plms=diffusers.PNDMScheduler,
-    )
 
     def __init__(self,
                  model_info: dict,
                  params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
@@ -175,14 +163,20 @@ class InvokeAIGenerator(metaclass=ABCMeta):
         '''
         Return list of all the schedulers that we currently handle.
         '''
-        return list(self.scheduler_map.keys())
+        return list(SCHEDULER_MAP.keys())
 
     def load_generator(self, model: StableDiffusionGeneratorPipeline, generator_class: Type[Generator]):
         return generator_class(model, self.params.precision)
 
     def get_scheduler(self, scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
-        scheduler_class = self.scheduler_map.get(scheduler_name,'ddim')
-        scheduler = scheduler_class.from_config(model.scheduler.config)
+        scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])
+
+        scheduler_config = model.scheduler.config
+        if "_backup" in scheduler_config:
+            scheduler_config = scheduler_config["_backup"]
+        scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config}
+        scheduler = scheduler_class.from_config(scheduler_config)
+
         # hack copied over from generate.py
         if not hasattr(scheduler, 'uses_inpainting_model'):
             scheduler.uses_inpainting_model = lambda: False
@@ -226,10 +220,10 @@ class Inpaint(Img2Img):
     def generate(self,
                  mask_image: Image.Image | torch.FloatTensor,
                  # Seam settings - when 0, doesn't fill seam
-                 seam_size: int = 0,
-                 seam_blur: int = 0,
+                 seam_size: int = 96,
+                 seam_blur: int = 16,
                  seam_strength: float = 0.7,
-                 seam_steps: int = 10,
+                 seam_steps: int = 30,
                  tile_size: int = 32,
                  inpaint_replace=False,
                  infill_method=None,

View File

@@ -4,6 +4,7 @@ invokeai.backend.generator.inpaint descends from .generator
 from __future__ import annotations
 
 import math
+from typing import Tuple, Union
 
 import cv2
 import numpy as np
@@ -59,7 +60,7 @@ class Inpaint(Img2Img):
             writeable=False,
         )
 
-    def infill_patchmatch(self, im: Image.Image) -> Image:
+    def infill_patchmatch(self, im: Image.Image) -> Image.Image:
         if im.mode != "RGBA":
             return im
 
@@ -75,18 +76,18 @@ class Inpaint(Img2Img):
         return im_patched
 
     def tile_fill_missing(
-        self, im: Image.Image, tile_size: int = 16, seed: int = None
-    ) -> Image:
+        self, im: Image.Image, tile_size: int = 16, seed: Union[int, None] = None
+    ) -> Image.Image:
         # Only fill if there's an alpha layer
         if im.mode != "RGBA":
             return im
 
         a = np.asarray(im, dtype=np.uint8)
 
-        tile_size = (tile_size, tile_size)
+        tile_size_tuple = (tile_size, tile_size)
 
         # Get the image as tiles of a specified size
-        tiles = self.get_tile_images(a, *tile_size).copy()
+        tiles = self.get_tile_images(a, *tile_size_tuple).copy()
 
         # Get the mask as tiles
         tiles_mask = tiles[:, :, :, :, 3]
@@ -127,7 +128,9 @@ class Inpaint(Img2Img):
 
         return si
 
-    def mask_edge(self, mask: Image, edge_size: int, edge_blur: int) -> Image:
+    def mask_edge(
+        self, mask: Image.Image, edge_size: int, edge_blur: int
+    ) -> Image.Image:
         npimg = np.asarray(mask, dtype=np.uint8)
 
         # Detect any partially transparent regions
@@ -206,15 +209,15 @@ class Inpaint(Img2Img):
         cfg_scale,
         ddim_eta,
         conditioning,
-        init_image: PIL.Image.Image | torch.FloatTensor,
-        mask_image: PIL.Image.Image | torch.FloatTensor,
+        init_image: Image.Image | torch.FloatTensor,
+        mask_image: Image.Image | torch.FloatTensor,
         strength: float,
         mask_blur_radius: int = 8,
         # Seam settings - when 0, doesn't fill seam
-        seam_size: int = 0,
-        seam_blur: int = 0,
+        seam_size: int = 96,
+        seam_blur: int = 16,
         seam_strength: float = 0.7,
-        seam_steps: int = 10,
+        seam_steps: int = 30,
         tile_size: int = 32,
         step_callback=None,
         inpaint_replace=False,
@@ -222,7 +225,7 @@ class Inpaint(Img2Img):
         infill_method=None,
         inpaint_width=None,
        inpaint_height=None,
-        inpaint_fill: tuple(int) = (0x7F, 0x7F, 0x7F, 0xFF),
+        inpaint_fill: Tuple[int, int, int, int] = (0x7F, 0x7F, 0x7F, 0xFF),
         attention_maps_callback=None,
         **kwargs,
     ):
@@ -239,7 +242,7 @@ class Inpaint(Img2Img):
         self.inpaint_width = inpaint_width
         self.inpaint_height = inpaint_height
 
-        if isinstance(init_image, PIL.Image.Image):
+        if isinstance(init_image, Image.Image):
            self.pil_image = init_image.copy()
 
            # Do infill
@@ -250,8 +253,8 @@ class Inpaint(Img2Img):
                     self.pil_image.copy(), seed=self.seed, tile_size=tile_size
                 )
             elif infill_method == "solid":
-                solid_bg = PIL.Image.new("RGBA", init_image.size, inpaint_fill)
-                init_filled = PIL.Image.alpha_composite(solid_bg, init_image)
+                solid_bg = Image.new("RGBA", init_image.size, inpaint_fill)
+                init_filled = Image.alpha_composite(solid_bg, init_image)
             else:
                 raise ValueError(
                     f"Non-supported infill type {infill_method}", infill_method
@@ -269,7 +272,7 @@ class Inpaint(Img2Img):
             # Create init tensor
             init_image = image_resized_to_grid_as_tensor(init_filled.convert("RGB"))
 
-        if isinstance(mask_image, PIL.Image.Image):
+        if isinstance(mask_image, Image.Image):
             self.pil_mask = mask_image.copy()
 
             debug_image(
                 mask_image,

View File

@@ -47,6 +47,7 @@ from diffusers import (
     LDMTextToImagePipeline,
     LMSDiscreteScheduler,
     PNDMScheduler,
+    UniPCMultistepScheduler,
     StableDiffusionPipeline,
     UNet2DConditionModel,
 )
@@ -1208,6 +1209,8 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
         scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config)
     elif scheduler_type == "dpm":
         scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
+    elif scheduler_type == 'unipc':
+        scheduler = UniPCMultistepScheduler.from_config(scheduler.config)
     elif scheduler_type == "ddim":
         scheduler = scheduler
     else:

View File

@@ -1214,7 +1214,7 @@ class ModelManager(object):
                 sha.update(chunk)
             hash = sha.hexdigest()
             toc = time.time()
-            self.logger.debug(f"sha256 = {hash} ({count} files hashed in", "%4.2fs)" % (toc - tic))
+            self.logger.debug(f"sha256 = {hash} ({count} files hashed in {toc - tic:4.2f}s)")
             with open(hashpath, "w") as f:
                 f.write(hash)
             return hash

View File

@@ -509,10 +509,13 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         run_id=None,
         callback: Callable[[PipelineIntermediateState], None] = None,
     ) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
+        if self.scheduler.config.get("cpu_only", False):
+            scheduler_device = torch.device('cpu')
+        else:
+            scheduler_device = self._model_group.device_for(self.unet)
+
         if timesteps is None:
-            self.scheduler.set_timesteps(
-                num_inference_steps, device=self._model_group.device_for(self.unet)
-            )
+            self.scheduler.set_timesteps(num_inference_steps, device=scheduler_device)
             timesteps = self.scheduler.timesteps
         infer_latents_from_embeddings = GeneratorToCallbackinator(
             self.generate_latents_from_embeddings, PipelineIntermediateState
@@ -726,11 +729,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         run_id=None,
         callback=None,
     ) -> InvokeAIStableDiffusionPipelineOutput:
-        timesteps, _ = self.get_img2img_timesteps(
-            num_inference_steps,
-            strength,
-            device=self._model_group.device_for(self.unet),
-        )
+        timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength)
         result_latents, result_attention_maps = self.latents_from_embeddings(
             latents=initial_latents if strength < 1.0 else torch.zeros_like(
                 initial_latents, device=initial_latents.device, dtype=initial_latents.dtype
@@ -756,13 +755,19 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         return self.check_for_safety(output, dtype=conditioning_data.dtype)
 
     def get_img2img_timesteps(
-        self, num_inference_steps: int, strength: float, device
+        self, num_inference_steps: int, strength: float, device=None
     ) -> (torch.Tensor, int):
         img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
         assert img2img_pipeline.scheduler is self.scheduler
-        img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
+
+        if self.scheduler.config.get("cpu_only", False):
+            scheduler_device = torch.device('cpu')
+        else:
+            scheduler_device = self._model_group.device_for(self.unet)
+
+        img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=scheduler_device)
         timesteps, adjusted_steps = img2img_pipeline.get_timesteps(
-            num_inference_steps, strength, device=device
+            num_inference_steps, strength, device=scheduler_device
         )
         # Workaround for low strength resulting in zero timesteps.
         # TODO: submit upstream fix for zero-step img2img
@@ -796,9 +801,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if init_image.dim() == 3:
             init_image = init_image.unsqueeze(0)
 
-        timesteps, _ = self.get_img2img_timesteps(
-            num_inference_steps, strength, device=device
-        )
+        timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength)
 
         # 6. Prepare latent variables
         # can't quite use upstream StableDiffusionImg2ImgPipeline.prepare_latents

View File

@ -0,0 +1 @@
from .schedulers import SCHEDULER_MAP

View File

@ -0,0 +1,22 @@
from diffusers import DDIMScheduler, DPMSolverMultistepScheduler, KDPM2DiscreteScheduler, \
KDPM2AncestralDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, \
HeunDiscreteScheduler, LMSDiscreteScheduler, PNDMScheduler, UniPCMultistepScheduler, \
DPMSolverSinglestepScheduler, DEISMultistepScheduler, DDPMScheduler
SCHEDULER_MAP = dict(
ddim=(DDIMScheduler, dict()),
ddpm=(DDPMScheduler, dict()),
deis=(DEISMultistepScheduler, dict()),
lms=(LMSDiscreteScheduler, dict()),
pndm=(PNDMScheduler, dict()),
heun=(HeunDiscreteScheduler, dict()),
euler=(EulerDiscreteScheduler, dict(use_karras_sigmas=False)),
euler_k=(EulerDiscreteScheduler, dict(use_karras_sigmas=True)),
euler_a=(EulerAncestralDiscreteScheduler, dict()),
kdpm_2=(KDPM2DiscreteScheduler, dict()),
kdpm_2_a=(KDPM2AncestralDiscreteScheduler, dict()),
dpmpp_2s=(DPMSolverSinglestepScheduler, dict()),
dpmpp_2m=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=False)),
dpmpp_2m_k=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=True)),
unipc=(UniPCMultistepScheduler, dict(cpu_only=True))
)

View File

@ -4,17 +4,20 @@ from .parse_seed_weights import parse_seed_weights
SAMPLER_CHOICES = [ SAMPLER_CHOICES = [
"ddim", "ddim",
"k_dpm_2_a", "ddpm",
"k_dpm_2", "deis",
"k_dpmpp_2_a", "lms",
"k_dpmpp_2",
"k_euler_a",
"k_euler",
"k_heun",
"k_lms",
"plms",
# diffusers:
"pndm", "pndm",
"heun",
"euler",
"euler_k",
"euler_a",
"kdpm_2",
"kdpm_2_a",
"dpmpp_2s",
"dpmpp_2m",
"dpmpp_2m_k",
"unipc",
] ]

View File

@ -1,13 +0,0 @@
{
"plugins": [
[
"transform-imports",
{
"lodash": {
"transform": "lodash/${member}",
"preventFullImport": true
}
}
]
]
}

View File

@ -35,3 +35,7 @@ stats.html
!.yarn/releases !.yarn/releases
!.yarn/sdks !.yarn/sdks
!.yarn/versions !.yarn/versions
# Yalc
.yalc
yalc.lock

View File

@ -5,6 +5,7 @@ import { PluginOption, UserConfig } from 'vite';
import dts from 'vite-plugin-dts'; import dts from 'vite-plugin-dts';
import eslint from 'vite-plugin-eslint'; import eslint from 'vite-plugin-eslint';
import tsconfigPaths from 'vite-tsconfig-paths'; import tsconfigPaths from 'vite-tsconfig-paths';
import cssInjectedByJsPlugin from 'vite-plugin-css-injected-by-js';
export const packageConfig: UserConfig = { export const packageConfig: UserConfig = {
base: './', base: './',
@ -16,9 +17,10 @@ export const packageConfig: UserConfig = {
dts({ dts({
insertTypesEntry: true, insertTypesEntry: true,
}), }),
cssInjectedByJsPlugin(),
], ],
build: { build: {
chunkSizeWarningLimit: 1500, cssCodeSplit: true,
lib: { lib: {
entry: path.resolve(__dirname, '../src/index.ts'), entry: path.resolve(__dirname, '../src/index.ts'),
name: 'InvokeAIUI', name: 'InvokeAIUI',
@ -30,6 +32,7 @@ export const packageConfig: UserConfig = {
globals: { globals: {
react: 'React', react: 'React',
'react-dom': 'ReactDOM', 'react-dom': 'ReactDOM',
'@emotion/react': 'EmotionReact',
}, },
}, },
}, },

View File

@ -37,7 +37,7 @@ From `invokeai/frontend/web/` run `yarn install` to get everything set up.
Start everything in dev mode: Start everything in dev mode:
1. Start the dev server: `yarn dev` 1. Start the dev server: `yarn dev`
2. Start the InvokeAI UI per usual: `invokeai --web` 2. Start the InvokeAI Nodes backend: `python scripts/invokeai-new.py --web # run from the repo root`
3. Point your browser to the dev server address e.g. <http://localhost:5173/> 3. Point your browser to the dev server address e.g. <http://localhost:5173/>
### Production builds ### Production builds

View File

@ -21,7 +21,6 @@
"scripts": { "scripts": {
"prepare": "cd ../../../ && husky install invokeai/frontend/web/.husky", "prepare": "cd ../../../ && husky install invokeai/frontend/web/.husky",
"dev": "concurrently \"vite dev\" \"yarn run theme:watch\"", "dev": "concurrently \"vite dev\" \"yarn run theme:watch\"",
"dev:nodes": "concurrently \"vite dev --mode nodes\" \"yarn run theme:watch\"",
"dev:host": "concurrently \"vite dev --host\" \"yarn run theme:watch\"", "dev:host": "concurrently \"vite dev --host\" \"yarn run theme:watch\"",
"build": "yarn run lint && vite build", "build": "yarn run lint && vite build",
"api:web": "openapi -i http://localhost:9090/openapi.json -o src/services/api --client axios --useOptions --useUnionTypes --exportSchemas true --indent 2 --request src/services/fixtures/request.ts", "api:web": "openapi -i http://localhost:9090/openapi.json -o src/services/api --client axios --useOptions --useUnionTypes --exportSchemas true --indent 2 --request src/services/fixtures/request.ts",
@ -90,6 +89,7 @@
"react-konva": "^18.2.7", "react-konva": "^18.2.7",
"react-konva-utils": "^1.0.4", "react-konva-utils": "^1.0.4",
"react-redux": "^8.0.5", "react-redux": "^8.0.5",
"react-resizable-panels": "^0.0.42",
"react-rnd": "^10.4.1", "react-rnd": "^10.4.1",
"react-transition-group": "^4.4.5", "react-transition-group": "^4.4.5",
"react-use": "^17.4.0", "react-use": "^17.4.0",
@ -99,6 +99,7 @@
"redux-deep-persist": "^1.0.7", "redux-deep-persist": "^1.0.7",
"redux-dynamic-middlewares": "^2.2.0", "redux-dynamic-middlewares": "^2.2.0",
"redux-persist": "^6.0.0", "redux-persist": "^6.0.0",
"redux-remember": "^3.3.1",
"roarr": "^7.15.0", "roarr": "^7.15.0",
"serialize-error": "^11.0.0", "serialize-error": "^11.0.0",
"socket.io-client": "^4.6.0", "socket.io-client": "^4.6.0",
@ -118,6 +119,7 @@
"@types/node": "^18.16.2", "@types/node": "^18.16.2",
"@types/react": "^18.2.0", "@types/react": "^18.2.0",
"@types/react-dom": "^18.2.1", "@types/react-dom": "^18.2.1",
"@types/react-redux": "^7.1.25",
"@types/react-transition-group": "^4.4.5", "@types/react-transition-group": "^4.4.5",
"@types/uuid": "^9.0.0", "@types/uuid": "^9.0.0",
"@typescript-eslint/eslint-plugin": "^5.59.1", "@typescript-eslint/eslint-plugin": "^5.59.1",
@ -143,6 +145,7 @@
"terser": "^5.17.1", "terser": "^5.17.1",
"ts-toolbelt": "^9.6.0", "ts-toolbelt": "^9.6.0",
"vite": "^4.3.3", "vite": "^4.3.3",
"vite-plugin-css-injected-by-js": "^3.1.1",
"vite-plugin-dts": "^2.3.0", "vite-plugin-dts": "^2.3.0",
"vite-plugin-eslint": "^1.8.1", "vite-plugin-eslint": "^1.8.1",
"vite-tsconfig-paths": "^4.2.0", "vite-tsconfig-paths": "^4.2.0",

View File

@ -25,7 +25,7 @@
"common": { "common": {
"hotkeysLabel": "Hotkeys", "hotkeysLabel": "Hotkeys",
"themeLabel": "Theme", "themeLabel": "Theme",
"languagePickerLabel": "Language Picker", "languagePickerLabel": "Language",
"reportBugLabel": "Report Bug", "reportBugLabel": "Report Bug",
"githubLabel": "Github", "githubLabel": "Github",
"discordLabel": "Discord", "discordLabel": "Discord",
@ -54,7 +54,7 @@
"img2img": "Image To Image", "img2img": "Image To Image",
"unifiedCanvas": "Unified Canvas", "unifiedCanvas": "Unified Canvas",
"linear": "Linear", "linear": "Linear",
"nodes": "Nodes", "nodes": "Node Editor",
"postprocessing": "Post Processing", "postprocessing": "Post Processing",
"nodesDesc": "A node based system for the generation of images is under development currently. Stay tuned for updates about this amazing feature.", "nodesDesc": "A node based system for the generation of images is under development currently. Stay tuned for updates about this amazing feature.",
"postProcessing": "Post Processing", "postProcessing": "Post Processing",
@ -102,7 +102,8 @@
"generate": "Generate", "generate": "Generate",
"openInNewTab": "Open in New Tab", "openInNewTab": "Open in New Tab",
"dontAskMeAgain": "Don't ask me again", "dontAskMeAgain": "Don't ask me again",
"areYouSure": "Are you sure?" "areYouSure": "Are you sure?",
"imagePrompt": "Image Prompt"
}, },
"gallery": { "gallery": {
"generations": "Generations", "generations": "Generations",
@ -453,9 +454,10 @@
"seed": "Seed", "seed": "Seed",
"imageToImage": "Image to Image", "imageToImage": "Image to Image",
"randomizeSeed": "Randomize Seed", "randomizeSeed": "Randomize Seed",
"shuffle": "Shuffle", "shuffle": "Shuffle Seed",
"noiseThreshold": "Noise Threshold", "noiseThreshold": "Noise Threshold",
"perlinNoise": "Perlin Noise", "perlinNoise": "Perlin Noise",
"noiseSettings": "Noise",
"variations": "Variations", "variations": "Variations",
"variationAmount": "Variation Amount", "variationAmount": "Variation Amount",
"seedWeights": "Seed Weights", "seedWeights": "Seed Weights",
@ -470,6 +472,8 @@
"scale": "Scale", "scale": "Scale",
"otherOptions": "Other Options", "otherOptions": "Other Options",
"seamlessTiling": "Seamless Tiling", "seamlessTiling": "Seamless Tiling",
"seamlessXAxis": "X Axis",
"seamlessYAxis": "Y Axis",
"hiresOptim": "High Res Optimization", "hiresOptim": "High Res Optimization",
"hiresStrength": "High Res Strength", "hiresStrength": "High Res Strength",
"imageFit": "Fit Initial Image To Output Size", "imageFit": "Fit Initial Image To Output Size",
@ -527,7 +531,8 @@
"useCanvasBeta": "Use Canvas Beta Layout", "useCanvasBeta": "Use Canvas Beta Layout",
"enableImageDebugging": "Enable Image Debugging", "enableImageDebugging": "Enable Image Debugging",
"useSlidersForAll": "Use Sliders For All Options", "useSlidersForAll": "Use Sliders For All Options",
"autoShowProgress": "Auto Show Progress Images", "showProgressInViewer": "Show Progress Images in Viewer",
"antialiasProgressImages": "Antialias Progress Images",
"resetWebUI": "Reset Web UI", "resetWebUI": "Reset Web UI",
"resetWebUIDesc1": "Resetting the web UI only resets the browser's local cache of your images and remembered settings. It does not delete any images from disk.", "resetWebUIDesc1": "Resetting the web UI only resets the browser's local cache of your images and remembered settings. It does not delete any images from disk.",
"resetWebUIDesc2": "If images aren't showing up in the gallery or something else isn't working, please try resetting before submitting an issue on GitHub.", "resetWebUIDesc2": "If images aren't showing up in the gallery or something else isn't working, please try resetting before submitting an issue on GitHub.",
@ -549,8 +554,9 @@
"downloadImageStarted": "Image Download Started", "downloadImageStarted": "Image Download Started",
"imageCopied": "Image Copied", "imageCopied": "Image Copied",
"imageLinkCopied": "Image Link Copied", "imageLinkCopied": "Image Link Copied",
"problemCopyingImageLink": "Unable to Copy Image Link",
"imageNotLoaded": "No Image Loaded", "imageNotLoaded": "No Image Loaded",
"imageNotLoadedDesc": "No image found to send to image to image module", "imageNotLoadedDesc": "Could not find image",
"imageSavedToGallery": "Image Saved to Gallery", "imageSavedToGallery": "Image Saved to Gallery",
"canvasMerged": "Canvas Merged", "canvasMerged": "Canvas Merged",
"sentToImageToImage": "Sent To Image To Image", "sentToImageToImage": "Sent To Image To Image",
@ -645,7 +651,8 @@
"betaClear": "Clear", "betaClear": "Clear",
"betaDarkenOutside": "Darken Outside", "betaDarkenOutside": "Darken Outside",
"betaLimitToBox": "Limit To Box", "betaLimitToBox": "Limit To Box",
"betaPreserveMasked": "Preserve Masked" "betaPreserveMasked": "Preserve Masked",
"antialiasing": "Antialiasing"
}, },
"ui": { "ui": {
"showProgressImages": "Show Progress Images", "showProgressImages": "Show Progress Images",

View File

@ -1,24 +1,18 @@
import ImageUploader from 'common/components/ImageUploader'; import ImageUploader from 'common/components/ImageUploader';
import ProgressBar from 'features/system/components/ProgressBar';
import SiteHeader from 'features/system/components/SiteHeader'; import SiteHeader from 'features/system/components/SiteHeader';
import ProgressBar from 'features/system/components/ProgressBar';
import InvokeTabs from 'features/ui/components/InvokeTabs'; import InvokeTabs from 'features/ui/components/InvokeTabs';
import useToastWatcher from 'features/system/hooks/useToastWatcher'; import useToastWatcher from 'features/system/hooks/useToastWatcher';
import FloatingGalleryButton from 'features/ui/components/FloatingGalleryButton'; import FloatingGalleryButton from 'features/ui/components/FloatingGalleryButton';
import FloatingParametersPanelButtons from 'features/ui/components/FloatingParametersPanelButtons'; import FloatingParametersPanelButtons from 'features/ui/components/FloatingParametersPanelButtons';
import { Box, Flex, Grid, Portal, useColorMode } from '@chakra-ui/react'; import { Box, Flex, Grid, Portal } from '@chakra-ui/react';
import { APP_HEIGHT, APP_WIDTH } from 'theme/util/constants'; import { APP_HEIGHT, APP_WIDTH } from 'theme/util/constants';
import ImageGalleryPanel from 'features/gallery/components/ImageGalleryPanel'; import GalleryDrawer from 'features/gallery/components/ImageGalleryPanel';
import Lightbox from 'features/lightbox/components/Lightbox'; import Lightbox from 'features/lightbox/components/Lightbox';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { import { memo, ReactNode, useCallback, useEffect, useState } from 'react';
memo,
PropsWithChildren,
useCallback,
useEffect,
useState,
} from 'react';
import { motion, AnimatePresence } from 'framer-motion'; import { motion, AnimatePresence } from 'framer-motion';
import Loading from 'common/components/Loading/Loading'; import Loading from 'common/components/Loading/Loading';
import { useIsApplicationReady } from 'features/system/hooks/useIsApplicationReady'; import { useIsApplicationReady } from 'features/system/hooks/useIsApplicationReady';
@ -27,20 +21,24 @@ import { useGlobalHotkeys } from 'common/hooks/useGlobalHotkeys';
import { configChanged } from 'features/system/store/configSlice'; import { configChanged } from 'features/system/store/configSlice';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { useLogger } from 'app/logging/useLogger'; import { useLogger } from 'app/logging/useLogger';
import ProgressImagePreview from 'features/parameters/components/ProgressImagePreview'; import ParametersDrawer from 'features/ui/components/ParametersDrawer';
import { languageSelector } from 'features/system/store/systemSelectors';
import i18n from 'i18n';
const DEFAULT_CONFIG = {}; const DEFAULT_CONFIG = {};
interface Props extends PropsWithChildren { interface Props {
config?: PartialAppConfig; config?: PartialAppConfig;
headerComponent?: ReactNode;
} }
const App = ({ config = DEFAULT_CONFIG, children }: Props) => { const App = ({ config = DEFAULT_CONFIG, headerComponent }: Props) => {
useToastWatcher(); useToastWatcher();
useGlobalHotkeys(); useGlobalHotkeys();
const log = useLogger();
const currentTheme = useAppSelector((state) => state.ui.currentTheme); const language = useAppSelector(languageSelector);
const log = useLogger();
const isLightboxEnabled = useFeatureStatus('lightbox').isFeatureEnabled; const isLightboxEnabled = useFeatureStatus('lightbox').isFeatureEnabled;
@ -48,18 +46,17 @@ const App = ({ config = DEFAULT_CONFIG, children }: Props) => {
const [loadingOverridden, setLoadingOverridden] = useState(false); const [loadingOverridden, setLoadingOverridden] = useState(false);
const { setColorMode } = useColorMode();
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
useEffect(() => {
i18n.changeLanguage(language);
}, [language]);
useEffect(() => { useEffect(() => {
log.info({ namespace: 'App', data: config }, 'Received config'); log.info({ namespace: 'App', data: config }, 'Received config');
dispatch(configChanged(config)); dispatch(configChanged(config));
}, [dispatch, config, log]); }, [dispatch, config, log]);
useEffect(() => {
setColorMode(['light'].includes(currentTheme) ? 'light' : 'dark');
}, [setColorMode, currentTheme]);
const handleOverrideClicked = useCallback(() => { const handleOverrideClicked = useCallback(() => {
setLoadingOverridden(true); setLoadingOverridden(true);
}, []); }, []);
@ -76,7 +73,7 @@ const App = ({ config = DEFAULT_CONFIG, children }: Props) => {
w={APP_WIDTH} w={APP_WIDTH}
h={APP_HEIGHT} h={APP_HEIGHT}
> >
{children || <SiteHeader />} {headerComponent || <SiteHeader />}
<Flex <Flex
gap={4} gap={4}
w={{ base: '100vw', xl: 'full' }} w={{ base: '100vw', xl: 'full' }}
@ -84,11 +81,13 @@ const App = ({ config = DEFAULT_CONFIG, children }: Props) => {
flexDir={{ base: 'column', xl: 'row' }} flexDir={{ base: 'column', xl: 'row' }}
> >
<InvokeTabs /> <InvokeTabs />
<ImageGalleryPanel />
</Flex> </Flex>
</Grid> </Grid>
</ImageUploader> </ImageUploader>
<GalleryDrawer />
<ParametersDrawer />
<AnimatePresence> <AnimatePresence>
{!isApplicationReady && !loadingOverridden && ( {!isApplicationReady && !loadingOverridden && (
<motion.div <motion.div
@ -121,7 +120,6 @@ const App = ({ config = DEFAULT_CONFIG, children }: Props) => {
<Portal> <Portal>
<FloatingGalleryButton /> <FloatingGalleryButton />
</Portal> </Portal>
<ProgressImagePreview />
</Grid> </Grid>
); );
}; };

View File

@ -1,18 +1,13 @@
import React, { lazy, memo, PropsWithChildren, useEffect } from 'react'; import React, {
lazy,
memo,
PropsWithChildren,
ReactNode,
useEffect,
} from 'react';
import { Provider } from 'react-redux'; import { Provider } from 'react-redux';
import { PersistGate } from 'redux-persist/integration/react';
import { store } from 'app/store/store'; import { store } from 'app/store/store';
import { persistor } from '../store/persistor';
import { OpenAPI } from 'services/api'; import { OpenAPI } from 'services/api';
import '@fontsource/inter/100.css';
import '@fontsource/inter/200.css';
import '@fontsource/inter/300.css';
import '@fontsource/inter/400.css';
import '@fontsource/inter/500.css';
import '@fontsource/inter/600.css';
import '@fontsource/inter/700.css';
import '@fontsource/inter/800.css';
import '@fontsource/inter/900.css';
import Loading from '../../common/components/Loading/Loading'; import Loading from '../../common/components/Loading/Loading';
import { addMiddleware, resetMiddlewares } from 'redux-dynamic-middlewares'; import { addMiddleware, resetMiddlewares } from 'redux-dynamic-middlewares';
@ -28,9 +23,10 @@ interface Props extends PropsWithChildren {
apiUrl?: string; apiUrl?: string;
token?: string; token?: string;
config?: PartialAppConfig; config?: PartialAppConfig;
headerComponent?: ReactNode;
} }
const InvokeAIUI = ({ apiUrl, token, config, children }: Props) => { const InvokeAIUI = ({ apiUrl, token, config, headerComponent }: Props) => {
useEffect(() => { useEffect(() => {
// configure API client token // configure API client token
if (token) { if (token) {
@ -57,13 +53,11 @@ const InvokeAIUI = ({ apiUrl, token, config, children }: Props) => {
return ( return (
<React.StrictMode> <React.StrictMode>
<Provider store={store}> <Provider store={store}>
<PersistGate loading={<Loading />} persistor={persistor}>
<React.Suspense fallback={<Loading />}> <React.Suspense fallback={<Loading />}>
<ThemeLocaleProvider> <ThemeLocaleProvider>
<App config={config}>{children}</App> <App config={config} headerComponent={headerComponent} />
</ThemeLocaleProvider> </ThemeLocaleProvider>
</React.Suspense> </React.Suspense>
</PersistGate>
</Provider> </Provider>
</React.StrictMode> </React.StrictMode>
); );

View File

@ -1,4 +1,8 @@
import { ChakraProvider, extendTheme } from '@chakra-ui/react'; import {
ChakraProvider,
createLocalStorageManager,
extendTheme,
} from '@chakra-ui/react';
import { ReactNode, useEffect } from 'react'; import { ReactNode, useEffect } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { theme as invokeAITheme } from 'theme/theme'; import { theme as invokeAITheme } from 'theme/theme';
@ -9,15 +13,8 @@ import { greenTeaThemeColors } from 'theme/colors/greenTea';
import { invokeAIThemeColors } from 'theme/colors/invokeAI'; import { invokeAIThemeColors } from 'theme/colors/invokeAI';
import { lightThemeColors } from 'theme/colors/lightTheme'; import { lightThemeColors } from 'theme/colors/lightTheme';
import { oceanBlueColors } from 'theme/colors/oceanBlue'; import { oceanBlueColors } from 'theme/colors/oceanBlue';
import '@fontsource/inter/100.css';
import '@fontsource/inter/200.css'; import '@fontsource/inter/variable.css';
import '@fontsource/inter/300.css';
import '@fontsource/inter/400.css';
import '@fontsource/inter/500.css';
import '@fontsource/inter/600.css';
import '@fontsource/inter/700.css';
import '@fontsource/inter/800.css';
import '@fontsource/inter/900.css';
import 'overlayscrollbars/overlayscrollbars.css'; import 'overlayscrollbars/overlayscrollbars.css';
import 'theme/css/overlayscrollbars.css'; import 'theme/css/overlayscrollbars.css';
@ -32,6 +29,8 @@ const THEMES = {
ocean: oceanBlueColors, ocean: oceanBlueColors,
}; };
const manager = createLocalStorageManager('@@invokeai-color-mode');
function ThemeLocaleProvider({ children }: ThemeLocaleProviderProps) { function ThemeLocaleProvider({ children }: ThemeLocaleProviderProps) {
const { i18n } = useTranslation(); const { i18n } = useTranslation();
@ -51,7 +50,11 @@ function ThemeLocaleProvider({ children }: ThemeLocaleProviderProps) {
document.body.dir = direction; document.body.dir = direction;
}, [direction]); }, [direction]);
return <ChakraProvider theme={theme}>{children}</ChakraProvider>; return (
<ChakraProvider theme={theme} colorModeManager={manager}>
{children}
</ChakraProvider>
);
} }
export default ThemeLocaleProvider; export default ThemeLocaleProvider;

View File

@ -2,17 +2,28 @@
export const DIFFUSERS_SCHEDULERS: Array<string> = [ export const DIFFUSERS_SCHEDULERS: Array<string> = [
'ddim', 'ddim',
'plms', 'ddpm',
'k_lms', 'deis',
'dpmpp_2', 'lms',
'k_dpm_2', 'pndm',
'k_dpm_2_a', 'heun',
'k_dpmpp_2', 'euler',
'k_euler', 'euler_k',
'k_euler_a', 'euler_a',
'k_heun', 'kdpm_2',
'kdpm_2_a',
'dpmpp_2s',
'dpmpp_2m',
'dpmpp_2m_k',
'unipc',
]; ];
export const IMG2IMG_DIFFUSERS_SCHEDULERS = DIFFUSERS_SCHEDULERS.filter(
(scheduler) => {
return scheduler !== 'dpmpp_2s';
}
);
// Valid image widths // Valid image widths
export const WIDTHS: Array<number> = Array.from(Array(64)).map( export const WIDTHS: Array<number> = Array.from(Array(64)).map(
(_x, i) => (i + 1) * 64 (_x, i) => (i + 1) * 64

View File

@ -1,26 +1,20 @@
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { validateSeedWeights } from 'common/util/seedWeightPairs'; import { validateSeedWeights } from 'common/util/seedWeightPairs';
import { initialCanvasImageSelector } from 'features/canvas/store/canvasSelectors';
import { generationSelector } from 'features/parameters/store/generationSelectors'; import { generationSelector } from 'features/parameters/store/generationSelectors';
import { systemSelector } from 'features/system/store/systemSelectors'; import { systemSelector } from 'features/system/store/systemSelectors';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { isEqual } from 'lodash-es'; import { isEqual } from 'lodash-es';
export const readinessSelector = createSelector( export const readinessSelector = createSelector(
[ [generationSelector, systemSelector, activeTabNameSelector],
generationSelector, (generation, system, activeTabName) => {
systemSelector,
initialCanvasImageSelector,
activeTabNameSelector,
],
(generation, system, initialCanvasImage, activeTabName) => {
const { const {
prompt, prompt,
shouldGenerateVariations, shouldGenerateVariations,
seedWeights, seedWeights,
initialImage, initialImage,
seed, seed,
isImageToImageEnabled,
} = generation; } = generation;
const { isProcessing, isConnected } = system; const { isProcessing, isConnected } = system;
@ -34,7 +28,7 @@ export const readinessSelector = createSelector(
reasonsWhyNotReady.push('Missing prompt'); reasonsWhyNotReady.push('Missing prompt');
} }
if (isImageToImageEnabled && !initialImage) { if (activeTabName === 'img2img' && !initialImage) {
isReady = false; isReady = false;
reasonsWhyNotReady.push('No initial image selected'); reasonsWhyNotReady.push('No initial image selected');
} }
@ -64,10 +58,5 @@ export const readinessSelector = createSelector(
// All good // All good
return { isReady, reasonsWhyNotReady }; return { isReady, reasonsWhyNotReady };
}, },
{ defaultSelectorOptions
memoizeOptions: {
equalityCheck: isEqual,
resultEqualityCheck: isEqual,
},
}
); );

View File

@ -1,209 +1,209 @@
// import { AnyAction, Dispatch, MiddlewareAPI } from '@reduxjs/toolkit'; import { AnyAction, Dispatch, MiddlewareAPI } from '@reduxjs/toolkit';
// import * as InvokeAI from 'app/types/invokeai'; import * as InvokeAI from 'app/types/invokeai';
// import type { RootState } from 'app/store/store'; import type { RootState } from 'app/store/store';
// import { import {
// frontendToBackendParameters, frontendToBackendParameters,
// FrontendToBackendParametersConfig, FrontendToBackendParametersConfig,
// } from 'common/util/parameterTranslation'; } from 'common/util/parameterTranslation';
// import dateFormat from 'dateformat'; import dateFormat from 'dateformat';
// import { import {
// GalleryCategory, GalleryCategory,
// GalleryState, GalleryState,
// removeImage, removeImage,
// } from 'features/gallery/store/gallerySlice'; } from 'features/gallery/store/gallerySlice';
// import { import {
// generationRequested, generationRequested,
// modelChangeRequested, modelChangeRequested,
// modelConvertRequested, modelConvertRequested,
// modelMergingRequested, modelMergingRequested,
// setIsProcessing, setIsProcessing,
// } from 'features/system/store/systemSlice'; } from 'features/system/store/systemSlice';
// import { InvokeTabName } from 'features/ui/store/tabMap'; import { InvokeTabName } from 'features/ui/store/tabMap';
// import { Socket } from 'socket.io-client'; import { Socket } from 'socket.io-client';
// /** /**
// * Returns an object containing all functions which use `socketio.emit()`. * Returns an object containing all functions which use `socketio.emit()`.
// * i.e. those which make server requests. * i.e. those which make server requests.
// */ */
// const makeSocketIOEmitters = ( const makeSocketIOEmitters = (
// store: MiddlewareAPI<Dispatch<AnyAction>, RootState>, store: MiddlewareAPI<Dispatch<AnyAction>, RootState>,
// socketio: Socket socketio: Socket
// ) => { ) => {
// // We need to dispatch actions to redux and get pieces of state from the store. // We need to dispatch actions to redux and get pieces of state from the store.
// const { dispatch, getState } = store; const { dispatch, getState } = store;
// return { return {
// emitGenerateImage: (generationMode: InvokeTabName) => { emitGenerateImage: (generationMode: InvokeTabName) => {
// dispatch(setIsProcessing(true)); dispatch(setIsProcessing(true));
// const state: RootState = getState(); const state: RootState = getState();
// const { const {
// generation: generationState, generation: generationState,
// postprocessing: postprocessingState, postprocessing: postprocessingState,
// system: systemState, system: systemState,
// canvas: canvasState, canvas: canvasState,
// } = state; } = state;
// const frontendToBackendParametersConfig: FrontendToBackendParametersConfig = const frontendToBackendParametersConfig: FrontendToBackendParametersConfig =
// { {
// generationMode, generationMode,
// generationState, generationState,
// postprocessingState, postprocessingState,
// canvasState, canvasState,
// systemState, systemState,
// }; };
// dispatch(generationRequested()); dispatch(generationRequested());
// const { generationParameters, esrganParameters, facetoolParameters } = const { generationParameters, esrganParameters, facetoolParameters } =
// frontendToBackendParameters(frontendToBackendParametersConfig); frontendToBackendParameters(frontendToBackendParametersConfig);
// socketio.emit( socketio.emit(
// 'generateImage', 'generateImage',
// generationParameters, generationParameters,
// esrganParameters, esrganParameters,
// facetoolParameters facetoolParameters
// ); );
// // we need to truncate the init_mask base64 else it takes up the whole log // we need to truncate the init_mask base64 else it takes up the whole log
// // TODO: handle maintaining masks for reproducibility in future // TODO: handle maintaining masks for reproducibility in future
// if (generationParameters.init_mask) { if (generationParameters.init_mask) {
// generationParameters.init_mask = generationParameters.init_mask generationParameters.init_mask = generationParameters.init_mask
// .substr(0, 64) .substr(0, 64)
// .concat('...'); .concat('...');
// } }
// if (generationParameters.init_img) { if (generationParameters.init_img) {
// generationParameters.init_img = generationParameters.init_img generationParameters.init_img = generationParameters.init_img
// .substr(0, 64) .substr(0, 64)
// .concat('...'); .concat('...');
// } }
// dispatch( dispatch(
// addLogEntry({ addLogEntry({
// timestamp: dateFormat(new Date(), 'isoDateTime'), timestamp: dateFormat(new Date(), 'isoDateTime'),
// message: `Image generation requested: ${JSON.stringify({ message: `Image generation requested: ${JSON.stringify({
// ...generationParameters, ...generationParameters,
// ...esrganParameters, ...esrganParameters,
// ...facetoolParameters, ...facetoolParameters,
// })}`, })}`,
// }) })
// ); );
// }, },
// emitRunESRGAN: (imageToProcess: InvokeAI._Image) => { emitRunESRGAN: (imageToProcess: InvokeAI._Image) => {
// dispatch(setIsProcessing(true)); dispatch(setIsProcessing(true));
// const { const {
// postprocessing: { postprocessing: {
// upscalingLevel, upscalingLevel,
// upscalingDenoising, upscalingDenoising,
// upscalingStrength, upscalingStrength,
// }, },
// } = getState(); } = getState();
// const esrganParameters = { const esrganParameters = {
// upscale: [upscalingLevel, upscalingDenoising, upscalingStrength], upscale: [upscalingLevel, upscalingDenoising, upscalingStrength],
// }; };
// socketio.emit('runPostprocessing', imageToProcess, { socketio.emit('runPostprocessing', imageToProcess, {
// type: 'esrgan', type: 'esrgan',
// ...esrganParameters, ...esrganParameters,
// }); });
// dispatch( dispatch(
// addLogEntry({ addLogEntry({
// timestamp: dateFormat(new Date(), 'isoDateTime'), timestamp: dateFormat(new Date(), 'isoDateTime'),
// message: `ESRGAN upscale requested: ${JSON.stringify({ message: `ESRGAN upscale requested: ${JSON.stringify({
// file: imageToProcess.url, file: imageToProcess.url,
// ...esrganParameters, ...esrganParameters,
// })}`, })}`,
// }) })
// ); );
// }, },
// emitRunFacetool: (imageToProcess: InvokeAI._Image) => { emitRunFacetool: (imageToProcess: InvokeAI._Image) => {
// dispatch(setIsProcessing(true)); dispatch(setIsProcessing(true));
// const { const {
// postprocessing: { facetoolType, facetoolStrength, codeformerFidelity }, postprocessing: { facetoolType, facetoolStrength, codeformerFidelity },
// } = getState(); } = getState();
// const facetoolParameters: Record<string, unknown> = { const facetoolParameters: Record<string, unknown> = {
// facetool_strength: facetoolStrength, facetool_strength: facetoolStrength,
// }; };
// if (facetoolType === 'codeformer') { if (facetoolType === 'codeformer') {
// facetoolParameters.codeformer_fidelity = codeformerFidelity; facetoolParameters.codeformer_fidelity = codeformerFidelity;
// } }
// socketio.emit('runPostprocessing', imageToProcess, { socketio.emit('runPostprocessing', imageToProcess, {
// type: facetoolType, type: facetoolType,
// ...facetoolParameters, ...facetoolParameters,
// }); });
// dispatch( dispatch(
// addLogEntry({ addLogEntry({
// timestamp: dateFormat(new Date(), 'isoDateTime'), timestamp: dateFormat(new Date(), 'isoDateTime'),
// message: `Face restoration (${facetoolType}) requested: ${JSON.stringify( message: `Face restoration (${facetoolType}) requested: ${JSON.stringify(
// { {
// file: imageToProcess.url, file: imageToProcess.url,
// ...facetoolParameters, ...facetoolParameters,
// } }
// )}`, )}`,
// }) })
// ); );
// }, },
// emitDeleteImage: (imageToDelete: InvokeAI._Image) => { emitDeleteImage: (imageToDelete: InvokeAI._Image) => {
// const { url, uuid, category, thumbnail } = imageToDelete; const { url, uuid, category, thumbnail } = imageToDelete;
// dispatch(removeImage(imageToDelete)); dispatch(removeImage(imageToDelete));
// socketio.emit('deleteImage', url, thumbnail, uuid, category); socketio.emit('deleteImage', url, thumbnail, uuid, category);
// }, },
// emitRequestImages: (category: GalleryCategory) => { emitRequestImages: (category: GalleryCategory) => {
// const gallery: GalleryState = getState().gallery; const gallery: GalleryState = getState().gallery;
// const { earliest_mtime } = gallery.categories[category]; const { earliest_mtime } = gallery.categories[category];
// socketio.emit('requestImages', category, earliest_mtime); socketio.emit('requestImages', category, earliest_mtime);
// }, },
// emitRequestNewImages: (category: GalleryCategory) => { emitRequestNewImages: (category: GalleryCategory) => {
// const gallery: GalleryState = getState().gallery; const gallery: GalleryState = getState().gallery;
// const { latest_mtime } = gallery.categories[category]; const { latest_mtime } = gallery.categories[category];
// socketio.emit('requestLatestImages', category, latest_mtime); socketio.emit('requestLatestImages', category, latest_mtime);
// }, },
// emitCancelProcessing: () => { emitCancelProcessing: () => {
// socketio.emit('cancel'); socketio.emit('cancel');
// }, },
// emitRequestSystemConfig: () => { emitRequestSystemConfig: () => {
// socketio.emit('requestSystemConfig'); socketio.emit('requestSystemConfig');
// }, },
// emitSearchForModels: (modelFolder: string) => { emitSearchForModels: (modelFolder: string) => {
// socketio.emit('searchForModels', modelFolder); socketio.emit('searchForModels', modelFolder);
// }, },
// emitAddNewModel: (modelConfig: InvokeAI.InvokeModelConfigProps) => { emitAddNewModel: (modelConfig: InvokeAI.InvokeModelConfigProps) => {
// socketio.emit('addNewModel', modelConfig); socketio.emit('addNewModel', modelConfig);
// }, },
// emitDeleteModel: (modelName: string) => { emitDeleteModel: (modelName: string) => {
// socketio.emit('deleteModel', modelName); socketio.emit('deleteModel', modelName);
// }, },
// emitConvertToDiffusers: ( emitConvertToDiffusers: (
// modelToConvert: InvokeAI.InvokeModelConversionProps modelToConvert: InvokeAI.InvokeModelConversionProps
// ) => { ) => {
// dispatch(modelConvertRequested()); dispatch(modelConvertRequested());
// socketio.emit('convertToDiffusers', modelToConvert); socketio.emit('convertToDiffusers', modelToConvert);
// }, },
// emitMergeDiffusersModels: ( emitMergeDiffusersModels: (
// modelMergeInfo: InvokeAI.InvokeModelMergingProps modelMergeInfo: InvokeAI.InvokeModelMergingProps
// ) => { ) => {
// dispatch(modelMergingRequested()); dispatch(modelMergingRequested());
// socketio.emit('mergeDiffusersModels', modelMergeInfo); socketio.emit('mergeDiffusersModels', modelMergeInfo);
// }, },
// emitRequestModelChange: (modelName: string) => { emitRequestModelChange: (modelName: string) => {
// dispatch(modelChangeRequested()); dispatch(modelChangeRequested());
// socketio.emit('requestModelChange', modelName); socketio.emit('requestModelChange', modelName);
// }, },
// emitSaveStagingAreaImageToGallery: (url: string) => { emitSaveStagingAreaImageToGallery: (url: string) => {
// socketio.emit('requestSaveStagingAreaImageToGallery', url); socketio.emit('requestSaveStagingAreaImageToGallery', url);
// }, },
// emitRequestEmptyTempFolder: () => { emitRequestEmptyTempFolder: () => {
// socketio.emit('requestEmptyTempFolder'); socketio.emit('requestEmptyTempFolder');
// }, },
// }; };
// }; };
// export default makeSocketIOEmitters; export default makeSocketIOEmitters;
export default {}; export default {};

View File

@ -0,0 +1,4 @@
import { createAction } from '@reduxjs/toolkit';
import { InvokeTabName } from 'features/ui/store/tabMap';
export const userInvoked = createAction<InvokeTabName>('app/userInvoked');
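
The `userInvoked` action is the trigger that the tab-specific listeners registered further down respond to. A minimal sketch of a dispatch site, assuming a hypothetical hook wrapper (the hook name is illustrative, not part of this commit):

```ts
// Hypothetical dispatch site for the text-to-image tab's Invoke button.
// addUserInvokedTextToImageListener matches on payload === 'txt2img'.
import { useAppDispatch } from 'app/store/storeHooks';
import { userInvoked } from 'app/store/actions';

export const useInvokeTextToImage = () => {
  const dispatch = useAppDispatch();
  // The listener middleware builds the txt2img graph and creates a session
  // in response to this action.
  return () => dispatch(userInvoked('txt2img'));
};
```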

View File

@ -0,0 +1,8 @@
export const LOCALSTORAGE_KEYS = [
'chakra-ui-color-mode',
'i18nextLng',
'ROARR_FILTER',
'ROARR_LOG',
];
export const LOCALSTORAGE_PREFIX = '@@invokeai-';

View File

@ -0,0 +1,36 @@
import { canvasPersistDenylist } from 'features/canvas/store/canvasPersistDenylist';
import { galleryPersistDenylist } from 'features/gallery/store/galleryPersistDenylist';
import { resultsPersistDenylist } from 'features/gallery/store/resultsPersistDenylist';
import { uploadsPersistDenylist } from 'features/gallery/store/uploadsPersistDenylist';
import { lightboxPersistDenylist } from 'features/lightbox/store/lightboxPersistDenylist';
import { nodesPersistDenylist } from 'features/nodes/store/nodesPersistDenylist';
import { generationPersistDenylist } from 'features/parameters/store/generationPersistDenylist';
import { postprocessingPersistDenylist } from 'features/parameters/store/postprocessingPersistDenylist';
import { modelsPersistDenylist } from 'features/system/store/modelsPersistDenylist';
import { systemPersistDenylist } from 'features/system/store/systemPersistDenylist';
import { uiPersistDenylist } from 'features/ui/store/uiPersistDenylist';
import { omit } from 'lodash-es';
import { SerializeFunction } from 'redux-remember';
const serializationDenylist: {
[key: string]: string[];
} = {
canvas: canvasPersistDenylist,
gallery: galleryPersistDenylist,
generation: generationPersistDenylist,
lightbox: lightboxPersistDenylist,
models: modelsPersistDenylist,
nodes: nodesPersistDenylist,
postprocessing: postprocessingPersistDenylist,
results: resultsPersistDenylist,
system: systemPersistDenylist,
// config: configPersistDenyList,
ui: uiPersistDenylist,
uploads: uploadsPersistDenylist,
// hotkeys: hotkeysPersistDenylist,
};
export const serialize: SerializeFunction = (data, key) => {
const result = omit(data, serializationDenylist[key]);
return JSON.stringify(result);
};

View File

@ -0,0 +1,38 @@
import { initialCanvasState } from 'features/canvas/store/canvasSlice';
import { initialGalleryState } from 'features/gallery/store/gallerySlice';
import { initialResultsState } from 'features/gallery/store/resultsSlice';
import { initialUploadsState } from 'features/gallery/store/uploadsSlice';
import { initialLightboxState } from 'features/lightbox/store/lightboxSlice';
import { initialNodesState } from 'features/nodes/store/nodesSlice';
import { initialGenerationState } from 'features/parameters/store/generationSlice';
import { initialPostprocessingState } from 'features/parameters/store/postprocessingSlice';
import { initialConfigState } from 'features/system/store/configSlice';
import { initialModelsState } from 'features/system/store/modelSlice';
import { initialSystemState } from 'features/system/store/systemSlice';
import { initialHotkeysState } from 'features/ui/store/hotkeysSlice';
import { initialUIState } from 'features/ui/store/uiSlice';
import { defaultsDeep } from 'lodash-es';
import { UnserializeFunction } from 'redux-remember';
const initialStates: {
[key: string]: any;
} = {
canvas: initialCanvasState,
gallery: initialGalleryState,
generation: initialGenerationState,
lightbox: initialLightboxState,
models: initialModelsState,
nodes: initialNodesState,
postprocessing: initialPostprocessingState,
results: initialResultsState,
system: initialSystemState,
config: initialConfigState,
ui: initialUIState,
uploads: initialUploadsState,
hotkeys: initialHotkeysState,
};
export const unserialize: UnserializeFunction = (data, key) => {
const result = defaultsDeep(JSON.parse(data), initialStates[key]);
return result;
};
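
These serialize/unserialize helpers are the counterparts to the `rememberEnhancer` call in the store setup further below. A sketch of how they plug together, assuming the option names accepted by redux-remember (only `persistDebounce` is visible verbatim in this diff; the relative import paths are also assumptions):

```ts
// Sketch of the redux-remember wiring; option names other than persistDebounce
// and the import paths are assumptions, not copied from this commit.
import { rememberEnhancer } from 'redux-remember';
import { serialize } from './serialize';
import { unserialize } from './unserialize';
import { LOCALSTORAGE_PREFIX } from '../../constants';

const enhancer = rememberEnhancer(window.localStorage, ['canvas', 'generation'], {
  prefix: LOCALSTORAGE_PREFIX, // stored keys look like '@@invokeai-canvas'
  persistDebounce: 300, // batch writes instead of writing on every action
  serialize, // drops each slice's denylisted keys before JSON.stringify
  unserialize, // deep-merges the stored JSON over the slice's initial state
});
```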

View File

@ -0,0 +1,30 @@
import { AnyAction } from '@reduxjs/toolkit';
import { isAnyGraphBuilt } from 'features/nodes/store/actions';
import { forEach } from 'lodash-es';
import { Graph } from 'services/api';
export const actionSanitizer = <A extends AnyAction>(action: A): A => {
if (isAnyGraphBuilt(action)) {
if (action.payload.nodes) {
const sanitizedNodes: Graph['nodes'] = {};
// Sanitize nodes as needed
forEach(action.payload.nodes, (node, key) => {
// Don't log the whole freaking dataURL
if (node.type === 'dataURL_image') {
const { dataURL, ...rest } = node;
sanitizedNodes[key] = { ...rest, dataURL: '<dataURL>' };
} else {
sanitizedNodes[key] = { ...node };
}
});
return {
...action,
payload: { ...action.payload, nodes: sanitizedNodes },
};
}
}
return action;
};
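
This sanitizer, together with `stateSanitizer` and `actionsDenylist`, is presumably fed into `configureStore`'s `devTools` options; that part of `store.ts` is not visible in this excerpt, so the wiring below is only a sketch under that assumption:

```ts
// Sketch only: assumes the sanitizers/denylist are passed via devTools options.
import { configureStore } from '@reduxjs/toolkit';
import { actionSanitizer } from './middleware/devtools/actionSanitizer';
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
import { actionsDenylist } from './middleware/devtools/actionsDenylist';

const store = configureStore({
  reducer: (state: Record<string, unknown> = {}) => state, // placeholder reducer for the sketch
  devTools: {
    actionSanitizer, // replaces dataURL payloads in graph-built actions with '<dataURL>'
    stateSanitizer, // currently a pass-through
    actionsDenylist, // hides noisy canvas/progress actions from the devtools log
  },
});
```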

View File

@ -0,0 +1,11 @@
export const actionsDenylist = [
'canvas/setCursorPosition',
'canvas/setStageCoordinates',
'canvas/setStageScale',
'canvas/setIsDrawing',
'canvas/setBoundingBoxCoordinates',
'canvas/setBoundingBoxDimensions',
'canvas/setIsDrawing',
'canvas/addPointToCurrentLine',
'socket/generatorProgress',
];

View File

@ -0,0 +1,3 @@
export const stateSanitizer = <S>(state: S): S => {
return state;
};

View File

@ -0,0 +1,45 @@
import {
createListenerMiddleware,
addListener,
ListenerEffect,
AnyAction,
} from '@reduxjs/toolkit';
import type { TypedStartListening, TypedAddListener } from '@reduxjs/toolkit';
import type { RootState, AppDispatch } from '../../store';
import { addInitialImageSelectedListener } from './listeners/initialImageSelected';
import { addImageResultReceivedListener } from './listeners/invocationComplete';
import { addImageUploadedListener } from './listeners/imageUploaded';
import { addRequestedImageDeletionListener } from './listeners/imageDeleted';
import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas';
import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage';
export const listenerMiddleware = createListenerMiddleware();
export type AppStartListening = TypedStartListening<RootState, AppDispatch>;
export const startAppListening =
listenerMiddleware.startListening as AppStartListening;
export const addAppListener = addListener as TypedAddListener<
RootState,
AppDispatch
>;
export type AppListenerEffect = ListenerEffect<
AnyAction,
RootState,
AppDispatch
>;
addImageUploadedListener();
addInitialImageSelectedListener();
addImageResultReceivedListener();
addRequestedImageDeletionListener();
addUserInvokedCanvasListener();
addUserInvokedNodesListener();
addUserInvokedTextToImageListener();
addUserInvokedImageToImageListener();
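
For these registered effects to run, `listenerMiddleware.middleware` has to be attached to the store's middleware chain; the relevant `configureStore` call is not shown in this excerpt, so the following is a sketch using Redux Toolkit's standard pattern:

```ts
// Sketch: attaching the listener middleware when the store is configured.
import { configureStore } from '@reduxjs/toolkit';
import { listenerMiddleware } from './middleware/listenerMiddleware';

const store = configureStore({
  reducer: (state: Record<string, unknown> = {}) => state, // placeholder reducer for the sketch
  middleware: (getDefaultMiddleware) =>
    // Prepend per the Redux Toolkit docs so the listener middleware runs
    // ahead of the default serializability checks.
    getDefaultMiddleware().prepend(listenerMiddleware.middleware),
});
```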

View File

@ -0,0 +1,31 @@
import { canvasGraphBuilt } from 'features/nodes/store/actions';
import { startAppListening } from '..';
import {
canvasSessionIdChanged,
stagingAreaInitialized,
} from 'features/canvas/store/canvasSlice';
import { sessionInvoked } from 'services/thunks/session';
export const addCanvasGraphBuiltListener = () =>
startAppListening({
actionCreator: canvasGraphBuilt,
effect: async (action, { dispatch, getState, take }) => {
const [{ meta }] = await take(sessionInvoked.fulfilled.match);
const { sessionId } = meta.arg;
const state = getState();
if (!state.canvas.layerState.stagingArea.boundingBox) {
dispatch(
stagingAreaInitialized({
sessionId,
boundingBox: {
...state.canvas.boundingBoxCoordinates,
...state.canvas.boundingBoxDimensions,
},
})
);
}
dispatch(canvasSessionIdChanged(sessionId));
},
});

View File

@ -0,0 +1,59 @@
import { requestedImageDeletion } from 'features/gallery/store/actions';
import { startAppListening } from '..';
import { imageDeleted } from 'services/thunks/image';
import { log } from 'app/logging/useLogger';
import { clamp } from 'lodash-es';
import { imageSelected } from 'features/gallery/store/gallerySlice';
const moduleLog = log.child({ namespace: 'addRequestedImageDeletionListener' });
export const addRequestedImageDeletionListener = () => {
startAppListening({
actionCreator: requestedImageDeletion,
effect: (action, { dispatch, getState }) => {
const image = action.payload;
if (!image) {
moduleLog.warn('No image provided');
return;
}
const { name, type } = image;
if (type !== 'uploads' && type !== 'results') {
moduleLog.warn({ data: image }, `Invalid image type ${type}`);
return;
}
const selectedImageName = getState().gallery.selectedImage?.name;
if (selectedImageName === name) {
const allIds = getState()[type].ids;
const allEntities = getState()[type].entities;
const deletedImageIndex = allIds.findIndex(
(result) => result.toString() === name
);
const filteredIds = allIds.filter((id) => id.toString() !== name);
const newSelectedImageIndex = clamp(
deletedImageIndex,
0,
filteredIds.length - 1
);
const newSelectedImageId = filteredIds[newSelectedImageIndex];
const newSelectedImage = allEntities[newSelectedImageId];
if (newSelectedImageId) {
dispatch(imageSelected(newSelectedImage));
} else {
dispatch(imageSelected());
}
}
dispatch(imageDeleted({ imageName: name, imageType: type }));
},
});
};

View File

@ -0,0 +1,25 @@
import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
import { startAppListening } from '..';
import { uploadAdded } from 'features/gallery/store/uploadsSlice';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { imageUploaded } from 'services/thunks/image';
export const addImageUploadedListener = () => {
startAppListening({
predicate: (action): action is ReturnType<typeof imageUploaded.fulfilled> =>
imageUploaded.fulfilled.match(action) &&
action.payload.response.image_type !== 'intermediates',
effect: (action, { dispatch, getState }) => {
const { response } = action.payload;
const state = getState();
const image = deserializeImageResponse(response);
dispatch(uploadAdded(image));
if (state.gallery.shouldAutoSwitchToNewImages) {
dispatch(imageSelected(image));
}
},
});
};

View File

@ -0,0 +1,54 @@
import { initialImageChanged } from 'features/parameters/store/generationSlice';
import { Image, isInvokeAIImage } from 'app/types/invokeai';
import { selectResultsById } from 'features/gallery/store/resultsSlice';
import { selectUploadsById } from 'features/gallery/store/uploadsSlice';
import { makeToast } from 'features/system/hooks/useToastWatcher';
import { t } from 'i18next';
import { addToast } from 'features/system/store/systemSlice';
import { startAppListening } from '..';
import { initialImageSelected } from 'features/parameters/store/actions';
export const addInitialImageSelectedListener = () => {
startAppListening({
actionCreator: initialImageSelected,
effect: (action, { getState, dispatch }) => {
if (!action.payload) {
dispatch(
addToast(
makeToast({ title: t('toast.imageNotLoadedDesc'), status: 'error' })
)
);
return;
}
if (isInvokeAIImage(action.payload)) {
dispatch(initialImageChanged(action.payload));
dispatch(addToast(makeToast(t('toast.sentToImageToImage'))));
return;
}
const { name, type } = action.payload;
let image: Image | undefined;
const state = getState();
if (type === 'results') {
image = selectResultsById(state, name);
} else if (type === 'uploads') {
image = selectUploadsById(state, name);
}
if (!image) {
dispatch(
addToast(
makeToast({ title: t('toast.imageNotLoadedDesc'), status: 'error' })
)
);
return;
}
dispatch(initialImageChanged(image));
dispatch(addToast(makeToast(t('toast.sentToImageToImage'))));
},
});
};

View File

@ -0,0 +1,88 @@
import { invocationComplete } from 'services/events/actions';
import { isImageOutput } from 'services/types/guards';
import {
buildImageUrls,
extractTimestampFromImageName,
} from 'services/util/deserializeImageField';
import { Image } from 'app/types/invokeai';
import { resultAdded } from 'features/gallery/store/resultsSlice';
import { imageReceived, thumbnailReceived } from 'services/thunks/image';
import { startAppListening } from '..';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
const nodeDenylist = ['dataURL_image'];
export const addImageResultReceivedListener = () => {
startAppListening({
predicate: (action) => {
if (
invocationComplete.match(action) &&
isImageOutput(action.payload.data.result)
) {
return true;
}
return false;
},
effect: (action, { getState, dispatch }) => {
if (!invocationComplete.match(action)) {
return;
}
const { data, shouldFetchImages } = action.payload;
const { result, node, graph_execution_state_id } = data;
if (isImageOutput(result) && !nodeDenylist.includes(node.type)) {
const name = result.image.image_name;
const type = result.image.image_type;
const state = getState();
// if we need to refetch, set URLs to placeholder for now
const { url, thumbnail } = shouldFetchImages
? { url: '', thumbnail: '' }
: buildImageUrls(type, name);
const timestamp = extractTimestampFromImageName(name);
const image: Image = {
name,
type,
url,
thumbnail,
metadata: {
created: timestamp,
width: result.width,
height: result.height,
invokeai: {
session_id: graph_execution_state_id,
...(node ? { node } : {}),
},
},
};
dispatch(resultAdded(image));
if (state.gallery.shouldAutoSwitchToNewImages) {
dispatch(imageSelected(image));
}
if (state.config.shouldFetchImages) {
dispatch(imageReceived({ imageName: name, imageType: type }));
dispatch(
thumbnailReceived({
thumbnailName: name,
thumbnailType: type,
})
);
}
if (
graph_execution_state_id ===
state.canvas.layerState.stagingArea.sessionId
) {
dispatch(addImageToStagingArea(image));
}
}
},
});
};

View File

@ -0,0 +1,126 @@
import { startAppListening } from '..';
import { sessionCreated, sessionInvoked } from 'services/thunks/session';
import { buildCanvasGraphAndBlobs } from 'features/nodes/util/graphBuilders/buildCanvasGraph';
import { log } from 'app/logging/useLogger';
import { canvasGraphBuilt } from 'features/nodes/store/actions';
import { imageUploaded } from 'services/thunks/image';
import { v4 as uuidv4 } from 'uuid';
import { Graph } from 'services/api';
import {
canvasSessionIdChanged,
stagingAreaInitialized,
} from 'features/canvas/store/canvasSlice';
import { userInvoked } from 'app/store/actions';
const moduleLog = log.child({ namespace: 'invoke' });
export const addUserInvokedCanvasListener = () => {
startAppListening({
predicate: (action): action is ReturnType<typeof userInvoked> =>
userInvoked.match(action) && action.payload === 'unifiedCanvas',
effect: async (action, { getState, dispatch, take }) => {
const state = getState();
const data = await buildCanvasGraphAndBlobs(state);
if (!data) {
moduleLog.error('Problem building graph');
return;
}
const {
rangeNode,
iterateNode,
baseNode,
edges,
baseBlob,
maskBlob,
generationMode,
} = data;
const baseFilename = `${uuidv4()}.png`;
const maskFilename = `${uuidv4()}.png`;
dispatch(
imageUploaded({
imageType: 'intermediates',
formData: {
file: new File([baseBlob], baseFilename, { type: 'image/png' }),
},
})
);
if (baseNode.type === 'img2img' || baseNode.type === 'inpaint') {
const [{ payload: basePayload }] = await take(
(action): action is ReturnType<typeof imageUploaded.fulfilled> =>
imageUploaded.fulfilled.match(action) &&
action.meta.arg.formData.file.name === baseFilename
);
const { image_name: baseName, image_type: baseType } =
basePayload.response;
baseNode.image = {
image_name: baseName,
image_type: baseType,
};
}
if (baseNode.type === 'inpaint') {
dispatch(
imageUploaded({
imageType: 'intermediates',
formData: {
file: new File([maskBlob], maskFilename, { type: 'image/png' }),
},
})
);
const [{ payload: maskPayload }] = await take(
(action): action is ReturnType<typeof imageUploaded.fulfilled> =>
imageUploaded.fulfilled.match(action) &&
action.meta.arg.formData.file.name === maskFilename
);
const { image_name: maskName, image_type: maskType } =
maskPayload.response;
baseNode.mask = {
image_name: maskName,
image_type: maskType,
};
}
// Assemble!
const nodes: Graph['nodes'] = {
[rangeNode.id]: rangeNode,
[iterateNode.id]: iterateNode,
[baseNode.id]: baseNode,
};
const graph = { nodes, edges };
dispatch(canvasGraphBuilt(graph));
moduleLog({ data: graph }, 'Canvas graph built');
dispatch(sessionCreated({ graph }));
const [{ meta }] = await take(sessionInvoked.fulfilled.match);
const { sessionId } = meta.arg;
if (!state.canvas.layerState.stagingArea.boundingBox) {
dispatch(
stagingAreaInitialized({
sessionId,
boundingBox: {
...state.canvas.boundingBoxCoordinates,
...state.canvas.boundingBoxDimensions,
},
})
);
}
dispatch(canvasSessionIdChanged(sessionId));
},
});
};

View File

@ -0,0 +1,24 @@
import { startAppListening } from '..';
import { buildImageToImageGraph } from 'features/nodes/util/graphBuilders/buildImageToImageGraph';
import { sessionCreated } from 'services/thunks/session';
import { log } from 'app/logging/useLogger';
import { imageToImageGraphBuilt } from 'features/nodes/store/actions';
import { userInvoked } from 'app/store/actions';
const moduleLog = log.child({ namespace: 'invoke' });
export const addUserInvokedImageToImageListener = () => {
startAppListening({
predicate: (action): action is ReturnType<typeof userInvoked> =>
userInvoked.match(action) && action.payload === 'img2img',
effect: (action, { getState, dispatch }) => {
const state = getState();
const graph = buildImageToImageGraph(state);
dispatch(imageToImageGraphBuilt(graph));
moduleLog({ data: graph }, 'Image to Image graph built');
dispatch(sessionCreated({ graph }));
},
});
};

View File

@ -0,0 +1,24 @@
import { startAppListening } from '..';
import { sessionCreated } from 'services/thunks/session';
import { buildNodesGraph } from 'features/nodes/util/graphBuilders/buildNodesGraph';
import { log } from 'app/logging/useLogger';
import { nodesGraphBuilt } from 'features/nodes/store/actions';
import { userInvoked } from 'app/store/actions';
const moduleLog = log.child({ namespace: 'invoke' });
export const addUserInvokedNodesListener = () => {
startAppListening({
predicate: (action): action is ReturnType<typeof userInvoked> =>
userInvoked.match(action) && action.payload === 'nodes',
effect: (action, { getState, dispatch }) => {
const state = getState();
const graph = buildNodesGraph(state);
dispatch(nodesGraphBuilt(graph));
moduleLog({ data: graph }, 'Nodes graph built');
dispatch(sessionCreated({ graph }));
},
});
};

View File

@ -0,0 +1,24 @@
import { startAppListening } from '..';
import { buildTextToImageGraph } from 'features/nodes/util/graphBuilders/buildTextToImageGraph';
import { sessionCreated } from 'services/thunks/session';
import { log } from 'app/logging/useLogger';
import { textToImageGraphBuilt } from 'features/nodes/store/actions';
import { userInvoked } from 'app/store/actions';
const moduleLog = log.child({ namespace: 'invoke' });
export const addUserInvokedTextToImageListener = () => {
startAppListening({
predicate: (action): action is ReturnType<typeof userInvoked> =>
userInvoked.match(action) && action.payload === 'txt2img',
effect: (action, { getState, dispatch }) => {
const state = getState();
const graph = buildTextToImageGraph(state);
dispatch(textToImageGraphBuilt(graph));
moduleLog({ data: graph }, 'Text to Image graph built');
dispatch(sessionCreated({ graph }));
},
});
};

View File

@ -1,4 +0,0 @@
import { store } from 'app/store/store';
import { persistStore } from 'redux-persist';
export const persistor = persistStore(store);

View File

@ -1,9 +1,12 @@
import { combineReducers, configureStore } from '@reduxjs/toolkit'; import {
AnyAction,
ThunkDispatch,
combineReducers,
configureStore,
} from '@reduxjs/toolkit';
import { persistReducer } from 'redux-persist'; import { rememberReducer, rememberEnhancer } from 'redux-remember';
import storage from 'redux-persist/lib/storage'; // defaults to localStorage for web
import dynamicMiddlewares from 'redux-dynamic-middlewares'; import dynamicMiddlewares from 'redux-dynamic-middlewares';
import { getPersistConfig } from 'redux-deep-persist';
import canvasReducer from 'features/canvas/store/canvasSlice'; import canvasReducer from 'features/canvas/store/canvasSlice';
import galleryReducer from 'features/gallery/store/gallerySlice'; import galleryReducer from 'features/gallery/store/gallerySlice';
@ -19,33 +22,17 @@ import hotkeysReducer from 'features/ui/store/hotkeysSlice';
import modelsReducer from 'features/system/store/modelSlice'; import modelsReducer from 'features/system/store/modelSlice';
import nodesReducer from 'features/nodes/store/nodesSlice'; import nodesReducer from 'features/nodes/store/nodesSlice';
import { canvasDenylist } from 'features/canvas/store/canvasPersistDenylist'; import { listenerMiddleware } from './middleware/listenerMiddleware';
import { galleryDenylist } from 'features/gallery/store/galleryPersistDenylist';
import { generationDenylist } from 'features/parameters/store/generationPersistDenylist';
import { lightboxDenylist } from 'features/lightbox/store/lightboxPersistDenylist';
import { modelsDenylist } from 'features/system/store/modelsPersistDenylist';
import { nodesDenylist } from 'features/nodes/store/nodesPersistDenylist';
import { postprocessingDenylist } from 'features/parameters/store/postprocessingPersistDenylist';
import { systemDenylist } from 'features/system/store/systemPersistDenylist';
import { uiDenylist } from 'features/ui/store/uiPersistDenylist';
import { resultsDenylist } from 'features/gallery/store/resultsPersistDenylist';
import { uploadsDenylist } from 'features/gallery/store/uploadsPersistDenylist';
/** import { actionSanitizer } from './middleware/devtools/actionSanitizer';
* redux-persist provides an easy and reliable way to persist state across reloads. import { stateSanitizer } from './middleware/devtools/stateSanitizer';
* import { actionsDenylist } from './middleware/devtools/actionsDenylist';
* While we definitely want generation parameters to be persisted, there are a number
* of things we do *not* want to be persisted across reloads:
* - Gallery/selected image (user may add/delete images from disk between page loads)
* - Connection/processing status
* - Availability of external libraries like ESRGAN/GFPGAN
*
* These can be denylisted in redux-persist.
*
* The necessary nested persistors with denylists are configured below.
*/
const rootReducer = combineReducers({ import { serialize } from './enhancers/reduxRemember/serialize';
import { unserialize } from './enhancers/reduxRemember/unserialize';
import { LOCALSTORAGE_PREFIX } from './constants';
const allReducers = {
canvas: canvasReducer, canvas: canvasReducer,
gallery: galleryReducer, gallery: galleryReducer,
generation: generationReducer, generation: generationReducer,
@ -59,65 +46,54 @@ const rootReducer = combineReducers({
ui: uiReducer, ui: uiReducer,
uploads: uploadsReducer, uploads: uploadsReducer,
hotkeys: hotkeysReducer, hotkeys: hotkeysReducer,
}); };
const rootPersistConfig = getPersistConfig({ const rootReducer = combineReducers(allReducers);
key: 'root',
storage,
rootReducer,
blacklist: [
...canvasDenylist,
...galleryDenylist,
...generationDenylist,
...lightboxDenylist,
...modelsDenylist,
...nodesDenylist,
...postprocessingDenylist,
// ...resultsDenylist,
'results',
...systemDenylist,
...uiDenylist,
// ...uploadsDenylist,
'uploads',
'hotkeys',
'config',
],
});
const persistedReducer = persistReducer(rootPersistConfig, rootReducer); const rememberedRootReducer = rememberReducer(rootReducer);
// TODO: rip the old middleware out when nodes is complete const rememberedKeys: (keyof typeof allReducers)[] = [
// export function buildMiddleware() { 'canvas',
// if (import.meta.env.MODE === 'nodes' || import.meta.env.MODE === 'package') { 'gallery',
// return socketMiddleware(); 'generation',
// } else { 'lightbox',
// return socketioMiddleware(); // 'models',
// } 'nodes',
// } 'postprocessing',
'system',
'ui',
// 'hotkeys',
// 'results',
// 'uploads',
// 'config',
];
export const store = configureStore({ export const store = configureStore({
reducer: persistedReducer, reducer: rememberedRootReducer,
enhancers: [
rememberEnhancer(window.localStorage, rememberedKeys, {
persistDebounce: 300,
serialize,
unserialize,
prefix: LOCALSTORAGE_PREFIX,
}),
],
middleware: (getDefaultMiddleware) => middleware: (getDefaultMiddleware) =>
getDefaultMiddleware({ getDefaultMiddleware({
immutableCheck: false, immutableCheck: false,
serializableCheck: false, serializableCheck: false,
}).concat(dynamicMiddlewares), })
.concat(dynamicMiddlewares)
.prepend(listenerMiddleware.middleware),
devTools: { devTools: {
// Uncommenting these very rapidly called actions makes the redux dev tools output much more readable actionsDenylist,
actionsDenylist: [ actionSanitizer,
'canvas/setCursorPosition', stateSanitizer,
'canvas/setStageCoordinates', trace: true,
'canvas/setStageScale',
'canvas/setIsDrawing',
'canvas/setBoundingBoxCoordinates',
'canvas/setBoundingBoxDimensions',
'canvas/setIsDrawing',
'canvas/addPointToCurrentLine',
'socket/generatorProgress',
],
}, },
}); });
export type AppGetState = typeof store.getState; export type AppGetState = typeof store.getState;
export type RootState = ReturnType<typeof store.getState>; export type RootState = ReturnType<typeof store.getState>;
export type AppThunkDispatch = ThunkDispatch<RootState, any, AnyAction>;
export type AppDispatch = typeof store.dispatch; export type AppDispatch = typeof store.dispatch;
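The serialize and unserialize functions passed to rememberEnhancer are imported from app/store/enhancers/reduxRemember and are not shown in this diff. A minimal sketch of the shape redux-remember expects, assuming plain JSON round-tripping and a per-key signature; both are assumptions, not the project's actual implementations:

// Sketch only: the real serialize/unserialize are not part of this hunk.
// redux-remember calls these per remembered slice before writing to and
// after reading from window.localStorage.
export const serialize = (data: unknown, _key: string): string =>
  JSON.stringify(data);

export const unserialize = (data: string, _key: string): unknown =>
  JSON.parse(data);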

View File

@ -1,6 +1,6 @@
import { TypedUseSelectorHook, useDispatch, useSelector } from 'react-redux'; import { TypedUseSelectorHook, useDispatch, useSelector } from 'react-redux';
import { AppDispatch, RootState } from 'app/store/store'; import { AppThunkDispatch, RootState } from 'app/store/store';
// Use throughout your app instead of plain `useDispatch` and `useSelector` // Use throughout your app instead of plain `useDispatch` and `useSelector`
export const useAppDispatch: () => AppDispatch = useDispatch; export const useAppDispatch = () => useDispatch<AppThunkDispatch>();
export const useAppSelector: TypedUseSelectorHook<RootState> = useSelector; export const useAppSelector: TypedUseSelectorHook<RootState> = useSelector;
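The switch to AppThunkDispatch means useAppDispatch can dispatch thunks without casts. An illustrative consumer, using the setShouldAntialias action added later in this commit:

import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { setShouldAntialias } from 'features/canvas/store/canvasSlice';

const AntialiasToggle = () => {
  // dispatch is typed as AppThunkDispatch; the selector sees RootState
  const dispatch = useAppDispatch();
  const shouldAntialias = useAppSelector(
    (state) => state.canvas.shouldAntialias
  );
  return (
    <input
      type="checkbox"
      checked={shouldAntialias}
      onChange={(e) => dispatch(setShouldAntialias(e.target.checked))}
    />
  );
};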

View File

@ -0,0 +1,7 @@
import { isEqual } from 'lodash-es';
export const defaultSelectorOptions = {
memoizeOptions: {
resultEqualityCheck: isEqual,
},
};
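defaultSelectorOptions is used throughout the rest of this commit as the third argument to createSelector. An illustrative selector; the selector itself is a sketch, only the option object comes from this diff:

import { createSelector } from '@reduxjs/toolkit';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { canvasSelector } from 'features/canvas/store/canvasSelectors';

// Deep-equality result check: consumers re-render only when the derived
// object actually changes, not on every new object identity.
export const boundingBoxSelector = createSelector(
  [canvasSelector],
  (canvas) => ({
    ...canvas.boundingBoxCoordinates,
    ...canvas.boundingBoxDimensions,
  }),
  defaultSelectorOptions
);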

View File

@ -12,12 +12,10 @@
* 'gfpgan'. * 'gfpgan'.
*/ */
import { GalleryCategory } from 'features/gallery/store/gallerySlice'; import { SelectedImage } from 'features/parameters/store/actions';
import { FacetoolType } from 'features/parameters/store/postprocessingSlice';
import { InvokeTabName } from 'features/ui/store/tabMap'; import { InvokeTabName } from 'features/ui/store/tabMap';
import { IRect } from 'konva/lib/types'; import { IRect } from 'konva/lib/types';
import { ImageResponseMetadata, ImageType } from 'services/api'; import { ImageResponseMetadata, ImageType } from 'services/api';
import { AnyInvocation } from 'services/events/types';
import { O } from 'ts-toolbelt'; import { O } from 'ts-toolbelt';
/** /**
@ -49,15 +47,20 @@ export type CommonGeneratedImageMetadata = {
postprocessing: null | Array<ESRGANMetadata | FacetoolMetadata>; postprocessing: null | Array<ESRGANMetadata | FacetoolMetadata>;
sampler: sampler:
| 'ddim' | 'ddim'
| 'k_dpm_2_a' | 'ddpm'
| 'k_dpm_2' | 'deis'
| 'k_dpmpp_2_a' | 'lms'
| 'k_dpmpp_2' | 'pndm'
| 'k_euler_a' | 'heun'
| 'k_euler' | 'euler'
| 'k_heun' | 'euler_k'
| 'k_lms' | 'euler_a'
| 'plms'; | 'kdpm_2'
| 'kdpm_2_a'
| 'dpmpp_2s'
| 'dpmpp_2m'
| 'dpmpp_2m_k'
| 'unipc';
prompt: Prompt; prompt: Prompt;
seed: number; seed: number;
variations: SeedWeights; variations: SeedWeights;
@ -126,6 +129,14 @@ export type Image = {
metadata: ImageResponseMetadata; metadata: ImageResponseMetadata;
}; };
export const isInvokeAIImage = (obj: Image | SelectedImage): obj is Image => {
if ('url' in obj && 'thumbnail' in obj) {
return true;
}
return false;
};
/** /**
* Types related to the system status. * Types related to the system status.
*/ */
@ -270,7 +281,7 @@ export type FoundModelResponse = {
// export type SystemConfigResponse = SystemConfig; // export type SystemConfigResponse = SystemConfig;
export type ImageResultResponse = Omit<_Image, 'uuid'> & { export type ImageResultResponse = Omit<Image, 'uuid'> & {
boundingBox?: IRect; boundingBox?: IRect;
generationMode: InvokeTabName; generationMode: InvokeTabName;
}; };
@ -315,11 +326,11 @@ export type AppFeature =
/** /**
* A disable-able Stable Diffusion feature * A disable-able Stable Diffusion feature
*/ */
export type StableDiffusionFeature = export type SDFeature =
| 'noiseConfig' | 'noise'
| 'variations' | 'variation'
| 'symmetry' | 'symmetry'
| 'tiling' | 'seamless'
| 'hires'; | 'hires';
/** /**
@ -337,6 +348,7 @@ export type AppConfig = {
shouldFetchImages: boolean; shouldFetchImages: boolean;
disabledTabs: InvokeTabName[]; disabledTabs: InvokeTabName[];
disabledFeatures: AppFeature[]; disabledFeatures: AppFeature[];
disabledSDFeatures: SDFeature[];
canRestoreDeletedImagesFromBin: boolean; canRestoreDeletedImagesFromBin: boolean;
sd: { sd: {
iterations: { iterations: {

View File

@ -0,0 +1,61 @@
import { ChevronUpIcon } from '@chakra-ui/icons';
import { Box, Collapse, Flex, Spacer, Switch } from '@chakra-ui/react';
import { PropsWithChildren, memo } from 'react';
export type IAIToggleCollapseProps = PropsWithChildren & {
label: string;
isOpen: boolean;
onToggle: () => void;
withSwitch?: boolean;
};
const IAICollapse = (props: IAIToggleCollapseProps) => {
const { label, isOpen, onToggle, children, withSwitch = false } = props;
return (
<Box>
<Flex
onClick={onToggle}
sx={{
alignItems: 'center',
p: 2,
px: 4,
borderTopRadius: 'base',
borderBottomRadius: isOpen ? 0 : 'base',
bg: isOpen ? 'base.750' : 'base.800',
color: 'base.100',
_hover: {
bg: isOpen ? 'base.700' : 'base.750',
},
fontSize: 'sm',
fontWeight: 600,
cursor: 'pointer',
transitionProperty: 'common',
transitionDuration: 'normal',
userSelect: 'none',
}}
>
{label}
<Spacer />
{withSwitch && <Switch isChecked={isOpen} pointerEvents="none" />}
{!withSwitch && (
<ChevronUpIcon
sx={{
w: '1rem',
h: '1rem',
transform: isOpen ? 'rotate(0deg)' : 'rotate(180deg)',
transitionProperty: 'common',
transitionDuration: 'normal',
}}
/>
)}
</Flex>
<Collapse in={isOpen} animateOpacity>
<Box sx={{ p: 4, borderBottomRadius: 'base', bg: 'base.800' }}>
{children}
</Box>
</Collapse>
</Box>
);
};
export default memo(IAICollapse);
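An illustrative usage of IAICollapse, pairing it with Chakra's useDisclosure; the component's import path is assumed from the common/components convention used elsewhere in this commit:

import { useDisclosure } from '@chakra-ui/react';
import IAICollapse from 'common/components/IAICollapse';

const AdvancedSection = () => {
  const { isOpen, onToggle } = useDisclosure();
  return (
    <IAICollapse label="Advanced" isOpen={isOpen} onToggle={onToggle} withSwitch>
      {/* collapsible content */}
      <div />
    </IAICollapse>
  );
};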

View File

@ -27,7 +27,7 @@ const IAIPopover = (props: IAIPopoverProps) => {
return ( return (
<Popover isLazy={isLazy} {...rest}> <Popover isLazy={isLazy} {...rest}>
<PopoverTrigger>{triggerComponent}</PopoverTrigger> <PopoverTrigger>{triggerComponent}</PopoverTrigger>
<PopoverContent> <PopoverContent shadow="dark-lg">
{hasArrow && <PopoverArrow />} {hasArrow && <PopoverArrow />}
{children} {children}
</PopoverContent> </PopoverContent>

View File

@ -0,0 +1,54 @@
import { Badge, Flex } from '@chakra-ui/react';
import { Image } from 'app/types/invokeai';
import { isNumber, isString } from 'lodash-es';
import { useMemo } from 'react';
type ImageMetadataOverlayProps = {
image: Image;
};
const ImageMetadataOverlay = ({ image }: ImageMetadataOverlayProps) => {
const dimensions = useMemo(() => {
    if (!isNumber(image.metadata?.width) || !isNumber(image.metadata?.height)) {
return;
}
return `${image.metadata?.width} × ${image.metadata?.height}`;
}, [image.metadata]);
const model = useMemo(() => {
if (!isString(image.metadata?.invokeai?.node?.model)) {
return;
}
return image.metadata?.invokeai?.node?.model;
}, [image.metadata]);
return (
<Flex
sx={{
pointerEvents: 'none',
flexDirection: 'column',
position: 'absolute',
top: 0,
right: 0,
p: 2,
alignItems: 'flex-end',
gap: 2,
}}
>
{dimensions && (
<Badge variant="solid" colorScheme="base">
{dimensions}
</Badge>
)}
{model && (
<Badge variant="solid" colorScheme="base">
{model}
</Badge>
)}
</Flex>
);
};
export default ImageMetadataOverlay;
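ImageMetadataOverlay positions itself absolutely, so it expects a relatively positioned wrapper around the image. An illustrative usage; the component's import path is assumed:

import { Box, Image as ChakraImage } from '@chakra-ui/react';
import { Image } from 'app/types/invokeai';
import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay';

const ImageWithOverlay = ({ image }: { image: Image }) => (
  <Box sx={{ position: 'relative' }}>
    <ChakraImage src={image.url} />
    <ImageMetadataOverlay image={image} />
  </Box>
);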

View File

@ -7,7 +7,7 @@ import { useAppDispatch } from 'app/store/storeHooks';
import { useCallback } from 'react'; import { useCallback } from 'react';
import { clearInitialImage } from 'features/parameters/store/generationSlice'; import { clearInitialImage } from 'features/parameters/store/generationSlice';
const ImageToImageSettingsHeader = () => { const InitialImageButtons = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
@ -18,24 +18,19 @@ const ImageToImageSettingsHeader = () => {
return ( return (
<Flex w="full" alignItems="center"> <Flex w="full" alignItems="center">
<Text size="sm" fontWeight={500} color="base.300"> <Text size="sm" fontWeight={500} color="base.300">
Image to Image {t('parameters.initialImage')}
</Text> </Text>
<Spacer /> <Spacer />
<ButtonGroup> <ButtonGroup>
<IAIIconButton <IAIIconButton
size="sm"
icon={<FaUndo />} icon={<FaUndo />}
aria-label={t('accessibility.reset')} aria-label={t('accessibility.reset')}
onClick={handleResetInitialImage} onClick={handleResetInitialImage}
/> />
<IAIIconButton <IAIIconButton icon={<FaUpload />} aria-label={t('common.upload')} />
size="sm"
icon={<FaUpload />}
aria-label={t('common.upload')}
/>
</ButtonGroup> </ButtonGroup>
</Flex> </Flex>
); );
}; };
export default ImageToImageSettingsHeader; export default InitialImageButtons;

View File

@ -1,36 +0,0 @@
import { Badge, Box, Flex } from '@chakra-ui/react';
import { Image } from 'app/types/invokeai';
type ImageToImageOverlayProps = {
image: Image;
};
const ImageToImageOverlay = ({ image }: ImageToImageOverlayProps) => {
return (
<Box
sx={{
top: 0,
left: 0,
w: 'full',
h: 'full',
position: 'absolute',
}}
>
<Flex
sx={{
position: 'absolute',
top: 0,
right: 0,
p: 2,
alignItems: 'flex-start',
}}
>
<Badge variant="solid" colorScheme="base">
{image.metadata?.width} × {image.metadata?.height}
</Badge>
</Flex>
</Box>
);
};
export default ImageToImageOverlay;

View File

@ -49,7 +49,7 @@ const ImageUploader = (props: ImageUploaderProps) => {
const fileAcceptedCallback = useCallback( const fileAcceptedCallback = useCallback(
async (file: File) => { async (file: File) => {
dispatch(imageUploaded({ formData: { file } })); dispatch(imageUploaded({ imageType: 'uploads', formData: { file } }));
}, },
[dispatch] [dispatch]
); );
@ -124,7 +124,7 @@ const ImageUploader = (props: ImageUploaderProps) => {
return; return;
} }
dispatch(imageUploaded({ formData: { file } })); dispatch(imageUploaded({ imageType: 'uploads', formData: { file } }));
}; };
document.addEventListener('paste', pasteImageListener); document.addEventListener('paste', pasteImageListener);
return () => { return () => {

View File

@ -7,7 +7,7 @@ const SelectImagePlaceholder = () => {
sx={{ sx={{
w: 'full', w: 'full',
h: 'full', h: 'full',
bg: 'base.800', // bg: 'base.800',
borderRadius: 'base', borderRadius: 'base',
alignItems: 'center', alignItems: 'center',
justifyContent: 'center', justifyContent: 'center',

View File

@ -2,6 +2,13 @@ import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice'; import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
import {
setActiveTab,
toggleGalleryPanel,
toggleParametersPanel,
togglePinGalleryPanel,
togglePinParametersPanel,
} from 'features/ui/store/uiSlice';
import { isEqual } from 'lodash-es'; import { isEqual } from 'lodash-es';
import { isHotkeyPressed, useHotkeys } from 'react-hotkeys-hook'; import { isHotkeyPressed, useHotkeys } from 'react-hotkeys-hook';
@ -36,4 +43,36 @@ export const useGlobalHotkeys = () => {
{ keyup: true, keydown: true }, { keyup: true, keydown: true },
[shift] [shift]
); );
useHotkeys('o', () => {
dispatch(toggleParametersPanel());
});
useHotkeys(['shift+o'], () => {
dispatch(togglePinParametersPanel());
});
useHotkeys('g', () => {
dispatch(toggleGalleryPanel());
});
useHotkeys(['shift+g'], () => {
dispatch(togglePinGalleryPanel());
});
useHotkeys('1', () => {
dispatch(setActiveTab('txt2img'));
});
useHotkeys('2', () => {
dispatch(setActiveTab('img2img'));
});
useHotkeys('3', () => {
dispatch(setActiveTab('unifiedCanvas'));
});
useHotkeys('4', () => {
dispatch(setActiveTab('nodes'));
});
}; };

View File

@ -0,0 +1,33 @@
export const getImageDataTransparency = (pixels: Uint8ClampedArray) => {
let isFullyTransparent = true;
let isPartiallyTransparent = false;
const len = pixels.length;
let i = 3;
for (i; i < len; i += 4) {
if (pixels[i] === 255) {
isFullyTransparent = false;
} else {
isPartiallyTransparent = true;
}
if (!isFullyTransparent && isPartiallyTransparent) {
return { isFullyTransparent, isPartiallyTransparent };
}
}
return { isFullyTransparent, isPartiallyTransparent };
};
export const areAnyPixelsBlack = (pixels: Uint8ClampedArray) => {
const len = pixels.length;
// step through RGBA quads explicitly; advancing the index inside the
// condition would drift out of alignment whenever a comparison short-circuits
for (let i = 0; i < len; i += 4) {
if (
pixels[i] === 0 &&
pixels[i + 1] === 0 &&
pixels[i + 2] === 0 &&
pixels[i + 3] === 255
) {
return true;
}
}
return false;
};
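A short illustrative consumer of the two helpers above, pulling RGBA data from a canvas 2D context; the import path matches the one used by getCanvasData later in this commit:

import {
  areAnyPixelsBlack,
  getImageDataTransparency,
} from 'common/util/arrayBuffer';

// Assumes the canvas already has content drawn to it.
const inspectCanvasPixels = (canvas: HTMLCanvasElement) => {
  const ctx = canvas.getContext('2d');
  if (!ctx) {
    return;
  }
  const { data } = ctx.getImageData(0, 0, canvas.width, canvas.height);
  return {
    ...getImageDataTransparency(data),
    hasBlackPixels: areAnyPixelsBlack(data),
  };
};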

View File

@ -19,6 +19,7 @@ import { InvokeTabName } from 'features/ui/store/tabMap';
import openBase64ImageInTab from './openBase64ImageInTab'; import openBase64ImageInTab from './openBase64ImageInTab';
import randomInt from './randomInt'; import randomInt from './randomInt';
import { stringToSeedWeightsArray } from './seedWeightPairs'; import { stringToSeedWeightsArray } from './seedWeightPairs';
import { getIsImageDataTransparent, getIsImageDataWhite } from './arrayBuffer';
export type FrontendToBackendParametersConfig = { export type FrontendToBackendParametersConfig = {
generationMode: InvokeTabName; generationMode: InvokeTabName;
@ -256,7 +257,7 @@ export const frontendToBackendParameters = (
...boundingBoxDimensions, ...boundingBoxDimensions,
}; };
const maskDataURL = generateMask( const { dataURL: maskDataURL, imageData: maskImageData } = generateMask(
isMaskEnabled ? objects.filter(isCanvasMaskLine) : [], isMaskEnabled ? objects.filter(isCanvasMaskLine) : [],
boundingBox boundingBox
); );
@ -287,6 +288,17 @@ export const frontendToBackendParameters = (
height: boundingBox.height, height: boundingBox.height,
}); });
const ctx = canvasBaseLayer.getContext();
const imageData = ctx.getImageData(
boundingBox.x + absPos.x,
boundingBox.y + absPos.y,
boundingBox.width,
boundingBox.height
);
const doesBaseHaveTransparency = getIsImageDataTransparent(imageData);
const doesMaskHaveTransparency = getIsImageDataWhite(maskImageData);
if (enableImageDebugging) { if (enableImageDebugging) {
openBase64ImageInTab([ openBase64ImageInTab([
{ base64: maskDataURL, caption: 'mask sent as init_mask' }, { base64: maskDataURL, caption: 'mask sent as init_mask' },

View File

@ -34,6 +34,7 @@ import IAICanvasStagingAreaToolbar from './IAICanvasStagingAreaToolbar';
import IAICanvasStatusText from './IAICanvasStatusText'; import IAICanvasStatusText from './IAICanvasStatusText';
import IAICanvasBoundingBox from './IAICanvasToolbar/IAICanvasBoundingBox'; import IAICanvasBoundingBox from './IAICanvasToolbar/IAICanvasBoundingBox';
import IAICanvasToolPreview from './IAICanvasToolPreview'; import IAICanvasToolPreview from './IAICanvasToolPreview';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
const selector = createSelector( const selector = createSelector(
[canvasSelector, isStagingSelector], [canvasSelector, isStagingSelector],
@ -52,6 +53,7 @@ const selector = createSelector(
shouldShowIntermediates, shouldShowIntermediates,
shouldShowGrid, shouldShowGrid,
shouldRestrictStrokesToBox, shouldRestrictStrokesToBox,
shouldAntialias,
} = canvas; } = canvas;
let stageCursor: string | undefined = 'none'; let stageCursor: string | undefined = 'none';
@ -80,13 +82,10 @@ const selector = createSelector(
tool, tool,
isStaging, isStaging,
shouldShowIntermediates, shouldShowIntermediates,
shouldAntialias,
}; };
}, },
{ defaultSelectorOptions
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
); );
const ChakraStage = chakra(Stage, { const ChakraStage = chakra(Stage, {
@ -106,6 +105,7 @@ const IAICanvas = () => {
tool, tool,
isStaging, isStaging,
shouldShowIntermediates, shouldShowIntermediates,
shouldAntialias,
} = useAppSelector(selector); } = useAppSelector(selector);
useCanvasHotkeys(); useCanvasHotkeys();
@ -190,7 +190,7 @@ const IAICanvas = () => {
id="base" id="base"
ref={canvasBaseLayerRefCallback} ref={canvasBaseLayerRefCallback}
listening={false} listening={false}
imageSmoothingEnabled={false} imageSmoothingEnabled={shouldAntialias}
> >
<IAICanvasObjectRenderer /> <IAICanvasObjectRenderer />
</Layer> </Layer>
@ -201,7 +201,7 @@ const IAICanvas = () => {
<Layer> <Layer>
<IAICanvasBoundingBoxOverlay /> <IAICanvasBoundingBoxOverlay />
</Layer> </Layer>
<Layer id="preview" imageSmoothingEnabled={false}> <Layer id="preview" imageSmoothingEnabled={shouldAntialias}>
{!isStaging && ( {!isStaging && (
<IAICanvasToolPreview <IAICanvasToolPreview
visible={tool !== 'move'} visible={tool !== 'move'}

View File

@ -12,18 +12,20 @@ const selector = createSelector(
[canvasSelector], [canvasSelector],
(canvas) => { (canvas) => {
const { const {
layerState: { layerState,
stagingArea: { images, selectedImageIndex },
},
shouldShowStagingImage, shouldShowStagingImage,
shouldShowStagingOutline, shouldShowStagingOutline,
boundingBoxCoordinates: { x, y }, boundingBoxCoordinates: { x, y },
boundingBoxDimensions: { width, height }, boundingBoxDimensions: { width, height },
} = canvas; } = canvas;
const { selectedImageIndex, images } = layerState.stagingArea;
return { return {
currentStagingAreaImage: currentStagingAreaImage:
images.length > 0 ? images[selectedImageIndex] : undefined, images.length > 0 && selectedImageIndex !== undefined
? images[selectedImageIndex]
: undefined,
isOnFirstImage: selectedImageIndex === 0, isOnFirstImage: selectedImageIndex === 0,
isOnLastImage: selectedImageIndex === images.length - 1, isOnLastImage: selectedImageIndex === images.length - 1,
shouldShowStagingImage, shouldShowStagingImage,

View File

@ -6,6 +6,7 @@ import IAIIconButton from 'common/components/IAIIconButton';
import IAIPopover from 'common/components/IAIPopover'; import IAIPopover from 'common/components/IAIPopover';
import { canvasSelector } from 'features/canvas/store/canvasSelectors'; import { canvasSelector } from 'features/canvas/store/canvasSelectors';
import { import {
setShouldAntialias,
setShouldAutoSave, setShouldAutoSave,
setShouldCropToBoundingBoxOnSave, setShouldCropToBoundingBoxOnSave,
setShouldDarkenOutsideBoundingBox, setShouldDarkenOutsideBoundingBox,
@ -36,6 +37,7 @@ export const canvasControlsSelector = createSelector(
shouldShowIntermediates, shouldShowIntermediates,
shouldSnapToGrid, shouldSnapToGrid,
shouldRestrictStrokesToBox, shouldRestrictStrokesToBox,
shouldAntialias,
} = canvas; } = canvas;
return { return {
@ -47,6 +49,7 @@ export const canvasControlsSelector = createSelector(
shouldShowIntermediates, shouldShowIntermediates,
shouldSnapToGrid, shouldSnapToGrid,
shouldRestrictStrokesToBox, shouldRestrictStrokesToBox,
shouldAntialias,
}; };
}, },
{ {
@ -69,6 +72,7 @@ const IAICanvasSettingsButtonPopover = () => {
shouldShowIntermediates, shouldShowIntermediates,
shouldSnapToGrid, shouldSnapToGrid,
shouldRestrictStrokesToBox, shouldRestrictStrokesToBox,
shouldAntialias,
} = useAppSelector(canvasControlsSelector); } = useAppSelector(canvasControlsSelector);
useHotkeys( useHotkeys(
@ -148,6 +152,12 @@ const IAICanvasSettingsButtonPopover = () => {
dispatch(setShouldShowCanvasDebugInfo(e.target.checked)) dispatch(setShouldShowCanvasDebugInfo(e.target.checked))
} }
/> />
<IAICheckbox
label={t('unifiedCanvas.antialiasing')}
isChecked={shouldAntialias}
onChange={(e) => dispatch(setShouldAntialias(e.target.checked))}
/>
<ClearCanvasHistoryButtonModal /> <ClearCanvasHistoryButtonModal />
<EmptyTempFolderButtonModal /> <EmptyTempFolderButtonModal />
</Flex> </Flex>

View File

@ -9,6 +9,12 @@ const itemsToDenylist: (keyof CanvasState)[] = [
'doesCanvasNeedScaling', 'doesCanvasNeedScaling',
]; ];
export const canvasPersistDenylist: (keyof CanvasState)[] = [
'cursorPosition',
'isCanvasInitialized',
'doesCanvasNeedScaling',
];
export const canvasDenylist = itemsToDenylist.map( export const canvasDenylist = itemsToDenylist.map(
(denylistItem) => `canvas.${denylistItem}` (denylistItem) => `canvas.${denylistItem}`
); );

View File

@ -38,7 +38,7 @@ export const initialLayerState: CanvasLayerState = {
}, },
}; };
const initialCanvasState: CanvasState = { export const initialCanvasState: CanvasState = {
boundingBoxCoordinates: { x: 0, y: 0 }, boundingBoxCoordinates: { x: 0, y: 0 },
boundingBoxDimensions: { width: 512, height: 512 }, boundingBoxDimensions: { width: 512, height: 512 },
boundingBoxPreviewFill: { r: 0, g: 0, b: 0, a: 0.5 }, boundingBoxPreviewFill: { r: 0, g: 0, b: 0, a: 0.5 },
@ -66,6 +66,7 @@ const initialCanvasState: CanvasState = {
minimumStageScale: 1, minimumStageScale: 1,
pastLayerStates: [], pastLayerStates: [],
scaledBoundingBoxDimensions: { width: 512, height: 512 }, scaledBoundingBoxDimensions: { width: 512, height: 512 },
shouldAntialias: true,
shouldAutoSave: false, shouldAutoSave: false,
shouldCropToBoundingBoxOnSave: false, shouldCropToBoundingBoxOnSave: false,
shouldDarkenOutsideBoundingBox: false, shouldDarkenOutsideBoundingBox: false,
@ -156,22 +157,20 @@ export const canvasSlice = createSlice({
setCursorPosition: (state, action: PayloadAction<Vector2d | null>) => { setCursorPosition: (state, action: PayloadAction<Vector2d | null>) => {
state.cursorPosition = action.payload; state.cursorPosition = action.payload;
}, },
setInitialCanvasImage: (state, action: PayloadAction<InvokeAI._Image>) => { setInitialCanvasImage: (state, action: PayloadAction<InvokeAI.Image>) => {
const image = action.payload; const image = action.payload;
const { width, height } = image.metadata;
const { stageDimensions } = state; const { stageDimensions } = state;
const newBoundingBoxDimensions = { const newBoundingBoxDimensions = {
width: roundDownToMultiple(clamp(image.width, 64, 512), 64), width: roundDownToMultiple(clamp(width, 64, 512), 64),
height: roundDownToMultiple(clamp(image.height, 64, 512), 64), height: roundDownToMultiple(clamp(height, 64, 512), 64),
}; };
const newBoundingBoxCoordinates = { const newBoundingBoxCoordinates = {
x: roundToMultiple( x: roundToMultiple(width / 2 - newBoundingBoxDimensions.width / 2, 64),
image.width / 2 - newBoundingBoxDimensions.width / 2,
64
),
y: roundToMultiple( y: roundToMultiple(
image.height / 2 - newBoundingBoxDimensions.height / 2, height / 2 - newBoundingBoxDimensions.height / 2,
64 64
), ),
}; };
@ -196,8 +195,8 @@ export const canvasSlice = createSlice({
layer: 'base', layer: 'base',
x: 0, x: 0,
y: 0, y: 0,
width: image.width, width: width,
height: image.height, height: height,
image: image, image: image,
}, },
], ],
@ -208,8 +207,8 @@ export const canvasSlice = createSlice({
const newScale = calculateScale( const newScale = calculateScale(
stageDimensions.width, stageDimensions.width,
stageDimensions.height, stageDimensions.height,
image.width, width,
image.height, height,
STAGE_PADDING_PERCENTAGE STAGE_PADDING_PERCENTAGE
); );
@ -218,8 +217,8 @@ export const canvasSlice = createSlice({
stageDimensions.height, stageDimensions.height,
0, 0,
0, 0,
image.width, width,
image.height, height,
newScale newScale
); );
state.stageScale = newScale; state.stageScale = newScale;
@ -287,16 +286,28 @@ export const canvasSlice = createSlice({
setIsMoveStageKeyHeld: (state, action: PayloadAction<boolean>) => { setIsMoveStageKeyHeld: (state, action: PayloadAction<boolean>) => {
state.isMoveStageKeyHeld = action.payload; state.isMoveStageKeyHeld = action.payload;
}, },
addImageToStagingArea: ( canvasSessionIdChanged: (state, action: PayloadAction<string>) => {
state.layerState.stagingArea.sessionId = action.payload;
},
stagingAreaInitialized: (
state, state,
action: PayloadAction<{ action: PayloadAction<{ sessionId: string; boundingBox: IRect }>
boundingBox: IRect;
image: InvokeAI._Image;
}>
) => { ) => {
const { boundingBox, image } = action.payload; const { sessionId, boundingBox } = action.payload;
if (!boundingBox || !image) return; state.layerState.stagingArea = {
boundingBox,
sessionId,
images: [],
selectedImageIndex: -1,
};
},
addImageToStagingArea: (state, action: PayloadAction<InvokeAI.Image>) => {
const image = action.payload;
if (!image || !state.layerState.stagingArea.boundingBox) {
return;
}
state.pastLayerStates.push(cloneDeep(state.layerState)); state.pastLayerStates.push(cloneDeep(state.layerState));
@ -307,7 +318,7 @@ export const canvasSlice = createSlice({
state.layerState.stagingArea.images.push({ state.layerState.stagingArea.images.push({
kind: 'image', kind: 'image',
layer: 'base', layer: 'base',
...boundingBox, ...state.layerState.stagingArea.boundingBox,
image, image,
}); });
@ -323,9 +334,7 @@ export const canvasSlice = createSlice({
state.pastLayerStates.shift(); state.pastLayerStates.shift();
} }
state.layerState.stagingArea = { state.layerState.stagingArea = { ...initialLayerState.stagingArea };
...initialLayerState.stagingArea,
};
state.futureLayerStates = []; state.futureLayerStates = [];
state.shouldShowStagingOutline = true; state.shouldShowStagingOutline = true;
@ -663,6 +672,10 @@ export const canvasSlice = createSlice({
} }
}, },
nextStagingAreaImage: (state) => { nextStagingAreaImage: (state) => {
if (!state.layerState.stagingArea.images.length) {
return;
}
const currentIndex = state.layerState.stagingArea.selectedImageIndex; const currentIndex = state.layerState.stagingArea.selectedImageIndex;
const length = state.layerState.stagingArea.images.length; const length = state.layerState.stagingArea.images.length;
@ -672,6 +685,10 @@ export const canvasSlice = createSlice({
); );
}, },
prevStagingAreaImage: (state) => { prevStagingAreaImage: (state) => {
if (!state.layerState.stagingArea.images.length) {
return;
}
const currentIndex = state.layerState.stagingArea.selectedImageIndex; const currentIndex = state.layerState.stagingArea.selectedImageIndex;
state.layerState.stagingArea.selectedImageIndex = Math.max( state.layerState.stagingArea.selectedImageIndex = Math.max(
@ -680,6 +697,10 @@ export const canvasSlice = createSlice({
); );
}, },
commitStagingAreaImage: (state) => { commitStagingAreaImage: (state) => {
if (!state.layerState.stagingArea.images.length) {
return;
}
const { images, selectedImageIndex } = state.layerState.stagingArea; const { images, selectedImageIndex } = state.layerState.stagingArea;
state.pastLayerStates.push(cloneDeep(state.layerState)); state.pastLayerStates.push(cloneDeep(state.layerState));
@ -776,6 +797,9 @@ export const canvasSlice = createSlice({
setShouldRestrictStrokesToBox: (state, action: PayloadAction<boolean>) => { setShouldRestrictStrokesToBox: (state, action: PayloadAction<boolean>) => {
state.shouldRestrictStrokesToBox = action.payload; state.shouldRestrictStrokesToBox = action.payload;
}, },
setShouldAntialias: (state, action: PayloadAction<boolean>) => {
state.shouldAntialias = action.payload;
},
setShouldCropToBoundingBoxOnSave: ( setShouldCropToBoundingBoxOnSave: (
state, state,
action: PayloadAction<boolean> action: PayloadAction<boolean>
@ -885,6 +909,9 @@ export const {
undo, undo,
setScaledBoundingBoxDimensions, setScaledBoundingBoxDimensions,
setShouldRestrictStrokesToBox, setShouldRestrictStrokesToBox,
stagingAreaInitialized,
canvasSessionIdChanged,
setShouldAntialias,
} = canvasSlice.actions; } = canvasSlice.actions;
export default canvasSlice.reducer; export default canvasSlice.reducer;

View File

@ -37,7 +37,7 @@ export type CanvasImage = {
y: number; y: number;
width: number; width: number;
height: number; height: number;
image: InvokeAI._Image; image: InvokeAI.Image;
}; };
export type CanvasMaskLine = { export type CanvasMaskLine = {
@ -90,9 +90,16 @@ export type CanvasLayerState = {
stagingArea: { stagingArea: {
images: CanvasImage[]; images: CanvasImage[];
selectedImageIndex: number; selectedImageIndex: number;
sessionId?: string;
boundingBox?: IRect;
}; };
}; };
export type CanvasSession = {
sessionId: string;
boundingBox: IRect;
};
// type guards // type guards
export const isCanvasMaskLine = (obj: CanvasObject): obj is CanvasMaskLine => export const isCanvasMaskLine = (obj: CanvasObject): obj is CanvasMaskLine =>
obj.kind === 'line' && obj.layer === 'mask'; obj.kind === 'line' && obj.layer === 'mask';
@ -125,7 +132,7 @@ export interface CanvasState {
cursorPosition: Vector2d | null; cursorPosition: Vector2d | null;
doesCanvasNeedScaling: boolean; doesCanvasNeedScaling: boolean;
futureLayerStates: CanvasLayerState[]; futureLayerStates: CanvasLayerState[];
intermediateImage?: InvokeAI._Image; intermediateImage?: InvokeAI.Image;
isCanvasInitialized: boolean; isCanvasInitialized: boolean;
isDrawing: boolean; isDrawing: boolean;
isMaskEnabled: boolean; isMaskEnabled: boolean;
@ -142,6 +149,7 @@ export interface CanvasState {
minimumStageScale: number; minimumStageScale: number;
pastLayerStates: CanvasLayerState[]; pastLayerStates: CanvasLayerState[];
scaledBoundingBoxDimensions: Dimensions; scaledBoundingBoxDimensions: Dimensions;
shouldAntialias: boolean;
shouldAutoSave: boolean; shouldAutoSave: boolean;
shouldCropToBoundingBoxOnSave: boolean; shouldCropToBoundingBoxOnSave: boolean;
shouldDarkenOutsideBoundingBox: boolean; shouldDarkenOutsideBoundingBox: boolean;

View File

@ -0,0 +1,13 @@
/**
* Gets a Blob from a canvas.
*/
export const canvasToBlob = async (canvas: HTMLCanvasElement): Promise<Blob> =>
new Promise((resolve, reject) => {
canvas.toBlob((blob) => {
if (blob) {
resolve(blob);
return;
}
reject('Unable to create Blob');
});
});
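An illustrative consumer of canvasToBlob; the relative import path follows getCanvasData below, and the File wrapping is an assumption about how the Blob might be used, not something shown in this hunk:

import { canvasToBlob } from './canvasToBlob';

// Sketch only: wrap the Blob in a File, e.g. for an upload thunk's formData.
const canvasToPngFile = async (canvas: HTMLCanvasElement) => {
  const blob = await canvasToBlob(canvas);
  return new File([blob], 'canvas.png', { type: 'image/png' });
};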

View File

@ -0,0 +1,29 @@
/**
* Gets an ImageData object from an image dataURL by drawing it to a canvas.
*/
export const dataURLToImageData = async (
dataURL: string,
width: number,
height: number
): Promise<ImageData> =>
new Promise((resolve, reject) => {
const canvas = document.createElement('canvas');
canvas.width = width;
canvas.height = height;
const ctx = canvas.getContext('2d');
const image = new Image();
if (!ctx) {
canvas.remove();
reject('Unable to get context');
return;
}
image.onload = function () {
ctx.drawImage(image, 0, 0);
canvas.remove();
resolve(ctx.getImageData(0, 0, width, height));
};
image.src = dataURL;
});
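An illustrative wrapper combining dataURLToImageData with a Konva stage snapshot, mirroring the pattern getCanvasData uses below; the helper name is a sketch, not part of this diff:

import { Stage } from 'konva/lib/Stage';
import { IRect } from 'konva/lib/types';
import { dataURLToImageData } from './dataURLToImageData';

// Snapshot a stage region and read back its pixels for inspection.
const stageRegionToImageData = async (stage: Stage, boundingBox: IRect) => {
  const dataURL = stage.toDataURL(boundingBox);
  return dataURLToImageData(dataURL, boundingBox.width, boundingBox.height);
};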

View File

@ -1,6 +1,110 @@
// import { CanvasMaskLine } from 'features/canvas/store/canvasTypes';
// import Konva from 'konva';
// import { Stage } from 'konva/lib/Stage';
// import { IRect } from 'konva/lib/types';
// /**
// * Generating a mask image from InpaintingCanvas.tsx is not as simple
// * as calling toDataURL() on the canvas, because the mask may be represented
// * by colored lines or transparency, or the user may have inverted the mask
// * display.
// *
// * So we need to regenerate the mask image by creating an offscreen canvas,
// * drawing the mask and compositing everything correctly to output a valid
// * mask image.
// */
// export const getStageDataURL = (stage: Stage, boundingBox: IRect): string => {
// // create an offscreen canvas and add the mask to it
// // const { stage, offscreenContainer } = buildMaskStage(lines, boundingBox);
// const dataURL = stage.toDataURL({ ...boundingBox });
// // const imageData = stage
// // .toCanvas()
// // .getContext('2d')
// // ?.getImageData(
// // boundingBox.x,
// // boundingBox.y,
// // boundingBox.width,
// // boundingBox.height
// // );
// // offscreenContainer.remove();
// // return { dataURL, imageData };
// return dataURL;
// };
// export const getStageImageData = (
// stage: Stage,
// boundingBox: IRect
// ): ImageData | undefined => {
// const imageData = stage
// .toCanvas()
// .getContext('2d')
// ?.getImageData(
// boundingBox.x,
// boundingBox.y,
// boundingBox.width,
// boundingBox.height
// );
// return imageData;
// };
// export const buildMaskStage = (
// lines: CanvasMaskLine[],
// boundingBox: IRect
// ): { stage: Stage; offscreenContainer: HTMLDivElement } => {
// // create an offscreen canvas and add the mask to it
// const { width, height } = boundingBox;
// const offscreenContainer = document.createElement('div');
// const stage = new Konva.Stage({
// container: offscreenContainer,
// width: width,
// height: height,
// });
// const baseLayer = new Konva.Layer();
// const maskLayer = new Konva.Layer();
// // composite the image onto the mask layer
// baseLayer.add(
// new Konva.Rect({
// ...boundingBox,
// fill: 'white',
// })
// );
// lines.forEach((line) =>
// maskLayer.add(
// new Konva.Line({
// points: line.points,
// stroke: 'black',
// strokeWidth: line.strokeWidth * 2,
// tension: 0,
// lineCap: 'round',
// lineJoin: 'round',
// shadowForStrokeEnabled: false,
// globalCompositeOperation:
// line.tool === 'brush' ? 'source-over' : 'destination-out',
// })
// )
// );
// stage.add(baseLayer);
// stage.add(maskLayer);
// return { stage, offscreenContainer };
// };
import { CanvasMaskLine } from 'features/canvas/store/canvasTypes'; import { CanvasMaskLine } from 'features/canvas/store/canvasTypes';
import Konva from 'konva'; import Konva from 'konva';
import { IRect } from 'konva/lib/types'; import { IRect } from 'konva/lib/types';
import { canvasToBlob } from './canvasToBlob';
/** /**
* Generating a mask image from InpaintingCanvas.tsx is not as simple * Generating a mask image from InpaintingCanvas.tsx is not as simple
@ -12,7 +116,7 @@ import { IRect } from 'konva/lib/types';
* drawing the mask and compositing everything correctly to output a valid * drawing the mask and compositing everything correctly to output a valid
* mask image. * mask image.
*/ */
const generateMask = (lines: CanvasMaskLine[], boundingBox: IRect): string => { const generateMask = async (lines: CanvasMaskLine[], boundingBox: IRect) => {
// create an offscreen canvas and add the mask to it // create an offscreen canvas and add the mask to it
const { width, height } = boundingBox; const { width, height } = boundingBox;
@ -54,11 +158,13 @@ const generateMask = (lines: CanvasMaskLine[], boundingBox: IRect): string => {
stage.add(baseLayer); stage.add(baseLayer);
stage.add(maskLayer); stage.add(maskLayer);
const dataURL = stage.toDataURL({ ...boundingBox }); const maskDataURL = stage.toDataURL(boundingBox);
const maskBlob = await canvasToBlob(stage.toCanvas(boundingBox));
offscreenContainer.remove(); offscreenContainer.remove();
return dataURL; return { maskDataURL, maskBlob };
}; };
export default generateMask; export default generateMask;
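generateMask is now async and resolves to both a dataURL and a Blob. A hedged sketch of a caller; the relative import path follows getCanvasData below, and the File wrapping is an assumption about how the Blob might be consumed, not something this hunk shows:

import { CanvasMaskLine } from 'features/canvas/store/canvasTypes';
import { IRect } from 'konva/lib/types';
import generateMask from './generateMask';

const buildMaskFile = async (lines: CanvasMaskLine[], boundingBox: IRect) => {
  const { maskDataURL, maskBlob } = await generateMask(lines, boundingBox);
  // maskDataURL can be previewed (e.g. via openBase64ImageInTab); wrapping
  // the Blob in a File for upload is an assumption, not part of this diff.
  const maskFile = new File([maskBlob], 'mask.png', { type: 'image/png' });
  return { maskDataURL, maskFile };
};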

View File

@ -0,0 +1,128 @@
import { RootState } from 'app/store/store';
import { getCanvasBaseLayer, getCanvasStage } from './konvaInstanceProvider';
import { isCanvasMaskLine } from '../store/canvasTypes';
import { log } from 'app/logging/useLogger';
import {
areAnyPixelsBlack,
getImageDataTransparency,
} from 'common/util/arrayBuffer';
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
import generateMask from './generateMask';
import { dataURLToImageData } from './dataURLToImageData';
import { canvasToBlob } from './canvasToBlob';
const moduleLog = log.child({ namespace: 'getCanvasDataURLs' });
export const getCanvasData = async (state: RootState) => {
const canvasBaseLayer = getCanvasBaseLayer();
const canvasStage = getCanvasStage();
if (!canvasBaseLayer || !canvasStage) {
moduleLog.error('Unable to find canvas / stage');
return;
}
const {
layerState: { objects },
boundingBoxCoordinates,
boundingBoxDimensions,
stageScale,
isMaskEnabled,
shouldPreserveMaskedArea,
boundingBoxScaleMethod: boundingBoxScale,
scaledBoundingBoxDimensions,
} = state.canvas;
const boundingBox = {
...boundingBoxCoordinates,
...boundingBoxDimensions,
};
// generationParameters.fit = false;
// generationParameters.strength = img2imgStrength;
// generationParameters.invert_mask = shouldPreserveMaskedArea;
// generationParameters.bounding_box = boundingBox;
const tempScale = canvasBaseLayer.scale();
canvasBaseLayer.scale({
x: 1 / stageScale,
y: 1 / stageScale,
});
const absPos = canvasBaseLayer.getAbsolutePosition();
const offsetBoundingBox = {
x: boundingBox.x + absPos.x,
y: boundingBox.y + absPos.y,
width: boundingBox.width,
height: boundingBox.height,
};
const baseDataURL = canvasBaseLayer.toDataURL(offsetBoundingBox);
const baseBlob = await canvasToBlob(
canvasBaseLayer.toCanvas(offsetBoundingBox)
);
canvasBaseLayer.scale(tempScale);
const { maskDataURL, maskBlob } = await generateMask(
isMaskEnabled ? objects.filter(isCanvasMaskLine) : [],
boundingBox
);
const baseImageData = await dataURLToImageData(
baseDataURL,
boundingBox.width,
boundingBox.height
);
const maskImageData = await dataURLToImageData(
maskDataURL,
boundingBox.width,
boundingBox.height
);
const {
isPartiallyTransparent: baseIsPartiallyTransparent,
isFullyTransparent: baseIsFullyTransparent,
} = getImageDataTransparency(baseImageData.data);
const doesMaskHaveBlackPixels = areAnyPixelsBlack(maskImageData.data);
if (state.system.enableImageDebugging) {
openBase64ImageInTab([
{ base64: maskDataURL, caption: 'mask b64' },
{ base64: baseDataURL, caption: 'image b64' },
]);
}
// generationParameters.init_img = imageDataURL;
// generationParameters.progress_images = false;
// if (boundingBoxScale !== 'none') {
// generationParameters.inpaint_width = scaledBoundingBoxDimensions.width;
// generationParameters.inpaint_height = scaledBoundingBoxDimensions.height;
// }
// generationParameters.seam_size = seamSize;
// generationParameters.seam_blur = seamBlur;
// generationParameters.seam_strength = seamStrength;
// generationParameters.seam_steps = seamSteps;
// generationParameters.tile_size = tileSize;
// generationParameters.infill_method = infillMethod;
// generationParameters.force_outpaint = false;
return {
baseDataURL,
baseBlob,
maskDataURL,
maskBlob,
baseIsPartiallyTransparent,
baseIsFullyTransparent,
doesMaskHaveBlackPixels,
};
};
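The flags returned by getCanvasData are not consumed in this hunk. A hedged sketch of one way they might drive the choice of generation mode; the mapping and the import path are assumptions, not taken from this diff:

import { RootState } from 'app/store/store';
import { getCanvasData } from './getCanvasData';

const chooseCanvasGenerationMode = async (state: RootState) => {
  const data = await getCanvasData(state);
  if (!data) {
    return;
  }
  const {
    baseIsPartiallyTransparent,
    baseIsFullyTransparent,
    doesMaskHaveBlackPixels,
  } = data;
  // Assumption: an empty base layer suggests txt2img, a partially transparent
  // base or a drawn mask suggests inpainting, otherwise plain img2img.
  return baseIsFullyTransparent
    ? 'txt2img'
    : baseIsPartiallyTransparent || doesMaskHaveBlackPixels
    ? 'inpaint'
    : 'img2img';
};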

View File

@ -1,12 +1,17 @@
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { get, isEqual, isNumber, isString } from 'lodash-es'; import { isEqual, isString } from 'lodash-es';
import { import {
ButtonGroup, ButtonGroup,
Flex, Flex,
FlexProps, FlexProps,
FormControl, IconButton,
Link, Link,
Menu,
MenuButton,
MenuItemOption,
MenuList,
MenuOptionGroup,
useDisclosure, useDisclosure,
useToast, useToast,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
@ -15,21 +20,12 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton'; import IAIButton from 'common/components/IAIButton';
import IAIIconButton from 'common/components/IAIIconButton'; import IAIIconButton from 'common/components/IAIIconButton';
import IAIPopover from 'common/components/IAIPopover'; import IAIPopover from 'common/components/IAIPopover';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import { GalleryState } from 'features/gallery/store/gallerySlice';
import { lightboxSelector } from 'features/lightbox/store/lightboxSelectors'; import { lightboxSelector } from 'features/lightbox/store/lightboxSelectors';
import { setIsLightboxOpen } from 'features/lightbox/store/lightboxSlice'; import { setIsLightboxOpen } from 'features/lightbox/store/lightboxSlice';
import FaceRestoreSettings from 'features/parameters/components/AdvancedParameters/FaceRestore/FaceRestoreSettings';
import UpscaleSettings from 'features/parameters/components/AdvancedParameters/Upscale/UpscaleSettings';
import {
initialImageSelected,
setAllParameters,
// setInitialImage,
setSeed,
} from 'features/parameters/store/generationSlice';
import { postprocessingSelector } from 'features/parameters/store/postprocessingSelectors'; import { postprocessingSelector } from 'features/parameters/store/postprocessingSelectors';
import { systemSelector } from 'features/system/store/systemSelectors'; import { systemSelector } from 'features/system/store/systemSelectors';
import { SystemState } from 'features/system/store/systemSlice';
import { import {
activeTabNameSelector, activeTabNameSelector,
uiSelector, uiSelector,
@ -56,6 +52,7 @@ import {
FaShare, FaShare,
FaShareAlt, FaShareAlt,
FaTrash, FaTrash,
FaWrench,
} from 'react-icons/fa'; } from 'react-icons/fa';
import { import {
gallerySelector, gallerySelector,
@ -66,8 +63,13 @@ import { useCallback } from 'react';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale'; import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import { useGetUrl } from 'common/util/getUrl'; import { useGetUrl } from 'common/util/getUrl';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { imageDeleted } from 'services/thunks/image';
import { useParameters } from 'features/parameters/hooks/useParameters'; import { useParameters } from 'features/parameters/hooks/useParameters';
import { initialImageSelected } from 'features/parameters/store/actions';
import { requestedImageDeletion } from '../store/actions';
import FaceRestoreSettings from 'features/parameters/components/Parameters/FaceRestore/FaceRestoreSettings';
import UpscaleSettings from 'features/parameters/components/Parameters/Upscale/UpscaleSettings';
import { allParametersSet } from 'features/parameters/store/generationSlice';
import DeleteImageButton from './ImageActionButtons/DeleteImageButton';
const currentImageButtonsSelector = createSelector( const currentImageButtonsSelector = createSelector(
[ [
@ -150,6 +152,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
} = useAppSelector(currentImageButtonsSelector); } = useAppSelector(currentImageButtonsSelector);
const isLightboxEnabled = useFeatureStatus('lightbox').isFeatureEnabled; const isLightboxEnabled = useFeatureStatus('lightbox').isFeatureEnabled;
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
const isUpscalingEnabled = useFeatureStatus('upscaling').isFeatureEnabled; const isUpscalingEnabled = useFeatureStatus('upscaling').isFeatureEnabled;
const isFaceRestoreEnabled = useFeatureStatus('faceRestore').isFeatureEnabled; const isFaceRestoreEnabled = useFeatureStatus('faceRestore').isFeatureEnabled;
@ -164,40 +167,59 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
const toast = useToast(); const toast = useToast();
const { t } = useTranslation(); const { t } = useTranslation();
const { recallPrompt, recallSeed, sendToImageToImage } = useParameters(); const { recallPrompt, recallSeed, recallAllParameters } = useParameters();
const handleCopyImage = useCallback(async () => { // const handleCopyImage = useCallback(async () => {
if (!image?.url) { // if (!image?.url) {
// return;
// }
// const url = getUrl(image.url);
// if (!url) {
// return;
// }
// const blob = await fetch(url).then((res) => res.blob());
// const data = [new ClipboardItem({ [blob.type]: blob })];
// await navigator.clipboard.write(data);
// toast({
// title: t('toast.imageCopied'),
// status: 'success',
// duration: 2500,
// isClosable: true,
// });
// }, [getUrl, t, image?.url, toast]);
const handleCopyImageLink = useCallback(() => {
const getImageUrl = () => {
if (!image) {
return; return;
} }
const url = getUrl(image.url); if (shouldTransformUrls) {
return getUrl(image.url);
}
if (image.url.startsWith('http')) {
return image.url;
}
return window.location.toString() + image.url;
};
const url = getImageUrl();
if (!url) { if (!url) {
return;
}
const blob = await fetch(url).then((res) => res.blob());
const data = [new ClipboardItem({ [blob.type]: blob })];
await navigator.clipboard.write(data);
toast({ toast({
title: t('toast.imageCopied'), title: t('toast.problemCopyingImageLink'),
status: 'success', status: 'error',
duration: 2500, duration: 2500,
isClosable: true, isClosable: true,
}); });
}, [getUrl, t, image?.url, toast]);
const handleCopyImageLink = useCallback(() => {
const url = image
? shouldTransformUrls
? getUrl(image.url)
: window.location.toString() + image.url
: '';
if (!url) {
return; return;
} }
@ -216,39 +238,15 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
}, [dispatch, shouldHidePreview]); }, [dispatch, shouldHidePreview]);
const handleClickUseAllParameters = useCallback(() => { const handleClickUseAllParameters = useCallback(() => {
if (!image) return; recallAllParameters(image);
// selectedImage.metadata && }, [image, recallAllParameters]);
// dispatch(setAllParameters(selectedImage.metadata));
// if (selectedImage.metadata?.image.type === 'img2img') {
// dispatch(setActiveTab('img2img'));
// } else if (selectedImage.metadata?.image.type === 'txt2img') {
// dispatch(setActiveTab('txt2img'));
// }
}, [image]);
useHotkeys( useHotkeys(
'a', 'a',
() => { () => {
const type = image?.metadata?.invokeai?.node?.types; handleClickUseAllParameters();
if (isString(type) && ['txt2img', 'img2img'].includes(type)) {
handleClickUseAllParameters();
toast({
title: t('toast.parametersSet'),
status: 'success',
duration: 2500,
isClosable: true,
});
} else {
toast({
title: t('toast.parametersNotSet'),
description: t('toast.parametersNotSetDesc'),
status: 'error',
duration: 2500,
isClosable: true,
});
}
}, },
[image] [image, recallAllParameters]
); );
const handleUseSeed = useCallback(() => { const handleUseSeed = useCallback(() => {
@ -264,8 +262,8 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
useHotkeys('p', handleUsePrompt, [image]); useHotkeys('p', handleUsePrompt, [image]);
const handleSendToImageToImage = useCallback(() => { const handleSendToImageToImage = useCallback(() => {
sendToImageToImage(image); dispatch(initialImageSelected(image));
}, [image, sendToImageToImage]); }, [dispatch, image]);
useHotkeys('shift+i', handleSendToImageToImage, [image]); useHotkeys('shift+i', handleSendToImageToImage, [image]);
@ -375,7 +373,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
const handleDelete = useCallback(() => { const handleDelete = useCallback(() => {
if (canDeleteImage && image) { if (canDeleteImage && image) {
dispatch(imageDeleted({ imageType: image.type, imageName: image.name })); dispatch(requestedImageDeletion(image));
} }
}, [image, canDeleteImage, dispatch]); }, [image, canDeleteImage, dispatch]);
@ -432,6 +430,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
> >
{t('parameters.sendToImg2Img')} {t('parameters.sendToImg2Img')}
</IAIButton> </IAIButton>
{isCanvasEnabled && (
<IAIButton <IAIButton
size="sm" size="sm"
onClick={handleSendToCanvas} onClick={handleSendToCanvas}
@ -439,14 +438,15 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
> >
{t('parameters.sendToUnifiedCanvas')} {t('parameters.sendToUnifiedCanvas')}
</IAIButton> </IAIButton>
)}
<IAIButton {/* <IAIButton
size="sm" size="sm"
onClick={handleCopyImage} onClick={handleCopyImage}
leftIcon={<FaCopy />} leftIcon={<FaCopy />}
> >
{t('parameters.copyImage')} {t('parameters.copyImage')}
</IAIButton> </IAIButton> */}
<IAIButton <IAIButton
size="sm" size="sm"
onClick={handleCopyImageLink} onClick={handleCopyImageLink}
@ -462,7 +462,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
</Link> </Link>
</Flex> </Flex>
</IAIPopover> </IAIPopover>
<IAIIconButton {/* <IAIIconButton
icon={shouldHidePreview ? <FaEyeSlash /> : <FaEye />} icon={shouldHidePreview ? <FaEyeSlash /> : <FaEye />}
tooltip={ tooltip={
!shouldHidePreview !shouldHidePreview
@ -476,7 +476,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
} }
isChecked={shouldHidePreview} isChecked={shouldHidePreview}
onClick={handlePreviewVisibility} onClick={handlePreviewVisibility}
/> /> */}
{isLightboxEnabled && ( {isLightboxEnabled && (
<IAIIconButton <IAIIconButton
icon={<FaExpand />} icon={<FaExpand />}
@ -518,8 +518,8 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
tooltip={`${t('parameters.useAll')} (A)`} tooltip={`${t('parameters.useAll')} (A)`}
aria-label={`${t('parameters.useAll')} (A)`} aria-label={`${t('parameters.useAll')} (A)`}
isDisabled={ isDisabled={
!['txt2img', 'img2img'].includes( !['txt2img', 'img2img', 'inpaint'].includes(
image?.metadata?.sd_metadata?.type String(image?.metadata?.invokeai?.node?.type)
) )
} }
onClick={handleClickUseAllParameters} onClick={handleClickUseAllParameters}
@ -602,22 +602,10 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
/> />
</ButtonGroup> </ButtonGroup>
<IAIIconButton <ButtonGroup isAttached={true}>
onClick={handleInitiateDelete} <DeleteImageButton image={image} />
icon={<FaTrash />} </ButtonGroup>
tooltip={`${t('gallery.deleteImage')} (Del)`}
aria-label={`${t('gallery.deleteImage')} (Del)`}
isDisabled={!image || !isConnected}
colorScheme="error"
/>
</Flex> </Flex>
{image && (
<DeleteImageModal
isOpen={isDeleteDialogOpen}
onClose={onDeleteDialogClose}
handleDelete={handleDelete}
/>
)}
</> </>
); );
}; };

View File

@ -2,26 +2,35 @@ import { Box, Flex, Image } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { useGetUrl } from 'common/util/getUrl'; import { useGetUrl } from 'common/util/getUrl';
import { systemSelector } from 'features/system/store/systemSelectors';
import { uiSelector } from 'features/ui/store/uiSelectors'; import { uiSelector } from 'features/ui/store/uiSelectors';
import { isEqual } from 'lodash-es'; import { isEqual } from 'lodash-es';
import { selectedImageSelector } from '../store/gallerySelectors'; import { gallerySelector } from '../store/gallerySelectors';
import CurrentImageFallback from './CurrentImageFallback';
import ImageMetadataViewer from './ImageMetaDataViewer/ImageMetadataViewer'; import ImageMetadataViewer from './ImageMetaDataViewer/ImageMetadataViewer';
import NextPrevImageButtons from './NextPrevImageButtons'; import NextPrevImageButtons from './NextPrevImageButtons';
import CurrentImageHidden from './CurrentImageHidden'; import CurrentImageHidden from './CurrentImageHidden';
import { memo } from 'react'; import { DragEvent, memo, useCallback } from 'react';
import { systemSelector } from 'features/system/store/systemSelectors';
import ImageFallbackSpinner from './ImageFallbackSpinner';
import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay';
export const imagesSelector = createSelector( export const imagesSelector = createSelector(
[uiSelector, selectedImageSelector, systemSelector], [uiSelector, gallerySelector, systemSelector],
(ui, selectedImage, system) => { (ui, gallery, system) => {
const { shouldShowImageDetails, shouldHidePreview } = ui; const {
shouldShowImageDetails,
shouldHidePreview,
shouldShowProgressInViewer,
} = ui;
const { selectedImage } = gallery;
const { progressImage, shouldAntialiasProgressImage } = system;
return { return {
shouldShowImageDetails, shouldShowImageDetails,
shouldHidePreview, shouldHidePreview,
image: selectedImage, image: selectedImage,
progressImage,
shouldShowProgressInViewer,
shouldAntialiasProgressImage,
}; };
}, },
{ {
@ -32,26 +41,61 @@ export const imagesSelector = createSelector(
); );
const CurrentImagePreview = () => { const CurrentImagePreview = () => {
const { shouldShowImageDetails, image, shouldHidePreview } = const {
useAppSelector(imagesSelector); shouldShowImageDetails,
image,
shouldHidePreview,
progressImage,
shouldShowProgressInViewer,
shouldAntialiasProgressImage,
} = useAppSelector(imagesSelector);
const { getUrl } = useGetUrl(); const { getUrl } = useGetUrl();
const handleDragStart = useCallback(
(e: DragEvent<HTMLDivElement>) => {
if (!image) {
return;
}
e.dataTransfer.setData('invokeai/imageName', image.name);
e.dataTransfer.setData('invokeai/imageType', image.type);
e.dataTransfer.effectAllowed = 'move';
},
[image]
);
return ( return (
<Flex <Flex
sx={{ sx={{
position: 'relative',
justifyContent: 'center',
alignItems: 'center',
width: '100%', width: '100%',
height: '100%', height: '100%',
position: 'relative',
alignItems: 'center',
justifyContent: 'center',
}} }}
> >
{image && ( {progressImage && shouldShowProgressInViewer ? (
<Image <Image
src={shouldHidePreview ? undefined : getUrl(image.url)} src={progressImage.dataURL}
width={image.metadata.width} width={progressImage.width}
height={image.metadata.height} height={progressImage.height}
fallback={shouldHidePreview ? <CurrentImageHidden /> : undefined} sx={{
objectFit: 'contain',
maxWidth: '100%',
maxHeight: '100%',
height: 'auto',
position: 'absolute',
borderRadius: 'base',
imageRendering: shouldAntialiasProgressImage ? 'auto' : 'pixelated',
}}
/>
) : (
image && (
<>
<Image
src={getUrl(image.url)}
fallbackStrategy="beforeLoadOrError"
fallback={<ImageFallbackSpinner />}
onDragStart={handleDragStart}
sx={{ sx={{
objectFit: 'contain', objectFit: 'contain',
maxWidth: '100%', maxWidth: '100%',
@ -61,6 +105,9 @@ const CurrentImagePreview = () => {
borderRadius: 'base', borderRadius: 'base',
}} }}
/> />
<ImageMetadataOverlay image={image} />
</>
)
)} )}
{shouldShowImageDetails && image && 'metadata' in image && ( {shouldShowImageDetails && image && 'metadata' in image && (
<Box <Box

View File

@ -0,0 +1,70 @@
import { Flex, Image, Spinner } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { systemSelector } from 'features/system/store/systemSelectors';
import { memo } from 'react';
import { gallerySelector } from '../store/gallerySelectors';
const selector = createSelector(
[systemSelector, gallerySelector],
(system, gallery) => {
const { shouldUseSingleGalleryColumn, galleryImageObjectFit } = gallery;
const { progressImage, shouldAntialiasProgressImage } = system;
return {
progressImage,
shouldUseSingleGalleryColumn,
galleryImageObjectFit,
shouldAntialiasProgressImage,
};
},
defaultSelectorOptions
);
const GalleryProgressImage = () => {
const {
progressImage,
shouldUseSingleGalleryColumn,
galleryImageObjectFit,
shouldAntialiasProgressImage,
} = useAppSelector(selector);
if (!progressImage) {
return null;
}
return (
<Flex
sx={{
w: 'full',
h: 'full',
alignItems: 'center',
justifyContent: 'center',
aspectRatio: '1/1',
position: 'relative',
}}
>
<Image
draggable={false}
src={progressImage.dataURL}
width={progressImage.width}
height={progressImage.height}
sx={{
objectFit: shouldUseSingleGalleryColumn
? 'contain'
: galleryImageObjectFit,
width: '100%',
height: '100%',
maxWidth: '100%',
maxHeight: '100%',
borderRadius: 'base',
imageRendering: shouldAntialiasProgressImage ? 'auto' : 'pixelated',
}}
/>
<Spinner sx={{ position: 'absolute', top: 1, right: 1, opacity: 0.7 }} />
</Flex>
);
};
export default memo(GalleryProgressImage);
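Note: GalleryProgressImage (and DeleteImageButton later in this commit) memoizes its selector with defaultSelectorOptions, a module that is not shown in this diff. The snippet below is only a guess at its shape, a common reselect configuration that compares derived results with lodash deep equality; it is not the project's actual definition.

import { isEqual } from 'lodash-es';

// Assumed shape: reselect options that reuse the previous result when the
// newly derived object is deep-equal to it, so subscribers skip a re-render.
export const defaultSelectorOptions = {
  memoizeOptions: {
    resultEqualityCheck: isEqual,
  },
};

Passed as the last argument to createSelector, as above, this keeps the component stable when the store changes but the derived slice of state does not.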

View File

@ -5,19 +5,20 @@ import {
Image, Image,
MenuItem, MenuItem,
MenuList, MenuList,
Skeleton,
useDisclosure, useDisclosure,
useTheme,
useToast, useToast,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { imageSelected } from 'features/gallery/store/gallerySlice'; import { imageSelected } from 'features/gallery/store/gallerySlice';
import { DragEvent, memo, useCallback, useState } from 'react'; import { DragEvent, MouseEvent, memo, useCallback, useState } from 'react';
import { FaCheck, FaExpand, FaImage, FaShare, FaTrash } from 'react-icons/fa'; import { FaCheck, FaExpand, FaImage, FaShare, FaTrash } from 'react-icons/fa';
import DeleteImageModal from './DeleteImageModal'; import DeleteImageModal from './DeleteImageModal';
import { ContextMenu } from 'chakra-ui-contextmenu'; import { ContextMenu } from 'chakra-ui-contextmenu';
import * as InvokeAI from 'app/types/invokeai'; import * as InvokeAI from 'app/types/invokeai';
import { resizeAndScaleCanvas } from 'features/canvas/store/canvasSlice'; import {
resizeAndScaleCanvas,
setInitialCanvasImage,
} from 'features/canvas/store/canvasSlice';
import { gallerySelector } from 'features/gallery/store/gallerySelectors'; import { gallerySelector } from 'features/gallery/store/gallerySelectors';
import { setActiveTab } from 'features/ui/store/uiSlice'; import { setActiveTab } from 'features/ui/store/uiSlice';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
@ -25,7 +26,6 @@ import IAIIconButton from 'common/components/IAIIconButton';
import { useGetUrl } from 'common/util/getUrl'; import { useGetUrl } from 'common/util/getUrl';
import { ExternalLinkIcon } from '@chakra-ui/icons'; import { ExternalLinkIcon } from '@chakra-ui/icons';
import { IoArrowUndoCircleOutline } from 'react-icons/io5'; import { IoArrowUndoCircleOutline } from 'react-icons/io5';
import { imageDeleted } from 'services/thunks/image';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { systemSelector } from 'features/system/store/systemSelectors'; import { systemSelector } from 'features/system/store/systemSelectors';
import { lightboxSelector } from 'features/lightbox/store/lightboxSelectors'; import { lightboxSelector } from 'features/lightbox/store/lightboxSelectors';
@ -33,6 +33,8 @@ import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { isEqual } from 'lodash-es'; import { isEqual } from 'lodash-es';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { useParameters } from 'features/parameters/hooks/useParameters'; import { useParameters } from 'features/parameters/hooks/useParameters';
import { initialImageSelected } from 'features/parameters/store/actions';
import { requestedImageDeletion } from '../store/actions';
export const selector = createSelector( export const selector = createSelector(
[gallerySelector, systemSelector, lightboxSelector, activeTabNameSelector], [gallerySelector, systemSelector, lightboxSelector, activeTabNameSelector],
@ -94,16 +96,19 @@ const HoverableImage = memo((props: HoverableImageProps) => {
} = useDisclosure(); } = useDisclosure();
const { image, isSelected } = props; const { image, isSelected } = props;
const { url, thumbnail, name, metadata } = image; const { url, thumbnail, name } = image;
const { getUrl } = useGetUrl(); const { getUrl } = useGetUrl();
const [isHovered, setIsHovered] = useState<boolean>(false); const [isHovered, setIsHovered] = useState<boolean>(false);
const toast = useToast(); const toast = useToast();
const { direction } = useTheme();
const { t } = useTranslation(); const { t } = useTranslation();
const { isFeatureEnabled: isLightboxEnabled } = useFeatureStatus('lightbox');
const { recallSeed, recallPrompt, sendToImageToImage, recallInitialImage } = const isLightboxEnabled = useFeatureStatus('lightbox').isFeatureEnabled;
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
const { recallSeed, recallPrompt, recallInitialImage, recallAllParameters } =
useParameters(); useParameters();
const handleMouseOver = () => setIsHovered(true); const handleMouseOver = () => setIsHovered(true);
@ -112,18 +117,22 @@ const HoverableImage = memo((props: HoverableImageProps) => {
// Immediately deletes an image // Immediately deletes an image
const handleDelete = useCallback(() => { const handleDelete = useCallback(() => {
if (canDeleteImage && image) { if (canDeleteImage && image) {
dispatch(imageDeleted({ imageType: image.type, imageName: image.name })); dispatch(requestedImageDeletion(image));
} }
}, [dispatch, image, canDeleteImage]); }, [dispatch, image, canDeleteImage]);
// Opens the alert dialog to check if user is sure they want to delete // Opens the alert dialog to check if user is sure they want to delete
const handleInitiateDelete = useCallback(() => { const handleInitiateDelete = useCallback(
(e: MouseEvent) => {
e.stopPropagation();
if (shouldConfirmOnDelete) { if (shouldConfirmOnDelete) {
onDeleteDialogOpen(); onDeleteDialogOpen();
} else { } else {
handleDelete(); handleDelete();
} }
}, [handleDelete, onDeleteDialogOpen, shouldConfirmOnDelete]); },
[handleDelete, onDeleteDialogOpen, shouldConfirmOnDelete]
);
const handleSelectImage = useCallback(() => { const handleSelectImage = useCallback(() => {
dispatch(imageSelected(image)); dispatch(imageSelected(image));
@ -148,8 +157,8 @@ const HoverableImage = memo((props: HoverableImageProps) => {
}, [image, recallSeed]); }, [image, recallSeed]);
const handleSendToImageToImage = useCallback(() => { const handleSendToImageToImage = useCallback(() => {
sendToImageToImage(image); dispatch(initialImageSelected(image));
}, [image, sendToImageToImage]); }, [dispatch, image]);
const handleRecallInitialImage = useCallback(() => { const handleRecallInitialImage = useCallback(() => {
recallInitialImage(image.metadata.invokeai?.node?.image); recallInitialImage(image.metadata.invokeai?.node?.image);
@ -159,7 +168,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
* TODO: the rest of these * TODO: the rest of these
*/ */
const handleSendToCanvas = () => { const handleSendToCanvas = () => {
// dispatch(setInitialCanvasImage(image)); dispatch(setInitialCanvasImage(image));
dispatch(resizeAndScaleCanvas()); dispatch(resizeAndScaleCanvas());
@ -175,16 +184,9 @@ const HoverableImage = memo((props: HoverableImageProps) => {
}); });
}; };
const handleUseAllParameters = () => { const handleUseAllParameters = useCallback(() => {
// metadata.invokeai?.node && recallAllParameters(image);
// dispatch(setAllParameters(metadata.invokeai?.node)); }, [image, recallAllParameters]);
// toast({
// title: t('toast.parametersSet'),
// status: 'success',
// duration: 2500,
// isClosable: true,
// });
};
const handleLightBox = () => { const handleLightBox = () => {
// dispatch(setCurrentImage(image)); // dispatch(setCurrentImage(image));
@ -238,7 +240,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
icon={<IoArrowUndoCircleOutline />} icon={<IoArrowUndoCircleOutline />}
onClickCapture={handleUseAllParameters} onClickCapture={handleUseAllParameters}
isDisabled={ isDisabled={
!['txt2img', 'img2img'].includes( !['txt2img', 'img2img', 'inpaint'].includes(
String(image?.metadata?.invokeai?.node?.type) String(image?.metadata?.invokeai?.node?.type)
) )
} }
@ -251,9 +253,11 @@ const HoverableImage = memo((props: HoverableImageProps) => {
> >
{t('parameters.sendToImg2Img')} {t('parameters.sendToImg2Img')}
</MenuItem> </MenuItem>
{isCanvasEnabled && (
<MenuItem icon={<FaShare />} onClickCapture={handleSendToCanvas}> <MenuItem icon={<FaShare />} onClickCapture={handleSendToCanvas}>
{t('parameters.sendToUnifiedCanvas')} {t('parameters.sendToUnifiedCanvas')}
</MenuItem> </MenuItem>
)}
<MenuItem icon={<FaTrash />} onClickCapture={onDeleteDialogOpen}> <MenuItem icon={<FaTrash />} onClickCapture={onDeleteDialogOpen}>
{t('gallery.deleteImage')} {t('gallery.deleteImage')}
</MenuItem> </MenuItem>
@ -279,6 +283,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
h: 'full', h: 'full',
transition: 'transform 0.2s ease-out', transition: 'transform 0.2s ease-out',
aspectRatio: '1/1', aspectRatio: '1/1',
cursor: 'pointer',
}} }}
> >
<Image <Image
@ -315,6 +320,8 @@ const HoverableImage = memo((props: HoverableImageProps) => {
sx={{ sx={{
width: '50%', width: '50%',
height: '50%', height: '50%',
maxWidth: '4rem',
maxHeight: '4rem',
fill: 'ok.500', fill: 'ok.500',
}} }}
/> />
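Note: the context menu now consults useFeatureStatus for both 'lightbox' and 'unifiedCanvas'. The hook itself is outside this diff; the sketch below is only a plausible shape for it, and the disabledFeatures field on the config slice is an assumption.

import { useMemo } from 'react';
import { useAppSelector } from 'app/store/storeHooks';
import { RootState } from 'app/store/store';

type AppFeature = 'lightbox' | 'unifiedCanvas';

const EMPTY_FEATURES: AppFeature[] = [];

// Assumption: the config slice carries a list of disabled feature names.
const selectDisabledFeatures = (state: RootState): AppFeature[] =>
  state.config.disabledFeatures ?? EMPTY_FEATURES;

// Hypothetical implementation: a feature counts as enabled unless it is
// explicitly listed as disabled in config state.
export const useFeatureStatus = (feature: AppFeature) => {
  const disabledFeatures = useAppSelector(selectDisabledFeatures);

  const isFeatureEnabled = useMemo(
    () => !disabledFeatures.includes(feature),
    [disabledFeatures, feature]
  );

  return { isFeatureEnabled };
};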

View File

@ -0,0 +1,92 @@
import { createSelector } from '@reduxjs/toolkit';
import { useDisclosure } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import { systemSelector } from 'features/system/store/systemSelectors';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import { FaTrash } from 'react-icons/fa';
import { memo, useCallback } from 'react';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import DeleteImageModal from '../DeleteImageModal';
import { requestedImageDeletion } from 'features/gallery/store/actions';
import { Image } from 'app/types/invokeai';
const selector = createSelector(
[systemSelector],
(system) => {
const { isProcessing, isConnected, shouldConfirmOnDelete } = system;
return {
canDeleteImage: isConnected && !isProcessing,
shouldConfirmOnDelete,
isProcessing,
isConnected,
};
},
defaultSelectorOptions
);
type DeleteImageButtonProps = {
image: Image | undefined;
};
const DeleteImageButton = (props: DeleteImageButtonProps) => {
const { image } = props;
const dispatch = useAppDispatch();
const { isProcessing, isConnected, canDeleteImage, shouldConfirmOnDelete } =
useAppSelector(selector);
const {
isOpen: isDeleteDialogOpen,
onOpen: onDeleteDialogOpen,
onClose: onDeleteDialogClose,
} = useDisclosure();
const { t } = useTranslation();
const handleDelete = useCallback(() => {
if (canDeleteImage && image) {
dispatch(requestedImageDeletion(image));
}
}, [image, canDeleteImage, dispatch]);
const handleInitiateDelete = useCallback(() => {
if (shouldConfirmOnDelete) {
onDeleteDialogOpen();
} else {
handleDelete();
}
}, [shouldConfirmOnDelete, onDeleteDialogOpen, handleDelete]);
useHotkeys('delete', handleInitiateDelete, [
image,
shouldConfirmOnDelete,
isConnected,
isProcessing,
]);
return (
<>
<IAIIconButton
onClick={handleInitiateDelete}
icon={<FaTrash />}
tooltip={`${t('gallery.deleteImage')} (Del)`}
aria-label={`${t('gallery.deleteImage')} (Del)`}
isDisabled={!image || !isConnected}
colorScheme="error"
/>
{image && (
<DeleteImageModal
isOpen={isDeleteDialogOpen}
onClose={onDeleteDialogClose}
handleDelete={handleDelete}
/>
)}
</>
);
};
export default memo(DeleteImageButton);
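Note: a brief usage sketch for the new button. The host component and the import path are hypothetical; the point is only that the button receives the currently selected image (possibly undefined) and handles the disabled state and confirm-on-delete flow itself.

import { useAppSelector } from 'app/store/storeHooks';
import DeleteImageButton from './DeleteImageButton';

// Hypothetical host: wherever the current image's action buttons live.
const CurrentImageActions = () => {
  const selectedImage = useAppSelector((state) => state.gallery.selectedImage);
  return <DeleteImageButton image={selectedImage} />;
};

export default CurrentImageActions;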

View File

@ -1,8 +1,8 @@
import { Flex, Spinner, SpinnerProps } from '@chakra-ui/react'; import { Flex, Spinner, SpinnerProps } from '@chakra-ui/react';
type CurrentImageFallbackProps = SpinnerProps; type ImageFallbackSpinnerProps = SpinnerProps;
const CurrentImageFallback = (props: CurrentImageFallbackProps) => { const ImageFallbackSpinner = (props: ImageFallbackSpinnerProps) => {
const { size = 'xl', ...rest } = props; const { size = 'xl', ...rest } = props;
return ( return (
@ -21,4 +21,4 @@ const CurrentImageFallback = (props: CurrentImageFallbackProps) => {
); );
}; };
export default CurrentImageFallback; export default ImageFallbackSpinner;

View File

@ -5,6 +5,7 @@ import {
FlexProps, FlexProps,
Grid, Grid,
Icon, Icon,
Image,
Text, Text,
forwardRef, forwardRef,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
@ -14,7 +15,10 @@ import IAICheckbox from 'common/components/IAICheckbox';
import IAIIconButton from 'common/components/IAIIconButton'; import IAIIconButton from 'common/components/IAIIconButton';
import IAIPopover from 'common/components/IAIPopover'; import IAIPopover from 'common/components/IAIPopover';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
import { imageGallerySelector } from 'features/gallery/store/gallerySelectors'; import {
gallerySelector,
imageGallerySelector,
} from 'features/gallery/store/gallerySelectors';
import { import {
setCurrentCategory, setCurrentCategory,
setGalleryImageMinimumWidth, setGalleryImageMinimumWidth,
@ -50,30 +54,42 @@ import { uploadsAdapter } from '../store/uploadsSlice';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { Virtuoso, VirtuosoGrid } from 'react-virtuoso'; import { Virtuoso, VirtuosoGrid } from 'react-virtuoso';
import { Image as ImageType } from 'app/types/invokeai';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import GalleryProgressImage from './GalleryProgressImage';
const GALLERY_SHOW_BUTTONS_MIN_WIDTH = 290; const GALLERY_SHOW_BUTTONS_MIN_WIDTH = 290;
const PROGRESS_IMAGE_PLACEHOLDER = 'PROGRESS_IMAGE_PLACEHOLDER';
const gallerySelector = createSelector( const selector = createSelector(
[ [(state: RootState) => state],
(state: RootState) => state.uploads, (state) => {
(state: RootState) => state.results, const { results, uploads, system, gallery } = state;
(state: RootState) => state.gallery,
],
(uploads, results, gallery) => {
const { currentCategory } = gallery; const { currentCategory } = gallery;
return currentCategory === 'results' if (currentCategory === 'results') {
? { const tempImages: (ImageType | typeof PROGRESS_IMAGE_PLACEHOLDER)[] = [];
images: resultsAdapter.getSelectors().selectAll(results),
if (system.progressImage) {
tempImages.push(PROGRESS_IMAGE_PLACEHOLDER);
}
return {
images: tempImages.concat(
resultsAdapter.getSelectors().selectAll(results)
),
isLoading: results.isLoading, isLoading: results.isLoading,
areMoreImagesAvailable: results.page < results.pages - 1, areMoreImagesAvailable: results.page < results.pages - 1,
};
} }
: {
return {
images: uploadsAdapter.getSelectors().selectAll(uploads), images: uploadsAdapter.getSelectors().selectAll(uploads),
isLoading: uploads.isLoading, isLoading: uploads.isLoading,
areMoreImagesAvailable: uploads.page < uploads.pages - 1, areMoreImagesAvailable: uploads.page < uploads.pages - 1,
}; };
} },
defaultSelectorOptions
); );
const ImageGalleryContent = () => { const ImageGalleryContent = () => {
@ -108,7 +124,7 @@ const ImageGalleryContent = () => {
} = useAppSelector(imageGallerySelector); } = useAppSelector(imageGallerySelector);
const { images, areMoreImagesAvailable, isLoading } = const { images, areMoreImagesAvailable, isLoading } =
useAppSelector(gallerySelector); useAppSelector(selector);
const handleClickLoadMore = () => { const handleClickLoadMore = () => {
if (currentCategory === 'results') { if (currentCategory === 'results') {
@ -170,8 +186,24 @@ const ImageGalleryContent = () => {
} }
}, []); }, []);
const handleEndReached = useCallback(() => {
if (currentCategory === 'results') {
dispatch(receivedResultImagesPage());
} else if (currentCategory === 'uploads') {
dispatch(receivedUploadImagesPage());
}
}, [dispatch, currentCategory]);
return ( return (
<Flex flexDirection="column" w="full" h="full" gap={4}> <Flex
sx={{
gap: 2,
flexDirection: 'column',
h: 'full',
w: 'full',
borderRadius: 'base',
}}
>
<Flex <Flex
ref={resizeObserverRef} ref={resizeObserverRef}
alignItems="center" alignItems="center"
@ -290,18 +322,27 @@ const ImageGalleryContent = () => {
<Virtuoso <Virtuoso
style={{ height: '100%' }} style={{ height: '100%' }}
data={images} data={images}
endReached={handleEndReached}
scrollerRef={(ref) => setScrollerRef(ref)} scrollerRef={(ref) => setScrollerRef(ref)}
itemContent={(index, image) => { itemContent={(index, image) => {
const { name } = image; const isSelected =
const isSelected = selectedImage?.name === name; image === PROGRESS_IMAGE_PLACEHOLDER
? false
: selectedImage?.name === image?.name;
return ( return (
<Flex sx={{ pb: 2 }}> <Flex sx={{ pb: 2 }}>
{image === PROGRESS_IMAGE_PLACEHOLDER ? (
<GalleryProgressImage
key={PROGRESS_IMAGE_PLACEHOLDER}
/>
) : (
<HoverableImage <HoverableImage
key={`${name}-${image.thumbnail}`} key={`${image.name}-${image.thumbnail}`}
image={image} image={image}
isSelected={isSelected} isSelected={isSelected}
/> />
)}
</Flex> </Flex>
); );
}} }}
@ -310,18 +351,23 @@ const ImageGalleryContent = () => {
<VirtuosoGrid <VirtuosoGrid
style={{ height: '100%' }} style={{ height: '100%' }}
data={images} data={images}
endReached={handleEndReached}
components={{ components={{
Item: ItemContainer, Item: ItemContainer,
List: ListContainer, List: ListContainer,
}} }}
scrollerRef={setScroller} scrollerRef={setScroller}
itemContent={(index, image) => { itemContent={(index, image) => {
const { name } = image; const isSelected =
const isSelected = selectedImage?.name === name; image === PROGRESS_IMAGE_PLACEHOLDER
? false
: selectedImage?.name === image?.name;
return ( return image === PROGRESS_IMAGE_PLACEHOLDER ? (
<GalleryProgressImage key={PROGRESS_IMAGE_PLACEHOLDER} />
) : (
<HoverableImage <HoverableImage
key={`${name}-${image.thumbnail}`} key={`${image.name}-${image.thumbnail}`}
image={image} image={image}
isSelected={isSelected} isSelected={isSelected}
/> />
@ -334,6 +380,7 @@ const ImageGalleryContent = () => {
onClick={handleClickLoadMore} onClick={handleClickLoadMore}
isDisabled={!areMoreImagesAvailable} isDisabled={!areMoreImagesAvailable}
isLoading={isLoading} isLoading={isLoading}
loadingText="Loading"
flexShrink={0} flexShrink={0}
> >
{areMoreImagesAvailable {areMoreImagesAvailable
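Note: the gallery list now mixes real images with a PROGRESS_IMAGE_PLACEHOLDER sentinel, and both Virtuoso item renderers repeat the same equality check. A small sketch of how that union can be expressed with a type guard; this is a refactor suggestion, not code from the diff.

import { Image as ImageType } from 'app/types/invokeai';

const PROGRESS_IMAGE_PLACEHOLDER = 'PROGRESS_IMAGE_PLACEHOLDER' as const;

// The gallery list is a mix of real images and a single sentinel string.
type GalleryItem = ImageType | typeof PROGRESS_IMAGE_PLACEHOLDER;

// Type guard: lets the item renderers branch once and get a correctly
// narrowed ImageType on the real-image side.
const isRealImage = (item: GalleryItem): item is ImageType =>
  item !== PROGRESS_IMAGE_PLACEHOLDER;

// itemContent={(index, item) =>
//   isRealImage(item) ? (
//     <HoverableImage
//       key={`${item.name}-${item.thumbnail}`}
//       image={item}
//       isSelected={selectedImage?.name === item.name}
//     />
//   ) : (
//     <GalleryProgressImage key={PROGRESS_IMAGE_PLACEHOLDER} />
//   )
// }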

View File

@ -5,7 +5,6 @@ import {
// selectPrevImage, // selectPrevImage,
setGalleryImageMinimumWidth, setGalleryImageMinimumWidth,
} from 'features/gallery/store/gallerySlice'; } from 'features/gallery/store/gallerySlice';
import { InvokeTabName } from 'features/ui/store/tabMap';
import { clamp, isEqual } from 'lodash-es'; import { clamp, isEqual } from 'lodash-es';
import { useHotkeys } from 'react-hotkeys-hook'; import { useHotkeys } from 'react-hotkeys-hook';
@ -13,11 +12,7 @@ import { useHotkeys } from 'react-hotkeys-hook';
import './ImageGallery.css'; import './ImageGallery.css';
import ImageGalleryContent from './ImageGalleryContent'; import ImageGalleryContent from './ImageGalleryContent';
import ResizableDrawer from 'features/ui/components/common/ResizableDrawer/ResizableDrawer'; import ResizableDrawer from 'features/ui/components/common/ResizableDrawer/ResizableDrawer';
import { import { setShouldShowGallery } from 'features/ui/store/uiSlice';
setShouldShowGallery,
toggleGalleryPanel,
togglePinGalleryPanel,
} from 'features/ui/store/uiSlice';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { import {
activeTabNameSelector, activeTabNameSelector,
@ -26,22 +21,20 @@ import {
import { isStagingSelector } from 'features/canvas/store/canvasSelectors'; import { isStagingSelector } from 'features/canvas/store/canvasSelectors';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale'; import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import { lightboxSelector } from 'features/lightbox/store/lightboxSelectors'; import { lightboxSelector } from 'features/lightbox/store/lightboxSelectors';
import useResolution from 'common/hooks/useResolution';
import { Flex } from '@chakra-ui/react';
import { memo } from 'react'; import { memo } from 'react';
const GALLERY_TAB_WIDTHS: Record< // const GALLERY_TAB_WIDTHS: Record<
InvokeTabName, // InvokeTabName,
{ galleryMinWidth: number; galleryMaxWidth: number } // { galleryMinWidth: number; galleryMaxWidth: number }
> = { // > = {
// txt2img: { galleryMinWidth: 200, galleryMaxWidth: 500 }, // txt2img: { galleryMinWidth: 200, galleryMaxWidth: 500 },
// img2img: { galleryMinWidth: 200, galleryMaxWidth: 500 }, // img2img: { galleryMinWidth: 200, galleryMaxWidth: 500 },
generate: { galleryMinWidth: 200, galleryMaxWidth: 500 }, // generate: { galleryMinWidth: 200, galleryMaxWidth: 500 },
unifiedCanvas: { galleryMinWidth: 200, galleryMaxWidth: 200 }, // unifiedCanvas: { galleryMinWidth: 200, galleryMaxWidth: 200 },
nodes: { galleryMinWidth: 200, galleryMaxWidth: 500 }, // nodes: { galleryMinWidth: 200, galleryMaxWidth: 500 },
// postprocessing: { galleryMinWidth: 200, galleryMaxWidth: 500 }, // postprocessing: { galleryMinWidth: 200, galleryMaxWidth: 500 },
// training: { galleryMinWidth: 200, galleryMaxWidth: 500 }, // training: { galleryMinWidth: 200, galleryMaxWidth: 500 },
}; // };
const galleryPanelSelector = createSelector( const galleryPanelSelector = createSelector(
[ [
@ -73,50 +66,50 @@ const galleryPanelSelector = createSelector(
} }
); );
export const ImageGalleryPanel = () => { const GalleryDrawer = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { const {
shouldPinGallery, shouldPinGallery,
shouldShowGallery, shouldShowGallery,
galleryImageMinimumWidth, galleryImageMinimumWidth,
activeTabName, // activeTabName,
isStaging, // isStaging,
isResizable, // isResizable,
isLightboxOpen, // isLightboxOpen,
} = useAppSelector(galleryPanelSelector); } = useAppSelector(galleryPanelSelector);
const handleSetShouldPinGallery = () => { // const handleSetShouldPinGallery = () => {
dispatch(togglePinGalleryPanel()); // dispatch(togglePinGalleryPanel());
dispatch(requestCanvasRescale()); // dispatch(requestCanvasRescale());
}; // };
const handleToggleGallery = () => { // const handleToggleGallery = () => {
dispatch(toggleGalleryPanel()); // dispatch(toggleGalleryPanel());
shouldPinGallery && dispatch(requestCanvasRescale()); // shouldPinGallery && dispatch(requestCanvasRescale());
}; // };
const handleCloseGallery = () => { const handleCloseGallery = () => {
dispatch(setShouldShowGallery(false)); dispatch(setShouldShowGallery(false));
shouldPinGallery && dispatch(requestCanvasRescale()); shouldPinGallery && dispatch(requestCanvasRescale());
}; };
const resolution = useResolution(); // const resolution = useResolution();
useHotkeys( // useHotkeys(
'g', // 'g',
() => { // () => {
handleToggleGallery(); // handleToggleGallery();
}, // },
[shouldPinGallery] // [shouldPinGallery]
); // );
useHotkeys( // useHotkeys(
'shift+g', // 'shift+g',
() => { // () => {
handleSetShouldPinGallery(); // handleSetShouldPinGallery();
}, // },
[shouldPinGallery] // [shouldPinGallery]
); // );
useHotkeys( useHotkeys(
'esc', 'esc',
@ -162,55 +155,71 @@ export const ImageGalleryPanel = () => {
[galleryImageMinimumWidth] [galleryImageMinimumWidth]
); );
const calcGalleryMinHeight = () => { // const calcGalleryMinHeight = () => {
if (resolution === 'desktop') return; // if (resolution === 'desktop') return;
return 300; // return 300;
}; // };
const imageGalleryContent = () => { // const imageGalleryContent = () => {
return ( // return (
<Flex // <Flex
w="100vw" // w="100vw"
h={{ base: 300, xl: '100vh' }} // h={{ base: 300, xl: '100vh' }}
paddingRight={{ base: 8, xl: 0 }} // paddingRight={{ base: 8, xl: 0 }}
paddingBottom={{ base: 4, xl: 0 }} // paddingBottom={{ base: 4, xl: 0 }}
> // >
<ImageGalleryContent /> // <ImageGalleryContent />
</Flex> // </Flex>
); // );
}; // };
// const resizableImageGalleryContent = () => {
// return (
// <ResizableDrawer
// direction="right"
// isResizable={isResizable || !shouldPinGallery}
// isOpen={shouldShowGallery}
// onClose={handleCloseGallery}
// isPinned={shouldPinGallery && !isLightboxOpen}
// minWidth={
// shouldPinGallery
// ? GALLERY_TAB_WIDTHS[activeTabName].galleryMinWidth
// : 200
// }
// maxWidth={
// shouldPinGallery
// ? GALLERY_TAB_WIDTHS[activeTabName].galleryMaxWidth
// : undefined
// }
// minHeight={calcGalleryMinHeight()}
// >
// <ImageGalleryContent />
// </ResizableDrawer>
// );
// };
// const renderImageGallery = () => {
// if (['mobile', 'tablet'].includes(resolution)) return imageGalleryContent();
// return resizableImageGalleryContent();
// };
if (shouldPinGallery) {
return null;
}
const resizableImageGalleryContent = () => {
return ( return (
<ResizableDrawer <ResizableDrawer
direction="right" direction="right"
isResizable={isResizable || !shouldPinGallery} isResizable={true}
isOpen={shouldShowGallery} isOpen={shouldShowGallery}
onClose={handleCloseGallery} onClose={handleCloseGallery}
isPinned={shouldPinGallery && !isLightboxOpen} minWidth={200}
minWidth={
shouldPinGallery
? GALLERY_TAB_WIDTHS[activeTabName].galleryMinWidth
: 200
}
maxWidth={
shouldPinGallery
? GALLERY_TAB_WIDTHS[activeTabName].galleryMaxWidth
: undefined
}
minHeight={calcGalleryMinHeight()}
> >
<ImageGalleryContent /> <ImageGalleryContent />
</ResizableDrawer> </ResizableDrawer>
); );
};
const renderImageGallery = () => { // return renderImageGallery();
if (['mobile', 'tablet'].includes(resolution)) return imageGalleryContent();
return resizableImageGalleryContent();
};
return renderImageGallery();
}; };
export default memo(ImageGalleryPanel); export default memo(GalleryDrawer);

View File

@ -3,7 +3,6 @@ import {
Box, Box,
Center, Center,
Flex, Flex,
Heading,
IconButton, IconButton,
Link, Link,
Text, Text,
@ -19,8 +18,6 @@ import {
setCfgScale, setCfgScale,
setHeight, setHeight,
setImg2imgStrength, setImg2imgStrength,
// setInitialImage,
setMaskPath,
setPerlin, setPerlin,
setSampler, setSampler,
setSeamless, setSeamless,
@ -31,21 +28,14 @@ import {
setThreshold, setThreshold,
setWidth, setWidth,
} from 'features/parameters/store/generationSlice'; } from 'features/parameters/store/generationSlice';
import { import { setHiresFix } from 'features/parameters/store/postprocessingSlice';
setCodeformerFidelity,
setFacetoolStrength,
setFacetoolType,
setHiresFix,
setUpscalingDenoising,
setUpscalingLevel,
setUpscalingStrength,
} from 'features/parameters/store/postprocessingSlice';
import { setShouldShowImageDetails } from 'features/ui/store/uiSlice'; import { setShouldShowImageDetails } from 'features/ui/store/uiSlice';
import { memo } from 'react'; import { memo } from 'react';
import { useHotkeys } from 'react-hotkeys-hook'; import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { FaCopy } from 'react-icons/fa'; import { FaCopy } from 'react-icons/fa';
import { IoArrowUndoCircleOutline } from 'react-icons/io5'; import { IoArrowUndoCircleOutline } from 'react-icons/io5';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
type MetadataItemProps = { type MetadataItemProps = {
isLink?: boolean; isLink?: boolean;
@ -300,7 +290,7 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
</Text> </Text>
</Center> </Center>
)} )}
<Flex gap={2} direction="column"> <Flex gap={2} direction="column" overflow="auto">
<Flex gap={2}> <Flex gap={2}>
<Tooltip label="Copy metadata JSON"> <Tooltip label="Copy metadata JSON">
<IconButton <IconButton
@ -314,22 +304,19 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
</Tooltip> </Tooltip>
<Text fontWeight="semibold">Metadata JSON:</Text> <Text fontWeight="semibold">Metadata JSON:</Text>
</Flex> </Flex>
<OverlayScrollbarsComponent defer>
<Box <Box
sx={{ sx={{
mt: 0,
mr: 2,
mb: 4,
ml: 2,
padding: 4, padding: 4,
borderRadius: 'base', borderRadius: 'base',
overflowX: 'scroll',
wordBreak: 'break-all',
bg: 'whiteAlpha.500', bg: 'whiteAlpha.500',
_dark: { bg: 'blackAlpha.500' }, _dark: { bg: 'blackAlpha.500' },
w: 'max-content',
}} }}
> >
<pre>{metadataJSON}</pre> <pre>{metadataJSON}</pre>
</Box> </Box>
</OverlayScrollbarsComponent>
</Flex> </Flex>
</Flex> </Flex>
); );
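Note: the metadata JSON block is now wrapped in OverlayScrollbarsComponent with defer. A minimal standalone sketch of that wrapper follows; the stylesheet import path follows the overlayscrollbars v2 package layout and is an assumption about the version in use.

import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import 'overlayscrollbars/overlayscrollbars.css'; // assumed v2 stylesheet path

const MetadataJSONBlock = (props: { metadataJSON: string }) => (
  // `defer` postpones scrollbar initialization until the browser is idle,
  // so large metadata payloads do not block the first paint.
  <OverlayScrollbarsComponent defer>
    <pre>{props.metadataJSON}</pre>
  </OverlayScrollbarsComponent>
);

export default MetadataJSONBlock;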

View File

@ -0,0 +1,7 @@
import { createAction } from '@reduxjs/toolkit';
import { Image } from 'app/types/invokeai';
import { SelectedImage } from 'features/parameters/store/actions';
export const requestedImageDeletion = createAction<
Image | SelectedImage | undefined
>('gallery/requestedImageDeletion');
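Note: requestedImageDeletion is a plain createAction with no reducer in this file, so something else must react to it. The wiring is not shown in this diff; the sketch below is one plausible way to consume the intent with Redux Toolkit's listener middleware and translate it into the imageDeleted thunk. The listenerMiddleware instance itself is an assumption.

import { createListenerMiddleware } from '@reduxjs/toolkit';
import { requestedImageDeletion } from 'features/gallery/store/actions';
import { imageDeleted } from 'services/thunks/image';

// Hypothetical wiring: translate the UI-level intent into the API thunk.
export const listenerMiddleware = createListenerMiddleware();

listenerMiddleware.startListening({
  actionCreator: requestedImageDeletion,
  effect: async (action, { dispatch }) => {
    const image = action.payload;
    if (!image) {
      return;
    }
    dispatch(imageDeleted({ imageType: image.type, imageName: image.name }));
  },
});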

View File

@ -4,12 +4,13 @@ import { GalleryState } from './gallerySlice';
* Gallery slice persist denylist * Gallery slice persist denylist
*/ */
const itemsToDenylist: (keyof GalleryState)[] = [ const itemsToDenylist: (keyof GalleryState)[] = [
'categories',
'currentCategory', 'currentCategory',
'currentImage',
'currentImageUuid',
'shouldAutoSwitchToNewImages', 'shouldAutoSwitchToNewImages',
'intermediateImage', ];
export const galleryPersistDenylist: (keyof GalleryState)[] = [
'currentCategory',
'shouldAutoSwitchToNewImages',
]; ];
export const galleryDenylist = itemsToDenylist.map( export const galleryDenylist = itemsToDenylist.map(

View File

@ -1,23 +1,14 @@
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { lightboxSelector } from 'features/lightbox/store/lightboxSelectors'; import { lightboxSelector } from 'features/lightbox/store/lightboxSelectors';
import { configSelector } from 'features/system/store/configSelectors';
import { systemSelector } from 'features/system/store/systemSelectors';
import { import {
activeTabNameSelector, activeTabNameSelector,
uiSelector, uiSelector,
} from 'features/ui/store/uiSelectors'; } from 'features/ui/store/uiSelectors';
import { isEqual } from 'lodash-es'; import { isEqual } from 'lodash-es';
import { import { selectResultsById, selectResultsEntities } from './resultsSlice';
selectResultsAll, import { selectUploadsAll, selectUploadsById } from './uploadsSlice';
selectResultsById,
selectResultsEntities,
} from './resultsSlice';
import {
selectUploadsAll,
selectUploadsById,
selectUploadsEntities,
} from './uploadsSlice';
export const gallerySelector = (state: RootState) => state.gallery; export const gallerySelector = (state: RootState) => state.gallery;
@ -44,6 +35,11 @@ export const imageGallerySelector = createSelector(
const { isLightboxOpen } = lightbox; const { isLightboxOpen } = lightbox;
const images =
currentCategory === 'results'
? selectResultsEntities(state)
: selectUploadsAll(state);
return { return {
shouldPinGallery, shouldPinGallery,
galleryImageMinimumWidth, galleryImageMinimumWidth,
@ -53,7 +49,7 @@ export const imageGallerySelector = createSelector(
: `repeat(auto-fill, minmax(${galleryImageMinimumWidth}px, auto))`, : `repeat(auto-fill, minmax(${galleryImageMinimumWidth}px, auto))`,
shouldAutoSwitchToNewImages, shouldAutoSwitchToNewImages,
currentCategory, currentCategory,
images: state[currentCategory].entities, images,
galleryWidth, galleryWidth,
shouldEnableResize: shouldEnableResize:
isLightboxOpen || isLightboxOpen ||
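Note: imageGallerySelector now picks selectResultsEntities for the results category but selectUploadsAll for uploads, and the two have different shapes. With createEntityAdapter, selectEntities returns a dictionary keyed by id while selectAll returns an ordered array. The sketch below shows how such selectors are generated (generic Redux Toolkit behavior, not code from this diff); selecting entities by image name is an assumption based on how they are used in this commit.

import { createEntityAdapter } from '@reduxjs/toolkit';
import { Image } from 'app/types/invokeai';
import { RootState } from 'app/store/store';

const resultsAdapter = createEntityAdapter<Image>({
  selectId: (image) => image.name,
});

// getSelectors yields, among others:
//   selectAll      -> Image[]                (ordered array)
//   selectEntities -> Record<string, Image>  (dictionary keyed by id)
//   selectById     -> Image | undefined
export const {
  selectAll: selectResultsAll,
  selectEntities: selectResultsEntities,
  selectById: selectResultsById,
} = resultsAdapter.getSelectors<RootState>((state) => state.results);

So the images field produced by imageGallerySelector is a dictionary for the results category but an array for uploads, and downstream consumers have to account for both shapes.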

View File

@ -1,10 +1,11 @@
import type { PayloadAction } from '@reduxjs/toolkit'; import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit';
import { invocationComplete } from 'services/events/actions'; import { Image } from 'app/types/invokeai';
import { isImageOutput } from 'services/types/guards'; import { imageReceived, thumbnailReceived } from 'services/thunks/image';
import { deserializeImageResponse } from 'services/util/deserializeImageResponse'; import {
import { imageUploaded } from 'services/thunks/image'; receivedResultImagesPage,
import { SelectedImage } from 'features/parameters/store/generationSlice'; receivedUploadImagesPage,
} from '../../../services/thunks/gallery';
type GalleryImageObjectFitType = 'contain' | 'cover'; type GalleryImageObjectFitType = 'contain' | 'cover';
@ -12,7 +13,7 @@ export interface GalleryState {
/** /**
* The selected image * The selected image
*/ */
selectedImage?: SelectedImage; selectedImage?: Image;
galleryImageMinimumWidth: number; galleryImageMinimumWidth: number;
galleryImageObjectFit: GalleryImageObjectFitType; galleryImageObjectFit: GalleryImageObjectFitType;
shouldAutoSwitchToNewImages: boolean; shouldAutoSwitchToNewImages: boolean;
@ -21,8 +22,7 @@ export interface GalleryState {
currentCategory: 'results' | 'uploads'; currentCategory: 'results' | 'uploads';
} }
const initialState: GalleryState = { export const initialGalleryState: GalleryState = {
selectedImage: undefined,
galleryImageMinimumWidth: 64, galleryImageMinimumWidth: 64,
galleryImageObjectFit: 'cover', galleryImageObjectFit: 'cover',
shouldAutoSwitchToNewImages: true, shouldAutoSwitchToNewImages: true,
@ -33,12 +33,9 @@ const initialState: GalleryState = {
export const gallerySlice = createSlice({ export const gallerySlice = createSlice({
name: 'gallery', name: 'gallery',
initialState, initialState: initialGalleryState,
reducers: { reducers: {
imageSelected: ( imageSelected: (state, action: PayloadAction<Image | undefined>) => {
state,
action: PayloadAction<SelectedImage | undefined>
) => {
state.selectedImage = action.payload; state.selectedImage = action.payload;
// TODO: if the user selects an image, disable the auto switch? // TODO: if the user selects an image, disable the auto switch?
// state.shouldAutoSwitchToNewImages = false; // state.shouldAutoSwitchToNewImages = false;
@ -72,27 +69,50 @@ export const gallerySlice = createSlice({
}, },
}, },
extraReducers(builder) { extraReducers(builder) {
/** builder.addCase(imageReceived.fulfilled, (state, action) => {
* Invocation Complete // When we get an updated URL for an image, we need to update the selectedImage in gallery,
*/ // which is currently its own object (instead of a reference to an image in results/uploads)
builder.addCase(invocationComplete, (state, action) => { const { imagePath } = action.payload;
const { data } = action.payload; const { imageName } = action.meta.arg;
if (isImageOutput(data.result) && state.shouldAutoSwitchToNewImages) {
state.selectedImage = { if (state.selectedImage?.name === imageName) {
name: data.result.image.image_name, state.selectedImage.url = imagePath;
type: 'results',
};
} }
}); });
/** builder.addCase(thumbnailReceived.fulfilled, (state, action) => {
* Upload Image - FULFILLED // When we get an updated URL for an image, we need to update the selectedImage in gallery,
*/ // which is currently its own object (instead of a reference to an image in results/uploads)
builder.addCase(imageUploaded.fulfilled, (state, action) => { const { thumbnailPath } = action.payload;
const { response } = action.payload; const { thumbnailName } = action.meta.arg;
const uploadedImage = deserializeImageResponse(response); if (state.selectedImage?.name === thumbnailName) {
state.selectedImage = { name: uploadedImage.name, type: 'uploads' }; state.selectedImage.thumbnail = thumbnailPath;
}
});
builder.addCase(receivedResultImagesPage.fulfilled, (state, action) => {
// rehydrate selectedImage URL when results list comes in
// solves case when outdated URL is in local storage
if (state.selectedImage) {
const selectedImageInResults = action.payload.items.find(
(image) => image.image_name === state.selectedImage!.name
);
if (selectedImageInResults) {
state.selectedImage.url = selectedImageInResults.image_url;
}
}
});
builder.addCase(receivedUploadImagesPage.fulfilled, (state, action) => {
// rehydrate selectedImage URL when results list comes in
// rehydrate selectedImage URL when uploads list comes in
// solves case when outdated URL is in local storage
if (state.selectedImage) {
const selectedImageInResults = action.payload.items.find(
(image) => image.image_name === state.selectedImage!.name
);
if (selectedImageInResults) {
state.selectedImage.url = selectedImageInResults.image_url;
}
}
}); });
}, },
}); });
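Note: the new extraReducers read action.meta.arg to find which image a fulfilled thunk refers to. With createAsyncThunk, meta.arg is simply whatever argument was passed when the thunk was dispatched. The sketch below mirrors that mechanism; the thunk name, argument shape, and the fetchImagePath helper are assumptions modelled on the reducers above, not the project's actual thunk definitions.

import { createAsyncThunk } from '@reduxjs/toolkit';

type ImageReceivedArg = { imageName: string; imageType: string };

// Placeholder for whatever call actually resolves an image's URL; the path
// format here is made up for the sketch.
const fetchImagePath = async (imageType: string, imageName: string) =>
  `${imageType}/${imageName}`;

// Hypothetical mirror of the imageReceived thunk consumed above: the object
// passed at dispatch time becomes action.meta.arg in every pending /
// fulfilled / rejected reducer.
export const imageReceived = createAsyncThunk(
  'api/imageReceived',
  async (arg: ImageReceivedArg) => {
    const imagePath = await fetchImagePath(arg.imageType, arg.imageName);
    return { imagePath };
  }
);

// In a reducer:
//   builder.addCase(imageReceived.fulfilled, (state, action) => {
//     action.payload.imagePath;   // what the thunk returned
//     action.meta.arg.imageName;  // what the caller passed to the thunk
//   });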

View File

@ -5,7 +5,9 @@ import { ResultsState } from './resultsSlice';
* *
* Currently denylisting results slice entirely, see persist config in store.ts * Currently denylisting results slice entirely, see persist config in store.ts
*/ */
const itemsToDenylist: (keyof ResultsState)[] = ['isLoading']; const itemsToDenylist: (keyof ResultsState)[] = [];
export const resultsPersistDenylist: (keyof ResultsState)[] = [];
export const resultsDenylist = itemsToDenylist.map( export const resultsDenylist = itemsToDenylist.map(
(denylistItem) => `results.${denylistItem}` (denylistItem) => `results.${denylistItem}`

View File

@ -1,17 +1,11 @@
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit'; import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
import { Image } from 'app/types/invokeai'; import { Image } from 'app/types/invokeai';
import { invocationComplete } from 'services/events/actions';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { import {
receivedResultImagesPage, receivedResultImagesPage,
IMAGES_PER_PAGE, IMAGES_PER_PAGE,
} from 'services/thunks/gallery'; } from 'services/thunks/gallery';
import { isImageOutput } from 'services/types/guards';
import {
buildImageUrls,
extractTimestampFromImageName,
} from 'services/util/deserializeImageField';
import { deserializeImageResponse } from 'services/util/deserializeImageResponse'; import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
import { import {
imageDeleted, imageDeleted,
@ -73,44 +67,6 @@ const resultsSlice = createSlice({
state.isLoading = false; state.isLoading = false;
}); });
/**
* Invocation Complete
*/
builder.addCase(invocationComplete, (state, action) => {
const { data, shouldFetchImages } = action.payload;
const { result, node, graph_execution_state_id } = data;
if (isImageOutput(result)) {
const name = result.image.image_name;
const type = result.image.image_type;
// if we need to refetch, set URLs to placeholder for now
const { url, thumbnail } = shouldFetchImages
? { url: '', thumbnail: '' }
: buildImageUrls(type, name);
const timestamp = extractTimestampFromImageName(name);
const image: Image = {
name,
type,
url,
thumbnail,
metadata: {
created: timestamp,
width: result.width,
height: result.height,
invokeai: {
session_id: graph_execution_state_id,
...(node ? { node } : {}),
},
},
};
resultsAdapter.setOne(state, image);
}
});
/** /**
* Image Received - FULFILLED * Image Received - FULFILLED
*/ */
@ -142,9 +98,10 @@ const resultsSlice = createSlice({
}); });
/** /**
* Delete Image - FULFILLED * Delete Image - PENDING
* Pre-emptively remove the image from the gallery
*/ */
builder.addCase(imageDeleted.fulfilled, (state, action) => { builder.addCase(imageDeleted.pending, (state, action) => {
const { imageType, imageName } = action.meta.arg; const { imageType, imageName } = action.meta.arg;
if (imageType === 'results') { if (imageType === 'results') {

View File

@ -5,7 +5,8 @@ import { UploadsState } from './uploadsSlice';
* *
* Currently denylisting uploads slice entirely, see persist config in store.ts * Currently denylisting uploads slice entirely, see persist config in store.ts
*/ */
const itemsToDenylist: (keyof UploadsState)[] = ['isLoading']; const itemsToDenylist: (keyof UploadsState)[] = [];
export const uploadsPersistDenylist: (keyof UploadsState)[] = [];
export const uploadsDenylist = itemsToDenylist.map( export const uploadsDenylist = itemsToDenylist.map(
(denylistItem) => `uploads.${denylistItem}` (denylistItem) => `uploads.${denylistItem}`

View File

@ -6,7 +6,7 @@ import {
receivedUploadImagesPage, receivedUploadImagesPage,
IMAGES_PER_PAGE, IMAGES_PER_PAGE,
} from 'services/thunks/gallery'; } from 'services/thunks/gallery';
import { imageDeleted, imageUploaded } from 'services/thunks/image'; import { imageDeleted } from 'services/thunks/image';
import { deserializeImageResponse } from 'services/util/deserializeImageResponse'; import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
export const uploadsAdapter = createEntityAdapter<Image>({ export const uploadsAdapter = createEntityAdapter<Image>({
@ -21,7 +21,7 @@ type AdditionalUploadsState = {
nextPage: number; nextPage: number;
}; };
const initialUploadsState = export const initialUploadsState =
uploadsAdapter.getInitialState<AdditionalUploadsState>({ uploadsAdapter.getInitialState<AdditionalUploadsState>({
page: 0, page: 0,
pages: 0, pages: 0,
@ -35,7 +35,7 @@ const uploadsSlice = createSlice({
name: 'uploads', name: 'uploads',
initialState: initialUploadsState, initialState: initialUploadsState,
reducers: { reducers: {
uploadAdded: uploadsAdapter.addOne, uploadAdded: uploadsAdapter.upsertOne,
}, },
extraReducers: (builder) => { extraReducers: (builder) => {
/** /**
@ -62,20 +62,10 @@ const uploadsSlice = createSlice({
}); });
/** /**
* Upload Image - FULFILLED * Delete Image - pending
* Pre-emptively remove the image from the gallery
*/ */
builder.addCase(imageUploaded.fulfilled, (state, action) => { builder.addCase(imageDeleted.pending, (state, action) => {
const { location, response } = action.payload;
const uploadedImage = deserializeImageResponse(response);
uploadsAdapter.setOne(state, uploadedImage);
});
/**
* Delete Image - FULFILLED
*/
builder.addCase(imageDeleted.fulfilled, (state, action) => {
const { imageType, imageName } = action.meta.arg; const { imageType, imageName } = action.meta.arg;
if (imageType === 'uploads') { if (imageType === 'uploads') {
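Note: uploadAdded switches from uploadsAdapter.addOne to upsertOne. The difference matters when an upload with the same name arrives twice: addOne ignores an entity whose id already exists, while upsertOne shallow-merges the incoming fields over the existing entry. A tiny standalone illustration with a throwaway adapter:

import { createEntityAdapter } from '@reduxjs/toolkit';

type Item = { name: string; url: string };

const adapter = createEntityAdapter<Item>({ selectId: (item) => item.name });
let state = adapter.getInitialState();

state = adapter.addOne(state, { name: 'a.png', url: 'old-url' });

// addOne: the existing 'a.png' is left untouched, the new url is dropped.
state = adapter.addOne(state, { name: 'a.png', url: 'new-url' });
// state.entities['a.png'].url === 'old-url'

// upsertOne: the existing entry is updated with the incoming fields.
state = adapter.upsertOne(state, { name: 'a.png', url: 'new-url' });
// state.entities['a.png'].url === 'new-url'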

View File

@ -4,7 +4,7 @@ import * as InvokeAI from 'app/types/invokeai';
import { useGetUrl } from 'common/util/getUrl'; import { useGetUrl } from 'common/util/getUrl';
type ReactPanZoomProps = { type ReactPanZoomProps = {
image: InvokeAI._Image; image: InvokeAI.Image;
styleClass?: string; styleClass?: string;
alt?: string; alt?: string;
ref?: React.Ref<HTMLImageElement>; ref?: React.Ref<HTMLImageElement>;

View File

@ -4,6 +4,9 @@ import { LightboxState } from './lightboxSlice';
* Lightbox slice persist denylist * Lightbox slice persist denylist
*/ */
const itemsToDenylist: (keyof LightboxState)[] = ['isLightboxOpen']; const itemsToDenylist: (keyof LightboxState)[] = ['isLightboxOpen'];
export const lightboxPersistDenylist: (keyof LightboxState)[] = [
'isLightboxOpen',
];
export const lightboxDenylist = itemsToDenylist.map( export const lightboxDenylist = itemsToDenylist.map(
(denylistItem) => `lightbox.${denylistItem}` (denylistItem) => `lightbox.${denylistItem}`

View File

@ -5,7 +5,7 @@ export interface LightboxState {
isLightboxOpen: boolean; isLightboxOpen: boolean;
} }
const initialLightboxState: LightboxState = { export const initialLightboxState: LightboxState = {
isLightboxOpen: false, isLightboxOpen: false,
}; };

View File

@ -1,5 +1,3 @@
import { v4 as uuidv4 } from 'uuid';
import 'reactflow/dist/style.css'; import 'reactflow/dist/style.css';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { import {
@ -8,12 +6,11 @@ import {
MenuButton, MenuButton,
MenuList, MenuList,
MenuItem, MenuItem,
IconButton,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { FaEllipsisV, FaPlus } from 'react-icons/fa'; import { FaEllipsisV } from 'react-icons/fa';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { nodeAdded } from '../store/nodesSlice'; import { nodeAdded } from '../store/nodesSlice';
import { cloneDeep, map } from 'lodash-es'; import { map } from 'lodash-es';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { useBuildInvocation } from '../hooks/useBuildInvocation'; import { useBuildInvocation } from '../hooks/useBuildInvocation';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';

Some files were not shown because too many files have changed in this diff.