diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
index 046e8f0c57..ed3ab0375c 100644
--- a/.github/CODEOWNERS
+++ b/.github/CODEOWNERS
@@ -2,7 +2,7 @@
/.github/workflows/ @lstein @blessedcoolant
# documentation
-/docs/ @lstein @tildebyte @blessedcoolant
+/docs/ @lstein @blessedcoolant @hipsterusername
/mkdocs.yml @lstein @blessedcoolant
# nodes
@@ -18,17 +18,17 @@
/invokeai/version @lstein @blessedcoolant
# web ui
-/invokeai/frontend @blessedcoolant @psychedelicious @lstein
-/invokeai/backend @blessedcoolant @psychedelicious @lstein
+/invokeai/frontend @blessedcoolant @psychedelicious @lstein @maryhipp
+/invokeai/backend @blessedcoolant @psychedelicious @lstein @maryhipp
# generation, model management, postprocessing
-/invokeai/backend @damian0815 @lstein @blessedcoolant @jpphoto @gregghelt2
+/invokeai/backend @damian0815 @lstein @blessedcoolant @jpphoto @gregghelt2 @StAlKeR7779
# front ends
/invokeai/frontend/CLI @lstein
/invokeai/frontend/install @lstein @ebr
-/invokeai/frontend/merge @lstein @blessedcoolant @hipsterusername
-/invokeai/frontend/training @lstein @blessedcoolant @hipsterusername
-/invokeai/frontend/web @psychedelicious @blessedcoolant
+/invokeai/frontend/merge @lstein @blessedcoolant
+/invokeai/frontend/training @lstein @blessedcoolant
+/invokeai/frontend/web @psychedelicious @blessedcoolant @maryhipp
diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py
index 1ea365eb50..d2fcf0a364 100644
--- a/invokeai/app/api/dependencies.py
+++ b/invokeai/app/api/dependencies.py
@@ -5,6 +5,7 @@ import os
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService
from invokeai.app.services.metadata import CoreMetadataService
+from invokeai.app.services.resource_name import SimpleNameService
from invokeai.app.services.urls import LocalUrlService
from invokeai.backend.util.logging import InvokeAILogger
@@ -67,7 +68,7 @@ class ApiDependencies:
metadata = CoreMetadataService()
image_record_storage = SqliteImageRecordStorage(db_location)
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
-
+ names = SimpleNameService()
latents = ForwardCacheLatentsStorage(
DiskLatentsStorage(f"{output_folder}/latents")
)
@@ -78,6 +79,7 @@ class ApiDependencies:
metadata=metadata,
url=urls,
logger=logger,
+ names=names,
graph_execution_manager=graph_execution_manager,
)
diff --git a/invokeai/app/api/models/images.py b/invokeai/app/api/models/images.py
deleted file mode 100644
index fa04702326..0000000000
--- a/invokeai/app/api/models/images.py
+++ /dev/null
@@ -1,39 +0,0 @@
-from typing import Optional
-from pydantic import BaseModel, Field
-
-from invokeai.app.models.image import ImageType
-
-
-class ImageResponseMetadata(BaseModel):
- """An image's metadata. Used only in HTTP responses."""
-
- created: int = Field(description="The creation timestamp of the image")
- width: int = Field(description="The width of the image in pixels")
- height: int = Field(description="The height of the image in pixels")
- # invokeai: Optional[InvokeAIMetadata] = Field(
- # description="The image's InvokeAI-specific metadata"
- # )
-
-
-class ImageResponse(BaseModel):
- """The response type for images"""
-
- image_type: ImageType = Field(description="The type of the image")
- image_name: str = Field(description="The name of the image")
- image_url: str = Field(description="The url of the image")
- thumbnail_url: str = Field(description="The url of the image's thumbnail")
- metadata: ImageResponseMetadata = Field(description="The image's metadata")
-
-
-class ProgressImage(BaseModel):
- """The progress image sent intermittently during processing"""
-
- width: int = Field(description="The effective width of the image in pixels")
- height: int = Field(description="The effective height of the image in pixels")
- dataURL: str = Field(description="The image data as a b64 data URL")
-
-
-class SavedImage(BaseModel):
- image_name: str = Field(description="The name of the saved image")
- thumbnail_name: str = Field(description="The name of the saved thumbnail")
- created: int = Field(description="The created timestamp of the saved image")
diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py
index 0615ff187e..ae10cce140 100644
--- a/invokeai/app/api/routers/images.py
+++ b/invokeai/app/api/routers/images.py
@@ -1,13 +1,19 @@
import io
-from fastapi import HTTPException, Path, Query, Request, Response, UploadFile
+from typing import Optional
+from fastapi import Body, HTTPException, Path, Query, Request, Response, UploadFile
from fastapi.routing import APIRouter
from fastapi.responses import FileResponse
from PIL import Image
from invokeai.app.models.image import (
ImageCategory,
- ImageType,
+ ResourceOrigin,
+)
+from invokeai.app.services.image_record_storage import OffsetPaginatedResults
+from invokeai.app.services.models.image_record import (
+ ImageDTO,
+ ImageRecordChanges,
+ ImageUrlsDTO,
)
-from invokeai.app.services.models.image_record import ImageDTO, ImageUrlsDTO
from invokeai.app.services.item_storage import PaginatedResults
from ..dependencies import ApiDependencies
@@ -27,10 +33,13 @@ images_router = APIRouter(prefix="/v1/images", tags=["images"])
)
async def upload_image(
file: UploadFile,
- image_type: ImageType,
request: Request,
response: Response,
- image_category: ImageCategory = ImageCategory.GENERAL,
+ image_category: ImageCategory = Query(description="The category of the image"),
+ is_intermediate: bool = Query(description="Whether this is an intermediate image"),
+ session_id: Optional[str] = Query(
+ default=None, description="The session ID associated with this upload, if any"
+ ),
) -> ImageDTO:
"""Uploads an image"""
if not file.content_type.startswith("image"):
@@ -46,9 +55,11 @@ async def upload_image(
try:
image_dto = ApiDependencies.invoker.services.images.create(
- pil_image,
- image_type,
- image_category,
+ image=pil_image,
+ image_origin=ResourceOrigin.EXTERNAL,
+ image_category=image_category,
+ session_id=session_id,
+ is_intermediate=is_intermediate,
)
response.status_code = 201
@@ -59,41 +70,61 @@ async def upload_image(
raise HTTPException(status_code=500, detail="Failed to create image")
-@images_router.delete("/{image_type}/{image_name}", operation_id="delete_image")
+@images_router.delete("/{image_origin}/{image_name}", operation_id="delete_image")
async def delete_image(
- image_type: ImageType = Query(description="The type of image to delete"),
+ image_origin: ResourceOrigin = Path(description="The origin of image to delete"),
image_name: str = Path(description="The name of the image to delete"),
) -> None:
"""Deletes an image"""
try:
- ApiDependencies.invoker.services.images.delete(image_type, image_name)
+ ApiDependencies.invoker.services.images.delete(image_origin, image_name)
except Exception as e:
# TODO: Does this need any exception handling at all?
pass
+@images_router.patch(
+ "/{image_origin}/{image_name}",
+ operation_id="update_image",
+ response_model=ImageDTO,
+)
+async def update_image(
+ image_origin: ResourceOrigin = Path(description="The origin of image to update"),
+ image_name: str = Path(description="The name of the image to update"),
+ image_changes: ImageRecordChanges = Body(
+ description="The changes to apply to the image"
+ ),
+) -> ImageDTO:
+ """Updates an image"""
+
+ try:
+ return ApiDependencies.invoker.services.images.update(
+ image_origin, image_name, image_changes
+ )
+ except Exception as e:
+ raise HTTPException(status_code=400, detail="Failed to update image")
+
+
@images_router.get(
- "/{image_type}/{image_name}/metadata",
+ "/{image_origin}/{image_name}/metadata",
operation_id="get_image_metadata",
response_model=ImageDTO,
)
async def get_image_metadata(
- image_type: ImageType = Path(description="The type of image to get"),
+ image_origin: ResourceOrigin = Path(description="The origin of image to get"),
image_name: str = Path(description="The name of image to get"),
) -> ImageDTO:
"""Gets an image's metadata"""
try:
- return ApiDependencies.invoker.services.images.get_dto(
- image_type, image_name
- )
+ return ApiDependencies.invoker.services.images.get_dto(image_origin, image_name)
except Exception as e:
raise HTTPException(status_code=404)
@images_router.get(
- "/{image_type}/{image_name}",
+ "/{image_origin}/{image_name}",
operation_id="get_image_full",
response_class=Response,
responses={
@@ -105,7 +136,7 @@ async def get_image_metadata(
},
)
async def get_image_full(
- image_type: ImageType = Path(
+ image_origin: ResourceOrigin = Path(
description="The type of full-resolution image file to get"
),
image_name: str = Path(description="The name of full-resolution image file to get"),
@@ -113,9 +144,7 @@ async def get_image_full(
"""Gets a full-resolution image file"""
try:
- path = ApiDependencies.invoker.services.images.get_path(
- image_type, image_name
- )
+ path = ApiDependencies.invoker.services.images.get_path(image_origin, image_name)
if not ApiDependencies.invoker.services.images.validate_path(path):
raise HTTPException(status_code=404)
@@ -131,7 +160,7 @@ async def get_image_full(
@images_router.get(
- "/{image_type}/{image_name}/thumbnail",
+ "/{image_origin}/{image_name}/thumbnail",
operation_id="get_image_thumbnail",
response_class=Response,
responses={
@@ -143,14 +172,14 @@ async def get_image_full(
},
)
async def get_image_thumbnail(
- image_type: ImageType = Path(description="The type of thumbnail image file to get"),
+ image_origin: ResourceOrigin = Path(description="The origin of thumbnail image file to get"),
image_name: str = Path(description="The name of thumbnail image file to get"),
) -> FileResponse:
"""Gets a thumbnail image file"""
try:
path = ApiDependencies.invoker.services.images.get_path(
- image_type, image_name, thumbnail=True
+ image_origin, image_name, thumbnail=True
)
if not ApiDependencies.invoker.services.images.validate_path(path):
raise HTTPException(status_code=404)
@@ -163,25 +192,25 @@ async def get_image_thumbnail(
@images_router.get(
- "/{image_type}/{image_name}/urls",
+ "/{image_origin}/{image_name}/urls",
operation_id="get_image_urls",
response_model=ImageUrlsDTO,
)
async def get_image_urls(
- image_type: ImageType = Path(description="The type of the image whose URL to get"),
+ image_origin: ResourceOrigin = Path(description="The origin of the image whose URL to get"),
image_name: str = Path(description="The name of the image whose URL to get"),
) -> ImageUrlsDTO:
"""Gets an image and thumbnail URL"""
try:
image_url = ApiDependencies.invoker.services.images.get_url(
- image_type, image_name
+ image_origin, image_name
)
thumbnail_url = ApiDependencies.invoker.services.images.get_url(
- image_type, image_name, thumbnail=True
+ image_origin, image_name, thumbnail=True
)
return ImageUrlsDTO(
- image_type=image_type,
+ image_origin=image_origin,
image_name=image_name,
image_url=image_url,
thumbnail_url=thumbnail_url,
@@ -193,23 +222,29 @@ async def get_image_urls(
@images_router.get(
"/",
operation_id="list_images_with_metadata",
- response_model=PaginatedResults[ImageDTO],
+ response_model=OffsetPaginatedResults[ImageDTO],
)
async def list_images_with_metadata(
- image_type: ImageType = Query(description="The type of images to list"),
- image_category: ImageCategory = Query(description="The kind of images to list"),
- page: int = Query(default=0, description="The page of image metadata to get"),
- per_page: int = Query(
- default=10, description="The number of image metadata per page"
+ image_origin: Optional[ResourceOrigin] = Query(
+ default=None, description="The origin of images to list"
),
-) -> PaginatedResults[ImageDTO]:
- """Gets a list of images with metadata"""
+ categories: Optional[list[ImageCategory]] = Query(
+ default=None, description="The categories of image to include"
+ ),
+ is_intermediate: Optional[bool] = Query(
+ default=None, description="Whether to list intermediate images"
+ ),
+ offset: int = Query(default=0, description="The page offset"),
+ limit: int = Query(default=10, description="The number of images per page"),
+) -> OffsetPaginatedResults[ImageDTO]:
+ """Gets a list of images"""
image_dtos = ApiDependencies.invoker.services.images.get_many(
- image_type,
- image_category,
- page,
- per_page,
+ offset,
+ limit,
+ image_origin,
+ categories,
+ is_intermediate,
)
return image_dtos
diff --git a/invokeai/app/cli_app.py b/invokeai/app/cli_app.py
index c6267a6871..15afe65a8a 100644
--- a/invokeai/app/cli_app.py
+++ b/invokeai/app/cli_app.py
@@ -12,11 +12,10 @@ from pydantic import BaseModel, ValidationError
from pydantic.fields import Field
import invokeai.backend.util.logging as logger
-import invokeai.version
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService
-from invokeai.app.services.metadata import (CoreMetadataService,
- PngMetadataService)
+from invokeai.app.services.metadata import CoreMetadataService
+from invokeai.app.services.resource_name import SimpleNameService
from invokeai.app.services.urls import LocalUrlService
from .cli.commands import (BaseCommand, CliContext, ExitCli,
@@ -232,6 +231,7 @@ def invoke_cli():
metadata = CoreMetadataService()
image_record_storage = SqliteImageRecordStorage(db_location)
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
+ names = SimpleNameService()
images = ImageService(
image_record_storage=image_record_storage,
@@ -239,6 +239,7 @@ def invoke_cli():
metadata=metadata,
url=urls,
logger=logger,
+ names=names,
graph_execution_manager=graph_execution_manager,
)
diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py
index da61641105..4ce3e839b6 100644
--- a/invokeai/app/invocations/baseinvocation.py
+++ b/invokeai/app/invocations/baseinvocation.py
@@ -78,6 +78,7 @@ class BaseInvocation(ABC, BaseModel):
#fmt: off
id: str = Field(description="The id of this node. Must be unique among all nodes.")
+ is_intermediate: bool = Field(default=False, description="Whether or not this node is an intermediate node.")
#fmt: on
@@ -95,6 +96,7 @@ class UIConfig(TypedDict, total=False):
"image",
"latents",
"model",
+ "control",
],
]
tags: List[str]
diff --git a/invokeai/app/invocations/collections.py b/invokeai/app/invocations/collections.py
index 475b6028a9..891f217317 100644
--- a/invokeai/app/invocations/collections.py
+++ b/invokeai/app/invocations/collections.py
@@ -22,6 +22,14 @@ class IntCollectionOutput(BaseInvocationOutput):
# Outputs
collection: list[int] = Field(default=[], description="The int collection")
+class FloatCollectionOutput(BaseInvocationOutput):
+ """A collection of floats"""
+
+ type: Literal["float_collection"] = "float_collection"
+
+ # Outputs
+ collection: list[float] = Field(default=[], description="The float collection")
+
class RangeInvocation(BaseInvocation):
"""Creates a range of numbers from start to stop with step"""
diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py
new file mode 100644
index 0000000000..7d5160a491
--- /dev/null
+++ b/invokeai/app/invocations/controlnet_image_processors.py
@@ -0,0 +1,428 @@
+# InvokeAI nodes for ControlNet image preprocessors
+# initial implementation by Gregg Helt, 2023
+# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
+
+import numpy as np
+from typing import Literal, Optional, Union, List
+from PIL import Image, ImageFilter, ImageOps
+from pydantic import BaseModel, Field
+
+from ..models.image import ImageField, ImageCategory, ResourceOrigin
+from .baseinvocation import (
+ BaseInvocation,
+ BaseInvocationOutput,
+ InvocationContext,
+ InvocationConfig,
+)
+from controlnet_aux import (
+ CannyDetector,
+ HEDdetector,
+ LineartDetector,
+ LineartAnimeDetector,
+ MidasDetector,
+ MLSDdetector,
+ NormalBaeDetector,
+ OpenposeDetector,
+ PidiNetDetector,
+ ContentShuffleDetector,
+ ZoeDetector,
+ MediapipeFaceDetector,
+)
+
+from .image import ImageOutput, PILInvocationConfig
+
+CONTROLNET_DEFAULT_MODELS = [
+ ###########################################
+ # lllyasviel sd v1.5, ControlNet v1.0 models
+ ##############################################
+ "lllyasviel/sd-controlnet-canny",
+ "lllyasviel/sd-controlnet-depth",
+ "lllyasviel/sd-controlnet-hed",
+ "lllyasviel/sd-controlnet-seg",
+ "lllyasviel/sd-controlnet-openpose",
+ "lllyasviel/sd-controlnet-scribble",
+ "lllyasviel/sd-controlnet-normal",
+ "lllyasviel/sd-controlnet-mlsd",
+
+ #############################################
+ # lllyasviel sd v1.5, ControlNet v1.1 models
+ #############################################
+ "lllyasviel/control_v11p_sd15_canny",
+ "lllyasviel/control_v11p_sd15_openpose",
+ "lllyasviel/control_v11p_sd15_seg",
+ # "lllyasviel/control_v11p_sd15_depth", # broken
+ "lllyasviel/control_v11f1p_sd15_depth",
+ "lllyasviel/control_v11p_sd15_normalbae",
+ "lllyasviel/control_v11p_sd15_scribble",
+ "lllyasviel/control_v11p_sd15_mlsd",
+ "lllyasviel/control_v11p_sd15_softedge",
+ "lllyasviel/control_v11p_sd15s2_lineart_anime",
+ "lllyasviel/control_v11p_sd15_lineart",
+ "lllyasviel/control_v11p_sd15_inpaint",
+ # "lllyasviel/control_v11u_sd15_tile",
+    # problem (temporary?) with huggingface "lllyasviel/control_v11u_sd15_tile",
+ # so for now replace "lllyasviel/control_v11f1e_sd15_tile",
+ "lllyasviel/control_v11e_sd15_shuffle",
+ "lllyasviel/control_v11e_sd15_ip2p",
+ "lllyasviel/control_v11f1e_sd15_tile",
+
+ #################################################
+    # thibaud sd v2.1 models (ControlNet v1.0? or v1.1?)
+ ##################################################
+ "thibaud/controlnet-sd21-openpose-diffusers",
+ "thibaud/controlnet-sd21-canny-diffusers",
+ "thibaud/controlnet-sd21-depth-diffusers",
+ "thibaud/controlnet-sd21-scribble-diffusers",
+ "thibaud/controlnet-sd21-hed-diffusers",
+ "thibaud/controlnet-sd21-zoedepth-diffusers",
+ "thibaud/controlnet-sd21-color-diffusers",
+ "thibaud/controlnet-sd21-openposev2-diffusers",
+ "thibaud/controlnet-sd21-lineart-diffusers",
+ "thibaud/controlnet-sd21-normalbae-diffusers",
+ "thibaud/controlnet-sd21-ade20k-diffusers",
+
+ ##############################################
+    # ControlNetMediaPipeFace, ControlNet v1.1
+ ##############################################
+ # ["CrucibleAI/ControlNetMediaPipeFace", "diffusion_sd15"], # SD 1.5
+ # diffusion_sd15 needs to be passed to from_pretrained() as subfolder arg
+ # hacked t2l to split to model & subfolder if format is "model,subfolder"
+ "CrucibleAI/ControlNetMediaPipeFace,diffusion_sd15", # SD 1.5
+ "CrucibleAI/ControlNetMediaPipeFace", # SD 2.1?
+]
+
+CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
+
+class ControlField(BaseModel):
+ image: ImageField = Field(default=None, description="processed image")
+ control_model: Optional[str] = Field(default=None, description="control model used")
+ control_weight: Optional[float] = Field(default=1, description="weight given to controlnet")
+ begin_step_percent: float = Field(default=0, ge=0, le=1,
+ description="% of total steps at which controlnet is first applied")
+ end_step_percent: float = Field(default=1, ge=0, le=1,
+ description="% of total steps at which controlnet is last applied")
+
+ class Config:
+ schema_extra = {
+ "required": ["image", "control_model", "control_weight", "begin_step_percent", "end_step_percent"]
+ }
+
+
+class ControlOutput(BaseInvocationOutput):
+ """node output for ControlNet info"""
+ # fmt: off
+ type: Literal["control_output"] = "control_output"
+ control: ControlField = Field(default=None, description="The control info dict")
+ # fmt: on
+
+
+class ControlNetInvocation(BaseInvocation):
+ """Collects ControlNet info to pass to other nodes"""
+ # fmt: off
+ type: Literal["controlnet"] = "controlnet"
+ # Inputs
+ image: ImageField = Field(default=None, description="image to process")
+ control_model: CONTROLNET_NAME_VALUES = Field(default="lllyasviel/sd-controlnet-canny",
+ description="control model used")
+ control_weight: float = Field(default=1.0, ge=0, le=1, description="weight given to controlnet")
+ # TODO: add support in backend core for begin_step_percent, end_step_percent, guess_mode
+ begin_step_percent: float = Field(default=0, ge=0, le=1,
+ description="% of total steps at which controlnet is first applied")
+ end_step_percent: float = Field(default=1, ge=0, le=1,
+ description="% of total steps at which controlnet is last applied")
+ # fmt: on
+
+
+ def invoke(self, context: InvocationContext) -> ControlOutput:
+
+ return ControlOutput(
+ control=ControlField(
+ image=self.image,
+ control_model=self.control_model,
+ control_weight=self.control_weight,
+ begin_step_percent=self.begin_step_percent,
+ end_step_percent=self.end_step_percent,
+ ),
+ )
+
+# TODO: move image processors to separate file (image_analysis.py)
+class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
+ """Base class for invocations that preprocess images for ControlNet"""
+
+ # fmt: off
+ type: Literal["image_processor"] = "image_processor"
+ # Inputs
+ image: ImageField = Field(default=None, description="image to process")
+ # fmt: on
+
+
+ def run_processor(self, image):
+ # superclass just passes through image without processing
+ return image
+
+ def invoke(self, context: InvocationContext) -> ImageOutput:
+
+ raw_image = context.services.images.get_pil_image(
+ self.image.image_origin, self.image.image_name
+ )
+ # image type should be PIL.PngImagePlugin.PngImageFile ?
+ processed_image = self.run_processor(raw_image)
+
+ # FIXME: what happened to image metadata?
+ # metadata = context.services.metadata.build_metadata(
+ # session_id=context.graph_execution_state_id, node=self
+ # )
+
+ # currently can't see processed image in node UI without a showImage node,
+    # so for now setting image_origin to ResourceOrigin.INTERNAL so it will get saved in gallery
+ image_dto = context.services.images.create(
+ image=processed_image,
+ image_origin=ResourceOrigin.INTERNAL,
+ image_category=ImageCategory.CONTROL,
+ session_id=context.graph_execution_state_id,
+ node_id=self.id,
+ is_intermediate=self.is_intermediate
+ )
+
+ """Builds an ImageOutput and its ImageField"""
+ processed_image_field = ImageField(
+ image_name=image_dto.image_name,
+ image_origin=image_dto.image_origin,
+ )
+ return ImageOutput(
+ image=processed_image_field,
+ # width=processed_image.width,
+ width = image_dto.width,
+ # height=processed_image.height,
+ height = image_dto.height,
+ # mode=processed_image.mode,
+ )
+
+
+class CannyImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+ """Canny edge detection for ControlNet"""
+ # fmt: off
+ type: Literal["canny_image_processor"] = "canny_image_processor"
+ # Input
+ low_threshold: float = Field(default=100, ge=0, description="low threshold of Canny pixel gradient")
+ high_threshold: float = Field(default=200, ge=0, description="high threshold of Canny pixel gradient")
+ # fmt: on
+
+ def run_processor(self, image):
+ canny_processor = CannyDetector()
+ processed_image = canny_processor(image, self.low_threshold, self.high_threshold)
+ return processed_image
+
+
+class HedImageprocessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+ """Applies HED edge detection to image"""
+ # fmt: off
+ type: Literal["hed_image_processor"] = "hed_image_processor"
+ # Inputs
+ detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
+ image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
+ # safe not supported in controlnet_aux v0.0.3
+ # safe: bool = Field(default=False, description="whether to use safe mode")
+ scribble: bool = Field(default=False, description="whether to use scribble mode")
+ # fmt: on
+
+ def run_processor(self, image):
+ hed_processor = HEDdetector.from_pretrained("lllyasviel/Annotators")
+ processed_image = hed_processor(image,
+ detect_resolution=self.detect_resolution,
+ image_resolution=self.image_resolution,
+ # safe not supported in controlnet_aux v0.0.3
+ # safe=self.safe,
+ scribble=self.scribble,
+ )
+ return processed_image
+
+
+class LineartImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+ """Applies line art processing to image"""
+ # fmt: off
+ type: Literal["lineart_image_processor"] = "lineart_image_processor"
+ # Inputs
+ detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
+ image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
+ coarse: bool = Field(default=False, description="whether to use coarse mode")
+ # fmt: on
+
+ def run_processor(self, image):
+ lineart_processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
+ processed_image = lineart_processor(image,
+ detect_resolution=self.detect_resolution,
+ image_resolution=self.image_resolution,
+ coarse=self.coarse)
+ return processed_image
+
+
+class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+ """Applies line art anime processing to image"""
+ # fmt: off
+ type: Literal["lineart_anime_image_processor"] = "lineart_anime_image_processor"
+ # Inputs
+ detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
+ image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
+ # fmt: on
+
+ def run_processor(self, image):
+ processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
+ processed_image = processor(image,
+ detect_resolution=self.detect_resolution,
+ image_resolution=self.image_resolution,
+ )
+ return processed_image
+
+
+class OpenposeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+ """Applies Openpose processing to image"""
+ # fmt: off
+ type: Literal["openpose_image_processor"] = "openpose_image_processor"
+ # Inputs
+ hand_and_face: bool = Field(default=False, description="whether to use hands and face mode")
+ detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
+ image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
+ # fmt: on
+
+ def run_processor(self, image):
+ openpose_processor = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
+ processed_image = openpose_processor(image,
+ detect_resolution=self.detect_resolution,
+ image_resolution=self.image_resolution,
+ hand_and_face=self.hand_and_face,
+ )
+ return processed_image
+
+
+class MidasDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+ """Applies Midas depth processing to image"""
+ # fmt: off
+ type: Literal["midas_depth_image_processor"] = "midas_depth_image_processor"
+ # Inputs
+ a_mult: float = Field(default=2.0, ge=0, description="Midas parameter a = amult * PI")
+ bg_th: float = Field(default=0.1, ge=0, description="Midas parameter bg_th")
+ # depth_and_normal not supported in controlnet_aux v0.0.3
+ # depth_and_normal: bool = Field(default=False, description="whether to use depth and normal mode")
+ # fmt: on
+
+ def run_processor(self, image):
+ midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
+ processed_image = midas_processor(image,
+ a=np.pi * self.a_mult,
+ bg_th=self.bg_th,
+                                       # depth_and_normal not supported in controlnet_aux v0.0.3
+ # depth_and_normal=self.depth_and_normal,
+ )
+ return processed_image
+
+
+class NormalbaeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+ """Applies NormalBae processing to image"""
+ # fmt: off
+ type: Literal["normalbae_image_processor"] = "normalbae_image_processor"
+ # Inputs
+ detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
+ image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
+ # fmt: on
+
+ def run_processor(self, image):
+ normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
+ processed_image = normalbae_processor(image,
+ detect_resolution=self.detect_resolution,
+ image_resolution=self.image_resolution)
+ return processed_image
+
+
+class MlsdImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+ """Applies MLSD processing to image"""
+ # fmt: off
+ type: Literal["mlsd_image_processor"] = "mlsd_image_processor"
+ # Inputs
+ detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
+ image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
+ thr_v: float = Field(default=0.1, ge=0, description="MLSD parameter thr_v")
+ thr_d: float = Field(default=0.1, ge=0, description="MLSD parameter thr_d")
+ # fmt: on
+
+ def run_processor(self, image):
+ mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
+ processed_image = mlsd_processor(image,
+ detect_resolution=self.detect_resolution,
+ image_resolution=self.image_resolution,
+ thr_v=self.thr_v,
+ thr_d=self.thr_d)
+ return processed_image
+
+
+class PidiImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+ """Applies PIDI processing to image"""
+ # fmt: off
+ type: Literal["pidi_image_processor"] = "pidi_image_processor"
+ # Inputs
+ detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
+ image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
+ safe: bool = Field(default=False, description="whether to use safe mode")
+ scribble: bool = Field(default=False, description="whether to use scribble mode")
+ # fmt: on
+
+ def run_processor(self, image):
+ pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
+ processed_image = pidi_processor(image,
+ detect_resolution=self.detect_resolution,
+ image_resolution=self.image_resolution,
+ safe=self.safe,
+ scribble=self.scribble)
+ return processed_image
+
+
+class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+ """Applies content shuffle processing to image"""
+ # fmt: off
+ type: Literal["content_shuffle_image_processor"] = "content_shuffle_image_processor"
+ # Inputs
+ detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
+ image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
+ h: Union[int | None] = Field(default=512, ge=0, description="content shuffle h parameter")
+ w: Union[int | None] = Field(default=512, ge=0, description="content shuffle w parameter")
+ f: Union[int | None] = Field(default=256, ge=0, description="cont")
+ # fmt: on
+
+ def run_processor(self, image):
+ content_shuffle_processor = ContentShuffleDetector()
+ processed_image = content_shuffle_processor(image,
+ detect_resolution=self.detect_resolution,
+ image_resolution=self.image_resolution,
+ h=self.h,
+ w=self.w,
+ f=self.f
+ )
+ return processed_image
+
+
+# should work with controlnet_aux >= 0.0.4 and timm <= 0.6.13
+class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+ """Applies Zoe depth processing to image"""
+ # fmt: off
+ type: Literal["zoe_depth_image_processor"] = "zoe_depth_image_processor"
+ # fmt: on
+
+ def run_processor(self, image):
+ zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
+ processed_image = zoe_depth_processor(image)
+ return processed_image
+
+
+class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+ """Applies mediapipe face processing to image"""
+ # fmt: off
+ type: Literal["mediapipe_face_processor"] = "mediapipe_face_processor"
+ # Inputs
+ max_faces: int = Field(default=1, ge=1, description="maximum number of faces to detect")
+ min_confidence: float = Field(default=0.5, ge=0, le=1, description="minimum confidence for face detection")
+ # fmt: on
+
+ def run_processor(self, image):
+ mediapipe_face_processor = MediapipeFaceDetector()
+ processed_image = mediapipe_face_processor(image, max_faces=self.max_faces, min_confidence=self.min_confidence)
+ return processed_image
diff --git a/invokeai/app/invocations/cv.py b/invokeai/app/invocations/cv.py
index 26e06a2af8..5275116a2a 100644
--- a/invokeai/app/invocations/cv.py
+++ b/invokeai/app/invocations/cv.py
@@ -7,7 +7,7 @@ import numpy
from PIL import Image, ImageOps
from pydantic import BaseModel, Field
-from invokeai.app.models.image import ImageCategory, ImageField, ImageType
+from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput
@@ -37,10 +37,10 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
- self.image.image_type, self.image.image_name
+ self.image.image_origin, self.image.image_name
)
mask = context.services.images.get_pil_image(
- self.mask.image_type, self.mask.image_name
+ self.mask.image_origin, self.mask.image_name
)
# Convert to cv image/mask
@@ -57,16 +57,17 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
image_dto = context.services.images.create(
image=image_inpainted,
- image_type=ImageType.INTERMEDIATE,
+ image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
+ is_intermediate=self.is_intermediate,
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
- image_type=image_dto.image_type,
+ image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
diff --git a/invokeai/app/invocations/generate.py b/invokeai/app/invocations/generate.py
index 1805362416..53d4d16330 100644
--- a/invokeai/app/invocations/generate.py
+++ b/invokeai/app/invocations/generate.py
@@ -3,16 +3,20 @@
from functools import partial
from typing import Literal, Optional, Union, get_args
+import torch
+from diffusers import ControlNetModel
from pydantic import BaseModel, Field
-from invokeai.app.models.image import ImageCategory, ImageType, ColorField, ImageField
+from invokeai.app.models.image import (ColorField, ImageCategory, ImageField,
+ ResourceOrigin)
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.backend.generator.inpaint import infill_methods
-from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
-from .image import ImageOutput
-from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
+
+from ...backend.generator import Img2Img, Inpaint, InvokeAIGenerator, Txt2Img
from ...backend.stable_diffusion import PipelineIntermediateState
from ..util.step_callback import stable_diffusion_step_callback
+from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext
+from .image import ImageOutput
SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
INFILL_METHODS = Literal[tuple(infill_methods())]
@@ -53,6 +57,9 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
model: str = Field(default="", description="The model to use (currently ignored)")
+ progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
+ control_model: Optional[str] = Field(default=None, description="The control model to use")
+ control_image: Optional[ImageField] = Field(default=None, description="The processed control image")
# fmt: on
# TODO: pass this an emitter method or something? or a session for dispatching?
@@ -73,17 +80,35 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
# Handle invalid model parameter
model = context.services.model_manager.get_model(self.model,node=self,context=context)
+ # loading controlnet image (currently requires pre-processed image)
+ control_image = (
+ None if self.control_image is None
+ else context.services.images.get_pil_image(
+ self.control_image.image_origin, self.control_image.image_name
+ )
+ )
+ # loading controlnet model
+ if (self.control_model is None or self.control_model==''):
+ control_model = None
+ else:
+ # FIXME: change this to dropdown menu?
+ # FIXME: generalize so don't have to hardcode torch_dtype and device
+ control_model = ControlNetModel.from_pretrained(self.control_model,
+ torch_dtype=torch.float16).to("cuda")
+
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
- outputs = Txt2Img(model).generate(
+ txt2img = Txt2Img(model, control_model=control_model)
+ outputs = txt2img.generate(
prompt=self.prompt,
step_callback=partial(self.dispatch_progress, context, source_node_id),
+ control_image=control_image,
**self.dict(
- exclude={"prompt"}
+ exclude={"prompt", "control_image" }
), # Shorthand for passing all of the parameters above manually
)
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
@@ -92,16 +117,17 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
image_dto = context.services.images.create(
image=generate_output.image,
- image_type=ImageType.RESULT,
+ image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
session_id=context.graph_execution_state_id,
node_id=self.id,
+ is_intermediate=self.is_intermediate,
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
- image_type=image_dto.image_type,
+ image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
@@ -141,7 +167,7 @@ class ImageToImageInvocation(TextToImageInvocation):
None
if self.image is None
else context.services.images.get_pil_image(
- self.image.image_type, self.image.image_name
+ self.image.image_origin, self.image.image_name
)
)
@@ -172,16 +198,17 @@ class ImageToImageInvocation(TextToImageInvocation):
image_dto = context.services.images.create(
image=generator_output.image,
- image_type=ImageType.RESULT,
+ image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
session_id=context.graph_execution_state_id,
node_id=self.id,
+ is_intermediate=self.is_intermediate,
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
- image_type=image_dto.image_type,
+ image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
@@ -253,13 +280,13 @@ class InpaintInvocation(ImageToImageInvocation):
None
if self.image is None
else context.services.images.get_pil_image(
- self.image.image_type, self.image.image_name
+ self.image.image_origin, self.image.image_name
)
)
mask = (
None
if self.mask is None
- else context.services.images.get_pil_image(self.mask.image_type, self.mask.image_name)
+ else context.services.images.get_pil_image(self.mask.image_origin, self.mask.image_name)
)
# Handle invalid model parameter
@@ -287,16 +314,17 @@ class InpaintInvocation(ImageToImageInvocation):
image_dto = context.services.images.create(
image=generator_output.image,
- image_type=ImageType.RESULT,
+ image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
session_id=context.graph_execution_state_id,
node_id=self.id,
+ is_intermediate=self.is_intermediate,
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
- image_type=image_dto.image_type,
+ image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
diff --git a/invokeai/app/invocations/image.py b/invokeai/app/invocations/image.py
index 21dfb4c1cd..d048410468 100644
--- a/invokeai/app/invocations/image.py
+++ b/invokeai/app/invocations/image.py
@@ -7,7 +7,7 @@ import numpy
from PIL import Image, ImageFilter, ImageOps, ImageChops
from pydantic import BaseModel, Field
-from ..models.image import ImageCategory, ImageField, ImageType
+from ..models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
@@ -72,12 +72,12 @@ class LoadImageInvocation(BaseInvocation):
)
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
- image = context.services.images.get_pil_image(self.image.image_type, self.image.image_name)
+ image = context.services.images.get_pil_image(self.image.image_origin, self.image.image_name)
return ImageOutput(
image=ImageField(
image_name=self.image.image_name,
- image_type=self.image.image_type,
+ image_origin=self.image.image_origin,
),
width=image.width,
height=image.height,
@@ -96,7 +96,7 @@ class ShowImageInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
- self.image.image_type, self.image.image_name
+ self.image.image_origin, self.image.image_name
)
if image:
image.show()
@@ -106,7 +106,7 @@ class ShowImageInvocation(BaseInvocation):
return ImageOutput(
image=ImageField(
image_name=self.image.image_name,
- image_type=self.image.image_type,
+ image_origin=self.image.image_origin,
),
width=image.width,
height=image.height,
@@ -129,7 +129,7 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
- self.image.image_type, self.image.image_name
+ self.image.image_origin, self.image.image_name
)
image_crop = Image.new(
@@ -139,16 +139,17 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
image_dto = context.services.images.create(
image=image_crop,
- image_type=ImageType.RESULT,
+ image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
+ is_intermediate=self.is_intermediate,
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
- image_type=image_dto.image_type,
+ image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
@@ -171,17 +172,17 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
def invoke(self, context: InvocationContext) -> ImageOutput:
base_image = context.services.images.get_pil_image(
- self.base_image.image_type, self.base_image.image_name
+ self.base_image.image_origin, self.base_image.image_name
)
image = context.services.images.get_pil_image(
- self.image.image_type, self.image.image_name
+ self.image.image_origin, self.image.image_name
)
mask = (
None
if self.mask is None
else ImageOps.invert(
context.services.images.get_pil_image(
- self.mask.image_type, self.mask.image_name
+ self.mask.image_origin, self.mask.image_name
)
)
)
@@ -200,16 +201,17 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
image_dto = context.services.images.create(
image=new_image,
- image_type=ImageType.RESULT,
+ image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
+ is_intermediate=self.is_intermediate,
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
- image_type=image_dto.image_type,
+ image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
@@ -229,7 +231,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
def invoke(self, context: InvocationContext) -> MaskOutput:
image = context.services.images.get_pil_image(
- self.image.image_type, self.image.image_name
+ self.image.image_origin, self.image.image_name
)
image_mask = image.split()[-1]
@@ -238,15 +240,16 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
image_dto = context.services.images.create(
image=image_mask,
- image_type=ImageType.RESULT,
+ image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.MASK,
node_id=self.id,
session_id=context.graph_execution_state_id,
+ is_intermediate=self.is_intermediate,
)
return MaskOutput(
mask=ImageField(
- image_type=image_dto.image_type, image_name=image_dto.image_name
+ image_origin=image_dto.image_origin, image_name=image_dto.image_name
),
width=image_dto.width,
height=image_dto.height,
@@ -266,25 +269,26 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
def invoke(self, context: InvocationContext) -> ImageOutput:
image1 = context.services.images.get_pil_image(
- self.image1.image_type, self.image1.image_name
+ self.image1.image_origin, self.image1.image_name
)
image2 = context.services.images.get_pil_image(
- self.image2.image_type, self.image2.image_name
+ self.image2.image_origin, self.image2.image_name
)
multiply_image = ImageChops.multiply(image1, image2)
image_dto = context.services.images.create(
image=multiply_image,
- image_type=ImageType.RESULT,
+ image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
+ is_intermediate=self.is_intermediate,
)
return ImageOutput(
image=ImageField(
- image_type=image_dto.image_type, image_name=image_dto.image_name
+ image_origin=image_dto.image_origin, image_name=image_dto.image_name
),
width=image_dto.width,
height=image_dto.height,
@@ -307,22 +311,23 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
- self.image.image_type, self.image.image_name
+ self.image.image_origin, self.image.image_name
)
channel_image = image.getchannel(self.channel)
image_dto = context.services.images.create(
image=channel_image,
- image_type=ImageType.RESULT,
+ image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
+ is_intermediate=self.is_intermediate,
)
return ImageOutput(
image=ImageField(
- image_type=image_dto.image_type, image_name=image_dto.image_name
+ image_origin=image_dto.image_origin, image_name=image_dto.image_name
),
width=image_dto.width,
height=image_dto.height,
@@ -345,22 +350,23 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
- self.image.image_type, self.image.image_name
+ self.image.image_origin, self.image.image_name
)
converted_image = image.convert(self.mode)
image_dto = context.services.images.create(
image=converted_image,
- image_type=ImageType.RESULT,
+ image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
+ is_intermediate=self.is_intermediate,
)
return ImageOutput(
image=ImageField(
- image_type=image_dto.image_type, image_name=image_dto.image_name
+ image_origin=image_dto.image_origin, image_name=image_dto.image_name
),
width=image_dto.width,
height=image_dto.height,
@@ -381,7 +387,7 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
- self.image.image_type, self.image.image_name
+ self.image.image_origin, self.image.image_name
)
blur = (
@@ -393,16 +399,126 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
image_dto = context.services.images.create(
image=blur_image,
- image_type=ImageType.RESULT,
+ image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
+ is_intermediate=self.is_intermediate,
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
- image_type=image_dto.image_type,
+ image_origin=image_dto.image_origin,
+ ),
+ width=image_dto.width,
+ height=image_dto.height,
+ )
+
+
+PIL_RESAMPLING_MODES = Literal[
+ "nearest",
+ "box",
+ "bilinear",
+ "hamming",
+ "bicubic",
+ "lanczos",
+]
+
+
+PIL_RESAMPLING_MAP = {
+ "nearest": Image.Resampling.NEAREST,
+ "box": Image.Resampling.BOX,
+ "bilinear": Image.Resampling.BILINEAR,
+ "hamming": Image.Resampling.HAMMING,
+ "bicubic": Image.Resampling.BICUBIC,
+ "lanczos": Image.Resampling.LANCZOS,
+}
+
+
+class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
+ """Resizes an image to specific dimensions"""
+
+ # fmt: off
+ type: Literal["img_resize"] = "img_resize"
+
+ # Inputs
+ image: Union[ImageField, None] = Field(default=None, description="The image to resize")
+ width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
+ height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
+ resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
+ # fmt: on
+
+ def invoke(self, context: InvocationContext) -> ImageOutput:
+ image = context.services.images.get_pil_image(
+ self.image.image_origin, self.image.image_name
+ )
+
+ resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
+
+ resize_image = image.resize(
+ (self.width, self.height),
+ resample=resample_mode,
+ )
+
+ image_dto = context.services.images.create(
+ image=resize_image,
+ image_origin=ResourceOrigin.INTERNAL,
+ image_category=ImageCategory.GENERAL,
+ node_id=self.id,
+ session_id=context.graph_execution_state_id,
+ is_intermediate=self.is_intermediate,
+ )
+
+ return ImageOutput(
+ image=ImageField(
+ image_name=image_dto.image_name,
+ image_origin=image_dto.image_origin,
+ ),
+ width=image_dto.width,
+ height=image_dto.height,
+ )
+
+
+class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
+ """Scales an image by a factor"""
+
+ # fmt: off
+ type: Literal["img_scale"] = "img_scale"
+
+ # Inputs
+ image: Union[ImageField, None] = Field(default=None, description="The image to scale")
+ scale_factor: float = Field(gt=0, description="The factor by which to scale the image")
+ resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
+ # fmt: on
+
+ def invoke(self, context: InvocationContext) -> ImageOutput:
+ image = context.services.images.get_pil_image(
+ self.image.image_origin, self.image.image_name
+ )
+
+ resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
+ width = int(image.width * self.scale_factor)
+ height = int(image.height * self.scale_factor)
+
+ resize_image = image.resize(
+ (width, height),
+ resample=resample_mode,
+ )
+
+ image_dto = context.services.images.create(
+ image=resize_image,
+ image_origin=ResourceOrigin.INTERNAL,
+ image_category=ImageCategory.GENERAL,
+ node_id=self.id,
+ session_id=context.graph_execution_state_id,
+ is_intermediate=self.is_intermediate,
+ )
+
+ return ImageOutput(
+ image=ImageField(
+ image_name=image_dto.image_name,
+ image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
@@ -423,7 +539,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
- self.image.image_type, self.image.image_name
+ self.image.image_origin, self.image.image_name
)
image_arr = numpy.asarray(image, dtype=numpy.float32) / 255
@@ -433,16 +549,17 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
image_dto = context.services.images.create(
image=lerp_image,
- image_type=ImageType.RESULT,
+ image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
+ is_intermediate=self.is_intermediate,
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
- image_type=image_dto.image_type,
+ image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
@@ -463,7 +580,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
- self.image.image_type, self.image.image_name
+ self.image.image_origin, self.image.image_name
)
image_arr = numpy.asarray(image, dtype=numpy.float32)
@@ -478,16 +595,17 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
image_dto = context.services.images.create(
image=ilerp_image,
- image_type=ImageType.RESULT,
+ image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
+ is_intermediate=self.is_intermediate,
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
- image_type=image_dto.image_type,
+ image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
diff --git a/invokeai/app/invocations/infill.py b/invokeai/app/invocations/infill.py
index 17a43dbdac..a06780c1f5 100644
--- a/invokeai/app/invocations/infill.py
+++ b/invokeai/app/invocations/infill.py
@@ -11,7 +11,7 @@ from invokeai.app.invocations.image import ImageOutput
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.backend.image_util.patchmatch import PatchMatch
-from ..models.image import ColorField, ImageCategory, ImageField, ImageType
+from ..models.image import ColorField, ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import (
BaseInvocation,
InvocationContext,
@@ -135,7 +135,7 @@ class InfillColorInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
- self.image.image_type, self.image.image_name
+ self.image.image_origin, self.image.image_name
)
solid_bg = Image.new("RGBA", image.size, self.color.tuple())
@@ -145,16 +145,17 @@ class InfillColorInvocation(BaseInvocation):
image_dto = context.services.images.create(
image=infilled,
- image_type=ImageType.RESULT,
+ image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
+ is_intermediate=self.is_intermediate,
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
- image_type=image_dto.image_type,
+ image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
@@ -179,7 +180,7 @@ class InfillTileInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
- self.image.image_type, self.image.image_name
+ self.image.image_origin, self.image.image_name
)
infilled = tile_fill_missing(
@@ -189,16 +190,17 @@ class InfillTileInvocation(BaseInvocation):
image_dto = context.services.images.create(
image=infilled,
- image_type=ImageType.RESULT,
+ image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
+ is_intermediate=self.is_intermediate,
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
- image_type=image_dto.image_type,
+ image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
@@ -216,7 +218,7 @@ class InfillPatchMatchInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
- self.image.image_type, self.image.image_name
+ self.image.image_origin, self.image.image_name
)
if PatchMatch.patchmatch_available():
@@ -226,16 +228,17 @@ class InfillPatchMatchInvocation(BaseInvocation):
image_dto = context.services.images.create(
image=infilled,
- image_type=ImageType.RESULT,
+ image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
+ is_intermediate=self.is_intermediate,
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
- image_type=image_dto.image_type,
+ image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index abfe92f828..6dc491611e 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -1,37 +1,36 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
-from typing import Literal, Optional, Union
+from contextlib import ExitStack
+from typing import List, Literal, Optional, Union
import einops
import torch
+from diffusers import ControlNetModel
from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import SchedulerMixin as Scheduler
from pydantic import BaseModel, Field, validator
-from contextlib import ExitStack
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.app.util.step_callback import stable_diffusion_step_callback
+from ..models.image import ImageCategory, ImageField, ResourceOrigin
from ...backend.image_util.seamless import configure_model_padding
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.stable_diffusion.diffusers_pipeline import (
- ConditioningData, StableDiffusionGeneratorPipeline,
+ ConditioningData, ControlNetData, StableDiffusionGeneratorPipeline,
image_resized_to_grid_as_tensor)
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import \
PostprocessingSettings
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from ...backend.util.devices import choose_torch_device, torch_dtype
-from ..services.image_file_storage import ImageType
-from ..services.model_manager_service import ModelManagerService
+from ...backend.model_management.lora import ModelPatcher
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
InvocationConfig, InvocationContext)
from .compel import ConditioningField
-from .image import ImageCategory, ImageField, ImageOutput
+from .controlnet_image_processors import ControlField
+from .image import ImageOutput
from .model import ModelInfo, UNetField, VaeField
-from ...backend.model_management.lora import LoRAHelper
-
-
class LatentsField(BaseModel):
"""A latents field used for passing latents between invocations"""
@@ -93,10 +92,12 @@ def get_scheduler(
orig_scheduler_info = context.services.model_manager.get_model(**scheduler_info.dict())
with orig_scheduler_info as orig_scheduler:
scheduler_config = orig_scheduler.config
+
if "_backup" in scheduler_config:
scheduler_config = scheduler_config["_backup"]
scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config}
scheduler = scheduler_class.from_config(scheduler_config)
+
# hack copied over from generate.py
if not hasattr(scheduler, 'uses_inpainting_model'):
scheduler.uses_inpainting_model = lambda: False
@@ -171,12 +172,13 @@ class TextToLatentsInvocation(BaseInvocation):
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
noise: Optional[LatentsField] = Field(description="The noise to use")
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
- cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
+ cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
unet: UNetField = Field(default=None, description="UNet submodel")
+ control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
# fmt: on
# Schema customisation
@@ -184,6 +186,10 @@ class TextToLatentsInvocation(BaseInvocation):
schema_extra = {
"ui": {
"tags": ["latents", "image"],
+ "type_hints": {
+ "model": "model",
+ "control": "control",
+ }
},
}
@@ -243,6 +249,82 @@ class TextToLatentsInvocation(BaseInvocation):
precision="float16" if unet.dtype == torch.float16 else "float32",
#precision="float16", # TODO:
)
+
+ def prep_control_data(self,
+ context: InvocationContext,
+ model: StableDiffusionGeneratorPipeline, # really only need model for dtype and device
+ control_input: List[ControlField],
+ latents_shape: List[int],
+ do_classifier_free_guidance: bool = True,
+ ) -> List[ControlNetData]:
+ # assuming fixed dimensional scaling of 8:1 for image:latents
+ control_height_resize = latents_shape[2] * 8
+ control_width_resize = latents_shape[3] * 8
+ if control_input is None:
+ # print("control input is None")
+ control_list = None
+ elif isinstance(control_input, list) and len(control_input) == 0:
+ # print("control input is empty list")
+ control_list = None
+ elif isinstance(control_input, ControlField):
+ # print("control input is ControlField")
+ control_list = [control_input]
+ elif isinstance(control_input, list) and len(control_input) > 0 and isinstance(control_input[0], ControlField):
+ # print("control input is list[ControlField]")
+ control_list = control_input
+ else:
+ # print("input control is unrecognized:", type(self.control))
+ control_list = None
+ if (control_list is None):
+ control_data = None
+ # from above handling, any control that is not None should now be of type list[ControlField]
+ else:
+ # FIXME: add checks to skip entry if model or image is None
+ # and if weight is None, populate with default 1.0?
+ control_data = []
+ control_models = []
+ for control_info in control_list:
+ # handle control models
+ if ("," in control_info.control_model):
+ control_model_split = control_info.control_model.split(",")
+ control_name = control_model_split[0]
+ control_subfolder = control_model_split[1]
+ print("Using HF model subfolders")
+ print(" control_name: ", control_name)
+ print(" control_subfolder: ", control_subfolder)
+ control_model = ControlNetModel.from_pretrained(control_name,
+ subfolder=control_subfolder,
+ torch_dtype=model.unet.dtype).to(model.device)
+ else:
+ control_model = ControlNetModel.from_pretrained(control_info.control_model,
+ torch_dtype=model.unet.dtype).to(model.device)
+ control_models.append(control_model)
+ control_image_field = control_info.image
+ input_image = context.services.images.get_pil_image(control_image_field.image_origin,
+ control_image_field.image_name)
+ # self.image.image_type, self.image.image_name
+ # FIXME: still need to test with different widths, heights, devices, dtypes
+ # and add in batch_size, num_images_per_prompt?
+ # and do real check for classifier_free_guidance?
+ # prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width)
+ control_image = model.prepare_control_image(
+ image=input_image,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ width=control_width_resize,
+ height=control_height_resize,
+ # batch_size=batch_size * num_images_per_prompt,
+ # num_images_per_prompt=num_images_per_prompt,
+ device=control_model.device,
+ dtype=control_model.dtype,
+ )
+ control_item = ControlNetData(model=control_model,
+ image_tensor=control_image,
+ weight=control_info.control_weight,
+ begin_step_percent=control_info.begin_step_percent,
+ end_step_percent=control_info.end_step_percent)
+ control_data.append(control_item)
+ # MultiControlNetModel has been refactored out, just need list[ControlNetData]
+ return control_data
def invoke(self, context: InvocationContext) -> LatentsOutput:
noise = context.services.latents.get(self.noise.latents_name)
@@ -269,13 +351,19 @@ class TextToLatentsInvocation(BaseInvocation):
loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
- with LoRAHelper.apply_lora_unet(pipeline.unet, loras):
+ print("type of control input: ", type(self.control))
+ control_data = self.prep_control_data(model=pipeline, context=context, control_input=self.control,
+ latents_shape=noise.shape,
+ do_classifier_free_guidance=(self.cfg_scale >= 1.0))
+
+ with ModelPatcher.apply_lora_unet(pipeline.unet, loras):
# TODO: Verify the noise is the right size
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)),
noise=noise,
num_inference_steps=self.steps,
conditioning_data=conditioning_data,
+ control_data=control_data, # list[ControlNetData]
callback=step_callback
)
@@ -286,7 +374,6 @@ class TextToLatentsInvocation(BaseInvocation):
context.services.latents.save(name, result_latents)
return build_latents_output(latents_name=name, latents=result_latents)
-
class LatentsToLatentsInvocation(TextToLatentsInvocation):
"""Generates latents using latents as base image."""
@@ -294,13 +381,17 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
# Inputs
latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
- strength: float = Field(default=0.5, description="The strength of the latents to use")
+ strength: float = Field(default=0.7, ge=0, le=1, description="The strength of the latents to use")
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["latents"],
+ "type_hints": {
+ "model": "model",
+ "control": "control",
+ }
},
}
@@ -315,7 +406,6 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, source_node_id, state)
- #unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict(),
)
@@ -345,7 +435,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
- with LoRAHelper.apply_lora_unet(pipeline.unet, loras):
+ with ModelPatcher.apply_lora_unet(pipeline.unet, loras):
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
latents=initial_latents,
timesteps=timesteps,
@@ -413,7 +503,7 @@ class LatentsToImageInvocation(BaseInvocation):
image_dto = context.services.images.create(
image=image,
- image_type=ImageType.RESULT,
+ image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
@@ -459,6 +549,7 @@ class ResizeLatentsInvocation(BaseInvocation):
torch.cuda.empty_cache()
name = f"{context.graph_execution_state_id}__{self.id}"
+ # context.services.latents.set(name, resized_latents)
context.services.latents.save(name, resized_latents)
return build_latents_output(latents_name=name, latents=resized_latents)
@@ -489,6 +580,7 @@ class ScaleLatentsInvocation(BaseInvocation):
torch.cuda.empty_cache()
name = f"{context.graph_execution_state_id}__{self.id}"
+ # context.services.latents.set(name, resized_latents)
context.services.latents.save(name, resized_latents)
return build_latents_output(latents_name=name, latents=resized_latents)
@@ -513,8 +605,11 @@ class ImageToLatentsInvocation(BaseInvocation):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
+ # image = context.services.images.get(
+ # self.image.image_type, self.image.image_name
+ # )
image = context.services.images.get_pil_image(
- self.image.image_type, self.image.image_name
+ self.image.image_origin, self.image.image_name
)
#vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
@@ -543,6 +638,6 @@ class ImageToLatentsInvocation(BaseInvocation):
latents = 0.18215 * latents
name = f"{context.graph_execution_state_id}__{self.id}"
+ # context.services.latents.set(name, latents)
context.services.latents.save(name, latents)
return build_latents_output(latents_name=name, latents=latents)
-
diff --git a/invokeai/app/invocations/math.py b/invokeai/app/invocations/math.py
index 2ce58c016b..113b630200 100644
--- a/invokeai/app/invocations/math.py
+++ b/invokeai/app/invocations/math.py
@@ -34,6 +34,15 @@ class IntOutput(BaseInvocationOutput):
# fmt: on
+class FloatOutput(BaseInvocationOutput):
+ """A float output"""
+
+ # fmt: off
+ type: Literal["float_output"] = "float_output"
+ param: float = Field(default=None, description="The output float")
+ # fmt: on
+
+
class AddInvocation(BaseInvocation, MathInvocationConfig):
"""Adds two numbers"""
diff --git a/invokeai/app/invocations/params.py b/invokeai/app/invocations/params.py
index fcc7f1737a..1c6297665b 100644
--- a/invokeai/app/invocations/params.py
+++ b/invokeai/app/invocations/params.py
@@ -3,7 +3,7 @@
from typing import Literal
from pydantic import Field
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
-from .math import IntOutput
+from .math import IntOutput, FloatOutput
# Pass-through parameter nodes - used by subgraphs
@@ -16,3 +16,13 @@ class ParamIntInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=self.a)
+
+class ParamFloatInvocation(BaseInvocation):
+ """A float parameter"""
+ #fmt: off
+ type: Literal["param_float"] = "param_float"
+ param: float = Field(default=0.0, description="The float value")
+ #fmt: on
+
+ def invoke(self, context: InvocationContext) -> FloatOutput:
+ return FloatOutput(param=self.param)
diff --git a/invokeai/app/invocations/reconstruct.py b/invokeai/app/invocations/reconstruct.py
index 024134cd46..5313411400 100644
--- a/invokeai/app/invocations/reconstruct.py
+++ b/invokeai/app/invocations/reconstruct.py
@@ -2,7 +2,7 @@ from typing import Literal, Union
from pydantic import Field
-from invokeai.app.models.image import ImageCategory, ImageField, ImageType
+from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput
@@ -29,7 +29,7 @@ class RestoreFaceInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
- self.image.image_type, self.image.image_name
+ self.image.image_origin, self.image.image_name
)
results = context.services.restoration.upscale_and_reconstruct(
image_list=[[image, 0]],
@@ -43,16 +43,17 @@ class RestoreFaceInvocation(BaseInvocation):
# TODO: can this return multiple results?
image_dto = context.services.images.create(
image=results[0][0],
- image_type=ImageType.INTERMEDIATE,
+ image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
+ is_intermediate=self.is_intermediate,
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
- image_type=image_dto.image_type,
+ image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
diff --git a/invokeai/app/invocations/upscale.py b/invokeai/app/invocations/upscale.py
index 75aeec784f..80e1567047 100644
--- a/invokeai/app/invocations/upscale.py
+++ b/invokeai/app/invocations/upscale.py
@@ -4,7 +4,7 @@ from typing import Literal, Union
from pydantic import Field
-from invokeai.app.models.image import ImageCategory, ImageField, ImageType
+from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput
@@ -31,7 +31,7 @@ class UpscaleInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
- self.image.image_type, self.image.image_name
+ self.image.image_origin, self.image.image_name
)
results = context.services.restoration.upscale_and_reconstruct(
image_list=[[image, 0]],
@@ -45,16 +45,17 @@ class UpscaleInvocation(BaseInvocation):
# TODO: can this return multiple results?
image_dto = context.services.images.create(
image=results[0][0],
- image_type=ImageType.RESULT,
+ image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
+ is_intermediate=self.is_intermediate,
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
- image_type=image_dto.image_type,
+ image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
diff --git a/invokeai/app/models/image.py b/invokeai/app/models/image.py
index 544951ea34..6d48f2dbb1 100644
--- a/invokeai/app/models/image.py
+++ b/invokeai/app/models/image.py
@@ -5,31 +5,52 @@ from pydantic import BaseModel, Field
from invokeai.app.util.metaenum import MetaEnum
-class ImageType(str, Enum, metaclass=MetaEnum):
- """The type of an image."""
+class ResourceOrigin(str, Enum, metaclass=MetaEnum):
+ """The origin of a resource (eg image).
- RESULT = "results"
- UPLOAD = "uploads"
- INTERMEDIATE = "intermediates"
+ - INTERNAL: The resource was created by the application.
+ - EXTERNAL: The resource was not created by the application.
+ This may be a user-initiated upload, or an internal application upload (eg Canvas init image).
+ """
+
+ INTERNAL = "internal"
+ """The resource was created by the application."""
+ EXTERNAL = "external"
+ """The resource was not created by the application.
+ This may be a user-initiated upload, or an internal application upload (eg Canvas init image).
+ """
-class InvalidImageTypeException(ValueError):
- """Raised when a provided value is not a valid ImageType.
+class InvalidOriginException(ValueError):
+ """Raised when a provided value is not a valid ResourceOrigin.
Subclasses `ValueError`.
"""
- def __init__(self, message="Invalid image type."):
+ def __init__(self, message="Invalid resource origin."):
super().__init__(message)
class ImageCategory(str, Enum, metaclass=MetaEnum):
- """The category of an image. Use ImageCategory.OTHER for non-default categories."""
+ """The category of an image.
+
+ - GENERAL: The image is an output, init image, or otherwise an image without a specialized purpose.
+ - MASK: The image is a mask image.
+ - CONTROL: The image is a ControlNet control image.
+ - USER: The image is a user-provided image.
+ - OTHER: The image is some other type of image with a specialized purpose. To be used by external nodes.
+ """
GENERAL = "general"
- CONTROL = "control"
+ """GENERAL: The image is an output, init image, or otherwise an image without a specialized purpose."""
MASK = "mask"
+ """MASK: The image is a mask image."""
+ CONTROL = "control"
+ """CONTROL: The image is a ControlNet control image."""
+ USER = "user"
+ """USER: The image is a user-provided image."""
OTHER = "other"
+ """OTHER: The image is some other type of image with a specialized purpose. To be used by external nodes."""
class InvalidImageCategoryException(ValueError):
@@ -45,13 +66,13 @@ class InvalidImageCategoryException(ValueError):
class ImageField(BaseModel):
"""An image field used for passing image objects between invocations"""
- image_type: ImageType = Field(
- default=ImageType.RESULT, description="The type of the image"
+ image_origin: ResourceOrigin = Field(
+ default=ResourceOrigin.INTERNAL, description="The type of the image"
)
image_name: Optional[str] = Field(default=None, description="The name of the image")
class Config:
- schema_extra = {"required": ["image_type", "image_name"]}
+ schema_extra = {"required": ["image_origin", "image_name"]}
class ColorField(BaseModel):
@@ -62,3 +83,11 @@ class ColorField(BaseModel):
def tuple(self) -> Tuple[int, int, int, int]:
return (self.r, self.g, self.b, self.a)
+
+
+class ProgressImage(BaseModel):
+ """The progress image sent intermittently during processing"""
+
+ width: int = Field(description="The effective width of the image in pixels")
+ height: int = Field(description="The effective height of the image in pixels")
+ dataURL: str = Field(description="The image data as a b64 data URL")
diff --git a/invokeai/app/services/events.py b/invokeai/app/services/events.py
index 7e0b5bdb07..75fc2c74d3 100644
--- a/invokeai/app/services/events.py
+++ b/invokeai/app/services/events.py
@@ -1,7 +1,7 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
-from typing import Any, Optional
-from invokeai.app.api.models.images import ProgressImage
+from typing import Any
+from invokeai.app.models.image import ProgressImage
from invokeai.app.util.misc import get_timestamp
from invokeai.app.services.model_manager_service import SDModelType, SDModelInfo
from invokeai.app.models.exceptions import CanceledException
diff --git a/invokeai/app/services/graph.py b/invokeai/app/services/graph.py
index 44688ada0a..60e196faa1 100644
--- a/invokeai/app/services/graph.py
+++ b/invokeai/app/services/graph.py
@@ -60,6 +60,35 @@ def get_input_field(node: BaseInvocation, field: str) -> Any:
node_input_field = node_inputs.get(field) or None
return node_input_field
+from typing import Optional, Union, List, get_args
+
+def is_union_subtype(t1, t2):
+ t1_args = get_args(t1)
+ t2_args = get_args(t2)
+
+ if not t1_args:
+ # t1 is a single type
+ return t1 in t2_args
+ else:
+ # t1 is a Union, check that all of its types are in t2_args
+ return all(arg in t2_args for arg in t1_args)
+
+def is_list_or_contains_list(t):
+ t_args = get_args(t)
+
+ # If the type is a List
+ if get_origin(t) is list:
+ return True
+
+ # If the type is a Union
+ elif t_args:
+ # Check if any of the types in the Union is a List
+ for arg in t_args:
+ if get_origin(arg) is list:
+ return True
+
+ return False
+
def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
if not from_type:
@@ -85,7 +114,8 @@ def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
if to_type in get_args(from_type):
return True
- if not issubclass(from_type, to_type):
+ # if not issubclass(from_type, to_type):
+ if not is_union_subtype(from_type, to_type):
return False
else:
return False
@@ -694,7 +724,11 @@ class Graph(BaseModel):
input_root_type = next(t[0] for t in type_degrees if t[1] == 0) # type: ignore
# Verify that all outputs are lists
- if not all((get_origin(f) == list for f in output_fields)):
+ # if not all((get_origin(f) == list for f in output_fields)):
+ # return False
+
+ # Verify that all outputs are lists
+ if not all(is_list_or_contains_list(f) for f in output_fields):
return False
# Verify that all outputs match the input type (are a base class or the same class)
diff --git a/invokeai/app/services/image_file_storage.py b/invokeai/app/services/image_file_storage.py
index 46070b3bf2..68a994ea75 100644
--- a/invokeai/app/services/image_file_storage.py
+++ b/invokeai/app/services/image_file_storage.py
@@ -9,7 +9,7 @@ from PIL.Image import Image as PILImageType
from PIL import Image, PngImagePlugin
from send2trash import send2trash
-from invokeai.app.models.image import ImageType
+from invokeai.app.models.image import ResourceOrigin
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
@@ -40,13 +40,13 @@ class ImageFileStorageBase(ABC):
"""Low-level service responsible for storing and retrieving image files."""
@abstractmethod
- def get(self, image_type: ImageType, image_name: str) -> PILImageType:
+ def get(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
"""Retrieves an image as PIL Image."""
pass
@abstractmethod
def get_path(
- self, image_type: ImageType, image_name: str, thumbnail: bool = False
+ self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
"""Gets the internal path to an image or thumbnail."""
pass
@@ -62,7 +62,7 @@ class ImageFileStorageBase(ABC):
def save(
self,
image: PILImageType,
- image_type: ImageType,
+ image_origin: ResourceOrigin,
image_name: str,
metadata: Optional[ImageMetadata] = None,
thumbnail_size: int = 256,
@@ -71,7 +71,7 @@ class ImageFileStorageBase(ABC):
pass
@abstractmethod
- def delete(self, image_type: ImageType, image_name: str) -> None:
+ def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
"""Deletes an image and its thumbnail (if one exists)."""
pass
@@ -93,17 +93,17 @@ class DiskImageFileStorage(ImageFileStorageBase):
Path(output_folder).mkdir(parents=True, exist_ok=True)
# TODO: don't hard-code. get/save/delete should maybe take subpath?
- for image_type in ImageType:
- Path(os.path.join(output_folder, image_type)).mkdir(
+ for image_origin in ResourceOrigin:
+ Path(os.path.join(output_folder, image_origin)).mkdir(
parents=True, exist_ok=True
)
- Path(os.path.join(output_folder, image_type, "thumbnails")).mkdir(
+ Path(os.path.join(output_folder, image_origin, "thumbnails")).mkdir(
parents=True, exist_ok=True
)
- def get(self, image_type: ImageType, image_name: str) -> PILImageType:
+ def get(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
try:
- image_path = self.get_path(image_type, image_name)
+ image_path = self.get_path(image_origin, image_name)
cache_item = self.__get_cache(image_path)
if cache_item:
return cache_item
@@ -117,13 +117,13 @@ class DiskImageFileStorage(ImageFileStorageBase):
def save(
self,
image: PILImageType,
- image_type: ImageType,
+ image_origin: ResourceOrigin,
image_name: str,
metadata: Optional[ImageMetadata] = None,
thumbnail_size: int = 256,
) -> None:
try:
- image_path = self.get_path(image_type, image_name)
+ image_path = self.get_path(image_origin, image_name)
if metadata is not None:
pnginfo = PngImagePlugin.PngInfo()
@@ -133,7 +133,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
image.save(image_path, "PNG")
thumbnail_name = get_thumbnail_name(image_name)
- thumbnail_path = self.get_path(image_type, thumbnail_name, thumbnail=True)
+ thumbnail_path = self.get_path(image_origin, thumbnail_name, thumbnail=True)
thumbnail_image = make_thumbnail(image, thumbnail_size)
thumbnail_image.save(thumbnail_path)
@@ -142,10 +142,10 @@ class DiskImageFileStorage(ImageFileStorageBase):
except Exception as e:
raise ImageFileSaveException from e
- def delete(self, image_type: ImageType, image_name: str) -> None:
+ def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
try:
basename = os.path.basename(image_name)
- image_path = self.get_path(image_type, basename)
+ image_path = self.get_path(image_origin, basename)
if os.path.exists(image_path):
send2trash(image_path)
@@ -153,7 +153,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
del self.__cache[image_path]
thumbnail_name = get_thumbnail_name(image_name)
- thumbnail_path = self.get_path(image_type, thumbnail_name, True)
+ thumbnail_path = self.get_path(image_origin, thumbnail_name, True)
if os.path.exists(thumbnail_path):
send2trash(thumbnail_path)
@@ -164,7 +164,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
# TODO: make this a bit more flexible for e.g. cloud storage
def get_path(
- self, image_type: ImageType, image_name: str, thumbnail: bool = False
+ self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
# strip out any relative path shenanigans
basename = os.path.basename(image_name)
@@ -172,10 +172,10 @@ class DiskImageFileStorage(ImageFileStorageBase):
if thumbnail:
thumbnail_name = get_thumbnail_name(basename)
path = os.path.join(
- self.__output_folder, image_type, "thumbnails", thumbnail_name
+ self.__output_folder, image_origin, "thumbnails", thumbnail_name
)
else:
- path = os.path.join(self.__output_folder, image_type, basename)
+ path = os.path.join(self.__output_folder, image_origin, basename)
abspath = os.path.abspath(path)
diff --git a/invokeai/app/services/image_record_storage.py b/invokeai/app/services/image_record_storage.py
index 4e1f73978b..6907ac3952 100644
--- a/invokeai/app/services/image_record_storage.py
+++ b/invokeai/app/services/image_record_storage.py
@@ -1,20 +1,35 @@
from abc import ABC, abstractmethod
from datetime import datetime
-from typing import Optional, cast
+from typing import Generic, Optional, TypeVar, cast
import sqlite3
import threading
from typing import Optional, Union
+from pydantic import BaseModel, Field
+from pydantic.generics import GenericModel
+
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.models.image import (
ImageCategory,
- ImageType,
+ ResourceOrigin,
)
from invokeai.app.services.models.image_record import (
ImageRecord,
+ ImageRecordChanges,
deserialize_image_record,
)
-from invokeai.app.services.item_storage import PaginatedResults
+
+T = TypeVar("T", bound=BaseModel)
+
+class OffsetPaginatedResults(GenericModel, Generic[T]):
+ """Offset-paginated results"""
+
+ # fmt: off
+ items: list[T] = Field(description="Items")
+ offset: int = Field(description="Offset from which to retrieve items")
+ limit: int = Field(description="Limit of items to get")
+ total: int = Field(description="Total number of items in result")
+ # fmt: on
# TODO: Should these excpetions subclass existing python exceptions?
@@ -45,25 +60,36 @@ class ImageRecordStorageBase(ABC):
# TODO: Implement an `update()` method
@abstractmethod
- def get(self, image_type: ImageType, image_name: str) -> ImageRecord:
+ def get(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord:
"""Gets an image record."""
pass
+ @abstractmethod
+ def update(
+ self,
+ image_name: str,
+ image_origin: ResourceOrigin,
+ changes: ImageRecordChanges,
+ ) -> None:
+ """Updates an image record."""
+ pass
+
@abstractmethod
def get_many(
self,
- image_type: ImageType,
- image_category: ImageCategory,
- page: int = 0,
- per_page: int = 10,
- ) -> PaginatedResults[ImageRecord]:
+ offset: int = 0,
+ limit: int = 10,
+ image_origin: Optional[ResourceOrigin] = None,
+ categories: Optional[list[ImageCategory]] = None,
+ is_intermediate: Optional[bool] = None,
+ ) -> OffsetPaginatedResults[ImageRecord]:
"""Gets a page of image records."""
pass
# TODO: The database has a nullable `deleted_at` column, currently unused.
# Should we implement soft deletes? Would need coordination with ImageFileStorage.
@abstractmethod
- def delete(self, image_type: ImageType, image_name: str) -> None:
+ def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
"""Deletes an image record."""
pass
@@ -71,13 +97,14 @@ class ImageRecordStorageBase(ABC):
def save(
self,
image_name: str,
- image_type: ImageType,
+ image_origin: ResourceOrigin,
image_category: ImageCategory,
width: int,
height: int,
session_id: Optional[str],
node_id: Optional[str],
metadata: Optional[ImageMetadata],
+ is_intermediate: bool = False,
) -> datetime:
"""Saves an image record."""
pass
@@ -91,7 +118,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
def __init__(self, filename: str) -> None:
super().__init__()
-
self._filename = filename
self._conn = sqlite3.connect(filename, check_same_thread=False)
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
@@ -117,7 +143,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
CREATE TABLE IF NOT EXISTS images (
image_name TEXT NOT NULL PRIMARY KEY,
-- This is an enum in python, unrestricted string here for flexibility
- image_type TEXT NOT NULL,
+ image_origin TEXT NOT NULL,
-- This is an enum in python, unrestricted string here for flexibility
image_category TEXT NOT NULL,
width INTEGER NOT NULL,
@@ -125,9 +151,10 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
session_id TEXT,
node_id TEXT,
metadata TEXT,
- created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ is_intermediate BOOLEAN DEFAULT FALSE,
+ created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
- updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Soft delete, currently unused
deleted_at DATETIME
);
@@ -142,7 +169,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
)
self._cursor.execute(
"""--sql
- CREATE INDEX IF NOT EXISTS idx_images_image_type ON images(image_type);
+ CREATE INDEX IF NOT EXISTS idx_images_image_origin ON images(image_origin);
"""
)
self._cursor.execute(
@@ -169,7 +196,9 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
"""
)
- def get(self, image_type: ImageType, image_name: str) -> Union[ImageRecord, None]:
+ def get(
+ self, image_origin: ResourceOrigin, image_name: str
+ ) -> Union[ImageRecord, None]:
try:
self._lock.acquire()
@@ -193,38 +222,110 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
return deserialize_image_record(dict(result))
+ def update(
+ self,
+ image_name: str,
+ image_origin: ResourceOrigin,
+ changes: ImageRecordChanges,
+ ) -> None:
+ try:
+ self._lock.acquire()
+ # Change the category of the image
+ if changes.image_category is not None:
+ self._cursor.execute(
+ f"""--sql
+ UPDATE images
+ SET image_category = ?
+ WHERE image_name = ?;
+ """,
+ (changes.image_category, image_name),
+ )
+
+ # Change the session associated with the image
+ if changes.session_id is not None:
+ self._cursor.execute(
+ f"""--sql
+ UPDATE images
+ SET session_id = ?
+ WHERE image_name = ?;
+ """,
+ (changes.session_id, image_name),
+ )
+
+ # Change the image's `is_intermediate` flag
+ if changes.is_intermediate is not None:
+ self._cursor.execute(
+ f"""--sql
+ UPDATE images
+ SET is_intermediate = ?
+ WHERE image_name = ?;
+ """,
+ (changes.is_intermediate, image_name),
+ )
+ self._conn.commit()
+ except sqlite3.Error as e:
+ self._conn.rollback()
+ raise ImageRecordSaveException from e
+ finally:
+ self._lock.release()
+
def get_many(
self,
- image_type: ImageType,
- image_category: ImageCategory,
- page: int = 0,
- per_page: int = 10,
- ) -> PaginatedResults[ImageRecord]:
+ offset: int = 0,
+ limit: int = 10,
+ image_origin: Optional[ResourceOrigin] = None,
+ categories: Optional[list[ImageCategory]] = None,
+ is_intermediate: Optional[bool] = None,
+ ) -> OffsetPaginatedResults[ImageRecord]:
try:
self._lock.acquire()
- self._cursor.execute(
- f"""--sql
- SELECT * FROM images
- WHERE image_type = ? AND image_category = ?
- ORDER BY created_at DESC
- LIMIT ? OFFSET ?;
- """,
- (image_type.value, image_category.value, per_page, page * per_page),
- )
+ # Manually build two queries - one for the count, one for the records
+ count_query = f"""SELECT COUNT(*) FROM images WHERE 1=1\n"""
+ images_query = f"""SELECT * FROM images WHERE 1=1\n"""
+
+ query_conditions = ""
+ query_params = []
+
+ if image_origin is not None:
+ query_conditions += f"""AND image_origin = ?\n"""
+ query_params.append(image_origin.value)
+
+ if categories is not None:
+ # Convert the enum values to a unique list of strings
+ category_strings = list(
+ map(lambda c: c.value, set(categories))
+ )
+ # Create the correct length of placeholders
+ placeholders = ",".join("?" * len(category_strings))
+ query_conditions += f"AND image_category IN ( {placeholders} )\n"
+
+ # Unpack the included categories into the query params
+ for c in category_strings:
+ query_params.append(c)
+
+ if is_intermediate is not None:
+ query_conditions += f"""AND is_intermediate = ?\n"""
+ query_params.append(is_intermediate)
+
+ query_pagination = f"""ORDER BY created_at DESC LIMIT ? OFFSET ?\n"""
+
+ # Final images query with pagination
+ images_query += query_conditions + query_pagination + ";"
+ # Add all the parameters
+ images_params = query_params.copy()
+ images_params.append(limit)
+ images_params.append(offset)
+ # Build the list of images, deserializing each row
+ self._cursor.execute(images_query, images_params)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
-
images = list(map(lambda r: deserialize_image_record(dict(r)), result))
- self._cursor.execute(
- """--sql
- SELECT count(*) FROM images
- WHERE image_type = ? AND image_category = ?
- """,
- (image_type.value, image_category.value),
- )
-
+ # Set up and execute the count query, without pagination
+ count_query += query_conditions + ";"
+ count_params = query_params.copy()
+ self._cursor.execute(count_query, count_params)
count = self._cursor.fetchone()[0]
except sqlite3.Error as e:
self._conn.rollback()
@@ -232,13 +333,11 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
finally:
self._lock.release()
- pageCount = int(count / per_page) + 1
-
- return PaginatedResults(
- items=images, page=page, pages=pageCount, per_page=per_page, total=count
+ return OffsetPaginatedResults(
+ items=images, offset=offset, limit=limit, total=count
)
- def delete(self, image_type: ImageType, image_name: str) -> None:
+ def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
try:
self._lock.acquire()
self._cursor.execute(
@@ -258,13 +357,14 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
def save(
self,
image_name: str,
- image_type: ImageType,
+ image_origin: ResourceOrigin,
image_category: ImageCategory,
session_id: Optional[str],
width: int,
height: int,
node_id: Optional[str],
metadata: Optional[ImageMetadata],
+ is_intermediate: bool = False,
) -> datetime:
try:
metadata_json = (
@@ -275,25 +375,27 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
"""--sql
INSERT OR IGNORE INTO images (
image_name,
- image_type,
+ image_origin,
image_category,
width,
height,
node_id,
session_id,
- metadata
+ metadata,
+ is_intermediate
)
- VALUES (?, ?, ?, ?, ?, ?, ?, ?);
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?);
""",
(
image_name,
- image_type.value,
+ image_origin.value,
image_category.value,
width,
height,
node_id,
session_id,
metadata_json,
+ is_intermediate,
),
)
self._conn.commit()
diff --git a/invokeai/app/services/images.py b/invokeai/app/services/images.py
index 914dd3b6d3..2618a9763e 100644
--- a/invokeai/app/services/images.py
+++ b/invokeai/app/services/images.py
@@ -1,14 +1,13 @@
from abc import ABC, abstractmethod
from logging import Logger
from typing import Optional, TYPE_CHECKING, Union
-import uuid
from PIL.Image import Image as PILImageType
from invokeai.app.models.image import (
ImageCategory,
- ImageType,
+ ResourceOrigin,
InvalidImageCategoryException,
- InvalidImageTypeException,
+ InvalidOriginException,
)
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.services.image_record_storage import (
@@ -16,10 +15,12 @@ from invokeai.app.services.image_record_storage import (
ImageRecordNotFoundException,
ImageRecordSaveException,
ImageRecordStorageBase,
+ OffsetPaginatedResults,
)
from invokeai.app.services.models.image_record import (
ImageRecord,
ImageDTO,
+ ImageRecordChanges,
image_record_to_dto,
)
from invokeai.app.services.image_file_storage import (
@@ -30,8 +31,8 @@ from invokeai.app.services.image_file_storage import (
)
from invokeai.app.services.item_storage import ItemStorageABC, PaginatedResults
from invokeai.app.services.metadata import MetadataServiceBase
+from invokeai.app.services.resource_name import NameServiceBase
from invokeai.app.services.urls import UrlServiceBase
-from invokeai.app.util.misc import get_iso_timestamp
if TYPE_CHECKING:
from invokeai.app.services.graph import GraphExecutionState
@@ -44,32 +45,42 @@ class ImageServiceABC(ABC):
def create(
self,
image: PILImageType,
- image_type: ImageType,
+ image_origin: ResourceOrigin,
image_category: ImageCategory,
node_id: Optional[str] = None,
session_id: Optional[str] = None,
- metadata: Optional[ImageMetadata] = None,
+ intermediate: bool = False,
) -> ImageDTO:
"""Creates an image, storing the file and its metadata."""
pass
@abstractmethod
- def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType:
+ def update(
+ self,
+ image_origin: ResourceOrigin,
+ image_name: str,
+ changes: ImageRecordChanges,
+ ) -> ImageDTO:
+ """Updates an image."""
+ pass
+
+ @abstractmethod
+ def get_pil_image(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
"""Gets an image as a PIL image."""
pass
@abstractmethod
- def get_record(self, image_type: ImageType, image_name: str) -> ImageRecord:
+ def get_record(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord:
"""Gets an image record."""
pass
@abstractmethod
- def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO:
+ def get_dto(self, image_origin: ResourceOrigin, image_name: str) -> ImageDTO:
"""Gets an image DTO."""
pass
@abstractmethod
- def get_path(self, image_type: ImageType, image_name: str) -> str:
+ def get_path(self, image_origin: ResourceOrigin, image_name: str) -> str:
"""Gets an image's path."""
pass
@@ -80,7 +91,7 @@ class ImageServiceABC(ABC):
@abstractmethod
def get_url(
- self, image_type: ImageType, image_name: str, thumbnail: bool = False
+ self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
"""Gets an image's or thumbnail's URL."""
pass
@@ -88,16 +99,17 @@ class ImageServiceABC(ABC):
@abstractmethod
def get_many(
self,
- image_type: ImageType,
- image_category: ImageCategory,
- page: int = 0,
- per_page: int = 10,
- ) -> PaginatedResults[ImageDTO]:
+ offset: int = 0,
+ limit: int = 10,
+ image_origin: Optional[ResourceOrigin] = None,
+ categories: Optional[list[ImageCategory]] = None,
+ is_intermediate: Optional[bool] = None,
+ ) -> OffsetPaginatedResults[ImageDTO]:
"""Gets a paginated list of image DTOs."""
pass
@abstractmethod
- def delete(self, image_type: ImageType, image_name: str):
+ def delete(self, image_origin: ResourceOrigin, image_name: str):
"""Deletes an image."""
pass
@@ -110,6 +122,7 @@ class ImageServiceDependencies:
metadata: MetadataServiceBase
urls: UrlServiceBase
logger: Logger
+ names: NameServiceBase
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
def __init__(
@@ -119,6 +132,7 @@ class ImageServiceDependencies:
metadata: MetadataServiceBase,
url: UrlServiceBase,
logger: Logger,
+ names: NameServiceBase,
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
):
self.records = image_record_storage
@@ -126,6 +140,7 @@ class ImageServiceDependencies:
self.metadata = metadata
self.urls = url
self.logger = logger
+ self.names = names
self.graph_execution_manager = graph_execution_manager
@@ -139,6 +154,7 @@ class ImageService(ImageServiceABC):
metadata: MetadataServiceBase,
url: UrlServiceBase,
logger: Logger,
+ names: NameServiceBase,
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
):
self._services = ImageServiceDependencies(
@@ -147,29 +163,26 @@ class ImageService(ImageServiceABC):
metadata=metadata,
url=url,
logger=logger,
+ names=names,
graph_execution_manager=graph_execution_manager,
)
def create(
self,
image: PILImageType,
- image_type: ImageType,
+ image_origin: ResourceOrigin,
image_category: ImageCategory,
node_id: Optional[str] = None,
session_id: Optional[str] = None,
+ is_intermediate: bool = False,
) -> ImageDTO:
- if image_type not in ImageType:
- raise InvalidImageTypeException
+ if image_origin not in ResourceOrigin:
+ raise InvalidOriginException
if image_category not in ImageCategory:
raise InvalidImageCategoryException
- image_name = self._create_image_name(
- image_type=image_type,
- image_category=image_category,
- node_id=node_id,
- session_id=session_id,
- )
+ image_name = self._services.names.create_image_name()
metadata = self._get_metadata(session_id, node_id)
@@ -180,10 +193,12 @@ class ImageService(ImageServiceABC):
created_at = self._services.records.save(
# Non-nullable fields
image_name=image_name,
- image_type=image_type,
+ image_origin=image_origin,
image_category=image_category,
width=width,
height=height,
+ # Meta fields
+ is_intermediate=is_intermediate,
# Nullable fields
node_id=node_id,
session_id=session_id,
@@ -191,21 +206,21 @@ class ImageService(ImageServiceABC):
)
self._services.files.save(
- image_type=image_type,
+ image_origin=image_origin,
image_name=image_name,
image=image,
metadata=metadata,
)
- image_url = self._services.urls.get_image_url(image_type, image_name)
+ image_url = self._services.urls.get_image_url(image_origin, image_name)
thumbnail_url = self._services.urls.get_image_url(
- image_type, image_name, True
+ image_origin, image_name, True
)
return ImageDTO(
# Non-nullable fields
image_name=image_name,
- image_type=image_type,
+ image_origin=image_origin,
image_category=image_category,
width=width,
height=height,
@@ -217,6 +232,7 @@ class ImageService(ImageServiceABC):
created_at=created_at,
updated_at=created_at, # this is always the same as the created_at at this time
deleted_at=None,
+ is_intermediate=is_intermediate,
# Extra non-nullable fields for DTO
image_url=image_url,
thumbnail_url=thumbnail_url,
@@ -231,9 +247,25 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem saving image record and file")
raise e
- def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType:
+ def update(
+ self,
+ image_origin: ResourceOrigin,
+ image_name: str,
+ changes: ImageRecordChanges,
+ ) -> ImageDTO:
try:
- return self._services.files.get(image_type, image_name)
+ self._services.records.update(image_name, image_origin, changes)
+ return self.get_dto(image_origin, image_name)
+ except ImageRecordSaveException:
+ self._services.logger.error("Failed to update image record")
+ raise
+ except Exception as e:
+ self._services.logger.error("Problem updating image record")
+ raise e
+
+ def get_pil_image(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
+ try:
+ return self._services.files.get(image_origin, image_name)
except ImageFileNotFoundException:
self._services.logger.error("Failed to get image file")
raise
@@ -241,9 +273,9 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem getting image file")
raise e
- def get_record(self, image_type: ImageType, image_name: str) -> ImageRecord:
+ def get_record(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord:
try:
- return self._services.records.get(image_type, image_name)
+ return self._services.records.get(image_origin, image_name)
except ImageRecordNotFoundException:
self._services.logger.error("Image record not found")
raise
@@ -251,14 +283,14 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem getting image record")
raise e
- def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO:
+ def get_dto(self, image_origin: ResourceOrigin, image_name: str) -> ImageDTO:
try:
- image_record = self._services.records.get(image_type, image_name)
+ image_record = self._services.records.get(image_origin, image_name)
image_dto = image_record_to_dto(
image_record,
- self._services.urls.get_image_url(image_type, image_name),
- self._services.urls.get_image_url(image_type, image_name, True),
+ self._services.urls.get_image_url(image_origin, image_name),
+ self._services.urls.get_image_url(image_origin, image_name, True),
)
return image_dto
@@ -270,10 +302,10 @@ class ImageService(ImageServiceABC):
raise e
def get_path(
- self, image_type: ImageType, image_name: str, thumbnail: bool = False
+ self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
try:
- return self._services.files.get_path(image_type, image_name, thumbnail)
+ return self._services.files.get_path(image_origin, image_name, thumbnail)
except Exception as e:
self._services.logger.error("Problem getting image path")
raise e
@@ -286,57 +318,58 @@ class ImageService(ImageServiceABC):
raise e
def get_url(
- self, image_type: ImageType, image_name: str, thumbnail: bool = False
+ self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
try:
- return self._services.urls.get_image_url(image_type, image_name, thumbnail)
+ return self._services.urls.get_image_url(image_origin, image_name, thumbnail)
except Exception as e:
self._services.logger.error("Problem getting image path")
raise e
def get_many(
self,
- image_type: ImageType,
- image_category: ImageCategory,
- page: int = 0,
- per_page: int = 10,
- ) -> PaginatedResults[ImageDTO]:
+ offset: int = 0,
+ limit: int = 10,
+ image_origin: Optional[ResourceOrigin] = None,
+ categories: Optional[list[ImageCategory]] = None,
+ is_intermediate: Optional[bool] = None,
+ ) -> OffsetPaginatedResults[ImageDTO]:
try:
results = self._services.records.get_many(
- image_type,
- image_category,
- page,
- per_page,
+ offset,
+ limit,
+ image_origin,
+ categories,
+ is_intermediate,
)
image_dtos = list(
map(
lambda r: image_record_to_dto(
r,
- self._services.urls.get_image_url(image_type, r.image_name),
+ self._services.urls.get_image_url(r.image_origin, r.image_name),
self._services.urls.get_image_url(
- image_type, r.image_name, True
+ r.image_origin, r.image_name, True
),
),
results.items,
)
)
- return PaginatedResults[ImageDTO](
+ return OffsetPaginatedResults[ImageDTO](
items=image_dtos,
- page=results.page,
- pages=results.pages,
- per_page=results.per_page,
+ offset=results.offset,
+ limit=results.limit,
total=results.total,
)
except Exception as e:
self._services.logger.error("Problem getting paginated image DTOs")
raise e
- def delete(self, image_type: ImageType, image_name: str):
+ def delete(self, image_origin: ResourceOrigin, image_name: str):
try:
- self._services.files.delete(image_type, image_name)
- self._services.records.delete(image_type, image_name)
+ self._services.files.delete(image_origin, image_name)
+ self._services.records.delete(image_origin, image_name)
except ImageRecordDeleteException:
self._services.logger.error(f"Failed to delete image record")
raise
@@ -347,21 +380,6 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem deleting image record and file")
raise e
- def _create_image_name(
- self,
- image_type: ImageType,
- image_category: ImageCategory,
- node_id: Optional[str] = None,
- session_id: Optional[str] = None,
- ) -> str:
- """Create a unique image name."""
- uuid_str = str(uuid.uuid4())
-
- if node_id is not None and session_id is not None:
- return f"{image_type.value}_{image_category.value}_{session_id}_{node_id}_{uuid_str}.png"
-
- return f"{image_type.value}_{image_category.value}_{uuid_str}.png"
-
def _get_metadata(
self, session_id: Optional[str] = None, node_id: Optional[str] = None
) -> Union[ImageMetadata, None]:
diff --git a/invokeai/app/services/models/image_record.py b/invokeai/app/services/models/image_record.py
index c1155ff73e..051236b12b 100644
--- a/invokeai/app/services/models/image_record.py
+++ b/invokeai/app/services/models/image_record.py
@@ -1,7 +1,7 @@
import datetime
from typing import Optional, Union
-from pydantic import BaseModel, Field
-from invokeai.app.models.image import ImageCategory, ImageType
+from pydantic import BaseModel, Extra, Field, StrictBool, StrictStr
+from invokeai.app.models.image import ImageCategory, ResourceOrigin
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.util.misc import get_iso_timestamp
@@ -11,8 +11,8 @@ class ImageRecord(BaseModel):
image_name: str = Field(description="The unique name of the image.")
"""The unique name of the image."""
- image_type: ImageType = Field(description="The type of the image.")
- """The type of the image."""
+ image_origin: ResourceOrigin = Field(description="The origin of the image.")
+ """The origin of the image."""
image_category: ImageCategory = Field(description="The category of the image.")
"""The category of the image."""
width: int = Field(description="The width of the image in px.")
@@ -31,6 +31,8 @@ class ImageRecord(BaseModel):
description="The deleted timestamp of the image."
)
"""The deleted timestamp of the image."""
+ is_intermediate: bool = Field(description="Whether this is an intermediate image.")
+ """Whether this is an intermediate image."""
session_id: Optional[str] = Field(
default=None,
description="The session ID that generated this image, if it is a generated image.",
@@ -48,13 +50,37 @@ class ImageRecord(BaseModel):
"""A limited subset of the image's generation metadata. Retrieve the image's session for full metadata."""
+class ImageRecordChanges(BaseModel, extra=Extra.forbid):
+ """A set of changes to apply to an image record.
+
+ Only limited changes are valid:
+ - `image_category`: change the category of an image
+ - `session_id`: change the session associated with an image
+ - `is_intermediate`: change the image's `is_intermediate` flag
+ """
+
+ image_category: Optional[ImageCategory] = Field(
+ description="The image's new category."
+ )
+ """The image's new category."""
+ session_id: Optional[StrictStr] = Field(
+ default=None,
+ description="The image's new session ID.",
+ )
+ """The image's new session ID."""
+ is_intermediate: Optional[StrictBool] = Field(
+ default=None, description="The image's new `is_intermediate` flag."
+ )
+ """The image's new `is_intermediate` flag."""
+
+
class ImageUrlsDTO(BaseModel):
"""The URLs for an image and its thumbnail."""
image_name: str = Field(description="The unique name of the image.")
"""The unique name of the image."""
- image_type: ImageType = Field(description="The type of the image.")
- """The type of the image."""
+ image_origin: ResourceOrigin = Field(description="The origin of the image.")
+ """The origin of the image."""
image_url: str = Field(description="The URL of the image.")
"""The URL of the image."""
thumbnail_url: str = Field(description="The URL of the image's thumbnail.")
@@ -84,7 +110,9 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
# Retrieve all the values, setting "reasonable" defaults if they are not present.
image_name = image_dict.get("image_name", "unknown")
- image_type = ImageType(image_dict.get("image_type", ImageType.RESULT.value))
+ image_origin = ResourceOrigin(
+ image_dict.get("image_origin", ResourceOrigin.INTERNAL.value)
+ )
image_category = ImageCategory(
image_dict.get("image_category", ImageCategory.GENERAL.value)
)
@@ -95,6 +123,7 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
created_at = image_dict.get("created_at", get_iso_timestamp())
updated_at = image_dict.get("updated_at", get_iso_timestamp())
deleted_at = image_dict.get("deleted_at", get_iso_timestamp())
+ is_intermediate = image_dict.get("is_intermediate", False)
raw_metadata = image_dict.get("metadata")
@@ -105,7 +134,7 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
return ImageRecord(
image_name=image_name,
- image_type=image_type,
+ image_origin=image_origin,
image_category=image_category,
width=width,
height=height,
@@ -115,4 +144,5 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
created_at=created_at,
updated_at=updated_at,
deleted_at=deleted_at,
+ is_intermediate=is_intermediate,
)
diff --git a/invokeai/app/services/resource_name.py b/invokeai/app/services/resource_name.py
new file mode 100644
index 0000000000..dd5a76cfc0
--- /dev/null
+++ b/invokeai/app/services/resource_name.py
@@ -0,0 +1,30 @@
+from abc import ABC, abstractmethod
+from enum import Enum, EnumMeta
+import uuid
+
+
+class ResourceType(str, Enum, metaclass=EnumMeta):
+ """Enum for resource types."""
+
+ IMAGE = "image"
+ LATENT = "latent"
+
+
+class NameServiceBase(ABC):
+ """Low-level service responsible for naming resources (images, latents, etc)."""
+
+ # TODO: Add customizable naming schemes
+ @abstractmethod
+ def create_image_name(self) -> str:
+ """Creates a name for an image."""
+ pass
+
+
+class SimpleNameService(NameServiceBase):
+ """Creates image names from UUIDs."""
+
+ # TODO: Add customizable naming schemes
+ def create_image_name(self) -> str:
+ uuid_str = str(uuid.uuid4())
+ filename = f"{uuid_str}.png"
+ return filename
diff --git a/invokeai/app/services/urls.py b/invokeai/app/services/urls.py
index 2716da60ad..4c8354c899 100644
--- a/invokeai/app/services/urls.py
+++ b/invokeai/app/services/urls.py
@@ -1,7 +1,7 @@
import os
from abc import ABC, abstractmethod
-from invokeai.app.models.image import ImageType
+from invokeai.app.models.image import ResourceOrigin
from invokeai.app.util.thumbnails import get_thumbnail_name
@@ -10,7 +10,7 @@ class UrlServiceBase(ABC):
@abstractmethod
def get_image_url(
- self, image_type: ImageType, image_name: str, thumbnail: bool = False
+ self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
"""Gets the URL for an image or thumbnail."""
pass
@@ -21,14 +21,14 @@ class LocalUrlService(UrlServiceBase):
self._base_url = base_url
def get_image_url(
- self, image_type: ImageType, image_name: str, thumbnail: bool = False
+ self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
image_basename = os.path.basename(image_name)
# These paths are determined by the routes in invokeai/app/api/routers/images.py
if thumbnail:
return (
- f"{self._base_url}/images/{image_type.value}/{image_basename}/thumbnail"
+ f"{self._base_url}/images/{image_origin.value}/{image_basename}/thumbnail"
)
- return f"{self._base_url}/images/{image_type.value}/{image_basename}"
+ return f"{self._base_url}/images/{image_origin.value}/{image_basename}"
diff --git a/invokeai/app/util/step_callback.py b/invokeai/app/util/step_callback.py
index 963e770406..b4b9a25909 100644
--- a/invokeai/app/util/step_callback.py
+++ b/invokeai/app/util/step_callback.py
@@ -1,5 +1,5 @@
-from invokeai.app.api.models.images import ProgressImage
from invokeai.app.models.exceptions import CanceledException
+from invokeai.app.models.image import ProgressImage
from ..invocations.baseinvocation import InvocationContext
from ...backend.util.util import image_to_dataURL
from ...backend.generator.base import Generator
diff --git a/invokeai/backend/generator/base.py b/invokeai/backend/generator/base.py
index bf976e7290..fb293ab5b2 100644
--- a/invokeai/backend/generator/base.py
+++ b/invokeai/backend/generator/base.py
@@ -75,9 +75,11 @@ class InvokeAIGenerator(metaclass=ABCMeta):
def __init__(self,
model_info: dict,
params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
+ **kwargs,
):
self.model_info=model_info
self.params=params
+ self.kwargs = kwargs
def generate(self,
prompt: str='',
@@ -120,7 +122,7 @@ class InvokeAIGenerator(metaclass=ABCMeta):
)
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt,model=model)
gen_class = self._generator_class()
- generator = gen_class(model, self.params.precision)
+ generator = gen_class(model, self.params.precision, **self.kwargs)
if self.params.variation_amount > 0:
generator.set_variation(generator_args.get('seed'),
generator_args.get('variation_amount'),
@@ -275,7 +277,7 @@ class Generator:
precision: str
model: DiffusionPipeline
- def __init__(self, model: DiffusionPipeline, precision: str):
+ def __init__(self, model: DiffusionPipeline, precision: str, **kwargs):
self.model = model
self.precision = precision
self.seed = None
diff --git a/invokeai/backend/generator/txt2img.py b/invokeai/backend/generator/txt2img.py
index e5a96212f0..189fe65710 100644
--- a/invokeai/backend/generator/txt2img.py
+++ b/invokeai/backend/generator/txt2img.py
@@ -4,6 +4,10 @@ invokeai.backend.generator.txt2img inherits from invokeai.backend.generator
import PIL.Image
import torch
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from diffusers.models.controlnet import ControlNetModel, ControlNetOutput
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
+
from ..stable_diffusion import (
ConditioningData,
PostprocessingSettings,
@@ -13,8 +17,13 @@ from .base import Generator
class Txt2Img(Generator):
- def __init__(self, model, precision):
- super().__init__(model, precision)
+ def __init__(self, model, precision,
+ control_model: Optional[Union[ControlNetModel, List[ControlNetModel]]] = None,
+ **kwargs):
+ self.control_model = control_model
+ if isinstance(self.control_model, list):
+ self.control_model = MultiControlNetModel(self.control_model)
+ super().__init__(model, precision, **kwargs)
@torch.no_grad()
def get_make_image(
@@ -42,9 +51,12 @@ class Txt2Img(Generator):
kwargs are 'width' and 'height'
"""
self.perlin = perlin
+ control_image = kwargs.get("control_image", None)
+ do_classifier_free_guidance = cfg_scale > 1.0
# noinspection PyTypeChecker
pipeline: StableDiffusionGeneratorPipeline = self.model
+ pipeline.control_model = self.control_model
pipeline.scheduler = sampler
uc, c, extra_conditioning_info = conditioning
@@ -61,6 +73,37 @@ class Txt2Img(Generator):
),
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
+ # FIXME: still need to test with different widths, heights, devices, dtypes
+ # and add in batch_size, num_images_per_prompt?
+ if control_image is not None:
+ if isinstance(self.control_model, ControlNetModel):
+ control_image = pipeline.prepare_control_image(
+ image=control_image,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ width=width,
+ height=height,
+ # batch_size=batch_size * num_images_per_prompt,
+ # num_images_per_prompt=num_images_per_prompt,
+ device=self.control_model.device,
+ dtype=self.control_model.dtype,
+ )
+ elif isinstance(self.control_model, MultiControlNetModel):
+ images = []
+ for image_ in control_image:
+ image_ = pipeline.prepare_control_image(
+ image=image_,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ width=width,
+ height=height,
+ # batch_size=batch_size * num_images_per_prompt,
+ # num_images_per_prompt=num_images_per_prompt,
+ device=self.control_model.device,
+ dtype=self.control_model.dtype,
+ )
+ images.append(image_)
+ control_image = images
+ kwargs["control_image"] = control_image
+
def make_image(x_T: torch.Tensor, _: int) -> PIL.Image.Image:
pipeline_output = pipeline.image_from_embeddings(
latents=torch.zeros_like(x_T, dtype=self.torch_dtype()),
@@ -68,6 +111,7 @@ class Txt2Img(Generator):
num_inference_steps=steps,
conditioning_data=conditioning_data,
callback=step_callback,
+ **kwargs,
)
if (
diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
index 4ca2a5cb30..c07b41e6c1 100644
--- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py
+++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -2,23 +2,29 @@ from __future__ import annotations
import dataclasses
import inspect
+import math
import secrets
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union
+from pydantic import BaseModel, Field
import einops
import PIL.Image
+import numpy as np
from accelerate.utils import set_seed
import psutil
import torch
import torchvision.transforms as T
from compel import EmbeddingsProvider
from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.models.controlnet import ControlNetModel, ControlNetOutput
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
StableDiffusionPipeline,
)
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
+
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import (
StableDiffusionImg2ImgPipeline,
)
@@ -27,6 +33,7 @@ from diffusers.pipelines.stable_diffusion.safety_checker import (
)
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
+from diffusers.utils import PIL_INTERPOLATION
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.outputs import BaseOutput
from torchvision.transforms.functional import resize as tv_resize
@@ -68,10 +75,10 @@ class AddsMaskLatents:
initial_image_latents: torch.Tensor
def __call__(
- self, latents: torch.Tensor, t: torch.Tensor, text_embeddings: torch.Tensor
+ self, latents: torch.Tensor, t: torch.Tensor, text_embeddings: torch.Tensor, **kwargs,
) -> torch.Tensor:
model_input = self.add_mask_channels(latents)
- return self.forward(model_input, t, text_embeddings)
+ return self.forward(model_input, t, text_embeddings, **kwargs)
def add_mask_channels(self, latents):
batch_size = latents.size(0)
@@ -207,6 +214,13 @@ class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
raise AssertionError("why was that an empty generator?")
return result
+@dataclass
+class ControlNetData:
+ model: ControlNetModel = Field(default=None)
+ image_tensor: torch.Tensor= Field(default=None)
+ weight: float = Field(default=1.0)
+ begin_step_percent: float = Field(default=0.0)
+ end_step_percent: float = Field(default=1.0)
@dataclass(frozen=True)
class ConditioningData:
@@ -302,6 +316,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
feature_extractor: Optional[CLIPFeatureExtractor],
requires_safety_checker: bool = False,
precision: str = "float32",
+ control_model: ControlNetModel = None,
):
super().__init__(
vae,
@@ -322,6 +337,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
+ # FIXME: can't currently register control module
+ # control_model=control_model,
)
self.invokeai_diffuser = InvokeAIDiffuserComponent(
self.unet, self._unet_forward, is_running_diffusers=True
@@ -341,6 +358,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
self._model_group = FullyLoadedModelGroup(self.unet.device)
self._model_group.install(*self._submodels)
+ self.control_model = control_model
def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
"""
@@ -463,6 +481,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
noise: torch.Tensor,
callback: Callable[[PipelineIntermediateState], None] = None,
run_id=None,
+ **kwargs,
) -> InvokeAIStableDiffusionPipelineOutput:
r"""
Function invoked when calling the pipeline for generation.
@@ -483,6 +502,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
noise=noise,
run_id=run_id,
callback=callback,
+ **kwargs,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache()
@@ -507,6 +527,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
additional_guidance: List[Callable] = None,
run_id=None,
callback: Callable[[PipelineIntermediateState], None] = None,
+ control_data: List[ControlNetData] = None,
+ **kwargs,
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
if self.scheduler.config.get("cpu_only", False):
scheduler_device = torch.device('cpu')
@@ -527,6 +549,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
additional_guidance=additional_guidance,
run_id=run_id,
callback=callback,
+ control_data=control_data,
+ **kwargs,
)
return result.latents, result.attention_map_saver
@@ -539,6 +563,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
noise: torch.Tensor,
run_id: str = None,
additional_guidance: List[Callable] = None,
+ control_data: List[ControlNetData] = None,
+ **kwargs,
):
self._adjust_memory_efficient_attention(latents)
if run_id is None:
@@ -568,7 +594,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
latents = self.scheduler.add_noise(latents, noise, batched_t)
attention_map_saver: Optional[AttentionMapSaver] = None
-
+ # print("timesteps:", timesteps)
for i, t in enumerate(self.progress_bar(timesteps)):
batched_t.fill_(t)
step_output = self.step(
@@ -578,6 +604,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
step_index=i,
total_step_count=len(timesteps),
additional_guidance=additional_guidance,
+ control_data=control_data,
+ **kwargs,
)
latents = step_output.prev_sample
@@ -618,10 +646,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
step_index: int,
total_step_count: int,
additional_guidance: List[Callable] = None,
+ control_data: List[ControlNetData] = None,
+ **kwargs,
):
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
timestep = t[0]
-
if additional_guidance is None:
additional_guidance = []
@@ -629,6 +658,48 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# i.e. before or after passing it to InvokeAIDiffuserComponent
latent_model_input = self.scheduler.scale_model_input(latents, timestep)
+ # default is no controlnet, so set controlnet processing output to None
+ down_block_res_samples, mid_block_res_sample = None, None
+
+ if control_data is not None:
+ if conditioning_data.guidance_scale > 1.0:
+ # expand the latents input to control model if doing classifier free guidance
+ # (which I think for now is always true, there is conditional elsewhere that stops execution if
+ # classifier_free_guidance is <= 1.0 ?)
+ latent_control_input = torch.cat([latent_model_input] * 2)
+ else:
+ latent_control_input = latent_model_input
+ # control_data should be type List[ControlNetData]
+ # this loop covers both ControlNet (one ControlNetData in list)
+ # and MultiControlNet (multiple ControlNetData in list)
+ for i, control_datum in enumerate(control_data):
+ # print("controlnet", i, "==>", type(control_datum))
+ first_control_step = math.floor(control_datum.begin_step_percent * total_step_count)
+ last_control_step = math.ceil(control_datum.end_step_percent * total_step_count)
+ # only apply controlnet if current step is within the controlnet's begin/end step range
+ if step_index >= first_control_step and step_index <= last_control_step:
+ # print("running controlnet", i, "for step", step_index)
+ down_samples, mid_sample = control_datum.model(
+ sample=latent_control_input,
+ timestep=timestep,
+ encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings,
+ conditioning_data.text_embeddings]),
+ controlnet_cond=control_datum.image_tensor,
+ conditioning_scale=control_datum.weight,
+ # cross_attention_kwargs,
+ guess_mode=False,
+ return_dict=False,
+ )
+ if down_block_res_samples is None and mid_block_res_sample is None:
+ down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
+ else:
+ # add controlnet outputs together if have multiple controlnets
+ down_block_res_samples = [
+ samples_prev + samples_curr
+ for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
+ ]
+ mid_block_res_sample += mid_sample
+
# predict the noise residual
noise_pred = self.invokeai_diffuser.do_diffusion_step(
latent_model_input,
@@ -638,6 +709,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
conditioning_data.guidance_scale,
step_index=step_index,
total_step_count=total_step_count,
+ down_block_additional_residuals=down_block_res_samples, # from controlnet(s)
+ mid_block_additional_residual=mid_block_res_sample, # from controlnet(s)
)
# compute the previous noisy sample x_t -> x_t-1
@@ -659,6 +732,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
t,
text_embeddings,
cross_attention_kwargs: Optional[dict[str, Any]] = None,
+ **kwargs,
):
"""predict the noise residual"""
if is_inpainting_model(self.unet) and latents.size(1) == 4:
@@ -678,7 +752,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# First three args should be positional, not keywords, so torch hooks can see them.
return self.unet(
- latents, t, text_embeddings, cross_attention_kwargs=cross_attention_kwargs
+ latents, t, text_embeddings, cross_attention_kwargs=cross_attention_kwargs,
+ **kwargs,
).sample
def img2img_from_embeddings(
@@ -728,7 +803,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
noise: torch.Tensor,
run_id=None,
callback=None,
- ) -> InvokeAIStableDiffusionPipelineOutput:
+ ) -> InvokeAIStableDiffusionPipelineOutput:
timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength)
result_latents, result_attention_maps = self.latents_from_embeddings(
latents=initial_latents if strength < 1.0 else torch.zeros_like(
@@ -940,3 +1015,51 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
debug_image(
img, f"latents {msg} {i+1}/{len(decoded)}", debug_status=True
)
+
+ # Copied from diffusers pipeline_stable_diffusion_controlnet.py
+ # Returns torch.Tensor of shape (batch_size, 3, height, width)
+ @staticmethod
+ def prepare_control_image(
+ image,
+ # FIXME: need to fix hardwiring of width and height, change to basing on latents dimensions?
+ # latents,
+ width=512, # should be 8 * latent.shape[3]
+ height=512, # should be 8 * latent height[2]
+ batch_size=1,
+ num_images_per_prompt=1,
+ device="cuda",
+ dtype=torch.float16,
+ do_classifier_free_guidance=True,
+ ):
+
+ if not isinstance(image, torch.Tensor):
+ if isinstance(image, PIL.Image.Image):
+ image = [image]
+
+ if isinstance(image[0], PIL.Image.Image):
+ images = []
+ for image_ in image:
+ image_ = image_.convert("RGB")
+ image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
+ image_ = np.array(image_)
+ image_ = image_[None, :]
+ images.append(image_)
+ image = images
+ image = np.concatenate(image, axis=0)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image.transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image)
+ elif isinstance(image[0], torch.Tensor):
+ image = torch.cat(image, dim=0)
+
+ image_batch_size = image.shape[0]
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_images_per_prompt
+ image = image.repeat_interleave(repeat_by, dim=0)
+ image = image.to(device=device, dtype=dtype)
+ if do_classifier_free_guidance:
+ image = torch.cat([image] * 2)
+ return image
diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
index 4131837b41..d05565c506 100644
--- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
+++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
@@ -181,6 +181,7 @@ class InvokeAIDiffuserComponent:
unconditional_guidance_scale: float,
step_index: Optional[int] = None,
total_step_count: Optional[int] = None,
+ **kwargs,
):
"""
:param x: current latents
@@ -209,7 +210,7 @@ class InvokeAIDiffuserComponent:
if wants_hybrid_conditioning:
unconditioned_next_x, conditioned_next_x = self._apply_hybrid_conditioning(
- x, sigma, unconditioning, conditioning
+ x, sigma, unconditioning, conditioning, **kwargs,
)
elif wants_cross_attention_control:
(
@@ -221,13 +222,14 @@ class InvokeAIDiffuserComponent:
unconditioning,
conditioning,
cross_attention_control_types_to_do,
+ **kwargs,
)
elif self.sequential_guidance:
(
unconditioned_next_x,
conditioned_next_x,
) = self._apply_standard_conditioning_sequentially(
- x, sigma, unconditioning, conditioning
+ x, sigma, unconditioning, conditioning, **kwargs,
)
else:
@@ -235,7 +237,7 @@ class InvokeAIDiffuserComponent:
unconditioned_next_x,
conditioned_next_x,
) = self._apply_standard_conditioning(
- x, sigma, unconditioning, conditioning
+ x, sigma, unconditioning, conditioning, **kwargs,
)
combined_next_x = self._combine(
@@ -282,13 +284,13 @@ class InvokeAIDiffuserComponent:
# methods below are called from do_diffusion_step and should be considered private to this class.
- def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning):
+ def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
# fast batched path
x_twice = torch.cat([x] * 2)
sigma_twice = torch.cat([sigma] * 2)
both_conditionings = torch.cat([unconditioning, conditioning])
both_results = self.model_forward_callback(
- x_twice, sigma_twice, both_conditionings
+ x_twice, sigma_twice, both_conditionings, **kwargs,
)
unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
if conditioned_next_x.device.type == "mps":
@@ -302,16 +304,17 @@ class InvokeAIDiffuserComponent:
sigma,
unconditioning: torch.Tensor,
conditioning: torch.Tensor,
+ **kwargs,
):
# low-memory sequential path
- unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)
- conditioned_next_x = self.model_forward_callback(x, sigma, conditioning)
+ unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs)
+ conditioned_next_x = self.model_forward_callback(x, sigma, conditioning, **kwargs)
if conditioned_next_x.device.type == "mps":
# prevent a result filled with zeros. seems to be a torch bug.
conditioned_next_x = conditioned_next_x.clone()
return unconditioned_next_x, conditioned_next_x
- def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning):
+ def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
assert isinstance(conditioning, dict)
assert isinstance(unconditioning, dict)
x_twice = torch.cat([x] * 2)
@@ -326,7 +329,7 @@ class InvokeAIDiffuserComponent:
else:
both_conditionings[k] = torch.cat([unconditioning[k], conditioning[k]])
unconditioned_next_x, conditioned_next_x = self.model_forward_callback(
- x_twice, sigma_twice, both_conditionings
+ x_twice, sigma_twice, both_conditionings, **kwargs,
).chunk(2)
return unconditioned_next_x, conditioned_next_x
@@ -337,6 +340,7 @@ class InvokeAIDiffuserComponent:
unconditioning,
conditioning,
cross_attention_control_types_to_do,
+ **kwargs,
):
if self.is_running_diffusers:
return self._apply_cross_attention_controlled_conditioning__diffusers(
@@ -345,6 +349,7 @@ class InvokeAIDiffuserComponent:
unconditioning,
conditioning,
cross_attention_control_types_to_do,
+ **kwargs,
)
else:
return self._apply_cross_attention_controlled_conditioning__compvis(
@@ -353,6 +358,7 @@ class InvokeAIDiffuserComponent:
unconditioning,
conditioning,
cross_attention_control_types_to_do,
+ **kwargs,
)
def _apply_cross_attention_controlled_conditioning__diffusers(
@@ -362,6 +368,7 @@ class InvokeAIDiffuserComponent:
unconditioning,
conditioning,
cross_attention_control_types_to_do,
+ **kwargs,
):
context: Context = self.cross_attention_control_context
@@ -377,6 +384,7 @@ class InvokeAIDiffuserComponent:
sigma,
unconditioning,
{"swap_cross_attn_context": cross_attn_processor_context},
+ **kwargs,
)
# do requested cross attention types for conditioning (positive prompt)
@@ -388,6 +396,7 @@ class InvokeAIDiffuserComponent:
sigma,
conditioning,
{"swap_cross_attn_context": cross_attn_processor_context},
+ **kwargs,
)
return unconditioned_next_x, conditioned_next_x
@@ -398,6 +407,7 @@ class InvokeAIDiffuserComponent:
unconditioning,
conditioning,
cross_attention_control_types_to_do,
+ **kwargs,
):
# print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do)
# slower non-batched path (20% slower on mac MPS)
@@ -411,13 +421,13 @@ class InvokeAIDiffuserComponent:
context: Context = self.cross_attention_control_context
try:
- unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)
+ unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs)
# process x using the original prompt, saving the attention maps
# print("saving attention maps for", cross_attention_control_types_to_do)
for ca_type in cross_attention_control_types_to_do:
context.request_save_attention_maps(ca_type)
- _ = self.model_forward_callback(x, sigma, conditioning)
+ _ = self.model_forward_callback(x, sigma, conditioning, **kwargs,)
context.clear_requests(cleanup=False)
# process x again, using the saved attention maps to control where self.edited_conditioning will be applied
@@ -428,7 +438,7 @@ class InvokeAIDiffuserComponent:
self.conditioning.cross_attention_control_args.edited_conditioning
)
conditioned_next_x = self.model_forward_callback(
- x, sigma, edited_conditioning
+ x, sigma, edited_conditioning, **kwargs,
)
context.clear_requests(cleanup=True)
diff --git a/invokeai/frontend/web/dist/index.html b/invokeai/frontend/web/dist/index.html
index 63618e60be..8a982a7268 100644
--- a/invokeai/frontend/web/dist/index.html
+++ b/invokeai/frontend/web/dist/index.html
@@ -12,7 +12,7 @@
margin: 0;
}
-
+
diff --git a/invokeai/frontend/web/dist/locales/en.json b/invokeai/frontend/web/dist/locales/en.json
index 94dff3934a..bf14dd5510 100644
--- a/invokeai/frontend/web/dist/locales/en.json
+++ b/invokeai/frontend/web/dist/locales/en.json
@@ -122,7 +122,9 @@
"noImagesInGallery": "No Images In Gallery",
"deleteImage": "Delete Image",
"deleteImageBin": "Deleted images will be sent to your operating system's Bin.",
- "deleteImagePermanent": "Deleted images cannot be restored."
+ "deleteImagePermanent": "Deleted images cannot be restored.",
+ "images": "Images",
+ "assets": "Assets"
},
"hotkeys": {
"keyboardShortcuts": "Keyboard Shortcuts",
@@ -452,6 +454,8 @@
"height": "Height",
"scheduler": "Scheduler",
"seed": "Seed",
+ "boundingBoxWidth": "Bounding Box Width",
+ "boundingBoxHeight": "Bounding Box Height",
"imageToImage": "Image to Image",
"randomizeSeed": "Randomize Seed",
"shuffle": "Shuffle Seed",
@@ -524,7 +528,7 @@
},
"settings": {
"models": "Models",
- "displayInProgress": "Display In-Progress Images",
+ "displayInProgress": "Display Progress Images",
"saveSteps": "Save images every n steps",
"confirmOnDelete": "Confirm On Delete",
"displayHelpIcons": "Display Help Icons",
@@ -564,6 +568,8 @@
"canvasMerged": "Canvas Merged",
"sentToImageToImage": "Sent To Image To Image",
"sentToUnifiedCanvas": "Sent to Unified Canvas",
+ "parameterSet": "Parameter set",
+ "parameterNotSet": "Parameter not set",
"parametersSet": "Parameters Set",
"parametersNotSet": "Parameters Not Set",
"parametersNotSetDesc": "No metadata found for this image.",
diff --git a/invokeai/frontend/web/package.json b/invokeai/frontend/web/package.json
index 13b8d78bf7..dd1c87effb 100644
--- a/invokeai/frontend/web/package.json
+++ b/invokeai/frontend/web/package.json
@@ -101,7 +101,8 @@
"serialize-error": "^11.0.0",
"socket.io-client": "^4.6.0",
"use-image": "^1.1.0",
- "uuid": "^9.0.0"
+ "uuid": "^9.0.0",
+ "zod": "^3.21.4"
},
"peerDependencies": {
"@chakra-ui/cli": "^2.4.0",
diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json
index 94dff3934a..bf14dd5510 100644
--- a/invokeai/frontend/web/public/locales/en.json
+++ b/invokeai/frontend/web/public/locales/en.json
@@ -122,7 +122,9 @@
"noImagesInGallery": "No Images In Gallery",
"deleteImage": "Delete Image",
"deleteImageBin": "Deleted images will be sent to your operating system's Bin.",
- "deleteImagePermanent": "Deleted images cannot be restored."
+ "deleteImagePermanent": "Deleted images cannot be restored.",
+ "images": "Images",
+ "assets": "Assets"
},
"hotkeys": {
"keyboardShortcuts": "Keyboard Shortcuts",
@@ -452,6 +454,8 @@
"height": "Height",
"scheduler": "Scheduler",
"seed": "Seed",
+ "boundingBoxWidth": "Bounding Box Width",
+ "boundingBoxHeight": "Bounding Box Height",
"imageToImage": "Image to Image",
"randomizeSeed": "Randomize Seed",
"shuffle": "Shuffle Seed",
@@ -524,7 +528,7 @@
},
"settings": {
"models": "Models",
- "displayInProgress": "Display In-Progress Images",
+ "displayInProgress": "Display Progress Images",
"saveSteps": "Save images every n steps",
"confirmOnDelete": "Confirm On Delete",
"displayHelpIcons": "Display Help Icons",
@@ -564,6 +568,8 @@
"canvasMerged": "Canvas Merged",
"sentToImageToImage": "Sent To Image To Image",
"sentToUnifiedCanvas": "Sent to Unified Canvas",
+ "parameterSet": "Parameter set",
+ "parameterNotSet": "Parameter not set",
"parametersSet": "Parameters Set",
"parametersNotSet": "Parameters Not Set",
"parametersNotSetDesc": "No metadata found for this image.",
diff --git a/invokeai/frontend/web/src/app/constants.ts b/invokeai/frontend/web/src/app/constants.ts
index d312d725ba..6700a732b3 100644
--- a/invokeai/frontend/web/src/app/constants.ts
+++ b/invokeai/frontend/web/src/app/constants.ts
@@ -21,25 +21,11 @@ export const SCHEDULERS = [
export type Scheduler = (typeof SCHEDULERS)[number];
-export const isScheduler = (x: string): x is Scheduler =>
- SCHEDULERS.includes(x as Scheduler);
-
-// Valid image widths
-export const WIDTHS: Array = Array.from(Array(64)).map(
- (_x, i) => (i + 1) * 64
-);
-
-// Valid image heights
-export const HEIGHTS: Array = Array.from(Array(64)).map(
- (_x, i) => (i + 1) * 64
-);
-
// Valid upscaling levels
export const UPSCALING_LEVELS: Array<{ key: string; value: number }> = [
{ key: '2x', value: 2 },
{ key: '4x', value: 4 },
];
-
export const NUMPY_RAND_MIN = 0;
export const NUMPY_RAND_MAX = 2147483647;
diff --git a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts
index 52995e0da3..9fb4ceae32 100644
--- a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts
+++ b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts
@@ -1,7 +1,5 @@
import { canvasPersistDenylist } from 'features/canvas/store/canvasPersistDenylist';
import { galleryPersistDenylist } from 'features/gallery/store/galleryPersistDenylist';
-import { resultsPersistDenylist } from 'features/gallery/store/resultsPersistDenylist';
-import { uploadsPersistDenylist } from 'features/gallery/store/uploadsPersistDenylist';
import { lightboxPersistDenylist } from 'features/lightbox/store/lightboxPersistDenylist';
import { nodesPersistDenylist } from 'features/nodes/store/nodesPersistDenylist';
import { generationPersistDenylist } from 'features/parameters/store/generationPersistDenylist';
@@ -22,11 +20,9 @@ const serializationDenylist: {
models: modelsPersistDenylist,
nodes: nodesPersistDenylist,
postprocessing: postprocessingPersistDenylist,
- results: resultsPersistDenylist,
system: systemPersistDenylist,
// config: configPersistDenyList,
ui: uiPersistDenylist,
- uploads: uploadsPersistDenylist,
// hotkeys: hotkeysPersistDenylist,
};
diff --git a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/unserialize.ts b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/unserialize.ts
index 155a7786b3..c6ae4946f2 100644
--- a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/unserialize.ts
+++ b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/unserialize.ts
@@ -1,7 +1,6 @@
import { initialCanvasState } from 'features/canvas/store/canvasSlice';
import { initialGalleryState } from 'features/gallery/store/gallerySlice';
-import { initialResultsState } from 'features/gallery/store/resultsSlice';
-import { initialUploadsState } from 'features/gallery/store/uploadsSlice';
+import { initialImagesState } from 'features/gallery/store/imagesSlice';
import { initialLightboxState } from 'features/lightbox/store/lightboxSlice';
import { initialNodesState } from 'features/nodes/store/nodesSlice';
import { initialGenerationState } from 'features/parameters/store/generationSlice';
@@ -24,12 +23,11 @@ const initialStates: {
models: initialModelsState,
nodes: initialNodesState,
postprocessing: initialPostprocessingState,
- results: initialResultsState,
system: initialSystemState,
config: initialConfigState,
ui: initialUIState,
- uploads: initialUploadsState,
hotkeys: initialHotkeysState,
+ images: initialImagesState,
};
export const unserialize: UnserializeFunction = (data, key) => {
diff --git a/invokeai/frontend/web/src/app/store/middleware/devtools/actionsDenylist.ts b/invokeai/frontend/web/src/app/store/middleware/devtools/actionsDenylist.ts
index 743537d7ea..eb54868735 100644
--- a/invokeai/frontend/web/src/app/store/middleware/devtools/actionsDenylist.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/devtools/actionsDenylist.ts
@@ -7,5 +7,6 @@ export const actionsDenylist = [
'canvas/setBoundingBoxDimensions',
'canvas/setIsDrawing',
'canvas/addPointToCurrentLine',
- 'socket/generatorProgress',
+ 'socket/socketGeneratorProgress',
+ 'socket/appSocketGeneratorProgress',
];
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts
index f23e83a191..ba16e56371 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts
@@ -8,9 +8,16 @@ import type { TypedStartListening, TypedAddListener } from '@reduxjs/toolkit';
import type { RootState, AppDispatch } from '../../store';
import { addInitialImageSelectedListener } from './listeners/initialImageSelected';
-import { addImageResultReceivedListener } from './listeners/invocationComplete';
-import { addImageUploadedListener } from './listeners/imageUploaded';
-import { addRequestedImageDeletionListener } from './listeners/imageDeleted';
+import {
+ addImageUploadedFulfilledListener,
+ addImageUploadedRejectedListener,
+} from './listeners/imageUploaded';
+import {
+ addImageDeletedFulfilledListener,
+ addImageDeletedPendingListener,
+ addImageDeletedRejectedListener,
+ addRequestedImageDeletionListener,
+} from './listeners/imageDeleted';
import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas';
import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
@@ -19,6 +26,50 @@ import { addCanvasSavedToGalleryListener } from './listeners/canvasSavedToGaller
import { addCanvasDownloadedAsImageListener } from './listeners/canvasDownloadedAsImage';
import { addCanvasCopiedToClipboardListener } from './listeners/canvasCopiedToClipboard';
import { addCanvasMergedListener } from './listeners/canvasMerged';
+import { addGeneratorProgressEventListener as addGeneratorProgressListener } from './listeners/socketio/socketGeneratorProgress';
+import { addGraphExecutionStateCompleteEventListener as addGraphExecutionStateCompleteListener } from './listeners/socketio/socketGraphExecutionStateComplete';
+import { addInvocationCompleteEventListener as addInvocationCompleteListener } from './listeners/socketio/socketInvocationComplete';
+import { addInvocationErrorEventListener as addInvocationErrorListener } from './listeners/socketio/socketInvocationError';
+import { addInvocationStartedEventListener as addInvocationStartedListener } from './listeners/socketio/socketInvocationStarted';
+import { addSocketConnectedEventListener as addSocketConnectedListener } from './listeners/socketio/socketConnected';
+import { addSocketDisconnectedEventListener as addSocketDisconnectedListener } from './listeners/socketio/socketDisconnected';
+import { addSocketSubscribedEventListener as addSocketSubscribedListener } from './listeners/socketio/socketSubscribed';
+import { addSocketUnsubscribedEventListener as addSocketUnsubscribedListener } from './listeners/socketio/socketUnsubscribed';
+import { addSessionReadyToInvokeListener } from './listeners/sessionReadyToInvoke';
+import {
+ addImageMetadataReceivedFulfilledListener,
+ addImageMetadataReceivedRejectedListener,
+} from './listeners/imageMetadataReceived';
+import {
+ addImageUrlsReceivedFulfilledListener,
+ addImageUrlsReceivedRejectedListener,
+} from './listeners/imageUrlsReceived';
+import {
+ addSessionCreatedFulfilledListener,
+ addSessionCreatedPendingListener,
+ addSessionCreatedRejectedListener,
+} from './listeners/sessionCreated';
+import {
+ addSessionInvokedFulfilledListener,
+ addSessionInvokedPendingListener,
+ addSessionInvokedRejectedListener,
+} from './listeners/sessionInvoked';
+import {
+ addSessionCanceledFulfilledListener,
+ addSessionCanceledPendingListener,
+ addSessionCanceledRejectedListener,
+} from './listeners/sessionCanceled';
+import {
+ addImageUpdatedFulfilledListener,
+ addImageUpdatedRejectedListener,
+} from './listeners/imageUpdated';
+import {
+ addReceivedPageOfImagesFulfilledListener,
+ addReceivedPageOfImagesRejectedListener,
+} from './listeners/receivedPageOfImages';
+import { addStagingAreaImageSavedListener } from './listeners/stagingAreaImageSaved';
+import { addCommitStagingAreaImageListener } from './listeners/addCommitStagingAreaImageListener';
+import { addImageCategoriesChangedListener } from './listeners/imageCategoriesChanged';
export const listenerMiddleware = createListenerMiddleware();
@@ -38,17 +89,87 @@ export type AppListenerEffect = ListenerEffect<
AppDispatch
>;
-addImageUploadedListener();
-addInitialImageSelectedListener();
-addImageResultReceivedListener();
-addRequestedImageDeletionListener();
+// Image uploaded
+addImageUploadedFulfilledListener();
+addImageUploadedRejectedListener();
+// Image updated
+addImageUpdatedFulfilledListener();
+addImageUpdatedRejectedListener();
+
+// Image selected
+addInitialImageSelectedListener();
+
+// Image deleted
+addRequestedImageDeletionListener();
+addImageDeletedPendingListener();
+addImageDeletedFulfilledListener();
+addImageDeletedRejectedListener();
+
+// Image metadata
+addImageMetadataReceivedFulfilledListener();
+addImageMetadataReceivedRejectedListener();
+
+// Image URLs
+addImageUrlsReceivedFulfilledListener();
+addImageUrlsReceivedRejectedListener();
+
+// User Invoked
addUserInvokedCanvasListener();
addUserInvokedNodesListener();
addUserInvokedTextToImageListener();
addUserInvokedImageToImageListener();
+addSessionReadyToInvokeListener();
+// Canvas actions
addCanvasSavedToGalleryListener();
addCanvasDownloadedAsImageListener();
addCanvasCopiedToClipboardListener();
addCanvasMergedListener();
+addStagingAreaImageSavedListener();
+addCommitStagingAreaImageListener();
+
+/**
+ * Socket.IO Events - these handle SIO events directly and pass on internal application actions.
+ * We don't handle SIO events in slices via `extraReducers` because some of these events shouldn't
+ * actually be handled at all.
+ *
+ * For example, we don't want to respond to progress events for canceled sessions. To avoid
+ * duplicating the logic to determine if an event should be responded to, we handle all of that
+ * "is this session canceled?" logic in these listeners.
+ *
+ * The `socketGeneratorProgress` listener will then only dispatch the `appSocketGeneratorProgress`
+ * action if it should be handled by the rest of the application. It is this `appSocketGeneratorProgress`
+ * action that is handled by reducers in slices.
+ */
+addGeneratorProgressListener();
+addGraphExecutionStateCompleteListener();
+addInvocationCompleteListener();
+addInvocationErrorListener();
+addInvocationStartedListener();
+addSocketConnectedListener();
+addSocketDisconnectedListener();
+addSocketSubscribedListener();
+addSocketUnsubscribedListener();
+
+// Session Created
+addSessionCreatedPendingListener();
+addSessionCreatedFulfilledListener();
+addSessionCreatedRejectedListener();
+
+// Session Invoked
+addSessionInvokedPendingListener();
+addSessionInvokedFulfilledListener();
+addSessionInvokedRejectedListener();
+
+// Session Canceled
+addSessionCanceledPendingListener();
+addSessionCanceledFulfilledListener();
+addSessionCanceledRejectedListener();
+
+// Fetching images
+addReceivedPageOfImagesFulfilledListener();
+addReceivedPageOfImagesRejectedListener();
+
+// Gallery
+addImageCategoriesChangedListener();
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener.ts
new file mode 100644
index 0000000000..90f71879a1
--- /dev/null
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener.ts
@@ -0,0 +1,42 @@
+import { startAppListening } from '..';
+import { log } from 'app/logging/useLogger';
+import { commitStagingAreaImage } from 'features/canvas/store/canvasSlice';
+import { sessionCanceled } from 'services/thunks/session';
+
+const moduleLog = log.child({ namespace: 'canvas' });
+
+export const addCommitStagingAreaImageListener = () => {
+ startAppListening({
+ actionCreator: commitStagingAreaImage,
+ effect: async (action, { dispatch, getState }) => {
+ const state = getState();
+ const { sessionId, isProcessing } = state.system;
+ const canvasSessionId = action.payload;
+
+ if (!isProcessing) {
+ // Only need to cancel if we are processing
+ return;
+ }
+
+ if (!canvasSessionId) {
+ moduleLog.debug('No canvas session, skipping cancel');
+ return;
+ }
+
+ if (canvasSessionId !== sessionId) {
+ moduleLog.debug(
+ {
+ data: {
+ canvasSessionId,
+ sessionId,
+ },
+ },
+ 'Canvas session does not match global session, skipping cancel'
+ );
+ return;
+ }
+
+ dispatch(sessionCanceled({ sessionId }));
+ },
+ });
+};
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasMerged.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasMerged.ts
index 1e2d99541c..80865f3126 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasMerged.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasMerged.ts
@@ -52,10 +52,11 @@ export const addCanvasMergedListener = () => {
dispatch(
imageUploaded({
- imageType: 'intermediates',
formData: {
file: new File([blob], filename, { type: 'image/png' }),
},
+ imageCategory: 'general',
+ isIntermediate: true,
})
);
@@ -65,7 +66,7 @@ export const addCanvasMergedListener = () => {
action.meta.arg.formData.file.name === filename
);
- const mergedCanvasImage = payload.response;
+ const mergedCanvasImage = payload;
dispatch(
setMergedCanvas({
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts
index d8237d1d5c..b89620775b 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts
@@ -4,16 +4,18 @@ import { log } from 'app/logging/useLogger';
import { imageUploaded } from 'services/thunks/image';
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
import { addToast } from 'features/system/store/systemSlice';
+import { v4 as uuidv4 } from 'uuid';
+import { imageUpserted } from 'features/gallery/store/imagesSlice';
const moduleLog = log.child({ namespace: 'canvasSavedToGalleryListener' });
export const addCanvasSavedToGalleryListener = () => {
startAppListening({
actionCreator: canvasSavedToGallery,
- effect: async (action, { dispatch, getState }) => {
+ effect: async (action, { dispatch, getState, take }) => {
const state = getState();
- const blob = await getBaseLayerBlob(state);
+ const blob = await getBaseLayerBlob(state, true);
if (!blob) {
moduleLog.error('Problem getting base layer blob');
@@ -27,14 +29,25 @@ export const addCanvasSavedToGalleryListener = () => {
return;
}
+ const filename = `mergedCanvas_${uuidv4()}.png`;
+
dispatch(
imageUploaded({
- imageType: 'results',
formData: {
- file: new File([blob], 'mergedCanvas.png', { type: 'image/png' }),
+ file: new File([blob], filename, { type: 'image/png' }),
},
+ imageCategory: 'general',
+ isIntermediate: false,
})
);
+
+ const [{ payload: uploadedImageDTO }] = await take(
+ (action): action is ReturnType =>
+ imageUploaded.fulfilled.match(action) &&
+ action.meta.arg.formData.file.name === filename
+ );
+
+ dispatch(imageUpserted(uploadedImageDTO));
},
});
};
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageCategoriesChanged.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageCategoriesChanged.ts
new file mode 100644
index 0000000000..85d56d3913
--- /dev/null
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageCategoriesChanged.ts
@@ -0,0 +1,24 @@
+import { log } from 'app/logging/useLogger';
+import { startAppListening } from '..';
+import { receivedPageOfImages } from 'services/thunks/image';
+import {
+ imageCategoriesChanged,
+ selectFilteredImagesAsArray,
+} from 'features/gallery/store/imagesSlice';
+
+const moduleLog = log.child({ namespace: 'gallery' });
+
+export const addImageCategoriesChangedListener = () => {
+ startAppListening({
+ actionCreator: imageCategoriesChanged,
+ effect: (action, { getState, dispatch }) => {
+ const filteredImagesCount = selectFilteredImagesAsArray(
+ getState()
+ ).length;
+
+ if (!filteredImagesCount) {
+ dispatch(receivedPageOfImages());
+ }
+ },
+ });
+};
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts
index 42a62b3d80..bf7ca4020c 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts
@@ -4,9 +4,18 @@ import { imageDeleted } from 'services/thunks/image';
import { log } from 'app/logging/useLogger';
import { clamp } from 'lodash-es';
import { imageSelected } from 'features/gallery/store/gallerySlice';
+import {
+ imageRemoved,
+ imagesAdapter,
+ selectImagesEntities,
+ selectImagesIds,
+} from 'features/gallery/store/imagesSlice';
const moduleLog = log.child({ namespace: 'addRequestedImageDeletionListener' });
+/**
+ * Called when the user requests an image deletion
+ */
export const addRequestedImageDeletionListener = () => {
startAppListening({
actionCreator: requestedImageDeletion,
@@ -17,24 +26,20 @@ export const addRequestedImageDeletionListener = () => {
return;
}
- const { image_name, image_type } = image;
+ const { image_name, image_origin } = image;
- if (image_type !== 'uploads' && image_type !== 'results') {
- moduleLog.warn({ data: image }, `Invalid image type ${image_type}`);
- return;
- }
+ const state = getState();
+ const selectedImage = state.gallery.selectedImage;
- const selectedImageName = getState().gallery.selectedImage?.image_name;
+ if (selectedImage && selectedImage.image_name === image_name) {
+ const ids = selectImagesIds(state);
+ const entities = selectImagesEntities(state);
- if (selectedImageName === image_name) {
- const allIds = getState()[image_type].ids;
- const allEntities = getState()[image_type].entities;
-
- const deletedImageIndex = allIds.findIndex(
+ const deletedImageIndex = ids.findIndex(
(result) => result.toString() === image_name
);
- const filteredIds = allIds.filter((id) => id.toString() !== image_name);
+ const filteredIds = ids.filter((id) => id.toString() !== image_name);
const newSelectedImageIndex = clamp(
deletedImageIndex,
@@ -44,7 +49,7 @@ export const addRequestedImageDeletionListener = () => {
const newSelectedImageId = filteredIds[newSelectedImageIndex];
- const newSelectedImage = allEntities[newSelectedImageId];
+ const newSelectedImage = entities[newSelectedImageId];
if (newSelectedImageId) {
dispatch(imageSelected(newSelectedImage));
@@ -53,7 +58,52 @@ export const addRequestedImageDeletionListener = () => {
}
}
- dispatch(imageDeleted({ imageName: image_name, imageType: image_type }));
+ dispatch(imageRemoved(image_name));
+
+ dispatch(
+ imageDeleted({ imageName: image_name, imageOrigin: image_origin })
+ );
+ },
+ });
+};
+
+/**
+ * Called when the actual delete request is sent to the server
+ */
+export const addImageDeletedPendingListener = () => {
+ startAppListening({
+ actionCreator: imageDeleted.pending,
+ effect: (action, { dispatch, getState }) => {
+ const { imageName, imageOrigin } = action.meta.arg;
+ // Preemptively remove the image from the gallery
+ imagesAdapter.removeOne(getState().images, imageName);
+ },
+ });
+};
+
+/**
+ * Called on successful delete
+ */
+export const addImageDeletedFulfilledListener = () => {
+ startAppListening({
+ actionCreator: imageDeleted.fulfilled,
+ effect: (action, { dispatch, getState }) => {
+ moduleLog.debug({ data: { image: action.meta.arg } }, 'Image deleted');
+ },
+ });
+};
+
+/**
+ * Called on failed delete
+ */
+export const addImageDeletedRejectedListener = () => {
+ startAppListening({
+ actionCreator: imageDeleted.rejected,
+ effect: (action, { dispatch, getState }) => {
+ moduleLog.debug(
+ { data: { image: action.meta.arg } },
+ 'Unable to delete image'
+ );
},
});
};
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageMetadataReceived.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageMetadataReceived.ts
new file mode 100644
index 0000000000..63aeecb95e
--- /dev/null
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageMetadataReceived.ts
@@ -0,0 +1,33 @@
+import { log } from 'app/logging/useLogger';
+import { startAppListening } from '..';
+import { imageMetadataReceived } from 'services/thunks/image';
+import { imageUpserted } from 'features/gallery/store/imagesSlice';
+
+const moduleLog = log.child({ namespace: 'image' });
+
+export const addImageMetadataReceivedFulfilledListener = () => {
+ startAppListening({
+ actionCreator: imageMetadataReceived.fulfilled,
+ effect: (action, { getState, dispatch }) => {
+ const image = action.payload;
+ if (image.is_intermediate) {
+ // No further actions needed for intermediate images
+ return;
+ }
+ moduleLog.debug({ data: { image } }, 'Image metadata received');
+ dispatch(imageUpserted(image));
+ },
+ });
+};
+
+export const addImageMetadataReceivedRejectedListener = () => {
+ startAppListening({
+ actionCreator: imageMetadataReceived.rejected,
+ effect: (action, { getState, dispatch }) => {
+ moduleLog.debug(
+ { data: { image: action.meta.arg } },
+ 'Problem receiving image metadata'
+ );
+ },
+ });
+};
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUpdated.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUpdated.ts
new file mode 100644
index 0000000000..6f8b46ec23
--- /dev/null
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUpdated.ts
@@ -0,0 +1,26 @@
+import { startAppListening } from '..';
+import { imageUpdated } from 'services/thunks/image';
+import { log } from 'app/logging/useLogger';
+
+const moduleLog = log.child({ namespace: 'image' });
+
+export const addImageUpdatedFulfilledListener = () => {
+ startAppListening({
+ actionCreator: imageUpdated.fulfilled,
+ effect: (action, { dispatch, getState }) => {
+ moduleLog.debug(
+ { oldImage: action.meta.arg, updatedImage: action.payload },
+ 'Image updated'
+ );
+ },
+ });
+};
+
+export const addImageUpdatedRejectedListener = () => {
+ startAppListening({
+ actionCreator: imageUpdated.rejected,
+ effect: (action, { dispatch }) => {
+ moduleLog.debug({ oldImage: action.meta.arg }, 'Image update failed');
+ },
+ });
+};
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts
index 1d66166c12..6d84431f80 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts
@@ -1,44 +1,46 @@
import { startAppListening } from '..';
-import { uploadAdded } from 'features/gallery/store/uploadsSlice';
-import { imageSelected } from 'features/gallery/store/gallerySlice';
import { imageUploaded } from 'services/thunks/image';
import { addToast } from 'features/system/store/systemSlice';
-import { initialImageSelected } from 'features/parameters/store/actions';
-import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
-import { resultAdded } from 'features/gallery/store/resultsSlice';
-import { isResultsImageDTO, isUploadsImageDTO } from 'services/types/guards';
+import { log } from 'app/logging/useLogger';
+import { imageUpserted } from 'features/gallery/store/imagesSlice';
-export const addImageUploadedListener = () => {
+const moduleLog = log.child({ namespace: 'image' });
+
+export const addImageUploadedFulfilledListener = () => {
startAppListening({
- predicate: (action): action is ReturnType =>
- imageUploaded.fulfilled.match(action) &&
- action.payload.response.image_type !== 'intermediates',
+ actionCreator: imageUploaded.fulfilled,
effect: (action, { dispatch, getState }) => {
- const { response: image } = action.payload;
+ const image = action.payload;
+
+ moduleLog.debug({ arg: '', image }, 'Image uploaded');
+
+ if (action.payload.is_intermediate) {
+ // No further actions needed for intermediate images
+ return;
+ }
const state = getState();
- if (isUploadsImageDTO(image)) {
- dispatch(uploadAdded(image));
-
- dispatch(addToast({ title: 'Image Uploaded', status: 'success' }));
-
- if (state.gallery.shouldAutoSwitchToNewImages) {
- dispatch(imageSelected(image));
- }
-
- if (action.meta.arg.activeTabName === 'img2img') {
- dispatch(initialImageSelected(image));
- }
-
- if (action.meta.arg.activeTabName === 'unifiedCanvas') {
- dispatch(setInitialCanvasImage(image));
- }
- }
-
- if (isResultsImageDTO(image)) {
- dispatch(resultAdded(image));
- }
+ dispatch(imageUpserted(image));
+ dispatch(addToast({ title: 'Image Uploaded', status: 'success' }));
+ },
+ });
+};
+
+export const addImageUploadedRejectedListener = () => {
+ startAppListening({
+ actionCreator: imageUploaded.rejected,
+ effect: (action, { dispatch }) => {
+ const { formData, ...rest } = action.meta.arg;
+ const sanitizedData = { arg: { ...rest, formData: { file: '' } } };
+ moduleLog.error({ data: sanitizedData }, 'Image upload failed');
+ dispatch(
+ addToast({
+ title: 'Image Upload Failed',
+ description: action.error.message,
+ status: 'error',
+ })
+ );
},
});
};
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUrlsReceived.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUrlsReceived.ts
new file mode 100644
index 0000000000..fd0461f893
--- /dev/null
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUrlsReceived.ts
@@ -0,0 +1,38 @@
+import { log } from 'app/logging/useLogger';
+import { startAppListening } from '..';
+import { imageUrlsReceived } from 'services/thunks/image';
+import { imagesAdapter } from 'features/gallery/store/imagesSlice';
+
+const moduleLog = log.child({ namespace: 'image' });
+
+export const addImageUrlsReceivedFulfilledListener = () => {
+ startAppListening({
+ actionCreator: imageUrlsReceived.fulfilled,
+ effect: (action, { getState, dispatch }) => {
+ const image = action.payload;
+ moduleLog.debug({ data: { image } }, 'Image URLs received');
+
+ const { image_name, image_url, thumbnail_url } = image;
+
+ imagesAdapter.updateOne(getState().images, {
+ id: image_name,
+ changes: {
+ image_url,
+ thumbnail_url,
+ },
+ });
+ },
+ });
+};
+
+export const addImageUrlsReceivedRejectedListener = () => {
+ startAppListening({
+ actionCreator: imageUrlsReceived.rejected,
+ effect: (action, { getState, dispatch }) => {
+ moduleLog.debug(
+ { data: { image: action.meta.arg } },
+ 'Problem getting image URLs'
+ );
+ },
+ });
+};
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/initialImageSelected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/initialImageSelected.ts
index d6cfc260f3..940cc84c1e 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/initialImageSelected.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/initialImageSelected.ts
@@ -1,6 +1,4 @@
import { initialImageChanged } from 'features/parameters/store/generationSlice';
-import { selectResultsById } from 'features/gallery/store/resultsSlice';
-import { selectUploadsById } from 'features/gallery/store/uploadsSlice';
import { t } from 'i18next';
import { addToast } from 'features/system/store/systemSlice';
import { startAppListening } from '..';
@@ -9,7 +7,7 @@ import {
isImageDTO,
} from 'features/parameters/store/actions';
import { makeToast } from 'app/components/Toaster';
-import { ImageDTO } from 'services/api';
+import { selectImagesById } from 'features/gallery/store/imagesSlice';
export const addInitialImageSelectedListener = () => {
startAppListening({
@@ -30,16 +28,8 @@ export const addInitialImageSelectedListener = () => {
return;
}
- const { image_name, image_type } = action.payload;
-
- let image: ImageDTO | undefined;
- const state = getState();
-
- if (image_type === 'results') {
- image = selectResultsById(state, image_name);
- } else if (image_type === 'uploads') {
- image = selectUploadsById(state, image_name);
- }
+ const imageName = action.payload;
+ const image = selectImagesById(getState(), imageName);
if (!image) {
dispatch(
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/invocationComplete.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/invocationComplete.ts
deleted file mode 100644
index 0222eea93c..0000000000
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/invocationComplete.ts
+++ /dev/null
@@ -1,62 +0,0 @@
-import { invocationComplete } from 'services/events/actions';
-import { isImageOutput } from 'services/types/guards';
-import {
- imageMetadataReceived,
- imageUrlsReceived,
-} from 'services/thunks/image';
-import { startAppListening } from '..';
-import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
-
-const nodeDenylist = ['dataURL_image'];
-
-export const addImageResultReceivedListener = () => {
- startAppListening({
- predicate: (action) => {
- if (
- invocationComplete.match(action) &&
- isImageOutput(action.payload.data.result)
- ) {
- return true;
- }
- return false;
- },
- effect: async (action, { getState, dispatch, take }) => {
- if (!invocationComplete.match(action)) {
- return;
- }
-
- const { data } = action.payload;
- const { result, node, graph_execution_state_id } = data;
-
- if (isImageOutput(result) && !nodeDenylist.includes(node.type)) {
- const { image_name, image_type } = result.image;
-
- dispatch(
- imageUrlsReceived({ imageName: image_name, imageType: image_type })
- );
-
- dispatch(
- imageMetadataReceived({
- imageName: image_name,
- imageType: image_type,
- })
- );
-
- // Handle canvas image
- if (
- graph_execution_state_id ===
- getState().canvas.layerState.stagingArea.sessionId
- ) {
- const [{ payload: image }] = await take(
- (
- action
- ): action is ReturnType =>
- imageMetadataReceived.fulfilled.match(action) &&
- action.payload.image_name === image_name
- );
- dispatch(addImageToStagingArea(image));
- }
- }
- },
- });
-};
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedPageOfImages.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedPageOfImages.ts
new file mode 100644
index 0000000000..cde7e22e3d
--- /dev/null
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedPageOfImages.ts
@@ -0,0 +1,33 @@
+import { log } from 'app/logging/useLogger';
+import { startAppListening } from '..';
+import { serializeError } from 'serialize-error';
+import { receivedPageOfImages } from 'services/thunks/image';
+
+const moduleLog = log.child({ namespace: 'gallery' });
+
+export const addReceivedPageOfImagesFulfilledListener = () => {
+ startAppListening({
+ actionCreator: receivedPageOfImages.fulfilled,
+ effect: (action, { getState, dispatch }) => {
+ const page = action.payload;
+ moduleLog.debug(
+ { data: { payload: action.payload } },
+ `Received ${page.items.length} images`
+ );
+ },
+ });
+};
+
+export const addReceivedPageOfImagesRejectedListener = () => {
+ startAppListening({
+ actionCreator: receivedPageOfImages.rejected,
+ effect: (action, { getState, dispatch }) => {
+ if (action.payload) {
+ moduleLog.debug(
+ { data: { error: serializeError(action.payload) } },
+ 'Problem receiving images'
+ );
+ }
+ },
+ });
+};
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionCanceled.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionCanceled.ts
new file mode 100644
index 0000000000..6274ad4dc8
--- /dev/null
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionCanceled.ts
@@ -0,0 +1,48 @@
+import { log } from 'app/logging/useLogger';
+import { startAppListening } from '..';
+import { sessionCanceled } from 'services/thunks/session';
+import { serializeError } from 'serialize-error';
+
+const moduleLog = log.child({ namespace: 'session' });
+
+export const addSessionCanceledPendingListener = () => {
+ startAppListening({
+ actionCreator: sessionCanceled.pending,
+ effect: (action, { getState, dispatch }) => {
+ //
+ },
+ });
+};
+
+export const addSessionCanceledFulfilledListener = () => {
+ startAppListening({
+ actionCreator: sessionCanceled.fulfilled,
+ effect: (action, { getState, dispatch }) => {
+ const { sessionId } = action.meta.arg;
+ moduleLog.debug(
+ { data: { sessionId } },
+ `Session canceled (${sessionId})`
+ );
+ },
+ });
+};
+
+export const addSessionCanceledRejectedListener = () => {
+ startAppListening({
+ actionCreator: sessionCanceled.rejected,
+ effect: (action, { getState, dispatch }) => {
+ if (action.payload) {
+ const { arg, error } = action.payload;
+ moduleLog.error(
+ {
+ data: {
+ arg,
+ error: serializeError(error),
+ },
+ },
+ `Problem canceling session`
+ );
+ }
+ },
+ });
+};
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionCreated.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionCreated.ts
new file mode 100644
index 0000000000..fb8a64d2e3
--- /dev/null
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionCreated.ts
@@ -0,0 +1,45 @@
+import { log } from 'app/logging/useLogger';
+import { startAppListening } from '..';
+import { sessionCreated } from 'services/thunks/session';
+import { serializeError } from 'serialize-error';
+
+const moduleLog = log.child({ namespace: 'session' });
+
+export const addSessionCreatedPendingListener = () => {
+ startAppListening({
+ actionCreator: sessionCreated.pending,
+ effect: (action, { getState, dispatch }) => {
+ //
+ },
+ });
+};
+
+export const addSessionCreatedFulfilledListener = () => {
+ startAppListening({
+ actionCreator: sessionCreated.fulfilled,
+ effect: (action, { getState, dispatch }) => {
+ const session = action.payload;
+ moduleLog.debug({ data: { session } }, `Session created (${session.id})`);
+ },
+ });
+};
+
+export const addSessionCreatedRejectedListener = () => {
+ startAppListening({
+ actionCreator: sessionCreated.rejected,
+ effect: (action, { getState, dispatch }) => {
+ if (action.payload) {
+ const { arg, error } = action.payload;
+ moduleLog.error(
+ {
+ data: {
+ arg,
+ error: serializeError(error),
+ },
+ },
+ `Problem creating session`
+ );
+ }
+ },
+ });
+};
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionInvoked.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionInvoked.ts
new file mode 100644
index 0000000000..272d1d9e1d
--- /dev/null
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionInvoked.ts
@@ -0,0 +1,48 @@
+import { log } from 'app/logging/useLogger';
+import { startAppListening } from '..';
+import { sessionInvoked } from 'services/thunks/session';
+import { serializeError } from 'serialize-error';
+
+const moduleLog = log.child({ namespace: 'session' });
+
+export const addSessionInvokedPendingListener = () => {
+ startAppListening({
+ actionCreator: sessionInvoked.pending,
+ effect: (action, { getState, dispatch }) => {
+ //
+ },
+ });
+};
+
+export const addSessionInvokedFulfilledListener = () => {
+ startAppListening({
+ actionCreator: sessionInvoked.fulfilled,
+ effect: (action, { getState, dispatch }) => {
+ const { sessionId } = action.meta.arg;
+ moduleLog.debug(
+ { data: { sessionId } },
+ `Session invoked (${sessionId})`
+ );
+ },
+ });
+};
+
+export const addSessionInvokedRejectedListener = () => {
+ startAppListening({
+ actionCreator: sessionInvoked.rejected,
+ effect: (action, { getState, dispatch }) => {
+ if (action.payload) {
+ const { arg, error } = action.payload;
+ moduleLog.error(
+ {
+ data: {
+ arg,
+ error: serializeError(error),
+ },
+ },
+ `Problem invoking session`
+ );
+ }
+ },
+ });
+};
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionReadyToInvoke.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionReadyToInvoke.ts
new file mode 100644
index 0000000000..8d4262e7da
--- /dev/null
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionReadyToInvoke.ts
@@ -0,0 +1,22 @@
+import { startAppListening } from '..';
+import { sessionInvoked } from 'services/thunks/session';
+import { log } from 'app/logging/useLogger';
+import { sessionReadyToInvoke } from 'features/system/store/actions';
+
+const moduleLog = log.child({ namespace: 'session' });
+
+export const addSessionReadyToInvokeListener = () => {
+ startAppListening({
+ actionCreator: sessionReadyToInvoke,
+ effect: (action, { getState, dispatch }) => {
+ const { sessionId } = getState().system;
+ if (sessionId) {
+ moduleLog.debug(
+ { sessionId },
+          `Session ready to invoke (${sessionId})`
+ );
+ dispatch(sessionInvoked({ sessionId }));
+ }
+ },
+ });
+};
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts
new file mode 100644
index 0000000000..3049d2c933
--- /dev/null
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts
@@ -0,0 +1,38 @@
+import { startAppListening } from '../..';
+import { log } from 'app/logging/useLogger';
+import { appSocketConnected, socketConnected } from 'services/events/actions';
+import { receivedPageOfImages } from 'services/thunks/image';
+import { receivedModels } from 'services/thunks/model';
+import { receivedOpenAPISchema } from 'services/thunks/schema';
+
+const moduleLog = log.child({ namespace: 'socketio' });
+
+export const addSocketConnectedEventListener = () => {
+ startAppListening({
+ actionCreator: socketConnected,
+ effect: (action, { dispatch, getState }) => {
+ const { timestamp } = action.payload;
+
+ moduleLog.debug({ timestamp }, 'Connected');
+
+ const { models, nodes, config, images } = getState();
+
+ const { disabledTabs } = config;
+
+ if (!images.ids.length) {
+ dispatch(receivedPageOfImages());
+ }
+
+ if (!models.ids.length) {
+ dispatch(receivedModels());
+ }
+
+ if (!nodes.schema && !disabledTabs.includes('nodes')) {
+ dispatch(receivedOpenAPISchema());
+ }
+
+ // pass along the socket event as an application action
+ dispatch(appSocketConnected(action.payload));
+ },
+ });
+};
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketDisconnected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketDisconnected.ts
new file mode 100644
index 0000000000..d5e8914cef
--- /dev/null
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketDisconnected.ts
@@ -0,0 +1,19 @@
+import { startAppListening } from '../..';
+import { log } from 'app/logging/useLogger';
+import {
+ socketDisconnected,
+ appSocketDisconnected,
+} from 'services/events/actions';
+
+const moduleLog = log.child({ namespace: 'socketio' });
+
+export const addSocketDisconnectedEventListener = () => {
+ startAppListening({
+ actionCreator: socketDisconnected,
+ effect: (action, { dispatch, getState }) => {
+ moduleLog.debug(action.payload, 'Disconnected');
+ // pass along the socket event as an application action
+ dispatch(appSocketDisconnected(action.payload));
+ },
+ });
+};
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress.ts
new file mode 100644
index 0000000000..756444d644
--- /dev/null
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress.ts
@@ -0,0 +1,34 @@
+import { startAppListening } from '../..';
+import { log } from 'app/logging/useLogger';
+import {
+ appSocketGeneratorProgress,
+ socketGeneratorProgress,
+} from 'services/events/actions';
+
+const moduleLog = log.child({ namespace: 'socketio' });
+
+export const addGeneratorProgressEventListener = () => {
+ startAppListening({
+ actionCreator: socketGeneratorProgress,
+ effect: (action, { dispatch, getState }) => {
+ if (
+ getState().system.canceledSession ===
+ action.payload.data.graph_execution_state_id
+ ) {
+ moduleLog.trace(
+ action.payload,
+ 'Ignored generator progress for canceled session'
+ );
+ return;
+ }
+
+ moduleLog.trace(
+ action.payload,
+ `Generator progress (${action.payload.data.node.type})`
+ );
+
+ // pass along the socket event as an application action
+ dispatch(appSocketGeneratorProgress(action.payload));
+ },
+ });
+};
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketGraphExecutionStateComplete.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketGraphExecutionStateComplete.ts
new file mode 100644
index 0000000000..7297825e32
--- /dev/null
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketGraphExecutionStateComplete.ts
@@ -0,0 +1,22 @@
+import { log } from 'app/logging/useLogger';
+import {
+ appSocketGraphExecutionStateComplete,
+ socketGraphExecutionStateComplete,
+} from 'services/events/actions';
+import { startAppListening } from '../..';
+
+const moduleLog = log.child({ namespace: 'socketio' });
+
+export const addGraphExecutionStateCompleteEventListener = () => {
+ startAppListening({
+ actionCreator: socketGraphExecutionStateComplete,
+ effect: (action, { dispatch, getState }) => {
+ moduleLog.debug(
+ action.payload,
+ `Session invocation complete (${action.payload.data.graph_execution_state_id})`
+ );
+ // pass along the socket event as an application action
+ dispatch(appSocketGraphExecutionStateComplete(action.payload));
+ },
+ });
+};
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts
new file mode 100644
index 0000000000..0b47f7a1be
--- /dev/null
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts
@@ -0,0 +1,67 @@
+import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
+import { startAppListening } from '../..';
+import { log } from 'app/logging/useLogger';
+import {
+ appSocketInvocationComplete,
+ socketInvocationComplete,
+} from 'services/events/actions';
+import { imageMetadataReceived } from 'services/thunks/image';
+import { sessionCanceled } from 'services/thunks/session';
+import { isImageOutput } from 'services/types/guards';
+import { progressImageSet } from 'features/system/store/systemSlice';
+
+const moduleLog = log.child({ namespace: 'socketio' });
+const nodeDenylist = ['dataURL_image'];
+
+export const addInvocationCompleteEventListener = () => {
+ startAppListening({
+ actionCreator: socketInvocationComplete,
+ effect: async (action, { dispatch, getState, take }) => {
+ moduleLog.debug(
+ action.payload,
+ `Invocation complete (${action.payload.data.node.type})`
+ );
+
+ const sessionId = action.payload.data.graph_execution_state_id;
+
+ const { cancelType, isCancelScheduled } = getState().system;
+
+ // Handle scheduled cancelation
+ if (cancelType === 'scheduled' && isCancelScheduled) {
+ dispatch(sessionCanceled({ sessionId }));
+ }
+
+ const { data } = action.payload;
+ const { result, node, graph_execution_state_id } = data;
+
+ // This complete event has an associated image output
+ if (isImageOutput(result) && !nodeDenylist.includes(node.type)) {
+ const { image_name, image_origin } = result.image;
+
+ // Get its metadata
+ dispatch(
+ imageMetadataReceived({
+ imageName: image_name,
+ imageOrigin: image_origin,
+ })
+ );
+
+ const [{ payload: imageDTO }] = await take(
+ imageMetadataReceived.fulfilled.match
+ );
+
+ // Handle canvas image
+ if (
+ graph_execution_state_id ===
+ getState().canvas.layerState.stagingArea.sessionId
+ ) {
+ dispatch(addImageToStagingArea(imageDTO));
+ }
+
+ dispatch(progressImageSet(null));
+ }
+ // pass along the socket event as an application action
+ dispatch(appSocketInvocationComplete(action.payload));
+ },
+ });
+};
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError.ts
new file mode 100644
index 0000000000..51480bbad4
--- /dev/null
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError.ts
@@ -0,0 +1,21 @@
+import { startAppListening } from '../..';
+import { log } from 'app/logging/useLogger';
+import {
+ appSocketInvocationError,
+ socketInvocationError,
+} from 'services/events/actions';
+
+const moduleLog = log.child({ namespace: 'socketio' });
+
+export const addInvocationErrorEventListener = () => {
+ startAppListening({
+ actionCreator: socketInvocationError,
+ effect: (action, { dispatch, getState }) => {
+ moduleLog.error(
+ action.payload,
+ `Invocation error (${action.payload.data.node.type})`
+ );
+ dispatch(appSocketInvocationError(action.payload));
+ },
+ });
+};
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationStarted.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationStarted.ts
new file mode 100644
index 0000000000..978be2fef5
--- /dev/null
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationStarted.ts
@@ -0,0 +1,32 @@
+import { startAppListening } from '../..';
+import { log } from 'app/logging/useLogger';
+import {
+ appSocketInvocationStarted,
+ socketInvocationStarted,
+} from 'services/events/actions';
+
+const moduleLog = log.child({ namespace: 'socketio' });
+
+export const addInvocationStartedEventListener = () => {
+ startAppListening({
+ actionCreator: socketInvocationStarted,
+ effect: (action, { dispatch, getState }) => {
+ if (
+ getState().system.canceledSession ===
+ action.payload.data.graph_execution_state_id
+ ) {
+ moduleLog.trace(
+ action.payload,
+ 'Ignored invocation started for canceled session'
+ );
+ return;
+ }
+
+ moduleLog.debug(
+ action.payload,
+ `Invocation started (${action.payload.data.node.type})`
+ );
+ dispatch(appSocketInvocationStarted(action.payload));
+ },
+ });
+};
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketSubscribed.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketSubscribed.ts
new file mode 100644
index 0000000000..871222981d
--- /dev/null
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketSubscribed.ts
@@ -0,0 +1,18 @@
+import { startAppListening } from '../..';
+import { log } from 'app/logging/useLogger';
+import { appSocketSubscribed, socketSubscribed } from 'services/events/actions';
+
+const moduleLog = log.child({ namespace: 'socketio' });
+
+export const addSocketSubscribedEventListener = () => {
+ startAppListening({
+ actionCreator: socketSubscribed,
+ effect: (action, { dispatch, getState }) => {
+ moduleLog.debug(
+ action.payload,
+        `Subscribed (${action.payload.sessionId})`
+ );
+ dispatch(appSocketSubscribed(action.payload));
+ },
+ });
+};
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketUnsubscribed.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketUnsubscribed.ts
new file mode 100644
index 0000000000..ff85379907
--- /dev/null
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketUnsubscribed.ts
@@ -0,0 +1,21 @@
+import { startAppListening } from '../..';
+import { log } from 'app/logging/useLogger';
+import {
+ appSocketUnsubscribed,
+ socketUnsubscribed,
+} from 'services/events/actions';
+
+const moduleLog = log.child({ namespace: 'socketio' });
+
+export const addSocketUnsubscribedEventListener = () => {
+ startAppListening({
+ actionCreator: socketUnsubscribed,
+ effect: (action, { dispatch, getState }) => {
+ moduleLog.debug(
+ action.payload,
+ `Unsubscribed (${action.payload.sessionId})`
+ );
+ dispatch(appSocketUnsubscribed(action.payload));
+ },
+ });
+};
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/stagingAreaImageSaved.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/stagingAreaImageSaved.ts
new file mode 100644
index 0000000000..9bd3cd6dd2
--- /dev/null
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/stagingAreaImageSaved.ts
@@ -0,0 +1,54 @@
+import { stagingAreaImageSaved } from 'features/canvas/store/actions';
+import { startAppListening } from '..';
+import { log } from 'app/logging/useLogger';
+import { imageUpdated } from 'services/thunks/image';
+import { imageUpserted } from 'features/gallery/store/imagesSlice';
+import { addToast } from 'features/system/store/systemSlice';
+
+const moduleLog = log.child({ namespace: 'canvas' });
+
+export const addStagingAreaImageSavedListener = () => {
+ startAppListening({
+ actionCreator: stagingAreaImageSaved,
+ effect: async (action, { dispatch, getState, take }) => {
+ const { image_name, image_origin } = action.payload;
+
+ dispatch(
+ imageUpdated({
+ imageName: image_name,
+ imageOrigin: image_origin,
+ requestBody: {
+ is_intermediate: false,
+ },
+ })
+ );
+
+ const [imageUpdatedAction] = await take(
+ (action) =>
+ (imageUpdated.fulfilled.match(action) ||
+ imageUpdated.rejected.match(action)) &&
+ action.meta.arg.imageName === image_name
+ );
+
+ if (imageUpdated.rejected.match(imageUpdatedAction)) {
+ moduleLog.error(
+ { data: { arg: imageUpdatedAction.meta.arg } },
+ 'Image saving failed'
+ );
+ dispatch(
+ addToast({
+ title: 'Image Saving Failed',
+ description: imageUpdatedAction.error.message,
+ status: 'error',
+ })
+ );
+ return;
+ }
+
+ if (imageUpdated.fulfilled.match(imageUpdatedAction)) {
+ dispatch(imageUpserted(imageUpdatedAction.payload));
+ dispatch(addToast({ title: 'Image Saved', status: 'success' }));
+ }
+ },
+ });
+};
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedCanvas.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedCanvas.ts
index 2ebd3684e9..0ee3016bdb 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedCanvas.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedCanvas.ts
@@ -1,9 +1,9 @@
import { startAppListening } from '..';
-import { sessionCreated, sessionInvoked } from 'services/thunks/session';
+import { sessionCreated } from 'services/thunks/session';
import { buildCanvasGraphComponents } from 'features/nodes/util/graphBuilders/buildCanvasGraph';
import { log } from 'app/logging/useLogger';
import { canvasGraphBuilt } from 'features/nodes/store/actions';
-import { imageUploaded } from 'services/thunks/image';
+import { imageUpdated, imageUploaded } from 'services/thunks/image';
import { v4 as uuidv4 } from 'uuid';
import { Graph } from 'services/api';
import {
@@ -15,12 +15,22 @@ import { getCanvasData } from 'features/canvas/util/getCanvasData';
import { getCanvasGenerationMode } from 'features/canvas/util/getCanvasGenerationMode';
import { blobToDataURL } from 'features/canvas/util/blobToDataURL';
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
+import { sessionReadyToInvoke } from 'features/system/store/actions';
const moduleLog = log.child({ namespace: 'invoke' });
/**
- * This listener is responsible for building the canvas graph and blobs when the user invokes the canvas.
- * It is also responsible for uploading the base and mask layers to the server.
+ * This listener is responsible invoking the canvas. This involves a number of steps:
+ *
+ * 1. Generate image blobs from the canvas layers
+ * 2. Determine the generation mode from the layers (txt2img, img2img, inpaint)
+ * 3. Build the canvas graph
+ * 4. Create the session with the graph
+ * 5. Upload the init image if necessary
+ * 6. Upload the mask image if necessary
+ * 7. Update the init and mask images with the session ID
+ * 8. Initialize the staging area if not yet initialized
+ * 9. Dispatch the sessionReadyToInvoke action to invoke the session
*/
export const addUserInvokedCanvasListener = () => {
startAppListening({
@@ -70,63 +80,7 @@ export const addUserInvokedCanvasListener = () => {
const { rangeNode, iterateNode, baseNode, edges } = graphComponents;
- // Upload the base layer, to be used as init image
- const baseFilename = `${uuidv4()}.png`;
-
- dispatch(
- imageUploaded({
- imageType: 'intermediates',
- formData: {
- file: new File([baseBlob], baseFilename, { type: 'image/png' }),
- },
- })
- );
-
- if (baseNode.type === 'img2img' || baseNode.type === 'inpaint') {
- const [{ payload: basePayload }] = await take(
- (action): action is ReturnType =>
- imageUploaded.fulfilled.match(action) &&
- action.meta.arg.formData.file.name === baseFilename
- );
-
- const { image_name: baseName, image_type: baseType } =
- basePayload.response;
-
- baseNode.image = {
- image_name: baseName,
- image_type: baseType,
- };
- }
-
- // Upload the mask layer image
- const maskFilename = `${uuidv4()}.png`;
-
- if (baseNode.type === 'inpaint') {
- dispatch(
- imageUploaded({
- imageType: 'intermediates',
- formData: {
- file: new File([maskBlob], maskFilename, { type: 'image/png' }),
- },
- })
- );
-
- const [{ payload: maskPayload }] = await take(
- (action): action is ReturnType =>
- imageUploaded.fulfilled.match(action) &&
- action.meta.arg.formData.file.name === maskFilename
- );
-
- const { image_name: maskName, image_type: maskType } =
- maskPayload.response;
-
- baseNode.mask = {
- image_name: maskName,
- image_type: maskType,
- };
- }
-
- // Assemble!
+ // Assemble! Note that this graph *does not have the init or mask image set yet!*
const nodes: Graph['nodes'] = {
[rangeNode.id]: rangeNode,
[iterateNode.id]: iterateNode,
@@ -136,15 +90,92 @@ export const addUserInvokedCanvasListener = () => {
const graph = { nodes, edges };
dispatch(canvasGraphBuilt(graph));
- moduleLog({ data: graph }, 'Canvas graph built');
- // Actually create the session
+ moduleLog.debug({ data: graph }, 'Canvas graph built');
+
+ // If we are generating img2img or inpaint, we need to upload the init images
+ if (baseNode.type === 'img2img' || baseNode.type === 'inpaint') {
+ const baseFilename = `${uuidv4()}.png`;
+ dispatch(
+ imageUploaded({
+ formData: {
+ file: new File([baseBlob], baseFilename, { type: 'image/png' }),
+ },
+ imageCategory: 'general',
+ isIntermediate: true,
+ })
+ );
+
+ // Wait for the image to be uploaded
+ const [{ payload: baseImageDTO }] = await take(
+ (action): action is ReturnType =>
+ imageUploaded.fulfilled.match(action) &&
+ action.meta.arg.formData.file.name === baseFilename
+ );
+
+ // Update the base node with the image name and type
+ baseNode.image = {
+ image_name: baseImageDTO.image_name,
+ image_origin: baseImageDTO.image_origin,
+ };
+ }
+
+ // For inpaint, we also need to upload the mask layer
+ if (baseNode.type === 'inpaint') {
+ const maskFilename = `${uuidv4()}.png`;
+ dispatch(
+ imageUploaded({
+ formData: {
+ file: new File([maskBlob], maskFilename, { type: 'image/png' }),
+ },
+ imageCategory: 'mask',
+ isIntermediate: true,
+ })
+ );
+
+ // Wait for the mask to be uploaded
+ const [{ payload: maskImageDTO }] = await take(
+ (action): action is ReturnType =>
+ imageUploaded.fulfilled.match(action) &&
+ action.meta.arg.formData.file.name === maskFilename
+ );
+
+ // Update the base node with the image name and type
+ baseNode.mask = {
+ image_name: maskImageDTO.image_name,
+ image_origin: maskImageDTO.image_origin,
+ };
+ }
+
+ // Create the session and wait for response
dispatch(sessionCreated({ graph }));
+ const [sessionCreatedAction] = await take(sessionCreated.fulfilled.match);
+ const sessionId = sessionCreatedAction.payload.id;
- // Wait for the session to be invoked (this is just the HTTP request to start processing)
- const [{ meta }] = await take(sessionInvoked.fulfilled.match);
+ // Associate the init image with the session, now that we have the session ID
+ if (
+ (baseNode.type === 'img2img' || baseNode.type === 'inpaint') &&
+ baseNode.image
+ ) {
+ dispatch(
+ imageUpdated({
+ imageName: baseNode.image.image_name,
+ imageOrigin: baseNode.image.image_origin,
+ requestBody: { session_id: sessionId },
+ })
+ );
+ }
- const { sessionId } = meta.arg;
+ // Associate the mask image with the session, now that we have the session ID
+ if (baseNode.type === 'inpaint' && baseNode.mask) {
+ dispatch(
+ imageUpdated({
+ imageName: baseNode.mask.image_name,
+ imageOrigin: baseNode.mask.image_origin,
+ requestBody: { session_id: sessionId },
+ })
+ );
+ }
if (!state.canvas.layerState.stagingArea.boundingBox) {
dispatch(
@@ -158,7 +189,11 @@ export const addUserInvokedCanvasListener = () => {
);
}
+ // Flag the session with the canvas session ID
dispatch(canvasSessionIdChanged(sessionId));
+
+ // We are ready to invoke the session!
+ dispatch(sessionReadyToInvoke());
},
});
};
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedImageToImage.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedImageToImage.ts
index e747aefa08..7dcbe8a41d 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedImageToImage.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedImageToImage.ts
@@ -4,6 +4,7 @@ import { sessionCreated } from 'services/thunks/session';
import { log } from 'app/logging/useLogger';
import { imageToImageGraphBuilt } from 'features/nodes/store/actions';
import { userInvoked } from 'app/store/actions';
+import { sessionReadyToInvoke } from 'features/system/store/actions';
const moduleLog = log.child({ namespace: 'invoke' });
@@ -11,14 +12,18 @@ export const addUserInvokedImageToImageListener = () => {
startAppListening({
predicate: (action): action is ReturnType =>
userInvoked.match(action) && action.payload === 'img2img',
- effect: (action, { getState, dispatch }) => {
+ effect: async (action, { getState, dispatch, take }) => {
const state = getState();
const graph = buildImageToImageGraph(state);
dispatch(imageToImageGraphBuilt(graph));
- moduleLog({ data: graph }, 'Image to Image graph built');
+ moduleLog.debug({ data: graph }, 'Image to Image graph built');
dispatch(sessionCreated({ graph }));
+
+ await take(sessionCreated.fulfilled.match);
+
+ dispatch(sessionReadyToInvoke());
},
});
};
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedNodes.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedNodes.ts
index 01e532d5ff..6fda3db0d6 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedNodes.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedNodes.ts
@@ -4,6 +4,7 @@ import { buildNodesGraph } from 'features/nodes/util/graphBuilders/buildNodesGra
import { log } from 'app/logging/useLogger';
import { nodesGraphBuilt } from 'features/nodes/store/actions';
import { userInvoked } from 'app/store/actions';
+import { sessionReadyToInvoke } from 'features/system/store/actions';
const moduleLog = log.child({ namespace: 'invoke' });
@@ -11,14 +12,18 @@ export const addUserInvokedNodesListener = () => {
startAppListening({
predicate: (action): action is ReturnType =>
userInvoked.match(action) && action.payload === 'nodes',
- effect: (action, { getState, dispatch }) => {
+ effect: async (action, { getState, dispatch, take }) => {
const state = getState();
const graph = buildNodesGraph(state);
dispatch(nodesGraphBuilt(graph));
- moduleLog({ data: graph }, 'Nodes graph built');
+ moduleLog.debug({ data: graph }, 'Nodes graph built');
dispatch(sessionCreated({ graph }));
+
+ await take(sessionCreated.fulfilled.match);
+
+ dispatch(sessionReadyToInvoke());
},
});
};
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedTextToImage.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedTextToImage.ts
index e3eb5d0b38..6042d86cb7 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedTextToImage.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedTextToImage.ts
@@ -4,6 +4,7 @@ import { sessionCreated } from 'services/thunks/session';
import { log } from 'app/logging/useLogger';
import { textToImageGraphBuilt } from 'features/nodes/store/actions';
import { userInvoked } from 'app/store/actions';
+import { sessionReadyToInvoke } from 'features/system/store/actions';
const moduleLog = log.child({ namespace: 'invoke' });
@@ -11,14 +12,20 @@ export const addUserInvokedTextToImageListener = () => {
startAppListening({
predicate: (action): action is ReturnType =>
userInvoked.match(action) && action.payload === 'txt2img',
- effect: (action, { getState, dispatch }) => {
+ effect: async (action, { getState, dispatch, take }) => {
const state = getState();
const graph = buildTextToImageGraph(state);
+
dispatch(textToImageGraphBuilt(graph));
- moduleLog({ data: graph }, 'Text to Image graph built');
+
+ moduleLog.debug({ data: graph }, 'Text to Image graph built');
dispatch(sessionCreated({ graph }));
+
+ await take(sessionCreated.fulfilled.match);
+
+ dispatch(sessionReadyToInvoke());
},
});
};
diff --git a/invokeai/frontend/web/src/app/store/store.ts b/invokeai/frontend/web/src/app/store/store.ts
index b89615b2c0..521610adcc 100644
--- a/invokeai/frontend/web/src/app/store/store.ts
+++ b/invokeai/frontend/web/src/app/store/store.ts
@@ -10,12 +10,12 @@ import dynamicMiddlewares from 'redux-dynamic-middlewares';
import canvasReducer from 'features/canvas/store/canvasSlice';
import galleryReducer from 'features/gallery/store/gallerySlice';
-import resultsReducer from 'features/gallery/store/resultsSlice';
-import uploadsReducer from 'features/gallery/store/uploadsSlice';
+import imagesReducer from 'features/gallery/store/imagesSlice';
import lightboxReducer from 'features/lightbox/store/lightboxSlice';
import generationReducer from 'features/parameters/store/generationSlice';
import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
import systemReducer from 'features/system/store/systemSlice';
+// import sessionReducer from 'features/system/store/sessionSlice';
import configReducer from 'features/system/store/configSlice';
import uiReducer from 'features/ui/store/uiSlice';
import hotkeysReducer from 'features/ui/store/hotkeysSlice';
@@ -40,12 +40,12 @@ const allReducers = {
models: modelsReducer,
nodes: nodesReducer,
postprocessing: postprocessingReducer,
- results: resultsReducer,
system: systemReducer,
config: configReducer,
ui: uiReducer,
- uploads: uploadsReducer,
hotkeys: hotkeysReducer,
+ images: imagesReducer,
+ // session: sessionReducer,
};
const rootReducer = combineReducers(allReducers);
@@ -63,8 +63,6 @@ const rememberedKeys: (keyof typeof allReducers)[] = [
'system',
'ui',
// 'hotkeys',
- // 'results',
- // 'uploads',
// 'config',
];
diff --git a/invokeai/frontend/web/src/app/types/invokeai.ts b/invokeai/frontend/web/src/app/types/invokeai.ts
index 68f7568779..8081ffa491 100644
--- a/invokeai/frontend/web/src/app/types/invokeai.ts
+++ b/invokeai/frontend/web/src/app/types/invokeai.ts
@@ -1,316 +1,82 @@
-/**
- * Types for images, the things they are made of, and the things
- * they make up.
- *
- * Generated images are txt2img and img2img images. They may have
- * had additional postprocessing done on them when they were first
- * generated.
- *
- * Postprocessed images are images which were not generated here
- * but only postprocessed by the app. They only get postprocessing
- * metadata and have a different image type, e.g. 'esrgan' or
- * 'gfpgan'.
- */
-
-import { SelectedImage } from 'features/parameters/store/actions';
import { InvokeTabName } from 'features/ui/store/tabMap';
-import { IRect } from 'konva/lib/types';
-import { ImageResponseMetadata, ImageType } from 'services/api';
import { O } from 'ts-toolbelt';
-/**
- * TODO:
- * Once an image has been generated, if it is postprocessed again,
- * additional postprocessing steps are added to its postprocessing
- * array.
- *
- * TODO: Better documentation of types.
- */
+// These are old types from the model management UI
-export type PromptItem = {
- prompt: string;
- weight: number;
-};
+// export type ModelStatus = 'active' | 'cached' | 'not loaded';
-// TECHDEBT: We need to retain compatibility with plain prompt strings and the structure Prompt type
-export type Prompt = Array | string;
-
-export type SeedWeightPair = {
- seed: number;
- weight: number;
-};
-
-export type SeedWeights = Array;
-
-// All generated images contain these metadata.
-export type CommonGeneratedImageMetadata = {
- postprocessing: null | Array;
- sampler:
- | 'ddim'
- | 'ddpm'
- | 'deis'
- | 'lms'
- | 'pndm'
- | 'heun'
- | 'heun_k'
- | 'euler'
- | 'euler_k'
- | 'euler_a'
- | 'kdpm_2'
- | 'kdpm_2_a'
- | 'dpmpp_2s'
- | 'dpmpp_2m'
- | 'dpmpp_2m_k'
- | 'unipc';
- prompt: Prompt;
- seed: number;
- variations: SeedWeights;
- steps: number;
- cfg_scale: number;
- width: number;
- height: number;
- seamless: boolean;
- hires_fix: boolean;
- extra: null | Record; // Pending development of RFC #266
-};
-
-// txt2img and img2img images have some unique attributes.
-export type Txt2ImgMetadata = CommonGeneratedImageMetadata & {
- type: 'txt2img';
-};
-
-export type Img2ImgMetadata = CommonGeneratedImageMetadata & {
- type: 'img2img';
- orig_hash: string;
- strength: number;
- fit: boolean;
- init_image_path: string;
- mask_image_path?: string;
-};
-
-// Superset of generated image metadata types.
-export type GeneratedImageMetadata = Txt2ImgMetadata | Img2ImgMetadata;
-
-// All post processed images contain these metadata.
-export type CommonPostProcessedImageMetadata = {
- orig_path: string;
- orig_hash: string;
-};
-
-// esrgan and gfpgan images have some unique attributes.
-export type ESRGANMetadata = CommonPostProcessedImageMetadata & {
- type: 'esrgan';
- scale: 2 | 4;
- strength: number;
- denoise_str: number;
-};
-
-export type FacetoolMetadata = CommonPostProcessedImageMetadata & {
- type: 'gfpgan' | 'codeformer';
- strength: number;
- fidelity?: number;
-};
-
-// Superset of all postprocessed image metadata types..
-export type PostProcessedImageMetadata = ESRGANMetadata | FacetoolMetadata;
-
-// Metadata includes the system config and image metadata.
-// export type Metadata = SystemGenerationMetadata & {
-// image: GeneratedImageMetadata | PostProcessedImageMetadata;
+// export type Model = {
+// status: ModelStatus;
+// description: string;
+// weights: string;
+// config?: string;
+// vae?: string;
+// width?: number;
+// height?: number;
+// default?: boolean;
+// format?: string;
// };
-/**
- * ResultImage
- */
-// export ty`pe Image = {
+// export type DiffusersModel = {
+// status: ModelStatus;
+// description: string;
+// repo_id?: string;
+// path?: string;
+// vae?: {
+// repo_id?: string;
+// path?: string;
+// };
+// format?: string;
+// default?: boolean;
+// };
+
+// export type ModelList = Record;
+
+// export type FoundModel = {
// name: string;
-// type: ImageType;
-// url: string;
-// thumbnail: string;
-// metadata: ImageResponseMetadata;
+// location: string;
// };
-// export const isInvokeAIImage = (obj: Image | SelectedImage): obj is Image => {
-// if ('url' in obj && 'thumbnail' in obj) {
-// return true;
-// }
-
-// return false;
+// export type InvokeModelConfigProps = {
+// name: string | undefined;
+// description: string | undefined;
+// config: string | undefined;
+// weights: string | undefined;
+// vae: string | undefined;
+// width: number | undefined;
+// height: number | undefined;
+// default: boolean | undefined;
+// format: string | undefined;
// };
-/**
- * Types related to the system status.
- */
-
-// // This represents the processing status of the backend.
-// export type SystemStatus = {
-// isProcessing: boolean;
-// currentStep: number;
-// totalSteps: number;
-// currentIteration: number;
-// totalIterations: number;
-// currentStatus: string;
-// currentStatusHasSteps: boolean;
-// hasError: boolean;
+// export type InvokeDiffusersModelConfigProps = {
+// name: string | undefined;
+// description: string | undefined;
+// repo_id: string | undefined;
+// path: string | undefined;
+// default: boolean | undefined;
+// format: string | undefined;
+// vae: {
+// repo_id: string | undefined;
+// path: string | undefined;
+// };
// };
-// export type SystemGenerationMetadata = {
-// model: string;
-// model_weights?: string;
-// model_id?: string;
-// model_hash: string;
-// app_id: string;
-// app_version: string;
+// export type InvokeModelConversionProps = {
+// model_name: string;
+// save_location: string;
+// custom_location: string | null;
// };
-// export type SystemConfig = SystemGenerationMetadata & {
-// model_list: ModelList;
-// infill_methods: string[];
+// export type InvokeModelMergingProps = {
+// models_to_merge: string[];
+// alpha: number;
+// interp: 'weighted_sum' | 'sigmoid' | 'inv_sigmoid' | 'add_difference';
+// force: boolean;
+// merged_model_name: string;
+// model_merge_save_path: string | null;
// };
-export type ModelStatus = 'active' | 'cached' | 'not loaded';
-
-export type Model = {
- status: ModelStatus;
- description: string;
- weights: string;
- config?: string;
- vae?: string;
- width?: number;
- height?: number;
- default?: boolean;
- format?: string;
-};
-
-export type DiffusersModel = {
- status: ModelStatus;
- description: string;
- repo_id?: string;
- path?: string;
- vae?: {
- repo_id?: string;
- path?: string;
- };
- format?: string;
- default?: boolean;
-};
-
-export type ModelList = Record;
-
-export type FoundModel = {
- name: string;
- location: string;
-};
-
-export type InvokeModelConfigProps = {
- name: string | undefined;
- description: string | undefined;
- config: string | undefined;
- weights: string | undefined;
- vae: string | undefined;
- width: number | undefined;
- height: number | undefined;
- default: boolean | undefined;
- format: string | undefined;
-};
-
-export type InvokeDiffusersModelConfigProps = {
- name: string | undefined;
- description: string | undefined;
- repo_id: string | undefined;
- path: string | undefined;
- default: boolean | undefined;
- format: string | undefined;
- vae: {
- repo_id: string | undefined;
- path: string | undefined;
- };
-};
-
-export type InvokeModelConversionProps = {
- model_name: string;
- save_location: string;
- custom_location: string | null;
-};
-
-export type InvokeModelMergingProps = {
- models_to_merge: string[];
- alpha: number;
- interp: 'weighted_sum' | 'sigmoid' | 'inv_sigmoid' | 'add_difference';
- force: boolean;
- merged_model_name: string;
- model_merge_save_path: string | null;
-};
-
-/**
- * These types type data received from the server via socketio.
- */
-
-export type ModelChangeResponse = {
- model_name: string;
- model_list: ModelList;
-};
-
-export type ModelConvertedResponse = {
- converted_model_name: string;
- model_list: ModelList;
-};
-
-export type ModelsMergedResponse = {
- merged_models: string[];
- merged_model_name: string;
- model_list: ModelList;
-};
-
-export type ModelAddedResponse = {
- new_model_name: string;
- model_list: ModelList;
- update: boolean;
-};
-
-export type ModelDeletedResponse = {
- deleted_model_name: string;
- model_list: ModelList;
-};
-
-export type FoundModelResponse = {
- search_folder: string;
- found_models: FoundModel[];
-};
-
-// export type SystemStatusResponse = SystemStatus;
-
-// export type SystemConfigResponse = SystemConfig;
-
-export type ImageResultResponse = Omit & {
- boundingBox?: IRect;
- generationMode: InvokeTabName;
-};
-
-export type ImageUploadResponse = {
- // image: Omit;
- url: string;
- mtime: number;
- width: number;
- height: number;
- thumbnail: string;
- // bbox: [number, number, number, number];
-};
-
-export type ErrorResponse = {
- message: string;
- additionalData?: string;
-};
-
-export type ImageUrlResponse = {
- url: string;
-};
-
-export type UploadOutpaintingMergeImagePayload = {
- dataURL: string;
- name: string;
-};
-
/**
* A disable-able application feature
*/
@@ -322,7 +88,8 @@ export type AppFeature =
| 'githubLink'
| 'discordLink'
| 'bugLink'
- | 'localization';
+ | 'localization'
+ | 'consoleLogging';
/**
* A disable-able Stable Diffusion feature
@@ -351,6 +118,7 @@ export type AppConfig = {
disabledSDFeatures: SDFeature[];
canRestoreDeletedImagesFromBin: boolean;
sd: {
+ defaultModel?: string;
iterations: {
initial: number;
min: number;
diff --git a/invokeai/frontend/web/src/common/components/IAICustomSelect.tsx b/invokeai/frontend/web/src/common/components/IAICustomSelect.tsx
index d9610346ec..6d6cdbadf5 100644
--- a/invokeai/frontend/web/src/common/components/IAICustomSelect.tsx
+++ b/invokeai/frontend/web/src/common/components/IAICustomSelect.tsx
@@ -21,9 +21,12 @@ import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import { memo } from 'react';
+export type ItemTooltips = { [key: string]: string };
+
type IAICustomSelectProps = {
label?: string;
items: string[];
+ itemTooltips?: ItemTooltips;
selectedItem: string;
setSelectedItem: (v: string | null | undefined) => void;
withCheckIcon?: boolean;
@@ -37,6 +40,7 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
const {
label,
items,
+ itemTooltips,
setSelectedItem,
selectedItem,
withCheckIcon,
@@ -118,48 +122,56 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
>
{items.map((item, index) => (
-
- {withCheckIcon ? (
-
-
- {selectedItem === item && }
-
-
-
- {item}
-
-
-
- ) : (
-
- {item}
-
- )}
-
+
+ {withCheckIcon ? (
+
+
+ {selectedItem === item && }
+
+
+
+ {item}
+
+
+
+ ) : (
+
+ {item}
+
+ )}
+
+
))}
diff --git a/invokeai/frontend/web/src/common/components/ImageUploadOverlay.tsx b/invokeai/frontend/web/src/common/components/ImageUploadOverlay.tsx
index 28d9d32a71..862d806eb1 100644
--- a/invokeai/frontend/web/src/common/components/ImageUploadOverlay.tsx
+++ b/invokeai/frontend/web/src/common/components/ImageUploadOverlay.tsx
@@ -4,7 +4,6 @@ import { useHotkeys } from 'react-hotkeys-hook';
type ImageUploadOverlayProps = {
isDragAccept: boolean;
isDragReject: boolean;
- overlaySecondaryText: string;
setIsHandlingUpload: (isHandlingUpload: boolean) => void;
};
@@ -12,7 +11,6 @@ const ImageUploadOverlay = (props: ImageUploadOverlayProps) => {
const {
isDragAccept,
isDragReject: _isDragAccept,
- overlaySecondaryText,
setIsHandlingUpload,
} = props;
@@ -48,7 +46,7 @@ const ImageUploadOverlay = (props: ImageUploadOverlayProps) => {
}}
>
{isDragAccept ? (
- Upload Image{overlaySecondaryText}
+ Drop to Upload
) : (
<>
Invalid Upload
diff --git a/invokeai/frontend/web/src/common/components/ImageUploader.tsx b/invokeai/frontend/web/src/common/components/ImageUploader.tsx
index db6b9ee517..17f6d68633 100644
--- a/invokeai/frontend/web/src/common/components/ImageUploader.tsx
+++ b/invokeai/frontend/web/src/common/components/ImageUploader.tsx
@@ -68,13 +68,13 @@ const ImageUploader = (props: ImageUploaderProps) => {
async (file: File) => {
dispatch(
imageUploaded({
- imageType: 'uploads',
formData: { file },
- activeTabName,
+ imageCategory: 'user',
+ isIntermediate: false,
})
);
},
- [dispatch, activeTabName]
+ [dispatch]
);
const onDrop = useCallback(
@@ -145,14 +145,6 @@ const ImageUploader = (props: ImageUploaderProps) => {
};
}, [inputRef, open, setOpenUploaderFunction]);
- const overlaySecondaryText = useMemo(() => {
- if (['img2img', 'unifiedCanvas'].includes(activeTabName)) {
- return ` to ${String(t(`common.${activeTabName}` as ResourceKey))}`;
- }
-
- return '';
- }, [t, activeTabName]);
-
return (
{
)}
diff --git a/invokeai/frontend/web/src/common/util/_parseMetadataZod.ts b/invokeai/frontend/web/src/common/util/_parseMetadataZod.ts
deleted file mode 100644
index 584399233f..0000000000
--- a/invokeai/frontend/web/src/common/util/_parseMetadataZod.ts
+++ /dev/null
@@ -1,119 +0,0 @@
-/**
- * PARTIAL ZOD IMPLEMENTATION
- *
- * doesn't work well bc like most validators, zod is not built to skip invalid values.
- * it mostly works but just seems clearer and simpler to manually parse for now.
- *
- * in the future it would be really nice if we could use zod for some things:
- * - zodios (axios + zod): https://github.com/ecyrbe/zodios
- * - openapi to zodios: https://github.com/astahmer/openapi-zod-client
- */
-
-// import { z } from 'zod';
-
-// const zMetadataStringField = z.string();
-// export type MetadataStringField = z.infer;
-
-// const zMetadataIntegerField = z.number().int();
-// export type MetadataIntegerField = z.infer;
-
-// const zMetadataFloatField = z.number();
-// export type MetadataFloatField = z.infer;
-
-// const zMetadataBooleanField = z.boolean();
-// export type MetadataBooleanField = z.infer;
-
-// const zMetadataImageField = z.object({
-// image_type: z.union([
-// z.literal('results'),
-// z.literal('uploads'),
-// z.literal('intermediates'),
-// ]),
-// image_name: z.string().min(1),
-// });
-// export type MetadataImageField = z.infer;
-
-// const zMetadataLatentsField = z.object({
-// latents_name: z.string().min(1),
-// });
-// export type MetadataLatentsField = z.infer;
-
-// /**
-// * zod Schema for any node field. Use a `transform()` to manually parse, skipping invalid values.
-// */
-// const zAnyMetadataField = z.any().transform((val, ctx) => {
-// // Grab the field name from the path
-// const fieldName = String(ctx.path[ctx.path.length - 1]);
-
-// // `id` and `type` must be strings if they exist
-// if (['id', 'type'].includes(fieldName)) {
-// const reservedStringPropertyResult = zMetadataStringField.safeParse(val);
-// if (reservedStringPropertyResult.success) {
-// return reservedStringPropertyResult.data;
-// }
-
-// return;
-// }
-
-// // Parse the rest of the fields, only returning the data if the parsing is successful
-
-// const stringFieldResult = zMetadataStringField.safeParse(val);
-// if (stringFieldResult.success) {
-// return stringFieldResult.data;
-// }
-
-// const integerFieldResult = zMetadataIntegerField.safeParse(val);
-// if (integerFieldResult.success) {
-// return integerFieldResult.data;
-// }
-
-// const floatFieldResult = zMetadataFloatField.safeParse(val);
-// if (floatFieldResult.success) {
-// return floatFieldResult.data;
-// }
-
-// const booleanFieldResult = zMetadataBooleanField.safeParse(val);
-// if (booleanFieldResult.success) {
-// return booleanFieldResult.data;
-// }
-
-// const imageFieldResult = zMetadataImageField.safeParse(val);
-// if (imageFieldResult.success) {
-// return imageFieldResult.data;
-// }
-
-// const latentsFieldResult = zMetadataImageField.safeParse(val);
-// if (latentsFieldResult.success) {
-// return latentsFieldResult.data;
-// }
-// });
-
-// /**
-// * The node metadata schema.
-// */
-// const zNodeMetadata = z.object({
-// session_id: z.string().min(1).optional(),
-// node: z.record(z.string().min(1), zAnyMetadataField).optional(),
-// });
-
-// export type NodeMetadata = z.infer;
-
-// const zMetadata = z.object({
-// invokeai: zNodeMetadata.optional(),
-// 'sd-metadata': z.record(z.string().min(1), z.any()).optional(),
-// });
-// export type Metadata = z.infer;
-
-// export const parseMetadata = (
-// metadata: Record
-// ): Metadata | undefined => {
-// const result = zMetadata.safeParse(metadata);
-// if (!result.success) {
-// console.log(result.error.issues);
-// return;
-// }
-
-// return result.data;
-// };
-
-export default {};
diff --git a/invokeai/frontend/web/src/common/util/parseMetadata.ts b/invokeai/frontend/web/src/common/util/parseMetadata.ts
deleted file mode 100644
index f1828d95e7..0000000000
--- a/invokeai/frontend/web/src/common/util/parseMetadata.ts
+++ /dev/null
@@ -1,334 +0,0 @@
-import { forEach, size } from 'lodash-es';
-import {
- ImageField,
- LatentsField,
- ConditioningField,
- UNetField,
- ClipField,
- VaeField,
-} from 'services/api';
-
-const OBJECT_TYPESTRING = '[object Object]';
-const STRING_TYPESTRING = '[object String]';
-const NUMBER_TYPESTRING = '[object Number]';
-const BOOLEAN_TYPESTRING = '[object Boolean]';
-const ARRAY_TYPESTRING = '[object Array]';
-
-const isObject = (obj: unknown): obj is Record =>
- Object.prototype.toString.call(obj) === OBJECT_TYPESTRING;
-
-const isString = (obj: unknown): obj is string =>
- Object.prototype.toString.call(obj) === STRING_TYPESTRING;
-
-const isNumber = (obj: unknown): obj is number =>
- Object.prototype.toString.call(obj) === NUMBER_TYPESTRING;
-
-const isBoolean = (obj: unknown): obj is boolean =>
- Object.prototype.toString.call(obj) === BOOLEAN_TYPESTRING;
-
-const isArray = (obj: unknown): obj is Array =>
- Object.prototype.toString.call(obj) === ARRAY_TYPESTRING;
-
-const parseImageField = (imageField: unknown): ImageField | undefined => {
- // Must be an object
- if (!isObject(imageField)) {
- return;
- }
-
- // An ImageField must have both `image_name` and `image_type`
- if (!('image_name' in imageField && 'image_type' in imageField)) {
- return;
- }
-
- // An ImageField's `image_type` must be one of the allowed values
- if (
- !['results', 'uploads', 'intermediates'].includes(imageField.image_type)
- ) {
- return;
- }
-
- // An ImageField's `image_name` must be a string
- if (typeof imageField.image_name !== 'string') {
- return;
- }
-
- // Build a valid ImageField
- return {
- image_type: imageField.image_type,
- image_name: imageField.image_name,
- };
-};
-
-const parseLatentsField = (latentsField: unknown): LatentsField | undefined => {
- // Must be an object
- if (!isObject(latentsField)) {
- return;
- }
-
- // A LatentsField must have a `latents_name`
- if (!('latents_name' in latentsField)) {
- return;
- }
-
- // A LatentsField's `latents_name` must be a string
- if (typeof latentsField.latents_name !== 'string') {
- return;
- }
-
- // Build a valid LatentsField
- return {
- latents_name: latentsField.latents_name,
- };
-};
-
-const parseConditioningField = (
- conditioningField: unknown
-): ConditioningField | undefined => {
- // Must be an object
- if (!isObject(conditioningField)) {
- return;
- }
-
- // A ConditioningField must have a `conditioning_name`
- if (!('conditioning_name' in conditioningField)) {
- return;
- }
-
- // A ConditioningField's `conditioning_name` must be a string
- if (typeof conditioningField.conditioning_name !== 'string') {
- return;
- }
-
- // Build a valid ConditioningField
- return {
- conditioning_name: conditioningField.conditioning_name,
- };
-};
-
-const _parseModelInfo = (modelInfo: unknown): ModelInfo | undefined => {
- // Must be an object
- if (!isObject(modelInfo)) {
- return;
- }
-
- if (!('model_name' in modelInfo && typeof modelInfo.model_name == 'string')) {
- return;
- }
-
- if (!('model_type' in modelInfo && typeof modelInfo.model_type == 'string')) {
- return;
- }
-
- if (!('submodel' in modelInfo && typeof modelInfo.submodel == 'string')) {
- return;
- }
-
- return {
- model_name: modelInfo.model_name,
- model_type: modelInfo.model_type,
- submodel: modelInfo.submodel,
- };
-};
-
-const parseUNetField = (unetField: unknown): UNetField | undefined => {
- // Must be an object
- if (!isObject(unetField)) {
- return;
- }
-
- if (!('unet' in unetField && 'scheduler' in unetField)) {
- return;
- }
-
- const unet = _parseModelInfo(unetField.unet);
- const scheduler = _parseModelInfo(unetField.scheduler);
-
- if (!(unet && scheduler)) {
- return;
- }
-
- // Build a valid UNetField
- return {
- unet: unet,
- scheduler: scheduler,
- };
-};
-
-const parseClipField = (clipField: unknown): ClipField | undefined => {
- // Must be an object
- if (!isObject(clipField)) {
- return;
- }
-
- if (!('tokenizer' in clipField && 'text_encoder' in clipField)) {
- return;
- }
-
- const tokenizer = _parseModelInfo(clipField.tokenizer);
- const text_encoder = _parseModelInfo(clipField.text_encoder);
-
- if (!(tokenizer && text_encoder)) {
- return;
- }
-
- // Build a valid ClipField
- return {
- tokenizer: tokenizer,
- text_encoder: text_encoder,
- };
-};
-
-const parseVaeField = (vaeField: unknown): VaeField | undefined => {
- // Must be an object
- if (!isObject(vaeField)) {
- return;
- }
-
- if (!('vae' in vaeField)) {
- return;
- }
-
- const vae = _parseModelInfo(vaeField.vae);
-
- if (!vae) {
- return;
- }
-
- // Build a valid VaeField
- return {
- vae: vae,
- };
-};
-
-type NodeMetadata = {
- [key: string]:
- | string
- | number
- | boolean
- | ImageField
- | LatentsField
- | ConditioningField
- | UNetField
- | ClipField
- | VaeField;
-};
-
-type InvokeAIMetadata = {
- session_id?: string;
- node?: NodeMetadata;
-};
-
-export const parseNodeMetadata = (
- nodeMetadata: Record
-): NodeMetadata | undefined => {
- if (!isObject(nodeMetadata)) {
- return;
- }
-
- const parsed: NodeMetadata = {};
-
- forEach(nodeMetadata, (nodeItem, nodeKey) => {
- // `id` and `type` must be strings if they are present
- if (['id', 'type'].includes(nodeKey)) {
- if (isString(nodeItem)) {
- parsed[nodeKey] = nodeItem;
- }
- return;
- }
-
- // valid object types are:
- // ImageField, LatentsField ConditioningField, UNetField, ClipField, VaeField
- if (isObject(nodeItem)) {
- if ('image_name' in nodeItem || 'image_type' in nodeItem) {
- const imageField = parseImageField(nodeItem);
- if (imageField) {
- parsed[nodeKey] = imageField;
- }
- return;
- }
-
- if ('latents_name' in nodeItem) {
- const latentsField = parseLatentsField(nodeItem);
- if (latentsField) {
- parsed[nodeKey] = latentsField;
- }
- return;
- }
-
- if ('conditioning_name' in nodeItem) {
- const conditioningField = parseConditioningField(nodeItem);
- if (conditioningField) {
- parsed[nodeKey] = conditioningField;
- }
- return;
- }
-
- if ('unet' in nodeItem && 'scheduler' in nodeItem) {
- const unetField = parseUNetField(nodeItem);
- if (unetField) {
- parsed[nodeKey] = unetField;
- }
- }
-
- if ('tokenizer' in nodeItem && 'text_encoder' in nodeItem) {
- const clipField = parseClipField(nodeItem);
- if (clipField) {
- parsed[nodeKey] = clipField;
- }
- }
-
- if ('vae' in nodeItem) {
- const vaeField = parseVaeField(nodeItem);
- if (vaeField) {
- parsed[nodeKey] = vaeField;
- }
- }
- }
-
- // otherwise we accept any string, number or boolean
- if (isString(nodeItem) || isNumber(nodeItem) || isBoolean(nodeItem)) {
- parsed[nodeKey] = nodeItem;
- return;
- }
- });
-
- if (size(parsed) === 0) {
- return;
- }
-
- return parsed;
-};
-
-export const parseInvokeAIMetadata = (
- metadata: Record | undefined
-): InvokeAIMetadata | undefined => {
- if (metadata === undefined) {
- return;
- }
-
- if (!isObject(metadata)) {
- return;
- }
-
- const parsed: InvokeAIMetadata = {};
-
- forEach(metadata, (item, key) => {
- if (key === 'session_id' && isString(item)) {
- parsed['session_id'] = item;
- }
-
- if (key === 'node' && isObject(item)) {
- const nodeMetadata = parseNodeMetadata(item);
-
- if (nodeMetadata) {
- parsed['node'] = nodeMetadata;
- }
- }
- });
-
- if (size(parsed) === 0) {
- return;
- }
-
- return parsed;
-};
diff --git a/invokeai/frontend/web/src/features/canvas/components/IAICanvasIntermediateImage.tsx b/invokeai/frontend/web/src/features/canvas/components/IAICanvasIntermediateImage.tsx
index 745825a975..ea5e9a6486 100644
--- a/invokeai/frontend/web/src/features/canvas/components/IAICanvasIntermediateImage.tsx
+++ b/invokeai/frontend/web/src/features/canvas/components/IAICanvasIntermediateImage.tsx
@@ -1,18 +1,24 @@
import { createSelector } from '@reduxjs/toolkit';
-import { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
-import { useGetUrl } from 'common/util/getUrl';
-import { GalleryState } from 'features/gallery/store/gallerySlice';
+import { systemSelector } from 'features/system/store/systemSelectors';
import { ImageConfig } from 'konva/lib/shapes/Image';
import { isEqual } from 'lodash-es';
import { useEffect, useState } from 'react';
import { Image as KonvaImage } from 'react-konva';
+import { canvasSelector } from '../store/canvasSelectors';
const selector = createSelector(
- [(state: RootState) => state.gallery],
- (gallery: GalleryState) => {
- return gallery.intermediateImage ? gallery.intermediateImage : null;
+ [systemSelector, canvasSelector],
+ (system, canvas) => {
+ const { progressImage, sessionId } = system;
+ const { sessionId: canvasSessionId, boundingBox } =
+ canvas.layerState.stagingArea;
+
+ return {
+ boundingBox,
+ progressImage: sessionId === canvasSessionId ? progressImage : undefined,
+ };
},
{
memoizeOptions: {
@@ -25,33 +31,34 @@ type Props = Omit;
const IAICanvasIntermediateImage = (props: Props) => {
const { ...rest } = props;
- const intermediateImage = useAppSelector(selector);
- const { getUrl } = useGetUrl();
+ const { progressImage, boundingBox } = useAppSelector(selector);
const [loadedImageElement, setLoadedImageElement] =
useState(null);
useEffect(() => {
- if (!intermediateImage) return;
+ if (!progressImage) {
+ return;
+ }
+
const tempImage = new Image();
tempImage.onload = () => {
setLoadedImageElement(tempImage);
};
- tempImage.src = getUrl(intermediateImage.url);
- }, [intermediateImage, getUrl]);
- if (!intermediateImage?.boundingBox) return null;
+ tempImage.src = progressImage.dataURL;
+ }, [progressImage]);
- const {
- boundingBox: { x, y, width, height },
- } = intermediateImage;
+ if (!(progressImage && boundingBox)) {
+ return null;
+ }
return loadedImageElement ? (
{
{shouldShowStagingImage && currentStagingAreaImage && (
diff --git a/invokeai/frontend/web/src/features/canvas/components/IAICanvasStagingAreaToolbar.tsx b/invokeai/frontend/web/src/features/canvas/components/IAICanvasStagingAreaToolbar.tsx
index 64c752fce0..76ffdcf082 100644
--- a/invokeai/frontend/web/src/features/canvas/components/IAICanvasStagingAreaToolbar.tsx
+++ b/invokeai/frontend/web/src/features/canvas/components/IAICanvasStagingAreaToolbar.tsx
@@ -1,6 +1,5 @@
import { ButtonGroup, Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
-// import { saveStagingAreaImageToGallery } from 'app/socketio/actions';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
@@ -26,13 +25,14 @@ import {
FaPlus,
FaSave,
} from 'react-icons/fa';
+import { stagingAreaImageSaved } from '../store/actions';
const selector = createSelector(
[canvasSelector],
(canvas) => {
const {
layerState: {
- stagingArea: { images, selectedImageIndex },
+ stagingArea: { images, selectedImageIndex, sessionId },
},
shouldShowStagingOutline,
shouldShowStagingImage,
@@ -45,6 +45,7 @@ const selector = createSelector(
isOnLastImage: selectedImageIndex === images.length - 1,
shouldShowStagingImage,
shouldShowStagingOutline,
+ sessionId,
};
},
{
@@ -61,6 +62,7 @@ const IAICanvasStagingAreaToolbar = () => {
isOnLastImage,
currentStagingAreaImage,
shouldShowStagingImage,
+ sessionId,
} = useAppSelector(selector);
const { t } = useTranslation();
@@ -106,9 +108,20 @@ const IAICanvasStagingAreaToolbar = () => {
}
);
- const handlePrevImage = () => dispatch(prevStagingAreaImage());
- const handleNextImage = () => dispatch(nextStagingAreaImage());
- const handleAccept = () => dispatch(commitStagingAreaImage());
+ const handlePrevImage = useCallback(
+ () => dispatch(prevStagingAreaImage()),
+ [dispatch]
+ );
+
+ const handleNextImage = useCallback(
+ () => dispatch(nextStagingAreaImage()),
+ [dispatch]
+ );
+
+ const handleAccept = useCallback(
+ () => dispatch(commitStagingAreaImage(sessionId)),
+ [dispatch, sessionId]
+ );
if (!currentStagingAreaImage) return null;
@@ -157,19 +170,15 @@ const IAICanvasStagingAreaToolbar = () => {
}
colorScheme="accent"
/>
- {/* }
onClick={() =>
- dispatch(
- saveStagingAreaImageToGallery(
- currentStagingAreaImage.image.image_url
- )
- )
+ dispatch(stagingAreaImageSaved(currentStagingAreaImage.image))
}
colorScheme="accent"
- /> */}
+ />
(
+ 'canvas/stagingAreaImageSaved'
+);
diff --git a/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts b/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts
index 0ebe5b264c..7f41066ba1 100644
--- a/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts
+++ b/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts
@@ -29,6 +29,7 @@ import {
isCanvasMaskLine,
} from './canvasTypes';
import { ImageDTO } from 'services/api';
+import { sessionCanceled } from 'services/thunks/session';
export const initialLayerState: CanvasLayerState = {
objects: [],
@@ -696,7 +697,10 @@ export const canvasSlice = createSlice({
0
);
},
- commitStagingAreaImage: (state) => {
+ commitStagingAreaImage: (
+ state,
+ action: PayloadAction
+ ) => {
if (!state.layerState.stagingArea.images.length) {
return;
}
@@ -841,6 +845,13 @@ export const canvasSlice = createSlice({
state.isTransformingBoundingBox = false;
},
},
+ extraReducers: (builder) => {
+ builder.addCase(sessionCanceled.pending, (state) => {
+ if (!state.layerState.stagingArea.images.length) {
+ state.layerState.stagingArea = initialLayerState.stagingArea;
+ }
+ });
+ },
});
export const {
diff --git a/invokeai/frontend/web/src/features/canvas/util/createMaskStage.ts b/invokeai/frontend/web/src/features/canvas/util/createMaskStage.ts
index 96ac592711..b417b3a786 100644
--- a/invokeai/frontend/web/src/features/canvas/util/createMaskStage.ts
+++ b/invokeai/frontend/web/src/features/canvas/util/createMaskStage.ts
@@ -9,7 +9,8 @@ import { IRect } from 'konva/lib/types';
*/
const createMaskStage = async (
lines: CanvasMaskLine[],
- boundingBox: IRect
+ boundingBox: IRect,
+ shouldInvertMask: boolean
): Promise => {
// create an offscreen canvas and add the mask to it
const { width, height } = boundingBox;
@@ -29,7 +30,7 @@ const createMaskStage = async (
baseLayer.add(
new Konva.Rect({
...boundingBox,
- fill: 'white',
+ fill: shouldInvertMask ? 'black' : 'white',
})
);
@@ -37,7 +38,7 @@ const createMaskStage = async (
maskLayer.add(
new Konva.Line({
points: line.points,
- stroke: 'black',
+ stroke: shouldInvertMask ? 'white' : 'black',
strokeWidth: line.strokeWidth * 2,
tension: 0,
lineCap: 'round',
diff --git a/invokeai/frontend/web/src/features/canvas/util/getCanvasData.ts b/invokeai/frontend/web/src/features/canvas/util/getCanvasData.ts
index 21a33aa349..d0190878e2 100644
--- a/invokeai/frontend/web/src/features/canvas/util/getCanvasData.ts
+++ b/invokeai/frontend/web/src/features/canvas/util/getCanvasData.ts
@@ -25,6 +25,7 @@ export const getCanvasData = async (state: RootState) => {
boundingBoxCoordinates,
boundingBoxDimensions,
isMaskEnabled,
+ shouldPreserveMaskedArea,
} = state.canvas;
const boundingBox = {
@@ -58,7 +59,8 @@ export const getCanvasData = async (state: RootState) => {
// For the mask layer, use the normal boundingBox
const maskStage = await createMaskStage(
isMaskEnabled ? objects.filter(isCanvasMaskLine) : [], // only include mask lines, and only if mask is enabled
- boundingBox
+ boundingBox,
+ shouldPreserveMaskedArea
);
const maskBlob = await konvaNodeToBlob(maskStage, boundingBox);
const maskImageData = await konvaNodeToImageData(maskStage, boundingBox);
diff --git a/invokeai/frontend/web/src/features/gallery/components/CurrentImageButtons.tsx b/invokeai/frontend/web/src/features/gallery/components/CurrentImageButtons.tsx
index c19a404a37..91bd1a0425 100644
--- a/invokeai/frontend/web/src/features/gallery/components/CurrentImageButtons.tsx
+++ b/invokeai/frontend/web/src/features/gallery/components/CurrentImageButtons.tsx
@@ -1,5 +1,5 @@
import { createSelector } from '@reduxjs/toolkit';
-import { isEqual, isString } from 'lodash-es';
+import { isEqual } from 'lodash-es';
import {
ButtonGroup,
@@ -25,8 +25,8 @@ import {
} from 'features/ui/store/uiSelectors';
import {
setActiveTab,
- setShouldHidePreview,
setShouldShowImageDetails,
+ setShouldShowProgressInViewer,
} from 'features/ui/store/uiSlice';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
@@ -37,23 +37,19 @@ import {
FaDownload,
FaExpand,
FaExpandArrowsAlt,
- FaEye,
- FaEyeSlash,
FaGrinStars,
+ FaHourglassHalf,
FaQuoteRight,
FaSeedling,
FaShare,
FaShareAlt,
- FaTrash,
- FaWrench,
} from 'react-icons/fa';
import { gallerySelector } from '../store/gallerySelectors';
-import DeleteImageModal from './DeleteImageModal';
import { useCallback } from 'react';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import { useGetUrl } from 'common/util/getUrl';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
-import { useParameters } from 'features/parameters/hooks/useParameters';
+import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { initialImageSelected } from 'features/parameters/store/actions';
import {
requestedImageDeletion,
@@ -62,7 +58,6 @@ import {
} from '../store/actions';
import FaceRestoreSettings from 'features/parameters/components/Parameters/FaceRestore/FaceRestoreSettings';
import UpscaleSettings from 'features/parameters/components/Parameters/Upscale/UpscaleSettings';
-import { allParametersSet } from 'features/parameters/store/generationSlice';
import DeleteImageButton from './ImageActionButtons/DeleteImageButton';
import { useAppToaster } from 'app/components/Toaster';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
@@ -90,7 +85,11 @@ const currentImageButtonsSelector = createSelector(
const { isLightboxOpen } = lightbox;
- const { shouldShowImageDetails, shouldHidePreview } = ui;
+ const {
+ shouldShowImageDetails,
+ shouldHidePreview,
+ shouldShowProgressInViewer,
+ } = ui;
const { selectedImage } = gallery;
@@ -112,6 +111,7 @@ const currentImageButtonsSelector = createSelector(
seed: selectedImage?.metadata?.seed,
prompt: selectedImage?.metadata?.positive_conditioning,
negativePrompt: selectedImage?.metadata?.negative_conditioning,
+ shouldShowProgressInViewer,
};
},
{
@@ -145,6 +145,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
image,
canDeleteImage,
shouldConfirmOnDelete,
+ shouldShowProgressInViewer,
} = useAppSelector(currentImageButtonsSelector);
const isLightboxEnabled = useFeatureStatus('lightbox').isFeatureEnabled;
@@ -163,7 +164,8 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
const toaster = useAppToaster();
const { t } = useTranslation();
- const { recallPrompt, recallSeed, recallAllParameters } = useParameters();
+ const { recallBothPrompts, recallSeed, recallAllParameters } =
+ useRecallParameters();
// const handleCopyImage = useCallback(async () => {
// if (!image?.url) {
@@ -229,10 +231,6 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
});
}, [toaster, shouldTransformUrls, getUrl, t, image]);
- const handlePreviewVisibility = useCallback(() => {
- dispatch(setShouldHidePreview(!shouldHidePreview));
- }, [dispatch, shouldHidePreview]);
-
const handleClickUseAllParameters = useCallback(() => {
recallAllParameters(image);
}, [image, recallAllParameters]);
@@ -252,11 +250,11 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
useHotkeys('s', handleUseSeed, [image]);
const handleUsePrompt = useCallback(() => {
- recallPrompt(
+ recallBothPrompts(
image?.metadata?.positive_conditioning,
image?.metadata?.negative_conditioning
);
- }, [image, recallPrompt]);
+ }, [image, recallBothPrompts]);
useHotkeys('p', handleUsePrompt, [image]);
@@ -386,6 +384,10 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
}
}, [shouldConfirmOnDelete, onDeleteDialogOpen, handleDelete]);
+ const handleClickProgressImagesToggle = useCallback(() => {
+ dispatch(setShouldShowProgressInViewer(!shouldShowProgressInViewer));
+ }, [dispatch, shouldShowProgressInViewer]);
+
useHotkeys('delete', handleInitiateDelete, [
image,
shouldConfirmOnDelete,
@@ -412,8 +414,9 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
}
/>
}
@@ -458,28 +461,17 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
{t('parameters.copyImageToLink')}
-
+
} size="sm" w="100%">
{t('parameters.downloadImage')}
- {/* : }
- tooltip={
- !shouldHidePreview
- ? t('parameters.hidePreview')
- : t('parameters.showPreview')
- }
- aria-label={
- !shouldHidePreview
- ? t('parameters.hidePreview')
- : t('parameters.showPreview')
- }
- isChecked={shouldHidePreview}
- onClick={handlePreviewVisibility}
- /> */}
{isLightboxEnabled && (
}
@@ -604,6 +596,16 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
/>
+
+ }
+ isChecked={shouldShowProgressInViewer}
+ onClick={handleClickProgressImagesToggle}
+ />
+
+
diff --git a/invokeai/frontend/web/src/features/gallery/components/CurrentImagePreview.tsx b/invokeai/frontend/web/src/features/gallery/components/CurrentImagePreview.tsx
index 4562e3458d..280d859b87 100644
--- a/invokeai/frontend/web/src/features/gallery/components/CurrentImagePreview.tsx
+++ b/invokeai/frontend/web/src/features/gallery/components/CurrentImagePreview.tsx
@@ -62,7 +62,6 @@ const CurrentImagePreview = () => {
return;
}
e.dataTransfer.setData('invokeai/imageName', image.image_name);
- e.dataTransfer.setData('invokeai/imageType', image.image_type);
e.dataTransfer.effectAllowed = 'move';
},
[image]
diff --git a/invokeai/frontend/web/src/features/gallery/components/HoverableImage.tsx b/invokeai/frontend/web/src/features/gallery/components/HoverableImage.tsx
index ed427f4984..f652cebda2 100644
--- a/invokeai/frontend/web/src/features/gallery/components/HoverableImage.tsx
+++ b/invokeai/frontend/web/src/features/gallery/components/HoverableImage.tsx
@@ -30,7 +30,7 @@ import { lightboxSelector } from 'features/lightbox/store/lightboxSelectors';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { isEqual } from 'lodash-es';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
-import { useParameters } from 'features/parameters/hooks/useParameters';
+import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { initialImageSelected } from 'features/parameters/store/actions';
import {
requestedImageDeletion,
@@ -114,8 +114,8 @@ const HoverableImage = memo((props: HoverableImageProps) => {
const isLightboxEnabled = useFeatureStatus('lightbox').isFeatureEnabled;
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
- const { recallSeed, recallPrompt, recallInitialImage, recallAllParameters } =
- useParameters();
+ const { recallBothPrompts, recallSeed, recallAllParameters } =
+ useRecallParameters();
const handleMouseOver = () => setIsHovered(true);
const handleMouseOut = () => setIsHovered(false);
@@ -147,7 +147,6 @@ const HoverableImage = memo((props: HoverableImageProps) => {
const handleDragStart = useCallback(
(e: DragEvent) => {
e.dataTransfer.setData('invokeai/imageName', image.image_name);
- e.dataTransfer.setData('invokeai/imageType', image.image_type);
e.dataTransfer.effectAllowed = 'move';
},
[image]
@@ -155,11 +154,15 @@ const HoverableImage = memo((props: HoverableImageProps) => {
// Recall parameters handlers
const handleRecallPrompt = useCallback(() => {
- recallPrompt(
+ recallBothPrompts(
image.metadata?.positive_conditioning,
image.metadata?.negative_conditioning
);
- }, [image, recallPrompt]);
+ }, [
+ image.metadata?.negative_conditioning,
+ image.metadata?.positive_conditioning,
+ recallBothPrompts,
+ ]);
const handleRecallSeed = useCallback(() => {
recallSeed(image.metadata?.seed);
diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx
index 468dfd694f..77f42a11a6 100644
--- a/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx
+++ b/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx
@@ -16,7 +16,6 @@ import IAIPopover from 'common/components/IAIPopover';
import IAISlider from 'common/components/IAISlider';
import { gallerySelector } from 'features/gallery/store/gallerySelectors';
import {
- setCurrentCategory,
setGalleryImageMinimumWidth,
setGalleryImageObjectFit,
setShouldAutoSwitchToNewImages,
@@ -31,59 +30,46 @@ import {
memo,
useCallback,
useEffect,
+ useMemo,
useRef,
useState,
} from 'react';
import { useTranslation } from 'react-i18next';
import { BsPinAngle, BsPinAngleFill } from 'react-icons/bs';
-import { FaImage, FaUser, FaWrench } from 'react-icons/fa';
+import { FaImage, FaServer, FaWrench } from 'react-icons/fa';
import { MdPhotoLibrary } from 'react-icons/md';
import HoverableImage from './HoverableImage';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
-import { resultsAdapter } from '../store/resultsSlice';
-import {
- receivedResultImagesPage,
- receivedUploadImagesPage,
-} from 'services/thunks/gallery';
-import { uploadsAdapter } from '../store/uploadsSlice';
import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import { Virtuoso, VirtuosoGrid } from 'react-virtuoso';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
-import GalleryProgressImage from './GalleryProgressImage';
import { uiSelector } from 'features/ui/store/uiSelectors';
-import { ImageDTO } from 'services/api';
-
-const GALLERY_SHOW_BUTTONS_MIN_WIDTH = 290;
-const PROGRESS_IMAGE_PLACEHOLDER = 'PROGRESS_IMAGE_PLACEHOLDER';
+import {
+ ASSETS_CATEGORIES,
+ IMAGE_CATEGORIES,
+ imageCategoriesChanged,
+ selectImagesAll,
+} from '../store/imagesSlice';
+import { receivedPageOfImages } from 'services/thunks/image';
const categorySelector = createSelector(
[(state: RootState) => state],
(state) => {
- const { results, uploads, system, gallery } = state;
- const { currentCategory } = gallery;
+ const { images } = state;
+ const { categories } = images;
- if (currentCategory === 'results') {
- const tempImages: (ImageDTO | typeof PROGRESS_IMAGE_PLACEHOLDER)[] = [];
-
- if (system.progressImage) {
- tempImages.push(PROGRESS_IMAGE_PLACEHOLDER);
- }
-
- return {
- images: tempImages.concat(
- resultsAdapter.getSelectors().selectAll(results)
- ),
- isLoading: results.isLoading,
- areMoreImagesAvailable: results.page < results.pages - 1,
- };
- }
+ const allImages = selectImagesAll(state);
+ const filteredImages = allImages.filter((i) =>
+ categories.includes(i.image_category)
+ );
return {
- images: uploadsAdapter.getSelectors().selectAll(uploads),
- isLoading: uploads.isLoading,
- areMoreImagesAvailable: uploads.page < uploads.pages - 1,
+ images: filteredImages,
+ isLoading: images.isLoading,
+ areMoreImagesAvailable: filteredImages.length < images.total,
+ categories: images.categories,
};
},
defaultSelectorOptions
@@ -93,7 +79,6 @@ const mainSelector = createSelector(
[gallerySelector, uiSelector],
(gallery, ui) => {
const {
- currentCategory,
galleryImageMinimumWidth,
galleryImageObjectFit,
shouldAutoSwitchToNewImages,
@@ -104,7 +89,6 @@ const mainSelector = createSelector(
const { shouldPinGallery } = ui;
return {
- currentCategory,
shouldPinGallery,
galleryImageMinimumWidth,
galleryImageObjectFit,
@@ -120,7 +104,6 @@ const ImageGalleryContent = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const resizeObserverRef = useRef(null);
- const [shouldShouldIconButtons, setShouldShouldIconButtons] = useState(true);
const rootRef = useRef(null);
const [scroller, setScroller] = useState(null);
const [initialize, osInstance] = useOverlayScrollbars({
@@ -137,7 +120,6 @@ const ImageGalleryContent = () => {
});
const {
- currentCategory,
shouldPinGallery,
galleryImageMinimumWidth,
galleryImageObjectFit,
@@ -146,18 +128,19 @@ const ImageGalleryContent = () => {
selectedImage,
} = useAppSelector(mainSelector);
- const { images, areMoreImagesAvailable, isLoading } =
+ const { images, areMoreImagesAvailable, isLoading, categories } =
useAppSelector(categorySelector);
- const handleClickLoadMore = () => {
- if (currentCategory === 'results') {
- dispatch(receivedResultImagesPage());
- }
+ const handleLoadMoreImages = useCallback(() => {
+ dispatch(receivedPageOfImages());
+ }, [dispatch]);
- if (currentCategory === 'uploads') {
- dispatch(receivedUploadImagesPage());
+ const handleEndReached = useMemo(() => {
+ if (areMoreImagesAvailable && !isLoading) {
+ return handleLoadMoreImages;
}
- };
+ return undefined;
+ }, [areMoreImagesAvailable, handleLoadMoreImages, isLoading]);
const handleChangeGalleryImageMinimumWidth = (v: number) => {
dispatch(setGalleryImageMinimumWidth(v));
@@ -168,28 +151,6 @@ const ImageGalleryContent = () => {
dispatch(requestCanvasRescale());
};
- useEffect(() => {
- if (!resizeObserverRef.current) {
- return;
- }
- const resizeObserver = new ResizeObserver(() => {
- if (!resizeObserverRef.current) {
- return;
- }
-
- if (
- resizeObserverRef.current.clientWidth < GALLERY_SHOW_BUTTONS_MIN_WIDTH
- ) {
- setShouldShouldIconButtons(true);
- return;
- }
-
- setShouldShouldIconButtons(false);
- });
- resizeObserver.observe(resizeObserverRef.current);
- return () => resizeObserver.disconnect(); // clean up
- }, []);
-
useEffect(() => {
const { current: root } = rootRef;
if (scroller && root) {
@@ -209,13 +170,13 @@ const ImageGalleryContent = () => {
}
}, []);
- const handleEndReached = useCallback(() => {
- if (currentCategory === 'results') {
- dispatch(receivedResultImagesPage());
- } else if (currentCategory === 'uploads') {
- dispatch(receivedUploadImagesPage());
- }
- }, [dispatch, currentCategory]);
+ const handleClickImagesCategory = useCallback(() => {
+ dispatch(imageCategoriesChanged(IMAGE_CATEGORIES));
+ }, [dispatch]);
+
+ const handleClickAssetsCategory = useCallback(() => {
+ dispatch(imageCategoriesChanged(ASSETS_CATEGORIES));
+ }, [dispatch]);
return (
{
alignItems="center"
justifyContent="space-between"
>
-
- {shouldShouldIconButtons ? (
- <>
- }
- onClick={() => dispatch(setCurrentCategory('results'))}
- />
- }
- onClick={() => dispatch(setCurrentCategory('uploads'))}
- />
- >
- ) : (
- <>
- dispatch(setCurrentCategory('results'))}
- flexGrow={1}
- >
- {t('gallery.generations')}
-
- dispatch(setCurrentCategory('uploads'))}
- flexGrow={1}
- >
- {t('gallery.uploads')}
-
- >
- )}
+
+ }
+ />
+ }
+ />
-
}
/>
}
@@ -347,28 +280,17 @@ const ImageGalleryContent = () => {
data={images}
endReached={handleEndReached}
scrollerRef={(ref) => setScrollerRef(ref)}
- itemContent={(index, image) => {
- const isSelected =
- image === PROGRESS_IMAGE_PLACEHOLDER
- ? false
- : selectedImage?.image_name === image?.image_name;
-
- return (
-
- {image === PROGRESS_IMAGE_PLACEHOLDER ? (
-
- ) : (
-
- )}
-
- );
- }}
+ itemContent={(index, image) => (
+
+
+
+ )}
/>
) : (
{
List: ListContainer,
}}
scrollerRef={setScroller}
- itemContent={(index, image) => {
- const isSelected =
- image === PROGRESS_IMAGE_PLACEHOLDER
- ? false
- : selectedImage?.image_name === image?.image_name;
-
- return image === PROGRESS_IMAGE_PLACEHOLDER ? (
-
- ) : (
-
- );
- }}
+ itemContent={(index, image) => (
+
+ )}
/>
)}
{
const { t } = useTranslation();
+
+ if (!value) {
+ return null;
+ }
+
return (
{onClick && (
@@ -115,6 +121,21 @@ const memoEqualityCheck = (
*/
const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
const dispatch = useAppDispatch();
+ const {
+ recallBothPrompts,
+ recallPositivePrompt,
+ recallNegativePrompt,
+ recallSeed,
+ recallInitialImage,
+ recallCfgScale,
+ recallModel,
+ recallScheduler,
+ recallSteps,
+ recallWidth,
+ recallHeight,
+ recallStrength,
+ recallAllParameters,
+ } = useRecallParameters();
useHotkeys('esc', () => {
dispatch(setShouldShowImageDetails(false));
@@ -161,52 +182,53 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
{metadata.type && (
)}
- {metadata.width && (
- dispatch(setWidth(Number(metadata.width)))}
- />
- )}
- {metadata.height && (
- dispatch(setHeight(Number(metadata.height)))}
- />
- )}
- {metadata.model && (
-
- )}
+ {sessionId && }
{metadata.positive_conditioning && (
+ recallPositivePrompt(metadata.positive_conditioning)
}
- onClick={() => setPositivePrompt(metadata.positive_conditioning!)}
/>
)}
{metadata.negative_conditioning && (
+ recallNegativePrompt(metadata.negative_conditioning)
}
- onClick={() => setNegativePrompt(metadata.negative_conditioning!)}
/>
)}
{metadata.seed !== undefined && (
dispatch(setSeed(Number(metadata.seed)))}
+ onClick={() => recallSeed(metadata.seed)}
+ />
+ )}
+ {metadata.model !== undefined && (
+ recallModel(metadata.model)}
+ />
+ )}
+ {metadata.width && (
+ recallWidth(metadata.width)}
+ />
+ )}
+ {metadata.height && (
+ recallHeight(metadata.height)}
/>
)}
{/* {metadata.threshold !== undefined && (
@@ -227,23 +249,21 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
- dispatch(setScheduler(metadata.scheduler as Scheduler))
- }
+ onClick={() => recallScheduler(metadata.scheduler)}
/>
)}
{metadata.steps && (
dispatch(setSteps(Number(metadata.steps)))}
+ onClick={() => recallSteps(metadata.steps)}
/>
)}
{metadata.cfg_scale !== undefined && (
dispatch(setCfgScale(Number(metadata.cfg_scale)))}
+ onClick={() => recallCfgScale(metadata.cfg_scale)}
/>
)}
{/* {metadata.variations && metadata.variations.length > 0 && (
@@ -284,9 +304,7 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
- dispatch(setImg2imgStrength(Number(metadata.strength)))
- }
+ onClick={() => recallStrength(metadata.strength)}
/>
)}
{/* {metadata.fit && (
diff --git a/invokeai/frontend/web/src/features/gallery/components/NextPrevImageButtons.tsx b/invokeai/frontend/web/src/features/gallery/components/NextPrevImageButtons.tsx
index fcf8359187..82e7a0d623 100644
--- a/invokeai/frontend/web/src/features/gallery/components/NextPrevImageButtons.tsx
+++ b/invokeai/frontend/web/src/features/gallery/components/NextPrevImageButtons.tsx
@@ -9,6 +9,10 @@ import { gallerySelector } from '../store/gallerySelectors';
import { RootState } from 'app/store/store';
import { imageSelected } from '../store/gallerySlice';
import { useHotkeys } from 'react-hotkeys-hook';
+import {
+ selectFilteredImagesAsObject,
+ selectFilteredImagesIds,
+} from '../store/imagesSlice';
const nextPrevButtonTriggerAreaStyles: ChakraProps['sx'] = {
height: '100%',
@@ -21,9 +25,14 @@ const nextPrevButtonStyles: ChakraProps['sx'] = {
};
export const nextPrevImageButtonsSelector = createSelector(
- [(state: RootState) => state, gallerySelector],
- (state, gallery) => {
- const { selectedImage, currentCategory } = gallery;
+ [
+ (state: RootState) => state,
+ gallerySelector,
+ selectFilteredImagesAsObject,
+ selectFilteredImagesIds,
+ ],
+ (state, gallery, filteredImagesAsObject, filteredImageIds) => {
+ const { selectedImage } = gallery;
if (!selectedImage) {
return {
@@ -32,29 +41,29 @@ export const nextPrevImageButtonsSelector = createSelector(
};
}
- const currentImageIndex = state[currentCategory].ids.findIndex(
+ const currentImageIndex = filteredImageIds.findIndex(
(i) => i === selectedImage.image_name
);
const nextImageIndex = clamp(
currentImageIndex + 1,
0,
- state[currentCategory].ids.length - 1
+ filteredImageIds.length - 1
);
const prevImageIndex = clamp(
currentImageIndex - 1,
0,
- state[currentCategory].ids.length - 1
+ filteredImageIds.length - 1
);
- const nextImageId = state[currentCategory].ids[nextImageIndex];
- const prevImageId = state[currentCategory].ids[prevImageIndex];
+ const nextImageId = filteredImageIds[nextImageIndex];
+ const prevImageId = filteredImageIds[prevImageIndex];
- const nextImage = state[currentCategory].entities[nextImageId];
- const prevImage = state[currentCategory].entities[prevImageId];
+ const nextImage = filteredImagesAsObject[nextImageId];
+ const prevImage = filteredImagesAsObject[prevImageId];
- const imagesLength = state[currentCategory].ids.length;
+ const imagesLength = filteredImageIds.length;
return {
isOnFirstImage: currentImageIndex === 0,
diff --git a/invokeai/frontend/web/src/features/gallery/hooks/useGetImageByName.ts b/invokeai/frontend/web/src/features/gallery/hooks/useGetImageByName.ts
index ad0870e7a4..89709b322a 100644
--- a/invokeai/frontend/web/src/features/gallery/hooks/useGetImageByName.ts
+++ b/invokeai/frontend/web/src/features/gallery/hooks/useGetImageByName.ts
@@ -1,33 +1,18 @@
-import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
-import { ImageType } from 'services/api';
-import { selectResultsEntities } from '../store/resultsSlice';
-import { selectUploadsEntities } from '../store/uploadsSlice';
+import { selectImagesEntities } from '../store/imagesSlice';
+import { useCallback } from 'react';
-const useGetImageByNameSelector = createSelector(
- [selectResultsEntities, selectUploadsEntities],
- (allResults, allUploads) => {
- return { allResults, allUploads };
- }
-);
-
-const useGetImageByNameAndType = () => {
- const { allResults, allUploads } = useAppSelector(useGetImageByNameSelector);
- return (name: string, type: ImageType) => {
- if (type === 'results') {
- const resultImagesResult = allResults[name];
- if (resultImagesResult) {
- return resultImagesResult;
+const useGetImageByName = () => {
+ const images = useAppSelector(selectImagesEntities);
+ return useCallback(
+ (name: string | undefined) => {
+ if (!name) {
+ return;
}
- }
-
- if (type === 'uploads') {
- const userImagesResult = allUploads[name];
- if (userImagesResult) {
- return userImagesResult;
- }
- }
- };
+ return images[name];
+ },
+ [images]
+ );
};
-export default useGetImageByNameAndType;
+export default useGetImageByName;
diff --git a/invokeai/frontend/web/src/features/gallery/store/actions.ts b/invokeai/frontend/web/src/features/gallery/store/actions.ts
index 7e071f279d..7c00201da9 100644
--- a/invokeai/frontend/web/src/features/gallery/store/actions.ts
+++ b/invokeai/frontend/web/src/features/gallery/store/actions.ts
@@ -1,9 +1,9 @@
import { createAction } from '@reduxjs/toolkit';
-import { ImageNameAndType } from 'features/parameters/store/actions';
+import { ImageNameAndOrigin } from 'features/parameters/store/actions';
import { ImageDTO } from 'services/api';
export const requestedImageDeletion = createAction<
- ImageDTO | ImageNameAndType | undefined
+ ImageDTO | ImageNameAndOrigin | undefined
>('gallery/requestedImageDeletion');
export const sentImageToCanvas = createAction('gallery/sentImageToCanvas');
diff --git a/invokeai/frontend/web/src/features/gallery/store/galleryPersistDenylist.ts b/invokeai/frontend/web/src/features/gallery/store/galleryPersistDenylist.ts
index 49f51d5a80..44e03f9f71 100644
--- a/invokeai/frontend/web/src/features/gallery/store/galleryPersistDenylist.ts
+++ b/invokeai/frontend/web/src/features/gallery/store/galleryPersistDenylist.ts
@@ -4,6 +4,5 @@ import { GalleryState } from './gallerySlice';
* Gallery slice persist denylist
*/
export const galleryPersistDenylist: (keyof GalleryState)[] = [
- 'currentCategory',
'shouldAutoSwitchToNewImages',
];
diff --git a/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts b/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts
index 9d6f5ece60..ab62646c0f 100644
--- a/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts
+++ b/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts
@@ -1,10 +1,7 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
-import {
- receivedResultImagesPage,
- receivedUploadImagesPage,
-} from '../../../services/thunks/gallery';
import { ImageDTO } from 'services/api';
+import { imageUpserted } from './imagesSlice';
type GalleryImageObjectFitType = 'contain' | 'cover';
@@ -14,7 +11,6 @@ export interface GalleryState {
galleryImageObjectFit: GalleryImageObjectFitType;
shouldAutoSwitchToNewImages: boolean;
shouldUseSingleGalleryColumn: boolean;
- currentCategory: 'results' | 'uploads';
}
export const initialGalleryState: GalleryState = {
@@ -22,7 +18,6 @@ export const initialGalleryState: GalleryState = {
galleryImageObjectFit: 'cover',
shouldAutoSwitchToNewImages: true,
shouldUseSingleGalleryColumn: false,
- currentCategory: 'results',
};
export const gallerySlice = createSlice({
@@ -46,12 +41,6 @@ export const gallerySlice = createSlice({
setShouldAutoSwitchToNewImages: (state, action: PayloadAction) => {
state.shouldAutoSwitchToNewImages = action.payload;
},
- setCurrentCategory: (
- state,
- action: PayloadAction<'results' | 'uploads'>
- ) => {
- state.currentCategory = action.payload;
- },
setShouldUseSingleGalleryColumn: (
state,
action: PayloadAction
@@ -59,37 +48,10 @@ export const gallerySlice = createSlice({
state.shouldUseSingleGalleryColumn = action.payload;
},
},
- extraReducers(builder) {
- builder.addCase(receivedResultImagesPage.fulfilled, (state, action) => {
- // rehydrate selectedImage URL when results list comes in
- // solves case when outdated URL is in local storage
- const selectedImage = state.selectedImage;
- if (selectedImage) {
- const selectedImageInResults = action.payload.items.find(
- (image) => image.image_name === selectedImage.image_name
- );
-
- if (selectedImageInResults) {
- selectedImage.image_url = selectedImageInResults.image_url;
- selectedImage.thumbnail_url = selectedImageInResults.thumbnail_url;
- state.selectedImage = selectedImage;
- }
- }
- });
- builder.addCase(receivedUploadImagesPage.fulfilled, (state, action) => {
- // rehydrate selectedImage URL when results list comes in
- // solves case when outdated URL is in local storage
- const selectedImage = state.selectedImage;
- if (selectedImage) {
- const selectedImageInResults = action.payload.items.find(
- (image) => image.image_name === selectedImage.image_name
- );
-
- if (selectedImageInResults) {
- selectedImage.image_url = selectedImageInResults.image_url;
- selectedImage.thumbnail_url = selectedImageInResults.thumbnail_url;
- state.selectedImage = selectedImage;
- }
+ extraReducers: (builder) => {
+ builder.addCase(imageUpserted, (state, action) => {
+ if (state.shouldAutoSwitchToNewImages) {
+ state.selectedImage = action.payload;
}
});
},
@@ -101,7 +63,6 @@ export const {
setGalleryImageObjectFit,
setShouldAutoSwitchToNewImages,
setShouldUseSingleGalleryColumn,
- setCurrentCategory,
} = gallerySlice.actions;
export default gallerySlice.reducer;
diff --git a/invokeai/frontend/web/src/features/gallery/store/imagesSlice.ts b/invokeai/frontend/web/src/features/gallery/store/imagesSlice.ts
new file mode 100644
index 0000000000..cb6469aeb4
--- /dev/null
+++ b/invokeai/frontend/web/src/features/gallery/store/imagesSlice.ts
@@ -0,0 +1,135 @@
+import {
+ PayloadAction,
+ createEntityAdapter,
+ createSelector,
+ createSlice,
+} from '@reduxjs/toolkit';
+import { RootState } from 'app/store/store';
+import { ImageCategory, ImageDTO } from 'services/api';
+import { dateComparator } from 'common/util/dateComparator';
+import { isString, keyBy } from 'lodash-es';
+import { receivedPageOfImages } from 'services/thunks/image';
+
+export const imagesAdapter = createEntityAdapter({
+ selectId: (image) => image.image_name,
+ sortComparer: (a, b) => dateComparator(b.created_at, a.created_at),
+});
+
+export const IMAGE_CATEGORIES: ImageCategory[] = ['general'];
+export const ASSETS_CATEGORIES: ImageCategory[] = [
+ 'control',
+ 'mask',
+ 'user',
+ 'other',
+];
+
+type AdditionaImagesState = {
+ offset: number;
+ limit: number;
+ total: number;
+ isLoading: boolean;
+ categories: ImageCategory[];
+};
+
+export const initialImagesState =
+ imagesAdapter.getInitialState({
+ offset: 0,
+ limit: 0,
+ total: 0,
+ isLoading: false,
+ categories: IMAGE_CATEGORIES,
+ });
+
+export type ImagesState = typeof initialImagesState;
+
+const imagesSlice = createSlice({
+ name: 'images',
+ initialState: initialImagesState,
+ reducers: {
+ imageUpserted: (state, action: PayloadAction) => {
+ imagesAdapter.upsertOne(state, action.payload);
+ },
+ imageRemoved: (state, action: PayloadAction) => {
+ if (isString(action.payload)) {
+ imagesAdapter.removeOne(state, action.payload);
+ return;
+ }
+
+ imagesAdapter.removeOne(state, action.payload.image_name);
+ },
+ imageCategoriesChanged: (state, action: PayloadAction) => {
+ state.categories = action.payload;
+ },
+ },
+ extraReducers: (builder) => {
+ builder.addCase(receivedPageOfImages.pending, (state) => {
+ state.isLoading = true;
+ });
+ builder.addCase(receivedPageOfImages.rejected, (state) => {
+ state.isLoading = false;
+ });
+ builder.addCase(receivedPageOfImages.fulfilled, (state, action) => {
+ state.isLoading = false;
+ const { items, offset, limit, total } = action.payload;
+ state.offset = offset;
+ state.limit = limit;
+ state.total = total;
+ imagesAdapter.upsertMany(state, items);
+ });
+ },
+});
+
+export const {
+ selectAll: selectImagesAll,
+ selectById: selectImagesById,
+ selectEntities: selectImagesEntities,
+ selectIds: selectImagesIds,
+ selectTotal: selectImagesTotal,
+} = imagesAdapter.getSelectors((state) => state.images);
+
+export const { imageUpserted, imageRemoved, imageCategoriesChanged } =
+ imagesSlice.actions;
+
+export default imagesSlice.reducer;
+
+export const selectFilteredImagesAsArray = createSelector(
+ (state: RootState) => state,
+ (state) => {
+ const {
+ images: { categories },
+ } = state;
+
+ return selectImagesAll(state).filter((i) =>
+ categories.includes(i.image_category)
+ );
+ }
+);
+
+export const selectFilteredImagesAsObject = createSelector(
+ (state: RootState) => state,
+ (state) => {
+ const {
+ images: { categories },
+ } = state;
+
+ return keyBy(
+ selectImagesAll(state).filter((i) =>
+ categories.includes(i.image_category)
+ ),
+ 'image_name'
+ );
+ }
+);
+
+export const selectFilteredImagesIds = createSelector(
+ (state: RootState) => state,
+ (state) => {
+ const {
+ images: { categories },
+ } = state;
+
+ return selectImagesAll(state)
+ .filter((i) => categories.includes(i.image_category))
+ .map((i) => i.image_name);
+ }
+);
diff --git a/invokeai/frontend/web/src/features/gallery/store/resultsPersistDenylist.ts b/invokeai/frontend/web/src/features/gallery/store/resultsPersistDenylist.ts
deleted file mode 100644
index 1c3d8aaaec..0000000000
--- a/invokeai/frontend/web/src/features/gallery/store/resultsPersistDenylist.ts
+++ /dev/null
@@ -1,8 +0,0 @@
-import { ResultsState } from './resultsSlice';
-
-/**
- * Results slice persist denylist
- *
- * Currently denylisting results slice entirely, see `serialize.ts`
- */
-export const resultsPersistDenylist: (keyof ResultsState)[] = [];
diff --git a/invokeai/frontend/web/src/features/gallery/store/resultsSlice.ts b/invokeai/frontend/web/src/features/gallery/store/resultsSlice.ts
deleted file mode 100644
index 125f4ff5d5..0000000000
--- a/invokeai/frontend/web/src/features/gallery/store/resultsSlice.ts
+++ /dev/null
@@ -1,125 +0,0 @@
-import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
-import { RootState } from 'app/store/store';
-import {
- receivedResultImagesPage,
- IMAGES_PER_PAGE,
-} from 'services/thunks/gallery';
-import {
- imageDeleted,
- imageMetadataReceived,
- imageUrlsReceived,
-} from 'services/thunks/image';
-import { ImageDTO } from 'services/api';
-import { dateComparator } from 'common/util/dateComparator';
-
-export type ResultsImageDTO = Omit & {
- image_type: 'results';
-};
-
-export const resultsAdapter = createEntityAdapter({
- selectId: (image) => image.image_name,
- sortComparer: (a, b) => dateComparator(b.created_at, a.created_at),
-});
-
-type AdditionalResultsState = {
- page: number;
- pages: number;
- isLoading: boolean;
- nextPage: number;
-};
-
-export const initialResultsState =
- resultsAdapter.getInitialState({
- page: 0,
- pages: 0,
- isLoading: false,
- nextPage: 0,
- });
-
-export type ResultsState = typeof initialResultsState;
-
-const resultsSlice = createSlice({
- name: 'results',
- initialState: initialResultsState,
- reducers: {
- resultAdded: resultsAdapter.upsertOne,
- },
- extraReducers: (builder) => {
- /**
- * Received Result Images Page - PENDING
- */
- builder.addCase(receivedResultImagesPage.pending, (state) => {
- state.isLoading = true;
- });
-
- /**
- * Received Result Images Page - FULFILLED
- */
- builder.addCase(receivedResultImagesPage.fulfilled, (state, action) => {
- const { page, pages } = action.payload;
-
- // We know these will all be of the results type, but it's not represented in the API types
- const items = action.payload.items as ResultsImageDTO[];
-
- resultsAdapter.setMany(state, items);
-
- state.page = page;
- state.pages = pages;
- state.nextPage = items.length < IMAGES_PER_PAGE ? page : page + 1;
- state.isLoading = false;
- });
-
- /**
- * Image Metadata Received - FULFILLED
- */
- builder.addCase(imageMetadataReceived.fulfilled, (state, action) => {
- const { image_type } = action.payload;
-
- if (image_type === 'results') {
- resultsAdapter.upsertOne(state, action.payload as ResultsImageDTO);
- }
- });
-
- /**
- * Image URLs Received - FULFILLED
- */
- builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
- const { image_name, image_type, image_url, thumbnail_url } =
- action.payload;
-
- if (image_type === 'results') {
- resultsAdapter.updateOne(state, {
- id: image_name,
- changes: {
- image_url: image_url,
- thumbnail_url: thumbnail_url,
- },
- });
- }
- });
-
- /**
- * Delete Image - PENDING
- * Pre-emptively remove the image from the gallery
- */
- builder.addCase(imageDeleted.pending, (state, action) => {
- const { imageType, imageName } = action.meta.arg;
-
- if (imageType === 'results') {
- resultsAdapter.removeOne(state, imageName);
- }
- });
- },
-});
-
-export const {
- selectAll: selectResultsAll,
- selectById: selectResultsById,
- selectEntities: selectResultsEntities,
- selectIds: selectResultsIds,
- selectTotal: selectResultsTotal,
-} = resultsAdapter.getSelectors((state) => state.results);
-
-export const { resultAdded } = resultsSlice.actions;
-
-export default resultsSlice.reducer;
diff --git a/invokeai/frontend/web/src/features/gallery/store/uploadsPersistDenylist.ts b/invokeai/frontend/web/src/features/gallery/store/uploadsPersistDenylist.ts
deleted file mode 100644
index 296e8b2057..0000000000
--- a/invokeai/frontend/web/src/features/gallery/store/uploadsPersistDenylist.ts
+++ /dev/null
@@ -1,8 +0,0 @@
-import { UploadsState } from './uploadsSlice';
-
-/**
- * Uploads slice persist denylist
- *
- * Currently denylisting uploads slice entirely, see `serialize.ts`
- */
-export const uploadsPersistDenylist: (keyof UploadsState)[] = [];
diff --git a/invokeai/frontend/web/src/features/gallery/store/uploadsSlice.ts b/invokeai/frontend/web/src/features/gallery/store/uploadsSlice.ts
deleted file mode 100644
index 5e458503ec..0000000000
--- a/invokeai/frontend/web/src/features/gallery/store/uploadsSlice.ts
+++ /dev/null
@@ -1,111 +0,0 @@
-import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
-
-import { RootState } from 'app/store/store';
-import {
- receivedUploadImagesPage,
- IMAGES_PER_PAGE,
-} from 'services/thunks/gallery';
-import { imageDeleted, imageUrlsReceived } from 'services/thunks/image';
-import { ImageDTO } from 'services/api';
-import { dateComparator } from 'common/util/dateComparator';
-
-export type UploadsImageDTO = Omit & {
- image_type: 'uploads';
-};
-
-export const uploadsAdapter = createEntityAdapter({
- selectId: (image) => image.image_name,
- sortComparer: (a, b) => dateComparator(b.created_at, a.created_at),
-});
-
-type AdditionalUploadsState = {
- page: number;
- pages: number;
- isLoading: boolean;
- nextPage: number;
-};
-
-export const initialUploadsState =
- uploadsAdapter.getInitialState({
- page: 0,
- pages: 0,
- nextPage: 0,
- isLoading: false,
- });
-
-export type UploadsState = typeof initialUploadsState;
-
-const uploadsSlice = createSlice({
- name: 'uploads',
- initialState: initialUploadsState,
- reducers: {
- uploadAdded: uploadsAdapter.upsertOne,
- },
- extraReducers: (builder) => {
- /**
- * Received Upload Images Page - PENDING
- */
- builder.addCase(receivedUploadImagesPage.pending, (state) => {
- state.isLoading = true;
- });
-
- /**
- * Received Upload Images Page - FULFILLED
- */
- builder.addCase(receivedUploadImagesPage.fulfilled, (state, action) => {
- const { page, pages } = action.payload;
-
- // We know these will all be of the uploads type, but it's not represented in the API types
- const items = action.payload.items as UploadsImageDTO[];
-
- uploadsAdapter.setMany(state, items);
-
- state.page = page;
- state.pages = pages;
- state.nextPage = items.length < IMAGES_PER_PAGE ? page : page + 1;
- state.isLoading = false;
- });
-
- /**
- * Image URLs Received - FULFILLED
- */
- builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
- const { image_name, image_type, image_url, thumbnail_url } =
- action.payload;
-
- if (image_type === 'uploads') {
- uploadsAdapter.updateOne(state, {
- id: image_name,
- changes: {
- image_url: image_url,
- thumbnail_url: thumbnail_url,
- },
- });
- }
- });
-
- /**
- * Delete Image - pending
- * Pre-emptively remove the image from the gallery
- */
- builder.addCase(imageDeleted.pending, (state, action) => {
- const { imageType, imageName } = action.meta.arg;
-
- if (imageType === 'uploads') {
- uploadsAdapter.removeOne(state, imageName);
- }
- });
- },
-});
-
-export const {
- selectAll: selectUploadsAll,
- selectById: selectUploadsById,
- selectEntities: selectUploadsEntities,
- selectIds: selectUploadsIds,
- selectTotal: selectUploadsTotal,
-} = uploadsAdapter.getSelectors((state) => state.uploads);
-
-export const { uploadAdded } = uploadsSlice.actions;
-
-export default uploadsSlice.reducer;
diff --git a/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx
index 341ca19fa9..65b7cfa560 100644
--- a/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx
@@ -10,6 +10,7 @@ import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComp
import UNetInputFieldComponent from './fields/UNetInputFieldComponent';
import ClipInputFieldComponent from './fields/ClipInputFieldComponent';
import VaeInputFieldComponent from './fields/VaeInputFieldComponent';
+import ControlInputFieldComponent from './fields/ControlInputFieldComponent';
import ModelInputFieldComponent from './fields/ModelInputFieldComponent';
import NumberInputFieldComponent from './fields/NumberInputFieldComponent';
import StringInputFieldComponent from './fields/StringInputFieldComponent';
@@ -130,6 +131,16 @@ const InputFieldComponent = (props: InputFieldComponentProps) => {
);
}
+ if (type === 'control' && template.type === 'control') {
+ return (
+
+ );
+ }
+
if (type === 'model' && template.type === 'model') {
return (
+) => {
+ const { nodeId, field } = props;
+
+ return null;
+};
+
+export default memo(ControlInputFieldComponent);
diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/ImageInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/ImageInputFieldComponent.tsx
index 18be021625..57cefb0a9c 100644
--- a/invokeai/frontend/web/src/features/nodes/components/fields/ImageInputFieldComponent.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/fields/ImageInputFieldComponent.tsx
@@ -2,7 +2,7 @@ import { Box, Image } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import SelectImagePlaceholder from 'common/components/SelectImagePlaceholder';
import { useGetUrl } from 'common/util/getUrl';
-import useGetImageByNameAndType from 'features/gallery/hooks/useGetImageByName';
+import useGetImageByName from 'features/gallery/hooks/useGetImageByName';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import {
@@ -11,7 +11,6 @@ import {
} from 'features/nodes/types/types';
import { DragEvent, memo, useCallback, useState } from 'react';
-import { ImageType } from 'services/api';
import { FieldComponentProps } from './types';
const ImageInputFieldComponent = (
@@ -19,7 +18,7 @@ const ImageInputFieldComponent = (
) => {
const { nodeId, field } = props;
- const getImageByNameAndType = useGetImageByNameAndType();
+ const getImageByName = useGetImageByName();
const dispatch = useAppDispatch();
const [url, setUrl] = useState(field.value?.image_url);
const { getUrl } = useGetUrl();
@@ -27,13 +26,7 @@ const ImageInputFieldComponent = (
const handleDrop = useCallback(
(e: DragEvent) => {
const name = e.dataTransfer.getData('invokeai/imageName');
- const type = e.dataTransfer.getData('invokeai/imageType') as ImageType;
-
- if (!name || !type) {
- return;
- }
-
- const image = getImageByNameAndType(name, type);
+ const image = getImageByName(name);
if (!image) {
return;
@@ -49,7 +42,7 @@ const ImageInputFieldComponent = (
})
);
},
- [getImageByNameAndType, dispatch, field.name, nodeId]
+ [getImageByName, dispatch, field.name, nodeId]
);
return (
diff --git a/invokeai/frontend/web/src/features/nodes/types/constants.ts b/invokeai/frontend/web/src/features/nodes/types/constants.ts
index 36c3514eeb..b8cd7efaa4 100644
--- a/invokeai/frontend/web/src/features/nodes/types/constants.ts
+++ b/invokeai/frontend/web/src/features/nodes/types/constants.ts
@@ -4,6 +4,7 @@ export const HANDLE_TOOLTIP_OPEN_DELAY = 500;
export const FIELD_TYPE_MAP: Record = {
integer: 'integer',
+ float: 'float',
number: 'float',
string: 'string',
boolean: 'boolean',
@@ -18,6 +19,8 @@ export const FIELD_TYPE_MAP: Record = {
array: 'array',
item: 'item',
ColorField: 'color',
+ ControlField: 'control',
+ control: 'control',
};
const COLOR_TOKEN_VALUE = 500;
@@ -25,6 +28,9 @@ const COLOR_TOKEN_VALUE = 500;
const getColorTokenCssVariable = (color: string) =>
`var(--invokeai-colors-${color}-${COLOR_TOKEN_VALUE})`;
+// @ts-ignore
+// @ts-ignore
+// @ts-ignore
export const FIELDS: Record = {
integer: {
color: 'red',
@@ -92,6 +98,12 @@ export const FIELDS: Record = {
title: 'Vae',
description: 'Vae submodel.',
},
+ control: {
+ color: 'cyan',
+ colorCssVar: getColorTokenCssVariable('cyan'), // TODO: no free color left
+ title: 'Control',
+ description: 'Control info passed between nodes.',
+ },
model: {
color: 'teal',
colorCssVar: getColorTokenCssVariable('teal'),
diff --git a/invokeai/frontend/web/src/features/nodes/types/types.ts b/invokeai/frontend/web/src/features/nodes/types/types.ts
index 930b7dad57..5e140b6eef 100644
--- a/invokeai/frontend/web/src/features/nodes/types/types.ts
+++ b/invokeai/frontend/web/src/features/nodes/types/types.ts
@@ -64,6 +64,7 @@ export type FieldType =
| 'unet'
| 'clip'
| 'vae'
+ | 'control'
| 'model'
| 'array'
| 'item'
@@ -88,6 +89,7 @@ export type InputFieldValue =
| UNetInputFieldValue
| ClipInputFieldValue
| VaeInputFieldValue
+ | ControlInputFieldValue
| EnumInputFieldValue
| ModelInputFieldValue
| ArrayInputFieldValue
@@ -111,6 +113,7 @@ export type InputFieldTemplate =
| UNetInputFieldTemplate
| ClipInputFieldTemplate
| VaeInputFieldTemplate
+ | ControlInputFieldTemplate
| EnumInputFieldTemplate
| ModelInputFieldTemplate
| ArrayInputFieldTemplate
@@ -186,6 +189,11 @@ export type LatentsInputFieldValue = FieldValueBase & {
export type ConditioningInputFieldValue = FieldValueBase & {
type: 'conditioning';
+ value?: string;
+};
+
+export type ControlInputFieldValue = FieldValueBase & {
+ type: 'control';
value?: undefined;
};
@@ -286,6 +294,11 @@ export type ConditioningInputFieldTemplate = InputFieldTemplateBase & {
type: 'conditioning';
};
+export type ControlInputFieldTemplate = InputFieldTemplateBase & {
+ default: undefined;
+ type: 'control';
+};
+
export type EnumInputFieldTemplate = InputFieldTemplateBase & {
default: string | number;
type: 'enum';
diff --git a/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts b/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts
index b275c84248..f1ad731d32 100644
--- a/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts
@@ -13,6 +13,7 @@ import {
UNetInputFieldTemplate,
ClipInputFieldTemplate,
VaeInputFieldTemplate,
+ ControlInputFieldTemplate,
StringInputFieldTemplate,
ModelInputFieldTemplate,
ArrayInputFieldTemplate,
@@ -263,6 +264,21 @@ const buildVaeInputFieldTemplate = ({
return template;
};
+const buildControlInputFieldTemplate = ({
+ schemaObject,
+ baseField,
+}: BuildInputFieldArg): ControlInputFieldTemplate => {
+ const template: ControlInputFieldTemplate = {
+ ...baseField,
+ type: 'control',
+ inputRequirement: 'always',
+ inputKind: 'connection',
+ default: schemaObject.default ?? undefined,
+ };
+
+ return template;
+};
+
const buildEnumInputFieldTemplate = ({
schemaObject,
baseField,
@@ -334,9 +350,20 @@ export const getFieldType = (
if (typeHints && name in typeHints) {
rawFieldType = typeHints[name];
} else if (!schemaObject.type) {
- rawFieldType = refObjectToFieldType(
- schemaObject.allOf![0] as OpenAPIV3.ReferenceObject
- );
+ // if schemaObject has no type, then it should have one of allOf, anyOf, oneOf
+ if (schemaObject.allOf) {
+ rawFieldType = refObjectToFieldType(
+ schemaObject.allOf![0] as OpenAPIV3.ReferenceObject
+ );
+ } else if (schemaObject.anyOf) {
+ rawFieldType = refObjectToFieldType(
+ schemaObject.anyOf![0] as OpenAPIV3.ReferenceObject
+ );
+ } else if (schemaObject.oneOf) {
+ rawFieldType = refObjectToFieldType(
+ schemaObject.oneOf![0] as OpenAPIV3.ReferenceObject
+ );
+ }
} else if (schemaObject.enum) {
rawFieldType = 'enum';
} else if (schemaObject.type) {
@@ -388,6 +415,9 @@ export const buildInputFieldTemplate = (
if (['vae'].includes(fieldType)) {
return buildVaeInputFieldTemplate({ schemaObject, baseField });
}
+ if (['control'].includes(fieldType)) {
+ return buildControlInputFieldTemplate({ schemaObject, baseField });
+ }
if (['model'].includes(fieldType)) {
return buildModelInputFieldTemplate({ schemaObject, baseField });
}
diff --git a/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts b/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts
index c0c19708c7..1703c45331 100644
--- a/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts
@@ -64,6 +64,10 @@ export const buildInputFieldValue = (
fieldValue.value = undefined;
}
+ if (template.type === 'control') {
+ fieldValue.value = undefined;
+ }
+
if (template.type === 'model') {
fieldValue.value = undefined;
}
diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasGraph.ts
index 3615f7d298..2d23b882ea 100644
--- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasGraph.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasGraph.ts
@@ -16,7 +16,7 @@ import { buildEdges } from '../edgeBuilders/buildEdges';
import { log } from 'app/logging/useLogger';
import { buildInpaintNode } from '../nodeBuilders/buildInpaintNode';
-const moduleLog = log.child({ namespace: 'buildCanvasGraph' });
+const moduleLog = log.child({ namespace: 'nodes' });
const buildBaseNode = (
nodeType: 'txt2img' | 'img2img' | 'inpaint' | 'outpaint',
@@ -26,18 +26,21 @@ const buildBaseNode = (
| ImageToImageInvocation
| InpaintInvocation
| undefined => {
- const dimensionsOverride = state.canvas.boundingBoxDimensions;
+ const overrides = {
+ ...state.canvas.boundingBoxDimensions,
+ is_intermediate: true,
+ };
if (nodeType === 'txt2img') {
- return buildTxt2ImgNode(state, dimensionsOverride);
+ return buildTxt2ImgNode(state, overrides);
}
if (nodeType === 'img2img') {
- return buildImg2ImgNode(state, dimensionsOverride);
+ return buildImg2ImgNode(state, overrides);
}
if (nodeType === 'inpaint' || nodeType === 'outpaint') {
- return buildInpaintNode(state, dimensionsOverride);
+ return buildInpaintNode(state, overrides);
}
};
@@ -77,18 +80,23 @@ export const buildCanvasGraphComponents = async (
infillMethod,
} = state.generation;
- // generationParameters.invert_mask = shouldPreserveMaskedArea;
- // if (boundingBoxScale !== 'none') {
- // generationParameters.inpaint_width = scaledBoundingBoxDimensions.width;
- // generationParameters.inpaint_height = scaledBoundingBoxDimensions.height;
- // }
+ const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } =
+ state.canvas;
+
+ if (boundingBoxScaleMethod !== 'none') {
+ baseNode.inpaint_width = scaledBoundingBoxDimensions.width;
+ baseNode.inpaint_height = scaledBoundingBoxDimensions.height;
+ }
+
baseNode.seam_size = seamSize;
baseNode.seam_blur = seamBlur;
baseNode.seam_strength = seamStrength;
baseNode.seam_steps = seamSteps;
- baseNode.tile_size = tileSize;
baseNode.infill_method = infillMethod as InpaintInvocation['infill_method'];
- // baseNode.force_outpaint = false;
+
+ if (infillMethod === 'tile') {
+ baseNode.tile_size = tileSize;
+ }
}
// We always range and iterate nodes, no matter the iteration count
diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildImageToImageGraph.ts
index d9eb80d654..fe4f6c63b5 100644
--- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildImageToImageGraph.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildImageToImageGraph.ts
@@ -2,21 +2,31 @@ import { RootState } from 'app/store/store';
import {
CompelInvocation,
Graph,
+ ImageResizeInvocation,
ImageToLatentsInvocation,
+ IterateInvocation,
LatentsToImageInvocation,
LatentsToLatentsInvocation,
+ NoiseInvocation,
+ RandomIntInvocation,
+ RangeOfSizeInvocation,
} from 'services/api';
import { NonNullableGraph } from 'features/nodes/types/types';
-import { addNoiseNodes } from '../nodeBuilders/addNoiseNodes';
import { log } from 'app/logging/useLogger';
+import { set } from 'lodash-es';
-const moduleLog = log.child({ namespace: 'buildImageToImageGraph' });
+const moduleLog = log.child({ namespace: 'nodes' });
const POSITIVE_CONDITIONING = 'positive_conditioning';
const NEGATIVE_CONDITIONING = 'negative_conditioning';
const IMAGE_TO_LATENTS = 'image_to_latents';
const LATENTS_TO_LATENTS = 'latents_to_latents';
const LATENTS_TO_IMAGE = 'latents_to_image';
+const RESIZE = 'resize_image';
+const NOISE = 'noise';
+const RANDOM_INT = 'rand_int';
+const RANGE_OF_SIZE = 'range_of_size';
+const ITERATE = 'iterate';
/**
* Builds the Image to Image tab graph.
@@ -31,6 +41,12 @@ export const buildImageToImageGraph = (state: RootState): Graph => {
steps,
initialImage,
img2imgStrength: strength,
+ shouldFitToWidthHeight,
+ width,
+ height,
+ iterations,
+ seed,
+ shouldRandomizeSeed,
} = state.generation;
if (!initialImage) {
@@ -38,12 +54,12 @@ export const buildImageToImageGraph = (state: RootState): Graph => {
throw new Error('No initial image found in state');
}
- let graph: NonNullableGraph = {
+ const graph: NonNullableGraph = {
nodes: {},
edges: [],
};
- // Create the conditioning, t2l and l2i nodes
+ // Create the positive conditioning (prompt) node
const positiveConditioningNode: CompelInvocation = {
id: POSITIVE_CONDITIONING,
type: 'compel',
@@ -51,6 +67,7 @@ export const buildImageToImageGraph = (state: RootState): Graph => {
model,
};
+ // Negative conditioning
const negativeConditioningNode: CompelInvocation = {
id: NEGATIVE_CONDITIONING,
type: 'compel',
@@ -58,16 +75,15 @@ export const buildImageToImageGraph = (state: RootState): Graph => {
model,
};
+ // This will encode the raster image to latents - but it may get its `image` from a resize node,
+ // so we do not set its `image` property yet
const imageToLatentsNode: ImageToLatentsInvocation = {
id: IMAGE_TO_LATENTS,
type: 'i2l',
model,
- image: {
- image_name: initialImage?.image_name,
- image_type: initialImage?.image_type,
- },
};
+ // This does the actual img2img inference
const latentsToLatentsNode: LatentsToLatentsInvocation = {
id: LATENTS_TO_LATENTS,
type: 'l2l',
@@ -78,20 +94,21 @@ export const buildImageToImageGraph = (state: RootState): Graph => {
strength,
};
+ // Finally we decode the latents back to an image
const latentsToImageNode: LatentsToImageInvocation = {
id: LATENTS_TO_IMAGE,
type: 'l2i',
model,
};
- // Add to the graph
+ // Add all those nodes to the graph
graph.nodes[POSITIVE_CONDITIONING] = positiveConditioningNode;
graph.nodes[NEGATIVE_CONDITIONING] = negativeConditioningNode;
graph.nodes[IMAGE_TO_LATENTS] = imageToLatentsNode;
graph.nodes[LATENTS_TO_LATENTS] = latentsToLatentsNode;
graph.nodes[LATENTS_TO_IMAGE] = latentsToImageNode;
- // Connect them
+ // Connect the prompt nodes to the imageToLatents node
graph.edges.push({
source: { node_id: POSITIVE_CONDITIONING, field: 'conditioning' },
destination: {
@@ -99,7 +116,6 @@ export const buildImageToImageGraph = (state: RootState): Graph => {
field: 'positive_conditioning',
},
});
-
graph.edges.push({
source: { node_id: NEGATIVE_CONDITIONING, field: 'conditioning' },
destination: {
@@ -108,6 +124,7 @@ export const buildImageToImageGraph = (state: RootState): Graph => {
},
});
+ // Connect the image-encoding node
graph.edges.push({
source: { node_id: IMAGE_TO_LATENTS, field: 'latents' },
destination: {
@@ -116,6 +133,7 @@ export const buildImageToImageGraph = (state: RootState): Graph => {
},
});
+ // Connect the image-decoding node
graph.edges.push({
source: { node_id: LATENTS_TO_LATENTS, field: 'latents' },
destination: {
@@ -124,8 +142,271 @@ export const buildImageToImageGraph = (state: RootState): Graph => {
},
});
- // Create and add the noise nodes
- graph = addNoiseNodes(graph, latentsToLatentsNode.id, state);
+ /**
+ * Now we need to handle iterations and random seeds. There are four possible scenarios:
+ * - Single iteration, explicit seed
+ * - Single iteration, random seed
+ * - Multiple iterations, explicit seed
+ * - Multiple iterations, random seed
+ *
+ * They all have different graphs and connections.
+ */
+
+ // Single iteration, explicit seed
+ if (!shouldRandomizeSeed && iterations === 1) {
+ // Noise node using the explicit seed
+ const noiseNode: NoiseInvocation = {
+ id: NOISE,
+ type: 'noise',
+ seed: seed,
+ };
+
+ graph.nodes[NOISE] = noiseNode;
+
+ // Connect noise to l2l
+ graph.edges.push({
+ source: { node_id: NOISE, field: 'noise' },
+ destination: {
+ node_id: LATENTS_TO_LATENTS,
+ field: 'noise',
+ },
+ });
+ }
+
+ // Single iteration, random seed
+ if (shouldRandomizeSeed && iterations === 1) {
+ // Random int node to generate the seed
+ const randomIntNode: RandomIntInvocation = {
+ id: RANDOM_INT,
+ type: 'rand_int',
+ };
+
+ // Noise node without any seed
+ const noiseNode: NoiseInvocation = {
+ id: NOISE,
+ type: 'noise',
+ };
+
+ graph.nodes[RANDOM_INT] = randomIntNode;
+ graph.nodes[NOISE] = noiseNode;
+
+ // Connect random int to the seed of the noise node
+ graph.edges.push({
+ source: { node_id: RANDOM_INT, field: 'a' },
+ destination: {
+ node_id: NOISE,
+ field: 'seed',
+ },
+ });
+
+ // Connect noise to l2l
+ graph.edges.push({
+ source: { node_id: NOISE, field: 'noise' },
+ destination: {
+ node_id: LATENTS_TO_LATENTS,
+ field: 'noise',
+ },
+ });
+ }
+
+ // Multiple iterations, explicit seed
+ if (!shouldRandomizeSeed && iterations > 1) {
+ // Range of size node to generate `iterations` count of seeds - range of size generates a collection
+ // of ints from `start` to `start + size`. The `start` is the seed, and the `size` is the number of
+ // iterations.
+ const rangeOfSizeNode: RangeOfSizeInvocation = {
+ id: RANGE_OF_SIZE,
+ type: 'range_of_size',
+ start: seed,
+ size: iterations,
+ };
+
+ // Iterate node to iterate over the seeds generated by the range of size node
+ const iterateNode: IterateInvocation = {
+ id: ITERATE,
+ type: 'iterate',
+ };
+
+ // Noise node without any seed
+ const noiseNode: NoiseInvocation = {
+ id: NOISE,
+ type: 'noise',
+ };
+
+ // Adding to the graph
+ graph.nodes[RANGE_OF_SIZE] = rangeOfSizeNode;
+ graph.nodes[ITERATE] = iterateNode;
+ graph.nodes[NOISE] = noiseNode;
+
+ // Connect range of size to iterate
+ graph.edges.push({
+ source: { node_id: RANGE_OF_SIZE, field: 'collection' },
+ destination: {
+ node_id: ITERATE,
+ field: 'collection',
+ },
+ });
+
+ // Connect iterate to noise
+ graph.edges.push({
+ source: {
+ node_id: ITERATE,
+ field: 'item',
+ },
+ destination: {
+ node_id: NOISE,
+ field: 'seed',
+ },
+ });
+
+ // Connect noise to l2l
+ graph.edges.push({
+ source: { node_id: NOISE, field: 'noise' },
+ destination: {
+ node_id: LATENTS_TO_LATENTS,
+ field: 'noise',
+ },
+ });
+ }
+
+ // Multiple iterations, random seed
+ if (shouldRandomizeSeed && iterations > 1) {
+ // Random int node to generate the seed
+ const randomIntNode: RandomIntInvocation = {
+ id: RANDOM_INT,
+ type: 'rand_int',
+ };
+
+ // Range of size node to generate `iterations` count of seeds - range of size generates a collection
+ const rangeOfSizeNode: RangeOfSizeInvocation = {
+ id: RANGE_OF_SIZE,
+ type: 'range_of_size',
+ size: iterations,
+ };
+
+ // Iterate node to iterate over the seeds generated by the range of size node
+ const iterateNode: IterateInvocation = {
+ id: ITERATE,
+ type: 'iterate',
+ };
+
+ // Noise node without any seed
+ const noiseNode: NoiseInvocation = {
+ id: NOISE,
+ type: 'noise',
+ width,
+ height,
+ };
+
+ // Adding to the graph
+ graph.nodes[RANDOM_INT] = randomIntNode;
+ graph.nodes[RANGE_OF_SIZE] = rangeOfSizeNode;
+ graph.nodes[ITERATE] = iterateNode;
+ graph.nodes[NOISE] = noiseNode;
+
+ // Connect random int to the start of the range of size so the range starts on the random first seed
+ graph.edges.push({
+ source: { node_id: RANDOM_INT, field: 'a' },
+ destination: { node_id: RANGE_OF_SIZE, field: 'start' },
+ });
+
+ // Connect range of size to iterate
+ graph.edges.push({
+ source: { node_id: RANGE_OF_SIZE, field: 'collection' },
+ destination: {
+ node_id: ITERATE,
+ field: 'collection',
+ },
+ });
+
+ // Connect iterate to noise
+ graph.edges.push({
+ source: {
+ node_id: ITERATE,
+ field: 'item',
+ },
+ destination: {
+ node_id: NOISE,
+ field: 'seed',
+ },
+ });
+
+ // Connect noise to l2l
+ graph.edges.push({
+ source: { node_id: NOISE, field: 'noise' },
+ destination: {
+ node_id: LATENTS_TO_LATENTS,
+ field: 'noise',
+ },
+ });
+ }
+
+ if (shouldFitToWidthHeight) {
+ // The init image needs to be resized to the specified width and height before being passed to `IMAGE_TO_LATENTS`
+
+ // Create a resize node, explicitly setting its image
+ const resizeNode: ImageResizeInvocation = {
+ id: RESIZE,
+ type: 'img_resize',
+ image: {
+ image_name: initialImage.image_name,
+ image_origin: initialImage.image_origin,
+ },
+ is_intermediate: true,
+ height,
+ width,
+ };
+
+ graph.nodes[RESIZE] = resizeNode;
+
+ // The `RESIZE` node then passes its image to `IMAGE_TO_LATENTS`
+ graph.edges.push({
+ source: { node_id: RESIZE, field: 'image' },
+ destination: {
+ node_id: IMAGE_TO_LATENTS,
+ field: 'image',
+ },
+ });
+
+ // The `RESIZE` node also passes its width and height to `NOISE`
+ graph.edges.push({
+ source: { node_id: RESIZE, field: 'width' },
+ destination: {
+ node_id: NOISE,
+ field: 'width',
+ },
+ });
+
+ graph.edges.push({
+ source: { node_id: RESIZE, field: 'height' },
+ destination: {
+ node_id: NOISE,
+ field: 'height',
+ },
+ });
+ } else {
+ // We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly
+ set(graph.nodes[IMAGE_TO_LATENTS], 'image', {
+ image_name: initialImage.image_name,
+ image_origin: initialImage.image_origin,
+ });
+
+ // Pass the image's dimensions to the `NOISE` node
+ graph.edges.push({
+ source: { node_id: IMAGE_TO_LATENTS, field: 'width' },
+ destination: {
+ node_id: NOISE,
+ field: 'width',
+ },
+ });
+ graph.edges.push({
+ source: { node_id: IMAGE_TO_LATENTS, field: 'height' },
+ destination: {
+ node_id: NOISE,
+ field: 'height',
+ },
+ });
+ }
return graph;
};
diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts
index eef7379624..6a700d4813 100644
--- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts
@@ -1,8 +1,9 @@
import { Graph } from 'services/api';
import { v4 as uuidv4 } from 'uuid';
-import { cloneDeep, reduce } from 'lodash-es';
+import { cloneDeep, forEach, omit, reduce, values } from 'lodash-es';
import { RootState } from 'app/store/store';
import { InputFieldValue } from 'features/nodes/types/types';
+import { AnyInvocation } from 'services/events/types';
/**
* We need to do special handling for some fields
@@ -89,6 +90,24 @@ export const buildNodesGraph = (state: RootState): Graph => {
[]
);
+ /**
+ * Omit all inputs that have edges connected.
+ *
+ * Fixes edge case where the user has connected an input, but also provided an invalid explicit,
+ * value.
+ *
+ * In this edge case, pydantic will invalidate the node based on the invalid explicit value,
+ * even though the actual value that will be used comes from the connection.
+ */
+ parsedEdges.forEach((edge) => {
+ const destination_node = parsedNodes[edge.destination.node_id];
+ const field = edge.destination.field;
+ parsedNodes[edge.destination.node_id] = omit(
+ destination_node,
+ field
+ ) as AnyInvocation;
+ });
+
// Assemble!
const graph = {
id: uuidv4(),
diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildTextToImageGraph.ts
index 51f89e8f74..753ccccff8 100644
--- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildTextToImageGraph.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildTextToImageGraph.ts
@@ -2,16 +2,23 @@ import { RootState } from 'app/store/store';
import {
CompelInvocation,
Graph,
+ IterateInvocation,
LatentsToImageInvocation,
+ NoiseInvocation,
+ RandomIntInvocation,
+ RangeOfSizeInvocation,
TextToLatentsInvocation,
} from 'services/api';
import { NonNullableGraph } from 'features/nodes/types/types';
-import { addNoiseNodes } from '../nodeBuilders/addNoiseNodes';
const POSITIVE_CONDITIONING = 'positive_conditioning';
const NEGATIVE_CONDITIONING = 'negative_conditioning';
const TEXT_TO_LATENTS = 'text_to_latents';
const LATENTS_TO_IMAGE = 'latents_to_image';
+const NOISE = 'noise';
+const RANDOM_INT = 'rand_int';
+const RANGE_OF_SIZE = 'range_of_size';
+const ITERATE = 'iterate';
/**
* Builds the Text to Image tab graph.
@@ -24,9 +31,14 @@ export const buildTextToImageGraph = (state: RootState): Graph => {
cfgScale: cfg_scale,
scheduler,
steps,
+ width,
+ height,
+ iterations,
+ seed,
+ shouldRandomizeSeed,
} = state.generation;
- let graph: NonNullableGraph = {
+ const graph: NonNullableGraph = {
nodes: {},
edges: [],
};
@@ -92,8 +104,209 @@ export const buildTextToImageGraph = (state: RootState): Graph => {
},
});
- // Create and add the noise nodes
- graph = addNoiseNodes(graph, TEXT_TO_LATENTS, state);
+ /**
+ * Now we need to handle iterations and random seeds. There are four possible scenarios:
+ * - Single iteration, explicit seed
+ * - Single iteration, random seed
+ * - Multiple iterations, explicit seed
+ * - Multiple iterations, random seed
+ *
+ * They all have different graphs and connections.
+ */
+ // Single iteration, explicit seed
+ if (!shouldRandomizeSeed && iterations === 1) {
+ // Noise node using the explicit seed
+ const noiseNode: NoiseInvocation = {
+ id: NOISE,
+ type: 'noise',
+ seed: seed,
+ width,
+ height,
+ };
+
+ graph.nodes[NOISE] = noiseNode;
+
+ // Connect noise to l2l
+ graph.edges.push({
+ source: { node_id: NOISE, field: 'noise' },
+ destination: {
+ node_id: TEXT_TO_LATENTS,
+ field: 'noise',
+ },
+ });
+ }
+
+ // Single iteration, random seed
+ if (shouldRandomizeSeed && iterations === 1) {
+ // Random int node to generate the seed
+ const randomIntNode: RandomIntInvocation = {
+ id: RANDOM_INT,
+ type: 'rand_int',
+ };
+
+ // Noise node without any seed
+ const noiseNode: NoiseInvocation = {
+ id: NOISE,
+ type: 'noise',
+ width,
+ height,
+ };
+
+ graph.nodes[RANDOM_INT] = randomIntNode;
+ graph.nodes[NOISE] = noiseNode;
+
+ // Connect random int to the seed of the noise node
+ graph.edges.push({
+ source: { node_id: RANDOM_INT, field: 'a' },
+ destination: {
+ node_id: NOISE,
+ field: 'seed',
+ },
+ });
+
+ // Connect noise to t2l
+ graph.edges.push({
+ source: { node_id: NOISE, field: 'noise' },
+ destination: {
+ node_id: TEXT_TO_LATENTS,
+ field: 'noise',
+ },
+ });
+ }
+
+ // Multiple iterations, explicit seed
+ if (!shouldRandomizeSeed && iterations > 1) {
+ // Range of size node to generate `iterations` count of seeds - range of size generates a collection
+ // of ints from `start` to `start + size`. The `start` is the seed, and the `size` is the number of
+ // iterations.
+ const rangeOfSizeNode: RangeOfSizeInvocation = {
+ id: RANGE_OF_SIZE,
+ type: 'range_of_size',
+ start: seed,
+ size: iterations,
+ };
+
+ // Iterate node to iterate over the seeds generated by the range of size node
+ const iterateNode: IterateInvocation = {
+ id: ITERATE,
+ type: 'iterate',
+ };
+
+ // Noise node without any seed
+ const noiseNode: NoiseInvocation = {
+ id: NOISE,
+ type: 'noise',
+ width,
+ height,
+ };
+
+ // Adding to the graph
+ graph.nodes[RANGE_OF_SIZE] = rangeOfSizeNode;
+ graph.nodes[ITERATE] = iterateNode;
+ graph.nodes[NOISE] = noiseNode;
+
+ // Connect range of size to iterate
+ graph.edges.push({
+ source: { node_id: RANGE_OF_SIZE, field: 'collection' },
+ destination: {
+ node_id: ITERATE,
+ field: 'collection',
+ },
+ });
+
+ // Connect iterate to noise
+ graph.edges.push({
+ source: {
+ node_id: ITERATE,
+ field: 'item',
+ },
+ destination: {
+ node_id: NOISE,
+ field: 'seed',
+ },
+ });
+
+ // Connect noise to t2l
+ graph.edges.push({
+ source: { node_id: NOISE, field: 'noise' },
+ destination: {
+ node_id: TEXT_TO_LATENTS,
+ field: 'noise',
+ },
+ });
+ }
+
+ // Multiple iterations, random seed
+ if (shouldRandomizeSeed && iterations > 1) {
+ // Random int node to generate the seed
+ const randomIntNode: RandomIntInvocation = {
+ id: RANDOM_INT,
+ type: 'rand_int',
+ };
+
+ // Range of size node to generate `iterations` count of seeds - range of size generates a collection
+ const rangeOfSizeNode: RangeOfSizeInvocation = {
+ id: RANGE_OF_SIZE,
+ type: 'range_of_size',
+ size: iterations,
+ };
+
+ // Iterate node to iterate over the seeds generated by the range of size node
+ const iterateNode: IterateInvocation = {
+ id: ITERATE,
+ type: 'iterate',
+ };
+
+ // Noise node without any seed
+ const noiseNode: NoiseInvocation = {
+ id: NOISE,
+ type: 'noise',
+ width,
+ height,
+ };
+
+ // Adding to the graph
+ graph.nodes[RANDOM_INT] = randomIntNode;
+ graph.nodes[RANGE_OF_SIZE] = rangeOfSizeNode;
+ graph.nodes[ITERATE] = iterateNode;
+ graph.nodes[NOISE] = noiseNode;
+
+ // Connect random int to the start of the range of size so the range starts on the random first seed
+ graph.edges.push({
+ source: { node_id: RANDOM_INT, field: 'a' },
+ destination: { node_id: RANGE_OF_SIZE, field: 'start' },
+ });
+
+ // Connect range of size to iterate
+ graph.edges.push({
+ source: { node_id: RANGE_OF_SIZE, field: 'collection' },
+ destination: {
+ node_id: ITERATE,
+ field: 'collection',
+ },
+ });
+
+ // Connect iterate to noise
+ graph.edges.push({
+ source: {
+ node_id: ITERATE,
+ field: 'item',
+ },
+ destination: {
+ node_id: NOISE,
+ field: 'seed',
+ },
+ });
+
+ // Connect noise to t2l
+ graph.edges.push({
+ source: { node_id: NOISE, field: 'noise' },
+ destination: {
+ node_id: TEXT_TO_LATENTS,
+ field: 'noise',
+ },
+ });
+ }
return graph;
};
diff --git a/invokeai/frontend/web/src/features/nodes/util/nodeBuilders/addNoiseNodes.ts b/invokeai/frontend/web/src/features/nodes/util/nodeBuilders/addNoiseNodes.ts
deleted file mode 100644
index ba3d4d8168..0000000000
--- a/invokeai/frontend/web/src/features/nodes/util/nodeBuilders/addNoiseNodes.ts
+++ /dev/null
@@ -1,208 +0,0 @@
-import { RootState } from 'app/store/store';
-import {
- IterateInvocation,
- NoiseInvocation,
- RandomIntInvocation,
- RangeOfSizeInvocation,
-} from 'services/api';
-import { NonNullableGraph } from 'features/nodes/types/types';
-import { cloneDeep } from 'lodash-es';
-
-const NOISE = 'noise';
-const RANDOM_INT = 'rand_int';
-const RANGE_OF_SIZE = 'range_of_size';
-const ITERATE = 'iterate';
-/**
- * Adds the appropriate noise nodes to a linear UI t2l or l2l graph.
- *
- * @param graph The graph to add the noise nodes to.
- * @param baseNodeId The id of the base node to connect the noise nodes to.
- * @param state The app state..
- */
-export const addNoiseNodes = (
- graph: NonNullableGraph,
- baseNodeId: string,
- state: RootState
-): NonNullableGraph => {
- const graphClone = cloneDeep(graph);
-
- // Create and add the noise nodes
- const { width, height, seed, iterations, shouldRandomizeSeed } =
- state.generation;
-
- // Single iteration, explicit seed
- if (!shouldRandomizeSeed && iterations === 1) {
- const noiseNode: NoiseInvocation = {
- id: NOISE,
- type: 'noise',
- seed: seed,
- width,
- height,
- };
-
- graphClone.nodes[NOISE] = noiseNode;
-
- // Connect them
- graphClone.edges.push({
- source: { node_id: NOISE, field: 'noise' },
- destination: {
- node_id: baseNodeId,
- field: 'noise',
- },
- });
- }
-
- // Single iteration, random seed
- if (shouldRandomizeSeed && iterations === 1) {
- // TODO: This assumes the `high` value is the max seed value
- const randomIntNode: RandomIntInvocation = {
- id: RANDOM_INT,
- type: 'rand_int',
- };
-
- const noiseNode: NoiseInvocation = {
- id: NOISE,
- type: 'noise',
- width,
- height,
- };
-
- graphClone.nodes[RANDOM_INT] = randomIntNode;
- graphClone.nodes[NOISE] = noiseNode;
-
- graphClone.edges.push({
- source: { node_id: RANDOM_INT, field: 'a' },
- destination: {
- node_id: NOISE,
- field: 'seed',
- },
- });
-
- graphClone.edges.push({
- source: { node_id: NOISE, field: 'noise' },
- destination: {
- node_id: baseNodeId,
- field: 'noise',
- },
- });
- }
-
- // Multiple iterations, explicit seed
- if (!shouldRandomizeSeed && iterations > 1) {
- const rangeOfSizeNode: RangeOfSizeInvocation = {
- id: RANGE_OF_SIZE,
- type: 'range_of_size',
- start: seed,
- size: iterations,
- };
-
- const iterateNode: IterateInvocation = {
- id: ITERATE,
- type: 'iterate',
- };
-
- const noiseNode: NoiseInvocation = {
- id: NOISE,
- type: 'noise',
- width,
- height,
- };
-
- graphClone.nodes[RANGE_OF_SIZE] = rangeOfSizeNode;
- graphClone.nodes[ITERATE] = iterateNode;
- graphClone.nodes[NOISE] = noiseNode;
-
- graphClone.edges.push({
- source: { node_id: RANGE_OF_SIZE, field: 'collection' },
- destination: {
- node_id: ITERATE,
- field: 'collection',
- },
- });
-
- graphClone.edges.push({
- source: {
- node_id: ITERATE,
- field: 'item',
- },
- destination: {
- node_id: NOISE,
- field: 'seed',
- },
- });
-
- graphClone.edges.push({
- source: { node_id: NOISE, field: 'noise' },
- destination: {
- node_id: baseNodeId,
- field: 'noise',
- },
- });
- }
-
- // Multiple iterations, random seed
- if (shouldRandomizeSeed && iterations > 1) {
- // TODO: This assumes the `high` value is the max seed value
- const randomIntNode: RandomIntInvocation = {
- id: RANDOM_INT,
- type: 'rand_int',
- };
-
- const rangeOfSizeNode: RangeOfSizeInvocation = {
- id: RANGE_OF_SIZE,
- type: 'range_of_size',
- size: iterations,
- };
-
- const iterateNode: IterateInvocation = {
- id: ITERATE,
- type: 'iterate',
- };
-
- const noiseNode: NoiseInvocation = {
- id: NOISE,
- type: 'noise',
- width,
- height,
- };
-
- graphClone.nodes[RANDOM_INT] = randomIntNode;
- graphClone.nodes[RANGE_OF_SIZE] = rangeOfSizeNode;
- graphClone.nodes[ITERATE] = iterateNode;
- graphClone.nodes[NOISE] = noiseNode;
-
- graphClone.edges.push({
- source: { node_id: RANDOM_INT, field: 'a' },
- destination: { node_id: RANGE_OF_SIZE, field: 'start' },
- });
-
- graphClone.edges.push({
- source: { node_id: RANGE_OF_SIZE, field: 'collection' },
- destination: {
- node_id: ITERATE,
- field: 'collection',
- },
- });
-
- graphClone.edges.push({
- source: {
- node_id: ITERATE,
- field: 'item',
- },
- destination: {
- node_id: NOISE,
- field: 'seed',
- },
- });
-
- graphClone.edges.push({
- source: { node_id: NOISE, field: 'noise' },
- destination: {
- node_id: baseNodeId,
- field: 'noise',
- },
- });
- }
-
- return graphClone;
-};
diff --git a/invokeai/frontend/web/src/features/nodes/util/nodeBuilders/buildImageToImageNode.ts b/invokeai/frontend/web/src/features/nodes/util/nodeBuilders/buildImageToImageNode.ts
index 5f00d12a23..558f937837 100644
--- a/invokeai/frontend/web/src/features/nodes/util/nodeBuilders/buildImageToImageNode.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/nodeBuilders/buildImageToImageNode.ts
@@ -58,7 +58,7 @@ export const buildImg2ImgNode = (
imageToImageNode.image = {
image_name: initialImage.name,
- image_type: initialImage.type,
+ image_origin: initialImage.type,
};
}
diff --git a/invokeai/frontend/web/src/features/nodes/util/nodeBuilders/buildInpaintNode.ts b/invokeai/frontend/web/src/features/nodes/util/nodeBuilders/buildInpaintNode.ts
index b3f6cca933..593658e536 100644
--- a/invokeai/frontend/web/src/features/nodes/util/nodeBuilders/buildInpaintNode.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/nodeBuilders/buildInpaintNode.ts
@@ -2,15 +2,12 @@ import { v4 as uuidv4 } from 'uuid';
import { RootState } from 'app/store/store';
import { InpaintInvocation } from 'services/api';
import { O } from 'ts-toolbelt';
-import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
export const buildInpaintNode = (
state: RootState,
overrides: O.Partial = {}
): InpaintInvocation => {
const nodeId = uuidv4();
- const { generation } = state;
- const activeTabName = activeTabNameSelector(state);
const {
positivePrompt: prompt,
@@ -25,8 +22,7 @@ export const buildInpaintNode = (
img2imgStrength: strength,
shouldFitToWidthHeight: fit,
shouldRandomizeSeed,
- initialImage,
- } = generation;
+ } = state.generation;
const inpaintNode: InpaintInvocation = {
id: nodeId,
@@ -42,19 +38,6 @@ export const buildInpaintNode = (
fit,
};
- // on Canvas tab, we do not manually specific init image
- if (activeTabName !== 'unifiedCanvas') {
- if (!initialImage) {
- // TODO: handle this more better
- throw 'no initial image';
- }
-
- inpaintNode.image = {
- image_name: initialImage.name,
- image_type: initialImage.type,
- };
- }
-
if (!shouldRandomizeSeed) {
inpaintNode.seed = seed;
}
diff --git a/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts b/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts
index ddd19b8749..c77fdeca5e 100644
--- a/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts
@@ -13,7 +13,9 @@ import {
buildOutputFieldTemplates,
} from './fieldTemplateBuilders';
-const invocationDenylist = ['Graph'];
+const RESERVED_FIELD_NAMES = ['id', 'type', 'is_intermediate'];
+
+const invocationDenylist = ['Graph', 'InvocationMeta'];
export const parseSchema = (openAPI: OpenAPIV3.Document) => {
// filter out non-invocation schemas, plus some tricky invocations for now
@@ -73,7 +75,7 @@ export const parseSchema = (openAPI: OpenAPIV3.Document) => {
(inputsAccumulator, property, propertyName) => {
if (
// `type` and `id` are not valid inputs/outputs
- !['type', 'id'].includes(propertyName) &&
+ !RESERVED_FIELD_NAMES.includes(propertyName) &&
isSchemaObject(property)
) {
const field: InputFieldTemplate | undefined =
diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/BoundingBox/ParamBoundingBoxHeight.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/BoundingBox/ParamBoundingBoxHeight.tsx
index 75ec70f257..dc83ba8907 100644
--- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/BoundingBox/ParamBoundingBoxHeight.tsx
+++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/BoundingBox/ParamBoundingBoxHeight.tsx
@@ -2,18 +2,22 @@ import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISlider from 'common/components/IAISlider';
-import { canvasSelector } from 'features/canvas/store/canvasSelectors';
+import {
+ canvasSelector,
+ isStagingSelector,
+} from 'features/canvas/store/canvasSelectors';
import { setBoundingBoxDimensions } from 'features/canvas/store/canvasSlice';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
const selector = createSelector(
- canvasSelector,
- (canvas) => {
+ [canvasSelector, isStagingSelector],
+ (canvas, isStaging) => {
const { boundingBoxDimensions } = canvas;
return {
boundingBoxDimensions,
+ isStaging,
};
},
defaultSelectorOptions
@@ -21,7 +25,7 @@ const selector = createSelector(
const ParamBoundingBoxWidth = () => {
const dispatch = useAppDispatch();
- const { boundingBoxDimensions } = useAppSelector(selector);
+ const { boundingBoxDimensions, isStaging } = useAppSelector(selector);
const { t } = useTranslation();
@@ -45,12 +49,13 @@ const ParamBoundingBoxWidth = () => {
return (
{
+ [canvasSelector, isStagingSelector],
+ (canvas, isStaging) => {
const { boundingBoxDimensions } = canvas;
return {
boundingBoxDimensions,
+ isStaging,
};
},
defaultSelectorOptions
@@ -21,7 +25,7 @@ const selector = createSelector(
const ParamBoundingBoxWidth = () => {
const dispatch = useAppDispatch();
- const { boundingBoxDimensions } = useAppSelector(selector);
+ const { boundingBoxDimensions, isStaging } = useAppSelector(selector);
const { t } = useTranslation();
@@ -45,12 +49,13 @@ const ParamBoundingBoxWidth = () => {
return (
{
return (
diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx
index 2b5db18d93..f4413c4cf6 100644
--- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx
+++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx
@@ -1,24 +1,33 @@
+import { createSelector } from '@reduxjs/toolkit';
import { Scheduler } from 'app/constants';
-import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
+import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAICustomSelect from 'common/components/IAICustomSelect';
+import { generationSelector } from 'features/parameters/store/generationSelectors';
import { setScheduler } from 'features/parameters/store/generationSlice';
-import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
+import { uiSelector } from 'features/ui/store/uiSelectors';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
+const selector = createSelector(
+ [uiSelector, generationSelector],
+ (ui, generation) => {
+ // TODO: DPMSolverSinglestepScheduler is fixed in https://github.com/huggingface/diffusers/pull/3413
+ // but we need to wait for the next release before removing this special handling.
+ const allSchedulers = ui.schedulers.filter((scheduler) => {
+ return !['dpmpp_2s'].includes(scheduler);
+ });
+
+ return {
+ scheduler: generation.scheduler,
+ allSchedulers,
+ };
+ },
+ defaultSelectorOptions
+);
+
const ParamScheduler = () => {
- const scheduler = useAppSelector(
- (state: RootState) => state.generation.scheduler
- );
-
- const activeTabName = useAppSelector(activeTabNameSelector);
-
- const schedulers = useAppSelector((state: RootState) => state.ui.schedulers);
-
- const img2imgSchedulers = schedulers.filter((scheduler) => {
- return !['dpmpp_2s'].includes(scheduler);
- });
+ const { allSchedulers, scheduler } = useAppSelector(selector);
const dispatch = useAppDispatch();
const { t } = useTranslation();
@@ -38,11 +47,7 @@ const ParamScheduler = () => {
label={t('parameters.scheduler')}
selectedItem={scheduler}
setSelectedItem={handleChange}
- items={
- ['img2img', 'unifiedCanvas'].includes(activeTabName)
- ? img2imgSchedulers
- : schedulers
- }
+ items={allSchedulers}
withCheckIcon
/>
);
diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/ImageToImage/InitialImagePreview.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/ImageToImage/InitialImagePreview.tsx
index be40f548e6..cfe1513420 100644
--- a/invokeai/frontend/web/src/features/parameters/components/Parameters/ImageToImage/InitialImagePreview.tsx
+++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/ImageToImage/InitialImagePreview.tsx
@@ -5,7 +5,6 @@ import { useGetUrl } from 'common/util/getUrl';
import { clearInitialImage } from 'features/parameters/store/generationSlice';
import { DragEvent, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
-import { ImageType } from 'services/api';
import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay';
import { generationSelector } from 'features/parameters/store/generationSelectors';
import { initialImageSelected } from 'features/parameters/store/actions';
@@ -55,9 +54,7 @@ const InitialImagePreview = () => {
const handleDrop = useCallback(
(e: DragEvent) => {
const name = e.dataTransfer.getData('invokeai/imageName');
- const type = e.dataTransfer.getData('invokeai/imageType') as ImageType;
-
- dispatch(initialImageSelected({ image_name: name, image_type: type }));
+ dispatch(initialImageSelected(name));
},
[dispatch]
);
diff --git a/invokeai/frontend/web/src/features/parameters/hooks/useParameters.ts b/invokeai/frontend/web/src/features/parameters/hooks/useParameters.ts
deleted file mode 100644
index 27ae63e5dd..0000000000
--- a/invokeai/frontend/web/src/features/parameters/hooks/useParameters.ts
+++ /dev/null
@@ -1,151 +0,0 @@
-import { useAppDispatch } from 'app/store/storeHooks';
-import { isFinite, isString } from 'lodash-es';
-import { useCallback } from 'react';
-import { useTranslation } from 'react-i18next';
-import useSetBothPrompts from './usePrompt';
-import { allParametersSet, setSeed } from '../store/generationSlice';
-import { isImageField } from 'services/types/guards';
-import { NUMPY_RAND_MAX } from 'app/constants';
-import { initialImageSelected } from '../store/actions';
-import { setActiveTab } from 'features/ui/store/uiSlice';
-import { useAppToaster } from 'app/components/Toaster';
-import { ImageDTO } from 'services/api';
-
-export const useParameters = () => {
- const dispatch = useAppDispatch();
- const toaster = useAppToaster();
- const { t } = useTranslation();
- const setBothPrompts = useSetBothPrompts();
-
- /**
- * Sets prompt with toast
- */
- const recallPrompt = useCallback(
- (prompt: unknown, negativePrompt?: unknown) => {
- if (!isString(prompt) || !isString(negativePrompt)) {
- toaster({
- title: t('toast.promptNotSet'),
- description: t('toast.promptNotSetDesc'),
- status: 'warning',
- duration: 2500,
- isClosable: true,
- });
- return;
- }
-
- setBothPrompts(prompt, negativePrompt);
- toaster({
- title: t('toast.promptSet'),
- status: 'info',
- duration: 2500,
- isClosable: true,
- });
- },
- [t, toaster, setBothPrompts]
- );
-
- /**
- * Sets seed with toast
- */
- const recallSeed = useCallback(
- (seed: unknown) => {
- const s = Number(seed);
- if (!isFinite(s) || (isFinite(s) && !(s >= 0 && s <= NUMPY_RAND_MAX))) {
- toaster({
- title: t('toast.seedNotSet'),
- description: t('toast.seedNotSetDesc'),
- status: 'warning',
- duration: 2500,
- isClosable: true,
- });
- return;
- }
-
- dispatch(setSeed(s));
- toaster({
- title: t('toast.seedSet'),
- status: 'info',
- duration: 2500,
- isClosable: true,
- });
- },
- [t, toaster, dispatch]
- );
-
- /**
- * Sets initial image with toast
- */
- const recallInitialImage = useCallback(
- async (image: unknown) => {
- if (!isImageField(image)) {
- toaster({
- title: t('toast.initialImageNotSet'),
- description: t('toast.initialImageNotSetDesc'),
- status: 'warning',
- duration: 2500,
- isClosable: true,
- });
- return;
- }
-
- dispatch(initialImageSelected(image));
- toaster({
- title: t('toast.initialImageSet'),
- status: 'info',
- duration: 2500,
- isClosable: true,
- });
- },
- [t, toaster, dispatch]
- );
-
- /**
- * Sets image as initial image with toast
- */
- const sendToImageToImage = useCallback(
- (image: ImageDTO) => {
- dispatch(initialImageSelected(image));
- },
- [dispatch]
- );
-
- const recallAllParameters = useCallback(
- (image: ImageDTO | undefined) => {
- const type = image?.metadata?.type;
- // not sure what this list should be
- if (['t2l', 'l2l', 'inpaint'].includes(String(type))) {
- dispatch(allParametersSet(image));
-
- if (image?.metadata?.type === 'l2l') {
- dispatch(setActiveTab('img2img'));
- } else if (image?.metadata?.type === 't2l') {
- dispatch(setActiveTab('txt2img'));
- }
-
- toaster({
- title: t('toast.parametersSet'),
- status: 'success',
- duration: 2500,
- isClosable: true,
- });
- } else {
- toaster({
- title: t('toast.parametersNotSet'),
- description: t('toast.parametersNotSetDesc'),
- status: 'error',
- duration: 2500,
- isClosable: true,
- });
- }
- },
- [t, toaster, dispatch]
- );
-
- return {
- recallPrompt,
- recallSeed,
- recallInitialImage,
- sendToImageToImage,
- recallAllParameters,
- };
-};
diff --git a/invokeai/frontend/web/src/features/parameters/hooks/usePrompt.ts b/invokeai/frontend/web/src/features/parameters/hooks/usePrompt.ts
deleted file mode 100644
index 3fee0bcdd8..0000000000
--- a/invokeai/frontend/web/src/features/parameters/hooks/usePrompt.ts
+++ /dev/null
@@ -1,23 +0,0 @@
-import { getPromptAndNegative } from 'common/util/getPromptAndNegative';
-
-import * as InvokeAI from 'app/types/invokeai';
-import promptToString from 'common/util/promptToString';
-import { useAppDispatch } from 'app/store/storeHooks';
-import { setNegativePrompt, setPositivePrompt } from '../store/generationSlice';
-import { useCallback } from 'react';
-
-// TECHDEBT: We have two metadata prompt formats and need to handle recalling either of them.
-// This hook provides a function to do that.
-const useSetBothPrompts = () => {
- const dispatch = useAppDispatch();
-
- return useCallback(
- (inputPrompt: InvokeAI.Prompt, negativePrompt: InvokeAI.Prompt) => {
- dispatch(setPositivePrompt(inputPrompt));
- dispatch(setNegativePrompt(negativePrompt));
- },
- [dispatch]
- );
-};
-
-export default useSetBothPrompts;
diff --git a/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts b/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts
new file mode 100644
index 0000000000..7b7a405867
--- /dev/null
+++ b/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts
@@ -0,0 +1,348 @@
+import { useAppDispatch } from 'app/store/storeHooks';
+import { useCallback } from 'react';
+import { useTranslation } from 'react-i18next';
+import {
+ modelSelected,
+ setCfgScale,
+ setHeight,
+ setImg2imgStrength,
+ setNegativePrompt,
+ setPositivePrompt,
+ setScheduler,
+ setSeed,
+ setSteps,
+ setWidth,
+} from '../store/generationSlice';
+import { isImageField } from 'services/types/guards';
+import { initialImageSelected } from '../store/actions';
+import { useAppToaster } from 'app/components/Toaster';
+import { ImageDTO } from 'services/api';
+import {
+ isValidCfgScale,
+ isValidHeight,
+ isValidModel,
+ isValidNegativePrompt,
+ isValidPositivePrompt,
+ isValidScheduler,
+ isValidSeed,
+ isValidSteps,
+ isValidStrength,
+ isValidWidth,
+} from '../store/parameterZodSchemas';
+
+export const useRecallParameters = () => {
+ const dispatch = useAppDispatch();
+ const toaster = useAppToaster();
+ const { t } = useTranslation();
+
+ const parameterSetToast = useCallback(() => {
+ toaster({
+ title: t('toast.parameterSet'),
+ status: 'info',
+ duration: 2500,
+ isClosable: true,
+ });
+ }, [t, toaster]);
+
+ const parameterNotSetToast = useCallback(() => {
+ toaster({
+ title: t('toast.parameterNotSet'),
+ status: 'warning',
+ duration: 2500,
+ isClosable: true,
+ });
+ }, [t, toaster]);
+
+ const allParameterSetToast = useCallback(() => {
+ toaster({
+ title: t('toast.parametersSet'),
+ status: 'info',
+ duration: 2500,
+ isClosable: true,
+ });
+ }, [t, toaster]);
+
+ const allParameterNotSetToast = useCallback(() => {
+ toaster({
+ title: t('toast.parametersNotSet'),
+ status: 'warning',
+ duration: 2500,
+ isClosable: true,
+ });
+ }, [t, toaster]);
+
+ /**
+ * Recall both prompts with toast
+ */
+ const recallBothPrompts = useCallback(
+ (positivePrompt: unknown, negativePrompt: unknown) => {
+ if (
+ isValidPositivePrompt(positivePrompt) ||
+ isValidNegativePrompt(negativePrompt)
+ ) {
+ if (isValidPositivePrompt(positivePrompt)) {
+ dispatch(setPositivePrompt(positivePrompt));
+ }
+ if (isValidNegativePrompt(negativePrompt)) {
+ dispatch(setNegativePrompt(negativePrompt));
+ }
+ parameterSetToast();
+ return;
+ }
+ parameterNotSetToast();
+ },
+ [dispatch, parameterSetToast, parameterNotSetToast]
+ );
+
+ /**
+ * Recall positive prompt with toast
+ */
+ const recallPositivePrompt = useCallback(
+ (positivePrompt: unknown) => {
+ if (!isValidPositivePrompt(positivePrompt)) {
+ parameterNotSetToast();
+ return;
+ }
+ dispatch(setPositivePrompt(positivePrompt));
+ parameterSetToast();
+ },
+ [dispatch, parameterSetToast, parameterNotSetToast]
+ );
+
+ /**
+ * Recall negative prompt with toast
+ */
+ const recallNegativePrompt = useCallback(
+ (negativePrompt: unknown) => {
+ if (!isValidNegativePrompt(negativePrompt)) {
+ parameterNotSetToast();
+ return;
+ }
+ dispatch(setNegativePrompt(negativePrompt));
+ parameterSetToast();
+ },
+ [dispatch, parameterSetToast, parameterNotSetToast]
+ );
+
+ /**
+ * Recall seed with toast
+ */
+ const recallSeed = useCallback(
+ (seed: unknown) => {
+ if (!isValidSeed(seed)) {
+ parameterNotSetToast();
+ return;
+ }
+ dispatch(setSeed(seed));
+ parameterSetToast();
+ },
+ [dispatch, parameterSetToast, parameterNotSetToast]
+ );
+
+ /**
+ * Recall CFG scale with toast
+ */
+ const recallCfgScale = useCallback(
+ (cfgScale: unknown) => {
+ if (!isValidCfgScale(cfgScale)) {
+ parameterNotSetToast();
+ return;
+ }
+ dispatch(setCfgScale(cfgScale));
+ parameterSetToast();
+ },
+ [dispatch, parameterSetToast, parameterNotSetToast]
+ );
+
+ /**
+ * Recall model with toast
+ */
+ const recallModel = useCallback(
+ (model: unknown) => {
+ if (!isValidModel(model)) {
+ parameterNotSetToast();
+ return;
+ }
+ dispatch(modelSelected(model));
+ parameterSetToast();
+ },
+ [dispatch, parameterSetToast, parameterNotSetToast]
+ );
+
+ /**
+ * Recall scheduler with toast
+ */
+ const recallScheduler = useCallback(
+ (scheduler: unknown) => {
+ if (!isValidScheduler(scheduler)) {
+ parameterNotSetToast();
+ return;
+ }
+ dispatch(setScheduler(scheduler));
+ parameterSetToast();
+ },
+ [dispatch, parameterSetToast, parameterNotSetToast]
+ );
+
+ /**
+ * Recall steps with toast
+ */
+ const recallSteps = useCallback(
+ (steps: unknown) => {
+ if (!isValidSteps(steps)) {
+ parameterNotSetToast();
+ return;
+ }
+ dispatch(setSteps(steps));
+ parameterSetToast();
+ },
+ [dispatch, parameterSetToast, parameterNotSetToast]
+ );
+
+ /**
+ * Recall width with toast
+ */
+ const recallWidth = useCallback(
+ (width: unknown) => {
+ if (!isValidWidth(width)) {
+ parameterNotSetToast();
+ return;
+ }
+ dispatch(setWidth(width));
+ parameterSetToast();
+ },
+ [dispatch, parameterSetToast, parameterNotSetToast]
+ );
+
+ /**
+ * Recall height with toast
+ */
+ const recallHeight = useCallback(
+ (height: unknown) => {
+ if (!isValidHeight(height)) {
+ parameterNotSetToast();
+ return;
+ }
+ dispatch(setHeight(height));
+ parameterSetToast();
+ },
+ [dispatch, parameterSetToast, parameterNotSetToast]
+ );
+
+ /**
+ * Recall strength with toast
+ */
+ const recallStrength = useCallback(
+ (strength: unknown) => {
+ if (!isValidStrength(strength)) {
+ parameterNotSetToast();
+ return;
+ }
+ dispatch(setImg2imgStrength(strength));
+ parameterSetToast();
+ },
+ [dispatch, parameterSetToast, parameterNotSetToast]
+ );
+
+ /**
+ * Sets initial image with toast
+ */
+ const recallInitialImage = useCallback(
+ async (image: unknown) => {
+ if (!isImageField(image)) {
+ parameterNotSetToast();
+ return;
+ }
+ dispatch(initialImageSelected(image.image_name));
+ parameterSetToast();
+ },
+ [dispatch, parameterSetToast, parameterNotSetToast]
+ );
+
+ /**
+ * Sets image as initial image with toast
+ */
+ const sendToImageToImage = useCallback(
+ (image: ImageDTO) => {
+ dispatch(initialImageSelected(image));
+ },
+ [dispatch]
+ );
+
+ const recallAllParameters = useCallback(
+ (image: ImageDTO | undefined) => {
+ if (!image || !image.metadata) {
+ allParameterNotSetToast();
+ return;
+ }
+ const {
+ cfg_scale,
+ height,
+ model,
+ positive_conditioning,
+ negative_conditioning,
+ scheduler,
+ seed,
+ steps,
+ width,
+ strength,
+ clip,
+ extra,
+ latents,
+ unet,
+ vae,
+ } = image.metadata;
+
+ if (isValidCfgScale(cfg_scale)) {
+ dispatch(setCfgScale(cfg_scale));
+ }
+ if (isValidModel(model)) {
+ dispatch(modelSelected(model));
+ }
+ if (isValidPositivePrompt(positive_conditioning)) {
+ dispatch(setPositivePrompt(positive_conditioning));
+ }
+ if (isValidNegativePrompt(negative_conditioning)) {
+ dispatch(setNegativePrompt(negative_conditioning));
+ }
+ if (isValidScheduler(scheduler)) {
+ dispatch(setScheduler(scheduler));
+ }
+ if (isValidSeed(seed)) {
+ dispatch(setSeed(seed));
+ }
+ if (isValidSteps(steps)) {
+ dispatch(setSteps(steps));
+ }
+ if (isValidWidth(width)) {
+ dispatch(setWidth(width));
+ }
+ if (isValidHeight(height)) {
+ dispatch(setHeight(height));
+ }
+ if (isValidStrength(strength)) {
+ dispatch(setImg2imgStrength(strength));
+ }
+
+ allParameterSetToast();
+ },
+ [allParameterNotSetToast, allParameterSetToast, dispatch]
+ );
+
+ return {
+ recallBothPrompts,
+ recallPositivePrompt,
+ recallNegativePrompt,
+ recallSeed,
+ recallInitialImage,
+ recallCfgScale,
+ recallModel,
+ recallScheduler,
+ recallSteps,
+ recallWidth,
+ recallHeight,
+ recallStrength,
+ recallAllParameters,
+ sendToImageToImage,
+ };
+};
diff --git a/invokeai/frontend/web/src/features/parameters/store/actions.ts b/invokeai/frontend/web/src/features/parameters/store/actions.ts
index 853597c809..e9b90134e1 100644
--- a/invokeai/frontend/web/src/features/parameters/store/actions.ts
+++ b/invokeai/frontend/web/src/features/parameters/store/actions.ts
@@ -1,10 +1,10 @@
import { createAction } from '@reduxjs/toolkit';
import { isObject } from 'lodash-es';
-import { ImageDTO, ImageType } from 'services/api';
+import { ImageDTO, ResourceOrigin } from 'services/api';
-export type ImageNameAndType = {
+export type ImageNameAndOrigin = {
image_name: string;
- image_type: ImageType;
+ image_origin: ResourceOrigin;
};
export const isImageDTO = (image: any): image is ImageDTO => {
@@ -13,8 +13,8 @@ export const isImageDTO = (image: any): image is ImageDTO => {
isObject(image) &&
'image_name' in image &&
image?.image_name !== undefined &&
- 'image_type' in image &&
- image?.image_type !== undefined &&
+ 'image_origin' in image &&
+ image?.image_origin !== undefined &&
'image_url' in image &&
image?.image_url !== undefined &&
'thumbnail_url' in image &&
@@ -26,6 +26,6 @@ export const isImageDTO = (image: any): image is ImageDTO => {
);
};
-export const initialImageSelected = createAction<
- ImageDTO | ImageNameAndType | undefined
->('generation/initialImageSelected');
+export const initialImageSelected = createAction(
+ 'generation/initialImageSelected'
+);
diff --git a/invokeai/frontend/web/src/features/parameters/store/generationSelectors.ts b/invokeai/frontend/web/src/features/parameters/store/generationSelectors.ts
index dbf5eec791..b7322740ef 100644
--- a/invokeai/frontend/web/src/features/parameters/store/generationSelectors.ts
+++ b/invokeai/frontend/web/src/features/parameters/store/generationSelectors.ts
@@ -1,34 +1,3 @@
-import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
-import { selectResultsById } from 'features/gallery/store/resultsSlice';
-import { selectUploadsById } from 'features/gallery/store/uploadsSlice';
-import { isEqual } from 'lodash-es';
export const generationSelector = (state: RootState) => state.generation;
-
-export const mayGenerateMultipleImagesSelector = createSelector(
- generationSelector,
- ({ shouldRandomizeSeed, shouldGenerateVariations }) => {
- return shouldRandomizeSeed || shouldGenerateVariations;
- },
- {
- memoizeOptions: {
- resultEqualityCheck: isEqual,
- },
- }
-);
-
-export const initialImageSelector = createSelector(
- [(state: RootState) => state, generationSelector],
- (state, generation) => {
- const { initialImage } = generation;
-
- if (initialImage?.type === 'results') {
- return selectResultsById(state, initialImage.name);
- }
-
- if (initialImage?.type === 'uploads') {
- return selectUploadsById(state, initialImage.name);
- }
- }
-);
diff --git a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts
index 849f848ff3..6420950e4a 100644
--- a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts
+++ b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts
@@ -1,43 +1,53 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
-import * as InvokeAI from 'app/types/invokeai';
-import promptToString from 'common/util/promptToString';
-import { clamp, sample } from 'lodash-es';
-import { setAllParametersReducer } from './setAllParametersReducer';
+import { clamp, sortBy } from 'lodash-es';
import { receivedModels } from 'services/thunks/model';
import { Scheduler } from 'app/constants';
import { ImageDTO } from 'services/api';
+import { configChanged } from 'features/system/store/configSlice';
+import {
+ CfgScaleParam,
+ HeightParam,
+ ModelParam,
+ NegativePromptParam,
+ PositivePromptParam,
+ SchedulerParam,
+ SeedParam,
+ StepsParam,
+ StrengthParam,
+ WidthParam,
+} from './parameterZodSchemas';
export interface GenerationState {
- cfgScale: number;
- height: number;
- img2imgStrength: number;
+ cfgScale: CfgScaleParam;
+ height: HeightParam;
+ img2imgStrength: StrengthParam;
infillMethod: string;
initialImage?: ImageDTO;
iterations: number;
perlin: number;
- positivePrompt: string;
- negativePrompt: string;
- scheduler: Scheduler;
+ positivePrompt: PositivePromptParam;
+ negativePrompt: NegativePromptParam;
+ scheduler: SchedulerParam;
seamBlur: number;
seamSize: number;
seamSteps: number;
seamStrength: number;
- seed: number;
+ seed: SeedParam;
seedWeights: string;
shouldFitToWidthHeight: boolean;
shouldGenerateVariations: boolean;
shouldRandomizeSeed: boolean;
shouldUseNoiseSettings: boolean;
- steps: number;
+ steps: StepsParam;
threshold: number;
tileSize: number;
variationAmount: number;
- width: number;
+ width: WidthParam;
shouldUseSymmetry: boolean;
horizontalSymmetrySteps: number;
verticalSymmetrySteps: number;
- model: string;
+ model: ModelParam;
shouldUseSeamless: boolean;
seamlessXAxis: boolean;
seamlessYAxis: boolean;
@@ -83,27 +93,11 @@ export const generationSlice = createSlice({
name: 'generation',
initialState,
reducers: {
- setPositivePrompt: (
- state,
- action: PayloadAction
- ) => {
- const newPrompt = action.payload;
- if (typeof newPrompt === 'string') {
- state.positivePrompt = newPrompt;
- } else {
- state.positivePrompt = promptToString(newPrompt);
- }
+ setPositivePrompt: (state, action: PayloadAction) => {
+ state.positivePrompt = action.payload;
},
- setNegativePrompt: (
- state,
- action: PayloadAction
- ) => {
- const newPrompt = action.payload;
- if (typeof newPrompt === 'string') {
- state.negativePrompt = newPrompt;
- } else {
- state.negativePrompt = promptToString(newPrompt);
- }
+ setNegativePrompt: (state, action: PayloadAction) => {
+ state.negativePrompt = action.payload;
},
setIterations: (state, action: PayloadAction) => {
state.iterations = action.payload;
@@ -174,7 +168,6 @@ export const generationSlice = createSlice({
state.shouldGenerateVariations = true;
state.variationAmount = 0;
},
- allParametersSet: setAllParametersReducer,
resetParametersState: (state) => {
return {
...state,
@@ -227,10 +220,15 @@ export const generationSlice = createSlice({
extraReducers: (builder) => {
builder.addCase(receivedModels.fulfilled, (state, action) => {
if (!state.model) {
- const randomModel = sample(action.payload);
- if (randomModel) {
- state.model = randomModel.name;
- }
+ const firstModel = sortBy(action.payload, 'name')[0];
+ state.model = firstModel.name;
+ }
+ });
+
+ builder.addCase(configChanged, (state, action) => {
+ const defaultModel = action.payload.sd?.defaultModel;
+ if (defaultModel && !state.model) {
+ state.model = defaultModel;
}
});
},
@@ -273,7 +271,6 @@ export const {
setSeamless,
setSeamlessXAxis,
setSeamlessYAxis,
- allParametersSet,
} = generationSlice.actions;
export default generationSlice.reducer;
diff --git a/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts b/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts
new file mode 100644
index 0000000000..b99e57bfbb
--- /dev/null
+++ b/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts
@@ -0,0 +1,156 @@
+import { NUMPY_RAND_MAX, SCHEDULERS } from 'app/constants';
+import { z } from 'zod';
+
+/**
+ * These zod schemas should match the pydantic node schemas.
+ *
+ * Parameters only need schemas if we want to recall them from metadata.
+ *
+ * Each parameter needs:
+ * - a zod schema
+ * - a type alias, inferred from the zod schema
+ * - a combo validation/type guard function, which returns true if the value is valid
+ */
+
+/**
+ * Zod schema for positive prompt parameter
+ */
+export const zPositivePrompt = z.string();
+/**
+ * Type alias for positive prompt parameter, inferred from its zod schema
+ */
+export type PositivePromptParam = z.infer;
+/**
+ * Validates/type-guards a value as a positive prompt parameter
+ */
+export const isValidPositivePrompt = (
+ val: unknown
+): val is PositivePromptParam => zPositivePrompt.safeParse(val).success;
+
+/**
+ * Zod schema for negative prompt parameter
+ */
+export const zNegativePrompt = z.string();
+/**
+ * Type alias for negative prompt parameter, inferred from its zod schema
+ */
+export type NegativePromptParam = z.infer;
+/**
+ * Validates/type-guards a value as a negative prompt parameter
+ */
+export const isValidNegativePrompt = (
+ val: unknown
+): val is NegativePromptParam => zNegativePrompt.safeParse(val).success;
+
+/**
+ * Zod schema for steps parameter
+ */
+export const zSteps = z.number().int().min(1);
+/**
+ * Type alias for steps parameter, inferred from its zod schema
+ */
+export type StepsParam = z.infer;
+/**
+ * Validates/type-guards a value as a steps parameter
+ */
+export const isValidSteps = (val: unknown): val is StepsParam =>
+ zSteps.safeParse(val).success;
+
+/**
+ * Zod schema for CFG scale parameter
+ */
+export const zCfgScale = z.number().min(1);
+/**
+ * Type alias for CFG scale parameter, inferred from its zod schema
+ */
+export type CfgScaleParam = z.infer;
+/**
+ * Validates/type-guards a value as a CFG scale parameter
+ */
+export const isValidCfgScale = (val: unknown): val is CfgScaleParam =>
+ zCfgScale.safeParse(val).success;
+
+/**
+ * Zod schema for scheduler parameter
+ */
+export const zScheduler = z.enum(SCHEDULERS);
+/**
+ * Type alias for scheduler parameter, inferred from its zod schema
+ */
+export type SchedulerParam = z.infer;
+/**
+ * Validates/type-guards a value as a scheduler parameter
+ */
+export const isValidScheduler = (val: unknown): val is SchedulerParam =>
+ zScheduler.safeParse(val).success;
+
+/**
+ * Zod schema for seed parameter
+ */
+export const zSeed = z.number().int().min(0).max(NUMPY_RAND_MAX);
+/**
+ * Type alias for seed parameter, inferred from its zod schema
+ */
+export type SeedParam = z.infer;
+/**
+ * Validates/type-guards a value as a seed parameter
+ */
+export const isValidSeed = (val: unknown): val is SeedParam =>
+ zSeed.safeParse(val).success;
+
+/**
+ * Zod schema for width parameter
+ */
+export const zWidth = z.number().multipleOf(8).min(64);
+/**
+ * Type alias for width parameter, inferred from its zod schema
+ */
+export type WidthParam = z.infer;
+/**
+ * Validates/type-guards a value as a width parameter
+ */
+export const isValidWidth = (val: unknown): val is WidthParam =>
+ zWidth.safeParse(val).success;
+
+/**
+ * Zod schema for height parameter
+ */
+export const zHeight = z.number().multipleOf(8).min(64);
+/**
+ * Type alias for height parameter, inferred from its zod schema
+ */
+export type HeightParam = z.infer;
+/**
+ * Validates/type-guards a value as a height parameter
+ */
+export const isValidHeight = (val: unknown): val is HeightParam =>
+ zHeight.safeParse(val).success;
+
+/**
+ * Zod schema for model parameter
+ * TODO: Make this a dynamically generated enum?
+ */
+export const zModel = z.string();
+/**
+ * Type alias for model parameter, inferred from its zod schema
+ */
+export type ModelParam = z.infer;
+/**
+ * Validates/type-guards a value as a model parameter
+ */
+export const isValidModel = (val: unknown): val is ModelParam =>
+ zModel.safeParse(val).success;
+
+/**
+ * Zod schema for l2l strength parameter
+ */
+export const zStrength = z.number().min(0).max(1);
+/**
+ * Type alias for l2l strength parameter, inferred from its zod schema
+ */
+export type StrengthParam = z.infer;
+/**
+ * Validates/type-guards a value as a l2l strength parameter
+ */
+export const isValidStrength = (val: unknown): val is StrengthParam =>
+ zStrength.safeParse(val).success;
diff --git a/invokeai/frontend/web/src/features/parameters/store/setAllParametersReducer.ts b/invokeai/frontend/web/src/features/parameters/store/setAllParametersReducer.ts
deleted file mode 100644
index 8f06c7d0ef..0000000000
--- a/invokeai/frontend/web/src/features/parameters/store/setAllParametersReducer.ts
+++ /dev/null
@@ -1,77 +0,0 @@
-import { Draft, PayloadAction } from '@reduxjs/toolkit';
-import { GenerationState } from './generationSlice';
-import { ImageDTO, ImageToImageInvocation } from 'services/api';
-import { isScheduler } from 'app/constants';
-
-export const setAllParametersReducer = (
- state: Draft,
- action: PayloadAction
-) => {
- const metadata = action.payload?.metadata;
-
- if (!metadata) {
- return;
- }
-
- // not sure what this list should be
- if (
- metadata.type === 't2l' ||
- metadata.type === 'l2l' ||
- metadata.type === 'inpaint'
- ) {
- const {
- cfg_scale,
- height,
- model,
- positive_conditioning,
- negative_conditioning,
- scheduler,
- seed,
- steps,
- width,
- } = metadata;
-
- if (cfg_scale !== undefined) {
- state.cfgScale = Number(cfg_scale);
- }
- if (height !== undefined) {
- state.height = Number(height);
- }
- if (model !== undefined) {
- state.model = String(model);
- }
- if (positive_conditioning !== undefined) {
- state.positivePrompt = String(positive_conditioning);
- }
- if (negative_conditioning !== undefined) {
- state.negativePrompt = String(negative_conditioning);
- }
- if (scheduler !== undefined) {
- const schedulerString = String(scheduler);
- if (isScheduler(schedulerString)) {
- state.scheduler = schedulerString;
- }
- }
- if (seed !== undefined) {
- state.seed = Number(seed);
- state.shouldRandomizeSeed = false;
- }
- if (steps !== undefined) {
- state.steps = Number(steps);
- }
- if (width !== undefined) {
- state.width = Number(width);
- }
- }
-
- if (metadata.type === 'l2l') {
- const { fit, image } = metadata as ImageToImageInvocation;
-
- if (fit !== undefined) {
- state.shouldFitToWidthHeight = Boolean(fit);
- }
- // if (image !== undefined) {
- // state.initialImage = image;
- // }
- }
-};
diff --git a/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx b/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx
index 520e30b60a..be4be8ceaa 100644
--- a/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx
+++ b/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx
@@ -4,19 +4,33 @@ import { isEqual } from 'lodash-es';
import { useTranslation } from 'react-i18next';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
-import { selectModelsById, selectModelsIds } from '../store/modelSlice';
+import {
+ selectModelsAll,
+ selectModelsById,
+ selectModelsIds,
+} from '../store/modelSlice';
import { RootState } from 'app/store/store';
import { modelSelected } from 'features/parameters/store/generationSlice';
import { generationSelector } from 'features/parameters/store/generationSelectors';
-import IAICustomSelect from 'common/components/IAICustomSelect';
+import IAICustomSelect, {
+ ItemTooltips,
+} from 'common/components/IAICustomSelect';
const selector = createSelector(
[(state: RootState) => state, generationSelector],
(state, generation) => {
const selectedModel = selectModelsById(state, generation.model);
const allModelNames = selectModelsIds(state).map((id) => String(id));
+ const allModelTooltips = selectModelsAll(state).reduce(
+ (allModelTooltips, model) => {
+ allModelTooltips[model.name] = model.description ?? '';
+ return allModelTooltips;
+ },
+ {} as ItemTooltips
+ );
return {
allModelNames,
+ allModelTooltips,
selectedModel,
};
},
@@ -30,7 +44,8 @@ const selector = createSelector(
const ModelSelect = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
- const { allModelNames, selectedModel } = useAppSelector(selector);
+ const { allModelNames, allModelTooltips, selectedModel } =
+ useAppSelector(selector);
const handleChangeModel = useCallback(
(v: string | null | undefined) => {
if (!v) {
@@ -46,6 +61,7 @@ const ModelSelect = () => {
label={t('modelManager.model')}
tooltip={selectedModel?.description}
items={allModelNames}
+ itemTooltips={allModelTooltips}
selectedItem={selectedModel?.name ?? ''}
setSelectedItem={handleChangeModel}
withCheckIcon={true}
diff --git a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsModal.tsx b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsModal.tsx
index 54556124c9..4cfe35081b 100644
--- a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsModal.tsx
+++ b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsModal.tsx
@@ -35,7 +35,13 @@ import {
} from 'features/ui/store/uiSlice';
import { UIState } from 'features/ui/store/uiTypes';
import { isEqual } from 'lodash-es';
-import { ChangeEvent, cloneElement, ReactElement, useCallback } from 'react';
+import {
+ ChangeEvent,
+ cloneElement,
+ ReactElement,
+ useCallback,
+ useEffect,
+} from 'react';
import { useTranslation } from 'react-i18next';
import { VALID_LOG_LEVELS } from 'app/logging/useLogger';
import { LogLevelName } from 'roarr';
@@ -85,15 +91,33 @@ const modalSectionStyles: ChakraProps['sx'] = {
borderRadius: 'base',
};
+type ConfigOptions = {
+ shouldShowDeveloperSettings: boolean;
+ shouldShowResetWebUiText: boolean;
+ shouldShowBetaLayout: boolean;
+};
+
type SettingsModalProps = {
/* The button to open the Settings Modal */
children: ReactElement;
+ config?: ConfigOptions;
};
-const SettingsModal = ({ children }: SettingsModalProps) => {
+const SettingsModal = ({ children, config }: SettingsModalProps) => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
+ const shouldShowBetaLayout = config?.shouldShowBetaLayout ?? true;
+ const shouldShowDeveloperSettings =
+ config?.shouldShowDeveloperSettings ?? true;
+ const shouldShowResetWebUiText = config?.shouldShowResetWebUiText ?? true;
+
+ useEffect(() => {
+ if (!shouldShowDeveloperSettings) {
+ dispatch(shouldLogToConsoleChanged(false));
+ }
+ }, [shouldShowDeveloperSettings, dispatch]);
+
const {
isOpen: isSettingsModalOpen,
onOpen: onSettingsModalOpen,
@@ -189,13 +213,15 @@ const SettingsModal = ({ children }: SettingsModalProps) => {
dispatch(setShouldDisplayGuides(e.target.checked))
}
/>
- ) =>
- dispatch(setShouldUseCanvasBetaLayout(e.target.checked))
- }
- />
+ {shouldShowBetaLayout && (
+ ) =>
+ dispatch(setShouldUseCanvasBetaLayout(e.target.checked))
+ }
+ />
+ )}
{
/>
-
- {t('settings.developer')}
-
-
- ) =>
- dispatch(setEnableImageDebugging(e.target.checked))
- }
- />
-
+ {shouldShowDeveloperSettings && (
+
+ {t('settings.developer')}
+
+
+ ) =>
+ dispatch(setEnableImageDebugging(e.target.checked))
+ }
+ />
+
+ )}
{t('settings.resetWebUI')}
{t('settings.resetWebUI')}
- {t('settings.resetWebUIDesc1')}
- {t('settings.resetWebUIDesc2')}
+ {shouldShowResetWebUiText && (
+ <>
+ {t('settings.resetWebUIDesc1')}
+ {t('settings.resetWebUIDesc2')}
+ >
+ )}
diff --git a/invokeai/frontend/web/src/features/system/store/actions.ts b/invokeai/frontend/web/src/features/system/store/actions.ts
new file mode 100644
index 0000000000..66181bc803
--- /dev/null
+++ b/invokeai/frontend/web/src/features/system/store/actions.ts
@@ -0,0 +1,3 @@
+import { createAction } from '@reduxjs/toolkit';
+
+export const sessionReadyToInvoke = createAction('system/sessionReadyToInvoke');
diff --git a/invokeai/frontend/web/src/features/system/store/sessionSlice.ts b/invokeai/frontend/web/src/features/system/store/sessionSlice.ts
new file mode 100644
index 0000000000..40d59c7baa
--- /dev/null
+++ b/invokeai/frontend/web/src/features/system/store/sessionSlice.ts
@@ -0,0 +1,62 @@
+// TODO: split system slice into this
+
+// import type { PayloadAction } from '@reduxjs/toolkit';
+// import { createSlice } from '@reduxjs/toolkit';
+// import { socketSubscribed, socketUnsubscribed } from 'services/events/actions';
+
+// export type SessionState = {
+// /**
+// * The current socket session id
+// */
+// sessionId: string;
+// /**
+// * Whether the current session is a canvas session. Needed to manage the staging area.
+// */
+// isCanvasSession: boolean;
+// /**
+// * When a session is canceled, its ID is stored here until a new session is created.
+// */
+// canceledSessionId: string;
+// };
+
+// export const initialSessionState: SessionState = {
+// sessionId: '',
+// isCanvasSession: false,
+// canceledSessionId: '',
+// };
+
+// export const sessionSlice = createSlice({
+// name: 'session',
+// initialState: initialSessionState,
+// reducers: {
+// sessionIdChanged: (state, action: PayloadAction) => {
+// state.sessionId = action.payload;
+// },
+// isCanvasSessionChanged: (state, action: PayloadAction) => {
+// state.isCanvasSession = action.payload;
+// },
+// },
+// extraReducers: (builder) => {
+// /**
+// * Socket Subscribed
+// */
+// builder.addCase(socketSubscribed, (state, action) => {
+// state.sessionId = action.payload.sessionId;
+// state.canceledSessionId = '';
+// });
+
+// /**
+// * Socket Unsubscribed
+// */
+// builder.addCase(socketUnsubscribed, (state) => {
+// state.sessionId = '';
+// });
+// },
+// });
+
+// export const { sessionIdChanged, isCanvasSessionChanged } =
+// sessionSlice.actions;
+
+// export default sessionSlice.reducer;
+
+export default {};
diff --git a/invokeai/frontend/web/src/features/system/store/systemSlice.ts b/invokeai/frontend/web/src/features/system/store/systemSlice.ts
index 7331fcdba9..6bc8d7106a 100644
--- a/invokeai/frontend/web/src/features/system/store/systemSlice.ts
+++ b/invokeai/frontend/web/src/features/system/store/systemSlice.ts
@@ -1,22 +1,15 @@
import { UseToastOptions } from '@chakra-ui/react';
-import type { PayloadAction } from '@reduxjs/toolkit';
+import { PayloadAction, isAnyOf } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import * as InvokeAI from 'app/types/invokeai';
-import {
- generatorProgress,
- graphExecutionStateComplete,
- invocationComplete,
- invocationError,
- invocationStarted,
- socketConnected,
- socketDisconnected,
- socketSubscribed,
- socketUnsubscribed,
-} from 'services/events/actions';
import { ProgressImage } from 'services/events/types';
import { makeToast } from '../../../app/components/Toaster';
-import { sessionCanceled, sessionInvoked } from 'services/thunks/session';
+import {
+ sessionCanceled,
+ sessionCreated,
+ sessionInvoked,
+} from 'services/thunks/session';
import { receivedModels } from 'services/thunks/model';
import { parsedOpenAPISchema } from 'features/nodes/store/nodesSlice';
import { LogLevelName } from 'roarr';
@@ -26,6 +19,17 @@ import { t } from 'i18next';
import { userInvoked } from 'app/store/actions';
import { LANGUAGES } from '../components/LanguagePicker';
import { imageUploaded } from 'services/thunks/image';
+import {
+ appSocketConnected,
+ appSocketDisconnected,
+ appSocketGeneratorProgress,
+ appSocketGraphExecutionStateComplete,
+ appSocketInvocationComplete,
+ appSocketInvocationError,
+ appSocketInvocationStarted,
+ appSocketSubscribed,
+ appSocketUnsubscribed,
+} from 'services/events/actions';
export type CancelStrategy = 'immediate' | 'scheduled';
@@ -215,12 +219,15 @@ export const systemSlice = createSlice({
languageChanged: (state, action: PayloadAction) => {
state.language = action.payload;
},
+ progressImageSet(state, action: PayloadAction) {
+ state.progressImage = action.payload;
+ },
},
extraReducers(builder) {
/**
* Socket Subscribed
*/
- builder.addCase(socketSubscribed, (state, action) => {
+ builder.addCase(appSocketSubscribed, (state, action) => {
state.sessionId = action.payload.sessionId;
state.canceledSession = '';
});
@@ -228,14 +235,14 @@ export const systemSlice = createSlice({
/**
* Socket Unsubscribed
*/
- builder.addCase(socketUnsubscribed, (state) => {
+ builder.addCase(appSocketUnsubscribed, (state) => {
state.sessionId = null;
});
/**
* Socket Connected
*/
- builder.addCase(socketConnected, (state) => {
+ builder.addCase(appSocketConnected, (state) => {
state.isConnected = true;
state.isCancelable = true;
state.isProcessing = false;
@@ -250,7 +257,7 @@ export const systemSlice = createSlice({
/**
* Socket Disconnected
*/
- builder.addCase(socketDisconnected, (state) => {
+ builder.addCase(appSocketDisconnected, (state) => {
state.isConnected = false;
state.isProcessing = false;
state.isCancelable = true;
@@ -265,7 +272,7 @@ export const systemSlice = createSlice({
/**
* Invocation Started
*/
- builder.addCase(invocationStarted, (state) => {
+ builder.addCase(appSocketInvocationStarted, (state) => {
state.isCancelable = true;
state.isProcessing = true;
state.currentStatusHasSteps = false;
@@ -279,7 +286,7 @@ export const systemSlice = createSlice({
/**
* Generator Progress
*/
- builder.addCase(generatorProgress, (state, action) => {
+ builder.addCase(appSocketGeneratorProgress, (state, action) => {
const { step, total_steps, progress_image } = action.payload.data;
state.isProcessing = true;
@@ -296,7 +303,7 @@ export const systemSlice = createSlice({
/**
* Invocation Complete
*/
- builder.addCase(invocationComplete, (state, action) => {
+ builder.addCase(appSocketInvocationComplete, (state, action) => {
const { data } = action.payload;
// state.currentIteration = 0;
@@ -305,7 +312,6 @@ export const systemSlice = createSlice({
state.currentStep = 0;
state.totalSteps = 0;
state.statusTranslationKey = 'common.statusProcessingComplete';
- state.progressImage = null;
if (state.canceledSession === data.graph_execution_state_id) {
state.isProcessing = false;
@@ -316,7 +322,7 @@ export const systemSlice = createSlice({
/**
* Invocation Error
*/
- builder.addCase(invocationError, (state) => {
+ builder.addCase(appSocketInvocationError, (state) => {
state.isProcessing = false;
state.isCancelable = true;
// state.currentIteration = 0;
@@ -333,7 +339,20 @@ export const systemSlice = createSlice({
});
/**
- * Session Invoked - PENDING
+ * Graph Execution State Complete
+ */
+ builder.addCase(appSocketGraphExecutionStateComplete, (state) => {
+ state.isProcessing = false;
+ state.isCancelable = false;
+ state.isCancelScheduled = false;
+ state.currentStep = 0;
+ state.totalSteps = 0;
+ state.statusTranslationKey = 'common.statusConnected';
+ state.progressImage = null;
+ });
+
+ /**
+ * User Invoked
*/
builder.addCase(userInvoked, (state) => {
@@ -343,15 +362,8 @@ export const systemSlice = createSlice({
state.statusTranslationKey = 'common.statusPreparing';
});
- builder.addCase(sessionInvoked.rejected, (state, action) => {
- const error = action.payload as string | undefined;
- state.toastQueue.push(
- makeToast({ title: error || t('toast.serverError'), status: 'error' })
- );
- });
-
/**
- * Session Canceled
+ * Session Canceled - FULFILLED
*/
builder.addCase(sessionCanceled.fulfilled, (state, action) => {
state.canceledSession = action.meta.arg.sessionId;
@@ -368,18 +380,6 @@ export const systemSlice = createSlice({
);
});
- /**
- * Session Canceled
- */
- builder.addCase(graphExecutionStateComplete, (state) => {
- state.isProcessing = false;
- state.isCancelable = false;
- state.isCancelScheduled = false;
- state.currentStep = 0;
- state.totalSteps = 0;
- state.statusTranslationKey = 'common.statusConnected';
- });
-
/**
* Received available models from the backend
*/
@@ -414,6 +414,26 @@ export const systemSlice = createSlice({
builder.addCase(imageUploaded.fulfilled, (state) => {
state.isUploading = false;
});
+
+ // *** Matchers - must be after all cases ***
+
+ /**
+ * Session Invoked - REJECTED
+ * Session Created - REJECTED
+ */
+ builder.addMatcher(isAnySessionRejected, (state, action) => {
+ state.isProcessing = false;
+ state.isCancelable = false;
+ state.isCancelScheduled = false;
+ state.currentStep = 0;
+ state.totalSteps = 0;
+ state.statusTranslationKey = 'common.statusConnected';
+ state.progressImage = null;
+
+ state.toastQueue.push(
+ makeToast({ title: t('toast.serverError'), status: 'error' })
+ );
+ });
},
});
@@ -438,6 +458,12 @@ export const {
isPersistedChanged,
shouldAntialiasProgressImageChanged,
languageChanged,
+ progressImageSet,
} = systemSlice.actions;
export default systemSlice.reducer;
+
+const isAnySessionRejected = isAnyOf(
+ sessionCreated.rejected,
+ sessionInvoked.rejected
+);
diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasCoreParameters.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasCoreParameters.tsx
index f2529e5529..1b6b61f018 100644
--- a/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasCoreParameters.tsx
+++ b/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasCoreParameters.tsx
@@ -7,11 +7,11 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import ParamIterations from 'features/parameters/components/Parameters/Core/ParamIterations';
import ParamSteps from 'features/parameters/components/Parameters/Core/ParamSteps';
import ParamCFGScale from 'features/parameters/components/Parameters/Core/ParamCFGScale';
-import ParamWidth from 'features/parameters/components/Parameters/Core/ParamWidth';
-import ParamHeight from 'features/parameters/components/Parameters/Core/ParamHeight';
import ImageToImageStrength from 'features/parameters/components/Parameters/ImageToImage/ImageToImageStrength';
import ImageToImageFit from 'features/parameters/components/Parameters/ImageToImage/ImageToImageFit';
import ParamSchedulerAndModel from 'features/parameters/components/Parameters/Core/ParamSchedulerAndModel';
+import ParamBoundingBoxWidth from 'features/parameters/components/Parameters/Canvas/BoundingBox/ParamBoundingBoxWidth';
+import ParamBoundingBoxHeight from 'features/parameters/components/Parameters/Canvas/BoundingBox/ParamBoundingBoxHeight';
const selector = createSelector(
uiSelector,
@@ -41,8 +41,8 @@ const UnifiedCanvasCoreParameters = () => {
-
-
+
+
@@ -55,8 +55,8 @@ const UnifiedCanvasCoreParameters = () => {
-
-
+
+
)}
diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasParameters.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasParameters.tsx
index 4aa68ad56a..c4501ffc44 100644
--- a/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasParameters.tsx
+++ b/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasParameters.tsx
@@ -2,7 +2,6 @@ import ProcessButtons from 'features/parameters/components/ProcessButtons/Proces
import ParamSeedCollapse from 'features/parameters/components/Parameters/Seed/ParamSeedCollapse';
import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse';
import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse';
-import ParamBoundingBoxCollapse from 'features/parameters/components/Parameters/Canvas/BoundingBox/ParamBoundingBoxCollapse';
import ParamInfillAndScalingCollapse from 'features/parameters/components/Parameters/Canvas/InfillAndScaling/ParamInfillAndScalingCollapse';
import ParamSeamCorrectionCollapse from 'features/parameters/components/Parameters/Canvas/SeamCorrection/ParamSeamCorrectionCollapse';
import UnifiedCanvasCoreParameters from './UnifiedCanvasCoreParameters';
@@ -20,7 +19,6 @@ const UnifiedCanvasParameters = () => {
-
>
diff --git a/invokeai/frontend/web/src/features/ui/store/uiSlice.ts b/invokeai/frontend/web/src/features/ui/store/uiSlice.ts
index 4893bb3bf6..65a48bc92c 100644
--- a/invokeai/frontend/web/src/features/ui/store/uiSlice.ts
+++ b/invokeai/frontend/web/src/features/ui/store/uiSlice.ts
@@ -19,7 +19,7 @@ export const initialUIState: UIState = {
shouldPinGallery: true,
shouldShowGallery: true,
shouldHidePreview: false,
- shouldShowProgressInViewer: false,
+ shouldShowProgressInViewer: true,
schedulers: SCHEDULERS,
};
diff --git a/invokeai/frontend/web/src/services/api/index.ts b/invokeai/frontend/web/src/services/api/index.ts
index ecf8621ed6..ff083079f9 100644
--- a/invokeai/frontend/web/src/services/api/index.ts
+++ b/invokeai/frontend/web/src/services/api/index.ts
@@ -7,8 +7,8 @@ export { OpenAPI } from './core/OpenAPI';
export type { OpenAPIConfig } from './core/OpenAPI';
export type { AddInvocation } from './models/AddInvocation';
-export type { BlurInvocation } from './models/BlurInvocation';
export type { Body_upload_image } from './models/Body_upload_image';
+export type { CannyImageProcessorInvocation } from './models/CannyImageProcessorInvocation';
export type { CkptModelInfo } from './models/CkptModelInfo';
export type { CollectInvocation } from './models/CollectInvocation';
export type { CollectInvocationOutput } from './models/CollectInvocationOutput';
@@ -16,26 +16,43 @@ export type { ColorField } from './models/ColorField';
export type { CompelInvocation } from './models/CompelInvocation';
export type { CompelOutput } from './models/CompelOutput';
export type { ConditioningField } from './models/ConditioningField';
+export type { ContentShuffleImageProcessorInvocation } from './models/ContentShuffleImageProcessorInvocation';
+export type { ControlField } from './models/ControlField';
+export type { ControlNetInvocation } from './models/ControlNetInvocation';
+export type { ControlOutput } from './models/ControlOutput';
export type { CreateModelRequest } from './models/CreateModelRequest';
-export type { CropImageInvocation } from './models/CropImageInvocation';
export type { CvInpaintInvocation } from './models/CvInpaintInvocation';
export type { DiffusersModelInfo } from './models/DiffusersModelInfo';
export type { DivideInvocation } from './models/DivideInvocation';
export type { Edge } from './models/Edge';
export type { EdgeConnection } from './models/EdgeConnection';
+export type { FloatCollectionOutput } from './models/FloatCollectionOutput';
+export type { FloatOutput } from './models/FloatOutput';
export type { Graph } from './models/Graph';
export type { GraphExecutionState } from './models/GraphExecutionState';
export type { GraphInvocation } from './models/GraphInvocation';
export type { GraphInvocationOutput } from './models/GraphInvocationOutput';
+export type { HedImageprocessorInvocation } from './models/HedImageprocessorInvocation';
export type { HTTPValidationError } from './models/HTTPValidationError';
+export type { ImageBlurInvocation } from './models/ImageBlurInvocation';
export type { ImageCategory } from './models/ImageCategory';
+export type { ImageChannelInvocation } from './models/ImageChannelInvocation';
+export type { ImageConvertInvocation } from './models/ImageConvertInvocation';
+export type { ImageCropInvocation } from './models/ImageCropInvocation';
export type { ImageDTO } from './models/ImageDTO';
export type { ImageField } from './models/ImageField';
+export type { ImageInverseLerpInvocation } from './models/ImageInverseLerpInvocation';
+export type { ImageLerpInvocation } from './models/ImageLerpInvocation';
export type { ImageMetadata } from './models/ImageMetadata';
+export type { ImageMultiplyInvocation } from './models/ImageMultiplyInvocation';
export type { ImageOutput } from './models/ImageOutput';
+export type { ImagePasteInvocation } from './models/ImagePasteInvocation';
+export type { ImageProcessorInvocation } from './models/ImageProcessorInvocation';
+export type { ImageRecordChanges } from './models/ImageRecordChanges';
+export type { ImageResizeInvocation } from './models/ImageResizeInvocation';
+export type { ImageScaleInvocation } from './models/ImageScaleInvocation';
export type { ImageToImageInvocation } from './models/ImageToImageInvocation';
export type { ImageToLatentsInvocation } from './models/ImageToLatentsInvocation';
-export type { ImageType } from './models/ImageType';
export type { ImageUrlsDTO } from './models/ImageUrlsDTO';
export type { InfillColorInvocation } from './models/InfillColorInvocation';
export type { InfillPatchMatchInvocation } from './models/InfillPatchMatchInvocation';
@@ -43,31 +60,38 @@ export type { InfillTileInvocation } from './models/InfillTileInvocation';
export type { InpaintInvocation } from './models/InpaintInvocation';
export type { IntCollectionOutput } from './models/IntCollectionOutput';
export type { IntOutput } from './models/IntOutput';
-export type { InverseLerpInvocation } from './models/InverseLerpInvocation';
export type { IterateInvocation } from './models/IterateInvocation';
export type { IterateInvocationOutput } from './models/IterateInvocationOutput';
export type { LatentsField } from './models/LatentsField';
export type { LatentsOutput } from './models/LatentsOutput';
export type { LatentsToImageInvocation } from './models/LatentsToImageInvocation';
export type { LatentsToLatentsInvocation } from './models/LatentsToLatentsInvocation';
-export type { LerpInvocation } from './models/LerpInvocation';
+export type { LineartAnimeImageProcessorInvocation } from './models/LineartAnimeImageProcessorInvocation';
+export type { LineartImageProcessorInvocation } from './models/LineartImageProcessorInvocation';
export type { LoadImageInvocation } from './models/LoadImageInvocation';
export type { MaskFromAlphaInvocation } from './models/MaskFromAlphaInvocation';
export type { MaskOutput } from './models/MaskOutput';
+export type { MediapipeFaceProcessorInvocation } from './models/MediapipeFaceProcessorInvocation';
+export type { MidasDepthImageProcessorInvocation } from './models/MidasDepthImageProcessorInvocation';
+export type { MlsdImageProcessorInvocation } from './models/MlsdImageProcessorInvocation';
export type { ModelsList } from './models/ModelsList';
export type { MultiplyInvocation } from './models/MultiplyInvocation';
export type { NoiseInvocation } from './models/NoiseInvocation';
export type { NoiseOutput } from './models/NoiseOutput';
+export type { NormalbaeImageProcessorInvocation } from './models/NormalbaeImageProcessorInvocation';
+export type { OffsetPaginatedResults_ImageDTO_ } from './models/OffsetPaginatedResults_ImageDTO_';
+export type { OpenposeImageProcessorInvocation } from './models/OpenposeImageProcessorInvocation';
export type { PaginatedResults_GraphExecutionState_ } from './models/PaginatedResults_GraphExecutionState_';
-export type { PaginatedResults_ImageDTO_ } from './models/PaginatedResults_ImageDTO_';
+export type { ParamFloatInvocation } from './models/ParamFloatInvocation';
export type { ParamIntInvocation } from './models/ParamIntInvocation';
-export type { PasteImageInvocation } from './models/PasteImageInvocation';
+export type { PidiImageProcessorInvocation } from './models/PidiImageProcessorInvocation';
export type { PromptOutput } from './models/PromptOutput';
export type { RandomIntInvocation } from './models/RandomIntInvocation';
export type { RandomRangeInvocation } from './models/RandomRangeInvocation';
export type { RangeInvocation } from './models/RangeInvocation';
export type { RangeOfSizeInvocation } from './models/RangeOfSizeInvocation';
export type { ResizeLatentsInvocation } from './models/ResizeLatentsInvocation';
+export type { ResourceOrigin } from './models/ResourceOrigin';
export type { RestoreFaceInvocation } from './models/RestoreFaceInvocation';
export type { ScaleLatentsInvocation } from './models/ScaleLatentsInvocation';
export type { ShowImageInvocation } from './models/ShowImageInvocation';
@@ -77,6 +101,7 @@ export type { TextToLatentsInvocation } from './models/TextToLatentsInvocation';
export type { UpscaleInvocation } from './models/UpscaleInvocation';
export type { VaeRepo } from './models/VaeRepo';
export type { ValidationError } from './models/ValidationError';
+export type { ZoeDepthImageProcessorInvocation } from './models/ZoeDepthImageProcessorInvocation';
export { ImagesService } from './services/ImagesService';
export { ModelsService } from './services/ModelsService';
diff --git a/invokeai/frontend/web/src/services/api/models/AddInvocation.ts b/invokeai/frontend/web/src/services/api/models/AddInvocation.ts
index 1ff7b010c2..e9671a918f 100644
--- a/invokeai/frontend/web/src/services/api/models/AddInvocation.ts
+++ b/invokeai/frontend/web/src/services/api/models/AddInvocation.ts
@@ -10,6 +10,10 @@ export type AddInvocation = {
* The id of this node. Must be unique among all nodes.
*/
id: string;
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
type?: 'add';
/**
* The first number
diff --git a/invokeai/frontend/web/src/services/api/models/CannyImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/models/CannyImageProcessorInvocation.ts
new file mode 100644
index 0000000000..3a8b0b21e7
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/models/CannyImageProcessorInvocation.ts
@@ -0,0 +1,33 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+
+import type { ImageField } from './ImageField';
+
+/**
+ * Canny edge detection for ControlNet
+ */
+export type CannyImageProcessorInvocation = {
+ /**
+ * The id of this node. Must be unique among all nodes.
+ */
+ id: string;
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
+ type?: 'canny_image_processor';
+ /**
+ * image to process
+ */
+ image?: ImageField;
+ /**
+ * low threshold of Canny pixel gradient
+ */
+ low_threshold?: number;
+ /**
+ * high threshold of Canny pixel gradient
+ */
+ high_threshold?: number;
+};
+
diff --git a/invokeai/frontend/web/src/services/api/models/CollectInvocation.ts b/invokeai/frontend/web/src/services/api/models/CollectInvocation.ts
index d250ae4450..f190ab7073 100644
--- a/invokeai/frontend/web/src/services/api/models/CollectInvocation.ts
+++ b/invokeai/frontend/web/src/services/api/models/CollectInvocation.ts
@@ -10,6 +10,10 @@ export type CollectInvocation = {
* The id of this node. Must be unique among all nodes.
*/
id: string;
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
type?: 'collect';
/**
* The item to collect (all inputs must be of the same type)
diff --git a/invokeai/frontend/web/src/services/api/models/CompelInvocation.ts b/invokeai/frontend/web/src/services/api/models/CompelInvocation.ts
index f03d53a841..1dc390c1be 100644
--- a/invokeai/frontend/web/src/services/api/models/CompelInvocation.ts
+++ b/invokeai/frontend/web/src/services/api/models/CompelInvocation.ts
@@ -10,6 +10,10 @@ export type CompelInvocation = {
* The id of this node. Must be unique among all nodes.
*/
id: string;
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
type?: 'compel';
/**
* Prompt
diff --git a/invokeai/frontend/web/src/services/api/models/ContentShuffleImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/models/ContentShuffleImageProcessorInvocation.ts
new file mode 100644
index 0000000000..d8bc3fe58e
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/models/ContentShuffleImageProcessorInvocation.ts
@@ -0,0 +1,45 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+
+import type { ImageField } from './ImageField';
+
+/**
+ * Applies content shuffle processing to image
+ */
+export type ContentShuffleImageProcessorInvocation = {
+ /**
+ * The id of this node. Must be unique among all nodes.
+ */
+ id: string;
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
+ type?: 'content_shuffle_image_processor';
+ /**
+ * image to process
+ */
+ image?: ImageField;
+ /**
+ * pixel resolution for edge detection
+ */
+ detect_resolution?: number;
+ /**
+ * pixel resolution for output image
+ */
+ image_resolution?: number;
+ /**
+ * content shuffle h parameter
+ */
+ 'h'?: number;
+ /**
+ * content shuffle w parameter
+ */
+ 'w'?: number;
+ /**
+   * content shuffle f parameter
+ */
+ 'f'?: number;
+};
+
diff --git a/invokeai/frontend/web/src/services/api/models/ControlField.ts b/invokeai/frontend/web/src/services/api/models/ControlField.ts
new file mode 100644
index 0000000000..4f493d4410
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/models/ControlField.ts
@@ -0,0 +1,29 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+
+import type { ImageField } from './ImageField';
+
+export type ControlField = {
+ /**
+ * processed image
+ */
+ image: ImageField;
+ /**
+ * control model used
+ */
+ control_model: string;
+ /**
+ * weight given to controlnet
+ */
+ control_weight: number;
+ /**
+ * % of total steps at which controlnet is first applied
+ */
+ begin_step_percent: number;
+ /**
+ * % of total steps at which controlnet is last applied
+ */
+ end_step_percent: number;
+};
+
diff --git a/invokeai/frontend/web/src/services/api/models/ControlNetInvocation.ts b/invokeai/frontend/web/src/services/api/models/ControlNetInvocation.ts
new file mode 100644
index 0000000000..fad3af911b
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/models/ControlNetInvocation.ts
@@ -0,0 +1,41 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+
+import type { ImageField } from './ImageField';
+
+/**
+ * Collects ControlNet info to pass to other nodes
+ */
+export type ControlNetInvocation = {
+ /**
+ * The id of this node. Must be unique among all nodes.
+ */
+ id: string;
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
+ type?: 'controlnet';
+ /**
+ * image to process
+ */
+ image?: ImageField;
+ /**
+ * control model used
+ */
+ control_model?: 'lllyasviel/sd-controlnet-canny' | 'lllyasviel/sd-controlnet-depth' | 'lllyasviel/sd-controlnet-hed' | 'lllyasviel/sd-controlnet-seg' | 'lllyasviel/sd-controlnet-openpose' | 'lllyasviel/sd-controlnet-scribble' | 'lllyasviel/sd-controlnet-normal' | 'lllyasviel/sd-controlnet-mlsd' | 'lllyasviel/control_v11p_sd15_canny' | 'lllyasviel/control_v11p_sd15_openpose' | 'lllyasviel/control_v11p_sd15_seg' | 'lllyasviel/control_v11f1p_sd15_depth' | 'lllyasviel/control_v11p_sd15_normalbae' | 'lllyasviel/control_v11p_sd15_scribble' | 'lllyasviel/control_v11p_sd15_mlsd' | 'lllyasviel/control_v11p_sd15_softedge' | 'lllyasviel/control_v11p_sd15s2_lineart_anime' | 'lllyasviel/control_v11p_sd15_lineart' | 'lllyasviel/control_v11p_sd15_inpaint' | 'lllyasviel/control_v11e_sd15_shuffle' | 'lllyasviel/control_v11e_sd15_ip2p' | 'lllyasviel/control_v11f1e_sd15_tile' | 'thibaud/controlnet-sd21-openpose-diffusers' | 'thibaud/controlnet-sd21-canny-diffusers' | 'thibaud/controlnet-sd21-depth-diffusers' | 'thibaud/controlnet-sd21-scribble-diffusers' | 'thibaud/controlnet-sd21-hed-diffusers' | 'thibaud/controlnet-sd21-zoedepth-diffusers' | 'thibaud/controlnet-sd21-color-diffusers' | 'thibaud/controlnet-sd21-openposev2-diffusers' | 'thibaud/controlnet-sd21-lineart-diffusers' | 'thibaud/controlnet-sd21-normalbae-diffusers' | 'thibaud/controlnet-sd21-ade20k-diffusers' | 'CrucibleAI/ControlNetMediaPipeFace,diffusion_sd15' | 'CrucibleAI/ControlNetMediaPipeFace';
+ /**
+ * weight given to controlnet
+ */
+ control_weight?: number;
+ /**
+ * % of total steps at which controlnet is first applied
+ */
+ begin_step_percent?: number;
+ /**
+ * % of total steps at which controlnet is last applied
+ */
+ end_step_percent?: number;
+};
+
diff --git a/invokeai/frontend/web/src/services/api/models/ControlOutput.ts b/invokeai/frontend/web/src/services/api/models/ControlOutput.ts
new file mode 100644
index 0000000000..43f1b3341c
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/models/ControlOutput.ts
@@ -0,0 +1,17 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+
+import type { ControlField } from './ControlField';
+
+/**
+ * node output for ControlNet info
+ */
+export type ControlOutput = {
+ type?: 'control_output';
+ /**
+ * The control info dict
+ */
+ control?: ControlField;
+};
+
diff --git a/invokeai/frontend/web/src/services/api/models/CvInpaintInvocation.ts b/invokeai/frontend/web/src/services/api/models/CvInpaintInvocation.ts
index 19342acf8f..874df93c30 100644
--- a/invokeai/frontend/web/src/services/api/models/CvInpaintInvocation.ts
+++ b/invokeai/frontend/web/src/services/api/models/CvInpaintInvocation.ts
@@ -12,6 +12,10 @@ export type CvInpaintInvocation = {
* The id of this node. Must be unique among all nodes.
*/
id: string;
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
type?: 'cv_inpaint';
/**
* The image to inpaint
diff --git a/invokeai/frontend/web/src/services/api/models/DivideInvocation.ts b/invokeai/frontend/web/src/services/api/models/DivideInvocation.ts
index 3cb262e9af..fd5b3475ae 100644
--- a/invokeai/frontend/web/src/services/api/models/DivideInvocation.ts
+++ b/invokeai/frontend/web/src/services/api/models/DivideInvocation.ts
@@ -10,6 +10,10 @@ export type DivideInvocation = {
* The id of this node. Must be unique among all nodes.
*/
id: string;
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
type?: 'div';
/**
* The first number
diff --git a/invokeai/frontend/web/src/services/api/models/FloatCollectionOutput.ts b/invokeai/frontend/web/src/services/api/models/FloatCollectionOutput.ts
new file mode 100644
index 0000000000..a3f08247a4
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/models/FloatCollectionOutput.ts
@@ -0,0 +1,15 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+
+/**
+ * A collection of floats
+ */
+export type FloatCollectionOutput = {
+ type?: 'float_collection';
+ /**
+ * The float collection
+ */
+  collection?: Array<number>;
+};
+
diff --git a/invokeai/frontend/web/src/services/api/models/FloatOutput.ts b/invokeai/frontend/web/src/services/api/models/FloatOutput.ts
new file mode 100644
index 0000000000..2331936b30
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/models/FloatOutput.ts
@@ -0,0 +1,15 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+
+/**
+ * A float output
+ */
+export type FloatOutput = {
+ type?: 'float_output';
+ /**
+ * The output float
+ */
+ param?: number;
+};
+
diff --git a/invokeai/frontend/web/src/services/api/models/Graph.ts b/invokeai/frontend/web/src/services/api/models/Graph.ts
index 039923e585..e89e815ab2 100644
--- a/invokeai/frontend/web/src/services/api/models/Graph.ts
+++ b/invokeai/frontend/web/src/services/api/models/Graph.ts
@@ -3,31 +3,50 @@
/* eslint-disable */
import type { AddInvocation } from './AddInvocation';
-import type { BlurInvocation } from './BlurInvocation';
+import type { CannyImageProcessorInvocation } from './CannyImageProcessorInvocation';
import type { CollectInvocation } from './CollectInvocation';
import type { CompelInvocation } from './CompelInvocation';
-import type { CropImageInvocation } from './CropImageInvocation';
+import type { ContentShuffleImageProcessorInvocation } from './ContentShuffleImageProcessorInvocation';
+import type { ControlNetInvocation } from './ControlNetInvocation';
import type { CvInpaintInvocation } from './CvInpaintInvocation';
import type { DivideInvocation } from './DivideInvocation';
import type { Edge } from './Edge';
import type { GraphInvocation } from './GraphInvocation';
+import type { HedImageprocessorInvocation } from './HedImageprocessorInvocation';
+import type { ImageBlurInvocation } from './ImageBlurInvocation';
+import type { ImageChannelInvocation } from './ImageChannelInvocation';
+import type { ImageConvertInvocation } from './ImageConvertInvocation';
+import type { ImageCropInvocation } from './ImageCropInvocation';
+import type { ImageInverseLerpInvocation } from './ImageInverseLerpInvocation';
+import type { ImageLerpInvocation } from './ImageLerpInvocation';
+import type { ImageMultiplyInvocation } from './ImageMultiplyInvocation';
+import type { ImagePasteInvocation } from './ImagePasteInvocation';
+import type { ImageProcessorInvocation } from './ImageProcessorInvocation';
+import type { ImageResizeInvocation } from './ImageResizeInvocation';
+import type { ImageScaleInvocation } from './ImageScaleInvocation';
import type { ImageToImageInvocation } from './ImageToImageInvocation';
import type { ImageToLatentsInvocation } from './ImageToLatentsInvocation';
import type { InfillColorInvocation } from './InfillColorInvocation';
import type { InfillPatchMatchInvocation } from './InfillPatchMatchInvocation';
import type { InfillTileInvocation } from './InfillTileInvocation';
import type { InpaintInvocation } from './InpaintInvocation';
-import type { InverseLerpInvocation } from './InverseLerpInvocation';
import type { IterateInvocation } from './IterateInvocation';
import type { LatentsToImageInvocation } from './LatentsToImageInvocation';
import type { LatentsToLatentsInvocation } from './LatentsToLatentsInvocation';
-import type { LerpInvocation } from './LerpInvocation';
+import type { LineartAnimeImageProcessorInvocation } from './LineartAnimeImageProcessorInvocation';
+import type { LineartImageProcessorInvocation } from './LineartImageProcessorInvocation';
import type { LoadImageInvocation } from './LoadImageInvocation';
import type { MaskFromAlphaInvocation } from './MaskFromAlphaInvocation';
+import type { MediapipeFaceProcessorInvocation } from './MediapipeFaceProcessorInvocation';
+import type { MidasDepthImageProcessorInvocation } from './MidasDepthImageProcessorInvocation';
+import type { MlsdImageProcessorInvocation } from './MlsdImageProcessorInvocation';
import type { MultiplyInvocation } from './MultiplyInvocation';
import type { NoiseInvocation } from './NoiseInvocation';
+import type { NormalbaeImageProcessorInvocation } from './NormalbaeImageProcessorInvocation';
+import type { OpenposeImageProcessorInvocation } from './OpenposeImageProcessorInvocation';
+import type { ParamFloatInvocation } from './ParamFloatInvocation';
import type { ParamIntInvocation } from './ParamIntInvocation';
-import type { PasteImageInvocation } from './PasteImageInvocation';
+import type { PidiImageProcessorInvocation } from './PidiImageProcessorInvocation';
import type { RandomIntInvocation } from './RandomIntInvocation';
import type { RandomRangeInvocation } from './RandomRangeInvocation';
import type { RangeInvocation } from './RangeInvocation';
@@ -40,6 +59,7 @@ import type { SubtractInvocation } from './SubtractInvocation';
import type { TextToImageInvocation } from './TextToImageInvocation';
import type { TextToLatentsInvocation } from './TextToLatentsInvocation';
import type { UpscaleInvocation } from './UpscaleInvocation';
+import type { ZoeDepthImageProcessorInvocation } from './ZoeDepthImageProcessorInvocation';
export type Graph = {
/**
@@ -49,7 +69,7 @@ export type Graph = {
/**
* The nodes in this graph
*/
- nodes?: Record;
+ nodes?: Record;
/**
* The connections between nodes and their fields in this graph
*/
diff --git a/invokeai/frontend/web/src/services/api/models/GraphExecutionState.ts b/invokeai/frontend/web/src/services/api/models/GraphExecutionState.ts
index 8c2eb05657..ea41ce055b 100644
--- a/invokeai/frontend/web/src/services/api/models/GraphExecutionState.ts
+++ b/invokeai/frontend/web/src/services/api/models/GraphExecutionState.ts
@@ -4,6 +4,9 @@
import type { CollectInvocationOutput } from './CollectInvocationOutput';
import type { CompelOutput } from './CompelOutput';
+import type { ControlOutput } from './ControlOutput';
+import type { FloatCollectionOutput } from './FloatCollectionOutput';
+import type { FloatOutput } from './FloatOutput';
import type { Graph } from './Graph';
import type { GraphInvocationOutput } from './GraphInvocationOutput';
import type { ImageOutput } from './ImageOutput';
@@ -42,7 +45,7 @@ export type GraphExecutionState = {
/**
* The results of node executions
*/
- results: Record;
+ results: Record;
/**
* Errors raised when executing nodes
*/
diff --git a/invokeai/frontend/web/src/services/api/models/GraphInvocation.ts b/invokeai/frontend/web/src/services/api/models/GraphInvocation.ts
index 5109a49a68..8512faae74 100644
--- a/invokeai/frontend/web/src/services/api/models/GraphInvocation.ts
+++ b/invokeai/frontend/web/src/services/api/models/GraphInvocation.ts
@@ -5,14 +5,17 @@
import type { Graph } from './Graph';
/**
- * A node to process inputs and produce outputs.
- * May use dependency injection in __init__ to receive providers.
+ * Execute a graph
*/
export type GraphInvocation = {
/**
* The id of this node. Must be unique among all nodes.
*/
id: string;
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
type?: 'graph';
/**
* The graph to run
diff --git a/invokeai/frontend/web/src/services/api/models/HedImageprocessorInvocation.ts b/invokeai/frontend/web/src/services/api/models/HedImageprocessorInvocation.ts
new file mode 100644
index 0000000000..f975f18968
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/models/HedImageprocessorInvocation.ts
@@ -0,0 +1,37 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+
+import type { ImageField } from './ImageField';
+
+/**
+ * Applies HED edge detection to image
+ */
+export type HedImageprocessorInvocation = {
+ /**
+ * The id of this node. Must be unique among all nodes.
+ */
+ id: string;
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
+ type?: 'hed_image_processor';
+ /**
+ * image to process
+ */
+ image?: ImageField;
+ /**
+ * pixel resolution for edge detection
+ */
+ detect_resolution?: number;
+ /**
+ * pixel resolution for output image
+ */
+ image_resolution?: number;
+ /**
+ * whether to use scribble mode
+ */
+ scribble?: boolean;
+};
+
diff --git a/invokeai/frontend/web/src/services/api/models/BlurInvocation.ts b/invokeai/frontend/web/src/services/api/models/ImageBlurInvocation.ts
similarity index 72%
rename from invokeai/frontend/web/src/services/api/models/BlurInvocation.ts
rename to invokeai/frontend/web/src/services/api/models/ImageBlurInvocation.ts
index 0643e4b309..3ba86d8fab 100644
--- a/invokeai/frontend/web/src/services/api/models/BlurInvocation.ts
+++ b/invokeai/frontend/web/src/services/api/models/ImageBlurInvocation.ts
@@ -7,12 +7,16 @@ import type { ImageField } from './ImageField';
/**
* Blurs an image
*/
-export type BlurInvocation = {
+export type ImageBlurInvocation = {
/**
* The id of this node. Must be unique among all nodes.
*/
id: string;
- type?: 'blur';
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
+ type?: 'img_blur';
/**
* The image to blur
*/
diff --git a/invokeai/frontend/web/src/services/api/models/ImageCategory.ts b/invokeai/frontend/web/src/services/api/models/ImageCategory.ts
index c4edf90fd3..84551d3cd6 100644
--- a/invokeai/frontend/web/src/services/api/models/ImageCategory.ts
+++ b/invokeai/frontend/web/src/services/api/models/ImageCategory.ts
@@ -3,6 +3,12 @@
/* eslint-disable */
/**
- * The category of an image. Use ImageCategory.OTHER for non-default categories.
+ * The category of an image.
+ *
+ * - GENERAL: The image is an output, init image, or otherwise an image without a specialized purpose.
+ * - MASK: The image is a mask image.
+ * - CONTROL: The image is a ControlNet control image.
+ * - USER: The image is a user-provided image.
+ * - OTHER: The image is some other type of image with a specialized purpose. To be used by external nodes.
*/
-export type ImageCategory = 'general' | 'control' | 'other';
+export type ImageCategory = 'general' | 'mask' | 'control' | 'user' | 'other';
diff --git a/invokeai/frontend/web/src/services/api/models/ImageChannelInvocation.ts b/invokeai/frontend/web/src/services/api/models/ImageChannelInvocation.ts
new file mode 100644
index 0000000000..47bfd4110f
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/models/ImageChannelInvocation.ts
@@ -0,0 +1,29 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+
+import type { ImageField } from './ImageField';
+
+/**
+ * Gets a channel from an image.
+ */
+export type ImageChannelInvocation = {
+ /**
+ * The id of this node. Must be unique among all nodes.
+ */
+ id: string;
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
+ type?: 'img_chan';
+ /**
+ * The image to get the channel from
+ */
+ image?: ImageField;
+ /**
+ * The channel to get
+ */
+ channel?: 'A' | 'R' | 'G' | 'B';
+};
+
diff --git a/invokeai/frontend/web/src/services/api/models/ImageConvertInvocation.ts b/invokeai/frontend/web/src/services/api/models/ImageConvertInvocation.ts
new file mode 100644
index 0000000000..4bd59d03b0
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/models/ImageConvertInvocation.ts
@@ -0,0 +1,29 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+
+import type { ImageField } from './ImageField';
+
+/**
+ * Converts an image to a different mode.
+ */
+export type ImageConvertInvocation = {
+ /**
+ * The id of this node. Must be unique among all nodes.
+ */
+ id: string;
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
+ type?: 'img_conv';
+ /**
+ * The image to convert
+ */
+ image?: ImageField;
+ /**
+ * The mode to convert to
+ */
+ mode?: 'L' | 'RGB' | 'RGBA' | 'CMYK' | 'YCbCr' | 'LAB' | 'HSV' | 'I' | 'F';
+};
+
diff --git a/invokeai/frontend/web/src/services/api/models/CropImageInvocation.ts b/invokeai/frontend/web/src/services/api/models/ImageCropInvocation.ts
similarity index 80%
rename from invokeai/frontend/web/src/services/api/models/CropImageInvocation.ts
rename to invokeai/frontend/web/src/services/api/models/ImageCropInvocation.ts
index 2676f5cb87..5207ebbf6d 100644
--- a/invokeai/frontend/web/src/services/api/models/CropImageInvocation.ts
+++ b/invokeai/frontend/web/src/services/api/models/ImageCropInvocation.ts
@@ -7,12 +7,16 @@ import type { ImageField } from './ImageField';
/**
* Crops an image to a specified box. The box can be outside of the image.
*/
-export type CropImageInvocation = {
+export type ImageCropInvocation = {
/**
* The id of this node. Must be unique among all nodes.
*/
id: string;
- type?: 'crop';
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
+ type?: 'img_crop';
/**
* The image to crop
*/
diff --git a/invokeai/frontend/web/src/services/api/models/ImageDTO.ts b/invokeai/frontend/web/src/services/api/models/ImageDTO.ts
index c5377b4c76..f5f2603b03 100644
--- a/invokeai/frontend/web/src/services/api/models/ImageDTO.ts
+++ b/invokeai/frontend/web/src/services/api/models/ImageDTO.ts
@@ -4,7 +4,7 @@
import type { ImageCategory } from './ImageCategory';
import type { ImageMetadata } from './ImageMetadata';
-import type { ImageType } from './ImageType';
+import type { ResourceOrigin } from './ResourceOrigin';
/**
* Deserialized image record, enriched for the frontend with URLs.
@@ -17,7 +17,7 @@ export type ImageDTO = {
/**
* The type of the image.
*/
- image_type: ImageType;
+ image_origin: ResourceOrigin;
/**
* The URL of the image.
*/
@@ -50,6 +50,10 @@ export type ImageDTO = {
* The deleted timestamp of the image.
*/
deleted_at?: string;
+ /**
+ * Whether this is an intermediate image.
+ */
+ is_intermediate: boolean;
/**
* The session ID that generated this image, if it is a generated image.
*/
diff --git a/invokeai/frontend/web/src/services/api/models/ImageField.ts b/invokeai/frontend/web/src/services/api/models/ImageField.ts
index fa22ae8007..63a12f4730 100644
--- a/invokeai/frontend/web/src/services/api/models/ImageField.ts
+++ b/invokeai/frontend/web/src/services/api/models/ImageField.ts
@@ -2,7 +2,7 @@
/* tslint:disable */
/* eslint-disable */
-import type { ImageType } from './ImageType';
+import type { ResourceOrigin } from './ResourceOrigin';
/**
* An image field used for passing image objects between invocations
@@ -11,7 +11,7 @@ export type ImageField = {
/**
* The type of the image
*/
- image_type: ImageType;
+ image_origin: ResourceOrigin;
/**
* The name of the image
*/
diff --git a/invokeai/frontend/web/src/services/api/models/InverseLerpInvocation.ts b/invokeai/frontend/web/src/services/api/models/ImageInverseLerpInvocation.ts
similarity index 73%
rename from invokeai/frontend/web/src/services/api/models/InverseLerpInvocation.ts
rename to invokeai/frontend/web/src/services/api/models/ImageInverseLerpInvocation.ts
index 33c59b7bac..0347d4dc38 100644
--- a/invokeai/frontend/web/src/services/api/models/InverseLerpInvocation.ts
+++ b/invokeai/frontend/web/src/services/api/models/ImageInverseLerpInvocation.ts
@@ -7,12 +7,16 @@ import type { ImageField } from './ImageField';
/**
* Inverse linear interpolation of all pixels of an image
*/
-export type InverseLerpInvocation = {
+export type ImageInverseLerpInvocation = {
/**
* The id of this node. Must be unique among all nodes.
*/
id: string;
- type?: 'ilerp';
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
+ type?: 'img_ilerp';
/**
* The image to lerp
*/
diff --git a/invokeai/frontend/web/src/services/api/models/LerpInvocation.ts b/invokeai/frontend/web/src/services/api/models/ImageLerpInvocation.ts
similarity index 74%
rename from invokeai/frontend/web/src/services/api/models/LerpInvocation.ts
rename to invokeai/frontend/web/src/services/api/models/ImageLerpInvocation.ts
index f2406c2246..388c86061c 100644
--- a/invokeai/frontend/web/src/services/api/models/LerpInvocation.ts
+++ b/invokeai/frontend/web/src/services/api/models/ImageLerpInvocation.ts
@@ -7,12 +7,16 @@ import type { ImageField } from './ImageField';
/**
* Linear interpolation of all pixels of an image
*/
-export type LerpInvocation = {
+export type ImageLerpInvocation = {
/**
* The id of this node. Must be unique among all nodes.
*/
id: string;
- type?: 'lerp';
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
+ type?: 'img_lerp';
/**
* The image to lerp
*/
diff --git a/invokeai/frontend/web/src/services/api/models/ImageMultiplyInvocation.ts b/invokeai/frontend/web/src/services/api/models/ImageMultiplyInvocation.ts
new file mode 100644
index 0000000000..751ee49158
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/models/ImageMultiplyInvocation.ts
@@ -0,0 +1,29 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+
+import type { ImageField } from './ImageField';
+
+/**
+ * Multiplies two images together using `PIL.ImageChops.multiply()`.
+ */
+export type ImageMultiplyInvocation = {
+ /**
+ * The id of this node. Must be unique among all nodes.
+ */
+ id: string;
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
+ type?: 'img_mul';
+ /**
+ * The first image to multiply
+ */
+ image1?: ImageField;
+ /**
+ * The second image to multiply
+ */
+ image2?: ImageField;
+};
+
diff --git a/invokeai/frontend/web/src/services/api/models/PasteImageInvocation.ts b/invokeai/frontend/web/src/services/api/models/ImagePasteInvocation.ts
similarity index 79%
rename from invokeai/frontend/web/src/services/api/models/PasteImageInvocation.ts
rename to invokeai/frontend/web/src/services/api/models/ImagePasteInvocation.ts
index 8a181ccf07..c883b9a5d8 100644
--- a/invokeai/frontend/web/src/services/api/models/PasteImageInvocation.ts
+++ b/invokeai/frontend/web/src/services/api/models/ImagePasteInvocation.ts
@@ -7,12 +7,16 @@ import type { ImageField } from './ImageField';
/**
* Pastes an image into another image.
*/
-export type PasteImageInvocation = {
+export type ImagePasteInvocation = {
/**
* The id of this node. Must be unique among all nodes.
*/
id: string;
- type?: 'paste';
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
+ type?: 'img_paste';
/**
* The base image
*/
diff --git a/invokeai/frontend/web/src/services/api/models/ImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/models/ImageProcessorInvocation.ts
new file mode 100644
index 0000000000..f972582e2f
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/models/ImageProcessorInvocation.ts
@@ -0,0 +1,25 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+
+import type { ImageField } from './ImageField';
+
+/**
+ * Base class for invocations that preprocess images for ControlNet
+ */
+export type ImageProcessorInvocation = {
+ /**
+ * The id of this node. Must be unique among all nodes.
+ */
+ id: string;
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
+ type?: 'image_processor';
+ /**
+ * image to process
+ */
+ image?: ImageField;
+};
+
diff --git a/invokeai/frontend/web/src/services/api/models/ImageRecordChanges.ts b/invokeai/frontend/web/src/services/api/models/ImageRecordChanges.ts
new file mode 100644
index 0000000000..e597cd907d
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/models/ImageRecordChanges.ts
@@ -0,0 +1,29 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+
+import type { ImageCategory } from './ImageCategory';
+
+/**
+ * A set of changes to apply to an image record.
+ *
+ * Only limited changes are valid:
+ * - `image_category`: change the category of an image
+ * - `session_id`: change the session associated with an image
+ * - `is_intermediate`: change the image's `is_intermediate` flag
+ */
+export type ImageRecordChanges = {
+ /**
+ * The image's new category.
+ */
+ image_category?: ImageCategory;
+ /**
+ * The image's new session ID.
+ */
+ session_id?: string;
+ /**
+ * The image's new `is_intermediate` flag.
+ */
+ is_intermediate?: boolean;
+};
+
diff --git a/invokeai/frontend/web/src/services/api/models/ImageResizeInvocation.ts b/invokeai/frontend/web/src/services/api/models/ImageResizeInvocation.ts
new file mode 100644
index 0000000000..3b096c83b7
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/models/ImageResizeInvocation.ts
@@ -0,0 +1,37 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+
+import type { ImageField } from './ImageField';
+
+/**
+ * Resizes an image to specific dimensions
+ */
+export type ImageResizeInvocation = {
+ /**
+ * The id of this node. Must be unique among all nodes.
+ */
+ id: string;
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
+ type?: 'img_resize';
+ /**
+ * The image to resize
+ */
+ image?: ImageField;
+ /**
+ * The width to resize to (px)
+ */
+ width: number;
+ /**
+ * The height to resize to (px)
+ */
+ height: number;
+ /**
+ * The resampling mode
+ */
+ resample_mode?: 'nearest' | 'box' | 'bilinear' | 'hamming' | 'bicubic' | 'lanczos';
+};
+
diff --git a/invokeai/frontend/web/src/services/api/models/ImageScaleInvocation.ts b/invokeai/frontend/web/src/services/api/models/ImageScaleInvocation.ts
new file mode 100644
index 0000000000..bf4da28a4a
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/models/ImageScaleInvocation.ts
@@ -0,0 +1,33 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+
+import type { ImageField } from './ImageField';
+
+/**
+ * Scales an image by a factor
+ */
+export type ImageScaleInvocation = {
+ /**
+ * The id of this node. Must be unique among all nodes.
+ */
+ id: string;
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
+ type?: 'img_scale';
+ /**
+ * The image to scale
+ */
+ image?: ImageField;
+ /**
+ * The factor by which to scale the image
+ */
+ scale_factor: number;
+ /**
+ * The resampling mode
+ */
+ resample_mode?: 'nearest' | 'box' | 'bilinear' | 'hamming' | 'bicubic' | 'lanczos';
+};
+
diff --git a/invokeai/frontend/web/src/services/api/models/ImageToImageInvocation.ts b/invokeai/frontend/web/src/services/api/models/ImageToImageInvocation.ts
index fb43c76921..e63ec93ada 100644
--- a/invokeai/frontend/web/src/services/api/models/ImageToImageInvocation.ts
+++ b/invokeai/frontend/web/src/services/api/models/ImageToImageInvocation.ts
@@ -12,6 +12,10 @@ export type ImageToImageInvocation = {
* The id of this node. Must be unique among all nodes.
*/
id: string;
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
type?: 'img2img';
/**
* The prompt to generate an image from
@@ -45,6 +49,18 @@ export type ImageToImageInvocation = {
* The model to use (currently ignored)
*/
model?: string;
+ /**
+ * Whether or not to produce progress images during generation
+ */
+ progress_images?: boolean;
+ /**
+ * The control model to use
+ */
+ control_model?: string;
+ /**
+ * The processed control image
+ */
+ control_image?: ImageField;
/**
* The input image
*/
diff --git a/invokeai/frontend/web/src/services/api/models/ImageToLatentsInvocation.ts b/invokeai/frontend/web/src/services/api/models/ImageToLatentsInvocation.ts
index f72d446615..5569c2fa86 100644
--- a/invokeai/frontend/web/src/services/api/models/ImageToLatentsInvocation.ts
+++ b/invokeai/frontend/web/src/services/api/models/ImageToLatentsInvocation.ts
@@ -12,6 +12,10 @@ export type ImageToLatentsInvocation = {
* The id of this node. Must be unique among all nodes.
*/
id: string;
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
type?: 'i2l';
/**
* The image to encode
diff --git a/invokeai/frontend/web/src/services/api/models/ImageType.ts b/invokeai/frontend/web/src/services/api/models/ImageType.ts
deleted file mode 100644
index bba9134e63..0000000000
--- a/invokeai/frontend/web/src/services/api/models/ImageType.ts
+++ /dev/null
@@ -1,8 +0,0 @@
-/* istanbul ignore file */
-/* tslint:disable */
-/* eslint-disable */
-
-/**
- * The type of an image.
- */
-export type ImageType = 'results' | 'uploads' | 'intermediates';
diff --git a/invokeai/frontend/web/src/services/api/models/ImageUrlsDTO.ts b/invokeai/frontend/web/src/services/api/models/ImageUrlsDTO.ts
index af80519ef2..81639be9b3 100644
--- a/invokeai/frontend/web/src/services/api/models/ImageUrlsDTO.ts
+++ b/invokeai/frontend/web/src/services/api/models/ImageUrlsDTO.ts
@@ -2,7 +2,7 @@
/* tslint:disable */
/* eslint-disable */
-import type { ImageType } from './ImageType';
+import type { ResourceOrigin } from './ResourceOrigin';
/**
* The URLs for an image and its thumbnail.
@@ -15,7 +15,7 @@ export type ImageUrlsDTO = {
/**
* The type of the image.
*/
- image_type: ImageType;
+ image_origin: ResourceOrigin;
/**
* The URL of the image.
*/
diff --git a/invokeai/frontend/web/src/services/api/models/InfillColorInvocation.ts b/invokeai/frontend/web/src/services/api/models/InfillColorInvocation.ts
index 157c976e11..3e637b299c 100644
--- a/invokeai/frontend/web/src/services/api/models/InfillColorInvocation.ts
+++ b/invokeai/frontend/web/src/services/api/models/InfillColorInvocation.ts
@@ -13,6 +13,10 @@ export type InfillColorInvocation = {
* The id of this node. Must be unique among all nodes.
*/
id: string;
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
type?: 'infill_rgba';
/**
* The image to infill
diff --git a/invokeai/frontend/web/src/services/api/models/InfillPatchMatchInvocation.ts b/invokeai/frontend/web/src/services/api/models/InfillPatchMatchInvocation.ts
index a4c18ade5d..325bfe2080 100644
--- a/invokeai/frontend/web/src/services/api/models/InfillPatchMatchInvocation.ts
+++ b/invokeai/frontend/web/src/services/api/models/InfillPatchMatchInvocation.ts
@@ -12,6 +12,10 @@ export type InfillPatchMatchInvocation = {
* The id of this node. Must be unique among all nodes.
*/
id: string;
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
type?: 'infill_patchmatch';
/**
* The image to infill
diff --git a/invokeai/frontend/web/src/services/api/models/InfillTileInvocation.ts b/invokeai/frontend/web/src/services/api/models/InfillTileInvocation.ts
index 12113f57f5..dfb1cbc61d 100644
--- a/invokeai/frontend/web/src/services/api/models/InfillTileInvocation.ts
+++ b/invokeai/frontend/web/src/services/api/models/InfillTileInvocation.ts
@@ -12,6 +12,10 @@ export type InfillTileInvocation = {
* The id of this node. Must be unique among all nodes.
*/
id: string;
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
type?: 'infill_tile';
/**
* The image to infill
diff --git a/invokeai/frontend/web/src/services/api/models/InpaintInvocation.ts b/invokeai/frontend/web/src/services/api/models/InpaintInvocation.ts
index 88ead9907c..b8ed268ef9 100644
--- a/invokeai/frontend/web/src/services/api/models/InpaintInvocation.ts
+++ b/invokeai/frontend/web/src/services/api/models/InpaintInvocation.ts
@@ -13,6 +13,10 @@ export type InpaintInvocation = {
* The id of this node. Must be unique among all nodes.
*/
id: string;
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
type?: 'inpaint';
/**
* The prompt to generate an image from
@@ -46,6 +50,18 @@ export type InpaintInvocation = {
* The model to use (currently ignored)
*/
model?: string;
+ /**
+ * Whether or not to produce progress images during generation
+ */
+ progress_images?: boolean;
+ /**
+ * The control model to use
+ */
+ control_model?: string;
+ /**
+ * The processed control image
+ */
+ control_image?: ImageField;
/**
* The input image
*/
diff --git a/invokeai/frontend/web/src/services/api/models/IterateInvocation.ts b/invokeai/frontend/web/src/services/api/models/IterateInvocation.ts
index 0ff7a1258d..15bf92dfea 100644
--- a/invokeai/frontend/web/src/services/api/models/IterateInvocation.ts
+++ b/invokeai/frontend/web/src/services/api/models/IterateInvocation.ts
@@ -3,14 +3,17 @@
/* eslint-disable */
/**
- * A node to process inputs and produce outputs.
- * May use dependency injection in __init__ to receive providers.
+ * Iterates over a list of items
*/
export type IterateInvocation = {
/**
* The id of this node. Must be unique among all nodes.
*/
id: string;
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
type?: 'iterate';
/**
* The list of items to iterate over
diff --git a/invokeai/frontend/web/src/services/api/models/LatentsToImageInvocation.ts b/invokeai/frontend/web/src/services/api/models/LatentsToImageInvocation.ts
index 8acd872e28..fcaa37d7e8 100644
--- a/invokeai/frontend/web/src/services/api/models/LatentsToImageInvocation.ts
+++ b/invokeai/frontend/web/src/services/api/models/LatentsToImageInvocation.ts
@@ -12,6 +12,10 @@ export type LatentsToImageInvocation = {
* The id of this node. Must be unique among all nodes.
*/
id: string;
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
type?: 'l2i';
/**
* The latents to generate an image from
diff --git a/invokeai/frontend/web/src/services/api/models/LatentsToLatentsInvocation.ts b/invokeai/frontend/web/src/services/api/models/LatentsToLatentsInvocation.ts
index 29995c6ad9..f5b4912141 100644
--- a/invokeai/frontend/web/src/services/api/models/LatentsToLatentsInvocation.ts
+++ b/invokeai/frontend/web/src/services/api/models/LatentsToLatentsInvocation.ts
@@ -3,6 +3,7 @@
/* eslint-disable */
import type { ConditioningField } from './ConditioningField';
+import type { ControlField } from './ControlField';
import type { LatentsField } from './LatentsField';
/**
@@ -13,6 +14,10 @@ export type LatentsToLatentsInvocation = {
* The id of this node. Must be unique among all nodes.
*/
id: string;
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
type?: 'l2l';
/**
* Positive conditioning for generation
@@ -43,13 +48,9 @@ export type LatentsToLatentsInvocation = {
*/
model?: string;
/**
- * Whether or not to generate an image that can tile without seams
+ * The control to use
*/
- seamless?: boolean;
- /**
- * The axes to tile the image on, 'x' and/or 'y'
- */
- seamless_axes?: string;
+  control?: (ControlField | Array<ControlField>);
/**
* The latents to use as a base image
*/
diff --git a/invokeai/frontend/web/src/services/api/models/LineartAnimeImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/models/LineartAnimeImageProcessorInvocation.ts
new file mode 100644
index 0000000000..4796d2a049
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/models/LineartAnimeImageProcessorInvocation.ts
@@ -0,0 +1,33 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+
+import type { ImageField } from './ImageField';
+
+/**
+ * Applies line art anime processing to image
+ */
+export type LineartAnimeImageProcessorInvocation = {
+ /**
+ * The id of this node. Must be unique among all nodes.
+ */
+ id: string;
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
+ type?: 'lineart_anime_image_processor';
+ /**
+ * image to process
+ */
+ image?: ImageField;
+ /**
+ * pixel resolution for edge detection
+ */
+ detect_resolution?: number;
+ /**
+ * pixel resolution for output image
+ */
+ image_resolution?: number;
+};
+
diff --git a/invokeai/frontend/web/src/services/api/models/LineartImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/models/LineartImageProcessorInvocation.ts
new file mode 100644
index 0000000000..8328849b50
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/models/LineartImageProcessorInvocation.ts
@@ -0,0 +1,37 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+
+import type { ImageField } from './ImageField';
+
+/**
+ * Applies line art processing to image
+ */
+export type LineartImageProcessorInvocation = {
+ /**
+ * The id of this node. Must be unique among all nodes.
+ */
+ id: string;
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
+ type?: 'lineart_image_processor';
+ /**
+ * image to process
+ */
+ image?: ImageField;
+ /**
+ * pixel resolution for edge detection
+ */
+ detect_resolution?: number;
+ /**
+ * pixel resolution for output image
+ */
+ image_resolution?: number;
+ /**
+ * whether to use coarse mode
+ */
+ coarse?: boolean;
+};
+
diff --git a/invokeai/frontend/web/src/services/api/models/LoadImageInvocation.ts b/invokeai/frontend/web/src/services/api/models/LoadImageInvocation.ts
index 745a9b44e4..f20d983f9b 100644
--- a/invokeai/frontend/web/src/services/api/models/LoadImageInvocation.ts
+++ b/invokeai/frontend/web/src/services/api/models/LoadImageInvocation.ts
@@ -2,7 +2,7 @@
/* tslint:disable */
/* eslint-disable */
-import type { ImageType } from './ImageType';
+import type { ImageField } from './ImageField';
/**
* Load an image and provide it as output.
@@ -12,14 +12,14 @@ export type LoadImageInvocation = {
* The id of this node. Must be unique among all nodes.
*/
id: string;
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
type?: 'load_image';
/**
- * The type of the image
+ * The image to load
*/
- image_type: ImageType;
- /**
- * The name of the image
- */
- image_name: string;
+ image?: ImageField;
};
diff --git a/invokeai/frontend/web/src/services/api/models/MaskFromAlphaInvocation.ts b/invokeai/frontend/web/src/services/api/models/MaskFromAlphaInvocation.ts
index e71b1f464b..e3693f6d98 100644
--- a/invokeai/frontend/web/src/services/api/models/MaskFromAlphaInvocation.ts
+++ b/invokeai/frontend/web/src/services/api/models/MaskFromAlphaInvocation.ts
@@ -12,6 +12,10 @@ export type MaskFromAlphaInvocation = {
* The id of this node. Must be unique among all nodes.
*/
id: string;
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
type?: 'tomask';
/**
* The image to create the mask from
diff --git a/invokeai/frontend/web/src/services/api/models/MediapipeFaceProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/models/MediapipeFaceProcessorInvocation.ts
new file mode 100644
index 0000000000..bd223eed7d
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/models/MediapipeFaceProcessorInvocation.ts
@@ -0,0 +1,33 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+
+import type { ImageField } from './ImageField';
+
+/**
+ * Applies mediapipe face processing to image
+ */
+export type MediapipeFaceProcessorInvocation = {
+ /**
+ * The id of this node. Must be unique among all nodes.
+ */
+ id: string;
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
+ type?: 'mediapipe_face_processor';
+ /**
+ * image to process
+ */
+ image?: ImageField;
+ /**
+ * maximum number of faces to detect
+ */
+ max_faces?: number;
+ /**
+ * minimum confidence for face detection
+ */
+ min_confidence?: number;
+};
+
diff --git a/invokeai/frontend/web/src/services/api/models/MidasDepthImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/models/MidasDepthImageProcessorInvocation.ts
new file mode 100644
index 0000000000..11023086a2
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/models/MidasDepthImageProcessorInvocation.ts
@@ -0,0 +1,33 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+
+import type { ImageField } from './ImageField';
+
+/**
+ * Applies Midas depth processing to image
+ */
+export type MidasDepthImageProcessorInvocation = {
+ /**
+ * The id of this node. Must be unique among all nodes.
+ */
+ id: string;
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
+ type?: 'midas_depth_image_processor';
+ /**
+ * image to process
+ */
+ image?: ImageField;
+ /**
+ * Midas parameter a = amult * PI
+ */
+ a_mult?: number;
+ /**
+ * Midas parameter bg_th
+ */
+ bg_th?: number;
+};
+
diff --git a/invokeai/frontend/web/src/services/api/models/MlsdImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/models/MlsdImageProcessorInvocation.ts
new file mode 100644
index 0000000000..c2d4a61b9a
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/models/MlsdImageProcessorInvocation.ts
@@ -0,0 +1,41 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+
+import type { ImageField } from './ImageField';
+
+/**
+ * Applies MLSD processing to image
+ */
+export type MlsdImageProcessorInvocation = {
+ /**
+ * The id of this node. Must be unique among all nodes.
+ */
+ id: string;
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
+ type?: 'mlsd_image_processor';
+ /**
+ * image to process
+ */
+ image?: ImageField;
+ /**
+ * pixel resolution for edge detection
+ */
+ detect_resolution?: number;
+ /**
+ * pixel resolution for output image
+ */
+ image_resolution?: number;
+ /**
+ * MLSD parameter thr_v
+ */
+ thr_v?: number;
+ /**
+ * MLSD parameter thr_d
+ */
+ thr_d?: number;
+};
+
diff --git a/invokeai/frontend/web/src/services/api/models/MultiplyInvocation.ts b/invokeai/frontend/web/src/services/api/models/MultiplyInvocation.ts
index eede8f18d7..9fd716f33d 100644
--- a/invokeai/frontend/web/src/services/api/models/MultiplyInvocation.ts
+++ b/invokeai/frontend/web/src/services/api/models/MultiplyInvocation.ts
@@ -10,6 +10,10 @@ export type MultiplyInvocation = {
* The id of this node. Must be unique among all nodes.
*/
id: string;
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
type?: 'mul';
/**
* The first number
diff --git a/invokeai/frontend/web/src/services/api/models/NoiseInvocation.ts b/invokeai/frontend/web/src/services/api/models/NoiseInvocation.ts
index 59e50b76f3..239a24bfe5 100644
--- a/invokeai/frontend/web/src/services/api/models/NoiseInvocation.ts
+++ b/invokeai/frontend/web/src/services/api/models/NoiseInvocation.ts
@@ -10,6 +10,10 @@ export type NoiseInvocation = {
* The id of this node. Must be unique among all nodes.
*/
id: string;
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
type?: 'noise';
/**
* The seed to use
diff --git a/invokeai/frontend/web/src/services/api/models/NormalbaeImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/models/NormalbaeImageProcessorInvocation.ts
new file mode 100644
index 0000000000..ecfb50a09f
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/models/NormalbaeImageProcessorInvocation.ts
@@ -0,0 +1,33 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+
+import type { ImageField } from './ImageField';
+
+/**
+ * Applies NormalBae processing to image
+ */
+export type NormalbaeImageProcessorInvocation = {
+ /**
+ * The id of this node. Must be unique among all nodes.
+ */
+ id: string;
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
+ type?: 'normalbae_image_processor';
+ /**
+ * image to process
+ */
+ image?: ImageField;
+ /**
+ * pixel resolution for edge detection
+ */
+ detect_resolution?: number;
+ /**
+ * pixel resolution for output image
+ */
+ image_resolution?: number;
+};
+
diff --git a/invokeai/frontend/web/src/services/api/models/PaginatedResults_ImageDTO_.ts b/invokeai/frontend/web/src/services/api/models/OffsetPaginatedResults_ImageDTO_.ts
similarity index 56%
rename from invokeai/frontend/web/src/services/api/models/PaginatedResults_ImageDTO_.ts
rename to invokeai/frontend/web/src/services/api/models/OffsetPaginatedResults_ImageDTO_.ts
index 5d2bdae5ab..3408bea6db 100644
--- a/invokeai/frontend/web/src/services/api/models/PaginatedResults_ImageDTO_.ts
+++ b/invokeai/frontend/web/src/services/api/models/OffsetPaginatedResults_ImageDTO_.ts
@@ -5,25 +5,21 @@
import type { ImageDTO } from './ImageDTO';
/**
- * Paginated results
+ * Offset-paginated results
*/
-export type PaginatedResults_ImageDTO_ = {
+export type OffsetPaginatedResults_ImageDTO_ = {
/**
* Items
*/
items: Array;
/**
- * Current Page
+ * Offset from which to retrieve items
*/
- page: number;
+ offset: number;
/**
- * Total number of pages
+ * Limit of items to get
*/
- pages: number;
- /**
- * Number of items per page
- */
- per_page: number;
+ limit: number;
/**
* Total number of items in result
*/
diff --git a/invokeai/frontend/web/src/services/api/models/OpenposeImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/models/OpenposeImageProcessorInvocation.ts
new file mode 100644
index 0000000000..5af21d542e
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/models/OpenposeImageProcessorInvocation.ts
@@ -0,0 +1,37 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+
+import type { ImageField } from './ImageField';
+
+/**
+ * Applies Openpose processing to image
+ */
+export type OpenposeImageProcessorInvocation = {
+ /**
+ * The id of this node. Must be unique among all nodes.
+ */
+ id: string;
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
+ type?: 'openpose_image_processor';
+ /**
+ * image to process
+ */
+ image?: ImageField;
+ /**
+ * whether to use hands and face mode
+ */
+ hand_and_face?: boolean;
+ /**
+ * pixel resolution for edge detection
+ */
+ detect_resolution?: number;
+ /**
+ * pixel resolution for output image
+ */
+ image_resolution?: number;
+};
+
diff --git a/invokeai/frontend/web/src/services/api/models/ParamFloatInvocation.ts b/invokeai/frontend/web/src/services/api/models/ParamFloatInvocation.ts
new file mode 100644
index 0000000000..87c01f847f
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/models/ParamFloatInvocation.ts
@@ -0,0 +1,23 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+
+/**
+ * A float parameter
+ */
+export type ParamFloatInvocation = {
+ /**
+ * The id of this node. Must be unique among all nodes.
+ */
+ id: string;
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
+ type?: 'param_float';
+ /**
+ * The float value
+ */
+ param?: number;
+};
+
diff --git a/invokeai/frontend/web/src/services/api/models/ParamIntInvocation.ts b/invokeai/frontend/web/src/services/api/models/ParamIntInvocation.ts
index 7047310a87..7a45d0a0ac 100644
--- a/invokeai/frontend/web/src/services/api/models/ParamIntInvocation.ts
+++ b/invokeai/frontend/web/src/services/api/models/ParamIntInvocation.ts
@@ -10,6 +10,10 @@ export type ParamIntInvocation = {
* The id of this node. Must be unique among all nodes.
*/
id: string;
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
type?: 'param_int';
/**
* The integer value
diff --git a/invokeai/frontend/web/src/services/api/models/PidiImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/models/PidiImageProcessorInvocation.ts
new file mode 100644
index 0000000000..a08bf6a920
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/models/PidiImageProcessorInvocation.ts
@@ -0,0 +1,41 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+
+import type { ImageField } from './ImageField';
+
+/**
+ * Applies PIDI processing to image
+ */
+export type PidiImageProcessorInvocation = {
+ /**
+ * The id of this node. Must be unique among all nodes.
+ */
+ id: string;
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
+ type?: 'pidi_image_processor';
+ /**
+ * image to process
+ */
+ image?: ImageField;
+ /**
+ * pixel resolution for edge detection
+ */
+ detect_resolution?: number;
+ /**
+ * pixel resolution for output image
+ */
+ image_resolution?: number;
+ /**
+ * whether to use safe mode
+ */
+ safe?: boolean;
+ /**
+ * whether to use scribble mode
+ */
+ scribble?: boolean;
+};
+
diff --git a/invokeai/frontend/web/src/services/api/models/RandomIntInvocation.ts b/invokeai/frontend/web/src/services/api/models/RandomIntInvocation.ts
index 0a5220c31d..a2f7c2f02a 100644
--- a/invokeai/frontend/web/src/services/api/models/RandomIntInvocation.ts
+++ b/invokeai/frontend/web/src/services/api/models/RandomIntInvocation.ts
@@ -10,6 +10,10 @@ export type RandomIntInvocation = {
* The id of this node. Must be unique among all nodes.
*/
id: string;
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
type?: 'rand_int';
/**
* The inclusive low value
diff --git a/invokeai/frontend/web/src/services/api/models/RandomRangeInvocation.ts b/invokeai/frontend/web/src/services/api/models/RandomRangeInvocation.ts
index c1f80042a6..925511578d 100644
--- a/invokeai/frontend/web/src/services/api/models/RandomRangeInvocation.ts
+++ b/invokeai/frontend/web/src/services/api/models/RandomRangeInvocation.ts
@@ -10,6 +10,10 @@ export type RandomRangeInvocation = {
* The id of this node. Must be unique among all nodes.
*/
id: string;
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
type?: 'random_range';
/**
* The inclusive low value
diff --git a/invokeai/frontend/web/src/services/api/models/RangeInvocation.ts b/invokeai/frontend/web/src/services/api/models/RangeInvocation.ts
index 1c37ca7fe3..3681602a95 100644
--- a/invokeai/frontend/web/src/services/api/models/RangeInvocation.ts
+++ b/invokeai/frontend/web/src/services/api/models/RangeInvocation.ts
@@ -10,6 +10,10 @@ export type RangeInvocation = {
* The id of this node. Must be unique among all nodes.
*/
id: string;
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
type?: 'range';
/**
* The start of the range
diff --git a/invokeai/frontend/web/src/services/api/models/RangeOfSizeInvocation.ts b/invokeai/frontend/web/src/services/api/models/RangeOfSizeInvocation.ts
index b918f17130..7dfac68d39 100644
--- a/invokeai/frontend/web/src/services/api/models/RangeOfSizeInvocation.ts
+++ b/invokeai/frontend/web/src/services/api/models/RangeOfSizeInvocation.ts
@@ -10,6 +10,10 @@ export type RangeOfSizeInvocation = {
* The id of this node. Must be unique among all nodes.
*/
id: string;
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
type?: 'range_of_size';
/**
* The start of the range
diff --git a/invokeai/frontend/web/src/services/api/models/ResizeLatentsInvocation.ts b/invokeai/frontend/web/src/services/api/models/ResizeLatentsInvocation.ts
index c0fabb4984..9a7b6c61e4 100644
--- a/invokeai/frontend/web/src/services/api/models/ResizeLatentsInvocation.ts
+++ b/invokeai/frontend/web/src/services/api/models/ResizeLatentsInvocation.ts
@@ -12,6 +12,10 @@ export type ResizeLatentsInvocation = {
* The id of this node. Must be unique among all nodes.
*/
id: string;
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
type?: 'lresize';
/**
* The latents to resize
diff --git a/invokeai/frontend/web/src/services/api/models/ResourceOrigin.ts b/invokeai/frontend/web/src/services/api/models/ResourceOrigin.ts
new file mode 100644
index 0000000000..a82edda0c1
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/models/ResourceOrigin.ts
@@ -0,0 +1,12 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+
+/**
+ * The origin of a resource (eg image).
+ *
+ * - INTERNAL: The resource was created by the application.
+ * - EXTERNAL: The resource was not created by the application.
+ * This may be a user-initiated upload, or an internal application upload (eg Canvas init image).
+ */
+export type ResourceOrigin = 'internal' | 'external';
diff --git a/invokeai/frontend/web/src/services/api/models/RestoreFaceInvocation.ts b/invokeai/frontend/web/src/services/api/models/RestoreFaceInvocation.ts
index e03ed01c81..0bacb5d805 100644
--- a/invokeai/frontend/web/src/services/api/models/RestoreFaceInvocation.ts
+++ b/invokeai/frontend/web/src/services/api/models/RestoreFaceInvocation.ts
@@ -12,6 +12,10 @@ export type RestoreFaceInvocation = {
* The id of this node. Must be unique among all nodes.
*/
id: string;
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
type?: 'restore_face';
/**
* The input image
diff --git a/invokeai/frontend/web/src/services/api/models/ScaleLatentsInvocation.ts b/invokeai/frontend/web/src/services/api/models/ScaleLatentsInvocation.ts
index f398eaf408..506b21e540 100644
--- a/invokeai/frontend/web/src/services/api/models/ScaleLatentsInvocation.ts
+++ b/invokeai/frontend/web/src/services/api/models/ScaleLatentsInvocation.ts
@@ -12,6 +12,10 @@ export type ScaleLatentsInvocation = {
* The id of this node. Must be unique among all nodes.
*/
id: string;
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
type?: 'lscale';
/**
* The latents to scale
diff --git a/invokeai/frontend/web/src/services/api/models/ShowImageInvocation.ts b/invokeai/frontend/web/src/services/api/models/ShowImageInvocation.ts
index 145895ad75..1b73055584 100644
--- a/invokeai/frontend/web/src/services/api/models/ShowImageInvocation.ts
+++ b/invokeai/frontend/web/src/services/api/models/ShowImageInvocation.ts
@@ -12,6 +12,10 @@ export type ShowImageInvocation = {
* The id of this node. Must be unique among all nodes.
*/
id: string;
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
type?: 'show_image';
/**
* The image to show
diff --git a/invokeai/frontend/web/src/services/api/models/SubtractInvocation.ts b/invokeai/frontend/web/src/services/api/models/SubtractInvocation.ts
index 6f2da116a2..23334bd891 100644
--- a/invokeai/frontend/web/src/services/api/models/SubtractInvocation.ts
+++ b/invokeai/frontend/web/src/services/api/models/SubtractInvocation.ts
@@ -10,6 +10,10 @@ export type SubtractInvocation = {
* The id of this node. Must be unique among all nodes.
*/
id: string;
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
type?: 'sub';
/**
* The first number
diff --git a/invokeai/frontend/web/src/services/api/models/TextToImageInvocation.ts b/invokeai/frontend/web/src/services/api/models/TextToImageInvocation.ts
index 184e35693b..7128ea8440 100644
--- a/invokeai/frontend/web/src/services/api/models/TextToImageInvocation.ts
+++ b/invokeai/frontend/web/src/services/api/models/TextToImageInvocation.ts
@@ -2,6 +2,8 @@
/* tslint:disable */
/* eslint-disable */
+import type { ImageField } from './ImageField';
+
/**
* Generates an image using text2img.
*/
@@ -10,6 +12,10 @@ export type TextToImageInvocation = {
* The id of this node. Must be unique among all nodes.
*/
id: string;
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
type?: 'txt2img';
/**
* The prompt to generate an image from
@@ -43,5 +49,17 @@ export type TextToImageInvocation = {
* The model to use (currently ignored)
*/
model?: string;
+ /**
+ * Whether or not to produce progress images during generation
+ */
+ progress_images?: boolean;
+ /**
+ * The control model to use
+ */
+ control_model?: string;
+ /**
+ * The processed control image
+ */
+ control_image?: ImageField;
};
diff --git a/invokeai/frontend/web/src/services/api/models/TextToLatentsInvocation.ts b/invokeai/frontend/web/src/services/api/models/TextToLatentsInvocation.ts
index d1ec5ed08c..f1831b2b59 100644
--- a/invokeai/frontend/web/src/services/api/models/TextToLatentsInvocation.ts
+++ b/invokeai/frontend/web/src/services/api/models/TextToLatentsInvocation.ts
@@ -3,6 +3,7 @@
/* eslint-disable */
import type { ConditioningField } from './ConditioningField';
+import type { ControlField } from './ControlField';
import type { LatentsField } from './LatentsField';
/**
@@ -13,6 +14,10 @@ export type TextToLatentsInvocation = {
* The id of this node. Must be unique among all nodes.
*/
id: string;
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
type?: 't2l';
/**
* Positive conditioning for generation
@@ -43,12 +48,8 @@ export type TextToLatentsInvocation = {
*/
model?: string;
/**
- * Whether or not to generate an image that can tile without seams
+ * The control to use
*/
- seamless?: boolean;
- /**
- * The axes to tile the image on, 'x' and/or 'y'
- */
- seamless_axes?: string;
+ control?: (ControlField | Array);
};
diff --git a/invokeai/frontend/web/src/services/api/models/UpscaleInvocation.ts b/invokeai/frontend/web/src/services/api/models/UpscaleInvocation.ts
index 8416c2454d..d0aca63964 100644
--- a/invokeai/frontend/web/src/services/api/models/UpscaleInvocation.ts
+++ b/invokeai/frontend/web/src/services/api/models/UpscaleInvocation.ts
@@ -12,6 +12,10 @@ export type UpscaleInvocation = {
* The id of this node. Must be unique among all nodes.
*/
id: string;
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
type?: 'upscale';
/**
* The input image
diff --git a/invokeai/frontend/web/src/services/api/models/ZoeDepthImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/models/ZoeDepthImageProcessorInvocation.ts
new file mode 100644
index 0000000000..55d05f3167
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/models/ZoeDepthImageProcessorInvocation.ts
@@ -0,0 +1,25 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+
+import type { ImageField } from './ImageField';
+
+/**
+ * Applies Zoe depth processing to image
+ */
+export type ZoeDepthImageProcessorInvocation = {
+ /**
+ * The id of this node. Must be unique among all nodes.
+ */
+ id: string;
+ /**
+ * Whether or not this node is an intermediate node.
+ */
+ is_intermediate?: boolean;
+ type?: 'zoe_depth_image_processor';
+ /**
+ * image to process
+ */
+ image?: ImageField;
+};
+
diff --git a/invokeai/frontend/web/src/services/api/schemas/$CannyImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/schemas/$CannyImageProcessorInvocation.ts
new file mode 100644
index 0000000000..e2f1bc2111
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/schemas/$CannyImageProcessorInvocation.ts
@@ -0,0 +1,31 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+export const $CannyImageProcessorInvocation = {
+ description: `Canny edge detection for ControlNet`,
+ properties: {
+ id: {
+ type: 'string',
+ description: `The id of this node. Must be unique among all nodes.`,
+ isRequired: true,
+ },
+ type: {
+ type: 'Enum',
+ },
+ image: {
+ type: 'all-of',
+ description: `image to process`,
+ contains: [{
+ type: 'ImageField',
+ }],
+ },
+ low_threshold: {
+ type: 'number',
+ description: `low threshold of Canny pixel gradient`,
+ },
+ high_threshold: {
+ type: 'number',
+ description: `high threshold of Canny pixel gradient`,
+ },
+ },
+} as const;
diff --git a/invokeai/frontend/web/src/services/api/schemas/$ContentShuffleImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/schemas/$ContentShuffleImageProcessorInvocation.ts
new file mode 100644
index 0000000000..9c51fdecc0
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/schemas/$ContentShuffleImageProcessorInvocation.ts
@@ -0,0 +1,43 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+export const $ContentShuffleImageProcessorInvocation = {
+ description: `Applies content shuffle processing to image`,
+ properties: {
+ id: {
+ type: 'string',
+ description: `The id of this node. Must be unique among all nodes.`,
+ isRequired: true,
+ },
+ type: {
+ type: 'Enum',
+ },
+ image: {
+ type: 'all-of',
+ description: `image to process`,
+ contains: [{
+ type: 'ImageField',
+ }],
+ },
+ detect_resolution: {
+ type: 'number',
+ description: `pixel resolution for edge detection`,
+ },
+ image_resolution: {
+ type: 'number',
+ description: `pixel resolution for output image`,
+ },
+ 'h': {
+ type: 'number',
+ description: `content shuffle h parameter`,
+ },
+ 'w': {
+ type: 'number',
+ description: `content shuffle w parameter`,
+ },
+ 'f': {
+ type: 'number',
+ description: `cont`,
+ },
+ },
+} as const;
diff --git a/invokeai/frontend/web/src/services/api/schemas/$ControlField.ts b/invokeai/frontend/web/src/services/api/schemas/$ControlField.ts
new file mode 100644
index 0000000000..81292b8638
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/schemas/$ControlField.ts
@@ -0,0 +1,37 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+export const $ControlField = {
+ properties: {
+ image: {
+ type: 'all-of',
+ description: `processed image`,
+ contains: [{
+ type: 'ImageField',
+ }],
+ isRequired: true,
+ },
+ control_model: {
+ type: 'string',
+ description: `control model used`,
+ isRequired: true,
+ },
+ control_weight: {
+ type: 'number',
+ description: `weight given to controlnet`,
+ isRequired: true,
+ },
+ begin_step_percent: {
+ type: 'number',
+ description: `% of total steps at which controlnet is first applied`,
+ isRequired: true,
+ maximum: 1,
+ },
+ end_step_percent: {
+ type: 'number',
+ description: `% of total steps at which controlnet is last applied`,
+ isRequired: true,
+ maximum: 1,
+ },
+ },
+} as const;
diff --git a/invokeai/frontend/web/src/services/api/schemas/$ControlNetInvocation.ts b/invokeai/frontend/web/src/services/api/schemas/$ControlNetInvocation.ts
new file mode 100644
index 0000000000..29ff507e66
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/schemas/$ControlNetInvocation.ts
@@ -0,0 +1,41 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+export const $ControlNetInvocation = {
+ description: `Collects ControlNet info to pass to other nodes`,
+ properties: {
+ id: {
+ type: 'string',
+ description: `The id of this node. Must be unique among all nodes.`,
+ isRequired: true,
+ },
+ type: {
+ type: 'Enum',
+ },
+ image: {
+ type: 'all-of',
+ description: `image to process`,
+ contains: [{
+ type: 'ImageField',
+ }],
+ },
+ control_model: {
+ type: 'Enum',
+ },
+ control_weight: {
+ type: 'number',
+ description: `weight given to controlnet`,
+ maximum: 1,
+ },
+ begin_step_percent: {
+ type: 'number',
+ description: `% of total steps at which controlnet is first applied`,
+ maximum: 1,
+ },
+ end_step_percent: {
+ type: 'number',
+ description: `% of total steps at which controlnet is last applied`,
+ maximum: 1,
+ },
+ },
+} as const;
diff --git a/invokeai/frontend/web/src/services/api/schemas/$ControlOutput.ts b/invokeai/frontend/web/src/services/api/schemas/$ControlOutput.ts
new file mode 100644
index 0000000000..d94d633fca
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/schemas/$ControlOutput.ts
@@ -0,0 +1,28 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+export const $ControlOutput = {
+ description: `node output for ControlNet info`,
+ properties: {
+ type: {
+ type: 'Enum',
+ },
+ control: {
+ type: 'all-of',
+ description: `The control info dict`,
+ contains: [{
+ type: 'ControlField',
+ }],
+ },
+ width: {
+ type: 'number',
+ description: `The width of the noise in pixels`,
+ isRequired: true,
+ },
+ height: {
+ type: 'number',
+ description: `The height of the noise in pixels`,
+ isRequired: true,
+ },
+ },
+} as const;
diff --git a/invokeai/frontend/web/src/services/api/schemas/$HedImageprocessorInvocation.ts b/invokeai/frontend/web/src/services/api/schemas/$HedImageprocessorInvocation.ts
new file mode 100644
index 0000000000..3cffa008f5
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/schemas/$HedImageprocessorInvocation.ts
@@ -0,0 +1,35 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+export const $HedImageprocessorInvocation = {
+ description: `Applies HED edge detection to image`,
+ properties: {
+ id: {
+ type: 'string',
+ description: `The id of this node. Must be unique among all nodes.`,
+ isRequired: true,
+ },
+ type: {
+ type: 'Enum',
+ },
+ image: {
+ type: 'all-of',
+ description: `image to process`,
+ contains: [{
+ type: 'ImageField',
+ }],
+ },
+ detect_resolution: {
+ type: 'number',
+ description: `pixel resolution for edge detection`,
+ },
+ image_resolution: {
+ type: 'number',
+ description: `pixel resolution for output image`,
+ },
+ scribble: {
+ type: 'boolean',
+ description: `whether to use scribble mode`,
+ },
+ },
+} as const;
diff --git a/invokeai/frontend/web/src/services/api/schemas/$ImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/schemas/$ImageProcessorInvocation.ts
new file mode 100644
index 0000000000..36748982c5
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/schemas/$ImageProcessorInvocation.ts
@@ -0,0 +1,23 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+export const $ImageProcessorInvocation = {
+ description: `Base class for invocations that preprocess images for ControlNet`,
+ properties: {
+ id: {
+ type: 'string',
+ description: `The id of this node. Must be unique among all nodes.`,
+ isRequired: true,
+ },
+ type: {
+ type: 'Enum',
+ },
+ image: {
+ type: 'all-of',
+ description: `image to process`,
+ contains: [{
+ type: 'ImageField',
+ }],
+ },
+ },
+} as const;
diff --git a/invokeai/frontend/web/src/services/api/schemas/$LineartAnimeImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/schemas/$LineartAnimeImageProcessorInvocation.ts
new file mode 100644
index 0000000000..63a9c8158c
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/schemas/$LineartAnimeImageProcessorInvocation.ts
@@ -0,0 +1,31 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+export const $LineartAnimeImageProcessorInvocation = {
+ description: `Applies line art anime processing to image`,
+ properties: {
+ id: {
+ type: 'string',
+ description: `The id of this node. Must be unique among all nodes.`,
+ isRequired: true,
+ },
+ type: {
+ type: 'Enum',
+ },
+ image: {
+ type: 'all-of',
+ description: `image to process`,
+ contains: [{
+ type: 'ImageField',
+ }],
+ },
+ detect_resolution: {
+ type: 'number',
+ description: `pixel resolution for edge detection`,
+ },
+ image_resolution: {
+ type: 'number',
+ description: `pixel resolution for output image`,
+ },
+ },
+} as const;
diff --git a/invokeai/frontend/web/src/services/api/schemas/$LineartImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/schemas/$LineartImageProcessorInvocation.ts
new file mode 100644
index 0000000000..6ba4064823
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/schemas/$LineartImageProcessorInvocation.ts
@@ -0,0 +1,35 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+export const $LineartImageProcessorInvocation = {
+ description: `Applies line art processing to image`,
+ properties: {
+ id: {
+ type: 'string',
+ description: `The id of this node. Must be unique among all nodes.`,
+ isRequired: true,
+ },
+ type: {
+ type: 'Enum',
+ },
+ image: {
+ type: 'all-of',
+ description: `image to process`,
+ contains: [{
+ type: 'ImageField',
+ }],
+ },
+ detect_resolution: {
+ type: 'number',
+ description: `pixel resolution for edge detection`,
+ },
+ image_resolution: {
+ type: 'number',
+ description: `pixel resolution for output image`,
+ },
+ coarse: {
+ type: 'boolean',
+ description: `whether to use coarse mode`,
+ },
+ },
+} as const;
diff --git a/invokeai/frontend/web/src/services/api/schemas/$MidasDepthImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/schemas/$MidasDepthImageProcessorInvocation.ts
new file mode 100644
index 0000000000..ea0b2b0099
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/schemas/$MidasDepthImageProcessorInvocation.ts
@@ -0,0 +1,31 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+export const $MidasDepthImageProcessorInvocation = {
+ description: `Applies Midas depth processing to image`,
+ properties: {
+ id: {
+ type: 'string',
+ description: `The id of this node. Must be unique among all nodes.`,
+ isRequired: true,
+ },
+ type: {
+ type: 'Enum',
+ },
+ image: {
+ type: 'all-of',
+ description: `image to process`,
+ contains: [{
+ type: 'ImageField',
+ }],
+ },
+ a_mult: {
+ type: 'number',
+ description: `Midas parameter a = amult * PI`,
+ },
+ bg_th: {
+ type: 'number',
+ description: `Midas parameter bg_th`,
+ },
+ },
+} as const;
diff --git a/invokeai/frontend/web/src/services/api/schemas/$MlsdImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/schemas/$MlsdImageProcessorInvocation.ts
new file mode 100644
index 0000000000..1bff7579cc
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/schemas/$MlsdImageProcessorInvocation.ts
@@ -0,0 +1,39 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+export const $MlsdImageProcessorInvocation = {
+ description: `Applies MLSD processing to image`,
+ properties: {
+ id: {
+ type: 'string',
+ description: `The id of this node. Must be unique among all nodes.`,
+ isRequired: true,
+ },
+ type: {
+ type: 'Enum',
+ },
+ image: {
+ type: 'all-of',
+ description: `image to process`,
+ contains: [{
+ type: 'ImageField',
+ }],
+ },
+ detect_resolution: {
+ type: 'number',
+ description: `pixel resolution for edge detection`,
+ },
+ image_resolution: {
+ type: 'number',
+ description: `pixel resolution for output image`,
+ },
+ thr_v: {
+ type: 'number',
+ description: `MLSD parameter thr_v`,
+ },
+ thr_d: {
+ type: 'number',
+ description: `MLSD parameter thr_d`,
+ },
+ },
+} as const;
diff --git a/invokeai/frontend/web/src/services/api/schemas/$NormalbaeImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/schemas/$NormalbaeImageProcessorInvocation.ts
new file mode 100644
index 0000000000..7cdfe6f3ae
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/schemas/$NormalbaeImageProcessorInvocation.ts
@@ -0,0 +1,31 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+export const $NormalbaeImageProcessorInvocation = {
+ description: `Applies NormalBae processing to image`,
+ properties: {
+ id: {
+ type: 'string',
+ description: `The id of this node. Must be unique among all nodes.`,
+ isRequired: true,
+ },
+ type: {
+ type: 'Enum',
+ },
+ image: {
+ type: 'all-of',
+ description: `image to process`,
+ contains: [{
+ type: 'ImageField',
+ }],
+ },
+ detect_resolution: {
+ type: 'number',
+ description: `pixel resolution for edge detection`,
+ },
+ image_resolution: {
+ type: 'number',
+ description: `pixel resolution for output image`,
+ },
+ },
+} as const;
diff --git a/invokeai/frontend/web/src/services/api/schemas/$OpenposeImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/schemas/$OpenposeImageProcessorInvocation.ts
new file mode 100644
index 0000000000..2a187e9cf2
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/schemas/$OpenposeImageProcessorInvocation.ts
@@ -0,0 +1,35 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+export const $OpenposeImageProcessorInvocation = {
+ description: `Applies Openpose processing to image`,
+ properties: {
+ id: {
+ type: 'string',
+ description: `The id of this node. Must be unique among all nodes.`,
+ isRequired: true,
+ },
+ type: {
+ type: 'Enum',
+ },
+ image: {
+ type: 'all-of',
+ description: `image to process`,
+ contains: [{
+ type: 'ImageField',
+ }],
+ },
+ hand_and_face: {
+ type: 'boolean',
+ description: `whether to use hands and face mode`,
+ },
+ detect_resolution: {
+ type: 'number',
+ description: `pixel resolution for edge detection`,
+ },
+ image_resolution: {
+ type: 'number',
+ description: `pixel resolution for output image`,
+ },
+ },
+} as const;
diff --git a/invokeai/frontend/web/src/services/api/schemas/$PidiImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/schemas/$PidiImageProcessorInvocation.ts
new file mode 100644
index 0000000000..0fd53967c2
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/schemas/$PidiImageProcessorInvocation.ts
@@ -0,0 +1,39 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+export const $PidiImageProcessorInvocation = {
+ description: `Applies PIDI processing to image`,
+ properties: {
+ id: {
+ type: 'string',
+ description: `The id of this node. Must be unique among all nodes.`,
+ isRequired: true,
+ },
+ type: {
+ type: 'Enum',
+ },
+ image: {
+ type: 'all-of',
+ description: `image to process`,
+ contains: [{
+ type: 'ImageField',
+ }],
+ },
+ detect_resolution: {
+ type: 'number',
+ description: `pixel resolution for edge detection`,
+ },
+ image_resolution: {
+ type: 'number',
+ description: `pixel resolution for output image`,
+ },
+ safe: {
+ type: 'boolean',
+ description: `whether to use safe mode`,
+ },
+ scribble: {
+ type: 'boolean',
+ description: `whether to use scribble mode`,
+ },
+ },
+} as const;
diff --git a/invokeai/frontend/web/src/services/api/schemas/$RandomIntInvocation.ts b/invokeai/frontend/web/src/services/api/schemas/$RandomIntInvocation.ts
new file mode 100644
index 0000000000..e5b0387d5a
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/schemas/$RandomIntInvocation.ts
@@ -0,0 +1,16 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+export const $RandomIntInvocation = {
+ description: `Outputs a single random integer.`,
+ properties: {
+ id: {
+ type: 'string',
+ description: `The id of this node. Must be unique among all nodes.`,
+ isRequired: true,
+ },
+ type: {
+ type: 'Enum',
+ },
+ },
+} as const;
diff --git a/invokeai/frontend/web/src/services/api/services/ImagesService.ts b/invokeai/frontend/web/src/services/api/services/ImagesService.ts
index 13b2ef836a..51fe6c820f 100644
--- a/invokeai/frontend/web/src/services/api/services/ImagesService.ts
+++ b/invokeai/frontend/web/src/services/api/services/ImagesService.ts
@@ -4,9 +4,10 @@
import type { Body_upload_image } from '../models/Body_upload_image';
import type { ImageCategory } from '../models/ImageCategory';
import type { ImageDTO } from '../models/ImageDTO';
-import type { ImageType } from '../models/ImageType';
+import type { ImageRecordChanges } from '../models/ImageRecordChanges';
import type { ImageUrlsDTO } from '../models/ImageUrlsDTO';
-import type { PaginatedResults_ImageDTO_ } from '../models/PaginatedResults_ImageDTO_';
+import type { OffsetPaginatedResults_ImageDTO_ } from '../models/OffsetPaginatedResults_ImageDTO_';
+import type { ResourceOrigin } from '../models/ResourceOrigin';
import type { CancelablePromise } from '../core/CancelablePromise';
import { OpenAPI } from '../core/OpenAPI';
@@ -16,41 +17,47 @@ export class ImagesService {
/**
* List Images With Metadata
- * Gets a list of images with metadata
- * @returns PaginatedResults_ImageDTO_ Successful Response
+ * Gets a list of images
+ * @returns OffsetPaginatedResults_ImageDTO_ Successful Response
* @throws ApiError
*/
public static listImagesWithMetadata({
- imageType,
- imageCategory,
- page,
- perPage = 10,
+ imageOrigin,
+ categories,
+ isIntermediate,
+ offset,
+ limit = 10,
}: {
/**
- * The type of images to list
+ * The origin of images to list
*/
- imageType: ImageType,
+ imageOrigin?: ResourceOrigin,
/**
- * The kind of images to list
+ * The categories of image to include
*/
- imageCategory: ImageCategory,
+ categories?: Array<ImageCategory>,
/**
- * The page of image metadata to get
+ * Whether to list intermediate images
*/
- page?: number,
+ isIntermediate?: boolean,
/**
- * The number of image metadata per page
+ * The page offset
*/
- perPage?: number,
- }): CancelablePromise<PaginatedResults_ImageDTO_> {
+ offset?: number,
+ /**
+ * The number of images per page
+ */
+ limit?: number,
+ }): CancelablePromise<OffsetPaginatedResults_ImageDTO_> {
return __request(OpenAPI, {
method: 'GET',
url: '/api/v1/images/',
query: {
- 'image_type': imageType,
- 'image_category': imageCategory,
- 'page': page,
- 'per_page': perPage,
+ 'image_origin': imageOrigin,
+ 'categories': categories,
+ 'is_intermediate': isIntermediate,
+ 'offset': offset,
+ 'limit': limit,
},
errors: {
422: `Validation Error`,
@@ -65,20 +72,32 @@ export class ImagesService {
* @throws ApiError
*/
public static uploadImage({
- imageType,
- formData,
imageCategory,
+ isIntermediate,
+ formData,
+ sessionId,
}: {
- imageType: ImageType,
+ /**
+ * The category of the image
+ */
+ imageCategory: ImageCategory,
+ /**
+ * Whether this is an intermediate image
+ */
+ isIntermediate: boolean,
formData: Body_upload_image,
- imageCategory?: ImageCategory,
+ /**
+ * The session ID associated with this upload, if any
+ */
+ sessionId?: string,
}): CancelablePromise {
return __request(OpenAPI, {
method: 'POST',
url: '/api/v1/images/',
query: {
- 'image_type': imageType,
'image_category': imageCategory,
+ 'is_intermediate': isIntermediate,
+ 'session_id': sessionId,
},
formData: formData,
mediaType: 'multipart/form-data',
@@ -96,13 +115,13 @@ export class ImagesService {
* @throws ApiError
*/
public static getImageFull({
- imageType,
+ imageOrigin,
imageName,
}: {
/**
* The type of full-resolution image file to get
*/
- imageType: ImageType,
+ imageOrigin: ResourceOrigin,
/**
* The name of full-resolution image file to get
*/
@@ -110,9 +129,9 @@ export class ImagesService {
}): CancelablePromise {
return __request(OpenAPI, {
method: 'GET',
- url: '/api/v1/images/{image_type}/{image_name}',
+ url: '/api/v1/images/{image_origin}/{image_name}',
path: {
- 'image_type': imageType,
+ 'image_origin': imageOrigin,
'image_name': imageName,
},
errors: {
@@ -129,10 +148,13 @@ export class ImagesService {
* @throws ApiError
*/
public static deleteImage({
- imageType,
+ imageOrigin,
imageName,
}: {
- imageType: ImageType,
+ /**
+ * The origin of image to delete
+ */
+ imageOrigin: ResourceOrigin,
/**
* The name of the image to delete
*/
@@ -140,9 +162,9 @@ export class ImagesService {
}): CancelablePromise {
return __request(OpenAPI, {
method: 'DELETE',
- url: '/api/v1/images/{image_type}/{image_name}',
+ url: '/api/v1/images/{image_origin}/{image_name}',
path: {
- 'image_type': imageType,
+ 'image_origin': imageOrigin,
'image_name': imageName,
},
errors: {
@@ -151,6 +173,42 @@ export class ImagesService {
});
}
+ /**
+ * Update Image
+ * Updates an image
+ * @returns ImageDTO Successful Response
+ * @throws ApiError
+ */
+ public static updateImage({
+ imageOrigin,
+ imageName,
+ requestBody,
+ }: {
+ /**
+ * The origin of image to update
+ */
+ imageOrigin: ResourceOrigin,
+ /**
+ * The name of the image to update
+ */
+ imageName: string,
+ requestBody: ImageRecordChanges,
+ }): CancelablePromise<ImageDTO> {
+ return __request(OpenAPI, {
+ method: 'PATCH',
+ url: '/api/v1/images/{image_origin}/{image_name}',
+ path: {
+ 'image_origin': imageOrigin,
+ 'image_name': imageName,
+ },
+ body: requestBody,
+ mediaType: 'application/json',
+ errors: {
+ 422: `Validation Error`,
+ },
+ });
+ }
+
/**
* Get Image Metadata
* Gets an image's metadata
@@ -158,13 +216,13 @@ export class ImagesService {
* @throws ApiError
*/
public static getImageMetadata({
- imageType,
+ imageOrigin,
imageName,
}: {
/**
- * The type of image to get
+ * The origin of image to get
*/
- imageType: ImageType,
+ imageOrigin: ResourceOrigin,
/**
* The name of image to get
*/
@@ -172,9 +230,9 @@ export class ImagesService {
}): CancelablePromise {
return __request(OpenAPI, {
method: 'GET',
- url: '/api/v1/images/{image_type}/{image_name}/metadata',
+ url: '/api/v1/images/{image_origin}/{image_name}/metadata',
path: {
- 'image_type': imageType,
+ 'image_origin': imageOrigin,
'image_name': imageName,
},
errors: {
@@ -190,13 +248,13 @@ export class ImagesService {
* @throws ApiError
*/
public static getImageThumbnail({
- imageType,
+ imageOrigin,
imageName,
}: {
/**
- * The type of thumbnail image file to get
+ * The origin of thumbnail image file to get
*/
- imageType: ImageType,
+ imageOrigin: ResourceOrigin,
/**
* The name of thumbnail image file to get
*/
@@ -204,9 +262,9 @@ export class ImagesService {
}): CancelablePromise {
return __request(OpenAPI, {
method: 'GET',
- url: '/api/v1/images/{image_type}/{image_name}/thumbnail',
+ url: '/api/v1/images/{image_origin}/{image_name}/thumbnail',
path: {
- 'image_type': imageType,
+ 'image_origin': imageOrigin,
'image_name': imageName,
},
errors: {
@@ -223,13 +281,13 @@ export class ImagesService {
* @throws ApiError
*/
public static getImageUrls({
- imageType,
+ imageOrigin,
imageName,
}: {
/**
- * The type of the image whose URL to get
+ * The origin of the image whose URL to get
*/
- imageType: ImageType,
+ imageOrigin: ResourceOrigin,
/**
* The name of the image whose URL to get
*/
@@ -237,9 +295,9 @@ export class ImagesService {
}): CancelablePromise<ImageUrlsDTO> {
return __request(OpenAPI, {
method: 'GET',
- url: '/api/v1/images/{image_type}/{image_name}/urls',
+ url: '/api/v1/images/{image_origin}/{image_name}/urls',
path: {
- 'image_type': imageType,
+ 'image_origin': imageOrigin,
'image_name': imageName,
},
errors: {
diff --git a/invokeai/frontend/web/src/services/api/services/SessionsService.ts b/invokeai/frontend/web/src/services/api/services/SessionsService.ts
index 23597c9e9e..6ae6783313 100644
--- a/invokeai/frontend/web/src/services/api/services/SessionsService.ts
+++ b/invokeai/frontend/web/src/services/api/services/SessionsService.ts
@@ -2,34 +2,53 @@
/* tslint:disable */
/* eslint-disable */
import type { AddInvocation } from '../models/AddInvocation';
-import type { BlurInvocation } from '../models/BlurInvocation';
+import type { CannyImageProcessorInvocation } from '../models/CannyImageProcessorInvocation';
import type { CollectInvocation } from '../models/CollectInvocation';
import type { CompelInvocation } from '../models/CompelInvocation';
-import type { CropImageInvocation } from '../models/CropImageInvocation';
+import type { ContentShuffleImageProcessorInvocation } from '../models/ContentShuffleImageProcessorInvocation';
+import type { ControlNetInvocation } from '../models/ControlNetInvocation';
import type { CvInpaintInvocation } from '../models/CvInpaintInvocation';
import type { DivideInvocation } from '../models/DivideInvocation';
import type { Edge } from '../models/Edge';
import type { Graph } from '../models/Graph';
import type { GraphExecutionState } from '../models/GraphExecutionState';
import type { GraphInvocation } from '../models/GraphInvocation';
+import type { HedImageprocessorInvocation } from '../models/HedImageprocessorInvocation';
+import type { ImageBlurInvocation } from '../models/ImageBlurInvocation';
+import type { ImageChannelInvocation } from '../models/ImageChannelInvocation';
+import type { ImageConvertInvocation } from '../models/ImageConvertInvocation';
+import type { ImageCropInvocation } from '../models/ImageCropInvocation';
+import type { ImageInverseLerpInvocation } from '../models/ImageInverseLerpInvocation';
+import type { ImageLerpInvocation } from '../models/ImageLerpInvocation';
+import type { ImageMultiplyInvocation } from '../models/ImageMultiplyInvocation';
+import type { ImagePasteInvocation } from '../models/ImagePasteInvocation';
+import type { ImageProcessorInvocation } from '../models/ImageProcessorInvocation';
+import type { ImageResizeInvocation } from '../models/ImageResizeInvocation';
+import type { ImageScaleInvocation } from '../models/ImageScaleInvocation';
import type { ImageToImageInvocation } from '../models/ImageToImageInvocation';
import type { ImageToLatentsInvocation } from '../models/ImageToLatentsInvocation';
import type { InfillColorInvocation } from '../models/InfillColorInvocation';
import type { InfillPatchMatchInvocation } from '../models/InfillPatchMatchInvocation';
import type { InfillTileInvocation } from '../models/InfillTileInvocation';
import type { InpaintInvocation } from '../models/InpaintInvocation';
-import type { InverseLerpInvocation } from '../models/InverseLerpInvocation';
import type { IterateInvocation } from '../models/IterateInvocation';
import type { LatentsToImageInvocation } from '../models/LatentsToImageInvocation';
import type { LatentsToLatentsInvocation } from '../models/LatentsToLatentsInvocation';
-import type { LerpInvocation } from '../models/LerpInvocation';
+import type { LineartAnimeImageProcessorInvocation } from '../models/LineartAnimeImageProcessorInvocation';
+import type { LineartImageProcessorInvocation } from '../models/LineartImageProcessorInvocation';
import type { LoadImageInvocation } from '../models/LoadImageInvocation';
import type { MaskFromAlphaInvocation } from '../models/MaskFromAlphaInvocation';
+import type { MediapipeFaceProcessorInvocation } from '../models/MediapipeFaceProcessorInvocation';
+import type { MidasDepthImageProcessorInvocation } from '../models/MidasDepthImageProcessorInvocation';
+import type { MlsdImageProcessorInvocation } from '../models/MlsdImageProcessorInvocation';
import type { MultiplyInvocation } from '../models/MultiplyInvocation';
import type { NoiseInvocation } from '../models/NoiseInvocation';
+import type { NormalbaeImageProcessorInvocation } from '../models/NormalbaeImageProcessorInvocation';
+import type { OpenposeImageProcessorInvocation } from '../models/OpenposeImageProcessorInvocation';
import type { PaginatedResults_GraphExecutionState_ } from '../models/PaginatedResults_GraphExecutionState_';
+import type { ParamFloatInvocation } from '../models/ParamFloatInvocation';
import type { ParamIntInvocation } from '../models/ParamIntInvocation';
-import type { PasteImageInvocation } from '../models/PasteImageInvocation';
+import type { PidiImageProcessorInvocation } from '../models/PidiImageProcessorInvocation';
import type { RandomIntInvocation } from '../models/RandomIntInvocation';
import type { RandomRangeInvocation } from '../models/RandomRangeInvocation';
import type { RangeInvocation } from '../models/RangeInvocation';
@@ -42,6 +61,7 @@ import type { SubtractInvocation } from '../models/SubtractInvocation';
import type { TextToImageInvocation } from '../models/TextToImageInvocation';
import type { TextToLatentsInvocation } from '../models/TextToLatentsInvocation';
import type { UpscaleInvocation } from '../models/UpscaleInvocation';
+import type { ZoeDepthImageProcessorInvocation } from '../models/ZoeDepthImageProcessorInvocation';
import type { CancelablePromise } from '../core/CancelablePromise';
import { OpenAPI } from '../core/OpenAPI';
@@ -151,7 +171,7 @@ export class SessionsService {
* The id of the session
*/
sessionId: string,
- requestBody: (LoadImageInvocation | ShowImageInvocation | CropImageInvocation | PasteImageInvocation | MaskFromAlphaInvocation | BlurInvocation | LerpInvocation | InverseLerpInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation),
+ requestBody: (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageprocessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation),
}): CancelablePromise {
return __request(OpenAPI, {
method: 'POST',
@@ -188,7 +208,7 @@ export class SessionsService {
* The path to the node in the graph
*/
nodePath: string,
- requestBody: (LoadImageInvocation | ShowImageInvocation | CropImageInvocation | PasteImageInvocation | MaskFromAlphaInvocation | BlurInvocation | LerpInvocation | InverseLerpInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation),
+ requestBody: (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageprocessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation),
}): CancelablePromise {
return __request(OpenAPI, {
method: 'PUT',
diff --git a/invokeai/frontend/web/src/services/events/actions.ts b/invokeai/frontend/web/src/services/events/actions.ts
index 76bffeaa49..5832cb24b1 100644
--- a/invokeai/frontend/web/src/services/events/actions.ts
+++ b/invokeai/frontend/web/src/services/events/actions.ts
@@ -12,46 +12,153 @@ type BaseSocketPayload = {
timestamp: string;
};
-// Create actions for each socket event
+// Create actions for each socket
// Middleware and redux can then respond to them as needed
+/**
+ * Socket.IO Connected
+ *
+ * Do not use. Only for use in middleware.
+ */
export const socketConnected = createAction<BaseSocketPayload>(
'socket/socketConnected'
);
+/**
+ * App-level Socket.IO Connected
+ */
+export const appSocketConnected = createAction<BaseSocketPayload>(
+ 'socket/appSocketConnected'
+);
+
+/**
+ * Socket.IO Disconnect
+ *
+ * Do not use. Only for use in middleware.
+ */
export const socketDisconnected = createAction<BaseSocketPayload>(
'socket/socketDisconnected'
);
+/**
+ * App-level Socket.IO Disconnected
+ */
+export const appSocketDisconnected = createAction<BaseSocketPayload>(
+ 'socket/appSocketDisconnected'
+);
+
+/**
+ * Socket.IO Subscribed
+ *
+ * Do not use. Only for use in middleware.
+ */
export const socketSubscribed = createAction<
BaseSocketPayload & { sessionId: string }
>('socket/socketSubscribed');
+/**
+ * App-level Socket.IO Subscribed
+ */
+export const appSocketSubscribed = createAction<
+ BaseSocketPayload & { sessionId: string }
+>('socket/appSocketSubscribed');
+
+/**
+ * Socket.IO Unsubscribed
+ *
+ * Do not use. Only for use in middleware.
+ */
export const socketUnsubscribed = createAction<
BaseSocketPayload & { sessionId: string }
>('socket/socketUnsubscribed');
-export const invocationStarted = createAction<
- BaseSocketPayload & { data: InvocationStartedEvent }
->('socket/invocationStarted');
+/**
+ * App-level Socket.IO Unsubscribed
+ */
+export const appSocketUnsubscribed = createAction<
+ BaseSocketPayload & { sessionId: string }
+>('socket/appSocketUnsubscribed');
-export const invocationComplete = createAction<
+/**
+ * Socket.IO Invocation Started
+ *
+ * Do not use. Only for use in middleware.
+ */
+export const socketInvocationStarted = createAction<
+ BaseSocketPayload & { data: InvocationStartedEvent }
+>('socket/socketInvocationStarted');
+
+/**
+ * App-level Socket.IO Invocation Started
+ */
+export const appSocketInvocationStarted = createAction<
+ BaseSocketPayload & { data: InvocationStartedEvent }
+>('socket/appSocketInvocationStarted');
+
+/**
+ * Socket.IO Invocation Complete
+ *
+ * Do not use. Only for use in middleware.
+ */
+export const socketInvocationComplete = createAction<
BaseSocketPayload & {
data: InvocationCompleteEvent;
}
->('socket/invocationComplete');
+>('socket/socketInvocationComplete');
-export const invocationError = createAction<
+/**
+ * App-level Socket.IO Invocation Complete
+ */
+export const appSocketInvocationComplete = createAction<
+ BaseSocketPayload & {
+ data: InvocationCompleteEvent;
+ }
+>('socket/appSocketInvocationComplete');
+
+/**
+ * Socket.IO Invocation Error
+ *
+ * Do not use. Only for use in middleware.
+ */
+export const socketInvocationError = createAction<
BaseSocketPayload & { data: InvocationErrorEvent }
->('socket/invocationError');
+>('socket/socketInvocationError');
-export const graphExecutionStateComplete = createAction<
+/**
+ * App-level Socket.IO Invocation Error
+ */
+export const appSocketInvocationError = createAction<
+ BaseSocketPayload & { data: InvocationErrorEvent }
+>('socket/appSocketInvocationError');
+
+/**
+ * Socket.IO Graph Execution State Complete
+ *
+ * Do not use. Only for use in middleware.
+ */
+export const socketGraphExecutionStateComplete = createAction<
BaseSocketPayload & { data: GraphExecutionStateCompleteEvent }
->('socket/graphExecutionStateComplete');
+>('socket/socketGraphExecutionStateComplete');
-export const generatorProgress = createAction<
+/**
+ * App-level Socket.IO Graph Execution State Complete
+ */
+export const appSocketGraphExecutionStateComplete = createAction<
+ BaseSocketPayload & { data: GraphExecutionStateCompleteEvent }
+>('socket/appSocketGraphExecutionStateComplete');
+
+/**
+ * Socket.IO Generator Progress
+ *
+ * Do not use. Only for use in middleware.
+ */
+export const socketGeneratorProgress = createAction<
BaseSocketPayload & { data: GeneratorProgressEvent }
->('socket/generatorProgress');
+>('socket/socketGeneratorProgress');
-// dispatch this when we need to fully reset the socket connection
-export const socketReset = createAction('socket/socketReset');
+/**
+ * App-level Socket.IO Generator Progress
+ */
+export const appSocketGeneratorProgress = createAction<
+ BaseSocketPayload & { data: GeneratorProgressEvent }
+>('socket/appSocketGeneratorProgress');
diff --git a/invokeai/frontend/web/src/services/events/middleware.ts b/invokeai/frontend/web/src/services/events/middleware.ts
index bd1d60099a..f1eb844f2c 100644
--- a/invokeai/frontend/web/src/services/events/middleware.ts
+++ b/invokeai/frontend/web/src/services/events/middleware.ts
@@ -8,7 +8,7 @@ import {
import { socketSubscribed, socketUnsubscribed } from './actions';
import { AppThunkDispatch, RootState } from 'app/store/store';
import { getTimestamp } from 'common/util/getTimestamp';
-import { sessionInvoked, sessionCreated } from 'services/thunks/session';
+import { sessionCreated } from 'services/thunks/session';
import { OpenAPI } from 'services/api';
import { setEventListeners } from 'services/events/util/setEventListeners';
import { log } from 'app/logging/useLogger';
@@ -64,15 +64,9 @@ export const socketMiddleware = () => {
if (sessionCreated.fulfilled.match(action)) {
const sessionId = action.payload.id;
- const sessionLog = socketioLog.child({ sessionId });
const oldSessionId = getState().system.sessionId;
if (oldSessionId) {
- sessionLog.debug(
- { oldSessionId },
- `Unsubscribed from old session (${oldSessionId})`
- );
-
socket.emit('unsubscribe', {
session: oldSessionId,
});
@@ -85,8 +79,6 @@ export const socketMiddleware = () => {
);
}
- sessionLog.debug(`Subscribe to new session (${sessionId})`);
-
socket.emit('subscribe', { session: sessionId });
dispatch(
@@ -95,9 +87,6 @@ export const socketMiddleware = () => {
timestamp: getTimestamp(),
})
);
-
- // Finally we actually invoke the session, starting processing
- dispatch(sessionInvoked({ sessionId }));
}
next(action);
diff --git a/invokeai/frontend/web/src/services/events/util/setEventListeners.ts b/invokeai/frontend/web/src/services/events/util/setEventListeners.ts
index 4431a9fd8b..2c4cba510a 100644
--- a/invokeai/frontend/web/src/services/events/util/setEventListeners.ts
+++ b/invokeai/frontend/web/src/services/events/util/setEventListeners.ts
@@ -1,14 +1,13 @@
import { MiddlewareAPI } from '@reduxjs/toolkit';
import { AppDispatch, RootState } from 'app/store/store';
import { getTimestamp } from 'common/util/getTimestamp';
-import { sessionCanceled } from 'services/thunks/session';
import { Socket } from 'socket.io-client';
import {
- generatorProgress,
- graphExecutionStateComplete,
- invocationComplete,
- invocationError,
- invocationStarted,
+ socketGeneratorProgress,
+ socketGraphExecutionStateComplete,
+ socketInvocationComplete,
+ socketInvocationError,
+ socketInvocationStarted,
socketConnected,
socketDisconnected,
socketSubscribed,
@@ -16,12 +15,6 @@ import {
import { ClientToServerEvents, ServerToClientEvents } from '../types';
import { Logger } from 'roarr';
import { JsonObject } from 'roarr/dist/types';
-import {
- receivedResultImagesPage,
- receivedUploadImagesPage,
-} from 'services/thunks/gallery';
-import { receivedModels } from 'services/thunks/model';
-import { receivedOpenAPISchema } from 'services/thunks/schema';
import { makeToast } from '../../../app/components/Toaster';
import { addToast } from '../../../features/system/store/systemSlice';
@@ -43,37 +36,13 @@ export const setEventListeners = (arg: SetEventListenersArg) => {
dispatch(socketConnected({ timestamp: getTimestamp() }));
- const { results, uploads, models, nodes, config, system } = getState();
+ const { sessionId } = getState().system;
- const { disabledTabs } = config;
-
- // These thunks need to be dispatch in middleware; cannot handle in a reducer
- if (!results.ids.length) {
- dispatch(receivedResultImagesPage());
- }
-
- if (!uploads.ids.length) {
- dispatch(receivedUploadImagesPage());
- }
-
- if (!models.ids.length) {
- dispatch(receivedModels());
- }
-
- if (!nodes.schema && !disabledTabs.includes('nodes')) {
- dispatch(receivedOpenAPISchema());
- }
-
- if (system.sessionId) {
- log.debug(
- { sessionId: system.sessionId },
- `Subscribed to existing session (${system.sessionId})`
- );
-
- socket.emit('subscribe', { session: system.sessionId });
+ if (sessionId) {
+ socket.emit('subscribe', { session: sessionId });
dispatch(
socketSubscribed({
- sessionId: system.sessionId,
+ sessionId,
timestamp: getTimestamp(),
})
);
@@ -101,7 +70,6 @@ export const setEventListeners = (arg: SetEventListenersArg) => {
* Disconnect
*/
socket.on('disconnect', () => {
- log.debug('Disconnected');
dispatch(socketDisconnected({ timestamp: getTimestamp() }));
});
@@ -109,70 +77,29 @@ export const setEventListeners = (arg: SetEventListenersArg) => {
* Invocation started
*/
socket.on('invocation_started', (data) => {
- if (getState().system.canceledSession === data.graph_execution_state_id) {
- log.trace(
- { data, sessionId: data.graph_execution_state_id },
- `Ignored invocation started (${data.node.type}) for canceled session (${data.graph_execution_state_id})`
- );
- return;
- }
-
- log.info(
- { data, sessionId: data.graph_execution_state_id },
- `Invocation started (${data.node.type})`
- );
- dispatch(invocationStarted({ data, timestamp: getTimestamp() }));
+ dispatch(socketInvocationStarted({ data, timestamp: getTimestamp() }));
});
/**
* Generator progress
*/
socket.on('generator_progress', (data) => {
- if (getState().system.canceledSession === data.graph_execution_state_id) {
- log.trace(
- { data, sessionId: data.graph_execution_state_id },
- `Ignored generator progress (${data.node.type}) for canceled session (${data.graph_execution_state_id})`
- );
- return;
- }
-
- log.trace(
- { data, sessionId: data.graph_execution_state_id },
- `Generator progress (${data.node.type})`
- );
- dispatch(generatorProgress({ data, timestamp: getTimestamp() }));
+ dispatch(socketGeneratorProgress({ data, timestamp: getTimestamp() }));
});
/**
* Invocation error
*/
socket.on('invocation_error', (data) => {
- log.error(
- { data, sessionId: data.graph_execution_state_id },
- `Invocation error (${data.node.type})`
- );
- dispatch(invocationError({ data, timestamp: getTimestamp() }));
+ dispatch(socketInvocationError({ data, timestamp: getTimestamp() }));
});
/**
* Invocation complete
*/
socket.on('invocation_complete', (data) => {
- log.info(
- { data, sessionId: data.graph_execution_state_id },
- `Invocation complete (${data.node.type})`
- );
- const sessionId = data.graph_execution_state_id;
-
- const { cancelType, isCancelScheduled } = getState().system;
-
- // Handle scheduled cancelation
- if (cancelType === 'scheduled' && isCancelScheduled) {
- dispatch(sessionCanceled({ sessionId }));
- }
-
dispatch(
- invocationComplete({
+ socketInvocationComplete({
data,
timestamp: getTimestamp(),
})
@@ -183,10 +110,11 @@ export const setEventListeners = (arg: SetEventListenersArg) => {
* Graph complete
*/
socket.on('graph_execution_state_complete', (data) => {
- log.info(
- { data, sessionId: data.graph_execution_state_id },
- `Graph execution state complete (${data.graph_execution_state_id})`
+ dispatch(
+ socketGraphExecutionStateComplete({
+ data,
+ timestamp: getTimestamp(),
+ })
);
- dispatch(graphExecutionStateComplete({ data, timestamp: getTimestamp() }));
});
};
diff --git a/invokeai/frontend/web/src/services/thunks/gallery.ts b/invokeai/frontend/web/src/services/thunks/gallery.ts
deleted file mode 100644
index 01e8a986b2..0000000000
--- a/invokeai/frontend/web/src/services/thunks/gallery.ts
+++ /dev/null
@@ -1,45 +0,0 @@
-import { log } from 'app/logging/useLogger';
-import { createAppAsyncThunk } from 'app/store/storeUtils';
-import { ImagesService } from 'services/api';
-
-export const IMAGES_PER_PAGE = 20;
-
-const galleryLog = log.child({ namespace: 'gallery' });
-
-export const receivedResultImagesPage = createAppAsyncThunk(
- 'results/receivedResultImagesPage',
- async (_arg, { getState, rejectWithValue }) => {
- const { page, pages, nextPage } = getState().results;
-
- if (nextPage === page) {
- rejectWithValue([]);
- }
-
- const response = await ImagesService.listImagesWithMetadata({
- imageType: 'results',
- imageCategory: 'general',
- page: getState().results.nextPage,
- perPage: IMAGES_PER_PAGE,
- });
-
- galleryLog.info({ response }, `Received ${response.items.length} results`);
-
- return response;
- }
-);
-
-export const receivedUploadImagesPage = createAppAsyncThunk(
- 'uploads/receivedUploadImagesPage',
- async (_arg, { getState }) => {
- const response = await ImagesService.listImagesWithMetadata({
- imageType: 'uploads',
- imageCategory: 'general',
- page: getState().uploads.nextPage,
- perPage: IMAGES_PER_PAGE,
- });
-
- galleryLog.info({ response }, `Received ${response.items.length} uploads`);
-
- return response;
- }
-);
diff --git a/invokeai/frontend/web/src/services/thunks/image.ts b/invokeai/frontend/web/src/services/thunks/image.ts
index 6831eb647d..87832c6b1e 100644
--- a/invokeai/frontend/web/src/services/thunks/image.ts
+++ b/invokeai/frontend/web/src/services/thunks/image.ts
@@ -1,10 +1,6 @@
-import { log } from 'app/logging/useLogger';
import { createAppAsyncThunk } from 'app/store/storeUtils';
-import { InvokeTabName } from 'features/ui/store/tabMap';
+import { selectImagesAll } from 'features/gallery/store/imagesSlice';
import { ImagesService } from 'services/api';
-import { getHeaders } from 'services/util/getHeaders';
-
-const imagesLog = log.child({ namespace: 'image' });
type imageUrlsReceivedArg = Parameters<
(typeof ImagesService)['getImageUrls']
@@ -17,7 +13,6 @@ export const imageUrlsReceived = createAppAsyncThunk(
'api/imageUrlsReceived',
async (arg: imageUrlsReceivedArg) => {
const response = await ImagesService.getImageUrls(arg);
- imagesLog.info({ arg, response }, 'Received image urls');
return response;
}
);
@@ -33,16 +28,11 @@ export const imageMetadataReceived = createAppAsyncThunk(
'api/imageMetadataReceived',
async (arg: imageMetadataReceivedArg) => {
const response = await ImagesService.getImageMetadata(arg);
- imagesLog.info({ arg, response }, 'Received image record');
return response;
}
);
-type ImageUploadedArg = Parameters<(typeof ImagesService)['uploadImage']>[0] & {
- // extra arg to determine post-upload actions - we check for this when the image is uploaded
- // to determine if we should set the init image
- activeTabName?: InvokeTabName;
-};
+type ImageUploadedArg = Parameters<(typeof ImagesService)['uploadImage']>[0];
/**
* `ImagesService.uploadImage()` thunk
@@ -51,13 +41,8 @@ export const imageUploaded = createAppAsyncThunk(
'api/imageUploaded',
async (arg: ImageUploadedArg) => {
// strip out `activeTabName` from arg - the route does not need it
- const { activeTabName, ...rest } = arg;
- const response = await ImagesService.uploadImage(rest);
- const { location } = getHeaders(response);
-
- imagesLog.debug({ arg: '', response, location }, 'Image uploaded');
-
- return { response, location };
+ const response = await ImagesService.uploadImage(arg);
+ return response;
}
);
@@ -70,9 +55,48 @@ export const imageDeleted = createAppAsyncThunk(
'api/imageDeleted',
async (arg: ImageDeletedArg) => {
const response = await ImagesService.deleteImage(arg);
-
- imagesLog.debug({ arg, response }, 'Image deleted');
-
+ return response;
+ }
+);
+
+type ImageUpdatedArg = Parameters<(typeof ImagesService)['updateImage']>[0];
+
+/**
+ * `ImagesService.updateImage()` thunk
+ */
+export const imageUpdated = createAppAsyncThunk(
+ 'api/imageUpdated',
+ async (arg: ImageUpdatedArg) => {
+ const response = await ImagesService.updateImage(arg);
+ return response;
+ }
+);
+
+type ImagesListedArg = Parameters<
+ (typeof ImagesService)['listImagesWithMetadata']
+>[0];
+
+export const IMAGES_PER_PAGE = 20;
+
+/**
+ * `ImagesService.listImagesWithMetadata()` thunk
+ */
+export const receivedPageOfImages = createAppAsyncThunk(
+ 'api/receivedPageOfImages',
+ async (_, { getState }) => {
+ const state = getState();
+ const { categories } = state.images;
+
+ const totalImagesInFilter = selectImagesAll(state).filter((i) =>
+ categories.includes(i.image_category)
+ ).length;
+
+ const response = await ImagesService.listImagesWithMetadata({
+ categories,
+ isIntermediate: false,
+ offset: totalImagesInFilter,
+ limit: IMAGES_PER_PAGE,
+ });
return response;
}
);
diff --git a/invokeai/frontend/web/src/services/thunks/session.ts b/invokeai/frontend/web/src/services/thunks/session.ts
index dca4134886..cf87fb30f5 100644
--- a/invokeai/frontend/web/src/services/thunks/session.ts
+++ b/invokeai/frontend/web/src/services/thunks/session.ts
@@ -1,7 +1,7 @@
import { createAppAsyncThunk } from 'app/store/storeUtils';
-import { SessionsService } from 'services/api';
+import { GraphExecutionState, SessionsService } from 'services/api';
import { log } from 'app/logging/useLogger';
-import { serializeError } from 'serialize-error';
+import { isObject } from 'lodash-es';
const sessionLog = log.child({ namespace: 'session' });
@@ -11,99 +11,89 @@ type SessionCreatedArg = {
>[0]['requestBody'];
};
+type SessionCreatedThunkConfig = {
+ rejectValue: { arg: SessionCreatedArg; error: unknown };
+};
+
/**
* `SessionsService.createSession()` thunk
*/
-export const sessionCreated = createAppAsyncThunk(
- 'api/sessionCreated',
- async (arg: SessionCreatedArg, { rejectWithValue }) => {
- try {
- const response = await SessionsService.createSession({
- requestBody: arg.graph,
- });
- sessionLog.info({ arg, response }, `Session created (${response.id})`);
- return response;
- } catch (err: any) {
- sessionLog.error(
- {
- error: serializeError(err),
- },
- 'Problem creating session'
- );
- return rejectWithValue(err.message);
- }
- }
-);
-
-type NodeAddedArg = Parameters<(typeof SessionsService)['addNode']>[0];
-
-/**
- * `SessionsService.addNode()` thunk
- */
-export const nodeAdded = createAppAsyncThunk(
- 'api/nodeAdded',
- async (
- arg: { node: NodeAddedArg['requestBody']; sessionId: string },
- _thunkApi
- ) => {
- const response = await SessionsService.addNode({
- requestBody: arg.node,
- sessionId: arg.sessionId,
+export const sessionCreated = createAppAsyncThunk<
+ GraphExecutionState,
+ SessionCreatedArg,
+ SessionCreatedThunkConfig
+>('api/sessionCreated', async (arg, { rejectWithValue }) => {
+ try {
+ const response = await SessionsService.createSession({
+ requestBody: arg.graph,
});
-
- sessionLog.info({ arg, response }, `Node added (${response})`);
-
return response;
+ } catch (error) {
+ return rejectWithValue({ arg, error });
}
-);
+});
+
+type SessionInvokedArg = { sessionId: string };
+
+type SessionInvokedThunkConfig = {
+ rejectValue: {
+ arg: SessionInvokedArg;
+ error: unknown;
+ };
+};
+
+const isErrorWithStatus = (error: unknown): error is { status: number } =>
+ isObject(error) && 'status' in error;
/**
* `SessionsService.invokeSession()` thunk
*/
-export const sessionInvoked = createAppAsyncThunk(
- 'api/sessionInvoked',
- async (arg: { sessionId: string }, { rejectWithValue }) => {
- const { sessionId } = arg;
+export const sessionInvoked = createAppAsyncThunk<
+ void,
+ SessionInvokedArg,
+ SessionInvokedThunkConfig
+>('api/sessionInvoked', async (arg, { rejectWithValue }) => {
+ const { sessionId } = arg;
- try {
- const response = await SessionsService.invokeSession({
- sessionId,
- all: true,
- });
- sessionLog.info({ arg, response }, `Session invoked (${sessionId})`);
-
- return response;
- } catch (error) {
- const err = error as any;
- if (err.status === 403) {
- return rejectWithValue(err.body.detail);
- }
- throw error;
+ try {
+ const response = await SessionsService.invokeSession({
+ sessionId,
+ all: true,
+ });
+ return response;
+ } catch (error) {
+ if (isErrorWithStatus(error) && error.status === 403) {
+ return rejectWithValue({ arg, error: (error as any).body.detail });
}
+ return rejectWithValue({ arg, error });
}
-);
+});
type SessionCanceledArg = Parameters<
(typeof SessionsService)['cancelSessionInvoke']
>[0];
-
+type SessionCanceledThunkConfig = {
+ rejectValue: {
+ arg: SessionCanceledArg;
+ error: unknown;
+ };
+};
/**
* `SessionsService.cancelSession()` thunk
*/
-export const sessionCanceled = createAppAsyncThunk(
- 'api/sessionCanceled',
- async (arg: SessionCanceledArg, _thunkApi) => {
- const { sessionId } = arg;
+export const sessionCanceled = createAppAsyncThunk<
+ void,
+ SessionCanceledArg,
+ SessionCanceledThunkConfig
+>('api/sessionCanceled', async (arg: SessionCanceledArg, _thunkApi) => {
+ const { sessionId } = arg;
- const response = await SessionsService.cancelSessionInvoke({
- sessionId,
- });
+ const response = await SessionsService.cancelSessionInvoke({
+ sessionId,
+ });
- sessionLog.info({ arg, response }, `Session canceled (${sessionId})`);
-
- return response;
- }
-);
+ return response;
+});
type SessionsListedArg = Parameters<
(typeof SessionsService)['listSessions']
diff --git a/invokeai/frontend/web/src/services/types/guards.ts b/invokeai/frontend/web/src/services/types/guards.ts
index 266e991f4d..4d33cfa246 100644
--- a/invokeai/frontend/web/src/services/types/guards.ts
+++ b/invokeai/frontend/web/src/services/types/guards.ts
@@ -1,5 +1,3 @@
-import { ResultsImageDTO } from 'features/gallery/store/resultsSlice';
-import { UploadsImageDTO } from 'features/gallery/store/uploadsSlice';
import { get, isObject, isString } from 'lodash-es';
import {
GraphExecutionState,
@@ -9,18 +7,11 @@ import {
PromptOutput,
IterateInvocationOutput,
CollectInvocationOutput,
- ImageType,
ImageField,
LatentsOutput,
- ImageDTO,
+ ResourceOrigin,
} from 'services/api';
-export const isUploadsImageDTO = (image: ImageDTO): image is UploadsImageDTO =>
- image.image_type === 'uploads';
-
-export const isResultsImageDTO = (image: ImageDTO): image is ResultsImageDTO =>
- image.image_type === 'results';
-
export const isImageOutput = (
output: GraphExecutionState['results'][string]
): output is ImageOutput => output.type === 'image_output';
@@ -49,10 +40,10 @@ export const isCollectOutput = (
output: GraphExecutionState['results'][string]
): output is CollectInvocationOutput => output.type === 'collect_output';
-export const isImageType = (t: unknown): t is ImageType =>
- isString(t) && ['results', 'uploads', 'intermediates'].includes(t);
+export const isResourceOrigin = (t: unknown): t is ResourceOrigin =>
+ isString(t) && ['internal', 'external'].includes(t);
export const isImageField = (imageField: unknown): imageField is ImageField =>
isObject(imageField) &&
isString(get(imageField, 'image_name')) &&
- isImageType(get(imageField, 'image_type'));
+ isResourceOrigin(get(imageField, 'image_origin'));
diff --git a/invokeai/frontend/web/stats.html b/invokeai/frontend/web/stats.html
index dba25766e4..e9c4381206 100644
--- a/invokeai/frontend/web/stats.html
+++ b/invokeai/frontend/web/stats.html
@@ -6157,7 +6157,7 @@ var drawChart = (function (exports) {