mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
merge with main
This commit is contained in:
commit
1103ab2844
13
.github/workflows/mkdocs-material.yml
vendored
13
.github/workflows/mkdocs-material.yml
vendored
@ -2,8 +2,7 @@ name: mkdocs-material
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- 'main'
|
||||
- 'development'
|
||||
- 'refs/heads/v2.3'
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
@ -12,6 +11,10 @@ jobs:
|
||||
mkdocs-material:
|
||||
if: github.event.pull_request.draft == false
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
REPO_URL: '${{ github.server_url }}/${{ github.repository }}'
|
||||
REPO_NAME: '${{ github.repository }}'
|
||||
SITE_URL: 'https://${{ github.repository_owner }}.github.io/InvokeAI'
|
||||
steps:
|
||||
- name: checkout sources
|
||||
uses: actions/checkout@v3
|
||||
@ -22,11 +25,15 @@ jobs:
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.10'
|
||||
cache: pip
|
||||
cache-dependency-path: pyproject.toml
|
||||
|
||||
- name: install requirements
|
||||
env:
|
||||
PIP_USE_PEP517: 1
|
||||
run: |
|
||||
python -m \
|
||||
pip install -r docs/requirements-mkdocs.txt
|
||||
pip install ".[docs]"
|
||||
|
||||
- name: confirm buildability
|
||||
run: |
|
||||
|
@ -247,8 +247,8 @@ class InvokeAiInstance:
|
||||
pip[
|
||||
"install",
|
||||
"--require-virtualenv",
|
||||
"torch",
|
||||
"torchvision",
|
||||
"torch~=2.0.0",
|
||||
"torchvision>=0.14.1",
|
||||
"--force-reinstall",
|
||||
"--find-links" if find_links is not None else None,
|
||||
find_links,
|
||||
|
@ -83,7 +83,7 @@ async def get_thumbnail(
|
||||
status_code=201,
|
||||
)
|
||||
async def upload_image(
|
||||
file: UploadFile, request: Request, response: Response
|
||||
file: UploadFile, image_type: ImageType, request: Request, response: Response
|
||||
) -> ImageResponse:
|
||||
if not file.content_type.startswith("image"):
|
||||
raise HTTPException(status_code=415, detail="Not an image")
|
||||
@ -99,21 +99,21 @@ async def upload_image(
|
||||
filename = f"{uuid.uuid4()}_{str(int(datetime.now(timezone.utc).timestamp()))}.png"
|
||||
|
||||
saved_image = ApiDependencies.invoker.services.images.save(
|
||||
ImageType.UPLOAD, filename, img
|
||||
image_type, filename, img
|
||||
)
|
||||
|
||||
invokeai_metadata = ApiDependencies.invoker.services.metadata.get_metadata(img)
|
||||
|
||||
image_url = ApiDependencies.invoker.services.images.get_uri(
|
||||
ImageType.UPLOAD, saved_image.image_name
|
||||
image_type, saved_image.image_name
|
||||
)
|
||||
|
||||
thumbnail_url = ApiDependencies.invoker.services.images.get_uri(
|
||||
ImageType.UPLOAD, saved_image.image_name, True
|
||||
image_type, saved_image.image_name, True
|
||||
)
|
||||
|
||||
res = ImageResponse(
|
||||
image_type=ImageType.UPLOAD,
|
||||
image_type=image_type,
|
||||
image_name=saved_image.image_name,
|
||||
image_url=image_url,
|
||||
thumbnail_url=thumbnail_url,
|
||||
|
@ -122,7 +122,6 @@ app.openapi = custom_openapi
|
||||
# Override API doc favicons
|
||||
app.mount("/static", StaticFiles(directory="static/dream_web"), name="static")
|
||||
|
||||
|
||||
@app.get("/docs", include_in_schema=False)
|
||||
def overridden_swagger():
|
||||
return get_swagger_ui_html(
|
||||
@ -140,6 +139,8 @@ def overridden_redoc():
|
||||
redoc_favicon_url="/static/favicon.ico",
|
||||
)
|
||||
|
||||
# Must mount *after* the other routes else it borks em
|
||||
app.mount("/", StaticFiles(directory="invokeai/frontend/web/dist", html=True), name="ui")
|
||||
|
||||
def invoke_api():
|
||||
global web_config
|
||||
|
@ -3,12 +3,12 @@
|
||||
from typing import Literal, Optional
|
||||
|
||||
import numpy as np
|
||||
import numpy.random
|
||||
from pydantic import Field
|
||||
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
InvocationConfig,
|
||||
InvocationContext,
|
||||
BaseInvocationOutput,
|
||||
)
|
||||
@ -50,11 +50,11 @@ class RandomRangeInvocation(BaseInvocation):
|
||||
default=np.iinfo(np.int32).max, description="The exclusive high value"
|
||||
)
|
||||
size: int = Field(default=1, description="The number of values to generate")
|
||||
seed: Optional[int] = Field(
|
||||
seed: int = Field(
|
||||
ge=0,
|
||||
le=np.iinfo(np.int32).max,
|
||||
description="The seed for the RNG",
|
||||
default_factory=lambda: numpy.random.randint(0, np.iinfo(np.int32).max),
|
||||
le=SEED_MAX,
|
||||
description="The seed for the RNG (omit for random)",
|
||||
default_factory=get_random_seed,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
|
||||
|
@ -1,15 +1,17 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from functools import partial
|
||||
from typing import Literal, Optional, Union
|
||||
from typing import Literal, Optional, Union, get_args
|
||||
|
||||
import numpy as np
|
||||
from torch import Tensor
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.models.image import ImageField, ImageType
|
||||
from invokeai.app.models.image import ColorField, ImageField, ImageType
|
||||
from invokeai.app.invocations.util.choose_model import choose_model
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
from invokeai.backend.generator.inpaint import infill_methods
|
||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
||||
from .image import ImageOutput, build_image_output
|
||||
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
|
||||
@ -17,7 +19,8 @@ from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ..util.step_callback import stable_diffusion_step_callback
|
||||
|
||||
SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
|
||||
|
||||
INFILL_METHODS = Literal[tuple(infill_methods())]
|
||||
DEFAULT_INFILL_METHOD = 'patchmatch' if 'patchmatch' in get_args(INFILL_METHODS) else 'tile'
|
||||
|
||||
class SDImageInvocation(BaseModel):
|
||||
"""Helper class to provide all Stable Diffusion raster image invocations with additional config"""
|
||||
@ -44,15 +47,13 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
|
||||
# TODO: consider making prompt optional to enable providing prompt through a link
|
||||
# fmt: off
|
||||
prompt: Optional[str] = Field(description="The prompt to generate an image from")
|
||||
seed: int = Field(default=-1,ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)", )
|
||||
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
||||
seed: int = Field(ge=0, le=SEED_MAX, description="The seed to use (omit for random)", default_factory=get_random_seed)
|
||||
steps: int = Field(default=30, gt=0, description="The number of steps to use to generate the image")
|
||||
width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting image", )
|
||||
height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting image", )
|
||||
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||
scheduler: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The scheduler to use" )
|
||||
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||
cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||
scheduler: SAMPLER_NAME_VALUES = Field(default="lms", description="The scheduler to use" )
|
||||
model: str = Field(default="", description="The model to use (currently ignored)")
|
||||
progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
|
||||
# fmt: on
|
||||
|
||||
# TODO: pass this an emitter method or something? or a session for dispatching?
|
||||
@ -148,7 +149,6 @@ class ImageToImageInvocation(TextToImageInvocation):
|
||||
self.image.image_type, self.image.image_name
|
||||
)
|
||||
)
|
||||
mask = None
|
||||
|
||||
if self.fit:
|
||||
image = image.resize((self.width, self.height))
|
||||
@ -165,7 +165,6 @@ class ImageToImageInvocation(TextToImageInvocation):
|
||||
outputs = Img2Img(model).generate(
|
||||
prompt=self.prompt,
|
||||
init_image=image,
|
||||
init_mask=mask,
|
||||
step_callback=partial(self.dispatch_progress, context, source_node_id),
|
||||
**self.dict(
|
||||
exclude={"prompt", "image", "mask"}
|
||||
@ -197,7 +196,6 @@ class ImageToImageInvocation(TextToImageInvocation):
|
||||
image=result_image,
|
||||
)
|
||||
|
||||
|
||||
class InpaintInvocation(ImageToImageInvocation):
|
||||
"""Generates an image using inpaint."""
|
||||
|
||||
@ -205,6 +203,17 @@ class InpaintInvocation(ImageToImageInvocation):
|
||||
|
||||
# Inputs
|
||||
mask: Union[ImageField, None] = Field(description="The mask")
|
||||
seam_size: int = Field(default=96, ge=1, description="The seam inpaint size (px)")
|
||||
seam_blur: int = Field(default=16, ge=0, description="The seam inpaint blur radius (px)")
|
||||
seam_strength: float = Field(
|
||||
default=0.75, gt=0, le=1, description="The seam inpaint strength"
|
||||
)
|
||||
seam_steps: int = Field(default=30, ge=1, description="The number of steps to use for seam inpaint")
|
||||
tile_size: int = Field(default=32, ge=1, description="The tile infill method size (px)")
|
||||
infill_method: INFILL_METHODS = Field(default=DEFAULT_INFILL_METHOD, description="The method used to infill empty regions (px)")
|
||||
inpaint_width: Optional[int] = Field(default=None, multiple_of=8, gt=0, description="The width of the inpaint region (px)")
|
||||
inpaint_height: Optional[int] = Field(default=None, multiple_of=8, gt=0, description="The height of the inpaint region (px)")
|
||||
inpaint_fill: Optional[ColorField] = Field(default=ColorField(r=127, g=127, b=127, a=255), description="The solid infill method color")
|
||||
inpaint_replace: float = Field(
|
||||
default=0.0,
|
||||
ge=0.0,
|
||||
|
@ -1,5 +1,6 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import io
|
||||
from typing import Literal, Optional
|
||||
|
||||
import numpy
|
||||
@ -32,14 +33,12 @@ class ImageOutput(BaseInvocationOutput):
|
||||
# fmt: off
|
||||
type: Literal["image"] = "image"
|
||||
image: ImageField = Field(default=None, description="The output image")
|
||||
width: Optional[int] = Field(default=None, description="The width of the image in pixels")
|
||||
height: Optional[int] = Field(default=None, description="The height of the image in pixels")
|
||||
width: int = Field(description="The width of the image in pixels")
|
||||
height: int = Field(description="The height of the image in pixels")
|
||||
# fmt: on
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"required": ["type", "image", "width", "height", "mode"]
|
||||
}
|
||||
schema_extra = {"required": ["type", "image", "width", "height"]}
|
||||
|
||||
|
||||
def build_image_output(
|
||||
@ -54,7 +53,6 @@ def build_image_output(
|
||||
image=image_field,
|
||||
width=image.width,
|
||||
height=image.height,
|
||||
mode=image.mode,
|
||||
)
|
||||
|
||||
|
||||
|
233
invokeai/app/invocations/infill.py
Normal file
233
invokeai/app/invocations/infill.py
Normal file
@ -0,0 +1,233 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Literal, Optional, Union, get_args
|
||||
|
||||
import numpy as np
|
||||
import math
|
||||
from PIL import Image, ImageOps
|
||||
from pydantic import Field
|
||||
|
||||
from invokeai.app.invocations.image import ImageOutput, build_image_output
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
from invokeai.backend.image_util.patchmatch import PatchMatch
|
||||
|
||||
from ..models.image import ColorField, ImageField, ImageType
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
InvocationContext,
|
||||
)
|
||||
|
||||
|
||||
def infill_methods() -> list[str]:
|
||||
methods = [
|
||||
"tile",
|
||||
"solid",
|
||||
]
|
||||
if PatchMatch.patchmatch_available():
|
||||
methods.insert(0, "patchmatch")
|
||||
return methods
|
||||
|
||||
|
||||
INFILL_METHODS = Literal[tuple(infill_methods())]
|
||||
DEFAULT_INFILL_METHOD = (
|
||||
"patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"
|
||||
)
|
||||
|
||||
|
||||
def infill_patchmatch(im: Image.Image) -> Image.Image:
|
||||
if im.mode != "RGBA":
|
||||
return im
|
||||
|
||||
# Skip patchmatch if patchmatch isn't available
|
||||
if not PatchMatch.patchmatch_available():
|
||||
return im
|
||||
|
||||
# Patchmatch (note, we may want to expose patch_size? Increasing it significantly impacts performance though)
|
||||
im_patched_np = PatchMatch.inpaint(
|
||||
im.convert("RGB"), ImageOps.invert(im.split()[-1]), patch_size=3
|
||||
)
|
||||
im_patched = Image.fromarray(im_patched_np, mode="RGB")
|
||||
return im_patched
|
||||
|
||||
|
||||
def get_tile_images(image: np.ndarray, width=8, height=8):
|
||||
_nrows, _ncols, depth = image.shape
|
||||
_strides = image.strides
|
||||
|
||||
nrows, _m = divmod(_nrows, height)
|
||||
ncols, _n = divmod(_ncols, width)
|
||||
if _m != 0 or _n != 0:
|
||||
return None
|
||||
|
||||
return np.lib.stride_tricks.as_strided(
|
||||
np.ravel(image),
|
||||
shape=(nrows, ncols, height, width, depth),
|
||||
strides=(height * _strides[0], width * _strides[1], *_strides),
|
||||
writeable=False,
|
||||
)
|
||||
|
||||
|
||||
def tile_fill_missing(
|
||||
im: Image.Image, tile_size: int = 16, seed: Union[int, None] = None
|
||||
) -> Image.Image:
|
||||
# Only fill if there's an alpha layer
|
||||
if im.mode != "RGBA":
|
||||
return im
|
||||
|
||||
a = np.asarray(im, dtype=np.uint8)
|
||||
|
||||
tile_size_tuple = (tile_size, tile_size)
|
||||
|
||||
# Get the image as tiles of a specified size
|
||||
tiles = get_tile_images(a, *tile_size_tuple).copy()
|
||||
|
||||
# Get the mask as tiles
|
||||
tiles_mask = tiles[:, :, :, :, 3]
|
||||
|
||||
# Find any mask tiles with any fully transparent pixels (we will be replacing these later)
|
||||
tmask_shape = tiles_mask.shape
|
||||
tiles_mask = tiles_mask.reshape(math.prod(tiles_mask.shape))
|
||||
n, ny = (math.prod(tmask_shape[0:2])), math.prod(tmask_shape[2:])
|
||||
tiles_mask = tiles_mask > 0
|
||||
tiles_mask = tiles_mask.reshape((n, ny)).all(axis=1)
|
||||
|
||||
# Get RGB tiles in single array and filter by the mask
|
||||
tshape = tiles.shape
|
||||
tiles_all = tiles.reshape((math.prod(tiles.shape[0:2]), *tiles.shape[2:]))
|
||||
filtered_tiles = tiles_all[tiles_mask]
|
||||
|
||||
if len(filtered_tiles) == 0:
|
||||
return im
|
||||
|
||||
# Find all invalid tiles and replace with a random valid tile
|
||||
replace_count = (tiles_mask == False).sum()
|
||||
rng = np.random.default_rng(seed=seed)
|
||||
tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[
|
||||
rng.choice(filtered_tiles.shape[0], replace_count), :, :, :
|
||||
]
|
||||
|
||||
# Convert back to an image
|
||||
tiles_all = tiles_all.reshape(tshape)
|
||||
tiles_all = tiles_all.swapaxes(1, 2)
|
||||
st = tiles_all.reshape(
|
||||
(
|
||||
math.prod(tiles_all.shape[0:2]),
|
||||
math.prod(tiles_all.shape[2:4]),
|
||||
tiles_all.shape[4],
|
||||
)
|
||||
)
|
||||
si = Image.fromarray(st, mode="RGBA")
|
||||
|
||||
return si
|
||||
|
||||
|
||||
class InfillColorInvocation(BaseInvocation):
|
||||
"""Infills transparent areas of an image with a solid color"""
|
||||
|
||||
type: Literal["infill_rgba"] = "infill_rgba"
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to infill")
|
||||
color: Optional[ColorField] = Field(
|
||||
default=ColorField(r=127, g=127, b=127, a=255),
|
||||
description="The color to use to infill",
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
)
|
||||
|
||||
solid_bg = Image.new("RGBA", image.size, self.color.tuple())
|
||||
infilled = Image.alpha_composite(solid_bg, image)
|
||||
|
||||
infilled.paste(image, (0, 0), image.split()[-1])
|
||||
|
||||
image_type = ImageType.RESULT
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, infilled, metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image=image,
|
||||
)
|
||||
|
||||
|
||||
class InfillTileInvocation(BaseInvocation):
|
||||
"""Infills transparent areas of an image with tiles of the image"""
|
||||
|
||||
type: Literal["infill_tile"] = "infill_tile"
|
||||
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to infill")
|
||||
tile_size: int = Field(default=32, ge=1, description="The tile size (px)")
|
||||
seed: int = Field(
|
||||
ge=0,
|
||||
le=SEED_MAX,
|
||||
description="The seed to use for tile generation (omit for random)",
|
||||
default_factory=get_random_seed,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
)
|
||||
|
||||
infilled = tile_fill_missing(
|
||||
image.copy(), seed=self.seed, tile_size=self.tile_size
|
||||
)
|
||||
infilled.paste(image, (0, 0), image.split()[-1])
|
||||
|
||||
image_type = ImageType.RESULT
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, infilled, metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image=image,
|
||||
)
|
||||
|
||||
|
||||
class InfillPatchMatchInvocation(BaseInvocation):
|
||||
"""Infills transparent areas of an image using the PatchMatch algorithm"""
|
||||
|
||||
type: Literal["infill_patchmatch"] = "infill_patchmatch"
|
||||
|
||||
image: Optional[ImageField] = Field(default=None, description="The image to infill")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
)
|
||||
|
||||
if PatchMatch.patchmatch_available():
|
||||
infilled = infill_patchmatch(image.copy())
|
||||
else:
|
||||
raise ValueError("PatchMatch is not available on this system")
|
||||
|
||||
image_type = ImageType.RESULT
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
|
||||
metadata = context.services.metadata.build_metadata(
|
||||
session_id=context.graph_execution_state_id, node=self
|
||||
)
|
||||
|
||||
context.services.images.save(image_type, image_name, infilled, metadata)
|
||||
return build_image_output(
|
||||
image_type=image_type,
|
||||
image_name=image_name,
|
||||
image=image,
|
||||
)
|
@ -1,11 +1,13 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import random
|
||||
from typing import Literal, Optional
|
||||
from typing import Literal, Optional, Union
|
||||
import einops
|
||||
from pydantic import BaseModel, Field
|
||||
import torch
|
||||
|
||||
from invokeai.app.invocations.util.choose_model import choose_model
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
|
||||
@ -13,7 +15,9 @@ from ...backend.model_management.model_manager import ModelManager
|
||||
from ...backend.util.devices import choose_torch_device, torch_dtype
|
||||
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
||||
from ...backend.image_util.seamless import configure_model_padding
|
||||
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline
|
||||
from ...backend.prompting.conditioning import get_uc_and_c_and_ec
|
||||
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline, image_resized_to_grid_as_tensor
|
||||
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
||||
import numpy as np
|
||||
from ..services.image_storage import ImageType
|
||||
@ -37,41 +41,55 @@ class LatentsField(BaseModel):
|
||||
class LatentsOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output latents"""
|
||||
#fmt: off
|
||||
type: Literal["latent_output"] = "latent_output"
|
||||
type: Literal["latents_output"] = "latents_output"
|
||||
|
||||
# Inputs
|
||||
latents: LatentsField = Field(default=None, description="The output latents")
|
||||
width: int = Field(description="The width of the latents in pixels")
|
||||
height: int = Field(description="The height of the latents in pixels")
|
||||
#fmt: on
|
||||
|
||||
|
||||
def build_latents_output(latents_name: str, latents: torch.Tensor):
|
||||
return LatentsOutput(
|
||||
latents=LatentsField(latents_name=latents_name),
|
||||
width=latents.size()[3] * 8,
|
||||
height=latents.size()[2] * 8,
|
||||
)
|
||||
|
||||
class NoiseOutput(BaseInvocationOutput):
|
||||
"""Invocation noise output"""
|
||||
#fmt: off
|
||||
type: Literal["noise_output"] = "noise_output"
|
||||
|
||||
# Inputs
|
||||
noise: LatentsField = Field(default=None, description="The output noise")
|
||||
width: int = Field(description="The width of the noise in pixels")
|
||||
height: int = Field(description="The height of the noise in pixels")
|
||||
#fmt: on
|
||||
|
||||
|
||||
# TODO: this seems like a hack
|
||||
scheduler_map = dict(
|
||||
ddim=diffusers.DDIMScheduler,
|
||||
dpmpp_2=diffusers.DPMSolverMultistepScheduler,
|
||||
k_dpm_2=diffusers.KDPM2DiscreteScheduler,
|
||||
k_dpm_2_a=diffusers.KDPM2AncestralDiscreteScheduler,
|
||||
k_dpmpp_2=diffusers.DPMSolverMultistepScheduler,
|
||||
k_euler=diffusers.EulerDiscreteScheduler,
|
||||
k_euler_a=diffusers.EulerAncestralDiscreteScheduler,
|
||||
k_heun=diffusers.HeunDiscreteScheduler,
|
||||
k_lms=diffusers.LMSDiscreteScheduler,
|
||||
plms=diffusers.PNDMScheduler,
|
||||
def build_noise_output(latents_name: str, latents: torch.Tensor):
|
||||
return NoiseOutput(
|
||||
noise=LatentsField(latents_name=latents_name),
|
||||
width=latents.size()[3] * 8,
|
||||
height=latents.size()[2] * 8,
|
||||
)
|
||||
|
||||
|
||||
SAMPLER_NAME_VALUES = Literal[
|
||||
tuple(list(scheduler_map.keys()))
|
||||
tuple(list(SCHEDULER_MAP.keys()))
|
||||
]
|
||||
|
||||
|
||||
def get_scheduler(scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
|
||||
scheduler_class = scheduler_map.get(scheduler_name,'ddim')
|
||||
scheduler = scheduler_class.from_config(model.scheduler.config)
|
||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])
|
||||
|
||||
scheduler_config = model.scheduler.config
|
||||
if "_backup" in scheduler_config:
|
||||
scheduler_config = scheduler_config["_backup"]
|
||||
scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config}
|
||||
scheduler = scheduler_class.from_config(scheduler_config)
|
||||
|
||||
# hack copied over from generate.py
|
||||
if not hasattr(scheduler, 'uses_inpainting_model'):
|
||||
scheduler.uses_inpainting_model = lambda: False
|
||||
@ -102,17 +120,13 @@ def get_noise(width:int, height:int, device:torch.device, seed:int = 0, latent_c
|
||||
return x
|
||||
|
||||
|
||||
def random_seed():
|
||||
return random.randint(0, np.iinfo(np.uint32).max)
|
||||
|
||||
|
||||
class NoiseInvocation(BaseInvocation):
|
||||
"""Generates latent noise."""
|
||||
|
||||
type: Literal["noise"] = "noise"
|
||||
|
||||
# Inputs
|
||||
seed: int = Field(ge=0, le=np.iinfo(np.uint32).max, description="The seed to use", default_factory=random_seed)
|
||||
seed: int = Field(ge=0, le=SEED_MAX, description="The seed to use", default_factory=get_random_seed)
|
||||
width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting noise", )
|
||||
height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting noise", )
|
||||
|
||||
@ -131,9 +145,7 @@ class NoiseInvocation(BaseInvocation):
|
||||
|
||||
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||
context.services.latents.set(name, noise)
|
||||
return NoiseOutput(
|
||||
noise=LatentsField(latents_name=name)
|
||||
)
|
||||
return build_noise_output(latents_name=name, latents=noise)
|
||||
|
||||
|
||||
# Text to image
|
||||
@ -149,11 +161,10 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
noise: Optional[LatentsField] = Field(description="The noise to use")
|
||||
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
||||
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||
scheduler: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The scheduler to use" )
|
||||
scheduler: SAMPLER_NAME_VALUES = Field(default="lms", description="The scheduler to use" )
|
||||
model: str = Field(default="", description="The model to use (currently ignored)")
|
||||
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||
seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||
model: str = Field(default="", description="The model to use (currently ignored)")
|
||||
progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
|
||||
# fmt: on
|
||||
|
||||
# Schema customisation
|
||||
@ -218,7 +229,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
h_symmetry_time_pct=None,#h_symmetry_time_pct,
|
||||
v_symmetry_time_pct=None#v_symmetry_time_pct,
|
||||
),
|
||||
).add_scheduler_args_if_applicable(model.scheduler, eta=None)#ddim_eta)
|
||||
).add_scheduler_args_if_applicable(model.scheduler, eta=0.0)#ddim_eta)
|
||||
return conditioning_data
|
||||
|
||||
|
||||
@ -250,9 +261,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
|
||||
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||
context.services.latents.set(name, result_latents)
|
||||
return LatentsOutput(
|
||||
latents=LatentsField(latents_name=name)
|
||||
)
|
||||
return build_latents_output(latents_name=name, latents=result_latents)
|
||||
|
||||
|
||||
class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
@ -260,6 +269,10 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
|
||||
type: Literal["l2l"] = "l2l"
|
||||
|
||||
# Inputs
|
||||
latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
|
||||
strength: float = Field(default=0.5, description="The strength of the latents to use")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
@ -271,10 +284,6 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
},
|
||||
}
|
||||
|
||||
# Inputs
|
||||
latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
|
||||
strength: float = Field(default=0.5, description="The strength of the latents to use")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
noise = context.services.latents.get(self.noise.latents_name)
|
||||
latent = context.services.latents.get(self.latents.latents_name)
|
||||
@ -287,7 +296,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
self.dispatch_progress(context, source_node_id, state)
|
||||
|
||||
model = self.get_model(context.services.model_manager)
|
||||
conditioning_data = self.get_conditioning_data(model)
|
||||
conditioning_data = self.get_conditioning_data(context, model)
|
||||
|
||||
# TODO: Verify the noise is the right size
|
||||
|
||||
@ -295,11 +304,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
latent, device=model.device, dtype=latent.dtype
|
||||
)
|
||||
|
||||
timesteps, _ = model.get_img2img_timesteps(
|
||||
self.steps,
|
||||
self.strength,
|
||||
device=model.device,
|
||||
)
|
||||
timesteps, _ = model.get_img2img_timesteps(self.steps, self.strength)
|
||||
|
||||
result_latents, result_attention_map_saver = model.latents_from_embeddings(
|
||||
latents=initial_latents,
|
||||
@ -315,9 +320,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
|
||||
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||
context.services.latents.set(name, result_latents)
|
||||
return LatentsOutput(
|
||||
latents=LatentsField(latents_name=name)
|
||||
)
|
||||
return build_latents_output(latents_name=name, latents=result_latents)
|
||||
|
||||
|
||||
# Latent to image
|
||||
@ -384,8 +387,8 @@ class ResizeLatentsInvocation(BaseInvocation):
|
||||
latents: Optional[LatentsField] = Field(description="The latents to resize")
|
||||
width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
|
||||
height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
|
||||
mode: Optional[LATENTS_INTERPOLATION_MODE] = Field(default="bilinear", description="The interpolation mode")
|
||||
antialias: Optional[bool] = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
|
||||
mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
|
||||
antialias: bool = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
@ -402,7 +405,7 @@ class ResizeLatentsInvocation(BaseInvocation):
|
||||
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
context.services.latents.set(name, resized_latents)
|
||||
return LatentsOutput(latents=LatentsField(latents_name=name))
|
||||
return build_latents_output(latents_name=name, latents=resized_latents)
|
||||
|
||||
|
||||
class ScaleLatentsInvocation(BaseInvocation):
|
||||
@ -413,8 +416,8 @@ class ScaleLatentsInvocation(BaseInvocation):
|
||||
# Inputs
|
||||
latents: Optional[LatentsField] = Field(description="The latents to scale")
|
||||
scale_factor: float = Field(gt=0, description="The factor by which to scale the latents")
|
||||
mode: Optional[LATENTS_INTERPOLATION_MODE] = Field(default="bilinear", description="The interpolation mode")
|
||||
antialias: Optional[bool] = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
|
||||
mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
|
||||
antialias: bool = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
@ -432,4 +435,48 @@ class ScaleLatentsInvocation(BaseInvocation):
|
||||
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
context.services.latents.set(name, resized_latents)
|
||||
return LatentsOutput(latents=LatentsField(latents_name=name))
|
||||
return build_latents_output(latents_name=name, latents=resized_latents)
|
||||
|
||||
|
||||
class ImageToLatentsInvocation(BaseInvocation):
|
||||
"""Encodes an image into latents."""
|
||||
|
||||
type: Literal["i2l"] = "i2l"
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, None] = Field(description="The image to encode")
|
||||
model: str = Field(default="", description="The model to use")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["latents", "image"],
|
||||
"type_hints": {"model": "model"},
|
||||
},
|
||||
}
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
)
|
||||
|
||||
# TODO: this only really needs the vae
|
||||
model_info = choose_model(context.services.model_manager, self.model)
|
||||
model: StableDiffusionGeneratorPipeline = model_info["model"]
|
||||
|
||||
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||
|
||||
if image_tensor.dim() == 3:
|
||||
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
|
||||
|
||||
latents = model.non_noised_latents_from_image(
|
||||
image_tensor,
|
||||
device=model._model_group.device_for(model.unet),
|
||||
dtype=model.unet.dtype,
|
||||
)
|
||||
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
context.services.latents.set(name, latents)
|
||||
return build_latents_output(latents_name=name, latents=latents)
|
||||
|
@ -3,6 +3,7 @@
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
import numpy as np
|
||||
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
||||
|
||||
@ -73,3 +74,12 @@ class DivideInvocation(BaseInvocation, MathInvocationConfig):
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=int(self.a / self.b))
|
||||
|
||||
|
||||
class RandomIntInvocation(BaseInvocation):
|
||||
"""Outputs a single random integer."""
|
||||
#fmt: off
|
||||
type: Literal["rand_int"] = "rand_int"
|
||||
#fmt: on
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=np.random.randint(0, np.iinfo(np.int32).max))
|
||||
|
@ -4,10 +4,11 @@ from invokeai.backend.model_management.model_manager import ModelManager
|
||||
def choose_model(model_manager: ModelManager, model_name: str):
|
||||
"""Returns the default model if the `model_name` not a valid model, else returns the selected model."""
|
||||
logger = model_manager.logger
|
||||
if model_manager.valid_model(model_name):
|
||||
model = model_manager.get_model(model_name)
|
||||
else:
|
||||
if model_name and not model_manager.valid_model(model_name):
|
||||
default_model_name = model_manager.default_model()
|
||||
logger.warning(f"\'{model_name}\' is not a valid model name. Using default model \'{default_model_name}\' instead.")
|
||||
model = model_manager.get_model()
|
||||
logger.warning(f"{model_name}' is not a valid model name. Using default model \'{model['model_name']}\' instead.")
|
||||
else:
|
||||
model = model_manager.get_model(model_name)
|
||||
|
||||
return model
|
||||
|
@ -1,5 +1,5 @@
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from typing import Optional, Tuple
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@ -27,3 +27,13 @@ class ImageField(BaseModel):
|
||||
|
||||
class Config:
|
||||
schema_extra = {"required": ["image_type", "image_name"]}
|
||||
|
||||
|
||||
class ColorField(BaseModel):
|
||||
r: int = Field(ge=0, le=255, description="The red component")
|
||||
g: int = Field(ge=0, le=255, description="The green component")
|
||||
b: int = Field(ge=0, le=255, description="The blue component")
|
||||
a: int = Field(ge=0, le=255, description="The alpha component")
|
||||
|
||||
def tuple(self) -> Tuple[int, int, int, int]:
|
||||
return (self.r, self.g, self.b, self.a)
|
||||
|
@ -49,12 +49,13 @@ def create_text_to_image() -> LibraryGraph:
|
||||
def create_system_graphs(graph_library: ItemStorageABC[LibraryGraph]) -> list[LibraryGraph]:
|
||||
"""Creates the default system graphs, or adds new versions if the old ones don't match"""
|
||||
|
||||
# TODO: Uncomment this when we are ready to fix this up to prevent breaking changes
|
||||
graphs: list[LibraryGraph] = list()
|
||||
|
||||
text_to_image = graph_library.get(default_text_to_image_graph_id)
|
||||
# text_to_image = graph_library.get(default_text_to_image_graph_id)
|
||||
|
||||
# TODO: Check if the graph is the same as the default one, and if not, update it
|
||||
#if text_to_image is None:
|
||||
# # TODO: Check if the graph is the same as the default one, and if not, update it
|
||||
# #if text_to_image is None:
|
||||
text_to_image = create_text_to_image()
|
||||
graph_library.set(text_to_image)
|
||||
|
||||
|
@ -270,4 +270,5 @@ class DiskImageStorage(ImageStorageBase):
|
||||
) # TODO: this should refresh position for LRU cache
|
||||
if len(self.__cache) > self.__max_cache_size:
|
||||
cache_id = self.__cache_ids.get()
|
||||
if cache_id in self.__cache:
|
||||
del self.__cache[cache_id]
|
||||
|
@ -20,9 +20,18 @@ class MetadataLatentsField(TypedDict):
|
||||
latents_name: str
|
||||
|
||||
|
||||
class MetadataColorField(TypedDict):
|
||||
"""Pydantic-less ColorField, used for metadata parsing"""
|
||||
r: int
|
||||
g: int
|
||||
b: int
|
||||
a: int
|
||||
|
||||
|
||||
|
||||
# TODO: This is a placeholder for `InvocationsUnion` pending resolution of circular imports
|
||||
NodeMetadata = Dict[
|
||||
str, str | int | float | bool | MetadataImageField | MetadataLatentsField
|
||||
str, None | str | int | float | bool | MetadataImageField | MetadataLatentsField | MetadataColorField
|
||||
]
|
||||
|
||||
|
||||
|
@ -1,3 +1,4 @@
|
||||
import time
|
||||
import traceback
|
||||
from threading import Event, Thread, BoundedSemaphore
|
||||
|
||||
@ -6,6 +7,7 @@ from .invocation_queue import InvocationQueueItem
|
||||
from .invoker import InvocationProcessorABC, Invoker
|
||||
from ..models.exceptions import CanceledException
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
__invoker_thread: Thread
|
||||
__stop_event: Event
|
||||
@ -34,8 +36,14 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
try:
|
||||
self.__threadLimit.acquire()
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
|
||||
except Exception as e:
|
||||
logger.debug("Exception while getting from queue: %s" % e)
|
||||
|
||||
if not queue_item: # Probably stopping
|
||||
# do not hammer the queue
|
||||
time.sleep(0.5)
|
||||
continue
|
||||
|
||||
graph_execution_state = (
|
||||
@ -124,7 +132,16 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
# Queue any further commands if invoking all
|
||||
is_complete = graph_execution_state.is_complete()
|
||||
if queue_item.invoke_all and not is_complete:
|
||||
try:
|
||||
self.__invoker.invoke(graph_execution_state, invoke_all=True)
|
||||
except Exception as e:
|
||||
logger.error("Error while invoking: %s" % e)
|
||||
self.__invoker.services.events.emit_invocation_error(
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
node=invocation.dict(),
|
||||
source_node_id=source_node_id,
|
||||
error=traceback.format_exc()
|
||||
)
|
||||
elif is_complete:
|
||||
self.__invoker.services.events.emit_graph_execution_complete(
|
||||
graph_execution_state.id
|
||||
|
@ -1,5 +1,13 @@
|
||||
import datetime
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get_timestamp():
|
||||
return int(datetime.datetime.now(datetime.timezone.utc).timestamp())
|
||||
|
||||
|
||||
SEED_MAX = np.iinfo(np.int32).max
|
||||
|
||||
|
||||
def get_random_seed():
|
||||
return np.random.randint(0, SEED_MAX)
|
||||
|
@ -31,6 +31,7 @@ from ..util.util import rand_perlin_2d
|
||||
from ..safety_checker import SafetyChecker
|
||||
from ..prompting.conditioning import get_uc_and_c_and_ec
|
||||
from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
|
||||
from ..stable_diffusion.schedulers import SCHEDULER_MAP
|
||||
|
||||
downsampling = 8
|
||||
|
||||
@ -71,19 +72,6 @@ class InvokeAIGeneratorOutput:
|
||||
# we are interposing a wrapper around the original Generator classes so that
|
||||
# old code that calls Generate will continue to work.
|
||||
class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
scheduler_map = dict(
|
||||
ddim=diffusers.DDIMScheduler,
|
||||
dpmpp_2=diffusers.DPMSolverMultistepScheduler,
|
||||
k_dpm_2=diffusers.KDPM2DiscreteScheduler,
|
||||
k_dpm_2_a=diffusers.KDPM2AncestralDiscreteScheduler,
|
||||
k_dpmpp_2=diffusers.DPMSolverMultistepScheduler,
|
||||
k_euler=diffusers.EulerDiscreteScheduler,
|
||||
k_euler_a=diffusers.EulerAncestralDiscreteScheduler,
|
||||
k_heun=diffusers.HeunDiscreteScheduler,
|
||||
k_lms=diffusers.LMSDiscreteScheduler,
|
||||
plms=diffusers.PNDMScheduler,
|
||||
)
|
||||
|
||||
def __init__(self,
|
||||
model_info: dict,
|
||||
params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
|
||||
@ -175,14 +163,20 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
'''
|
||||
Return list of all the schedulers that we currently handle.
|
||||
'''
|
||||
return list(self.scheduler_map.keys())
|
||||
return list(SCHEDULER_MAP.keys())
|
||||
|
||||
def load_generator(self, model: StableDiffusionGeneratorPipeline, generator_class: Type[Generator]):
|
||||
return generator_class(model, self.params.precision)
|
||||
|
||||
def get_scheduler(self, scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
|
||||
scheduler_class = self.scheduler_map.get(scheduler_name,'ddim')
|
||||
scheduler = scheduler_class.from_config(model.scheduler.config)
|
||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])
|
||||
|
||||
scheduler_config = model.scheduler.config
|
||||
if "_backup" in scheduler_config:
|
||||
scheduler_config = scheduler_config["_backup"]
|
||||
scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config}
|
||||
scheduler = scheduler_class.from_config(scheduler_config)
|
||||
|
||||
# hack copied over from generate.py
|
||||
if not hasattr(scheduler, 'uses_inpainting_model'):
|
||||
scheduler.uses_inpainting_model = lambda: False
|
||||
@ -226,10 +220,10 @@ class Inpaint(Img2Img):
|
||||
def generate(self,
|
||||
mask_image: Image.Image | torch.FloatTensor,
|
||||
# Seam settings - when 0, doesn't fill seam
|
||||
seam_size: int = 0,
|
||||
seam_blur: int = 0,
|
||||
seam_size: int = 96,
|
||||
seam_blur: int = 16,
|
||||
seam_strength: float = 0.7,
|
||||
seam_steps: int = 10,
|
||||
seam_steps: int = 30,
|
||||
tile_size: int = 32,
|
||||
inpaint_replace=False,
|
||||
infill_method=None,
|
||||
|
@ -4,6 +4,7 @@ invokeai.backend.generator.inpaint descends from .generator
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import Tuple, Union
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
@ -59,7 +60,7 @@ class Inpaint(Img2Img):
|
||||
writeable=False,
|
||||
)
|
||||
|
||||
def infill_patchmatch(self, im: Image.Image) -> Image:
|
||||
def infill_patchmatch(self, im: Image.Image) -> Image.Image:
|
||||
if im.mode != "RGBA":
|
||||
return im
|
||||
|
||||
@ -75,18 +76,18 @@ class Inpaint(Img2Img):
|
||||
return im_patched
|
||||
|
||||
def tile_fill_missing(
|
||||
self, im: Image.Image, tile_size: int = 16, seed: int = None
|
||||
) -> Image:
|
||||
self, im: Image.Image, tile_size: int = 16, seed: Union[int, None] = None
|
||||
) -> Image.Image:
|
||||
# Only fill if there's an alpha layer
|
||||
if im.mode != "RGBA":
|
||||
return im
|
||||
|
||||
a = np.asarray(im, dtype=np.uint8)
|
||||
|
||||
tile_size = (tile_size, tile_size)
|
||||
tile_size_tuple = (tile_size, tile_size)
|
||||
|
||||
# Get the image as tiles of a specified size
|
||||
tiles = self.get_tile_images(a, *tile_size).copy()
|
||||
tiles = self.get_tile_images(a, *tile_size_tuple).copy()
|
||||
|
||||
# Get the mask as tiles
|
||||
tiles_mask = tiles[:, :, :, :, 3]
|
||||
@ -127,7 +128,9 @@ class Inpaint(Img2Img):
|
||||
|
||||
return si
|
||||
|
||||
def mask_edge(self, mask: Image, edge_size: int, edge_blur: int) -> Image:
|
||||
def mask_edge(
|
||||
self, mask: Image.Image, edge_size: int, edge_blur: int
|
||||
) -> Image.Image:
|
||||
npimg = np.asarray(mask, dtype=np.uint8)
|
||||
|
||||
# Detect any partially transparent regions
|
||||
@ -206,15 +209,15 @@ class Inpaint(Img2Img):
|
||||
cfg_scale,
|
||||
ddim_eta,
|
||||
conditioning,
|
||||
init_image: PIL.Image.Image | torch.FloatTensor,
|
||||
mask_image: PIL.Image.Image | torch.FloatTensor,
|
||||
init_image: Image.Image | torch.FloatTensor,
|
||||
mask_image: Image.Image | torch.FloatTensor,
|
||||
strength: float,
|
||||
mask_blur_radius: int = 8,
|
||||
# Seam settings - when 0, doesn't fill seam
|
||||
seam_size: int = 0,
|
||||
seam_blur: int = 0,
|
||||
seam_size: int = 96,
|
||||
seam_blur: int = 16,
|
||||
seam_strength: float = 0.7,
|
||||
seam_steps: int = 10,
|
||||
seam_steps: int = 30,
|
||||
tile_size: int = 32,
|
||||
step_callback=None,
|
||||
inpaint_replace=False,
|
||||
@ -222,7 +225,7 @@ class Inpaint(Img2Img):
|
||||
infill_method=None,
|
||||
inpaint_width=None,
|
||||
inpaint_height=None,
|
||||
inpaint_fill: tuple(int) = (0x7F, 0x7F, 0x7F, 0xFF),
|
||||
inpaint_fill: Tuple[int, int, int, int] = (0x7F, 0x7F, 0x7F, 0xFF),
|
||||
attention_maps_callback=None,
|
||||
**kwargs,
|
||||
):
|
||||
@ -239,7 +242,7 @@ class Inpaint(Img2Img):
|
||||
self.inpaint_width = inpaint_width
|
||||
self.inpaint_height = inpaint_height
|
||||
|
||||
if isinstance(init_image, PIL.Image.Image):
|
||||
if isinstance(init_image, Image.Image):
|
||||
self.pil_image = init_image.copy()
|
||||
|
||||
# Do infill
|
||||
@ -250,8 +253,8 @@ class Inpaint(Img2Img):
|
||||
self.pil_image.copy(), seed=self.seed, tile_size=tile_size
|
||||
)
|
||||
elif infill_method == "solid":
|
||||
solid_bg = PIL.Image.new("RGBA", init_image.size, inpaint_fill)
|
||||
init_filled = PIL.Image.alpha_composite(solid_bg, init_image)
|
||||
solid_bg = Image.new("RGBA", init_image.size, inpaint_fill)
|
||||
init_filled = Image.alpha_composite(solid_bg, init_image)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Non-supported infill type {infill_method}", infill_method
|
||||
@ -269,7 +272,7 @@ class Inpaint(Img2Img):
|
||||
# Create init tensor
|
||||
init_image = image_resized_to_grid_as_tensor(init_filled.convert("RGB"))
|
||||
|
||||
if isinstance(mask_image, PIL.Image.Image):
|
||||
if isinstance(mask_image, Image.Image):
|
||||
self.pil_mask = mask_image.copy()
|
||||
debug_image(
|
||||
mask_image,
|
||||
|
@ -47,6 +47,7 @@ from diffusers import (
|
||||
LDMTextToImagePipeline,
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
UniPCMultistepScheduler,
|
||||
StableDiffusionPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
@ -1208,6 +1209,8 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config)
|
||||
elif scheduler_type == "dpm":
|
||||
scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
|
||||
elif scheduler_type == 'unipc':
|
||||
scheduler = UniPCMultistepScheduler.from_config(scheduler.config)
|
||||
elif scheduler_type == "ddim":
|
||||
scheduler = scheduler
|
||||
else:
|
||||
|
@ -1214,7 +1214,7 @@ class ModelManager(object):
|
||||
sha.update(chunk)
|
||||
hash = sha.hexdigest()
|
||||
toc = time.time()
|
||||
self.logger.debug(f"sha256 = {hash} ({count} files hashed in", "%4.2fs)" % (toc - tic))
|
||||
self.logger.debug(f"sha256 = {hash} ({count} files hashed in {toc - tic:4.2f}s)")
|
||||
with open(hashpath, "w") as f:
|
||||
f.write(hash)
|
||||
return hash
|
||||
|
@ -509,10 +509,13 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
run_id=None,
|
||||
callback: Callable[[PipelineIntermediateState], None] = None,
|
||||
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
|
||||
if self.scheduler.config.get("cpu_only", False):
|
||||
scheduler_device = torch.device('cpu')
|
||||
else:
|
||||
scheduler_device = self._model_group.device_for(self.unet)
|
||||
|
||||
if timesteps is None:
|
||||
self.scheduler.set_timesteps(
|
||||
num_inference_steps, device=self._model_group.device_for(self.unet)
|
||||
)
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=scheduler_device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
infer_latents_from_embeddings = GeneratorToCallbackinator(
|
||||
self.generate_latents_from_embeddings, PipelineIntermediateState
|
||||
@ -726,11 +729,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
run_id=None,
|
||||
callback=None,
|
||||
) -> InvokeAIStableDiffusionPipelineOutput:
|
||||
timesteps, _ = self.get_img2img_timesteps(
|
||||
num_inference_steps,
|
||||
strength,
|
||||
device=self._model_group.device_for(self.unet),
|
||||
)
|
||||
timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength)
|
||||
result_latents, result_attention_maps = self.latents_from_embeddings(
|
||||
latents=initial_latents if strength < 1.0 else torch.zeros_like(
|
||||
initial_latents, device=initial_latents.device, dtype=initial_latents.dtype
|
||||
@ -756,13 +755,19 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
return self.check_for_safety(output, dtype=conditioning_data.dtype)
|
||||
|
||||
def get_img2img_timesteps(
|
||||
self, num_inference_steps: int, strength: float, device
|
||||
self, num_inference_steps: int, strength: float, device=None
|
||||
) -> (torch.Tensor, int):
|
||||
img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
|
||||
assert img2img_pipeline.scheduler is self.scheduler
|
||||
img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
|
||||
if self.scheduler.config.get("cpu_only", False):
|
||||
scheduler_device = torch.device('cpu')
|
||||
else:
|
||||
scheduler_device = self._model_group.device_for(self.unet)
|
||||
|
||||
img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=scheduler_device)
|
||||
timesteps, adjusted_steps = img2img_pipeline.get_timesteps(
|
||||
num_inference_steps, strength, device=device
|
||||
num_inference_steps, strength, device=scheduler_device
|
||||
)
|
||||
# Workaround for low strength resulting in zero timesteps.
|
||||
# TODO: submit upstream fix for zero-step img2img
|
||||
@ -796,9 +801,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
if init_image.dim() == 3:
|
||||
init_image = init_image.unsqueeze(0)
|
||||
|
||||
timesteps, _ = self.get_img2img_timesteps(
|
||||
num_inference_steps, strength, device=device
|
||||
)
|
||||
timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength)
|
||||
|
||||
# 6. Prepare latent variables
|
||||
# can't quite use upstream StableDiffusionImg2ImgPipeline.prepare_latents
|
||||
|
1
invokeai/backend/stable_diffusion/schedulers/__init__.py
Normal file
1
invokeai/backend/stable_diffusion/schedulers/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from .schedulers import SCHEDULER_MAP
|
22
invokeai/backend/stable_diffusion/schedulers/schedulers.py
Normal file
22
invokeai/backend/stable_diffusion/schedulers/schedulers.py
Normal file
@ -0,0 +1,22 @@
|
||||
from diffusers import DDIMScheduler, DPMSolverMultistepScheduler, KDPM2DiscreteScheduler, \
|
||||
KDPM2AncestralDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, \
|
||||
HeunDiscreteScheduler, LMSDiscreteScheduler, PNDMScheduler, UniPCMultistepScheduler, \
|
||||
DPMSolverSinglestepScheduler, DEISMultistepScheduler, DDPMScheduler
|
||||
|
||||
SCHEDULER_MAP = dict(
|
||||
ddim=(DDIMScheduler, dict()),
|
||||
ddpm=(DDPMScheduler, dict()),
|
||||
deis=(DEISMultistepScheduler, dict()),
|
||||
lms=(LMSDiscreteScheduler, dict()),
|
||||
pndm=(PNDMScheduler, dict()),
|
||||
heun=(HeunDiscreteScheduler, dict()),
|
||||
euler=(EulerDiscreteScheduler, dict(use_karras_sigmas=False)),
|
||||
euler_k=(EulerDiscreteScheduler, dict(use_karras_sigmas=True)),
|
||||
euler_a=(EulerAncestralDiscreteScheduler, dict()),
|
||||
kdpm_2=(KDPM2DiscreteScheduler, dict()),
|
||||
kdpm_2_a=(KDPM2AncestralDiscreteScheduler, dict()),
|
||||
dpmpp_2s=(DPMSolverSinglestepScheduler, dict()),
|
||||
dpmpp_2m=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=False)),
|
||||
dpmpp_2m_k=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=True)),
|
||||
unipc=(UniPCMultistepScheduler, dict(cpu_only=True))
|
||||
)
|
@ -4,17 +4,20 @@ from .parse_seed_weights import parse_seed_weights
|
||||
|
||||
SAMPLER_CHOICES = [
|
||||
"ddim",
|
||||
"k_dpm_2_a",
|
||||
"k_dpm_2",
|
||||
"k_dpmpp_2_a",
|
||||
"k_dpmpp_2",
|
||||
"k_euler_a",
|
||||
"k_euler",
|
||||
"k_heun",
|
||||
"k_lms",
|
||||
"plms",
|
||||
# diffusers:
|
||||
"ddpm",
|
||||
"deis",
|
||||
"lms",
|
||||
"pndm",
|
||||
"heun",
|
||||
"euler",
|
||||
"euler_k",
|
||||
"euler_a",
|
||||
"kdpm_2",
|
||||
"kdpm_2_a",
|
||||
"dpmpp_2s",
|
||||
"dpmpp_2m",
|
||||
"dpmpp_2m_k",
|
||||
"unipc",
|
||||
]
|
||||
|
||||
|
||||
|
@ -1,13 +0,0 @@
|
||||
{
|
||||
"plugins": [
|
||||
[
|
||||
"transform-imports",
|
||||
{
|
||||
"lodash": {
|
||||
"transform": "lodash/${member}",
|
||||
"preventFullImport": true
|
||||
}
|
||||
}
|
||||
]
|
||||
]
|
||||
}
|
4
invokeai/frontend/web/.gitignore
vendored
4
invokeai/frontend/web/.gitignore
vendored
@ -35,3 +35,7 @@ stats.html
|
||||
!.yarn/releases
|
||||
!.yarn/sdks
|
||||
!.yarn/versions
|
||||
|
||||
# Yalc
|
||||
.yalc
|
||||
yalc.lock
|
@ -5,6 +5,7 @@ import { PluginOption, UserConfig } from 'vite';
|
||||
import dts from 'vite-plugin-dts';
|
||||
import eslint from 'vite-plugin-eslint';
|
||||
import tsconfigPaths from 'vite-tsconfig-paths';
|
||||
import cssInjectedByJsPlugin from 'vite-plugin-css-injected-by-js';
|
||||
|
||||
export const packageConfig: UserConfig = {
|
||||
base: './',
|
||||
@ -16,9 +17,10 @@ export const packageConfig: UserConfig = {
|
||||
dts({
|
||||
insertTypesEntry: true,
|
||||
}),
|
||||
cssInjectedByJsPlugin(),
|
||||
],
|
||||
build: {
|
||||
chunkSizeWarningLimit: 1500,
|
||||
cssCodeSplit: true,
|
||||
lib: {
|
||||
entry: path.resolve(__dirname, '../src/index.ts'),
|
||||
name: 'InvokeAIUI',
|
||||
@ -30,6 +32,7 @@ export const packageConfig: UserConfig = {
|
||||
globals: {
|
||||
react: 'React',
|
||||
'react-dom': 'ReactDOM',
|
||||
'@emotion/react': 'EmotionReact',
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -37,7 +37,7 @@ From `invokeai/frontend/web/` run `yarn install` to get everything set up.
|
||||
Start everything in dev mode:
|
||||
|
||||
1. Start the dev server: `yarn dev`
|
||||
2. Start the InvokeAI UI per usual: `invokeai --web`
|
||||
2. Start the InvokeAI Nodes backend: `python scripts/invokeai-new.py --web # run from the repo root`
|
||||
3. Point your browser to the dev server address e.g. <http://localhost:5173/>
|
||||
|
||||
### Production builds
|
||||
|
@ -21,7 +21,6 @@
|
||||
"scripts": {
|
||||
"prepare": "cd ../../../ && husky install invokeai/frontend/web/.husky",
|
||||
"dev": "concurrently \"vite dev\" \"yarn run theme:watch\"",
|
||||
"dev:nodes": "concurrently \"vite dev --mode nodes\" \"yarn run theme:watch\"",
|
||||
"dev:host": "concurrently \"vite dev --host\" \"yarn run theme:watch\"",
|
||||
"build": "yarn run lint && vite build",
|
||||
"api:web": "openapi -i http://localhost:9090/openapi.json -o src/services/api --client axios --useOptions --useUnionTypes --exportSchemas true --indent 2 --request src/services/fixtures/request.ts",
|
||||
@ -90,6 +89,7 @@
|
||||
"react-konva": "^18.2.7",
|
||||
"react-konva-utils": "^1.0.4",
|
||||
"react-redux": "^8.0.5",
|
||||
"react-resizable-panels": "^0.0.42",
|
||||
"react-rnd": "^10.4.1",
|
||||
"react-transition-group": "^4.4.5",
|
||||
"react-use": "^17.4.0",
|
||||
@ -99,6 +99,7 @@
|
||||
"redux-deep-persist": "^1.0.7",
|
||||
"redux-dynamic-middlewares": "^2.2.0",
|
||||
"redux-persist": "^6.0.0",
|
||||
"redux-remember": "^3.3.1",
|
||||
"roarr": "^7.15.0",
|
||||
"serialize-error": "^11.0.0",
|
||||
"socket.io-client": "^4.6.0",
|
||||
@ -118,6 +119,7 @@
|
||||
"@types/node": "^18.16.2",
|
||||
"@types/react": "^18.2.0",
|
||||
"@types/react-dom": "^18.2.1",
|
||||
"@types/react-redux": "^7.1.25",
|
||||
"@types/react-transition-group": "^4.4.5",
|
||||
"@types/uuid": "^9.0.0",
|
||||
"@typescript-eslint/eslint-plugin": "^5.59.1",
|
||||
@ -143,6 +145,7 @@
|
||||
"terser": "^5.17.1",
|
||||
"ts-toolbelt": "^9.6.0",
|
||||
"vite": "^4.3.3",
|
||||
"vite-plugin-css-injected-by-js": "^3.1.1",
|
||||
"vite-plugin-dts": "^2.3.0",
|
||||
"vite-plugin-eslint": "^1.8.1",
|
||||
"vite-tsconfig-paths": "^4.2.0",
|
||||
|
@ -25,7 +25,7 @@
|
||||
"common": {
|
||||
"hotkeysLabel": "Hotkeys",
|
||||
"themeLabel": "Theme",
|
||||
"languagePickerLabel": "Language Picker",
|
||||
"languagePickerLabel": "Language",
|
||||
"reportBugLabel": "Report Bug",
|
||||
"githubLabel": "Github",
|
||||
"discordLabel": "Discord",
|
||||
@ -54,7 +54,7 @@
|
||||
"img2img": "Image To Image",
|
||||
"unifiedCanvas": "Unified Canvas",
|
||||
"linear": "Linear",
|
||||
"nodes": "Nodes",
|
||||
"nodes": "Node Editor",
|
||||
"postprocessing": "Post Processing",
|
||||
"nodesDesc": "A node based system for the generation of images is under development currently. Stay tuned for updates about this amazing feature.",
|
||||
"postProcessing": "Post Processing",
|
||||
@ -102,7 +102,8 @@
|
||||
"generate": "Generate",
|
||||
"openInNewTab": "Open in New Tab",
|
||||
"dontAskMeAgain": "Don't ask me again",
|
||||
"areYouSure": "Are you sure?"
|
||||
"areYouSure": "Are you sure?",
|
||||
"imagePrompt": "Image Prompt"
|
||||
},
|
||||
"gallery": {
|
||||
"generations": "Generations",
|
||||
@ -453,9 +454,10 @@
|
||||
"seed": "Seed",
|
||||
"imageToImage": "Image to Image",
|
||||
"randomizeSeed": "Randomize Seed",
|
||||
"shuffle": "Shuffle",
|
||||
"shuffle": "Shuffle Seed",
|
||||
"noiseThreshold": "Noise Threshold",
|
||||
"perlinNoise": "Perlin Noise",
|
||||
"noiseSettings": "Noise",
|
||||
"variations": "Variations",
|
||||
"variationAmount": "Variation Amount",
|
||||
"seedWeights": "Seed Weights",
|
||||
@ -470,6 +472,8 @@
|
||||
"scale": "Scale",
|
||||
"otherOptions": "Other Options",
|
||||
"seamlessTiling": "Seamless Tiling",
|
||||
"seamlessXAxis": "X Axis",
|
||||
"seamlessYAxis": "Y Axis",
|
||||
"hiresOptim": "High Res Optimization",
|
||||
"hiresStrength": "High Res Strength",
|
||||
"imageFit": "Fit Initial Image To Output Size",
|
||||
@ -527,7 +531,8 @@
|
||||
"useCanvasBeta": "Use Canvas Beta Layout",
|
||||
"enableImageDebugging": "Enable Image Debugging",
|
||||
"useSlidersForAll": "Use Sliders For All Options",
|
||||
"autoShowProgress": "Auto Show Progress Images",
|
||||
"showProgressInViewer": "Show Progress Images in Viewer",
|
||||
"antialiasProgressImages": "Antialias Progress Images",
|
||||
"resetWebUI": "Reset Web UI",
|
||||
"resetWebUIDesc1": "Resetting the web UI only resets the browser's local cache of your images and remembered settings. It does not delete any images from disk.",
|
||||
"resetWebUIDesc2": "If images aren't showing up in the gallery or something else isn't working, please try resetting before submitting an issue on GitHub.",
|
||||
@ -549,8 +554,9 @@
|
||||
"downloadImageStarted": "Image Download Started",
|
||||
"imageCopied": "Image Copied",
|
||||
"imageLinkCopied": "Image Link Copied",
|
||||
"problemCopyingImageLink": "Unable to Copy Image Link",
|
||||
"imageNotLoaded": "No Image Loaded",
|
||||
"imageNotLoadedDesc": "No image found to send to image to image module",
|
||||
"imageNotLoadedDesc": "Could not find image",
|
||||
"imageSavedToGallery": "Image Saved to Gallery",
|
||||
"canvasMerged": "Canvas Merged",
|
||||
"sentToImageToImage": "Sent To Image To Image",
|
||||
@ -645,7 +651,8 @@
|
||||
"betaClear": "Clear",
|
||||
"betaDarkenOutside": "Darken Outside",
|
||||
"betaLimitToBox": "Limit To Box",
|
||||
"betaPreserveMasked": "Preserve Masked"
|
||||
"betaPreserveMasked": "Preserve Masked",
|
||||
"antialiasing": "Antialiasing"
|
||||
},
|
||||
"ui": {
|
||||
"showProgressImages": "Show Progress Images",
|
||||
|
@ -1,24 +1,18 @@
|
||||
import ImageUploader from 'common/components/ImageUploader';
|
||||
import ProgressBar from 'features/system/components/ProgressBar';
|
||||
import SiteHeader from 'features/system/components/SiteHeader';
|
||||
import ProgressBar from 'features/system/components/ProgressBar';
|
||||
import InvokeTabs from 'features/ui/components/InvokeTabs';
|
||||
|
||||
import useToastWatcher from 'features/system/hooks/useToastWatcher';
|
||||
|
||||
import FloatingGalleryButton from 'features/ui/components/FloatingGalleryButton';
|
||||
import FloatingParametersPanelButtons from 'features/ui/components/FloatingParametersPanelButtons';
|
||||
import { Box, Flex, Grid, Portal, useColorMode } from '@chakra-ui/react';
|
||||
import { Box, Flex, Grid, Portal } from '@chakra-ui/react';
|
||||
import { APP_HEIGHT, APP_WIDTH } from 'theme/util/constants';
|
||||
import ImageGalleryPanel from 'features/gallery/components/ImageGalleryPanel';
|
||||
import GalleryDrawer from 'features/gallery/components/ImageGalleryPanel';
|
||||
import Lightbox from 'features/lightbox/components/Lightbox';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import {
|
||||
memo,
|
||||
PropsWithChildren,
|
||||
useCallback,
|
||||
useEffect,
|
||||
useState,
|
||||
} from 'react';
|
||||
import { memo, ReactNode, useCallback, useEffect, useState } from 'react';
|
||||
import { motion, AnimatePresence } from 'framer-motion';
|
||||
import Loading from 'common/components/Loading/Loading';
|
||||
import { useIsApplicationReady } from 'features/system/hooks/useIsApplicationReady';
|
||||
@ -27,20 +21,24 @@ import { useGlobalHotkeys } from 'common/hooks/useGlobalHotkeys';
|
||||
import { configChanged } from 'features/system/store/configSlice';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
import { useLogger } from 'app/logging/useLogger';
|
||||
import ProgressImagePreview from 'features/parameters/components/ProgressImagePreview';
|
||||
import ParametersDrawer from 'features/ui/components/ParametersDrawer';
|
||||
import { languageSelector } from 'features/system/store/systemSelectors';
|
||||
import i18n from 'i18n';
|
||||
|
||||
const DEFAULT_CONFIG = {};
|
||||
|
||||
interface Props extends PropsWithChildren {
|
||||
interface Props {
|
||||
config?: PartialAppConfig;
|
||||
headerComponent?: ReactNode;
|
||||
}
|
||||
|
||||
const App = ({ config = DEFAULT_CONFIG, children }: Props) => {
|
||||
const App = ({ config = DEFAULT_CONFIG, headerComponent }: Props) => {
|
||||
useToastWatcher();
|
||||
useGlobalHotkeys();
|
||||
const log = useLogger();
|
||||
|
||||
const currentTheme = useAppSelector((state) => state.ui.currentTheme);
|
||||
const language = useAppSelector(languageSelector);
|
||||
|
||||
const log = useLogger();
|
||||
|
||||
const isLightboxEnabled = useFeatureStatus('lightbox').isFeatureEnabled;
|
||||
|
||||
@ -48,18 +46,17 @@ const App = ({ config = DEFAULT_CONFIG, children }: Props) => {
|
||||
|
||||
const [loadingOverridden, setLoadingOverridden] = useState(false);
|
||||
|
||||
const { setColorMode } = useColorMode();
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
useEffect(() => {
|
||||
i18n.changeLanguage(language);
|
||||
}, [language]);
|
||||
|
||||
useEffect(() => {
|
||||
log.info({ namespace: 'App', data: config }, 'Received config');
|
||||
dispatch(configChanged(config));
|
||||
}, [dispatch, config, log]);
|
||||
|
||||
useEffect(() => {
|
||||
setColorMode(['light'].includes(currentTheme) ? 'light' : 'dark');
|
||||
}, [setColorMode, currentTheme]);
|
||||
|
||||
const handleOverrideClicked = useCallback(() => {
|
||||
setLoadingOverridden(true);
|
||||
}, []);
|
||||
@ -76,7 +73,7 @@ const App = ({ config = DEFAULT_CONFIG, children }: Props) => {
|
||||
w={APP_WIDTH}
|
||||
h={APP_HEIGHT}
|
||||
>
|
||||
{children || <SiteHeader />}
|
||||
{headerComponent || <SiteHeader />}
|
||||
<Flex
|
||||
gap={4}
|
||||
w={{ base: '100vw', xl: 'full' }}
|
||||
@ -84,11 +81,13 @@ const App = ({ config = DEFAULT_CONFIG, children }: Props) => {
|
||||
flexDir={{ base: 'column', xl: 'row' }}
|
||||
>
|
||||
<InvokeTabs />
|
||||
<ImageGalleryPanel />
|
||||
</Flex>
|
||||
</Grid>
|
||||
</ImageUploader>
|
||||
|
||||
<GalleryDrawer />
|
||||
<ParametersDrawer />
|
||||
|
||||
<AnimatePresence>
|
||||
{!isApplicationReady && !loadingOverridden && (
|
||||
<motion.div
|
||||
@ -121,7 +120,6 @@ const App = ({ config = DEFAULT_CONFIG, children }: Props) => {
|
||||
<Portal>
|
||||
<FloatingGalleryButton />
|
||||
</Portal>
|
||||
<ProgressImagePreview />
|
||||
</Grid>
|
||||
);
|
||||
};
|
||||
|
@ -1,18 +1,13 @@
|
||||
import React, { lazy, memo, PropsWithChildren, useEffect } from 'react';
|
||||
import React, {
|
||||
lazy,
|
||||
memo,
|
||||
PropsWithChildren,
|
||||
ReactNode,
|
||||
useEffect,
|
||||
} from 'react';
|
||||
import { Provider } from 'react-redux';
|
||||
import { PersistGate } from 'redux-persist/integration/react';
|
||||
import { store } from 'app/store/store';
|
||||
import { persistor } from '../store/persistor';
|
||||
import { OpenAPI } from 'services/api';
|
||||
import '@fontsource/inter/100.css';
|
||||
import '@fontsource/inter/200.css';
|
||||
import '@fontsource/inter/300.css';
|
||||
import '@fontsource/inter/400.css';
|
||||
import '@fontsource/inter/500.css';
|
||||
import '@fontsource/inter/600.css';
|
||||
import '@fontsource/inter/700.css';
|
||||
import '@fontsource/inter/800.css';
|
||||
import '@fontsource/inter/900.css';
|
||||
|
||||
import Loading from '../../common/components/Loading/Loading';
|
||||
import { addMiddleware, resetMiddlewares } from 'redux-dynamic-middlewares';
|
||||
@ -28,9 +23,10 @@ interface Props extends PropsWithChildren {
|
||||
apiUrl?: string;
|
||||
token?: string;
|
||||
config?: PartialAppConfig;
|
||||
headerComponent?: ReactNode;
|
||||
}
|
||||
|
||||
const InvokeAIUI = ({ apiUrl, token, config, children }: Props) => {
|
||||
const InvokeAIUI = ({ apiUrl, token, config, headerComponent }: Props) => {
|
||||
useEffect(() => {
|
||||
// configure API client token
|
||||
if (token) {
|
||||
@ -57,13 +53,11 @@ const InvokeAIUI = ({ apiUrl, token, config, children }: Props) => {
|
||||
return (
|
||||
<React.StrictMode>
|
||||
<Provider store={store}>
|
||||
<PersistGate loading={<Loading />} persistor={persistor}>
|
||||
<React.Suspense fallback={<Loading />}>
|
||||
<ThemeLocaleProvider>
|
||||
<App config={config}>{children}</App>
|
||||
<App config={config} headerComponent={headerComponent} />
|
||||
</ThemeLocaleProvider>
|
||||
</React.Suspense>
|
||||
</PersistGate>
|
||||
</Provider>
|
||||
</React.StrictMode>
|
||||
);
|
||||
|
@ -1,4 +1,8 @@
|
||||
import { ChakraProvider, extendTheme } from '@chakra-ui/react';
|
||||
import {
|
||||
ChakraProvider,
|
||||
createLocalStorageManager,
|
||||
extendTheme,
|
||||
} from '@chakra-ui/react';
|
||||
import { ReactNode, useEffect } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { theme as invokeAITheme } from 'theme/theme';
|
||||
@ -9,15 +13,8 @@ import { greenTeaThemeColors } from 'theme/colors/greenTea';
|
||||
import { invokeAIThemeColors } from 'theme/colors/invokeAI';
|
||||
import { lightThemeColors } from 'theme/colors/lightTheme';
|
||||
import { oceanBlueColors } from 'theme/colors/oceanBlue';
|
||||
import '@fontsource/inter/100.css';
|
||||
import '@fontsource/inter/200.css';
|
||||
import '@fontsource/inter/300.css';
|
||||
import '@fontsource/inter/400.css';
|
||||
import '@fontsource/inter/500.css';
|
||||
import '@fontsource/inter/600.css';
|
||||
import '@fontsource/inter/700.css';
|
||||
import '@fontsource/inter/800.css';
|
||||
import '@fontsource/inter/900.css';
|
||||
|
||||
import '@fontsource/inter/variable.css';
|
||||
import 'overlayscrollbars/overlayscrollbars.css';
|
||||
import 'theme/css/overlayscrollbars.css';
|
||||
|
||||
@ -32,6 +29,8 @@ const THEMES = {
|
||||
ocean: oceanBlueColors,
|
||||
};
|
||||
|
||||
const manager = createLocalStorageManager('@@invokeai-color-mode');
|
||||
|
||||
function ThemeLocaleProvider({ children }: ThemeLocaleProviderProps) {
|
||||
const { i18n } = useTranslation();
|
||||
|
||||
@ -51,7 +50,11 @@ function ThemeLocaleProvider({ children }: ThemeLocaleProviderProps) {
|
||||
document.body.dir = direction;
|
||||
}, [direction]);
|
||||
|
||||
return <ChakraProvider theme={theme}>{children}</ChakraProvider>;
|
||||
return (
|
||||
<ChakraProvider theme={theme} colorModeManager={manager}>
|
||||
{children}
|
||||
</ChakraProvider>
|
||||
);
|
||||
}
|
||||
|
||||
export default ThemeLocaleProvider;
|
||||
|
@ -2,17 +2,28 @@
|
||||
|
||||
export const DIFFUSERS_SCHEDULERS: Array<string> = [
|
||||
'ddim',
|
||||
'plms',
|
||||
'k_lms',
|
||||
'dpmpp_2',
|
||||
'k_dpm_2',
|
||||
'k_dpm_2_a',
|
||||
'k_dpmpp_2',
|
||||
'k_euler',
|
||||
'k_euler_a',
|
||||
'k_heun',
|
||||
'ddpm',
|
||||
'deis',
|
||||
'lms',
|
||||
'pndm',
|
||||
'heun',
|
||||
'euler',
|
||||
'euler_k',
|
||||
'euler_a',
|
||||
'kdpm_2',
|
||||
'kdpm_2_a',
|
||||
'dpmpp_2s',
|
||||
'dpmpp_2m',
|
||||
'dpmpp_2m_k',
|
||||
'unipc',
|
||||
];
|
||||
|
||||
export const IMG2IMG_DIFFUSERS_SCHEDULERS = DIFFUSERS_SCHEDULERS.filter(
|
||||
(scheduler) => {
|
||||
return scheduler !== 'dpmpp_2s';
|
||||
}
|
||||
);
|
||||
|
||||
// Valid image widths
|
||||
export const WIDTHS: Array<number> = Array.from(Array(64)).map(
|
||||
(_x, i) => (i + 1) * 64
|
||||
|
@ -1,26 +1,20 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { validateSeedWeights } from 'common/util/seedWeightPairs';
|
||||
import { initialCanvasImageSelector } from 'features/canvas/store/canvasSelectors';
|
||||
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import { isEqual } from 'lodash-es';
|
||||
|
||||
export const readinessSelector = createSelector(
|
||||
[
|
||||
generationSelector,
|
||||
systemSelector,
|
||||
initialCanvasImageSelector,
|
||||
activeTabNameSelector,
|
||||
],
|
||||
(generation, system, initialCanvasImage, activeTabName) => {
|
||||
[generationSelector, systemSelector, activeTabNameSelector],
|
||||
(generation, system, activeTabName) => {
|
||||
const {
|
||||
prompt,
|
||||
shouldGenerateVariations,
|
||||
seedWeights,
|
||||
initialImage,
|
||||
seed,
|
||||
isImageToImageEnabled,
|
||||
} = generation;
|
||||
|
||||
const { isProcessing, isConnected } = system;
|
||||
@ -34,7 +28,7 @@ export const readinessSelector = createSelector(
|
||||
reasonsWhyNotReady.push('Missing prompt');
|
||||
}
|
||||
|
||||
if (isImageToImageEnabled && !initialImage) {
|
||||
if (activeTabName === 'img2img' && !initialImage) {
|
||||
isReady = false;
|
||||
reasonsWhyNotReady.push('No initial image selected');
|
||||
}
|
||||
@ -64,10 +58,5 @@ export const readinessSelector = createSelector(
|
||||
// All good
|
||||
return { isReady, reasonsWhyNotReady };
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
equalityCheck: isEqual,
|
||||
resultEqualityCheck: isEqual,
|
||||
},
|
||||
}
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
@ -1,209 +1,209 @@
|
||||
// import { AnyAction, Dispatch, MiddlewareAPI } from '@reduxjs/toolkit';
|
||||
// import * as InvokeAI from 'app/types/invokeai';
|
||||
// import type { RootState } from 'app/store/store';
|
||||
// import {
|
||||
// frontendToBackendParameters,
|
||||
// FrontendToBackendParametersConfig,
|
||||
// } from 'common/util/parameterTranslation';
|
||||
// import dateFormat from 'dateformat';
|
||||
// import {
|
||||
// GalleryCategory,
|
||||
// GalleryState,
|
||||
// removeImage,
|
||||
// } from 'features/gallery/store/gallerySlice';
|
||||
// import {
|
||||
// generationRequested,
|
||||
// modelChangeRequested,
|
||||
// modelConvertRequested,
|
||||
// modelMergingRequested,
|
||||
// setIsProcessing,
|
||||
// } from 'features/system/store/systemSlice';
|
||||
// import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||
// import { Socket } from 'socket.io-client';
|
||||
import { AnyAction, Dispatch, MiddlewareAPI } from '@reduxjs/toolkit';
|
||||
import * as InvokeAI from 'app/types/invokeai';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import {
|
||||
frontendToBackendParameters,
|
||||
FrontendToBackendParametersConfig,
|
||||
} from 'common/util/parameterTranslation';
|
||||
import dateFormat from 'dateformat';
|
||||
import {
|
||||
GalleryCategory,
|
||||
GalleryState,
|
||||
removeImage,
|
||||
} from 'features/gallery/store/gallerySlice';
|
||||
import {
|
||||
generationRequested,
|
||||
modelChangeRequested,
|
||||
modelConvertRequested,
|
||||
modelMergingRequested,
|
||||
setIsProcessing,
|
||||
} from 'features/system/store/systemSlice';
|
||||
import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||
import { Socket } from 'socket.io-client';
|
||||
|
||||
// /**
|
||||
// * Returns an object containing all functions which use `socketio.emit()`.
|
||||
// * i.e. those which make server requests.
|
||||
// */
|
||||
// const makeSocketIOEmitters = (
|
||||
// store: MiddlewareAPI<Dispatch<AnyAction>, RootState>,
|
||||
// socketio: Socket
|
||||
// ) => {
|
||||
// // We need to dispatch actions to redux and get pieces of state from the store.
|
||||
// const { dispatch, getState } = store;
|
||||
/**
|
||||
* Returns an object containing all functions which use `socketio.emit()`.
|
||||
* i.e. those which make server requests.
|
||||
*/
|
||||
const makeSocketIOEmitters = (
|
||||
store: MiddlewareAPI<Dispatch<AnyAction>, RootState>,
|
||||
socketio: Socket
|
||||
) => {
|
||||
// We need to dispatch actions to redux and get pieces of state from the store.
|
||||
const { dispatch, getState } = store;
|
||||
|
||||
// return {
|
||||
// emitGenerateImage: (generationMode: InvokeTabName) => {
|
||||
// dispatch(setIsProcessing(true));
|
||||
return {
|
||||
emitGenerateImage: (generationMode: InvokeTabName) => {
|
||||
dispatch(setIsProcessing(true));
|
||||
|
||||
// const state: RootState = getState();
|
||||
const state: RootState = getState();
|
||||
|
||||
// const {
|
||||
// generation: generationState,
|
||||
// postprocessing: postprocessingState,
|
||||
// system: systemState,
|
||||
// canvas: canvasState,
|
||||
// } = state;
|
||||
const {
|
||||
generation: generationState,
|
||||
postprocessing: postprocessingState,
|
||||
system: systemState,
|
||||
canvas: canvasState,
|
||||
} = state;
|
||||
|
||||
// const frontendToBackendParametersConfig: FrontendToBackendParametersConfig =
|
||||
// {
|
||||
// generationMode,
|
||||
// generationState,
|
||||
// postprocessingState,
|
||||
// canvasState,
|
||||
// systemState,
|
||||
// };
|
||||
const frontendToBackendParametersConfig: FrontendToBackendParametersConfig =
|
||||
{
|
||||
generationMode,
|
||||
generationState,
|
||||
postprocessingState,
|
||||
canvasState,
|
||||
systemState,
|
||||
};
|
||||
|
||||
// dispatch(generationRequested());
|
||||
dispatch(generationRequested());
|
||||
|
||||
// const { generationParameters, esrganParameters, facetoolParameters } =
|
||||
// frontendToBackendParameters(frontendToBackendParametersConfig);
|
||||
const { generationParameters, esrganParameters, facetoolParameters } =
|
||||
frontendToBackendParameters(frontendToBackendParametersConfig);
|
||||
|
||||
// socketio.emit(
|
||||
// 'generateImage',
|
||||
// generationParameters,
|
||||
// esrganParameters,
|
||||
// facetoolParameters
|
||||
// );
|
||||
socketio.emit(
|
||||
'generateImage',
|
||||
generationParameters,
|
||||
esrganParameters,
|
||||
facetoolParameters
|
||||
);
|
||||
|
||||
// // we need to truncate the init_mask base64 else it takes up the whole log
|
||||
// // TODO: handle maintaining masks for reproducibility in future
|
||||
// if (generationParameters.init_mask) {
|
||||
// generationParameters.init_mask = generationParameters.init_mask
|
||||
// .substr(0, 64)
|
||||
// .concat('...');
|
||||
// }
|
||||
// if (generationParameters.init_img) {
|
||||
// generationParameters.init_img = generationParameters.init_img
|
||||
// .substr(0, 64)
|
||||
// .concat('...');
|
||||
// }
|
||||
// we need to truncate the init_mask base64 else it takes up the whole log
|
||||
// TODO: handle maintaining masks for reproducibility in future
|
||||
if (generationParameters.init_mask) {
|
||||
generationParameters.init_mask = generationParameters.init_mask
|
||||
.substr(0, 64)
|
||||
.concat('...');
|
||||
}
|
||||
if (generationParameters.init_img) {
|
||||
generationParameters.init_img = generationParameters.init_img
|
||||
.substr(0, 64)
|
||||
.concat('...');
|
||||
}
|
||||
|
||||
// dispatch(
|
||||
// addLogEntry({
|
||||
// timestamp: dateFormat(new Date(), 'isoDateTime'),
|
||||
// message: `Image generation requested: ${JSON.stringify({
|
||||
// ...generationParameters,
|
||||
// ...esrganParameters,
|
||||
// ...facetoolParameters,
|
||||
// })}`,
|
||||
// })
|
||||
// );
|
||||
// },
|
||||
// emitRunESRGAN: (imageToProcess: InvokeAI._Image) => {
|
||||
// dispatch(setIsProcessing(true));
|
||||
dispatch(
|
||||
addLogEntry({
|
||||
timestamp: dateFormat(new Date(), 'isoDateTime'),
|
||||
message: `Image generation requested: ${JSON.stringify({
|
||||
...generationParameters,
|
||||
...esrganParameters,
|
||||
...facetoolParameters,
|
||||
})}`,
|
||||
})
|
||||
);
|
||||
},
|
||||
emitRunESRGAN: (imageToProcess: InvokeAI._Image) => {
|
||||
dispatch(setIsProcessing(true));
|
||||
|
||||
// const {
|
||||
// postprocessing: {
|
||||
// upscalingLevel,
|
||||
// upscalingDenoising,
|
||||
// upscalingStrength,
|
||||
// },
|
||||
// } = getState();
|
||||
const {
|
||||
postprocessing: {
|
||||
upscalingLevel,
|
||||
upscalingDenoising,
|
||||
upscalingStrength,
|
||||
},
|
||||
} = getState();
|
||||
|
||||
// const esrganParameters = {
|
||||
// upscale: [upscalingLevel, upscalingDenoising, upscalingStrength],
|
||||
// };
|
||||
// socketio.emit('runPostprocessing', imageToProcess, {
|
||||
// type: 'esrgan',
|
||||
// ...esrganParameters,
|
||||
// });
|
||||
// dispatch(
|
||||
// addLogEntry({
|
||||
// timestamp: dateFormat(new Date(), 'isoDateTime'),
|
||||
// message: `ESRGAN upscale requested: ${JSON.stringify({
|
||||
// file: imageToProcess.url,
|
||||
// ...esrganParameters,
|
||||
// })}`,
|
||||
// })
|
||||
// );
|
||||
// },
|
||||
// emitRunFacetool: (imageToProcess: InvokeAI._Image) => {
|
||||
// dispatch(setIsProcessing(true));
|
||||
const esrganParameters = {
|
||||
upscale: [upscalingLevel, upscalingDenoising, upscalingStrength],
|
||||
};
|
||||
socketio.emit('runPostprocessing', imageToProcess, {
|
||||
type: 'esrgan',
|
||||
...esrganParameters,
|
||||
});
|
||||
dispatch(
|
||||
addLogEntry({
|
||||
timestamp: dateFormat(new Date(), 'isoDateTime'),
|
||||
message: `ESRGAN upscale requested: ${JSON.stringify({
|
||||
file: imageToProcess.url,
|
||||
...esrganParameters,
|
||||
})}`,
|
||||
})
|
||||
);
|
||||
},
|
||||
emitRunFacetool: (imageToProcess: InvokeAI._Image) => {
|
||||
dispatch(setIsProcessing(true));
|
||||
|
||||
// const {
|
||||
// postprocessing: { facetoolType, facetoolStrength, codeformerFidelity },
|
||||
// } = getState();
|
||||
const {
|
||||
postprocessing: { facetoolType, facetoolStrength, codeformerFidelity },
|
||||
} = getState();
|
||||
|
||||
// const facetoolParameters: Record<string, unknown> = {
|
||||
// facetool_strength: facetoolStrength,
|
||||
// };
|
||||
const facetoolParameters: Record<string, unknown> = {
|
||||
facetool_strength: facetoolStrength,
|
||||
};
|
||||
|
||||
// if (facetoolType === 'codeformer') {
|
||||
// facetoolParameters.codeformer_fidelity = codeformerFidelity;
|
||||
// }
|
||||
if (facetoolType === 'codeformer') {
|
||||
facetoolParameters.codeformer_fidelity = codeformerFidelity;
|
||||
}
|
||||
|
||||
// socketio.emit('runPostprocessing', imageToProcess, {
|
||||
// type: facetoolType,
|
||||
// ...facetoolParameters,
|
||||
// });
|
||||
// dispatch(
|
||||
// addLogEntry({
|
||||
// timestamp: dateFormat(new Date(), 'isoDateTime'),
|
||||
// message: `Face restoration (${facetoolType}) requested: ${JSON.stringify(
|
||||
// {
|
||||
// file: imageToProcess.url,
|
||||
// ...facetoolParameters,
|
||||
// }
|
||||
// )}`,
|
||||
// })
|
||||
// );
|
||||
// },
|
||||
// emitDeleteImage: (imageToDelete: InvokeAI._Image) => {
|
||||
// const { url, uuid, category, thumbnail } = imageToDelete;
|
||||
// dispatch(removeImage(imageToDelete));
|
||||
// socketio.emit('deleteImage', url, thumbnail, uuid, category);
|
||||
// },
|
||||
// emitRequestImages: (category: GalleryCategory) => {
|
||||
// const gallery: GalleryState = getState().gallery;
|
||||
// const { earliest_mtime } = gallery.categories[category];
|
||||
// socketio.emit('requestImages', category, earliest_mtime);
|
||||
// },
|
||||
// emitRequestNewImages: (category: GalleryCategory) => {
|
||||
// const gallery: GalleryState = getState().gallery;
|
||||
// const { latest_mtime } = gallery.categories[category];
|
||||
// socketio.emit('requestLatestImages', category, latest_mtime);
|
||||
// },
|
||||
// emitCancelProcessing: () => {
|
||||
// socketio.emit('cancel');
|
||||
// },
|
||||
// emitRequestSystemConfig: () => {
|
||||
// socketio.emit('requestSystemConfig');
|
||||
// },
|
||||
// emitSearchForModels: (modelFolder: string) => {
|
||||
// socketio.emit('searchForModels', modelFolder);
|
||||
// },
|
||||
// emitAddNewModel: (modelConfig: InvokeAI.InvokeModelConfigProps) => {
|
||||
// socketio.emit('addNewModel', modelConfig);
|
||||
// },
|
||||
// emitDeleteModel: (modelName: string) => {
|
||||
// socketio.emit('deleteModel', modelName);
|
||||
// },
|
||||
// emitConvertToDiffusers: (
|
||||
// modelToConvert: InvokeAI.InvokeModelConversionProps
|
||||
// ) => {
|
||||
// dispatch(modelConvertRequested());
|
||||
// socketio.emit('convertToDiffusers', modelToConvert);
|
||||
// },
|
||||
// emitMergeDiffusersModels: (
|
||||
// modelMergeInfo: InvokeAI.InvokeModelMergingProps
|
||||
// ) => {
|
||||
// dispatch(modelMergingRequested());
|
||||
// socketio.emit('mergeDiffusersModels', modelMergeInfo);
|
||||
// },
|
||||
// emitRequestModelChange: (modelName: string) => {
|
||||
// dispatch(modelChangeRequested());
|
||||
// socketio.emit('requestModelChange', modelName);
|
||||
// },
|
||||
// emitSaveStagingAreaImageToGallery: (url: string) => {
|
||||
// socketio.emit('requestSaveStagingAreaImageToGallery', url);
|
||||
// },
|
||||
// emitRequestEmptyTempFolder: () => {
|
||||
// socketio.emit('requestEmptyTempFolder');
|
||||
// },
|
||||
// };
|
||||
// };
|
||||
socketio.emit('runPostprocessing', imageToProcess, {
|
||||
type: facetoolType,
|
||||
...facetoolParameters,
|
||||
});
|
||||
dispatch(
|
||||
addLogEntry({
|
||||
timestamp: dateFormat(new Date(), 'isoDateTime'),
|
||||
message: `Face restoration (${facetoolType}) requested: ${JSON.stringify(
|
||||
{
|
||||
file: imageToProcess.url,
|
||||
...facetoolParameters,
|
||||
}
|
||||
)}`,
|
||||
})
|
||||
);
|
||||
},
|
||||
emitDeleteImage: (imageToDelete: InvokeAI._Image) => {
|
||||
const { url, uuid, category, thumbnail } = imageToDelete;
|
||||
dispatch(removeImage(imageToDelete));
|
||||
socketio.emit('deleteImage', url, thumbnail, uuid, category);
|
||||
},
|
||||
emitRequestImages: (category: GalleryCategory) => {
|
||||
const gallery: GalleryState = getState().gallery;
|
||||
const { earliest_mtime } = gallery.categories[category];
|
||||
socketio.emit('requestImages', category, earliest_mtime);
|
||||
},
|
||||
emitRequestNewImages: (category: GalleryCategory) => {
|
||||
const gallery: GalleryState = getState().gallery;
|
||||
const { latest_mtime } = gallery.categories[category];
|
||||
socketio.emit('requestLatestImages', category, latest_mtime);
|
||||
},
|
||||
emitCancelProcessing: () => {
|
||||
socketio.emit('cancel');
|
||||
},
|
||||
emitRequestSystemConfig: () => {
|
||||
socketio.emit('requestSystemConfig');
|
||||
},
|
||||
emitSearchForModels: (modelFolder: string) => {
|
||||
socketio.emit('searchForModels', modelFolder);
|
||||
},
|
||||
emitAddNewModel: (modelConfig: InvokeAI.InvokeModelConfigProps) => {
|
||||
socketio.emit('addNewModel', modelConfig);
|
||||
},
|
||||
emitDeleteModel: (modelName: string) => {
|
||||
socketio.emit('deleteModel', modelName);
|
||||
},
|
||||
emitConvertToDiffusers: (
|
||||
modelToConvert: InvokeAI.InvokeModelConversionProps
|
||||
) => {
|
||||
dispatch(modelConvertRequested());
|
||||
socketio.emit('convertToDiffusers', modelToConvert);
|
||||
},
|
||||
emitMergeDiffusersModels: (
|
||||
modelMergeInfo: InvokeAI.InvokeModelMergingProps
|
||||
) => {
|
||||
dispatch(modelMergingRequested());
|
||||
socketio.emit('mergeDiffusersModels', modelMergeInfo);
|
||||
},
|
||||
emitRequestModelChange: (modelName: string) => {
|
||||
dispatch(modelChangeRequested());
|
||||
socketio.emit('requestModelChange', modelName);
|
||||
},
|
||||
emitSaveStagingAreaImageToGallery: (url: string) => {
|
||||
socketio.emit('requestSaveStagingAreaImageToGallery', url);
|
||||
},
|
||||
emitRequestEmptyTempFolder: () => {
|
||||
socketio.emit('requestEmptyTempFolder');
|
||||
},
|
||||
};
|
||||
};
|
||||
|
||||
// export default makeSocketIOEmitters;
|
||||
export default makeSocketIOEmitters;
|
||||
|
||||
export default {};
|
||||
|
4
invokeai/frontend/web/src/app/store/actions.ts
Normal file
4
invokeai/frontend/web/src/app/store/actions.ts
Normal file
@ -0,0 +1,4 @@
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||
|
||||
export const userInvoked = createAction<InvokeTabName>('app/userInvoked');
|
8
invokeai/frontend/web/src/app/store/constants.ts
Normal file
8
invokeai/frontend/web/src/app/store/constants.ts
Normal file
@ -0,0 +1,8 @@
|
||||
export const LOCALSTORAGE_KEYS = [
|
||||
'chakra-ui-color-mode',
|
||||
'i18nextLng',
|
||||
'ROARR_FILTER',
|
||||
'ROARR_LOG',
|
||||
];
|
||||
|
||||
export const LOCALSTORAGE_PREFIX = '@@invokeai-';
|
@ -0,0 +1,36 @@
|
||||
import { canvasPersistDenylist } from 'features/canvas/store/canvasPersistDenylist';
|
||||
import { galleryPersistDenylist } from 'features/gallery/store/galleryPersistDenylist';
|
||||
import { resultsPersistDenylist } from 'features/gallery/store/resultsPersistDenylist';
|
||||
import { uploadsPersistDenylist } from 'features/gallery/store/uploadsPersistDenylist';
|
||||
import { lightboxPersistDenylist } from 'features/lightbox/store/lightboxPersistDenylist';
|
||||
import { nodesPersistDenylist } from 'features/nodes/store/nodesPersistDenylist';
|
||||
import { generationPersistDenylist } from 'features/parameters/store/generationPersistDenylist';
|
||||
import { postprocessingPersistDenylist } from 'features/parameters/store/postprocessingPersistDenylist';
|
||||
import { modelsPersistDenylist } from 'features/system/store/modelsPersistDenylist';
|
||||
import { systemPersistDenylist } from 'features/system/store/systemPersistDenylist';
|
||||
import { uiPersistDenylist } from 'features/ui/store/uiPersistDenylist';
|
||||
import { omit } from 'lodash-es';
|
||||
import { SerializeFunction } from 'redux-remember';
|
||||
|
||||
const serializationDenylist: {
|
||||
[key: string]: string[];
|
||||
} = {
|
||||
canvas: canvasPersistDenylist,
|
||||
gallery: galleryPersistDenylist,
|
||||
generation: generationPersistDenylist,
|
||||
lightbox: lightboxPersistDenylist,
|
||||
models: modelsPersistDenylist,
|
||||
nodes: nodesPersistDenylist,
|
||||
postprocessing: postprocessingPersistDenylist,
|
||||
results: resultsPersistDenylist,
|
||||
system: systemPersistDenylist,
|
||||
// config: configPersistDenyList,
|
||||
ui: uiPersistDenylist,
|
||||
uploads: uploadsPersistDenylist,
|
||||
// hotkeys: hotkeysPersistDenylist,
|
||||
};
|
||||
|
||||
export const serialize: SerializeFunction = (data, key) => {
|
||||
const result = omit(data, serializationDenylist[key]);
|
||||
return JSON.stringify(result);
|
||||
};
|
@ -0,0 +1,38 @@
|
||||
import { initialCanvasState } from 'features/canvas/store/canvasSlice';
|
||||
import { initialGalleryState } from 'features/gallery/store/gallerySlice';
|
||||
import { initialResultsState } from 'features/gallery/store/resultsSlice';
|
||||
import { initialUploadsState } from 'features/gallery/store/uploadsSlice';
|
||||
import { initialLightboxState } from 'features/lightbox/store/lightboxSlice';
|
||||
import { initialNodesState } from 'features/nodes/store/nodesSlice';
|
||||
import { initialGenerationState } from 'features/parameters/store/generationSlice';
|
||||
import { initialPostprocessingState } from 'features/parameters/store/postprocessingSlice';
|
||||
import { initialConfigState } from 'features/system/store/configSlice';
|
||||
import { initialModelsState } from 'features/system/store/modelSlice';
|
||||
import { initialSystemState } from 'features/system/store/systemSlice';
|
||||
import { initialHotkeysState } from 'features/ui/store/hotkeysSlice';
|
||||
import { initialUIState } from 'features/ui/store/uiSlice';
|
||||
import { defaultsDeep } from 'lodash-es';
|
||||
import { UnserializeFunction } from 'redux-remember';
|
||||
|
||||
const initialStates: {
|
||||
[key: string]: any;
|
||||
} = {
|
||||
canvas: initialCanvasState,
|
||||
gallery: initialGalleryState,
|
||||
generation: initialGenerationState,
|
||||
lightbox: initialLightboxState,
|
||||
models: initialModelsState,
|
||||
nodes: initialNodesState,
|
||||
postprocessing: initialPostprocessingState,
|
||||
results: initialResultsState,
|
||||
system: initialSystemState,
|
||||
config: initialConfigState,
|
||||
ui: initialUIState,
|
||||
uploads: initialUploadsState,
|
||||
hotkeys: initialHotkeysState,
|
||||
};
|
||||
|
||||
export const unserialize: UnserializeFunction = (data, key) => {
|
||||
const result = defaultsDeep(JSON.parse(data), initialStates[key]);
|
||||
return result;
|
||||
};
|
@ -0,0 +1,30 @@
|
||||
import { AnyAction } from '@reduxjs/toolkit';
|
||||
import { isAnyGraphBuilt } from 'features/nodes/store/actions';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { Graph } from 'services/api';
|
||||
|
||||
export const actionSanitizer = <A extends AnyAction>(action: A): A => {
|
||||
if (isAnyGraphBuilt(action)) {
|
||||
if (action.payload.nodes) {
|
||||
const sanitizedNodes: Graph['nodes'] = {};
|
||||
|
||||
// Sanitize nodes as needed
|
||||
forEach(action.payload.nodes, (node, key) => {
|
||||
// Don't log the whole freaking dataURL
|
||||
if (node.type === 'dataURL_image') {
|
||||
const { dataURL, ...rest } = node;
|
||||
sanitizedNodes[key] = { ...rest, dataURL: '<dataURL>' };
|
||||
} else {
|
||||
sanitizedNodes[key] = { ...node };
|
||||
}
|
||||
});
|
||||
|
||||
return {
|
||||
...action,
|
||||
payload: { ...action.payload, nodes: sanitizedNodes },
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
return action;
|
||||
};
|
@ -0,0 +1,11 @@
|
||||
export const actionsDenylist = [
|
||||
'canvas/setCursorPosition',
|
||||
'canvas/setStageCoordinates',
|
||||
'canvas/setStageScale',
|
||||
'canvas/setIsDrawing',
|
||||
'canvas/setBoundingBoxCoordinates',
|
||||
'canvas/setBoundingBoxDimensions',
|
||||
'canvas/setIsDrawing',
|
||||
'canvas/addPointToCurrentLine',
|
||||
'socket/generatorProgress',
|
||||
];
|
@ -0,0 +1,3 @@
|
||||
export const stateSanitizer = <S>(state: S): S => {
|
||||
return state;
|
||||
};
|
@ -0,0 +1,45 @@
|
||||
import {
|
||||
createListenerMiddleware,
|
||||
addListener,
|
||||
ListenerEffect,
|
||||
AnyAction,
|
||||
} from '@reduxjs/toolkit';
|
||||
import type { TypedStartListening, TypedAddListener } from '@reduxjs/toolkit';
|
||||
|
||||
import type { RootState, AppDispatch } from '../../store';
|
||||
import { addInitialImageSelectedListener } from './listeners/initialImageSelected';
|
||||
import { addImageResultReceivedListener } from './listeners/invocationComplete';
|
||||
import { addImageUploadedListener } from './listeners/imageUploaded';
|
||||
import { addRequestedImageDeletionListener } from './listeners/imageDeleted';
|
||||
import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas';
|
||||
import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
|
||||
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
|
||||
import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage';
|
||||
|
||||
export const listenerMiddleware = createListenerMiddleware();
|
||||
|
||||
export type AppStartListening = TypedStartListening<RootState, AppDispatch>;
|
||||
|
||||
export const startAppListening =
|
||||
listenerMiddleware.startListening as AppStartListening;
|
||||
|
||||
export const addAppListener = addListener as TypedAddListener<
|
||||
RootState,
|
||||
AppDispatch
|
||||
>;
|
||||
|
||||
export type AppListenerEffect = ListenerEffect<
|
||||
AnyAction,
|
||||
RootState,
|
||||
AppDispatch
|
||||
>;
|
||||
|
||||
addImageUploadedListener();
|
||||
addInitialImageSelectedListener();
|
||||
addImageResultReceivedListener();
|
||||
addRequestedImageDeletionListener();
|
||||
|
||||
addUserInvokedCanvasListener();
|
||||
addUserInvokedNodesListener();
|
||||
addUserInvokedTextToImageListener();
|
||||
addUserInvokedImageToImageListener();
|
@ -0,0 +1,31 @@
|
||||
import { canvasGraphBuilt } from 'features/nodes/store/actions';
|
||||
import { startAppListening } from '..';
|
||||
import {
|
||||
canvasSessionIdChanged,
|
||||
stagingAreaInitialized,
|
||||
} from 'features/canvas/store/canvasSlice';
|
||||
import { sessionInvoked } from 'services/thunks/session';
|
||||
|
||||
export const addCanvasGraphBuiltListener = () =>
|
||||
startAppListening({
|
||||
actionCreator: canvasGraphBuilt,
|
||||
effect: async (action, { dispatch, getState, take }) => {
|
||||
const [{ meta }] = await take(sessionInvoked.fulfilled.match);
|
||||
const { sessionId } = meta.arg;
|
||||
const state = getState();
|
||||
|
||||
if (!state.canvas.layerState.stagingArea.boundingBox) {
|
||||
dispatch(
|
||||
stagingAreaInitialized({
|
||||
sessionId,
|
||||
boundingBox: {
|
||||
...state.canvas.boundingBoxCoordinates,
|
||||
...state.canvas.boundingBoxDimensions,
|
||||
},
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
dispatch(canvasSessionIdChanged(sessionId));
|
||||
},
|
||||
});
|
@ -0,0 +1,59 @@
|
||||
import { requestedImageDeletion } from 'features/gallery/store/actions';
|
||||
import { startAppListening } from '..';
|
||||
import { imageDeleted } from 'services/thunks/image';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { clamp } from 'lodash-es';
|
||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'addRequestedImageDeletionListener' });
|
||||
|
||||
export const addRequestedImageDeletionListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: requestedImageDeletion,
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
const image = action.payload;
|
||||
if (!image) {
|
||||
moduleLog.warn('No image provided');
|
||||
return;
|
||||
}
|
||||
|
||||
const { name, type } = image;
|
||||
|
||||
if (type !== 'uploads' && type !== 'results') {
|
||||
moduleLog.warn({ data: image }, `Invalid image type ${type}`);
|
||||
return;
|
||||
}
|
||||
|
||||
const selectedImageName = getState().gallery.selectedImage?.name;
|
||||
|
||||
if (selectedImageName === name) {
|
||||
const allIds = getState()[type].ids;
|
||||
const allEntities = getState()[type].entities;
|
||||
|
||||
const deletedImageIndex = allIds.findIndex(
|
||||
(result) => result.toString() === name
|
||||
);
|
||||
|
||||
const filteredIds = allIds.filter((id) => id.toString() !== name);
|
||||
|
||||
const newSelectedImageIndex = clamp(
|
||||
deletedImageIndex,
|
||||
0,
|
||||
filteredIds.length - 1
|
||||
);
|
||||
|
||||
const newSelectedImageId = filteredIds[newSelectedImageIndex];
|
||||
|
||||
const newSelectedImage = allEntities[newSelectedImageId];
|
||||
|
||||
if (newSelectedImageId) {
|
||||
dispatch(imageSelected(newSelectedImage));
|
||||
} else {
|
||||
dispatch(imageSelected());
|
||||
}
|
||||
}
|
||||
|
||||
dispatch(imageDeleted({ imageName: name, imageType: type }));
|
||||
},
|
||||
});
|
||||
};
|
@ -0,0 +1,25 @@
|
||||
import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
|
||||
import { startAppListening } from '..';
|
||||
import { uploadAdded } from 'features/gallery/store/uploadsSlice';
|
||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||
import { imageUploaded } from 'services/thunks/image';
|
||||
|
||||
export const addImageUploadedListener = () => {
|
||||
startAppListening({
|
||||
predicate: (action): action is ReturnType<typeof imageUploaded.fulfilled> =>
|
||||
imageUploaded.fulfilled.match(action) &&
|
||||
action.payload.response.image_type !== 'intermediates',
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
const { response } = action.payload;
|
||||
|
||||
const state = getState();
|
||||
const image = deserializeImageResponse(response);
|
||||
|
||||
dispatch(uploadAdded(image));
|
||||
|
||||
if (state.gallery.shouldAutoSwitchToNewImages) {
|
||||
dispatch(imageSelected(image));
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
@ -0,0 +1,54 @@
|
||||
import { initialImageChanged } from 'features/parameters/store/generationSlice';
|
||||
import { Image, isInvokeAIImage } from 'app/types/invokeai';
|
||||
import { selectResultsById } from 'features/gallery/store/resultsSlice';
|
||||
import { selectUploadsById } from 'features/gallery/store/uploadsSlice';
|
||||
import { makeToast } from 'features/system/hooks/useToastWatcher';
|
||||
import { t } from 'i18next';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { startAppListening } from '..';
|
||||
import { initialImageSelected } from 'features/parameters/store/actions';
|
||||
|
||||
export const addInitialImageSelectedListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: initialImageSelected,
|
||||
effect: (action, { getState, dispatch }) => {
|
||||
if (!action.payload) {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({ title: t('toast.imageNotLoadedDesc'), status: 'error' })
|
||||
)
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if (isInvokeAIImage(action.payload)) {
|
||||
dispatch(initialImageChanged(action.payload));
|
||||
dispatch(addToast(makeToast(t('toast.sentToImageToImage'))));
|
||||
return;
|
||||
}
|
||||
|
||||
const { name, type } = action.payload;
|
||||
|
||||
let image: Image | undefined;
|
||||
const state = getState();
|
||||
|
||||
if (type === 'results') {
|
||||
image = selectResultsById(state, name);
|
||||
} else if (type === 'uploads') {
|
||||
image = selectUploadsById(state, name);
|
||||
}
|
||||
|
||||
if (!image) {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({ title: t('toast.imageNotLoadedDesc'), status: 'error' })
|
||||
)
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(initialImageChanged(image));
|
||||
dispatch(addToast(makeToast(t('toast.sentToImageToImage'))));
|
||||
},
|
||||
});
|
||||
};
|
@ -0,0 +1,88 @@
|
||||
import { invocationComplete } from 'services/events/actions';
|
||||
import { isImageOutput } from 'services/types/guards';
|
||||
import {
|
||||
buildImageUrls,
|
||||
extractTimestampFromImageName,
|
||||
} from 'services/util/deserializeImageField';
|
||||
import { Image } from 'app/types/invokeai';
|
||||
import { resultAdded } from 'features/gallery/store/resultsSlice';
|
||||
import { imageReceived, thumbnailReceived } from 'services/thunks/image';
|
||||
import { startAppListening } from '..';
|
||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||
import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
|
||||
|
||||
const nodeDenylist = ['dataURL_image'];
|
||||
|
||||
export const addImageResultReceivedListener = () => {
|
||||
startAppListening({
|
||||
predicate: (action) => {
|
||||
if (
|
||||
invocationComplete.match(action) &&
|
||||
isImageOutput(action.payload.data.result)
|
||||
) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
},
|
||||
effect: (action, { getState, dispatch }) => {
|
||||
if (!invocationComplete.match(action)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const { data, shouldFetchImages } = action.payload;
|
||||
const { result, node, graph_execution_state_id } = data;
|
||||
|
||||
if (isImageOutput(result) && !nodeDenylist.includes(node.type)) {
|
||||
const name = result.image.image_name;
|
||||
const type = result.image.image_type;
|
||||
const state = getState();
|
||||
|
||||
// if we need to refetch, set URLs to placeholder for now
|
||||
const { url, thumbnail } = shouldFetchImages
|
||||
? { url: '', thumbnail: '' }
|
||||
: buildImageUrls(type, name);
|
||||
|
||||
const timestamp = extractTimestampFromImageName(name);
|
||||
|
||||
const image: Image = {
|
||||
name,
|
||||
type,
|
||||
url,
|
||||
thumbnail,
|
||||
metadata: {
|
||||
created: timestamp,
|
||||
width: result.width,
|
||||
height: result.height,
|
||||
invokeai: {
|
||||
session_id: graph_execution_state_id,
|
||||
...(node ? { node } : {}),
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
dispatch(resultAdded(image));
|
||||
|
||||
if (state.gallery.shouldAutoSwitchToNewImages) {
|
||||
dispatch(imageSelected(image));
|
||||
}
|
||||
|
||||
if (state.config.shouldFetchImages) {
|
||||
dispatch(imageReceived({ imageName: name, imageType: type }));
|
||||
dispatch(
|
||||
thumbnailReceived({
|
||||
thumbnailName: name,
|
||||
thumbnailType: type,
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
graph_execution_state_id ===
|
||||
state.canvas.layerState.stagingArea.sessionId
|
||||
) {
|
||||
dispatch(addImageToStagingArea(image));
|
||||
}
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
@ -0,0 +1,126 @@
|
||||
import { startAppListening } from '..';
|
||||
import { sessionCreated, sessionInvoked } from 'services/thunks/session';
|
||||
import { buildCanvasGraphAndBlobs } from 'features/nodes/util/graphBuilders/buildCanvasGraph';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { canvasGraphBuilt } from 'features/nodes/store/actions';
|
||||
import { imageUploaded } from 'services/thunks/image';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import { Graph } from 'services/api';
|
||||
import {
|
||||
canvasSessionIdChanged,
|
||||
stagingAreaInitialized,
|
||||
} from 'features/canvas/store/canvasSlice';
|
||||
import { userInvoked } from 'app/store/actions';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'invoke' });
|
||||
|
||||
export const addUserInvokedCanvasListener = () => {
|
||||
startAppListening({
|
||||
predicate: (action): action is ReturnType<typeof userInvoked> =>
|
||||
userInvoked.match(action) && action.payload === 'unifiedCanvas',
|
||||
effect: async (action, { getState, dispatch, take }) => {
|
||||
const state = getState();
|
||||
|
||||
const data = await buildCanvasGraphAndBlobs(state);
|
||||
|
||||
if (!data) {
|
||||
moduleLog.error('Problem building graph');
|
||||
return;
|
||||
}
|
||||
|
||||
const {
|
||||
rangeNode,
|
||||
iterateNode,
|
||||
baseNode,
|
||||
edges,
|
||||
baseBlob,
|
||||
maskBlob,
|
||||
generationMode,
|
||||
} = data;
|
||||
|
||||
const baseFilename = `${uuidv4()}.png`;
|
||||
const maskFilename = `${uuidv4()}.png`;
|
||||
|
||||
dispatch(
|
||||
imageUploaded({
|
||||
imageType: 'intermediates',
|
||||
formData: {
|
||||
file: new File([baseBlob], baseFilename, { type: 'image/png' }),
|
||||
},
|
||||
})
|
||||
);
|
||||
|
||||
if (baseNode.type === 'img2img' || baseNode.type === 'inpaint') {
|
||||
const [{ payload: basePayload }] = await take(
|
||||
(action): action is ReturnType<typeof imageUploaded.fulfilled> =>
|
||||
imageUploaded.fulfilled.match(action) &&
|
||||
action.meta.arg.formData.file.name === baseFilename
|
||||
);
|
||||
|
||||
const { image_name: baseName, image_type: baseType } =
|
||||
basePayload.response;
|
||||
|
||||
baseNode.image = {
|
||||
image_name: baseName,
|
||||
image_type: baseType,
|
||||
};
|
||||
}
|
||||
|
||||
if (baseNode.type === 'inpaint') {
|
||||
dispatch(
|
||||
imageUploaded({
|
||||
imageType: 'intermediates',
|
||||
formData: {
|
||||
file: new File([maskBlob], maskFilename, { type: 'image/png' }),
|
||||
},
|
||||
})
|
||||
);
|
||||
|
||||
const [{ payload: maskPayload }] = await take(
|
||||
(action): action is ReturnType<typeof imageUploaded.fulfilled> =>
|
||||
imageUploaded.fulfilled.match(action) &&
|
||||
action.meta.arg.formData.file.name === maskFilename
|
||||
);
|
||||
|
||||
const { image_name: maskName, image_type: maskType } =
|
||||
maskPayload.response;
|
||||
|
||||
baseNode.mask = {
|
||||
image_name: maskName,
|
||||
image_type: maskType,
|
||||
};
|
||||
}
|
||||
|
||||
// Assemble!
|
||||
const nodes: Graph['nodes'] = {
|
||||
[rangeNode.id]: rangeNode,
|
||||
[iterateNode.id]: iterateNode,
|
||||
[baseNode.id]: baseNode,
|
||||
};
|
||||
|
||||
const graph = { nodes, edges };
|
||||
|
||||
dispatch(canvasGraphBuilt(graph));
|
||||
moduleLog({ data: graph }, 'Canvas graph built');
|
||||
|
||||
dispatch(sessionCreated({ graph }));
|
||||
|
||||
const [{ meta }] = await take(sessionInvoked.fulfilled.match);
|
||||
const { sessionId } = meta.arg;
|
||||
|
||||
if (!state.canvas.layerState.stagingArea.boundingBox) {
|
||||
dispatch(
|
||||
stagingAreaInitialized({
|
||||
sessionId,
|
||||
boundingBox: {
|
||||
...state.canvas.boundingBoxCoordinates,
|
||||
...state.canvas.boundingBoxDimensions,
|
||||
},
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
dispatch(canvasSessionIdChanged(sessionId));
|
||||
},
|
||||
});
|
||||
};
|
@ -0,0 +1,24 @@
|
||||
import { startAppListening } from '..';
|
||||
import { buildImageToImageGraph } from 'features/nodes/util/graphBuilders/buildImageToImageGraph';
|
||||
import { sessionCreated } from 'services/thunks/session';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { imageToImageGraphBuilt } from 'features/nodes/store/actions';
|
||||
import { userInvoked } from 'app/store/actions';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'invoke' });
|
||||
|
||||
export const addUserInvokedImageToImageListener = () => {
|
||||
startAppListening({
|
||||
predicate: (action): action is ReturnType<typeof userInvoked> =>
|
||||
userInvoked.match(action) && action.payload === 'img2img',
|
||||
effect: (action, { getState, dispatch }) => {
|
||||
const state = getState();
|
||||
|
||||
const graph = buildImageToImageGraph(state);
|
||||
dispatch(imageToImageGraphBuilt(graph));
|
||||
moduleLog({ data: graph }, 'Image to Image graph built');
|
||||
|
||||
dispatch(sessionCreated({ graph }));
|
||||
},
|
||||
});
|
||||
};
|
@ -0,0 +1,24 @@
|
||||
import { startAppListening } from '..';
|
||||
import { sessionCreated } from 'services/thunks/session';
|
||||
import { buildNodesGraph } from 'features/nodes/util/graphBuilders/buildNodesGraph';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { nodesGraphBuilt } from 'features/nodes/store/actions';
|
||||
import { userInvoked } from 'app/store/actions';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'invoke' });
|
||||
|
||||
export const addUserInvokedNodesListener = () => {
|
||||
startAppListening({
|
||||
predicate: (action): action is ReturnType<typeof userInvoked> =>
|
||||
userInvoked.match(action) && action.payload === 'nodes',
|
||||
effect: (action, { getState, dispatch }) => {
|
||||
const state = getState();
|
||||
|
||||
const graph = buildNodesGraph(state);
|
||||
dispatch(nodesGraphBuilt(graph));
|
||||
moduleLog({ data: graph }, 'Nodes graph built');
|
||||
|
||||
dispatch(sessionCreated({ graph }));
|
||||
},
|
||||
});
|
||||
};
|
@ -0,0 +1,24 @@
|
||||
import { startAppListening } from '..';
|
||||
import { buildTextToImageGraph } from 'features/nodes/util/graphBuilders/buildTextToImageGraph';
|
||||
import { sessionCreated } from 'services/thunks/session';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { textToImageGraphBuilt } from 'features/nodes/store/actions';
|
||||
import { userInvoked } from 'app/store/actions';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'invoke' });
|
||||
|
||||
export const addUserInvokedTextToImageListener = () => {
|
||||
startAppListening({
|
||||
predicate: (action): action is ReturnType<typeof userInvoked> =>
|
||||
userInvoked.match(action) && action.payload === 'txt2img',
|
||||
effect: (action, { getState, dispatch }) => {
|
||||
const state = getState();
|
||||
|
||||
const graph = buildTextToImageGraph(state);
|
||||
dispatch(textToImageGraphBuilt(graph));
|
||||
moduleLog({ data: graph }, 'Text to Image graph built');
|
||||
|
||||
dispatch(sessionCreated({ graph }));
|
||||
},
|
||||
});
|
||||
};
|
@ -1,4 +0,0 @@
|
||||
import { store } from 'app/store/store';
|
||||
import { persistStore } from 'redux-persist';
|
||||
|
||||
export const persistor = persistStore(store);
|
@ -1,9 +1,12 @@
|
||||
import { combineReducers, configureStore } from '@reduxjs/toolkit';
|
||||
import {
|
||||
AnyAction,
|
||||
ThunkDispatch,
|
||||
combineReducers,
|
||||
configureStore,
|
||||
} from '@reduxjs/toolkit';
|
||||
|
||||
import { persistReducer } from 'redux-persist';
|
||||
import storage from 'redux-persist/lib/storage'; // defaults to localStorage for web
|
||||
import { rememberReducer, rememberEnhancer } from 'redux-remember';
|
||||
import dynamicMiddlewares from 'redux-dynamic-middlewares';
|
||||
import { getPersistConfig } from 'redux-deep-persist';
|
||||
|
||||
import canvasReducer from 'features/canvas/store/canvasSlice';
|
||||
import galleryReducer from 'features/gallery/store/gallerySlice';
|
||||
@ -19,33 +22,17 @@ import hotkeysReducer from 'features/ui/store/hotkeysSlice';
|
||||
import modelsReducer from 'features/system/store/modelSlice';
|
||||
import nodesReducer from 'features/nodes/store/nodesSlice';
|
||||
|
||||
import { canvasDenylist } from 'features/canvas/store/canvasPersistDenylist';
|
||||
import { galleryDenylist } from 'features/gallery/store/galleryPersistDenylist';
|
||||
import { generationDenylist } from 'features/parameters/store/generationPersistDenylist';
|
||||
import { lightboxDenylist } from 'features/lightbox/store/lightboxPersistDenylist';
|
||||
import { modelsDenylist } from 'features/system/store/modelsPersistDenylist';
|
||||
import { nodesDenylist } from 'features/nodes/store/nodesPersistDenylist';
|
||||
import { postprocessingDenylist } from 'features/parameters/store/postprocessingPersistDenylist';
|
||||
import { systemDenylist } from 'features/system/store/systemPersistDenylist';
|
||||
import { uiDenylist } from 'features/ui/store/uiPersistDenylist';
|
||||
import { resultsDenylist } from 'features/gallery/store/resultsPersistDenylist';
|
||||
import { uploadsDenylist } from 'features/gallery/store/uploadsPersistDenylist';
|
||||
import { listenerMiddleware } from './middleware/listenerMiddleware';
|
||||
|
||||
/**
|
||||
* redux-persist provides an easy and reliable way to persist state across reloads.
|
||||
*
|
||||
* While we definitely want generation parameters to be persisted, there are a number
|
||||
* of things we do *not* want to be persisted across reloads:
|
||||
* - Gallery/selected image (user may add/delete images from disk between page loads)
|
||||
* - Connection/processing status
|
||||
* - Availability of external libraries like ESRGAN/GFPGAN
|
||||
*
|
||||
* These can be denylisted in redux-persist.
|
||||
*
|
||||
* The necesssary nested persistors with denylists are configured below.
|
||||
*/
|
||||
import { actionSanitizer } from './middleware/devtools/actionSanitizer';
|
||||
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
|
||||
import { actionsDenylist } from './middleware/devtools/actionsDenylist';
|
||||
|
||||
const rootReducer = combineReducers({
|
||||
import { serialize } from './enhancers/reduxRemember/serialize';
|
||||
import { unserialize } from './enhancers/reduxRemember/unserialize';
|
||||
import { LOCALSTORAGE_PREFIX } from './constants';
|
||||
|
||||
const allReducers = {
|
||||
canvas: canvasReducer,
|
||||
gallery: galleryReducer,
|
||||
generation: generationReducer,
|
||||
@ -59,65 +46,54 @@ const rootReducer = combineReducers({
|
||||
ui: uiReducer,
|
||||
uploads: uploadsReducer,
|
||||
hotkeys: hotkeysReducer,
|
||||
});
|
||||
};
|
||||
|
||||
const rootPersistConfig = getPersistConfig({
|
||||
key: 'root',
|
||||
storage,
|
||||
rootReducer,
|
||||
blacklist: [
|
||||
...canvasDenylist,
|
||||
...galleryDenylist,
|
||||
...generationDenylist,
|
||||
...lightboxDenylist,
|
||||
...modelsDenylist,
|
||||
...nodesDenylist,
|
||||
...postprocessingDenylist,
|
||||
// ...resultsDenylist,
|
||||
'results',
|
||||
...systemDenylist,
|
||||
...uiDenylist,
|
||||
// ...uploadsDenylist,
|
||||
'uploads',
|
||||
'hotkeys',
|
||||
'config',
|
||||
],
|
||||
});
|
||||
const rootReducer = combineReducers(allReducers);
|
||||
|
||||
const persistedReducer = persistReducer(rootPersistConfig, rootReducer);
|
||||
const rememberedRootReducer = rememberReducer(rootReducer);
|
||||
|
||||
// TODO: rip the old middleware out when nodes is complete
|
||||
// export function buildMiddleware() {
|
||||
// if (import.meta.env.MODE === 'nodes' || import.meta.env.MODE === 'package') {
|
||||
// return socketMiddleware();
|
||||
// } else {
|
||||
// return socketioMiddleware();
|
||||
// }
|
||||
// }
|
||||
const rememberedKeys: (keyof typeof allReducers)[] = [
|
||||
'canvas',
|
||||
'gallery',
|
||||
'generation',
|
||||
'lightbox',
|
||||
// 'models',
|
||||
'nodes',
|
||||
'postprocessing',
|
||||
'system',
|
||||
'ui',
|
||||
// 'hotkeys',
|
||||
// 'results',
|
||||
// 'uploads',
|
||||
// 'config',
|
||||
];
|
||||
|
||||
export const store = configureStore({
|
||||
reducer: persistedReducer,
|
||||
reducer: rememberedRootReducer,
|
||||
enhancers: [
|
||||
rememberEnhancer(window.localStorage, rememberedKeys, {
|
||||
persistDebounce: 300,
|
||||
serialize,
|
||||
unserialize,
|
||||
prefix: LOCALSTORAGE_PREFIX,
|
||||
}),
|
||||
],
|
||||
middleware: (getDefaultMiddleware) =>
|
||||
getDefaultMiddleware({
|
||||
immutableCheck: false,
|
||||
serializableCheck: false,
|
||||
}).concat(dynamicMiddlewares),
|
||||
})
|
||||
.concat(dynamicMiddlewares)
|
||||
.prepend(listenerMiddleware.middleware),
|
||||
devTools: {
|
||||
// Uncommenting these very rapidly called actions makes the redux dev tools output much more readable
|
||||
actionsDenylist: [
|
||||
'canvas/setCursorPosition',
|
||||
'canvas/setStageCoordinates',
|
||||
'canvas/setStageScale',
|
||||
'canvas/setIsDrawing',
|
||||
'canvas/setBoundingBoxCoordinates',
|
||||
'canvas/setBoundingBoxDimensions',
|
||||
'canvas/setIsDrawing',
|
||||
'canvas/addPointToCurrentLine',
|
||||
'socket/generatorProgress',
|
||||
],
|
||||
actionsDenylist,
|
||||
actionSanitizer,
|
||||
stateSanitizer,
|
||||
trace: true,
|
||||
},
|
||||
});
|
||||
|
||||
export type AppGetState = typeof store.getState;
|
||||
export type RootState = ReturnType<typeof store.getState>;
|
||||
export type AppThunkDispatch = ThunkDispatch<RootState, any, AnyAction>;
|
||||
export type AppDispatch = typeof store.dispatch;
|
||||
|
@ -1,6 +1,6 @@
|
||||
import { TypedUseSelectorHook, useDispatch, useSelector } from 'react-redux';
|
||||
import { AppDispatch, RootState } from 'app/store/store';
|
||||
import { AppThunkDispatch, RootState } from 'app/store/store';
|
||||
|
||||
// Use throughout your app instead of plain `useDispatch` and `useSelector`
|
||||
export const useAppDispatch: () => AppDispatch = useDispatch;
|
||||
export const useAppDispatch = () => useDispatch<AppThunkDispatch>();
|
||||
export const useAppSelector: TypedUseSelectorHook<RootState> = useSelector;
|
||||
|
@ -0,0 +1,7 @@
|
||||
import { isEqual } from 'lodash-es';
|
||||
|
||||
export const defaultSelectorOptions = {
|
||||
memoizeOptions: {
|
||||
resultEqualityCheck: isEqual,
|
||||
},
|
||||
};
|
@ -12,12 +12,10 @@
|
||||
* 'gfpgan'.
|
||||
*/
|
||||
|
||||
import { GalleryCategory } from 'features/gallery/store/gallerySlice';
|
||||
import { FacetoolType } from 'features/parameters/store/postprocessingSlice';
|
||||
import { SelectedImage } from 'features/parameters/store/actions';
|
||||
import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||
import { IRect } from 'konva/lib/types';
|
||||
import { ImageResponseMetadata, ImageType } from 'services/api';
|
||||
import { AnyInvocation } from 'services/events/types';
|
||||
import { O } from 'ts-toolbelt';
|
||||
|
||||
/**
|
||||
@ -49,15 +47,20 @@ export type CommonGeneratedImageMetadata = {
|
||||
postprocessing: null | Array<ESRGANMetadata | FacetoolMetadata>;
|
||||
sampler:
|
||||
| 'ddim'
|
||||
| 'k_dpm_2_a'
|
||||
| 'k_dpm_2'
|
||||
| 'k_dpmpp_2_a'
|
||||
| 'k_dpmpp_2'
|
||||
| 'k_euler_a'
|
||||
| 'k_euler'
|
||||
| 'k_heun'
|
||||
| 'k_lms'
|
||||
| 'plms';
|
||||
| 'ddpm'
|
||||
| 'deis'
|
||||
| 'lms'
|
||||
| 'pndm'
|
||||
| 'heun'
|
||||
| 'euler'
|
||||
| 'euler_k'
|
||||
| 'euler_a'
|
||||
| 'kdpm_2'
|
||||
| 'kdpm_2_a'
|
||||
| 'dpmpp_2s'
|
||||
| 'dpmpp_2m'
|
||||
| 'dpmpp_2m_k'
|
||||
| 'unipc';
|
||||
prompt: Prompt;
|
||||
seed: number;
|
||||
variations: SeedWeights;
|
||||
@ -126,6 +129,14 @@ export type Image = {
|
||||
metadata: ImageResponseMetadata;
|
||||
};
|
||||
|
||||
export const isInvokeAIImage = (obj: Image | SelectedImage): obj is Image => {
|
||||
if ('url' in obj && 'thumbnail' in obj) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
};
|
||||
|
||||
/**
|
||||
* Types related to the system status.
|
||||
*/
|
||||
@ -270,7 +281,7 @@ export type FoundModelResponse = {
|
||||
|
||||
// export type SystemConfigResponse = SystemConfig;
|
||||
|
||||
export type ImageResultResponse = Omit<_Image, 'uuid'> & {
|
||||
export type ImageResultResponse = Omit<Image, 'uuid'> & {
|
||||
boundingBox?: IRect;
|
||||
generationMode: InvokeTabName;
|
||||
};
|
||||
@ -315,11 +326,11 @@ export type AppFeature =
|
||||
/**
|
||||
* A disable-able Stable Diffusion feature
|
||||
*/
|
||||
export type StableDiffusionFeature =
|
||||
| 'noiseConfig'
|
||||
| 'variations'
|
||||
export type SDFeature =
|
||||
| 'noise'
|
||||
| 'variation'
|
||||
| 'symmetry'
|
||||
| 'tiling'
|
||||
| 'seamless'
|
||||
| 'hires';
|
||||
|
||||
/**
|
||||
@ -337,6 +348,7 @@ export type AppConfig = {
|
||||
shouldFetchImages: boolean;
|
||||
disabledTabs: InvokeTabName[];
|
||||
disabledFeatures: AppFeature[];
|
||||
disabledSDFeatures: SDFeature[];
|
||||
canRestoreDeletedImagesFromBin: boolean;
|
||||
sd: {
|
||||
iterations: {
|
||||
|
61
invokeai/frontend/web/src/common/components/IAICollapse.tsx
Normal file
61
invokeai/frontend/web/src/common/components/IAICollapse.tsx
Normal file
@ -0,0 +1,61 @@
|
||||
import { ChevronUpIcon } from '@chakra-ui/icons';
|
||||
import { Box, Collapse, Flex, Spacer, Switch } from '@chakra-ui/react';
|
||||
import { PropsWithChildren, memo } from 'react';
|
||||
|
||||
export type IAIToggleCollapseProps = PropsWithChildren & {
|
||||
label: string;
|
||||
isOpen: boolean;
|
||||
onToggle: () => void;
|
||||
withSwitch?: boolean;
|
||||
};
|
||||
|
||||
const IAICollapse = (props: IAIToggleCollapseProps) => {
|
||||
const { label, isOpen, onToggle, children, withSwitch = false } = props;
|
||||
return (
|
||||
<Box>
|
||||
<Flex
|
||||
onClick={onToggle}
|
||||
sx={{
|
||||
alignItems: 'center',
|
||||
p: 2,
|
||||
px: 4,
|
||||
borderTopRadius: 'base',
|
||||
borderBottomRadius: isOpen ? 0 : 'base',
|
||||
bg: isOpen ? 'base.750' : 'base.800',
|
||||
color: 'base.100',
|
||||
_hover: {
|
||||
bg: isOpen ? 'base.700' : 'base.750',
|
||||
},
|
||||
fontSize: 'sm',
|
||||
fontWeight: 600,
|
||||
cursor: 'pointer',
|
||||
transitionProperty: 'common',
|
||||
transitionDuration: 'normal',
|
||||
userSelect: 'none',
|
||||
}}
|
||||
>
|
||||
{label}
|
||||
<Spacer />
|
||||
{withSwitch && <Switch isChecked={isOpen} pointerEvents="none" />}
|
||||
{!withSwitch && (
|
||||
<ChevronUpIcon
|
||||
sx={{
|
||||
w: '1rem',
|
||||
h: '1rem',
|
||||
transform: isOpen ? 'rotate(0deg)' : 'rotate(180deg)',
|
||||
transitionProperty: 'common',
|
||||
transitionDuration: 'normal',
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
</Flex>
|
||||
<Collapse in={isOpen} animateOpacity>
|
||||
<Box sx={{ p: 4, borderBottomRadius: 'base', bg: 'base.800' }}>
|
||||
{children}
|
||||
</Box>
|
||||
</Collapse>
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(IAICollapse);
|
@ -27,7 +27,7 @@ const IAIPopover = (props: IAIPopoverProps) => {
|
||||
return (
|
||||
<Popover isLazy={isLazy} {...rest}>
|
||||
<PopoverTrigger>{triggerComponent}</PopoverTrigger>
|
||||
<PopoverContent>
|
||||
<PopoverContent shadow="dark-lg">
|
||||
{hasArrow && <PopoverArrow />}
|
||||
{children}
|
||||
</PopoverContent>
|
||||
|
@ -0,0 +1,54 @@
|
||||
import { Badge, Flex } from '@chakra-ui/react';
|
||||
import { Image } from 'app/types/invokeai';
|
||||
import { isNumber, isString } from 'lodash-es';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
type ImageMetadataOverlayProps = {
|
||||
image: Image;
|
||||
};
|
||||
|
||||
const ImageMetadataOverlay = ({ image }: ImageMetadataOverlayProps) => {
|
||||
const dimensions = useMemo(() => {
|
||||
if (!isNumber(image.metadata?.width) || isNumber(!image.metadata?.height)) {
|
||||
return;
|
||||
}
|
||||
|
||||
return `${image.metadata?.width} × ${image.metadata?.height}`;
|
||||
}, [image.metadata]);
|
||||
|
||||
const model = useMemo(() => {
|
||||
if (!isString(image.metadata?.invokeai?.node?.model)) {
|
||||
return;
|
||||
}
|
||||
|
||||
return image.metadata?.invokeai?.node?.model;
|
||||
}, [image.metadata]);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
sx={{
|
||||
pointerEvents: 'none',
|
||||
flexDirection: 'column',
|
||||
position: 'absolute',
|
||||
top: 0,
|
||||
right: 0,
|
||||
p: 2,
|
||||
alignItems: 'flex-end',
|
||||
gap: 2,
|
||||
}}
|
||||
>
|
||||
{dimensions && (
|
||||
<Badge variant="solid" colorScheme="base">
|
||||
{dimensions}
|
||||
</Badge>
|
||||
)}
|
||||
{model && (
|
||||
<Badge variant="solid" colorScheme="base">
|
||||
{model}
|
||||
</Badge>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default ImageMetadataOverlay;
|
@ -7,7 +7,7 @@ import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useCallback } from 'react';
|
||||
import { clearInitialImage } from 'features/parameters/store/generationSlice';
|
||||
|
||||
const ImageToImageSettingsHeader = () => {
|
||||
const InitialImageButtons = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
@ -18,24 +18,19 @@ const ImageToImageSettingsHeader = () => {
|
||||
return (
|
||||
<Flex w="full" alignItems="center">
|
||||
<Text size="sm" fontWeight={500} color="base.300">
|
||||
Image to Image
|
||||
{t('parameters.initialImage')}
|
||||
</Text>
|
||||
<Spacer />
|
||||
<ButtonGroup>
|
||||
<IAIIconButton
|
||||
size="sm"
|
||||
icon={<FaUndo />}
|
||||
aria-label={t('accessibility.reset')}
|
||||
onClick={handleResetInitialImage}
|
||||
/>
|
||||
<IAIIconButton
|
||||
size="sm"
|
||||
icon={<FaUpload />}
|
||||
aria-label={t('common.upload')}
|
||||
/>
|
||||
<IAIIconButton icon={<FaUpload />} aria-label={t('common.upload')} />
|
||||
</ButtonGroup>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default ImageToImageSettingsHeader;
|
||||
export default InitialImageButtons;
|
@ -1,36 +0,0 @@
|
||||
import { Badge, Box, Flex } from '@chakra-ui/react';
|
||||
import { Image } from 'app/types/invokeai';
|
||||
|
||||
type ImageToImageOverlayProps = {
|
||||
image: Image;
|
||||
};
|
||||
|
||||
const ImageToImageOverlay = ({ image }: ImageToImageOverlayProps) => {
|
||||
return (
|
||||
<Box
|
||||
sx={{
|
||||
top: 0,
|
||||
left: 0,
|
||||
w: 'full',
|
||||
h: 'full',
|
||||
position: 'absolute',
|
||||
}}
|
||||
>
|
||||
<Flex
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
top: 0,
|
||||
right: 0,
|
||||
p: 2,
|
||||
alignItems: 'flex-start',
|
||||
}}
|
||||
>
|
||||
<Badge variant="solid" colorScheme="base">
|
||||
{image.metadata?.width} × {image.metadata?.height}
|
||||
</Badge>
|
||||
</Flex>
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
||||
export default ImageToImageOverlay;
|
@ -49,7 +49,7 @@ const ImageUploader = (props: ImageUploaderProps) => {
|
||||
|
||||
const fileAcceptedCallback = useCallback(
|
||||
async (file: File) => {
|
||||
dispatch(imageUploaded({ formData: { file } }));
|
||||
dispatch(imageUploaded({ imageType: 'uploads', formData: { file } }));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
@ -124,7 +124,7 @@ const ImageUploader = (props: ImageUploaderProps) => {
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(imageUploaded({ formData: { file } }));
|
||||
dispatch(imageUploaded({ imageType: 'uploads', formData: { file } }));
|
||||
};
|
||||
document.addEventListener('paste', pasteImageListener);
|
||||
return () => {
|
||||
|
@ -7,7 +7,7 @@ const SelectImagePlaceholder = () => {
|
||||
sx={{
|
||||
w: 'full',
|
||||
h: 'full',
|
||||
bg: 'base.800',
|
||||
// bg: 'base.800',
|
||||
borderRadius: 'base',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
|
@ -2,6 +2,13 @@ import { createSelector } from '@reduxjs/toolkit';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
|
||||
import {
|
||||
setActiveTab,
|
||||
toggleGalleryPanel,
|
||||
toggleParametersPanel,
|
||||
togglePinGalleryPanel,
|
||||
togglePinParametersPanel,
|
||||
} from 'features/ui/store/uiSlice';
|
||||
import { isEqual } from 'lodash-es';
|
||||
import { isHotkeyPressed, useHotkeys } from 'react-hotkeys-hook';
|
||||
|
||||
@ -36,4 +43,36 @@ export const useGlobalHotkeys = () => {
|
||||
{ keyup: true, keydown: true },
|
||||
[shift]
|
||||
);
|
||||
|
||||
useHotkeys('o', () => {
|
||||
dispatch(toggleParametersPanel());
|
||||
});
|
||||
|
||||
useHotkeys(['shift+o'], () => {
|
||||
dispatch(togglePinParametersPanel());
|
||||
});
|
||||
|
||||
useHotkeys('g', () => {
|
||||
dispatch(toggleGalleryPanel());
|
||||
});
|
||||
|
||||
useHotkeys(['shift+g'], () => {
|
||||
dispatch(togglePinGalleryPanel());
|
||||
});
|
||||
|
||||
useHotkeys('1', () => {
|
||||
dispatch(setActiveTab('txt2img'));
|
||||
});
|
||||
|
||||
useHotkeys('2', () => {
|
||||
dispatch(setActiveTab('img2img'));
|
||||
});
|
||||
|
||||
useHotkeys('3', () => {
|
||||
dispatch(setActiveTab('unifiedCanvas'));
|
||||
});
|
||||
|
||||
useHotkeys('4', () => {
|
||||
dispatch(setActiveTab('nodes'));
|
||||
});
|
||||
};
|
||||
|
33
invokeai/frontend/web/src/common/util/arrayBuffer.ts
Normal file
33
invokeai/frontend/web/src/common/util/arrayBuffer.ts
Normal file
@ -0,0 +1,33 @@
|
||||
export const getImageDataTransparency = (pixels: Uint8ClampedArray) => {
|
||||
let isFullyTransparent = true;
|
||||
let isPartiallyTransparent = false;
|
||||
const len = pixels.length;
|
||||
let i = 3;
|
||||
for (i; i < len; i += 4) {
|
||||
if (pixels[i] === 255) {
|
||||
isFullyTransparent = false;
|
||||
} else {
|
||||
isPartiallyTransparent = true;
|
||||
}
|
||||
if (!isFullyTransparent && isPartiallyTransparent) {
|
||||
return { isFullyTransparent, isPartiallyTransparent };
|
||||
}
|
||||
}
|
||||
return { isFullyTransparent, isPartiallyTransparent };
|
||||
};
|
||||
|
||||
export const areAnyPixelsBlack = (pixels: Uint8ClampedArray) => {
|
||||
const len = pixels.length;
|
||||
let i = 0;
|
||||
for (i; i < len; ) {
|
||||
if (
|
||||
pixels[i++] === 0 &&
|
||||
pixels[i++] === 0 &&
|
||||
pixels[i++] === 0 &&
|
||||
pixels[i++] === 255
|
||||
) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
};
|
@ -19,6 +19,7 @@ import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||
import openBase64ImageInTab from './openBase64ImageInTab';
|
||||
import randomInt from './randomInt';
|
||||
import { stringToSeedWeightsArray } from './seedWeightPairs';
|
||||
import { getIsImageDataTransparent, getIsImageDataWhite } from './arrayBuffer';
|
||||
|
||||
export type FrontendToBackendParametersConfig = {
|
||||
generationMode: InvokeTabName;
|
||||
@ -256,7 +257,7 @@ export const frontendToBackendParameters = (
|
||||
...boundingBoxDimensions,
|
||||
};
|
||||
|
||||
const maskDataURL = generateMask(
|
||||
const { dataURL: maskDataURL, imageData: maskImageData } = generateMask(
|
||||
isMaskEnabled ? objects.filter(isCanvasMaskLine) : [],
|
||||
boundingBox
|
||||
);
|
||||
@ -287,6 +288,17 @@ export const frontendToBackendParameters = (
|
||||
height: boundingBox.height,
|
||||
});
|
||||
|
||||
const ctx = canvasBaseLayer.getContext();
|
||||
const imageData = ctx.getImageData(
|
||||
boundingBox.x + absPos.x,
|
||||
boundingBox.y + absPos.y,
|
||||
boundingBox.width,
|
||||
boundingBox.height
|
||||
);
|
||||
|
||||
const doesBaseHaveTransparency = getIsImageDataTransparent(imageData);
|
||||
const doesMaskHaveTransparency = getIsImageDataWhite(maskImageData);
|
||||
|
||||
if (enableImageDebugging) {
|
||||
openBase64ImageInTab([
|
||||
{ base64: maskDataURL, caption: 'mask sent as init_mask' },
|
||||
|
@ -34,6 +34,7 @@ import IAICanvasStagingAreaToolbar from './IAICanvasStagingAreaToolbar';
|
||||
import IAICanvasStatusText from './IAICanvasStatusText';
|
||||
import IAICanvasBoundingBox from './IAICanvasToolbar/IAICanvasBoundingBox';
|
||||
import IAICanvasToolPreview from './IAICanvasToolPreview';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
|
||||
const selector = createSelector(
|
||||
[canvasSelector, isStagingSelector],
|
||||
@ -52,6 +53,7 @@ const selector = createSelector(
|
||||
shouldShowIntermediates,
|
||||
shouldShowGrid,
|
||||
shouldRestrictStrokesToBox,
|
||||
shouldAntialias,
|
||||
} = canvas;
|
||||
|
||||
let stageCursor: string | undefined = 'none';
|
||||
@ -80,13 +82,10 @@ const selector = createSelector(
|
||||
tool,
|
||||
isStaging,
|
||||
shouldShowIntermediates,
|
||||
shouldAntialias,
|
||||
};
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
resultEqualityCheck: isEqual,
|
||||
},
|
||||
}
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const ChakraStage = chakra(Stage, {
|
||||
@ -106,6 +105,7 @@ const IAICanvas = () => {
|
||||
tool,
|
||||
isStaging,
|
||||
shouldShowIntermediates,
|
||||
shouldAntialias,
|
||||
} = useAppSelector(selector);
|
||||
useCanvasHotkeys();
|
||||
|
||||
@ -190,7 +190,7 @@ const IAICanvas = () => {
|
||||
id="base"
|
||||
ref={canvasBaseLayerRefCallback}
|
||||
listening={false}
|
||||
imageSmoothingEnabled={false}
|
||||
imageSmoothingEnabled={shouldAntialias}
|
||||
>
|
||||
<IAICanvasObjectRenderer />
|
||||
</Layer>
|
||||
@ -201,7 +201,7 @@ const IAICanvas = () => {
|
||||
<Layer>
|
||||
<IAICanvasBoundingBoxOverlay />
|
||||
</Layer>
|
||||
<Layer id="preview" imageSmoothingEnabled={false}>
|
||||
<Layer id="preview" imageSmoothingEnabled={shouldAntialias}>
|
||||
{!isStaging && (
|
||||
<IAICanvasToolPreview
|
||||
visible={tool !== 'move'}
|
||||
|
@ -12,18 +12,20 @@ const selector = createSelector(
|
||||
[canvasSelector],
|
||||
(canvas) => {
|
||||
const {
|
||||
layerState: {
|
||||
stagingArea: { images, selectedImageIndex },
|
||||
},
|
||||
layerState,
|
||||
shouldShowStagingImage,
|
||||
shouldShowStagingOutline,
|
||||
boundingBoxCoordinates: { x, y },
|
||||
boundingBoxDimensions: { width, height },
|
||||
} = canvas;
|
||||
|
||||
const { selectedImageIndex, images } = layerState.stagingArea;
|
||||
|
||||
return {
|
||||
currentStagingAreaImage:
|
||||
images.length > 0 ? images[selectedImageIndex] : undefined,
|
||||
images.length > 0 && selectedImageIndex !== undefined
|
||||
? images[selectedImageIndex]
|
||||
: undefined,
|
||||
isOnFirstImage: selectedImageIndex === 0,
|
||||
isOnLastImage: selectedImageIndex === images.length - 1,
|
||||
shouldShowStagingImage,
|
||||
|
@ -6,6 +6,7 @@ import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import IAIPopover from 'common/components/IAIPopover';
|
||||
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
|
||||
import {
|
||||
setShouldAntialias,
|
||||
setShouldAutoSave,
|
||||
setShouldCropToBoundingBoxOnSave,
|
||||
setShouldDarkenOutsideBoundingBox,
|
||||
@ -36,6 +37,7 @@ export const canvasControlsSelector = createSelector(
|
||||
shouldShowIntermediates,
|
||||
shouldSnapToGrid,
|
||||
shouldRestrictStrokesToBox,
|
||||
shouldAntialias,
|
||||
} = canvas;
|
||||
|
||||
return {
|
||||
@ -47,6 +49,7 @@ export const canvasControlsSelector = createSelector(
|
||||
shouldShowIntermediates,
|
||||
shouldSnapToGrid,
|
||||
shouldRestrictStrokesToBox,
|
||||
shouldAntialias,
|
||||
};
|
||||
},
|
||||
{
|
||||
@ -69,6 +72,7 @@ const IAICanvasSettingsButtonPopover = () => {
|
||||
shouldShowIntermediates,
|
||||
shouldSnapToGrid,
|
||||
shouldRestrictStrokesToBox,
|
||||
shouldAntialias,
|
||||
} = useAppSelector(canvasControlsSelector);
|
||||
|
||||
useHotkeys(
|
||||
@ -148,6 +152,12 @@ const IAICanvasSettingsButtonPopover = () => {
|
||||
dispatch(setShouldShowCanvasDebugInfo(e.target.checked))
|
||||
}
|
||||
/>
|
||||
|
||||
<IAICheckbox
|
||||
label={t('unifiedCanvas.antialiasing')}
|
||||
isChecked={shouldAntialias}
|
||||
onChange={(e) => dispatch(setShouldAntialias(e.target.checked))}
|
||||
/>
|
||||
<ClearCanvasHistoryButtonModal />
|
||||
<EmptyTempFolderButtonModal />
|
||||
</Flex>
|
||||
|
@ -9,6 +9,12 @@ const itemsToDenylist: (keyof CanvasState)[] = [
|
||||
'doesCanvasNeedScaling',
|
||||
];
|
||||
|
||||
export const canvasPersistDenylist: (keyof CanvasState)[] = [
|
||||
'cursorPosition',
|
||||
'isCanvasInitialized',
|
||||
'doesCanvasNeedScaling',
|
||||
];
|
||||
|
||||
export const canvasDenylist = itemsToDenylist.map(
|
||||
(denylistItem) => `canvas.${denylistItem}`
|
||||
);
|
||||
|
@ -38,7 +38,7 @@ export const initialLayerState: CanvasLayerState = {
|
||||
},
|
||||
};
|
||||
|
||||
const initialCanvasState: CanvasState = {
|
||||
export const initialCanvasState: CanvasState = {
|
||||
boundingBoxCoordinates: { x: 0, y: 0 },
|
||||
boundingBoxDimensions: { width: 512, height: 512 },
|
||||
boundingBoxPreviewFill: { r: 0, g: 0, b: 0, a: 0.5 },
|
||||
@ -66,6 +66,7 @@ const initialCanvasState: CanvasState = {
|
||||
minimumStageScale: 1,
|
||||
pastLayerStates: [],
|
||||
scaledBoundingBoxDimensions: { width: 512, height: 512 },
|
||||
shouldAntialias: true,
|
||||
shouldAutoSave: false,
|
||||
shouldCropToBoundingBoxOnSave: false,
|
||||
shouldDarkenOutsideBoundingBox: false,
|
||||
@ -156,22 +157,20 @@ export const canvasSlice = createSlice({
|
||||
setCursorPosition: (state, action: PayloadAction<Vector2d | null>) => {
|
||||
state.cursorPosition = action.payload;
|
||||
},
|
||||
setInitialCanvasImage: (state, action: PayloadAction<InvokeAI._Image>) => {
|
||||
setInitialCanvasImage: (state, action: PayloadAction<InvokeAI.Image>) => {
|
||||
const image = action.payload;
|
||||
const { width, height } = image.metadata;
|
||||
const { stageDimensions } = state;
|
||||
|
||||
const newBoundingBoxDimensions = {
|
||||
width: roundDownToMultiple(clamp(image.width, 64, 512), 64),
|
||||
height: roundDownToMultiple(clamp(image.height, 64, 512), 64),
|
||||
width: roundDownToMultiple(clamp(width, 64, 512), 64),
|
||||
height: roundDownToMultiple(clamp(height, 64, 512), 64),
|
||||
};
|
||||
|
||||
const newBoundingBoxCoordinates = {
|
||||
x: roundToMultiple(
|
||||
image.width / 2 - newBoundingBoxDimensions.width / 2,
|
||||
64
|
||||
),
|
||||
x: roundToMultiple(width / 2 - newBoundingBoxDimensions.width / 2, 64),
|
||||
y: roundToMultiple(
|
||||
image.height / 2 - newBoundingBoxDimensions.height / 2,
|
||||
height / 2 - newBoundingBoxDimensions.height / 2,
|
||||
64
|
||||
),
|
||||
};
|
||||
@ -196,8 +195,8 @@ export const canvasSlice = createSlice({
|
||||
layer: 'base',
|
||||
x: 0,
|
||||
y: 0,
|
||||
width: image.width,
|
||||
height: image.height,
|
||||
width: width,
|
||||
height: height,
|
||||
image: image,
|
||||
},
|
||||
],
|
||||
@ -208,8 +207,8 @@ export const canvasSlice = createSlice({
|
||||
const newScale = calculateScale(
|
||||
stageDimensions.width,
|
||||
stageDimensions.height,
|
||||
image.width,
|
||||
image.height,
|
||||
width,
|
||||
height,
|
||||
STAGE_PADDING_PERCENTAGE
|
||||
);
|
||||
|
||||
@ -218,8 +217,8 @@ export const canvasSlice = createSlice({
|
||||
stageDimensions.height,
|
||||
0,
|
||||
0,
|
||||
image.width,
|
||||
image.height,
|
||||
width,
|
||||
height,
|
||||
newScale
|
||||
);
|
||||
state.stageScale = newScale;
|
||||
@ -287,16 +286,28 @@ export const canvasSlice = createSlice({
|
||||
setIsMoveStageKeyHeld: (state, action: PayloadAction<boolean>) => {
|
||||
state.isMoveStageKeyHeld = action.payload;
|
||||
},
|
||||
addImageToStagingArea: (
|
||||
canvasSessionIdChanged: (state, action: PayloadAction<string>) => {
|
||||
state.layerState.stagingArea.sessionId = action.payload;
|
||||
},
|
||||
stagingAreaInitialized: (
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
boundingBox: IRect;
|
||||
image: InvokeAI._Image;
|
||||
}>
|
||||
action: PayloadAction<{ sessionId: string; boundingBox: IRect }>
|
||||
) => {
|
||||
const { boundingBox, image } = action.payload;
|
||||
const { sessionId, boundingBox } = action.payload;
|
||||
|
||||
if (!boundingBox || !image) return;
|
||||
state.layerState.stagingArea = {
|
||||
boundingBox,
|
||||
sessionId,
|
||||
images: [],
|
||||
selectedImageIndex: -1,
|
||||
};
|
||||
},
|
||||
addImageToStagingArea: (state, action: PayloadAction<InvokeAI.Image>) => {
|
||||
const image = action.payload;
|
||||
|
||||
if (!image || !state.layerState.stagingArea.boundingBox) {
|
||||
return;
|
||||
}
|
||||
|
||||
state.pastLayerStates.push(cloneDeep(state.layerState));
|
||||
|
||||
@ -307,7 +318,7 @@ export const canvasSlice = createSlice({
|
||||
state.layerState.stagingArea.images.push({
|
||||
kind: 'image',
|
||||
layer: 'base',
|
||||
...boundingBox,
|
||||
...state.layerState.stagingArea.boundingBox,
|
||||
image,
|
||||
});
|
||||
|
||||
@ -323,9 +334,7 @@ export const canvasSlice = createSlice({
|
||||
state.pastLayerStates.shift();
|
||||
}
|
||||
|
||||
state.layerState.stagingArea = {
|
||||
...initialLayerState.stagingArea,
|
||||
};
|
||||
state.layerState.stagingArea = { ...initialLayerState.stagingArea };
|
||||
|
||||
state.futureLayerStates = [];
|
||||
state.shouldShowStagingOutline = true;
|
||||
@ -663,6 +672,10 @@ export const canvasSlice = createSlice({
|
||||
}
|
||||
},
|
||||
nextStagingAreaImage: (state) => {
|
||||
if (!state.layerState.stagingArea.images.length) {
|
||||
return;
|
||||
}
|
||||
|
||||
const currentIndex = state.layerState.stagingArea.selectedImageIndex;
|
||||
const length = state.layerState.stagingArea.images.length;
|
||||
|
||||
@ -672,6 +685,10 @@ export const canvasSlice = createSlice({
|
||||
);
|
||||
},
|
||||
prevStagingAreaImage: (state) => {
|
||||
if (!state.layerState.stagingArea.images.length) {
|
||||
return;
|
||||
}
|
||||
|
||||
const currentIndex = state.layerState.stagingArea.selectedImageIndex;
|
||||
|
||||
state.layerState.stagingArea.selectedImageIndex = Math.max(
|
||||
@ -680,6 +697,10 @@ export const canvasSlice = createSlice({
|
||||
);
|
||||
},
|
||||
commitStagingAreaImage: (state) => {
|
||||
if (!state.layerState.stagingArea.images.length) {
|
||||
return;
|
||||
}
|
||||
|
||||
const { images, selectedImageIndex } = state.layerState.stagingArea;
|
||||
|
||||
state.pastLayerStates.push(cloneDeep(state.layerState));
|
||||
@ -776,6 +797,9 @@ export const canvasSlice = createSlice({
|
||||
setShouldRestrictStrokesToBox: (state, action: PayloadAction<boolean>) => {
|
||||
state.shouldRestrictStrokesToBox = action.payload;
|
||||
},
|
||||
setShouldAntialias: (state, action: PayloadAction<boolean>) => {
|
||||
state.shouldAntialias = action.payload;
|
||||
},
|
||||
setShouldCropToBoundingBoxOnSave: (
|
||||
state,
|
||||
action: PayloadAction<boolean>
|
||||
@ -885,6 +909,9 @@ export const {
|
||||
undo,
|
||||
setScaledBoundingBoxDimensions,
|
||||
setShouldRestrictStrokesToBox,
|
||||
stagingAreaInitialized,
|
||||
canvasSessionIdChanged,
|
||||
setShouldAntialias,
|
||||
} = canvasSlice.actions;
|
||||
|
||||
export default canvasSlice.reducer;
|
||||
|
@ -37,7 +37,7 @@ export type CanvasImage = {
|
||||
y: number;
|
||||
width: number;
|
||||
height: number;
|
||||
image: InvokeAI._Image;
|
||||
image: InvokeAI.Image;
|
||||
};
|
||||
|
||||
export type CanvasMaskLine = {
|
||||
@ -90,9 +90,16 @@ export type CanvasLayerState = {
|
||||
stagingArea: {
|
||||
images: CanvasImage[];
|
||||
selectedImageIndex: number;
|
||||
sessionId?: string;
|
||||
boundingBox?: IRect;
|
||||
};
|
||||
};
|
||||
|
||||
export type CanvasSession = {
|
||||
sessionId: string;
|
||||
boundingBox: IRect;
|
||||
};
|
||||
|
||||
// type guards
|
||||
export const isCanvasMaskLine = (obj: CanvasObject): obj is CanvasMaskLine =>
|
||||
obj.kind === 'line' && obj.layer === 'mask';
|
||||
@ -125,7 +132,7 @@ export interface CanvasState {
|
||||
cursorPosition: Vector2d | null;
|
||||
doesCanvasNeedScaling: boolean;
|
||||
futureLayerStates: CanvasLayerState[];
|
||||
intermediateImage?: InvokeAI._Image;
|
||||
intermediateImage?: InvokeAI.Image;
|
||||
isCanvasInitialized: boolean;
|
||||
isDrawing: boolean;
|
||||
isMaskEnabled: boolean;
|
||||
@ -142,6 +149,7 @@ export interface CanvasState {
|
||||
minimumStageScale: number;
|
||||
pastLayerStates: CanvasLayerState[];
|
||||
scaledBoundingBoxDimensions: Dimensions;
|
||||
shouldAntialias: boolean;
|
||||
shouldAutoSave: boolean;
|
||||
shouldCropToBoundingBoxOnSave: boolean;
|
||||
shouldDarkenOutsideBoundingBox: boolean;
|
||||
|
@ -0,0 +1,13 @@
|
||||
/**
|
||||
* Gets a Blob from a canvas.
|
||||
*/
|
||||
export const canvasToBlob = async (canvas: HTMLCanvasElement): Promise<Blob> =>
|
||||
new Promise((resolve, reject) => {
|
||||
canvas.toBlob((blob) => {
|
||||
if (blob) {
|
||||
resolve(blob);
|
||||
return;
|
||||
}
|
||||
reject('Unable to create Blob');
|
||||
});
|
||||
});
|
@ -0,0 +1,29 @@
|
||||
/**
|
||||
* Gets an ImageData object from an image dataURL by drawing it to a canvas.
|
||||
*/
|
||||
export const dataURLToImageData = async (
|
||||
dataURL: string,
|
||||
width: number,
|
||||
height: number
|
||||
): Promise<ImageData> =>
|
||||
new Promise((resolve, reject) => {
|
||||
const canvas = document.createElement('canvas');
|
||||
canvas.width = width;
|
||||
canvas.height = height;
|
||||
const ctx = canvas.getContext('2d');
|
||||
const image = new Image();
|
||||
|
||||
if (!ctx) {
|
||||
canvas.remove();
|
||||
reject('Unable to get context');
|
||||
return;
|
||||
}
|
||||
|
||||
image.onload = function () {
|
||||
ctx.drawImage(image, 0, 0);
|
||||
canvas.remove();
|
||||
resolve(ctx.getImageData(0, 0, width, height));
|
||||
};
|
||||
|
||||
image.src = dataURL;
|
||||
});
|
@ -1,6 +1,110 @@
|
||||
// import { CanvasMaskLine } from 'features/canvas/store/canvasTypes';
|
||||
// import Konva from 'konva';
|
||||
// import { Stage } from 'konva/lib/Stage';
|
||||
// import { IRect } from 'konva/lib/types';
|
||||
|
||||
// /**
|
||||
// * Generating a mask image from InpaintingCanvas.tsx is not as simple
|
||||
// * as calling toDataURL() on the canvas, because the mask may be represented
|
||||
// * by colored lines or transparency, or the user may have inverted the mask
|
||||
// * display.
|
||||
// *
|
||||
// * So we need to regenerate the mask image by creating an offscreen canvas,
|
||||
// * drawing the mask and compositing everything correctly to output a valid
|
||||
// * mask image.
|
||||
// */
|
||||
// export const getStageDataURL = (stage: Stage, boundingBox: IRect): string => {
|
||||
// // create an offscreen canvas and add the mask to it
|
||||
// // const { stage, offscreenContainer } = buildMaskStage(lines, boundingBox);
|
||||
|
||||
// const dataURL = stage.toDataURL({ ...boundingBox });
|
||||
|
||||
// // const imageData = stage
|
||||
// // .toCanvas()
|
||||
// // .getContext('2d')
|
||||
// // ?.getImageData(
|
||||
// // boundingBox.x,
|
||||
// // boundingBox.y,
|
||||
// // boundingBox.width,
|
||||
// // boundingBox.height
|
||||
// // );
|
||||
|
||||
// // offscreenContainer.remove();
|
||||
|
||||
// // return { dataURL, imageData };
|
||||
|
||||
// return dataURL;
|
||||
// };
|
||||
|
||||
// export const getStageImageData = (
|
||||
// stage: Stage,
|
||||
// boundingBox: IRect
|
||||
// ): ImageData | undefined => {
|
||||
// const imageData = stage
|
||||
// .toCanvas()
|
||||
// .getContext('2d')
|
||||
// ?.getImageData(
|
||||
// boundingBox.x,
|
||||
// boundingBox.y,
|
||||
// boundingBox.width,
|
||||
// boundingBox.height
|
||||
// );
|
||||
|
||||
// return imageData;
|
||||
// };
|
||||
|
||||
// export const buildMaskStage = (
|
||||
// lines: CanvasMaskLine[],
|
||||
// boundingBox: IRect
|
||||
// ): { stage: Stage; offscreenContainer: HTMLDivElement } => {
|
||||
// // create an offscreen canvas and add the mask to it
|
||||
// const { width, height } = boundingBox;
|
||||
|
||||
// const offscreenContainer = document.createElement('div');
|
||||
|
||||
// const stage = new Konva.Stage({
|
||||
// container: offscreenContainer,
|
||||
// width: width,
|
||||
// height: height,
|
||||
// });
|
||||
|
||||
// const baseLayer = new Konva.Layer();
|
||||
// const maskLayer = new Konva.Layer();
|
||||
|
||||
// // composite the image onto the mask layer
|
||||
// baseLayer.add(
|
||||
// new Konva.Rect({
|
||||
// ...boundingBox,
|
||||
// fill: 'white',
|
||||
// })
|
||||
// );
|
||||
|
||||
// lines.forEach((line) =>
|
||||
// maskLayer.add(
|
||||
// new Konva.Line({
|
||||
// points: line.points,
|
||||
// stroke: 'black',
|
||||
// strokeWidth: line.strokeWidth * 2,
|
||||
// tension: 0,
|
||||
// lineCap: 'round',
|
||||
// lineJoin: 'round',
|
||||
// shadowForStrokeEnabled: false,
|
||||
// globalCompositeOperation:
|
||||
// line.tool === 'brush' ? 'source-over' : 'destination-out',
|
||||
// })
|
||||
// )
|
||||
// );
|
||||
|
||||
// stage.add(baseLayer);
|
||||
// stage.add(maskLayer);
|
||||
|
||||
// return { stage, offscreenContainer };
|
||||
// };
|
||||
|
||||
import { CanvasMaskLine } from 'features/canvas/store/canvasTypes';
|
||||
import Konva from 'konva';
|
||||
import { IRect } from 'konva/lib/types';
|
||||
import { canvasToBlob } from './canvasToBlob';
|
||||
|
||||
/**
|
||||
* Generating a mask image from InpaintingCanvas.tsx is not as simple
|
||||
@ -12,7 +116,7 @@ import { IRect } from 'konva/lib/types';
|
||||
* drawing the mask and compositing everything correctly to output a valid
|
||||
* mask image.
|
||||
*/
|
||||
const generateMask = (lines: CanvasMaskLine[], boundingBox: IRect): string => {
|
||||
const generateMask = async (lines: CanvasMaskLine[], boundingBox: IRect) => {
|
||||
// create an offscreen canvas and add the mask to it
|
||||
const { width, height } = boundingBox;
|
||||
|
||||
@ -54,11 +158,13 @@ const generateMask = (lines: CanvasMaskLine[], boundingBox: IRect): string => {
|
||||
stage.add(baseLayer);
|
||||
stage.add(maskLayer);
|
||||
|
||||
const dataURL = stage.toDataURL({ ...boundingBox });
|
||||
const maskDataURL = stage.toDataURL(boundingBox);
|
||||
|
||||
const maskBlob = await canvasToBlob(stage.toCanvas(boundingBox));
|
||||
|
||||
offscreenContainer.remove();
|
||||
|
||||
return dataURL;
|
||||
return { maskDataURL, maskBlob };
|
||||
};
|
||||
|
||||
export default generateMask;
|
||||
|
128
invokeai/frontend/web/src/features/canvas/util/getCanvasData.ts
Normal file
128
invokeai/frontend/web/src/features/canvas/util/getCanvasData.ts
Normal file
@ -0,0 +1,128 @@
|
||||
import { RootState } from 'app/store/store';
|
||||
import { getCanvasBaseLayer, getCanvasStage } from './konvaInstanceProvider';
|
||||
import { isCanvasMaskLine } from '../store/canvasTypes';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import {
|
||||
areAnyPixelsBlack,
|
||||
getImageDataTransparency,
|
||||
} from 'common/util/arrayBuffer';
|
||||
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
|
||||
import generateMask from './generateMask';
|
||||
import { dataURLToImageData } from './dataURLToImageData';
|
||||
import { canvasToBlob } from './canvasToBlob';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'getCanvasDataURLs' });
|
||||
|
||||
export const getCanvasData = async (state: RootState) => {
|
||||
const canvasBaseLayer = getCanvasBaseLayer();
|
||||
const canvasStage = getCanvasStage();
|
||||
|
||||
if (!canvasBaseLayer || !canvasStage) {
|
||||
moduleLog.error('Unable to find canvas / stage');
|
||||
return;
|
||||
}
|
||||
|
||||
const {
|
||||
layerState: { objects },
|
||||
boundingBoxCoordinates,
|
||||
boundingBoxDimensions,
|
||||
stageScale,
|
||||
isMaskEnabled,
|
||||
shouldPreserveMaskedArea,
|
||||
boundingBoxScaleMethod: boundingBoxScale,
|
||||
scaledBoundingBoxDimensions,
|
||||
} = state.canvas;
|
||||
|
||||
const boundingBox = {
|
||||
...boundingBoxCoordinates,
|
||||
...boundingBoxDimensions,
|
||||
};
|
||||
|
||||
// generationParameters.fit = false;
|
||||
|
||||
// generationParameters.strength = img2imgStrength;
|
||||
|
||||
// generationParameters.invert_mask = shouldPreserveMaskedArea;
|
||||
|
||||
// generationParameters.bounding_box = boundingBox;
|
||||
|
||||
const tempScale = canvasBaseLayer.scale();
|
||||
|
||||
canvasBaseLayer.scale({
|
||||
x: 1 / stageScale,
|
||||
y: 1 / stageScale,
|
||||
});
|
||||
|
||||
const absPos = canvasBaseLayer.getAbsolutePosition();
|
||||
|
||||
const offsetBoundingBox = {
|
||||
x: boundingBox.x + absPos.x,
|
||||
y: boundingBox.y + absPos.y,
|
||||
width: boundingBox.width,
|
||||
height: boundingBox.height,
|
||||
};
|
||||
|
||||
const baseDataURL = canvasBaseLayer.toDataURL(offsetBoundingBox);
|
||||
const baseBlob = await canvasToBlob(
|
||||
canvasBaseLayer.toCanvas(offsetBoundingBox)
|
||||
);
|
||||
|
||||
canvasBaseLayer.scale(tempScale);
|
||||
|
||||
const { maskDataURL, maskBlob } = await generateMask(
|
||||
isMaskEnabled ? objects.filter(isCanvasMaskLine) : [],
|
||||
boundingBox
|
||||
);
|
||||
|
||||
const baseImageData = await dataURLToImageData(
|
||||
baseDataURL,
|
||||
boundingBox.width,
|
||||
boundingBox.height
|
||||
);
|
||||
|
||||
const maskImageData = await dataURLToImageData(
|
||||
maskDataURL,
|
||||
boundingBox.width,
|
||||
boundingBox.height
|
||||
);
|
||||
|
||||
const {
|
||||
isPartiallyTransparent: baseIsPartiallyTransparent,
|
||||
isFullyTransparent: baseIsFullyTransparent,
|
||||
} = getImageDataTransparency(baseImageData.data);
|
||||
|
||||
const doesMaskHaveBlackPixels = areAnyPixelsBlack(maskImageData.data);
|
||||
|
||||
if (state.system.enableImageDebugging) {
|
||||
openBase64ImageInTab([
|
||||
{ base64: maskDataURL, caption: 'mask b64' },
|
||||
{ base64: baseDataURL, caption: 'image b64' },
|
||||
]);
|
||||
}
|
||||
|
||||
// generationParameters.init_img = imageDataURL;
|
||||
// generationParameters.progress_images = false;
|
||||
|
||||
// if (boundingBoxScale !== 'none') {
|
||||
// generationParameters.inpaint_width = scaledBoundingBoxDimensions.width;
|
||||
// generationParameters.inpaint_height = scaledBoundingBoxDimensions.height;
|
||||
// }
|
||||
|
||||
// generationParameters.seam_size = seamSize;
|
||||
// generationParameters.seam_blur = seamBlur;
|
||||
// generationParameters.seam_strength = seamStrength;
|
||||
// generationParameters.seam_steps = seamSteps;
|
||||
// generationParameters.tile_size = tileSize;
|
||||
// generationParameters.infill_method = infillMethod;
|
||||
// generationParameters.force_outpaint = false;
|
||||
|
||||
return {
|
||||
baseDataURL,
|
||||
baseBlob,
|
||||
maskDataURL,
|
||||
maskBlob,
|
||||
baseIsPartiallyTransparent,
|
||||
baseIsFullyTransparent,
|
||||
doesMaskHaveBlackPixels,
|
||||
};
|
||||
};
|
@ -1,12 +1,17 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { get, isEqual, isNumber, isString } from 'lodash-es';
|
||||
import { isEqual, isString } from 'lodash-es';
|
||||
|
||||
import {
|
||||
ButtonGroup,
|
||||
Flex,
|
||||
FlexProps,
|
||||
FormControl,
|
||||
IconButton,
|
||||
Link,
|
||||
Menu,
|
||||
MenuButton,
|
||||
MenuItemOption,
|
||||
MenuList,
|
||||
MenuOptionGroup,
|
||||
useDisclosure,
|
||||
useToast,
|
||||
} from '@chakra-ui/react';
|
||||
@ -15,21 +20,12 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import IAIPopover from 'common/components/IAIPopover';
|
||||
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
|
||||
import { GalleryState } from 'features/gallery/store/gallerySlice';
|
||||
|
||||
import { lightboxSelector } from 'features/lightbox/store/lightboxSelectors';
|
||||
import { setIsLightboxOpen } from 'features/lightbox/store/lightboxSlice';
|
||||
import FaceRestoreSettings from 'features/parameters/components/AdvancedParameters/FaceRestore/FaceRestoreSettings';
|
||||
import UpscaleSettings from 'features/parameters/components/AdvancedParameters/Upscale/UpscaleSettings';
|
||||
import {
|
||||
initialImageSelected,
|
||||
setAllParameters,
|
||||
// setInitialImage,
|
||||
setSeed,
|
||||
} from 'features/parameters/store/generationSlice';
|
||||
import { postprocessingSelector } from 'features/parameters/store/postprocessingSelectors';
|
||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
||||
import { SystemState } from 'features/system/store/systemSlice';
|
||||
|
||||
import {
|
||||
activeTabNameSelector,
|
||||
uiSelector,
|
||||
@ -56,6 +52,7 @@ import {
|
||||
FaShare,
|
||||
FaShareAlt,
|
||||
FaTrash,
|
||||
FaWrench,
|
||||
} from 'react-icons/fa';
|
||||
import {
|
||||
gallerySelector,
|
||||
@ -66,8 +63,13 @@ import { useCallback } from 'react';
|
||||
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
||||
import { useGetUrl } from 'common/util/getUrl';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
import { imageDeleted } from 'services/thunks/image';
|
||||
import { useParameters } from 'features/parameters/hooks/useParameters';
|
||||
import { initialImageSelected } from 'features/parameters/store/actions';
|
||||
import { requestedImageDeletion } from '../store/actions';
|
||||
import FaceRestoreSettings from 'features/parameters/components/Parameters/FaceRestore/FaceRestoreSettings';
|
||||
import UpscaleSettings from 'features/parameters/components/Parameters/Upscale/UpscaleSettings';
|
||||
import { allParametersSet } from 'features/parameters/store/generationSlice';
|
||||
import DeleteImageButton from './ImageActionButtons/DeleteImageButton';
|
||||
|
||||
const currentImageButtonsSelector = createSelector(
|
||||
[
|
||||
@ -150,6 +152,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
} = useAppSelector(currentImageButtonsSelector);
|
||||
|
||||
const isLightboxEnabled = useFeatureStatus('lightbox').isFeatureEnabled;
|
||||
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
|
||||
const isUpscalingEnabled = useFeatureStatus('upscaling').isFeatureEnabled;
|
||||
const isFaceRestoreEnabled = useFeatureStatus('faceRestore').isFeatureEnabled;
|
||||
|
||||
@ -164,40 +167,59 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
const toast = useToast();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const { recallPrompt, recallSeed, sendToImageToImage } = useParameters();
|
||||
const { recallPrompt, recallSeed, recallAllParameters } = useParameters();
|
||||
|
||||
const handleCopyImage = useCallback(async () => {
|
||||
if (!image?.url) {
|
||||
// const handleCopyImage = useCallback(async () => {
|
||||
// if (!image?.url) {
|
||||
// return;
|
||||
// }
|
||||
|
||||
// const url = getUrl(image.url);
|
||||
|
||||
// if (!url) {
|
||||
// return;
|
||||
// }
|
||||
|
||||
// const blob = await fetch(url).then((res) => res.blob());
|
||||
// const data = [new ClipboardItem({ [blob.type]: blob })];
|
||||
|
||||
// await navigator.clipboard.write(data);
|
||||
|
||||
// toast({
|
||||
// title: t('toast.imageCopied'),
|
||||
// status: 'success',
|
||||
// duration: 2500,
|
||||
// isClosable: true,
|
||||
// });
|
||||
// }, [getUrl, t, image?.url, toast]);
|
||||
|
||||
const handleCopyImageLink = useCallback(() => {
|
||||
const getImageUrl = () => {
|
||||
if (!image) {
|
||||
return;
|
||||
}
|
||||
|
||||
const url = getUrl(image.url);
|
||||
if (shouldTransformUrls) {
|
||||
return getUrl(image.url);
|
||||
}
|
||||
|
||||
if (image.url.startsWith('http')) {
|
||||
return image.url;
|
||||
}
|
||||
|
||||
return window.location.toString() + image.url;
|
||||
};
|
||||
|
||||
const url = getImageUrl();
|
||||
|
||||
if (!url) {
|
||||
return;
|
||||
}
|
||||
|
||||
const blob = await fetch(url).then((res) => res.blob());
|
||||
const data = [new ClipboardItem({ [blob.type]: blob })];
|
||||
|
||||
await navigator.clipboard.write(data);
|
||||
|
||||
toast({
|
||||
title: t('toast.imageCopied'),
|
||||
status: 'success',
|
||||
title: t('toast.problemCopyingImageLink'),
|
||||
status: 'error',
|
||||
duration: 2500,
|
||||
isClosable: true,
|
||||
});
|
||||
}, [getUrl, t, image?.url, toast]);
|
||||
|
||||
const handleCopyImageLink = useCallback(() => {
|
||||
const url = image
|
||||
? shouldTransformUrls
|
||||
? getUrl(image.url)
|
||||
: window.location.toString() + image.url
|
||||
: '';
|
||||
|
||||
if (!url) {
|
||||
return;
|
||||
}
|
||||
|
||||
@ -216,39 +238,15 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
}, [dispatch, shouldHidePreview]);
|
||||
|
||||
const handleClickUseAllParameters = useCallback(() => {
|
||||
if (!image) return;
|
||||
// selectedImage.metadata &&
|
||||
// dispatch(setAllParameters(selectedImage.metadata));
|
||||
// if (selectedImage.metadata?.image.type === 'img2img') {
|
||||
// dispatch(setActiveTab('img2img'));
|
||||
// } else if (selectedImage.metadata?.image.type === 'txt2img') {
|
||||
// dispatch(setActiveTab('txt2img'));
|
||||
// }
|
||||
}, [image]);
|
||||
recallAllParameters(image);
|
||||
}, [image, recallAllParameters]);
|
||||
|
||||
useHotkeys(
|
||||
'a',
|
||||
() => {
|
||||
const type = image?.metadata?.invokeai?.node?.types;
|
||||
if (isString(type) && ['txt2img', 'img2img'].includes(type)) {
|
||||
handleClickUseAllParameters();
|
||||
toast({
|
||||
title: t('toast.parametersSet'),
|
||||
status: 'success',
|
||||
duration: 2500,
|
||||
isClosable: true,
|
||||
});
|
||||
} else {
|
||||
toast({
|
||||
title: t('toast.parametersNotSet'),
|
||||
description: t('toast.parametersNotSetDesc'),
|
||||
status: 'error',
|
||||
duration: 2500,
|
||||
isClosable: true,
|
||||
});
|
||||
}
|
||||
handleClickUseAllParameters;
|
||||
},
|
||||
[image]
|
||||
[image, recallAllParameters]
|
||||
);
|
||||
|
||||
const handleUseSeed = useCallback(() => {
|
||||
@ -264,8 +262,8 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
useHotkeys('p', handleUsePrompt, [image]);
|
||||
|
||||
const handleSendToImageToImage = useCallback(() => {
|
||||
sendToImageToImage(image);
|
||||
}, [image, sendToImageToImage]);
|
||||
dispatch(initialImageSelected(image));
|
||||
}, [dispatch, image]);
|
||||
|
||||
useHotkeys('shift+i', handleSendToImageToImage, [image]);
|
||||
|
||||
@ -375,7 +373,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
|
||||
const handleDelete = useCallback(() => {
|
||||
if (canDeleteImage && image) {
|
||||
dispatch(imageDeleted({ imageType: image.type, imageName: image.name }));
|
||||
dispatch(requestedImageDeletion(image));
|
||||
}
|
||||
}, [image, canDeleteImage, dispatch]);
|
||||
|
||||
@ -432,6 +430,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
>
|
||||
{t('parameters.sendToImg2Img')}
|
||||
</IAIButton>
|
||||
{isCanvasEnabled && (
|
||||
<IAIButton
|
||||
size="sm"
|
||||
onClick={handleSendToCanvas}
|
||||
@ -439,14 +438,15 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
>
|
||||
{t('parameters.sendToUnifiedCanvas')}
|
||||
</IAIButton>
|
||||
)}
|
||||
|
||||
<IAIButton
|
||||
{/* <IAIButton
|
||||
size="sm"
|
||||
onClick={handleCopyImage}
|
||||
leftIcon={<FaCopy />}
|
||||
>
|
||||
{t('parameters.copyImage')}
|
||||
</IAIButton>
|
||||
</IAIButton> */}
|
||||
<IAIButton
|
||||
size="sm"
|
||||
onClick={handleCopyImageLink}
|
||||
@ -462,7 +462,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
</Link>
|
||||
</Flex>
|
||||
</IAIPopover>
|
||||
<IAIIconButton
|
||||
{/* <IAIIconButton
|
||||
icon={shouldHidePreview ? <FaEyeSlash /> : <FaEye />}
|
||||
tooltip={
|
||||
!shouldHidePreview
|
||||
@ -476,7 +476,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
}
|
||||
isChecked={shouldHidePreview}
|
||||
onClick={handlePreviewVisibility}
|
||||
/>
|
||||
/> */}
|
||||
{isLightboxEnabled && (
|
||||
<IAIIconButton
|
||||
icon={<FaExpand />}
|
||||
@ -518,8 +518,8 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
tooltip={`${t('parameters.useAll')} (A)`}
|
||||
aria-label={`${t('parameters.useAll')} (A)`}
|
||||
isDisabled={
|
||||
!['txt2img', 'img2img'].includes(
|
||||
image?.metadata?.sd_metadata?.type
|
||||
!['txt2img', 'img2img', 'inpaint'].includes(
|
||||
String(image?.metadata?.invokeai?.node?.type)
|
||||
)
|
||||
}
|
||||
onClick={handleClickUseAllParameters}
|
||||
@ -602,22 +602,10 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
/>
|
||||
</ButtonGroup>
|
||||
|
||||
<IAIIconButton
|
||||
onClick={handleInitiateDelete}
|
||||
icon={<FaTrash />}
|
||||
tooltip={`${t('gallery.deleteImage')} (Del)`}
|
||||
aria-label={`${t('gallery.deleteImage')} (Del)`}
|
||||
isDisabled={!image || !isConnected}
|
||||
colorScheme="error"
|
||||
/>
|
||||
<ButtonGroup isAttached={true}>
|
||||
<DeleteImageButton image={image} />
|
||||
</ButtonGroup>
|
||||
</Flex>
|
||||
{image && (
|
||||
<DeleteImageModal
|
||||
isOpen={isDeleteDialogOpen}
|
||||
onClose={onDeleteDialogClose}
|
||||
handleDelete={handleDelete}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
@ -2,26 +2,35 @@ import { Box, Flex, Image } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useGetUrl } from 'common/util/getUrl';
|
||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
||||
import { uiSelector } from 'features/ui/store/uiSelectors';
|
||||
import { isEqual } from 'lodash-es';
|
||||
|
||||
import { selectedImageSelector } from '../store/gallerySelectors';
|
||||
import CurrentImageFallback from './CurrentImageFallback';
|
||||
import { gallerySelector } from '../store/gallerySelectors';
|
||||
import ImageMetadataViewer from './ImageMetaDataViewer/ImageMetadataViewer';
|
||||
import NextPrevImageButtons from './NextPrevImageButtons';
|
||||
import CurrentImageHidden from './CurrentImageHidden';
|
||||
import { memo } from 'react';
|
||||
import { DragEvent, memo, useCallback } from 'react';
|
||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
||||
import ImageFallbackSpinner from './ImageFallbackSpinner';
|
||||
import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay';
|
||||
|
||||
export const imagesSelector = createSelector(
|
||||
[uiSelector, selectedImageSelector, systemSelector],
|
||||
(ui, selectedImage, system) => {
|
||||
const { shouldShowImageDetails, shouldHidePreview } = ui;
|
||||
|
||||
[uiSelector, gallerySelector, systemSelector],
|
||||
(ui, gallery, system) => {
|
||||
const {
|
||||
shouldShowImageDetails,
|
||||
shouldHidePreview,
|
||||
shouldShowProgressInViewer,
|
||||
} = ui;
|
||||
const { selectedImage } = gallery;
|
||||
const { progressImage, shouldAntialiasProgressImage } = system;
|
||||
return {
|
||||
shouldShowImageDetails,
|
||||
shouldHidePreview,
|
||||
image: selectedImage,
|
||||
progressImage,
|
||||
shouldShowProgressInViewer,
|
||||
shouldAntialiasProgressImage,
|
||||
};
|
||||
},
|
||||
{
|
||||
@ -32,26 +41,61 @@ export const imagesSelector = createSelector(
|
||||
);
|
||||
|
||||
const CurrentImagePreview = () => {
|
||||
const { shouldShowImageDetails, image, shouldHidePreview } =
|
||||
useAppSelector(imagesSelector);
|
||||
const {
|
||||
shouldShowImageDetails,
|
||||
image,
|
||||
shouldHidePreview,
|
||||
progressImage,
|
||||
shouldShowProgressInViewer,
|
||||
shouldAntialiasProgressImage,
|
||||
} = useAppSelector(imagesSelector);
|
||||
const { getUrl } = useGetUrl();
|
||||
|
||||
const handleDragStart = useCallback(
|
||||
(e: DragEvent<HTMLDivElement>) => {
|
||||
if (!image) {
|
||||
return;
|
||||
}
|
||||
e.dataTransfer.setData('invokeai/imageName', image.name);
|
||||
e.dataTransfer.setData('invokeai/imageType', image.type);
|
||||
e.dataTransfer.effectAllowed = 'move';
|
||||
},
|
||||
[image]
|
||||
);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
sx={{
|
||||
position: 'relative',
|
||||
justifyContent: 'center',
|
||||
alignItems: 'center',
|
||||
width: '100%',
|
||||
height: '100%',
|
||||
position: 'relative',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
}}
|
||||
>
|
||||
{image && (
|
||||
{progressImage && shouldShowProgressInViewer ? (
|
||||
<Image
|
||||
src={shouldHidePreview ? undefined : getUrl(image.url)}
|
||||
width={image.metadata.width}
|
||||
height={image.metadata.height}
|
||||
fallback={shouldHidePreview ? <CurrentImageHidden /> : undefined}
|
||||
src={progressImage.dataURL}
|
||||
width={progressImage.width}
|
||||
height={progressImage.height}
|
||||
sx={{
|
||||
objectFit: 'contain',
|
||||
maxWidth: '100%',
|
||||
maxHeight: '100%',
|
||||
height: 'auto',
|
||||
position: 'absolute',
|
||||
borderRadius: 'base',
|
||||
imageRendering: shouldAntialiasProgressImage ? 'auto' : 'pixelated',
|
||||
}}
|
||||
/>
|
||||
) : (
|
||||
image && (
|
||||
<>
|
||||
<Image
|
||||
src={getUrl(image.url)}
|
||||
fallbackStrategy="beforeLoadOrError"
|
||||
fallback={<ImageFallbackSpinner />}
|
||||
onDragStart={handleDragStart}
|
||||
sx={{
|
||||
objectFit: 'contain',
|
||||
maxWidth: '100%',
|
||||
@ -61,6 +105,9 @@ const CurrentImagePreview = () => {
|
||||
borderRadius: 'base',
|
||||
}}
|
||||
/>
|
||||
<ImageMetadataOverlay image={image} />
|
||||
</>
|
||||
)
|
||||
)}
|
||||
{shouldShowImageDetails && image && 'metadata' in image && (
|
||||
<Box
|
||||
|
@ -0,0 +1,70 @@
|
||||
import { Flex, Image, Spinner } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
||||
import { memo } from 'react';
|
||||
import { gallerySelector } from '../store/gallerySelectors';
|
||||
|
||||
const selector = createSelector(
|
||||
[systemSelector, gallerySelector],
|
||||
(system, gallery) => {
|
||||
const { shouldUseSingleGalleryColumn, galleryImageObjectFit } = gallery;
|
||||
const { progressImage, shouldAntialiasProgressImage } = system;
|
||||
|
||||
return {
|
||||
progressImage,
|
||||
shouldUseSingleGalleryColumn,
|
||||
galleryImageObjectFit,
|
||||
shouldAntialiasProgressImage,
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const GalleryProgressImage = () => {
|
||||
const {
|
||||
progressImage,
|
||||
shouldUseSingleGalleryColumn,
|
||||
galleryImageObjectFit,
|
||||
shouldAntialiasProgressImage,
|
||||
} = useAppSelector(selector);
|
||||
|
||||
if (!progressImage) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<Flex
|
||||
sx={{
|
||||
w: 'full',
|
||||
h: 'full',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
aspectRatio: '1/1',
|
||||
position: 'relative',
|
||||
}}
|
||||
>
|
||||
<Image
|
||||
draggable={false}
|
||||
src={progressImage.dataURL}
|
||||
width={progressImage.width}
|
||||
height={progressImage.height}
|
||||
sx={{
|
||||
objectFit: shouldUseSingleGalleryColumn
|
||||
? 'contain'
|
||||
: galleryImageObjectFit,
|
||||
width: '100%',
|
||||
height: '100%',
|
||||
maxWidth: '100%',
|
||||
maxHeight: '100%',
|
||||
borderRadius: 'base',
|
||||
imageRendering: shouldAntialiasProgressImage ? 'auto' : 'pixelated',
|
||||
}}
|
||||
/>
|
||||
<Spinner sx={{ position: 'absolute', top: 1, right: 1, opacity: 0.7 }} />
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(GalleryProgressImage);
|
@ -5,19 +5,20 @@ import {
|
||||
Image,
|
||||
MenuItem,
|
||||
MenuList,
|
||||
Skeleton,
|
||||
useDisclosure,
|
||||
useTheme,
|
||||
useToast,
|
||||
} from '@chakra-ui/react';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||
import { DragEvent, memo, useCallback, useState } from 'react';
|
||||
import { DragEvent, MouseEvent, memo, useCallback, useState } from 'react';
|
||||
import { FaCheck, FaExpand, FaImage, FaShare, FaTrash } from 'react-icons/fa';
|
||||
import DeleteImageModal from './DeleteImageModal';
|
||||
import { ContextMenu } from 'chakra-ui-contextmenu';
|
||||
import * as InvokeAI from 'app/types/invokeai';
|
||||
import { resizeAndScaleCanvas } from 'features/canvas/store/canvasSlice';
|
||||
import {
|
||||
resizeAndScaleCanvas,
|
||||
setInitialCanvasImage,
|
||||
} from 'features/canvas/store/canvasSlice';
|
||||
import { gallerySelector } from 'features/gallery/store/gallerySelectors';
|
||||
import { setActiveTab } from 'features/ui/store/uiSlice';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@ -25,7 +26,6 @@ import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import { useGetUrl } from 'common/util/getUrl';
|
||||
import { ExternalLinkIcon } from '@chakra-ui/icons';
|
||||
import { IoArrowUndoCircleOutline } from 'react-icons/io5';
|
||||
import { imageDeleted } from 'services/thunks/image';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
||||
import { lightboxSelector } from 'features/lightbox/store/lightboxSelectors';
|
||||
@ -33,6 +33,8 @@ import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import { isEqual } from 'lodash-es';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
import { useParameters } from 'features/parameters/hooks/useParameters';
|
||||
import { initialImageSelected } from 'features/parameters/store/actions';
|
||||
import { requestedImageDeletion } from '../store/actions';
|
||||
|
||||
export const selector = createSelector(
|
||||
[gallerySelector, systemSelector, lightboxSelector, activeTabNameSelector],
|
||||
@ -94,16 +96,19 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
} = useDisclosure();
|
||||
|
||||
const { image, isSelected } = props;
|
||||
const { url, thumbnail, name, metadata } = image;
|
||||
const { url, thumbnail, name } = image;
|
||||
const { getUrl } = useGetUrl();
|
||||
|
||||
const [isHovered, setIsHovered] = useState<boolean>(false);
|
||||
|
||||
const toast = useToast();
|
||||
const { direction } = useTheme();
|
||||
|
||||
const { t } = useTranslation();
|
||||
const { isFeatureEnabled: isLightboxEnabled } = useFeatureStatus('lightbox');
|
||||
const { recallSeed, recallPrompt, sendToImageToImage, recallInitialImage } =
|
||||
|
||||
const isLightboxEnabled = useFeatureStatus('lightbox').isFeatureEnabled;
|
||||
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
|
||||
|
||||
const { recallSeed, recallPrompt, recallInitialImage, recallAllParameters } =
|
||||
useParameters();
|
||||
|
||||
const handleMouseOver = () => setIsHovered(true);
|
||||
@ -112,18 +117,22 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
// Immediately deletes an image
|
||||
const handleDelete = useCallback(() => {
|
||||
if (canDeleteImage && image) {
|
||||
dispatch(imageDeleted({ imageType: image.type, imageName: image.name }));
|
||||
dispatch(requestedImageDeletion(image));
|
||||
}
|
||||
}, [dispatch, image, canDeleteImage]);
|
||||
|
||||
// Opens the alert dialog to check if user is sure they want to delete
|
||||
const handleInitiateDelete = useCallback(() => {
|
||||
const handleInitiateDelete = useCallback(
|
||||
(e: MouseEvent) => {
|
||||
e.stopPropagation();
|
||||
if (shouldConfirmOnDelete) {
|
||||
onDeleteDialogOpen();
|
||||
} else {
|
||||
handleDelete();
|
||||
}
|
||||
}, [handleDelete, onDeleteDialogOpen, shouldConfirmOnDelete]);
|
||||
},
|
||||
[handleDelete, onDeleteDialogOpen, shouldConfirmOnDelete]
|
||||
);
|
||||
|
||||
const handleSelectImage = useCallback(() => {
|
||||
dispatch(imageSelected(image));
|
||||
@ -148,8 +157,8 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
}, [image, recallSeed]);
|
||||
|
||||
const handleSendToImageToImage = useCallback(() => {
|
||||
sendToImageToImage(image);
|
||||
}, [image, sendToImageToImage]);
|
||||
dispatch(initialImageSelected(image));
|
||||
}, [dispatch, image]);
|
||||
|
||||
const handleRecallInitialImage = useCallback(() => {
|
||||
recallInitialImage(image.metadata.invokeai?.node?.image);
|
||||
@ -159,7 +168,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
* TODO: the rest of these
|
||||
*/
|
||||
const handleSendToCanvas = () => {
|
||||
// dispatch(setInitialCanvasImage(image));
|
||||
dispatch(setInitialCanvasImage(image));
|
||||
|
||||
dispatch(resizeAndScaleCanvas());
|
||||
|
||||
@ -175,16 +184,9 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
});
|
||||
};
|
||||
|
||||
const handleUseAllParameters = () => {
|
||||
// metadata.invokeai?.node &&
|
||||
// dispatch(setAllParameters(metadata.invokeai?.node));
|
||||
// toast({
|
||||
// title: t('toast.parametersSet'),
|
||||
// status: 'success',
|
||||
// duration: 2500,
|
||||
// isClosable: true,
|
||||
// });
|
||||
};
|
||||
const handleUseAllParameters = useCallback(() => {
|
||||
recallAllParameters(image);
|
||||
}, [image, recallAllParameters]);
|
||||
|
||||
const handleLightBox = () => {
|
||||
// dispatch(setCurrentImage(image));
|
||||
@ -238,7 +240,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
icon={<IoArrowUndoCircleOutline />}
|
||||
onClickCapture={handleUseAllParameters}
|
||||
isDisabled={
|
||||
!['txt2img', 'img2img'].includes(
|
||||
!['txt2img', 'img2img', 'inpaint'].includes(
|
||||
String(image?.metadata?.invokeai?.node?.type)
|
||||
)
|
||||
}
|
||||
@ -251,9 +253,11 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
>
|
||||
{t('parameters.sendToImg2Img')}
|
||||
</MenuItem>
|
||||
{isCanvasEnabled && (
|
||||
<MenuItem icon={<FaShare />} onClickCapture={handleSendToCanvas}>
|
||||
{t('parameters.sendToUnifiedCanvas')}
|
||||
</MenuItem>
|
||||
)}
|
||||
<MenuItem icon={<FaTrash />} onClickCapture={onDeleteDialogOpen}>
|
||||
{t('gallery.deleteImage')}
|
||||
</MenuItem>
|
||||
@ -279,6 +283,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
h: 'full',
|
||||
transition: 'transform 0.2s ease-out',
|
||||
aspectRatio: '1/1',
|
||||
cursor: 'pointer',
|
||||
}}
|
||||
>
|
||||
<Image
|
||||
@ -315,6 +320,8 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
sx={{
|
||||
width: '50%',
|
||||
height: '50%',
|
||||
maxWidth: '4rem',
|
||||
maxHeight: '4rem',
|
||||
fill: 'ok.500',
|
||||
}}
|
||||
/>
|
||||
|
@ -0,0 +1,92 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
|
||||
import { useDisclosure } from '@chakra-ui/react';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
||||
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FaTrash } from 'react-icons/fa';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import DeleteImageModal from '../DeleteImageModal';
|
||||
import { requestedImageDeletion } from 'features/gallery/store/actions';
|
||||
import { Image } from 'app/types/invokeai';
|
||||
|
||||
const selector = createSelector(
|
||||
[systemSelector],
|
||||
(system) => {
|
||||
const { isProcessing, isConnected, shouldConfirmOnDelete } = system;
|
||||
|
||||
return {
|
||||
canDeleteImage: isConnected && !isProcessing,
|
||||
shouldConfirmOnDelete,
|
||||
isProcessing,
|
||||
isConnected,
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
type DeleteImageButtonProps = {
|
||||
image: Image | undefined;
|
||||
};
|
||||
|
||||
const DeleteImageButton = (props: DeleteImageButtonProps) => {
|
||||
const { image } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const { isProcessing, isConnected, canDeleteImage, shouldConfirmOnDelete } =
|
||||
useAppSelector(selector);
|
||||
|
||||
const {
|
||||
isOpen: isDeleteDialogOpen,
|
||||
onOpen: onDeleteDialogOpen,
|
||||
onClose: onDeleteDialogClose,
|
||||
} = useDisclosure();
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
const handleDelete = useCallback(() => {
|
||||
if (canDeleteImage && image) {
|
||||
dispatch(requestedImageDeletion(image));
|
||||
}
|
||||
}, [image, canDeleteImage, dispatch]);
|
||||
|
||||
const handleInitiateDelete = useCallback(() => {
|
||||
if (shouldConfirmOnDelete) {
|
||||
onDeleteDialogOpen();
|
||||
} else {
|
||||
handleDelete();
|
||||
}
|
||||
}, [shouldConfirmOnDelete, onDeleteDialogOpen, handleDelete]);
|
||||
|
||||
useHotkeys('delete', handleInitiateDelete, [
|
||||
image,
|
||||
shouldConfirmOnDelete,
|
||||
isConnected,
|
||||
isProcessing,
|
||||
]);
|
||||
|
||||
return (
|
||||
<>
|
||||
<IAIIconButton
|
||||
onClick={handleInitiateDelete}
|
||||
icon={<FaTrash />}
|
||||
tooltip={`${t('gallery.deleteImage')} (Del)`}
|
||||
aria-label={`${t('gallery.deleteImage')} (Del)`}
|
||||
isDisabled={!image || !isConnected}
|
||||
colorScheme="error"
|
||||
/>
|
||||
{image && (
|
||||
<DeleteImageModal
|
||||
isOpen={isDeleteDialogOpen}
|
||||
onClose={onDeleteDialogClose}
|
||||
handleDelete={handleDelete}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(DeleteImageButton);
|
@ -1,8 +1,8 @@
|
||||
import { Flex, Spinner, SpinnerProps } from '@chakra-ui/react';
|
||||
|
||||
type CurrentImageFallbackProps = SpinnerProps;
|
||||
type ImageFallbackSpinnerProps = SpinnerProps;
|
||||
|
||||
const CurrentImageFallback = (props: CurrentImageFallbackProps) => {
|
||||
const ImageFallbackSpinner = (props: ImageFallbackSpinnerProps) => {
|
||||
const { size = 'xl', ...rest } = props;
|
||||
|
||||
return (
|
||||
@ -21,4 +21,4 @@ const CurrentImageFallback = (props: CurrentImageFallbackProps) => {
|
||||
);
|
||||
};
|
||||
|
||||
export default CurrentImageFallback;
|
||||
export default ImageFallbackSpinner;
|
@ -5,6 +5,7 @@ import {
|
||||
FlexProps,
|
||||
Grid,
|
||||
Icon,
|
||||
Image,
|
||||
Text,
|
||||
forwardRef,
|
||||
} from '@chakra-ui/react';
|
||||
@ -14,7 +15,10 @@ import IAICheckbox from 'common/components/IAICheckbox';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import IAIPopover from 'common/components/IAIPopover';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { imageGallerySelector } from 'features/gallery/store/gallerySelectors';
|
||||
import {
|
||||
gallerySelector,
|
||||
imageGallerySelector,
|
||||
} from 'features/gallery/store/gallerySelectors';
|
||||
import {
|
||||
setCurrentCategory,
|
||||
setGalleryImageMinimumWidth,
|
||||
@ -50,30 +54,42 @@ import { uploadsAdapter } from '../store/uploadsSlice';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { Virtuoso, VirtuosoGrid } from 'react-virtuoso';
|
||||
import { Image as ImageType } from 'app/types/invokeai';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import GalleryProgressImage from './GalleryProgressImage';
|
||||
|
||||
const GALLERY_SHOW_BUTTONS_MIN_WIDTH = 290;
|
||||
const PROGRESS_IMAGE_PLACEHOLDER = 'PROGRESS_IMAGE_PLACEHOLDER';
|
||||
|
||||
const gallerySelector = createSelector(
|
||||
[
|
||||
(state: RootState) => state.uploads,
|
||||
(state: RootState) => state.results,
|
||||
(state: RootState) => state.gallery,
|
||||
],
|
||||
(uploads, results, gallery) => {
|
||||
const selector = createSelector(
|
||||
[(state: RootState) => state],
|
||||
(state) => {
|
||||
const { results, uploads, system, gallery } = state;
|
||||
const { currentCategory } = gallery;
|
||||
|
||||
return currentCategory === 'results'
|
||||
? {
|
||||
images: resultsAdapter.getSelectors().selectAll(results),
|
||||
if (currentCategory === 'results') {
|
||||
const tempImages: (ImageType | typeof PROGRESS_IMAGE_PLACEHOLDER)[] = [];
|
||||
|
||||
if (system.progressImage) {
|
||||
tempImages.push(PROGRESS_IMAGE_PLACEHOLDER);
|
||||
}
|
||||
|
||||
return {
|
||||
images: tempImages.concat(
|
||||
resultsAdapter.getSelectors().selectAll(results)
|
||||
),
|
||||
isLoading: results.isLoading,
|
||||
areMoreImagesAvailable: results.page < results.pages - 1,
|
||||
};
|
||||
}
|
||||
: {
|
||||
|
||||
return {
|
||||
images: uploadsAdapter.getSelectors().selectAll(uploads),
|
||||
isLoading: uploads.isLoading,
|
||||
areMoreImagesAvailable: uploads.page < uploads.pages - 1,
|
||||
};
|
||||
}
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const ImageGalleryContent = () => {
|
||||
@ -108,7 +124,7 @@ const ImageGalleryContent = () => {
|
||||
} = useAppSelector(imageGallerySelector);
|
||||
|
||||
const { images, areMoreImagesAvailable, isLoading } =
|
||||
useAppSelector(gallerySelector);
|
||||
useAppSelector(selector);
|
||||
|
||||
const handleClickLoadMore = () => {
|
||||
if (currentCategory === 'results') {
|
||||
@ -170,8 +186,24 @@ const ImageGalleryContent = () => {
|
||||
}
|
||||
}, []);
|
||||
|
||||
const handleEndReached = useCallback(() => {
|
||||
if (currentCategory === 'results') {
|
||||
dispatch(receivedResultImagesPage());
|
||||
} else if (currentCategory === 'uploads') {
|
||||
dispatch(receivedUploadImagesPage());
|
||||
}
|
||||
}, [dispatch, currentCategory]);
|
||||
|
||||
return (
|
||||
<Flex flexDirection="column" w="full" h="full" gap={4}>
|
||||
<Flex
|
||||
sx={{
|
||||
gap: 2,
|
||||
flexDirection: 'column',
|
||||
h: 'full',
|
||||
w: 'full',
|
||||
borderRadius: 'base',
|
||||
}}
|
||||
>
|
||||
<Flex
|
||||
ref={resizeObserverRef}
|
||||
alignItems="center"
|
||||
@ -290,18 +322,27 @@ const ImageGalleryContent = () => {
|
||||
<Virtuoso
|
||||
style={{ height: '100%' }}
|
||||
data={images}
|
||||
endReached={handleEndReached}
|
||||
scrollerRef={(ref) => setScrollerRef(ref)}
|
||||
itemContent={(index, image) => {
|
||||
const { name } = image;
|
||||
const isSelected = selectedImage?.name === name;
|
||||
const isSelected =
|
||||
image === PROGRESS_IMAGE_PLACEHOLDER
|
||||
? false
|
||||
: selectedImage?.name === image?.name;
|
||||
|
||||
return (
|
||||
<Flex sx={{ pb: 2 }}>
|
||||
{image === PROGRESS_IMAGE_PLACEHOLDER ? (
|
||||
<GalleryProgressImage
|
||||
key={PROGRESS_IMAGE_PLACEHOLDER}
|
||||
/>
|
||||
) : (
|
||||
<HoverableImage
|
||||
key={`${name}-${image.thumbnail}`}
|
||||
key={`${image.name}-${image.thumbnail}`}
|
||||
image={image}
|
||||
isSelected={isSelected}
|
||||
/>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
}}
|
||||
@ -310,18 +351,23 @@ const ImageGalleryContent = () => {
|
||||
<VirtuosoGrid
|
||||
style={{ height: '100%' }}
|
||||
data={images}
|
||||
endReached={handleEndReached}
|
||||
components={{
|
||||
Item: ItemContainer,
|
||||
List: ListContainer,
|
||||
}}
|
||||
scrollerRef={setScroller}
|
||||
itemContent={(index, image) => {
|
||||
const { name } = image;
|
||||
const isSelected = selectedImage?.name === name;
|
||||
const isSelected =
|
||||
image === PROGRESS_IMAGE_PLACEHOLDER
|
||||
? false
|
||||
: selectedImage?.name === image?.name;
|
||||
|
||||
return (
|
||||
return image === PROGRESS_IMAGE_PLACEHOLDER ? (
|
||||
<GalleryProgressImage key={PROGRESS_IMAGE_PLACEHOLDER} />
|
||||
) : (
|
||||
<HoverableImage
|
||||
key={`${name}-${image.thumbnail}`}
|
||||
key={`${image.name}-${image.thumbnail}`}
|
||||
image={image}
|
||||
isSelected={isSelected}
|
||||
/>
|
||||
@ -334,6 +380,7 @@ const ImageGalleryContent = () => {
|
||||
onClick={handleClickLoadMore}
|
||||
isDisabled={!areMoreImagesAvailable}
|
||||
isLoading={isLoading}
|
||||
loadingText="Loading"
|
||||
flexShrink={0}
|
||||
>
|
||||
{areMoreImagesAvailable
|
||||
|
@ -5,7 +5,6 @@ import {
|
||||
// selectPrevImage,
|
||||
setGalleryImageMinimumWidth,
|
||||
} from 'features/gallery/store/gallerySlice';
|
||||
import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||
|
||||
import { clamp, isEqual } from 'lodash-es';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
@ -13,11 +12,7 @@ import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import './ImageGallery.css';
|
||||
import ImageGalleryContent from './ImageGalleryContent';
|
||||
import ResizableDrawer from 'features/ui/components/common/ResizableDrawer/ResizableDrawer';
|
||||
import {
|
||||
setShouldShowGallery,
|
||||
toggleGalleryPanel,
|
||||
togglePinGalleryPanel,
|
||||
} from 'features/ui/store/uiSlice';
|
||||
import { setShouldShowGallery } from 'features/ui/store/uiSlice';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import {
|
||||
activeTabNameSelector,
|
||||
@ -26,22 +21,20 @@ import {
|
||||
import { isStagingSelector } from 'features/canvas/store/canvasSelectors';
|
||||
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
||||
import { lightboxSelector } from 'features/lightbox/store/lightboxSelectors';
|
||||
import useResolution from 'common/hooks/useResolution';
|
||||
import { Flex } from '@chakra-ui/react';
|
||||
import { memo } from 'react';
|
||||
|
||||
const GALLERY_TAB_WIDTHS: Record<
|
||||
InvokeTabName,
|
||||
{ galleryMinWidth: number; galleryMaxWidth: number }
|
||||
> = {
|
||||
// const GALLERY_TAB_WIDTHS: Record<
|
||||
// InvokeTabName,
|
||||
// { galleryMinWidth: number; galleryMaxWidth: number }
|
||||
// > = {
|
||||
// txt2img: { galleryMinWidth: 200, galleryMaxWidth: 500 },
|
||||
// img2img: { galleryMinWidth: 200, galleryMaxWidth: 500 },
|
||||
generate: { galleryMinWidth: 200, galleryMaxWidth: 500 },
|
||||
unifiedCanvas: { galleryMinWidth: 200, galleryMaxWidth: 200 },
|
||||
nodes: { galleryMinWidth: 200, galleryMaxWidth: 500 },
|
||||
// generate: { galleryMinWidth: 200, galleryMaxWidth: 500 },
|
||||
// unifiedCanvas: { galleryMinWidth: 200, galleryMaxWidth: 200 },
|
||||
// nodes: { galleryMinWidth: 200, galleryMaxWidth: 500 },
|
||||
// postprocessing: { galleryMinWidth: 200, galleryMaxWidth: 500 },
|
||||
// training: { galleryMinWidth: 200, galleryMaxWidth: 500 },
|
||||
};
|
||||
// };
|
||||
|
||||
const galleryPanelSelector = createSelector(
|
||||
[
|
||||
@ -73,50 +66,50 @@ const galleryPanelSelector = createSelector(
|
||||
}
|
||||
);
|
||||
|
||||
export const ImageGalleryPanel = () => {
|
||||
const GalleryDrawer = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const {
|
||||
shouldPinGallery,
|
||||
shouldShowGallery,
|
||||
galleryImageMinimumWidth,
|
||||
activeTabName,
|
||||
isStaging,
|
||||
isResizable,
|
||||
isLightboxOpen,
|
||||
// activeTabName,
|
||||
// isStaging,
|
||||
// isResizable,
|
||||
// isLightboxOpen,
|
||||
} = useAppSelector(galleryPanelSelector);
|
||||
|
||||
const handleSetShouldPinGallery = () => {
|
||||
dispatch(togglePinGalleryPanel());
|
||||
dispatch(requestCanvasRescale());
|
||||
};
|
||||
// const handleSetShouldPinGallery = () => {
|
||||
// dispatch(togglePinGalleryPanel());
|
||||
// dispatch(requestCanvasRescale());
|
||||
// };
|
||||
|
||||
const handleToggleGallery = () => {
|
||||
dispatch(toggleGalleryPanel());
|
||||
shouldPinGallery && dispatch(requestCanvasRescale());
|
||||
};
|
||||
// const handleToggleGallery = () => {
|
||||
// dispatch(toggleGalleryPanel());
|
||||
// shouldPinGallery && dispatch(requestCanvasRescale());
|
||||
// };
|
||||
|
||||
const handleCloseGallery = () => {
|
||||
dispatch(setShouldShowGallery(false));
|
||||
shouldPinGallery && dispatch(requestCanvasRescale());
|
||||
};
|
||||
|
||||
const resolution = useResolution();
|
||||
// const resolution = useResolution();
|
||||
|
||||
useHotkeys(
|
||||
'g',
|
||||
() => {
|
||||
handleToggleGallery();
|
||||
},
|
||||
[shouldPinGallery]
|
||||
);
|
||||
// useHotkeys(
|
||||
// 'g',
|
||||
// () => {
|
||||
// handleToggleGallery();
|
||||
// },
|
||||
// [shouldPinGallery]
|
||||
// );
|
||||
|
||||
useHotkeys(
|
||||
'shift+g',
|
||||
() => {
|
||||
handleSetShouldPinGallery();
|
||||
},
|
||||
[shouldPinGallery]
|
||||
);
|
||||
// useHotkeys(
|
||||
// 'shift+g',
|
||||
// () => {
|
||||
// handleSetShouldPinGallery();
|
||||
// },
|
||||
// [shouldPinGallery]
|
||||
// );
|
||||
|
||||
useHotkeys(
|
||||
'esc',
|
||||
@ -162,55 +155,71 @@ export const ImageGalleryPanel = () => {
|
||||
[galleryImageMinimumWidth]
|
||||
);
|
||||
|
||||
const calcGalleryMinHeight = () => {
|
||||
if (resolution === 'desktop') return;
|
||||
return 300;
|
||||
};
|
||||
// const calcGalleryMinHeight = () => {
|
||||
// if (resolution === 'desktop') return;
|
||||
// return 300;
|
||||
// };
|
||||
|
||||
const imageGalleryContent = () => {
|
||||
return (
|
||||
<Flex
|
||||
w="100vw"
|
||||
h={{ base: 300, xl: '100vh' }}
|
||||
paddingRight={{ base: 8, xl: 0 }}
|
||||
paddingBottom={{ base: 4, xl: 0 }}
|
||||
>
|
||||
<ImageGalleryContent />
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
// const imageGalleryContent = () => {
|
||||
// return (
|
||||
// <Flex
|
||||
// w="100vw"
|
||||
// h={{ base: 300, xl: '100vh' }}
|
||||
// paddingRight={{ base: 8, xl: 0 }}
|
||||
// paddingBottom={{ base: 4, xl: 0 }}
|
||||
// >
|
||||
// <ImageGalleryContent />
|
||||
// </Flex>
|
||||
// );
|
||||
// };
|
||||
|
||||
// const resizableImageGalleryContent = () => {
|
||||
// return (
|
||||
// <ResizableDrawer
|
||||
// direction="right"
|
||||
// isResizable={isResizable || !shouldPinGallery}
|
||||
// isOpen={shouldShowGallery}
|
||||
// onClose={handleCloseGallery}
|
||||
// isPinned={shouldPinGallery && !isLightboxOpen}
|
||||
// minWidth={
|
||||
// shouldPinGallery
|
||||
// ? GALLERY_TAB_WIDTHS[activeTabName].galleryMinWidth
|
||||
// : 200
|
||||
// }
|
||||
// maxWidth={
|
||||
// shouldPinGallery
|
||||
// ? GALLERY_TAB_WIDTHS[activeTabName].galleryMaxWidth
|
||||
// : undefined
|
||||
// }
|
||||
// minHeight={calcGalleryMinHeight()}
|
||||
// >
|
||||
// <ImageGalleryContent />
|
||||
// </ResizableDrawer>
|
||||
// );
|
||||
// };
|
||||
|
||||
// const renderImageGallery = () => {
|
||||
// if (['mobile', 'tablet'].includes(resolution)) return imageGalleryContent();
|
||||
// return resizableImageGalleryContent();
|
||||
// };
|
||||
|
||||
if (shouldPinGallery) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const resizableImageGalleryContent = () => {
|
||||
return (
|
||||
<ResizableDrawer
|
||||
direction="right"
|
||||
isResizable={isResizable || !shouldPinGallery}
|
||||
isResizable={true}
|
||||
isOpen={shouldShowGallery}
|
||||
onClose={handleCloseGallery}
|
||||
isPinned={shouldPinGallery && !isLightboxOpen}
|
||||
minWidth={
|
||||
shouldPinGallery
|
||||
? GALLERY_TAB_WIDTHS[activeTabName].galleryMinWidth
|
||||
: 200
|
||||
}
|
||||
maxWidth={
|
||||
shouldPinGallery
|
||||
? GALLERY_TAB_WIDTHS[activeTabName].galleryMaxWidth
|
||||
: undefined
|
||||
}
|
||||
minHeight={calcGalleryMinHeight()}
|
||||
minWidth={200}
|
||||
>
|
||||
<ImageGalleryContent />
|
||||
</ResizableDrawer>
|
||||
);
|
||||
|
||||
// return renderImageGallery();
|
||||
};
|
||||
|
||||
const renderImageGallery = () => {
|
||||
if (['mobile', 'tablet'].includes(resolution)) return imageGalleryContent();
|
||||
return resizableImageGalleryContent();
|
||||
};
|
||||
|
||||
return renderImageGallery();
|
||||
};
|
||||
|
||||
export default memo(ImageGalleryPanel);
|
||||
export default memo(GalleryDrawer);
|
||||
|
@ -3,7 +3,6 @@ import {
|
||||
Box,
|
||||
Center,
|
||||
Flex,
|
||||
Heading,
|
||||
IconButton,
|
||||
Link,
|
||||
Text,
|
||||
@ -19,8 +18,6 @@ import {
|
||||
setCfgScale,
|
||||
setHeight,
|
||||
setImg2imgStrength,
|
||||
// setInitialImage,
|
||||
setMaskPath,
|
||||
setPerlin,
|
||||
setSampler,
|
||||
setSeamless,
|
||||
@ -31,21 +28,14 @@ import {
|
||||
setThreshold,
|
||||
setWidth,
|
||||
} from 'features/parameters/store/generationSlice';
|
||||
import {
|
||||
setCodeformerFidelity,
|
||||
setFacetoolStrength,
|
||||
setFacetoolType,
|
||||
setHiresFix,
|
||||
setUpscalingDenoising,
|
||||
setUpscalingLevel,
|
||||
setUpscalingStrength,
|
||||
} from 'features/parameters/store/postprocessingSlice';
|
||||
import { setHiresFix } from 'features/parameters/store/postprocessingSlice';
|
||||
import { setShouldShowImageDetails } from 'features/ui/store/uiSlice';
|
||||
import { memo } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FaCopy } from 'react-icons/fa';
|
||||
import { IoArrowUndoCircleOutline } from 'react-icons/io5';
|
||||
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
||||
|
||||
type MetadataItemProps = {
|
||||
isLink?: boolean;
|
||||
@ -300,7 +290,7 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
|
||||
</Text>
|
||||
</Center>
|
||||
)}
|
||||
<Flex gap={2} direction="column">
|
||||
<Flex gap={2} direction="column" overflow="auto">
|
||||
<Flex gap={2}>
|
||||
<Tooltip label="Copy metadata JSON">
|
||||
<IconButton
|
||||
@ -314,22 +304,19 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
|
||||
</Tooltip>
|
||||
<Text fontWeight="semibold">Metadata JSON:</Text>
|
||||
</Flex>
|
||||
<OverlayScrollbarsComponent defer>
|
||||
<Box
|
||||
sx={{
|
||||
mt: 0,
|
||||
mr: 2,
|
||||
mb: 4,
|
||||
ml: 2,
|
||||
padding: 4,
|
||||
borderRadius: 'base',
|
||||
overflowX: 'scroll',
|
||||
wordBreak: 'break-all',
|
||||
bg: 'whiteAlpha.500',
|
||||
_dark: { bg: 'blackAlpha.500' },
|
||||
w: 'max-content',
|
||||
}}
|
||||
>
|
||||
<pre>{metadataJSON}</pre>
|
||||
</Box>
|
||||
</OverlayScrollbarsComponent>
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
|
@ -0,0 +1,7 @@
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import { Image } from 'app/types/invokeai';
|
||||
import { SelectedImage } from 'features/parameters/store/actions';
|
||||
|
||||
export const requestedImageDeletion = createAction<
|
||||
Image | SelectedImage | undefined
|
||||
>('gallery/requestedImageDeletion');
|
@ -4,12 +4,13 @@ import { GalleryState } from './gallerySlice';
|
||||
* Gallery slice persist denylist
|
||||
*/
|
||||
const itemsToDenylist: (keyof GalleryState)[] = [
|
||||
'categories',
|
||||
'currentCategory',
|
||||
'currentImage',
|
||||
'currentImageUuid',
|
||||
'shouldAutoSwitchToNewImages',
|
||||
'intermediateImage',
|
||||
];
|
||||
|
||||
export const galleryPersistDenylist: (keyof GalleryState)[] = [
|
||||
'currentCategory',
|
||||
'shouldAutoSwitchToNewImages',
|
||||
];
|
||||
|
||||
export const galleryDenylist = itemsToDenylist.map(
|
||||
|
@ -1,23 +1,14 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { lightboxSelector } from 'features/lightbox/store/lightboxSelectors';
|
||||
import { configSelector } from 'features/system/store/configSelectors';
|
||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
||||
|
||||
import {
|
||||
activeTabNameSelector,
|
||||
uiSelector,
|
||||
} from 'features/ui/store/uiSelectors';
|
||||
import { isEqual } from 'lodash-es';
|
||||
import {
|
||||
selectResultsAll,
|
||||
selectResultsById,
|
||||
selectResultsEntities,
|
||||
} from './resultsSlice';
|
||||
import {
|
||||
selectUploadsAll,
|
||||
selectUploadsById,
|
||||
selectUploadsEntities,
|
||||
} from './uploadsSlice';
|
||||
import { selectResultsById, selectResultsEntities } from './resultsSlice';
|
||||
import { selectUploadsAll, selectUploadsById } from './uploadsSlice';
|
||||
|
||||
export const gallerySelector = (state: RootState) => state.gallery;
|
||||
|
||||
@ -44,6 +35,11 @@ export const imageGallerySelector = createSelector(
|
||||
|
||||
const { isLightboxOpen } = lightbox;
|
||||
|
||||
const images =
|
||||
currentCategory === 'results'
|
||||
? selectResultsEntities(state)
|
||||
: selectUploadsAll(state);
|
||||
|
||||
return {
|
||||
shouldPinGallery,
|
||||
galleryImageMinimumWidth,
|
||||
@ -53,7 +49,7 @@ export const imageGallerySelector = createSelector(
|
||||
: `repeat(auto-fill, minmax(${galleryImageMinimumWidth}px, auto))`,
|
||||
shouldAutoSwitchToNewImages,
|
||||
currentCategory,
|
||||
images: state[currentCategory].entities,
|
||||
images,
|
||||
galleryWidth,
|
||||
shouldEnableResize:
|
||||
isLightboxOpen ||
|
||||
|
@ -1,10 +1,11 @@
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import { invocationComplete } from 'services/events/actions';
|
||||
import { isImageOutput } from 'services/types/guards';
|
||||
import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
|
||||
import { imageUploaded } from 'services/thunks/image';
|
||||
import { SelectedImage } from 'features/parameters/store/generationSlice';
|
||||
import { Image } from 'app/types/invokeai';
|
||||
import { imageReceived, thumbnailReceived } from 'services/thunks/image';
|
||||
import {
|
||||
receivedResultImagesPage,
|
||||
receivedUploadImagesPage,
|
||||
} from '../../../services/thunks/gallery';
|
||||
|
||||
type GalleryImageObjectFitType = 'contain' | 'cover';
|
||||
|
||||
@ -12,7 +13,7 @@ export interface GalleryState {
|
||||
/**
|
||||
* The selected image
|
||||
*/
|
||||
selectedImage?: SelectedImage;
|
||||
selectedImage?: Image;
|
||||
galleryImageMinimumWidth: number;
|
||||
galleryImageObjectFit: GalleryImageObjectFitType;
|
||||
shouldAutoSwitchToNewImages: boolean;
|
||||
@ -21,8 +22,7 @@ export interface GalleryState {
|
||||
currentCategory: 'results' | 'uploads';
|
||||
}
|
||||
|
||||
const initialState: GalleryState = {
|
||||
selectedImage: undefined,
|
||||
export const initialGalleryState: GalleryState = {
|
||||
galleryImageMinimumWidth: 64,
|
||||
galleryImageObjectFit: 'cover',
|
||||
shouldAutoSwitchToNewImages: true,
|
||||
@ -33,12 +33,9 @@ const initialState: GalleryState = {
|
||||
|
||||
export const gallerySlice = createSlice({
|
||||
name: 'gallery',
|
||||
initialState,
|
||||
initialState: initialGalleryState,
|
||||
reducers: {
|
||||
imageSelected: (
|
||||
state,
|
||||
action: PayloadAction<SelectedImage | undefined>
|
||||
) => {
|
||||
imageSelected: (state, action: PayloadAction<Image | undefined>) => {
|
||||
state.selectedImage = action.payload;
|
||||
// TODO: if the user selects an image, disable the auto switch?
|
||||
// state.shouldAutoSwitchToNewImages = false;
|
||||
@ -72,27 +69,50 @@ export const gallerySlice = createSlice({
|
||||
},
|
||||
},
|
||||
extraReducers(builder) {
|
||||
/**
|
||||
* Invocation Complete
|
||||
*/
|
||||
builder.addCase(invocationComplete, (state, action) => {
|
||||
const { data } = action.payload;
|
||||
if (isImageOutput(data.result) && state.shouldAutoSwitchToNewImages) {
|
||||
state.selectedImage = {
|
||||
name: data.result.image.image_name,
|
||||
type: 'results',
|
||||
};
|
||||
builder.addCase(imageReceived.fulfilled, (state, action) => {
|
||||
// When we get an updated URL for an image, we need to update the selectedImage in gallery,
|
||||
// which is currently its own object (instead of a reference to an image in results/uploads)
|
||||
const { imagePath } = action.payload;
|
||||
const { imageName } = action.meta.arg;
|
||||
|
||||
if (state.selectedImage?.name === imageName) {
|
||||
state.selectedImage.url = imagePath;
|
||||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* Upload Image - FULFILLED
|
||||
*/
|
||||
builder.addCase(imageUploaded.fulfilled, (state, action) => {
|
||||
const { response } = action.payload;
|
||||
builder.addCase(thumbnailReceived.fulfilled, (state, action) => {
|
||||
// When we get an updated URL for an image, we need to update the selectedImage in gallery,
|
||||
// which is currently its own object (instead of a reference to an image in results/uploads)
|
||||
const { thumbnailPath } = action.payload;
|
||||
const { thumbnailName } = action.meta.arg;
|
||||
|
||||
const uploadedImage = deserializeImageResponse(response);
|
||||
state.selectedImage = { name: uploadedImage.name, type: 'uploads' };
|
||||
if (state.selectedImage?.name === thumbnailName) {
|
||||
state.selectedImage.thumbnail = thumbnailPath;
|
||||
}
|
||||
});
|
||||
builder.addCase(receivedResultImagesPage.fulfilled, (state, action) => {
|
||||
// rehydrate selectedImage URL when results list comes in
|
||||
// solves case when outdated URL is in local storage
|
||||
if (state.selectedImage) {
|
||||
const selectedImageInResults = action.payload.items.find(
|
||||
(image) => image.image_name === state.selectedImage!.name
|
||||
);
|
||||
if (selectedImageInResults) {
|
||||
state.selectedImage.url = selectedImageInResults.image_url;
|
||||
}
|
||||
}
|
||||
});
|
||||
builder.addCase(receivedUploadImagesPage.fulfilled, (state, action) => {
|
||||
// rehydrate selectedImage URL when results list comes in
|
||||
// solves case when outdated URL is in local storage
|
||||
if (state.selectedImage) {
|
||||
const selectedImageInResults = action.payload.items.find(
|
||||
(image) => image.image_name === state.selectedImage!.name
|
||||
);
|
||||
if (selectedImageInResults) {
|
||||
state.selectedImage.url = selectedImageInResults.image_url;
|
||||
}
|
||||
}
|
||||
});
|
||||
},
|
||||
});
|
||||
|
@ -5,7 +5,9 @@ import { ResultsState } from './resultsSlice';
|
||||
*
|
||||
* Currently denylisting results slice entirely, see persist config in store.ts
|
||||
*/
|
||||
const itemsToDenylist: (keyof ResultsState)[] = ['isLoading'];
|
||||
const itemsToDenylist: (keyof ResultsState)[] = [];
|
||||
|
||||
export const resultsPersistDenylist: (keyof ResultsState)[] = [];
|
||||
|
||||
export const resultsDenylist = itemsToDenylist.map(
|
||||
(denylistItem) => `results.${denylistItem}`
|
||||
|
@ -1,17 +1,11 @@
|
||||
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
|
||||
import { Image } from 'app/types/invokeai';
|
||||
import { invocationComplete } from 'services/events/actions';
|
||||
|
||||
import { RootState } from 'app/store/store';
|
||||
import {
|
||||
receivedResultImagesPage,
|
||||
IMAGES_PER_PAGE,
|
||||
} from 'services/thunks/gallery';
|
||||
import { isImageOutput } from 'services/types/guards';
|
||||
import {
|
||||
buildImageUrls,
|
||||
extractTimestampFromImageName,
|
||||
} from 'services/util/deserializeImageField';
|
||||
import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
|
||||
import {
|
||||
imageDeleted,
|
||||
@ -73,44 +67,6 @@ const resultsSlice = createSlice({
|
||||
state.isLoading = false;
|
||||
});
|
||||
|
||||
/**
|
||||
* Invocation Complete
|
||||
*/
|
||||
builder.addCase(invocationComplete, (state, action) => {
|
||||
const { data, shouldFetchImages } = action.payload;
|
||||
const { result, node, graph_execution_state_id } = data;
|
||||
|
||||
if (isImageOutput(result)) {
|
||||
const name = result.image.image_name;
|
||||
const type = result.image.image_type;
|
||||
|
||||
// if we need to refetch, set URLs to placeholder for now
|
||||
const { url, thumbnail } = shouldFetchImages
|
||||
? { url: '', thumbnail: '' }
|
||||
: buildImageUrls(type, name);
|
||||
|
||||
const timestamp = extractTimestampFromImageName(name);
|
||||
|
||||
const image: Image = {
|
||||
name,
|
||||
type,
|
||||
url,
|
||||
thumbnail,
|
||||
metadata: {
|
||||
created: timestamp,
|
||||
width: result.width,
|
||||
height: result.height,
|
||||
invokeai: {
|
||||
session_id: graph_execution_state_id,
|
||||
...(node ? { node } : {}),
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
resultsAdapter.setOne(state, image);
|
||||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* Image Received - FULFILLED
|
||||
*/
|
||||
@ -142,9 +98,10 @@ const resultsSlice = createSlice({
|
||||
});
|
||||
|
||||
/**
|
||||
* Delete Image - FULFILLED
|
||||
* Delete Image - PENDING
|
||||
* Pre-emptively remove the image from the gallery
|
||||
*/
|
||||
builder.addCase(imageDeleted.fulfilled, (state, action) => {
|
||||
builder.addCase(imageDeleted.pending, (state, action) => {
|
||||
const { imageType, imageName } = action.meta.arg;
|
||||
|
||||
if (imageType === 'results') {
|
||||
|
@ -5,7 +5,8 @@ import { UploadsState } from './uploadsSlice';
|
||||
*
|
||||
* Currently denylisting uploads slice entirely, see persist config in store.ts
|
||||
*/
|
||||
const itemsToDenylist: (keyof UploadsState)[] = ['isLoading'];
|
||||
const itemsToDenylist: (keyof UploadsState)[] = [];
|
||||
export const uploadsPersistDenylist: (keyof UploadsState)[] = [];
|
||||
|
||||
export const uploadsDenylist = itemsToDenylist.map(
|
||||
(denylistItem) => `uploads.${denylistItem}`
|
||||
|
@ -6,7 +6,7 @@ import {
|
||||
receivedUploadImagesPage,
|
||||
IMAGES_PER_PAGE,
|
||||
} from 'services/thunks/gallery';
|
||||
import { imageDeleted, imageUploaded } from 'services/thunks/image';
|
||||
import { imageDeleted } from 'services/thunks/image';
|
||||
import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
|
||||
|
||||
export const uploadsAdapter = createEntityAdapter<Image>({
|
||||
@ -21,7 +21,7 @@ type AdditionalUploadsState = {
|
||||
nextPage: number;
|
||||
};
|
||||
|
||||
const initialUploadsState =
|
||||
export const initialUploadsState =
|
||||
uploadsAdapter.getInitialState<AdditionalUploadsState>({
|
||||
page: 0,
|
||||
pages: 0,
|
||||
@ -35,7 +35,7 @@ const uploadsSlice = createSlice({
|
||||
name: 'uploads',
|
||||
initialState: initialUploadsState,
|
||||
reducers: {
|
||||
uploadAdded: uploadsAdapter.addOne,
|
||||
uploadAdded: uploadsAdapter.upsertOne,
|
||||
},
|
||||
extraReducers: (builder) => {
|
||||
/**
|
||||
@ -62,20 +62,10 @@ const uploadsSlice = createSlice({
|
||||
});
|
||||
|
||||
/**
|
||||
* Upload Image - FULFILLED
|
||||
* Delete Image - pending
|
||||
* Pre-emptively remove the image from the gallery
|
||||
*/
|
||||
builder.addCase(imageUploaded.fulfilled, (state, action) => {
|
||||
const { location, response } = action.payload;
|
||||
|
||||
const uploadedImage = deserializeImageResponse(response);
|
||||
|
||||
uploadsAdapter.setOne(state, uploadedImage);
|
||||
});
|
||||
|
||||
/**
|
||||
* Delete Image - FULFILLED
|
||||
*/
|
||||
builder.addCase(imageDeleted.fulfilled, (state, action) => {
|
||||
builder.addCase(imageDeleted.pending, (state, action) => {
|
||||
const { imageType, imageName } = action.meta.arg;
|
||||
|
||||
if (imageType === 'uploads') {
|
||||
|
@ -4,7 +4,7 @@ import * as InvokeAI from 'app/types/invokeai';
|
||||
import { useGetUrl } from 'common/util/getUrl';
|
||||
|
||||
type ReactPanZoomProps = {
|
||||
image: InvokeAI._Image;
|
||||
image: InvokeAI.Image;
|
||||
styleClass?: string;
|
||||
alt?: string;
|
||||
ref?: React.Ref<HTMLImageElement>;
|
||||
|
@ -4,6 +4,9 @@ import { LightboxState } from './lightboxSlice';
|
||||
* Lightbox slice persist denylist
|
||||
*/
|
||||
const itemsToDenylist: (keyof LightboxState)[] = ['isLightboxOpen'];
|
||||
export const lightboxPersistDenylist: (keyof LightboxState)[] = [
|
||||
'isLightboxOpen',
|
||||
];
|
||||
|
||||
export const lightboxDenylist = itemsToDenylist.map(
|
||||
(denylistItem) => `lightbox.${denylistItem}`
|
||||
|
@ -5,7 +5,7 @@ export interface LightboxState {
|
||||
isLightboxOpen: boolean;
|
||||
}
|
||||
|
||||
const initialLightboxState: LightboxState = {
|
||||
export const initialLightboxState: LightboxState = {
|
||||
isLightboxOpen: false,
|
||||
};
|
||||
|
||||
|
@ -1,5 +1,3 @@
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
|
||||
import 'reactflow/dist/style.css';
|
||||
import { memo, useCallback } from 'react';
|
||||
import {
|
||||
@ -8,12 +6,11 @@ import {
|
||||
MenuButton,
|
||||
MenuList,
|
||||
MenuItem,
|
||||
IconButton,
|
||||
} from '@chakra-ui/react';
|
||||
import { FaEllipsisV, FaPlus } from 'react-icons/fa';
|
||||
import { FaEllipsisV } from 'react-icons/fa';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { nodeAdded } from '../store/nodesSlice';
|
||||
import { cloneDeep, map } from 'lodash-es';
|
||||
import { map } from 'lodash-es';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { useBuildInvocation } from '../hooks/useBuildInvocation';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user