mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
merge with main
This commit is contained in:
commit
1103ab2844
13
.github/workflows/mkdocs-material.yml
vendored
13
.github/workflows/mkdocs-material.yml
vendored
@ -2,8 +2,7 @@ name: mkdocs-material
|
|||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
branches:
|
branches:
|
||||||
- 'main'
|
- 'refs/heads/v2.3'
|
||||||
- 'development'
|
|
||||||
|
|
||||||
permissions:
|
permissions:
|
||||||
contents: write
|
contents: write
|
||||||
@ -12,6 +11,10 @@ jobs:
|
|||||||
mkdocs-material:
|
mkdocs-material:
|
||||||
if: github.event.pull_request.draft == false
|
if: github.event.pull_request.draft == false
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
env:
|
||||||
|
REPO_URL: '${{ github.server_url }}/${{ github.repository }}'
|
||||||
|
REPO_NAME: '${{ github.repository }}'
|
||||||
|
SITE_URL: 'https://${{ github.repository_owner }}.github.io/InvokeAI'
|
||||||
steps:
|
steps:
|
||||||
- name: checkout sources
|
- name: checkout sources
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v3
|
||||||
@ -22,11 +25,15 @@ jobs:
|
|||||||
uses: actions/setup-python@v4
|
uses: actions/setup-python@v4
|
||||||
with:
|
with:
|
||||||
python-version: '3.10'
|
python-version: '3.10'
|
||||||
|
cache: pip
|
||||||
|
cache-dependency-path: pyproject.toml
|
||||||
|
|
||||||
- name: install requirements
|
- name: install requirements
|
||||||
|
env:
|
||||||
|
PIP_USE_PEP517: 1
|
||||||
run: |
|
run: |
|
||||||
python -m \
|
python -m \
|
||||||
pip install -r docs/requirements-mkdocs.txt
|
pip install ".[docs]"
|
||||||
|
|
||||||
- name: confirm buildability
|
- name: confirm buildability
|
||||||
run: |
|
run: |
|
||||||
|
@ -247,8 +247,8 @@ class InvokeAiInstance:
|
|||||||
pip[
|
pip[
|
||||||
"install",
|
"install",
|
||||||
"--require-virtualenv",
|
"--require-virtualenv",
|
||||||
"torch",
|
"torch~=2.0.0",
|
||||||
"torchvision",
|
"torchvision>=0.14.1",
|
||||||
"--force-reinstall",
|
"--force-reinstall",
|
||||||
"--find-links" if find_links is not None else None,
|
"--find-links" if find_links is not None else None,
|
||||||
find_links,
|
find_links,
|
||||||
|
@ -83,7 +83,7 @@ async def get_thumbnail(
|
|||||||
status_code=201,
|
status_code=201,
|
||||||
)
|
)
|
||||||
async def upload_image(
|
async def upload_image(
|
||||||
file: UploadFile, request: Request, response: Response
|
file: UploadFile, image_type: ImageType, request: Request, response: Response
|
||||||
) -> ImageResponse:
|
) -> ImageResponse:
|
||||||
if not file.content_type.startswith("image"):
|
if not file.content_type.startswith("image"):
|
||||||
raise HTTPException(status_code=415, detail="Not an image")
|
raise HTTPException(status_code=415, detail="Not an image")
|
||||||
@ -99,21 +99,21 @@ async def upload_image(
|
|||||||
filename = f"{uuid.uuid4()}_{str(int(datetime.now(timezone.utc).timestamp()))}.png"
|
filename = f"{uuid.uuid4()}_{str(int(datetime.now(timezone.utc).timestamp()))}.png"
|
||||||
|
|
||||||
saved_image = ApiDependencies.invoker.services.images.save(
|
saved_image = ApiDependencies.invoker.services.images.save(
|
||||||
ImageType.UPLOAD, filename, img
|
image_type, filename, img
|
||||||
)
|
)
|
||||||
|
|
||||||
invokeai_metadata = ApiDependencies.invoker.services.metadata.get_metadata(img)
|
invokeai_metadata = ApiDependencies.invoker.services.metadata.get_metadata(img)
|
||||||
|
|
||||||
image_url = ApiDependencies.invoker.services.images.get_uri(
|
image_url = ApiDependencies.invoker.services.images.get_uri(
|
||||||
ImageType.UPLOAD, saved_image.image_name
|
image_type, saved_image.image_name
|
||||||
)
|
)
|
||||||
|
|
||||||
thumbnail_url = ApiDependencies.invoker.services.images.get_uri(
|
thumbnail_url = ApiDependencies.invoker.services.images.get_uri(
|
||||||
ImageType.UPLOAD, saved_image.image_name, True
|
image_type, saved_image.image_name, True
|
||||||
)
|
)
|
||||||
|
|
||||||
res = ImageResponse(
|
res = ImageResponse(
|
||||||
image_type=ImageType.UPLOAD,
|
image_type=image_type,
|
||||||
image_name=saved_image.image_name,
|
image_name=saved_image.image_name,
|
||||||
image_url=image_url,
|
image_url=image_url,
|
||||||
thumbnail_url=thumbnail_url,
|
thumbnail_url=thumbnail_url,
|
||||||
|
@ -122,7 +122,6 @@ app.openapi = custom_openapi
|
|||||||
# Override API doc favicons
|
# Override API doc favicons
|
||||||
app.mount("/static", StaticFiles(directory="static/dream_web"), name="static")
|
app.mount("/static", StaticFiles(directory="static/dream_web"), name="static")
|
||||||
|
|
||||||
|
|
||||||
@app.get("/docs", include_in_schema=False)
|
@app.get("/docs", include_in_schema=False)
|
||||||
def overridden_swagger():
|
def overridden_swagger():
|
||||||
return get_swagger_ui_html(
|
return get_swagger_ui_html(
|
||||||
@ -140,6 +139,8 @@ def overridden_redoc():
|
|||||||
redoc_favicon_url="/static/favicon.ico",
|
redoc_favicon_url="/static/favicon.ico",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Must mount *after* the other routes else it borks em
|
||||||
|
app.mount("/", StaticFiles(directory="invokeai/frontend/web/dist", html=True), name="ui")
|
||||||
|
|
||||||
def invoke_api():
|
def invoke_api():
|
||||||
global web_config
|
global web_config
|
||||||
|
@ -3,12 +3,12 @@
|
|||||||
from typing import Literal, Optional
|
from typing import Literal, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import numpy.random
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
|
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||||
|
|
||||||
from .baseinvocation import (
|
from .baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
InvocationConfig,
|
|
||||||
InvocationContext,
|
InvocationContext,
|
||||||
BaseInvocationOutput,
|
BaseInvocationOutput,
|
||||||
)
|
)
|
||||||
@ -50,11 +50,11 @@ class RandomRangeInvocation(BaseInvocation):
|
|||||||
default=np.iinfo(np.int32).max, description="The exclusive high value"
|
default=np.iinfo(np.int32).max, description="The exclusive high value"
|
||||||
)
|
)
|
||||||
size: int = Field(default=1, description="The number of values to generate")
|
size: int = Field(default=1, description="The number of values to generate")
|
||||||
seed: Optional[int] = Field(
|
seed: int = Field(
|
||||||
ge=0,
|
ge=0,
|
||||||
le=np.iinfo(np.int32).max,
|
le=SEED_MAX,
|
||||||
description="The seed for the RNG",
|
description="The seed for the RNG (omit for random)",
|
||||||
default_factory=lambda: numpy.random.randint(0, np.iinfo(np.int32).max),
|
default_factory=get_random_seed,
|
||||||
)
|
)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
|
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
|
||||||
|
@ -1,15 +1,17 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Literal, Optional, Union
|
from typing import Literal, Optional, Union, get_args
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from invokeai.app.models.image import ImageField, ImageType
|
from invokeai.app.models.image import ColorField, ImageField, ImageType
|
||||||
from invokeai.app.invocations.util.choose_model import choose_model
|
from invokeai.app.invocations.util.choose_model import choose_model
|
||||||
|
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||||
|
from invokeai.backend.generator.inpaint import infill_methods
|
||||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
||||||
from .image import ImageOutput, build_image_output
|
from .image import ImageOutput, build_image_output
|
||||||
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
|
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
|
||||||
@ -17,7 +19,8 @@ from ...backend.stable_diffusion import PipelineIntermediateState
|
|||||||
from ..util.step_callback import stable_diffusion_step_callback
|
from ..util.step_callback import stable_diffusion_step_callback
|
||||||
|
|
||||||
SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
|
SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
|
||||||
|
INFILL_METHODS = Literal[tuple(infill_methods())]
|
||||||
|
DEFAULT_INFILL_METHOD = 'patchmatch' if 'patchmatch' in get_args(INFILL_METHODS) else 'tile'
|
||||||
|
|
||||||
class SDImageInvocation(BaseModel):
|
class SDImageInvocation(BaseModel):
|
||||||
"""Helper class to provide all Stable Diffusion raster image invocations with additional config"""
|
"""Helper class to provide all Stable Diffusion raster image invocations with additional config"""
|
||||||
@ -44,15 +47,13 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
|
|||||||
# TODO: consider making prompt optional to enable providing prompt through a link
|
# TODO: consider making prompt optional to enable providing prompt through a link
|
||||||
# fmt: off
|
# fmt: off
|
||||||
prompt: Optional[str] = Field(description="The prompt to generate an image from")
|
prompt: Optional[str] = Field(description="The prompt to generate an image from")
|
||||||
seed: int = Field(default=-1,ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)", )
|
seed: int = Field(ge=0, le=SEED_MAX, description="The seed to use (omit for random)", default_factory=get_random_seed)
|
||||||
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
steps: int = Field(default=30, gt=0, description="The number of steps to use to generate the image")
|
||||||
width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting image", )
|
width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting image", )
|
||||||
height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting image", )
|
height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting image", )
|
||||||
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||||
scheduler: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The scheduler to use" )
|
scheduler: SAMPLER_NAME_VALUES = Field(default="lms", description="The scheduler to use" )
|
||||||
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
|
||||||
model: str = Field(default="", description="The model to use (currently ignored)")
|
model: str = Field(default="", description="The model to use (currently ignored)")
|
||||||
progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
|
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
# TODO: pass this an emitter method or something? or a session for dispatching?
|
# TODO: pass this an emitter method or something? or a session for dispatching?
|
||||||
@ -148,7 +149,6 @@ class ImageToImageInvocation(TextToImageInvocation):
|
|||||||
self.image.image_type, self.image.image_name
|
self.image.image_type, self.image.image_name
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
mask = None
|
|
||||||
|
|
||||||
if self.fit:
|
if self.fit:
|
||||||
image = image.resize((self.width, self.height))
|
image = image.resize((self.width, self.height))
|
||||||
@ -165,7 +165,6 @@ class ImageToImageInvocation(TextToImageInvocation):
|
|||||||
outputs = Img2Img(model).generate(
|
outputs = Img2Img(model).generate(
|
||||||
prompt=self.prompt,
|
prompt=self.prompt,
|
||||||
init_image=image,
|
init_image=image,
|
||||||
init_mask=mask,
|
|
||||||
step_callback=partial(self.dispatch_progress, context, source_node_id),
|
step_callback=partial(self.dispatch_progress, context, source_node_id),
|
||||||
**self.dict(
|
**self.dict(
|
||||||
exclude={"prompt", "image", "mask"}
|
exclude={"prompt", "image", "mask"}
|
||||||
@ -197,7 +196,6 @@ class ImageToImageInvocation(TextToImageInvocation):
|
|||||||
image=result_image,
|
image=result_image,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class InpaintInvocation(ImageToImageInvocation):
|
class InpaintInvocation(ImageToImageInvocation):
|
||||||
"""Generates an image using inpaint."""
|
"""Generates an image using inpaint."""
|
||||||
|
|
||||||
@ -205,6 +203,17 @@ class InpaintInvocation(ImageToImageInvocation):
|
|||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
mask: Union[ImageField, None] = Field(description="The mask")
|
mask: Union[ImageField, None] = Field(description="The mask")
|
||||||
|
seam_size: int = Field(default=96, ge=1, description="The seam inpaint size (px)")
|
||||||
|
seam_blur: int = Field(default=16, ge=0, description="The seam inpaint blur radius (px)")
|
||||||
|
seam_strength: float = Field(
|
||||||
|
default=0.75, gt=0, le=1, description="The seam inpaint strength"
|
||||||
|
)
|
||||||
|
seam_steps: int = Field(default=30, ge=1, description="The number of steps to use for seam inpaint")
|
||||||
|
tile_size: int = Field(default=32, ge=1, description="The tile infill method size (px)")
|
||||||
|
infill_method: INFILL_METHODS = Field(default=DEFAULT_INFILL_METHOD, description="The method used to infill empty regions (px)")
|
||||||
|
inpaint_width: Optional[int] = Field(default=None, multiple_of=8, gt=0, description="The width of the inpaint region (px)")
|
||||||
|
inpaint_height: Optional[int] = Field(default=None, multiple_of=8, gt=0, description="The height of the inpaint region (px)")
|
||||||
|
inpaint_fill: Optional[ColorField] = Field(default=ColorField(r=127, g=127, b=127, a=255), description="The solid infill method color")
|
||||||
inpaint_replace: float = Field(
|
inpaint_replace: float = Field(
|
||||||
default=0.0,
|
default=0.0,
|
||||||
ge=0.0,
|
ge=0.0,
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
|
import io
|
||||||
from typing import Literal, Optional
|
from typing import Literal, Optional
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
@ -32,14 +33,12 @@ class ImageOutput(BaseInvocationOutput):
|
|||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["image"] = "image"
|
type: Literal["image"] = "image"
|
||||||
image: ImageField = Field(default=None, description="The output image")
|
image: ImageField = Field(default=None, description="The output image")
|
||||||
width: Optional[int] = Field(default=None, description="The width of the image in pixels")
|
width: int = Field(description="The width of the image in pixels")
|
||||||
height: Optional[int] = Field(default=None, description="The height of the image in pixels")
|
height: int = Field(description="The height of the image in pixels")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
schema_extra = {
|
schema_extra = {"required": ["type", "image", "width", "height"]}
|
||||||
"required": ["type", "image", "width", "height", "mode"]
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def build_image_output(
|
def build_image_output(
|
||||||
@ -54,7 +53,6 @@ def build_image_output(
|
|||||||
image=image_field,
|
image=image_field,
|
||||||
width=image.width,
|
width=image.width,
|
||||||
height=image.height,
|
height=image.height,
|
||||||
mode=image.mode,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
233
invokeai/app/invocations/infill.py
Normal file
233
invokeai/app/invocations/infill.py
Normal file
@ -0,0 +1,233 @@
|
|||||||
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
|
from typing import Literal, Optional, Union, get_args
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import math
|
||||||
|
from PIL import Image, ImageOps
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from invokeai.app.invocations.image import ImageOutput, build_image_output
|
||||||
|
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||||
|
from invokeai.backend.image_util.patchmatch import PatchMatch
|
||||||
|
|
||||||
|
from ..models.image import ColorField, ImageField, ImageType
|
||||||
|
from .baseinvocation import (
|
||||||
|
BaseInvocation,
|
||||||
|
InvocationContext,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def infill_methods() -> list[str]:
|
||||||
|
methods = [
|
||||||
|
"tile",
|
||||||
|
"solid",
|
||||||
|
]
|
||||||
|
if PatchMatch.patchmatch_available():
|
||||||
|
methods.insert(0, "patchmatch")
|
||||||
|
return methods
|
||||||
|
|
||||||
|
|
||||||
|
INFILL_METHODS = Literal[tuple(infill_methods())]
|
||||||
|
DEFAULT_INFILL_METHOD = (
|
||||||
|
"patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def infill_patchmatch(im: Image.Image) -> Image.Image:
|
||||||
|
if im.mode != "RGBA":
|
||||||
|
return im
|
||||||
|
|
||||||
|
# Skip patchmatch if patchmatch isn't available
|
||||||
|
if not PatchMatch.patchmatch_available():
|
||||||
|
return im
|
||||||
|
|
||||||
|
# Patchmatch (note, we may want to expose patch_size? Increasing it significantly impacts performance though)
|
||||||
|
im_patched_np = PatchMatch.inpaint(
|
||||||
|
im.convert("RGB"), ImageOps.invert(im.split()[-1]), patch_size=3
|
||||||
|
)
|
||||||
|
im_patched = Image.fromarray(im_patched_np, mode="RGB")
|
||||||
|
return im_patched
|
||||||
|
|
||||||
|
|
||||||
|
def get_tile_images(image: np.ndarray, width=8, height=8):
|
||||||
|
_nrows, _ncols, depth = image.shape
|
||||||
|
_strides = image.strides
|
||||||
|
|
||||||
|
nrows, _m = divmod(_nrows, height)
|
||||||
|
ncols, _n = divmod(_ncols, width)
|
||||||
|
if _m != 0 or _n != 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return np.lib.stride_tricks.as_strided(
|
||||||
|
np.ravel(image),
|
||||||
|
shape=(nrows, ncols, height, width, depth),
|
||||||
|
strides=(height * _strides[0], width * _strides[1], *_strides),
|
||||||
|
writeable=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def tile_fill_missing(
|
||||||
|
im: Image.Image, tile_size: int = 16, seed: Union[int, None] = None
|
||||||
|
) -> Image.Image:
|
||||||
|
# Only fill if there's an alpha layer
|
||||||
|
if im.mode != "RGBA":
|
||||||
|
return im
|
||||||
|
|
||||||
|
a = np.asarray(im, dtype=np.uint8)
|
||||||
|
|
||||||
|
tile_size_tuple = (tile_size, tile_size)
|
||||||
|
|
||||||
|
# Get the image as tiles of a specified size
|
||||||
|
tiles = get_tile_images(a, *tile_size_tuple).copy()
|
||||||
|
|
||||||
|
# Get the mask as tiles
|
||||||
|
tiles_mask = tiles[:, :, :, :, 3]
|
||||||
|
|
||||||
|
# Find any mask tiles with any fully transparent pixels (we will be replacing these later)
|
||||||
|
tmask_shape = tiles_mask.shape
|
||||||
|
tiles_mask = tiles_mask.reshape(math.prod(tiles_mask.shape))
|
||||||
|
n, ny = (math.prod(tmask_shape[0:2])), math.prod(tmask_shape[2:])
|
||||||
|
tiles_mask = tiles_mask > 0
|
||||||
|
tiles_mask = tiles_mask.reshape((n, ny)).all(axis=1)
|
||||||
|
|
||||||
|
# Get RGB tiles in single array and filter by the mask
|
||||||
|
tshape = tiles.shape
|
||||||
|
tiles_all = tiles.reshape((math.prod(tiles.shape[0:2]), *tiles.shape[2:]))
|
||||||
|
filtered_tiles = tiles_all[tiles_mask]
|
||||||
|
|
||||||
|
if len(filtered_tiles) == 0:
|
||||||
|
return im
|
||||||
|
|
||||||
|
# Find all invalid tiles and replace with a random valid tile
|
||||||
|
replace_count = (tiles_mask == False).sum()
|
||||||
|
rng = np.random.default_rng(seed=seed)
|
||||||
|
tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[
|
||||||
|
rng.choice(filtered_tiles.shape[0], replace_count), :, :, :
|
||||||
|
]
|
||||||
|
|
||||||
|
# Convert back to an image
|
||||||
|
tiles_all = tiles_all.reshape(tshape)
|
||||||
|
tiles_all = tiles_all.swapaxes(1, 2)
|
||||||
|
st = tiles_all.reshape(
|
||||||
|
(
|
||||||
|
math.prod(tiles_all.shape[0:2]),
|
||||||
|
math.prod(tiles_all.shape[2:4]),
|
||||||
|
tiles_all.shape[4],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
si = Image.fromarray(st, mode="RGBA")
|
||||||
|
|
||||||
|
return si
|
||||||
|
|
||||||
|
|
||||||
|
class InfillColorInvocation(BaseInvocation):
|
||||||
|
"""Infills transparent areas of an image with a solid color"""
|
||||||
|
|
||||||
|
type: Literal["infill_rgba"] = "infill_rgba"
|
||||||
|
image: Optional[ImageField] = Field(default=None, description="The image to infill")
|
||||||
|
color: Optional[ColorField] = Field(
|
||||||
|
default=ColorField(r=127, g=127, b=127, a=255),
|
||||||
|
description="The color to use to infill",
|
||||||
|
)
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
|
image = context.services.images.get(
|
||||||
|
self.image.image_type, self.image.image_name
|
||||||
|
)
|
||||||
|
|
||||||
|
solid_bg = Image.new("RGBA", image.size, self.color.tuple())
|
||||||
|
infilled = Image.alpha_composite(solid_bg, image)
|
||||||
|
|
||||||
|
infilled.paste(image, (0, 0), image.split()[-1])
|
||||||
|
|
||||||
|
image_type = ImageType.RESULT
|
||||||
|
image_name = context.services.images.create_name(
|
||||||
|
context.graph_execution_state_id, self.id
|
||||||
|
)
|
||||||
|
|
||||||
|
metadata = context.services.metadata.build_metadata(
|
||||||
|
session_id=context.graph_execution_state_id, node=self
|
||||||
|
)
|
||||||
|
|
||||||
|
context.services.images.save(image_type, image_name, infilled, metadata)
|
||||||
|
return build_image_output(
|
||||||
|
image_type=image_type,
|
||||||
|
image_name=image_name,
|
||||||
|
image=image,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class InfillTileInvocation(BaseInvocation):
|
||||||
|
"""Infills transparent areas of an image with tiles of the image"""
|
||||||
|
|
||||||
|
type: Literal["infill_tile"] = "infill_tile"
|
||||||
|
|
||||||
|
image: Optional[ImageField] = Field(default=None, description="The image to infill")
|
||||||
|
tile_size: int = Field(default=32, ge=1, description="The tile size (px)")
|
||||||
|
seed: int = Field(
|
||||||
|
ge=0,
|
||||||
|
le=SEED_MAX,
|
||||||
|
description="The seed to use for tile generation (omit for random)",
|
||||||
|
default_factory=get_random_seed,
|
||||||
|
)
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
|
image = context.services.images.get(
|
||||||
|
self.image.image_type, self.image.image_name
|
||||||
|
)
|
||||||
|
|
||||||
|
infilled = tile_fill_missing(
|
||||||
|
image.copy(), seed=self.seed, tile_size=self.tile_size
|
||||||
|
)
|
||||||
|
infilled.paste(image, (0, 0), image.split()[-1])
|
||||||
|
|
||||||
|
image_type = ImageType.RESULT
|
||||||
|
image_name = context.services.images.create_name(
|
||||||
|
context.graph_execution_state_id, self.id
|
||||||
|
)
|
||||||
|
|
||||||
|
metadata = context.services.metadata.build_metadata(
|
||||||
|
session_id=context.graph_execution_state_id, node=self
|
||||||
|
)
|
||||||
|
|
||||||
|
context.services.images.save(image_type, image_name, infilled, metadata)
|
||||||
|
return build_image_output(
|
||||||
|
image_type=image_type,
|
||||||
|
image_name=image_name,
|
||||||
|
image=image,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class InfillPatchMatchInvocation(BaseInvocation):
|
||||||
|
"""Infills transparent areas of an image using the PatchMatch algorithm"""
|
||||||
|
|
||||||
|
type: Literal["infill_patchmatch"] = "infill_patchmatch"
|
||||||
|
|
||||||
|
image: Optional[ImageField] = Field(default=None, description="The image to infill")
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
|
image = context.services.images.get(
|
||||||
|
self.image.image_type, self.image.image_name
|
||||||
|
)
|
||||||
|
|
||||||
|
if PatchMatch.patchmatch_available():
|
||||||
|
infilled = infill_patchmatch(image.copy())
|
||||||
|
else:
|
||||||
|
raise ValueError("PatchMatch is not available on this system")
|
||||||
|
|
||||||
|
image_type = ImageType.RESULT
|
||||||
|
image_name = context.services.images.create_name(
|
||||||
|
context.graph_execution_state_id, self.id
|
||||||
|
)
|
||||||
|
|
||||||
|
metadata = context.services.metadata.build_metadata(
|
||||||
|
session_id=context.graph_execution_state_id, node=self
|
||||||
|
)
|
||||||
|
|
||||||
|
context.services.images.save(image_type, image_name, infilled, metadata)
|
||||||
|
return build_image_output(
|
||||||
|
image_type=image_type,
|
||||||
|
image_name=image_name,
|
||||||
|
image=image,
|
||||||
|
)
|
@ -1,11 +1,13 @@
|
|||||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
import random
|
import random
|
||||||
from typing import Literal, Optional
|
from typing import Literal, Optional, Union
|
||||||
|
import einops
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from invokeai.app.invocations.util.choose_model import choose_model
|
from invokeai.app.invocations.util.choose_model import choose_model
|
||||||
|
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||||
|
|
||||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||||
|
|
||||||
@ -13,7 +15,9 @@ from ...backend.model_management.model_manager import ModelManager
|
|||||||
from ...backend.util.devices import choose_torch_device, torch_dtype
|
from ...backend.util.devices import choose_torch_device, torch_dtype
|
||||||
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
||||||
from ...backend.image_util.seamless import configure_model_padding
|
from ...backend.image_util.seamless import configure_model_padding
|
||||||
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline
|
from ...backend.prompting.conditioning import get_uc_and_c_and_ec
|
||||||
|
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline, image_resized_to_grid_as_tensor
|
||||||
|
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from ..services.image_storage import ImageType
|
from ..services.image_storage import ImageType
|
||||||
@ -37,41 +41,55 @@ class LatentsField(BaseModel):
|
|||||||
class LatentsOutput(BaseInvocationOutput):
|
class LatentsOutput(BaseInvocationOutput):
|
||||||
"""Base class for invocations that output latents"""
|
"""Base class for invocations that output latents"""
|
||||||
#fmt: off
|
#fmt: off
|
||||||
type: Literal["latent_output"] = "latent_output"
|
type: Literal["latents_output"] = "latents_output"
|
||||||
|
|
||||||
|
# Inputs
|
||||||
latents: LatentsField = Field(default=None, description="The output latents")
|
latents: LatentsField = Field(default=None, description="The output latents")
|
||||||
|
width: int = Field(description="The width of the latents in pixels")
|
||||||
|
height: int = Field(description="The height of the latents in pixels")
|
||||||
#fmt: on
|
#fmt: on
|
||||||
|
|
||||||
|
|
||||||
|
def build_latents_output(latents_name: str, latents: torch.Tensor):
|
||||||
|
return LatentsOutput(
|
||||||
|
latents=LatentsField(latents_name=latents_name),
|
||||||
|
width=latents.size()[3] * 8,
|
||||||
|
height=latents.size()[2] * 8,
|
||||||
|
)
|
||||||
|
|
||||||
class NoiseOutput(BaseInvocationOutput):
|
class NoiseOutput(BaseInvocationOutput):
|
||||||
"""Invocation noise output"""
|
"""Invocation noise output"""
|
||||||
#fmt: off
|
#fmt: off
|
||||||
type: Literal["noise_output"] = "noise_output"
|
type: Literal["noise_output"] = "noise_output"
|
||||||
|
|
||||||
|
# Inputs
|
||||||
noise: LatentsField = Field(default=None, description="The output noise")
|
noise: LatentsField = Field(default=None, description="The output noise")
|
||||||
|
width: int = Field(description="The width of the noise in pixels")
|
||||||
|
height: int = Field(description="The height of the noise in pixels")
|
||||||
#fmt: on
|
#fmt: on
|
||||||
|
|
||||||
|
def build_noise_output(latents_name: str, latents: torch.Tensor):
|
||||||
# TODO: this seems like a hack
|
return NoiseOutput(
|
||||||
scheduler_map = dict(
|
noise=LatentsField(latents_name=latents_name),
|
||||||
ddim=diffusers.DDIMScheduler,
|
width=latents.size()[3] * 8,
|
||||||
dpmpp_2=diffusers.DPMSolverMultistepScheduler,
|
height=latents.size()[2] * 8,
|
||||||
k_dpm_2=diffusers.KDPM2DiscreteScheduler,
|
|
||||||
k_dpm_2_a=diffusers.KDPM2AncestralDiscreteScheduler,
|
|
||||||
k_dpmpp_2=diffusers.DPMSolverMultistepScheduler,
|
|
||||||
k_euler=diffusers.EulerDiscreteScheduler,
|
|
||||||
k_euler_a=diffusers.EulerAncestralDiscreteScheduler,
|
|
||||||
k_heun=diffusers.HeunDiscreteScheduler,
|
|
||||||
k_lms=diffusers.LMSDiscreteScheduler,
|
|
||||||
plms=diffusers.PNDMScheduler,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
SAMPLER_NAME_VALUES = Literal[
|
SAMPLER_NAME_VALUES = Literal[
|
||||||
tuple(list(scheduler_map.keys()))
|
tuple(list(SCHEDULER_MAP.keys()))
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def get_scheduler(scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
|
def get_scheduler(scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
|
||||||
scheduler_class = scheduler_map.get(scheduler_name,'ddim')
|
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])
|
||||||
scheduler = scheduler_class.from_config(model.scheduler.config)
|
|
||||||
|
scheduler_config = model.scheduler.config
|
||||||
|
if "_backup" in scheduler_config:
|
||||||
|
scheduler_config = scheduler_config["_backup"]
|
||||||
|
scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config}
|
||||||
|
scheduler = scheduler_class.from_config(scheduler_config)
|
||||||
|
|
||||||
# hack copied over from generate.py
|
# hack copied over from generate.py
|
||||||
if not hasattr(scheduler, 'uses_inpainting_model'):
|
if not hasattr(scheduler, 'uses_inpainting_model'):
|
||||||
scheduler.uses_inpainting_model = lambda: False
|
scheduler.uses_inpainting_model = lambda: False
|
||||||
@ -102,17 +120,13 @@ def get_noise(width:int, height:int, device:torch.device, seed:int = 0, latent_c
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
def random_seed():
|
|
||||||
return random.randint(0, np.iinfo(np.uint32).max)
|
|
||||||
|
|
||||||
|
|
||||||
class NoiseInvocation(BaseInvocation):
|
class NoiseInvocation(BaseInvocation):
|
||||||
"""Generates latent noise."""
|
"""Generates latent noise."""
|
||||||
|
|
||||||
type: Literal["noise"] = "noise"
|
type: Literal["noise"] = "noise"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
seed: int = Field(ge=0, le=np.iinfo(np.uint32).max, description="The seed to use", default_factory=random_seed)
|
seed: int = Field(ge=0, le=SEED_MAX, description="The seed to use", default_factory=get_random_seed)
|
||||||
width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting noise", )
|
width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting noise", )
|
||||||
height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting noise", )
|
height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting noise", )
|
||||||
|
|
||||||
@ -131,9 +145,7 @@ class NoiseInvocation(BaseInvocation):
|
|||||||
|
|
||||||
name = f'{context.graph_execution_state_id}__{self.id}'
|
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||||
context.services.latents.set(name, noise)
|
context.services.latents.set(name, noise)
|
||||||
return NoiseOutput(
|
return build_noise_output(latents_name=name, latents=noise)
|
||||||
noise=LatentsField(latents_name=name)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Text to image
|
# Text to image
|
||||||
@ -149,11 +161,10 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
noise: Optional[LatentsField] = Field(description="The noise to use")
|
noise: Optional[LatentsField] = Field(description="The noise to use")
|
||||||
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
||||||
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||||
scheduler: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The scheduler to use" )
|
scheduler: SAMPLER_NAME_VALUES = Field(default="lms", description="The scheduler to use" )
|
||||||
|
model: str = Field(default="", description="The model to use (currently ignored)")
|
||||||
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||||
seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||||
model: str = Field(default="", description="The model to use (currently ignored)")
|
|
||||||
progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
|
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
# Schema customisation
|
# Schema customisation
|
||||||
@ -218,7 +229,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
h_symmetry_time_pct=None,#h_symmetry_time_pct,
|
h_symmetry_time_pct=None,#h_symmetry_time_pct,
|
||||||
v_symmetry_time_pct=None#v_symmetry_time_pct,
|
v_symmetry_time_pct=None#v_symmetry_time_pct,
|
||||||
),
|
),
|
||||||
).add_scheduler_args_if_applicable(model.scheduler, eta=None)#ddim_eta)
|
).add_scheduler_args_if_applicable(model.scheduler, eta=0.0)#ddim_eta)
|
||||||
return conditioning_data
|
return conditioning_data
|
||||||
|
|
||||||
|
|
||||||
@ -250,9 +261,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
name = f'{context.graph_execution_state_id}__{self.id}'
|
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||||
context.services.latents.set(name, result_latents)
|
context.services.latents.set(name, result_latents)
|
||||||
return LatentsOutput(
|
return build_latents_output(latents_name=name, latents=result_latents)
|
||||||
latents=LatentsField(latents_name=name)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||||
@ -260,6 +269,10 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
|||||||
|
|
||||||
type: Literal["l2l"] = "l2l"
|
type: Literal["l2l"] = "l2l"
|
||||||
|
|
||||||
|
# Inputs
|
||||||
|
latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
|
||||||
|
strength: float = Field(default=0.5, description="The strength of the latents to use")
|
||||||
|
|
||||||
# Schema customisation
|
# Schema customisation
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
@ -271,10 +284,6 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
# Inputs
|
|
||||||
latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
|
|
||||||
strength: float = Field(default=0.5, description="The strength of the latents to use")
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
noise = context.services.latents.get(self.noise.latents_name)
|
noise = context.services.latents.get(self.noise.latents_name)
|
||||||
latent = context.services.latents.get(self.latents.latents_name)
|
latent = context.services.latents.get(self.latents.latents_name)
|
||||||
@ -287,7 +296,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
|||||||
self.dispatch_progress(context, source_node_id, state)
|
self.dispatch_progress(context, source_node_id, state)
|
||||||
|
|
||||||
model = self.get_model(context.services.model_manager)
|
model = self.get_model(context.services.model_manager)
|
||||||
conditioning_data = self.get_conditioning_data(model)
|
conditioning_data = self.get_conditioning_data(context, model)
|
||||||
|
|
||||||
# TODO: Verify the noise is the right size
|
# TODO: Verify the noise is the right size
|
||||||
|
|
||||||
@ -295,11 +304,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
|||||||
latent, device=model.device, dtype=latent.dtype
|
latent, device=model.device, dtype=latent.dtype
|
||||||
)
|
)
|
||||||
|
|
||||||
timesteps, _ = model.get_img2img_timesteps(
|
timesteps, _ = model.get_img2img_timesteps(self.steps, self.strength)
|
||||||
self.steps,
|
|
||||||
self.strength,
|
|
||||||
device=model.device,
|
|
||||||
)
|
|
||||||
|
|
||||||
result_latents, result_attention_map_saver = model.latents_from_embeddings(
|
result_latents, result_attention_map_saver = model.latents_from_embeddings(
|
||||||
latents=initial_latents,
|
latents=initial_latents,
|
||||||
@ -315,9 +320,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
|||||||
|
|
||||||
name = f'{context.graph_execution_state_id}__{self.id}'
|
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||||
context.services.latents.set(name, result_latents)
|
context.services.latents.set(name, result_latents)
|
||||||
return LatentsOutput(
|
return build_latents_output(latents_name=name, latents=result_latents)
|
||||||
latents=LatentsField(latents_name=name)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Latent to image
|
# Latent to image
|
||||||
@ -384,8 +387,8 @@ class ResizeLatentsInvocation(BaseInvocation):
|
|||||||
latents: Optional[LatentsField] = Field(description="The latents to resize")
|
latents: Optional[LatentsField] = Field(description="The latents to resize")
|
||||||
width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
|
width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
|
||||||
height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
|
height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
|
||||||
mode: Optional[LATENTS_INTERPOLATION_MODE] = Field(default="bilinear", description="The interpolation mode")
|
mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
|
||||||
antialias: Optional[bool] = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
|
antialias: bool = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
latents = context.services.latents.get(self.latents.latents_name)
|
latents = context.services.latents.get(self.latents.latents_name)
|
||||||
@ -402,7 +405,7 @@ class ResizeLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||||
context.services.latents.set(name, resized_latents)
|
context.services.latents.set(name, resized_latents)
|
||||||
return LatentsOutput(latents=LatentsField(latents_name=name))
|
return build_latents_output(latents_name=name, latents=resized_latents)
|
||||||
|
|
||||||
|
|
||||||
class ScaleLatentsInvocation(BaseInvocation):
|
class ScaleLatentsInvocation(BaseInvocation):
|
||||||
@ -413,8 +416,8 @@ class ScaleLatentsInvocation(BaseInvocation):
|
|||||||
# Inputs
|
# Inputs
|
||||||
latents: Optional[LatentsField] = Field(description="The latents to scale")
|
latents: Optional[LatentsField] = Field(description="The latents to scale")
|
||||||
scale_factor: float = Field(gt=0, description="The factor by which to scale the latents")
|
scale_factor: float = Field(gt=0, description="The factor by which to scale the latents")
|
||||||
mode: Optional[LATENTS_INTERPOLATION_MODE] = Field(default="bilinear", description="The interpolation mode")
|
mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
|
||||||
antialias: Optional[bool] = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
|
antialias: bool = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
latents = context.services.latents.get(self.latents.latents_name)
|
latents = context.services.latents.get(self.latents.latents_name)
|
||||||
@ -432,4 +435,48 @@ class ScaleLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||||
context.services.latents.set(name, resized_latents)
|
context.services.latents.set(name, resized_latents)
|
||||||
return LatentsOutput(latents=LatentsField(latents_name=name))
|
return build_latents_output(latents_name=name, latents=resized_latents)
|
||||||
|
|
||||||
|
|
||||||
|
class ImageToLatentsInvocation(BaseInvocation):
|
||||||
|
"""Encodes an image into latents."""
|
||||||
|
|
||||||
|
type: Literal["i2l"] = "i2l"
|
||||||
|
|
||||||
|
# Inputs
|
||||||
|
image: Union[ImageField, None] = Field(description="The image to encode")
|
||||||
|
model: str = Field(default="", description="The model to use")
|
||||||
|
|
||||||
|
# Schema customisation
|
||||||
|
class Config(InvocationConfig):
|
||||||
|
schema_extra = {
|
||||||
|
"ui": {
|
||||||
|
"tags": ["latents", "image"],
|
||||||
|
"type_hints": {"model": "model"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
|
image = context.services.images.get(
|
||||||
|
self.image.image_type, self.image.image_name
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: this only really needs the vae
|
||||||
|
model_info = choose_model(context.services.model_manager, self.model)
|
||||||
|
model: StableDiffusionGeneratorPipeline = model_info["model"]
|
||||||
|
|
||||||
|
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||||
|
|
||||||
|
if image_tensor.dim() == 3:
|
||||||
|
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
|
||||||
|
|
||||||
|
latents = model.non_noised_latents_from_image(
|
||||||
|
image_tensor,
|
||||||
|
device=model._model_group.device_for(model.unet),
|
||||||
|
dtype=model.unet.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||||
|
context.services.latents.set(name, latents)
|
||||||
|
return build_latents_output(latents_name=name, latents=latents)
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
||||||
|
|
||||||
@ -73,3 +74,12 @@ class DivideInvocation(BaseInvocation, MathInvocationConfig):
|
|||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||||
return IntOutput(a=int(self.a / self.b))
|
return IntOutput(a=int(self.a / self.b))
|
||||||
|
|
||||||
|
|
||||||
|
class RandomIntInvocation(BaseInvocation):
|
||||||
|
"""Outputs a single random integer."""
|
||||||
|
#fmt: off
|
||||||
|
type: Literal["rand_int"] = "rand_int"
|
||||||
|
#fmt: on
|
||||||
|
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||||
|
return IntOutput(a=np.random.randint(0, np.iinfo(np.int32).max))
|
||||||
|
@ -4,10 +4,11 @@ from invokeai.backend.model_management.model_manager import ModelManager
|
|||||||
def choose_model(model_manager: ModelManager, model_name: str):
|
def choose_model(model_manager: ModelManager, model_name: str):
|
||||||
"""Returns the default model if the `model_name` not a valid model, else returns the selected model."""
|
"""Returns the default model if the `model_name` not a valid model, else returns the selected model."""
|
||||||
logger = model_manager.logger
|
logger = model_manager.logger
|
||||||
if model_manager.valid_model(model_name):
|
if model_name and not model_manager.valid_model(model_name):
|
||||||
model = model_manager.get_model(model_name)
|
default_model_name = model_manager.default_model()
|
||||||
else:
|
logger.warning(f"\'{model_name}\' is not a valid model name. Using default model \'{default_model_name}\' instead.")
|
||||||
model = model_manager.get_model()
|
model = model_manager.get_model()
|
||||||
logger.warning(f"{model_name}' is not a valid model name. Using default model \'{model['model_name']}\' instead.")
|
else:
|
||||||
|
model = model_manager.get_model(model_name)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Optional
|
from typing import Optional, Tuple
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
@ -27,3 +27,13 @@ class ImageField(BaseModel):
|
|||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
schema_extra = {"required": ["image_type", "image_name"]}
|
schema_extra = {"required": ["image_type", "image_name"]}
|
||||||
|
|
||||||
|
|
||||||
|
class ColorField(BaseModel):
|
||||||
|
r: int = Field(ge=0, le=255, description="The red component")
|
||||||
|
g: int = Field(ge=0, le=255, description="The green component")
|
||||||
|
b: int = Field(ge=0, le=255, description="The blue component")
|
||||||
|
a: int = Field(ge=0, le=255, description="The alpha component")
|
||||||
|
|
||||||
|
def tuple(self) -> Tuple[int, int, int, int]:
|
||||||
|
return (self.r, self.g, self.b, self.a)
|
||||||
|
@ -49,12 +49,13 @@ def create_text_to_image() -> LibraryGraph:
|
|||||||
def create_system_graphs(graph_library: ItemStorageABC[LibraryGraph]) -> list[LibraryGraph]:
|
def create_system_graphs(graph_library: ItemStorageABC[LibraryGraph]) -> list[LibraryGraph]:
|
||||||
"""Creates the default system graphs, or adds new versions if the old ones don't match"""
|
"""Creates the default system graphs, or adds new versions if the old ones don't match"""
|
||||||
|
|
||||||
|
# TODO: Uncomment this when we are ready to fix this up to prevent breaking changes
|
||||||
graphs: list[LibraryGraph] = list()
|
graphs: list[LibraryGraph] = list()
|
||||||
|
|
||||||
text_to_image = graph_library.get(default_text_to_image_graph_id)
|
# text_to_image = graph_library.get(default_text_to_image_graph_id)
|
||||||
|
|
||||||
# TODO: Check if the graph is the same as the default one, and if not, update it
|
# # TODO: Check if the graph is the same as the default one, and if not, update it
|
||||||
#if text_to_image is None:
|
# #if text_to_image is None:
|
||||||
text_to_image = create_text_to_image()
|
text_to_image = create_text_to_image()
|
||||||
graph_library.set(text_to_image)
|
graph_library.set(text_to_image)
|
||||||
|
|
||||||
|
@ -270,4 +270,5 @@ class DiskImageStorage(ImageStorageBase):
|
|||||||
) # TODO: this should refresh position for LRU cache
|
) # TODO: this should refresh position for LRU cache
|
||||||
if len(self.__cache) > self.__max_cache_size:
|
if len(self.__cache) > self.__max_cache_size:
|
||||||
cache_id = self.__cache_ids.get()
|
cache_id = self.__cache_ids.get()
|
||||||
|
if cache_id in self.__cache:
|
||||||
del self.__cache[cache_id]
|
del self.__cache[cache_id]
|
||||||
|
@ -20,9 +20,18 @@ class MetadataLatentsField(TypedDict):
|
|||||||
latents_name: str
|
latents_name: str
|
||||||
|
|
||||||
|
|
||||||
|
class MetadataColorField(TypedDict):
|
||||||
|
"""Pydantic-less ColorField, used for metadata parsing"""
|
||||||
|
r: int
|
||||||
|
g: int
|
||||||
|
b: int
|
||||||
|
a: int
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: This is a placeholder for `InvocationsUnion` pending resolution of circular imports
|
# TODO: This is a placeholder for `InvocationsUnion` pending resolution of circular imports
|
||||||
NodeMetadata = Dict[
|
NodeMetadata = Dict[
|
||||||
str, str | int | float | bool | MetadataImageField | MetadataLatentsField
|
str, None | str | int | float | bool | MetadataImageField | MetadataLatentsField | MetadataColorField
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from threading import Event, Thread, BoundedSemaphore
|
from threading import Event, Thread, BoundedSemaphore
|
||||||
|
|
||||||
@ -6,6 +7,7 @@ from .invocation_queue import InvocationQueueItem
|
|||||||
from .invoker import InvocationProcessorABC, Invoker
|
from .invoker import InvocationProcessorABC, Invoker
|
||||||
from ..models.exceptions import CanceledException
|
from ..models.exceptions import CanceledException
|
||||||
|
|
||||||
|
import invokeai.backend.util.logging as logger
|
||||||
class DefaultInvocationProcessor(InvocationProcessorABC):
|
class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||||
__invoker_thread: Thread
|
__invoker_thread: Thread
|
||||||
__stop_event: Event
|
__stop_event: Event
|
||||||
@ -34,8 +36,14 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
try:
|
try:
|
||||||
self.__threadLimit.acquire()
|
self.__threadLimit.acquire()
|
||||||
while not stop_event.is_set():
|
while not stop_event.is_set():
|
||||||
|
try:
|
||||||
queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
|
queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug("Exception while getting from queue: %s" % e)
|
||||||
|
|
||||||
if not queue_item: # Probably stopping
|
if not queue_item: # Probably stopping
|
||||||
|
# do not hammer the queue
|
||||||
|
time.sleep(0.5)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
graph_execution_state = (
|
graph_execution_state = (
|
||||||
@ -124,7 +132,16 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
# Queue any further commands if invoking all
|
# Queue any further commands if invoking all
|
||||||
is_complete = graph_execution_state.is_complete()
|
is_complete = graph_execution_state.is_complete()
|
||||||
if queue_item.invoke_all and not is_complete:
|
if queue_item.invoke_all and not is_complete:
|
||||||
|
try:
|
||||||
self.__invoker.invoke(graph_execution_state, invoke_all=True)
|
self.__invoker.invoke(graph_execution_state, invoke_all=True)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error while invoking: %s" % e)
|
||||||
|
self.__invoker.services.events.emit_invocation_error(
|
||||||
|
graph_execution_state_id=graph_execution_state.id,
|
||||||
|
node=invocation.dict(),
|
||||||
|
source_node_id=source_node_id,
|
||||||
|
error=traceback.format_exc()
|
||||||
|
)
|
||||||
elif is_complete:
|
elif is_complete:
|
||||||
self.__invoker.services.events.emit_graph_execution_complete(
|
self.__invoker.services.events.emit_graph_execution_complete(
|
||||||
graph_execution_state.id
|
graph_execution_state.id
|
||||||
|
@ -1,5 +1,13 @@
|
|||||||
import datetime
|
import datetime
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def get_timestamp():
|
def get_timestamp():
|
||||||
return int(datetime.datetime.now(datetime.timezone.utc).timestamp())
|
return int(datetime.datetime.now(datetime.timezone.utc).timestamp())
|
||||||
|
|
||||||
|
|
||||||
|
SEED_MAX = np.iinfo(np.int32).max
|
||||||
|
|
||||||
|
|
||||||
|
def get_random_seed():
|
||||||
|
return np.random.randint(0, SEED_MAX)
|
||||||
|
@ -31,6 +31,7 @@ from ..util.util import rand_perlin_2d
|
|||||||
from ..safety_checker import SafetyChecker
|
from ..safety_checker import SafetyChecker
|
||||||
from ..prompting.conditioning import get_uc_and_c_and_ec
|
from ..prompting.conditioning import get_uc_and_c_and_ec
|
||||||
from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
|
from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
|
||||||
|
from ..stable_diffusion.schedulers import SCHEDULER_MAP
|
||||||
|
|
||||||
downsampling = 8
|
downsampling = 8
|
||||||
|
|
||||||
@ -71,19 +72,6 @@ class InvokeAIGeneratorOutput:
|
|||||||
# we are interposing a wrapper around the original Generator classes so that
|
# we are interposing a wrapper around the original Generator classes so that
|
||||||
# old code that calls Generate will continue to work.
|
# old code that calls Generate will continue to work.
|
||||||
class InvokeAIGenerator(metaclass=ABCMeta):
|
class InvokeAIGenerator(metaclass=ABCMeta):
|
||||||
scheduler_map = dict(
|
|
||||||
ddim=diffusers.DDIMScheduler,
|
|
||||||
dpmpp_2=diffusers.DPMSolverMultistepScheduler,
|
|
||||||
k_dpm_2=diffusers.KDPM2DiscreteScheduler,
|
|
||||||
k_dpm_2_a=diffusers.KDPM2AncestralDiscreteScheduler,
|
|
||||||
k_dpmpp_2=diffusers.DPMSolverMultistepScheduler,
|
|
||||||
k_euler=diffusers.EulerDiscreteScheduler,
|
|
||||||
k_euler_a=diffusers.EulerAncestralDiscreteScheduler,
|
|
||||||
k_heun=diffusers.HeunDiscreteScheduler,
|
|
||||||
k_lms=diffusers.LMSDiscreteScheduler,
|
|
||||||
plms=diffusers.PNDMScheduler,
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
model_info: dict,
|
model_info: dict,
|
||||||
params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
|
params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
|
||||||
@ -175,14 +163,20 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
|||||||
'''
|
'''
|
||||||
Return list of all the schedulers that we currently handle.
|
Return list of all the schedulers that we currently handle.
|
||||||
'''
|
'''
|
||||||
return list(self.scheduler_map.keys())
|
return list(SCHEDULER_MAP.keys())
|
||||||
|
|
||||||
def load_generator(self, model: StableDiffusionGeneratorPipeline, generator_class: Type[Generator]):
|
def load_generator(self, model: StableDiffusionGeneratorPipeline, generator_class: Type[Generator]):
|
||||||
return generator_class(model, self.params.precision)
|
return generator_class(model, self.params.precision)
|
||||||
|
|
||||||
def get_scheduler(self, scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
|
def get_scheduler(self, scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
|
||||||
scheduler_class = self.scheduler_map.get(scheduler_name,'ddim')
|
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])
|
||||||
scheduler = scheduler_class.from_config(model.scheduler.config)
|
|
||||||
|
scheduler_config = model.scheduler.config
|
||||||
|
if "_backup" in scheduler_config:
|
||||||
|
scheduler_config = scheduler_config["_backup"]
|
||||||
|
scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config}
|
||||||
|
scheduler = scheduler_class.from_config(scheduler_config)
|
||||||
|
|
||||||
# hack copied over from generate.py
|
# hack copied over from generate.py
|
||||||
if not hasattr(scheduler, 'uses_inpainting_model'):
|
if not hasattr(scheduler, 'uses_inpainting_model'):
|
||||||
scheduler.uses_inpainting_model = lambda: False
|
scheduler.uses_inpainting_model = lambda: False
|
||||||
@ -226,10 +220,10 @@ class Inpaint(Img2Img):
|
|||||||
def generate(self,
|
def generate(self,
|
||||||
mask_image: Image.Image | torch.FloatTensor,
|
mask_image: Image.Image | torch.FloatTensor,
|
||||||
# Seam settings - when 0, doesn't fill seam
|
# Seam settings - when 0, doesn't fill seam
|
||||||
seam_size: int = 0,
|
seam_size: int = 96,
|
||||||
seam_blur: int = 0,
|
seam_blur: int = 16,
|
||||||
seam_strength: float = 0.7,
|
seam_strength: float = 0.7,
|
||||||
seam_steps: int = 10,
|
seam_steps: int = 30,
|
||||||
tile_size: int = 32,
|
tile_size: int = 32,
|
||||||
inpaint_replace=False,
|
inpaint_replace=False,
|
||||||
infill_method=None,
|
infill_method=None,
|
||||||
|
@ -4,6 +4,7 @@ invokeai.backend.generator.inpaint descends from .generator
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import math
|
import math
|
||||||
|
from typing import Tuple, Union
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -59,7 +60,7 @@ class Inpaint(Img2Img):
|
|||||||
writeable=False,
|
writeable=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
def infill_patchmatch(self, im: Image.Image) -> Image:
|
def infill_patchmatch(self, im: Image.Image) -> Image.Image:
|
||||||
if im.mode != "RGBA":
|
if im.mode != "RGBA":
|
||||||
return im
|
return im
|
||||||
|
|
||||||
@ -75,18 +76,18 @@ class Inpaint(Img2Img):
|
|||||||
return im_patched
|
return im_patched
|
||||||
|
|
||||||
def tile_fill_missing(
|
def tile_fill_missing(
|
||||||
self, im: Image.Image, tile_size: int = 16, seed: int = None
|
self, im: Image.Image, tile_size: int = 16, seed: Union[int, None] = None
|
||||||
) -> Image:
|
) -> Image.Image:
|
||||||
# Only fill if there's an alpha layer
|
# Only fill if there's an alpha layer
|
||||||
if im.mode != "RGBA":
|
if im.mode != "RGBA":
|
||||||
return im
|
return im
|
||||||
|
|
||||||
a = np.asarray(im, dtype=np.uint8)
|
a = np.asarray(im, dtype=np.uint8)
|
||||||
|
|
||||||
tile_size = (tile_size, tile_size)
|
tile_size_tuple = (tile_size, tile_size)
|
||||||
|
|
||||||
# Get the image as tiles of a specified size
|
# Get the image as tiles of a specified size
|
||||||
tiles = self.get_tile_images(a, *tile_size).copy()
|
tiles = self.get_tile_images(a, *tile_size_tuple).copy()
|
||||||
|
|
||||||
# Get the mask as tiles
|
# Get the mask as tiles
|
||||||
tiles_mask = tiles[:, :, :, :, 3]
|
tiles_mask = tiles[:, :, :, :, 3]
|
||||||
@ -127,7 +128,9 @@ class Inpaint(Img2Img):
|
|||||||
|
|
||||||
return si
|
return si
|
||||||
|
|
||||||
def mask_edge(self, mask: Image, edge_size: int, edge_blur: int) -> Image:
|
def mask_edge(
|
||||||
|
self, mask: Image.Image, edge_size: int, edge_blur: int
|
||||||
|
) -> Image.Image:
|
||||||
npimg = np.asarray(mask, dtype=np.uint8)
|
npimg = np.asarray(mask, dtype=np.uint8)
|
||||||
|
|
||||||
# Detect any partially transparent regions
|
# Detect any partially transparent regions
|
||||||
@ -206,15 +209,15 @@ class Inpaint(Img2Img):
|
|||||||
cfg_scale,
|
cfg_scale,
|
||||||
ddim_eta,
|
ddim_eta,
|
||||||
conditioning,
|
conditioning,
|
||||||
init_image: PIL.Image.Image | torch.FloatTensor,
|
init_image: Image.Image | torch.FloatTensor,
|
||||||
mask_image: PIL.Image.Image | torch.FloatTensor,
|
mask_image: Image.Image | torch.FloatTensor,
|
||||||
strength: float,
|
strength: float,
|
||||||
mask_blur_radius: int = 8,
|
mask_blur_radius: int = 8,
|
||||||
# Seam settings - when 0, doesn't fill seam
|
# Seam settings - when 0, doesn't fill seam
|
||||||
seam_size: int = 0,
|
seam_size: int = 96,
|
||||||
seam_blur: int = 0,
|
seam_blur: int = 16,
|
||||||
seam_strength: float = 0.7,
|
seam_strength: float = 0.7,
|
||||||
seam_steps: int = 10,
|
seam_steps: int = 30,
|
||||||
tile_size: int = 32,
|
tile_size: int = 32,
|
||||||
step_callback=None,
|
step_callback=None,
|
||||||
inpaint_replace=False,
|
inpaint_replace=False,
|
||||||
@ -222,7 +225,7 @@ class Inpaint(Img2Img):
|
|||||||
infill_method=None,
|
infill_method=None,
|
||||||
inpaint_width=None,
|
inpaint_width=None,
|
||||||
inpaint_height=None,
|
inpaint_height=None,
|
||||||
inpaint_fill: tuple(int) = (0x7F, 0x7F, 0x7F, 0xFF),
|
inpaint_fill: Tuple[int, int, int, int] = (0x7F, 0x7F, 0x7F, 0xFF),
|
||||||
attention_maps_callback=None,
|
attention_maps_callback=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
@ -239,7 +242,7 @@ class Inpaint(Img2Img):
|
|||||||
self.inpaint_width = inpaint_width
|
self.inpaint_width = inpaint_width
|
||||||
self.inpaint_height = inpaint_height
|
self.inpaint_height = inpaint_height
|
||||||
|
|
||||||
if isinstance(init_image, PIL.Image.Image):
|
if isinstance(init_image, Image.Image):
|
||||||
self.pil_image = init_image.copy()
|
self.pil_image = init_image.copy()
|
||||||
|
|
||||||
# Do infill
|
# Do infill
|
||||||
@ -250,8 +253,8 @@ class Inpaint(Img2Img):
|
|||||||
self.pil_image.copy(), seed=self.seed, tile_size=tile_size
|
self.pil_image.copy(), seed=self.seed, tile_size=tile_size
|
||||||
)
|
)
|
||||||
elif infill_method == "solid":
|
elif infill_method == "solid":
|
||||||
solid_bg = PIL.Image.new("RGBA", init_image.size, inpaint_fill)
|
solid_bg = Image.new("RGBA", init_image.size, inpaint_fill)
|
||||||
init_filled = PIL.Image.alpha_composite(solid_bg, init_image)
|
init_filled = Image.alpha_composite(solid_bg, init_image)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Non-supported infill type {infill_method}", infill_method
|
f"Non-supported infill type {infill_method}", infill_method
|
||||||
@ -269,7 +272,7 @@ class Inpaint(Img2Img):
|
|||||||
# Create init tensor
|
# Create init tensor
|
||||||
init_image = image_resized_to_grid_as_tensor(init_filled.convert("RGB"))
|
init_image = image_resized_to_grid_as_tensor(init_filled.convert("RGB"))
|
||||||
|
|
||||||
if isinstance(mask_image, PIL.Image.Image):
|
if isinstance(mask_image, Image.Image):
|
||||||
self.pil_mask = mask_image.copy()
|
self.pil_mask = mask_image.copy()
|
||||||
debug_image(
|
debug_image(
|
||||||
mask_image,
|
mask_image,
|
||||||
|
@ -47,6 +47,7 @@ from diffusers import (
|
|||||||
LDMTextToImagePipeline,
|
LDMTextToImagePipeline,
|
||||||
LMSDiscreteScheduler,
|
LMSDiscreteScheduler,
|
||||||
PNDMScheduler,
|
PNDMScheduler,
|
||||||
|
UniPCMultistepScheduler,
|
||||||
StableDiffusionPipeline,
|
StableDiffusionPipeline,
|
||||||
UNet2DConditionModel,
|
UNet2DConditionModel,
|
||||||
)
|
)
|
||||||
@ -1208,6 +1209,8 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
|||||||
scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config)
|
scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config)
|
||||||
elif scheduler_type == "dpm":
|
elif scheduler_type == "dpm":
|
||||||
scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
|
scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
|
||||||
|
elif scheduler_type == 'unipc':
|
||||||
|
scheduler = UniPCMultistepScheduler.from_config(scheduler.config)
|
||||||
elif scheduler_type == "ddim":
|
elif scheduler_type == "ddim":
|
||||||
scheduler = scheduler
|
scheduler = scheduler
|
||||||
else:
|
else:
|
||||||
|
@ -1214,7 +1214,7 @@ class ModelManager(object):
|
|||||||
sha.update(chunk)
|
sha.update(chunk)
|
||||||
hash = sha.hexdigest()
|
hash = sha.hexdigest()
|
||||||
toc = time.time()
|
toc = time.time()
|
||||||
self.logger.debug(f"sha256 = {hash} ({count} files hashed in", "%4.2fs)" % (toc - tic))
|
self.logger.debug(f"sha256 = {hash} ({count} files hashed in {toc - tic:4.2f}s)")
|
||||||
with open(hashpath, "w") as f:
|
with open(hashpath, "w") as f:
|
||||||
f.write(hash)
|
f.write(hash)
|
||||||
return hash
|
return hash
|
||||||
|
@ -509,10 +509,13 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
run_id=None,
|
run_id=None,
|
||||||
callback: Callable[[PipelineIntermediateState], None] = None,
|
callback: Callable[[PipelineIntermediateState], None] = None,
|
||||||
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
|
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
|
||||||
|
if self.scheduler.config.get("cpu_only", False):
|
||||||
|
scheduler_device = torch.device('cpu')
|
||||||
|
else:
|
||||||
|
scheduler_device = self._model_group.device_for(self.unet)
|
||||||
|
|
||||||
if timesteps is None:
|
if timesteps is None:
|
||||||
self.scheduler.set_timesteps(
|
self.scheduler.set_timesteps(num_inference_steps, device=scheduler_device)
|
||||||
num_inference_steps, device=self._model_group.device_for(self.unet)
|
|
||||||
)
|
|
||||||
timesteps = self.scheduler.timesteps
|
timesteps = self.scheduler.timesteps
|
||||||
infer_latents_from_embeddings = GeneratorToCallbackinator(
|
infer_latents_from_embeddings = GeneratorToCallbackinator(
|
||||||
self.generate_latents_from_embeddings, PipelineIntermediateState
|
self.generate_latents_from_embeddings, PipelineIntermediateState
|
||||||
@ -726,11 +729,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
run_id=None,
|
run_id=None,
|
||||||
callback=None,
|
callback=None,
|
||||||
) -> InvokeAIStableDiffusionPipelineOutput:
|
) -> InvokeAIStableDiffusionPipelineOutput:
|
||||||
timesteps, _ = self.get_img2img_timesteps(
|
timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength)
|
||||||
num_inference_steps,
|
|
||||||
strength,
|
|
||||||
device=self._model_group.device_for(self.unet),
|
|
||||||
)
|
|
||||||
result_latents, result_attention_maps = self.latents_from_embeddings(
|
result_latents, result_attention_maps = self.latents_from_embeddings(
|
||||||
latents=initial_latents if strength < 1.0 else torch.zeros_like(
|
latents=initial_latents if strength < 1.0 else torch.zeros_like(
|
||||||
initial_latents, device=initial_latents.device, dtype=initial_latents.dtype
|
initial_latents, device=initial_latents.device, dtype=initial_latents.dtype
|
||||||
@ -756,13 +755,19 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
return self.check_for_safety(output, dtype=conditioning_data.dtype)
|
return self.check_for_safety(output, dtype=conditioning_data.dtype)
|
||||||
|
|
||||||
def get_img2img_timesteps(
|
def get_img2img_timesteps(
|
||||||
self, num_inference_steps: int, strength: float, device
|
self, num_inference_steps: int, strength: float, device=None
|
||||||
) -> (torch.Tensor, int):
|
) -> (torch.Tensor, int):
|
||||||
img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
|
img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
|
||||||
assert img2img_pipeline.scheduler is self.scheduler
|
assert img2img_pipeline.scheduler is self.scheduler
|
||||||
img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
|
|
||||||
|
if self.scheduler.config.get("cpu_only", False):
|
||||||
|
scheduler_device = torch.device('cpu')
|
||||||
|
else:
|
||||||
|
scheduler_device = self._model_group.device_for(self.unet)
|
||||||
|
|
||||||
|
img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=scheduler_device)
|
||||||
timesteps, adjusted_steps = img2img_pipeline.get_timesteps(
|
timesteps, adjusted_steps = img2img_pipeline.get_timesteps(
|
||||||
num_inference_steps, strength, device=device
|
num_inference_steps, strength, device=scheduler_device
|
||||||
)
|
)
|
||||||
# Workaround for low strength resulting in zero timesteps.
|
# Workaround for low strength resulting in zero timesteps.
|
||||||
# TODO: submit upstream fix for zero-step img2img
|
# TODO: submit upstream fix for zero-step img2img
|
||||||
@ -796,9 +801,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
if init_image.dim() == 3:
|
if init_image.dim() == 3:
|
||||||
init_image = init_image.unsqueeze(0)
|
init_image = init_image.unsqueeze(0)
|
||||||
|
|
||||||
timesteps, _ = self.get_img2img_timesteps(
|
timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength)
|
||||||
num_inference_steps, strength, device=device
|
|
||||||
)
|
|
||||||
|
|
||||||
# 6. Prepare latent variables
|
# 6. Prepare latent variables
|
||||||
# can't quite use upstream StableDiffusionImg2ImgPipeline.prepare_latents
|
# can't quite use upstream StableDiffusionImg2ImgPipeline.prepare_latents
|
||||||
|
1
invokeai/backend/stable_diffusion/schedulers/__init__.py
Normal file
1
invokeai/backend/stable_diffusion/schedulers/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
from .schedulers import SCHEDULER_MAP
|
22
invokeai/backend/stable_diffusion/schedulers/schedulers.py
Normal file
22
invokeai/backend/stable_diffusion/schedulers/schedulers.py
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
from diffusers import DDIMScheduler, DPMSolverMultistepScheduler, KDPM2DiscreteScheduler, \
|
||||||
|
KDPM2AncestralDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, \
|
||||||
|
HeunDiscreteScheduler, LMSDiscreteScheduler, PNDMScheduler, UniPCMultistepScheduler, \
|
||||||
|
DPMSolverSinglestepScheduler, DEISMultistepScheduler, DDPMScheduler
|
||||||
|
|
||||||
|
SCHEDULER_MAP = dict(
|
||||||
|
ddim=(DDIMScheduler, dict()),
|
||||||
|
ddpm=(DDPMScheduler, dict()),
|
||||||
|
deis=(DEISMultistepScheduler, dict()),
|
||||||
|
lms=(LMSDiscreteScheduler, dict()),
|
||||||
|
pndm=(PNDMScheduler, dict()),
|
||||||
|
heun=(HeunDiscreteScheduler, dict()),
|
||||||
|
euler=(EulerDiscreteScheduler, dict(use_karras_sigmas=False)),
|
||||||
|
euler_k=(EulerDiscreteScheduler, dict(use_karras_sigmas=True)),
|
||||||
|
euler_a=(EulerAncestralDiscreteScheduler, dict()),
|
||||||
|
kdpm_2=(KDPM2DiscreteScheduler, dict()),
|
||||||
|
kdpm_2_a=(KDPM2AncestralDiscreteScheduler, dict()),
|
||||||
|
dpmpp_2s=(DPMSolverSinglestepScheduler, dict()),
|
||||||
|
dpmpp_2m=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=False)),
|
||||||
|
dpmpp_2m_k=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=True)),
|
||||||
|
unipc=(UniPCMultistepScheduler, dict(cpu_only=True))
|
||||||
|
)
|
@ -4,17 +4,20 @@ from .parse_seed_weights import parse_seed_weights
|
|||||||
|
|
||||||
SAMPLER_CHOICES = [
|
SAMPLER_CHOICES = [
|
||||||
"ddim",
|
"ddim",
|
||||||
"k_dpm_2_a",
|
"ddpm",
|
||||||
"k_dpm_2",
|
"deis",
|
||||||
"k_dpmpp_2_a",
|
"lms",
|
||||||
"k_dpmpp_2",
|
|
||||||
"k_euler_a",
|
|
||||||
"k_euler",
|
|
||||||
"k_heun",
|
|
||||||
"k_lms",
|
|
||||||
"plms",
|
|
||||||
# diffusers:
|
|
||||||
"pndm",
|
"pndm",
|
||||||
|
"heun",
|
||||||
|
"euler",
|
||||||
|
"euler_k",
|
||||||
|
"euler_a",
|
||||||
|
"kdpm_2",
|
||||||
|
"kdpm_2_a",
|
||||||
|
"dpmpp_2s",
|
||||||
|
"dpmpp_2m",
|
||||||
|
"dpmpp_2m_k",
|
||||||
|
"unipc",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,13 +0,0 @@
|
|||||||
{
|
|
||||||
"plugins": [
|
|
||||||
[
|
|
||||||
"transform-imports",
|
|
||||||
{
|
|
||||||
"lodash": {
|
|
||||||
"transform": "lodash/${member}",
|
|
||||||
"preventFullImport": true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
]
|
|
||||||
}
|
|
4
invokeai/frontend/web/.gitignore
vendored
4
invokeai/frontend/web/.gitignore
vendored
@ -35,3 +35,7 @@ stats.html
|
|||||||
!.yarn/releases
|
!.yarn/releases
|
||||||
!.yarn/sdks
|
!.yarn/sdks
|
||||||
!.yarn/versions
|
!.yarn/versions
|
||||||
|
|
||||||
|
# Yalc
|
||||||
|
.yalc
|
||||||
|
yalc.lock
|
@ -5,6 +5,7 @@ import { PluginOption, UserConfig } from 'vite';
|
|||||||
import dts from 'vite-plugin-dts';
|
import dts from 'vite-plugin-dts';
|
||||||
import eslint from 'vite-plugin-eslint';
|
import eslint from 'vite-plugin-eslint';
|
||||||
import tsconfigPaths from 'vite-tsconfig-paths';
|
import tsconfigPaths from 'vite-tsconfig-paths';
|
||||||
|
import cssInjectedByJsPlugin from 'vite-plugin-css-injected-by-js';
|
||||||
|
|
||||||
export const packageConfig: UserConfig = {
|
export const packageConfig: UserConfig = {
|
||||||
base: './',
|
base: './',
|
||||||
@ -16,9 +17,10 @@ export const packageConfig: UserConfig = {
|
|||||||
dts({
|
dts({
|
||||||
insertTypesEntry: true,
|
insertTypesEntry: true,
|
||||||
}),
|
}),
|
||||||
|
cssInjectedByJsPlugin(),
|
||||||
],
|
],
|
||||||
build: {
|
build: {
|
||||||
chunkSizeWarningLimit: 1500,
|
cssCodeSplit: true,
|
||||||
lib: {
|
lib: {
|
||||||
entry: path.resolve(__dirname, '../src/index.ts'),
|
entry: path.resolve(__dirname, '../src/index.ts'),
|
||||||
name: 'InvokeAIUI',
|
name: 'InvokeAIUI',
|
||||||
@ -30,6 +32,7 @@ export const packageConfig: UserConfig = {
|
|||||||
globals: {
|
globals: {
|
||||||
react: 'React',
|
react: 'React',
|
||||||
'react-dom': 'ReactDOM',
|
'react-dom': 'ReactDOM',
|
||||||
|
'@emotion/react': 'EmotionReact',
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -37,7 +37,7 @@ From `invokeai/frontend/web/` run `yarn install` to get everything set up.
|
|||||||
Start everything in dev mode:
|
Start everything in dev mode:
|
||||||
|
|
||||||
1. Start the dev server: `yarn dev`
|
1. Start the dev server: `yarn dev`
|
||||||
2. Start the InvokeAI UI per usual: `invokeai --web`
|
2. Start the InvokeAI Nodes backend: `python scripts/invokeai-new.py --web # run from the repo root`
|
||||||
3. Point your browser to the dev server address e.g. <http://localhost:5173/>
|
3. Point your browser to the dev server address e.g. <http://localhost:5173/>
|
||||||
|
|
||||||
### Production builds
|
### Production builds
|
||||||
|
@ -21,7 +21,6 @@
|
|||||||
"scripts": {
|
"scripts": {
|
||||||
"prepare": "cd ../../../ && husky install invokeai/frontend/web/.husky",
|
"prepare": "cd ../../../ && husky install invokeai/frontend/web/.husky",
|
||||||
"dev": "concurrently \"vite dev\" \"yarn run theme:watch\"",
|
"dev": "concurrently \"vite dev\" \"yarn run theme:watch\"",
|
||||||
"dev:nodes": "concurrently \"vite dev --mode nodes\" \"yarn run theme:watch\"",
|
|
||||||
"dev:host": "concurrently \"vite dev --host\" \"yarn run theme:watch\"",
|
"dev:host": "concurrently \"vite dev --host\" \"yarn run theme:watch\"",
|
||||||
"build": "yarn run lint && vite build",
|
"build": "yarn run lint && vite build",
|
||||||
"api:web": "openapi -i http://localhost:9090/openapi.json -o src/services/api --client axios --useOptions --useUnionTypes --exportSchemas true --indent 2 --request src/services/fixtures/request.ts",
|
"api:web": "openapi -i http://localhost:9090/openapi.json -o src/services/api --client axios --useOptions --useUnionTypes --exportSchemas true --indent 2 --request src/services/fixtures/request.ts",
|
||||||
@ -90,6 +89,7 @@
|
|||||||
"react-konva": "^18.2.7",
|
"react-konva": "^18.2.7",
|
||||||
"react-konva-utils": "^1.0.4",
|
"react-konva-utils": "^1.0.4",
|
||||||
"react-redux": "^8.0.5",
|
"react-redux": "^8.0.5",
|
||||||
|
"react-resizable-panels": "^0.0.42",
|
||||||
"react-rnd": "^10.4.1",
|
"react-rnd": "^10.4.1",
|
||||||
"react-transition-group": "^4.4.5",
|
"react-transition-group": "^4.4.5",
|
||||||
"react-use": "^17.4.0",
|
"react-use": "^17.4.0",
|
||||||
@ -99,6 +99,7 @@
|
|||||||
"redux-deep-persist": "^1.0.7",
|
"redux-deep-persist": "^1.0.7",
|
||||||
"redux-dynamic-middlewares": "^2.2.0",
|
"redux-dynamic-middlewares": "^2.2.0",
|
||||||
"redux-persist": "^6.0.0",
|
"redux-persist": "^6.0.0",
|
||||||
|
"redux-remember": "^3.3.1",
|
||||||
"roarr": "^7.15.0",
|
"roarr": "^7.15.0",
|
||||||
"serialize-error": "^11.0.0",
|
"serialize-error": "^11.0.0",
|
||||||
"socket.io-client": "^4.6.0",
|
"socket.io-client": "^4.6.0",
|
||||||
@ -118,6 +119,7 @@
|
|||||||
"@types/node": "^18.16.2",
|
"@types/node": "^18.16.2",
|
||||||
"@types/react": "^18.2.0",
|
"@types/react": "^18.2.0",
|
||||||
"@types/react-dom": "^18.2.1",
|
"@types/react-dom": "^18.2.1",
|
||||||
|
"@types/react-redux": "^7.1.25",
|
||||||
"@types/react-transition-group": "^4.4.5",
|
"@types/react-transition-group": "^4.4.5",
|
||||||
"@types/uuid": "^9.0.0",
|
"@types/uuid": "^9.0.0",
|
||||||
"@typescript-eslint/eslint-plugin": "^5.59.1",
|
"@typescript-eslint/eslint-plugin": "^5.59.1",
|
||||||
@ -143,6 +145,7 @@
|
|||||||
"terser": "^5.17.1",
|
"terser": "^5.17.1",
|
||||||
"ts-toolbelt": "^9.6.0",
|
"ts-toolbelt": "^9.6.0",
|
||||||
"vite": "^4.3.3",
|
"vite": "^4.3.3",
|
||||||
|
"vite-plugin-css-injected-by-js": "^3.1.1",
|
||||||
"vite-plugin-dts": "^2.3.0",
|
"vite-plugin-dts": "^2.3.0",
|
||||||
"vite-plugin-eslint": "^1.8.1",
|
"vite-plugin-eslint": "^1.8.1",
|
||||||
"vite-tsconfig-paths": "^4.2.0",
|
"vite-tsconfig-paths": "^4.2.0",
|
||||||
|
@ -25,7 +25,7 @@
|
|||||||
"common": {
|
"common": {
|
||||||
"hotkeysLabel": "Hotkeys",
|
"hotkeysLabel": "Hotkeys",
|
||||||
"themeLabel": "Theme",
|
"themeLabel": "Theme",
|
||||||
"languagePickerLabel": "Language Picker",
|
"languagePickerLabel": "Language",
|
||||||
"reportBugLabel": "Report Bug",
|
"reportBugLabel": "Report Bug",
|
||||||
"githubLabel": "Github",
|
"githubLabel": "Github",
|
||||||
"discordLabel": "Discord",
|
"discordLabel": "Discord",
|
||||||
@ -54,7 +54,7 @@
|
|||||||
"img2img": "Image To Image",
|
"img2img": "Image To Image",
|
||||||
"unifiedCanvas": "Unified Canvas",
|
"unifiedCanvas": "Unified Canvas",
|
||||||
"linear": "Linear",
|
"linear": "Linear",
|
||||||
"nodes": "Nodes",
|
"nodes": "Node Editor",
|
||||||
"postprocessing": "Post Processing",
|
"postprocessing": "Post Processing",
|
||||||
"nodesDesc": "A node based system for the generation of images is under development currently. Stay tuned for updates about this amazing feature.",
|
"nodesDesc": "A node based system for the generation of images is under development currently. Stay tuned for updates about this amazing feature.",
|
||||||
"postProcessing": "Post Processing",
|
"postProcessing": "Post Processing",
|
||||||
@ -102,7 +102,8 @@
|
|||||||
"generate": "Generate",
|
"generate": "Generate",
|
||||||
"openInNewTab": "Open in New Tab",
|
"openInNewTab": "Open in New Tab",
|
||||||
"dontAskMeAgain": "Don't ask me again",
|
"dontAskMeAgain": "Don't ask me again",
|
||||||
"areYouSure": "Are you sure?"
|
"areYouSure": "Are you sure?",
|
||||||
|
"imagePrompt": "Image Prompt"
|
||||||
},
|
},
|
||||||
"gallery": {
|
"gallery": {
|
||||||
"generations": "Generations",
|
"generations": "Generations",
|
||||||
@ -453,9 +454,10 @@
|
|||||||
"seed": "Seed",
|
"seed": "Seed",
|
||||||
"imageToImage": "Image to Image",
|
"imageToImage": "Image to Image",
|
||||||
"randomizeSeed": "Randomize Seed",
|
"randomizeSeed": "Randomize Seed",
|
||||||
"shuffle": "Shuffle",
|
"shuffle": "Shuffle Seed",
|
||||||
"noiseThreshold": "Noise Threshold",
|
"noiseThreshold": "Noise Threshold",
|
||||||
"perlinNoise": "Perlin Noise",
|
"perlinNoise": "Perlin Noise",
|
||||||
|
"noiseSettings": "Noise",
|
||||||
"variations": "Variations",
|
"variations": "Variations",
|
||||||
"variationAmount": "Variation Amount",
|
"variationAmount": "Variation Amount",
|
||||||
"seedWeights": "Seed Weights",
|
"seedWeights": "Seed Weights",
|
||||||
@ -470,6 +472,8 @@
|
|||||||
"scale": "Scale",
|
"scale": "Scale",
|
||||||
"otherOptions": "Other Options",
|
"otherOptions": "Other Options",
|
||||||
"seamlessTiling": "Seamless Tiling",
|
"seamlessTiling": "Seamless Tiling",
|
||||||
|
"seamlessXAxis": "X Axis",
|
||||||
|
"seamlessYAxis": "Y Axis",
|
||||||
"hiresOptim": "High Res Optimization",
|
"hiresOptim": "High Res Optimization",
|
||||||
"hiresStrength": "High Res Strength",
|
"hiresStrength": "High Res Strength",
|
||||||
"imageFit": "Fit Initial Image To Output Size",
|
"imageFit": "Fit Initial Image To Output Size",
|
||||||
@ -527,7 +531,8 @@
|
|||||||
"useCanvasBeta": "Use Canvas Beta Layout",
|
"useCanvasBeta": "Use Canvas Beta Layout",
|
||||||
"enableImageDebugging": "Enable Image Debugging",
|
"enableImageDebugging": "Enable Image Debugging",
|
||||||
"useSlidersForAll": "Use Sliders For All Options",
|
"useSlidersForAll": "Use Sliders For All Options",
|
||||||
"autoShowProgress": "Auto Show Progress Images",
|
"showProgressInViewer": "Show Progress Images in Viewer",
|
||||||
|
"antialiasProgressImages": "Antialias Progress Images",
|
||||||
"resetWebUI": "Reset Web UI",
|
"resetWebUI": "Reset Web UI",
|
||||||
"resetWebUIDesc1": "Resetting the web UI only resets the browser's local cache of your images and remembered settings. It does not delete any images from disk.",
|
"resetWebUIDesc1": "Resetting the web UI only resets the browser's local cache of your images and remembered settings. It does not delete any images from disk.",
|
||||||
"resetWebUIDesc2": "If images aren't showing up in the gallery or something else isn't working, please try resetting before submitting an issue on GitHub.",
|
"resetWebUIDesc2": "If images aren't showing up in the gallery or something else isn't working, please try resetting before submitting an issue on GitHub.",
|
||||||
@ -549,8 +554,9 @@
|
|||||||
"downloadImageStarted": "Image Download Started",
|
"downloadImageStarted": "Image Download Started",
|
||||||
"imageCopied": "Image Copied",
|
"imageCopied": "Image Copied",
|
||||||
"imageLinkCopied": "Image Link Copied",
|
"imageLinkCopied": "Image Link Copied",
|
||||||
|
"problemCopyingImageLink": "Unable to Copy Image Link",
|
||||||
"imageNotLoaded": "No Image Loaded",
|
"imageNotLoaded": "No Image Loaded",
|
||||||
"imageNotLoadedDesc": "No image found to send to image to image module",
|
"imageNotLoadedDesc": "Could not find image",
|
||||||
"imageSavedToGallery": "Image Saved to Gallery",
|
"imageSavedToGallery": "Image Saved to Gallery",
|
||||||
"canvasMerged": "Canvas Merged",
|
"canvasMerged": "Canvas Merged",
|
||||||
"sentToImageToImage": "Sent To Image To Image",
|
"sentToImageToImage": "Sent To Image To Image",
|
||||||
@ -645,7 +651,8 @@
|
|||||||
"betaClear": "Clear",
|
"betaClear": "Clear",
|
||||||
"betaDarkenOutside": "Darken Outside",
|
"betaDarkenOutside": "Darken Outside",
|
||||||
"betaLimitToBox": "Limit To Box",
|
"betaLimitToBox": "Limit To Box",
|
||||||
"betaPreserveMasked": "Preserve Masked"
|
"betaPreserveMasked": "Preserve Masked",
|
||||||
|
"antialiasing": "Antialiasing"
|
||||||
},
|
},
|
||||||
"ui": {
|
"ui": {
|
||||||
"showProgressImages": "Show Progress Images",
|
"showProgressImages": "Show Progress Images",
|
||||||
|
@ -1,24 +1,18 @@
|
|||||||
import ImageUploader from 'common/components/ImageUploader';
|
import ImageUploader from 'common/components/ImageUploader';
|
||||||
import ProgressBar from 'features/system/components/ProgressBar';
|
|
||||||
import SiteHeader from 'features/system/components/SiteHeader';
|
import SiteHeader from 'features/system/components/SiteHeader';
|
||||||
|
import ProgressBar from 'features/system/components/ProgressBar';
|
||||||
import InvokeTabs from 'features/ui/components/InvokeTabs';
|
import InvokeTabs from 'features/ui/components/InvokeTabs';
|
||||||
|
|
||||||
import useToastWatcher from 'features/system/hooks/useToastWatcher';
|
import useToastWatcher from 'features/system/hooks/useToastWatcher';
|
||||||
|
|
||||||
import FloatingGalleryButton from 'features/ui/components/FloatingGalleryButton';
|
import FloatingGalleryButton from 'features/ui/components/FloatingGalleryButton';
|
||||||
import FloatingParametersPanelButtons from 'features/ui/components/FloatingParametersPanelButtons';
|
import FloatingParametersPanelButtons from 'features/ui/components/FloatingParametersPanelButtons';
|
||||||
import { Box, Flex, Grid, Portal, useColorMode } from '@chakra-ui/react';
|
import { Box, Flex, Grid, Portal } from '@chakra-ui/react';
|
||||||
import { APP_HEIGHT, APP_WIDTH } from 'theme/util/constants';
|
import { APP_HEIGHT, APP_WIDTH } from 'theme/util/constants';
|
||||||
import ImageGalleryPanel from 'features/gallery/components/ImageGalleryPanel';
|
import GalleryDrawer from 'features/gallery/components/ImageGalleryPanel';
|
||||||
import Lightbox from 'features/lightbox/components/Lightbox';
|
import Lightbox from 'features/lightbox/components/Lightbox';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import {
|
import { memo, ReactNode, useCallback, useEffect, useState } from 'react';
|
||||||
memo,
|
|
||||||
PropsWithChildren,
|
|
||||||
useCallback,
|
|
||||||
useEffect,
|
|
||||||
useState,
|
|
||||||
} from 'react';
|
|
||||||
import { motion, AnimatePresence } from 'framer-motion';
|
import { motion, AnimatePresence } from 'framer-motion';
|
||||||
import Loading from 'common/components/Loading/Loading';
|
import Loading from 'common/components/Loading/Loading';
|
||||||
import { useIsApplicationReady } from 'features/system/hooks/useIsApplicationReady';
|
import { useIsApplicationReady } from 'features/system/hooks/useIsApplicationReady';
|
||||||
@ -27,20 +21,24 @@ import { useGlobalHotkeys } from 'common/hooks/useGlobalHotkeys';
|
|||||||
import { configChanged } from 'features/system/store/configSlice';
|
import { configChanged } from 'features/system/store/configSlice';
|
||||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||||
import { useLogger } from 'app/logging/useLogger';
|
import { useLogger } from 'app/logging/useLogger';
|
||||||
import ProgressImagePreview from 'features/parameters/components/ProgressImagePreview';
|
import ParametersDrawer from 'features/ui/components/ParametersDrawer';
|
||||||
|
import { languageSelector } from 'features/system/store/systemSelectors';
|
||||||
|
import i18n from 'i18n';
|
||||||
|
|
||||||
const DEFAULT_CONFIG = {};
|
const DEFAULT_CONFIG = {};
|
||||||
|
|
||||||
interface Props extends PropsWithChildren {
|
interface Props {
|
||||||
config?: PartialAppConfig;
|
config?: PartialAppConfig;
|
||||||
|
headerComponent?: ReactNode;
|
||||||
}
|
}
|
||||||
|
|
||||||
const App = ({ config = DEFAULT_CONFIG, children }: Props) => {
|
const App = ({ config = DEFAULT_CONFIG, headerComponent }: Props) => {
|
||||||
useToastWatcher();
|
useToastWatcher();
|
||||||
useGlobalHotkeys();
|
useGlobalHotkeys();
|
||||||
const log = useLogger();
|
|
||||||
|
|
||||||
const currentTheme = useAppSelector((state) => state.ui.currentTheme);
|
const language = useAppSelector(languageSelector);
|
||||||
|
|
||||||
|
const log = useLogger();
|
||||||
|
|
||||||
const isLightboxEnabled = useFeatureStatus('lightbox').isFeatureEnabled;
|
const isLightboxEnabled = useFeatureStatus('lightbox').isFeatureEnabled;
|
||||||
|
|
||||||
@ -48,18 +46,17 @@ const App = ({ config = DEFAULT_CONFIG, children }: Props) => {
|
|||||||
|
|
||||||
const [loadingOverridden, setLoadingOverridden] = useState(false);
|
const [loadingOverridden, setLoadingOverridden] = useState(false);
|
||||||
|
|
||||||
const { setColorMode } = useColorMode();
|
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
i18n.changeLanguage(language);
|
||||||
|
}, [language]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
log.info({ namespace: 'App', data: config }, 'Received config');
|
log.info({ namespace: 'App', data: config }, 'Received config');
|
||||||
dispatch(configChanged(config));
|
dispatch(configChanged(config));
|
||||||
}, [dispatch, config, log]);
|
}, [dispatch, config, log]);
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
setColorMode(['light'].includes(currentTheme) ? 'light' : 'dark');
|
|
||||||
}, [setColorMode, currentTheme]);
|
|
||||||
|
|
||||||
const handleOverrideClicked = useCallback(() => {
|
const handleOverrideClicked = useCallback(() => {
|
||||||
setLoadingOverridden(true);
|
setLoadingOverridden(true);
|
||||||
}, []);
|
}, []);
|
||||||
@ -76,7 +73,7 @@ const App = ({ config = DEFAULT_CONFIG, children }: Props) => {
|
|||||||
w={APP_WIDTH}
|
w={APP_WIDTH}
|
||||||
h={APP_HEIGHT}
|
h={APP_HEIGHT}
|
||||||
>
|
>
|
||||||
{children || <SiteHeader />}
|
{headerComponent || <SiteHeader />}
|
||||||
<Flex
|
<Flex
|
||||||
gap={4}
|
gap={4}
|
||||||
w={{ base: '100vw', xl: 'full' }}
|
w={{ base: '100vw', xl: 'full' }}
|
||||||
@ -84,11 +81,13 @@ const App = ({ config = DEFAULT_CONFIG, children }: Props) => {
|
|||||||
flexDir={{ base: 'column', xl: 'row' }}
|
flexDir={{ base: 'column', xl: 'row' }}
|
||||||
>
|
>
|
||||||
<InvokeTabs />
|
<InvokeTabs />
|
||||||
<ImageGalleryPanel />
|
|
||||||
</Flex>
|
</Flex>
|
||||||
</Grid>
|
</Grid>
|
||||||
</ImageUploader>
|
</ImageUploader>
|
||||||
|
|
||||||
|
<GalleryDrawer />
|
||||||
|
<ParametersDrawer />
|
||||||
|
|
||||||
<AnimatePresence>
|
<AnimatePresence>
|
||||||
{!isApplicationReady && !loadingOverridden && (
|
{!isApplicationReady && !loadingOverridden && (
|
||||||
<motion.div
|
<motion.div
|
||||||
@ -121,7 +120,6 @@ const App = ({ config = DEFAULT_CONFIG, children }: Props) => {
|
|||||||
<Portal>
|
<Portal>
|
||||||
<FloatingGalleryButton />
|
<FloatingGalleryButton />
|
||||||
</Portal>
|
</Portal>
|
||||||
<ProgressImagePreview />
|
|
||||||
</Grid>
|
</Grid>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -1,18 +1,13 @@
|
|||||||
import React, { lazy, memo, PropsWithChildren, useEffect } from 'react';
|
import React, {
|
||||||
|
lazy,
|
||||||
|
memo,
|
||||||
|
PropsWithChildren,
|
||||||
|
ReactNode,
|
||||||
|
useEffect,
|
||||||
|
} from 'react';
|
||||||
import { Provider } from 'react-redux';
|
import { Provider } from 'react-redux';
|
||||||
import { PersistGate } from 'redux-persist/integration/react';
|
|
||||||
import { store } from 'app/store/store';
|
import { store } from 'app/store/store';
|
||||||
import { persistor } from '../store/persistor';
|
|
||||||
import { OpenAPI } from 'services/api';
|
import { OpenAPI } from 'services/api';
|
||||||
import '@fontsource/inter/100.css';
|
|
||||||
import '@fontsource/inter/200.css';
|
|
||||||
import '@fontsource/inter/300.css';
|
|
||||||
import '@fontsource/inter/400.css';
|
|
||||||
import '@fontsource/inter/500.css';
|
|
||||||
import '@fontsource/inter/600.css';
|
|
||||||
import '@fontsource/inter/700.css';
|
|
||||||
import '@fontsource/inter/800.css';
|
|
||||||
import '@fontsource/inter/900.css';
|
|
||||||
|
|
||||||
import Loading from '../../common/components/Loading/Loading';
|
import Loading from '../../common/components/Loading/Loading';
|
||||||
import { addMiddleware, resetMiddlewares } from 'redux-dynamic-middlewares';
|
import { addMiddleware, resetMiddlewares } from 'redux-dynamic-middlewares';
|
||||||
@ -28,9 +23,10 @@ interface Props extends PropsWithChildren {
|
|||||||
apiUrl?: string;
|
apiUrl?: string;
|
||||||
token?: string;
|
token?: string;
|
||||||
config?: PartialAppConfig;
|
config?: PartialAppConfig;
|
||||||
|
headerComponent?: ReactNode;
|
||||||
}
|
}
|
||||||
|
|
||||||
const InvokeAIUI = ({ apiUrl, token, config, children }: Props) => {
|
const InvokeAIUI = ({ apiUrl, token, config, headerComponent }: Props) => {
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
// configure API client token
|
// configure API client token
|
||||||
if (token) {
|
if (token) {
|
||||||
@ -57,13 +53,11 @@ const InvokeAIUI = ({ apiUrl, token, config, children }: Props) => {
|
|||||||
return (
|
return (
|
||||||
<React.StrictMode>
|
<React.StrictMode>
|
||||||
<Provider store={store}>
|
<Provider store={store}>
|
||||||
<PersistGate loading={<Loading />} persistor={persistor}>
|
|
||||||
<React.Suspense fallback={<Loading />}>
|
<React.Suspense fallback={<Loading />}>
|
||||||
<ThemeLocaleProvider>
|
<ThemeLocaleProvider>
|
||||||
<App config={config}>{children}</App>
|
<App config={config} headerComponent={headerComponent} />
|
||||||
</ThemeLocaleProvider>
|
</ThemeLocaleProvider>
|
||||||
</React.Suspense>
|
</React.Suspense>
|
||||||
</PersistGate>
|
|
||||||
</Provider>
|
</Provider>
|
||||||
</React.StrictMode>
|
</React.StrictMode>
|
||||||
);
|
);
|
||||||
|
@ -1,4 +1,8 @@
|
|||||||
import { ChakraProvider, extendTheme } from '@chakra-ui/react';
|
import {
|
||||||
|
ChakraProvider,
|
||||||
|
createLocalStorageManager,
|
||||||
|
extendTheme,
|
||||||
|
} from '@chakra-ui/react';
|
||||||
import { ReactNode, useEffect } from 'react';
|
import { ReactNode, useEffect } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { theme as invokeAITheme } from 'theme/theme';
|
import { theme as invokeAITheme } from 'theme/theme';
|
||||||
@ -9,15 +13,8 @@ import { greenTeaThemeColors } from 'theme/colors/greenTea';
|
|||||||
import { invokeAIThemeColors } from 'theme/colors/invokeAI';
|
import { invokeAIThemeColors } from 'theme/colors/invokeAI';
|
||||||
import { lightThemeColors } from 'theme/colors/lightTheme';
|
import { lightThemeColors } from 'theme/colors/lightTheme';
|
||||||
import { oceanBlueColors } from 'theme/colors/oceanBlue';
|
import { oceanBlueColors } from 'theme/colors/oceanBlue';
|
||||||
import '@fontsource/inter/100.css';
|
|
||||||
import '@fontsource/inter/200.css';
|
import '@fontsource/inter/variable.css';
|
||||||
import '@fontsource/inter/300.css';
|
|
||||||
import '@fontsource/inter/400.css';
|
|
||||||
import '@fontsource/inter/500.css';
|
|
||||||
import '@fontsource/inter/600.css';
|
|
||||||
import '@fontsource/inter/700.css';
|
|
||||||
import '@fontsource/inter/800.css';
|
|
||||||
import '@fontsource/inter/900.css';
|
|
||||||
import 'overlayscrollbars/overlayscrollbars.css';
|
import 'overlayscrollbars/overlayscrollbars.css';
|
||||||
import 'theme/css/overlayscrollbars.css';
|
import 'theme/css/overlayscrollbars.css';
|
||||||
|
|
||||||
@ -32,6 +29,8 @@ const THEMES = {
|
|||||||
ocean: oceanBlueColors,
|
ocean: oceanBlueColors,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const manager = createLocalStorageManager('@@invokeai-color-mode');
|
||||||
|
|
||||||
function ThemeLocaleProvider({ children }: ThemeLocaleProviderProps) {
|
function ThemeLocaleProvider({ children }: ThemeLocaleProviderProps) {
|
||||||
const { i18n } = useTranslation();
|
const { i18n } = useTranslation();
|
||||||
|
|
||||||
@ -51,7 +50,11 @@ function ThemeLocaleProvider({ children }: ThemeLocaleProviderProps) {
|
|||||||
document.body.dir = direction;
|
document.body.dir = direction;
|
||||||
}, [direction]);
|
}, [direction]);
|
||||||
|
|
||||||
return <ChakraProvider theme={theme}>{children}</ChakraProvider>;
|
return (
|
||||||
|
<ChakraProvider theme={theme} colorModeManager={manager}>
|
||||||
|
{children}
|
||||||
|
</ChakraProvider>
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
export default ThemeLocaleProvider;
|
export default ThemeLocaleProvider;
|
||||||
|
@ -2,17 +2,28 @@
|
|||||||
|
|
||||||
export const DIFFUSERS_SCHEDULERS: Array<string> = [
|
export const DIFFUSERS_SCHEDULERS: Array<string> = [
|
||||||
'ddim',
|
'ddim',
|
||||||
'plms',
|
'ddpm',
|
||||||
'k_lms',
|
'deis',
|
||||||
'dpmpp_2',
|
'lms',
|
||||||
'k_dpm_2',
|
'pndm',
|
||||||
'k_dpm_2_a',
|
'heun',
|
||||||
'k_dpmpp_2',
|
'euler',
|
||||||
'k_euler',
|
'euler_k',
|
||||||
'k_euler_a',
|
'euler_a',
|
||||||
'k_heun',
|
'kdpm_2',
|
||||||
|
'kdpm_2_a',
|
||||||
|
'dpmpp_2s',
|
||||||
|
'dpmpp_2m',
|
||||||
|
'dpmpp_2m_k',
|
||||||
|
'unipc',
|
||||||
];
|
];
|
||||||
|
|
||||||
|
export const IMG2IMG_DIFFUSERS_SCHEDULERS = DIFFUSERS_SCHEDULERS.filter(
|
||||||
|
(scheduler) => {
|
||||||
|
return scheduler !== 'dpmpp_2s';
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
// Valid image widths
|
// Valid image widths
|
||||||
export const WIDTHS: Array<number> = Array.from(Array(64)).map(
|
export const WIDTHS: Array<number> = Array.from(Array(64)).map(
|
||||||
(_x, i) => (i + 1) * 64
|
(_x, i) => (i + 1) * 64
|
||||||
|
@ -1,26 +1,20 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import { validateSeedWeights } from 'common/util/seedWeightPairs';
|
import { validateSeedWeights } from 'common/util/seedWeightPairs';
|
||||||
import { initialCanvasImageSelector } from 'features/canvas/store/canvasSelectors';
|
|
||||||
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
||||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
import { systemSelector } from 'features/system/store/systemSelectors';
|
||||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||||
import { isEqual } from 'lodash-es';
|
import { isEqual } from 'lodash-es';
|
||||||
|
|
||||||
export const readinessSelector = createSelector(
|
export const readinessSelector = createSelector(
|
||||||
[
|
[generationSelector, systemSelector, activeTabNameSelector],
|
||||||
generationSelector,
|
(generation, system, activeTabName) => {
|
||||||
systemSelector,
|
|
||||||
initialCanvasImageSelector,
|
|
||||||
activeTabNameSelector,
|
|
||||||
],
|
|
||||||
(generation, system, initialCanvasImage, activeTabName) => {
|
|
||||||
const {
|
const {
|
||||||
prompt,
|
prompt,
|
||||||
shouldGenerateVariations,
|
shouldGenerateVariations,
|
||||||
seedWeights,
|
seedWeights,
|
||||||
initialImage,
|
initialImage,
|
||||||
seed,
|
seed,
|
||||||
isImageToImageEnabled,
|
|
||||||
} = generation;
|
} = generation;
|
||||||
|
|
||||||
const { isProcessing, isConnected } = system;
|
const { isProcessing, isConnected } = system;
|
||||||
@ -34,7 +28,7 @@ export const readinessSelector = createSelector(
|
|||||||
reasonsWhyNotReady.push('Missing prompt');
|
reasonsWhyNotReady.push('Missing prompt');
|
||||||
}
|
}
|
||||||
|
|
||||||
if (isImageToImageEnabled && !initialImage) {
|
if (activeTabName === 'img2img' && !initialImage) {
|
||||||
isReady = false;
|
isReady = false;
|
||||||
reasonsWhyNotReady.push('No initial image selected');
|
reasonsWhyNotReady.push('No initial image selected');
|
||||||
}
|
}
|
||||||
@ -64,10 +58,5 @@ export const readinessSelector = createSelector(
|
|||||||
// All good
|
// All good
|
||||||
return { isReady, reasonsWhyNotReady };
|
return { isReady, reasonsWhyNotReady };
|
||||||
},
|
},
|
||||||
{
|
defaultSelectorOptions
|
||||||
memoizeOptions: {
|
|
||||||
equalityCheck: isEqual,
|
|
||||||
resultEqualityCheck: isEqual,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
);
|
);
|
||||||
|
@ -1,209 +1,209 @@
|
|||||||
// import { AnyAction, Dispatch, MiddlewareAPI } from '@reduxjs/toolkit';
|
import { AnyAction, Dispatch, MiddlewareAPI } from '@reduxjs/toolkit';
|
||||||
// import * as InvokeAI from 'app/types/invokeai';
|
import * as InvokeAI from 'app/types/invokeai';
|
||||||
// import type { RootState } from 'app/store/store';
|
import type { RootState } from 'app/store/store';
|
||||||
// import {
|
import {
|
||||||
// frontendToBackendParameters,
|
frontendToBackendParameters,
|
||||||
// FrontendToBackendParametersConfig,
|
FrontendToBackendParametersConfig,
|
||||||
// } from 'common/util/parameterTranslation';
|
} from 'common/util/parameterTranslation';
|
||||||
// import dateFormat from 'dateformat';
|
import dateFormat from 'dateformat';
|
||||||
// import {
|
import {
|
||||||
// GalleryCategory,
|
GalleryCategory,
|
||||||
// GalleryState,
|
GalleryState,
|
||||||
// removeImage,
|
removeImage,
|
||||||
// } from 'features/gallery/store/gallerySlice';
|
} from 'features/gallery/store/gallerySlice';
|
||||||
// import {
|
import {
|
||||||
// generationRequested,
|
generationRequested,
|
||||||
// modelChangeRequested,
|
modelChangeRequested,
|
||||||
// modelConvertRequested,
|
modelConvertRequested,
|
||||||
// modelMergingRequested,
|
modelMergingRequested,
|
||||||
// setIsProcessing,
|
setIsProcessing,
|
||||||
// } from 'features/system/store/systemSlice';
|
} from 'features/system/store/systemSlice';
|
||||||
// import { InvokeTabName } from 'features/ui/store/tabMap';
|
import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||||
// import { Socket } from 'socket.io-client';
|
import { Socket } from 'socket.io-client';
|
||||||
|
|
||||||
// /**
|
/**
|
||||||
// * Returns an object containing all functions which use `socketio.emit()`.
|
* Returns an object containing all functions which use `socketio.emit()`.
|
||||||
// * i.e. those which make server requests.
|
* i.e. those which make server requests.
|
||||||
// */
|
*/
|
||||||
// const makeSocketIOEmitters = (
|
const makeSocketIOEmitters = (
|
||||||
// store: MiddlewareAPI<Dispatch<AnyAction>, RootState>,
|
store: MiddlewareAPI<Dispatch<AnyAction>, RootState>,
|
||||||
// socketio: Socket
|
socketio: Socket
|
||||||
// ) => {
|
) => {
|
||||||
// // We need to dispatch actions to redux and get pieces of state from the store.
|
// We need to dispatch actions to redux and get pieces of state from the store.
|
||||||
// const { dispatch, getState } = store;
|
const { dispatch, getState } = store;
|
||||||
|
|
||||||
// return {
|
return {
|
||||||
// emitGenerateImage: (generationMode: InvokeTabName) => {
|
emitGenerateImage: (generationMode: InvokeTabName) => {
|
||||||
// dispatch(setIsProcessing(true));
|
dispatch(setIsProcessing(true));
|
||||||
|
|
||||||
// const state: RootState = getState();
|
const state: RootState = getState();
|
||||||
|
|
||||||
// const {
|
const {
|
||||||
// generation: generationState,
|
generation: generationState,
|
||||||
// postprocessing: postprocessingState,
|
postprocessing: postprocessingState,
|
||||||
// system: systemState,
|
system: systemState,
|
||||||
// canvas: canvasState,
|
canvas: canvasState,
|
||||||
// } = state;
|
} = state;
|
||||||
|
|
||||||
// const frontendToBackendParametersConfig: FrontendToBackendParametersConfig =
|
const frontendToBackendParametersConfig: FrontendToBackendParametersConfig =
|
||||||
// {
|
{
|
||||||
// generationMode,
|
generationMode,
|
||||||
// generationState,
|
generationState,
|
||||||
// postprocessingState,
|
postprocessingState,
|
||||||
// canvasState,
|
canvasState,
|
||||||
// systemState,
|
systemState,
|
||||||
// };
|
};
|
||||||
|
|
||||||
// dispatch(generationRequested());
|
dispatch(generationRequested());
|
||||||
|
|
||||||
// const { generationParameters, esrganParameters, facetoolParameters } =
|
const { generationParameters, esrganParameters, facetoolParameters } =
|
||||||
// frontendToBackendParameters(frontendToBackendParametersConfig);
|
frontendToBackendParameters(frontendToBackendParametersConfig);
|
||||||
|
|
||||||
// socketio.emit(
|
socketio.emit(
|
||||||
// 'generateImage',
|
'generateImage',
|
||||||
// generationParameters,
|
generationParameters,
|
||||||
// esrganParameters,
|
esrganParameters,
|
||||||
// facetoolParameters
|
facetoolParameters
|
||||||
// );
|
);
|
||||||
|
|
||||||
// // we need to truncate the init_mask base64 else it takes up the whole log
|
// we need to truncate the init_mask base64 else it takes up the whole log
|
||||||
// // TODO: handle maintaining masks for reproducibility in future
|
// TODO: handle maintaining masks for reproducibility in future
|
||||||
// if (generationParameters.init_mask) {
|
if (generationParameters.init_mask) {
|
||||||
// generationParameters.init_mask = generationParameters.init_mask
|
generationParameters.init_mask = generationParameters.init_mask
|
||||||
// .substr(0, 64)
|
.substr(0, 64)
|
||||||
// .concat('...');
|
.concat('...');
|
||||||
// }
|
}
|
||||||
// if (generationParameters.init_img) {
|
if (generationParameters.init_img) {
|
||||||
// generationParameters.init_img = generationParameters.init_img
|
generationParameters.init_img = generationParameters.init_img
|
||||||
// .substr(0, 64)
|
.substr(0, 64)
|
||||||
// .concat('...');
|
.concat('...');
|
||||||
// }
|
}
|
||||||
|
|
||||||
// dispatch(
|
dispatch(
|
||||||
// addLogEntry({
|
addLogEntry({
|
||||||
// timestamp: dateFormat(new Date(), 'isoDateTime'),
|
timestamp: dateFormat(new Date(), 'isoDateTime'),
|
||||||
// message: `Image generation requested: ${JSON.stringify({
|
message: `Image generation requested: ${JSON.stringify({
|
||||||
// ...generationParameters,
|
...generationParameters,
|
||||||
// ...esrganParameters,
|
...esrganParameters,
|
||||||
// ...facetoolParameters,
|
...facetoolParameters,
|
||||||
// })}`,
|
})}`,
|
||||||
// })
|
})
|
||||||
// );
|
);
|
||||||
// },
|
},
|
||||||
// emitRunESRGAN: (imageToProcess: InvokeAI._Image) => {
|
emitRunESRGAN: (imageToProcess: InvokeAI._Image) => {
|
||||||
// dispatch(setIsProcessing(true));
|
dispatch(setIsProcessing(true));
|
||||||
|
|
||||||
// const {
|
const {
|
||||||
// postprocessing: {
|
postprocessing: {
|
||||||
// upscalingLevel,
|
upscalingLevel,
|
||||||
// upscalingDenoising,
|
upscalingDenoising,
|
||||||
// upscalingStrength,
|
upscalingStrength,
|
||||||
// },
|
},
|
||||||
// } = getState();
|
} = getState();
|
||||||
|
|
||||||
// const esrganParameters = {
|
const esrganParameters = {
|
||||||
// upscale: [upscalingLevel, upscalingDenoising, upscalingStrength],
|
upscale: [upscalingLevel, upscalingDenoising, upscalingStrength],
|
||||||
// };
|
};
|
||||||
// socketio.emit('runPostprocessing', imageToProcess, {
|
socketio.emit('runPostprocessing', imageToProcess, {
|
||||||
// type: 'esrgan',
|
type: 'esrgan',
|
||||||
// ...esrganParameters,
|
...esrganParameters,
|
||||||
// });
|
});
|
||||||
// dispatch(
|
dispatch(
|
||||||
// addLogEntry({
|
addLogEntry({
|
||||||
// timestamp: dateFormat(new Date(), 'isoDateTime'),
|
timestamp: dateFormat(new Date(), 'isoDateTime'),
|
||||||
// message: `ESRGAN upscale requested: ${JSON.stringify({
|
message: `ESRGAN upscale requested: ${JSON.stringify({
|
||||||
// file: imageToProcess.url,
|
file: imageToProcess.url,
|
||||||
// ...esrganParameters,
|
...esrganParameters,
|
||||||
// })}`,
|
})}`,
|
||||||
// })
|
})
|
||||||
// );
|
);
|
||||||
// },
|
},
|
||||||
// emitRunFacetool: (imageToProcess: InvokeAI._Image) => {
|
emitRunFacetool: (imageToProcess: InvokeAI._Image) => {
|
||||||
// dispatch(setIsProcessing(true));
|
dispatch(setIsProcessing(true));
|
||||||
|
|
||||||
// const {
|
const {
|
||||||
// postprocessing: { facetoolType, facetoolStrength, codeformerFidelity },
|
postprocessing: { facetoolType, facetoolStrength, codeformerFidelity },
|
||||||
// } = getState();
|
} = getState();
|
||||||
|
|
||||||
// const facetoolParameters: Record<string, unknown> = {
|
const facetoolParameters: Record<string, unknown> = {
|
||||||
// facetool_strength: facetoolStrength,
|
facetool_strength: facetoolStrength,
|
||||||
// };
|
};
|
||||||
|
|
||||||
// if (facetoolType === 'codeformer') {
|
if (facetoolType === 'codeformer') {
|
||||||
// facetoolParameters.codeformer_fidelity = codeformerFidelity;
|
facetoolParameters.codeformer_fidelity = codeformerFidelity;
|
||||||
// }
|
}
|
||||||
|
|
||||||
// socketio.emit('runPostprocessing', imageToProcess, {
|
socketio.emit('runPostprocessing', imageToProcess, {
|
||||||
// type: facetoolType,
|
type: facetoolType,
|
||||||
// ...facetoolParameters,
|
...facetoolParameters,
|
||||||
// });
|
});
|
||||||
// dispatch(
|
dispatch(
|
||||||
// addLogEntry({
|
addLogEntry({
|
||||||
// timestamp: dateFormat(new Date(), 'isoDateTime'),
|
timestamp: dateFormat(new Date(), 'isoDateTime'),
|
||||||
// message: `Face restoration (${facetoolType}) requested: ${JSON.stringify(
|
message: `Face restoration (${facetoolType}) requested: ${JSON.stringify(
|
||||||
// {
|
{
|
||||||
// file: imageToProcess.url,
|
file: imageToProcess.url,
|
||||||
// ...facetoolParameters,
|
...facetoolParameters,
|
||||||
// }
|
}
|
||||||
// )}`,
|
)}`,
|
||||||
// })
|
})
|
||||||
// );
|
);
|
||||||
// },
|
},
|
||||||
// emitDeleteImage: (imageToDelete: InvokeAI._Image) => {
|
emitDeleteImage: (imageToDelete: InvokeAI._Image) => {
|
||||||
// const { url, uuid, category, thumbnail } = imageToDelete;
|
const { url, uuid, category, thumbnail } = imageToDelete;
|
||||||
// dispatch(removeImage(imageToDelete));
|
dispatch(removeImage(imageToDelete));
|
||||||
// socketio.emit('deleteImage', url, thumbnail, uuid, category);
|
socketio.emit('deleteImage', url, thumbnail, uuid, category);
|
||||||
// },
|
},
|
||||||
// emitRequestImages: (category: GalleryCategory) => {
|
emitRequestImages: (category: GalleryCategory) => {
|
||||||
// const gallery: GalleryState = getState().gallery;
|
const gallery: GalleryState = getState().gallery;
|
||||||
// const { earliest_mtime } = gallery.categories[category];
|
const { earliest_mtime } = gallery.categories[category];
|
||||||
// socketio.emit('requestImages', category, earliest_mtime);
|
socketio.emit('requestImages', category, earliest_mtime);
|
||||||
// },
|
},
|
||||||
// emitRequestNewImages: (category: GalleryCategory) => {
|
emitRequestNewImages: (category: GalleryCategory) => {
|
||||||
// const gallery: GalleryState = getState().gallery;
|
const gallery: GalleryState = getState().gallery;
|
||||||
// const { latest_mtime } = gallery.categories[category];
|
const { latest_mtime } = gallery.categories[category];
|
||||||
// socketio.emit('requestLatestImages', category, latest_mtime);
|
socketio.emit('requestLatestImages', category, latest_mtime);
|
||||||
// },
|
},
|
||||||
// emitCancelProcessing: () => {
|
emitCancelProcessing: () => {
|
||||||
// socketio.emit('cancel');
|
socketio.emit('cancel');
|
||||||
// },
|
},
|
||||||
// emitRequestSystemConfig: () => {
|
emitRequestSystemConfig: () => {
|
||||||
// socketio.emit('requestSystemConfig');
|
socketio.emit('requestSystemConfig');
|
||||||
// },
|
},
|
||||||
// emitSearchForModels: (modelFolder: string) => {
|
emitSearchForModels: (modelFolder: string) => {
|
||||||
// socketio.emit('searchForModels', modelFolder);
|
socketio.emit('searchForModels', modelFolder);
|
||||||
// },
|
},
|
||||||
// emitAddNewModel: (modelConfig: InvokeAI.InvokeModelConfigProps) => {
|
emitAddNewModel: (modelConfig: InvokeAI.InvokeModelConfigProps) => {
|
||||||
// socketio.emit('addNewModel', modelConfig);
|
socketio.emit('addNewModel', modelConfig);
|
||||||
// },
|
},
|
||||||
// emitDeleteModel: (modelName: string) => {
|
emitDeleteModel: (modelName: string) => {
|
||||||
// socketio.emit('deleteModel', modelName);
|
socketio.emit('deleteModel', modelName);
|
||||||
// },
|
},
|
||||||
// emitConvertToDiffusers: (
|
emitConvertToDiffusers: (
|
||||||
// modelToConvert: InvokeAI.InvokeModelConversionProps
|
modelToConvert: InvokeAI.InvokeModelConversionProps
|
||||||
// ) => {
|
) => {
|
||||||
// dispatch(modelConvertRequested());
|
dispatch(modelConvertRequested());
|
||||||
// socketio.emit('convertToDiffusers', modelToConvert);
|
socketio.emit('convertToDiffusers', modelToConvert);
|
||||||
// },
|
},
|
||||||
// emitMergeDiffusersModels: (
|
emitMergeDiffusersModels: (
|
||||||
// modelMergeInfo: InvokeAI.InvokeModelMergingProps
|
modelMergeInfo: InvokeAI.InvokeModelMergingProps
|
||||||
// ) => {
|
) => {
|
||||||
// dispatch(modelMergingRequested());
|
dispatch(modelMergingRequested());
|
||||||
// socketio.emit('mergeDiffusersModels', modelMergeInfo);
|
socketio.emit('mergeDiffusersModels', modelMergeInfo);
|
||||||
// },
|
},
|
||||||
// emitRequestModelChange: (modelName: string) => {
|
emitRequestModelChange: (modelName: string) => {
|
||||||
// dispatch(modelChangeRequested());
|
dispatch(modelChangeRequested());
|
||||||
// socketio.emit('requestModelChange', modelName);
|
socketio.emit('requestModelChange', modelName);
|
||||||
// },
|
},
|
||||||
// emitSaveStagingAreaImageToGallery: (url: string) => {
|
emitSaveStagingAreaImageToGallery: (url: string) => {
|
||||||
// socketio.emit('requestSaveStagingAreaImageToGallery', url);
|
socketio.emit('requestSaveStagingAreaImageToGallery', url);
|
||||||
// },
|
},
|
||||||
// emitRequestEmptyTempFolder: () => {
|
emitRequestEmptyTempFolder: () => {
|
||||||
// socketio.emit('requestEmptyTempFolder');
|
socketio.emit('requestEmptyTempFolder');
|
||||||
// },
|
},
|
||||||
// };
|
};
|
||||||
// };
|
};
|
||||||
|
|
||||||
// export default makeSocketIOEmitters;
|
export default makeSocketIOEmitters;
|
||||||
|
|
||||||
export default {};
|
export default {};
|
||||||
|
4
invokeai/frontend/web/src/app/store/actions.ts
Normal file
4
invokeai/frontend/web/src/app/store/actions.ts
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
import { createAction } from '@reduxjs/toolkit';
|
||||||
|
import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||||
|
|
||||||
|
export const userInvoked = createAction<InvokeTabName>('app/userInvoked');
|
8
invokeai/frontend/web/src/app/store/constants.ts
Normal file
8
invokeai/frontend/web/src/app/store/constants.ts
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
export const LOCALSTORAGE_KEYS = [
|
||||||
|
'chakra-ui-color-mode',
|
||||||
|
'i18nextLng',
|
||||||
|
'ROARR_FILTER',
|
||||||
|
'ROARR_LOG',
|
||||||
|
];
|
||||||
|
|
||||||
|
export const LOCALSTORAGE_PREFIX = '@@invokeai-';
|
@ -0,0 +1,36 @@
|
|||||||
|
import { canvasPersistDenylist } from 'features/canvas/store/canvasPersistDenylist';
|
||||||
|
import { galleryPersistDenylist } from 'features/gallery/store/galleryPersistDenylist';
|
||||||
|
import { resultsPersistDenylist } from 'features/gallery/store/resultsPersistDenylist';
|
||||||
|
import { uploadsPersistDenylist } from 'features/gallery/store/uploadsPersistDenylist';
|
||||||
|
import { lightboxPersistDenylist } from 'features/lightbox/store/lightboxPersistDenylist';
|
||||||
|
import { nodesPersistDenylist } from 'features/nodes/store/nodesPersistDenylist';
|
||||||
|
import { generationPersistDenylist } from 'features/parameters/store/generationPersistDenylist';
|
||||||
|
import { postprocessingPersistDenylist } from 'features/parameters/store/postprocessingPersistDenylist';
|
||||||
|
import { modelsPersistDenylist } from 'features/system/store/modelsPersistDenylist';
|
||||||
|
import { systemPersistDenylist } from 'features/system/store/systemPersistDenylist';
|
||||||
|
import { uiPersistDenylist } from 'features/ui/store/uiPersistDenylist';
|
||||||
|
import { omit } from 'lodash-es';
|
||||||
|
import { SerializeFunction } from 'redux-remember';
|
||||||
|
|
||||||
|
const serializationDenylist: {
|
||||||
|
[key: string]: string[];
|
||||||
|
} = {
|
||||||
|
canvas: canvasPersistDenylist,
|
||||||
|
gallery: galleryPersistDenylist,
|
||||||
|
generation: generationPersistDenylist,
|
||||||
|
lightbox: lightboxPersistDenylist,
|
||||||
|
models: modelsPersistDenylist,
|
||||||
|
nodes: nodesPersistDenylist,
|
||||||
|
postprocessing: postprocessingPersistDenylist,
|
||||||
|
results: resultsPersistDenylist,
|
||||||
|
system: systemPersistDenylist,
|
||||||
|
// config: configPersistDenyList,
|
||||||
|
ui: uiPersistDenylist,
|
||||||
|
uploads: uploadsPersistDenylist,
|
||||||
|
// hotkeys: hotkeysPersistDenylist,
|
||||||
|
};
|
||||||
|
|
||||||
|
export const serialize: SerializeFunction = (data, key) => {
|
||||||
|
const result = omit(data, serializationDenylist[key]);
|
||||||
|
return JSON.stringify(result);
|
||||||
|
};
|
@ -0,0 +1,38 @@
|
|||||||
|
import { initialCanvasState } from 'features/canvas/store/canvasSlice';
|
||||||
|
import { initialGalleryState } from 'features/gallery/store/gallerySlice';
|
||||||
|
import { initialResultsState } from 'features/gallery/store/resultsSlice';
|
||||||
|
import { initialUploadsState } from 'features/gallery/store/uploadsSlice';
|
||||||
|
import { initialLightboxState } from 'features/lightbox/store/lightboxSlice';
|
||||||
|
import { initialNodesState } from 'features/nodes/store/nodesSlice';
|
||||||
|
import { initialGenerationState } from 'features/parameters/store/generationSlice';
|
||||||
|
import { initialPostprocessingState } from 'features/parameters/store/postprocessingSlice';
|
||||||
|
import { initialConfigState } from 'features/system/store/configSlice';
|
||||||
|
import { initialModelsState } from 'features/system/store/modelSlice';
|
||||||
|
import { initialSystemState } from 'features/system/store/systemSlice';
|
||||||
|
import { initialHotkeysState } from 'features/ui/store/hotkeysSlice';
|
||||||
|
import { initialUIState } from 'features/ui/store/uiSlice';
|
||||||
|
import { defaultsDeep } from 'lodash-es';
|
||||||
|
import { UnserializeFunction } from 'redux-remember';
|
||||||
|
|
||||||
|
const initialStates: {
|
||||||
|
[key: string]: any;
|
||||||
|
} = {
|
||||||
|
canvas: initialCanvasState,
|
||||||
|
gallery: initialGalleryState,
|
||||||
|
generation: initialGenerationState,
|
||||||
|
lightbox: initialLightboxState,
|
||||||
|
models: initialModelsState,
|
||||||
|
nodes: initialNodesState,
|
||||||
|
postprocessing: initialPostprocessingState,
|
||||||
|
results: initialResultsState,
|
||||||
|
system: initialSystemState,
|
||||||
|
config: initialConfigState,
|
||||||
|
ui: initialUIState,
|
||||||
|
uploads: initialUploadsState,
|
||||||
|
hotkeys: initialHotkeysState,
|
||||||
|
};
|
||||||
|
|
||||||
|
export const unserialize: UnserializeFunction = (data, key) => {
|
||||||
|
const result = defaultsDeep(JSON.parse(data), initialStates[key]);
|
||||||
|
return result;
|
||||||
|
};
|
@ -0,0 +1,30 @@
|
|||||||
|
import { AnyAction } from '@reduxjs/toolkit';
|
||||||
|
import { isAnyGraphBuilt } from 'features/nodes/store/actions';
|
||||||
|
import { forEach } from 'lodash-es';
|
||||||
|
import { Graph } from 'services/api';
|
||||||
|
|
||||||
|
export const actionSanitizer = <A extends AnyAction>(action: A): A => {
|
||||||
|
if (isAnyGraphBuilt(action)) {
|
||||||
|
if (action.payload.nodes) {
|
||||||
|
const sanitizedNodes: Graph['nodes'] = {};
|
||||||
|
|
||||||
|
// Sanitize nodes as needed
|
||||||
|
forEach(action.payload.nodes, (node, key) => {
|
||||||
|
// Don't log the whole freaking dataURL
|
||||||
|
if (node.type === 'dataURL_image') {
|
||||||
|
const { dataURL, ...rest } = node;
|
||||||
|
sanitizedNodes[key] = { ...rest, dataURL: '<dataURL>' };
|
||||||
|
} else {
|
||||||
|
sanitizedNodes[key] = { ...node };
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
return {
|
||||||
|
...action,
|
||||||
|
payload: { ...action.payload, nodes: sanitizedNodes },
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return action;
|
||||||
|
};
|
@ -0,0 +1,11 @@
|
|||||||
|
export const actionsDenylist = [
|
||||||
|
'canvas/setCursorPosition',
|
||||||
|
'canvas/setStageCoordinates',
|
||||||
|
'canvas/setStageScale',
|
||||||
|
'canvas/setIsDrawing',
|
||||||
|
'canvas/setBoundingBoxCoordinates',
|
||||||
|
'canvas/setBoundingBoxDimensions',
|
||||||
|
'canvas/setIsDrawing',
|
||||||
|
'canvas/addPointToCurrentLine',
|
||||||
|
'socket/generatorProgress',
|
||||||
|
];
|
@ -0,0 +1,3 @@
|
|||||||
|
export const stateSanitizer = <S>(state: S): S => {
|
||||||
|
return state;
|
||||||
|
};
|
@ -0,0 +1,45 @@
|
|||||||
|
import {
|
||||||
|
createListenerMiddleware,
|
||||||
|
addListener,
|
||||||
|
ListenerEffect,
|
||||||
|
AnyAction,
|
||||||
|
} from '@reduxjs/toolkit';
|
||||||
|
import type { TypedStartListening, TypedAddListener } from '@reduxjs/toolkit';
|
||||||
|
|
||||||
|
import type { RootState, AppDispatch } from '../../store';
|
||||||
|
import { addInitialImageSelectedListener } from './listeners/initialImageSelected';
|
||||||
|
import { addImageResultReceivedListener } from './listeners/invocationComplete';
|
||||||
|
import { addImageUploadedListener } from './listeners/imageUploaded';
|
||||||
|
import { addRequestedImageDeletionListener } from './listeners/imageDeleted';
|
||||||
|
import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas';
|
||||||
|
import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
|
||||||
|
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
|
||||||
|
import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage';
|
||||||
|
|
||||||
|
export const listenerMiddleware = createListenerMiddleware();
|
||||||
|
|
||||||
|
export type AppStartListening = TypedStartListening<RootState, AppDispatch>;
|
||||||
|
|
||||||
|
export const startAppListening =
|
||||||
|
listenerMiddleware.startListening as AppStartListening;
|
||||||
|
|
||||||
|
export const addAppListener = addListener as TypedAddListener<
|
||||||
|
RootState,
|
||||||
|
AppDispatch
|
||||||
|
>;
|
||||||
|
|
||||||
|
export type AppListenerEffect = ListenerEffect<
|
||||||
|
AnyAction,
|
||||||
|
RootState,
|
||||||
|
AppDispatch
|
||||||
|
>;
|
||||||
|
|
||||||
|
addImageUploadedListener();
|
||||||
|
addInitialImageSelectedListener();
|
||||||
|
addImageResultReceivedListener();
|
||||||
|
addRequestedImageDeletionListener();
|
||||||
|
|
||||||
|
addUserInvokedCanvasListener();
|
||||||
|
addUserInvokedNodesListener();
|
||||||
|
addUserInvokedTextToImageListener();
|
||||||
|
addUserInvokedImageToImageListener();
|
@ -0,0 +1,31 @@
|
|||||||
|
import { canvasGraphBuilt } from 'features/nodes/store/actions';
|
||||||
|
import { startAppListening } from '..';
|
||||||
|
import {
|
||||||
|
canvasSessionIdChanged,
|
||||||
|
stagingAreaInitialized,
|
||||||
|
} from 'features/canvas/store/canvasSlice';
|
||||||
|
import { sessionInvoked } from 'services/thunks/session';
|
||||||
|
|
||||||
|
export const addCanvasGraphBuiltListener = () =>
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: canvasGraphBuilt,
|
||||||
|
effect: async (action, { dispatch, getState, take }) => {
|
||||||
|
const [{ meta }] = await take(sessionInvoked.fulfilled.match);
|
||||||
|
const { sessionId } = meta.arg;
|
||||||
|
const state = getState();
|
||||||
|
|
||||||
|
if (!state.canvas.layerState.stagingArea.boundingBox) {
|
||||||
|
dispatch(
|
||||||
|
stagingAreaInitialized({
|
||||||
|
sessionId,
|
||||||
|
boundingBox: {
|
||||||
|
...state.canvas.boundingBoxCoordinates,
|
||||||
|
...state.canvas.boundingBoxDimensions,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatch(canvasSessionIdChanged(sessionId));
|
||||||
|
},
|
||||||
|
});
|
@ -0,0 +1,59 @@
|
|||||||
|
import { requestedImageDeletion } from 'features/gallery/store/actions';
|
||||||
|
import { startAppListening } from '..';
|
||||||
|
import { imageDeleted } from 'services/thunks/image';
|
||||||
|
import { log } from 'app/logging/useLogger';
|
||||||
|
import { clamp } from 'lodash-es';
|
||||||
|
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||||
|
|
||||||
|
const moduleLog = log.child({ namespace: 'addRequestedImageDeletionListener' });
|
||||||
|
|
||||||
|
export const addRequestedImageDeletionListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: requestedImageDeletion,
|
||||||
|
effect: (action, { dispatch, getState }) => {
|
||||||
|
const image = action.payload;
|
||||||
|
if (!image) {
|
||||||
|
moduleLog.warn('No image provided');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const { name, type } = image;
|
||||||
|
|
||||||
|
if (type !== 'uploads' && type !== 'results') {
|
||||||
|
moduleLog.warn({ data: image }, `Invalid image type ${type}`);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const selectedImageName = getState().gallery.selectedImage?.name;
|
||||||
|
|
||||||
|
if (selectedImageName === name) {
|
||||||
|
const allIds = getState()[type].ids;
|
||||||
|
const allEntities = getState()[type].entities;
|
||||||
|
|
||||||
|
const deletedImageIndex = allIds.findIndex(
|
||||||
|
(result) => result.toString() === name
|
||||||
|
);
|
||||||
|
|
||||||
|
const filteredIds = allIds.filter((id) => id.toString() !== name);
|
||||||
|
|
||||||
|
const newSelectedImageIndex = clamp(
|
||||||
|
deletedImageIndex,
|
||||||
|
0,
|
||||||
|
filteredIds.length - 1
|
||||||
|
);
|
||||||
|
|
||||||
|
const newSelectedImageId = filteredIds[newSelectedImageIndex];
|
||||||
|
|
||||||
|
const newSelectedImage = allEntities[newSelectedImageId];
|
||||||
|
|
||||||
|
if (newSelectedImageId) {
|
||||||
|
dispatch(imageSelected(newSelectedImage));
|
||||||
|
} else {
|
||||||
|
dispatch(imageSelected());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatch(imageDeleted({ imageName: name, imageType: type }));
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
@ -0,0 +1,25 @@
|
|||||||
|
import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
|
||||||
|
import { startAppListening } from '..';
|
||||||
|
import { uploadAdded } from 'features/gallery/store/uploadsSlice';
|
||||||
|
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||||
|
import { imageUploaded } from 'services/thunks/image';
|
||||||
|
|
||||||
|
export const addImageUploadedListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
predicate: (action): action is ReturnType<typeof imageUploaded.fulfilled> =>
|
||||||
|
imageUploaded.fulfilled.match(action) &&
|
||||||
|
action.payload.response.image_type !== 'intermediates',
|
||||||
|
effect: (action, { dispatch, getState }) => {
|
||||||
|
const { response } = action.payload;
|
||||||
|
|
||||||
|
const state = getState();
|
||||||
|
const image = deserializeImageResponse(response);
|
||||||
|
|
||||||
|
dispatch(uploadAdded(image));
|
||||||
|
|
||||||
|
if (state.gallery.shouldAutoSwitchToNewImages) {
|
||||||
|
dispatch(imageSelected(image));
|
||||||
|
}
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
@ -0,0 +1,54 @@
|
|||||||
|
import { initialImageChanged } from 'features/parameters/store/generationSlice';
|
||||||
|
import { Image, isInvokeAIImage } from 'app/types/invokeai';
|
||||||
|
import { selectResultsById } from 'features/gallery/store/resultsSlice';
|
||||||
|
import { selectUploadsById } from 'features/gallery/store/uploadsSlice';
|
||||||
|
import { makeToast } from 'features/system/hooks/useToastWatcher';
|
||||||
|
import { t } from 'i18next';
|
||||||
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
|
import { startAppListening } from '..';
|
||||||
|
import { initialImageSelected } from 'features/parameters/store/actions';
|
||||||
|
|
||||||
|
export const addInitialImageSelectedListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: initialImageSelected,
|
||||||
|
effect: (action, { getState, dispatch }) => {
|
||||||
|
if (!action.payload) {
|
||||||
|
dispatch(
|
||||||
|
addToast(
|
||||||
|
makeToast({ title: t('toast.imageNotLoadedDesc'), status: 'error' })
|
||||||
|
)
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isInvokeAIImage(action.payload)) {
|
||||||
|
dispatch(initialImageChanged(action.payload));
|
||||||
|
dispatch(addToast(makeToast(t('toast.sentToImageToImage'))));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const { name, type } = action.payload;
|
||||||
|
|
||||||
|
let image: Image | undefined;
|
||||||
|
const state = getState();
|
||||||
|
|
||||||
|
if (type === 'results') {
|
||||||
|
image = selectResultsById(state, name);
|
||||||
|
} else if (type === 'uploads') {
|
||||||
|
image = selectUploadsById(state, name);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!image) {
|
||||||
|
dispatch(
|
||||||
|
addToast(
|
||||||
|
makeToast({ title: t('toast.imageNotLoadedDesc'), status: 'error' })
|
||||||
|
)
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatch(initialImageChanged(image));
|
||||||
|
dispatch(addToast(makeToast(t('toast.sentToImageToImage'))));
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
@ -0,0 +1,88 @@
|
|||||||
|
import { invocationComplete } from 'services/events/actions';
|
||||||
|
import { isImageOutput } from 'services/types/guards';
|
||||||
|
import {
|
||||||
|
buildImageUrls,
|
||||||
|
extractTimestampFromImageName,
|
||||||
|
} from 'services/util/deserializeImageField';
|
||||||
|
import { Image } from 'app/types/invokeai';
|
||||||
|
import { resultAdded } from 'features/gallery/store/resultsSlice';
|
||||||
|
import { imageReceived, thumbnailReceived } from 'services/thunks/image';
|
||||||
|
import { startAppListening } from '..';
|
||||||
|
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||||
|
import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
|
||||||
|
|
||||||
|
const nodeDenylist = ['dataURL_image'];
|
||||||
|
|
||||||
|
export const addImageResultReceivedListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
predicate: (action) => {
|
||||||
|
if (
|
||||||
|
invocationComplete.match(action) &&
|
||||||
|
isImageOutput(action.payload.data.result)
|
||||||
|
) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
},
|
||||||
|
effect: (action, { getState, dispatch }) => {
|
||||||
|
if (!invocationComplete.match(action)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const { data, shouldFetchImages } = action.payload;
|
||||||
|
const { result, node, graph_execution_state_id } = data;
|
||||||
|
|
||||||
|
if (isImageOutput(result) && !nodeDenylist.includes(node.type)) {
|
||||||
|
const name = result.image.image_name;
|
||||||
|
const type = result.image.image_type;
|
||||||
|
const state = getState();
|
||||||
|
|
||||||
|
// if we need to refetch, set URLs to placeholder for now
|
||||||
|
const { url, thumbnail } = shouldFetchImages
|
||||||
|
? { url: '', thumbnail: '' }
|
||||||
|
: buildImageUrls(type, name);
|
||||||
|
|
||||||
|
const timestamp = extractTimestampFromImageName(name);
|
||||||
|
|
||||||
|
const image: Image = {
|
||||||
|
name,
|
||||||
|
type,
|
||||||
|
url,
|
||||||
|
thumbnail,
|
||||||
|
metadata: {
|
||||||
|
created: timestamp,
|
||||||
|
width: result.width,
|
||||||
|
height: result.height,
|
||||||
|
invokeai: {
|
||||||
|
session_id: graph_execution_state_id,
|
||||||
|
...(node ? { node } : {}),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
dispatch(resultAdded(image));
|
||||||
|
|
||||||
|
if (state.gallery.shouldAutoSwitchToNewImages) {
|
||||||
|
dispatch(imageSelected(image));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (state.config.shouldFetchImages) {
|
||||||
|
dispatch(imageReceived({ imageName: name, imageType: type }));
|
||||||
|
dispatch(
|
||||||
|
thumbnailReceived({
|
||||||
|
thumbnailName: name,
|
||||||
|
thumbnailType: type,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (
|
||||||
|
graph_execution_state_id ===
|
||||||
|
state.canvas.layerState.stagingArea.sessionId
|
||||||
|
) {
|
||||||
|
dispatch(addImageToStagingArea(image));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
@ -0,0 +1,126 @@
|
|||||||
|
import { startAppListening } from '..';
|
||||||
|
import { sessionCreated, sessionInvoked } from 'services/thunks/session';
|
||||||
|
import { buildCanvasGraphAndBlobs } from 'features/nodes/util/graphBuilders/buildCanvasGraph';
|
||||||
|
import { log } from 'app/logging/useLogger';
|
||||||
|
import { canvasGraphBuilt } from 'features/nodes/store/actions';
|
||||||
|
import { imageUploaded } from 'services/thunks/image';
|
||||||
|
import { v4 as uuidv4 } from 'uuid';
|
||||||
|
import { Graph } from 'services/api';
|
||||||
|
import {
|
||||||
|
canvasSessionIdChanged,
|
||||||
|
stagingAreaInitialized,
|
||||||
|
} from 'features/canvas/store/canvasSlice';
|
||||||
|
import { userInvoked } from 'app/store/actions';
|
||||||
|
|
||||||
|
const moduleLog = log.child({ namespace: 'invoke' });
|
||||||
|
|
||||||
|
export const addUserInvokedCanvasListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
predicate: (action): action is ReturnType<typeof userInvoked> =>
|
||||||
|
userInvoked.match(action) && action.payload === 'unifiedCanvas',
|
||||||
|
effect: async (action, { getState, dispatch, take }) => {
|
||||||
|
const state = getState();
|
||||||
|
|
||||||
|
const data = await buildCanvasGraphAndBlobs(state);
|
||||||
|
|
||||||
|
if (!data) {
|
||||||
|
moduleLog.error('Problem building graph');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const {
|
||||||
|
rangeNode,
|
||||||
|
iterateNode,
|
||||||
|
baseNode,
|
||||||
|
edges,
|
||||||
|
baseBlob,
|
||||||
|
maskBlob,
|
||||||
|
generationMode,
|
||||||
|
} = data;
|
||||||
|
|
||||||
|
const baseFilename = `${uuidv4()}.png`;
|
||||||
|
const maskFilename = `${uuidv4()}.png`;
|
||||||
|
|
||||||
|
dispatch(
|
||||||
|
imageUploaded({
|
||||||
|
imageType: 'intermediates',
|
||||||
|
formData: {
|
||||||
|
file: new File([baseBlob], baseFilename, { type: 'image/png' }),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
);
|
||||||
|
|
||||||
|
if (baseNode.type === 'img2img' || baseNode.type === 'inpaint') {
|
||||||
|
const [{ payload: basePayload }] = await take(
|
||||||
|
(action): action is ReturnType<typeof imageUploaded.fulfilled> =>
|
||||||
|
imageUploaded.fulfilled.match(action) &&
|
||||||
|
action.meta.arg.formData.file.name === baseFilename
|
||||||
|
);
|
||||||
|
|
||||||
|
const { image_name: baseName, image_type: baseType } =
|
||||||
|
basePayload.response;
|
||||||
|
|
||||||
|
baseNode.image = {
|
||||||
|
image_name: baseName,
|
||||||
|
image_type: baseType,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
if (baseNode.type === 'inpaint') {
|
||||||
|
dispatch(
|
||||||
|
imageUploaded({
|
||||||
|
imageType: 'intermediates',
|
||||||
|
formData: {
|
||||||
|
file: new File([maskBlob], maskFilename, { type: 'image/png' }),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
);
|
||||||
|
|
||||||
|
const [{ payload: maskPayload }] = await take(
|
||||||
|
(action): action is ReturnType<typeof imageUploaded.fulfilled> =>
|
||||||
|
imageUploaded.fulfilled.match(action) &&
|
||||||
|
action.meta.arg.formData.file.name === maskFilename
|
||||||
|
);
|
||||||
|
|
||||||
|
const { image_name: maskName, image_type: maskType } =
|
||||||
|
maskPayload.response;
|
||||||
|
|
||||||
|
baseNode.mask = {
|
||||||
|
image_name: maskName,
|
||||||
|
image_type: maskType,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assemble!
|
||||||
|
const nodes: Graph['nodes'] = {
|
||||||
|
[rangeNode.id]: rangeNode,
|
||||||
|
[iterateNode.id]: iterateNode,
|
||||||
|
[baseNode.id]: baseNode,
|
||||||
|
};
|
||||||
|
|
||||||
|
const graph = { nodes, edges };
|
||||||
|
|
||||||
|
dispatch(canvasGraphBuilt(graph));
|
||||||
|
moduleLog({ data: graph }, 'Canvas graph built');
|
||||||
|
|
||||||
|
dispatch(sessionCreated({ graph }));
|
||||||
|
|
||||||
|
const [{ meta }] = await take(sessionInvoked.fulfilled.match);
|
||||||
|
const { sessionId } = meta.arg;
|
||||||
|
|
||||||
|
if (!state.canvas.layerState.stagingArea.boundingBox) {
|
||||||
|
dispatch(
|
||||||
|
stagingAreaInitialized({
|
||||||
|
sessionId,
|
||||||
|
boundingBox: {
|
||||||
|
...state.canvas.boundingBoxCoordinates,
|
||||||
|
...state.canvas.boundingBoxDimensions,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatch(canvasSessionIdChanged(sessionId));
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
@ -0,0 +1,24 @@
|
|||||||
|
import { startAppListening } from '..';
|
||||||
|
import { buildImageToImageGraph } from 'features/nodes/util/graphBuilders/buildImageToImageGraph';
|
||||||
|
import { sessionCreated } from 'services/thunks/session';
|
||||||
|
import { log } from 'app/logging/useLogger';
|
||||||
|
import { imageToImageGraphBuilt } from 'features/nodes/store/actions';
|
||||||
|
import { userInvoked } from 'app/store/actions';
|
||||||
|
|
||||||
|
const moduleLog = log.child({ namespace: 'invoke' });
|
||||||
|
|
||||||
|
export const addUserInvokedImageToImageListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
predicate: (action): action is ReturnType<typeof userInvoked> =>
|
||||||
|
userInvoked.match(action) && action.payload === 'img2img',
|
||||||
|
effect: (action, { getState, dispatch }) => {
|
||||||
|
const state = getState();
|
||||||
|
|
||||||
|
const graph = buildImageToImageGraph(state);
|
||||||
|
dispatch(imageToImageGraphBuilt(graph));
|
||||||
|
moduleLog({ data: graph }, 'Image to Image graph built');
|
||||||
|
|
||||||
|
dispatch(sessionCreated({ graph }));
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
@ -0,0 +1,24 @@
|
|||||||
|
import { startAppListening } from '..';
|
||||||
|
import { sessionCreated } from 'services/thunks/session';
|
||||||
|
import { buildNodesGraph } from 'features/nodes/util/graphBuilders/buildNodesGraph';
|
||||||
|
import { log } from 'app/logging/useLogger';
|
||||||
|
import { nodesGraphBuilt } from 'features/nodes/store/actions';
|
||||||
|
import { userInvoked } from 'app/store/actions';
|
||||||
|
|
||||||
|
const moduleLog = log.child({ namespace: 'invoke' });
|
||||||
|
|
||||||
|
export const addUserInvokedNodesListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
predicate: (action): action is ReturnType<typeof userInvoked> =>
|
||||||
|
userInvoked.match(action) && action.payload === 'nodes',
|
||||||
|
effect: (action, { getState, dispatch }) => {
|
||||||
|
const state = getState();
|
||||||
|
|
||||||
|
const graph = buildNodesGraph(state);
|
||||||
|
dispatch(nodesGraphBuilt(graph));
|
||||||
|
moduleLog({ data: graph }, 'Nodes graph built');
|
||||||
|
|
||||||
|
dispatch(sessionCreated({ graph }));
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
@ -0,0 +1,24 @@
|
|||||||
|
import { startAppListening } from '..';
|
||||||
|
import { buildTextToImageGraph } from 'features/nodes/util/graphBuilders/buildTextToImageGraph';
|
||||||
|
import { sessionCreated } from 'services/thunks/session';
|
||||||
|
import { log } from 'app/logging/useLogger';
|
||||||
|
import { textToImageGraphBuilt } from 'features/nodes/store/actions';
|
||||||
|
import { userInvoked } from 'app/store/actions';
|
||||||
|
|
||||||
|
const moduleLog = log.child({ namespace: 'invoke' });
|
||||||
|
|
||||||
|
export const addUserInvokedTextToImageListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
predicate: (action): action is ReturnType<typeof userInvoked> =>
|
||||||
|
userInvoked.match(action) && action.payload === 'txt2img',
|
||||||
|
effect: (action, { getState, dispatch }) => {
|
||||||
|
const state = getState();
|
||||||
|
|
||||||
|
const graph = buildTextToImageGraph(state);
|
||||||
|
dispatch(textToImageGraphBuilt(graph));
|
||||||
|
moduleLog({ data: graph }, 'Text to Image graph built');
|
||||||
|
|
||||||
|
dispatch(sessionCreated({ graph }));
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
@ -1,4 +0,0 @@
|
|||||||
import { store } from 'app/store/store';
|
|
||||||
import { persistStore } from 'redux-persist';
|
|
||||||
|
|
||||||
export const persistor = persistStore(store);
|
|
@ -1,9 +1,12 @@
|
|||||||
import { combineReducers, configureStore } from '@reduxjs/toolkit';
|
import {
|
||||||
|
AnyAction,
|
||||||
|
ThunkDispatch,
|
||||||
|
combineReducers,
|
||||||
|
configureStore,
|
||||||
|
} from '@reduxjs/toolkit';
|
||||||
|
|
||||||
import { persistReducer } from 'redux-persist';
|
import { rememberReducer, rememberEnhancer } from 'redux-remember';
|
||||||
import storage from 'redux-persist/lib/storage'; // defaults to localStorage for web
|
|
||||||
import dynamicMiddlewares from 'redux-dynamic-middlewares';
|
import dynamicMiddlewares from 'redux-dynamic-middlewares';
|
||||||
import { getPersistConfig } from 'redux-deep-persist';
|
|
||||||
|
|
||||||
import canvasReducer from 'features/canvas/store/canvasSlice';
|
import canvasReducer from 'features/canvas/store/canvasSlice';
|
||||||
import galleryReducer from 'features/gallery/store/gallerySlice';
|
import galleryReducer from 'features/gallery/store/gallerySlice';
|
||||||
@ -19,33 +22,17 @@ import hotkeysReducer from 'features/ui/store/hotkeysSlice';
|
|||||||
import modelsReducer from 'features/system/store/modelSlice';
|
import modelsReducer from 'features/system/store/modelSlice';
|
||||||
import nodesReducer from 'features/nodes/store/nodesSlice';
|
import nodesReducer from 'features/nodes/store/nodesSlice';
|
||||||
|
|
||||||
import { canvasDenylist } from 'features/canvas/store/canvasPersistDenylist';
|
import { listenerMiddleware } from './middleware/listenerMiddleware';
|
||||||
import { galleryDenylist } from 'features/gallery/store/galleryPersistDenylist';
|
|
||||||
import { generationDenylist } from 'features/parameters/store/generationPersistDenylist';
|
|
||||||
import { lightboxDenylist } from 'features/lightbox/store/lightboxPersistDenylist';
|
|
||||||
import { modelsDenylist } from 'features/system/store/modelsPersistDenylist';
|
|
||||||
import { nodesDenylist } from 'features/nodes/store/nodesPersistDenylist';
|
|
||||||
import { postprocessingDenylist } from 'features/parameters/store/postprocessingPersistDenylist';
|
|
||||||
import { systemDenylist } from 'features/system/store/systemPersistDenylist';
|
|
||||||
import { uiDenylist } from 'features/ui/store/uiPersistDenylist';
|
|
||||||
import { resultsDenylist } from 'features/gallery/store/resultsPersistDenylist';
|
|
||||||
import { uploadsDenylist } from 'features/gallery/store/uploadsPersistDenylist';
|
|
||||||
|
|
||||||
/**
|
import { actionSanitizer } from './middleware/devtools/actionSanitizer';
|
||||||
* redux-persist provides an easy and reliable way to persist state across reloads.
|
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
|
||||||
*
|
import { actionsDenylist } from './middleware/devtools/actionsDenylist';
|
||||||
* While we definitely want generation parameters to be persisted, there are a number
|
|
||||||
* of things we do *not* want to be persisted across reloads:
|
|
||||||
* - Gallery/selected image (user may add/delete images from disk between page loads)
|
|
||||||
* - Connection/processing status
|
|
||||||
* - Availability of external libraries like ESRGAN/GFPGAN
|
|
||||||
*
|
|
||||||
* These can be denylisted in redux-persist.
|
|
||||||
*
|
|
||||||
* The necesssary nested persistors with denylists are configured below.
|
|
||||||
*/
|
|
||||||
|
|
||||||
const rootReducer = combineReducers({
|
import { serialize } from './enhancers/reduxRemember/serialize';
|
||||||
|
import { unserialize } from './enhancers/reduxRemember/unserialize';
|
||||||
|
import { LOCALSTORAGE_PREFIX } from './constants';
|
||||||
|
|
||||||
|
const allReducers = {
|
||||||
canvas: canvasReducer,
|
canvas: canvasReducer,
|
||||||
gallery: galleryReducer,
|
gallery: galleryReducer,
|
||||||
generation: generationReducer,
|
generation: generationReducer,
|
||||||
@ -59,65 +46,54 @@ const rootReducer = combineReducers({
|
|||||||
ui: uiReducer,
|
ui: uiReducer,
|
||||||
uploads: uploadsReducer,
|
uploads: uploadsReducer,
|
||||||
hotkeys: hotkeysReducer,
|
hotkeys: hotkeysReducer,
|
||||||
});
|
};
|
||||||
|
|
||||||
const rootPersistConfig = getPersistConfig({
|
const rootReducer = combineReducers(allReducers);
|
||||||
key: 'root',
|
|
||||||
storage,
|
|
||||||
rootReducer,
|
|
||||||
blacklist: [
|
|
||||||
...canvasDenylist,
|
|
||||||
...galleryDenylist,
|
|
||||||
...generationDenylist,
|
|
||||||
...lightboxDenylist,
|
|
||||||
...modelsDenylist,
|
|
||||||
...nodesDenylist,
|
|
||||||
...postprocessingDenylist,
|
|
||||||
// ...resultsDenylist,
|
|
||||||
'results',
|
|
||||||
...systemDenylist,
|
|
||||||
...uiDenylist,
|
|
||||||
// ...uploadsDenylist,
|
|
||||||
'uploads',
|
|
||||||
'hotkeys',
|
|
||||||
'config',
|
|
||||||
],
|
|
||||||
});
|
|
||||||
|
|
||||||
const persistedReducer = persistReducer(rootPersistConfig, rootReducer);
|
const rememberedRootReducer = rememberReducer(rootReducer);
|
||||||
|
|
||||||
// TODO: rip the old middleware out when nodes is complete
|
const rememberedKeys: (keyof typeof allReducers)[] = [
|
||||||
// export function buildMiddleware() {
|
'canvas',
|
||||||
// if (import.meta.env.MODE === 'nodes' || import.meta.env.MODE === 'package') {
|
'gallery',
|
||||||
// return socketMiddleware();
|
'generation',
|
||||||
// } else {
|
'lightbox',
|
||||||
// return socketioMiddleware();
|
// 'models',
|
||||||
// }
|
'nodes',
|
||||||
// }
|
'postprocessing',
|
||||||
|
'system',
|
||||||
|
'ui',
|
||||||
|
// 'hotkeys',
|
||||||
|
// 'results',
|
||||||
|
// 'uploads',
|
||||||
|
// 'config',
|
||||||
|
];
|
||||||
|
|
||||||
export const store = configureStore({
|
export const store = configureStore({
|
||||||
reducer: persistedReducer,
|
reducer: rememberedRootReducer,
|
||||||
|
enhancers: [
|
||||||
|
rememberEnhancer(window.localStorage, rememberedKeys, {
|
||||||
|
persistDebounce: 300,
|
||||||
|
serialize,
|
||||||
|
unserialize,
|
||||||
|
prefix: LOCALSTORAGE_PREFIX,
|
||||||
|
}),
|
||||||
|
],
|
||||||
middleware: (getDefaultMiddleware) =>
|
middleware: (getDefaultMiddleware) =>
|
||||||
getDefaultMiddleware({
|
getDefaultMiddleware({
|
||||||
immutableCheck: false,
|
immutableCheck: false,
|
||||||
serializableCheck: false,
|
serializableCheck: false,
|
||||||
}).concat(dynamicMiddlewares),
|
})
|
||||||
|
.concat(dynamicMiddlewares)
|
||||||
|
.prepend(listenerMiddleware.middleware),
|
||||||
devTools: {
|
devTools: {
|
||||||
// Uncommenting these very rapidly called actions makes the redux dev tools output much more readable
|
actionsDenylist,
|
||||||
actionsDenylist: [
|
actionSanitizer,
|
||||||
'canvas/setCursorPosition',
|
stateSanitizer,
|
||||||
'canvas/setStageCoordinates',
|
trace: true,
|
||||||
'canvas/setStageScale',
|
|
||||||
'canvas/setIsDrawing',
|
|
||||||
'canvas/setBoundingBoxCoordinates',
|
|
||||||
'canvas/setBoundingBoxDimensions',
|
|
||||||
'canvas/setIsDrawing',
|
|
||||||
'canvas/addPointToCurrentLine',
|
|
||||||
'socket/generatorProgress',
|
|
||||||
],
|
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
export type AppGetState = typeof store.getState;
|
export type AppGetState = typeof store.getState;
|
||||||
export type RootState = ReturnType<typeof store.getState>;
|
export type RootState = ReturnType<typeof store.getState>;
|
||||||
|
export type AppThunkDispatch = ThunkDispatch<RootState, any, AnyAction>;
|
||||||
export type AppDispatch = typeof store.dispatch;
|
export type AppDispatch = typeof store.dispatch;
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import { TypedUseSelectorHook, useDispatch, useSelector } from 'react-redux';
|
import { TypedUseSelectorHook, useDispatch, useSelector } from 'react-redux';
|
||||||
import { AppDispatch, RootState } from 'app/store/store';
|
import { AppThunkDispatch, RootState } from 'app/store/store';
|
||||||
|
|
||||||
// Use throughout your app instead of plain `useDispatch` and `useSelector`
|
// Use throughout your app instead of plain `useDispatch` and `useSelector`
|
||||||
export const useAppDispatch: () => AppDispatch = useDispatch;
|
export const useAppDispatch = () => useDispatch<AppThunkDispatch>();
|
||||||
export const useAppSelector: TypedUseSelectorHook<RootState> = useSelector;
|
export const useAppSelector: TypedUseSelectorHook<RootState> = useSelector;
|
||||||
|
@ -0,0 +1,7 @@
|
|||||||
|
import { isEqual } from 'lodash-es';
|
||||||
|
|
||||||
|
export const defaultSelectorOptions = {
|
||||||
|
memoizeOptions: {
|
||||||
|
resultEqualityCheck: isEqual,
|
||||||
|
},
|
||||||
|
};
|
@ -12,12 +12,10 @@
|
|||||||
* 'gfpgan'.
|
* 'gfpgan'.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import { GalleryCategory } from 'features/gallery/store/gallerySlice';
|
import { SelectedImage } from 'features/parameters/store/actions';
|
||||||
import { FacetoolType } from 'features/parameters/store/postprocessingSlice';
|
|
||||||
import { InvokeTabName } from 'features/ui/store/tabMap';
|
import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||||
import { IRect } from 'konva/lib/types';
|
import { IRect } from 'konva/lib/types';
|
||||||
import { ImageResponseMetadata, ImageType } from 'services/api';
|
import { ImageResponseMetadata, ImageType } from 'services/api';
|
||||||
import { AnyInvocation } from 'services/events/types';
|
|
||||||
import { O } from 'ts-toolbelt';
|
import { O } from 'ts-toolbelt';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -49,15 +47,20 @@ export type CommonGeneratedImageMetadata = {
|
|||||||
postprocessing: null | Array<ESRGANMetadata | FacetoolMetadata>;
|
postprocessing: null | Array<ESRGANMetadata | FacetoolMetadata>;
|
||||||
sampler:
|
sampler:
|
||||||
| 'ddim'
|
| 'ddim'
|
||||||
| 'k_dpm_2_a'
|
| 'ddpm'
|
||||||
| 'k_dpm_2'
|
| 'deis'
|
||||||
| 'k_dpmpp_2_a'
|
| 'lms'
|
||||||
| 'k_dpmpp_2'
|
| 'pndm'
|
||||||
| 'k_euler_a'
|
| 'heun'
|
||||||
| 'k_euler'
|
| 'euler'
|
||||||
| 'k_heun'
|
| 'euler_k'
|
||||||
| 'k_lms'
|
| 'euler_a'
|
||||||
| 'plms';
|
| 'kdpm_2'
|
||||||
|
| 'kdpm_2_a'
|
||||||
|
| 'dpmpp_2s'
|
||||||
|
| 'dpmpp_2m'
|
||||||
|
| 'dpmpp_2m_k'
|
||||||
|
| 'unipc';
|
||||||
prompt: Prompt;
|
prompt: Prompt;
|
||||||
seed: number;
|
seed: number;
|
||||||
variations: SeedWeights;
|
variations: SeedWeights;
|
||||||
@ -126,6 +129,14 @@ export type Image = {
|
|||||||
metadata: ImageResponseMetadata;
|
metadata: ImageResponseMetadata;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export const isInvokeAIImage = (obj: Image | SelectedImage): obj is Image => {
|
||||||
|
if ('url' in obj && 'thumbnail' in obj) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Types related to the system status.
|
* Types related to the system status.
|
||||||
*/
|
*/
|
||||||
@ -270,7 +281,7 @@ export type FoundModelResponse = {
|
|||||||
|
|
||||||
// export type SystemConfigResponse = SystemConfig;
|
// export type SystemConfigResponse = SystemConfig;
|
||||||
|
|
||||||
export type ImageResultResponse = Omit<_Image, 'uuid'> & {
|
export type ImageResultResponse = Omit<Image, 'uuid'> & {
|
||||||
boundingBox?: IRect;
|
boundingBox?: IRect;
|
||||||
generationMode: InvokeTabName;
|
generationMode: InvokeTabName;
|
||||||
};
|
};
|
||||||
@ -315,11 +326,11 @@ export type AppFeature =
|
|||||||
/**
|
/**
|
||||||
* A disable-able Stable Diffusion feature
|
* A disable-able Stable Diffusion feature
|
||||||
*/
|
*/
|
||||||
export type StableDiffusionFeature =
|
export type SDFeature =
|
||||||
| 'noiseConfig'
|
| 'noise'
|
||||||
| 'variations'
|
| 'variation'
|
||||||
| 'symmetry'
|
| 'symmetry'
|
||||||
| 'tiling'
|
| 'seamless'
|
||||||
| 'hires';
|
| 'hires';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -337,6 +348,7 @@ export type AppConfig = {
|
|||||||
shouldFetchImages: boolean;
|
shouldFetchImages: boolean;
|
||||||
disabledTabs: InvokeTabName[];
|
disabledTabs: InvokeTabName[];
|
||||||
disabledFeatures: AppFeature[];
|
disabledFeatures: AppFeature[];
|
||||||
|
disabledSDFeatures: SDFeature[];
|
||||||
canRestoreDeletedImagesFromBin: boolean;
|
canRestoreDeletedImagesFromBin: boolean;
|
||||||
sd: {
|
sd: {
|
||||||
iterations: {
|
iterations: {
|
||||||
|
61
invokeai/frontend/web/src/common/components/IAICollapse.tsx
Normal file
61
invokeai/frontend/web/src/common/components/IAICollapse.tsx
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
import { ChevronUpIcon } from '@chakra-ui/icons';
|
||||||
|
import { Box, Collapse, Flex, Spacer, Switch } from '@chakra-ui/react';
|
||||||
|
import { PropsWithChildren, memo } from 'react';
|
||||||
|
|
||||||
|
export type IAIToggleCollapseProps = PropsWithChildren & {
|
||||||
|
label: string;
|
||||||
|
isOpen: boolean;
|
||||||
|
onToggle: () => void;
|
||||||
|
withSwitch?: boolean;
|
||||||
|
};
|
||||||
|
|
||||||
|
const IAICollapse = (props: IAIToggleCollapseProps) => {
|
||||||
|
const { label, isOpen, onToggle, children, withSwitch = false } = props;
|
||||||
|
return (
|
||||||
|
<Box>
|
||||||
|
<Flex
|
||||||
|
onClick={onToggle}
|
||||||
|
sx={{
|
||||||
|
alignItems: 'center',
|
||||||
|
p: 2,
|
||||||
|
px: 4,
|
||||||
|
borderTopRadius: 'base',
|
||||||
|
borderBottomRadius: isOpen ? 0 : 'base',
|
||||||
|
bg: isOpen ? 'base.750' : 'base.800',
|
||||||
|
color: 'base.100',
|
||||||
|
_hover: {
|
||||||
|
bg: isOpen ? 'base.700' : 'base.750',
|
||||||
|
},
|
||||||
|
fontSize: 'sm',
|
||||||
|
fontWeight: 600,
|
||||||
|
cursor: 'pointer',
|
||||||
|
transitionProperty: 'common',
|
||||||
|
transitionDuration: 'normal',
|
||||||
|
userSelect: 'none',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{label}
|
||||||
|
<Spacer />
|
||||||
|
{withSwitch && <Switch isChecked={isOpen} pointerEvents="none" />}
|
||||||
|
{!withSwitch && (
|
||||||
|
<ChevronUpIcon
|
||||||
|
sx={{
|
||||||
|
w: '1rem',
|
||||||
|
h: '1rem',
|
||||||
|
transform: isOpen ? 'rotate(0deg)' : 'rotate(180deg)',
|
||||||
|
transitionProperty: 'common',
|
||||||
|
transitionDuration: 'normal',
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
</Flex>
|
||||||
|
<Collapse in={isOpen} animateOpacity>
|
||||||
|
<Box sx={{ p: 4, borderBottomRadius: 'base', bg: 'base.800' }}>
|
||||||
|
{children}
|
||||||
|
</Box>
|
||||||
|
</Collapse>
|
||||||
|
</Box>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(IAICollapse);
|
@ -27,7 +27,7 @@ const IAIPopover = (props: IAIPopoverProps) => {
|
|||||||
return (
|
return (
|
||||||
<Popover isLazy={isLazy} {...rest}>
|
<Popover isLazy={isLazy} {...rest}>
|
||||||
<PopoverTrigger>{triggerComponent}</PopoverTrigger>
|
<PopoverTrigger>{triggerComponent}</PopoverTrigger>
|
||||||
<PopoverContent>
|
<PopoverContent shadow="dark-lg">
|
||||||
{hasArrow && <PopoverArrow />}
|
{hasArrow && <PopoverArrow />}
|
||||||
{children}
|
{children}
|
||||||
</PopoverContent>
|
</PopoverContent>
|
||||||
|
@ -0,0 +1,54 @@
|
|||||||
|
import { Badge, Flex } from '@chakra-ui/react';
|
||||||
|
import { Image } from 'app/types/invokeai';
|
||||||
|
import { isNumber, isString } from 'lodash-es';
|
||||||
|
import { useMemo } from 'react';
|
||||||
|
|
||||||
|
type ImageMetadataOverlayProps = {
|
||||||
|
image: Image;
|
||||||
|
};
|
||||||
|
|
||||||
|
const ImageMetadataOverlay = ({ image }: ImageMetadataOverlayProps) => {
|
||||||
|
const dimensions = useMemo(() => {
|
||||||
|
if (!isNumber(image.metadata?.width) || isNumber(!image.metadata?.height)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
return `${image.metadata?.width} × ${image.metadata?.height}`;
|
||||||
|
}, [image.metadata]);
|
||||||
|
|
||||||
|
const model = useMemo(() => {
|
||||||
|
if (!isString(image.metadata?.invokeai?.node?.model)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
return image.metadata?.invokeai?.node?.model;
|
||||||
|
}, [image.metadata]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex
|
||||||
|
sx={{
|
||||||
|
pointerEvents: 'none',
|
||||||
|
flexDirection: 'column',
|
||||||
|
position: 'absolute',
|
||||||
|
top: 0,
|
||||||
|
right: 0,
|
||||||
|
p: 2,
|
||||||
|
alignItems: 'flex-end',
|
||||||
|
gap: 2,
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{dimensions && (
|
||||||
|
<Badge variant="solid" colorScheme="base">
|
||||||
|
{dimensions}
|
||||||
|
</Badge>
|
||||||
|
)}
|
||||||
|
{model && (
|
||||||
|
<Badge variant="solid" colorScheme="base">
|
||||||
|
{model}
|
||||||
|
</Badge>
|
||||||
|
)}
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default ImageMetadataOverlay;
|
@ -7,7 +7,7 @@ import { useAppDispatch } from 'app/store/storeHooks';
|
|||||||
import { useCallback } from 'react';
|
import { useCallback } from 'react';
|
||||||
import { clearInitialImage } from 'features/parameters/store/generationSlice';
|
import { clearInitialImage } from 'features/parameters/store/generationSlice';
|
||||||
|
|
||||||
const ImageToImageSettingsHeader = () => {
|
const InitialImageButtons = () => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
@ -18,24 +18,19 @@ const ImageToImageSettingsHeader = () => {
|
|||||||
return (
|
return (
|
||||||
<Flex w="full" alignItems="center">
|
<Flex w="full" alignItems="center">
|
||||||
<Text size="sm" fontWeight={500} color="base.300">
|
<Text size="sm" fontWeight={500} color="base.300">
|
||||||
Image to Image
|
{t('parameters.initialImage')}
|
||||||
</Text>
|
</Text>
|
||||||
<Spacer />
|
<Spacer />
|
||||||
<ButtonGroup>
|
<ButtonGroup>
|
||||||
<IAIIconButton
|
<IAIIconButton
|
||||||
size="sm"
|
|
||||||
icon={<FaUndo />}
|
icon={<FaUndo />}
|
||||||
aria-label={t('accessibility.reset')}
|
aria-label={t('accessibility.reset')}
|
||||||
onClick={handleResetInitialImage}
|
onClick={handleResetInitialImage}
|
||||||
/>
|
/>
|
||||||
<IAIIconButton
|
<IAIIconButton icon={<FaUpload />} aria-label={t('common.upload')} />
|
||||||
size="sm"
|
|
||||||
icon={<FaUpload />}
|
|
||||||
aria-label={t('common.upload')}
|
|
||||||
/>
|
|
||||||
</ButtonGroup>
|
</ButtonGroup>
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export default ImageToImageSettingsHeader;
|
export default InitialImageButtons;
|
@ -1,36 +0,0 @@
|
|||||||
import { Badge, Box, Flex } from '@chakra-ui/react';
|
|
||||||
import { Image } from 'app/types/invokeai';
|
|
||||||
|
|
||||||
type ImageToImageOverlayProps = {
|
|
||||||
image: Image;
|
|
||||||
};
|
|
||||||
|
|
||||||
const ImageToImageOverlay = ({ image }: ImageToImageOverlayProps) => {
|
|
||||||
return (
|
|
||||||
<Box
|
|
||||||
sx={{
|
|
||||||
top: 0,
|
|
||||||
left: 0,
|
|
||||||
w: 'full',
|
|
||||||
h: 'full',
|
|
||||||
position: 'absolute',
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
<Flex
|
|
||||||
sx={{
|
|
||||||
position: 'absolute',
|
|
||||||
top: 0,
|
|
||||||
right: 0,
|
|
||||||
p: 2,
|
|
||||||
alignItems: 'flex-start',
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
<Badge variant="solid" colorScheme="base">
|
|
||||||
{image.metadata?.width} × {image.metadata?.height}
|
|
||||||
</Badge>
|
|
||||||
</Flex>
|
|
||||||
</Box>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default ImageToImageOverlay;
|
|
@ -49,7 +49,7 @@ const ImageUploader = (props: ImageUploaderProps) => {
|
|||||||
|
|
||||||
const fileAcceptedCallback = useCallback(
|
const fileAcceptedCallback = useCallback(
|
||||||
async (file: File) => {
|
async (file: File) => {
|
||||||
dispatch(imageUploaded({ formData: { file } }));
|
dispatch(imageUploaded({ imageType: 'uploads', formData: { file } }));
|
||||||
},
|
},
|
||||||
[dispatch]
|
[dispatch]
|
||||||
);
|
);
|
||||||
@ -124,7 +124,7 @@ const ImageUploader = (props: ImageUploaderProps) => {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
dispatch(imageUploaded({ formData: { file } }));
|
dispatch(imageUploaded({ imageType: 'uploads', formData: { file } }));
|
||||||
};
|
};
|
||||||
document.addEventListener('paste', pasteImageListener);
|
document.addEventListener('paste', pasteImageListener);
|
||||||
return () => {
|
return () => {
|
||||||
|
@ -7,7 +7,7 @@ const SelectImagePlaceholder = () => {
|
|||||||
sx={{
|
sx={{
|
||||||
w: 'full',
|
w: 'full',
|
||||||
h: 'full',
|
h: 'full',
|
||||||
bg: 'base.800',
|
// bg: 'base.800',
|
||||||
borderRadius: 'base',
|
borderRadius: 'base',
|
||||||
alignItems: 'center',
|
alignItems: 'center',
|
||||||
justifyContent: 'center',
|
justifyContent: 'center',
|
||||||
|
@ -2,6 +2,13 @@ import { createSelector } from '@reduxjs/toolkit';
|
|||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
|
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
|
||||||
|
import {
|
||||||
|
setActiveTab,
|
||||||
|
toggleGalleryPanel,
|
||||||
|
toggleParametersPanel,
|
||||||
|
togglePinGalleryPanel,
|
||||||
|
togglePinParametersPanel,
|
||||||
|
} from 'features/ui/store/uiSlice';
|
||||||
import { isEqual } from 'lodash-es';
|
import { isEqual } from 'lodash-es';
|
||||||
import { isHotkeyPressed, useHotkeys } from 'react-hotkeys-hook';
|
import { isHotkeyPressed, useHotkeys } from 'react-hotkeys-hook';
|
||||||
|
|
||||||
@ -36,4 +43,36 @@ export const useGlobalHotkeys = () => {
|
|||||||
{ keyup: true, keydown: true },
|
{ keyup: true, keydown: true },
|
||||||
[shift]
|
[shift]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
useHotkeys('o', () => {
|
||||||
|
dispatch(toggleParametersPanel());
|
||||||
|
});
|
||||||
|
|
||||||
|
useHotkeys(['shift+o'], () => {
|
||||||
|
dispatch(togglePinParametersPanel());
|
||||||
|
});
|
||||||
|
|
||||||
|
useHotkeys('g', () => {
|
||||||
|
dispatch(toggleGalleryPanel());
|
||||||
|
});
|
||||||
|
|
||||||
|
useHotkeys(['shift+g'], () => {
|
||||||
|
dispatch(togglePinGalleryPanel());
|
||||||
|
});
|
||||||
|
|
||||||
|
useHotkeys('1', () => {
|
||||||
|
dispatch(setActiveTab('txt2img'));
|
||||||
|
});
|
||||||
|
|
||||||
|
useHotkeys('2', () => {
|
||||||
|
dispatch(setActiveTab('img2img'));
|
||||||
|
});
|
||||||
|
|
||||||
|
useHotkeys('3', () => {
|
||||||
|
dispatch(setActiveTab('unifiedCanvas'));
|
||||||
|
});
|
||||||
|
|
||||||
|
useHotkeys('4', () => {
|
||||||
|
dispatch(setActiveTab('nodes'));
|
||||||
|
});
|
||||||
};
|
};
|
||||||
|
33
invokeai/frontend/web/src/common/util/arrayBuffer.ts
Normal file
33
invokeai/frontend/web/src/common/util/arrayBuffer.ts
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
export const getImageDataTransparency = (pixels: Uint8ClampedArray) => {
|
||||||
|
let isFullyTransparent = true;
|
||||||
|
let isPartiallyTransparent = false;
|
||||||
|
const len = pixels.length;
|
||||||
|
let i = 3;
|
||||||
|
for (i; i < len; i += 4) {
|
||||||
|
if (pixels[i] === 255) {
|
||||||
|
isFullyTransparent = false;
|
||||||
|
} else {
|
||||||
|
isPartiallyTransparent = true;
|
||||||
|
}
|
||||||
|
if (!isFullyTransparent && isPartiallyTransparent) {
|
||||||
|
return { isFullyTransparent, isPartiallyTransparent };
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return { isFullyTransparent, isPartiallyTransparent };
|
||||||
|
};
|
||||||
|
|
||||||
|
export const areAnyPixelsBlack = (pixels: Uint8ClampedArray) => {
|
||||||
|
const len = pixels.length;
|
||||||
|
let i = 0;
|
||||||
|
for (i; i < len; ) {
|
||||||
|
if (
|
||||||
|
pixels[i++] === 0 &&
|
||||||
|
pixels[i++] === 0 &&
|
||||||
|
pixels[i++] === 0 &&
|
||||||
|
pixels[i++] === 255
|
||||||
|
) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
};
|
@ -19,6 +19,7 @@ import { InvokeTabName } from 'features/ui/store/tabMap';
|
|||||||
import openBase64ImageInTab from './openBase64ImageInTab';
|
import openBase64ImageInTab from './openBase64ImageInTab';
|
||||||
import randomInt from './randomInt';
|
import randomInt from './randomInt';
|
||||||
import { stringToSeedWeightsArray } from './seedWeightPairs';
|
import { stringToSeedWeightsArray } from './seedWeightPairs';
|
||||||
|
import { getIsImageDataTransparent, getIsImageDataWhite } from './arrayBuffer';
|
||||||
|
|
||||||
export type FrontendToBackendParametersConfig = {
|
export type FrontendToBackendParametersConfig = {
|
||||||
generationMode: InvokeTabName;
|
generationMode: InvokeTabName;
|
||||||
@ -256,7 +257,7 @@ export const frontendToBackendParameters = (
|
|||||||
...boundingBoxDimensions,
|
...boundingBoxDimensions,
|
||||||
};
|
};
|
||||||
|
|
||||||
const maskDataURL = generateMask(
|
const { dataURL: maskDataURL, imageData: maskImageData } = generateMask(
|
||||||
isMaskEnabled ? objects.filter(isCanvasMaskLine) : [],
|
isMaskEnabled ? objects.filter(isCanvasMaskLine) : [],
|
||||||
boundingBox
|
boundingBox
|
||||||
);
|
);
|
||||||
@ -287,6 +288,17 @@ export const frontendToBackendParameters = (
|
|||||||
height: boundingBox.height,
|
height: boundingBox.height,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
const ctx = canvasBaseLayer.getContext();
|
||||||
|
const imageData = ctx.getImageData(
|
||||||
|
boundingBox.x + absPos.x,
|
||||||
|
boundingBox.y + absPos.y,
|
||||||
|
boundingBox.width,
|
||||||
|
boundingBox.height
|
||||||
|
);
|
||||||
|
|
||||||
|
const doesBaseHaveTransparency = getIsImageDataTransparent(imageData);
|
||||||
|
const doesMaskHaveTransparency = getIsImageDataWhite(maskImageData);
|
||||||
|
|
||||||
if (enableImageDebugging) {
|
if (enableImageDebugging) {
|
||||||
openBase64ImageInTab([
|
openBase64ImageInTab([
|
||||||
{ base64: maskDataURL, caption: 'mask sent as init_mask' },
|
{ base64: maskDataURL, caption: 'mask sent as init_mask' },
|
||||||
|
@ -34,6 +34,7 @@ import IAICanvasStagingAreaToolbar from './IAICanvasStagingAreaToolbar';
|
|||||||
import IAICanvasStatusText from './IAICanvasStatusText';
|
import IAICanvasStatusText from './IAICanvasStatusText';
|
||||||
import IAICanvasBoundingBox from './IAICanvasToolbar/IAICanvasBoundingBox';
|
import IAICanvasBoundingBox from './IAICanvasToolbar/IAICanvasBoundingBox';
|
||||||
import IAICanvasToolPreview from './IAICanvasToolPreview';
|
import IAICanvasToolPreview from './IAICanvasToolPreview';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
[canvasSelector, isStagingSelector],
|
[canvasSelector, isStagingSelector],
|
||||||
@ -52,6 +53,7 @@ const selector = createSelector(
|
|||||||
shouldShowIntermediates,
|
shouldShowIntermediates,
|
||||||
shouldShowGrid,
|
shouldShowGrid,
|
||||||
shouldRestrictStrokesToBox,
|
shouldRestrictStrokesToBox,
|
||||||
|
shouldAntialias,
|
||||||
} = canvas;
|
} = canvas;
|
||||||
|
|
||||||
let stageCursor: string | undefined = 'none';
|
let stageCursor: string | undefined = 'none';
|
||||||
@ -80,13 +82,10 @@ const selector = createSelector(
|
|||||||
tool,
|
tool,
|
||||||
isStaging,
|
isStaging,
|
||||||
shouldShowIntermediates,
|
shouldShowIntermediates,
|
||||||
|
shouldAntialias,
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
{
|
defaultSelectorOptions
|
||||||
memoizeOptions: {
|
|
||||||
resultEqualityCheck: isEqual,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
);
|
);
|
||||||
|
|
||||||
const ChakraStage = chakra(Stage, {
|
const ChakraStage = chakra(Stage, {
|
||||||
@ -106,6 +105,7 @@ const IAICanvas = () => {
|
|||||||
tool,
|
tool,
|
||||||
isStaging,
|
isStaging,
|
||||||
shouldShowIntermediates,
|
shouldShowIntermediates,
|
||||||
|
shouldAntialias,
|
||||||
} = useAppSelector(selector);
|
} = useAppSelector(selector);
|
||||||
useCanvasHotkeys();
|
useCanvasHotkeys();
|
||||||
|
|
||||||
@ -190,7 +190,7 @@ const IAICanvas = () => {
|
|||||||
id="base"
|
id="base"
|
||||||
ref={canvasBaseLayerRefCallback}
|
ref={canvasBaseLayerRefCallback}
|
||||||
listening={false}
|
listening={false}
|
||||||
imageSmoothingEnabled={false}
|
imageSmoothingEnabled={shouldAntialias}
|
||||||
>
|
>
|
||||||
<IAICanvasObjectRenderer />
|
<IAICanvasObjectRenderer />
|
||||||
</Layer>
|
</Layer>
|
||||||
@ -201,7 +201,7 @@ const IAICanvas = () => {
|
|||||||
<Layer>
|
<Layer>
|
||||||
<IAICanvasBoundingBoxOverlay />
|
<IAICanvasBoundingBoxOverlay />
|
||||||
</Layer>
|
</Layer>
|
||||||
<Layer id="preview" imageSmoothingEnabled={false}>
|
<Layer id="preview" imageSmoothingEnabled={shouldAntialias}>
|
||||||
{!isStaging && (
|
{!isStaging && (
|
||||||
<IAICanvasToolPreview
|
<IAICanvasToolPreview
|
||||||
visible={tool !== 'move'}
|
visible={tool !== 'move'}
|
||||||
|
@ -12,18 +12,20 @@ const selector = createSelector(
|
|||||||
[canvasSelector],
|
[canvasSelector],
|
||||||
(canvas) => {
|
(canvas) => {
|
||||||
const {
|
const {
|
||||||
layerState: {
|
layerState,
|
||||||
stagingArea: { images, selectedImageIndex },
|
|
||||||
},
|
|
||||||
shouldShowStagingImage,
|
shouldShowStagingImage,
|
||||||
shouldShowStagingOutline,
|
shouldShowStagingOutline,
|
||||||
boundingBoxCoordinates: { x, y },
|
boundingBoxCoordinates: { x, y },
|
||||||
boundingBoxDimensions: { width, height },
|
boundingBoxDimensions: { width, height },
|
||||||
} = canvas;
|
} = canvas;
|
||||||
|
|
||||||
|
const { selectedImageIndex, images } = layerState.stagingArea;
|
||||||
|
|
||||||
return {
|
return {
|
||||||
currentStagingAreaImage:
|
currentStagingAreaImage:
|
||||||
images.length > 0 ? images[selectedImageIndex] : undefined,
|
images.length > 0 && selectedImageIndex !== undefined
|
||||||
|
? images[selectedImageIndex]
|
||||||
|
: undefined,
|
||||||
isOnFirstImage: selectedImageIndex === 0,
|
isOnFirstImage: selectedImageIndex === 0,
|
||||||
isOnLastImage: selectedImageIndex === images.length - 1,
|
isOnLastImage: selectedImageIndex === images.length - 1,
|
||||||
shouldShowStagingImage,
|
shouldShowStagingImage,
|
||||||
|
@ -6,6 +6,7 @@ import IAIIconButton from 'common/components/IAIIconButton';
|
|||||||
import IAIPopover from 'common/components/IAIPopover';
|
import IAIPopover from 'common/components/IAIPopover';
|
||||||
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
|
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
|
||||||
import {
|
import {
|
||||||
|
setShouldAntialias,
|
||||||
setShouldAutoSave,
|
setShouldAutoSave,
|
||||||
setShouldCropToBoundingBoxOnSave,
|
setShouldCropToBoundingBoxOnSave,
|
||||||
setShouldDarkenOutsideBoundingBox,
|
setShouldDarkenOutsideBoundingBox,
|
||||||
@ -36,6 +37,7 @@ export const canvasControlsSelector = createSelector(
|
|||||||
shouldShowIntermediates,
|
shouldShowIntermediates,
|
||||||
shouldSnapToGrid,
|
shouldSnapToGrid,
|
||||||
shouldRestrictStrokesToBox,
|
shouldRestrictStrokesToBox,
|
||||||
|
shouldAntialias,
|
||||||
} = canvas;
|
} = canvas;
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@ -47,6 +49,7 @@ export const canvasControlsSelector = createSelector(
|
|||||||
shouldShowIntermediates,
|
shouldShowIntermediates,
|
||||||
shouldSnapToGrid,
|
shouldSnapToGrid,
|
||||||
shouldRestrictStrokesToBox,
|
shouldRestrictStrokesToBox,
|
||||||
|
shouldAntialias,
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -69,6 +72,7 @@ const IAICanvasSettingsButtonPopover = () => {
|
|||||||
shouldShowIntermediates,
|
shouldShowIntermediates,
|
||||||
shouldSnapToGrid,
|
shouldSnapToGrid,
|
||||||
shouldRestrictStrokesToBox,
|
shouldRestrictStrokesToBox,
|
||||||
|
shouldAntialias,
|
||||||
} = useAppSelector(canvasControlsSelector);
|
} = useAppSelector(canvasControlsSelector);
|
||||||
|
|
||||||
useHotkeys(
|
useHotkeys(
|
||||||
@ -148,6 +152,12 @@ const IAICanvasSettingsButtonPopover = () => {
|
|||||||
dispatch(setShouldShowCanvasDebugInfo(e.target.checked))
|
dispatch(setShouldShowCanvasDebugInfo(e.target.checked))
|
||||||
}
|
}
|
||||||
/>
|
/>
|
||||||
|
|
||||||
|
<IAICheckbox
|
||||||
|
label={t('unifiedCanvas.antialiasing')}
|
||||||
|
isChecked={shouldAntialias}
|
||||||
|
onChange={(e) => dispatch(setShouldAntialias(e.target.checked))}
|
||||||
|
/>
|
||||||
<ClearCanvasHistoryButtonModal />
|
<ClearCanvasHistoryButtonModal />
|
||||||
<EmptyTempFolderButtonModal />
|
<EmptyTempFolderButtonModal />
|
||||||
</Flex>
|
</Flex>
|
||||||
|
@ -9,6 +9,12 @@ const itemsToDenylist: (keyof CanvasState)[] = [
|
|||||||
'doesCanvasNeedScaling',
|
'doesCanvasNeedScaling',
|
||||||
];
|
];
|
||||||
|
|
||||||
|
export const canvasPersistDenylist: (keyof CanvasState)[] = [
|
||||||
|
'cursorPosition',
|
||||||
|
'isCanvasInitialized',
|
||||||
|
'doesCanvasNeedScaling',
|
||||||
|
];
|
||||||
|
|
||||||
export const canvasDenylist = itemsToDenylist.map(
|
export const canvasDenylist = itemsToDenylist.map(
|
||||||
(denylistItem) => `canvas.${denylistItem}`
|
(denylistItem) => `canvas.${denylistItem}`
|
||||||
);
|
);
|
||||||
|
@ -38,7 +38,7 @@ export const initialLayerState: CanvasLayerState = {
|
|||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
const initialCanvasState: CanvasState = {
|
export const initialCanvasState: CanvasState = {
|
||||||
boundingBoxCoordinates: { x: 0, y: 0 },
|
boundingBoxCoordinates: { x: 0, y: 0 },
|
||||||
boundingBoxDimensions: { width: 512, height: 512 },
|
boundingBoxDimensions: { width: 512, height: 512 },
|
||||||
boundingBoxPreviewFill: { r: 0, g: 0, b: 0, a: 0.5 },
|
boundingBoxPreviewFill: { r: 0, g: 0, b: 0, a: 0.5 },
|
||||||
@ -66,6 +66,7 @@ const initialCanvasState: CanvasState = {
|
|||||||
minimumStageScale: 1,
|
minimumStageScale: 1,
|
||||||
pastLayerStates: [],
|
pastLayerStates: [],
|
||||||
scaledBoundingBoxDimensions: { width: 512, height: 512 },
|
scaledBoundingBoxDimensions: { width: 512, height: 512 },
|
||||||
|
shouldAntialias: true,
|
||||||
shouldAutoSave: false,
|
shouldAutoSave: false,
|
||||||
shouldCropToBoundingBoxOnSave: false,
|
shouldCropToBoundingBoxOnSave: false,
|
||||||
shouldDarkenOutsideBoundingBox: false,
|
shouldDarkenOutsideBoundingBox: false,
|
||||||
@ -156,22 +157,20 @@ export const canvasSlice = createSlice({
|
|||||||
setCursorPosition: (state, action: PayloadAction<Vector2d | null>) => {
|
setCursorPosition: (state, action: PayloadAction<Vector2d | null>) => {
|
||||||
state.cursorPosition = action.payload;
|
state.cursorPosition = action.payload;
|
||||||
},
|
},
|
||||||
setInitialCanvasImage: (state, action: PayloadAction<InvokeAI._Image>) => {
|
setInitialCanvasImage: (state, action: PayloadAction<InvokeAI.Image>) => {
|
||||||
const image = action.payload;
|
const image = action.payload;
|
||||||
|
const { width, height } = image.metadata;
|
||||||
const { stageDimensions } = state;
|
const { stageDimensions } = state;
|
||||||
|
|
||||||
const newBoundingBoxDimensions = {
|
const newBoundingBoxDimensions = {
|
||||||
width: roundDownToMultiple(clamp(image.width, 64, 512), 64),
|
width: roundDownToMultiple(clamp(width, 64, 512), 64),
|
||||||
height: roundDownToMultiple(clamp(image.height, 64, 512), 64),
|
height: roundDownToMultiple(clamp(height, 64, 512), 64),
|
||||||
};
|
};
|
||||||
|
|
||||||
const newBoundingBoxCoordinates = {
|
const newBoundingBoxCoordinates = {
|
||||||
x: roundToMultiple(
|
x: roundToMultiple(width / 2 - newBoundingBoxDimensions.width / 2, 64),
|
||||||
image.width / 2 - newBoundingBoxDimensions.width / 2,
|
|
||||||
64
|
|
||||||
),
|
|
||||||
y: roundToMultiple(
|
y: roundToMultiple(
|
||||||
image.height / 2 - newBoundingBoxDimensions.height / 2,
|
height / 2 - newBoundingBoxDimensions.height / 2,
|
||||||
64
|
64
|
||||||
),
|
),
|
||||||
};
|
};
|
||||||
@ -196,8 +195,8 @@ export const canvasSlice = createSlice({
|
|||||||
layer: 'base',
|
layer: 'base',
|
||||||
x: 0,
|
x: 0,
|
||||||
y: 0,
|
y: 0,
|
||||||
width: image.width,
|
width: width,
|
||||||
height: image.height,
|
height: height,
|
||||||
image: image,
|
image: image,
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
@ -208,8 +207,8 @@ export const canvasSlice = createSlice({
|
|||||||
const newScale = calculateScale(
|
const newScale = calculateScale(
|
||||||
stageDimensions.width,
|
stageDimensions.width,
|
||||||
stageDimensions.height,
|
stageDimensions.height,
|
||||||
image.width,
|
width,
|
||||||
image.height,
|
height,
|
||||||
STAGE_PADDING_PERCENTAGE
|
STAGE_PADDING_PERCENTAGE
|
||||||
);
|
);
|
||||||
|
|
||||||
@ -218,8 +217,8 @@ export const canvasSlice = createSlice({
|
|||||||
stageDimensions.height,
|
stageDimensions.height,
|
||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
image.width,
|
width,
|
||||||
image.height,
|
height,
|
||||||
newScale
|
newScale
|
||||||
);
|
);
|
||||||
state.stageScale = newScale;
|
state.stageScale = newScale;
|
||||||
@ -287,16 +286,28 @@ export const canvasSlice = createSlice({
|
|||||||
setIsMoveStageKeyHeld: (state, action: PayloadAction<boolean>) => {
|
setIsMoveStageKeyHeld: (state, action: PayloadAction<boolean>) => {
|
||||||
state.isMoveStageKeyHeld = action.payload;
|
state.isMoveStageKeyHeld = action.payload;
|
||||||
},
|
},
|
||||||
addImageToStagingArea: (
|
canvasSessionIdChanged: (state, action: PayloadAction<string>) => {
|
||||||
|
state.layerState.stagingArea.sessionId = action.payload;
|
||||||
|
},
|
||||||
|
stagingAreaInitialized: (
|
||||||
state,
|
state,
|
||||||
action: PayloadAction<{
|
action: PayloadAction<{ sessionId: string; boundingBox: IRect }>
|
||||||
boundingBox: IRect;
|
|
||||||
image: InvokeAI._Image;
|
|
||||||
}>
|
|
||||||
) => {
|
) => {
|
||||||
const { boundingBox, image } = action.payload;
|
const { sessionId, boundingBox } = action.payload;
|
||||||
|
|
||||||
if (!boundingBox || !image) return;
|
state.layerState.stagingArea = {
|
||||||
|
boundingBox,
|
||||||
|
sessionId,
|
||||||
|
images: [],
|
||||||
|
selectedImageIndex: -1,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
addImageToStagingArea: (state, action: PayloadAction<InvokeAI.Image>) => {
|
||||||
|
const image = action.payload;
|
||||||
|
|
||||||
|
if (!image || !state.layerState.stagingArea.boundingBox) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
state.pastLayerStates.push(cloneDeep(state.layerState));
|
state.pastLayerStates.push(cloneDeep(state.layerState));
|
||||||
|
|
||||||
@ -307,7 +318,7 @@ export const canvasSlice = createSlice({
|
|||||||
state.layerState.stagingArea.images.push({
|
state.layerState.stagingArea.images.push({
|
||||||
kind: 'image',
|
kind: 'image',
|
||||||
layer: 'base',
|
layer: 'base',
|
||||||
...boundingBox,
|
...state.layerState.stagingArea.boundingBox,
|
||||||
image,
|
image,
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -323,9 +334,7 @@ export const canvasSlice = createSlice({
|
|||||||
state.pastLayerStates.shift();
|
state.pastLayerStates.shift();
|
||||||
}
|
}
|
||||||
|
|
||||||
state.layerState.stagingArea = {
|
state.layerState.stagingArea = { ...initialLayerState.stagingArea };
|
||||||
...initialLayerState.stagingArea,
|
|
||||||
};
|
|
||||||
|
|
||||||
state.futureLayerStates = [];
|
state.futureLayerStates = [];
|
||||||
state.shouldShowStagingOutline = true;
|
state.shouldShowStagingOutline = true;
|
||||||
@ -663,6 +672,10 @@ export const canvasSlice = createSlice({
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
nextStagingAreaImage: (state) => {
|
nextStagingAreaImage: (state) => {
|
||||||
|
if (!state.layerState.stagingArea.images.length) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
const currentIndex = state.layerState.stagingArea.selectedImageIndex;
|
const currentIndex = state.layerState.stagingArea.selectedImageIndex;
|
||||||
const length = state.layerState.stagingArea.images.length;
|
const length = state.layerState.stagingArea.images.length;
|
||||||
|
|
||||||
@ -672,6 +685,10 @@ export const canvasSlice = createSlice({
|
|||||||
);
|
);
|
||||||
},
|
},
|
||||||
prevStagingAreaImage: (state) => {
|
prevStagingAreaImage: (state) => {
|
||||||
|
if (!state.layerState.stagingArea.images.length) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
const currentIndex = state.layerState.stagingArea.selectedImageIndex;
|
const currentIndex = state.layerState.stagingArea.selectedImageIndex;
|
||||||
|
|
||||||
state.layerState.stagingArea.selectedImageIndex = Math.max(
|
state.layerState.stagingArea.selectedImageIndex = Math.max(
|
||||||
@ -680,6 +697,10 @@ export const canvasSlice = createSlice({
|
|||||||
);
|
);
|
||||||
},
|
},
|
||||||
commitStagingAreaImage: (state) => {
|
commitStagingAreaImage: (state) => {
|
||||||
|
if (!state.layerState.stagingArea.images.length) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
const { images, selectedImageIndex } = state.layerState.stagingArea;
|
const { images, selectedImageIndex } = state.layerState.stagingArea;
|
||||||
|
|
||||||
state.pastLayerStates.push(cloneDeep(state.layerState));
|
state.pastLayerStates.push(cloneDeep(state.layerState));
|
||||||
@ -776,6 +797,9 @@ export const canvasSlice = createSlice({
|
|||||||
setShouldRestrictStrokesToBox: (state, action: PayloadAction<boolean>) => {
|
setShouldRestrictStrokesToBox: (state, action: PayloadAction<boolean>) => {
|
||||||
state.shouldRestrictStrokesToBox = action.payload;
|
state.shouldRestrictStrokesToBox = action.payload;
|
||||||
},
|
},
|
||||||
|
setShouldAntialias: (state, action: PayloadAction<boolean>) => {
|
||||||
|
state.shouldAntialias = action.payload;
|
||||||
|
},
|
||||||
setShouldCropToBoundingBoxOnSave: (
|
setShouldCropToBoundingBoxOnSave: (
|
||||||
state,
|
state,
|
||||||
action: PayloadAction<boolean>
|
action: PayloadAction<boolean>
|
||||||
@ -885,6 +909,9 @@ export const {
|
|||||||
undo,
|
undo,
|
||||||
setScaledBoundingBoxDimensions,
|
setScaledBoundingBoxDimensions,
|
||||||
setShouldRestrictStrokesToBox,
|
setShouldRestrictStrokesToBox,
|
||||||
|
stagingAreaInitialized,
|
||||||
|
canvasSessionIdChanged,
|
||||||
|
setShouldAntialias,
|
||||||
} = canvasSlice.actions;
|
} = canvasSlice.actions;
|
||||||
|
|
||||||
export default canvasSlice.reducer;
|
export default canvasSlice.reducer;
|
||||||
|
@ -37,7 +37,7 @@ export type CanvasImage = {
|
|||||||
y: number;
|
y: number;
|
||||||
width: number;
|
width: number;
|
||||||
height: number;
|
height: number;
|
||||||
image: InvokeAI._Image;
|
image: InvokeAI.Image;
|
||||||
};
|
};
|
||||||
|
|
||||||
export type CanvasMaskLine = {
|
export type CanvasMaskLine = {
|
||||||
@ -90,9 +90,16 @@ export type CanvasLayerState = {
|
|||||||
stagingArea: {
|
stagingArea: {
|
||||||
images: CanvasImage[];
|
images: CanvasImage[];
|
||||||
selectedImageIndex: number;
|
selectedImageIndex: number;
|
||||||
|
sessionId?: string;
|
||||||
|
boundingBox?: IRect;
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type CanvasSession = {
|
||||||
|
sessionId: string;
|
||||||
|
boundingBox: IRect;
|
||||||
|
};
|
||||||
|
|
||||||
// type guards
|
// type guards
|
||||||
export const isCanvasMaskLine = (obj: CanvasObject): obj is CanvasMaskLine =>
|
export const isCanvasMaskLine = (obj: CanvasObject): obj is CanvasMaskLine =>
|
||||||
obj.kind === 'line' && obj.layer === 'mask';
|
obj.kind === 'line' && obj.layer === 'mask';
|
||||||
@ -125,7 +132,7 @@ export interface CanvasState {
|
|||||||
cursorPosition: Vector2d | null;
|
cursorPosition: Vector2d | null;
|
||||||
doesCanvasNeedScaling: boolean;
|
doesCanvasNeedScaling: boolean;
|
||||||
futureLayerStates: CanvasLayerState[];
|
futureLayerStates: CanvasLayerState[];
|
||||||
intermediateImage?: InvokeAI._Image;
|
intermediateImage?: InvokeAI.Image;
|
||||||
isCanvasInitialized: boolean;
|
isCanvasInitialized: boolean;
|
||||||
isDrawing: boolean;
|
isDrawing: boolean;
|
||||||
isMaskEnabled: boolean;
|
isMaskEnabled: boolean;
|
||||||
@ -142,6 +149,7 @@ export interface CanvasState {
|
|||||||
minimumStageScale: number;
|
minimumStageScale: number;
|
||||||
pastLayerStates: CanvasLayerState[];
|
pastLayerStates: CanvasLayerState[];
|
||||||
scaledBoundingBoxDimensions: Dimensions;
|
scaledBoundingBoxDimensions: Dimensions;
|
||||||
|
shouldAntialias: boolean;
|
||||||
shouldAutoSave: boolean;
|
shouldAutoSave: boolean;
|
||||||
shouldCropToBoundingBoxOnSave: boolean;
|
shouldCropToBoundingBoxOnSave: boolean;
|
||||||
shouldDarkenOutsideBoundingBox: boolean;
|
shouldDarkenOutsideBoundingBox: boolean;
|
||||||
|
@ -0,0 +1,13 @@
|
|||||||
|
/**
|
||||||
|
* Gets a Blob from a canvas.
|
||||||
|
*/
|
||||||
|
export const canvasToBlob = async (canvas: HTMLCanvasElement): Promise<Blob> =>
|
||||||
|
new Promise((resolve, reject) => {
|
||||||
|
canvas.toBlob((blob) => {
|
||||||
|
if (blob) {
|
||||||
|
resolve(blob);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
reject('Unable to create Blob');
|
||||||
|
});
|
||||||
|
});
|
@ -0,0 +1,29 @@
|
|||||||
|
/**
|
||||||
|
* Gets an ImageData object from an image dataURL by drawing it to a canvas.
|
||||||
|
*/
|
||||||
|
export const dataURLToImageData = async (
|
||||||
|
dataURL: string,
|
||||||
|
width: number,
|
||||||
|
height: number
|
||||||
|
): Promise<ImageData> =>
|
||||||
|
new Promise((resolve, reject) => {
|
||||||
|
const canvas = document.createElement('canvas');
|
||||||
|
canvas.width = width;
|
||||||
|
canvas.height = height;
|
||||||
|
const ctx = canvas.getContext('2d');
|
||||||
|
const image = new Image();
|
||||||
|
|
||||||
|
if (!ctx) {
|
||||||
|
canvas.remove();
|
||||||
|
reject('Unable to get context');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
image.onload = function () {
|
||||||
|
ctx.drawImage(image, 0, 0);
|
||||||
|
canvas.remove();
|
||||||
|
resolve(ctx.getImageData(0, 0, width, height));
|
||||||
|
};
|
||||||
|
|
||||||
|
image.src = dataURL;
|
||||||
|
});
|
@ -1,6 +1,110 @@
|
|||||||
|
// import { CanvasMaskLine } from 'features/canvas/store/canvasTypes';
|
||||||
|
// import Konva from 'konva';
|
||||||
|
// import { Stage } from 'konva/lib/Stage';
|
||||||
|
// import { IRect } from 'konva/lib/types';
|
||||||
|
|
||||||
|
// /**
|
||||||
|
// * Generating a mask image from InpaintingCanvas.tsx is not as simple
|
||||||
|
// * as calling toDataURL() on the canvas, because the mask may be represented
|
||||||
|
// * by colored lines or transparency, or the user may have inverted the mask
|
||||||
|
// * display.
|
||||||
|
// *
|
||||||
|
// * So we need to regenerate the mask image by creating an offscreen canvas,
|
||||||
|
// * drawing the mask and compositing everything correctly to output a valid
|
||||||
|
// * mask image.
|
||||||
|
// */
|
||||||
|
// export const getStageDataURL = (stage: Stage, boundingBox: IRect): string => {
|
||||||
|
// // create an offscreen canvas and add the mask to it
|
||||||
|
// // const { stage, offscreenContainer } = buildMaskStage(lines, boundingBox);
|
||||||
|
|
||||||
|
// const dataURL = stage.toDataURL({ ...boundingBox });
|
||||||
|
|
||||||
|
// // const imageData = stage
|
||||||
|
// // .toCanvas()
|
||||||
|
// // .getContext('2d')
|
||||||
|
// // ?.getImageData(
|
||||||
|
// // boundingBox.x,
|
||||||
|
// // boundingBox.y,
|
||||||
|
// // boundingBox.width,
|
||||||
|
// // boundingBox.height
|
||||||
|
// // );
|
||||||
|
|
||||||
|
// // offscreenContainer.remove();
|
||||||
|
|
||||||
|
// // return { dataURL, imageData };
|
||||||
|
|
||||||
|
// return dataURL;
|
||||||
|
// };
|
||||||
|
|
||||||
|
// export const getStageImageData = (
|
||||||
|
// stage: Stage,
|
||||||
|
// boundingBox: IRect
|
||||||
|
// ): ImageData | undefined => {
|
||||||
|
// const imageData = stage
|
||||||
|
// .toCanvas()
|
||||||
|
// .getContext('2d')
|
||||||
|
// ?.getImageData(
|
||||||
|
// boundingBox.x,
|
||||||
|
// boundingBox.y,
|
||||||
|
// boundingBox.width,
|
||||||
|
// boundingBox.height
|
||||||
|
// );
|
||||||
|
|
||||||
|
// return imageData;
|
||||||
|
// };
|
||||||
|
|
||||||
|
// export const buildMaskStage = (
|
||||||
|
// lines: CanvasMaskLine[],
|
||||||
|
// boundingBox: IRect
|
||||||
|
// ): { stage: Stage; offscreenContainer: HTMLDivElement } => {
|
||||||
|
// // create an offscreen canvas and add the mask to it
|
||||||
|
// const { width, height } = boundingBox;
|
||||||
|
|
||||||
|
// const offscreenContainer = document.createElement('div');
|
||||||
|
|
||||||
|
// const stage = new Konva.Stage({
|
||||||
|
// container: offscreenContainer,
|
||||||
|
// width: width,
|
||||||
|
// height: height,
|
||||||
|
// });
|
||||||
|
|
||||||
|
// const baseLayer = new Konva.Layer();
|
||||||
|
// const maskLayer = new Konva.Layer();
|
||||||
|
|
||||||
|
// // composite the image onto the mask layer
|
||||||
|
// baseLayer.add(
|
||||||
|
// new Konva.Rect({
|
||||||
|
// ...boundingBox,
|
||||||
|
// fill: 'white',
|
||||||
|
// })
|
||||||
|
// );
|
||||||
|
|
||||||
|
// lines.forEach((line) =>
|
||||||
|
// maskLayer.add(
|
||||||
|
// new Konva.Line({
|
||||||
|
// points: line.points,
|
||||||
|
// stroke: 'black',
|
||||||
|
// strokeWidth: line.strokeWidth * 2,
|
||||||
|
// tension: 0,
|
||||||
|
// lineCap: 'round',
|
||||||
|
// lineJoin: 'round',
|
||||||
|
// shadowForStrokeEnabled: false,
|
||||||
|
// globalCompositeOperation:
|
||||||
|
// line.tool === 'brush' ? 'source-over' : 'destination-out',
|
||||||
|
// })
|
||||||
|
// )
|
||||||
|
// );
|
||||||
|
|
||||||
|
// stage.add(baseLayer);
|
||||||
|
// stage.add(maskLayer);
|
||||||
|
|
||||||
|
// return { stage, offscreenContainer };
|
||||||
|
// };
|
||||||
|
|
||||||
import { CanvasMaskLine } from 'features/canvas/store/canvasTypes';
|
import { CanvasMaskLine } from 'features/canvas/store/canvasTypes';
|
||||||
import Konva from 'konva';
|
import Konva from 'konva';
|
||||||
import { IRect } from 'konva/lib/types';
|
import { IRect } from 'konva/lib/types';
|
||||||
|
import { canvasToBlob } from './canvasToBlob';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Generating a mask image from InpaintingCanvas.tsx is not as simple
|
* Generating a mask image from InpaintingCanvas.tsx is not as simple
|
||||||
@ -12,7 +116,7 @@ import { IRect } from 'konva/lib/types';
|
|||||||
* drawing the mask and compositing everything correctly to output a valid
|
* drawing the mask and compositing everything correctly to output a valid
|
||||||
* mask image.
|
* mask image.
|
||||||
*/
|
*/
|
||||||
const generateMask = (lines: CanvasMaskLine[], boundingBox: IRect): string => {
|
const generateMask = async (lines: CanvasMaskLine[], boundingBox: IRect) => {
|
||||||
// create an offscreen canvas and add the mask to it
|
// create an offscreen canvas and add the mask to it
|
||||||
const { width, height } = boundingBox;
|
const { width, height } = boundingBox;
|
||||||
|
|
||||||
@ -54,11 +158,13 @@ const generateMask = (lines: CanvasMaskLine[], boundingBox: IRect): string => {
|
|||||||
stage.add(baseLayer);
|
stage.add(baseLayer);
|
||||||
stage.add(maskLayer);
|
stage.add(maskLayer);
|
||||||
|
|
||||||
const dataURL = stage.toDataURL({ ...boundingBox });
|
const maskDataURL = stage.toDataURL(boundingBox);
|
||||||
|
|
||||||
|
const maskBlob = await canvasToBlob(stage.toCanvas(boundingBox));
|
||||||
|
|
||||||
offscreenContainer.remove();
|
offscreenContainer.remove();
|
||||||
|
|
||||||
return dataURL;
|
return { maskDataURL, maskBlob };
|
||||||
};
|
};
|
||||||
|
|
||||||
export default generateMask;
|
export default generateMask;
|
||||||
|
128
invokeai/frontend/web/src/features/canvas/util/getCanvasData.ts
Normal file
128
invokeai/frontend/web/src/features/canvas/util/getCanvasData.ts
Normal file
@ -0,0 +1,128 @@
|
|||||||
|
import { RootState } from 'app/store/store';
|
||||||
|
import { getCanvasBaseLayer, getCanvasStage } from './konvaInstanceProvider';
|
||||||
|
import { isCanvasMaskLine } from '../store/canvasTypes';
|
||||||
|
import { log } from 'app/logging/useLogger';
|
||||||
|
import {
|
||||||
|
areAnyPixelsBlack,
|
||||||
|
getImageDataTransparency,
|
||||||
|
} from 'common/util/arrayBuffer';
|
||||||
|
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
|
||||||
|
import generateMask from './generateMask';
|
||||||
|
import { dataURLToImageData } from './dataURLToImageData';
|
||||||
|
import { canvasToBlob } from './canvasToBlob';
|
||||||
|
|
||||||
|
const moduleLog = log.child({ namespace: 'getCanvasDataURLs' });
|
||||||
|
|
||||||
|
export const getCanvasData = async (state: RootState) => {
|
||||||
|
const canvasBaseLayer = getCanvasBaseLayer();
|
||||||
|
const canvasStage = getCanvasStage();
|
||||||
|
|
||||||
|
if (!canvasBaseLayer || !canvasStage) {
|
||||||
|
moduleLog.error('Unable to find canvas / stage');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const {
|
||||||
|
layerState: { objects },
|
||||||
|
boundingBoxCoordinates,
|
||||||
|
boundingBoxDimensions,
|
||||||
|
stageScale,
|
||||||
|
isMaskEnabled,
|
||||||
|
shouldPreserveMaskedArea,
|
||||||
|
boundingBoxScaleMethod: boundingBoxScale,
|
||||||
|
scaledBoundingBoxDimensions,
|
||||||
|
} = state.canvas;
|
||||||
|
|
||||||
|
const boundingBox = {
|
||||||
|
...boundingBoxCoordinates,
|
||||||
|
...boundingBoxDimensions,
|
||||||
|
};
|
||||||
|
|
||||||
|
// generationParameters.fit = false;
|
||||||
|
|
||||||
|
// generationParameters.strength = img2imgStrength;
|
||||||
|
|
||||||
|
// generationParameters.invert_mask = shouldPreserveMaskedArea;
|
||||||
|
|
||||||
|
// generationParameters.bounding_box = boundingBox;
|
||||||
|
|
||||||
|
const tempScale = canvasBaseLayer.scale();
|
||||||
|
|
||||||
|
canvasBaseLayer.scale({
|
||||||
|
x: 1 / stageScale,
|
||||||
|
y: 1 / stageScale,
|
||||||
|
});
|
||||||
|
|
||||||
|
const absPos = canvasBaseLayer.getAbsolutePosition();
|
||||||
|
|
||||||
|
const offsetBoundingBox = {
|
||||||
|
x: boundingBox.x + absPos.x,
|
||||||
|
y: boundingBox.y + absPos.y,
|
||||||
|
width: boundingBox.width,
|
||||||
|
height: boundingBox.height,
|
||||||
|
};
|
||||||
|
|
||||||
|
const baseDataURL = canvasBaseLayer.toDataURL(offsetBoundingBox);
|
||||||
|
const baseBlob = await canvasToBlob(
|
||||||
|
canvasBaseLayer.toCanvas(offsetBoundingBox)
|
||||||
|
);
|
||||||
|
|
||||||
|
canvasBaseLayer.scale(tempScale);
|
||||||
|
|
||||||
|
const { maskDataURL, maskBlob } = await generateMask(
|
||||||
|
isMaskEnabled ? objects.filter(isCanvasMaskLine) : [],
|
||||||
|
boundingBox
|
||||||
|
);
|
||||||
|
|
||||||
|
const baseImageData = await dataURLToImageData(
|
||||||
|
baseDataURL,
|
||||||
|
boundingBox.width,
|
||||||
|
boundingBox.height
|
||||||
|
);
|
||||||
|
|
||||||
|
const maskImageData = await dataURLToImageData(
|
||||||
|
maskDataURL,
|
||||||
|
boundingBox.width,
|
||||||
|
boundingBox.height
|
||||||
|
);
|
||||||
|
|
||||||
|
const {
|
||||||
|
isPartiallyTransparent: baseIsPartiallyTransparent,
|
||||||
|
isFullyTransparent: baseIsFullyTransparent,
|
||||||
|
} = getImageDataTransparency(baseImageData.data);
|
||||||
|
|
||||||
|
const doesMaskHaveBlackPixels = areAnyPixelsBlack(maskImageData.data);
|
||||||
|
|
||||||
|
if (state.system.enableImageDebugging) {
|
||||||
|
openBase64ImageInTab([
|
||||||
|
{ base64: maskDataURL, caption: 'mask b64' },
|
||||||
|
{ base64: baseDataURL, caption: 'image b64' },
|
||||||
|
]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// generationParameters.init_img = imageDataURL;
|
||||||
|
// generationParameters.progress_images = false;
|
||||||
|
|
||||||
|
// if (boundingBoxScale !== 'none') {
|
||||||
|
// generationParameters.inpaint_width = scaledBoundingBoxDimensions.width;
|
||||||
|
// generationParameters.inpaint_height = scaledBoundingBoxDimensions.height;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// generationParameters.seam_size = seamSize;
|
||||||
|
// generationParameters.seam_blur = seamBlur;
|
||||||
|
// generationParameters.seam_strength = seamStrength;
|
||||||
|
// generationParameters.seam_steps = seamSteps;
|
||||||
|
// generationParameters.tile_size = tileSize;
|
||||||
|
// generationParameters.infill_method = infillMethod;
|
||||||
|
// generationParameters.force_outpaint = false;
|
||||||
|
|
||||||
|
return {
|
||||||
|
baseDataURL,
|
||||||
|
baseBlob,
|
||||||
|
maskDataURL,
|
||||||
|
maskBlob,
|
||||||
|
baseIsPartiallyTransparent,
|
||||||
|
baseIsFullyTransparent,
|
||||||
|
doesMaskHaveBlackPixels,
|
||||||
|
};
|
||||||
|
};
|
@ -1,12 +1,17 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { get, isEqual, isNumber, isString } from 'lodash-es';
|
import { isEqual, isString } from 'lodash-es';
|
||||||
|
|
||||||
import {
|
import {
|
||||||
ButtonGroup,
|
ButtonGroup,
|
||||||
Flex,
|
Flex,
|
||||||
FlexProps,
|
FlexProps,
|
||||||
FormControl,
|
IconButton,
|
||||||
Link,
|
Link,
|
||||||
|
Menu,
|
||||||
|
MenuButton,
|
||||||
|
MenuItemOption,
|
||||||
|
MenuList,
|
||||||
|
MenuOptionGroup,
|
||||||
useDisclosure,
|
useDisclosure,
|
||||||
useToast,
|
useToast,
|
||||||
} from '@chakra-ui/react';
|
} from '@chakra-ui/react';
|
||||||
@ -15,21 +20,12 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|||||||
import IAIButton from 'common/components/IAIButton';
|
import IAIButton from 'common/components/IAIButton';
|
||||||
import IAIIconButton from 'common/components/IAIIconButton';
|
import IAIIconButton from 'common/components/IAIIconButton';
|
||||||
import IAIPopover from 'common/components/IAIPopover';
|
import IAIPopover from 'common/components/IAIPopover';
|
||||||
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
|
|
||||||
import { GalleryState } from 'features/gallery/store/gallerySlice';
|
|
||||||
import { lightboxSelector } from 'features/lightbox/store/lightboxSelectors';
|
import { lightboxSelector } from 'features/lightbox/store/lightboxSelectors';
|
||||||
import { setIsLightboxOpen } from 'features/lightbox/store/lightboxSlice';
|
import { setIsLightboxOpen } from 'features/lightbox/store/lightboxSlice';
|
||||||
import FaceRestoreSettings from 'features/parameters/components/AdvancedParameters/FaceRestore/FaceRestoreSettings';
|
|
||||||
import UpscaleSettings from 'features/parameters/components/AdvancedParameters/Upscale/UpscaleSettings';
|
|
||||||
import {
|
|
||||||
initialImageSelected,
|
|
||||||
setAllParameters,
|
|
||||||
// setInitialImage,
|
|
||||||
setSeed,
|
|
||||||
} from 'features/parameters/store/generationSlice';
|
|
||||||
import { postprocessingSelector } from 'features/parameters/store/postprocessingSelectors';
|
import { postprocessingSelector } from 'features/parameters/store/postprocessingSelectors';
|
||||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
import { systemSelector } from 'features/system/store/systemSelectors';
|
||||||
import { SystemState } from 'features/system/store/systemSlice';
|
|
||||||
import {
|
import {
|
||||||
activeTabNameSelector,
|
activeTabNameSelector,
|
||||||
uiSelector,
|
uiSelector,
|
||||||
@ -56,6 +52,7 @@ import {
|
|||||||
FaShare,
|
FaShare,
|
||||||
FaShareAlt,
|
FaShareAlt,
|
||||||
FaTrash,
|
FaTrash,
|
||||||
|
FaWrench,
|
||||||
} from 'react-icons/fa';
|
} from 'react-icons/fa';
|
||||||
import {
|
import {
|
||||||
gallerySelector,
|
gallerySelector,
|
||||||
@ -66,8 +63,13 @@ import { useCallback } from 'react';
|
|||||||
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
||||||
import { useGetUrl } from 'common/util/getUrl';
|
import { useGetUrl } from 'common/util/getUrl';
|
||||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||||
import { imageDeleted } from 'services/thunks/image';
|
|
||||||
import { useParameters } from 'features/parameters/hooks/useParameters';
|
import { useParameters } from 'features/parameters/hooks/useParameters';
|
||||||
|
import { initialImageSelected } from 'features/parameters/store/actions';
|
||||||
|
import { requestedImageDeletion } from '../store/actions';
|
||||||
|
import FaceRestoreSettings from 'features/parameters/components/Parameters/FaceRestore/FaceRestoreSettings';
|
||||||
|
import UpscaleSettings from 'features/parameters/components/Parameters/Upscale/UpscaleSettings';
|
||||||
|
import { allParametersSet } from 'features/parameters/store/generationSlice';
|
||||||
|
import DeleteImageButton from './ImageActionButtons/DeleteImageButton';
|
||||||
|
|
||||||
const currentImageButtonsSelector = createSelector(
|
const currentImageButtonsSelector = createSelector(
|
||||||
[
|
[
|
||||||
@ -150,6 +152,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
} = useAppSelector(currentImageButtonsSelector);
|
} = useAppSelector(currentImageButtonsSelector);
|
||||||
|
|
||||||
const isLightboxEnabled = useFeatureStatus('lightbox').isFeatureEnabled;
|
const isLightboxEnabled = useFeatureStatus('lightbox').isFeatureEnabled;
|
||||||
|
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
|
||||||
const isUpscalingEnabled = useFeatureStatus('upscaling').isFeatureEnabled;
|
const isUpscalingEnabled = useFeatureStatus('upscaling').isFeatureEnabled;
|
||||||
const isFaceRestoreEnabled = useFeatureStatus('faceRestore').isFeatureEnabled;
|
const isFaceRestoreEnabled = useFeatureStatus('faceRestore').isFeatureEnabled;
|
||||||
|
|
||||||
@ -164,40 +167,59 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
const toast = useToast();
|
const toast = useToast();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const { recallPrompt, recallSeed, sendToImageToImage } = useParameters();
|
const { recallPrompt, recallSeed, recallAllParameters } = useParameters();
|
||||||
|
|
||||||
const handleCopyImage = useCallback(async () => {
|
// const handleCopyImage = useCallback(async () => {
|
||||||
if (!image?.url) {
|
// if (!image?.url) {
|
||||||
|
// return;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// const url = getUrl(image.url);
|
||||||
|
|
||||||
|
// if (!url) {
|
||||||
|
// return;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// const blob = await fetch(url).then((res) => res.blob());
|
||||||
|
// const data = [new ClipboardItem({ [blob.type]: blob })];
|
||||||
|
|
||||||
|
// await navigator.clipboard.write(data);
|
||||||
|
|
||||||
|
// toast({
|
||||||
|
// title: t('toast.imageCopied'),
|
||||||
|
// status: 'success',
|
||||||
|
// duration: 2500,
|
||||||
|
// isClosable: true,
|
||||||
|
// });
|
||||||
|
// }, [getUrl, t, image?.url, toast]);
|
||||||
|
|
||||||
|
const handleCopyImageLink = useCallback(() => {
|
||||||
|
const getImageUrl = () => {
|
||||||
|
if (!image) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const url = getUrl(image.url);
|
if (shouldTransformUrls) {
|
||||||
|
return getUrl(image.url);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (image.url.startsWith('http')) {
|
||||||
|
return image.url;
|
||||||
|
}
|
||||||
|
|
||||||
|
return window.location.toString() + image.url;
|
||||||
|
};
|
||||||
|
|
||||||
|
const url = getImageUrl();
|
||||||
|
|
||||||
if (!url) {
|
if (!url) {
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const blob = await fetch(url).then((res) => res.blob());
|
|
||||||
const data = [new ClipboardItem({ [blob.type]: blob })];
|
|
||||||
|
|
||||||
await navigator.clipboard.write(data);
|
|
||||||
|
|
||||||
toast({
|
toast({
|
||||||
title: t('toast.imageCopied'),
|
title: t('toast.problemCopyingImageLink'),
|
||||||
status: 'success',
|
status: 'error',
|
||||||
duration: 2500,
|
duration: 2500,
|
||||||
isClosable: true,
|
isClosable: true,
|
||||||
});
|
});
|
||||||
}, [getUrl, t, image?.url, toast]);
|
|
||||||
|
|
||||||
const handleCopyImageLink = useCallback(() => {
|
|
||||||
const url = image
|
|
||||||
? shouldTransformUrls
|
|
||||||
? getUrl(image.url)
|
|
||||||
: window.location.toString() + image.url
|
|
||||||
: '';
|
|
||||||
|
|
||||||
if (!url) {
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -216,39 +238,15 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
}, [dispatch, shouldHidePreview]);
|
}, [dispatch, shouldHidePreview]);
|
||||||
|
|
||||||
const handleClickUseAllParameters = useCallback(() => {
|
const handleClickUseAllParameters = useCallback(() => {
|
||||||
if (!image) return;
|
recallAllParameters(image);
|
||||||
// selectedImage.metadata &&
|
}, [image, recallAllParameters]);
|
||||||
// dispatch(setAllParameters(selectedImage.metadata));
|
|
||||||
// if (selectedImage.metadata?.image.type === 'img2img') {
|
|
||||||
// dispatch(setActiveTab('img2img'));
|
|
||||||
// } else if (selectedImage.metadata?.image.type === 'txt2img') {
|
|
||||||
// dispatch(setActiveTab('txt2img'));
|
|
||||||
// }
|
|
||||||
}, [image]);
|
|
||||||
|
|
||||||
useHotkeys(
|
useHotkeys(
|
||||||
'a',
|
'a',
|
||||||
() => {
|
() => {
|
||||||
const type = image?.metadata?.invokeai?.node?.types;
|
handleClickUseAllParameters;
|
||||||
if (isString(type) && ['txt2img', 'img2img'].includes(type)) {
|
|
||||||
handleClickUseAllParameters();
|
|
||||||
toast({
|
|
||||||
title: t('toast.parametersSet'),
|
|
||||||
status: 'success',
|
|
||||||
duration: 2500,
|
|
||||||
isClosable: true,
|
|
||||||
});
|
|
||||||
} else {
|
|
||||||
toast({
|
|
||||||
title: t('toast.parametersNotSet'),
|
|
||||||
description: t('toast.parametersNotSetDesc'),
|
|
||||||
status: 'error',
|
|
||||||
duration: 2500,
|
|
||||||
isClosable: true,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
[image]
|
[image, recallAllParameters]
|
||||||
);
|
);
|
||||||
|
|
||||||
const handleUseSeed = useCallback(() => {
|
const handleUseSeed = useCallback(() => {
|
||||||
@ -264,8 +262,8 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
useHotkeys('p', handleUsePrompt, [image]);
|
useHotkeys('p', handleUsePrompt, [image]);
|
||||||
|
|
||||||
const handleSendToImageToImage = useCallback(() => {
|
const handleSendToImageToImage = useCallback(() => {
|
||||||
sendToImageToImage(image);
|
dispatch(initialImageSelected(image));
|
||||||
}, [image, sendToImageToImage]);
|
}, [dispatch, image]);
|
||||||
|
|
||||||
useHotkeys('shift+i', handleSendToImageToImage, [image]);
|
useHotkeys('shift+i', handleSendToImageToImage, [image]);
|
||||||
|
|
||||||
@ -375,7 +373,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
|
|
||||||
const handleDelete = useCallback(() => {
|
const handleDelete = useCallback(() => {
|
||||||
if (canDeleteImage && image) {
|
if (canDeleteImage && image) {
|
||||||
dispatch(imageDeleted({ imageType: image.type, imageName: image.name }));
|
dispatch(requestedImageDeletion(image));
|
||||||
}
|
}
|
||||||
}, [image, canDeleteImage, dispatch]);
|
}, [image, canDeleteImage, dispatch]);
|
||||||
|
|
||||||
@ -432,6 +430,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
>
|
>
|
||||||
{t('parameters.sendToImg2Img')}
|
{t('parameters.sendToImg2Img')}
|
||||||
</IAIButton>
|
</IAIButton>
|
||||||
|
{isCanvasEnabled && (
|
||||||
<IAIButton
|
<IAIButton
|
||||||
size="sm"
|
size="sm"
|
||||||
onClick={handleSendToCanvas}
|
onClick={handleSendToCanvas}
|
||||||
@ -439,14 +438,15 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
>
|
>
|
||||||
{t('parameters.sendToUnifiedCanvas')}
|
{t('parameters.sendToUnifiedCanvas')}
|
||||||
</IAIButton>
|
</IAIButton>
|
||||||
|
)}
|
||||||
|
|
||||||
<IAIButton
|
{/* <IAIButton
|
||||||
size="sm"
|
size="sm"
|
||||||
onClick={handleCopyImage}
|
onClick={handleCopyImage}
|
||||||
leftIcon={<FaCopy />}
|
leftIcon={<FaCopy />}
|
||||||
>
|
>
|
||||||
{t('parameters.copyImage')}
|
{t('parameters.copyImage')}
|
||||||
</IAIButton>
|
</IAIButton> */}
|
||||||
<IAIButton
|
<IAIButton
|
||||||
size="sm"
|
size="sm"
|
||||||
onClick={handleCopyImageLink}
|
onClick={handleCopyImageLink}
|
||||||
@ -462,7 +462,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
</Link>
|
</Link>
|
||||||
</Flex>
|
</Flex>
|
||||||
</IAIPopover>
|
</IAIPopover>
|
||||||
<IAIIconButton
|
{/* <IAIIconButton
|
||||||
icon={shouldHidePreview ? <FaEyeSlash /> : <FaEye />}
|
icon={shouldHidePreview ? <FaEyeSlash /> : <FaEye />}
|
||||||
tooltip={
|
tooltip={
|
||||||
!shouldHidePreview
|
!shouldHidePreview
|
||||||
@ -476,7 +476,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
}
|
}
|
||||||
isChecked={shouldHidePreview}
|
isChecked={shouldHidePreview}
|
||||||
onClick={handlePreviewVisibility}
|
onClick={handlePreviewVisibility}
|
||||||
/>
|
/> */}
|
||||||
{isLightboxEnabled && (
|
{isLightboxEnabled && (
|
||||||
<IAIIconButton
|
<IAIIconButton
|
||||||
icon={<FaExpand />}
|
icon={<FaExpand />}
|
||||||
@ -518,8 +518,8 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
tooltip={`${t('parameters.useAll')} (A)`}
|
tooltip={`${t('parameters.useAll')} (A)`}
|
||||||
aria-label={`${t('parameters.useAll')} (A)`}
|
aria-label={`${t('parameters.useAll')} (A)`}
|
||||||
isDisabled={
|
isDisabled={
|
||||||
!['txt2img', 'img2img'].includes(
|
!['txt2img', 'img2img', 'inpaint'].includes(
|
||||||
image?.metadata?.sd_metadata?.type
|
String(image?.metadata?.invokeai?.node?.type)
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
onClick={handleClickUseAllParameters}
|
onClick={handleClickUseAllParameters}
|
||||||
@ -602,22 +602,10 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
/>
|
/>
|
||||||
</ButtonGroup>
|
</ButtonGroup>
|
||||||
|
|
||||||
<IAIIconButton
|
<ButtonGroup isAttached={true}>
|
||||||
onClick={handleInitiateDelete}
|
<DeleteImageButton image={image} />
|
||||||
icon={<FaTrash />}
|
</ButtonGroup>
|
||||||
tooltip={`${t('gallery.deleteImage')} (Del)`}
|
|
||||||
aria-label={`${t('gallery.deleteImage')} (Del)`}
|
|
||||||
isDisabled={!image || !isConnected}
|
|
||||||
colorScheme="error"
|
|
||||||
/>
|
|
||||||
</Flex>
|
</Flex>
|
||||||
{image && (
|
|
||||||
<DeleteImageModal
|
|
||||||
isOpen={isDeleteDialogOpen}
|
|
||||||
onClose={onDeleteDialogClose}
|
|
||||||
handleDelete={handleDelete}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
</>
|
</>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -2,26 +2,35 @@ import { Box, Flex, Image } from '@chakra-ui/react';
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { useGetUrl } from 'common/util/getUrl';
|
import { useGetUrl } from 'common/util/getUrl';
|
||||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
|
||||||
import { uiSelector } from 'features/ui/store/uiSelectors';
|
import { uiSelector } from 'features/ui/store/uiSelectors';
|
||||||
import { isEqual } from 'lodash-es';
|
import { isEqual } from 'lodash-es';
|
||||||
|
|
||||||
import { selectedImageSelector } from '../store/gallerySelectors';
|
import { gallerySelector } from '../store/gallerySelectors';
|
||||||
import CurrentImageFallback from './CurrentImageFallback';
|
|
||||||
import ImageMetadataViewer from './ImageMetaDataViewer/ImageMetadataViewer';
|
import ImageMetadataViewer from './ImageMetaDataViewer/ImageMetadataViewer';
|
||||||
import NextPrevImageButtons from './NextPrevImageButtons';
|
import NextPrevImageButtons from './NextPrevImageButtons';
|
||||||
import CurrentImageHidden from './CurrentImageHidden';
|
import CurrentImageHidden from './CurrentImageHidden';
|
||||||
import { memo } from 'react';
|
import { DragEvent, memo, useCallback } from 'react';
|
||||||
|
import { systemSelector } from 'features/system/store/systemSelectors';
|
||||||
|
import ImageFallbackSpinner from './ImageFallbackSpinner';
|
||||||
|
import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay';
|
||||||
|
|
||||||
export const imagesSelector = createSelector(
|
export const imagesSelector = createSelector(
|
||||||
[uiSelector, selectedImageSelector, systemSelector],
|
[uiSelector, gallerySelector, systemSelector],
|
||||||
(ui, selectedImage, system) => {
|
(ui, gallery, system) => {
|
||||||
const { shouldShowImageDetails, shouldHidePreview } = ui;
|
const {
|
||||||
|
shouldShowImageDetails,
|
||||||
|
shouldHidePreview,
|
||||||
|
shouldShowProgressInViewer,
|
||||||
|
} = ui;
|
||||||
|
const { selectedImage } = gallery;
|
||||||
|
const { progressImage, shouldAntialiasProgressImage } = system;
|
||||||
return {
|
return {
|
||||||
shouldShowImageDetails,
|
shouldShowImageDetails,
|
||||||
shouldHidePreview,
|
shouldHidePreview,
|
||||||
image: selectedImage,
|
image: selectedImage,
|
||||||
|
progressImage,
|
||||||
|
shouldShowProgressInViewer,
|
||||||
|
shouldAntialiasProgressImage,
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -32,26 +41,61 @@ export const imagesSelector = createSelector(
|
|||||||
);
|
);
|
||||||
|
|
||||||
const CurrentImagePreview = () => {
|
const CurrentImagePreview = () => {
|
||||||
const { shouldShowImageDetails, image, shouldHidePreview } =
|
const {
|
||||||
useAppSelector(imagesSelector);
|
shouldShowImageDetails,
|
||||||
|
image,
|
||||||
|
shouldHidePreview,
|
||||||
|
progressImage,
|
||||||
|
shouldShowProgressInViewer,
|
||||||
|
shouldAntialiasProgressImage,
|
||||||
|
} = useAppSelector(imagesSelector);
|
||||||
const { getUrl } = useGetUrl();
|
const { getUrl } = useGetUrl();
|
||||||
|
|
||||||
|
const handleDragStart = useCallback(
|
||||||
|
(e: DragEvent<HTMLDivElement>) => {
|
||||||
|
if (!image) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
e.dataTransfer.setData('invokeai/imageName', image.name);
|
||||||
|
e.dataTransfer.setData('invokeai/imageType', image.type);
|
||||||
|
e.dataTransfer.effectAllowed = 'move';
|
||||||
|
},
|
||||||
|
[image]
|
||||||
|
);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex
|
<Flex
|
||||||
sx={{
|
sx={{
|
||||||
position: 'relative',
|
|
||||||
justifyContent: 'center',
|
|
||||||
alignItems: 'center',
|
|
||||||
width: '100%',
|
width: '100%',
|
||||||
height: '100%',
|
height: '100%',
|
||||||
|
position: 'relative',
|
||||||
|
alignItems: 'center',
|
||||||
|
justifyContent: 'center',
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
{image && (
|
{progressImage && shouldShowProgressInViewer ? (
|
||||||
<Image
|
<Image
|
||||||
src={shouldHidePreview ? undefined : getUrl(image.url)}
|
src={progressImage.dataURL}
|
||||||
width={image.metadata.width}
|
width={progressImage.width}
|
||||||
height={image.metadata.height}
|
height={progressImage.height}
|
||||||
fallback={shouldHidePreview ? <CurrentImageHidden /> : undefined}
|
sx={{
|
||||||
|
objectFit: 'contain',
|
||||||
|
maxWidth: '100%',
|
||||||
|
maxHeight: '100%',
|
||||||
|
height: 'auto',
|
||||||
|
position: 'absolute',
|
||||||
|
borderRadius: 'base',
|
||||||
|
imageRendering: shouldAntialiasProgressImage ? 'auto' : 'pixelated',
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
) : (
|
||||||
|
image && (
|
||||||
|
<>
|
||||||
|
<Image
|
||||||
|
src={getUrl(image.url)}
|
||||||
|
fallbackStrategy="beforeLoadOrError"
|
||||||
|
fallback={<ImageFallbackSpinner />}
|
||||||
|
onDragStart={handleDragStart}
|
||||||
sx={{
|
sx={{
|
||||||
objectFit: 'contain',
|
objectFit: 'contain',
|
||||||
maxWidth: '100%',
|
maxWidth: '100%',
|
||||||
@ -61,6 +105,9 @@ const CurrentImagePreview = () => {
|
|||||||
borderRadius: 'base',
|
borderRadius: 'base',
|
||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
|
<ImageMetadataOverlay image={image} />
|
||||||
|
</>
|
||||||
|
)
|
||||||
)}
|
)}
|
||||||
{shouldShowImageDetails && image && 'metadata' in image && (
|
{shouldShowImageDetails && image && 'metadata' in image && (
|
||||||
<Box
|
<Box
|
||||||
|
@ -0,0 +1,70 @@
|
|||||||
|
import { Flex, Image, Spinner } from '@chakra-ui/react';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
|
import { systemSelector } from 'features/system/store/systemSelectors';
|
||||||
|
import { memo } from 'react';
|
||||||
|
import { gallerySelector } from '../store/gallerySelectors';
|
||||||
|
|
||||||
|
const selector = createSelector(
|
||||||
|
[systemSelector, gallerySelector],
|
||||||
|
(system, gallery) => {
|
||||||
|
const { shouldUseSingleGalleryColumn, galleryImageObjectFit } = gallery;
|
||||||
|
const { progressImage, shouldAntialiasProgressImage } = system;
|
||||||
|
|
||||||
|
return {
|
||||||
|
progressImage,
|
||||||
|
shouldUseSingleGalleryColumn,
|
||||||
|
galleryImageObjectFit,
|
||||||
|
shouldAntialiasProgressImage,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
);
|
||||||
|
|
||||||
|
const GalleryProgressImage = () => {
|
||||||
|
const {
|
||||||
|
progressImage,
|
||||||
|
shouldUseSingleGalleryColumn,
|
||||||
|
galleryImageObjectFit,
|
||||||
|
shouldAntialiasProgressImage,
|
||||||
|
} = useAppSelector(selector);
|
||||||
|
|
||||||
|
if (!progressImage) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex
|
||||||
|
sx={{
|
||||||
|
w: 'full',
|
||||||
|
h: 'full',
|
||||||
|
alignItems: 'center',
|
||||||
|
justifyContent: 'center',
|
||||||
|
aspectRatio: '1/1',
|
||||||
|
position: 'relative',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<Image
|
||||||
|
draggable={false}
|
||||||
|
src={progressImage.dataURL}
|
||||||
|
width={progressImage.width}
|
||||||
|
height={progressImage.height}
|
||||||
|
sx={{
|
||||||
|
objectFit: shouldUseSingleGalleryColumn
|
||||||
|
? 'contain'
|
||||||
|
: galleryImageObjectFit,
|
||||||
|
width: '100%',
|
||||||
|
height: '100%',
|
||||||
|
maxWidth: '100%',
|
||||||
|
maxHeight: '100%',
|
||||||
|
borderRadius: 'base',
|
||||||
|
imageRendering: shouldAntialiasProgressImage ? 'auto' : 'pixelated',
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
<Spinner sx={{ position: 'absolute', top: 1, right: 1, opacity: 0.7 }} />
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(GalleryProgressImage);
|
@ -5,19 +5,20 @@ import {
|
|||||||
Image,
|
Image,
|
||||||
MenuItem,
|
MenuItem,
|
||||||
MenuList,
|
MenuList,
|
||||||
Skeleton,
|
|
||||||
useDisclosure,
|
useDisclosure,
|
||||||
useTheme,
|
|
||||||
useToast,
|
useToast,
|
||||||
} from '@chakra-ui/react';
|
} from '@chakra-ui/react';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||||
import { DragEvent, memo, useCallback, useState } from 'react';
|
import { DragEvent, MouseEvent, memo, useCallback, useState } from 'react';
|
||||||
import { FaCheck, FaExpand, FaImage, FaShare, FaTrash } from 'react-icons/fa';
|
import { FaCheck, FaExpand, FaImage, FaShare, FaTrash } from 'react-icons/fa';
|
||||||
import DeleteImageModal from './DeleteImageModal';
|
import DeleteImageModal from './DeleteImageModal';
|
||||||
import { ContextMenu } from 'chakra-ui-contextmenu';
|
import { ContextMenu } from 'chakra-ui-contextmenu';
|
||||||
import * as InvokeAI from 'app/types/invokeai';
|
import * as InvokeAI from 'app/types/invokeai';
|
||||||
import { resizeAndScaleCanvas } from 'features/canvas/store/canvasSlice';
|
import {
|
||||||
|
resizeAndScaleCanvas,
|
||||||
|
setInitialCanvasImage,
|
||||||
|
} from 'features/canvas/store/canvasSlice';
|
||||||
import { gallerySelector } from 'features/gallery/store/gallerySelectors';
|
import { gallerySelector } from 'features/gallery/store/gallerySelectors';
|
||||||
import { setActiveTab } from 'features/ui/store/uiSlice';
|
import { setActiveTab } from 'features/ui/store/uiSlice';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
@ -25,7 +26,6 @@ import IAIIconButton from 'common/components/IAIIconButton';
|
|||||||
import { useGetUrl } from 'common/util/getUrl';
|
import { useGetUrl } from 'common/util/getUrl';
|
||||||
import { ExternalLinkIcon } from '@chakra-ui/icons';
|
import { ExternalLinkIcon } from '@chakra-ui/icons';
|
||||||
import { IoArrowUndoCircleOutline } from 'react-icons/io5';
|
import { IoArrowUndoCircleOutline } from 'react-icons/io5';
|
||||||
import { imageDeleted } from 'services/thunks/image';
|
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
import { systemSelector } from 'features/system/store/systemSelectors';
|
||||||
import { lightboxSelector } from 'features/lightbox/store/lightboxSelectors';
|
import { lightboxSelector } from 'features/lightbox/store/lightboxSelectors';
|
||||||
@ -33,6 +33,8 @@ import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
|||||||
import { isEqual } from 'lodash-es';
|
import { isEqual } from 'lodash-es';
|
||||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||||
import { useParameters } from 'features/parameters/hooks/useParameters';
|
import { useParameters } from 'features/parameters/hooks/useParameters';
|
||||||
|
import { initialImageSelected } from 'features/parameters/store/actions';
|
||||||
|
import { requestedImageDeletion } from '../store/actions';
|
||||||
|
|
||||||
export const selector = createSelector(
|
export const selector = createSelector(
|
||||||
[gallerySelector, systemSelector, lightboxSelector, activeTabNameSelector],
|
[gallerySelector, systemSelector, lightboxSelector, activeTabNameSelector],
|
||||||
@ -94,16 +96,19 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
|||||||
} = useDisclosure();
|
} = useDisclosure();
|
||||||
|
|
||||||
const { image, isSelected } = props;
|
const { image, isSelected } = props;
|
||||||
const { url, thumbnail, name, metadata } = image;
|
const { url, thumbnail, name } = image;
|
||||||
const { getUrl } = useGetUrl();
|
const { getUrl } = useGetUrl();
|
||||||
|
|
||||||
const [isHovered, setIsHovered] = useState<boolean>(false);
|
const [isHovered, setIsHovered] = useState<boolean>(false);
|
||||||
|
|
||||||
const toast = useToast();
|
const toast = useToast();
|
||||||
const { direction } = useTheme();
|
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const { isFeatureEnabled: isLightboxEnabled } = useFeatureStatus('lightbox');
|
|
||||||
const { recallSeed, recallPrompt, sendToImageToImage, recallInitialImage } =
|
const isLightboxEnabled = useFeatureStatus('lightbox').isFeatureEnabled;
|
||||||
|
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
|
||||||
|
|
||||||
|
const { recallSeed, recallPrompt, recallInitialImage, recallAllParameters } =
|
||||||
useParameters();
|
useParameters();
|
||||||
|
|
||||||
const handleMouseOver = () => setIsHovered(true);
|
const handleMouseOver = () => setIsHovered(true);
|
||||||
@ -112,18 +117,22 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
|||||||
// Immediately deletes an image
|
// Immediately deletes an image
|
||||||
const handleDelete = useCallback(() => {
|
const handleDelete = useCallback(() => {
|
||||||
if (canDeleteImage && image) {
|
if (canDeleteImage && image) {
|
||||||
dispatch(imageDeleted({ imageType: image.type, imageName: image.name }));
|
dispatch(requestedImageDeletion(image));
|
||||||
}
|
}
|
||||||
}, [dispatch, image, canDeleteImage]);
|
}, [dispatch, image, canDeleteImage]);
|
||||||
|
|
||||||
// Opens the alert dialog to check if user is sure they want to delete
|
// Opens the alert dialog to check if user is sure they want to delete
|
||||||
const handleInitiateDelete = useCallback(() => {
|
const handleInitiateDelete = useCallback(
|
||||||
|
(e: MouseEvent) => {
|
||||||
|
e.stopPropagation();
|
||||||
if (shouldConfirmOnDelete) {
|
if (shouldConfirmOnDelete) {
|
||||||
onDeleteDialogOpen();
|
onDeleteDialogOpen();
|
||||||
} else {
|
} else {
|
||||||
handleDelete();
|
handleDelete();
|
||||||
}
|
}
|
||||||
}, [handleDelete, onDeleteDialogOpen, shouldConfirmOnDelete]);
|
},
|
||||||
|
[handleDelete, onDeleteDialogOpen, shouldConfirmOnDelete]
|
||||||
|
);
|
||||||
|
|
||||||
const handleSelectImage = useCallback(() => {
|
const handleSelectImage = useCallback(() => {
|
||||||
dispatch(imageSelected(image));
|
dispatch(imageSelected(image));
|
||||||
@ -148,8 +157,8 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
|||||||
}, [image, recallSeed]);
|
}, [image, recallSeed]);
|
||||||
|
|
||||||
const handleSendToImageToImage = useCallback(() => {
|
const handleSendToImageToImage = useCallback(() => {
|
||||||
sendToImageToImage(image);
|
dispatch(initialImageSelected(image));
|
||||||
}, [image, sendToImageToImage]);
|
}, [dispatch, image]);
|
||||||
|
|
||||||
const handleRecallInitialImage = useCallback(() => {
|
const handleRecallInitialImage = useCallback(() => {
|
||||||
recallInitialImage(image.metadata.invokeai?.node?.image);
|
recallInitialImage(image.metadata.invokeai?.node?.image);
|
||||||
@ -159,7 +168,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
|||||||
* TODO: the rest of these
|
* TODO: the rest of these
|
||||||
*/
|
*/
|
||||||
const handleSendToCanvas = () => {
|
const handleSendToCanvas = () => {
|
||||||
// dispatch(setInitialCanvasImage(image));
|
dispatch(setInitialCanvasImage(image));
|
||||||
|
|
||||||
dispatch(resizeAndScaleCanvas());
|
dispatch(resizeAndScaleCanvas());
|
||||||
|
|
||||||
@ -175,16 +184,9 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
|||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
|
||||||
const handleUseAllParameters = () => {
|
const handleUseAllParameters = useCallback(() => {
|
||||||
// metadata.invokeai?.node &&
|
recallAllParameters(image);
|
||||||
// dispatch(setAllParameters(metadata.invokeai?.node));
|
}, [image, recallAllParameters]);
|
||||||
// toast({
|
|
||||||
// title: t('toast.parametersSet'),
|
|
||||||
// status: 'success',
|
|
||||||
// duration: 2500,
|
|
||||||
// isClosable: true,
|
|
||||||
// });
|
|
||||||
};
|
|
||||||
|
|
||||||
const handleLightBox = () => {
|
const handleLightBox = () => {
|
||||||
// dispatch(setCurrentImage(image));
|
// dispatch(setCurrentImage(image));
|
||||||
@ -238,7 +240,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
|||||||
icon={<IoArrowUndoCircleOutline />}
|
icon={<IoArrowUndoCircleOutline />}
|
||||||
onClickCapture={handleUseAllParameters}
|
onClickCapture={handleUseAllParameters}
|
||||||
isDisabled={
|
isDisabled={
|
||||||
!['txt2img', 'img2img'].includes(
|
!['txt2img', 'img2img', 'inpaint'].includes(
|
||||||
String(image?.metadata?.invokeai?.node?.type)
|
String(image?.metadata?.invokeai?.node?.type)
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@ -251,9 +253,11 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
|||||||
>
|
>
|
||||||
{t('parameters.sendToImg2Img')}
|
{t('parameters.sendToImg2Img')}
|
||||||
</MenuItem>
|
</MenuItem>
|
||||||
|
{isCanvasEnabled && (
|
||||||
<MenuItem icon={<FaShare />} onClickCapture={handleSendToCanvas}>
|
<MenuItem icon={<FaShare />} onClickCapture={handleSendToCanvas}>
|
||||||
{t('parameters.sendToUnifiedCanvas')}
|
{t('parameters.sendToUnifiedCanvas')}
|
||||||
</MenuItem>
|
</MenuItem>
|
||||||
|
)}
|
||||||
<MenuItem icon={<FaTrash />} onClickCapture={onDeleteDialogOpen}>
|
<MenuItem icon={<FaTrash />} onClickCapture={onDeleteDialogOpen}>
|
||||||
{t('gallery.deleteImage')}
|
{t('gallery.deleteImage')}
|
||||||
</MenuItem>
|
</MenuItem>
|
||||||
@ -279,6 +283,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
|||||||
h: 'full',
|
h: 'full',
|
||||||
transition: 'transform 0.2s ease-out',
|
transition: 'transform 0.2s ease-out',
|
||||||
aspectRatio: '1/1',
|
aspectRatio: '1/1',
|
||||||
|
cursor: 'pointer',
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<Image
|
<Image
|
||||||
@ -315,6 +320,8 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
|||||||
sx={{
|
sx={{
|
||||||
width: '50%',
|
width: '50%',
|
||||||
height: '50%',
|
height: '50%',
|
||||||
|
maxWidth: '4rem',
|
||||||
|
maxHeight: '4rem',
|
||||||
fill: 'ok.500',
|
fill: 'ok.500',
|
||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
|
@ -0,0 +1,92 @@
|
|||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
|
||||||
|
import { useDisclosure } from '@chakra-ui/react';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import IAIIconButton from 'common/components/IAIIconButton';
|
||||||
|
import { systemSelector } from 'features/system/store/systemSelectors';
|
||||||
|
|
||||||
|
import { useHotkeys } from 'react-hotkeys-hook';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { FaTrash } from 'react-icons/fa';
|
||||||
|
import { memo, useCallback } from 'react';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
|
import DeleteImageModal from '../DeleteImageModal';
|
||||||
|
import { requestedImageDeletion } from 'features/gallery/store/actions';
|
||||||
|
import { Image } from 'app/types/invokeai';
|
||||||
|
|
||||||
|
const selector = createSelector(
|
||||||
|
[systemSelector],
|
||||||
|
(system) => {
|
||||||
|
const { isProcessing, isConnected, shouldConfirmOnDelete } = system;
|
||||||
|
|
||||||
|
return {
|
||||||
|
canDeleteImage: isConnected && !isProcessing,
|
||||||
|
shouldConfirmOnDelete,
|
||||||
|
isProcessing,
|
||||||
|
isConnected,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
);
|
||||||
|
|
||||||
|
type DeleteImageButtonProps = {
|
||||||
|
image: Image | undefined;
|
||||||
|
};
|
||||||
|
|
||||||
|
const DeleteImageButton = (props: DeleteImageButtonProps) => {
|
||||||
|
const { image } = props;
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const { isProcessing, isConnected, canDeleteImage, shouldConfirmOnDelete } =
|
||||||
|
useAppSelector(selector);
|
||||||
|
|
||||||
|
const {
|
||||||
|
isOpen: isDeleteDialogOpen,
|
||||||
|
onOpen: onDeleteDialogOpen,
|
||||||
|
onClose: onDeleteDialogClose,
|
||||||
|
} = useDisclosure();
|
||||||
|
|
||||||
|
const { t } = useTranslation();
|
||||||
|
|
||||||
|
const handleDelete = useCallback(() => {
|
||||||
|
if (canDeleteImage && image) {
|
||||||
|
dispatch(requestedImageDeletion(image));
|
||||||
|
}
|
||||||
|
}, [image, canDeleteImage, dispatch]);
|
||||||
|
|
||||||
|
const handleInitiateDelete = useCallback(() => {
|
||||||
|
if (shouldConfirmOnDelete) {
|
||||||
|
onDeleteDialogOpen();
|
||||||
|
} else {
|
||||||
|
handleDelete();
|
||||||
|
}
|
||||||
|
}, [shouldConfirmOnDelete, onDeleteDialogOpen, handleDelete]);
|
||||||
|
|
||||||
|
useHotkeys('delete', handleInitiateDelete, [
|
||||||
|
image,
|
||||||
|
shouldConfirmOnDelete,
|
||||||
|
isConnected,
|
||||||
|
isProcessing,
|
||||||
|
]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
<IAIIconButton
|
||||||
|
onClick={handleInitiateDelete}
|
||||||
|
icon={<FaTrash />}
|
||||||
|
tooltip={`${t('gallery.deleteImage')} (Del)`}
|
||||||
|
aria-label={`${t('gallery.deleteImage')} (Del)`}
|
||||||
|
isDisabled={!image || !isConnected}
|
||||||
|
colorScheme="error"
|
||||||
|
/>
|
||||||
|
{image && (
|
||||||
|
<DeleteImageModal
|
||||||
|
isOpen={isDeleteDialogOpen}
|
||||||
|
onClose={onDeleteDialogClose}
|
||||||
|
handleDelete={handleDelete}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(DeleteImageButton);
|
@ -1,8 +1,8 @@
|
|||||||
import { Flex, Spinner, SpinnerProps } from '@chakra-ui/react';
|
import { Flex, Spinner, SpinnerProps } from '@chakra-ui/react';
|
||||||
|
|
||||||
type CurrentImageFallbackProps = SpinnerProps;
|
type ImageFallbackSpinnerProps = SpinnerProps;
|
||||||
|
|
||||||
const CurrentImageFallback = (props: CurrentImageFallbackProps) => {
|
const ImageFallbackSpinner = (props: ImageFallbackSpinnerProps) => {
|
||||||
const { size = 'xl', ...rest } = props;
|
const { size = 'xl', ...rest } = props;
|
||||||
|
|
||||||
return (
|
return (
|
||||||
@ -21,4 +21,4 @@ const CurrentImageFallback = (props: CurrentImageFallbackProps) => {
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export default CurrentImageFallback;
|
export default ImageFallbackSpinner;
|
@ -5,6 +5,7 @@ import {
|
|||||||
FlexProps,
|
FlexProps,
|
||||||
Grid,
|
Grid,
|
||||||
Icon,
|
Icon,
|
||||||
|
Image,
|
||||||
Text,
|
Text,
|
||||||
forwardRef,
|
forwardRef,
|
||||||
} from '@chakra-ui/react';
|
} from '@chakra-ui/react';
|
||||||
@ -14,7 +15,10 @@ import IAICheckbox from 'common/components/IAICheckbox';
|
|||||||
import IAIIconButton from 'common/components/IAIIconButton';
|
import IAIIconButton from 'common/components/IAIIconButton';
|
||||||
import IAIPopover from 'common/components/IAIPopover';
|
import IAIPopover from 'common/components/IAIPopover';
|
||||||
import IAISlider from 'common/components/IAISlider';
|
import IAISlider from 'common/components/IAISlider';
|
||||||
import { imageGallerySelector } from 'features/gallery/store/gallerySelectors';
|
import {
|
||||||
|
gallerySelector,
|
||||||
|
imageGallerySelector,
|
||||||
|
} from 'features/gallery/store/gallerySelectors';
|
||||||
import {
|
import {
|
||||||
setCurrentCategory,
|
setCurrentCategory,
|
||||||
setGalleryImageMinimumWidth,
|
setGalleryImageMinimumWidth,
|
||||||
@ -50,30 +54,42 @@ import { uploadsAdapter } from '../store/uploadsSlice';
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { Virtuoso, VirtuosoGrid } from 'react-virtuoso';
|
import { Virtuoso, VirtuosoGrid } from 'react-virtuoso';
|
||||||
|
import { Image as ImageType } from 'app/types/invokeai';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
|
import GalleryProgressImage from './GalleryProgressImage';
|
||||||
|
|
||||||
const GALLERY_SHOW_BUTTONS_MIN_WIDTH = 290;
|
const GALLERY_SHOW_BUTTONS_MIN_WIDTH = 290;
|
||||||
|
const PROGRESS_IMAGE_PLACEHOLDER = 'PROGRESS_IMAGE_PLACEHOLDER';
|
||||||
|
|
||||||
const gallerySelector = createSelector(
|
const selector = createSelector(
|
||||||
[
|
[(state: RootState) => state],
|
||||||
(state: RootState) => state.uploads,
|
(state) => {
|
||||||
(state: RootState) => state.results,
|
const { results, uploads, system, gallery } = state;
|
||||||
(state: RootState) => state.gallery,
|
|
||||||
],
|
|
||||||
(uploads, results, gallery) => {
|
|
||||||
const { currentCategory } = gallery;
|
const { currentCategory } = gallery;
|
||||||
|
|
||||||
return currentCategory === 'results'
|
if (currentCategory === 'results') {
|
||||||
? {
|
const tempImages: (ImageType | typeof PROGRESS_IMAGE_PLACEHOLDER)[] = [];
|
||||||
images: resultsAdapter.getSelectors().selectAll(results),
|
|
||||||
|
if (system.progressImage) {
|
||||||
|
tempImages.push(PROGRESS_IMAGE_PLACEHOLDER);
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
images: tempImages.concat(
|
||||||
|
resultsAdapter.getSelectors().selectAll(results)
|
||||||
|
),
|
||||||
isLoading: results.isLoading,
|
isLoading: results.isLoading,
|
||||||
areMoreImagesAvailable: results.page < results.pages - 1,
|
areMoreImagesAvailable: results.page < results.pages - 1,
|
||||||
|
};
|
||||||
}
|
}
|
||||||
: {
|
|
||||||
|
return {
|
||||||
images: uploadsAdapter.getSelectors().selectAll(uploads),
|
images: uploadsAdapter.getSelectors().selectAll(uploads),
|
||||||
isLoading: uploads.isLoading,
|
isLoading: uploads.isLoading,
|
||||||
areMoreImagesAvailable: uploads.page < uploads.pages - 1,
|
areMoreImagesAvailable: uploads.page < uploads.pages - 1,
|
||||||
};
|
};
|
||||||
}
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
);
|
);
|
||||||
|
|
||||||
const ImageGalleryContent = () => {
|
const ImageGalleryContent = () => {
|
||||||
@ -108,7 +124,7 @@ const ImageGalleryContent = () => {
|
|||||||
} = useAppSelector(imageGallerySelector);
|
} = useAppSelector(imageGallerySelector);
|
||||||
|
|
||||||
const { images, areMoreImagesAvailable, isLoading } =
|
const { images, areMoreImagesAvailable, isLoading } =
|
||||||
useAppSelector(gallerySelector);
|
useAppSelector(selector);
|
||||||
|
|
||||||
const handleClickLoadMore = () => {
|
const handleClickLoadMore = () => {
|
||||||
if (currentCategory === 'results') {
|
if (currentCategory === 'results') {
|
||||||
@ -170,8 +186,24 @@ const ImageGalleryContent = () => {
|
|||||||
}
|
}
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
|
const handleEndReached = useCallback(() => {
|
||||||
|
if (currentCategory === 'results') {
|
||||||
|
dispatch(receivedResultImagesPage());
|
||||||
|
} else if (currentCategory === 'uploads') {
|
||||||
|
dispatch(receivedUploadImagesPage());
|
||||||
|
}
|
||||||
|
}, [dispatch, currentCategory]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex flexDirection="column" w="full" h="full" gap={4}>
|
<Flex
|
||||||
|
sx={{
|
||||||
|
gap: 2,
|
||||||
|
flexDirection: 'column',
|
||||||
|
h: 'full',
|
||||||
|
w: 'full',
|
||||||
|
borderRadius: 'base',
|
||||||
|
}}
|
||||||
|
>
|
||||||
<Flex
|
<Flex
|
||||||
ref={resizeObserverRef}
|
ref={resizeObserverRef}
|
||||||
alignItems="center"
|
alignItems="center"
|
||||||
@ -290,18 +322,27 @@ const ImageGalleryContent = () => {
|
|||||||
<Virtuoso
|
<Virtuoso
|
||||||
style={{ height: '100%' }}
|
style={{ height: '100%' }}
|
||||||
data={images}
|
data={images}
|
||||||
|
endReached={handleEndReached}
|
||||||
scrollerRef={(ref) => setScrollerRef(ref)}
|
scrollerRef={(ref) => setScrollerRef(ref)}
|
||||||
itemContent={(index, image) => {
|
itemContent={(index, image) => {
|
||||||
const { name } = image;
|
const isSelected =
|
||||||
const isSelected = selectedImage?.name === name;
|
image === PROGRESS_IMAGE_PLACEHOLDER
|
||||||
|
? false
|
||||||
|
: selectedImage?.name === image?.name;
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex sx={{ pb: 2 }}>
|
<Flex sx={{ pb: 2 }}>
|
||||||
|
{image === PROGRESS_IMAGE_PLACEHOLDER ? (
|
||||||
|
<GalleryProgressImage
|
||||||
|
key={PROGRESS_IMAGE_PLACEHOLDER}
|
||||||
|
/>
|
||||||
|
) : (
|
||||||
<HoverableImage
|
<HoverableImage
|
||||||
key={`${name}-${image.thumbnail}`}
|
key={`${image.name}-${image.thumbnail}`}
|
||||||
image={image}
|
image={image}
|
||||||
isSelected={isSelected}
|
isSelected={isSelected}
|
||||||
/>
|
/>
|
||||||
|
)}
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
}}
|
}}
|
||||||
@ -310,18 +351,23 @@ const ImageGalleryContent = () => {
|
|||||||
<VirtuosoGrid
|
<VirtuosoGrid
|
||||||
style={{ height: '100%' }}
|
style={{ height: '100%' }}
|
||||||
data={images}
|
data={images}
|
||||||
|
endReached={handleEndReached}
|
||||||
components={{
|
components={{
|
||||||
Item: ItemContainer,
|
Item: ItemContainer,
|
||||||
List: ListContainer,
|
List: ListContainer,
|
||||||
}}
|
}}
|
||||||
scrollerRef={setScroller}
|
scrollerRef={setScroller}
|
||||||
itemContent={(index, image) => {
|
itemContent={(index, image) => {
|
||||||
const { name } = image;
|
const isSelected =
|
||||||
const isSelected = selectedImage?.name === name;
|
image === PROGRESS_IMAGE_PLACEHOLDER
|
||||||
|
? false
|
||||||
|
: selectedImage?.name === image?.name;
|
||||||
|
|
||||||
return (
|
return image === PROGRESS_IMAGE_PLACEHOLDER ? (
|
||||||
|
<GalleryProgressImage key={PROGRESS_IMAGE_PLACEHOLDER} />
|
||||||
|
) : (
|
||||||
<HoverableImage
|
<HoverableImage
|
||||||
key={`${name}-${image.thumbnail}`}
|
key={`${image.name}-${image.thumbnail}`}
|
||||||
image={image}
|
image={image}
|
||||||
isSelected={isSelected}
|
isSelected={isSelected}
|
||||||
/>
|
/>
|
||||||
@ -334,6 +380,7 @@ const ImageGalleryContent = () => {
|
|||||||
onClick={handleClickLoadMore}
|
onClick={handleClickLoadMore}
|
||||||
isDisabled={!areMoreImagesAvailable}
|
isDisabled={!areMoreImagesAvailable}
|
||||||
isLoading={isLoading}
|
isLoading={isLoading}
|
||||||
|
loadingText="Loading"
|
||||||
flexShrink={0}
|
flexShrink={0}
|
||||||
>
|
>
|
||||||
{areMoreImagesAvailable
|
{areMoreImagesAvailable
|
||||||
|
@ -5,7 +5,6 @@ import {
|
|||||||
// selectPrevImage,
|
// selectPrevImage,
|
||||||
setGalleryImageMinimumWidth,
|
setGalleryImageMinimumWidth,
|
||||||
} from 'features/gallery/store/gallerySlice';
|
} from 'features/gallery/store/gallerySlice';
|
||||||
import { InvokeTabName } from 'features/ui/store/tabMap';
|
|
||||||
|
|
||||||
import { clamp, isEqual } from 'lodash-es';
|
import { clamp, isEqual } from 'lodash-es';
|
||||||
import { useHotkeys } from 'react-hotkeys-hook';
|
import { useHotkeys } from 'react-hotkeys-hook';
|
||||||
@ -13,11 +12,7 @@ import { useHotkeys } from 'react-hotkeys-hook';
|
|||||||
import './ImageGallery.css';
|
import './ImageGallery.css';
|
||||||
import ImageGalleryContent from './ImageGalleryContent';
|
import ImageGalleryContent from './ImageGalleryContent';
|
||||||
import ResizableDrawer from 'features/ui/components/common/ResizableDrawer/ResizableDrawer';
|
import ResizableDrawer from 'features/ui/components/common/ResizableDrawer/ResizableDrawer';
|
||||||
import {
|
import { setShouldShowGallery } from 'features/ui/store/uiSlice';
|
||||||
setShouldShowGallery,
|
|
||||||
toggleGalleryPanel,
|
|
||||||
togglePinGalleryPanel,
|
|
||||||
} from 'features/ui/store/uiSlice';
|
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import {
|
import {
|
||||||
activeTabNameSelector,
|
activeTabNameSelector,
|
||||||
@ -26,22 +21,20 @@ import {
|
|||||||
import { isStagingSelector } from 'features/canvas/store/canvasSelectors';
|
import { isStagingSelector } from 'features/canvas/store/canvasSelectors';
|
||||||
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
||||||
import { lightboxSelector } from 'features/lightbox/store/lightboxSelectors';
|
import { lightboxSelector } from 'features/lightbox/store/lightboxSelectors';
|
||||||
import useResolution from 'common/hooks/useResolution';
|
|
||||||
import { Flex } from '@chakra-ui/react';
|
|
||||||
import { memo } from 'react';
|
import { memo } from 'react';
|
||||||
|
|
||||||
const GALLERY_TAB_WIDTHS: Record<
|
// const GALLERY_TAB_WIDTHS: Record<
|
||||||
InvokeTabName,
|
// InvokeTabName,
|
||||||
{ galleryMinWidth: number; galleryMaxWidth: number }
|
// { galleryMinWidth: number; galleryMaxWidth: number }
|
||||||
> = {
|
// > = {
|
||||||
// txt2img: { galleryMinWidth: 200, galleryMaxWidth: 500 },
|
// txt2img: { galleryMinWidth: 200, galleryMaxWidth: 500 },
|
||||||
// img2img: { galleryMinWidth: 200, galleryMaxWidth: 500 },
|
// img2img: { galleryMinWidth: 200, galleryMaxWidth: 500 },
|
||||||
generate: { galleryMinWidth: 200, galleryMaxWidth: 500 },
|
// generate: { galleryMinWidth: 200, galleryMaxWidth: 500 },
|
||||||
unifiedCanvas: { galleryMinWidth: 200, galleryMaxWidth: 200 },
|
// unifiedCanvas: { galleryMinWidth: 200, galleryMaxWidth: 200 },
|
||||||
nodes: { galleryMinWidth: 200, galleryMaxWidth: 500 },
|
// nodes: { galleryMinWidth: 200, galleryMaxWidth: 500 },
|
||||||
// postprocessing: { galleryMinWidth: 200, galleryMaxWidth: 500 },
|
// postprocessing: { galleryMinWidth: 200, galleryMaxWidth: 500 },
|
||||||
// training: { galleryMinWidth: 200, galleryMaxWidth: 500 },
|
// training: { galleryMinWidth: 200, galleryMaxWidth: 500 },
|
||||||
};
|
// };
|
||||||
|
|
||||||
const galleryPanelSelector = createSelector(
|
const galleryPanelSelector = createSelector(
|
||||||
[
|
[
|
||||||
@ -73,50 +66,50 @@ const galleryPanelSelector = createSelector(
|
|||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
export const ImageGalleryPanel = () => {
|
const GalleryDrawer = () => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const {
|
const {
|
||||||
shouldPinGallery,
|
shouldPinGallery,
|
||||||
shouldShowGallery,
|
shouldShowGallery,
|
||||||
galleryImageMinimumWidth,
|
galleryImageMinimumWidth,
|
||||||
activeTabName,
|
// activeTabName,
|
||||||
isStaging,
|
// isStaging,
|
||||||
isResizable,
|
// isResizable,
|
||||||
isLightboxOpen,
|
// isLightboxOpen,
|
||||||
} = useAppSelector(galleryPanelSelector);
|
} = useAppSelector(galleryPanelSelector);
|
||||||
|
|
||||||
const handleSetShouldPinGallery = () => {
|
// const handleSetShouldPinGallery = () => {
|
||||||
dispatch(togglePinGalleryPanel());
|
// dispatch(togglePinGalleryPanel());
|
||||||
dispatch(requestCanvasRescale());
|
// dispatch(requestCanvasRescale());
|
||||||
};
|
// };
|
||||||
|
|
||||||
const handleToggleGallery = () => {
|
// const handleToggleGallery = () => {
|
||||||
dispatch(toggleGalleryPanel());
|
// dispatch(toggleGalleryPanel());
|
||||||
shouldPinGallery && dispatch(requestCanvasRescale());
|
// shouldPinGallery && dispatch(requestCanvasRescale());
|
||||||
};
|
// };
|
||||||
|
|
||||||
const handleCloseGallery = () => {
|
const handleCloseGallery = () => {
|
||||||
dispatch(setShouldShowGallery(false));
|
dispatch(setShouldShowGallery(false));
|
||||||
shouldPinGallery && dispatch(requestCanvasRescale());
|
shouldPinGallery && dispatch(requestCanvasRescale());
|
||||||
};
|
};
|
||||||
|
|
||||||
const resolution = useResolution();
|
// const resolution = useResolution();
|
||||||
|
|
||||||
useHotkeys(
|
// useHotkeys(
|
||||||
'g',
|
// 'g',
|
||||||
() => {
|
// () => {
|
||||||
handleToggleGallery();
|
// handleToggleGallery();
|
||||||
},
|
// },
|
||||||
[shouldPinGallery]
|
// [shouldPinGallery]
|
||||||
);
|
// );
|
||||||
|
|
||||||
useHotkeys(
|
// useHotkeys(
|
||||||
'shift+g',
|
// 'shift+g',
|
||||||
() => {
|
// () => {
|
||||||
handleSetShouldPinGallery();
|
// handleSetShouldPinGallery();
|
||||||
},
|
// },
|
||||||
[shouldPinGallery]
|
// [shouldPinGallery]
|
||||||
);
|
// );
|
||||||
|
|
||||||
useHotkeys(
|
useHotkeys(
|
||||||
'esc',
|
'esc',
|
||||||
@ -162,55 +155,71 @@ export const ImageGalleryPanel = () => {
|
|||||||
[galleryImageMinimumWidth]
|
[galleryImageMinimumWidth]
|
||||||
);
|
);
|
||||||
|
|
||||||
const calcGalleryMinHeight = () => {
|
// const calcGalleryMinHeight = () => {
|
||||||
if (resolution === 'desktop') return;
|
// if (resolution === 'desktop') return;
|
||||||
return 300;
|
// return 300;
|
||||||
};
|
// };
|
||||||
|
|
||||||
const imageGalleryContent = () => {
|
// const imageGalleryContent = () => {
|
||||||
return (
|
// return (
|
||||||
<Flex
|
// <Flex
|
||||||
w="100vw"
|
// w="100vw"
|
||||||
h={{ base: 300, xl: '100vh' }}
|
// h={{ base: 300, xl: '100vh' }}
|
||||||
paddingRight={{ base: 8, xl: 0 }}
|
// paddingRight={{ base: 8, xl: 0 }}
|
||||||
paddingBottom={{ base: 4, xl: 0 }}
|
// paddingBottom={{ base: 4, xl: 0 }}
|
||||||
>
|
// >
|
||||||
<ImageGalleryContent />
|
// <ImageGalleryContent />
|
||||||
</Flex>
|
// </Flex>
|
||||||
);
|
// );
|
||||||
};
|
// };
|
||||||
|
|
||||||
|
// const resizableImageGalleryContent = () => {
|
||||||
|
// return (
|
||||||
|
// <ResizableDrawer
|
||||||
|
// direction="right"
|
||||||
|
// isResizable={isResizable || !shouldPinGallery}
|
||||||
|
// isOpen={shouldShowGallery}
|
||||||
|
// onClose={handleCloseGallery}
|
||||||
|
// isPinned={shouldPinGallery && !isLightboxOpen}
|
||||||
|
// minWidth={
|
||||||
|
// shouldPinGallery
|
||||||
|
// ? GALLERY_TAB_WIDTHS[activeTabName].galleryMinWidth
|
||||||
|
// : 200
|
||||||
|
// }
|
||||||
|
// maxWidth={
|
||||||
|
// shouldPinGallery
|
||||||
|
// ? GALLERY_TAB_WIDTHS[activeTabName].galleryMaxWidth
|
||||||
|
// : undefined
|
||||||
|
// }
|
||||||
|
// minHeight={calcGalleryMinHeight()}
|
||||||
|
// >
|
||||||
|
// <ImageGalleryContent />
|
||||||
|
// </ResizableDrawer>
|
||||||
|
// );
|
||||||
|
// };
|
||||||
|
|
||||||
|
// const renderImageGallery = () => {
|
||||||
|
// if (['mobile', 'tablet'].includes(resolution)) return imageGalleryContent();
|
||||||
|
// return resizableImageGalleryContent();
|
||||||
|
// };
|
||||||
|
|
||||||
|
if (shouldPinGallery) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
const resizableImageGalleryContent = () => {
|
|
||||||
return (
|
return (
|
||||||
<ResizableDrawer
|
<ResizableDrawer
|
||||||
direction="right"
|
direction="right"
|
||||||
isResizable={isResizable || !shouldPinGallery}
|
isResizable={true}
|
||||||
isOpen={shouldShowGallery}
|
isOpen={shouldShowGallery}
|
||||||
onClose={handleCloseGallery}
|
onClose={handleCloseGallery}
|
||||||
isPinned={shouldPinGallery && !isLightboxOpen}
|
minWidth={200}
|
||||||
minWidth={
|
|
||||||
shouldPinGallery
|
|
||||||
? GALLERY_TAB_WIDTHS[activeTabName].galleryMinWidth
|
|
||||||
: 200
|
|
||||||
}
|
|
||||||
maxWidth={
|
|
||||||
shouldPinGallery
|
|
||||||
? GALLERY_TAB_WIDTHS[activeTabName].galleryMaxWidth
|
|
||||||
: undefined
|
|
||||||
}
|
|
||||||
minHeight={calcGalleryMinHeight()}
|
|
||||||
>
|
>
|
||||||
<ImageGalleryContent />
|
<ImageGalleryContent />
|
||||||
</ResizableDrawer>
|
</ResizableDrawer>
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// return renderImageGallery();
|
||||||
};
|
};
|
||||||
|
|
||||||
const renderImageGallery = () => {
|
export default memo(GalleryDrawer);
|
||||||
if (['mobile', 'tablet'].includes(resolution)) return imageGalleryContent();
|
|
||||||
return resizableImageGalleryContent();
|
|
||||||
};
|
|
||||||
|
|
||||||
return renderImageGallery();
|
|
||||||
};
|
|
||||||
|
|
||||||
export default memo(ImageGalleryPanel);
|
|
||||||
|
@ -3,7 +3,6 @@ import {
|
|||||||
Box,
|
Box,
|
||||||
Center,
|
Center,
|
||||||
Flex,
|
Flex,
|
||||||
Heading,
|
|
||||||
IconButton,
|
IconButton,
|
||||||
Link,
|
Link,
|
||||||
Text,
|
Text,
|
||||||
@ -19,8 +18,6 @@ import {
|
|||||||
setCfgScale,
|
setCfgScale,
|
||||||
setHeight,
|
setHeight,
|
||||||
setImg2imgStrength,
|
setImg2imgStrength,
|
||||||
// setInitialImage,
|
|
||||||
setMaskPath,
|
|
||||||
setPerlin,
|
setPerlin,
|
||||||
setSampler,
|
setSampler,
|
||||||
setSeamless,
|
setSeamless,
|
||||||
@ -31,21 +28,14 @@ import {
|
|||||||
setThreshold,
|
setThreshold,
|
||||||
setWidth,
|
setWidth,
|
||||||
} from 'features/parameters/store/generationSlice';
|
} from 'features/parameters/store/generationSlice';
|
||||||
import {
|
import { setHiresFix } from 'features/parameters/store/postprocessingSlice';
|
||||||
setCodeformerFidelity,
|
|
||||||
setFacetoolStrength,
|
|
||||||
setFacetoolType,
|
|
||||||
setHiresFix,
|
|
||||||
setUpscalingDenoising,
|
|
||||||
setUpscalingLevel,
|
|
||||||
setUpscalingStrength,
|
|
||||||
} from 'features/parameters/store/postprocessingSlice';
|
|
||||||
import { setShouldShowImageDetails } from 'features/ui/store/uiSlice';
|
import { setShouldShowImageDetails } from 'features/ui/store/uiSlice';
|
||||||
import { memo } from 'react';
|
import { memo } from 'react';
|
||||||
import { useHotkeys } from 'react-hotkeys-hook';
|
import { useHotkeys } from 'react-hotkeys-hook';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { FaCopy } from 'react-icons/fa';
|
import { FaCopy } from 'react-icons/fa';
|
||||||
import { IoArrowUndoCircleOutline } from 'react-icons/io5';
|
import { IoArrowUndoCircleOutline } from 'react-icons/io5';
|
||||||
|
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
||||||
|
|
||||||
type MetadataItemProps = {
|
type MetadataItemProps = {
|
||||||
isLink?: boolean;
|
isLink?: boolean;
|
||||||
@ -300,7 +290,7 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
|
|||||||
</Text>
|
</Text>
|
||||||
</Center>
|
</Center>
|
||||||
)}
|
)}
|
||||||
<Flex gap={2} direction="column">
|
<Flex gap={2} direction="column" overflow="auto">
|
||||||
<Flex gap={2}>
|
<Flex gap={2}>
|
||||||
<Tooltip label="Copy metadata JSON">
|
<Tooltip label="Copy metadata JSON">
|
||||||
<IconButton
|
<IconButton
|
||||||
@ -314,22 +304,19 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
|
|||||||
</Tooltip>
|
</Tooltip>
|
||||||
<Text fontWeight="semibold">Metadata JSON:</Text>
|
<Text fontWeight="semibold">Metadata JSON:</Text>
|
||||||
</Flex>
|
</Flex>
|
||||||
|
<OverlayScrollbarsComponent defer>
|
||||||
<Box
|
<Box
|
||||||
sx={{
|
sx={{
|
||||||
mt: 0,
|
|
||||||
mr: 2,
|
|
||||||
mb: 4,
|
|
||||||
ml: 2,
|
|
||||||
padding: 4,
|
padding: 4,
|
||||||
borderRadius: 'base',
|
borderRadius: 'base',
|
||||||
overflowX: 'scroll',
|
|
||||||
wordBreak: 'break-all',
|
|
||||||
bg: 'whiteAlpha.500',
|
bg: 'whiteAlpha.500',
|
||||||
_dark: { bg: 'blackAlpha.500' },
|
_dark: { bg: 'blackAlpha.500' },
|
||||||
|
w: 'max-content',
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<pre>{metadataJSON}</pre>
|
<pre>{metadataJSON}</pre>
|
||||||
</Box>
|
</Box>
|
||||||
|
</OverlayScrollbarsComponent>
|
||||||
</Flex>
|
</Flex>
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
|
@ -0,0 +1,7 @@
|
|||||||
|
import { createAction } from '@reduxjs/toolkit';
|
||||||
|
import { Image } from 'app/types/invokeai';
|
||||||
|
import { SelectedImage } from 'features/parameters/store/actions';
|
||||||
|
|
||||||
|
export const requestedImageDeletion = createAction<
|
||||||
|
Image | SelectedImage | undefined
|
||||||
|
>('gallery/requestedImageDeletion');
|
@ -4,12 +4,13 @@ import { GalleryState } from './gallerySlice';
|
|||||||
* Gallery slice persist denylist
|
* Gallery slice persist denylist
|
||||||
*/
|
*/
|
||||||
const itemsToDenylist: (keyof GalleryState)[] = [
|
const itemsToDenylist: (keyof GalleryState)[] = [
|
||||||
'categories',
|
|
||||||
'currentCategory',
|
'currentCategory',
|
||||||
'currentImage',
|
|
||||||
'currentImageUuid',
|
|
||||||
'shouldAutoSwitchToNewImages',
|
'shouldAutoSwitchToNewImages',
|
||||||
'intermediateImage',
|
];
|
||||||
|
|
||||||
|
export const galleryPersistDenylist: (keyof GalleryState)[] = [
|
||||||
|
'currentCategory',
|
||||||
|
'shouldAutoSwitchToNewImages',
|
||||||
];
|
];
|
||||||
|
|
||||||
export const galleryDenylist = itemsToDenylist.map(
|
export const galleryDenylist = itemsToDenylist.map(
|
||||||
|
@ -1,23 +1,14 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { lightboxSelector } from 'features/lightbox/store/lightboxSelectors';
|
import { lightboxSelector } from 'features/lightbox/store/lightboxSelectors';
|
||||||
import { configSelector } from 'features/system/store/configSelectors';
|
|
||||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
|
||||||
import {
|
import {
|
||||||
activeTabNameSelector,
|
activeTabNameSelector,
|
||||||
uiSelector,
|
uiSelector,
|
||||||
} from 'features/ui/store/uiSelectors';
|
} from 'features/ui/store/uiSelectors';
|
||||||
import { isEqual } from 'lodash-es';
|
import { isEqual } from 'lodash-es';
|
||||||
import {
|
import { selectResultsById, selectResultsEntities } from './resultsSlice';
|
||||||
selectResultsAll,
|
import { selectUploadsAll, selectUploadsById } from './uploadsSlice';
|
||||||
selectResultsById,
|
|
||||||
selectResultsEntities,
|
|
||||||
} from './resultsSlice';
|
|
||||||
import {
|
|
||||||
selectUploadsAll,
|
|
||||||
selectUploadsById,
|
|
||||||
selectUploadsEntities,
|
|
||||||
} from './uploadsSlice';
|
|
||||||
|
|
||||||
export const gallerySelector = (state: RootState) => state.gallery;
|
export const gallerySelector = (state: RootState) => state.gallery;
|
||||||
|
|
||||||
@ -44,6 +35,11 @@ export const imageGallerySelector = createSelector(
|
|||||||
|
|
||||||
const { isLightboxOpen } = lightbox;
|
const { isLightboxOpen } = lightbox;
|
||||||
|
|
||||||
|
const images =
|
||||||
|
currentCategory === 'results'
|
||||||
|
? selectResultsEntities(state)
|
||||||
|
: selectUploadsAll(state);
|
||||||
|
|
||||||
return {
|
return {
|
||||||
shouldPinGallery,
|
shouldPinGallery,
|
||||||
galleryImageMinimumWidth,
|
galleryImageMinimumWidth,
|
||||||
@ -53,7 +49,7 @@ export const imageGallerySelector = createSelector(
|
|||||||
: `repeat(auto-fill, minmax(${galleryImageMinimumWidth}px, auto))`,
|
: `repeat(auto-fill, minmax(${galleryImageMinimumWidth}px, auto))`,
|
||||||
shouldAutoSwitchToNewImages,
|
shouldAutoSwitchToNewImages,
|
||||||
currentCategory,
|
currentCategory,
|
||||||
images: state[currentCategory].entities,
|
images,
|
||||||
galleryWidth,
|
galleryWidth,
|
||||||
shouldEnableResize:
|
shouldEnableResize:
|
||||||
isLightboxOpen ||
|
isLightboxOpen ||
|
||||||
|
@ -1,10 +1,11 @@
|
|||||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||||
import { createSlice } from '@reduxjs/toolkit';
|
import { createSlice } from '@reduxjs/toolkit';
|
||||||
import { invocationComplete } from 'services/events/actions';
|
import { Image } from 'app/types/invokeai';
|
||||||
import { isImageOutput } from 'services/types/guards';
|
import { imageReceived, thumbnailReceived } from 'services/thunks/image';
|
||||||
import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
|
import {
|
||||||
import { imageUploaded } from 'services/thunks/image';
|
receivedResultImagesPage,
|
||||||
import { SelectedImage } from 'features/parameters/store/generationSlice';
|
receivedUploadImagesPage,
|
||||||
|
} from '../../../services/thunks/gallery';
|
||||||
|
|
||||||
type GalleryImageObjectFitType = 'contain' | 'cover';
|
type GalleryImageObjectFitType = 'contain' | 'cover';
|
||||||
|
|
||||||
@ -12,7 +13,7 @@ export interface GalleryState {
|
|||||||
/**
|
/**
|
||||||
* The selected image
|
* The selected image
|
||||||
*/
|
*/
|
||||||
selectedImage?: SelectedImage;
|
selectedImage?: Image;
|
||||||
galleryImageMinimumWidth: number;
|
galleryImageMinimumWidth: number;
|
||||||
galleryImageObjectFit: GalleryImageObjectFitType;
|
galleryImageObjectFit: GalleryImageObjectFitType;
|
||||||
shouldAutoSwitchToNewImages: boolean;
|
shouldAutoSwitchToNewImages: boolean;
|
||||||
@ -21,8 +22,7 @@ export interface GalleryState {
|
|||||||
currentCategory: 'results' | 'uploads';
|
currentCategory: 'results' | 'uploads';
|
||||||
}
|
}
|
||||||
|
|
||||||
const initialState: GalleryState = {
|
export const initialGalleryState: GalleryState = {
|
||||||
selectedImage: undefined,
|
|
||||||
galleryImageMinimumWidth: 64,
|
galleryImageMinimumWidth: 64,
|
||||||
galleryImageObjectFit: 'cover',
|
galleryImageObjectFit: 'cover',
|
||||||
shouldAutoSwitchToNewImages: true,
|
shouldAutoSwitchToNewImages: true,
|
||||||
@ -33,12 +33,9 @@ const initialState: GalleryState = {
|
|||||||
|
|
||||||
export const gallerySlice = createSlice({
|
export const gallerySlice = createSlice({
|
||||||
name: 'gallery',
|
name: 'gallery',
|
||||||
initialState,
|
initialState: initialGalleryState,
|
||||||
reducers: {
|
reducers: {
|
||||||
imageSelected: (
|
imageSelected: (state, action: PayloadAction<Image | undefined>) => {
|
||||||
state,
|
|
||||||
action: PayloadAction<SelectedImage | undefined>
|
|
||||||
) => {
|
|
||||||
state.selectedImage = action.payload;
|
state.selectedImage = action.payload;
|
||||||
// TODO: if the user selects an image, disable the auto switch?
|
// TODO: if the user selects an image, disable the auto switch?
|
||||||
// state.shouldAutoSwitchToNewImages = false;
|
// state.shouldAutoSwitchToNewImages = false;
|
||||||
@ -72,27 +69,50 @@ export const gallerySlice = createSlice({
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
extraReducers(builder) {
|
extraReducers(builder) {
|
||||||
/**
|
builder.addCase(imageReceived.fulfilled, (state, action) => {
|
||||||
* Invocation Complete
|
// When we get an updated URL for an image, we need to update the selectedImage in gallery,
|
||||||
*/
|
// which is currently its own object (instead of a reference to an image in results/uploads)
|
||||||
builder.addCase(invocationComplete, (state, action) => {
|
const { imagePath } = action.payload;
|
||||||
const { data } = action.payload;
|
const { imageName } = action.meta.arg;
|
||||||
if (isImageOutput(data.result) && state.shouldAutoSwitchToNewImages) {
|
|
||||||
state.selectedImage = {
|
if (state.selectedImage?.name === imageName) {
|
||||||
name: data.result.image.image_name,
|
state.selectedImage.url = imagePath;
|
||||||
type: 'results',
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
/**
|
builder.addCase(thumbnailReceived.fulfilled, (state, action) => {
|
||||||
* Upload Image - FULFILLED
|
// When we get an updated URL for an image, we need to update the selectedImage in gallery,
|
||||||
*/
|
// which is currently its own object (instead of a reference to an image in results/uploads)
|
||||||
builder.addCase(imageUploaded.fulfilled, (state, action) => {
|
const { thumbnailPath } = action.payload;
|
||||||
const { response } = action.payload;
|
const { thumbnailName } = action.meta.arg;
|
||||||
|
|
||||||
const uploadedImage = deserializeImageResponse(response);
|
if (state.selectedImage?.name === thumbnailName) {
|
||||||
state.selectedImage = { name: uploadedImage.name, type: 'uploads' };
|
state.selectedImage.thumbnail = thumbnailPath;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
builder.addCase(receivedResultImagesPage.fulfilled, (state, action) => {
|
||||||
|
// rehydrate selectedImage URL when results list comes in
|
||||||
|
// solves case when outdated URL is in local storage
|
||||||
|
if (state.selectedImage) {
|
||||||
|
const selectedImageInResults = action.payload.items.find(
|
||||||
|
(image) => image.image_name === state.selectedImage!.name
|
||||||
|
);
|
||||||
|
if (selectedImageInResults) {
|
||||||
|
state.selectedImage.url = selectedImageInResults.image_url;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
builder.addCase(receivedUploadImagesPage.fulfilled, (state, action) => {
|
||||||
|
// rehydrate selectedImage URL when results list comes in
|
||||||
|
// solves case when outdated URL is in local storage
|
||||||
|
if (state.selectedImage) {
|
||||||
|
const selectedImageInResults = action.payload.items.find(
|
||||||
|
(image) => image.image_name === state.selectedImage!.name
|
||||||
|
);
|
||||||
|
if (selectedImageInResults) {
|
||||||
|
state.selectedImage.url = selectedImageInResults.image_url;
|
||||||
|
}
|
||||||
|
}
|
||||||
});
|
});
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
@ -5,7 +5,9 @@ import { ResultsState } from './resultsSlice';
|
|||||||
*
|
*
|
||||||
* Currently denylisting results slice entirely, see persist config in store.ts
|
* Currently denylisting results slice entirely, see persist config in store.ts
|
||||||
*/
|
*/
|
||||||
const itemsToDenylist: (keyof ResultsState)[] = ['isLoading'];
|
const itemsToDenylist: (keyof ResultsState)[] = [];
|
||||||
|
|
||||||
|
export const resultsPersistDenylist: (keyof ResultsState)[] = [];
|
||||||
|
|
||||||
export const resultsDenylist = itemsToDenylist.map(
|
export const resultsDenylist = itemsToDenylist.map(
|
||||||
(denylistItem) => `results.${denylistItem}`
|
(denylistItem) => `results.${denylistItem}`
|
||||||
|
@ -1,17 +1,11 @@
|
|||||||
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
|
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
|
||||||
import { Image } from 'app/types/invokeai';
|
import { Image } from 'app/types/invokeai';
|
||||||
import { invocationComplete } from 'services/events/actions';
|
|
||||||
|
|
||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import {
|
import {
|
||||||
receivedResultImagesPage,
|
receivedResultImagesPage,
|
||||||
IMAGES_PER_PAGE,
|
IMAGES_PER_PAGE,
|
||||||
} from 'services/thunks/gallery';
|
} from 'services/thunks/gallery';
|
||||||
import { isImageOutput } from 'services/types/guards';
|
|
||||||
import {
|
|
||||||
buildImageUrls,
|
|
||||||
extractTimestampFromImageName,
|
|
||||||
} from 'services/util/deserializeImageField';
|
|
||||||
import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
|
import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
|
||||||
import {
|
import {
|
||||||
imageDeleted,
|
imageDeleted,
|
||||||
@ -73,44 +67,6 @@ const resultsSlice = createSlice({
|
|||||||
state.isLoading = false;
|
state.isLoading = false;
|
||||||
});
|
});
|
||||||
|
|
||||||
/**
|
|
||||||
* Invocation Complete
|
|
||||||
*/
|
|
||||||
builder.addCase(invocationComplete, (state, action) => {
|
|
||||||
const { data, shouldFetchImages } = action.payload;
|
|
||||||
const { result, node, graph_execution_state_id } = data;
|
|
||||||
|
|
||||||
if (isImageOutput(result)) {
|
|
||||||
const name = result.image.image_name;
|
|
||||||
const type = result.image.image_type;
|
|
||||||
|
|
||||||
// if we need to refetch, set URLs to placeholder for now
|
|
||||||
const { url, thumbnail } = shouldFetchImages
|
|
||||||
? { url: '', thumbnail: '' }
|
|
||||||
: buildImageUrls(type, name);
|
|
||||||
|
|
||||||
const timestamp = extractTimestampFromImageName(name);
|
|
||||||
|
|
||||||
const image: Image = {
|
|
||||||
name,
|
|
||||||
type,
|
|
||||||
url,
|
|
||||||
thumbnail,
|
|
||||||
metadata: {
|
|
||||||
created: timestamp,
|
|
||||||
width: result.width,
|
|
||||||
height: result.height,
|
|
||||||
invokeai: {
|
|
||||||
session_id: graph_execution_state_id,
|
|
||||||
...(node ? { node } : {}),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
resultsAdapter.setOne(state, image);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Image Received - FULFILLED
|
* Image Received - FULFILLED
|
||||||
*/
|
*/
|
||||||
@ -142,9 +98,10 @@ const resultsSlice = createSlice({
|
|||||||
});
|
});
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Delete Image - FULFILLED
|
* Delete Image - PENDING
|
||||||
|
* Pre-emptively remove the image from the gallery
|
||||||
*/
|
*/
|
||||||
builder.addCase(imageDeleted.fulfilled, (state, action) => {
|
builder.addCase(imageDeleted.pending, (state, action) => {
|
||||||
const { imageType, imageName } = action.meta.arg;
|
const { imageType, imageName } = action.meta.arg;
|
||||||
|
|
||||||
if (imageType === 'results') {
|
if (imageType === 'results') {
|
||||||
|
@ -5,7 +5,8 @@ import { UploadsState } from './uploadsSlice';
|
|||||||
*
|
*
|
||||||
* Currently denylisting uploads slice entirely, see persist config in store.ts
|
* Currently denylisting uploads slice entirely, see persist config in store.ts
|
||||||
*/
|
*/
|
||||||
const itemsToDenylist: (keyof UploadsState)[] = ['isLoading'];
|
const itemsToDenylist: (keyof UploadsState)[] = [];
|
||||||
|
export const uploadsPersistDenylist: (keyof UploadsState)[] = [];
|
||||||
|
|
||||||
export const uploadsDenylist = itemsToDenylist.map(
|
export const uploadsDenylist = itemsToDenylist.map(
|
||||||
(denylistItem) => `uploads.${denylistItem}`
|
(denylistItem) => `uploads.${denylistItem}`
|
||||||
|
@ -6,7 +6,7 @@ import {
|
|||||||
receivedUploadImagesPage,
|
receivedUploadImagesPage,
|
||||||
IMAGES_PER_PAGE,
|
IMAGES_PER_PAGE,
|
||||||
} from 'services/thunks/gallery';
|
} from 'services/thunks/gallery';
|
||||||
import { imageDeleted, imageUploaded } from 'services/thunks/image';
|
import { imageDeleted } from 'services/thunks/image';
|
||||||
import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
|
import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
|
||||||
|
|
||||||
export const uploadsAdapter = createEntityAdapter<Image>({
|
export const uploadsAdapter = createEntityAdapter<Image>({
|
||||||
@ -21,7 +21,7 @@ type AdditionalUploadsState = {
|
|||||||
nextPage: number;
|
nextPage: number;
|
||||||
};
|
};
|
||||||
|
|
||||||
const initialUploadsState =
|
export const initialUploadsState =
|
||||||
uploadsAdapter.getInitialState<AdditionalUploadsState>({
|
uploadsAdapter.getInitialState<AdditionalUploadsState>({
|
||||||
page: 0,
|
page: 0,
|
||||||
pages: 0,
|
pages: 0,
|
||||||
@ -35,7 +35,7 @@ const uploadsSlice = createSlice({
|
|||||||
name: 'uploads',
|
name: 'uploads',
|
||||||
initialState: initialUploadsState,
|
initialState: initialUploadsState,
|
||||||
reducers: {
|
reducers: {
|
||||||
uploadAdded: uploadsAdapter.addOne,
|
uploadAdded: uploadsAdapter.upsertOne,
|
||||||
},
|
},
|
||||||
extraReducers: (builder) => {
|
extraReducers: (builder) => {
|
||||||
/**
|
/**
|
||||||
@ -62,20 +62,10 @@ const uploadsSlice = createSlice({
|
|||||||
});
|
});
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Upload Image - FULFILLED
|
* Delete Image - pending
|
||||||
|
* Pre-emptively remove the image from the gallery
|
||||||
*/
|
*/
|
||||||
builder.addCase(imageUploaded.fulfilled, (state, action) => {
|
builder.addCase(imageDeleted.pending, (state, action) => {
|
||||||
const { location, response } = action.payload;
|
|
||||||
|
|
||||||
const uploadedImage = deserializeImageResponse(response);
|
|
||||||
|
|
||||||
uploadsAdapter.setOne(state, uploadedImage);
|
|
||||||
});
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Delete Image - FULFILLED
|
|
||||||
*/
|
|
||||||
builder.addCase(imageDeleted.fulfilled, (state, action) => {
|
|
||||||
const { imageType, imageName } = action.meta.arg;
|
const { imageType, imageName } = action.meta.arg;
|
||||||
|
|
||||||
if (imageType === 'uploads') {
|
if (imageType === 'uploads') {
|
||||||
|
@ -4,7 +4,7 @@ import * as InvokeAI from 'app/types/invokeai';
|
|||||||
import { useGetUrl } from 'common/util/getUrl';
|
import { useGetUrl } from 'common/util/getUrl';
|
||||||
|
|
||||||
type ReactPanZoomProps = {
|
type ReactPanZoomProps = {
|
||||||
image: InvokeAI._Image;
|
image: InvokeAI.Image;
|
||||||
styleClass?: string;
|
styleClass?: string;
|
||||||
alt?: string;
|
alt?: string;
|
||||||
ref?: React.Ref<HTMLImageElement>;
|
ref?: React.Ref<HTMLImageElement>;
|
||||||
|
@ -4,6 +4,9 @@ import { LightboxState } from './lightboxSlice';
|
|||||||
* Lightbox slice persist denylist
|
* Lightbox slice persist denylist
|
||||||
*/
|
*/
|
||||||
const itemsToDenylist: (keyof LightboxState)[] = ['isLightboxOpen'];
|
const itemsToDenylist: (keyof LightboxState)[] = ['isLightboxOpen'];
|
||||||
|
export const lightboxPersistDenylist: (keyof LightboxState)[] = [
|
||||||
|
'isLightboxOpen',
|
||||||
|
];
|
||||||
|
|
||||||
export const lightboxDenylist = itemsToDenylist.map(
|
export const lightboxDenylist = itemsToDenylist.map(
|
||||||
(denylistItem) => `lightbox.${denylistItem}`
|
(denylistItem) => `lightbox.${denylistItem}`
|
||||||
|
@ -5,7 +5,7 @@ export interface LightboxState {
|
|||||||
isLightboxOpen: boolean;
|
isLightboxOpen: boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
const initialLightboxState: LightboxState = {
|
export const initialLightboxState: LightboxState = {
|
||||||
isLightboxOpen: false,
|
isLightboxOpen: false,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -1,5 +1,3 @@
|
|||||||
import { v4 as uuidv4 } from 'uuid';
|
|
||||||
|
|
||||||
import 'reactflow/dist/style.css';
|
import 'reactflow/dist/style.css';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import {
|
import {
|
||||||
@ -8,12 +6,11 @@ import {
|
|||||||
MenuButton,
|
MenuButton,
|
||||||
MenuList,
|
MenuList,
|
||||||
MenuItem,
|
MenuItem,
|
||||||
IconButton,
|
|
||||||
} from '@chakra-ui/react';
|
} from '@chakra-ui/react';
|
||||||
import { FaEllipsisV, FaPlus } from 'react-icons/fa';
|
import { FaEllipsisV } from 'react-icons/fa';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { nodeAdded } from '../store/nodesSlice';
|
import { nodeAdded } from '../store/nodesSlice';
|
||||||
import { cloneDeep, map } from 'lodash-es';
|
import { map } from 'lodash-es';
|
||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { useBuildInvocation } from '../hooks/useBuildInvocation';
|
import { useBuildInvocation } from '../hooks/useBuildInvocation';
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user