diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 046e8f0c57..ed3ab0375c 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -2,7 +2,7 @@ /.github/workflows/ @lstein @blessedcoolant # documentation -/docs/ @lstein @tildebyte @blessedcoolant +/docs/ @lstein @blessedcoolant @hipsterusername /mkdocs.yml @lstein @blessedcoolant # nodes @@ -18,17 +18,17 @@ /invokeai/version @lstein @blessedcoolant # web ui -/invokeai/frontend @blessedcoolant @psychedelicious @lstein -/invokeai/backend @blessedcoolant @psychedelicious @lstein +/invokeai/frontend @blessedcoolant @psychedelicious @lstein @maryhipp +/invokeai/backend @blessedcoolant @psychedelicious @lstein @maryhipp # generation, model management, postprocessing -/invokeai/backend @damian0815 @lstein @blessedcoolant @jpphoto @gregghelt2 +/invokeai/backend @damian0815 @lstein @blessedcoolant @jpphoto @gregghelt2 @StAlKeR7779 # front ends /invokeai/frontend/CLI @lstein /invokeai/frontend/install @lstein @ebr -/invokeai/frontend/merge @lstein @blessedcoolant @hipsterusername -/invokeai/frontend/training @lstein @blessedcoolant @hipsterusername -/invokeai/frontend/web @psychedelicious @blessedcoolant +/invokeai/frontend/merge @lstein @blessedcoolant +/invokeai/frontend/training @lstein @blessedcoolant +/invokeai/frontend/web @psychedelicious @blessedcoolant @maryhipp diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 1ea365eb50..d2fcf0a364 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -5,6 +5,7 @@ import os from invokeai.app.services.image_record_storage import SqliteImageRecordStorage from invokeai.app.services.images import ImageService from invokeai.app.services.metadata import CoreMetadataService +from invokeai.app.services.resource_name import SimpleNameService from invokeai.app.services.urls import LocalUrlService from invokeai.backend.util.logging import InvokeAILogger @@ -67,7 +68,7 @@ class ApiDependencies: metadata = 
CoreMetadataService() image_record_storage = SqliteImageRecordStorage(db_location) image_file_storage = DiskImageFileStorage(f"{output_folder}/images") - + names = SimpleNameService() latents = ForwardCacheLatentsStorage( DiskLatentsStorage(f"{output_folder}/latents") ) @@ -78,6 +79,7 @@ class ApiDependencies: metadata=metadata, url=urls, logger=logger, + names=names, graph_execution_manager=graph_execution_manager, ) diff --git a/invokeai/app/api/models/images.py b/invokeai/app/api/models/images.py deleted file mode 100644 index fa04702326..0000000000 --- a/invokeai/app/api/models/images.py +++ /dev/null @@ -1,39 +0,0 @@ -from typing import Optional -from pydantic import BaseModel, Field - -from invokeai.app.models.image import ImageType - - -class ImageResponseMetadata(BaseModel): - """An image's metadata. Used only in HTTP responses.""" - - created: int = Field(description="The creation timestamp of the image") - width: int = Field(description="The width of the image in pixels") - height: int = Field(description="The height of the image in pixels") - # invokeai: Optional[InvokeAIMetadata] = Field( - # description="The image's InvokeAI-specific metadata" - # ) - - -class ImageResponse(BaseModel): - """The response type for images""" - - image_type: ImageType = Field(description="The type of the image") - image_name: str = Field(description="The name of the image") - image_url: str = Field(description="The url of the image") - thumbnail_url: str = Field(description="The url of the image's thumbnail") - metadata: ImageResponseMetadata = Field(description="The image's metadata") - - -class ProgressImage(BaseModel): - """The progress image sent intermittently during processing""" - - width: int = Field(description="The effective width of the image in pixels") - height: int = Field(description="The effective height of the image in pixels") - dataURL: str = Field(description="The image data as a b64 data URL") - - -class SavedImage(BaseModel): - image_name: str = 
Field(description="The name of the saved image") - thumbnail_name: str = Field(description="The name of the saved thumbnail") - created: int = Field(description="The created timestamp of the saved image") diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py index 0615ff187e..ae10cce140 100644 --- a/invokeai/app/api/routers/images.py +++ b/invokeai/app/api/routers/images.py @@ -1,13 +1,19 @@ import io -from fastapi import HTTPException, Path, Query, Request, Response, UploadFile +from typing import Optional +from fastapi import Body, HTTPException, Path, Query, Request, Response, UploadFile from fastapi.routing import APIRouter from fastapi.responses import FileResponse from PIL import Image from invokeai.app.models.image import ( ImageCategory, - ImageType, + ResourceOrigin, +) +from invokeai.app.services.image_record_storage import OffsetPaginatedResults +from invokeai.app.services.models.image_record import ( + ImageDTO, + ImageRecordChanges, + ImageUrlsDTO, ) -from invokeai.app.services.models.image_record import ImageDTO, ImageUrlsDTO from invokeai.app.services.item_storage import PaginatedResults from ..dependencies import ApiDependencies @@ -27,10 +33,13 @@ images_router = APIRouter(prefix="/v1/images", tags=["images"]) ) async def upload_image( file: UploadFile, - image_type: ImageType, request: Request, response: Response, - image_category: ImageCategory = ImageCategory.GENERAL, + image_category: ImageCategory = Query(description="The category of the image"), + is_intermediate: bool = Query(description="Whether this is an intermediate image"), + session_id: Optional[str] = Query( + default=None, description="The session ID associated with this upload, if any" + ), ) -> ImageDTO: """Uploads an image""" if not file.content_type.startswith("image"): @@ -46,9 +55,11 @@ async def upload_image( try: image_dto = ApiDependencies.invoker.services.images.create( - pil_image, - image_type, - image_category, + image=pil_image, + 
image_origin=ResourceOrigin.EXTERNAL, + image_category=image_category, + session_id=session_id, + is_intermediate=is_intermediate, ) response.status_code = 201 @@ -59,41 +70,61 @@ async def upload_image( raise HTTPException(status_code=500, detail="Failed to create image") -@images_router.delete("/{image_type}/{image_name}", operation_id="delete_image") +@images_router.delete("/{image_origin}/{image_name}", operation_id="delete_image") async def delete_image( - image_type: ImageType = Query(description="The type of image to delete"), + image_origin: ResourceOrigin = Path(description="The origin of image to delete"), image_name: str = Path(description="The name of the image to delete"), ) -> None: """Deletes an image""" try: - ApiDependencies.invoker.services.images.delete(image_type, image_name) + ApiDependencies.invoker.services.images.delete(image_origin, image_name) except Exception as e: # TODO: Does this need any exception handling at all? pass +@images_router.patch( + "/{image_origin}/{image_name}", + operation_id="update_image", + response_model=ImageDTO, +) +async def update_image( + image_origin: ResourceOrigin = Path(description="The origin of image to update"), + image_name: str = Path(description="The name of the image to update"), + image_changes: ImageRecordChanges = Body( + description="The changes to apply to the image" + ), +) -> ImageDTO: + """Updates an image""" + + try: + return ApiDependencies.invoker.services.images.update( + image_origin, image_name, image_changes + ) + except Exception as e: + raise HTTPException(status_code=400, detail="Failed to update image") + + @images_router.get( - "/{image_type}/{image_name}/metadata", + "/{image_origin}/{image_name}/metadata", operation_id="get_image_metadata", response_model=ImageDTO, ) async def get_image_metadata( - image_type: ImageType = Path(description="The type of image to get"), + image_origin: ResourceOrigin = Path(description="The origin of image to get"), image_name: str = 
Path(description="The name of image to get"), ) -> ImageDTO: """Gets an image's metadata""" try: - return ApiDependencies.invoker.services.images.get_dto( - image_type, image_name - ) + return ApiDependencies.invoker.services.images.get_dto(image_origin, image_name) except Exception as e: raise HTTPException(status_code=404) @images_router.get( - "/{image_type}/{image_name}", + "/{image_origin}/{image_name}", operation_id="get_image_full", response_class=Response, responses={ @@ -105,7 +136,7 @@ async def get_image_metadata( }, ) async def get_image_full( - image_type: ImageType = Path( + image_origin: ResourceOrigin = Path( description="The type of full-resolution image file to get" ), image_name: str = Path(description="The name of full-resolution image file to get"), @@ -113,9 +144,7 @@ async def get_image_full( """Gets a full-resolution image file""" try: - path = ApiDependencies.invoker.services.images.get_path( - image_type, image_name - ) + path = ApiDependencies.invoker.services.images.get_path(image_origin, image_name) if not ApiDependencies.invoker.services.images.validate_path(path): raise HTTPException(status_code=404) @@ -131,7 +160,7 @@ async def get_image_full( @images_router.get( - "/{image_type}/{image_name}/thumbnail", + "/{image_origin}/{image_name}/thumbnail", operation_id="get_image_thumbnail", response_class=Response, responses={ @@ -143,14 +172,14 @@ async def get_image_full( }, ) async def get_image_thumbnail( - image_type: ImageType = Path(description="The type of thumbnail image file to get"), + image_origin: ResourceOrigin = Path(description="The origin of thumbnail image file to get"), image_name: str = Path(description="The name of thumbnail image file to get"), ) -> FileResponse: """Gets a thumbnail image file""" try: path = ApiDependencies.invoker.services.images.get_path( - image_type, image_name, thumbnail=True + image_origin, image_name, thumbnail=True ) if not ApiDependencies.invoker.services.images.validate_path(path): raise 
HTTPException(status_code=404) @@ -163,25 +192,25 @@ async def get_image_thumbnail( @images_router.get( - "/{image_type}/{image_name}/urls", + "/{image_origin}/{image_name}/urls", operation_id="get_image_urls", response_model=ImageUrlsDTO, ) async def get_image_urls( - image_type: ImageType = Path(description="The type of the image whose URL to get"), + image_origin: ResourceOrigin = Path(description="The origin of the image whose URL to get"), image_name: str = Path(description="The name of the image whose URL to get"), ) -> ImageUrlsDTO: """Gets an image and thumbnail URL""" try: image_url = ApiDependencies.invoker.services.images.get_url( - image_type, image_name + image_origin, image_name ) thumbnail_url = ApiDependencies.invoker.services.images.get_url( - image_type, image_name, thumbnail=True + image_origin, image_name, thumbnail=True ) return ImageUrlsDTO( - image_type=image_type, + image_origin=image_origin, image_name=image_name, image_url=image_url, thumbnail_url=thumbnail_url, @@ -193,23 +222,29 @@ async def get_image_urls( @images_router.get( "/", operation_id="list_images_with_metadata", - response_model=PaginatedResults[ImageDTO], + response_model=OffsetPaginatedResults[ImageDTO], ) async def list_images_with_metadata( - image_type: ImageType = Query(description="The type of images to list"), - image_category: ImageCategory = Query(description="The kind of images to list"), - page: int = Query(default=0, description="The page of image metadata to get"), - per_page: int = Query( - default=10, description="The number of image metadata per page" + image_origin: Optional[ResourceOrigin] = Query( + default=None, description="The origin of images to list" ), -) -> PaginatedResults[ImageDTO]: - """Gets a list of images with metadata""" + categories: Optional[list[ImageCategory]] = Query( + default=None, description="The categories of image to include" + ), + is_intermediate: Optional[bool] = Query( + default=None, description="Whether to list intermediate 
images" + ), + offset: int = Query(default=0, description="The page offset"), + limit: int = Query(default=10, description="The number of images per page"), +) -> OffsetPaginatedResults[ImageDTO]: + """Gets a list of images""" image_dtos = ApiDependencies.invoker.services.images.get_many( - image_type, - image_category, - page, - per_page, + offset, + limit, + image_origin, + categories, + is_intermediate, ) return image_dtos diff --git a/invokeai/app/cli_app.py b/invokeai/app/cli_app.py index c6267a6871..15afe65a8a 100644 --- a/invokeai/app/cli_app.py +++ b/invokeai/app/cli_app.py @@ -12,11 +12,10 @@ from pydantic import BaseModel, ValidationError from pydantic.fields import Field import invokeai.backend.util.logging as logger -import invokeai.version from invokeai.app.services.image_record_storage import SqliteImageRecordStorage from invokeai.app.services.images import ImageService -from invokeai.app.services.metadata import (CoreMetadataService, - PngMetadataService) +from invokeai.app.services.metadata import CoreMetadataService +from invokeai.app.services.resource_name import SimpleNameService from invokeai.app.services.urls import LocalUrlService from .cli.commands import (BaseCommand, CliContext, ExitCli, @@ -232,6 +231,7 @@ def invoke_cli(): metadata = CoreMetadataService() image_record_storage = SqliteImageRecordStorage(db_location) image_file_storage = DiskImageFileStorage(f"{output_folder}/images") + names = SimpleNameService() images = ImageService( image_record_storage=image_record_storage, @@ -239,6 +239,7 @@ def invoke_cli(): metadata=metadata, url=urls, logger=logger, + names=names, graph_execution_manager=graph_execution_manager, ) diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index da61641105..4ce3e839b6 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -78,6 +78,7 @@ class BaseInvocation(ABC, BaseModel): #fmt: off id: str = 
Field(description="The id of this node. Must be unique among all nodes.") + is_intermediate: bool = Field(default=False, description="Whether or not this node is an intermediate node.") #fmt: on @@ -95,6 +96,7 @@ class UIConfig(TypedDict, total=False): "image", "latents", "model", + "control", ], ] tags: List[str] diff --git a/invokeai/app/invocations/collections.py b/invokeai/app/invocations/collections.py index 475b6028a9..891f217317 100644 --- a/invokeai/app/invocations/collections.py +++ b/invokeai/app/invocations/collections.py @@ -22,6 +22,14 @@ class IntCollectionOutput(BaseInvocationOutput): # Outputs collection: list[int] = Field(default=[], description="The int collection") +class FloatCollectionOutput(BaseInvocationOutput): + """A collection of floats""" + + type: Literal["float_collection"] = "float_collection" + + # Outputs + collection: list[float] = Field(default=[], description="The float collection") + class RangeInvocation(BaseInvocation): """Creates a range of numbers from start to stop with step""" diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py new file mode 100644 index 0000000000..7d5160a491 --- /dev/null +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -0,0 +1,428 @@ +# InvokeAI nodes for ControlNet image preprocessors +# initial implementation by Gregg Helt, 2023 +# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux + +import numpy as np +from typing import Literal, Optional, Union, List +from PIL import Image, ImageFilter, ImageOps +from pydantic import BaseModel, Field + +from ..models.image import ImageField, ImageCategory, ResourceOrigin +from .baseinvocation import ( + BaseInvocation, + BaseInvocationOutput, + InvocationContext, + InvocationConfig, +) +from controlnet_aux import ( + CannyDetector, + HEDdetector, + LineartDetector, + LineartAnimeDetector, + MidasDetector, + MLSDdetector, + 
NormalBaeDetector, + OpenposeDetector, + PidiNetDetector, + ContentShuffleDetector, + ZoeDetector, + MediapipeFaceDetector, +) + +from .image import ImageOutput, PILInvocationConfig + +CONTROLNET_DEFAULT_MODELS = [ + ########################################### + # lllyasviel sd v1.5, ControlNet v1.0 models + ############################################## + "lllyasviel/sd-controlnet-canny", + "lllyasviel/sd-controlnet-depth", + "lllyasviel/sd-controlnet-hed", + "lllyasviel/sd-controlnet-seg", + "lllyasviel/sd-controlnet-openpose", + "lllyasviel/sd-controlnet-scribble", + "lllyasviel/sd-controlnet-normal", + "lllyasviel/sd-controlnet-mlsd", + + ############################################# + # lllyasviel sd v1.5, ControlNet v1.1 models + ############################################# + "lllyasviel/control_v11p_sd15_canny", + "lllyasviel/control_v11p_sd15_openpose", + "lllyasviel/control_v11p_sd15_seg", + # "lllyasviel/control_v11p_sd15_depth", # broken + "lllyasviel/control_v11f1p_sd15_depth", + "lllyasviel/control_v11p_sd15_normalbae", + "lllyasviel/control_v11p_sd15_scribble", + "lllyasviel/control_v11p_sd15_mlsd", + "lllyasviel/control_v11p_sd15_softedge", + "lllyasviel/control_v11p_sd15s2_lineart_anime", + "lllyasviel/control_v11p_sd15_lineart", + "lllyasviel/control_v11p_sd15_inpaint", + # "lllyasviel/control_v11u_sd15_tile", + # problem (temporary?) with huggingface "lllyasviel/control_v11u_sd15_tile", + # so for now replace "lllyasviel/control_v11f1e_sd15_tile", + "lllyasviel/control_v11e_sd15_shuffle", + "lllyasviel/control_v11e_sd15_ip2p", + "lllyasviel/control_v11f1e_sd15_tile", + + ################################################# + # thibaud sd v2.1 models (ControlNet v1.0? or v1.1? 
+ ################################################## + "thibaud/controlnet-sd21-openpose-diffusers", + "thibaud/controlnet-sd21-canny-diffusers", + "thibaud/controlnet-sd21-depth-diffusers", + "thibaud/controlnet-sd21-scribble-diffusers", + "thibaud/controlnet-sd21-hed-diffusers", + "thibaud/controlnet-sd21-zoedepth-diffusers", + "thibaud/controlnet-sd21-color-diffusers", + "thibaud/controlnet-sd21-openposev2-diffusers", + "thibaud/controlnet-sd21-lineart-diffusers", + "thibaud/controlnet-sd21-normalbae-diffusers", + "thibaud/controlnet-sd21-ade20k-diffusers", + + ############################################## + # ControlNetMediaPipeface, ControlNet v1.1 + ############################################## + # ["CrucibleAI/ControlNetMediaPipeFace", "diffusion_sd15"], # SD 1.5 + # diffusion_sd15 needs to be passed to from_pretrained() as subfolder arg + # hacked t2l to split to model & subfolder if format is "model,subfolder" + "CrucibleAI/ControlNetMediaPipeFace,diffusion_sd15", # SD 1.5 + "CrucibleAI/ControlNetMediaPipeFace", # SD 2.1? 
+] + +CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)] + +class ControlField(BaseModel): + image: ImageField = Field(default=None, description="processed image") + control_model: Optional[str] = Field(default=None, description="control model used") + control_weight: Optional[float] = Field(default=1, description="weight given to controlnet") + begin_step_percent: float = Field(default=0, ge=0, le=1, + description="% of total steps at which controlnet is first applied") + end_step_percent: float = Field(default=1, ge=0, le=1, + description="% of total steps at which controlnet is last applied") + + class Config: + schema_extra = { + "required": ["image", "control_model", "control_weight", "begin_step_percent", "end_step_percent"] + } + + +class ControlOutput(BaseInvocationOutput): + """node output for ControlNet info""" + # fmt: off + type: Literal["control_output"] = "control_output" + control: ControlField = Field(default=None, description="The control info dict") + # fmt: on + + +class ControlNetInvocation(BaseInvocation): + """Collects ControlNet info to pass to other nodes""" + # fmt: off + type: Literal["controlnet"] = "controlnet" + # Inputs + image: ImageField = Field(default=None, description="image to process") + control_model: CONTROLNET_NAME_VALUES = Field(default="lllyasviel/sd-controlnet-canny", + description="control model used") + control_weight: float = Field(default=1.0, ge=0, le=1, description="weight given to controlnet") + # TODO: add support in backend core for begin_step_percent, end_step_percent, guess_mode + begin_step_percent: float = Field(default=0, ge=0, le=1, + description="% of total steps at which controlnet is first applied") + end_step_percent: float = Field(default=1, ge=0, le=1, + description="% of total steps at which controlnet is last applied") + # fmt: on + + + def invoke(self, context: InvocationContext) -> ControlOutput: + + return ControlOutput( + control=ControlField( + image=self.image, + 
control_model=self.control_model, + control_weight=self.control_weight, + begin_step_percent=self.begin_step_percent, + end_step_percent=self.end_step_percent, + ), + ) + +# TODO: move image processors to separate file (image_analysis.py) +class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig): + """Base class for invocations that preprocess images for ControlNet""" + + # fmt: off + type: Literal["image_processor"] = "image_processor" + # Inputs + image: ImageField = Field(default=None, description="image to process") + # fmt: on + + + def run_processor(self, image): + # superclass just passes through image without processing + return image + + def invoke(self, context: InvocationContext) -> ImageOutput: + + raw_image = context.services.images.get_pil_image( + self.image.image_origin, self.image.image_name + ) + # image type should be PIL.PngImagePlugin.PngImageFile ? + processed_image = self.run_processor(raw_image) + + # FIXME: what happened to image metadata? + # metadata = context.services.metadata.build_metadata( + # session_id=context.graph_execution_state_id, node=self + # ) + + # currently can't see processed image in node UI without a showImage node, + # so for now setting image_type to RESULT instead of INTERMEDIATE so will get saved in gallery + image_dto = context.services.images.create( + image=processed_image, + image_origin=ResourceOrigin.INTERNAL, + image_category=ImageCategory.CONTROL, + session_id=context.graph_execution_state_id, + node_id=self.id, + is_intermediate=self.is_intermediate + ) + + """Builds an ImageOutput and its ImageField""" + processed_image_field = ImageField( + image_name=image_dto.image_name, + image_origin=image_dto.image_origin, + ) + return ImageOutput( + image=processed_image_field, + # width=processed_image.width, + width = image_dto.width, + # height=processed_image.height, + height = image_dto.height, + # mode=processed_image.mode, + ) + + +class CannyImageProcessorInvocation(ImageProcessorInvocation, 
PILInvocationConfig): + """Canny edge detection for ControlNet""" + # fmt: off + type: Literal["canny_image_processor"] = "canny_image_processor" + # Input + low_threshold: float = Field(default=100, ge=0, description="low threshold of Canny pixel gradient") + high_threshold: float = Field(default=200, ge=0, description="high threshold of Canny pixel gradient") + # fmt: on + + def run_processor(self, image): + canny_processor = CannyDetector() + processed_image = canny_processor(image, self.low_threshold, self.high_threshold) + return processed_image + + +class HedImageprocessorInvocation(ImageProcessorInvocation, PILInvocationConfig): + """Applies HED edge detection to image""" + # fmt: off + type: Literal["hed_image_processor"] = "hed_image_processor" + # Inputs + detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection") + image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image") + # safe not supported in controlnet_aux v0.0.3 + # safe: bool = Field(default=False, description="whether to use safe mode") + scribble: bool = Field(default=False, description="whether to use scribble mode") + # fmt: on + + def run_processor(self, image): + hed_processor = HEDdetector.from_pretrained("lllyasviel/Annotators") + processed_image = hed_processor(image, + detect_resolution=self.detect_resolution, + image_resolution=self.image_resolution, + # safe not supported in controlnet_aux v0.0.3 + # safe=self.safe, + scribble=self.scribble, + ) + return processed_image + + +class LineartImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): + """Applies line art processing to image""" + # fmt: off + type: Literal["lineart_image_processor"] = "lineart_image_processor" + # Inputs + detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection") + image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image") + 
coarse: bool = Field(default=False, description="whether to use coarse mode") + # fmt: on + + def run_processor(self, image): + lineart_processor = LineartDetector.from_pretrained("lllyasviel/Annotators") + processed_image = lineart_processor(image, + detect_resolution=self.detect_resolution, + image_resolution=self.image_resolution, + coarse=self.coarse) + return processed_image + + +class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): + """Applies line art anime processing to image""" + # fmt: off + type: Literal["lineart_anime_image_processor"] = "lineart_anime_image_processor" + # Inputs + detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection") + image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image") + # fmt: on + + def run_processor(self, image): + processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators") + processed_image = processor(image, + detect_resolution=self.detect_resolution, + image_resolution=self.image_resolution, + ) + return processed_image + + +class OpenposeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): + """Applies Openpose processing to image""" + # fmt: off + type: Literal["openpose_image_processor"] = "openpose_image_processor" + # Inputs + hand_and_face: bool = Field(default=False, description="whether to use hands and face mode") + detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection") + image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image") + # fmt: on + + def run_processor(self, image): + openpose_processor = OpenposeDetector.from_pretrained("lllyasviel/Annotators") + processed_image = openpose_processor(image, + detect_resolution=self.detect_resolution, + image_resolution=self.image_resolution, + hand_and_face=self.hand_and_face, + ) + return processed_image + + +class 
MidasDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): + """Applies Midas depth processing to image""" + # fmt: off + type: Literal["midas_depth_image_processor"] = "midas_depth_image_processor" + # Inputs + a_mult: float = Field(default=2.0, ge=0, description="Midas parameter a = amult * PI") + bg_th: float = Field(default=0.1, ge=0, description="Midas parameter bg_th") + # depth_and_normal not supported in controlnet_aux v0.0.3 + # depth_and_normal: bool = Field(default=False, description="whether to use depth and normal mode") + # fmt: on + + def run_processor(self, image): + midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators") + processed_image = midas_processor(image, + a=np.pi * self.a_mult, + bg_th=self.bg_th, + # depth_and_normal not supported in controlnet_aux v0.0.3 + # depth_and_normal=self.depth_and_normal, + ) + return processed_image + + +class NormalbaeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): + """Applies NormalBae processing to image""" + # fmt: off + type: Literal["normalbae_image_processor"] = "normalbae_image_processor" + # Inputs + detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection") + image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image") + # fmt: on + + def run_processor(self, image): + normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators") + processed_image = normalbae_processor(image, + detect_resolution=self.detect_resolution, + image_resolution=self.image_resolution) + return processed_image + + +class MlsdImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): + """Applies MLSD processing to image""" + # fmt: off + type: Literal["mlsd_image_processor"] = "mlsd_image_processor" + # Inputs + detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection") + image_resolution: int = 
Field(default=512, ge=0, description="pixel resolution for output image") + thr_v: float = Field(default=0.1, ge=0, description="MLSD parameter thr_v") + thr_d: float = Field(default=0.1, ge=0, description="MLSD parameter thr_d") + # fmt: on + + def run_processor(self, image): + mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators") + processed_image = mlsd_processor(image, + detect_resolution=self.detect_resolution, + image_resolution=self.image_resolution, + thr_v=self.thr_v, + thr_d=self.thr_d) + return processed_image + + +class PidiImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): + """Applies PIDI processing to image""" + # fmt: off + type: Literal["pidi_image_processor"] = "pidi_image_processor" + # Inputs + detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection") + image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image") + safe: bool = Field(default=False, description="whether to use safe mode") + scribble: bool = Field(default=False, description="whether to use scribble mode") + # fmt: on + + def run_processor(self, image): + pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators") + processed_image = pidi_processor(image, + detect_resolution=self.detect_resolution, + image_resolution=self.image_resolution, + safe=self.safe, + scribble=self.scribble) + return processed_image + + +class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): + """Applies content shuffle processing to image""" + # fmt: off + type: Literal["content_shuffle_image_processor"] = "content_shuffle_image_processor" + # Inputs + detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection") + image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image") + h: Union[int | None] = Field(default=512, ge=0, description="content shuffle h 
parameter") + w: Union[int | None] = Field(default=512, ge=0, description="content shuffle w parameter") + f: Union[int | None] = Field(default=256, ge=0, description="cont") + # fmt: on + + def run_processor(self, image): + content_shuffle_processor = ContentShuffleDetector() + processed_image = content_shuffle_processor(image, + detect_resolution=self.detect_resolution, + image_resolution=self.image_resolution, + h=self.h, + w=self.w, + f=self.f + ) + return processed_image + + +# should work with controlnet_aux >= 0.0.4 and timm <= 0.6.13 +class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): + """Applies Zoe depth processing to image""" + # fmt: off + type: Literal["zoe_depth_image_processor"] = "zoe_depth_image_processor" + # fmt: on + + def run_processor(self, image): + zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators") + processed_image = zoe_depth_processor(image) + return processed_image + + +class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): + """Applies mediapipe face processing to image""" + # fmt: off + type: Literal["mediapipe_face_processor"] = "mediapipe_face_processor" + # Inputs + max_faces: int = Field(default=1, ge=1, description="maximum number of faces to detect") + min_confidence: float = Field(default=0.5, ge=0, le=1, description="minimum confidence for face detection") + # fmt: on + + def run_processor(self, image): + mediapipe_face_processor = MediapipeFaceDetector() + processed_image = mediapipe_face_processor(image, max_faces=self.max_faces, min_confidence=self.min_confidence) + return processed_image diff --git a/invokeai/app/invocations/cv.py b/invokeai/app/invocations/cv.py index 26e06a2af8..5275116a2a 100644 --- a/invokeai/app/invocations/cv.py +++ b/invokeai/app/invocations/cv.py @@ -7,7 +7,7 @@ import numpy from PIL import Image, ImageOps from pydantic import BaseModel, Field -from invokeai.app.models.image import ImageCategory, ImageField, 
ImageType +from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig from .image import ImageOutput @@ -37,10 +37,10 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig): def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image( - self.image.image_type, self.image.image_name + self.image.image_origin, self.image.image_name ) mask = context.services.images.get_pil_image( - self.mask.image_type, self.mask.image_name + self.mask.image_origin, self.mask.image_name ) # Convert to cv image/mask @@ -57,16 +57,17 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig): image_dto = context.services.images.create( image=image_inpainted, - image_type=ImageType.INTERMEDIATE, + image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, + is_intermediate=self.is_intermediate, ) return ImageOutput( image=ImageField( image_name=image_dto.image_name, - image_type=image_dto.image_type, + image_origin=image_dto.image_origin, ), width=image_dto.width, height=image_dto.height, diff --git a/invokeai/app/invocations/generate.py b/invokeai/app/invocations/generate.py index 1805362416..53d4d16330 100644 --- a/invokeai/app/invocations/generate.py +++ b/invokeai/app/invocations/generate.py @@ -3,16 +3,20 @@ from functools import partial from typing import Literal, Optional, Union, get_args +import torch +from diffusers import ControlNetModel from pydantic import BaseModel, Field -from invokeai.app.models.image import ImageCategory, ImageType, ColorField, ImageField +from invokeai.app.models.image import (ColorField, ImageCategory, ImageField, + ResourceOrigin) from invokeai.app.util.misc import SEED_MAX, get_random_seed from invokeai.backend.generator.inpaint import infill_methods -from .baseinvocation import BaseInvocation, InvocationContext, 
InvocationConfig -from .image import ImageOutput -from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator + +from ...backend.generator import Img2Img, Inpaint, InvokeAIGenerator, Txt2Img from ...backend.stable_diffusion import PipelineIntermediateState from ..util.step_callback import stable_diffusion_step_callback +from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext +from .image import ImageOutput SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())] INFILL_METHODS = Literal[tuple(infill_methods())] @@ -53,6 +57,9 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation): cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", ) scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" ) model: str = Field(default="", description="The model to use (currently ignored)") + progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", ) + control_model: Optional[str] = Field(default=None, description="The control model to use") + control_image: Optional[ImageField] = Field(default=None, description="The processed control image") # fmt: on # TODO: pass this an emitter method or something? or a session for dispatching? @@ -73,17 +80,35 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation): # Handle invalid model parameter model = context.services.model_manager.get_model(self.model,node=self,context=context) + # loading controlnet image (currently requires pre-processed image) + control_image = ( + None if self.control_image is None + else context.services.images.get_pil_image( + self.control_image.image_origin, self.control_image.image_name + ) + ) + # loading controlnet model + if (self.control_model is None or self.control_model==''): + control_model = None + else: + # FIXME: change this to dropdown menu? 
+ # FIXME: generalize so don't have to hardcode torch_dtype and device + control_model = ControlNetModel.from_pretrained(self.control_model, + torch_dtype=torch.float16).to("cuda") + # Get the source node id (we are invoking the prepared node) graph_execution_state = context.services.graph_execution_manager.get( context.graph_execution_state_id ) source_node_id = graph_execution_state.prepared_source_mapping[self.id] - outputs = Txt2Img(model).generate( + txt2img = Txt2Img(model, control_model=control_model) + outputs = txt2img.generate( prompt=self.prompt, step_callback=partial(self.dispatch_progress, context, source_node_id), + control_image=control_image, **self.dict( - exclude={"prompt"} + exclude={"prompt", "control_image" } ), # Shorthand for passing all of the parameters above manually ) # Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object @@ -92,16 +117,17 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation): image_dto = context.services.images.create( image=generate_output.image, - image_type=ImageType.RESULT, + image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.GENERAL, session_id=context.graph_execution_state_id, node_id=self.id, + is_intermediate=self.is_intermediate, ) return ImageOutput( image=ImageField( image_name=image_dto.image_name, - image_type=image_dto.image_type, + image_origin=image_dto.image_origin, ), width=image_dto.width, height=image_dto.height, @@ -141,7 +167,7 @@ class ImageToImageInvocation(TextToImageInvocation): None if self.image is None else context.services.images.get_pil_image( - self.image.image_type, self.image.image_name + self.image.image_origin, self.image.image_name ) ) @@ -172,16 +198,17 @@ class ImageToImageInvocation(TextToImageInvocation): image_dto = context.services.images.create( image=generator_output.image, - image_type=ImageType.RESULT, + image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.GENERAL, 
session_id=context.graph_execution_state_id, node_id=self.id, + is_intermediate=self.is_intermediate, ) return ImageOutput( image=ImageField( image_name=image_dto.image_name, - image_type=image_dto.image_type, + image_origin=image_dto.image_origin, ), width=image_dto.width, height=image_dto.height, @@ -253,13 +280,13 @@ class InpaintInvocation(ImageToImageInvocation): None if self.image is None else context.services.images.get_pil_image( - self.image.image_type, self.image.image_name + self.image.image_origin, self.image.image_name ) ) mask = ( None if self.mask is None - else context.services.images.get_pil_image(self.mask.image_type, self.mask.image_name) + else context.services.images.get_pil_image(self.mask.image_origin, self.mask.image_name) ) # Handle invalid model parameter @@ -287,16 +314,17 @@ class InpaintInvocation(ImageToImageInvocation): image_dto = context.services.images.create( image=generator_output.image, - image_type=ImageType.RESULT, + image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.GENERAL, session_id=context.graph_execution_state_id, node_id=self.id, + is_intermediate=self.is_intermediate, ) return ImageOutput( image=ImageField( image_name=image_dto.image_name, - image_type=image_dto.image_type, + image_origin=image_dto.image_origin, ), width=image_dto.width, height=image_dto.height, diff --git a/invokeai/app/invocations/image.py b/invokeai/app/invocations/image.py index 21dfb4c1cd..d048410468 100644 --- a/invokeai/app/invocations/image.py +++ b/invokeai/app/invocations/image.py @@ -7,7 +7,7 @@ import numpy from PIL import Image, ImageFilter, ImageOps, ImageChops from pydantic import BaseModel, Field -from ..models.image import ImageCategory, ImageField, ImageType +from ..models.image import ImageCategory, ImageField, ResourceOrigin from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, @@ -72,12 +72,12 @@ class LoadImageInvocation(BaseInvocation): ) # fmt: on def invoke(self, context: InvocationContext) -> 
ImageOutput: - image = context.services.images.get_pil_image(self.image.image_type, self.image.image_name) + image = context.services.images.get_pil_image(self.image.image_origin, self.image.image_name) return ImageOutput( image=ImageField( image_name=self.image.image_name, - image_type=self.image.image_type, + image_origin=self.image.image_origin, ), width=image.width, height=image.height, @@ -96,7 +96,7 @@ class ShowImageInvocation(BaseInvocation): def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image( - self.image.image_type, self.image.image_name + self.image.image_origin, self.image.image_name ) if image: image.show() @@ -106,7 +106,7 @@ class ShowImageInvocation(BaseInvocation): return ImageOutput( image=ImageField( image_name=self.image.image_name, - image_type=self.image.image_type, + image_origin=self.image.image_origin, ), width=image.width, height=image.height, @@ -129,7 +129,7 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig): def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image( - self.image.image_type, self.image.image_name + self.image.image_origin, self.image.image_name ) image_crop = Image.new( @@ -139,16 +139,17 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig): image_dto = context.services.images.create( image=image_crop, - image_type=ImageType.RESULT, + image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, + is_intermediate=self.is_intermediate, ) return ImageOutput( image=ImageField( image_name=image_dto.image_name, - image_type=image_dto.image_type, + image_origin=image_dto.image_origin, ), width=image_dto.width, height=image_dto.height, @@ -171,17 +172,17 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig): def invoke(self, context: InvocationContext) -> ImageOutput: base_image = 
context.services.images.get_pil_image( - self.base_image.image_type, self.base_image.image_name + self.base_image.image_origin, self.base_image.image_name ) image = context.services.images.get_pil_image( - self.image.image_type, self.image.image_name + self.image.image_origin, self.image.image_name ) mask = ( None if self.mask is None else ImageOps.invert( context.services.images.get_pil_image( - self.mask.image_type, self.mask.image_name + self.mask.image_origin, self.mask.image_name ) ) ) @@ -200,16 +201,17 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig): image_dto = context.services.images.create( image=new_image, - image_type=ImageType.RESULT, + image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, + is_intermediate=self.is_intermediate, ) return ImageOutput( image=ImageField( image_name=image_dto.image_name, - image_type=image_dto.image_type, + image_origin=image_dto.image_origin, ), width=image_dto.width, height=image_dto.height, @@ -229,7 +231,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig): def invoke(self, context: InvocationContext) -> MaskOutput: image = context.services.images.get_pil_image( - self.image.image_type, self.image.image_name + self.image.image_origin, self.image.image_name ) image_mask = image.split()[-1] @@ -238,15 +240,16 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig): image_dto = context.services.images.create( image=image_mask, - image_type=ImageType.RESULT, + image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.MASK, node_id=self.id, session_id=context.graph_execution_state_id, + is_intermediate=self.is_intermediate, ) return MaskOutput( mask=ImageField( - image_type=image_dto.image_type, image_name=image_dto.image_name + image_origin=image_dto.image_origin, image_name=image_dto.image_name ), width=image_dto.width, height=image_dto.height, @@ -266,25 +269,26 @@ class 
ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig): def invoke(self, context: InvocationContext) -> ImageOutput: image1 = context.services.images.get_pil_image( - self.image1.image_type, self.image1.image_name + self.image1.image_origin, self.image1.image_name ) image2 = context.services.images.get_pil_image( - self.image2.image_type, self.image2.image_name + self.image2.image_origin, self.image2.image_name ) multiply_image = ImageChops.multiply(image1, image2) image_dto = context.services.images.create( image=multiply_image, - image_type=ImageType.RESULT, + image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, + is_intermediate=self.is_intermediate, ) return ImageOutput( image=ImageField( - image_type=image_dto.image_type, image_name=image_dto.image_name + image_origin=image_dto.image_origin, image_name=image_dto.image_name ), width=image_dto.width, height=image_dto.height, @@ -307,22 +311,23 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig): def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image( - self.image.image_type, self.image.image_name + self.image.image_origin, self.image.image_name ) channel_image = image.getchannel(self.channel) image_dto = context.services.images.create( image=channel_image, - image_type=ImageType.RESULT, + image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, + is_intermediate=self.is_intermediate, ) return ImageOutput( image=ImageField( - image_type=image_dto.image_type, image_name=image_dto.image_name + image_origin=image_dto.image_origin, image_name=image_dto.image_name ), width=image_dto.width, height=image_dto.height, @@ -345,22 +350,23 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig): def invoke(self, context: InvocationContext) -> ImageOutput: image = 
context.services.images.get_pil_image( - self.image.image_type, self.image.image_name + self.image.image_origin, self.image.image_name ) converted_image = image.convert(self.mode) image_dto = context.services.images.create( image=converted_image, - image_type=ImageType.RESULT, + image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, + is_intermediate=self.is_intermediate, ) return ImageOutput( image=ImageField( - image_type=image_dto.image_type, image_name=image_dto.image_name + image_origin=image_dto.image_origin, image_name=image_dto.image_name ), width=image_dto.width, height=image_dto.height, @@ -381,7 +387,7 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig): def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image( - self.image.image_type, self.image.image_name + self.image.image_origin, self.image.image_name ) blur = ( @@ -393,16 +399,126 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig): image_dto = context.services.images.create( image=blur_image, - image_type=ImageType.RESULT, + image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, + is_intermediate=self.is_intermediate, ) return ImageOutput( image=ImageField( image_name=image_dto.image_name, - image_type=image_dto.image_type, + image_origin=image_dto.image_origin, + ), + width=image_dto.width, + height=image_dto.height, + ) + + +PIL_RESAMPLING_MODES = Literal[ + "nearest", + "box", + "bilinear", + "hamming", + "bicubic", + "lanczos", +] + + +PIL_RESAMPLING_MAP = { + "nearest": Image.Resampling.NEAREST, + "box": Image.Resampling.BOX, + "bilinear": Image.Resampling.BILINEAR, + "hamming": Image.Resampling.HAMMING, + "bicubic": Image.Resampling.BICUBIC, + "lanczos": Image.Resampling.LANCZOS, +} + + +class ImageResizeInvocation(BaseInvocation, PILInvocationConfig): 
+ """Resizes an image to specific dimensions""" + + # fmt: off + type: Literal["img_resize"] = "img_resize" + + # Inputs + image: Union[ImageField, None] = Field(default=None, description="The image to resize") + width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)") + height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)") + resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode") + # fmt: on + + def invoke(self, context: InvocationContext) -> ImageOutput: + image = context.services.images.get_pil_image( + self.image.image_origin, self.image.image_name + ) + + resample_mode = PIL_RESAMPLING_MAP[self.resample_mode] + + resize_image = image.resize( + (self.width, self.height), + resample=resample_mode, + ) + + image_dto = context.services.images.create( + image=resize_image, + image_origin=ResourceOrigin.INTERNAL, + image_category=ImageCategory.GENERAL, + node_id=self.id, + session_id=context.graph_execution_state_id, + is_intermediate=self.is_intermediate, + ) + + return ImageOutput( + image=ImageField( + image_name=image_dto.image_name, + image_origin=image_dto.image_origin, + ), + width=image_dto.width, + height=image_dto.height, + ) + + +class ImageScaleInvocation(BaseInvocation, PILInvocationConfig): + """Scales an image by a factor""" + + # fmt: off + type: Literal["img_scale"] = "img_scale" + + # Inputs + image: Union[ImageField, None] = Field(default=None, description="The image to scale") + scale_factor: float = Field(gt=0, description="The factor by which to scale the image") + resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode") + # fmt: on + + def invoke(self, context: InvocationContext) -> ImageOutput: + image = context.services.images.get_pil_image( + self.image.image_origin, self.image.image_name + ) + + resample_mode = PIL_RESAMPLING_MAP[self.resample_mode] + width = int(image.width * self.scale_factor) + 
height = int(image.height * self.scale_factor) + + resize_image = image.resize( + (width, height), + resample=resample_mode, + ) + + image_dto = context.services.images.create( + image=resize_image, + image_origin=ResourceOrigin.INTERNAL, + image_category=ImageCategory.GENERAL, + node_id=self.id, + session_id=context.graph_execution_state_id, + is_intermediate=self.is_intermediate, + ) + + return ImageOutput( + image=ImageField( + image_name=image_dto.image_name, + image_origin=image_dto.image_origin, ), width=image_dto.width, height=image_dto.height, @@ -423,7 +539,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig): def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image( - self.image.image_type, self.image.image_name + self.image.image_origin, self.image.image_name ) image_arr = numpy.asarray(image, dtype=numpy.float32) / 255 @@ -433,16 +549,17 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig): image_dto = context.services.images.create( image=lerp_image, - image_type=ImageType.RESULT, + image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, + is_intermediate=self.is_intermediate, ) return ImageOutput( image=ImageField( image_name=image_dto.image_name, - image_type=image_dto.image_type, + image_origin=image_dto.image_origin, ), width=image_dto.width, height=image_dto.height, @@ -463,7 +580,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig): def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image( - self.image.image_type, self.image.image_name + self.image.image_origin, self.image.image_name ) image_arr = numpy.asarray(image, dtype=numpy.float32) @@ -478,16 +595,17 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig): image_dto = context.services.images.create( image=ilerp_image, - 
image_type=ImageType.RESULT, + image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, + is_intermediate=self.is_intermediate, ) return ImageOutput( image=ImageField( image_name=image_dto.image_name, - image_type=image_dto.image_type, + image_origin=image_dto.image_origin, ), width=image_dto.width, height=image_dto.height, diff --git a/invokeai/app/invocations/infill.py b/invokeai/app/invocations/infill.py index 17a43dbdac..a06780c1f5 100644 --- a/invokeai/app/invocations/infill.py +++ b/invokeai/app/invocations/infill.py @@ -11,7 +11,7 @@ from invokeai.app.invocations.image import ImageOutput from invokeai.app.util.misc import SEED_MAX, get_random_seed from invokeai.backend.image_util.patchmatch import PatchMatch -from ..models.image import ColorField, ImageCategory, ImageField, ImageType +from ..models.image import ColorField, ImageCategory, ImageField, ResourceOrigin from .baseinvocation import ( BaseInvocation, InvocationContext, @@ -135,7 +135,7 @@ class InfillColorInvocation(BaseInvocation): def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image( - self.image.image_type, self.image.image_name + self.image.image_origin, self.image.image_name ) solid_bg = Image.new("RGBA", image.size, self.color.tuple()) @@ -145,16 +145,17 @@ class InfillColorInvocation(BaseInvocation): image_dto = context.services.images.create( image=infilled, - image_type=ImageType.RESULT, + image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, + is_intermediate=self.is_intermediate, ) return ImageOutput( image=ImageField( image_name=image_dto.image_name, - image_type=image_dto.image_type, + image_origin=image_dto.image_origin, ), width=image_dto.width, height=image_dto.height, @@ -179,7 +180,7 @@ class InfillTileInvocation(BaseInvocation): def invoke(self, context: 
InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image( - self.image.image_type, self.image.image_name + self.image.image_origin, self.image.image_name ) infilled = tile_fill_missing( @@ -189,16 +190,17 @@ class InfillTileInvocation(BaseInvocation): image_dto = context.services.images.create( image=infilled, - image_type=ImageType.RESULT, + image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, + is_intermediate=self.is_intermediate, ) return ImageOutput( image=ImageField( image_name=image_dto.image_name, - image_type=image_dto.image_type, + image_origin=image_dto.image_origin, ), width=image_dto.width, height=image_dto.height, @@ -216,7 +218,7 @@ class InfillPatchMatchInvocation(BaseInvocation): def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image( - self.image.image_type, self.image.image_name + self.image.image_origin, self.image.image_name ) if PatchMatch.patchmatch_available(): @@ -226,16 +228,17 @@ class InfillPatchMatchInvocation(BaseInvocation): image_dto = context.services.images.create( image=infilled, - image_type=ImageType.RESULT, + image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, + is_intermediate=self.is_intermediate, ) return ImageOutput( image=ImageField( image_name=image_dto.image_name, - image_type=image_dto.image_type, + image_origin=image_dto.image_origin, ), width=image_dto.width, height=image_dto.height, diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index abfe92f828..6dc491611e 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -1,37 +1,36 @@ # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) -from typing import Literal, Optional, Union +from contextlib import ExitStack +from typing import List, 
Literal, Optional, Union import einops import torch +from diffusers import ControlNetModel from diffusers.image_processor import VaeImageProcessor from diffusers.schedulers import SchedulerMixin as Scheduler from pydantic import BaseModel, Field, validator -from contextlib import ExitStack from invokeai.app.util.misc import SEED_MAX, get_random_seed from invokeai.app.util.step_callback import stable_diffusion_step_callback +from ..models.image import ImageCategory, ImageField, ResourceOrigin from ...backend.image_util.seamless import configure_model_padding from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.stable_diffusion.diffusers_pipeline import ( - ConditioningData, StableDiffusionGeneratorPipeline, + ConditioningData, ControlNetData, StableDiffusionGeneratorPipeline, image_resized_to_grid_as_tensor) from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import \ PostprocessingSettings from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP from ...backend.util.devices import choose_torch_device, torch_dtype -from ..services.image_file_storage import ImageType -from ..services.model_manager_service import ModelManagerService +from ...backend.model_management.lora import ModelPatcher from .baseinvocation import (BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext) from .compel import ConditioningField -from .image import ImageCategory, ImageField, ImageOutput +from .controlnet_image_processors import ControlField +from .image import ImageOutput from .model import ModelInfo, UNetField, VaeField -from ...backend.model_management.lora import LoRAHelper - - class LatentsField(BaseModel): """A latents field used for passing latents between invocations""" @@ -93,10 +92,12 @@ def get_scheduler( orig_scheduler_info = context.services.model_manager.get_model(**scheduler_info.dict()) with orig_scheduler_info as orig_scheduler: scheduler_config = orig_scheduler.config + if "_backup" in 
scheduler_config: scheduler_config = scheduler_config["_backup"] scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config} scheduler = scheduler_class.from_config(scheduler_config) + # hack copied over from generate.py if not hasattr(scheduler, 'uses_inpainting_model'): scheduler.uses_inpainting_model = lambda: False @@ -171,12 +172,13 @@ class TextToLatentsInvocation(BaseInvocation): negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation") noise: Optional[LatentsField] = Field(description="The noise to use") steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image") - cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", ) + cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", ) scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" ) seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", ) seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'") unet: UNetField = Field(default=None, description="UNet submodel") + control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use") # fmt: on # Schema customisation @@ -184,6 +186,10 @@ class TextToLatentsInvocation(BaseInvocation): schema_extra = { "ui": { "tags": ["latents", "image"], + "type_hints": { + "model": "model", + "control": "control", + } }, } @@ -243,6 +249,82 @@ class TextToLatentsInvocation(BaseInvocation): precision="float16" if unet.dtype == torch.float16 else "float32", #precision="float16", # TODO: ) + + def prep_control_data(self, + context: InvocationContext, + model: StableDiffusionGeneratorPipeline, # 
really only need model for dtype and device + control_input: List[ControlField], + latents_shape: List[int], + do_classifier_free_guidance: bool = True, + ) -> List[ControlNetData]: + # assuming fixed dimensional scaling of 8:1 for image:latents + control_height_resize = latents_shape[2] * 8 + control_width_resize = latents_shape[3] * 8 + if control_input is None: + # print("control input is None") + control_list = None + elif isinstance(control_input, list) and len(control_input) == 0: + # print("control input is empty list") + control_list = None + elif isinstance(control_input, ControlField): + # print("control input is ControlField") + control_list = [control_input] + elif isinstance(control_input, list) and len(control_input) > 0 and isinstance(control_input[0], ControlField): + # print("control input is list[ControlField]") + control_list = control_input + else: + # print("input control is unrecognized:", type(self.control)) + control_list = None + if (control_list is None): + control_data = None + # from above handling, any control that is not None should now be of type list[ControlField] + else: + # FIXME: add checks to skip entry if model or image is None + # and if weight is None, populate with default 1.0? 
+ control_data = [] + control_models = [] + for control_info in control_list: + # handle control models + if ("," in control_info.control_model): + control_model_split = control_info.control_model.split(",") + control_name = control_model_split[0] + control_subfolder = control_model_split[1] + print("Using HF model subfolders") + print(" control_name: ", control_name) + print(" control_subfolder: ", control_subfolder) + control_model = ControlNetModel.from_pretrained(control_name, + subfolder=control_subfolder, + torch_dtype=model.unet.dtype).to(model.device) + else: + control_model = ControlNetModel.from_pretrained(control_info.control_model, + torch_dtype=model.unet.dtype).to(model.device) + control_models.append(control_model) + control_image_field = control_info.image + input_image = context.services.images.get_pil_image(control_image_field.image_origin, + control_image_field.image_name) + # self.image.image_type, self.image.image_name + # FIXME: still need to test with different widths, heights, devices, dtypes + # and add in batch_size, num_images_per_prompt? + # and do real check for classifier_free_guidance? 
+ # prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width) + control_image = model.prepare_control_image( + image=input_image, + do_classifier_free_guidance=do_classifier_free_guidance, + width=control_width_resize, + height=control_height_resize, + # batch_size=batch_size * num_images_per_prompt, + # num_images_per_prompt=num_images_per_prompt, + device=control_model.device, + dtype=control_model.dtype, + ) + control_item = ControlNetData(model=control_model, + image_tensor=control_image, + weight=control_info.control_weight, + begin_step_percent=control_info.begin_step_percent, + end_step_percent=control_info.end_step_percent) + control_data.append(control_item) + # MultiControlNetModel has been refactored out, just need list[ControlNetData] + return control_data def invoke(self, context: InvocationContext) -> LatentsOutput: noise = context.services.latents.get(self.noise.latents_name) @@ -269,13 +351,19 @@ class TextToLatentsInvocation(BaseInvocation): loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras] - with LoRAHelper.apply_lora_unet(pipeline.unet, loras): + print("type of control input: ", type(self.control)) + control_data = self.prep_control_data(model=pipeline, context=context, control_input=self.control, + latents_shape=noise.shape, + do_classifier_free_guidance=(self.cfg_scale >= 1.0)) + + with ModelPatcher.apply_lora_unet(pipeline.unet, loras): # TODO: Verify the noise is the right size result_latents, result_attention_map_saver = pipeline.latents_from_embeddings( latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)), noise=noise, num_inference_steps=self.steps, conditioning_data=conditioning_data, + control_data=control_data, # list[ControlNetData] callback=step_callback ) @@ -286,7 +374,6 @@ class TextToLatentsInvocation(BaseInvocation): context.services.latents.save(name, result_latents) return 
build_latents_output(latents_name=name, latents=result_latents) - class LatentsToLatentsInvocation(TextToLatentsInvocation): """Generates latents using latents as base image.""" @@ -294,13 +381,17 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): # Inputs latents: Optional[LatentsField] = Field(description="The latents to use as a base image") - strength: float = Field(default=0.5, description="The strength of the latents to use") + strength: float = Field(default=0.7, ge=0, le=1, description="The strength of the latents to use") # Schema customisation class Config(InvocationConfig): schema_extra = { "ui": { "tags": ["latents"], + "type_hints": { + "model": "model", + "control": "control", + } }, } @@ -315,7 +406,6 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): def step_callback(state: PipelineIntermediateState): self.dispatch_progress(context, source_node_id, state) - #unet_info = context.services.model_manager.get_model(**self.unet.unet.dict()) unet_info = context.services.model_manager.get_model( **self.unet.unet.dict(), ) @@ -345,7 +435,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras] - with LoRAHelper.apply_lora_unet(pipeline.unet, loras): + with ModelPatcher.apply_lora_unet(pipeline.unet, loras): result_latents, result_attention_map_saver = pipeline.latents_from_embeddings( latents=initial_latents, timesteps=timesteps, @@ -413,7 +503,7 @@ class LatentsToImageInvocation(BaseInvocation): image_dto = context.services.images.create( image=image, - image_type=ImageType.RESULT, + image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, @@ -459,6 +549,7 @@ class ResizeLatentsInvocation(BaseInvocation): torch.cuda.empty_cache() name = f"{context.graph_execution_state_id}__{self.id}" + # 
context.services.latents.set(name, resized_latents) context.services.latents.save(name, resized_latents) return build_latents_output(latents_name=name, latents=resized_latents) @@ -489,6 +580,7 @@ class ScaleLatentsInvocation(BaseInvocation): torch.cuda.empty_cache() name = f"{context.graph_execution_state_id}__{self.id}" + # context.services.latents.set(name, resized_latents) context.services.latents.save(name, resized_latents) return build_latents_output(latents_name=name, latents=resized_latents) @@ -513,8 +605,11 @@ class ImageToLatentsInvocation(BaseInvocation): @torch.no_grad() def invoke(self, context: InvocationContext) -> LatentsOutput: + # image = context.services.images.get( + # self.image.image_type, self.image.image_name + # ) image = context.services.images.get_pil_image( - self.image.image_type, self.image.image_name + self.image.image_origin, self.image.image_name ) #vae_info = context.services.model_manager.get_model(**self.vae.vae.dict()) @@ -543,6 +638,6 @@ class ImageToLatentsInvocation(BaseInvocation): latents = 0.18215 * latents name = f"{context.graph_execution_state_id}__{self.id}" + # context.services.latents.set(name, latents) context.services.latents.save(name, latents) return build_latents_output(latents_name=name, latents=latents) - diff --git a/invokeai/app/invocations/math.py b/invokeai/app/invocations/math.py index 2ce58c016b..113b630200 100644 --- a/invokeai/app/invocations/math.py +++ b/invokeai/app/invocations/math.py @@ -34,6 +34,15 @@ class IntOutput(BaseInvocationOutput): # fmt: on +class FloatOutput(BaseInvocationOutput): + """A float output""" + + # fmt: off + type: Literal["float_output"] = "float_output" + param: float = Field(default=None, description="The output float") + # fmt: on + + class AddInvocation(BaseInvocation, MathInvocationConfig): """Adds two numbers""" diff --git a/invokeai/app/invocations/params.py b/invokeai/app/invocations/params.py index fcc7f1737a..1c6297665b 100644 --- 
a/invokeai/app/invocations/params.py +++ b/invokeai/app/invocations/params.py @@ -3,7 +3,7 @@ from typing import Literal from pydantic import Field from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext -from .math import IntOutput +from .math import IntOutput, FloatOutput # Pass-through parameter nodes - used by subgraphs @@ -16,3 +16,13 @@ class ParamIntInvocation(BaseInvocation): def invoke(self, context: InvocationContext) -> IntOutput: return IntOutput(a=self.a) + +class ParamFloatInvocation(BaseInvocation): + """A float parameter""" + #fmt: off + type: Literal["param_float"] = "param_float" + param: float = Field(default=0.0, description="The float value") + #fmt: on + + def invoke(self, context: InvocationContext) -> FloatOutput: + return FloatOutput(param=self.param) diff --git a/invokeai/app/invocations/reconstruct.py b/invokeai/app/invocations/reconstruct.py index 024134cd46..5313411400 100644 --- a/invokeai/app/invocations/reconstruct.py +++ b/invokeai/app/invocations/reconstruct.py @@ -2,7 +2,7 @@ from typing import Literal, Union from pydantic import Field -from invokeai.app.models.image import ImageCategory, ImageField, ImageType +from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig from .image import ImageOutput @@ -29,7 +29,7 @@ class RestoreFaceInvocation(BaseInvocation): def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image( - self.image.image_type, self.image.image_name + self.image.image_origin, self.image.image_name ) results = context.services.restoration.upscale_and_reconstruct( image_list=[[image, 0]], @@ -43,16 +43,17 @@ class RestoreFaceInvocation(BaseInvocation): # TODO: can this return multiple results? 
image_dto = context.services.images.create( image=results[0][0], - image_type=ImageType.INTERMEDIATE, + image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, + is_intermediate=self.is_intermediate, ) return ImageOutput( image=ImageField( image_name=image_dto.image_name, - image_type=image_dto.image_type, + image_origin=image_dto.image_origin, ), width=image_dto.width, height=image_dto.height, diff --git a/invokeai/app/invocations/upscale.py b/invokeai/app/invocations/upscale.py index 75aeec784f..80e1567047 100644 --- a/invokeai/app/invocations/upscale.py +++ b/invokeai/app/invocations/upscale.py @@ -4,7 +4,7 @@ from typing import Literal, Union from pydantic import Field -from invokeai.app.models.image import ImageCategory, ImageField, ImageType +from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig from .image import ImageOutput @@ -31,7 +31,7 @@ class UpscaleInvocation(BaseInvocation): def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image( - self.image.image_type, self.image.image_name + self.image.image_origin, self.image.image_name ) results = context.services.restoration.upscale_and_reconstruct( image_list=[[image, 0]], @@ -45,16 +45,17 @@ class UpscaleInvocation(BaseInvocation): # TODO: can this return multiple results? 
image_dto = context.services.images.create( image=results[0][0], - image_type=ImageType.RESULT, + image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, + is_intermediate=self.is_intermediate, ) return ImageOutput( image=ImageField( image_name=image_dto.image_name, - image_type=image_dto.image_type, + image_origin=image_dto.image_origin, ), width=image_dto.width, height=image_dto.height, diff --git a/invokeai/app/models/image.py b/invokeai/app/models/image.py index 544951ea34..6d48f2dbb1 100644 --- a/invokeai/app/models/image.py +++ b/invokeai/app/models/image.py @@ -5,31 +5,52 @@ from pydantic import BaseModel, Field from invokeai.app.util.metaenum import MetaEnum -class ImageType(str, Enum, metaclass=MetaEnum): - """The type of an image.""" +class ResourceOrigin(str, Enum, metaclass=MetaEnum): + """The origin of a resource (eg image). - RESULT = "results" - UPLOAD = "uploads" - INTERMEDIATE = "intermediates" + - INTERNAL: The resource was created by the application. + - EXTERNAL: The resource was not created by the application. + This may be a user-initiated upload, or an internal application upload (eg Canvas init image). + """ + + INTERNAL = "internal" + """The resource was created by the application.""" + EXTERNAL = "external" + """The resource was not created by the application. + This may be a user-initiated upload, or an internal application upload (eg Canvas init image). + """ -class InvalidImageTypeException(ValueError): - """Raised when a provided value is not a valid ImageType. +class InvalidOriginException(ValueError): + """Raised when a provided value is not a valid ResourceOrigin. Subclasses `ValueError`. """ - def __init__(self, message="Invalid image type."): + def __init__(self, message="Invalid resource origin."): super().__init__(message) class ImageCategory(str, Enum, metaclass=MetaEnum): - """The category of an image. 
Use ImageCategory.OTHER for non-default categories.""" + """The category of an image. + + - GENERAL: The image is an output, init image, or otherwise an image without a specialized purpose. + - MASK: The image is a mask image. + - CONTROL: The image is a ControlNet control image. + - USER: The image is a user-provide image. + - OTHER: The image is some other type of image with a specialized purpose. To be used by external nodes. + """ GENERAL = "general" - CONTROL = "control" + """GENERAL: The image is an output, init image, or otherwise an image without a specialized purpose.""" MASK = "mask" + """MASK: The image is a mask image.""" + CONTROL = "control" + """CONTROL: The image is a ControlNet control image.""" + USER = "user" + """USER: The image is a user-provide image.""" OTHER = "other" + """OTHER: The image is some other type of image with a specialized purpose. To be used by external nodes.""" class InvalidImageCategoryException(ValueError): @@ -45,13 +66,13 @@ class InvalidImageCategoryException(ValueError): class ImageField(BaseModel): """An image field used for passing image objects between invocations""" - image_type: ImageType = Field( - default=ImageType.RESULT, description="The type of the image" + image_origin: ResourceOrigin = Field( + default=ResourceOrigin.INTERNAL, description="The type of the image" ) image_name: Optional[str] = Field(default=None, description="The name of the image") class Config: - schema_extra = {"required": ["image_type", "image_name"]} + schema_extra = {"required": ["image_origin", "image_name"]} class ColorField(BaseModel): @@ -62,3 +83,11 @@ class ColorField(BaseModel): def tuple(self) -> Tuple[int, int, int, int]: return (self.r, self.g, self.b, self.a) + + +class ProgressImage(BaseModel): + """The progress image sent intermittently during processing""" + + width: int = Field(description="The effective width of the image in pixels") + height: int = Field(description="The effective height of the image in pixels") + 
dataURL: str = Field(description="The image data as a b64 data URL") diff --git a/invokeai/app/services/events.py b/invokeai/app/services/events.py index 7e0b5bdb07..75fc2c74d3 100644 --- a/invokeai/app/services/events.py +++ b/invokeai/app/services/events.py @@ -1,7 +1,7 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) -from typing import Any, Optional -from invokeai.app.api.models.images import ProgressImage +from typing import Any +from invokeai.app.models.image import ProgressImage from invokeai.app.util.misc import get_timestamp from invokeai.app.services.model_manager_service import SDModelType, SDModelInfo from invokeai.app.models.exceptions import CanceledException diff --git a/invokeai/app/services/graph.py b/invokeai/app/services/graph.py index 44688ada0a..60e196faa1 100644 --- a/invokeai/app/services/graph.py +++ b/invokeai/app/services/graph.py @@ -60,6 +60,35 @@ def get_input_field(node: BaseInvocation, field: str) -> Any: node_input_field = node_inputs.get(field) or None return node_input_field +from typing import Optional, Union, List, get_args + +def is_union_subtype(t1, t2): + t1_args = get_args(t1) + t2_args = get_args(t2) + + if not t1_args: + # t1 is a single type + return t1 in t2_args + else: + # t1 is a Union, check that all of its types are in t2_args + return all(arg in t2_args for arg in t1_args) + +def is_list_or_contains_list(t): + t_args = get_args(t) + + # If the type is a List + if get_origin(t) is list: + return True + + # If the type is a Union + elif t_args: + # Check if any of the types in the Union is a List + for arg in t_args: + if get_origin(arg) is list: + return True + + return False + def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool: if not from_type: @@ -85,7 +114,8 @@ def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool: if to_type in get_args(from_type): return True - if not issubclass(from_type, to_type): + # if not issubclass(from_type, to_type): + if 
not is_union_subtype(from_type, to_type): return False else: return False @@ -694,7 +724,11 @@ class Graph(BaseModel): input_root_type = next(t[0] for t in type_degrees if t[1] == 0) # type: ignore # Verify that all outputs are lists - if not all((get_origin(f) == list for f in output_fields)): + # if not all((get_origin(f) == list for f in output_fields)): + # return False + + # Verify that all outputs are lists + if not all(is_list_or_contains_list(f) for f in output_fields): return False # Verify that all outputs match the input type (are a base class or the same class) diff --git a/invokeai/app/services/image_file_storage.py b/invokeai/app/services/image_file_storage.py index 46070b3bf2..68a994ea75 100644 --- a/invokeai/app/services/image_file_storage.py +++ b/invokeai/app/services/image_file_storage.py @@ -9,7 +9,7 @@ from PIL.Image import Image as PILImageType from PIL import Image, PngImagePlugin from send2trash import send2trash -from invokeai.app.models.image import ImageType +from invokeai.app.models.image import ResourceOrigin from invokeai.app.models.metadata import ImageMetadata from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail @@ -40,13 +40,13 @@ class ImageFileStorageBase(ABC): """Low-level service responsible for storing and retrieving image files.""" @abstractmethod - def get(self, image_type: ImageType, image_name: str) -> PILImageType: + def get(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType: """Retrieves an image as PIL Image.""" pass @abstractmethod def get_path( - self, image_type: ImageType, image_name: str, thumbnail: bool = False + self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False ) -> str: """Gets the internal path to an image or thumbnail.""" pass @@ -62,7 +62,7 @@ class ImageFileStorageBase(ABC): def save( self, image: PILImageType, - image_type: ImageType, + image_origin: ResourceOrigin, image_name: str, metadata: Optional[ImageMetadata] = None, thumbnail_size: 
int = 256, @@ -71,7 +71,7 @@ class ImageFileStorageBase(ABC): pass @abstractmethod - def delete(self, image_type: ImageType, image_name: str) -> None: + def delete(self, image_origin: ResourceOrigin, image_name: str) -> None: """Deletes an image and its thumbnail (if one exists).""" pass @@ -93,17 +93,17 @@ class DiskImageFileStorage(ImageFileStorageBase): Path(output_folder).mkdir(parents=True, exist_ok=True) # TODO: don't hard-code. get/save/delete should maybe take subpath? - for image_type in ImageType: - Path(os.path.join(output_folder, image_type)).mkdir( + for image_origin in ResourceOrigin: + Path(os.path.join(output_folder, image_origin)).mkdir( parents=True, exist_ok=True ) - Path(os.path.join(output_folder, image_type, "thumbnails")).mkdir( + Path(os.path.join(output_folder, image_origin, "thumbnails")).mkdir( parents=True, exist_ok=True ) - def get(self, image_type: ImageType, image_name: str) -> PILImageType: + def get(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType: try: - image_path = self.get_path(image_type, image_name) + image_path = self.get_path(image_origin, image_name) cache_item = self.__get_cache(image_path) if cache_item: return cache_item @@ -117,13 +117,13 @@ class DiskImageFileStorage(ImageFileStorageBase): def save( self, image: PILImageType, - image_type: ImageType, + image_origin: ResourceOrigin, image_name: str, metadata: Optional[ImageMetadata] = None, thumbnail_size: int = 256, ) -> None: try: - image_path = self.get_path(image_type, image_name) + image_path = self.get_path(image_origin, image_name) if metadata is not None: pnginfo = PngImagePlugin.PngInfo() @@ -133,7 +133,7 @@ class DiskImageFileStorage(ImageFileStorageBase): image.save(image_path, "PNG") thumbnail_name = get_thumbnail_name(image_name) - thumbnail_path = self.get_path(image_type, thumbnail_name, thumbnail=True) + thumbnail_path = self.get_path(image_origin, thumbnail_name, thumbnail=True) thumbnail_image = make_thumbnail(image, thumbnail_size) 
thumbnail_image.save(thumbnail_path) @@ -142,10 +142,10 @@ class DiskImageFileStorage(ImageFileStorageBase): except Exception as e: raise ImageFileSaveException from e - def delete(self, image_type: ImageType, image_name: str) -> None: + def delete(self, image_origin: ResourceOrigin, image_name: str) -> None: try: basename = os.path.basename(image_name) - image_path = self.get_path(image_type, basename) + image_path = self.get_path(image_origin, basename) if os.path.exists(image_path): send2trash(image_path) @@ -153,7 +153,7 @@ class DiskImageFileStorage(ImageFileStorageBase): del self.__cache[image_path] thumbnail_name = get_thumbnail_name(image_name) - thumbnail_path = self.get_path(image_type, thumbnail_name, True) + thumbnail_path = self.get_path(image_origin, thumbnail_name, True) if os.path.exists(thumbnail_path): send2trash(thumbnail_path) @@ -164,7 +164,7 @@ class DiskImageFileStorage(ImageFileStorageBase): # TODO: make this a bit more flexible for e.g. cloud storage def get_path( - self, image_type: ImageType, image_name: str, thumbnail: bool = False + self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False ) -> str: # strip out any relative path shenanigans basename = os.path.basename(image_name) @@ -172,10 +172,10 @@ class DiskImageFileStorage(ImageFileStorageBase): if thumbnail: thumbnail_name = get_thumbnail_name(basename) path = os.path.join( - self.__output_folder, image_type, "thumbnails", thumbnail_name + self.__output_folder, image_origin, "thumbnails", thumbnail_name ) else: - path = os.path.join(self.__output_folder, image_type, basename) + path = os.path.join(self.__output_folder, image_origin, basename) abspath = os.path.abspath(path) diff --git a/invokeai/app/services/image_record_storage.py b/invokeai/app/services/image_record_storage.py index 4e1f73978b..6907ac3952 100644 --- a/invokeai/app/services/image_record_storage.py +++ b/invokeai/app/services/image_record_storage.py @@ -1,20 +1,35 @@ from abc import ABC, 
abstractmethod from datetime import datetime -from typing import Optional, cast +from typing import Generic, Optional, TypeVar, cast import sqlite3 import threading from typing import Optional, Union +from pydantic import BaseModel, Field +from pydantic.generics import GenericModel + from invokeai.app.models.metadata import ImageMetadata from invokeai.app.models.image import ( ImageCategory, - ImageType, + ResourceOrigin, ) from invokeai.app.services.models.image_record import ( ImageRecord, + ImageRecordChanges, deserialize_image_record, ) -from invokeai.app.services.item_storage import PaginatedResults + +T = TypeVar("T", bound=BaseModel) + +class OffsetPaginatedResults(GenericModel, Generic[T]): + """Offset-paginated results""" + + # fmt: off + items: list[T] = Field(description="Items") + offset: int = Field(description="Offset from which to retrieve items") + limit: int = Field(description="Limit of items to get") + total: int = Field(description="Total number of items in result") + # fmt: on # TODO: Should these excpetions subclass existing python exceptions? 
@@ -45,25 +60,36 @@ class ImageRecordStorageBase(ABC): # TODO: Implement an `update()` method @abstractmethod - def get(self, image_type: ImageType, image_name: str) -> ImageRecord: + def get(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord: """Gets an image record.""" pass + @abstractmethod + def update( + self, + image_name: str, + image_origin: ResourceOrigin, + changes: ImageRecordChanges, + ) -> None: + """Updates an image record.""" + pass + @abstractmethod def get_many( self, - image_type: ImageType, - image_category: ImageCategory, - page: int = 0, - per_page: int = 10, - ) -> PaginatedResults[ImageRecord]: + offset: int = 0, + limit: int = 10, + image_origin: Optional[ResourceOrigin] = None, + categories: Optional[list[ImageCategory]] = None, + is_intermediate: Optional[bool] = None, + ) -> OffsetPaginatedResults[ImageRecord]: """Gets a page of image records.""" pass # TODO: The database has a nullable `deleted_at` column, currently unused. # Should we implement soft deletes? Would need coordination with ImageFileStorage. @abstractmethod - def delete(self, image_type: ImageType, image_name: str) -> None: + def delete(self, image_origin: ResourceOrigin, image_name: str) -> None: """Deletes an image record.""" pass @@ -71,13 +97,14 @@ class ImageRecordStorageBase(ABC): def save( self, image_name: str, - image_type: ImageType, + image_origin: ResourceOrigin, image_category: ImageCategory, width: int, height: int, session_id: Optional[str], node_id: Optional[str], metadata: Optional[ImageMetadata], + is_intermediate: bool = False, ) -> datetime: """Saves an image record.""" pass @@ -91,7 +118,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): def __init__(self, filename: str) -> None: super().__init__() - self._filename = filename self._conn = sqlite3.connect(filename, check_same_thread=False) # Enable row factory to get rows as dictionaries (must be done before making the cursor!) 
@@ -117,7 +143,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): CREATE TABLE IF NOT EXISTS images ( image_name TEXT NOT NULL PRIMARY KEY, -- This is an enum in python, unrestricted string here for flexibility - image_type TEXT NOT NULL, + image_origin TEXT NOT NULL, -- This is an enum in python, unrestricted string here for flexibility image_category TEXT NOT NULL, width INTEGER NOT NULL, @@ -125,9 +151,10 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): session_id TEXT, node_id TEXT, metadata TEXT, - created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + is_intermediate BOOLEAN DEFAULT FALSE, + created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), -- Updated via trigger - updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), -- Soft delete, currently unused deleted_at DATETIME ); @@ -142,7 +169,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): ) self._cursor.execute( """--sql - CREATE INDEX IF NOT EXISTS idx_images_image_type ON images(image_type); + CREATE INDEX IF NOT EXISTS idx_images_image_origin ON images(image_origin); """ ) self._cursor.execute( @@ -169,7 +196,9 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): """ ) - def get(self, image_type: ImageType, image_name: str) -> Union[ImageRecord, None]: + def get( + self, image_origin: ResourceOrigin, image_name: str + ) -> Union[ImageRecord, None]: try: self._lock.acquire() @@ -193,38 +222,110 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): return deserialize_image_record(dict(result)) + def update( + self, + image_name: str, + image_origin: ResourceOrigin, + changes: ImageRecordChanges, + ) -> None: + try: + self._lock.acquire() + # Change the category of the image + if changes.image_category is not None: + self._cursor.execute( + f"""--sql + UPDATE images + SET image_category = ? 
+ WHERE image_name = ?; + """, + (changes.image_category, image_name), + ) + + # Change the session associated with the image + if changes.session_id is not None: + self._cursor.execute( + f"""--sql + UPDATE images + SET session_id = ? + WHERE image_name = ?; + """, + (changes.session_id, image_name), + ) + + # Change the image's `is_intermediate`` flag + if changes.is_intermediate is not None: + self._cursor.execute( + f"""--sql + UPDATE images + SET is_intermediate = ? + WHERE image_name = ?; + """, + (changes.is_intermediate, image_name), + ) + self._conn.commit() + except sqlite3.Error as e: + self._conn.rollback() + raise ImageRecordSaveException from e + finally: + self._lock.release() + def get_many( self, - image_type: ImageType, - image_category: ImageCategory, - page: int = 0, - per_page: int = 10, - ) -> PaginatedResults[ImageRecord]: + offset: int = 0, + limit: int = 10, + image_origin: Optional[ResourceOrigin] = None, + categories: Optional[list[ImageCategory]] = None, + is_intermediate: Optional[bool] = None, + ) -> OffsetPaginatedResults[ImageRecord]: try: self._lock.acquire() - self._cursor.execute( - f"""--sql - SELECT * FROM images - WHERE image_type = ? AND image_category = ? - ORDER BY created_at DESC - LIMIT ? OFFSET ?; - """, - (image_type.value, image_category.value, per_page, page * per_page), - ) + # Manually build two queries - one for the count, one for the records + count_query = f"""SELECT COUNT(*) FROM images WHERE 1=1\n""" + images_query = f"""SELECT * FROM images WHERE 1=1\n""" + + query_conditions = "" + query_params = [] + + if image_origin is not None: + query_conditions += f"""AND image_origin = ?\n""" + query_params.append(image_origin.value) + + if categories is not None: + ## Convert the enum values to unique list of strings + category_strings = list( + map(lambda c: c.value, set(categories)) + ) + # Create the correct length of placeholders + placeholders = ",".join("?" 
* len(category_strings)) + query_conditions += f"AND image_category IN ( {placeholders} )\n" + + # Unpack the included categories into the query params + for c in category_strings: + query_params.append(c) + + if is_intermediate is not None: + query_conditions += f"""AND is_intermediate = ?\n""" + query_params.append(is_intermediate) + + query_pagination = f"""ORDER BY created_at DESC LIMIT ? OFFSET ?\n""" + + # Final images query with pagination + images_query += query_conditions + query_pagination + ";" + # Add all the parameters + images_params = query_params.copy() + images_params.append(limit) + images_params.append(offset) + # Build the list of images, deserializing each row + self._cursor.execute(images_query, images_params) result = cast(list[sqlite3.Row], self._cursor.fetchall()) - images = list(map(lambda r: deserialize_image_record(dict(r)), result)) - self._cursor.execute( - """--sql - SELECT count(*) FROM images - WHERE image_type = ? AND image_category = ? - """, - (image_type.value, image_category.value), - ) - + # Set up and execute the count query, without pagination + count_query += query_conditions + ";" + count_params = query_params.copy() + self._cursor.execute(count_query, count_params) count = self._cursor.fetchone()[0] except sqlite3.Error as e: self._conn.rollback() @@ -232,13 +333,11 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): finally: self._lock.release() - pageCount = int(count / per_page) + 1 - - return PaginatedResults( - items=images, page=page, pages=pageCount, per_page=per_page, total=count + return OffsetPaginatedResults( + items=images, offset=offset, limit=limit, total=count ) - def delete(self, image_type: ImageType, image_name: str) -> None: + def delete(self, image_origin: ResourceOrigin, image_name: str) -> None: try: self._lock.acquire() self._cursor.execute( @@ -258,13 +357,14 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): def save( self, image_name: str, - image_type: ImageType, + image_origin: 
ResourceOrigin, image_category: ImageCategory, session_id: Optional[str], width: int, height: int, node_id: Optional[str], metadata: Optional[ImageMetadata], + is_intermediate: bool = False, ) -> datetime: try: metadata_json = ( @@ -275,25 +375,27 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): """--sql INSERT OR IGNORE INTO images ( image_name, - image_type, + image_origin, image_category, width, height, node_id, session_id, - metadata + metadata, + is_intermediate ) - VALUES (?, ?, ?, ?, ?, ?, ?, ?); + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?); """, ( image_name, - image_type.value, + image_origin.value, image_category.value, width, height, node_id, session_id, metadata_json, + is_intermediate, ), ) self._conn.commit() diff --git a/invokeai/app/services/images.py b/invokeai/app/services/images.py index 914dd3b6d3..2618a9763e 100644 --- a/invokeai/app/services/images.py +++ b/invokeai/app/services/images.py @@ -1,14 +1,13 @@ from abc import ABC, abstractmethod from logging import Logger from typing import Optional, TYPE_CHECKING, Union -import uuid from PIL.Image import Image as PILImageType from invokeai.app.models.image import ( ImageCategory, - ImageType, + ResourceOrigin, InvalidImageCategoryException, - InvalidImageTypeException, + InvalidOriginException, ) from invokeai.app.models.metadata import ImageMetadata from invokeai.app.services.image_record_storage import ( @@ -16,10 +15,12 @@ from invokeai.app.services.image_record_storage import ( ImageRecordNotFoundException, ImageRecordSaveException, ImageRecordStorageBase, + OffsetPaginatedResults, ) from invokeai.app.services.models.image_record import ( ImageRecord, ImageDTO, + ImageRecordChanges, image_record_to_dto, ) from invokeai.app.services.image_file_storage import ( @@ -30,8 +31,8 @@ from invokeai.app.services.image_file_storage import ( ) from invokeai.app.services.item_storage import ItemStorageABC, PaginatedResults from invokeai.app.services.metadata import MetadataServiceBase +from 
invokeai.app.services.resource_name import NameServiceBase from invokeai.app.services.urls import UrlServiceBase -from invokeai.app.util.misc import get_iso_timestamp if TYPE_CHECKING: from invokeai.app.services.graph import GraphExecutionState @@ -44,32 +45,42 @@ class ImageServiceABC(ABC): def create( self, image: PILImageType, - image_type: ImageType, + image_origin: ResourceOrigin, image_category: ImageCategory, node_id: Optional[str] = None, session_id: Optional[str] = None, - metadata: Optional[ImageMetadata] = None, + intermediate: bool = False, ) -> ImageDTO: """Creates an image, storing the file and its metadata.""" pass @abstractmethod - def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType: + def update( + self, + image_origin: ResourceOrigin, + image_name: str, + changes: ImageRecordChanges, + ) -> ImageDTO: + """Updates an image.""" + pass + + @abstractmethod + def get_pil_image(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType: """Gets an image as a PIL image.""" pass @abstractmethod - def get_record(self, image_type: ImageType, image_name: str) -> ImageRecord: + def get_record(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord: """Gets an image record.""" pass @abstractmethod - def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO: + def get_dto(self, image_origin: ResourceOrigin, image_name: str) -> ImageDTO: """Gets an image DTO.""" pass @abstractmethod - def get_path(self, image_type: ImageType, image_name: str) -> str: + def get_path(self, image_origin: ResourceOrigin, image_name: str) -> str: """Gets an image's path.""" pass @@ -80,7 +91,7 @@ class ImageServiceABC(ABC): @abstractmethod def get_url( - self, image_type: ImageType, image_name: str, thumbnail: bool = False + self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False ) -> str: """Gets an image's or thumbnail's URL.""" pass @@ -88,16 +99,17 @@ class ImageServiceABC(ABC): @abstractmethod def 
get_many( self, - image_type: ImageType, - image_category: ImageCategory, - page: int = 0, - per_page: int = 10, - ) -> PaginatedResults[ImageDTO]: + offset: int = 0, + limit: int = 10, + image_origin: Optional[ResourceOrigin] = None, + categories: Optional[list[ImageCategory]] = None, + is_intermediate: Optional[bool] = None, + ) -> OffsetPaginatedResults[ImageDTO]: """Gets a paginated list of image DTOs.""" pass @abstractmethod - def delete(self, image_type: ImageType, image_name: str): + def delete(self, image_origin: ResourceOrigin, image_name: str): """Deletes an image.""" pass @@ -110,6 +122,7 @@ class ImageServiceDependencies: metadata: MetadataServiceBase urls: UrlServiceBase logger: Logger + names: NameServiceBase graph_execution_manager: ItemStorageABC["GraphExecutionState"] def __init__( @@ -119,6 +132,7 @@ class ImageServiceDependencies: metadata: MetadataServiceBase, url: UrlServiceBase, logger: Logger, + names: NameServiceBase, graph_execution_manager: ItemStorageABC["GraphExecutionState"], ): self.records = image_record_storage @@ -126,6 +140,7 @@ class ImageServiceDependencies: self.metadata = metadata self.urls = url self.logger = logger + self.names = names self.graph_execution_manager = graph_execution_manager @@ -139,6 +154,7 @@ class ImageService(ImageServiceABC): metadata: MetadataServiceBase, url: UrlServiceBase, logger: Logger, + names: NameServiceBase, graph_execution_manager: ItemStorageABC["GraphExecutionState"], ): self._services = ImageServiceDependencies( @@ -147,29 +163,26 @@ class ImageService(ImageServiceABC): metadata=metadata, url=url, logger=logger, + names=names, graph_execution_manager=graph_execution_manager, ) def create( self, image: PILImageType, - image_type: ImageType, + image_origin: ResourceOrigin, image_category: ImageCategory, node_id: Optional[str] = None, session_id: Optional[str] = None, + is_intermediate: bool = False, ) -> ImageDTO: - if image_type not in ImageType: - raise InvalidImageTypeException + if 
image_origin not in ResourceOrigin: + raise InvalidOriginException if image_category not in ImageCategory: raise InvalidImageCategoryException - image_name = self._create_image_name( - image_type=image_type, - image_category=image_category, - node_id=node_id, - session_id=session_id, - ) + image_name = self._services.names.create_image_name() metadata = self._get_metadata(session_id, node_id) @@ -180,10 +193,12 @@ class ImageService(ImageServiceABC): created_at = self._services.records.save( # Non-nullable fields image_name=image_name, - image_type=image_type, + image_origin=image_origin, image_category=image_category, width=width, height=height, + # Meta fields + is_intermediate=is_intermediate, # Nullable fields node_id=node_id, session_id=session_id, @@ -191,21 +206,21 @@ class ImageService(ImageServiceABC): ) self._services.files.save( - image_type=image_type, + image_origin=image_origin, image_name=image_name, image=image, metadata=metadata, ) - image_url = self._services.urls.get_image_url(image_type, image_name) + image_url = self._services.urls.get_image_url(image_origin, image_name) thumbnail_url = self._services.urls.get_image_url( - image_type, image_name, True + image_origin, image_name, True ) return ImageDTO( # Non-nullable fields image_name=image_name, - image_type=image_type, + image_origin=image_origin, image_category=image_category, width=width, height=height, @@ -217,6 +232,7 @@ class ImageService(ImageServiceABC): created_at=created_at, updated_at=created_at, # this is always the same as the created_at at this time deleted_at=None, + is_intermediate=is_intermediate, # Extra non-nullable fields for DTO image_url=image_url, thumbnail_url=thumbnail_url, @@ -231,9 +247,25 @@ class ImageService(ImageServiceABC): self._services.logger.error("Problem saving image record and file") raise e - def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType: + def update( + self, + image_origin: ResourceOrigin, + image_name: str, + changes: 
ImageRecordChanges, + ) -> ImageDTO: try: - return self._services.files.get(image_type, image_name) + self._services.records.update(image_name, image_origin, changes) + return self.get_dto(image_origin, image_name) + except ImageRecordSaveException: + self._services.logger.error("Failed to update image record") + raise + except Exception as e: + self._services.logger.error("Problem updating image record") + raise e + + def get_pil_image(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType: + try: + return self._services.files.get(image_origin, image_name) except ImageFileNotFoundException: self._services.logger.error("Failed to get image file") raise @@ -241,9 +273,9 @@ class ImageService(ImageServiceABC): self._services.logger.error("Problem getting image file") raise e - def get_record(self, image_type: ImageType, image_name: str) -> ImageRecord: + def get_record(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord: try: - return self._services.records.get(image_type, image_name) + return self._services.records.get(image_origin, image_name) except ImageRecordNotFoundException: self._services.logger.error("Image record not found") raise @@ -251,14 +283,14 @@ class ImageService(ImageServiceABC): self._services.logger.error("Problem getting image record") raise e - def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO: + def get_dto(self, image_origin: ResourceOrigin, image_name: str) -> ImageDTO: try: - image_record = self._services.records.get(image_type, image_name) + image_record = self._services.records.get(image_origin, image_name) image_dto = image_record_to_dto( image_record, - self._services.urls.get_image_url(image_type, image_name), - self._services.urls.get_image_url(image_type, image_name, True), + self._services.urls.get_image_url(image_origin, image_name), + self._services.urls.get_image_url(image_origin, image_name, True), ) return image_dto @@ -270,10 +302,10 @@ class ImageService(ImageServiceABC): 
raise e def get_path( - self, image_type: ImageType, image_name: str, thumbnail: bool = False + self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False ) -> str: try: - return self._services.files.get_path(image_type, image_name, thumbnail) + return self._services.files.get_path(image_origin, image_name, thumbnail) except Exception as e: self._services.logger.error("Problem getting image path") raise e @@ -286,57 +318,58 @@ class ImageService(ImageServiceABC): raise e def get_url( - self, image_type: ImageType, image_name: str, thumbnail: bool = False + self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False ) -> str: try: - return self._services.urls.get_image_url(image_type, image_name, thumbnail) + return self._services.urls.get_image_url(image_origin, image_name, thumbnail) except Exception as e: self._services.logger.error("Problem getting image path") raise e def get_many( self, - image_type: ImageType, - image_category: ImageCategory, - page: int = 0, - per_page: int = 10, - ) -> PaginatedResults[ImageDTO]: + offset: int = 0, + limit: int = 10, + image_origin: Optional[ResourceOrigin] = None, + categories: Optional[list[ImageCategory]] = None, + is_intermediate: Optional[bool] = None, + ) -> OffsetPaginatedResults[ImageDTO]: try: results = self._services.records.get_many( - image_type, - image_category, - page, - per_page, + offset, + limit, + image_origin, + categories, + is_intermediate, ) image_dtos = list( map( lambda r: image_record_to_dto( r, - self._services.urls.get_image_url(image_type, r.image_name), + self._services.urls.get_image_url(r.image_origin, r.image_name), self._services.urls.get_image_url( - image_type, r.image_name, True + r.image_origin, r.image_name, True ), ), results.items, ) ) - return PaginatedResults[ImageDTO]( + return OffsetPaginatedResults[ImageDTO]( items=image_dtos, - page=results.page, - pages=results.pages, - per_page=results.per_page, + offset=results.offset, + 
limit=results.limit, total=results.total, ) except Exception as e: self._services.logger.error("Problem getting paginated image DTOs") raise e - def delete(self, image_type: ImageType, image_name: str): + def delete(self, image_origin: ResourceOrigin, image_name: str): try: - self._services.files.delete(image_type, image_name) - self._services.records.delete(image_type, image_name) + self._services.files.delete(image_origin, image_name) + self._services.records.delete(image_origin, image_name) except ImageRecordDeleteException: self._services.logger.error(f"Failed to delete image record") raise @@ -347,21 +380,6 @@ class ImageService(ImageServiceABC): self._services.logger.error("Problem deleting image record and file") raise e - def _create_image_name( - self, - image_type: ImageType, - image_category: ImageCategory, - node_id: Optional[str] = None, - session_id: Optional[str] = None, - ) -> str: - """Create a unique image name.""" - uuid_str = str(uuid.uuid4()) - - if node_id is not None and session_id is not None: - return f"{image_type.value}_{image_category.value}_{session_id}_{node_id}_{uuid_str}.png" - - return f"{image_type.value}_{image_category.value}_{uuid_str}.png" - def _get_metadata( self, session_id: Optional[str] = None, node_id: Optional[str] = None ) -> Union[ImageMetadata, None]: diff --git a/invokeai/app/services/models/image_record.py b/invokeai/app/services/models/image_record.py index c1155ff73e..051236b12b 100644 --- a/invokeai/app/services/models/image_record.py +++ b/invokeai/app/services/models/image_record.py @@ -1,7 +1,7 @@ import datetime from typing import Optional, Union -from pydantic import BaseModel, Field -from invokeai.app.models.image import ImageCategory, ImageType +from pydantic import BaseModel, Extra, Field, StrictBool, StrictStr +from invokeai.app.models.image import ImageCategory, ResourceOrigin from invokeai.app.models.metadata import ImageMetadata from invokeai.app.util.misc import get_iso_timestamp @@ -11,8 +11,8 @@ 
class ImageRecord(BaseModel): image_name: str = Field(description="The unique name of the image.") """The unique name of the image.""" - image_type: ImageType = Field(description="The type of the image.") - """The type of the image.""" + image_origin: ResourceOrigin = Field(description="The type of the image.") + """The origin of the image.""" image_category: ImageCategory = Field(description="The category of the image.") """The category of the image.""" width: int = Field(description="The width of the image in px.") @@ -31,6 +31,8 @@ class ImageRecord(BaseModel): description="The deleted timestamp of the image." ) """The deleted timestamp of the image.""" + is_intermediate: bool = Field(description="Whether this is an intermediate image.") + """Whether this is an intermediate image.""" session_id: Optional[str] = Field( default=None, description="The session ID that generated this image, if it is a generated image.", @@ -48,13 +50,37 @@ class ImageRecord(BaseModel): """A limited subset of the image's generation metadata. Retrieve the image's session for full metadata.""" +class ImageRecordChanges(BaseModel, extra=Extra.forbid): + """A set of changes to apply to an image record. + + Only limited changes are valid: + - `image_category`: change the category of an image + - `session_id`: change the session associated with an image + - `is_intermediate`: change the image's `is_intermediate` flag + """ + + image_category: Optional[ImageCategory] = Field( + description="The image's new category." + ) + """The image's new category.""" + session_id: Optional[StrictStr] = Field( + default=None, + description="The image's new session ID.", + ) + """The image's new session ID.""" + is_intermediate: Optional[StrictBool] = Field( + default=None, description="The image's new `is_intermediate` flag." 
+ ) + """The image's new `is_intermediate` flag.""" + + class ImageUrlsDTO(BaseModel): """The URLs for an image and its thumbnail.""" image_name: str = Field(description="The unique name of the image.") """The unique name of the image.""" - image_type: ImageType = Field(description="The type of the image.") - """The type of the image.""" + image_origin: ResourceOrigin = Field(description="The type of the image.") + """The origin of the image.""" image_url: str = Field(description="The URL of the image.") """The URL of the image.""" thumbnail_url: str = Field(description="The URL of the image's thumbnail.") @@ -84,7 +110,9 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord: # Retrieve all the values, setting "reasonable" defaults if they are not present. image_name = image_dict.get("image_name", "unknown") - image_type = ImageType(image_dict.get("image_type", ImageType.RESULT.value)) + image_origin = ResourceOrigin( + image_dict.get("image_origin", ResourceOrigin.INTERNAL.value) + ) image_category = ImageCategory( image_dict.get("image_category", ImageCategory.GENERAL.value) ) @@ -95,6 +123,7 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord: created_at = image_dict.get("created_at", get_iso_timestamp()) updated_at = image_dict.get("updated_at", get_iso_timestamp()) deleted_at = image_dict.get("deleted_at", get_iso_timestamp()) + is_intermediate = image_dict.get("is_intermediate", False) raw_metadata = image_dict.get("metadata") @@ -105,7 +134,7 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord: return ImageRecord( image_name=image_name, - image_type=image_type, + image_origin=image_origin, image_category=image_category, width=width, height=height, @@ -115,4 +144,5 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord: created_at=created_at, updated_at=updated_at, deleted_at=deleted_at, + is_intermediate=is_intermediate, ) diff --git a/invokeai/app/services/resource_name.py 
b/invokeai/app/services/resource_name.py new file mode 100644 index 0000000000..dd5a76cfc0 --- /dev/null +++ b/invokeai/app/services/resource_name.py @@ -0,0 +1,30 @@ +from abc import ABC, abstractmethod +from enum import Enum, EnumMeta +import uuid + + +class ResourceType(str, Enum, metaclass=EnumMeta): + """Enum for resource types.""" + + IMAGE = "image" + LATENT = "latent" + + +class NameServiceBase(ABC): + """Low-level service responsible for naming resources (images, latents, etc).""" + + # TODO: Add customizable naming schemes + @abstractmethod + def create_image_name(self) -> str: + """Creates a name for an image.""" + pass + + +class SimpleNameService(NameServiceBase): + """Creates image names from UUIDs.""" + + # TODO: Add customizable naming schemes + def create_image_name(self) -> str: + uuid_str = str(uuid.uuid4()) + filename = f"{uuid_str}.png" + return filename diff --git a/invokeai/app/services/urls.py b/invokeai/app/services/urls.py index 2716da60ad..4c8354c899 100644 --- a/invokeai/app/services/urls.py +++ b/invokeai/app/services/urls.py @@ -1,7 +1,7 @@ import os from abc import ABC, abstractmethod -from invokeai.app.models.image import ImageType +from invokeai.app.models.image import ResourceOrigin from invokeai.app.util.thumbnails import get_thumbnail_name @@ -10,7 +10,7 @@ class UrlServiceBase(ABC): @abstractmethod def get_image_url( - self, image_type: ImageType, image_name: str, thumbnail: bool = False + self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False ) -> str: """Gets the URL for an image or thumbnail.""" pass @@ -21,14 +21,14 @@ class LocalUrlService(UrlServiceBase): self._base_url = base_url def get_image_url( - self, image_type: ImageType, image_name: str, thumbnail: bool = False + self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False ) -> str: image_basename = os.path.basename(image_name) # These paths are determined by the routes in invokeai/app/api/routers/images.py if thumbnail: 
return ( - f"{self._base_url}/images/{image_type.value}/{image_basename}/thumbnail" + f"{self._base_url}/images/{image_origin.value}/{image_basename}/thumbnail" ) - return f"{self._base_url}/images/{image_type.value}/{image_basename}" + return f"{self._base_url}/images/{image_origin.value}/{image_basename}" diff --git a/invokeai/app/util/step_callback.py b/invokeai/app/util/step_callback.py index 963e770406..b4b9a25909 100644 --- a/invokeai/app/util/step_callback.py +++ b/invokeai/app/util/step_callback.py @@ -1,5 +1,5 @@ -from invokeai.app.api.models.images import ProgressImage from invokeai.app.models.exceptions import CanceledException +from invokeai.app.models.image import ProgressImage from ..invocations.baseinvocation import InvocationContext from ...backend.util.util import image_to_dataURL from ...backend.generator.base import Generator diff --git a/invokeai/backend/generator/base.py b/invokeai/backend/generator/base.py index bf976e7290..fb293ab5b2 100644 --- a/invokeai/backend/generator/base.py +++ b/invokeai/backend/generator/base.py @@ -75,9 +75,11 @@ class InvokeAIGenerator(metaclass=ABCMeta): def __init__(self, model_info: dict, params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(), + **kwargs, ): self.model_info=model_info self.params=params + self.kwargs = kwargs def generate(self, prompt: str='', @@ -120,7 +122,7 @@ class InvokeAIGenerator(metaclass=ABCMeta): ) uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt,model=model) gen_class = self._generator_class() - generator = gen_class(model, self.params.precision) + generator = gen_class(model, self.params.precision, **self.kwargs) if self.params.variation_amount > 0: generator.set_variation(generator_args.get('seed'), generator_args.get('variation_amount'), @@ -275,7 +277,7 @@ class Generator: precision: str model: DiffusionPipeline - def __init__(self, model: DiffusionPipeline, precision: str): + def __init__(self, model: DiffusionPipeline, precision: str, **kwargs): 
self.model = model self.precision = precision self.seed = None diff --git a/invokeai/backend/generator/txt2img.py b/invokeai/backend/generator/txt2img.py index e5a96212f0..189fe65710 100644 --- a/invokeai/backend/generator/txt2img.py +++ b/invokeai/backend/generator/txt2img.py @@ -4,6 +4,10 @@ invokeai.backend.generator.txt2img inherits from invokeai.backend.generator import PIL.Image import torch +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from diffusers.models.controlnet import ControlNetModel, ControlNetOutput +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel + from ..stable_diffusion import ( ConditioningData, PostprocessingSettings, @@ -13,8 +17,13 @@ from .base import Generator class Txt2Img(Generator): - def __init__(self, model, precision): - super().__init__(model, precision) + def __init__(self, model, precision, + control_model: Optional[Union[ControlNetModel, List[ControlNetModel]]] = None, + **kwargs): + self.control_model = control_model + if isinstance(self.control_model, list): + self.control_model = MultiControlNetModel(self.control_model) + super().__init__(model, precision, **kwargs) @torch.no_grad() def get_make_image( @@ -42,9 +51,12 @@ class Txt2Img(Generator): kwargs are 'width' and 'height' """ self.perlin = perlin + control_image = kwargs.get("control_image", None) + do_classifier_free_guidance = cfg_scale > 1.0 # noinspection PyTypeChecker pipeline: StableDiffusionGeneratorPipeline = self.model + pipeline.control_model = self.control_model pipeline.scheduler = sampler uc, c, extra_conditioning_info = conditioning @@ -61,6 +73,37 @@ class Txt2Img(Generator): ), ).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta) + # FIXME: still need to test with different widths, heights, devices, dtypes + # and add in batch_size, num_images_per_prompt? 
+ if control_image is not None: + if isinstance(self.control_model, ControlNetModel): + control_image = pipeline.prepare_control_image( + image=control_image, + do_classifier_free_guidance=do_classifier_free_guidance, + width=width, + height=height, + # batch_size=batch_size * num_images_per_prompt, + # num_images_per_prompt=num_images_per_prompt, + device=self.control_model.device, + dtype=self.control_model.dtype, + ) + elif isinstance(self.control_model, MultiControlNetModel): + images = [] + for image_ in control_image: + image_ = pipeline.prepare_control_image( + image=image_, + do_classifier_free_guidance=do_classifier_free_guidance, + width=width, + height=height, + # batch_size=batch_size * num_images_per_prompt, + # num_images_per_prompt=num_images_per_prompt, + device=self.control_model.device, + dtype=self.control_model.dtype, + ) + images.append(image_) + control_image = images + kwargs["control_image"] = control_image + def make_image(x_T: torch.Tensor, _: int) -> PIL.Image.Image: pipeline_output = pipeline.image_from_embeddings( latents=torch.zeros_like(x_T, dtype=self.torch_dtype()), @@ -68,6 +111,7 @@ class Txt2Img(Generator): num_inference_steps=steps, conditioning_data=conditioning_data, callback=step_callback, + **kwargs, ) if ( diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index 4ca2a5cb30..c07b41e6c1 100644 --- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -2,23 +2,29 @@ from __future__ import annotations import dataclasses import inspect +import math import secrets from collections.abc import Sequence from dataclasses import dataclass, field from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union +from pydantic import BaseModel, Field import einops import PIL.Image +import numpy as np from accelerate.utils import set_seed import psutil import torch import 
torchvision.transforms as T from compel import EmbeddingsProvider from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.models.controlnet import ControlNetModel, ControlNetOutput from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import ( StableDiffusionPipeline, ) +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel + from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import ( StableDiffusionImg2ImgPipeline, ) @@ -27,6 +33,7 @@ from diffusers.pipelines.stable_diffusion.safety_checker import ( ) from diffusers.schedulers import KarrasDiffusionSchedulers from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput +from diffusers.utils import PIL_INTERPOLATION from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.outputs import BaseOutput from torchvision.transforms.functional import resize as tv_resize @@ -68,10 +75,10 @@ class AddsMaskLatents: initial_image_latents: torch.Tensor def __call__( - self, latents: torch.Tensor, t: torch.Tensor, text_embeddings: torch.Tensor + self, latents: torch.Tensor, t: torch.Tensor, text_embeddings: torch.Tensor, **kwargs, ) -> torch.Tensor: model_input = self.add_mask_channels(latents) - return self.forward(model_input, t, text_embeddings) + return self.forward(model_input, t, text_embeddings, **kwargs) def add_mask_channels(self, latents): batch_size = latents.size(0) @@ -207,6 +214,13 @@ class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]): raise AssertionError("why was that an empty generator?") return result +@dataclass +class ControlNetData: + model: ControlNetModel = Field(default=None) + image_tensor: torch.Tensor= Field(default=None) + weight: float = Field(default=1.0) + begin_step_percent: float = Field(default=0.0) + end_step_percent: float = 
Field(default=1.0) @dataclass(frozen=True) class ConditioningData: @@ -302,6 +316,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): feature_extractor: Optional[CLIPFeatureExtractor], requires_safety_checker: bool = False, precision: str = "float32", + control_model: ControlNetModel = None, ): super().__init__( vae, @@ -322,6 +337,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, + # FIXME: can't currently register control module + # control_model=control_model, ) self.invokeai_diffuser = InvokeAIDiffuserComponent( self.unet, self._unet_forward, is_running_diffusers=True @@ -341,6 +358,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): self._model_group = FullyLoadedModelGroup(self.unet.device) self._model_group.install(*self._submodels) + self.control_model = control_model def _adjust_memory_efficient_attention(self, latents: torch.Tensor): """ @@ -463,6 +481,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): noise: torch.Tensor, callback: Callable[[PipelineIntermediateState], None] = None, run_id=None, + **kwargs, ) -> InvokeAIStableDiffusionPipelineOutput: r""" Function invoked when calling the pipeline for generation. 
@@ -483,6 +502,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): noise=noise, run_id=run_id, callback=callback, + **kwargs, ) # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 torch.cuda.empty_cache() @@ -507,6 +527,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): additional_guidance: List[Callable] = None, run_id=None, callback: Callable[[PipelineIntermediateState], None] = None, + control_data: List[ControlNetData] = None, + **kwargs, ) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]: if self.scheduler.config.get("cpu_only", False): scheduler_device = torch.device('cpu') @@ -527,6 +549,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): additional_guidance=additional_guidance, run_id=run_id, callback=callback, + control_data=control_data, + **kwargs, ) return result.latents, result.attention_map_saver @@ -539,6 +563,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): noise: torch.Tensor, run_id: str = None, additional_guidance: List[Callable] = None, + control_data: List[ControlNetData] = None, + **kwargs, ): self._adjust_memory_efficient_attention(latents) if run_id is None: @@ -568,7 +594,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): latents = self.scheduler.add_noise(latents, noise, batched_t) attention_map_saver: Optional[AttentionMapSaver] = None - + # print("timesteps:", timesteps) for i, t in enumerate(self.progress_bar(timesteps)): batched_t.fill_(t) step_output = self.step( @@ -578,6 +604,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): step_index=i, total_step_count=len(timesteps), additional_guidance=additional_guidance, + control_data=control_data, + **kwargs, ) latents = step_output.prev_sample @@ -618,10 +646,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): step_index: int, total_step_count: int, additional_guidance: List[Callable] = None, + control_data: List[ControlNetData] 
= None, + **kwargs, ): # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value timestep = t[0] - if additional_guidance is None: additional_guidance = [] @@ -629,6 +658,48 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): # i.e. before or after passing it to InvokeAIDiffuserComponent latent_model_input = self.scheduler.scale_model_input(latents, timestep) + # default is no controlnet, so set controlnet processing output to None + down_block_res_samples, mid_block_res_sample = None, None + + if control_data is not None: + if conditioning_data.guidance_scale > 1.0: + # expand the latents input to control model if doing classifier free guidance + # (which I think for now is always true, there is conditional elsewhere that stops execution if + # classifier_free_guidance is <= 1.0 ?) + latent_control_input = torch.cat([latent_model_input] * 2) + else: + latent_control_input = latent_model_input + # control_data should be type List[ControlNetData] + # this loop covers both ControlNet (one ControlNetData in list) + # and MultiControlNet (multiple ControlNetData in list) + for i, control_datum in enumerate(control_data): + # print("controlnet", i, "==>", type(control_datum)) + first_control_step = math.floor(control_datum.begin_step_percent * total_step_count) + last_control_step = math.ceil(control_datum.end_step_percent * total_step_count) + # only apply controlnet if current step is within the controlnet's begin/end step range + if step_index >= first_control_step and step_index <= last_control_step: + # print("running controlnet", i, "for step", step_index) + down_samples, mid_sample = control_datum.model( + sample=latent_control_input, + timestep=timestep, + encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings, + conditioning_data.text_embeddings]), + controlnet_cond=control_datum.image_tensor, + conditioning_scale=control_datum.weight, + # cross_attention_kwargs, + guess_mode=False, + 
return_dict=False, + ) + if down_block_res_samples is None and mid_block_res_sample is None: + down_block_res_samples, mid_block_res_sample = down_samples, mid_sample + else: + # add controlnet outputs together if have multiple controlnets + down_block_res_samples = [ + samples_prev + samples_curr + for samples_prev, samples_curr in zip(down_block_res_samples, down_samples) + ] + mid_block_res_sample += mid_sample + # predict the noise residual noise_pred = self.invokeai_diffuser.do_diffusion_step( latent_model_input, @@ -638,6 +709,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): conditioning_data.guidance_scale, step_index=step_index, total_step_count=total_step_count, + down_block_additional_residuals=down_block_res_samples, # from controlnet(s) + mid_block_additional_residual=mid_block_res_sample, # from controlnet(s) ) # compute the previous noisy sample x_t -> x_t-1 @@ -659,6 +732,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): t, text_embeddings, cross_attention_kwargs: Optional[dict[str, Any]] = None, + **kwargs, ): """predict the noise residual""" if is_inpainting_model(self.unet) and latents.size(1) == 4: @@ -678,7 +752,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): # First three args should be positional, not keywords, so torch hooks can see them. 
return self.unet( - latents, t, text_embeddings, cross_attention_kwargs=cross_attention_kwargs + latents, t, text_embeddings, cross_attention_kwargs=cross_attention_kwargs, + **kwargs, ).sample def img2img_from_embeddings( @@ -728,7 +803,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): noise: torch.Tensor, run_id=None, callback=None, - ) -> InvokeAIStableDiffusionPipelineOutput: + ) -> InvokeAIStableDiffusionPipelineOutput: timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength) result_latents, result_attention_maps = self.latents_from_embeddings( latents=initial_latents if strength < 1.0 else torch.zeros_like( @@ -940,3 +1015,51 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): debug_image( img, f"latents {msg} {i+1}/{len(decoded)}", debug_status=True ) + + # Copied from diffusers pipeline_stable_diffusion_controlnet.py + # Returns torch.Tensor of shape (batch_size, 3, height, width) + @staticmethod + def prepare_control_image( + image, + # FIXME: need to fix hardwiring of width and height, change to basing on latents dimensions? 
+ # latents, + width=512, # should be 8 * latent.shape[3] + height=512, # should be 8 * latent height[2] + batch_size=1, + num_images_per_prompt=1, + device="cuda", + dtype=torch.float16, + do_classifier_free_guidance=True, + ): + + if not isinstance(image, torch.Tensor): + if isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + images = [] + for image_ in image: + image_ = image_.convert("RGB") + image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]) + image_ = np.array(image_) + image_ = image_[None, :] + images.append(image_) + image = images + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + + image_batch_size = image.shape[0] + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + image = image.repeat_interleave(repeat_by, dim=0) + image = image.to(device=device, dtype=dtype) + if do_classifier_free_guidance: + image = torch.cat([image] * 2) + return image diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py index 4131837b41..d05565c506 100644 --- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py +++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py @@ -181,6 +181,7 @@ class InvokeAIDiffuserComponent: unconditional_guidance_scale: float, step_index: Optional[int] = None, total_step_count: Optional[int] = None, + **kwargs, ): """ :param x: current latents @@ -209,7 +210,7 @@ class InvokeAIDiffuserComponent: if wants_hybrid_conditioning: unconditioned_next_x, conditioned_next_x = self._apply_hybrid_conditioning( - x, sigma, unconditioning, 
conditioning + x, sigma, unconditioning, conditioning, **kwargs, ) elif wants_cross_attention_control: ( @@ -221,13 +222,14 @@ class InvokeAIDiffuserComponent: unconditioning, conditioning, cross_attention_control_types_to_do, + **kwargs, ) elif self.sequential_guidance: ( unconditioned_next_x, conditioned_next_x, ) = self._apply_standard_conditioning_sequentially( - x, sigma, unconditioning, conditioning + x, sigma, unconditioning, conditioning, **kwargs, ) else: @@ -235,7 +237,7 @@ class InvokeAIDiffuserComponent: unconditioned_next_x, conditioned_next_x, ) = self._apply_standard_conditioning( - x, sigma, unconditioning, conditioning + x, sigma, unconditioning, conditioning, **kwargs, ) combined_next_x = self._combine( @@ -282,13 +284,13 @@ class InvokeAIDiffuserComponent: # methods below are called from do_diffusion_step and should be considered private to this class. - def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning): + def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs): # fast batched path x_twice = torch.cat([x] * 2) sigma_twice = torch.cat([sigma] * 2) both_conditionings = torch.cat([unconditioning, conditioning]) both_results = self.model_forward_callback( - x_twice, sigma_twice, both_conditionings + x_twice, sigma_twice, both_conditionings, **kwargs, ) unconditioned_next_x, conditioned_next_x = both_results.chunk(2) if conditioned_next_x.device.type == "mps": @@ -302,16 +304,17 @@ class InvokeAIDiffuserComponent: sigma, unconditioning: torch.Tensor, conditioning: torch.Tensor, + **kwargs, ): # low-memory sequential path - unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning) - conditioned_next_x = self.model_forward_callback(x, sigma, conditioning) + unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs) + conditioned_next_x = self.model_forward_callback(x, sigma, conditioning, **kwargs) if conditioned_next_x.device.type == "mps": # 
prevent a result filled with zeros. seems to be a torch bug. conditioned_next_x = conditioned_next_x.clone() return unconditioned_next_x, conditioned_next_x - def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning): + def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs): assert isinstance(conditioning, dict) assert isinstance(unconditioning, dict) x_twice = torch.cat([x] * 2) @@ -326,7 +329,7 @@ class InvokeAIDiffuserComponent: else: both_conditionings[k] = torch.cat([unconditioning[k], conditioning[k]]) unconditioned_next_x, conditioned_next_x = self.model_forward_callback( - x_twice, sigma_twice, both_conditionings + x_twice, sigma_twice, both_conditionings, **kwargs, ).chunk(2) return unconditioned_next_x, conditioned_next_x @@ -337,6 +340,7 @@ class InvokeAIDiffuserComponent: unconditioning, conditioning, cross_attention_control_types_to_do, + **kwargs, ): if self.is_running_diffusers: return self._apply_cross_attention_controlled_conditioning__diffusers( @@ -345,6 +349,7 @@ class InvokeAIDiffuserComponent: unconditioning, conditioning, cross_attention_control_types_to_do, + **kwargs, ) else: return self._apply_cross_attention_controlled_conditioning__compvis( @@ -353,6 +358,7 @@ class InvokeAIDiffuserComponent: unconditioning, conditioning, cross_attention_control_types_to_do, + **kwargs, ) def _apply_cross_attention_controlled_conditioning__diffusers( @@ -362,6 +368,7 @@ class InvokeAIDiffuserComponent: unconditioning, conditioning, cross_attention_control_types_to_do, + **kwargs, ): context: Context = self.cross_attention_control_context @@ -377,6 +384,7 @@ class InvokeAIDiffuserComponent: sigma, unconditioning, {"swap_cross_attn_context": cross_attn_processor_context}, + **kwargs, ) # do requested cross attention types for conditioning (positive prompt) @@ -388,6 +396,7 @@ class InvokeAIDiffuserComponent: sigma, conditioning, {"swap_cross_attn_context": cross_attn_processor_context}, + **kwargs, ) 
return unconditioned_next_x, conditioned_next_x @@ -398,6 +407,7 @@ class InvokeAIDiffuserComponent: unconditioning, conditioning, cross_attention_control_types_to_do, + **kwargs, ): # print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do) # slower non-batched path (20% slower on mac MPS) @@ -411,13 +421,13 @@ class InvokeAIDiffuserComponent: context: Context = self.cross_attention_control_context try: - unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning) + unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs) # process x using the original prompt, saving the attention maps # print("saving attention maps for", cross_attention_control_types_to_do) for ca_type in cross_attention_control_types_to_do: context.request_save_attention_maps(ca_type) - _ = self.model_forward_callback(x, sigma, conditioning) + _ = self.model_forward_callback(x, sigma, conditioning, **kwargs,) context.clear_requests(cleanup=False) # process x again, using the saved attention maps to control where self.edited_conditioning will be applied @@ -428,7 +438,7 @@ class InvokeAIDiffuserComponent: self.conditioning.cross_attention_control_args.edited_conditioning ) conditioned_next_x = self.model_forward_callback( - x, sigma, edited_conditioning + x, sigma, edited_conditioning, **kwargs, ) context.clear_requests(cleanup=True) diff --git a/invokeai/frontend/web/dist/index.html b/invokeai/frontend/web/dist/index.html index 63618e60be..8a982a7268 100644 --- a/invokeai/frontend/web/dist/index.html +++ b/invokeai/frontend/web/dist/index.html @@ -12,7 +12,7 @@ margin: 0; } - + diff --git a/invokeai/frontend/web/dist/locales/en.json b/invokeai/frontend/web/dist/locales/en.json index 94dff3934a..bf14dd5510 100644 --- a/invokeai/frontend/web/dist/locales/en.json +++ b/invokeai/frontend/web/dist/locales/en.json @@ -122,7 +122,9 @@ "noImagesInGallery": "No Images In Gallery", "deleteImage": "Delete Image", 
"deleteImageBin": "Deleted images will be sent to your operating system's Bin.", - "deleteImagePermanent": "Deleted images cannot be restored." + "deleteImagePermanent": "Deleted images cannot be restored.", + "images": "Images", + "assets": "Assets" }, "hotkeys": { "keyboardShortcuts": "Keyboard Shortcuts", @@ -452,6 +454,8 @@ "height": "Height", "scheduler": "Scheduler", "seed": "Seed", + "boundingBoxWidth": "Bounding Box Width", + "boundingBoxHeight": "Bounding Box Height", "imageToImage": "Image to Image", "randomizeSeed": "Randomize Seed", "shuffle": "Shuffle Seed", @@ -524,7 +528,7 @@ }, "settings": { "models": "Models", - "displayInProgress": "Display In-Progress Images", + "displayInProgress": "Display Progress Images", "saveSteps": "Save images every n steps", "confirmOnDelete": "Confirm On Delete", "displayHelpIcons": "Display Help Icons", @@ -564,6 +568,8 @@ "canvasMerged": "Canvas Merged", "sentToImageToImage": "Sent To Image To Image", "sentToUnifiedCanvas": "Sent to Unified Canvas", + "parameterSet": "Parameter set", + "parameterNotSet": "Parameter not set", "parametersSet": "Parameters Set", "parametersNotSet": "Parameters Not Set", "parametersNotSetDesc": "No metadata found for this image.", diff --git a/invokeai/frontend/web/package.json b/invokeai/frontend/web/package.json index 13b8d78bf7..dd1c87effb 100644 --- a/invokeai/frontend/web/package.json +++ b/invokeai/frontend/web/package.json @@ -101,7 +101,8 @@ "serialize-error": "^11.0.0", "socket.io-client": "^4.6.0", "use-image": "^1.1.0", - "uuid": "^9.0.0" + "uuid": "^9.0.0", + "zod": "^3.21.4" }, "peerDependencies": { "@chakra-ui/cli": "^2.4.0", diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 94dff3934a..bf14dd5510 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -122,7 +122,9 @@ "noImagesInGallery": "No Images In Gallery", "deleteImage": "Delete Image", 
"deleteImageBin": "Deleted images will be sent to your operating system's Bin.", - "deleteImagePermanent": "Deleted images cannot be restored." + "deleteImagePermanent": "Deleted images cannot be restored.", + "images": "Images", + "assets": "Assets" }, "hotkeys": { "keyboardShortcuts": "Keyboard Shortcuts", @@ -452,6 +454,8 @@ "height": "Height", "scheduler": "Scheduler", "seed": "Seed", + "boundingBoxWidth": "Bounding Box Width", + "boundingBoxHeight": "Bounding Box Height", "imageToImage": "Image to Image", "randomizeSeed": "Randomize Seed", "shuffle": "Shuffle Seed", @@ -524,7 +528,7 @@ }, "settings": { "models": "Models", - "displayInProgress": "Display In-Progress Images", + "displayInProgress": "Display Progress Images", "saveSteps": "Save images every n steps", "confirmOnDelete": "Confirm On Delete", "displayHelpIcons": "Display Help Icons", @@ -564,6 +568,8 @@ "canvasMerged": "Canvas Merged", "sentToImageToImage": "Sent To Image To Image", "sentToUnifiedCanvas": "Sent to Unified Canvas", + "parameterSet": "Parameter set", + "parameterNotSet": "Parameter not set", "parametersSet": "Parameters Set", "parametersNotSet": "Parameters Not Set", "parametersNotSetDesc": "No metadata found for this image.", diff --git a/invokeai/frontend/web/src/app/constants.ts b/invokeai/frontend/web/src/app/constants.ts index d312d725ba..6700a732b3 100644 --- a/invokeai/frontend/web/src/app/constants.ts +++ b/invokeai/frontend/web/src/app/constants.ts @@ -21,25 +21,11 @@ export const SCHEDULERS = [ export type Scheduler = (typeof SCHEDULERS)[number]; -export const isScheduler = (x: string): x is Scheduler => - SCHEDULERS.includes(x as Scheduler); - -// Valid image widths -export const WIDTHS: Array = Array.from(Array(64)).map( - (_x, i) => (i + 1) * 64 -); - -// Valid image heights -export const HEIGHTS: Array = Array.from(Array(64)).map( - (_x, i) => (i + 1) * 64 -); - // Valid upscaling levels export const UPSCALING_LEVELS: Array<{ key: string; value: number }> = [ { key: 
'2x', value: 2 }, { key: '4x', value: 4 }, ]; - export const NUMPY_RAND_MIN = 0; export const NUMPY_RAND_MAX = 2147483647; diff --git a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts index 52995e0da3..9fb4ceae32 100644 --- a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts +++ b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts @@ -1,7 +1,5 @@ import { canvasPersistDenylist } from 'features/canvas/store/canvasPersistDenylist'; import { galleryPersistDenylist } from 'features/gallery/store/galleryPersistDenylist'; -import { resultsPersistDenylist } from 'features/gallery/store/resultsPersistDenylist'; -import { uploadsPersistDenylist } from 'features/gallery/store/uploadsPersistDenylist'; import { lightboxPersistDenylist } from 'features/lightbox/store/lightboxPersistDenylist'; import { nodesPersistDenylist } from 'features/nodes/store/nodesPersistDenylist'; import { generationPersistDenylist } from 'features/parameters/store/generationPersistDenylist'; @@ -22,11 +20,9 @@ const serializationDenylist: { models: modelsPersistDenylist, nodes: nodesPersistDenylist, postprocessing: postprocessingPersistDenylist, - results: resultsPersistDenylist, system: systemPersistDenylist, // config: configPersistDenyList, ui: uiPersistDenylist, - uploads: uploadsPersistDenylist, // hotkeys: hotkeysPersistDenylist, }; diff --git a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/unserialize.ts b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/unserialize.ts index 155a7786b3..c6ae4946f2 100644 --- a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/unserialize.ts +++ b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/unserialize.ts @@ -1,7 +1,6 @@ import { initialCanvasState } from 'features/canvas/store/canvasSlice'; import { initialGalleryState } from 'features/gallery/store/gallerySlice'; 
-import { initialResultsState } from 'features/gallery/store/resultsSlice'; -import { initialUploadsState } from 'features/gallery/store/uploadsSlice'; +import { initialImagesState } from 'features/gallery/store/imagesSlice'; import { initialLightboxState } from 'features/lightbox/store/lightboxSlice'; import { initialNodesState } from 'features/nodes/store/nodesSlice'; import { initialGenerationState } from 'features/parameters/store/generationSlice'; @@ -24,12 +23,11 @@ const initialStates: { models: initialModelsState, nodes: initialNodesState, postprocessing: initialPostprocessingState, - results: initialResultsState, system: initialSystemState, config: initialConfigState, ui: initialUIState, - uploads: initialUploadsState, hotkeys: initialHotkeysState, + images: initialImagesState, }; export const unserialize: UnserializeFunction = (data, key) => { diff --git a/invokeai/frontend/web/src/app/store/middleware/devtools/actionsDenylist.ts b/invokeai/frontend/web/src/app/store/middleware/devtools/actionsDenylist.ts index 743537d7ea..eb54868735 100644 --- a/invokeai/frontend/web/src/app/store/middleware/devtools/actionsDenylist.ts +++ b/invokeai/frontend/web/src/app/store/middleware/devtools/actionsDenylist.ts @@ -7,5 +7,6 @@ export const actionsDenylist = [ 'canvas/setBoundingBoxDimensions', 'canvas/setIsDrawing', 'canvas/addPointToCurrentLine', - 'socket/generatorProgress', + 'socket/socketGeneratorProgress', + 'socket/appSocketGeneratorProgress', ]; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts index f23e83a191..ba16e56371 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts @@ -8,9 +8,16 @@ import type { TypedStartListening, TypedAddListener } from '@reduxjs/toolkit'; import type { RootState, AppDispatch } from '../../store'; 
import { addInitialImageSelectedListener } from './listeners/initialImageSelected'; -import { addImageResultReceivedListener } from './listeners/invocationComplete'; -import { addImageUploadedListener } from './listeners/imageUploaded'; -import { addRequestedImageDeletionListener } from './listeners/imageDeleted'; +import { + addImageUploadedFulfilledListener, + addImageUploadedRejectedListener, +} from './listeners/imageUploaded'; +import { + addImageDeletedFulfilledListener, + addImageDeletedPendingListener, + addImageDeletedRejectedListener, + addRequestedImageDeletionListener, +} from './listeners/imageDeleted'; import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas'; import { addUserInvokedNodesListener } from './listeners/userInvokedNodes'; import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage'; @@ -19,6 +26,50 @@ import { addCanvasSavedToGalleryListener } from './listeners/canvasSavedToGaller import { addCanvasDownloadedAsImageListener } from './listeners/canvasDownloadedAsImage'; import { addCanvasCopiedToClipboardListener } from './listeners/canvasCopiedToClipboard'; import { addCanvasMergedListener } from './listeners/canvasMerged'; +import { addGeneratorProgressEventListener as addGeneratorProgressListener } from './listeners/socketio/socketGeneratorProgress'; +import { addGraphExecutionStateCompleteEventListener as addGraphExecutionStateCompleteListener } from './listeners/socketio/socketGraphExecutionStateComplete'; +import { addInvocationCompleteEventListener as addInvocationCompleteListener } from './listeners/socketio/socketInvocationComplete'; +import { addInvocationErrorEventListener as addInvocationErrorListener } from './listeners/socketio/socketInvocationError'; +import { addInvocationStartedEventListener as addInvocationStartedListener } from './listeners/socketio/socketInvocationStarted'; +import { addSocketConnectedEventListener as addSocketConnectedListener } from 
'./listeners/socketio/socketConnected'; +import { addSocketDisconnectedEventListener as addSocketDisconnectedListener } from './listeners/socketio/socketDisconnected'; +import { addSocketSubscribedEventListener as addSocketSubscribedListener } from './listeners/socketio/socketSubscribed'; +import { addSocketUnsubscribedEventListener as addSocketUnsubscribedListener } from './listeners/socketio/socketUnsubscribed'; +import { addSessionReadyToInvokeListener } from './listeners/sessionReadyToInvoke'; +import { + addImageMetadataReceivedFulfilledListener, + addImageMetadataReceivedRejectedListener, +} from './listeners/imageMetadataReceived'; +import { + addImageUrlsReceivedFulfilledListener, + addImageUrlsReceivedRejectedListener, +} from './listeners/imageUrlsReceived'; +import { + addSessionCreatedFulfilledListener, + addSessionCreatedPendingListener, + addSessionCreatedRejectedListener, +} from './listeners/sessionCreated'; +import { + addSessionInvokedFulfilledListener, + addSessionInvokedPendingListener, + addSessionInvokedRejectedListener, +} from './listeners/sessionInvoked'; +import { + addSessionCanceledFulfilledListener, + addSessionCanceledPendingListener, + addSessionCanceledRejectedListener, +} from './listeners/sessionCanceled'; +import { + addImageUpdatedFulfilledListener, + addImageUpdatedRejectedListener, +} from './listeners/imageUpdated'; +import { + addReceivedPageOfImagesFulfilledListener, + addReceivedPageOfImagesRejectedListener, +} from './listeners/receivedPageOfImages'; +import { addStagingAreaImageSavedListener } from './listeners/stagingAreaImageSaved'; +import { addCommitStagingAreaImageListener } from './listeners/addCommitStagingAreaImageListener'; +import { addImageCategoriesChangedListener } from './listeners/imageCategoriesChanged'; export const listenerMiddleware = createListenerMiddleware(); @@ -38,17 +89,87 @@ export type AppListenerEffect = ListenerEffect< AppDispatch >; -addImageUploadedListener(); 
-addInitialImageSelectedListener(); -addImageResultReceivedListener(); -addRequestedImageDeletionListener(); +// Image uploaded +addImageUploadedFulfilledListener(); +addImageUploadedRejectedListener(); +// Image updated +addImageUpdatedFulfilledListener(); +addImageUpdatedRejectedListener(); + +// Image selected +addInitialImageSelectedListener(); + +// Image deleted +addRequestedImageDeletionListener(); +addImageDeletedPendingListener(); +addImageDeletedFulfilledListener(); +addImageDeletedRejectedListener(); + +// Image metadata +addImageMetadataReceivedFulfilledListener(); +addImageMetadataReceivedRejectedListener(); + +// Image URLs +addImageUrlsReceivedFulfilledListener(); +addImageUrlsReceivedRejectedListener(); + +// User Invoked addUserInvokedCanvasListener(); addUserInvokedNodesListener(); addUserInvokedTextToImageListener(); addUserInvokedImageToImageListener(); +addSessionReadyToInvokeListener(); +// Canvas actions addCanvasSavedToGalleryListener(); addCanvasDownloadedAsImageListener(); addCanvasCopiedToClipboardListener(); addCanvasMergedListener(); +addStagingAreaImageSavedListener(); +addCommitStagingAreaImageListener(); + +/** + * Socket.IO Events - these handle SIO events directly and pass on internal application actions. + * We don't handle SIO events in slices via `extraReducers` because some of these events shouldn't + * actually be handled at all. + * + * For example, we don't want to respond to progress events for canceled sessions. To avoid + * duplicating the logic to determine if an event should be responded to, we handle all of that + * "is this session canceled?" logic in these listeners. + * + * The `socketGeneratorProgress` listener will then only dispatch the `appSocketGeneratorProgress` + * action if it should be handled by the rest of the application. It is this `appSocketGeneratorProgress` + * action that is handled by reducers in slices. 
+ */ +addGeneratorProgressListener(); +addGraphExecutionStateCompleteListener(); +addInvocationCompleteListener(); +addInvocationErrorListener(); +addInvocationStartedListener(); +addSocketConnectedListener(); +addSocketDisconnectedListener(); +addSocketSubscribedListener(); +addSocketUnsubscribedListener(); + +// Session Created +addSessionCreatedPendingListener(); +addSessionCreatedFulfilledListener(); +addSessionCreatedRejectedListener(); + +// Session Invoked +addSessionInvokedPendingListener(); +addSessionInvokedFulfilledListener(); +addSessionInvokedRejectedListener(); + +// Session Canceled +addSessionCanceledPendingListener(); +addSessionCanceledFulfilledListener(); +addSessionCanceledRejectedListener(); + +// Fetching images +addReceivedPageOfImagesFulfilledListener(); +addReceivedPageOfImagesRejectedListener(); + +// Gallery +addImageCategoriesChangedListener(); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener.ts new file mode 100644 index 0000000000..90f71879a1 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener.ts @@ -0,0 +1,42 @@ +import { startAppListening } from '..'; +import { log } from 'app/logging/useLogger'; +import { commitStagingAreaImage } from 'features/canvas/store/canvasSlice'; +import { sessionCanceled } from 'services/thunks/session'; + +const moduleLog = log.child({ namespace: 'canvas' }); + +export const addCommitStagingAreaImageListener = () => { + startAppListening({ + actionCreator: commitStagingAreaImage, + effect: async (action, { dispatch, getState }) => { + const state = getState(); + const { sessionId, isProcessing } = state.system; + const canvasSessionId = action.payload; + + if (!isProcessing) { + // Only need to cancel if we are processing + return; + } + + if 
(!canvasSessionId) { + moduleLog.debug('No canvas session, skipping cancel'); + return; + } + + if (canvasSessionId !== sessionId) { + moduleLog.debug( + { + data: { + canvasSessionId, + sessionId, + }, + }, + 'Canvas session does not match global session, skipping cancel' + ); + return; + } + + dispatch(sessionCanceled({ sessionId })); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasMerged.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasMerged.ts index 1e2d99541c..80865f3126 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasMerged.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasMerged.ts @@ -52,10 +52,11 @@ export const addCanvasMergedListener = () => { dispatch( imageUploaded({ - imageType: 'intermediates', formData: { file: new File([blob], filename, { type: 'image/png' }), }, + imageCategory: 'general', + isIntermediate: true, }) ); @@ -65,7 +66,7 @@ export const addCanvasMergedListener = () => { action.meta.arg.formData.file.name === filename ); - const mergedCanvasImage = payload.response; + const mergedCanvasImage = payload; dispatch( setMergedCanvas({ diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts index d8237d1d5c..b89620775b 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts @@ -4,16 +4,18 @@ import { log } from 'app/logging/useLogger'; import { imageUploaded } from 'services/thunks/image'; import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob'; import { addToast } from 'features/system/store/systemSlice'; +import { v4 as 
uuidv4 } from 'uuid'; +import { imageUpserted } from 'features/gallery/store/imagesSlice'; const moduleLog = log.child({ namespace: 'canvasSavedToGalleryListener' }); export const addCanvasSavedToGalleryListener = () => { startAppListening({ actionCreator: canvasSavedToGallery, - effect: async (action, { dispatch, getState }) => { + effect: async (action, { dispatch, getState, take }) => { const state = getState(); - const blob = await getBaseLayerBlob(state); + const blob = await getBaseLayerBlob(state, true); if (!blob) { moduleLog.error('Problem getting base layer blob'); @@ -27,14 +29,25 @@ export const addCanvasSavedToGalleryListener = () => { return; } + const filename = `mergedCanvas_${uuidv4()}.png`; + dispatch( imageUploaded({ - imageType: 'results', formData: { - file: new File([blob], 'mergedCanvas.png', { type: 'image/png' }), + file: new File([blob], filename, { type: 'image/png' }), }, + imageCategory: 'general', + isIntermediate: false, }) ); + + const [{ payload: uploadedImageDTO }] = await take( + (action): action is ReturnType => + imageUploaded.fulfilled.match(action) && + action.meta.arg.formData.file.name === filename + ); + + dispatch(imageUpserted(uploadedImageDTO)); }, }); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageCategoriesChanged.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageCategoriesChanged.ts new file mode 100644 index 0000000000..85d56d3913 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageCategoriesChanged.ts @@ -0,0 +1,24 @@ +import { log } from 'app/logging/useLogger'; +import { startAppListening } from '..'; +import { receivedPageOfImages } from 'services/thunks/image'; +import { + imageCategoriesChanged, + selectFilteredImagesAsArray, +} from 'features/gallery/store/imagesSlice'; + +const moduleLog = log.child({ namespace: 'gallery' }); + +export const addImageCategoriesChangedListener = 
() => { + startAppListening({ + actionCreator: imageCategoriesChanged, + effect: (action, { getState, dispatch }) => { + const filteredImagesCount = selectFilteredImagesAsArray( + getState() + ).length; + + if (!filteredImagesCount) { + dispatch(receivedPageOfImages()); + } + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts index 42a62b3d80..bf7ca4020c 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts @@ -4,9 +4,18 @@ import { imageDeleted } from 'services/thunks/image'; import { log } from 'app/logging/useLogger'; import { clamp } from 'lodash-es'; import { imageSelected } from 'features/gallery/store/gallerySlice'; +import { + imageRemoved, + imagesAdapter, + selectImagesEntities, + selectImagesIds, +} from 'features/gallery/store/imagesSlice'; const moduleLog = log.child({ namespace: 'addRequestedImageDeletionListener' }); +/** + * Called when the user requests an image deletion + */ export const addRequestedImageDeletionListener = () => { startAppListening({ actionCreator: requestedImageDeletion, @@ -17,24 +26,20 @@ export const addRequestedImageDeletionListener = () => { return; } - const { image_name, image_type } = image; + const { image_name, image_origin } = image; - if (image_type !== 'uploads' && image_type !== 'results') { - moduleLog.warn({ data: image }, `Invalid image type ${image_type}`); - return; - } + const state = getState(); + const selectedImage = state.gallery.selectedImage; - const selectedImageName = getState().gallery.selectedImage?.image_name; + if (selectedImage && selectedImage.image_name === image_name) { + const ids = selectImagesIds(state); + const entities = selectImagesEntities(state); - if (selectedImageName === 
image_name) { - const allIds = getState()[image_type].ids; - const allEntities = getState()[image_type].entities; - - const deletedImageIndex = allIds.findIndex( + const deletedImageIndex = ids.findIndex( (result) => result.toString() === image_name ); - const filteredIds = allIds.filter((id) => id.toString() !== image_name); + const filteredIds = ids.filter((id) => id.toString() !== image_name); const newSelectedImageIndex = clamp( deletedImageIndex, @@ -44,7 +49,7 @@ export const addRequestedImageDeletionListener = () => { const newSelectedImageId = filteredIds[newSelectedImageIndex]; - const newSelectedImage = allEntities[newSelectedImageId]; + const newSelectedImage = entities[newSelectedImageId]; if (newSelectedImageId) { dispatch(imageSelected(newSelectedImage)); @@ -53,7 +58,52 @@ export const addRequestedImageDeletionListener = () => { } } - dispatch(imageDeleted({ imageName: image_name, imageType: image_type })); + dispatch(imageRemoved(image_name)); + + dispatch( + imageDeleted({ imageName: image_name, imageOrigin: image_origin }) + ); + }, + }); +}; + +/** + * Called when the actual delete request is sent to the server + */ +export const addImageDeletedPendingListener = () => { + startAppListening({ + actionCreator: imageDeleted.pending, + effect: (action, { dispatch, getState }) => { + const { imageName, imageOrigin } = action.meta.arg; + // Preemptively remove the image from the gallery + imagesAdapter.removeOne(getState().images, imageName); + }, + }); +}; + +/** + * Called on successful delete + */ +export const addImageDeletedFulfilledListener = () => { + startAppListening({ + actionCreator: imageDeleted.fulfilled, + effect: (action, { dispatch, getState }) => { + moduleLog.debug({ data: { image: action.meta.arg } }, 'Image deleted'); + }, + }); +}; + +/** + * Called on failed delete + */ +export const addImageDeletedRejectedListener = () => { + startAppListening({ + actionCreator: imageDeleted.rejected, + effect: (action, { dispatch, getState }) => 
{ + moduleLog.debug( + { data: { image: action.meta.arg } }, + 'Unable to delete image' + ); }, }); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageMetadataReceived.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageMetadataReceived.ts new file mode 100644 index 0000000000..63aeecb95e --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageMetadataReceived.ts @@ -0,0 +1,33 @@ +import { log } from 'app/logging/useLogger'; +import { startAppListening } from '..'; +import { imageMetadataReceived } from 'services/thunks/image'; +import { imageUpserted } from 'features/gallery/store/imagesSlice'; + +const moduleLog = log.child({ namespace: 'image' }); + +export const addImageMetadataReceivedFulfilledListener = () => { + startAppListening({ + actionCreator: imageMetadataReceived.fulfilled, + effect: (action, { getState, dispatch }) => { + const image = action.payload; + if (image.is_intermediate) { + // No further actions needed for intermediate images + return; + } + moduleLog.debug({ data: { image } }, 'Image metadata received'); + dispatch(imageUpserted(image)); + }, + }); +}; + +export const addImageMetadataReceivedRejectedListener = () => { + startAppListening({ + actionCreator: imageMetadataReceived.rejected, + effect: (action, { getState, dispatch }) => { + moduleLog.debug( + { data: { image: action.meta.arg } }, + 'Problem receiving image metadata' + ); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUpdated.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUpdated.ts new file mode 100644 index 0000000000..6f8b46ec23 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUpdated.ts @@ -0,0 +1,26 @@ +import { startAppListening } from '..'; +import { imageUpdated } from 'services/thunks/image'; 
+import { log } from 'app/logging/useLogger'; + +const moduleLog = log.child({ namespace: 'image' }); + +export const addImageUpdatedFulfilledListener = () => { + startAppListening({ + actionCreator: imageUpdated.fulfilled, + effect: (action, { dispatch, getState }) => { + moduleLog.debug( + { oldImage: action.meta.arg, updatedImage: action.payload }, + 'Image updated' + ); + }, + }); +}; + +export const addImageUpdatedRejectedListener = () => { + startAppListening({ + actionCreator: imageUpdated.rejected, + effect: (action, { dispatch }) => { + moduleLog.debug({ oldImage: action.meta.arg }, 'Image update failed'); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts index 1d66166c12..6d84431f80 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts @@ -1,44 +1,46 @@ import { startAppListening } from '..'; -import { uploadAdded } from 'features/gallery/store/uploadsSlice'; -import { imageSelected } from 'features/gallery/store/gallerySlice'; import { imageUploaded } from 'services/thunks/image'; import { addToast } from 'features/system/store/systemSlice'; -import { initialImageSelected } from 'features/parameters/store/actions'; -import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice'; -import { resultAdded } from 'features/gallery/store/resultsSlice'; -import { isResultsImageDTO, isUploadsImageDTO } from 'services/types/guards'; +import { log } from 'app/logging/useLogger'; +import { imageUpserted } from 'features/gallery/store/imagesSlice'; -export const addImageUploadedListener = () => { +const moduleLog = log.child({ namespace: 'image' }); + +export const addImageUploadedFulfilledListener = () => { startAppListening({ - predicate: 
(action): action is ReturnType => - imageUploaded.fulfilled.match(action) && - action.payload.response.image_type !== 'intermediates', + actionCreator: imageUploaded.fulfilled, effect: (action, { dispatch, getState }) => { - const { response: image } = action.payload; + const image = action.payload; + + moduleLog.debug({ arg: '', image }, 'Image uploaded'); + + if (action.payload.is_intermediate) { + // No further actions needed for intermediate images + return; + } const state = getState(); - if (isUploadsImageDTO(image)) { - dispatch(uploadAdded(image)); - - dispatch(addToast({ title: 'Image Uploaded', status: 'success' })); - - if (state.gallery.shouldAutoSwitchToNewImages) { - dispatch(imageSelected(image)); - } - - if (action.meta.arg.activeTabName === 'img2img') { - dispatch(initialImageSelected(image)); - } - - if (action.meta.arg.activeTabName === 'unifiedCanvas') { - dispatch(setInitialCanvasImage(image)); - } - } - - if (isResultsImageDTO(image)) { - dispatch(resultAdded(image)); - } + dispatch(imageUpserted(image)); + dispatch(addToast({ title: 'Image Uploaded', status: 'success' })); + }, + }); +}; + +export const addImageUploadedRejectedListener = () => { + startAppListening({ + actionCreator: imageUploaded.rejected, + effect: (action, { dispatch }) => { + const { formData, ...rest } = action.meta.arg; + const sanitizedData = { arg: { ...rest, formData: { file: '' } } }; + moduleLog.error({ data: sanitizedData }, 'Image upload failed'); + dispatch( + addToast({ + title: 'Image Upload Failed', + description: action.error.message, + status: 'error', + }) + ); }, }); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUrlsReceived.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUrlsReceived.ts new file mode 100644 index 0000000000..fd0461f893 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUrlsReceived.ts @@ -0,0 +1,38 @@ 
+import { log } from 'app/logging/useLogger'; +import { startAppListening } from '..'; +import { imageUrlsReceived } from 'services/thunks/image'; +import { imagesAdapter } from 'features/gallery/store/imagesSlice'; + +const moduleLog = log.child({ namespace: 'image' }); + +export const addImageUrlsReceivedFulfilledListener = () => { + startAppListening({ + actionCreator: imageUrlsReceived.fulfilled, + effect: (action, { getState, dispatch }) => { + const image = action.payload; + moduleLog.debug({ data: { image } }, 'Image URLs received'); + + const { image_name, image_url, thumbnail_url } = image; + + /* NOTE(review): this runs outside a reducer and the return value is discarded, so the store is presumably never updated with the new URLs - confirm, and dispatch a slice action instead */ imagesAdapter.updateOne(getState().images, { + id: image_name, + changes: { + image_url, + thumbnail_url, + }, + }); + }, + }); +}; + +export const addImageUrlsReceivedRejectedListener = () => { + startAppListening({ + actionCreator: imageUrlsReceived.rejected, + effect: (action, { getState, dispatch }) => { + moduleLog.debug( + { data: { image: action.meta.arg } }, + 'Problem getting image URLs' + ); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/initialImageSelected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/initialImageSelected.ts index d6cfc260f3..940cc84c1e 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/initialImageSelected.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/initialImageSelected.ts @@ -1,6 +1,4 @@ import { initialImageChanged } from 'features/parameters/store/generationSlice'; -import { selectResultsById } from 'features/gallery/store/resultsSlice'; -import { selectUploadsById } from 'features/gallery/store/uploadsSlice'; import { t } from 'i18next'; import { addToast } from 'features/system/store/systemSlice'; import { startAppListening } from '..'; @@ -9,7 +7,7 @@ import { isImageDTO, } from 'features/parameters/store/actions'; import { makeToast } from
'app/components/Toaster'; -import { ImageDTO } from 'services/api'; +import { selectImagesById } from 'features/gallery/store/imagesSlice'; export const addInitialImageSelectedListener = () => { startAppListening({ @@ -30,16 +28,8 @@ export const addInitialImageSelectedListener = () => { return; } - const { image_name, image_type } = action.payload; - - let image: ImageDTO | undefined; - const state = getState(); - - if (image_type === 'results') { - image = selectResultsById(state, image_name); - } else if (image_type === 'uploads') { - image = selectUploadsById(state, image_name); - } + const imageName = action.payload; + const image = selectImagesById(getState(), imageName); if (!image) { dispatch( diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/invocationComplete.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/invocationComplete.ts deleted file mode 100644 index 0222eea93c..0000000000 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/invocationComplete.ts +++ /dev/null @@ -1,62 +0,0 @@ -import { invocationComplete } from 'services/events/actions'; -import { isImageOutput } from 'services/types/guards'; -import { - imageMetadataReceived, - imageUrlsReceived, -} from 'services/thunks/image'; -import { startAppListening } from '..'; -import { addImageToStagingArea } from 'features/canvas/store/canvasSlice'; - -const nodeDenylist = ['dataURL_image']; - -export const addImageResultReceivedListener = () => { - startAppListening({ - predicate: (action) => { - if ( - invocationComplete.match(action) && - isImageOutput(action.payload.data.result) - ) { - return true; - } - return false; - }, - effect: async (action, { getState, dispatch, take }) => { - if (!invocationComplete.match(action)) { - return; - } - - const { data } = action.payload; - const { result, node, graph_execution_state_id } = data; - - if (isImageOutput(result) && 
!nodeDenylist.includes(node.type)) { - const { image_name, image_type } = result.image; - - dispatch( - imageUrlsReceived({ imageName: image_name, imageType: image_type }) - ); - - dispatch( - imageMetadataReceived({ - imageName: image_name, - imageType: image_type, - }) - ); - - // Handle canvas image - if ( - graph_execution_state_id === - getState().canvas.layerState.stagingArea.sessionId - ) { - const [{ payload: image }] = await take( - ( - action - ): action is ReturnType => - imageMetadataReceived.fulfilled.match(action) && - action.payload.image_name === image_name - ); - dispatch(addImageToStagingArea(image)); - } - } - }, - }); -}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedPageOfImages.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedPageOfImages.ts new file mode 100644 index 0000000000..cde7e22e3d --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedPageOfImages.ts @@ -0,0 +1,33 @@ +import { log } from 'app/logging/useLogger'; +import { startAppListening } from '..'; +import { serializeError } from 'serialize-error'; +import { receivedPageOfImages } from 'services/thunks/image'; + +const moduleLog = log.child({ namespace: 'gallery' }); + +export const addReceivedPageOfImagesFulfilledListener = () => { + startAppListening({ + actionCreator: receivedPageOfImages.fulfilled, + effect: (action, { getState, dispatch }) => { + const page = action.payload; + moduleLog.debug( + { data: { payload: action.payload } }, + `Received ${page.items.length} images` + ); + }, + }); +}; + +export const addReceivedPageOfImagesRejectedListener = () => { + startAppListening({ + actionCreator: receivedPageOfImages.rejected, + effect: (action, { getState, dispatch }) => { + if (action.payload) { + moduleLog.debug( + { data: { error: serializeError(action.payload) } }, + 'Problem receiving images' + ); + } + }, + }); +}; diff --git 
a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionCanceled.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionCanceled.ts new file mode 100644 index 0000000000..6274ad4dc8 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionCanceled.ts @@ -0,0 +1,48 @@ +import { log } from 'app/logging/useLogger'; +import { startAppListening } from '..'; +import { sessionCanceled } from 'services/thunks/session'; +import { serializeError } from 'serialize-error'; + +const moduleLog = log.child({ namespace: 'session' }); + +export const addSessionCanceledPendingListener = () => { + startAppListening({ + actionCreator: sessionCanceled.pending, + effect: (action, { getState, dispatch }) => { + // + }, + }); +}; + +export const addSessionCanceledFulfilledListener = () => { + startAppListening({ + actionCreator: sessionCanceled.fulfilled, + effect: (action, { getState, dispatch }) => { + const { sessionId } = action.meta.arg; + moduleLog.debug( + { data: { sessionId } }, + `Session canceled (${sessionId})` + ); + }, + }); +}; + +export const addSessionCanceledRejectedListener = () => { + startAppListening({ + actionCreator: sessionCanceled.rejected, + effect: (action, { getState, dispatch }) => { + if (action.payload) { + const { arg, error } = action.payload; + moduleLog.error( + { + data: { + arg, + error: serializeError(error), + }, + }, + `Problem canceling session` + ); + } + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionCreated.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionCreated.ts new file mode 100644 index 0000000000..fb8a64d2e3 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionCreated.ts @@ -0,0 +1,45 @@ +import { log } from 'app/logging/useLogger'; +import { startAppListening } from '..'; +import { 
sessionCreated } from 'services/thunks/session'; +import { serializeError } from 'serialize-error'; + +const moduleLog = log.child({ namespace: 'session' }); + +export const addSessionCreatedPendingListener = () => { + startAppListening({ + actionCreator: sessionCreated.pending, + effect: (action, { getState, dispatch }) => { + // + }, + }); +}; + +export const addSessionCreatedFulfilledListener = () => { + startAppListening({ + actionCreator: sessionCreated.fulfilled, + effect: (action, { getState, dispatch }) => { + const session = action.payload; + moduleLog.debug({ data: { session } }, `Session created (${session.id})`); + }, + }); +}; + +export const addSessionCreatedRejectedListener = () => { + startAppListening({ + actionCreator: sessionCreated.rejected, + effect: (action, { getState, dispatch }) => { + if (action.payload) { + const { arg, error } = action.payload; + moduleLog.error( + { + data: { + arg, + error: serializeError(error), + }, + }, + `Problem creating session` + ); + } + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionInvoked.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionInvoked.ts new file mode 100644 index 0000000000..272d1d9e1d --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionInvoked.ts @@ -0,0 +1,48 @@ +import { log } from 'app/logging/useLogger'; +import { startAppListening } from '..'; +import { sessionInvoked } from 'services/thunks/session'; +import { serializeError } from 'serialize-error'; + +const moduleLog = log.child({ namespace: 'session' }); + +export const addSessionInvokedPendingListener = () => { + startAppListening({ + actionCreator: sessionInvoked.pending, + effect: (action, { getState, dispatch }) => { + // + }, + }); +}; + +export const addSessionInvokedFulfilledListener = () => { + startAppListening({ + actionCreator: sessionInvoked.fulfilled, + effect: (action, { 
getState, dispatch }) => { + const { sessionId } = action.meta.arg; + moduleLog.debug( + { data: { sessionId } }, + `Session invoked (${sessionId})` + ); + }, + }); +}; + +export const addSessionInvokedRejectedListener = () => { + startAppListening({ + actionCreator: sessionInvoked.rejected, + effect: (action, { getState, dispatch }) => { + if (action.payload) { + const { arg, error } = action.payload; + moduleLog.error( + { + data: { + arg, + error: serializeError(error), + }, + }, + `Problem invoking session` + ); + } + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionReadyToInvoke.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionReadyToInvoke.ts new file mode 100644 index 0000000000..8d4262e7da --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionReadyToInvoke.ts @@ -0,0 +1,22 @@ +import { startAppListening } from '..'; +import { sessionInvoked } from 'services/thunks/session'; +import { log } from 'app/logging/useLogger'; +import { sessionReadyToInvoke } from 'features/system/store/actions'; + +const moduleLog = log.child({ namespace: 'session' }); + +export const addSessionReadyToInvokeListener = () => { + startAppListening({ + actionCreator: sessionReadyToInvoke, + effect: (action, { getState, dispatch }) => { + const { sessionId } = getState().system; + if (sessionId) { + moduleLog.debug( + { sessionId }, + `Session ready to invoke (${sessionId})` + ); + dispatch(sessionInvoked({ sessionId })); // invoke the session id currently held in system state + } + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts new file mode 100644 index 0000000000..3049d2c933 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts @@ -0,0
+1,38 @@ +import { startAppListening } from '../..'; +import { log } from 'app/logging/useLogger'; +import { appSocketConnected, socketConnected } from 'services/events/actions'; +import { receivedPageOfImages } from 'services/thunks/image'; +import { receivedModels } from 'services/thunks/model'; +import { receivedOpenAPISchema } from 'services/thunks/schema'; + +const moduleLog = log.child({ namespace: 'socketio' }); + +export const addSocketConnectedEventListener = () => { + startAppListening({ + actionCreator: socketConnected, + effect: (action, { dispatch, getState }) => { + const { timestamp } = action.payload; + + moduleLog.debug({ timestamp }, 'Connected'); + + const { models, nodes, config, images } = getState(); + + const { disabledTabs } = config; + + if (!images.ids.length) { + dispatch(receivedPageOfImages()); + } + + if (!models.ids.length) { + dispatch(receivedModels()); + } + + if (!nodes.schema && !disabledTabs.includes('nodes')) { + dispatch(receivedOpenAPISchema()); + } + + // pass along the socket event as an application action + dispatch(appSocketConnected(action.payload)); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketDisconnected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketDisconnected.ts new file mode 100644 index 0000000000..d5e8914cef --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketDisconnected.ts @@ -0,0 +1,19 @@ +import { startAppListening } from '../..'; +import { log } from 'app/logging/useLogger'; +import { + socketDisconnected, + appSocketDisconnected, +} from 'services/events/actions'; + +const moduleLog = log.child({ namespace: 'socketio' }); + +export const addSocketDisconnectedEventListener = () => { + startAppListening({ + actionCreator: socketDisconnected, + effect: (action, { dispatch, getState }) => { + moduleLog.debug(action.payload, 
'Disconnected'); + // pass along the socket event as an application action + dispatch(appSocketDisconnected(action.payload)); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress.ts new file mode 100644 index 0000000000..756444d644 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress.ts @@ -0,0 +1,34 @@ +import { startAppListening } from '../..'; +import { log } from 'app/logging/useLogger'; +import { + appSocketGeneratorProgress, + socketGeneratorProgress, +} from 'services/events/actions'; + +const moduleLog = log.child({ namespace: 'socketio' }); + +export const addGeneratorProgressEventListener = () => { + startAppListening({ + actionCreator: socketGeneratorProgress, + effect: (action, { dispatch, getState }) => { + if ( + getState().system.canceledSession === + action.payload.data.graph_execution_state_id + ) { + moduleLog.trace( + action.payload, + 'Ignored generator progress for canceled session' + ); + return; + } + + moduleLog.trace( + action.payload, + `Generator progress (${action.payload.data.node.type})` + ); + + // pass along the socket event as an application action + dispatch(appSocketGeneratorProgress(action.payload)); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketGraphExecutionStateComplete.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketGraphExecutionStateComplete.ts new file mode 100644 index 0000000000..7297825e32 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketGraphExecutionStateComplete.ts @@ -0,0 +1,22 @@ +import { log } from 'app/logging/useLogger'; +import { + 
appSocketGraphExecutionStateComplete, + socketGraphExecutionStateComplete, +} from 'services/events/actions'; +import { startAppListening } from '../..'; + +const moduleLog = log.child({ namespace: 'socketio' }); + +export const addGraphExecutionStateCompleteEventListener = () => { + startAppListening({ + actionCreator: socketGraphExecutionStateComplete, + effect: (action, { dispatch, getState }) => { + moduleLog.debug( + action.payload, + `Session invocation complete (${action.payload.data.graph_execution_state_id})` + ); + // pass along the socket event as an application action + dispatch(appSocketGraphExecutionStateComplete(action.payload)); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts new file mode 100644 index 0000000000..0b47f7a1be --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts @@ -0,0 +1,67 @@ +import { addImageToStagingArea } from 'features/canvas/store/canvasSlice'; +import { startAppListening } from '../..'; +import { log } from 'app/logging/useLogger'; +import { + appSocketInvocationComplete, + socketInvocationComplete, +} from 'services/events/actions'; +import { imageMetadataReceived } from 'services/thunks/image'; +import { sessionCanceled } from 'services/thunks/session'; +import { isImageOutput } from 'services/types/guards'; +import { progressImageSet } from 'features/system/store/systemSlice'; + +const moduleLog = log.child({ namespace: 'socketio' }); +const nodeDenylist = ['dataURL_image']; + +export const addInvocationCompleteEventListener = () => { + startAppListening({ + actionCreator: socketInvocationComplete, + effect: async (action, { dispatch, getState, take }) => { + moduleLog.debug( + action.payload, + `Invocation complete 
(${action.payload.data.node.type})` + ); + + const sessionId = action.payload.data.graph_execution_state_id; + + const { cancelType, isCancelScheduled } = getState().system; + + // Handle scheduled cancelation + if (cancelType === 'scheduled' && isCancelScheduled) { + dispatch(sessionCanceled({ sessionId })); + } + + const { data } = action.payload; + const { result, node, graph_execution_state_id } = data; + + // This complete event has an associated image output + if (isImageOutput(result) && !nodeDenylist.includes(node.type)) { + const { image_name, image_origin } = result.image; + + // Get its metadata, then wait for the response belonging to this specific image (other metadata requests may be in flight) + dispatch( + imageMetadataReceived({ + imageName: image_name, + imageOrigin: image_origin, + }) + ); + + const [{ payload: imageDTO }] = await take( + (metadataAction): metadataAction is ReturnType<typeof imageMetadataReceived.fulfilled> => imageMetadataReceived.fulfilled.match(metadataAction) && metadataAction.payload.image_name === image_name + ); + + // Handle canvas image + if ( + graph_execution_state_id === + getState().canvas.layerState.stagingArea.sessionId + ) { + dispatch(addImageToStagingArea(imageDTO)); + } + + dispatch(progressImageSet(null)); + } + // pass along the socket event as an application action + dispatch(appSocketInvocationComplete(action.payload)); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError.ts new file mode 100644 index 0000000000..51480bbad4 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError.ts @@ -0,0 +1,21 @@ +import { startAppListening } from '../..'; +import { log } from 'app/logging/useLogger'; +import { + appSocketInvocationError, + socketInvocationError, +} from 'services/events/actions'; + +const moduleLog = log.child({ namespace: 'socketio' }); + +export const addInvocationErrorEventListener = () => { + startAppListening({ + actionCreator: socketInvocationError, + effect: (action, { dispatch, getState }) =>
{ + moduleLog.error( + action.payload, + `Invocation error (${action.payload.data.node.type})` + ); + dispatch(appSocketInvocationError(action.payload)); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationStarted.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationStarted.ts new file mode 100644 index 0000000000..978be2fef5 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationStarted.ts @@ -0,0 +1,32 @@ +import { startAppListening } from '../..'; +import { log } from 'app/logging/useLogger'; +import { + appSocketInvocationStarted, + socketInvocationStarted, +} from 'services/events/actions'; + +const moduleLog = log.child({ namespace: 'socketio' }); + +export const addInvocationStartedEventListener = () => { + startAppListening({ + actionCreator: socketInvocationStarted, + effect: (action, { dispatch, getState }) => { + if ( + getState().system.canceledSession === + action.payload.data.graph_execution_state_id + ) { + moduleLog.trace( + action.payload, + 'Ignored invocation started for canceled session' + ); + return; + } + + moduleLog.debug( + action.payload, + `Invocation started (${action.payload.data.node.type})` + ); + dispatch(appSocketInvocationStarted(action.payload)); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketSubscribed.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketSubscribed.ts new file mode 100644 index 0000000000..871222981d --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketSubscribed.ts @@ -0,0 +1,18 @@ +import { startAppListening } from '../..'; +import { log } from 'app/logging/useLogger'; +import { appSocketSubscribed, socketSubscribed } from 'services/events/actions'; 
+ +const moduleLog = log.child({ namespace: 'socketio' }); + +export const addSocketSubscribedEventListener = () => { + startAppListening({ + actionCreator: socketSubscribed, + effect: (action, { dispatch, getState }) => { + moduleLog.debug( + action.payload, + `Subscribed (${action.payload.sessionId})` + ); + dispatch(appSocketSubscribed(action.payload)); // pass along the socket event as an application action + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketUnsubscribed.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketUnsubscribed.ts new file mode 100644 index 0000000000..ff85379907 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketUnsubscribed.ts @@ -0,0 +1,21 @@ +import { startAppListening } from '../..'; +import { log } from 'app/logging/useLogger'; +import { + appSocketUnsubscribed, + socketUnsubscribed, +} from 'services/events/actions'; + +const moduleLog = log.child({ namespace: 'socketio' }); + +export const addSocketUnsubscribedEventListener = () => { + startAppListening({ + actionCreator: socketUnsubscribed, + effect: (action, { dispatch, getState }) => { + moduleLog.debug( + action.payload, + `Unsubscribed (${action.payload.sessionId})` + ); + dispatch(appSocketUnsubscribed(action.payload)); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/stagingAreaImageSaved.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/stagingAreaImageSaved.ts new file mode 100644 index 0000000000..9bd3cd6dd2 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/stagingAreaImageSaved.ts @@ -0,0 +1,54 @@ +import { stagingAreaImageSaved } from 'features/canvas/store/actions'; +import { startAppListening } from '..'; +import { log } from 'app/logging/useLogger'; +import { imageUpdated } from 'services/thunks/image'; +import {
imageUpserted } from 'features/gallery/store/imagesSlice'; +import { addToast } from 'features/system/store/systemSlice'; + +const moduleLog = log.child({ namespace: 'canvas' }); + +export const addStagingAreaImageSavedListener = () => { + startAppListening({ + actionCreator: stagingAreaImageSaved, + effect: async (action, { dispatch, getState, take }) => { + const { image_name, image_origin } = action.payload; + + dispatch( + imageUpdated({ + imageName: image_name, + imageOrigin: image_origin, + requestBody: { + is_intermediate: false, + }, + }) + ); + + const [imageUpdatedAction] = await take( + (action) => + (imageUpdated.fulfilled.match(action) || + imageUpdated.rejected.match(action)) && + action.meta.arg.imageName === image_name + ); + + if (imageUpdated.rejected.match(imageUpdatedAction)) { + moduleLog.error( + { data: { arg: imageUpdatedAction.meta.arg } }, + 'Image saving failed' + ); + dispatch( + addToast({ + title: 'Image Saving Failed', + description: imageUpdatedAction.error.message, + status: 'error', + }) + ); + return; + } + + if (imageUpdated.fulfilled.match(imageUpdatedAction)) { + dispatch(imageUpserted(imageUpdatedAction.payload)); + dispatch(addToast({ title: 'Image Saved', status: 'success' })); + } + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedCanvas.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedCanvas.ts index 2ebd3684e9..0ee3016bdb 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedCanvas.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedCanvas.ts @@ -1,9 +1,9 @@ import { startAppListening } from '..'; -import { sessionCreated, sessionInvoked } from 'services/thunks/session'; +import { sessionCreated } from 'services/thunks/session'; import { buildCanvasGraphComponents } from 'features/nodes/util/graphBuilders/buildCanvasGraph'; 
import { log } from 'app/logging/useLogger'; import { canvasGraphBuilt } from 'features/nodes/store/actions'; -import { imageUploaded } from 'services/thunks/image'; +import { imageUpdated, imageUploaded } from 'services/thunks/image'; import { v4 as uuidv4 } from 'uuid'; import { Graph } from 'services/api'; import { @@ -15,12 +15,22 @@ import { getCanvasData } from 'features/canvas/util/getCanvasData'; import { getCanvasGenerationMode } from 'features/canvas/util/getCanvasGenerationMode'; import { blobToDataURL } from 'features/canvas/util/blobToDataURL'; import openBase64ImageInTab from 'common/util/openBase64ImageInTab'; +import { sessionReadyToInvoke } from 'features/system/store/actions'; const moduleLog = log.child({ namespace: 'invoke' }); /** - * This listener is responsible for building the canvas graph and blobs when the user invokes the canvas. - * It is also responsible for uploading the base and mask layers to the server. + * This listener is responsible for invoking the canvas. This involves a number of steps: + * + * 1. Generate image blobs from the canvas layers + * 2. Determine the generation mode from the layers (txt2img, img2img, inpaint) + * 3. Build the canvas graph + * 4. Create the session with the graph + * 5. Upload the init image if necessary + * 6. Upload the mask image if necessary + * 7. Update the init and mask images with the session ID + * 8. Initialize the staging area if not yet initialized + * 9. 
Dispatch the sessionReadyToInvoke action to invoke the session */ export const addUserInvokedCanvasListener = () => { startAppListening({ @@ -70,63 +80,7 @@ export const addUserInvokedCanvasListener = () => { const { rangeNode, iterateNode, baseNode, edges } = graphComponents; - // Upload the base layer, to be used as init image - const baseFilename = `${uuidv4()}.png`; - - dispatch( - imageUploaded({ - imageType: 'intermediates', - formData: { - file: new File([baseBlob], baseFilename, { type: 'image/png' }), - }, - }) - ); - - if (baseNode.type === 'img2img' || baseNode.type === 'inpaint') { - const [{ payload: basePayload }] = await take( - (action): action is ReturnType => - imageUploaded.fulfilled.match(action) && - action.meta.arg.formData.file.name === baseFilename - ); - - const { image_name: baseName, image_type: baseType } = - basePayload.response; - - baseNode.image = { - image_name: baseName, - image_type: baseType, - }; - } - - // Upload the mask layer image - const maskFilename = `${uuidv4()}.png`; - - if (baseNode.type === 'inpaint') { - dispatch( - imageUploaded({ - imageType: 'intermediates', - formData: { - file: new File([maskBlob], maskFilename, { type: 'image/png' }), - }, - }) - ); - - const [{ payload: maskPayload }] = await take( - (action): action is ReturnType => - imageUploaded.fulfilled.match(action) && - action.meta.arg.formData.file.name === maskFilename - ); - - const { image_name: maskName, image_type: maskType } = - maskPayload.response; - - baseNode.mask = { - image_name: maskName, - image_type: maskType, - }; - } - - // Assemble! + // Assemble! 
Note that this graph *does not have the init or mask image set yet!* const nodes: Graph['nodes'] = { [rangeNode.id]: rangeNode, [iterateNode.id]: iterateNode, @@ -136,15 +90,92 @@ export const addUserInvokedCanvasListener = () => { const graph = { nodes, edges }; dispatch(canvasGraphBuilt(graph)); - moduleLog({ data: graph }, 'Canvas graph built'); - // Actually create the session + moduleLog.debug({ data: graph }, 'Canvas graph built'); + + // If we are generating img2img or inpaint, we need to upload the init images + if (baseNode.type === 'img2img' || baseNode.type === 'inpaint') { + const baseFilename = `${uuidv4()}.png`; + dispatch( + imageUploaded({ + formData: { + file: new File([baseBlob], baseFilename, { type: 'image/png' }), + }, + imageCategory: 'general', + isIntermediate: true, + }) + ); + + // Wait for the image to be uploaded + const [{ payload: baseImageDTO }] = await take( + (action): action is ReturnType => + imageUploaded.fulfilled.match(action) && + action.meta.arg.formData.file.name === baseFilename + ); + + // Update the base node with the image name and type + baseNode.image = { + image_name: baseImageDTO.image_name, + image_origin: baseImageDTO.image_origin, + }; + } + + // For inpaint, we also need to upload the mask layer + if (baseNode.type === 'inpaint') { + const maskFilename = `${uuidv4()}.png`; + dispatch( + imageUploaded({ + formData: { + file: new File([maskBlob], maskFilename, { type: 'image/png' }), + }, + imageCategory: 'mask', + isIntermediate: true, + }) + ); + + // Wait for the mask to be uploaded + const [{ payload: maskImageDTO }] = await take( + (action): action is ReturnType => + imageUploaded.fulfilled.match(action) && + action.meta.arg.formData.file.name === maskFilename + ); + + // Update the base node with the image name and type + baseNode.mask = { + image_name: maskImageDTO.image_name, + image_origin: maskImageDTO.image_origin, + }; + } + + // Create the session and wait for response dispatch(sessionCreated({ graph 
})); + const [sessionCreatedAction] = await take(sessionCreated.fulfilled.match); + const sessionId = sessionCreatedAction.payload.id; - // Wait for the session to be invoked (this is just the HTTP request to start processing) - const [{ meta }] = await take(sessionInvoked.fulfilled.match); + // Associate the init image with the session, now that we have the session ID + if ( + (baseNode.type === 'img2img' || baseNode.type === 'inpaint') && + baseNode.image + ) { + dispatch( + imageUpdated({ + imageName: baseNode.image.image_name, + imageOrigin: baseNode.image.image_origin, + requestBody: { session_id: sessionId }, + }) + ); + } - const { sessionId } = meta.arg; + // Associate the mask image with the session, now that we have the session ID + if (baseNode.type === 'inpaint' && baseNode.mask) { + dispatch( + imageUpdated({ + imageName: baseNode.mask.image_name, + imageOrigin: baseNode.mask.image_origin, + requestBody: { session_id: sessionId }, + }) + ); + } if (!state.canvas.layerState.stagingArea.boundingBox) { dispatch( @@ -158,7 +189,11 @@ export const addUserInvokedCanvasListener = () => { ); } + // Flag the session with the canvas session ID dispatch(canvasSessionIdChanged(sessionId)); + + // We are ready to invoke the session! 
+ dispatch(sessionReadyToInvoke()); }, }); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedImageToImage.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedImageToImage.ts index e747aefa08..7dcbe8a41d 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedImageToImage.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedImageToImage.ts @@ -4,6 +4,7 @@ import { sessionCreated } from 'services/thunks/session'; import { log } from 'app/logging/useLogger'; import { imageToImageGraphBuilt } from 'features/nodes/store/actions'; import { userInvoked } from 'app/store/actions'; +import { sessionReadyToInvoke } from 'features/system/store/actions'; const moduleLog = log.child({ namespace: 'invoke' }); @@ -11,14 +12,18 @@ export const addUserInvokedImageToImageListener = () => { startAppListening({ predicate: (action): action is ReturnType => userInvoked.match(action) && action.payload === 'img2img', - effect: (action, { getState, dispatch }) => { + effect: async (action, { getState, dispatch, take }) => { const state = getState(); const graph = buildImageToImageGraph(state); dispatch(imageToImageGraphBuilt(graph)); - moduleLog({ data: graph }, 'Image to Image graph built'); + moduleLog.debug({ data: graph }, 'Image to Image graph built'); dispatch(sessionCreated({ graph })); + + await take(sessionCreated.fulfilled.match); + + dispatch(sessionReadyToInvoke()); }, }); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedNodes.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedNodes.ts index 01e532d5ff..6fda3db0d6 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedNodes.ts +++ 
b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedNodes.ts @@ -4,6 +4,7 @@ import { buildNodesGraph } from 'features/nodes/util/graphBuilders/buildNodesGra import { log } from 'app/logging/useLogger'; import { nodesGraphBuilt } from 'features/nodes/store/actions'; import { userInvoked } from 'app/store/actions'; +import { sessionReadyToInvoke } from 'features/system/store/actions'; const moduleLog = log.child({ namespace: 'invoke' }); @@ -11,14 +12,18 @@ export const addUserInvokedNodesListener = () => { startAppListening({ predicate: (action): action is ReturnType => userInvoked.match(action) && action.payload === 'nodes', - effect: (action, { getState, dispatch }) => { + effect: async (action, { getState, dispatch, take }) => { const state = getState(); const graph = buildNodesGraph(state); dispatch(nodesGraphBuilt(graph)); - moduleLog({ data: graph }, 'Nodes graph built'); + moduleLog.debug({ data: graph }, 'Nodes graph built'); dispatch(sessionCreated({ graph })); + + await take(sessionCreated.fulfilled.match); + + dispatch(sessionReadyToInvoke()); }, }); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedTextToImage.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedTextToImage.ts index e3eb5d0b38..6042d86cb7 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedTextToImage.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedTextToImage.ts @@ -4,6 +4,7 @@ import { sessionCreated } from 'services/thunks/session'; import { log } from 'app/logging/useLogger'; import { textToImageGraphBuilt } from 'features/nodes/store/actions'; import { userInvoked } from 'app/store/actions'; +import { sessionReadyToInvoke } from 'features/system/store/actions'; const moduleLog = log.child({ namespace: 'invoke' }); @@ -11,14 +12,20 @@ export const 
addUserInvokedTextToImageListener = () => { startAppListening({ predicate: (action): action is ReturnType => userInvoked.match(action) && action.payload === 'txt2img', - effect: (action, { getState, dispatch }) => { + effect: async (action, { getState, dispatch, take }) => { const state = getState(); const graph = buildTextToImageGraph(state); + dispatch(textToImageGraphBuilt(graph)); - moduleLog({ data: graph }, 'Text to Image graph built'); + + moduleLog.debug({ data: graph }, 'Text to Image graph built'); dispatch(sessionCreated({ graph })); + + await take(sessionCreated.fulfilled.match); + + dispatch(sessionReadyToInvoke()); }, }); }; diff --git a/invokeai/frontend/web/src/app/store/store.ts b/invokeai/frontend/web/src/app/store/store.ts index b89615b2c0..521610adcc 100644 --- a/invokeai/frontend/web/src/app/store/store.ts +++ b/invokeai/frontend/web/src/app/store/store.ts @@ -10,12 +10,12 @@ import dynamicMiddlewares from 'redux-dynamic-middlewares'; import canvasReducer from 'features/canvas/store/canvasSlice'; import galleryReducer from 'features/gallery/store/gallerySlice'; -import resultsReducer from 'features/gallery/store/resultsSlice'; -import uploadsReducer from 'features/gallery/store/uploadsSlice'; +import imagesReducer from 'features/gallery/store/imagesSlice'; import lightboxReducer from 'features/lightbox/store/lightboxSlice'; import generationReducer from 'features/parameters/store/generationSlice'; import postprocessingReducer from 'features/parameters/store/postprocessingSlice'; import systemReducer from 'features/system/store/systemSlice'; +// import sessionReducer from 'features/system/store/sessionSlice'; import configReducer from 'features/system/store/configSlice'; import uiReducer from 'features/ui/store/uiSlice'; import hotkeysReducer from 'features/ui/store/hotkeysSlice'; @@ -40,12 +40,12 @@ const allReducers = { models: modelsReducer, nodes: nodesReducer, postprocessing: postprocessingReducer, - results: resultsReducer, system: 
systemReducer, config: configReducer, ui: uiReducer, - uploads: uploadsReducer, hotkeys: hotkeysReducer, + images: imagesReducer, + // session: sessionReducer, }; const rootReducer = combineReducers(allReducers); @@ -63,8 +63,6 @@ const rememberedKeys: (keyof typeof allReducers)[] = [ 'system', 'ui', // 'hotkeys', - // 'results', - // 'uploads', // 'config', ]; diff --git a/invokeai/frontend/web/src/app/types/invokeai.ts b/invokeai/frontend/web/src/app/types/invokeai.ts index 68f7568779..8081ffa491 100644 --- a/invokeai/frontend/web/src/app/types/invokeai.ts +++ b/invokeai/frontend/web/src/app/types/invokeai.ts @@ -1,316 +1,82 @@ -/** - * Types for images, the things they are made of, and the things - * they make up. - * - * Generated images are txt2img and img2img images. They may have - * had additional postprocessing done on them when they were first - * generated. - * - * Postprocessed images are images which were not generated here - * but only postprocessed by the app. They only get postprocessing - * metadata and have a different image type, e.g. 'esrgan' or - * 'gfpgan'. - */ - -import { SelectedImage } from 'features/parameters/store/actions'; import { InvokeTabName } from 'features/ui/store/tabMap'; -import { IRect } from 'konva/lib/types'; -import { ImageResponseMetadata, ImageType } from 'services/api'; import { O } from 'ts-toolbelt'; -/** - * TODO: - * Once an image has been generated, if it is postprocessed again, - * additional postprocessing steps are added to its postprocessing - * array. - * - * TODO: Better documentation of types. 
- */ +// These are old types from the model management UI -export type PromptItem = { - prompt: string; - weight: number; -}; +// export type ModelStatus = 'active' | 'cached' | 'not loaded'; -// TECHDEBT: We need to retain compatibility with plain prompt strings and the structure Prompt type -export type Prompt = Array | string; - -export type SeedWeightPair = { - seed: number; - weight: number; -}; - -export type SeedWeights = Array; - -// All generated images contain these metadata. -export type CommonGeneratedImageMetadata = { - postprocessing: null | Array; - sampler: - | 'ddim' - | 'ddpm' - | 'deis' - | 'lms' - | 'pndm' - | 'heun' - | 'heun_k' - | 'euler' - | 'euler_k' - | 'euler_a' - | 'kdpm_2' - | 'kdpm_2_a' - | 'dpmpp_2s' - | 'dpmpp_2m' - | 'dpmpp_2m_k' - | 'unipc'; - prompt: Prompt; - seed: number; - variations: SeedWeights; - steps: number; - cfg_scale: number; - width: number; - height: number; - seamless: boolean; - hires_fix: boolean; - extra: null | Record; // Pending development of RFC #266 -}; - -// txt2img and img2img images have some unique attributes. -export type Txt2ImgMetadata = CommonGeneratedImageMetadata & { - type: 'txt2img'; -}; - -export type Img2ImgMetadata = CommonGeneratedImageMetadata & { - type: 'img2img'; - orig_hash: string; - strength: number; - fit: boolean; - init_image_path: string; - mask_image_path?: string; -}; - -// Superset of generated image metadata types. -export type GeneratedImageMetadata = Txt2ImgMetadata | Img2ImgMetadata; - -// All post processed images contain these metadata. -export type CommonPostProcessedImageMetadata = { - orig_path: string; - orig_hash: string; -}; - -// esrgan and gfpgan images have some unique attributes. 
-export type ESRGANMetadata = CommonPostProcessedImageMetadata & { - type: 'esrgan'; - scale: 2 | 4; - strength: number; - denoise_str: number; -}; - -export type FacetoolMetadata = CommonPostProcessedImageMetadata & { - type: 'gfpgan' | 'codeformer'; - strength: number; - fidelity?: number; -}; - -// Superset of all postprocessed image metadata types.. -export type PostProcessedImageMetadata = ESRGANMetadata | FacetoolMetadata; - -// Metadata includes the system config and image metadata. -// export type Metadata = SystemGenerationMetadata & { -// image: GeneratedImageMetadata | PostProcessedImageMetadata; +// export type Model = { +// status: ModelStatus; +// description: string; +// weights: string; +// config?: string; +// vae?: string; +// width?: number; +// height?: number; +// default?: boolean; +// format?: string; // }; -/** - * ResultImage - */ -// export ty`pe Image = { +// export type DiffusersModel = { +// status: ModelStatus; +// description: string; +// repo_id?: string; +// path?: string; +// vae?: { +// repo_id?: string; +// path?: string; +// }; +// format?: string; +// default?: boolean; +// }; + +// export type ModelList = Record; + +// export type FoundModel = { // name: string; -// type: ImageType; -// url: string; -// thumbnail: string; -// metadata: ImageResponseMetadata; +// location: string; // }; -// export const isInvokeAIImage = (obj: Image | SelectedImage): obj is Image => { -// if ('url' in obj && 'thumbnail' in obj) { -// return true; -// } - -// return false; +// export type InvokeModelConfigProps = { +// name: string | undefined; +// description: string | undefined; +// config: string | undefined; +// weights: string | undefined; +// vae: string | undefined; +// width: number | undefined; +// height: number | undefined; +// default: boolean | undefined; +// format: string | undefined; // }; -/** - * Types related to the system status. - */ - -// // This represents the processing status of the backend. 
-// export type SystemStatus = { -// isProcessing: boolean; -// currentStep: number; -// totalSteps: number; -// currentIteration: number; -// totalIterations: number; -// currentStatus: string; -// currentStatusHasSteps: boolean; -// hasError: boolean; +// export type InvokeDiffusersModelConfigProps = { +// name: string | undefined; +// description: string | undefined; +// repo_id: string | undefined; +// path: string | undefined; +// default: boolean | undefined; +// format: string | undefined; +// vae: { +// repo_id: string | undefined; +// path: string | undefined; +// }; // }; -// export type SystemGenerationMetadata = { -// model: string; -// model_weights?: string; -// model_id?: string; -// model_hash: string; -// app_id: string; -// app_version: string; +// export type InvokeModelConversionProps = { +// model_name: string; +// save_location: string; +// custom_location: string | null; // }; -// export type SystemConfig = SystemGenerationMetadata & { -// model_list: ModelList; -// infill_methods: string[]; +// export type InvokeModelMergingProps = { +// models_to_merge: string[]; +// alpha: number; +// interp: 'weighted_sum' | 'sigmoid' | 'inv_sigmoid' | 'add_difference'; +// force: boolean; +// merged_model_name: string; +// model_merge_save_path: string | null; // }; -export type ModelStatus = 'active' | 'cached' | 'not loaded'; - -export type Model = { - status: ModelStatus; - description: string; - weights: string; - config?: string; - vae?: string; - width?: number; - height?: number; - default?: boolean; - format?: string; -}; - -export type DiffusersModel = { - status: ModelStatus; - description: string; - repo_id?: string; - path?: string; - vae?: { - repo_id?: string; - path?: string; - }; - format?: string; - default?: boolean; -}; - -export type ModelList = Record; - -export type FoundModel = { - name: string; - location: string; -}; - -export type InvokeModelConfigProps = { - name: string | undefined; - description: string | undefined; - config: 
string | undefined; - weights: string | undefined; - vae: string | undefined; - width: number | undefined; - height: number | undefined; - default: boolean | undefined; - format: string | undefined; -}; - -export type InvokeDiffusersModelConfigProps = { - name: string | undefined; - description: string | undefined; - repo_id: string | undefined; - path: string | undefined; - default: boolean | undefined; - format: string | undefined; - vae: { - repo_id: string | undefined; - path: string | undefined; - }; -}; - -export type InvokeModelConversionProps = { - model_name: string; - save_location: string; - custom_location: string | null; -}; - -export type InvokeModelMergingProps = { - models_to_merge: string[]; - alpha: number; - interp: 'weighted_sum' | 'sigmoid' | 'inv_sigmoid' | 'add_difference'; - force: boolean; - merged_model_name: string; - model_merge_save_path: string | null; -}; - -/** - * These types type data received from the server via socketio. - */ - -export type ModelChangeResponse = { - model_name: string; - model_list: ModelList; -}; - -export type ModelConvertedResponse = { - converted_model_name: string; - model_list: ModelList; -}; - -export type ModelsMergedResponse = { - merged_models: string[]; - merged_model_name: string; - model_list: ModelList; -}; - -export type ModelAddedResponse = { - new_model_name: string; - model_list: ModelList; - update: boolean; -}; - -export type ModelDeletedResponse = { - deleted_model_name: string; - model_list: ModelList; -}; - -export type FoundModelResponse = { - search_folder: string; - found_models: FoundModel[]; -}; - -// export type SystemStatusResponse = SystemStatus; - -// export type SystemConfigResponse = SystemConfig; - -export type ImageResultResponse = Omit & { - boundingBox?: IRect; - generationMode: InvokeTabName; -}; - -export type ImageUploadResponse = { - // image: Omit; - url: string; - mtime: number; - width: number; - height: number; - thumbnail: string; - // bbox: [number, number, number, 
number]; -}; - -export type ErrorResponse = { - message: string; - additionalData?: string; -}; - -export type ImageUrlResponse = { - url: string; -}; - -export type UploadOutpaintingMergeImagePayload = { - dataURL: string; - name: string; -}; - /** * A disable-able application feature */ @@ -322,7 +88,8 @@ export type AppFeature = | 'githubLink' | 'discordLink' | 'bugLink' - | 'localization'; + | 'localization' + | 'consoleLogging'; /** * A disable-able Stable Diffusion feature @@ -351,6 +118,7 @@ export type AppConfig = { disabledSDFeatures: SDFeature[]; canRestoreDeletedImagesFromBin: boolean; sd: { + defaultModel?: string; iterations: { initial: number; min: number; diff --git a/invokeai/frontend/web/src/common/components/IAICustomSelect.tsx b/invokeai/frontend/web/src/common/components/IAICustomSelect.tsx index d9610346ec..6d6cdbadf5 100644 --- a/invokeai/frontend/web/src/common/components/IAICustomSelect.tsx +++ b/invokeai/frontend/web/src/common/components/IAICustomSelect.tsx @@ -21,9 +21,12 @@ import { OverlayScrollbarsComponent } from 'overlayscrollbars-react'; import { memo } from 'react'; +export type ItemTooltips = { [key: string]: string }; + type IAICustomSelectProps = { label?: string; items: string[]; + itemTooltips?: ItemTooltips; selectedItem: string; setSelectedItem: (v: string | null | undefined) => void; withCheckIcon?: boolean; @@ -37,6 +40,7 @@ const IAICustomSelect = (props: IAICustomSelectProps) => { const { label, items, + itemTooltips, setSelectedItem, selectedItem, withCheckIcon, @@ -118,48 +122,56 @@ const IAICustomSelect = (props: IAICustomSelectProps) => { > {items.map((item, index) => ( - - {withCheckIcon ? ( - - - {selectedItem === item && } - - - - {item} - - - - ) : ( - - {item} - - )} - + + {withCheckIcon ? 
( + + + {selectedItem === item && } + + + + {item} + + + + ) : ( + + {item} + + )} + + ))} diff --git a/invokeai/frontend/web/src/common/components/ImageUploadOverlay.tsx b/invokeai/frontend/web/src/common/components/ImageUploadOverlay.tsx index 28d9d32a71..862d806eb1 100644 --- a/invokeai/frontend/web/src/common/components/ImageUploadOverlay.tsx +++ b/invokeai/frontend/web/src/common/components/ImageUploadOverlay.tsx @@ -4,7 +4,6 @@ import { useHotkeys } from 'react-hotkeys-hook'; type ImageUploadOverlayProps = { isDragAccept: boolean; isDragReject: boolean; - overlaySecondaryText: string; setIsHandlingUpload: (isHandlingUpload: boolean) => void; }; @@ -12,7 +11,6 @@ const ImageUploadOverlay = (props: ImageUploadOverlayProps) => { const { isDragAccept, isDragReject: _isDragAccept, - overlaySecondaryText, setIsHandlingUpload, } = props; @@ -48,7 +46,7 @@ const ImageUploadOverlay = (props: ImageUploadOverlayProps) => { }} > {isDragAccept ? ( - Upload Image{overlaySecondaryText} + Drop to Upload ) : ( <> Invalid Upload diff --git a/invokeai/frontend/web/src/common/components/ImageUploader.tsx b/invokeai/frontend/web/src/common/components/ImageUploader.tsx index db6b9ee517..17f6d68633 100644 --- a/invokeai/frontend/web/src/common/components/ImageUploader.tsx +++ b/invokeai/frontend/web/src/common/components/ImageUploader.tsx @@ -68,13 +68,13 @@ const ImageUploader = (props: ImageUploaderProps) => { async (file: File) => { dispatch( imageUploaded({ - imageType: 'uploads', formData: { file }, - activeTabName, + imageCategory: 'user', + isIntermediate: false, }) ); }, - [dispatch, activeTabName] + [dispatch] ); const onDrop = useCallback( @@ -145,14 +145,6 @@ const ImageUploader = (props: ImageUploaderProps) => { }; }, [inputRef, open, setOpenUploaderFunction]); - const overlaySecondaryText = useMemo(() => { - if (['img2img', 'unifiedCanvas'].includes(activeTabName)) { - return ` to ${String(t(`common.${activeTabName}` as ResourceKey))}`; - } - - return ''; - }, [t, 
activeTabName]); - return ( { )} diff --git a/invokeai/frontend/web/src/common/util/_parseMetadataZod.ts b/invokeai/frontend/web/src/common/util/_parseMetadataZod.ts deleted file mode 100644 index 584399233f..0000000000 --- a/invokeai/frontend/web/src/common/util/_parseMetadataZod.ts +++ /dev/null @@ -1,119 +0,0 @@ -/** - * PARTIAL ZOD IMPLEMENTATION - * - * doesn't work well bc like most validators, zod is not built to skip invalid values. - * it mostly works but just seems clearer and simpler to manually parse for now. - * - * in the future it would be really nice if we could use zod for some things: - * - zodios (axios + zod): https://github.com/ecyrbe/zodios - * - openapi to zodios: https://github.com/astahmer/openapi-zod-client - */ - -// import { z } from 'zod'; - -// const zMetadataStringField = z.string(); -// export type MetadataStringField = z.infer; - -// const zMetadataIntegerField = z.number().int(); -// export type MetadataIntegerField = z.infer; - -// const zMetadataFloatField = z.number(); -// export type MetadataFloatField = z.infer; - -// const zMetadataBooleanField = z.boolean(); -// export type MetadataBooleanField = z.infer; - -// const zMetadataImageField = z.object({ -// image_type: z.union([ -// z.literal('results'), -// z.literal('uploads'), -// z.literal('intermediates'), -// ]), -// image_name: z.string().min(1), -// }); -// export type MetadataImageField = z.infer; - -// const zMetadataLatentsField = z.object({ -// latents_name: z.string().min(1), -// }); -// export type MetadataLatentsField = z.infer; - -// /** -// * zod Schema for any node field. Use a `transform()` to manually parse, skipping invalid values. 
-// */ -// const zAnyMetadataField = z.any().transform((val, ctx) => { -// // Grab the field name from the path -// const fieldName = String(ctx.path[ctx.path.length - 1]); - -// // `id` and `type` must be strings if they exist -// if (['id', 'type'].includes(fieldName)) { -// const reservedStringPropertyResult = zMetadataStringField.safeParse(val); -// if (reservedStringPropertyResult.success) { -// return reservedStringPropertyResult.data; -// } - -// return; -// } - -// // Parse the rest of the fields, only returning the data if the parsing is successful - -// const stringFieldResult = zMetadataStringField.safeParse(val); -// if (stringFieldResult.success) { -// return stringFieldResult.data; -// } - -// const integerFieldResult = zMetadataIntegerField.safeParse(val); -// if (integerFieldResult.success) { -// return integerFieldResult.data; -// } - -// const floatFieldResult = zMetadataFloatField.safeParse(val); -// if (floatFieldResult.success) { -// return floatFieldResult.data; -// } - -// const booleanFieldResult = zMetadataBooleanField.safeParse(val); -// if (booleanFieldResult.success) { -// return booleanFieldResult.data; -// } - -// const imageFieldResult = zMetadataImageField.safeParse(val); -// if (imageFieldResult.success) { -// return imageFieldResult.data; -// } - -// const latentsFieldResult = zMetadataImageField.safeParse(val); -// if (latentsFieldResult.success) { -// return latentsFieldResult.data; -// } -// }); - -// /** -// * The node metadata schema. 
-// */ -// const zNodeMetadata = z.object({ -// session_id: z.string().min(1).optional(), -// node: z.record(z.string().min(1), zAnyMetadataField).optional(), -// }); - -// export type NodeMetadata = z.infer; - -// const zMetadata = z.object({ -// invokeai: zNodeMetadata.optional(), -// 'sd-metadata': z.record(z.string().min(1), z.any()).optional(), -// }); -// export type Metadata = z.infer; - -// export const parseMetadata = ( -// metadata: Record -// ): Metadata | undefined => { -// const result = zMetadata.safeParse(metadata); -// if (!result.success) { -// console.log(result.error.issues); -// return; -// } - -// return result.data; -// }; - -export default {}; diff --git a/invokeai/frontend/web/src/common/util/parseMetadata.ts b/invokeai/frontend/web/src/common/util/parseMetadata.ts deleted file mode 100644 index f1828d95e7..0000000000 --- a/invokeai/frontend/web/src/common/util/parseMetadata.ts +++ /dev/null @@ -1,334 +0,0 @@ -import { forEach, size } from 'lodash-es'; -import { - ImageField, - LatentsField, - ConditioningField, - UNetField, - ClipField, - VaeField, -} from 'services/api'; - -const OBJECT_TYPESTRING = '[object Object]'; -const STRING_TYPESTRING = '[object String]'; -const NUMBER_TYPESTRING = '[object Number]'; -const BOOLEAN_TYPESTRING = '[object Boolean]'; -const ARRAY_TYPESTRING = '[object Array]'; - -const isObject = (obj: unknown): obj is Record => - Object.prototype.toString.call(obj) === OBJECT_TYPESTRING; - -const isString = (obj: unknown): obj is string => - Object.prototype.toString.call(obj) === STRING_TYPESTRING; - -const isNumber = (obj: unknown): obj is number => - Object.prototype.toString.call(obj) === NUMBER_TYPESTRING; - -const isBoolean = (obj: unknown): obj is boolean => - Object.prototype.toString.call(obj) === BOOLEAN_TYPESTRING; - -const isArray = (obj: unknown): obj is Array => - Object.prototype.toString.call(obj) === ARRAY_TYPESTRING; - -const parseImageField = (imageField: unknown): ImageField | undefined => { - // 
Must be an object - if (!isObject(imageField)) { - return; - } - - // An ImageField must have both `image_name` and `image_type` - if (!('image_name' in imageField && 'image_type' in imageField)) { - return; - } - - // An ImageField's `image_type` must be one of the allowed values - if ( - !['results', 'uploads', 'intermediates'].includes(imageField.image_type) - ) { - return; - } - - // An ImageField's `image_name` must be a string - if (typeof imageField.image_name !== 'string') { - return; - } - - // Build a valid ImageField - return { - image_type: imageField.image_type, - image_name: imageField.image_name, - }; -}; - -const parseLatentsField = (latentsField: unknown): LatentsField | undefined => { - // Must be an object - if (!isObject(latentsField)) { - return; - } - - // A LatentsField must have a `latents_name` - if (!('latents_name' in latentsField)) { - return; - } - - // A LatentsField's `latents_name` must be a string - if (typeof latentsField.latents_name !== 'string') { - return; - } - - // Build a valid LatentsField - return { - latents_name: latentsField.latents_name, - }; -}; - -const parseConditioningField = ( - conditioningField: unknown -): ConditioningField | undefined => { - // Must be an object - if (!isObject(conditioningField)) { - return; - } - - // A ConditioningField must have a `conditioning_name` - if (!('conditioning_name' in conditioningField)) { - return; - } - - // A ConditioningField's `conditioning_name` must be a string - if (typeof conditioningField.conditioning_name !== 'string') { - return; - } - - // Build a valid ConditioningField - return { - conditioning_name: conditioningField.conditioning_name, - }; -}; - -const _parseModelInfo = (modelInfo: unknown): ModelInfo | undefined => { - // Must be an object - if (!isObject(modelInfo)) { - return; - } - - if (!('model_name' in modelInfo && typeof modelInfo.model_name == 'string')) { - return; - } - - if (!('model_type' in modelInfo && typeof modelInfo.model_type == 'string')) { 
- return; - } - - if (!('submodel' in modelInfo && typeof modelInfo.submodel == 'string')) { - return; - } - - return { - model_name: modelInfo.model_name, - model_type: modelInfo.model_type, - submodel: modelInfo.submodel, - }; -}; - -const parseUNetField = (unetField: unknown): UNetField | undefined => { - // Must be an object - if (!isObject(unetField)) { - return; - } - - if (!('unet' in unetField && 'scheduler' in unetField)) { - return; - } - - const unet = _parseModelInfo(unetField.unet); - const scheduler = _parseModelInfo(unetField.scheduler); - - if (!(unet && scheduler)) { - return; - } - - // Build a valid UNetField - return { - unet: unet, - scheduler: scheduler, - }; -}; - -const parseClipField = (clipField: unknown): ClipField | undefined => { - // Must be an object - if (!isObject(clipField)) { - return; - } - - if (!('tokenizer' in clipField && 'text_encoder' in clipField)) { - return; - } - - const tokenizer = _parseModelInfo(clipField.tokenizer); - const text_encoder = _parseModelInfo(clipField.text_encoder); - - if (!(tokenizer && text_encoder)) { - return; - } - - // Build a valid ClipField - return { - tokenizer: tokenizer, - text_encoder: text_encoder, - }; -}; - -const parseVaeField = (vaeField: unknown): VaeField | undefined => { - // Must be an object - if (!isObject(vaeField)) { - return; - } - - if (!('vae' in vaeField)) { - return; - } - - const vae = _parseModelInfo(vaeField.vae); - - if (!vae) { - return; - } - - // Build a valid VaeField - return { - vae: vae, - }; -}; - -type NodeMetadata = { - [key: string]: - | string - | number - | boolean - | ImageField - | LatentsField - | ConditioningField - | UNetField - | ClipField - | VaeField; -}; - -type InvokeAIMetadata = { - session_id?: string; - node?: NodeMetadata; -}; - -export const parseNodeMetadata = ( - nodeMetadata: Record -): NodeMetadata | undefined => { - if (!isObject(nodeMetadata)) { - return; - } - - const parsed: NodeMetadata = {}; - - forEach(nodeMetadata, (nodeItem, 
nodeKey) => { - // `id` and `type` must be strings if they are present - if (['id', 'type'].includes(nodeKey)) { - if (isString(nodeItem)) { - parsed[nodeKey] = nodeItem; - } - return; - } - - // valid object types are: - // ImageField, LatentsField ConditioningField, UNetField, ClipField, VaeField - if (isObject(nodeItem)) { - if ('image_name' in nodeItem || 'image_type' in nodeItem) { - const imageField = parseImageField(nodeItem); - if (imageField) { - parsed[nodeKey] = imageField; - } - return; - } - - if ('latents_name' in nodeItem) { - const latentsField = parseLatentsField(nodeItem); - if (latentsField) { - parsed[nodeKey] = latentsField; - } - return; - } - - if ('conditioning_name' in nodeItem) { - const conditioningField = parseConditioningField(nodeItem); - if (conditioningField) { - parsed[nodeKey] = conditioningField; - } - return; - } - - if ('unet' in nodeItem && 'scheduler' in nodeItem) { - const unetField = parseUNetField(nodeItem); - if (unetField) { - parsed[nodeKey] = unetField; - } - } - - if ('tokenizer' in nodeItem && 'text_encoder' in nodeItem) { - const clipField = parseClipField(nodeItem); - if (clipField) { - parsed[nodeKey] = clipField; - } - } - - if ('vae' in nodeItem) { - const vaeField = parseVaeField(nodeItem); - if (vaeField) { - parsed[nodeKey] = vaeField; - } - } - } - - // otherwise we accept any string, number or boolean - if (isString(nodeItem) || isNumber(nodeItem) || isBoolean(nodeItem)) { - parsed[nodeKey] = nodeItem; - return; - } - }); - - if (size(parsed) === 0) { - return; - } - - return parsed; -}; - -export const parseInvokeAIMetadata = ( - metadata: Record | undefined -): InvokeAIMetadata | undefined => { - if (metadata === undefined) { - return; - } - - if (!isObject(metadata)) { - return; - } - - const parsed: InvokeAIMetadata = {}; - - forEach(metadata, (item, key) => { - if (key === 'session_id' && isString(item)) { - parsed['session_id'] = item; - } - - if (key === 'node' && isObject(item)) { - const 
nodeMetadata = parseNodeMetadata(item); - - if (nodeMetadata) { - parsed['node'] = nodeMetadata; - } - } - }); - - if (size(parsed) === 0) { - return; - } - - return parsed; -}; diff --git a/invokeai/frontend/web/src/features/canvas/components/IAICanvasIntermediateImage.tsx b/invokeai/frontend/web/src/features/canvas/components/IAICanvasIntermediateImage.tsx index 745825a975..ea5e9a6486 100644 --- a/invokeai/frontend/web/src/features/canvas/components/IAICanvasIntermediateImage.tsx +++ b/invokeai/frontend/web/src/features/canvas/components/IAICanvasIntermediateImage.tsx @@ -1,18 +1,24 @@ import { createSelector } from '@reduxjs/toolkit'; -import { RootState } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; -import { useGetUrl } from 'common/util/getUrl'; -import { GalleryState } from 'features/gallery/store/gallerySlice'; +import { systemSelector } from 'features/system/store/systemSelectors'; import { ImageConfig } from 'konva/lib/shapes/Image'; import { isEqual } from 'lodash-es'; import { useEffect, useState } from 'react'; import { Image as KonvaImage } from 'react-konva'; +import { canvasSelector } from '../store/canvasSelectors'; const selector = createSelector( - [(state: RootState) => state.gallery], - (gallery: GalleryState) => { - return gallery.intermediateImage ? gallery.intermediateImage : null; + [systemSelector, canvasSelector], + (system, canvas) => { + const { progressImage, sessionId } = system; + const { sessionId: canvasSessionId, boundingBox } = + canvas.layerState.stagingArea; + + return { + boundingBox, + progressImage: sessionId === canvasSessionId ? 
progressImage : undefined, + }; }, { memoizeOptions: { @@ -25,33 +31,34 @@ type Props = Omit; const IAICanvasIntermediateImage = (props: Props) => { const { ...rest } = props; - const intermediateImage = useAppSelector(selector); - const { getUrl } = useGetUrl(); + const { progressImage, boundingBox } = useAppSelector(selector); const [loadedImageElement, setLoadedImageElement] = useState(null); useEffect(() => { - if (!intermediateImage) return; + if (!progressImage) { + return; + } + const tempImage = new Image(); tempImage.onload = () => { setLoadedImageElement(tempImage); }; - tempImage.src = getUrl(intermediateImage.url); - }, [intermediateImage, getUrl]); - if (!intermediateImage?.boundingBox) return null; + tempImage.src = progressImage.dataURL; + }, [progressImage]); - const { - boundingBox: { x, y, width, height }, - } = intermediateImage; + if (!(progressImage && boundingBox)) { + return null; + } return loadedImageElement ? ( { {shouldShowStagingImage && currentStagingAreaImage && ( diff --git a/invokeai/frontend/web/src/features/canvas/components/IAICanvasStagingAreaToolbar.tsx b/invokeai/frontend/web/src/features/canvas/components/IAICanvasStagingAreaToolbar.tsx index 64c752fce0..76ffdcf082 100644 --- a/invokeai/frontend/web/src/features/canvas/components/IAICanvasStagingAreaToolbar.tsx +++ b/invokeai/frontend/web/src/features/canvas/components/IAICanvasStagingAreaToolbar.tsx @@ -1,6 +1,5 @@ import { ButtonGroup, Flex } from '@chakra-ui/react'; import { createSelector } from '@reduxjs/toolkit'; -// import { saveStagingAreaImageToGallery } from 'app/socketio/actions'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import IAIIconButton from 'common/components/IAIIconButton'; import { canvasSelector } from 'features/canvas/store/canvasSelectors'; @@ -26,13 +25,14 @@ import { FaPlus, FaSave, } from 'react-icons/fa'; +import { stagingAreaImageSaved } from '../store/actions'; const selector = createSelector( [canvasSelector], (canvas) 
=> { const { layerState: { - stagingArea: { images, selectedImageIndex }, + stagingArea: { images, selectedImageIndex, sessionId }, }, shouldShowStagingOutline, shouldShowStagingImage, @@ -45,6 +45,7 @@ const selector = createSelector( isOnLastImage: selectedImageIndex === images.length - 1, shouldShowStagingImage, shouldShowStagingOutline, + sessionId, }; }, { @@ -61,6 +62,7 @@ const IAICanvasStagingAreaToolbar = () => { isOnLastImage, currentStagingAreaImage, shouldShowStagingImage, + sessionId, } = useAppSelector(selector); const { t } = useTranslation(); @@ -106,9 +108,20 @@ const IAICanvasStagingAreaToolbar = () => { } ); - const handlePrevImage = () => dispatch(prevStagingAreaImage()); - const handleNextImage = () => dispatch(nextStagingAreaImage()); - const handleAccept = () => dispatch(commitStagingAreaImage()); + const handlePrevImage = useCallback( + () => dispatch(prevStagingAreaImage()), + [dispatch] + ); + + const handleNextImage = useCallback( + () => dispatch(nextStagingAreaImage()), + [dispatch] + ); + + const handleAccept = useCallback( + () => dispatch(commitStagingAreaImage(sessionId)), + [dispatch, sessionId] + ); if (!currentStagingAreaImage) return null; @@ -157,19 +170,15 @@ const IAICanvasStagingAreaToolbar = () => { } colorScheme="accent" /> - {/* } onClick={() => - dispatch( - saveStagingAreaImageToGallery( - currentStagingAreaImage.image.image_url - ) - ) + dispatch(stagingAreaImageSaved(currentStagingAreaImage.image)) } colorScheme="accent" - /> */} + /> ( + 'canvas/stagingAreaImageSaved' +); diff --git a/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts b/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts index 0ebe5b264c..7f41066ba1 100644 --- a/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts +++ b/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts @@ -29,6 +29,7 @@ import { isCanvasMaskLine, } from './canvasTypes'; import { ImageDTO } from 'services/api'; +import { sessionCanceled } 
from 'services/thunks/session'; export const initialLayerState: CanvasLayerState = { objects: [], @@ -696,7 +697,10 @@ export const canvasSlice = createSlice({ 0 ); }, - commitStagingAreaImage: (state) => { + commitStagingAreaImage: ( + state, + action: PayloadAction + ) => { if (!state.layerState.stagingArea.images.length) { return; } @@ -841,6 +845,13 @@ export const canvasSlice = createSlice({ state.isTransformingBoundingBox = false; }, }, + extraReducers: (builder) => { + builder.addCase(sessionCanceled.pending, (state) => { + if (!state.layerState.stagingArea.images.length) { + state.layerState.stagingArea = initialLayerState.stagingArea; + } + }); + }, }); export const { diff --git a/invokeai/frontend/web/src/features/canvas/util/createMaskStage.ts b/invokeai/frontend/web/src/features/canvas/util/createMaskStage.ts index 96ac592711..b417b3a786 100644 --- a/invokeai/frontend/web/src/features/canvas/util/createMaskStage.ts +++ b/invokeai/frontend/web/src/features/canvas/util/createMaskStage.ts @@ -9,7 +9,8 @@ import { IRect } from 'konva/lib/types'; */ const createMaskStage = async ( lines: CanvasMaskLine[], - boundingBox: IRect + boundingBox: IRect, + shouldInvertMask: boolean ): Promise => { // create an offscreen canvas and add the mask to it const { width, height } = boundingBox; @@ -29,7 +30,7 @@ const createMaskStage = async ( baseLayer.add( new Konva.Rect({ ...boundingBox, - fill: 'white', + fill: shouldInvertMask ? 'black' : 'white', }) ); @@ -37,7 +38,7 @@ const createMaskStage = async ( maskLayer.add( new Konva.Line({ points: line.points, - stroke: 'black', + stroke: shouldInvertMask ? 
'white' : 'black', strokeWidth: line.strokeWidth * 2, tension: 0, lineCap: 'round', diff --git a/invokeai/frontend/web/src/features/canvas/util/getCanvasData.ts b/invokeai/frontend/web/src/features/canvas/util/getCanvasData.ts index 21a33aa349..d0190878e2 100644 --- a/invokeai/frontend/web/src/features/canvas/util/getCanvasData.ts +++ b/invokeai/frontend/web/src/features/canvas/util/getCanvasData.ts @@ -25,6 +25,7 @@ export const getCanvasData = async (state: RootState) => { boundingBoxCoordinates, boundingBoxDimensions, isMaskEnabled, + shouldPreserveMaskedArea, } = state.canvas; const boundingBox = { @@ -58,7 +59,8 @@ export const getCanvasData = async (state: RootState) => { // For the mask layer, use the normal boundingBox const maskStage = await createMaskStage( isMaskEnabled ? objects.filter(isCanvasMaskLine) : [], // only include mask lines, and only if mask is enabled - boundingBox + boundingBox, + shouldPreserveMaskedArea ); const maskBlob = await konvaNodeToBlob(maskStage, boundingBox); const maskImageData = await konvaNodeToImageData(maskStage, boundingBox); diff --git a/invokeai/frontend/web/src/features/gallery/components/CurrentImageButtons.tsx b/invokeai/frontend/web/src/features/gallery/components/CurrentImageButtons.tsx index c19a404a37..91bd1a0425 100644 --- a/invokeai/frontend/web/src/features/gallery/components/CurrentImageButtons.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/CurrentImageButtons.tsx @@ -1,5 +1,5 @@ import { createSelector } from '@reduxjs/toolkit'; -import { isEqual, isString } from 'lodash-es'; +import { isEqual } from 'lodash-es'; import { ButtonGroup, @@ -25,8 +25,8 @@ import { } from 'features/ui/store/uiSelectors'; import { setActiveTab, - setShouldHidePreview, setShouldShowImageDetails, + setShouldShowProgressInViewer, } from 'features/ui/store/uiSlice'; import { useHotkeys } from 'react-hotkeys-hook'; import { useTranslation } from 'react-i18next'; @@ -37,23 +37,19 @@ import { FaDownload, FaExpand, 
FaExpandArrowsAlt, - FaEye, - FaEyeSlash, FaGrinStars, + FaHourglassHalf, FaQuoteRight, FaSeedling, FaShare, FaShareAlt, - FaTrash, - FaWrench, } from 'react-icons/fa'; import { gallerySelector } from '../store/gallerySelectors'; -import DeleteImageModal from './DeleteImageModal'; import { useCallback } from 'react'; import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale'; import { useGetUrl } from 'common/util/getUrl'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; -import { useParameters } from 'features/parameters/hooks/useParameters'; +import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters'; import { initialImageSelected } from 'features/parameters/store/actions'; import { requestedImageDeletion, @@ -62,7 +58,6 @@ import { } from '../store/actions'; import FaceRestoreSettings from 'features/parameters/components/Parameters/FaceRestore/FaceRestoreSettings'; import UpscaleSettings from 'features/parameters/components/Parameters/Upscale/UpscaleSettings'; -import { allParametersSet } from 'features/parameters/store/generationSlice'; import DeleteImageButton from './ImageActionButtons/DeleteImageButton'; import { useAppToaster } from 'app/components/Toaster'; import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice'; @@ -90,7 +85,11 @@ const currentImageButtonsSelector = createSelector( const { isLightboxOpen } = lightbox; - const { shouldShowImageDetails, shouldHidePreview } = ui; + const { + shouldShowImageDetails, + shouldHidePreview, + shouldShowProgressInViewer, + } = ui; const { selectedImage } = gallery; @@ -112,6 +111,7 @@ const currentImageButtonsSelector = createSelector( seed: selectedImage?.metadata?.seed, prompt: selectedImage?.metadata?.positive_conditioning, negativePrompt: selectedImage?.metadata?.negative_conditioning, + shouldShowProgressInViewer, }; }, { @@ -145,6 +145,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => { image, 
canDeleteImage, shouldConfirmOnDelete, + shouldShowProgressInViewer, } = useAppSelector(currentImageButtonsSelector); const isLightboxEnabled = useFeatureStatus('lightbox').isFeatureEnabled; @@ -163,7 +164,8 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => { const toaster = useAppToaster(); const { t } = useTranslation(); - const { recallPrompt, recallSeed, recallAllParameters } = useParameters(); + const { recallBothPrompts, recallSeed, recallAllParameters } = + useRecallParameters(); // const handleCopyImage = useCallback(async () => { // if (!image?.url) { @@ -229,10 +231,6 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => { }); }, [toaster, shouldTransformUrls, getUrl, t, image]); - const handlePreviewVisibility = useCallback(() => { - dispatch(setShouldHidePreview(!shouldHidePreview)); - }, [dispatch, shouldHidePreview]); - const handleClickUseAllParameters = useCallback(() => { recallAllParameters(image); }, [image, recallAllParameters]); @@ -252,11 +250,11 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => { useHotkeys('s', handleUseSeed, [image]); const handleUsePrompt = useCallback(() => { - recallPrompt( + recallBothPrompts( image?.metadata?.positive_conditioning, image?.metadata?.negative_conditioning ); - }, [image, recallPrompt]); + }, [image, recallBothPrompts]); useHotkeys('p', handleUsePrompt, [image]); @@ -386,6 +384,10 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => { } }, [shouldConfirmOnDelete, onDeleteDialogOpen, handleDelete]); + const handleClickProgressImagesToggle = useCallback(() => { + dispatch(setShouldShowProgressInViewer(!shouldShowProgressInViewer)); + }, [dispatch, shouldShowProgressInViewer]); + useHotkeys('delete', handleInitiateDelete, [ image, shouldConfirmOnDelete, @@ -412,8 +414,9 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => { } /> } @@ -458,28 +461,17 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => { 
{t('parameters.copyImageToLink')} - + } size="sm" w="100%"> {t('parameters.downloadImage')} - {/* : } - tooltip={ - !shouldHidePreview - ? t('parameters.hidePreview') - : t('parameters.showPreview') - } - aria-label={ - !shouldHidePreview - ? t('parameters.hidePreview') - : t('parameters.showPreview') - } - isChecked={shouldHidePreview} - onClick={handlePreviewVisibility} - /> */} {isLightboxEnabled && ( } @@ -604,6 +596,16 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => { /> + + } + isChecked={shouldShowProgressInViewer} + onClick={handleClickProgressImagesToggle} + /> + + diff --git a/invokeai/frontend/web/src/features/gallery/components/CurrentImagePreview.tsx b/invokeai/frontend/web/src/features/gallery/components/CurrentImagePreview.tsx index 4562e3458d..280d859b87 100644 --- a/invokeai/frontend/web/src/features/gallery/components/CurrentImagePreview.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/CurrentImagePreview.tsx @@ -62,7 +62,6 @@ const CurrentImagePreview = () => { return; } e.dataTransfer.setData('invokeai/imageName', image.image_name); - e.dataTransfer.setData('invokeai/imageType', image.image_type); e.dataTransfer.effectAllowed = 'move'; }, [image] diff --git a/invokeai/frontend/web/src/features/gallery/components/HoverableImage.tsx b/invokeai/frontend/web/src/features/gallery/components/HoverableImage.tsx index ed427f4984..f652cebda2 100644 --- a/invokeai/frontend/web/src/features/gallery/components/HoverableImage.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/HoverableImage.tsx @@ -30,7 +30,7 @@ import { lightboxSelector } from 'features/lightbox/store/lightboxSelectors'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; import { isEqual } from 'lodash-es'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; -import { useParameters } from 'features/parameters/hooks/useParameters'; +import { useRecallParameters } from 
'features/parameters/hooks/useRecallParameters'; import { initialImageSelected } from 'features/parameters/store/actions'; import { requestedImageDeletion, @@ -114,8 +114,8 @@ const HoverableImage = memo((props: HoverableImageProps) => { const isLightboxEnabled = useFeatureStatus('lightbox').isFeatureEnabled; const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled; - const { recallSeed, recallPrompt, recallInitialImage, recallAllParameters } = - useParameters(); + const { recallBothPrompts, recallSeed, recallAllParameters } = + useRecallParameters(); const handleMouseOver = () => setIsHovered(true); const handleMouseOut = () => setIsHovered(false); @@ -147,7 +147,6 @@ const HoverableImage = memo((props: HoverableImageProps) => { const handleDragStart = useCallback( (e: DragEvent) => { e.dataTransfer.setData('invokeai/imageName', image.image_name); - e.dataTransfer.setData('invokeai/imageType', image.image_type); e.dataTransfer.effectAllowed = 'move'; }, [image] @@ -155,11 +154,15 @@ const HoverableImage = memo((props: HoverableImageProps) => { // Recall parameters handlers const handleRecallPrompt = useCallback(() => { - recallPrompt( + recallBothPrompts( image.metadata?.positive_conditioning, image.metadata?.negative_conditioning ); - }, [image, recallPrompt]); + }, [ + image.metadata?.negative_conditioning, + image.metadata?.positive_conditioning, + recallBothPrompts, + ]); const handleRecallSeed = useCallback(() => { recallSeed(image.metadata?.seed); diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx index 468dfd694f..77f42a11a6 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx @@ -16,7 +16,6 @@ import IAIPopover from 'common/components/IAIPopover'; import IAISlider from 'common/components/IAISlider'; 
import { gallerySelector } from 'features/gallery/store/gallerySelectors'; import { - setCurrentCategory, setGalleryImageMinimumWidth, setGalleryImageObjectFit, setShouldAutoSwitchToNewImages, @@ -31,59 +30,46 @@ import { memo, useCallback, useEffect, + useMemo, useRef, useState, } from 'react'; import { useTranslation } from 'react-i18next'; import { BsPinAngle, BsPinAngleFill } from 'react-icons/bs'; -import { FaImage, FaUser, FaWrench } from 'react-icons/fa'; +import { FaImage, FaServer, FaWrench } from 'react-icons/fa'; import { MdPhotoLibrary } from 'react-icons/md'; import HoverableImage from './HoverableImage'; import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale'; -import { resultsAdapter } from '../store/resultsSlice'; -import { - receivedResultImagesPage, - receivedUploadImagesPage, -} from 'services/thunks/gallery'; -import { uploadsAdapter } from '../store/uploadsSlice'; import { createSelector } from '@reduxjs/toolkit'; import { RootState } from 'app/store/store'; import { Virtuoso, VirtuosoGrid } from 'react-virtuoso'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import GalleryProgressImage from './GalleryProgressImage'; import { uiSelector } from 'features/ui/store/uiSelectors'; -import { ImageDTO } from 'services/api'; - -const GALLERY_SHOW_BUTTONS_MIN_WIDTH = 290; -const PROGRESS_IMAGE_PLACEHOLDER = 'PROGRESS_IMAGE_PLACEHOLDER'; +import { + ASSETS_CATEGORIES, + IMAGE_CATEGORIES, + imageCategoriesChanged, + selectImagesAll, +} from '../store/imagesSlice'; +import { receivedPageOfImages } from 'services/thunks/image'; const categorySelector = createSelector( [(state: RootState) => state], (state) => { - const { results, uploads, system, gallery } = state; - const { currentCategory } = gallery; + const { images } = state; + const { categories } = images; - if (currentCategory === 'results') { - const tempImages: (ImageDTO | typeof PROGRESS_IMAGE_PLACEHOLDER)[] = []; - - if 
(system.progressImage) { - tempImages.push(PROGRESS_IMAGE_PLACEHOLDER); - } - - return { - images: tempImages.concat( - resultsAdapter.getSelectors().selectAll(results) - ), - isLoading: results.isLoading, - areMoreImagesAvailable: results.page < results.pages - 1, - }; - } + const allImages = selectImagesAll(state); + const filteredImages = allImages.filter((i) => + categories.includes(i.image_category) + ); return { - images: uploadsAdapter.getSelectors().selectAll(uploads), - isLoading: uploads.isLoading, - areMoreImagesAvailable: uploads.page < uploads.pages - 1, + images: filteredImages, + isLoading: images.isLoading, + areMoreImagesAvailable: filteredImages.length < images.total, + categories: images.categories, }; }, defaultSelectorOptions @@ -93,7 +79,6 @@ const mainSelector = createSelector( [gallerySelector, uiSelector], (gallery, ui) => { const { - currentCategory, galleryImageMinimumWidth, galleryImageObjectFit, shouldAutoSwitchToNewImages, @@ -104,7 +89,6 @@ const mainSelector = createSelector( const { shouldPinGallery } = ui; return { - currentCategory, shouldPinGallery, galleryImageMinimumWidth, galleryImageObjectFit, @@ -120,7 +104,6 @@ const ImageGalleryContent = () => { const dispatch = useAppDispatch(); const { t } = useTranslation(); const resizeObserverRef = useRef(null); - const [shouldShouldIconButtons, setShouldShouldIconButtons] = useState(true); const rootRef = useRef(null); const [scroller, setScroller] = useState(null); const [initialize, osInstance] = useOverlayScrollbars({ @@ -137,7 +120,6 @@ const ImageGalleryContent = () => { }); const { - currentCategory, shouldPinGallery, galleryImageMinimumWidth, galleryImageObjectFit, @@ -146,18 +128,19 @@ const ImageGalleryContent = () => { selectedImage, } = useAppSelector(mainSelector); - const { images, areMoreImagesAvailable, isLoading } = + const { images, areMoreImagesAvailable, isLoading, categories } = useAppSelector(categorySelector); - const handleClickLoadMore = () => { - if 
(currentCategory === 'results') { - dispatch(receivedResultImagesPage()); - } + const handleLoadMoreImages = useCallback(() => { + dispatch(receivedPageOfImages()); + }, [dispatch]); - if (currentCategory === 'uploads') { - dispatch(receivedUploadImagesPage()); + const handleEndReached = useMemo(() => { + if (areMoreImagesAvailable && !isLoading) { + return handleLoadMoreImages; } - }; + return undefined; + }, [areMoreImagesAvailable, handleLoadMoreImages, isLoading]); const handleChangeGalleryImageMinimumWidth = (v: number) => { dispatch(setGalleryImageMinimumWidth(v)); @@ -168,28 +151,6 @@ const ImageGalleryContent = () => { dispatch(requestCanvasRescale()); }; - useEffect(() => { - if (!resizeObserverRef.current) { - return; - } - const resizeObserver = new ResizeObserver(() => { - if (!resizeObserverRef.current) { - return; - } - - if ( - resizeObserverRef.current.clientWidth < GALLERY_SHOW_BUTTONS_MIN_WIDTH - ) { - setShouldShouldIconButtons(true); - return; - } - - setShouldShouldIconButtons(false); - }); - resizeObserver.observe(resizeObserverRef.current); - return () => resizeObserver.disconnect(); // clean up - }, []); - useEffect(() => { const { current: root } = rootRef; if (scroller && root) { @@ -209,13 +170,13 @@ const ImageGalleryContent = () => { } }, []); - const handleEndReached = useCallback(() => { - if (currentCategory === 'results') { - dispatch(receivedResultImagesPage()); - } else if (currentCategory === 'uploads') { - dispatch(receivedUploadImagesPage()); - } - }, [dispatch, currentCategory]); + const handleClickImagesCategory = useCallback(() => { + dispatch(imageCategoriesChanged(IMAGE_CATEGORIES)); + }, [dispatch]); + + const handleClickAssetsCategory = useCallback(() => { + dispatch(imageCategoriesChanged(ASSETS_CATEGORIES)); + }, [dispatch]); return ( { alignItems="center" justifyContent="space-between" > - - {shouldShouldIconButtons ? 
( - <> - } - onClick={() => dispatch(setCurrentCategory('results'))} - /> - } - onClick={() => dispatch(setCurrentCategory('uploads'))} - /> - - ) : ( - <> - dispatch(setCurrentCategory('results'))} - flexGrow={1} - > - {t('gallery.generations')} - - dispatch(setCurrentCategory('uploads'))} - flexGrow={1} - > - {t('gallery.uploads')} - - - )} + + } + /> + } + /> - } /> } @@ -347,28 +280,17 @@ const ImageGalleryContent = () => { data={images} endReached={handleEndReached} scrollerRef={(ref) => setScrollerRef(ref)} - itemContent={(index, image) => { - const isSelected = - image === PROGRESS_IMAGE_PLACEHOLDER - ? false - : selectedImage?.image_name === image?.image_name; - - return ( - - {image === PROGRESS_IMAGE_PLACEHOLDER ? ( - - ) : ( - - )} - - ); - }} + itemContent={(index, image) => ( + + + + )} /> ) : ( { List: ListContainer, }} scrollerRef={setScroller} - itemContent={(index, image) => { - const isSelected = - image === PROGRESS_IMAGE_PLACEHOLDER - ? false - : selectedImage?.image_name === image?.image_name; - - return image === PROGRESS_IMAGE_PLACEHOLDER ? 
( - - ) : ( - - ); - }} + itemContent={(index, image) => ( + + )} /> )} { const { t } = useTranslation(); + + if (!value) { + return null; + } + return ( {onClick && ( @@ -115,6 +121,21 @@ const memoEqualityCheck = ( */ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => { const dispatch = useAppDispatch(); + const { + recallBothPrompts, + recallPositivePrompt, + recallNegativePrompt, + recallSeed, + recallInitialImage, + recallCfgScale, + recallModel, + recallScheduler, + recallSteps, + recallWidth, + recallHeight, + recallStrength, + recallAllParameters, + } = useRecallParameters(); useHotkeys('esc', () => { dispatch(setShouldShowImageDetails(false)); @@ -161,52 +182,53 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => { {metadata.type && ( )} - {metadata.width && ( - dispatch(setWidth(Number(metadata.width)))} - /> - )} - {metadata.height && ( - dispatch(setHeight(Number(metadata.height)))} - /> - )} - {metadata.model && ( - - )} + {sessionId && } {metadata.positive_conditioning && ( + recallPositivePrompt(metadata.positive_conditioning) } - onClick={() => setPositivePrompt(metadata.positive_conditioning!)} /> )} {metadata.negative_conditioning && ( + recallNegativePrompt(metadata.negative_conditioning) } - onClick={() => setNegativePrompt(metadata.negative_conditioning!)} /> )} {metadata.seed !== undefined && ( dispatch(setSeed(Number(metadata.seed)))} + onClick={() => recallSeed(metadata.seed)} + /> + )} + {metadata.model !== undefined && ( + recallModel(metadata.model)} + /> + )} + {metadata.width && ( + recallWidth(metadata.width)} + /> + )} + {metadata.height && ( + recallHeight(metadata.height)} /> )} {/* {metadata.threshold !== undefined && ( @@ -227,23 +249,21 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => { - dispatch(setScheduler(metadata.scheduler as Scheduler)) - } + onClick={() => recallScheduler(metadata.scheduler)} /> )} {metadata.steps && ( 
dispatch(setSteps(Number(metadata.steps)))} + onClick={() => recallSteps(metadata.steps)} /> )} {metadata.cfg_scale !== undefined && ( dispatch(setCfgScale(Number(metadata.cfg_scale)))} + onClick={() => recallCfgScale(metadata.cfg_scale)} /> )} {/* {metadata.variations && metadata.variations.length > 0 && ( @@ -284,9 +304,7 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => { - dispatch(setImg2imgStrength(Number(metadata.strength))) - } + onClick={() => recallStrength(metadata.strength)} /> )} {/* {metadata.fit && ( diff --git a/invokeai/frontend/web/src/features/gallery/components/NextPrevImageButtons.tsx b/invokeai/frontend/web/src/features/gallery/components/NextPrevImageButtons.tsx index fcf8359187..82e7a0d623 100644 --- a/invokeai/frontend/web/src/features/gallery/components/NextPrevImageButtons.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/NextPrevImageButtons.tsx @@ -9,6 +9,10 @@ import { gallerySelector } from '../store/gallerySelectors'; import { RootState } from 'app/store/store'; import { imageSelected } from '../store/gallerySlice'; import { useHotkeys } from 'react-hotkeys-hook'; +import { + selectFilteredImagesAsObject, + selectFilteredImagesIds, +} from '../store/imagesSlice'; const nextPrevButtonTriggerAreaStyles: ChakraProps['sx'] = { height: '100%', @@ -21,9 +25,14 @@ const nextPrevButtonStyles: ChakraProps['sx'] = { }; export const nextPrevImageButtonsSelector = createSelector( - [(state: RootState) => state, gallerySelector], - (state, gallery) => { - const { selectedImage, currentCategory } = gallery; + [ + (state: RootState) => state, + gallerySelector, + selectFilteredImagesAsObject, + selectFilteredImagesIds, + ], + (state, gallery, filteredImagesAsObject, filteredImageIds) => { + const { selectedImage } = gallery; if (!selectedImage) { return { @@ -32,29 +41,29 @@ export const nextPrevImageButtonsSelector = createSelector( }; } - const currentImageIndex = state[currentCategory].ids.findIndex( + 
const currentImageIndex = filteredImageIds.findIndex( (i) => i === selectedImage.image_name ); const nextImageIndex = clamp( currentImageIndex + 1, 0, - state[currentCategory].ids.length - 1 + filteredImageIds.length - 1 ); const prevImageIndex = clamp( currentImageIndex - 1, 0, - state[currentCategory].ids.length - 1 + filteredImageIds.length - 1 ); - const nextImageId = state[currentCategory].ids[nextImageIndex]; - const prevImageId = state[currentCategory].ids[prevImageIndex]; + const nextImageId = filteredImageIds[nextImageIndex]; + const prevImageId = filteredImageIds[prevImageIndex]; - const nextImage = state[currentCategory].entities[nextImageId]; - const prevImage = state[currentCategory].entities[prevImageId]; + const nextImage = filteredImagesAsObject[nextImageId]; + const prevImage = filteredImagesAsObject[prevImageId]; - const imagesLength = state[currentCategory].ids.length; + const imagesLength = filteredImageIds.length; return { isOnFirstImage: currentImageIndex === 0, diff --git a/invokeai/frontend/web/src/features/gallery/hooks/useGetImageByName.ts b/invokeai/frontend/web/src/features/gallery/hooks/useGetImageByName.ts index ad0870e7a4..89709b322a 100644 --- a/invokeai/frontend/web/src/features/gallery/hooks/useGetImageByName.ts +++ b/invokeai/frontend/web/src/features/gallery/hooks/useGetImageByName.ts @@ -1,33 +1,18 @@ -import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; -import { ImageType } from 'services/api'; -import { selectResultsEntities } from '../store/resultsSlice'; -import { selectUploadsEntities } from '../store/uploadsSlice'; +import { selectImagesEntities } from '../store/imagesSlice'; +import { useCallback } from 'react'; -const useGetImageByNameSelector = createSelector( - [selectResultsEntities, selectUploadsEntities], - (allResults, allUploads) => { - return { allResults, allUploads }; - } -); - -const useGetImageByNameAndType = () => { - const { allResults, allUploads } = 
useAppSelector(useGetImageByNameSelector); - return (name: string, type: ImageType) => { - if (type === 'results') { - const resultImagesResult = allResults[name]; - if (resultImagesResult) { - return resultImagesResult; +const useGetImageByName = () => { + const images = useAppSelector(selectImagesEntities); + return useCallback( + (name: string | undefined) => { + if (!name) { + return; } - } - - if (type === 'uploads') { - const userImagesResult = allUploads[name]; - if (userImagesResult) { - return userImagesResult; - } - } - }; + return images[name]; + }, + [images] + ); }; -export default useGetImageByNameAndType; +export default useGetImageByName; diff --git a/invokeai/frontend/web/src/features/gallery/store/actions.ts b/invokeai/frontend/web/src/features/gallery/store/actions.ts index 7e071f279d..7c00201da9 100644 --- a/invokeai/frontend/web/src/features/gallery/store/actions.ts +++ b/invokeai/frontend/web/src/features/gallery/store/actions.ts @@ -1,9 +1,9 @@ import { createAction } from '@reduxjs/toolkit'; -import { ImageNameAndType } from 'features/parameters/store/actions'; +import { ImageNameAndOrigin } from 'features/parameters/store/actions'; import { ImageDTO } from 'services/api'; export const requestedImageDeletion = createAction< - ImageDTO | ImageNameAndType | undefined + ImageDTO | ImageNameAndOrigin | undefined >('gallery/requestedImageDeletion'); export const sentImageToCanvas = createAction('gallery/sentImageToCanvas'); diff --git a/invokeai/frontend/web/src/features/gallery/store/galleryPersistDenylist.ts b/invokeai/frontend/web/src/features/gallery/store/galleryPersistDenylist.ts index 49f51d5a80..44e03f9f71 100644 --- a/invokeai/frontend/web/src/features/gallery/store/galleryPersistDenylist.ts +++ b/invokeai/frontend/web/src/features/gallery/store/galleryPersistDenylist.ts @@ -4,6 +4,5 @@ import { GalleryState } from './gallerySlice'; * Gallery slice persist denylist */ export const galleryPersistDenylist: (keyof GalleryState)[] = [ - 
'currentCategory', 'shouldAutoSwitchToNewImages', ]; diff --git a/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts b/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts index 9d6f5ece60..ab62646c0f 100644 --- a/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts +++ b/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts @@ -1,10 +1,7 @@ import type { PayloadAction } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit'; -import { - receivedResultImagesPage, - receivedUploadImagesPage, -} from '../../../services/thunks/gallery'; import { ImageDTO } from 'services/api'; +import { imageUpserted } from './imagesSlice'; type GalleryImageObjectFitType = 'contain' | 'cover'; @@ -14,7 +11,6 @@ export interface GalleryState { galleryImageObjectFit: GalleryImageObjectFitType; shouldAutoSwitchToNewImages: boolean; shouldUseSingleGalleryColumn: boolean; - currentCategory: 'results' | 'uploads'; } export const initialGalleryState: GalleryState = { @@ -22,7 +18,6 @@ export const initialGalleryState: GalleryState = { galleryImageObjectFit: 'cover', shouldAutoSwitchToNewImages: true, shouldUseSingleGalleryColumn: false, - currentCategory: 'results', }; export const gallerySlice = createSlice({ @@ -46,12 +41,6 @@ export const gallerySlice = createSlice({ setShouldAutoSwitchToNewImages: (state, action: PayloadAction) => { state.shouldAutoSwitchToNewImages = action.payload; }, - setCurrentCategory: ( - state, - action: PayloadAction<'results' | 'uploads'> - ) => { - state.currentCategory = action.payload; - }, setShouldUseSingleGalleryColumn: ( state, action: PayloadAction @@ -59,37 +48,10 @@ export const gallerySlice = createSlice({ state.shouldUseSingleGalleryColumn = action.payload; }, }, - extraReducers(builder) { - builder.addCase(receivedResultImagesPage.fulfilled, (state, action) => { - // rehydrate selectedImage URL when results list comes in - // solves case when outdated URL is in local storage - const 
selectedImage = state.selectedImage; - if (selectedImage) { - const selectedImageInResults = action.payload.items.find( - (image) => image.image_name === selectedImage.image_name - ); - - if (selectedImageInResults) { - selectedImage.image_url = selectedImageInResults.image_url; - selectedImage.thumbnail_url = selectedImageInResults.thumbnail_url; - state.selectedImage = selectedImage; - } - } - }); - builder.addCase(receivedUploadImagesPage.fulfilled, (state, action) => { - // rehydrate selectedImage URL when results list comes in - // solves case when outdated URL is in local storage - const selectedImage = state.selectedImage; - if (selectedImage) { - const selectedImageInResults = action.payload.items.find( - (image) => image.image_name === selectedImage.image_name - ); - - if (selectedImageInResults) { - selectedImage.image_url = selectedImageInResults.image_url; - selectedImage.thumbnail_url = selectedImageInResults.thumbnail_url; - state.selectedImage = selectedImage; - } + extraReducers: (builder) => { + builder.addCase(imageUpserted, (state, action) => { + if (state.shouldAutoSwitchToNewImages) { + state.selectedImage = action.payload; } }); }, @@ -101,7 +63,6 @@ export const { setGalleryImageObjectFit, setShouldAutoSwitchToNewImages, setShouldUseSingleGalleryColumn, - setCurrentCategory, } = gallerySlice.actions; export default gallerySlice.reducer; diff --git a/invokeai/frontend/web/src/features/gallery/store/imagesSlice.ts b/invokeai/frontend/web/src/features/gallery/store/imagesSlice.ts new file mode 100644 index 0000000000..cb6469aeb4 --- /dev/null +++ b/invokeai/frontend/web/src/features/gallery/store/imagesSlice.ts @@ -0,0 +1,135 @@ +import { + PayloadAction, + createEntityAdapter, + createSelector, + createSlice, +} from '@reduxjs/toolkit'; +import { RootState } from 'app/store/store'; +import { ImageCategory, ImageDTO } from 'services/api'; +import { dateComparator } from 'common/util/dateComparator'; +import { isString, keyBy } from 'lodash-es'; 
+import { receivedPageOfImages } from 'services/thunks/image'; + +export const imagesAdapter = createEntityAdapter({ + selectId: (image) => image.image_name, + sortComparer: (a, b) => dateComparator(b.created_at, a.created_at), +}); + +export const IMAGE_CATEGORIES: ImageCategory[] = ['general']; +export const ASSETS_CATEGORIES: ImageCategory[] = [ + 'control', + 'mask', + 'user', + 'other', +]; + +type AdditionaImagesState = { + offset: number; + limit: number; + total: number; + isLoading: boolean; + categories: ImageCategory[]; +}; + +export const initialImagesState = + imagesAdapter.getInitialState({ + offset: 0, + limit: 0, + total: 0, + isLoading: false, + categories: IMAGE_CATEGORIES, + }); + +export type ImagesState = typeof initialImagesState; + +const imagesSlice = createSlice({ + name: 'images', + initialState: initialImagesState, + reducers: { + imageUpserted: (state, action: PayloadAction) => { + imagesAdapter.upsertOne(state, action.payload); + }, + imageRemoved: (state, action: PayloadAction) => { + if (isString(action.payload)) { + imagesAdapter.removeOne(state, action.payload); + return; + } + + imagesAdapter.removeOne(state, action.payload.image_name); + }, + imageCategoriesChanged: (state, action: PayloadAction) => { + state.categories = action.payload; + }, + }, + extraReducers: (builder) => { + builder.addCase(receivedPageOfImages.pending, (state) => { + state.isLoading = true; + }); + builder.addCase(receivedPageOfImages.rejected, (state) => { + state.isLoading = false; + }); + builder.addCase(receivedPageOfImages.fulfilled, (state, action) => { + state.isLoading = false; + const { items, offset, limit, total } = action.payload; + state.offset = offset; + state.limit = limit; + state.total = total; + imagesAdapter.upsertMany(state, items); + }); + }, +}); + +export const { + selectAll: selectImagesAll, + selectById: selectImagesById, + selectEntities: selectImagesEntities, + selectIds: selectImagesIds, + selectTotal: selectImagesTotal, +} = 
imagesAdapter.getSelectors((state) => state.images); + +export const { imageUpserted, imageRemoved, imageCategoriesChanged } = + imagesSlice.actions; + +export default imagesSlice.reducer; + +export const selectFilteredImagesAsArray = createSelector( + (state: RootState) => state, + (state) => { + const { + images: { categories }, + } = state; + + return selectImagesAll(state).filter((i) => + categories.includes(i.image_category) + ); + } +); + +export const selectFilteredImagesAsObject = createSelector( + (state: RootState) => state, + (state) => { + const { + images: { categories }, + } = state; + + return keyBy( + selectImagesAll(state).filter((i) => + categories.includes(i.image_category) + ), + 'image_name' + ); + } +); + +export const selectFilteredImagesIds = createSelector( + (state: RootState) => state, + (state) => { + const { + images: { categories }, + } = state; + + return selectImagesAll(state) + .filter((i) => categories.includes(i.image_category)) + .map((i) => i.image_name); + } +); diff --git a/invokeai/frontend/web/src/features/gallery/store/resultsPersistDenylist.ts b/invokeai/frontend/web/src/features/gallery/store/resultsPersistDenylist.ts deleted file mode 100644 index 1c3d8aaaec..0000000000 --- a/invokeai/frontend/web/src/features/gallery/store/resultsPersistDenylist.ts +++ /dev/null @@ -1,8 +0,0 @@ -import { ResultsState } from './resultsSlice'; - -/** - * Results slice persist denylist - * - * Currently denylisting results slice entirely, see `serialize.ts` - */ -export const resultsPersistDenylist: (keyof ResultsState)[] = []; diff --git a/invokeai/frontend/web/src/features/gallery/store/resultsSlice.ts b/invokeai/frontend/web/src/features/gallery/store/resultsSlice.ts deleted file mode 100644 index 125f4ff5d5..0000000000 --- a/invokeai/frontend/web/src/features/gallery/store/resultsSlice.ts +++ /dev/null @@ -1,125 +0,0 @@ -import { createEntityAdapter, createSlice } from '@reduxjs/toolkit'; -import { RootState } from 'app/store/store'; 
-import { - receivedResultImagesPage, - IMAGES_PER_PAGE, -} from 'services/thunks/gallery'; -import { - imageDeleted, - imageMetadataReceived, - imageUrlsReceived, -} from 'services/thunks/image'; -import { ImageDTO } from 'services/api'; -import { dateComparator } from 'common/util/dateComparator'; - -export type ResultsImageDTO = Omit & { - image_type: 'results'; -}; - -export const resultsAdapter = createEntityAdapter({ - selectId: (image) => image.image_name, - sortComparer: (a, b) => dateComparator(b.created_at, a.created_at), -}); - -type AdditionalResultsState = { - page: number; - pages: number; - isLoading: boolean; - nextPage: number; -}; - -export const initialResultsState = - resultsAdapter.getInitialState({ - page: 0, - pages: 0, - isLoading: false, - nextPage: 0, - }); - -export type ResultsState = typeof initialResultsState; - -const resultsSlice = createSlice({ - name: 'results', - initialState: initialResultsState, - reducers: { - resultAdded: resultsAdapter.upsertOne, - }, - extraReducers: (builder) => { - /** - * Received Result Images Page - PENDING - */ - builder.addCase(receivedResultImagesPage.pending, (state) => { - state.isLoading = true; - }); - - /** - * Received Result Images Page - FULFILLED - */ - builder.addCase(receivedResultImagesPage.fulfilled, (state, action) => { - const { page, pages } = action.payload; - - // We know these will all be of the results type, but it's not represented in the API types - const items = action.payload.items as ResultsImageDTO[]; - - resultsAdapter.setMany(state, items); - - state.page = page; - state.pages = pages; - state.nextPage = items.length < IMAGES_PER_PAGE ? 
page : page + 1; - state.isLoading = false; - }); - - /** - * Image Metadata Received - FULFILLED - */ - builder.addCase(imageMetadataReceived.fulfilled, (state, action) => { - const { image_type } = action.payload; - - if (image_type === 'results') { - resultsAdapter.upsertOne(state, action.payload as ResultsImageDTO); - } - }); - - /** - * Image URLs Received - FULFILLED - */ - builder.addCase(imageUrlsReceived.fulfilled, (state, action) => { - const { image_name, image_type, image_url, thumbnail_url } = - action.payload; - - if (image_type === 'results') { - resultsAdapter.updateOne(state, { - id: image_name, - changes: { - image_url: image_url, - thumbnail_url: thumbnail_url, - }, - }); - } - }); - - /** - * Delete Image - PENDING - * Pre-emptively remove the image from the gallery - */ - builder.addCase(imageDeleted.pending, (state, action) => { - const { imageType, imageName } = action.meta.arg; - - if (imageType === 'results') { - resultsAdapter.removeOne(state, imageName); - } - }); - }, -}); - -export const { - selectAll: selectResultsAll, - selectById: selectResultsById, - selectEntities: selectResultsEntities, - selectIds: selectResultsIds, - selectTotal: selectResultsTotal, -} = resultsAdapter.getSelectors((state) => state.results); - -export const { resultAdded } = resultsSlice.actions; - -export default resultsSlice.reducer; diff --git a/invokeai/frontend/web/src/features/gallery/store/uploadsPersistDenylist.ts b/invokeai/frontend/web/src/features/gallery/store/uploadsPersistDenylist.ts deleted file mode 100644 index 296e8b2057..0000000000 --- a/invokeai/frontend/web/src/features/gallery/store/uploadsPersistDenylist.ts +++ /dev/null @@ -1,8 +0,0 @@ -import { UploadsState } from './uploadsSlice'; - -/** - * Uploads slice persist denylist - * - * Currently denylisting uploads slice entirely, see `serialize.ts` - */ -export const uploadsPersistDenylist: (keyof UploadsState)[] = []; diff --git 
a/invokeai/frontend/web/src/features/gallery/store/uploadsSlice.ts b/invokeai/frontend/web/src/features/gallery/store/uploadsSlice.ts deleted file mode 100644 index 5e458503ec..0000000000 --- a/invokeai/frontend/web/src/features/gallery/store/uploadsSlice.ts +++ /dev/null @@ -1,111 +0,0 @@ -import { createEntityAdapter, createSlice } from '@reduxjs/toolkit'; - -import { RootState } from 'app/store/store'; -import { - receivedUploadImagesPage, - IMAGES_PER_PAGE, -} from 'services/thunks/gallery'; -import { imageDeleted, imageUrlsReceived } from 'services/thunks/image'; -import { ImageDTO } from 'services/api'; -import { dateComparator } from 'common/util/dateComparator'; - -export type UploadsImageDTO = Omit & { - image_type: 'uploads'; -}; - -export const uploadsAdapter = createEntityAdapter({ - selectId: (image) => image.image_name, - sortComparer: (a, b) => dateComparator(b.created_at, a.created_at), -}); - -type AdditionalUploadsState = { - page: number; - pages: number; - isLoading: boolean; - nextPage: number; -}; - -export const initialUploadsState = - uploadsAdapter.getInitialState({ - page: 0, - pages: 0, - nextPage: 0, - isLoading: false, - }); - -export type UploadsState = typeof initialUploadsState; - -const uploadsSlice = createSlice({ - name: 'uploads', - initialState: initialUploadsState, - reducers: { - uploadAdded: uploadsAdapter.upsertOne, - }, - extraReducers: (builder) => { - /** - * Received Upload Images Page - PENDING - */ - builder.addCase(receivedUploadImagesPage.pending, (state) => { - state.isLoading = true; - }); - - /** - * Received Upload Images Page - FULFILLED - */ - builder.addCase(receivedUploadImagesPage.fulfilled, (state, action) => { - const { page, pages } = action.payload; - - // We know these will all be of the uploads type, but it's not represented in the API types - const items = action.payload.items as UploadsImageDTO[]; - - uploadsAdapter.setMany(state, items); - - state.page = page; - state.pages = pages; - state.nextPage 
= items.length < IMAGES_PER_PAGE ? page : page + 1; - state.isLoading = false; - }); - - /** - * Image URLs Received - FULFILLED - */ - builder.addCase(imageUrlsReceived.fulfilled, (state, action) => { - const { image_name, image_type, image_url, thumbnail_url } = - action.payload; - - if (image_type === 'uploads') { - uploadsAdapter.updateOne(state, { - id: image_name, - changes: { - image_url: image_url, - thumbnail_url: thumbnail_url, - }, - }); - } - }); - - /** - * Delete Image - pending - * Pre-emptively remove the image from the gallery - */ - builder.addCase(imageDeleted.pending, (state, action) => { - const { imageType, imageName } = action.meta.arg; - - if (imageType === 'uploads') { - uploadsAdapter.removeOne(state, imageName); - } - }); - }, -}); - -export const { - selectAll: selectUploadsAll, - selectById: selectUploadsById, - selectEntities: selectUploadsEntities, - selectIds: selectUploadsIds, - selectTotal: selectUploadsTotal, -} = uploadsAdapter.getSelectors((state) => state.uploads); - -export const { uploadAdded } = uploadsSlice.actions; - -export default uploadsSlice.reducer; diff --git a/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx index 341ca19fa9..65b7cfa560 100644 --- a/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx @@ -10,6 +10,7 @@ import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComp import UNetInputFieldComponent from './fields/UNetInputFieldComponent'; import ClipInputFieldComponent from './fields/ClipInputFieldComponent'; import VaeInputFieldComponent from './fields/VaeInputFieldComponent'; +import ControlInputFieldComponent from './fields/ControlInputFieldComponent'; import ModelInputFieldComponent from './fields/ModelInputFieldComponent'; import NumberInputFieldComponent from 
'./fields/NumberInputFieldComponent'; import StringInputFieldComponent from './fields/StringInputFieldComponent'; @@ -130,6 +131,16 @@ const InputFieldComponent = (props: InputFieldComponentProps) => { ); } + if (type === 'control' && template.type === 'control') { + return ( + + ); + } + if (type === 'model' && template.type === 'model') { return ( +) => { + const { nodeId, field } = props; + + return null; +}; + +export default memo(ControlInputFieldComponent); diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/ImageInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/ImageInputFieldComponent.tsx index 18be021625..57cefb0a9c 100644 --- a/invokeai/frontend/web/src/features/nodes/components/fields/ImageInputFieldComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/fields/ImageInputFieldComponent.tsx @@ -2,7 +2,7 @@ import { Box, Image } from '@chakra-ui/react'; import { useAppDispatch } from 'app/store/storeHooks'; import SelectImagePlaceholder from 'common/components/SelectImagePlaceholder'; import { useGetUrl } from 'common/util/getUrl'; -import useGetImageByNameAndType from 'features/gallery/hooks/useGetImageByName'; +import useGetImageByName from 'features/gallery/hooks/useGetImageByName'; import { fieldValueChanged } from 'features/nodes/store/nodesSlice'; import { @@ -11,7 +11,6 @@ import { } from 'features/nodes/types/types'; import { DragEvent, memo, useCallback, useState } from 'react'; -import { ImageType } from 'services/api'; import { FieldComponentProps } from './types'; const ImageInputFieldComponent = ( @@ -19,7 +18,7 @@ const ImageInputFieldComponent = ( ) => { const { nodeId, field } = props; - const getImageByNameAndType = useGetImageByNameAndType(); + const getImageByName = useGetImageByName(); const dispatch = useAppDispatch(); const [url, setUrl] = useState(field.value?.image_url); const { getUrl } = useGetUrl(); @@ -27,13 +26,7 @@ const ImageInputFieldComponent = ( const 
handleDrop = useCallback( (e: DragEvent) => { const name = e.dataTransfer.getData('invokeai/imageName'); - const type = e.dataTransfer.getData('invokeai/imageType') as ImageType; - - if (!name || !type) { - return; - } - - const image = getImageByNameAndType(name, type); + const image = getImageByName(name); if (!image) { return; @@ -49,7 +42,7 @@ const ImageInputFieldComponent = ( }) ); }, - [getImageByNameAndType, dispatch, field.name, nodeId] + [getImageByName, dispatch, field.name, nodeId] ); return ( diff --git a/invokeai/frontend/web/src/features/nodes/types/constants.ts b/invokeai/frontend/web/src/features/nodes/types/constants.ts index 36c3514eeb..b8cd7efaa4 100644 --- a/invokeai/frontend/web/src/features/nodes/types/constants.ts +++ b/invokeai/frontend/web/src/features/nodes/types/constants.ts @@ -4,6 +4,7 @@ export const HANDLE_TOOLTIP_OPEN_DELAY = 500; export const FIELD_TYPE_MAP: Record = { integer: 'integer', + float: 'float', number: 'float', string: 'string', boolean: 'boolean', @@ -18,6 +19,8 @@ export const FIELD_TYPE_MAP: Record = { array: 'array', item: 'item', ColorField: 'color', + ControlField: 'control', + control: 'control', }; const COLOR_TOKEN_VALUE = 500; @@ -25,6 +28,9 @@ const COLOR_TOKEN_VALUE = 500; const getColorTokenCssVariable = (color: string) => `var(--invokeai-colors-${color}-${COLOR_TOKEN_VALUE})`; +// @ts-ignore +// @ts-ignore +// @ts-ignore export const FIELDS: Record = { integer: { color: 'red', @@ -92,6 +98,12 @@ export const FIELDS: Record = { title: 'Vae', description: 'Vae submodel.', }, + control: { + color: 'cyan', + colorCssVar: getColorTokenCssVariable('cyan'), // TODO: no free color left + title: 'Control', + description: 'Control info passed between nodes.', + }, model: { color: 'teal', colorCssVar: getColorTokenCssVariable('teal'), diff --git a/invokeai/frontend/web/src/features/nodes/types/types.ts b/invokeai/frontend/web/src/features/nodes/types/types.ts index 930b7dad57..5e140b6eef 100644 --- 
a/invokeai/frontend/web/src/features/nodes/types/types.ts +++ b/invokeai/frontend/web/src/features/nodes/types/types.ts @@ -64,6 +64,7 @@ export type FieldType = | 'unet' | 'clip' | 'vae' + | 'control' | 'model' | 'array' | 'item' @@ -88,6 +89,7 @@ export type InputFieldValue = | UNetInputFieldValue | ClipInputFieldValue | VaeInputFieldValue + | ControlInputFieldValue | EnumInputFieldValue | ModelInputFieldValue | ArrayInputFieldValue @@ -111,6 +113,7 @@ export type InputFieldTemplate = | UNetInputFieldTemplate | ClipInputFieldTemplate | VaeInputFieldTemplate + | ControlInputFieldTemplate | EnumInputFieldTemplate | ModelInputFieldTemplate | ArrayInputFieldTemplate @@ -186,6 +189,11 @@ export type LatentsInputFieldValue = FieldValueBase & { export type ConditioningInputFieldValue = FieldValueBase & { type: 'conditioning'; + value?: string; +}; + +export type ControlInputFieldValue = FieldValueBase & { + type: 'control'; value?: undefined; }; @@ -286,6 +294,11 @@ export type ConditioningInputFieldTemplate = InputFieldTemplateBase & { type: 'conditioning'; }; +export type ControlInputFieldTemplate = InputFieldTemplateBase & { + default: undefined; + type: 'control'; +}; + export type EnumInputFieldTemplate = InputFieldTemplateBase & { default: string | number; type: 'enum'; diff --git a/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts b/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts index b275c84248..f1ad731d32 100644 --- a/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts +++ b/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts @@ -13,6 +13,7 @@ import { UNetInputFieldTemplate, ClipInputFieldTemplate, VaeInputFieldTemplate, + ControlInputFieldTemplate, StringInputFieldTemplate, ModelInputFieldTemplate, ArrayInputFieldTemplate, @@ -263,6 +264,21 @@ const buildVaeInputFieldTemplate = ({ return template; }; +const buildControlInputFieldTemplate = ({ + schemaObject, + baseField, +}: 
BuildInputFieldArg): ControlInputFieldTemplate => { + const template: ControlInputFieldTemplate = { + ...baseField, + type: 'control', + inputRequirement: 'always', + inputKind: 'connection', + default: schemaObject.default ?? undefined, + }; + + return template; +}; + const buildEnumInputFieldTemplate = ({ schemaObject, baseField, @@ -334,9 +350,20 @@ export const getFieldType = ( if (typeHints && name in typeHints) { rawFieldType = typeHints[name]; } else if (!schemaObject.type) { - rawFieldType = refObjectToFieldType( - schemaObject.allOf![0] as OpenAPIV3.ReferenceObject - ); + // if schemaObject has no type, then it should have one of allOf, anyOf, oneOf + if (schemaObject.allOf) { + rawFieldType = refObjectToFieldType( + schemaObject.allOf![0] as OpenAPIV3.ReferenceObject + ); + } else if (schemaObject.anyOf) { + rawFieldType = refObjectToFieldType( + schemaObject.anyOf![0] as OpenAPIV3.ReferenceObject + ); + } else if (schemaObject.oneOf) { + rawFieldType = refObjectToFieldType( + schemaObject.oneOf![0] as OpenAPIV3.ReferenceObject + ); + } } else if (schemaObject.enum) { rawFieldType = 'enum'; } else if (schemaObject.type) { @@ -388,6 +415,9 @@ export const buildInputFieldTemplate = ( if (['vae'].includes(fieldType)) { return buildVaeInputFieldTemplate({ schemaObject, baseField }); } + if (['control'].includes(fieldType)) { + return buildControlInputFieldTemplate({ schemaObject, baseField }); + } if (['model'].includes(fieldType)) { return buildModelInputFieldTemplate({ schemaObject, baseField }); } diff --git a/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts b/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts index c0c19708c7..1703c45331 100644 --- a/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts +++ b/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts @@ -64,6 +64,10 @@ export const buildInputFieldValue = ( fieldValue.value = undefined; } + if (template.type === 'control') { + 
fieldValue.value = undefined; + } + if (template.type === 'model') { fieldValue.value = undefined; } diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasGraph.ts index 3615f7d298..2d23b882ea 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasGraph.ts @@ -16,7 +16,7 @@ import { buildEdges } from '../edgeBuilders/buildEdges'; import { log } from 'app/logging/useLogger'; import { buildInpaintNode } from '../nodeBuilders/buildInpaintNode'; -const moduleLog = log.child({ namespace: 'buildCanvasGraph' }); +const moduleLog = log.child({ namespace: 'nodes' }); const buildBaseNode = ( nodeType: 'txt2img' | 'img2img' | 'inpaint' | 'outpaint', @@ -26,18 +26,21 @@ const buildBaseNode = ( | ImageToImageInvocation | InpaintInvocation | undefined => { - const dimensionsOverride = state.canvas.boundingBoxDimensions; + const overrides = { + ...state.canvas.boundingBoxDimensions, + is_intermediate: true, + }; if (nodeType === 'txt2img') { - return buildTxt2ImgNode(state, dimensionsOverride); + return buildTxt2ImgNode(state, overrides); } if (nodeType === 'img2img') { - return buildImg2ImgNode(state, dimensionsOverride); + return buildImg2ImgNode(state, overrides); } if (nodeType === 'inpaint' || nodeType === 'outpaint') { - return buildInpaintNode(state, dimensionsOverride); + return buildInpaintNode(state, overrides); } }; @@ -77,18 +80,23 @@ export const buildCanvasGraphComponents = async ( infillMethod, } = state.generation; - // generationParameters.invert_mask = shouldPreserveMaskedArea; - // if (boundingBoxScale !== 'none') { - // generationParameters.inpaint_width = scaledBoundingBoxDimensions.width; - // generationParameters.inpaint_height = scaledBoundingBoxDimensions.height; - // } + const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } 
= + state.canvas; + + if (boundingBoxScaleMethod !== 'none') { + baseNode.inpaint_width = scaledBoundingBoxDimensions.width; + baseNode.inpaint_height = scaledBoundingBoxDimensions.height; + } + baseNode.seam_size = seamSize; baseNode.seam_blur = seamBlur; baseNode.seam_strength = seamStrength; baseNode.seam_steps = seamSteps; - baseNode.tile_size = tileSize; baseNode.infill_method = infillMethod as InpaintInvocation['infill_method']; - // baseNode.force_outpaint = false; + + if (infillMethod === 'tile') { + baseNode.tile_size = tileSize; + } } // We always range and iterate nodes, no matter the iteration count diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildImageToImageGraph.ts index d9eb80d654..fe4f6c63b5 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildImageToImageGraph.ts @@ -2,21 +2,31 @@ import { RootState } from 'app/store/store'; import { CompelInvocation, Graph, + ImageResizeInvocation, ImageToLatentsInvocation, + IterateInvocation, LatentsToImageInvocation, LatentsToLatentsInvocation, + NoiseInvocation, + RandomIntInvocation, + RangeOfSizeInvocation, } from 'services/api'; import { NonNullableGraph } from 'features/nodes/types/types'; -import { addNoiseNodes } from '../nodeBuilders/addNoiseNodes'; import { log } from 'app/logging/useLogger'; +import { set } from 'lodash-es'; -const moduleLog = log.child({ namespace: 'buildImageToImageGraph' }); +const moduleLog = log.child({ namespace: 'nodes' }); const POSITIVE_CONDITIONING = 'positive_conditioning'; const NEGATIVE_CONDITIONING = 'negative_conditioning'; const IMAGE_TO_LATENTS = 'image_to_latents'; const LATENTS_TO_LATENTS = 'latents_to_latents'; const LATENTS_TO_IMAGE = 'latents_to_image'; +const RESIZE = 'resize_image'; +const NOISE = 'noise'; +const RANDOM_INT = 
'rand_int'; +const RANGE_OF_SIZE = 'range_of_size'; +const ITERATE = 'iterate'; /** * Builds the Image to Image tab graph. @@ -31,6 +41,12 @@ export const buildImageToImageGraph = (state: RootState): Graph => { steps, initialImage, img2imgStrength: strength, + shouldFitToWidthHeight, + width, + height, + iterations, + seed, + shouldRandomizeSeed, } = state.generation; if (!initialImage) { @@ -38,12 +54,12 @@ export const buildImageToImageGraph = (state: RootState): Graph => { throw new Error('No initial image found in state'); } - let graph: NonNullableGraph = { + const graph: NonNullableGraph = { nodes: {}, edges: [], }; - // Create the conditioning, t2l and l2i nodes + // Create the positive conditioning (prompt) node const positiveConditioningNode: CompelInvocation = { id: POSITIVE_CONDITIONING, type: 'compel', @@ -51,6 +67,7 @@ export const buildImageToImageGraph = (state: RootState): Graph => { model, }; + // Negative conditioning const negativeConditioningNode: CompelInvocation = { id: NEGATIVE_CONDITIONING, type: 'compel', @@ -58,16 +75,15 @@ export const buildImageToImageGraph = (state: RootState): Graph => { model, }; + // This will encode the raster image to latents - but it may get its `image` from a resize node, + // so we do not set its `image` property yet const imageToLatentsNode: ImageToLatentsInvocation = { id: IMAGE_TO_LATENTS, type: 'i2l', model, - image: { - image_name: initialImage?.image_name, - image_type: initialImage?.image_type, - }, }; + // This does the actual img2img inference const latentsToLatentsNode: LatentsToLatentsInvocation = { id: LATENTS_TO_LATENTS, type: 'l2l', @@ -78,20 +94,21 @@ export const buildImageToImageGraph = (state: RootState): Graph => { strength, }; + // Finally we decode the latents back to an image const latentsToImageNode: LatentsToImageInvocation = { id: LATENTS_TO_IMAGE, type: 'l2i', model, }; - // Add to the graph + // Add all those nodes to the graph graph.nodes[POSITIVE_CONDITIONING] = 
positiveConditioningNode; graph.nodes[NEGATIVE_CONDITIONING] = negativeConditioningNode; graph.nodes[IMAGE_TO_LATENTS] = imageToLatentsNode; graph.nodes[LATENTS_TO_LATENTS] = latentsToLatentsNode; graph.nodes[LATENTS_TO_IMAGE] = latentsToImageNode; - // Connect them + // Connect the prompt nodes to the imageToLatents node graph.edges.push({ source: { node_id: POSITIVE_CONDITIONING, field: 'conditioning' }, destination: { @@ -99,7 +116,6 @@ export const buildImageToImageGraph = (state: RootState): Graph => { field: 'positive_conditioning', }, }); - graph.edges.push({ source: { node_id: NEGATIVE_CONDITIONING, field: 'conditioning' }, destination: { @@ -108,6 +124,7 @@ export const buildImageToImageGraph = (state: RootState): Graph => { }, }); + // Connect the image-encoding node graph.edges.push({ source: { node_id: IMAGE_TO_LATENTS, field: 'latents' }, destination: { @@ -116,6 +133,7 @@ export const buildImageToImageGraph = (state: RootState): Graph => { }, }); + // Connect the image-decoding node graph.edges.push({ source: { node_id: LATENTS_TO_LATENTS, field: 'latents' }, destination: { @@ -124,8 +142,271 @@ export const buildImageToImageGraph = (state: RootState): Graph => { }, }); - // Create and add the noise nodes - graph = addNoiseNodes(graph, latentsToLatentsNode.id, state); + /** + * Now we need to handle iterations and random seeds. There are four possible scenarios: + * - Single iteration, explicit seed + * - Single iteration, random seed + * - Multiple iterations, explicit seed + * - Multiple iterations, random seed + * + * They all have different graphs and connections. 
+ */ + + // Single iteration, explicit seed + if (!shouldRandomizeSeed && iterations === 1) { + // Noise node using the explicit seed + const noiseNode: NoiseInvocation = { + id: NOISE, + type: 'noise', + seed: seed, + }; + + graph.nodes[NOISE] = noiseNode; + + // Connect noise to l2l + graph.edges.push({ + source: { node_id: NOISE, field: 'noise' }, + destination: { + node_id: LATENTS_TO_LATENTS, + field: 'noise', + }, + }); + } + + // Single iteration, random seed + if (shouldRandomizeSeed && iterations === 1) { + // Random int node to generate the seed + const randomIntNode: RandomIntInvocation = { + id: RANDOM_INT, + type: 'rand_int', + }; + + // Noise node without any seed + const noiseNode: NoiseInvocation = { + id: NOISE, + type: 'noise', + }; + + graph.nodes[RANDOM_INT] = randomIntNode; + graph.nodes[NOISE] = noiseNode; + + // Connect random int to the seed of the noise node + graph.edges.push({ + source: { node_id: RANDOM_INT, field: 'a' }, + destination: { + node_id: NOISE, + field: 'seed', + }, + }); + + // Connect noise to l2l + graph.edges.push({ + source: { node_id: NOISE, field: 'noise' }, + destination: { + node_id: LATENTS_TO_LATENTS, + field: 'noise', + }, + }); + } + + // Multiple iterations, explicit seed + if (!shouldRandomizeSeed && iterations > 1) { + // Range of size node to generate `iterations` count of seeds - range of size generates a collection + // of ints from `start` to `start + size`. The `start` is the seed, and the `size` is the number of + // iterations. 
+ const rangeOfSizeNode: RangeOfSizeInvocation = { + id: RANGE_OF_SIZE, + type: 'range_of_size', + start: seed, + size: iterations, + }; + + // Iterate node to iterate over the seeds generated by the range of size node + const iterateNode: IterateInvocation = { + id: ITERATE, + type: 'iterate', + }; + + // Noise node without any seed + const noiseNode: NoiseInvocation = { + id: NOISE, + type: 'noise', + }; + + // Adding to the graph + graph.nodes[RANGE_OF_SIZE] = rangeOfSizeNode; + graph.nodes[ITERATE] = iterateNode; + graph.nodes[NOISE] = noiseNode; + + // Connect range of size to iterate + graph.edges.push({ + source: { node_id: RANGE_OF_SIZE, field: 'collection' }, + destination: { + node_id: ITERATE, + field: 'collection', + }, + }); + + // Connect iterate to noise + graph.edges.push({ + source: { + node_id: ITERATE, + field: 'item', + }, + destination: { + node_id: NOISE, + field: 'seed', + }, + }); + + // Connect noise to l2l + graph.edges.push({ + source: { node_id: NOISE, field: 'noise' }, + destination: { + node_id: LATENTS_TO_LATENTS, + field: 'noise', + }, + }); + } + + // Multiple iterations, random seed + if (shouldRandomizeSeed && iterations > 1) { + // Random int node to generate the seed + const randomIntNode: RandomIntInvocation = { + id: RANDOM_INT, + type: 'rand_int', + }; + + // Range of size node to generate `iterations` count of seeds - range of size generates a collection + const rangeOfSizeNode: RangeOfSizeInvocation = { + id: RANGE_OF_SIZE, + type: 'range_of_size', + size: iterations, + }; + + // Iterate node to iterate over the seeds generated by the range of size node + const iterateNode: IterateInvocation = { + id: ITERATE, + type: 'iterate', + }; + + // Noise node without any seed + const noiseNode: NoiseInvocation = { + id: NOISE, + type: 'noise', + width, + height, + }; + + // Adding to the graph + graph.nodes[RANDOM_INT] = randomIntNode; + graph.nodes[RANGE_OF_SIZE] = rangeOfSizeNode; + graph.nodes[ITERATE] = iterateNode; + 
graph.nodes[NOISE] = noiseNode; + + // Connect random int to the start of the range of size so the range starts on the random first seed + graph.edges.push({ + source: { node_id: RANDOM_INT, field: 'a' }, + destination: { node_id: RANGE_OF_SIZE, field: 'start' }, + }); + + // Connect range of size to iterate + graph.edges.push({ + source: { node_id: RANGE_OF_SIZE, field: 'collection' }, + destination: { + node_id: ITERATE, + field: 'collection', + }, + }); + + // Connect iterate to noise + graph.edges.push({ + source: { + node_id: ITERATE, + field: 'item', + }, + destination: { + node_id: NOISE, + field: 'seed', + }, + }); + + // Connect noise to l2l + graph.edges.push({ + source: { node_id: NOISE, field: 'noise' }, + destination: { + node_id: LATENTS_TO_LATENTS, + field: 'noise', + }, + }); + } + + if (shouldFitToWidthHeight) { + // The init image needs to be resized to the specified width and height before being passed to `IMAGE_TO_LATENTS` + + // Create a resize node, explicitly setting its image + const resizeNode: ImageResizeInvocation = { + id: RESIZE, + type: 'img_resize', + image: { + image_name: initialImage.image_name, + image_origin: initialImage.image_origin, + }, + is_intermediate: true, + height, + width, + }; + + graph.nodes[RESIZE] = resizeNode; + + // The `RESIZE` node then passes its image to `IMAGE_TO_LATENTS` + graph.edges.push({ + source: { node_id: RESIZE, field: 'image' }, + destination: { + node_id: IMAGE_TO_LATENTS, + field: 'image', + }, + }); + + // The `RESIZE` node also passes its width and height to `NOISE` + graph.edges.push({ + source: { node_id: RESIZE, field: 'width' }, + destination: { + node_id: NOISE, + field: 'width', + }, + }); + + graph.edges.push({ + source: { node_id: RESIZE, field: 'height' }, + destination: { + node_id: NOISE, + field: 'height', + }, + }); + } else { + // We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly + set(graph.nodes[IMAGE_TO_LATENTS], 'image', { + image_name: 
initialImage.image_name, + image_origin: initialImage.image_origin, + }); + + // Pass the image's dimensions to the `NOISE` node + graph.edges.push({ + source: { node_id: IMAGE_TO_LATENTS, field: 'width' }, + destination: { + node_id: NOISE, + field: 'width', + }, + }); + graph.edges.push({ + source: { node_id: IMAGE_TO_LATENTS, field: 'height' }, + destination: { + node_id: NOISE, + field: 'height', + }, + }); + } return graph; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts index eef7379624..6a700d4813 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts @@ -1,8 +1,9 @@ import { Graph } from 'services/api'; import { v4 as uuidv4 } from 'uuid'; -import { cloneDeep, reduce } from 'lodash-es'; +import { cloneDeep, forEach, omit, reduce, values } from 'lodash-es'; import { RootState } from 'app/store/store'; import { InputFieldValue } from 'features/nodes/types/types'; +import { AnyInvocation } from 'services/events/types'; /** * We need to do special handling for some fields @@ -89,6 +90,24 @@ export const buildNodesGraph = (state: RootState): Graph => { [] ); + /** + * Omit all inputs that have edges connected. + * + * Fixes edge case where the user has connected an input, but also provided an invalid explicit, + * value. + * + * In this edge case, pydantic will invalidate the node based on the invalid explicit value, + * even though the actual value that will be used comes from the connection. + */ + parsedEdges.forEach((edge) => { + const destination_node = parsedNodes[edge.destination.node_id]; + const field = edge.destination.field; + parsedNodes[edge.destination.node_id] = omit( + destination_node, + field + ) as AnyInvocation; + }); + // Assemble! 
const graph = { id: uuidv4(), diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildTextToImageGraph.ts index 51f89e8f74..753ccccff8 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildTextToImageGraph.ts @@ -2,16 +2,23 @@ import { RootState } from 'app/store/store'; import { CompelInvocation, Graph, + IterateInvocation, LatentsToImageInvocation, + NoiseInvocation, + RandomIntInvocation, + RangeOfSizeInvocation, TextToLatentsInvocation, } from 'services/api'; import { NonNullableGraph } from 'features/nodes/types/types'; -import { addNoiseNodes } from '../nodeBuilders/addNoiseNodes'; const POSITIVE_CONDITIONING = 'positive_conditioning'; const NEGATIVE_CONDITIONING = 'negative_conditioning'; const TEXT_TO_LATENTS = 'text_to_latents'; const LATENTS_TO_IMAGE = 'latents_to_image'; +const NOISE = 'noise'; +const RANDOM_INT = 'rand_int'; +const RANGE_OF_SIZE = 'range_of_size'; +const ITERATE = 'iterate'; /** * Builds the Text to Image tab graph. @@ -24,9 +31,14 @@ export const buildTextToImageGraph = (state: RootState): Graph => { cfgScale: cfg_scale, scheduler, steps, + width, + height, + iterations, + seed, + shouldRandomizeSeed, } = state.generation; - let graph: NonNullableGraph = { + const graph: NonNullableGraph = { nodes: {}, edges: [], }; @@ -92,8 +104,209 @@ export const buildTextToImageGraph = (state: RootState): Graph => { }, }); - // Create and add the noise nodes - graph = addNoiseNodes(graph, TEXT_TO_LATENTS, state); + /** + * Now we need to handle iterations and random seeds. There are four possible scenarios: + * - Single iteration, explicit seed + * - Single iteration, random seed + * - Multiple iterations, explicit seed + * - Multiple iterations, random seed + * + * They all have different graphs and connections. 
+ */ + // Single iteration, explicit seed + if (!shouldRandomizeSeed && iterations === 1) { + // Noise node using the explicit seed + const noiseNode: NoiseInvocation = { + id: NOISE, + type: 'noise', + seed: seed, + width, + height, + }; + + graph.nodes[NOISE] = noiseNode; + + // Connect noise to l2l + graph.edges.push({ + source: { node_id: NOISE, field: 'noise' }, + destination: { + node_id: TEXT_TO_LATENTS, + field: 'noise', + }, + }); + } + + // Single iteration, random seed + if (shouldRandomizeSeed && iterations === 1) { + // Random int node to generate the seed + const randomIntNode: RandomIntInvocation = { + id: RANDOM_INT, + type: 'rand_int', + }; + + // Noise node without any seed + const noiseNode: NoiseInvocation = { + id: NOISE, + type: 'noise', + width, + height, + }; + + graph.nodes[RANDOM_INT] = randomIntNode; + graph.nodes[NOISE] = noiseNode; + + // Connect random int to the seed of the noise node + graph.edges.push({ + source: { node_id: RANDOM_INT, field: 'a' }, + destination: { + node_id: NOISE, + field: 'seed', + }, + }); + + // Connect noise to t2l + graph.edges.push({ + source: { node_id: NOISE, field: 'noise' }, + destination: { + node_id: TEXT_TO_LATENTS, + field: 'noise', + }, + }); + } + + // Multiple iterations, explicit seed + if (!shouldRandomizeSeed && iterations > 1) { + // Range of size node to generate `iterations` count of seeds - range of size generates a collection + // of ints from `start` to `start + size`. The `start` is the seed, and the `size` is the number of + // iterations. 
+ const rangeOfSizeNode: RangeOfSizeInvocation = { + id: RANGE_OF_SIZE, + type: 'range_of_size', + start: seed, + size: iterations, + }; + + // Iterate node to iterate over the seeds generated by the range of size node + const iterateNode: IterateInvocation = { + id: ITERATE, + type: 'iterate', + }; + + // Noise node without any seed + const noiseNode: NoiseInvocation = { + id: NOISE, + type: 'noise', + width, + height, + }; + + // Adding to the graph + graph.nodes[RANGE_OF_SIZE] = rangeOfSizeNode; + graph.nodes[ITERATE] = iterateNode; + graph.nodes[NOISE] = noiseNode; + + // Connect range of size to iterate + graph.edges.push({ + source: { node_id: RANGE_OF_SIZE, field: 'collection' }, + destination: { + node_id: ITERATE, + field: 'collection', + }, + }); + + // Connect iterate to noise + graph.edges.push({ + source: { + node_id: ITERATE, + field: 'item', + }, + destination: { + node_id: NOISE, + field: 'seed', + }, + }); + + // Connect noise to t2l + graph.edges.push({ + source: { node_id: NOISE, field: 'noise' }, + destination: { + node_id: TEXT_TO_LATENTS, + field: 'noise', + }, + }); + } + + // Multiple iterations, random seed + if (shouldRandomizeSeed && iterations > 1) { + // Random int node to generate the seed + const randomIntNode: RandomIntInvocation = { + id: RANDOM_INT, + type: 'rand_int', + }; + + // Range of size node to generate `iterations` count of seeds - range of size generates a collection + const rangeOfSizeNode: RangeOfSizeInvocation = { + id: RANGE_OF_SIZE, + type: 'range_of_size', + size: iterations, + }; + + // Iterate node to iterate over the seeds generated by the range of size node + const iterateNode: IterateInvocation = { + id: ITERATE, + type: 'iterate', + }; + + // Noise node without any seed + const noiseNode: NoiseInvocation = { + id: NOISE, + type: 'noise', + width, + height, + }; + + // Adding to the graph + graph.nodes[RANDOM_INT] = randomIntNode; + graph.nodes[RANGE_OF_SIZE] = rangeOfSizeNode; + graph.nodes[ITERATE] = 
iterateNode; + graph.nodes[NOISE] = noiseNode; + + // Connect random int to the start of the range of size so the range starts on the random first seed + graph.edges.push({ + source: { node_id: RANDOM_INT, field: 'a' }, + destination: { node_id: RANGE_OF_SIZE, field: 'start' }, + }); + + // Connect range of size to iterate + graph.edges.push({ + source: { node_id: RANGE_OF_SIZE, field: 'collection' }, + destination: { + node_id: ITERATE, + field: 'collection', + }, + }); + + // Connect iterate to noise + graph.edges.push({ + source: { + node_id: ITERATE, + field: 'item', + }, + destination: { + node_id: NOISE, + field: 'seed', + }, + }); + + // Connect noise to t2l + graph.edges.push({ + source: { node_id: NOISE, field: 'noise' }, + destination: { + node_id: TEXT_TO_LATENTS, + field: 'noise', + }, + }); + } return graph; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/nodeBuilders/addNoiseNodes.ts b/invokeai/frontend/web/src/features/nodes/util/nodeBuilders/addNoiseNodes.ts deleted file mode 100644 index ba3d4d8168..0000000000 --- a/invokeai/frontend/web/src/features/nodes/util/nodeBuilders/addNoiseNodes.ts +++ /dev/null @@ -1,208 +0,0 @@ -import { RootState } from 'app/store/store'; -import { - IterateInvocation, - NoiseInvocation, - RandomIntInvocation, - RangeOfSizeInvocation, -} from 'services/api'; -import { NonNullableGraph } from 'features/nodes/types/types'; -import { cloneDeep } from 'lodash-es'; - -const NOISE = 'noise'; -const RANDOM_INT = 'rand_int'; -const RANGE_OF_SIZE = 'range_of_size'; -const ITERATE = 'iterate'; -/** - * Adds the appropriate noise nodes to a linear UI t2l or l2l graph. - * - * @param graph The graph to add the noise nodes to. - * @param baseNodeId The id of the base node to connect the noise nodes to. - * @param state The app state.. 
- */ -export const addNoiseNodes = ( - graph: NonNullableGraph, - baseNodeId: string, - state: RootState -): NonNullableGraph => { - const graphClone = cloneDeep(graph); - - // Create and add the noise nodes - const { width, height, seed, iterations, shouldRandomizeSeed } = - state.generation; - - // Single iteration, explicit seed - if (!shouldRandomizeSeed && iterations === 1) { - const noiseNode: NoiseInvocation = { - id: NOISE, - type: 'noise', - seed: seed, - width, - height, - }; - - graphClone.nodes[NOISE] = noiseNode; - - // Connect them - graphClone.edges.push({ - source: { node_id: NOISE, field: 'noise' }, - destination: { - node_id: baseNodeId, - field: 'noise', - }, - }); - } - - // Single iteration, random seed - if (shouldRandomizeSeed && iterations === 1) { - // TODO: This assumes the `high` value is the max seed value - const randomIntNode: RandomIntInvocation = { - id: RANDOM_INT, - type: 'rand_int', - }; - - const noiseNode: NoiseInvocation = { - id: NOISE, - type: 'noise', - width, - height, - }; - - graphClone.nodes[RANDOM_INT] = randomIntNode; - graphClone.nodes[NOISE] = noiseNode; - - graphClone.edges.push({ - source: { node_id: RANDOM_INT, field: 'a' }, - destination: { - node_id: NOISE, - field: 'seed', - }, - }); - - graphClone.edges.push({ - source: { node_id: NOISE, field: 'noise' }, - destination: { - node_id: baseNodeId, - field: 'noise', - }, - }); - } - - // Multiple iterations, explicit seed - if (!shouldRandomizeSeed && iterations > 1) { - const rangeOfSizeNode: RangeOfSizeInvocation = { - id: RANGE_OF_SIZE, - type: 'range_of_size', - start: seed, - size: iterations, - }; - - const iterateNode: IterateInvocation = { - id: ITERATE, - type: 'iterate', - }; - - const noiseNode: NoiseInvocation = { - id: NOISE, - type: 'noise', - width, - height, - }; - - graphClone.nodes[RANGE_OF_SIZE] = rangeOfSizeNode; - graphClone.nodes[ITERATE] = iterateNode; - graphClone.nodes[NOISE] = noiseNode; - - graphClone.edges.push({ - source: { node_id: 
RANGE_OF_SIZE, field: 'collection' }, - destination: { - node_id: ITERATE, - field: 'collection', - }, - }); - - graphClone.edges.push({ - source: { - node_id: ITERATE, - field: 'item', - }, - destination: { - node_id: NOISE, - field: 'seed', - }, - }); - - graphClone.edges.push({ - source: { node_id: NOISE, field: 'noise' }, - destination: { - node_id: baseNodeId, - field: 'noise', - }, - }); - } - - // Multiple iterations, random seed - if (shouldRandomizeSeed && iterations > 1) { - // TODO: This assumes the `high` value is the max seed value - const randomIntNode: RandomIntInvocation = { - id: RANDOM_INT, - type: 'rand_int', - }; - - const rangeOfSizeNode: RangeOfSizeInvocation = { - id: RANGE_OF_SIZE, - type: 'range_of_size', - size: iterations, - }; - - const iterateNode: IterateInvocation = { - id: ITERATE, - type: 'iterate', - }; - - const noiseNode: NoiseInvocation = { - id: NOISE, - type: 'noise', - width, - height, - }; - - graphClone.nodes[RANDOM_INT] = randomIntNode; - graphClone.nodes[RANGE_OF_SIZE] = rangeOfSizeNode; - graphClone.nodes[ITERATE] = iterateNode; - graphClone.nodes[NOISE] = noiseNode; - - graphClone.edges.push({ - source: { node_id: RANDOM_INT, field: 'a' }, - destination: { node_id: RANGE_OF_SIZE, field: 'start' }, - }); - - graphClone.edges.push({ - source: { node_id: RANGE_OF_SIZE, field: 'collection' }, - destination: { - node_id: ITERATE, - field: 'collection', - }, - }); - - graphClone.edges.push({ - source: { - node_id: ITERATE, - field: 'item', - }, - destination: { - node_id: NOISE, - field: 'seed', - }, - }); - - graphClone.edges.push({ - source: { node_id: NOISE, field: 'noise' }, - destination: { - node_id: baseNodeId, - field: 'noise', - }, - }); - } - - return graphClone; -}; diff --git a/invokeai/frontend/web/src/features/nodes/util/nodeBuilders/buildImageToImageNode.ts b/invokeai/frontend/web/src/features/nodes/util/nodeBuilders/buildImageToImageNode.ts index 5f00d12a23..558f937837 100644 --- 
a/invokeai/frontend/web/src/features/nodes/util/nodeBuilders/buildImageToImageNode.ts +++ b/invokeai/frontend/web/src/features/nodes/util/nodeBuilders/buildImageToImageNode.ts @@ -58,7 +58,7 @@ export const buildImg2ImgNode = ( imageToImageNode.image = { image_name: initialImage.name, - image_type: initialImage.type, + image_origin: initialImage.type, }; } diff --git a/invokeai/frontend/web/src/features/nodes/util/nodeBuilders/buildInpaintNode.ts b/invokeai/frontend/web/src/features/nodes/util/nodeBuilders/buildInpaintNode.ts index b3f6cca933..593658e536 100644 --- a/invokeai/frontend/web/src/features/nodes/util/nodeBuilders/buildInpaintNode.ts +++ b/invokeai/frontend/web/src/features/nodes/util/nodeBuilders/buildInpaintNode.ts @@ -2,15 +2,12 @@ import { v4 as uuidv4 } from 'uuid'; import { RootState } from 'app/store/store'; import { InpaintInvocation } from 'services/api'; import { O } from 'ts-toolbelt'; -import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; export const buildInpaintNode = ( state: RootState, overrides: O.Partial = {} ): InpaintInvocation => { const nodeId = uuidv4(); - const { generation } = state; - const activeTabName = activeTabNameSelector(state); const { positivePrompt: prompt, @@ -25,8 +22,7 @@ export const buildInpaintNode = ( img2imgStrength: strength, shouldFitToWidthHeight: fit, shouldRandomizeSeed, - initialImage, - } = generation; + } = state.generation; const inpaintNode: InpaintInvocation = { id: nodeId, @@ -42,19 +38,6 @@ export const buildInpaintNode = ( fit, }; - // on Canvas tab, we do not manually specific init image - if (activeTabName !== 'unifiedCanvas') { - if (!initialImage) { - // TODO: handle this more better - throw 'no initial image'; - } - - inpaintNode.image = { - image_name: initialImage.name, - image_type: initialImage.type, - }; - } - if (!shouldRandomizeSeed) { inpaintNode.seed = seed; } diff --git a/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts 
b/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts index ddd19b8749..c77fdeca5e 100644 --- a/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts +++ b/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts @@ -13,7 +13,9 @@ import { buildOutputFieldTemplates, } from './fieldTemplateBuilders'; -const invocationDenylist = ['Graph']; +const RESERVED_FIELD_NAMES = ['id', 'type', 'is_intermediate']; + +const invocationDenylist = ['Graph', 'InvocationMeta']; export const parseSchema = (openAPI: OpenAPIV3.Document) => { // filter out non-invocation schemas, plus some tricky invocations for now @@ -73,7 +75,7 @@ export const parseSchema = (openAPI: OpenAPIV3.Document) => { (inputsAccumulator, property, propertyName) => { if ( // `type` and `id` are not valid inputs/outputs - !['type', 'id'].includes(propertyName) && + !RESERVED_FIELD_NAMES.includes(propertyName) && isSchemaObject(property) ) { const field: InputFieldTemplate | undefined = diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/BoundingBox/ParamBoundingBoxHeight.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/BoundingBox/ParamBoundingBoxHeight.tsx index 75ec70f257..dc83ba8907 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/BoundingBox/ParamBoundingBoxHeight.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/BoundingBox/ParamBoundingBoxHeight.tsx @@ -2,18 +2,22 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAISlider from 'common/components/IAISlider'; -import { canvasSelector } from 'features/canvas/store/canvasSelectors'; +import { + canvasSelector, + isStagingSelector, +} from 'features/canvas/store/canvasSelectors'; import { setBoundingBoxDimensions } from 
'features/canvas/store/canvasSlice'; import { memo } from 'react'; import { useTranslation } from 'react-i18next'; const selector = createSelector( - canvasSelector, - (canvas) => { + [canvasSelector, isStagingSelector], + (canvas, isStaging) => { const { boundingBoxDimensions } = canvas; return { boundingBoxDimensions, + isStaging, }; }, defaultSelectorOptions @@ -21,7 +25,7 @@ const selector = createSelector( const ParamBoundingBoxWidth = () => { const dispatch = useAppDispatch(); - const { boundingBoxDimensions } = useAppSelector(selector); + const { boundingBoxDimensions, isStaging } = useAppSelector(selector); const { t } = useTranslation(); @@ -45,12 +49,13 @@ const ParamBoundingBoxWidth = () => { return ( { + [canvasSelector, isStagingSelector], + (canvas, isStaging) => { const { boundingBoxDimensions } = canvas; return { boundingBoxDimensions, + isStaging, }; }, defaultSelectorOptions @@ -21,7 +25,7 @@ const selector = createSelector( const ParamBoundingBoxWidth = () => { const dispatch = useAppDispatch(); - const { boundingBoxDimensions } = useAppSelector(selector); + const { boundingBoxDimensions, isStaging } = useAppSelector(selector); const { t } = useTranslation(); @@ -45,12 +49,13 @@ const ParamBoundingBoxWidth = () => { return ( { return ( diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx index 2b5db18d93..f4413c4cf6 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx @@ -1,24 +1,33 @@ +import { createSelector } from '@reduxjs/toolkit'; import { Scheduler } from 'app/constants'; -import { RootState } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 
'app/store/util/defaultMemoizeOptions'; import IAICustomSelect from 'common/components/IAICustomSelect'; +import { generationSelector } from 'features/parameters/store/generationSelectors'; import { setScheduler } from 'features/parameters/store/generationSlice'; -import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; +import { uiSelector } from 'features/ui/store/uiSelectors'; import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; +const selector = createSelector( + [uiSelector, generationSelector], + (ui, generation) => { + // TODO: DPMSolverSinglestepScheduler is fixed in https://github.com/huggingface/diffusers/pull/3413 + // but we need to wait for the next release before removing this special handling. + const allSchedulers = ui.schedulers.filter((scheduler) => { + return !['dpmpp_2s'].includes(scheduler); + }); + + return { + scheduler: generation.scheduler, + allSchedulers, + }; + }, + defaultSelectorOptions +); + const ParamScheduler = () => { - const scheduler = useAppSelector( - (state: RootState) => state.generation.scheduler - ); - - const activeTabName = useAppSelector(activeTabNameSelector); - - const schedulers = useAppSelector((state: RootState) => state.ui.schedulers); - - const img2imgSchedulers = schedulers.filter((scheduler) => { - return !['dpmpp_2s'].includes(scheduler); - }); + const { allSchedulers, scheduler } = useAppSelector(selector); const dispatch = useAppDispatch(); const { t } = useTranslation(); @@ -38,11 +47,7 @@ const ParamScheduler = () => { label={t('parameters.scheduler')} selectedItem={scheduler} setSelectedItem={handleChange} - items={ - ['img2img', 'unifiedCanvas'].includes(activeTabName) - ? 
img2imgSchedulers - : schedulers - } + items={allSchedulers} withCheckIcon /> ); diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/ImageToImage/InitialImagePreview.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/ImageToImage/InitialImagePreview.tsx index be40f548e6..cfe1513420 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/ImageToImage/InitialImagePreview.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/ImageToImage/InitialImagePreview.tsx @@ -5,7 +5,6 @@ import { useGetUrl } from 'common/util/getUrl'; import { clearInitialImage } from 'features/parameters/store/generationSlice'; import { DragEvent, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; -import { ImageType } from 'services/api'; import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay'; import { generationSelector } from 'features/parameters/store/generationSelectors'; import { initialImageSelected } from 'features/parameters/store/actions'; @@ -55,9 +54,7 @@ const InitialImagePreview = () => { const handleDrop = useCallback( (e: DragEvent) => { const name = e.dataTransfer.getData('invokeai/imageName'); - const type = e.dataTransfer.getData('invokeai/imageType') as ImageType; - - dispatch(initialImageSelected({ image_name: name, image_type: type })); + dispatch(initialImageSelected(name)); }, [dispatch] ); diff --git a/invokeai/frontend/web/src/features/parameters/hooks/useParameters.ts b/invokeai/frontend/web/src/features/parameters/hooks/useParameters.ts deleted file mode 100644 index 27ae63e5dd..0000000000 --- a/invokeai/frontend/web/src/features/parameters/hooks/useParameters.ts +++ /dev/null @@ -1,151 +0,0 @@ -import { useAppDispatch } from 'app/store/storeHooks'; -import { isFinite, isString } from 'lodash-es'; -import { useCallback } from 'react'; -import { useTranslation } from 'react-i18next'; -import useSetBothPrompts from 
'./usePrompt'; -import { allParametersSet, setSeed } from '../store/generationSlice'; -import { isImageField } from 'services/types/guards'; -import { NUMPY_RAND_MAX } from 'app/constants'; -import { initialImageSelected } from '../store/actions'; -import { setActiveTab } from 'features/ui/store/uiSlice'; -import { useAppToaster } from 'app/components/Toaster'; -import { ImageDTO } from 'services/api'; - -export const useParameters = () => { - const dispatch = useAppDispatch(); - const toaster = useAppToaster(); - const { t } = useTranslation(); - const setBothPrompts = useSetBothPrompts(); - - /** - * Sets prompt with toast - */ - const recallPrompt = useCallback( - (prompt: unknown, negativePrompt?: unknown) => { - if (!isString(prompt) || !isString(negativePrompt)) { - toaster({ - title: t('toast.promptNotSet'), - description: t('toast.promptNotSetDesc'), - status: 'warning', - duration: 2500, - isClosable: true, - }); - return; - } - - setBothPrompts(prompt, negativePrompt); - toaster({ - title: t('toast.promptSet'), - status: 'info', - duration: 2500, - isClosable: true, - }); - }, - [t, toaster, setBothPrompts] - ); - - /** - * Sets seed with toast - */ - const recallSeed = useCallback( - (seed: unknown) => { - const s = Number(seed); - if (!isFinite(s) || (isFinite(s) && !(s >= 0 && s <= NUMPY_RAND_MAX))) { - toaster({ - title: t('toast.seedNotSet'), - description: t('toast.seedNotSetDesc'), - status: 'warning', - duration: 2500, - isClosable: true, - }); - return; - } - - dispatch(setSeed(s)); - toaster({ - title: t('toast.seedSet'), - status: 'info', - duration: 2500, - isClosable: true, - }); - }, - [t, toaster, dispatch] - ); - - /** - * Sets initial image with toast - */ - const recallInitialImage = useCallback( - async (image: unknown) => { - if (!isImageField(image)) { - toaster({ - title: t('toast.initialImageNotSet'), - description: t('toast.initialImageNotSetDesc'), - status: 'warning', - duration: 2500, - isClosable: true, - }); - return; - } - - 
dispatch(initialImageSelected(image)); - toaster({ - title: t('toast.initialImageSet'), - status: 'info', - duration: 2500, - isClosable: true, - }); - }, - [t, toaster, dispatch] - ); - - /** - * Sets image as initial image with toast - */ - const sendToImageToImage = useCallback( - (image: ImageDTO) => { - dispatch(initialImageSelected(image)); - }, - [dispatch] - ); - - const recallAllParameters = useCallback( - (image: ImageDTO | undefined) => { - const type = image?.metadata?.type; - // not sure what this list should be - if (['t2l', 'l2l', 'inpaint'].includes(String(type))) { - dispatch(allParametersSet(image)); - - if (image?.metadata?.type === 'l2l') { - dispatch(setActiveTab('img2img')); - } else if (image?.metadata?.type === 't2l') { - dispatch(setActiveTab('txt2img')); - } - - toaster({ - title: t('toast.parametersSet'), - status: 'success', - duration: 2500, - isClosable: true, - }); - } else { - toaster({ - title: t('toast.parametersNotSet'), - description: t('toast.parametersNotSetDesc'), - status: 'error', - duration: 2500, - isClosable: true, - }); - } - }, - [t, toaster, dispatch] - ); - - return { - recallPrompt, - recallSeed, - recallInitialImage, - sendToImageToImage, - recallAllParameters, - }; -}; diff --git a/invokeai/frontend/web/src/features/parameters/hooks/usePrompt.ts b/invokeai/frontend/web/src/features/parameters/hooks/usePrompt.ts deleted file mode 100644 index 3fee0bcdd8..0000000000 --- a/invokeai/frontend/web/src/features/parameters/hooks/usePrompt.ts +++ /dev/null @@ -1,23 +0,0 @@ -import { getPromptAndNegative } from 'common/util/getPromptAndNegative'; - -import * as InvokeAI from 'app/types/invokeai'; -import promptToString from 'common/util/promptToString'; -import { useAppDispatch } from 'app/store/storeHooks'; -import { setNegativePrompt, setPositivePrompt } from '../store/generationSlice'; -import { useCallback } from 'react'; - -// TECHDEBT: We have two metadata prompt formats and need to handle recalling either of them. 
-// This hook provides a function to do that. -const useSetBothPrompts = () => { - const dispatch = useAppDispatch(); - - return useCallback( - (inputPrompt: InvokeAI.Prompt, negativePrompt: InvokeAI.Prompt) => { - dispatch(setPositivePrompt(inputPrompt)); - dispatch(setNegativePrompt(negativePrompt)); - }, - [dispatch] - ); -}; - -export default useSetBothPrompts; diff --git a/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts b/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts new file mode 100644 index 0000000000..7b7a405867 --- /dev/null +++ b/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts @@ -0,0 +1,348 @@ +import { useAppDispatch } from 'app/store/storeHooks'; +import { useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; +import { + modelSelected, + setCfgScale, + setHeight, + setImg2imgStrength, + setNegativePrompt, + setPositivePrompt, + setScheduler, + setSeed, + setSteps, + setWidth, +} from '../store/generationSlice'; +import { isImageField } from 'services/types/guards'; +import { initialImageSelected } from '../store/actions'; +import { useAppToaster } from 'app/components/Toaster'; +import { ImageDTO } from 'services/api'; +import { + isValidCfgScale, + isValidHeight, + isValidModel, + isValidNegativePrompt, + isValidPositivePrompt, + isValidScheduler, + isValidSeed, + isValidSteps, + isValidStrength, + isValidWidth, +} from '../store/parameterZodSchemas'; + +export const useRecallParameters = () => { + const dispatch = useAppDispatch(); + const toaster = useAppToaster(); + const { t } = useTranslation(); + + const parameterSetToast = useCallback(() => { + toaster({ + title: t('toast.parameterSet'), + status: 'info', + duration: 2500, + isClosable: true, + }); + }, [t, toaster]); + + const parameterNotSetToast = useCallback(() => { + toaster({ + title: t('toast.parameterNotSet'), + status: 'warning', + duration: 2500, + isClosable: true, + }); + }, 
[t, toaster]); + + const allParameterSetToast = useCallback(() => { + toaster({ + title: t('toast.parametersSet'), + status: 'info', + duration: 2500, + isClosable: true, + }); + }, [t, toaster]); + + const allParameterNotSetToast = useCallback(() => { + toaster({ + title: t('toast.parametersNotSet'), + status: 'warning', + duration: 2500, + isClosable: true, + }); + }, [t, toaster]); + + /** + * Recall both prompts with toast + */ + const recallBothPrompts = useCallback( + (positivePrompt: unknown, negativePrompt: unknown) => { + if ( + isValidPositivePrompt(positivePrompt) || + isValidNegativePrompt(negativePrompt) + ) { + if (isValidPositivePrompt(positivePrompt)) { + dispatch(setPositivePrompt(positivePrompt)); + } + if (isValidNegativePrompt(negativePrompt)) { + dispatch(setNegativePrompt(negativePrompt)); + } + parameterSetToast(); + return; + } + parameterNotSetToast(); + }, + [dispatch, parameterSetToast, parameterNotSetToast] + ); + + /** + * Recall positive prompt with toast + */ + const recallPositivePrompt = useCallback( + (positivePrompt: unknown) => { + if (!isValidPositivePrompt(positivePrompt)) { + parameterNotSetToast(); + return; + } + dispatch(setPositivePrompt(positivePrompt)); + parameterSetToast(); + }, + [dispatch, parameterSetToast, parameterNotSetToast] + ); + + /** + * Recall negative prompt with toast + */ + const recallNegativePrompt = useCallback( + (negativePrompt: unknown) => { + if (!isValidNegativePrompt(negativePrompt)) { + parameterNotSetToast(); + return; + } + dispatch(setNegativePrompt(negativePrompt)); + parameterSetToast(); + }, + [dispatch, parameterSetToast, parameterNotSetToast] + ); + + /** + * Recall seed with toast + */ + const recallSeed = useCallback( + (seed: unknown) => { + if (!isValidSeed(seed)) { + parameterNotSetToast(); + return; + } + dispatch(setSeed(seed)); + parameterSetToast(); + }, + [dispatch, parameterSetToast, parameterNotSetToast] + ); + + /** + * Recall CFG scale with toast + */ + const 
recallCfgScale = useCallback( + (cfgScale: unknown) => { + if (!isValidCfgScale(cfgScale)) { + parameterNotSetToast(); + return; + } + dispatch(setCfgScale(cfgScale)); + parameterSetToast(); + }, + [dispatch, parameterSetToast, parameterNotSetToast] + ); + + /** + * Recall model with toast + */ + const recallModel = useCallback( + (model: unknown) => { + if (!isValidModel(model)) { + parameterNotSetToast(); + return; + } + dispatch(modelSelected(model)); + parameterSetToast(); + }, + [dispatch, parameterSetToast, parameterNotSetToast] + ); + + /** + * Recall scheduler with toast + */ + const recallScheduler = useCallback( + (scheduler: unknown) => { + if (!isValidScheduler(scheduler)) { + parameterNotSetToast(); + return; + } + dispatch(setScheduler(scheduler)); + parameterSetToast(); + }, + [dispatch, parameterSetToast, parameterNotSetToast] + ); + + /** + * Recall steps with toast + */ + const recallSteps = useCallback( + (steps: unknown) => { + if (!isValidSteps(steps)) { + parameterNotSetToast(); + return; + } + dispatch(setSteps(steps)); + parameterSetToast(); + }, + [dispatch, parameterSetToast, parameterNotSetToast] + ); + + /** + * Recall width with toast + */ + const recallWidth = useCallback( + (width: unknown) => { + if (!isValidWidth(width)) { + parameterNotSetToast(); + return; + } + dispatch(setWidth(width)); + parameterSetToast(); + }, + [dispatch, parameterSetToast, parameterNotSetToast] + ); + + /** + * Recall height with toast + */ + const recallHeight = useCallback( + (height: unknown) => { + if (!isValidHeight(height)) { + parameterNotSetToast(); + return; + } + dispatch(setHeight(height)); + parameterSetToast(); + }, + [dispatch, parameterSetToast, parameterNotSetToast] + ); + + /** + * Recall strength with toast + */ + const recallStrength = useCallback( + (strength: unknown) => { + if (!isValidStrength(strength)) { + parameterNotSetToast(); + return; + } + dispatch(setImg2imgStrength(strength)); + parameterSetToast(); + }, + [dispatch, 
parameterSetToast, parameterNotSetToast] + ); + + /** + * Sets initial image with toast + */ + const recallInitialImage = useCallback( + async (image: unknown) => { + if (!isImageField(image)) { + parameterNotSetToast(); + return; + } + dispatch(initialImageSelected(image.image_name)); + parameterSetToast(); + }, + [dispatch, parameterSetToast, parameterNotSetToast] + ); + + /** + * Sets image as initial image with toast + */ + const sendToImageToImage = useCallback( + (image: ImageDTO) => { + dispatch(initialImageSelected(image)); + }, + [dispatch] + ); + + const recallAllParameters = useCallback( + (image: ImageDTO | undefined) => { + if (!image || !image.metadata) { + allParameterNotSetToast(); + return; + } + const { + cfg_scale, + height, + model, + positive_conditioning, + negative_conditioning, + scheduler, + seed, + steps, + width, + strength, + clip, + extra, + latents, + unet, + vae, + } = image.metadata; + + if (isValidCfgScale(cfg_scale)) { + dispatch(setCfgScale(cfg_scale)); + } + if (isValidModel(model)) { + dispatch(modelSelected(model)); + } + if (isValidPositivePrompt(positive_conditioning)) { + dispatch(setPositivePrompt(positive_conditioning)); + } + if (isValidNegativePrompt(negative_conditioning)) { + dispatch(setNegativePrompt(negative_conditioning)); + } + if (isValidScheduler(scheduler)) { + dispatch(setScheduler(scheduler)); + } + if (isValidSeed(seed)) { + dispatch(setSeed(seed)); + } + if (isValidSteps(steps)) { + dispatch(setSteps(steps)); + } + if (isValidWidth(width)) { + dispatch(setWidth(width)); + } + if (isValidHeight(height)) { + dispatch(setHeight(height)); + } + if (isValidStrength(strength)) { + dispatch(setImg2imgStrength(strength)); + } + + allParameterSetToast(); + }, + [allParameterNotSetToast, allParameterSetToast, dispatch] + ); + + return { + recallBothPrompts, + recallPositivePrompt, + recallNegativePrompt, + recallSeed, + recallInitialImage, + recallCfgScale, + recallModel, + recallScheduler, + recallSteps, + 
recallWidth, + recallHeight, + recallStrength, + recallAllParameters, + sendToImageToImage, + }; +}; diff --git a/invokeai/frontend/web/src/features/parameters/store/actions.ts b/invokeai/frontend/web/src/features/parameters/store/actions.ts index 853597c809..e9b90134e1 100644 --- a/invokeai/frontend/web/src/features/parameters/store/actions.ts +++ b/invokeai/frontend/web/src/features/parameters/store/actions.ts @@ -1,10 +1,10 @@ import { createAction } from '@reduxjs/toolkit'; import { isObject } from 'lodash-es'; -import { ImageDTO, ImageType } from 'services/api'; +import { ImageDTO, ResourceOrigin } from 'services/api'; -export type ImageNameAndType = { +export type ImageNameAndOrigin = { image_name: string; - image_type: ImageType; + image_origin: ResourceOrigin; }; export const isImageDTO = (image: any): image is ImageDTO => { @@ -13,8 +13,8 @@ export const isImageDTO = (image: any): image is ImageDTO => { isObject(image) && 'image_name' in image && image?.image_name !== undefined && - 'image_type' in image && - image?.image_type !== undefined && + 'image_origin' in image && + image?.image_origin !== undefined && 'image_url' in image && image?.image_url !== undefined && 'thumbnail_url' in image && @@ -26,6 +26,6 @@ export const isImageDTO = (image: any): image is ImageDTO => { ); }; -export const initialImageSelected = createAction< - ImageDTO | ImageNameAndType | undefined ->('generation/initialImageSelected'); +export const initialImageSelected = createAction( + 'generation/initialImageSelected' +); diff --git a/invokeai/frontend/web/src/features/parameters/store/generationSelectors.ts b/invokeai/frontend/web/src/features/parameters/store/generationSelectors.ts index dbf5eec791..b7322740ef 100644 --- a/invokeai/frontend/web/src/features/parameters/store/generationSelectors.ts +++ b/invokeai/frontend/web/src/features/parameters/store/generationSelectors.ts @@ -1,34 +1,3 @@ -import { createSelector } from '@reduxjs/toolkit'; import { RootState } from 
'app/store/store'; -import { selectResultsById } from 'features/gallery/store/resultsSlice'; -import { selectUploadsById } from 'features/gallery/store/uploadsSlice'; -import { isEqual } from 'lodash-es'; export const generationSelector = (state: RootState) => state.generation; - -export const mayGenerateMultipleImagesSelector = createSelector( - generationSelector, - ({ shouldRandomizeSeed, shouldGenerateVariations }) => { - return shouldRandomizeSeed || shouldGenerateVariations; - }, - { - memoizeOptions: { - resultEqualityCheck: isEqual, - }, - } -); - -export const initialImageSelector = createSelector( - [(state: RootState) => state, generationSelector], - (state, generation) => { - const { initialImage } = generation; - - if (initialImage?.type === 'results') { - return selectResultsById(state, initialImage.name); - } - - if (initialImage?.type === 'uploads') { - return selectUploadsById(state, initialImage.name); - } - } -); diff --git a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts index 849f848ff3..6420950e4a 100644 --- a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts +++ b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts @@ -1,43 +1,53 @@ import type { PayloadAction } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit'; -import * as InvokeAI from 'app/types/invokeai'; -import promptToString from 'common/util/promptToString'; -import { clamp, sample } from 'lodash-es'; -import { setAllParametersReducer } from './setAllParametersReducer'; +import { clamp, sortBy } from 'lodash-es'; import { receivedModels } from 'services/thunks/model'; import { Scheduler } from 'app/constants'; import { ImageDTO } from 'services/api'; +import { configChanged } from 'features/system/store/configSlice'; +import { + CfgScaleParam, + HeightParam, + ModelParam, + NegativePromptParam, + PositivePromptParam, + 
SchedulerParam, + SeedParam, + StepsParam, + StrengthParam, + WidthParam, +} from './parameterZodSchemas'; export interface GenerationState { - cfgScale: number; - height: number; - img2imgStrength: number; + cfgScale: CfgScaleParam; + height: HeightParam; + img2imgStrength: StrengthParam; infillMethod: string; initialImage?: ImageDTO; iterations: number; perlin: number; - positivePrompt: string; - negativePrompt: string; - scheduler: Scheduler; + positivePrompt: PositivePromptParam; + negativePrompt: NegativePromptParam; + scheduler: SchedulerParam; seamBlur: number; seamSize: number; seamSteps: number; seamStrength: number; - seed: number; + seed: SeedParam; seedWeights: string; shouldFitToWidthHeight: boolean; shouldGenerateVariations: boolean; shouldRandomizeSeed: boolean; shouldUseNoiseSettings: boolean; - steps: number; + steps: StepsParam; threshold: number; tileSize: number; variationAmount: number; - width: number; + width: WidthParam; shouldUseSymmetry: boolean; horizontalSymmetrySteps: number; verticalSymmetrySteps: number; - model: string; + model: ModelParam; shouldUseSeamless: boolean; seamlessXAxis: boolean; seamlessYAxis: boolean; @@ -83,27 +93,11 @@ export const generationSlice = createSlice({ name: 'generation', initialState, reducers: { - setPositivePrompt: ( - state, - action: PayloadAction - ) => { - const newPrompt = action.payload; - if (typeof newPrompt === 'string') { - state.positivePrompt = newPrompt; - } else { - state.positivePrompt = promptToString(newPrompt); - } + setPositivePrompt: (state, action: PayloadAction) => { + state.positivePrompt = action.payload; }, - setNegativePrompt: ( - state, - action: PayloadAction - ) => { - const newPrompt = action.payload; - if (typeof newPrompt === 'string') { - state.negativePrompt = newPrompt; - } else { - state.negativePrompt = promptToString(newPrompt); - } + setNegativePrompt: (state, action: PayloadAction) => { + state.negativePrompt = action.payload; }, setIterations: (state, action: 
PayloadAction) => { state.iterations = action.payload; @@ -174,7 +168,6 @@ export const generationSlice = createSlice({ state.shouldGenerateVariations = true; state.variationAmount = 0; }, - allParametersSet: setAllParametersReducer, resetParametersState: (state) => { return { ...state, @@ -227,10 +220,15 @@ export const generationSlice = createSlice({ extraReducers: (builder) => { builder.addCase(receivedModels.fulfilled, (state, action) => { if (!state.model) { - const randomModel = sample(action.payload); - if (randomModel) { - state.model = randomModel.name; - } + const firstModel = sortBy(action.payload, 'name')[0]; + state.model = firstModel.name; + } + }); + + builder.addCase(configChanged, (state, action) => { + const defaultModel = action.payload.sd?.defaultModel; + if (defaultModel && !state.model) { + state.model = defaultModel; } }); }, @@ -273,7 +271,6 @@ export const { setSeamless, setSeamlessXAxis, setSeamlessYAxis, - allParametersSet, } = generationSlice.actions; export default generationSlice.reducer; diff --git a/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts b/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts new file mode 100644 index 0000000000..b99e57bfbb --- /dev/null +++ b/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts @@ -0,0 +1,156 @@ +import { NUMPY_RAND_MAX, SCHEDULERS } from 'app/constants'; +import { z } from 'zod'; + +/** + * These zod schemas should match the pydantic node schemas. + * + * Parameters only need schemas if we want to recall them from metadata. 
+ * + * Each parameter needs: + * - a zod schema + * - a type alias, inferred from the zod schema + * - a combo validation/type guard function, which returns true if the value is valid + */ + +/** + * Zod schema for positive prompt parameter + */ +export const zPositivePrompt = z.string(); +/** + * Type alias for positive prompt parameter, inferred from its zod schema + */ +export type PositivePromptParam = z.infer; +/** + * Validates/type-guards a value as a positive prompt parameter + */ +export const isValidPositivePrompt = ( + val: unknown +): val is PositivePromptParam => zPositivePrompt.safeParse(val).success; + +/** + * Zod schema for negative prompt parameter + */ +export const zNegativePrompt = z.string(); +/** + * Type alias for negative prompt parameter, inferred from its zod schema + */ +export type NegativePromptParam = z.infer; +/** + * Validates/type-guards a value as a negative prompt parameter + */ +export const isValidNegativePrompt = ( + val: unknown +): val is NegativePromptParam => zNegativePrompt.safeParse(val).success; + +/** + * Zod schema for steps parameter + */ +export const zSteps = z.number().int().min(1); +/** + * Type alias for steps parameter, inferred from its zod schema + */ +export type StepsParam = z.infer; +/** + * Validates/type-guards a value as a steps parameter + */ +export const isValidSteps = (val: unknown): val is StepsParam => + zSteps.safeParse(val).success; + +/** + * Zod schema for CFG scale parameter + */ +export const zCfgScale = z.number().min(1); +/** + * Type alias for CFG scale parameter, inferred from its zod schema + */ +export type CfgScaleParam = z.infer; +/** + * Validates/type-guards a value as a CFG scale parameter + */ +export const isValidCfgScale = (val: unknown): val is CfgScaleParam => + zCfgScale.safeParse(val).success; + +/** + * Zod schema for scheduler parameter + */ +export const zScheduler = z.enum(SCHEDULERS); +/** + * Type alias for scheduler parameter, inferred from its zod schema + */ 
+export type SchedulerParam = z.infer; +/** + * Validates/type-guards a value as a scheduler parameter + */ +export const isValidScheduler = (val: unknown): val is SchedulerParam => + zScheduler.safeParse(val).success; + +/** + * Zod schema for seed parameter + */ +export const zSeed = z.number().int().min(0).max(NUMPY_RAND_MAX); +/** + * Type alias for seed parameter, inferred from its zod schema + */ +export type SeedParam = z.infer; +/** + * Validates/type-guards a value as a seed parameter + */ +export const isValidSeed = (val: unknown): val is SeedParam => + zSeed.safeParse(val).success; + +/** + * Zod schema for width parameter + */ +export const zWidth = z.number().multipleOf(8).min(64); +/** + * Type alias for width parameter, inferred from its zod schema + */ +export type WidthParam = z.infer; +/** + * Validates/type-guards a value as a width parameter + */ +export const isValidWidth = (val: unknown): val is WidthParam => + zWidth.safeParse(val).success; + +/** + * Zod schema for height parameter + */ +export const zHeight = z.number().multipleOf(8).min(64); +/** + * Type alias for height parameter, inferred from its zod schema + */ +export type HeightParam = z.infer; +/** + * Validates/type-guards a value as a height parameter + */ +export const isValidHeight = (val: unknown): val is HeightParam => + zHeight.safeParse(val).success; + +/** + * Zod schema for model parameter + * TODO: Make this a dynamically generated enum? 
+ */ +export const zModel = z.string(); +/** + * Type alias for model parameter, inferred from its zod schema + */ +export type ModelParam = z.infer; +/** + * Validates/type-guards a value as a model parameter + */ +export const isValidModel = (val: unknown): val is ModelParam => + zModel.safeParse(val).success; + +/** + * Zod schema for l2l strength parameter + */ +export const zStrength = z.number().min(0).max(1); +/** + * Type alias for l2l strength parameter, inferred from its zod schema + */ +export type StrengthParam = z.infer; +/** + * Validates/type-guards a value as a l2l strength parameter + */ +export const isValidStrength = (val: unknown): val is StrengthParam => + zStrength.safeParse(val).success; diff --git a/invokeai/frontend/web/src/features/parameters/store/setAllParametersReducer.ts b/invokeai/frontend/web/src/features/parameters/store/setAllParametersReducer.ts deleted file mode 100644 index 8f06c7d0ef..0000000000 --- a/invokeai/frontend/web/src/features/parameters/store/setAllParametersReducer.ts +++ /dev/null @@ -1,77 +0,0 @@ -import { Draft, PayloadAction } from '@reduxjs/toolkit'; -import { GenerationState } from './generationSlice'; -import { ImageDTO, ImageToImageInvocation } from 'services/api'; -import { isScheduler } from 'app/constants'; - -export const setAllParametersReducer = ( - state: Draft, - action: PayloadAction -) => { - const metadata = action.payload?.metadata; - - if (!metadata) { - return; - } - - // not sure what this list should be - if ( - metadata.type === 't2l' || - metadata.type === 'l2l' || - metadata.type === 'inpaint' - ) { - const { - cfg_scale, - height, - model, - positive_conditioning, - negative_conditioning, - scheduler, - seed, - steps, - width, - } = metadata; - - if (cfg_scale !== undefined) { - state.cfgScale = Number(cfg_scale); - } - if (height !== undefined) { - state.height = Number(height); - } - if (model !== undefined) { - state.model = String(model); - } - if (positive_conditioning !== undefined) 
{ - state.positivePrompt = String(positive_conditioning); - } - if (negative_conditioning !== undefined) { - state.negativePrompt = String(negative_conditioning); - } - if (scheduler !== undefined) { - const schedulerString = String(scheduler); - if (isScheduler(schedulerString)) { - state.scheduler = schedulerString; - } - } - if (seed !== undefined) { - state.seed = Number(seed); - state.shouldRandomizeSeed = false; - } - if (steps !== undefined) { - state.steps = Number(steps); - } - if (width !== undefined) { - state.width = Number(width); - } - } - - if (metadata.type === 'l2l') { - const { fit, image } = metadata as ImageToImageInvocation; - - if (fit !== undefined) { - state.shouldFitToWidthHeight = Boolean(fit); - } - // if (image !== undefined) { - // state.initialImage = image; - // } - } -}; diff --git a/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx b/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx index 520e30b60a..be4be8ceaa 100644 --- a/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx +++ b/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx @@ -4,19 +4,33 @@ import { isEqual } from 'lodash-es'; import { useTranslation } from 'react-i18next'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { selectModelsById, selectModelsIds } from '../store/modelSlice'; +import { + selectModelsAll, + selectModelsById, + selectModelsIds, +} from '../store/modelSlice'; import { RootState } from 'app/store/store'; import { modelSelected } from 'features/parameters/store/generationSlice'; import { generationSelector } from 'features/parameters/store/generationSelectors'; -import IAICustomSelect from 'common/components/IAICustomSelect'; +import IAICustomSelect, { + ItemTooltips, +} from 'common/components/IAICustomSelect'; const selector = createSelector( [(state: RootState) => state, generationSelector], (state, generation) => { const selectedModel = 
selectModelsById(state, generation.model); const allModelNames = selectModelsIds(state).map((id) => String(id)); + const allModelTooltips = selectModelsAll(state).reduce( + (allModelTooltips, model) => { + allModelTooltips[model.name] = model.description ?? ''; + return allModelTooltips; + }, + {} as ItemTooltips + ); return { allModelNames, + allModelTooltips, selectedModel, }; }, @@ -30,7 +44,8 @@ const selector = createSelector( const ModelSelect = () => { const dispatch = useAppDispatch(); const { t } = useTranslation(); - const { allModelNames, selectedModel } = useAppSelector(selector); + const { allModelNames, allModelTooltips, selectedModel } = + useAppSelector(selector); const handleChangeModel = useCallback( (v: string | null | undefined) => { if (!v) { @@ -46,6 +61,7 @@ const ModelSelect = () => { label={t('modelManager.model')} tooltip={selectedModel?.description} items={allModelNames} + itemTooltips={allModelTooltips} selectedItem={selectedModel?.name ?? ''} setSelectedItem={handleChangeModel} withCheckIcon={true} diff --git a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsModal.tsx b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsModal.tsx index 54556124c9..4cfe35081b 100644 --- a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsModal.tsx +++ b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsModal.tsx @@ -35,7 +35,13 @@ import { } from 'features/ui/store/uiSlice'; import { UIState } from 'features/ui/store/uiTypes'; import { isEqual } from 'lodash-es'; -import { ChangeEvent, cloneElement, ReactElement, useCallback } from 'react'; +import { + ChangeEvent, + cloneElement, + ReactElement, + useCallback, + useEffect, +} from 'react'; import { useTranslation } from 'react-i18next'; import { VALID_LOG_LEVELS } from 'app/logging/useLogger'; import { LogLevelName } from 'roarr'; @@ -85,15 +91,33 @@ const modalSectionStyles: ChakraProps['sx'] = { 
borderRadius: 'base', }; +type ConfigOptions = { + shouldShowDeveloperSettings: boolean; + shouldShowResetWebUiText: boolean; + shouldShowBetaLayout: boolean; +}; + type SettingsModalProps = { /* The button to open the Settings Modal */ children: ReactElement; + config?: ConfigOptions; }; -const SettingsModal = ({ children }: SettingsModalProps) => { +const SettingsModal = ({ children, config }: SettingsModalProps) => { const dispatch = useAppDispatch(); const { t } = useTranslation(); + const shouldShowBetaLayout = config?.shouldShowBetaLayout ?? true; + const shouldShowDeveloperSettings = + config?.shouldShowDeveloperSettings ?? true; + const shouldShowResetWebUiText = config?.shouldShowResetWebUiText ?? true; + + useEffect(() => { + if (!shouldShowDeveloperSettings) { + dispatch(shouldLogToConsoleChanged(false)); + } + }, [shouldShowDeveloperSettings, dispatch]); + const { isOpen: isSettingsModalOpen, onOpen: onSettingsModalOpen, @@ -189,13 +213,15 @@ const SettingsModal = ({ children }: SettingsModalProps) => { dispatch(setShouldDisplayGuides(e.target.checked)) } /> - ) => - dispatch(setShouldUseCanvasBetaLayout(e.target.checked)) - } - /> + {shouldShowBetaLayout && ( + ) => + dispatch(setShouldUseCanvasBetaLayout(e.target.checked)) + } + /> + )} { /> - - {t('settings.developer')} - - - ) => - dispatch(setEnableImageDebugging(e.target.checked)) - } - /> - + {shouldShowDeveloperSettings && ( + + {t('settings.developer')} + + + ) => + dispatch(setEnableImageDebugging(e.target.checked)) + } + /> + + )} {t('settings.resetWebUI')} {t('settings.resetWebUI')} - {t('settings.resetWebUIDesc1')} - {t('settings.resetWebUIDesc2')} + {shouldShowResetWebUiText && ( + <> + {t('settings.resetWebUIDesc1')} + {t('settings.resetWebUIDesc2')} + + )} diff --git a/invokeai/frontend/web/src/features/system/store/actions.ts b/invokeai/frontend/web/src/features/system/store/actions.ts new file mode 100644 index 0000000000..66181bc803 --- /dev/null +++ 
b/invokeai/frontend/web/src/features/system/store/actions.ts @@ -0,0 +1,3 @@ +import { createAction } from '@reduxjs/toolkit'; + +export const sessionReadyToInvoke = createAction('system/sessionReadyToInvoke'); diff --git a/invokeai/frontend/web/src/features/system/store/sessionSlice.ts b/invokeai/frontend/web/src/features/system/store/sessionSlice.ts new file mode 100644 index 0000000000..40d59c7baa --- /dev/null +++ b/invokeai/frontend/web/src/features/system/store/sessionSlice.ts @@ -0,0 +1,62 @@ +// TODO: split system slice inot this + +// import type { PayloadAction } from '@reduxjs/toolkit'; +// import { createSlice } from '@reduxjs/toolkit'; +// import { socketSubscribed, socketUnsubscribed } from 'services/events/actions'; + +// export type SessionState = { +// /** +// * The current socket session id +// */ +// sessionId: string; +// /** +// * Whether the current session is a canvas session. Needed to manage the staging area. +// */ +// isCanvasSession: boolean; +// /** +// * When a session is canceled, its ID is stored here until a new session is created. 
+// */ +// canceledSessionId: string; +// }; + +// export const initialSessionState: SessionState = { +// sessionId: '', +// isCanvasSession: false, +// canceledSessionId: '', +// }; + +// export const sessionSlice = createSlice({ +// name: 'session', +// initialState: initialSessionState, +// reducers: { +// sessionIdChanged: (state, action: PayloadAction) => { +// state.sessionId = action.payload; +// }, +// isCanvasSessionChanged: (state, action: PayloadAction) => { +// state.isCanvasSession = action.payload; +// }, +// }, +// extraReducers: (builder) => { +// /** +// * Socket Subscribed +// */ +// builder.addCase(socketSubscribed, (state, action) => { +// state.sessionId = action.payload.sessionId; +// state.canceledSessionId = ''; +// }); + +// /** +// * Socket Unsubscribed +// */ +// builder.addCase(socketUnsubscribed, (state) => { +// state.sessionId = ''; +// }); +// }, +// }); + +// export const { sessionIdChanged, isCanvasSessionChanged } = +// sessionSlice.actions; + +// export default sessionSlice.reducer; + +export default {}; diff --git a/invokeai/frontend/web/src/features/system/store/systemSlice.ts b/invokeai/frontend/web/src/features/system/store/systemSlice.ts index 7331fcdba9..6bc8d7106a 100644 --- a/invokeai/frontend/web/src/features/system/store/systemSlice.ts +++ b/invokeai/frontend/web/src/features/system/store/systemSlice.ts @@ -1,22 +1,15 @@ import { UseToastOptions } from '@chakra-ui/react'; -import type { PayloadAction } from '@reduxjs/toolkit'; +import { PayloadAction, isAnyOf } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit'; import * as InvokeAI from 'app/types/invokeai'; -import { - generatorProgress, - graphExecutionStateComplete, - invocationComplete, - invocationError, - invocationStarted, - socketConnected, - socketDisconnected, - socketSubscribed, - socketUnsubscribed, -} from 'services/events/actions'; import { ProgressImage } from 'services/events/types'; import { makeToast } from 
'../../../app/components/Toaster'; -import { sessionCanceled, sessionInvoked } from 'services/thunks/session'; +import { + sessionCanceled, + sessionCreated, + sessionInvoked, +} from 'services/thunks/session'; import { receivedModels } from 'services/thunks/model'; import { parsedOpenAPISchema } from 'features/nodes/store/nodesSlice'; import { LogLevelName } from 'roarr'; @@ -26,6 +19,17 @@ import { t } from 'i18next'; import { userInvoked } from 'app/store/actions'; import { LANGUAGES } from '../components/LanguagePicker'; import { imageUploaded } from 'services/thunks/image'; +import { + appSocketConnected, + appSocketDisconnected, + appSocketGeneratorProgress, + appSocketGraphExecutionStateComplete, + appSocketInvocationComplete, + appSocketInvocationError, + appSocketInvocationStarted, + appSocketSubscribed, + appSocketUnsubscribed, +} from 'services/events/actions'; export type CancelStrategy = 'immediate' | 'scheduled'; @@ -215,12 +219,15 @@ export const systemSlice = createSlice({ languageChanged: (state, action: PayloadAction) => { state.language = action.payload; }, + progressImageSet(state, action: PayloadAction) { + state.progressImage = action.payload; + }, }, extraReducers(builder) { /** * Socket Subscribed */ - builder.addCase(socketSubscribed, (state, action) => { + builder.addCase(appSocketSubscribed, (state, action) => { state.sessionId = action.payload.sessionId; state.canceledSession = ''; }); @@ -228,14 +235,14 @@ export const systemSlice = createSlice({ /** * Socket Unsubscribed */ - builder.addCase(socketUnsubscribed, (state) => { + builder.addCase(appSocketUnsubscribed, (state) => { state.sessionId = null; }); /** * Socket Connected */ - builder.addCase(socketConnected, (state) => { + builder.addCase(appSocketConnected, (state) => { state.isConnected = true; state.isCancelable = true; state.isProcessing = false; @@ -250,7 +257,7 @@ export const systemSlice = createSlice({ /** * Socket Disconnected */ - builder.addCase(socketDisconnected, 
(state) => { + builder.addCase(appSocketDisconnected, (state) => { state.isConnected = false; state.isProcessing = false; state.isCancelable = true; @@ -265,7 +272,7 @@ export const systemSlice = createSlice({ /** * Invocation Started */ - builder.addCase(invocationStarted, (state) => { + builder.addCase(appSocketInvocationStarted, (state) => { state.isCancelable = true; state.isProcessing = true; state.currentStatusHasSteps = false; @@ -279,7 +286,7 @@ export const systemSlice = createSlice({ /** * Generator Progress */ - builder.addCase(generatorProgress, (state, action) => { + builder.addCase(appSocketGeneratorProgress, (state, action) => { const { step, total_steps, progress_image } = action.payload.data; state.isProcessing = true; @@ -296,7 +303,7 @@ export const systemSlice = createSlice({ /** * Invocation Complete */ - builder.addCase(invocationComplete, (state, action) => { + builder.addCase(appSocketInvocationComplete, (state, action) => { const { data } = action.payload; // state.currentIteration = 0; @@ -305,7 +312,6 @@ export const systemSlice = createSlice({ state.currentStep = 0; state.totalSteps = 0; state.statusTranslationKey = 'common.statusProcessingComplete'; - state.progressImage = null; if (state.canceledSession === data.graph_execution_state_id) { state.isProcessing = false; @@ -316,7 +322,7 @@ export const systemSlice = createSlice({ /** * Invocation Error */ - builder.addCase(invocationError, (state) => { + builder.addCase(appSocketInvocationError, (state) => { state.isProcessing = false; state.isCancelable = true; // state.currentIteration = 0; @@ -333,7 +339,20 @@ export const systemSlice = createSlice({ }); /** - * Session Invoked - PENDING + * Graph Execution State Complete + */ + builder.addCase(appSocketGraphExecutionStateComplete, (state) => { + state.isProcessing = false; + state.isCancelable = false; + state.isCancelScheduled = false; + state.currentStep = 0; + state.totalSteps = 0; + state.statusTranslationKey = 
'common.statusConnected'; + state.progressImage = null; + }); + + /** + * User Invoked */ builder.addCase(userInvoked, (state) => { @@ -343,15 +362,8 @@ export const systemSlice = createSlice({ state.statusTranslationKey = 'common.statusPreparing'; }); - builder.addCase(sessionInvoked.rejected, (state, action) => { - const error = action.payload as string | undefined; - state.toastQueue.push( - makeToast({ title: error || t('toast.serverError'), status: 'error' }) - ); - }); - /** - * Session Canceled + * Session Canceled - FULFILLED */ builder.addCase(sessionCanceled.fulfilled, (state, action) => { state.canceledSession = action.meta.arg.sessionId; @@ -368,18 +380,6 @@ export const systemSlice = createSlice({ ); }); - /** - * Session Canceled - */ - builder.addCase(graphExecutionStateComplete, (state) => { - state.isProcessing = false; - state.isCancelable = false; - state.isCancelScheduled = false; - state.currentStep = 0; - state.totalSteps = 0; - state.statusTranslationKey = 'common.statusConnected'; - }); - /** * Received available models from the backend */ @@ -414,6 +414,26 @@ export const systemSlice = createSlice({ builder.addCase(imageUploaded.fulfilled, (state) => { state.isUploading = false; }); + + // *** Matchers - must be after all cases *** + + /** + * Session Invoked - REJECTED + * Session Created - REJECTED + */ + builder.addMatcher(isAnySessionRejected, (state, action) => { + state.isProcessing = false; + state.isCancelable = false; + state.isCancelScheduled = false; + state.currentStep = 0; + state.totalSteps = 0; + state.statusTranslationKey = 'common.statusConnected'; + state.progressImage = null; + + state.toastQueue.push( + makeToast({ title: t('toast.serverError'), status: 'error' }) + ); + }); }, }); @@ -438,6 +458,12 @@ export const { isPersistedChanged, shouldAntialiasProgressImageChanged, languageChanged, + progressImageSet, } = systemSlice.actions; export default systemSlice.reducer; + +const isAnySessionRejected = isAnyOf( + 
sessionCreated.rejected, + sessionInvoked.rejected +); diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasCoreParameters.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasCoreParameters.tsx index f2529e5529..1b6b61f018 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasCoreParameters.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasCoreParameters.tsx @@ -7,11 +7,11 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import ParamIterations from 'features/parameters/components/Parameters/Core/ParamIterations'; import ParamSteps from 'features/parameters/components/Parameters/Core/ParamSteps'; import ParamCFGScale from 'features/parameters/components/Parameters/Core/ParamCFGScale'; -import ParamWidth from 'features/parameters/components/Parameters/Core/ParamWidth'; -import ParamHeight from 'features/parameters/components/Parameters/Core/ParamHeight'; import ImageToImageStrength from 'features/parameters/components/Parameters/ImageToImage/ImageToImageStrength'; import ImageToImageFit from 'features/parameters/components/Parameters/ImageToImage/ImageToImageFit'; import ParamSchedulerAndModel from 'features/parameters/components/Parameters/Core/ParamSchedulerAndModel'; +import ParamBoundingBoxWidth from 'features/parameters/components/Parameters/Canvas/BoundingBox/ParamBoundingBoxWidth'; +import ParamBoundingBoxHeight from 'features/parameters/components/Parameters/Canvas/BoundingBox/ParamBoundingBoxHeight'; const selector = createSelector( uiSelector, @@ -41,8 +41,8 @@ const UnifiedCanvasCoreParameters = () => { - - + + @@ -55,8 +55,8 @@ const UnifiedCanvasCoreParameters = () => { - - + + )} diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasParameters.tsx 
b/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasParameters.tsx index 4aa68ad56a..c4501ffc44 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasParameters.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasParameters.tsx @@ -2,7 +2,6 @@ import ProcessButtons from 'features/parameters/components/ProcessButtons/Proces import ParamSeedCollapse from 'features/parameters/components/Parameters/Seed/ParamSeedCollapse'; import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse'; import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse'; -import ParamBoundingBoxCollapse from 'features/parameters/components/Parameters/Canvas/BoundingBox/ParamBoundingBoxCollapse'; import ParamInfillAndScalingCollapse from 'features/parameters/components/Parameters/Canvas/InfillAndScaling/ParamInfillAndScalingCollapse'; import ParamSeamCorrectionCollapse from 'features/parameters/components/Parameters/Canvas/SeamCorrection/ParamSeamCorrectionCollapse'; import UnifiedCanvasCoreParameters from './UnifiedCanvasCoreParameters'; @@ -20,7 +19,6 @@ const UnifiedCanvasParameters = () => { - diff --git a/invokeai/frontend/web/src/features/ui/store/uiSlice.ts b/invokeai/frontend/web/src/features/ui/store/uiSlice.ts index 4893bb3bf6..65a48bc92c 100644 --- a/invokeai/frontend/web/src/features/ui/store/uiSlice.ts +++ b/invokeai/frontend/web/src/features/ui/store/uiSlice.ts @@ -19,7 +19,7 @@ export const initialUIState: UIState = { shouldPinGallery: true, shouldShowGallery: true, shouldHidePreview: false, - shouldShowProgressInViewer: false, + shouldShowProgressInViewer: true, schedulers: SCHEDULERS, }; diff --git a/invokeai/frontend/web/src/services/api/index.ts b/invokeai/frontend/web/src/services/api/index.ts index ecf8621ed6..ff083079f9 100644 --- 
a/invokeai/frontend/web/src/services/api/index.ts +++ b/invokeai/frontend/web/src/services/api/index.ts @@ -7,8 +7,8 @@ export { OpenAPI } from './core/OpenAPI'; export type { OpenAPIConfig } from './core/OpenAPI'; export type { AddInvocation } from './models/AddInvocation'; -export type { BlurInvocation } from './models/BlurInvocation'; export type { Body_upload_image } from './models/Body_upload_image'; +export type { CannyImageProcessorInvocation } from './models/CannyImageProcessorInvocation'; export type { CkptModelInfo } from './models/CkptModelInfo'; export type { CollectInvocation } from './models/CollectInvocation'; export type { CollectInvocationOutput } from './models/CollectInvocationOutput'; @@ -16,26 +16,43 @@ export type { ColorField } from './models/ColorField'; export type { CompelInvocation } from './models/CompelInvocation'; export type { CompelOutput } from './models/CompelOutput'; export type { ConditioningField } from './models/ConditioningField'; +export type { ContentShuffleImageProcessorInvocation } from './models/ContentShuffleImageProcessorInvocation'; +export type { ControlField } from './models/ControlField'; +export type { ControlNetInvocation } from './models/ControlNetInvocation'; +export type { ControlOutput } from './models/ControlOutput'; export type { CreateModelRequest } from './models/CreateModelRequest'; -export type { CropImageInvocation } from './models/CropImageInvocation'; export type { CvInpaintInvocation } from './models/CvInpaintInvocation'; export type { DiffusersModelInfo } from './models/DiffusersModelInfo'; export type { DivideInvocation } from './models/DivideInvocation'; export type { Edge } from './models/Edge'; export type { EdgeConnection } from './models/EdgeConnection'; +export type { FloatCollectionOutput } from './models/FloatCollectionOutput'; +export type { FloatOutput } from './models/FloatOutput'; export type { Graph } from './models/Graph'; export type { GraphExecutionState } from 
'./models/GraphExecutionState'; export type { GraphInvocation } from './models/GraphInvocation'; export type { GraphInvocationOutput } from './models/GraphInvocationOutput'; +export type { HedImageprocessorInvocation } from './models/HedImageprocessorInvocation'; export type { HTTPValidationError } from './models/HTTPValidationError'; +export type { ImageBlurInvocation } from './models/ImageBlurInvocation'; export type { ImageCategory } from './models/ImageCategory'; +export type { ImageChannelInvocation } from './models/ImageChannelInvocation'; +export type { ImageConvertInvocation } from './models/ImageConvertInvocation'; +export type { ImageCropInvocation } from './models/ImageCropInvocation'; export type { ImageDTO } from './models/ImageDTO'; export type { ImageField } from './models/ImageField'; +export type { ImageInverseLerpInvocation } from './models/ImageInverseLerpInvocation'; +export type { ImageLerpInvocation } from './models/ImageLerpInvocation'; export type { ImageMetadata } from './models/ImageMetadata'; +export type { ImageMultiplyInvocation } from './models/ImageMultiplyInvocation'; export type { ImageOutput } from './models/ImageOutput'; +export type { ImagePasteInvocation } from './models/ImagePasteInvocation'; +export type { ImageProcessorInvocation } from './models/ImageProcessorInvocation'; +export type { ImageRecordChanges } from './models/ImageRecordChanges'; +export type { ImageResizeInvocation } from './models/ImageResizeInvocation'; +export type { ImageScaleInvocation } from './models/ImageScaleInvocation'; export type { ImageToImageInvocation } from './models/ImageToImageInvocation'; export type { ImageToLatentsInvocation } from './models/ImageToLatentsInvocation'; -export type { ImageType } from './models/ImageType'; export type { ImageUrlsDTO } from './models/ImageUrlsDTO'; export type { InfillColorInvocation } from './models/InfillColorInvocation'; export type { InfillPatchMatchInvocation } from './models/InfillPatchMatchInvocation'; 
@@ -43,31 +60,38 @@ export type { InfillTileInvocation } from './models/InfillTileInvocation'; export type { InpaintInvocation } from './models/InpaintInvocation'; export type { IntCollectionOutput } from './models/IntCollectionOutput'; export type { IntOutput } from './models/IntOutput'; -export type { InverseLerpInvocation } from './models/InverseLerpInvocation'; export type { IterateInvocation } from './models/IterateInvocation'; export type { IterateInvocationOutput } from './models/IterateInvocationOutput'; export type { LatentsField } from './models/LatentsField'; export type { LatentsOutput } from './models/LatentsOutput'; export type { LatentsToImageInvocation } from './models/LatentsToImageInvocation'; export type { LatentsToLatentsInvocation } from './models/LatentsToLatentsInvocation'; -export type { LerpInvocation } from './models/LerpInvocation'; +export type { LineartAnimeImageProcessorInvocation } from './models/LineartAnimeImageProcessorInvocation'; +export type { LineartImageProcessorInvocation } from './models/LineartImageProcessorInvocation'; export type { LoadImageInvocation } from './models/LoadImageInvocation'; export type { MaskFromAlphaInvocation } from './models/MaskFromAlphaInvocation'; export type { MaskOutput } from './models/MaskOutput'; +export type { MediapipeFaceProcessorInvocation } from './models/MediapipeFaceProcessorInvocation'; +export type { MidasDepthImageProcessorInvocation } from './models/MidasDepthImageProcessorInvocation'; +export type { MlsdImageProcessorInvocation } from './models/MlsdImageProcessorInvocation'; export type { ModelsList } from './models/ModelsList'; export type { MultiplyInvocation } from './models/MultiplyInvocation'; export type { NoiseInvocation } from './models/NoiseInvocation'; export type { NoiseOutput } from './models/NoiseOutput'; +export type { NormalbaeImageProcessorInvocation } from './models/NormalbaeImageProcessorInvocation'; +export type { OffsetPaginatedResults_ImageDTO_ } from 
'./models/OffsetPaginatedResults_ImageDTO_'; +export type { OpenposeImageProcessorInvocation } from './models/OpenposeImageProcessorInvocation'; export type { PaginatedResults_GraphExecutionState_ } from './models/PaginatedResults_GraphExecutionState_'; -export type { PaginatedResults_ImageDTO_ } from './models/PaginatedResults_ImageDTO_'; +export type { ParamFloatInvocation } from './models/ParamFloatInvocation'; export type { ParamIntInvocation } from './models/ParamIntInvocation'; -export type { PasteImageInvocation } from './models/PasteImageInvocation'; +export type { PidiImageProcessorInvocation } from './models/PidiImageProcessorInvocation'; export type { PromptOutput } from './models/PromptOutput'; export type { RandomIntInvocation } from './models/RandomIntInvocation'; export type { RandomRangeInvocation } from './models/RandomRangeInvocation'; export type { RangeInvocation } from './models/RangeInvocation'; export type { RangeOfSizeInvocation } from './models/RangeOfSizeInvocation'; export type { ResizeLatentsInvocation } from './models/ResizeLatentsInvocation'; +export type { ResourceOrigin } from './models/ResourceOrigin'; export type { RestoreFaceInvocation } from './models/RestoreFaceInvocation'; export type { ScaleLatentsInvocation } from './models/ScaleLatentsInvocation'; export type { ShowImageInvocation } from './models/ShowImageInvocation'; @@ -77,6 +101,7 @@ export type { TextToLatentsInvocation } from './models/TextToLatentsInvocation'; export type { UpscaleInvocation } from './models/UpscaleInvocation'; export type { VaeRepo } from './models/VaeRepo'; export type { ValidationError } from './models/ValidationError'; +export type { ZoeDepthImageProcessorInvocation } from './models/ZoeDepthImageProcessorInvocation'; export { ImagesService } from './services/ImagesService'; export { ModelsService } from './services/ModelsService'; diff --git a/invokeai/frontend/web/src/services/api/models/AddInvocation.ts 
b/invokeai/frontend/web/src/services/api/models/AddInvocation.ts index 1ff7b010c2..e9671a918f 100644 --- a/invokeai/frontend/web/src/services/api/models/AddInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/AddInvocation.ts @@ -10,6 +10,10 @@ export type AddInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; type?: 'add'; /** * The first number diff --git a/invokeai/frontend/web/src/services/api/models/CannyImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/models/CannyImageProcessorInvocation.ts new file mode 100644 index 0000000000..3a8b0b21e7 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/CannyImageProcessorInvocation.ts @@ -0,0 +1,33 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { ImageField } from './ImageField'; + +/** + * Canny edge detection for ControlNet + */ +export type CannyImageProcessorInvocation = { + /** + * The id of this node. Must be unique among all nodes. + */ + id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; + type?: 'canny_image_processor'; + /** + * image to process + */ + image?: ImageField; + /** + * low threshold of Canny pixel gradient + */ + low_threshold?: number; + /** + * high threshold of Canny pixel gradient + */ + high_threshold?: number; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/CollectInvocation.ts b/invokeai/frontend/web/src/services/api/models/CollectInvocation.ts index d250ae4450..f190ab7073 100644 --- a/invokeai/frontend/web/src/services/api/models/CollectInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/CollectInvocation.ts @@ -10,6 +10,10 @@ export type CollectInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. 
+ */ + is_intermediate?: boolean; type?: 'collect'; /** * The item to collect (all inputs must be of the same type) diff --git a/invokeai/frontend/web/src/services/api/models/CompelInvocation.ts b/invokeai/frontend/web/src/services/api/models/CompelInvocation.ts index f03d53a841..1dc390c1be 100644 --- a/invokeai/frontend/web/src/services/api/models/CompelInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/CompelInvocation.ts @@ -10,6 +10,10 @@ export type CompelInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; type?: 'compel'; /** * Prompt diff --git a/invokeai/frontend/web/src/services/api/models/ContentShuffleImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/models/ContentShuffleImageProcessorInvocation.ts new file mode 100644 index 0000000000..d8bc3fe58e --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/ContentShuffleImageProcessorInvocation.ts @@ -0,0 +1,45 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { ImageField } from './ImageField'; + +/** + * Applies content shuffle processing to image + */ +export type ContentShuffleImageProcessorInvocation = { + /** + * The id of this node. Must be unique among all nodes. + */ + id: string; + /** + * Whether or not this node is an intermediate node. 
+ */ + is_intermediate?: boolean; + type?: 'content_shuffle_image_processor'; + /** + * image to process + */ + image?: ImageField; + /** + * pixel resolution for edge detection + */ + detect_resolution?: number; + /** + * pixel resolution for output image + */ + image_resolution?: number; + /** + * content shuffle h parameter + */ + 'h'?: number; + /** + * content shuffle w parameter + */ + 'w'?: number; + /** + * cont + */ + 'f'?: number; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/ControlField.ts b/invokeai/frontend/web/src/services/api/models/ControlField.ts new file mode 100644 index 0000000000..4f493d4410 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/ControlField.ts @@ -0,0 +1,29 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { ImageField } from './ImageField'; + +export type ControlField = { + /** + * processed image + */ + image: ImageField; + /** + * control model used + */ + control_model: string; + /** + * weight given to controlnet + */ + control_weight: number; + /** + * % of total steps at which controlnet is first applied + */ + begin_step_percent: number; + /** + * % of total steps at which controlnet is last applied + */ + end_step_percent: number; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/ControlNetInvocation.ts b/invokeai/frontend/web/src/services/api/models/ControlNetInvocation.ts new file mode 100644 index 0000000000..fad3af911b --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/ControlNetInvocation.ts @@ -0,0 +1,41 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { ImageField } from './ImageField'; + +/** + * Collects ControlNet info to pass to other nodes + */ +export type ControlNetInvocation = { + /** + * The id of this node. Must be unique among all nodes. + */ + id: string; + /** + * Whether or not this node is an intermediate node. 
+ */ + is_intermediate?: boolean; + type?: 'controlnet'; + /** + * image to process + */ + image?: ImageField; + /** + * control model used + */ + control_model?: 'lllyasviel/sd-controlnet-canny' | 'lllyasviel/sd-controlnet-depth' | 'lllyasviel/sd-controlnet-hed' | 'lllyasviel/sd-controlnet-seg' | 'lllyasviel/sd-controlnet-openpose' | 'lllyasviel/sd-controlnet-scribble' | 'lllyasviel/sd-controlnet-normal' | 'lllyasviel/sd-controlnet-mlsd' | 'lllyasviel/control_v11p_sd15_canny' | 'lllyasviel/control_v11p_sd15_openpose' | 'lllyasviel/control_v11p_sd15_seg' | 'lllyasviel/control_v11f1p_sd15_depth' | 'lllyasviel/control_v11p_sd15_normalbae' | 'lllyasviel/control_v11p_sd15_scribble' | 'lllyasviel/control_v11p_sd15_mlsd' | 'lllyasviel/control_v11p_sd15_softedge' | 'lllyasviel/control_v11p_sd15s2_lineart_anime' | 'lllyasviel/control_v11p_sd15_lineart' | 'lllyasviel/control_v11p_sd15_inpaint' | 'lllyasviel/control_v11e_sd15_shuffle' | 'lllyasviel/control_v11e_sd15_ip2p' | 'lllyasviel/control_v11f1e_sd15_tile' | 'thibaud/controlnet-sd21-openpose-diffusers' | 'thibaud/controlnet-sd21-canny-diffusers' | 'thibaud/controlnet-sd21-depth-diffusers' | 'thibaud/controlnet-sd21-scribble-diffusers' | 'thibaud/controlnet-sd21-hed-diffusers' | 'thibaud/controlnet-sd21-zoedepth-diffusers' | 'thibaud/controlnet-sd21-color-diffusers' | 'thibaud/controlnet-sd21-openposev2-diffusers' | 'thibaud/controlnet-sd21-lineart-diffusers' | 'thibaud/controlnet-sd21-normalbae-diffusers' | 'thibaud/controlnet-sd21-ade20k-diffusers' | 'CrucibleAI/ControlNetMediaPipeFace,diffusion_sd15' | 'CrucibleAI/ControlNetMediaPipeFace'; + /** + * weight given to controlnet + */ + control_weight?: number; + /** + * % of total steps at which controlnet is first applied + */ + begin_step_percent?: number; + /** + * % of total steps at which controlnet is last applied + */ + end_step_percent?: number; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/ControlOutput.ts 
b/invokeai/frontend/web/src/services/api/models/ControlOutput.ts new file mode 100644 index 0000000000..43f1b3341c --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/ControlOutput.ts @@ -0,0 +1,17 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { ControlField } from './ControlField'; + +/** + * node output for ControlNet info + */ +export type ControlOutput = { + type?: 'control_output'; + /** + * The control info dict + */ + control?: ControlField; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/CvInpaintInvocation.ts b/invokeai/frontend/web/src/services/api/models/CvInpaintInvocation.ts index 19342acf8f..874df93c30 100644 --- a/invokeai/frontend/web/src/services/api/models/CvInpaintInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/CvInpaintInvocation.ts @@ -12,6 +12,10 @@ export type CvInpaintInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; type?: 'cv_inpaint'; /** * The image to inpaint diff --git a/invokeai/frontend/web/src/services/api/models/DivideInvocation.ts b/invokeai/frontend/web/src/services/api/models/DivideInvocation.ts index 3cb262e9af..fd5b3475ae 100644 --- a/invokeai/frontend/web/src/services/api/models/DivideInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/DivideInvocation.ts @@ -10,6 +10,10 @@ export type DivideInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. 
+ */ + is_intermediate?: boolean; type?: 'div'; /** * The first number diff --git a/invokeai/frontend/web/src/services/api/models/FloatCollectionOutput.ts b/invokeai/frontend/web/src/services/api/models/FloatCollectionOutput.ts new file mode 100644 index 0000000000..a3f08247a4 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/FloatCollectionOutput.ts @@ -0,0 +1,15 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +/** + * A collection of floats + */ +export type FloatCollectionOutput = { + type?: 'float_collection'; + /** + * The float collection + */ + collection?: Array; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/FloatOutput.ts b/invokeai/frontend/web/src/services/api/models/FloatOutput.ts new file mode 100644 index 0000000000..2331936b30 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/FloatOutput.ts @@ -0,0 +1,15 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +/** + * A float output + */ +export type FloatOutput = { + type?: 'float_output'; + /** + * The output float + */ + param?: number; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/Graph.ts b/invokeai/frontend/web/src/services/api/models/Graph.ts index 039923e585..e89e815ab2 100644 --- a/invokeai/frontend/web/src/services/api/models/Graph.ts +++ b/invokeai/frontend/web/src/services/api/models/Graph.ts @@ -3,31 +3,50 @@ /* eslint-disable */ import type { AddInvocation } from './AddInvocation'; -import type { BlurInvocation } from './BlurInvocation'; +import type { CannyImageProcessorInvocation } from './CannyImageProcessorInvocation'; import type { CollectInvocation } from './CollectInvocation'; import type { CompelInvocation } from './CompelInvocation'; -import type { CropImageInvocation } from './CropImageInvocation'; +import type { ContentShuffleImageProcessorInvocation } from './ContentShuffleImageProcessorInvocation'; +import type { ControlNetInvocation } from 
'./ControlNetInvocation'; import type { CvInpaintInvocation } from './CvInpaintInvocation'; import type { DivideInvocation } from './DivideInvocation'; import type { Edge } from './Edge'; import type { GraphInvocation } from './GraphInvocation'; +import type { HedImageprocessorInvocation } from './HedImageprocessorInvocation'; +import type { ImageBlurInvocation } from './ImageBlurInvocation'; +import type { ImageChannelInvocation } from './ImageChannelInvocation'; +import type { ImageConvertInvocation } from './ImageConvertInvocation'; +import type { ImageCropInvocation } from './ImageCropInvocation'; +import type { ImageInverseLerpInvocation } from './ImageInverseLerpInvocation'; +import type { ImageLerpInvocation } from './ImageLerpInvocation'; +import type { ImageMultiplyInvocation } from './ImageMultiplyInvocation'; +import type { ImagePasteInvocation } from './ImagePasteInvocation'; +import type { ImageProcessorInvocation } from './ImageProcessorInvocation'; +import type { ImageResizeInvocation } from './ImageResizeInvocation'; +import type { ImageScaleInvocation } from './ImageScaleInvocation'; import type { ImageToImageInvocation } from './ImageToImageInvocation'; import type { ImageToLatentsInvocation } from './ImageToLatentsInvocation'; import type { InfillColorInvocation } from './InfillColorInvocation'; import type { InfillPatchMatchInvocation } from './InfillPatchMatchInvocation'; import type { InfillTileInvocation } from './InfillTileInvocation'; import type { InpaintInvocation } from './InpaintInvocation'; -import type { InverseLerpInvocation } from './InverseLerpInvocation'; import type { IterateInvocation } from './IterateInvocation'; import type { LatentsToImageInvocation } from './LatentsToImageInvocation'; import type { LatentsToLatentsInvocation } from './LatentsToLatentsInvocation'; -import type { LerpInvocation } from './LerpInvocation'; +import type { LineartAnimeImageProcessorInvocation } from './LineartAnimeImageProcessorInvocation'; 
+import type { LineartImageProcessorInvocation } from './LineartImageProcessorInvocation'; import type { LoadImageInvocation } from './LoadImageInvocation'; import type { MaskFromAlphaInvocation } from './MaskFromAlphaInvocation'; +import type { MediapipeFaceProcessorInvocation } from './MediapipeFaceProcessorInvocation'; +import type { MidasDepthImageProcessorInvocation } from './MidasDepthImageProcessorInvocation'; +import type { MlsdImageProcessorInvocation } from './MlsdImageProcessorInvocation'; import type { MultiplyInvocation } from './MultiplyInvocation'; import type { NoiseInvocation } from './NoiseInvocation'; +import type { NormalbaeImageProcessorInvocation } from './NormalbaeImageProcessorInvocation'; +import type { OpenposeImageProcessorInvocation } from './OpenposeImageProcessorInvocation'; +import type { ParamFloatInvocation } from './ParamFloatInvocation'; import type { ParamIntInvocation } from './ParamIntInvocation'; -import type { PasteImageInvocation } from './PasteImageInvocation'; +import type { PidiImageProcessorInvocation } from './PidiImageProcessorInvocation'; import type { RandomIntInvocation } from './RandomIntInvocation'; import type { RandomRangeInvocation } from './RandomRangeInvocation'; import type { RangeInvocation } from './RangeInvocation'; @@ -40,6 +59,7 @@ import type { SubtractInvocation } from './SubtractInvocation'; import type { TextToImageInvocation } from './TextToImageInvocation'; import type { TextToLatentsInvocation } from './TextToLatentsInvocation'; import type { UpscaleInvocation } from './UpscaleInvocation'; +import type { ZoeDepthImageProcessorInvocation } from './ZoeDepthImageProcessorInvocation'; export type Graph = { /** @@ -49,7 +69,7 @@ export type Graph = { /** * The nodes in this graph */ - nodes?: Record; + nodes?: Record; /** * The connections between nodes and their fields in this graph */ diff --git a/invokeai/frontend/web/src/services/api/models/GraphExecutionState.ts 
b/invokeai/frontend/web/src/services/api/models/GraphExecutionState.ts index 8c2eb05657..ea41ce055b 100644 --- a/invokeai/frontend/web/src/services/api/models/GraphExecutionState.ts +++ b/invokeai/frontend/web/src/services/api/models/GraphExecutionState.ts @@ -4,6 +4,9 @@ import type { CollectInvocationOutput } from './CollectInvocationOutput'; import type { CompelOutput } from './CompelOutput'; +import type { ControlOutput } from './ControlOutput'; +import type { FloatCollectionOutput } from './FloatCollectionOutput'; +import type { FloatOutput } from './FloatOutput'; import type { Graph } from './Graph'; import type { GraphInvocationOutput } from './GraphInvocationOutput'; import type { ImageOutput } from './ImageOutput'; @@ -42,7 +45,7 @@ export type GraphExecutionState = { /** * The results of node executions */ - results: Record; + results: Record; /** * Errors raised when executing nodes */ diff --git a/invokeai/frontend/web/src/services/api/models/GraphInvocation.ts b/invokeai/frontend/web/src/services/api/models/GraphInvocation.ts index 5109a49a68..8512faae74 100644 --- a/invokeai/frontend/web/src/services/api/models/GraphInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/GraphInvocation.ts @@ -5,14 +5,17 @@ import type { Graph } from './Graph'; /** - * A node to process inputs and produce outputs. - * May use dependency injection in __init__ to receive providers. + * Execute a graph */ export type GraphInvocation = { /** * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. 
+ */ + is_intermediate?: boolean; type?: 'graph'; /** * The graph to run diff --git a/invokeai/frontend/web/src/services/api/models/HedImageprocessorInvocation.ts b/invokeai/frontend/web/src/services/api/models/HedImageprocessorInvocation.ts new file mode 100644 index 0000000000..f975f18968 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/HedImageprocessorInvocation.ts @@ -0,0 +1,37 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { ImageField } from './ImageField'; + +/** + * Applies HED edge detection to image + */ +export type HedImageprocessorInvocation = { + /** + * The id of this node. Must be unique among all nodes. + */ + id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; + type?: 'hed_image_processor'; + /** + * image to process + */ + image?: ImageField; + /** + * pixel resolution for edge detection + */ + detect_resolution?: number; + /** + * pixel resolution for output image + */ + image_resolution?: number; + /** + * whether to use scribble mode + */ + scribble?: boolean; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/BlurInvocation.ts b/invokeai/frontend/web/src/services/api/models/ImageBlurInvocation.ts similarity index 72% rename from invokeai/frontend/web/src/services/api/models/BlurInvocation.ts rename to invokeai/frontend/web/src/services/api/models/ImageBlurInvocation.ts index 0643e4b309..3ba86d8fab 100644 --- a/invokeai/frontend/web/src/services/api/models/BlurInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/ImageBlurInvocation.ts @@ -7,12 +7,16 @@ import type { ImageField } from './ImageField'; /** * Blurs an image */ -export type BlurInvocation = { +export type ImageBlurInvocation = { /** * The id of this node. Must be unique among all nodes. */ id: string; - type?: 'blur'; + /** + * Whether or not this node is an intermediate node. 
+ */ + is_intermediate?: boolean; + type?: 'img_blur'; /** * The image to blur */ diff --git a/invokeai/frontend/web/src/services/api/models/ImageCategory.ts b/invokeai/frontend/web/src/services/api/models/ImageCategory.ts index c4edf90fd3..84551d3cd6 100644 --- a/invokeai/frontend/web/src/services/api/models/ImageCategory.ts +++ b/invokeai/frontend/web/src/services/api/models/ImageCategory.ts @@ -3,6 +3,12 @@ /* eslint-disable */ /** - * The category of an image. Use ImageCategory.OTHER for non-default categories. + * The category of an image. + * + * - GENERAL: The image is an output, init image, or otherwise an image without a specialized purpose. + * - MASK: The image is a mask image. + * - CONTROL: The image is a ControlNet control image. + * - USER: The image is a user-provided image. + * - OTHER: The image is some other type of image with a specialized purpose. To be used by external nodes. */ -export type ImageCategory = 'general' | 'control' | 'other'; +export type ImageCategory = 'general' | 'mask' | 'control' | 'user' | 'other'; diff --git a/invokeai/frontend/web/src/services/api/models/ImageChannelInvocation.ts b/invokeai/frontend/web/src/services/api/models/ImageChannelInvocation.ts new file mode 100644 index 0000000000..47bfd4110f --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/ImageChannelInvocation.ts @@ -0,0 +1,29 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { ImageField } from './ImageField'; + +/** + * Gets a channel from an image. + */ +export type ImageChannelInvocation = { + /** + * The id of this node. Must be unique among all nodes. + */ + id: string; + /** + * Whether or not this node is an intermediate node. 
+ */ + is_intermediate?: boolean; + type?: 'img_chan'; + /** + * The image to get the channel from + */ + image?: ImageField; + /** + * The channel to get + */ + channel?: 'A' | 'R' | 'G' | 'B'; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/ImageConvertInvocation.ts b/invokeai/frontend/web/src/services/api/models/ImageConvertInvocation.ts new file mode 100644 index 0000000000..4bd59d03b0 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/ImageConvertInvocation.ts @@ -0,0 +1,29 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { ImageField } from './ImageField'; + +/** + * Converts an image to a different mode. + */ +export type ImageConvertInvocation = { + /** + * The id of this node. Must be unique among all nodes. + */ + id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; + type?: 'img_conv'; + /** + * The image to convert + */ + image?: ImageField; + /** + * The mode to convert to + */ + mode?: 'L' | 'RGB' | 'RGBA' | 'CMYK' | 'YCbCr' | 'LAB' | 'HSV' | 'I' | 'F'; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/CropImageInvocation.ts b/invokeai/frontend/web/src/services/api/models/ImageCropInvocation.ts similarity index 80% rename from invokeai/frontend/web/src/services/api/models/CropImageInvocation.ts rename to invokeai/frontend/web/src/services/api/models/ImageCropInvocation.ts index 2676f5cb87..5207ebbf6d 100644 --- a/invokeai/frontend/web/src/services/api/models/CropImageInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/ImageCropInvocation.ts @@ -7,12 +7,16 @@ import type { ImageField } from './ImageField'; /** * Crops an image to a specified box. The box can be outside of the image. */ -export type CropImageInvocation = { +export type ImageCropInvocation = { /** * The id of this node. Must be unique among all nodes. 
*/ id: string; - type?: 'crop'; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; + type?: 'img_crop'; /** * The image to crop */ diff --git a/invokeai/frontend/web/src/services/api/models/ImageDTO.ts b/invokeai/frontend/web/src/services/api/models/ImageDTO.ts index c5377b4c76..f5f2603b03 100644 --- a/invokeai/frontend/web/src/services/api/models/ImageDTO.ts +++ b/invokeai/frontend/web/src/services/api/models/ImageDTO.ts @@ -4,7 +4,7 @@ import type { ImageCategory } from './ImageCategory'; import type { ImageMetadata } from './ImageMetadata'; -import type { ImageType } from './ImageType'; +import type { ResourceOrigin } from './ResourceOrigin'; /** * Deserialized image record, enriched for the frontend with URLs. @@ -17,7 +17,7 @@ export type ImageDTO = { /** * The type of the image. */ - image_type: ImageType; + image_origin: ResourceOrigin; /** * The URL of the image. */ @@ -50,6 +50,10 @@ export type ImageDTO = { * The deleted timestamp of the image. */ deleted_at?: string; + /** + * Whether this is an intermediate image. + */ + is_intermediate: boolean; /** * The session ID that generated this image, if it is a generated image. 
*/ diff --git a/invokeai/frontend/web/src/services/api/models/ImageField.ts b/invokeai/frontend/web/src/services/api/models/ImageField.ts index fa22ae8007..63a12f4730 100644 --- a/invokeai/frontend/web/src/services/api/models/ImageField.ts +++ b/invokeai/frontend/web/src/services/api/models/ImageField.ts @@ -2,7 +2,7 @@ /* tslint:disable */ /* eslint-disable */ -import type { ImageType } from './ImageType'; +import type { ResourceOrigin } from './ResourceOrigin'; /** * An image field used for passing image objects between invocations @@ -11,7 +11,7 @@ export type ImageField = { /** * The type of the image */ - image_type: ImageType; + image_origin: ResourceOrigin; /** * The name of the image */ diff --git a/invokeai/frontend/web/src/services/api/models/InverseLerpInvocation.ts b/invokeai/frontend/web/src/services/api/models/ImageInverseLerpInvocation.ts similarity index 73% rename from invokeai/frontend/web/src/services/api/models/InverseLerpInvocation.ts rename to invokeai/frontend/web/src/services/api/models/ImageInverseLerpInvocation.ts index 33c59b7bac..0347d4dc38 100644 --- a/invokeai/frontend/web/src/services/api/models/InverseLerpInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/ImageInverseLerpInvocation.ts @@ -7,12 +7,16 @@ import type { ImageField } from './ImageField'; /** * Inverse linear interpolation of all pixels of an image */ -export type InverseLerpInvocation = { +export type ImageInverseLerpInvocation = { /** * The id of this node. Must be unique among all nodes. */ id: string; - type?: 'ilerp'; + /** + * Whether or not this node is an intermediate node. 
+ */ + is_intermediate?: boolean; + type?: 'img_ilerp'; /** * The image to lerp */ diff --git a/invokeai/frontend/web/src/services/api/models/LerpInvocation.ts b/invokeai/frontend/web/src/services/api/models/ImageLerpInvocation.ts similarity index 74% rename from invokeai/frontend/web/src/services/api/models/LerpInvocation.ts rename to invokeai/frontend/web/src/services/api/models/ImageLerpInvocation.ts index f2406c2246..388c86061c 100644 --- a/invokeai/frontend/web/src/services/api/models/LerpInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/ImageLerpInvocation.ts @@ -7,12 +7,16 @@ import type { ImageField } from './ImageField'; /** * Linear interpolation of all pixels of an image */ -export type LerpInvocation = { +export type ImageLerpInvocation = { /** * The id of this node. Must be unique among all nodes. */ id: string; - type?: 'lerp'; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; + type?: 'img_lerp'; /** * The image to lerp */ diff --git a/invokeai/frontend/web/src/services/api/models/ImageMultiplyInvocation.ts b/invokeai/frontend/web/src/services/api/models/ImageMultiplyInvocation.ts new file mode 100644 index 0000000000..751ee49158 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/ImageMultiplyInvocation.ts @@ -0,0 +1,29 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { ImageField } from './ImageField'; + +/** + * Multiplies two images together using `PIL.ImageChops.multiply()`. + */ +export type ImageMultiplyInvocation = { + /** + * The id of this node. Must be unique among all nodes. + */ + id: string; + /** + * Whether or not this node is an intermediate node. 
+ */ + is_intermediate?: boolean; + type?: 'img_mul'; + /** + * The first image to multiply + */ + image1?: ImageField; + /** + * The second image to multiply + */ + image2?: ImageField; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/PasteImageInvocation.ts b/invokeai/frontend/web/src/services/api/models/ImagePasteInvocation.ts similarity index 79% rename from invokeai/frontend/web/src/services/api/models/PasteImageInvocation.ts rename to invokeai/frontend/web/src/services/api/models/ImagePasteInvocation.ts index 8a181ccf07..c883b9a5d8 100644 --- a/invokeai/frontend/web/src/services/api/models/PasteImageInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/ImagePasteInvocation.ts @@ -7,12 +7,16 @@ import type { ImageField } from './ImageField'; /** * Pastes an image into another image. */ -export type PasteImageInvocation = { +export type ImagePasteInvocation = { /** * The id of this node. Must be unique among all nodes. */ id: string; - type?: 'paste'; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; + type?: 'img_paste'; /** * The base image */ diff --git a/invokeai/frontend/web/src/services/api/models/ImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/models/ImageProcessorInvocation.ts new file mode 100644 index 0000000000..f972582e2f --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/ImageProcessorInvocation.ts @@ -0,0 +1,25 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { ImageField } from './ImageField'; + +/** + * Base class for invocations that preprocess images for ControlNet + */ +export type ImageProcessorInvocation = { + /** + * The id of this node. Must be unique among all nodes. + */ + id: string; + /** + * Whether or not this node is an intermediate node. 
+ */ + is_intermediate?: boolean; + type?: 'image_processor'; + /** + * image to process + */ + image?: ImageField; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/ImageRecordChanges.ts b/invokeai/frontend/web/src/services/api/models/ImageRecordChanges.ts new file mode 100644 index 0000000000..e597cd907d --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/ImageRecordChanges.ts @@ -0,0 +1,29 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { ImageCategory } from './ImageCategory'; + +/** + * A set of changes to apply to an image record. + * + * Only limited changes are valid: + * - `image_category`: change the category of an image + * - `session_id`: change the session associated with an image + * - `is_intermediate`: change the image's `is_intermediate` flag + */ +export type ImageRecordChanges = { + /** + * The image's new category. + */ + image_category?: ImageCategory; + /** + * The image's new session ID. + */ + session_id?: string; + /** + * The image's new `is_intermediate` flag. + */ + is_intermediate?: boolean; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/ImageResizeInvocation.ts b/invokeai/frontend/web/src/services/api/models/ImageResizeInvocation.ts new file mode 100644 index 0000000000..3b096c83b7 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/ImageResizeInvocation.ts @@ -0,0 +1,37 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { ImageField } from './ImageField'; + +/** + * Resizes an image to specific dimensions + */ +export type ImageResizeInvocation = { + /** + * The id of this node. Must be unique among all nodes. + */ + id: string; + /** + * Whether or not this node is an intermediate node. 
+ */ + is_intermediate?: boolean; + type?: 'img_resize'; + /** + * The image to resize + */ + image?: ImageField; + /** + * The width to resize to (px) + */ + width: number; + /** + * The height to resize to (px) + */ + height: number; + /** + * The resampling mode + */ + resample_mode?: 'nearest' | 'box' | 'bilinear' | 'hamming' | 'bicubic' | 'lanczos'; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/ImageScaleInvocation.ts b/invokeai/frontend/web/src/services/api/models/ImageScaleInvocation.ts new file mode 100644 index 0000000000..bf4da28a4a --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/ImageScaleInvocation.ts @@ -0,0 +1,33 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { ImageField } from './ImageField'; + +/** + * Scales an image by a factor + */ +export type ImageScaleInvocation = { + /** + * The id of this node. Must be unique among all nodes. + */ + id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; + type?: 'img_scale'; + /** + * The image to scale + */ + image?: ImageField; + /** + * The factor by which to scale the image + */ + scale_factor: number; + /** + * The resampling mode + */ + resample_mode?: 'nearest' | 'box' | 'bilinear' | 'hamming' | 'bicubic' | 'lanczos'; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/ImageToImageInvocation.ts b/invokeai/frontend/web/src/services/api/models/ImageToImageInvocation.ts index fb43c76921..e63ec93ada 100644 --- a/invokeai/frontend/web/src/services/api/models/ImageToImageInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/ImageToImageInvocation.ts @@ -12,6 +12,10 @@ export type ImageToImageInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. 
+ */ + is_intermediate?: boolean; type?: 'img2img'; /** * The prompt to generate an image from @@ -45,6 +49,18 @@ export type ImageToImageInvocation = { * The model to use (currently ignored) */ model?: string; + /** + * Whether or not to produce progress images during generation + */ + progress_images?: boolean; + /** + * The control model to use + */ + control_model?: string; + /** + * The processed control image + */ + control_image?: ImageField; /** * The input image */ diff --git a/invokeai/frontend/web/src/services/api/models/ImageToLatentsInvocation.ts b/invokeai/frontend/web/src/services/api/models/ImageToLatentsInvocation.ts index f72d446615..5569c2fa86 100644 --- a/invokeai/frontend/web/src/services/api/models/ImageToLatentsInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/ImageToLatentsInvocation.ts @@ -12,6 +12,10 @@ export type ImageToLatentsInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; type?: 'i2l'; /** * The image to encode diff --git a/invokeai/frontend/web/src/services/api/models/ImageType.ts b/invokeai/frontend/web/src/services/api/models/ImageType.ts deleted file mode 100644 index bba9134e63..0000000000 --- a/invokeai/frontend/web/src/services/api/models/ImageType.ts +++ /dev/null @@ -1,8 +0,0 @@ -/* istanbul ignore file */ -/* tslint:disable */ -/* eslint-disable */ - -/** - * The type of an image. 
- */ -export type ImageType = 'results' | 'uploads' | 'intermediates'; diff --git a/invokeai/frontend/web/src/services/api/models/ImageUrlsDTO.ts b/invokeai/frontend/web/src/services/api/models/ImageUrlsDTO.ts index af80519ef2..81639be9b3 100644 --- a/invokeai/frontend/web/src/services/api/models/ImageUrlsDTO.ts +++ b/invokeai/frontend/web/src/services/api/models/ImageUrlsDTO.ts @@ -2,7 +2,7 @@ /* tslint:disable */ /* eslint-disable */ -import type { ImageType } from './ImageType'; +import type { ResourceOrigin } from './ResourceOrigin'; /** * The URLs for an image and its thumbnail. @@ -15,7 +15,7 @@ export type ImageUrlsDTO = { /** * The type of the image. */ - image_type: ImageType; + image_origin: ResourceOrigin; /** * The URL of the image. */ diff --git a/invokeai/frontend/web/src/services/api/models/InfillColorInvocation.ts b/invokeai/frontend/web/src/services/api/models/InfillColorInvocation.ts index 157c976e11..3e637b299c 100644 --- a/invokeai/frontend/web/src/services/api/models/InfillColorInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/InfillColorInvocation.ts @@ -13,6 +13,10 @@ export type InfillColorInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; type?: 'infill_rgba'; /** * The image to infill diff --git a/invokeai/frontend/web/src/services/api/models/InfillPatchMatchInvocation.ts b/invokeai/frontend/web/src/services/api/models/InfillPatchMatchInvocation.ts index a4c18ade5d..325bfe2080 100644 --- a/invokeai/frontend/web/src/services/api/models/InfillPatchMatchInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/InfillPatchMatchInvocation.ts @@ -12,6 +12,10 @@ export type InfillPatchMatchInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. 
+ */ + is_intermediate?: boolean; type?: 'infill_patchmatch'; /** * The image to infill diff --git a/invokeai/frontend/web/src/services/api/models/InfillTileInvocation.ts b/invokeai/frontend/web/src/services/api/models/InfillTileInvocation.ts index 12113f57f5..dfb1cbc61d 100644 --- a/invokeai/frontend/web/src/services/api/models/InfillTileInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/InfillTileInvocation.ts @@ -12,6 +12,10 @@ export type InfillTileInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; type?: 'infill_tile'; /** * The image to infill diff --git a/invokeai/frontend/web/src/services/api/models/InpaintInvocation.ts b/invokeai/frontend/web/src/services/api/models/InpaintInvocation.ts index 88ead9907c..b8ed268ef9 100644 --- a/invokeai/frontend/web/src/services/api/models/InpaintInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/InpaintInvocation.ts @@ -13,6 +13,10 @@ export type InpaintInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. 
+ */ + is_intermediate?: boolean; type?: 'inpaint'; /** * The prompt to generate an image from @@ -46,6 +50,18 @@ export type InpaintInvocation = { * The model to use (currently ignored) */ model?: string; + /** + * Whether or not to produce progress images during generation + */ + progress_images?: boolean; + /** + * The control model to use + */ + control_model?: string; + /** + * The processed control image + */ + control_image?: ImageField; /** * The input image */ diff --git a/invokeai/frontend/web/src/services/api/models/IterateInvocation.ts b/invokeai/frontend/web/src/services/api/models/IterateInvocation.ts index 0ff7a1258d..15bf92dfea 100644 --- a/invokeai/frontend/web/src/services/api/models/IterateInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/IterateInvocation.ts @@ -3,14 +3,17 @@ /* eslint-disable */ /** - * A node to process inputs and produce outputs. - * May use dependency injection in __init__ to receive providers. + * Iterates over a list of items */ export type IterateInvocation = { /** * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; type?: 'iterate'; /** * The list of items to iterate over diff --git a/invokeai/frontend/web/src/services/api/models/LatentsToImageInvocation.ts b/invokeai/frontend/web/src/services/api/models/LatentsToImageInvocation.ts index 8acd872e28..fcaa37d7e8 100644 --- a/invokeai/frontend/web/src/services/api/models/LatentsToImageInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/LatentsToImageInvocation.ts @@ -12,6 +12,10 @@ export type LatentsToImageInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. 
+ */ + is_intermediate?: boolean; type?: 'l2i'; /** * The latents to generate an image from diff --git a/invokeai/frontend/web/src/services/api/models/LatentsToLatentsInvocation.ts b/invokeai/frontend/web/src/services/api/models/LatentsToLatentsInvocation.ts index 29995c6ad9..f5b4912141 100644 --- a/invokeai/frontend/web/src/services/api/models/LatentsToLatentsInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/LatentsToLatentsInvocation.ts @@ -3,6 +3,7 @@ /* eslint-disable */ import type { ConditioningField } from './ConditioningField'; +import type { ControlField } from './ControlField'; import type { LatentsField } from './LatentsField'; /** @@ -13,6 +14,10 @@ export type LatentsToLatentsInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; type?: 'l2l'; /** * Positive conditioning for generation @@ -43,13 +48,9 @@ export type LatentsToLatentsInvocation = { */ model?: string; /** - * Whether or not to generate an image that can tile without seams + * The control to use */ - seamless?: boolean; - /** - * The axes to tile the image on, 'x' and/or 'y' - */ - seamless_axes?: string; + control?: (ControlField | Array<ControlField>); /** * The latents to use as a base image */ diff --git a/invokeai/frontend/web/src/services/api/models/LineartAnimeImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/models/LineartAnimeImageProcessorInvocation.ts new file mode 100644 index 0000000000..4796d2a049 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/LineartAnimeImageProcessorInvocation.ts @@ -0,0 +1,33 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { ImageField } from './ImageField'; + +/** + * Applies line art anime processing to image + */ +export type LineartAnimeImageProcessorInvocation = { + /** + * The id of this node. Must be unique among all nodes. 
+ */ + id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; + type?: 'lineart_anime_image_processor'; + /** + * image to process + */ + image?: ImageField; + /** + * pixel resolution for edge detection + */ + detect_resolution?: number; + /** + * pixel resolution for output image + */ + image_resolution?: number; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/LineartImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/models/LineartImageProcessorInvocation.ts new file mode 100644 index 0000000000..8328849b50 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/LineartImageProcessorInvocation.ts @@ -0,0 +1,37 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { ImageField } from './ImageField'; + +/** + * Applies line art processing to image + */ +export type LineartImageProcessorInvocation = { + /** + * The id of this node. Must be unique among all nodes. + */ + id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; + type?: 'lineart_image_processor'; + /** + * image to process + */ + image?: ImageField; + /** + * pixel resolution for edge detection + */ + detect_resolution?: number; + /** + * pixel resolution for output image + */ + image_resolution?: number; + /** + * whether to use coarse mode + */ + coarse?: boolean; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/LoadImageInvocation.ts b/invokeai/frontend/web/src/services/api/models/LoadImageInvocation.ts index 745a9b44e4..f20d983f9b 100644 --- a/invokeai/frontend/web/src/services/api/models/LoadImageInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/LoadImageInvocation.ts @@ -2,7 +2,7 @@ /* tslint:disable */ /* eslint-disable */ -import type { ImageType } from './ImageType'; +import type { ImageField } from './ImageField'; /** * Load an image and provide it as output. 
@@ -12,14 +12,14 @@ export type LoadImageInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; type?: 'load_image'; /** - * The type of the image + * The image to load */ - image_type: ImageType; - /** - * The name of the image - */ - image_name: string; + image?: ImageField; }; diff --git a/invokeai/frontend/web/src/services/api/models/MaskFromAlphaInvocation.ts b/invokeai/frontend/web/src/services/api/models/MaskFromAlphaInvocation.ts index e71b1f464b..e3693f6d98 100644 --- a/invokeai/frontend/web/src/services/api/models/MaskFromAlphaInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/MaskFromAlphaInvocation.ts @@ -12,6 +12,10 @@ export type MaskFromAlphaInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; type?: 'tomask'; /** * The image to create the mask from diff --git a/invokeai/frontend/web/src/services/api/models/MediapipeFaceProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/models/MediapipeFaceProcessorInvocation.ts new file mode 100644 index 0000000000..bd223eed7d --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/MediapipeFaceProcessorInvocation.ts @@ -0,0 +1,33 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { ImageField } from './ImageField'; + +/** + * Applies mediapipe face processing to image + */ +export type MediapipeFaceProcessorInvocation = { + /** + * The id of this node. Must be unique among all nodes. + */ + id: string; + /** + * Whether or not this node is an intermediate node. 
+ */ + is_intermediate?: boolean; + type?: 'mediapipe_face_processor'; + /** + * image to process + */ + image?: ImageField; + /** + * maximum number of faces to detect + */ + max_faces?: number; + /** + * minimum confidence for face detection + */ + min_confidence?: number; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/MidasDepthImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/models/MidasDepthImageProcessorInvocation.ts new file mode 100644 index 0000000000..11023086a2 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/MidasDepthImageProcessorInvocation.ts @@ -0,0 +1,33 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { ImageField } from './ImageField'; + +/** + * Applies Midas depth processing to image + */ +export type MidasDepthImageProcessorInvocation = { + /** + * The id of this node. Must be unique among all nodes. + */ + id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; + type?: 'midas_depth_image_processor'; + /** + * image to process + */ + image?: ImageField; + /** + * Midas parameter a = amult * PI + */ + a_mult?: number; + /** + * Midas parameter bg_th + */ + bg_th?: number; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/MlsdImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/models/MlsdImageProcessorInvocation.ts new file mode 100644 index 0000000000..c2d4a61b9a --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/MlsdImageProcessorInvocation.ts @@ -0,0 +1,41 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { ImageField } from './ImageField'; + +/** + * Applies MLSD processing to image + */ +export type MlsdImageProcessorInvocation = { + /** + * The id of this node. Must be unique among all nodes. + */ + id: string; + /** + * Whether or not this node is an intermediate node. 
+ */ + is_intermediate?: boolean; + type?: 'mlsd_image_processor'; + /** + * image to process + */ + image?: ImageField; + /** + * pixel resolution for edge detection + */ + detect_resolution?: number; + /** + * pixel resolution for output image + */ + image_resolution?: number; + /** + * MLSD parameter thr_v + */ + thr_v?: number; + /** + * MLSD parameter thr_d + */ + thr_d?: number; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/MultiplyInvocation.ts b/invokeai/frontend/web/src/services/api/models/MultiplyInvocation.ts index eede8f18d7..9fd716f33d 100644 --- a/invokeai/frontend/web/src/services/api/models/MultiplyInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/MultiplyInvocation.ts @@ -10,6 +10,10 @@ export type MultiplyInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; type?: 'mul'; /** * The first number diff --git a/invokeai/frontend/web/src/services/api/models/NoiseInvocation.ts b/invokeai/frontend/web/src/services/api/models/NoiseInvocation.ts index 59e50b76f3..239a24bfe5 100644 --- a/invokeai/frontend/web/src/services/api/models/NoiseInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/NoiseInvocation.ts @@ -10,6 +10,10 @@ export type NoiseInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. 
+ */ + is_intermediate?: boolean; type?: 'noise'; /** * The seed to use diff --git a/invokeai/frontend/web/src/services/api/models/NormalbaeImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/models/NormalbaeImageProcessorInvocation.ts new file mode 100644 index 0000000000..ecfb50a09f --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/NormalbaeImageProcessorInvocation.ts @@ -0,0 +1,33 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { ImageField } from './ImageField'; + +/** + * Applies NormalBae processing to image + */ +export type NormalbaeImageProcessorInvocation = { + /** + * The id of this node. Must be unique among all nodes. + */ + id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; + type?: 'normalbae_image_processor'; + /** + * image to process + */ + image?: ImageField; + /** + * pixel resolution for edge detection + */ + detect_resolution?: number; + /** + * pixel resolution for output image + */ + image_resolution?: number; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/PaginatedResults_ImageDTO_.ts b/invokeai/frontend/web/src/services/api/models/OffsetPaginatedResults_ImageDTO_.ts similarity index 56% rename from invokeai/frontend/web/src/services/api/models/PaginatedResults_ImageDTO_.ts rename to invokeai/frontend/web/src/services/api/models/OffsetPaginatedResults_ImageDTO_.ts index 5d2bdae5ab..3408bea6db 100644 --- a/invokeai/frontend/web/src/services/api/models/PaginatedResults_ImageDTO_.ts +++ b/invokeai/frontend/web/src/services/api/models/OffsetPaginatedResults_ImageDTO_.ts @@ -5,25 +5,21 @@ import type { ImageDTO } from './ImageDTO'; /** - * Paginated results + * Offset-paginated results */ -export type PaginatedResults_ImageDTO_ = { +export type OffsetPaginatedResults_ImageDTO_ = { /** * Items */ items: Array<ImageDTO>; /** - * Current Page + * Offset from which to retrieve items */ - page: number; + offset: 
number; /** - * Total number of pages + * Limit of items to get */ - pages: number; - /** - * Number of items per page - */ - per_page: number; + limit: number; /** * Total number of items in result */ diff --git a/invokeai/frontend/web/src/services/api/models/OpenposeImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/models/OpenposeImageProcessorInvocation.ts new file mode 100644 index 0000000000..5af21d542e --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/OpenposeImageProcessorInvocation.ts @@ -0,0 +1,37 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { ImageField } from './ImageField'; + +/** + * Applies Openpose processing to image + */ +export type OpenposeImageProcessorInvocation = { + /** + * The id of this node. Must be unique among all nodes. + */ + id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; + type?: 'openpose_image_processor'; + /** + * image to process + */ + image?: ImageField; + /** + * whether to use hands and face mode + */ + hand_and_face?: boolean; + /** + * pixel resolution for edge detection + */ + detect_resolution?: number; + /** + * pixel resolution for output image + */ + image_resolution?: number; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/ParamFloatInvocation.ts b/invokeai/frontend/web/src/services/api/models/ParamFloatInvocation.ts new file mode 100644 index 0000000000..87c01f847f --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/ParamFloatInvocation.ts @@ -0,0 +1,23 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +/** + * A float parameter + */ +export type ParamFloatInvocation = { + /** + * The id of this node. Must be unique among all nodes. + */ + id: string; + /** + * Whether or not this node is an intermediate node. 
+ */ + is_intermediate?: boolean; + type?: 'param_float'; + /** + * The float value + */ + param?: number; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/ParamIntInvocation.ts b/invokeai/frontend/web/src/services/api/models/ParamIntInvocation.ts index 7047310a87..7a45d0a0ac 100644 --- a/invokeai/frontend/web/src/services/api/models/ParamIntInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/ParamIntInvocation.ts @@ -10,6 +10,10 @@ export type ParamIntInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; type?: 'param_int'; /** * The integer value diff --git a/invokeai/frontend/web/src/services/api/models/PidiImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/models/PidiImageProcessorInvocation.ts new file mode 100644 index 0000000000..a08bf6a920 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/PidiImageProcessorInvocation.ts @@ -0,0 +1,41 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { ImageField } from './ImageField'; + +/** + * Applies PIDI processing to image + */ +export type PidiImageProcessorInvocation = { + /** + * The id of this node. Must be unique among all nodes. + */ + id: string; + /** + * Whether or not this node is an intermediate node. 
+ */ + is_intermediate?: boolean; + type?: 'pidi_image_processor'; + /** + * image to process + */ + image?: ImageField; + /** + * pixel resolution for edge detection + */ + detect_resolution?: number; + /** + * pixel resolution for output image + */ + image_resolution?: number; + /** + * whether to use safe mode + */ + safe?: boolean; + /** + * whether to use scribble mode + */ + scribble?: boolean; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/RandomIntInvocation.ts b/invokeai/frontend/web/src/services/api/models/RandomIntInvocation.ts index 0a5220c31d..a2f7c2f02a 100644 --- a/invokeai/frontend/web/src/services/api/models/RandomIntInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/RandomIntInvocation.ts @@ -10,6 +10,10 @@ export type RandomIntInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; type?: 'rand_int'; /** * The inclusive low value diff --git a/invokeai/frontend/web/src/services/api/models/RandomRangeInvocation.ts b/invokeai/frontend/web/src/services/api/models/RandomRangeInvocation.ts index c1f80042a6..925511578d 100644 --- a/invokeai/frontend/web/src/services/api/models/RandomRangeInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/RandomRangeInvocation.ts @@ -10,6 +10,10 @@ export type RandomRangeInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. 
+ */ + is_intermediate?: boolean; type?: 'random_range'; /** * The inclusive low value diff --git a/invokeai/frontend/web/src/services/api/models/RangeInvocation.ts b/invokeai/frontend/web/src/services/api/models/RangeInvocation.ts index 1c37ca7fe3..3681602a95 100644 --- a/invokeai/frontend/web/src/services/api/models/RangeInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/RangeInvocation.ts @@ -10,6 +10,10 @@ export type RangeInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; type?: 'range'; /** * The start of the range diff --git a/invokeai/frontend/web/src/services/api/models/RangeOfSizeInvocation.ts b/invokeai/frontend/web/src/services/api/models/RangeOfSizeInvocation.ts index b918f17130..7dfac68d39 100644 --- a/invokeai/frontend/web/src/services/api/models/RangeOfSizeInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/RangeOfSizeInvocation.ts @@ -10,6 +10,10 @@ export type RangeOfSizeInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; type?: 'range_of_size'; /** * The start of the range diff --git a/invokeai/frontend/web/src/services/api/models/ResizeLatentsInvocation.ts b/invokeai/frontend/web/src/services/api/models/ResizeLatentsInvocation.ts index c0fabb4984..9a7b6c61e4 100644 --- a/invokeai/frontend/web/src/services/api/models/ResizeLatentsInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/ResizeLatentsInvocation.ts @@ -12,6 +12,10 @@ export type ResizeLatentsInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. 
+ */ + is_intermediate?: boolean; type?: 'lresize'; /** * The latents to resize diff --git a/invokeai/frontend/web/src/services/api/models/ResourceOrigin.ts b/invokeai/frontend/web/src/services/api/models/ResourceOrigin.ts new file mode 100644 index 0000000000..a82edda0c1 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/ResourceOrigin.ts @@ -0,0 +1,12 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +/** + * The origin of a resource (eg image). + * + * - INTERNAL: The resource was created by the application. + * - EXTERNAL: The resource was not created by the application. + * This may be a user-initiated upload, or an internal application upload (eg Canvas init image). + */ +export type ResourceOrigin = 'internal' | 'external'; diff --git a/invokeai/frontend/web/src/services/api/models/RestoreFaceInvocation.ts b/invokeai/frontend/web/src/services/api/models/RestoreFaceInvocation.ts index e03ed01c81..0bacb5d805 100644 --- a/invokeai/frontend/web/src/services/api/models/RestoreFaceInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/RestoreFaceInvocation.ts @@ -12,6 +12,10 @@ export type RestoreFaceInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; type?: 'restore_face'; /** * The input image diff --git a/invokeai/frontend/web/src/services/api/models/ScaleLatentsInvocation.ts b/invokeai/frontend/web/src/services/api/models/ScaleLatentsInvocation.ts index f398eaf408..506b21e540 100644 --- a/invokeai/frontend/web/src/services/api/models/ScaleLatentsInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/ScaleLatentsInvocation.ts @@ -12,6 +12,10 @@ export type ScaleLatentsInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. 
+ */ + is_intermediate?: boolean; type?: 'lscale'; /** * The latents to scale diff --git a/invokeai/frontend/web/src/services/api/models/ShowImageInvocation.ts b/invokeai/frontend/web/src/services/api/models/ShowImageInvocation.ts index 145895ad75..1b73055584 100644 --- a/invokeai/frontend/web/src/services/api/models/ShowImageInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/ShowImageInvocation.ts @@ -12,6 +12,10 @@ export type ShowImageInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; type?: 'show_image'; /** * The image to show diff --git a/invokeai/frontend/web/src/services/api/models/SubtractInvocation.ts b/invokeai/frontend/web/src/services/api/models/SubtractInvocation.ts index 6f2da116a2..23334bd891 100644 --- a/invokeai/frontend/web/src/services/api/models/SubtractInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/SubtractInvocation.ts @@ -10,6 +10,10 @@ export type SubtractInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; type?: 'sub'; /** * The first number diff --git a/invokeai/frontend/web/src/services/api/models/TextToImageInvocation.ts b/invokeai/frontend/web/src/services/api/models/TextToImageInvocation.ts index 184e35693b..7128ea8440 100644 --- a/invokeai/frontend/web/src/services/api/models/TextToImageInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/TextToImageInvocation.ts @@ -2,6 +2,8 @@ /* tslint:disable */ /* eslint-disable */ +import type { ImageField } from './ImageField'; + /** * Generates an image using text2img. */ @@ -10,6 +12,10 @@ export type TextToImageInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. 
+ */ + is_intermediate?: boolean; type?: 'txt2img'; /** * The prompt to generate an image from @@ -43,5 +49,17 @@ export type TextToImageInvocation = { * The model to use (currently ignored) */ model?: string; + /** + * Whether or not to produce progress images during generation + */ + progress_images?: boolean; + /** + * The control model to use + */ + control_model?: string; + /** + * The processed control image + */ + control_image?: ImageField; }; diff --git a/invokeai/frontend/web/src/services/api/models/TextToLatentsInvocation.ts b/invokeai/frontend/web/src/services/api/models/TextToLatentsInvocation.ts index d1ec5ed08c..f1831b2b59 100644 --- a/invokeai/frontend/web/src/services/api/models/TextToLatentsInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/TextToLatentsInvocation.ts @@ -3,6 +3,7 @@ /* eslint-disable */ import type { ConditioningField } from './ConditioningField'; +import type { ControlField } from './ControlField'; import type { LatentsField } from './LatentsField'; /** @@ -13,6 +14,10 @@ export type TextToLatentsInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. 
+ */ + is_intermediate?: boolean; type?: 't2l'; /** * Positive conditioning for generation @@ -43,12 +48,8 @@ export type TextToLatentsInvocation = { */ model?: string; /** - * Whether or not to generate an image that can tile without seams + * The control to use */ - seamless?: boolean; - /** - * The axes to tile the image on, 'x' and/or 'y' - */ - seamless_axes?: string; + control?: (ControlField | Array); }; diff --git a/invokeai/frontend/web/src/services/api/models/UpscaleInvocation.ts b/invokeai/frontend/web/src/services/api/models/UpscaleInvocation.ts index 8416c2454d..d0aca63964 100644 --- a/invokeai/frontend/web/src/services/api/models/UpscaleInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/UpscaleInvocation.ts @@ -12,6 +12,10 @@ export type UpscaleInvocation = { * The id of this node. Must be unique among all nodes. */ id: string; + /** + * Whether or not this node is an intermediate node. + */ + is_intermediate?: boolean; type?: 'upscale'; /** * The input image diff --git a/invokeai/frontend/web/src/services/api/models/ZoeDepthImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/models/ZoeDepthImageProcessorInvocation.ts new file mode 100644 index 0000000000..55d05f3167 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/ZoeDepthImageProcessorInvocation.ts @@ -0,0 +1,25 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { ImageField } from './ImageField'; + +/** + * Applies Zoe depth processing to image + */ +export type ZoeDepthImageProcessorInvocation = { + /** + * The id of this node. Must be unique among all nodes. + */ + id: string; + /** + * Whether or not this node is an intermediate node. 
+ */ + is_intermediate?: boolean; + type?: 'zoe_depth_image_processor'; + /** + * image to process + */ + image?: ImageField; +}; + diff --git a/invokeai/frontend/web/src/services/api/schemas/$CannyImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/schemas/$CannyImageProcessorInvocation.ts new file mode 100644 index 0000000000..e2f1bc2111 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/schemas/$CannyImageProcessorInvocation.ts @@ -0,0 +1,31 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +export const $CannyImageProcessorInvocation = { + description: `Canny edge detection for ControlNet`, + properties: { + id: { + type: 'string', + description: `The id of this node. Must be unique among all nodes.`, + isRequired: true, + }, + type: { + type: 'Enum', + }, + image: { + type: 'all-of', + description: `image to process`, + contains: [{ + type: 'ImageField', + }], + }, + low_threshold: { + type: 'number', + description: `low threshold of Canny pixel gradient`, + }, + high_threshold: { + type: 'number', + description: `high threshold of Canny pixel gradient`, + }, + }, +} as const; diff --git a/invokeai/frontend/web/src/services/api/schemas/$ContentShuffleImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/schemas/$ContentShuffleImageProcessorInvocation.ts new file mode 100644 index 0000000000..9c51fdecc0 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/schemas/$ContentShuffleImageProcessorInvocation.ts @@ -0,0 +1,43 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +export const $ContentShuffleImageProcessorInvocation = { + description: `Applies content shuffle processing to image`, + properties: { + id: { + type: 'string', + description: `The id of this node. 
Must be unique among all nodes.`, + isRequired: true, + }, + type: { + type: 'Enum', + }, + image: { + type: 'all-of', + description: `image to process`, + contains: [{ + type: 'ImageField', + }], + }, + detect_resolution: { + type: 'number', + description: `pixel resolution for edge detection`, + }, + image_resolution: { + type: 'number', + description: `pixel resolution for output image`, + }, + 'h': { + type: 'number', + description: `content shuffle h parameter`, + }, + 'w': { + type: 'number', + description: `content shuffle w parameter`, + }, + 'f': { + type: 'number', + description: `cont`, + }, + }, +} as const; diff --git a/invokeai/frontend/web/src/services/api/schemas/$ControlField.ts b/invokeai/frontend/web/src/services/api/schemas/$ControlField.ts new file mode 100644 index 0000000000..81292b8638 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/schemas/$ControlField.ts @@ -0,0 +1,37 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +export const $ControlField = { + properties: { + image: { + type: 'all-of', + description: `processed image`, + contains: [{ + type: 'ImageField', + }], + isRequired: true, + }, + control_model: { + type: 'string', + description: `control model used`, + isRequired: true, + }, + control_weight: { + type: 'number', + description: `weight given to controlnet`, + isRequired: true, + }, + begin_step_percent: { + type: 'number', + description: `% of total steps at which controlnet is first applied`, + isRequired: true, + maximum: 1, + }, + end_step_percent: { + type: 'number', + description: `% of total steps at which controlnet is last applied`, + isRequired: true, + maximum: 1, + }, + }, +} as const; diff --git a/invokeai/frontend/web/src/services/api/schemas/$ControlNetInvocation.ts b/invokeai/frontend/web/src/services/api/schemas/$ControlNetInvocation.ts new file mode 100644 index 0000000000..29ff507e66 --- /dev/null +++ 
b/invokeai/frontend/web/src/services/api/schemas/$ControlNetInvocation.ts @@ -0,0 +1,41 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +export const $ControlNetInvocation = { + description: `Collects ControlNet info to pass to other nodes`, + properties: { + id: { + type: 'string', + description: `The id of this node. Must be unique among all nodes.`, + isRequired: true, + }, + type: { + type: 'Enum', + }, + image: { + type: 'all-of', + description: `image to process`, + contains: [{ + type: 'ImageField', + }], + }, + control_model: { + type: 'Enum', + }, + control_weight: { + type: 'number', + description: `weight given to controlnet`, + maximum: 1, + }, + begin_step_percent: { + type: 'number', + description: `% of total steps at which controlnet is first applied`, + maximum: 1, + }, + end_step_percent: { + type: 'number', + description: `% of total steps at which controlnet is last applied`, + maximum: 1, + }, + }, +} as const; diff --git a/invokeai/frontend/web/src/services/api/schemas/$ControlOutput.ts b/invokeai/frontend/web/src/services/api/schemas/$ControlOutput.ts new file mode 100644 index 0000000000..d94d633fca --- /dev/null +++ b/invokeai/frontend/web/src/services/api/schemas/$ControlOutput.ts @@ -0,0 +1,28 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +export const $ControlOutput = { + description: `node output for ControlNet info`, + properties: { + type: { + type: 'Enum', + }, + control: { + type: 'all-of', + description: `The control info dict`, + contains: [{ + type: 'ControlField', + }], + }, + width: { + type: 'number', + description: `The width of the noise in pixels`, + isRequired: true, + }, + height: { + type: 'number', + description: `The height of the noise in pixels`, + isRequired: true, + }, + }, +} as const; diff --git a/invokeai/frontend/web/src/services/api/schemas/$HedImageprocessorInvocation.ts b/invokeai/frontend/web/src/services/api/schemas/$HedImageprocessorInvocation.ts new 
file mode 100644 index 0000000000..3cffa008f5 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/schemas/$HedImageprocessorInvocation.ts @@ -0,0 +1,35 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +export const $HedImageprocessorInvocation = { + description: `Applies HED edge detection to image`, + properties: { + id: { + type: 'string', + description: `The id of this node. Must be unique among all nodes.`, + isRequired: true, + }, + type: { + type: 'Enum', + }, + image: { + type: 'all-of', + description: `image to process`, + contains: [{ + type: 'ImageField', + }], + }, + detect_resolution: { + type: 'number', + description: `pixel resolution for edge detection`, + }, + image_resolution: { + type: 'number', + description: `pixel resolution for output image`, + }, + scribble: { + type: 'boolean', + description: `whether to use scribble mode`, + }, + }, +} as const; diff --git a/invokeai/frontend/web/src/services/api/schemas/$ImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/schemas/$ImageProcessorInvocation.ts new file mode 100644 index 0000000000..36748982c5 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/schemas/$ImageProcessorInvocation.ts @@ -0,0 +1,23 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +export const $ImageProcessorInvocation = { + description: `Base class for invocations that preprocess images for ControlNet`, + properties: { + id: { + type: 'string', + description: `The id of this node. 
Must be unique among all nodes.`, + isRequired: true, + }, + type: { + type: 'Enum', + }, + image: { + type: 'all-of', + description: `image to process`, + contains: [{ + type: 'ImageField', + }], + }, + }, +} as const; diff --git a/invokeai/frontend/web/src/services/api/schemas/$LineartAnimeImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/schemas/$LineartAnimeImageProcessorInvocation.ts new file mode 100644 index 0000000000..63a9c8158c --- /dev/null +++ b/invokeai/frontend/web/src/services/api/schemas/$LineartAnimeImageProcessorInvocation.ts @@ -0,0 +1,31 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +export const $LineartAnimeImageProcessorInvocation = { + description: `Applies line art anime processing to image`, + properties: { + id: { + type: 'string', + description: `The id of this node. Must be unique among all nodes.`, + isRequired: true, + }, + type: { + type: 'Enum', + }, + image: { + type: 'all-of', + description: `image to process`, + contains: [{ + type: 'ImageField', + }], + }, + detect_resolution: { + type: 'number', + description: `pixel resolution for edge detection`, + }, + image_resolution: { + type: 'number', + description: `pixel resolution for output image`, + }, + }, +} as const; diff --git a/invokeai/frontend/web/src/services/api/schemas/$LineartImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/schemas/$LineartImageProcessorInvocation.ts new file mode 100644 index 0000000000..6ba4064823 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/schemas/$LineartImageProcessorInvocation.ts @@ -0,0 +1,35 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +export const $LineartImageProcessorInvocation = { + description: `Applies line art processing to image`, + properties: { + id: { + type: 'string', + description: `The id of this node. 
Must be unique among all nodes.`, + isRequired: true, + }, + type: { + type: 'Enum', + }, + image: { + type: 'all-of', + description: `image to process`, + contains: [{ + type: 'ImageField', + }], + }, + detect_resolution: { + type: 'number', + description: `pixel resolution for edge detection`, + }, + image_resolution: { + type: 'number', + description: `pixel resolution for output image`, + }, + coarse: { + type: 'boolean', + description: `whether to use coarse mode`, + }, + }, +} as const; diff --git a/invokeai/frontend/web/src/services/api/schemas/$MidasDepthImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/schemas/$MidasDepthImageProcessorInvocation.ts new file mode 100644 index 0000000000..ea0b2b0099 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/schemas/$MidasDepthImageProcessorInvocation.ts @@ -0,0 +1,31 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +export const $MidasDepthImageProcessorInvocation = { + description: `Applies Midas depth processing to image`, + properties: { + id: { + type: 'string', + description: `The id of this node. 
Must be unique among all nodes.`, + isRequired: true, + }, + type: { + type: 'Enum', + }, + image: { + type: 'all-of', + description: `image to process`, + contains: [{ + type: 'ImageField', + }], + }, + a_mult: { + type: 'number', + description: `Midas parameter a = amult * PI`, + }, + bg_th: { + type: 'number', + description: `Midas parameter bg_th`, + }, + }, +} as const; diff --git a/invokeai/frontend/web/src/services/api/schemas/$MlsdImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/schemas/$MlsdImageProcessorInvocation.ts new file mode 100644 index 0000000000..1bff7579cc --- /dev/null +++ b/invokeai/frontend/web/src/services/api/schemas/$MlsdImageProcessorInvocation.ts @@ -0,0 +1,39 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +export const $MlsdImageProcessorInvocation = { + description: `Applies MLSD processing to image`, + properties: { + id: { + type: 'string', + description: `The id of this node. Must be unique among all nodes.`, + isRequired: true, + }, + type: { + type: 'Enum', + }, + image: { + type: 'all-of', + description: `image to process`, + contains: [{ + type: 'ImageField', + }], + }, + detect_resolution: { + type: 'number', + description: `pixel resolution for edge detection`, + }, + image_resolution: { + type: 'number', + description: `pixel resolution for output image`, + }, + thr_v: { + type: 'number', + description: `MLSD parameter thr_v`, + }, + thr_d: { + type: 'number', + description: `MLSD parameter thr_d`, + }, + }, +} as const; diff --git a/invokeai/frontend/web/src/services/api/schemas/$NormalbaeImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/schemas/$NormalbaeImageProcessorInvocation.ts new file mode 100644 index 0000000000..7cdfe6f3ae --- /dev/null +++ b/invokeai/frontend/web/src/services/api/schemas/$NormalbaeImageProcessorInvocation.ts @@ -0,0 +1,31 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +export const 
$NormalbaeImageProcessorInvocation = { + description: `Applies NormalBae processing to image`, + properties: { + id: { + type: 'string', + description: `The id of this node. Must be unique among all nodes.`, + isRequired: true, + }, + type: { + type: 'Enum', + }, + image: { + type: 'all-of', + description: `image to process`, + contains: [{ + type: 'ImageField', + }], + }, + detect_resolution: { + type: 'number', + description: `pixel resolution for edge detection`, + }, + image_resolution: { + type: 'number', + description: `pixel resolution for output image`, + }, + }, +} as const; diff --git a/invokeai/frontend/web/src/services/api/schemas/$OpenposeImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/schemas/$OpenposeImageProcessorInvocation.ts new file mode 100644 index 0000000000..2a187e9cf2 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/schemas/$OpenposeImageProcessorInvocation.ts @@ -0,0 +1,35 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +export const $OpenposeImageProcessorInvocation = { + description: `Applies Openpose processing to image`, + properties: { + id: { + type: 'string', + description: `The id of this node. 
Must be unique among all nodes.`, + isRequired: true, + }, + type: { + type: 'Enum', + }, + image: { + type: 'all-of', + description: `image to process`, + contains: [{ + type: 'ImageField', + }], + }, + hand_and_face: { + type: 'boolean', + description: `whether to use hands and face mode`, + }, + detect_resolution: { + type: 'number', + description: `pixel resolution for edge detection`, + }, + image_resolution: { + type: 'number', + description: `pixel resolution for output image`, + }, + }, +} as const; diff --git a/invokeai/frontend/web/src/services/api/schemas/$PidiImageProcessorInvocation.ts b/invokeai/frontend/web/src/services/api/schemas/$PidiImageProcessorInvocation.ts new file mode 100644 index 0000000000..0fd53967c2 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/schemas/$PidiImageProcessorInvocation.ts @@ -0,0 +1,39 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +export const $PidiImageProcessorInvocation = { + description: `Applies PIDI processing to image`, + properties: { + id: { + type: 'string', + description: `The id of this node. 
Must be unique among all nodes.`, + isRequired: true, + }, + type: { + type: 'Enum', + }, + image: { + type: 'all-of', + description: `image to process`, + contains: [{ + type: 'ImageField', + }], + }, + detect_resolution: { + type: 'number', + description: `pixel resolution for edge detection`, + }, + image_resolution: { + type: 'number', + description: `pixel resolution for output image`, + }, + safe: { + type: 'boolean', + description: `whether to use safe mode`, + }, + scribble: { + type: 'boolean', + description: `whether to use scribble mode`, + }, + }, +} as const; diff --git a/invokeai/frontend/web/src/services/api/schemas/$RandomIntInvocation.ts b/invokeai/frontend/web/src/services/api/schemas/$RandomIntInvocation.ts new file mode 100644 index 0000000000..e5b0387d5a --- /dev/null +++ b/invokeai/frontend/web/src/services/api/schemas/$RandomIntInvocation.ts @@ -0,0 +1,16 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +export const $RandomIntInvocation = { + description: `Outputs a single random integer.`, + properties: { + id: { + type: 'string', + description: `The id of this node. 
Must be unique among all nodes.`, + isRequired: true, + }, + type: { + type: 'Enum', + }, + }, +} as const; diff --git a/invokeai/frontend/web/src/services/api/services/ImagesService.ts b/invokeai/frontend/web/src/services/api/services/ImagesService.ts index 13b2ef836a..51fe6c820f 100644 --- a/invokeai/frontend/web/src/services/api/services/ImagesService.ts +++ b/invokeai/frontend/web/src/services/api/services/ImagesService.ts @@ -4,9 +4,10 @@ import type { Body_upload_image } from '../models/Body_upload_image'; import type { ImageCategory } from '../models/ImageCategory'; import type { ImageDTO } from '../models/ImageDTO'; -import type { ImageType } from '../models/ImageType'; +import type { ImageRecordChanges } from '../models/ImageRecordChanges'; import type { ImageUrlsDTO } from '../models/ImageUrlsDTO'; -import type { PaginatedResults_ImageDTO_ } from '../models/PaginatedResults_ImageDTO_'; +import type { OffsetPaginatedResults_ImageDTO_ } from '../models/OffsetPaginatedResults_ImageDTO_'; +import type { ResourceOrigin } from '../models/ResourceOrigin'; import type { CancelablePromise } from '../core/CancelablePromise'; import { OpenAPI } from '../core/OpenAPI'; @@ -16,41 +17,47 @@ export class ImagesService { /** * List Images With Metadata - * Gets a list of images with metadata - * @returns PaginatedResults_ImageDTO_ Successful Response + * Gets a list of images + * @returns OffsetPaginatedResults_ImageDTO_ Successful Response * @throws ApiError */ public static listImagesWithMetadata({ - imageType, - imageCategory, - page, - perPage = 10, + imageOrigin, + categories, + isIntermediate, + offset, + limit = 10, }: { /** - * The type of images to list + * The origin of images to list */ - imageType: ImageType, + imageOrigin?: ResourceOrigin, /** - * The kind of images to list + * The categories of image to include */ - imageCategory: ImageCategory, + categories?: Array, /** - * The page of image metadata to get + * Whether to list intermediate images */ - 
page?: number, + isIntermediate?: boolean, /** - * The number of image metadata per page + * The page offset */ - perPage?: number, - }): CancelablePromise { + offset?: number, + /** + * The number of images per page + */ + limit?: number, + }): CancelablePromise { return __request(OpenAPI, { method: 'GET', url: '/api/v1/images/', query: { - 'image_type': imageType, - 'image_category': imageCategory, - 'page': page, - 'per_page': perPage, + 'image_origin': imageOrigin, + 'categories': categories, + 'is_intermediate': isIntermediate, + 'offset': offset, + 'limit': limit, }, errors: { 422: `Validation Error`, @@ -65,20 +72,32 @@ export class ImagesService { * @throws ApiError */ public static uploadImage({ - imageType, - formData, imageCategory, + isIntermediate, + formData, + sessionId, }: { - imageType: ImageType, + /** + * The category of the image + */ + imageCategory: ImageCategory, + /** + * Whether this is an intermediate image + */ + isIntermediate: boolean, formData: Body_upload_image, - imageCategory?: ImageCategory, + /** + * The session ID associated with this upload, if any + */ + sessionId?: string, }): CancelablePromise { return __request(OpenAPI, { method: 'POST', url: '/api/v1/images/', query: { - 'image_type': imageType, 'image_category': imageCategory, + 'is_intermediate': isIntermediate, + 'session_id': sessionId, }, formData: formData, mediaType: 'multipart/form-data', @@ -96,13 +115,13 @@ export class ImagesService { * @throws ApiError */ public static getImageFull({ - imageType, + imageOrigin, imageName, }: { /** * The type of full-resolution image file to get */ - imageType: ImageType, + imageOrigin: ResourceOrigin, /** * The name of full-resolution image file to get */ @@ -110,9 +129,9 @@ export class ImagesService { }): CancelablePromise { return __request(OpenAPI, { method: 'GET', - url: '/api/v1/images/{image_type}/{image_name}', + url: '/api/v1/images/{image_origin}/{image_name}', path: { - 'image_type': imageType, + 'image_origin': 
imageOrigin, 'image_name': imageName, }, errors: { @@ -129,10 +148,13 @@ export class ImagesService { * @throws ApiError */ public static deleteImage({ - imageType, + imageOrigin, imageName, }: { - imageType: ImageType, + /** + * The origin of image to delete + */ + imageOrigin: ResourceOrigin, /** * The name of the image to delete */ @@ -140,9 +162,9 @@ export class ImagesService { }): CancelablePromise { return __request(OpenAPI, { method: 'DELETE', - url: '/api/v1/images/{image_type}/{image_name}', + url: '/api/v1/images/{image_origin}/{image_name}', path: { - 'image_type': imageType, + 'image_origin': imageOrigin, 'image_name': imageName, }, errors: { @@ -151,6 +173,42 @@ export class ImagesService { }); } + /** + * Update Image + * Updates an image + * @returns ImageDTO Successful Response + * @throws ApiError + */ + public static updateImage({ + imageOrigin, + imageName, + requestBody, + }: { + /** + * The origin of image to update + */ + imageOrigin: ResourceOrigin, + /** + * The name of the image to update + */ + imageName: string, + requestBody: ImageRecordChanges, + }): CancelablePromise { + return __request(OpenAPI, { + method: 'PATCH', + url: '/api/v1/images/{image_origin}/{image_name}', + path: { + 'image_origin': imageOrigin, + 'image_name': imageName, + }, + body: requestBody, + mediaType: 'application/json', + errors: { + 422: `Validation Error`, + }, + }); + } + /** * Get Image Metadata * Gets an image's metadata @@ -158,13 +216,13 @@ export class ImagesService { * @throws ApiError */ public static getImageMetadata({ - imageType, + imageOrigin, imageName, }: { /** - * The type of image to get + * The origin of image to get */ - imageType: ImageType, + imageOrigin: ResourceOrigin, /** * The name of image to get */ @@ -172,9 +230,9 @@ export class ImagesService { }): CancelablePromise { return __request(OpenAPI, { method: 'GET', - url: '/api/v1/images/{image_type}/{image_name}/metadata', + url: '/api/v1/images/{image_origin}/{image_name}/metadata', 
path: { - 'image_type': imageType, + 'image_origin': imageOrigin, 'image_name': imageName, }, errors: { @@ -190,13 +248,13 @@ export class ImagesService { * @throws ApiError */ public static getImageThumbnail({ - imageType, + imageOrigin, imageName, }: { /** - * The type of thumbnail image file to get + * The origin of thumbnail image file to get */ - imageType: ImageType, + imageOrigin: ResourceOrigin, /** * The name of thumbnail image file to get */ @@ -204,9 +262,9 @@ export class ImagesService { }): CancelablePromise { return __request(OpenAPI, { method: 'GET', - url: '/api/v1/images/{image_type}/{image_name}/thumbnail', + url: '/api/v1/images/{image_origin}/{image_name}/thumbnail', path: { - 'image_type': imageType, + 'image_origin': imageOrigin, 'image_name': imageName, }, errors: { @@ -223,13 +281,13 @@ export class ImagesService { * @throws ApiError */ public static getImageUrls({ - imageType, + imageOrigin, imageName, }: { /** - * The type of the image whose URL to get + * The origin of the image whose URL to get */ - imageType: ImageType, + imageOrigin: ResourceOrigin, /** * The name of the image whose URL to get */ @@ -237,9 +295,9 @@ export class ImagesService { }): CancelablePromise { return __request(OpenAPI, { method: 'GET', - url: '/api/v1/images/{image_type}/{image_name}/urls', + url: '/api/v1/images/{image_origin}/{image_name}/urls', path: { - 'image_type': imageType, + 'image_origin': imageOrigin, 'image_name': imageName, }, errors: { diff --git a/invokeai/frontend/web/src/services/api/services/SessionsService.ts b/invokeai/frontend/web/src/services/api/services/SessionsService.ts index 23597c9e9e..6ae6783313 100644 --- a/invokeai/frontend/web/src/services/api/services/SessionsService.ts +++ b/invokeai/frontend/web/src/services/api/services/SessionsService.ts @@ -2,34 +2,53 @@ /* tslint:disable */ /* eslint-disable */ import type { AddInvocation } from '../models/AddInvocation'; -import type { BlurInvocation } from '../models/BlurInvocation'; 
+import type { CannyImageProcessorInvocation } from '../models/CannyImageProcessorInvocation'; import type { CollectInvocation } from '../models/CollectInvocation'; import type { CompelInvocation } from '../models/CompelInvocation'; -import type { CropImageInvocation } from '../models/CropImageInvocation'; +import type { ContentShuffleImageProcessorInvocation } from '../models/ContentShuffleImageProcessorInvocation'; +import type { ControlNetInvocation } from '../models/ControlNetInvocation'; import type { CvInpaintInvocation } from '../models/CvInpaintInvocation'; import type { DivideInvocation } from '../models/DivideInvocation'; import type { Edge } from '../models/Edge'; import type { Graph } from '../models/Graph'; import type { GraphExecutionState } from '../models/GraphExecutionState'; import type { GraphInvocation } from '../models/GraphInvocation'; +import type { HedImageprocessorInvocation } from '../models/HedImageprocessorInvocation'; +import type { ImageBlurInvocation } from '../models/ImageBlurInvocation'; +import type { ImageChannelInvocation } from '../models/ImageChannelInvocation'; +import type { ImageConvertInvocation } from '../models/ImageConvertInvocation'; +import type { ImageCropInvocation } from '../models/ImageCropInvocation'; +import type { ImageInverseLerpInvocation } from '../models/ImageInverseLerpInvocation'; +import type { ImageLerpInvocation } from '../models/ImageLerpInvocation'; +import type { ImageMultiplyInvocation } from '../models/ImageMultiplyInvocation'; +import type { ImagePasteInvocation } from '../models/ImagePasteInvocation'; +import type { ImageProcessorInvocation } from '../models/ImageProcessorInvocation'; +import type { ImageResizeInvocation } from '../models/ImageResizeInvocation'; +import type { ImageScaleInvocation } from '../models/ImageScaleInvocation'; import type { ImageToImageInvocation } from '../models/ImageToImageInvocation'; import type { ImageToLatentsInvocation } from 
'../models/ImageToLatentsInvocation'; import type { InfillColorInvocation } from '../models/InfillColorInvocation'; import type { InfillPatchMatchInvocation } from '../models/InfillPatchMatchInvocation'; import type { InfillTileInvocation } from '../models/InfillTileInvocation'; import type { InpaintInvocation } from '../models/InpaintInvocation'; -import type { InverseLerpInvocation } from '../models/InverseLerpInvocation'; import type { IterateInvocation } from '../models/IterateInvocation'; import type { LatentsToImageInvocation } from '../models/LatentsToImageInvocation'; import type { LatentsToLatentsInvocation } from '../models/LatentsToLatentsInvocation'; -import type { LerpInvocation } from '../models/LerpInvocation'; +import type { LineartAnimeImageProcessorInvocation } from '../models/LineartAnimeImageProcessorInvocation'; +import type { LineartImageProcessorInvocation } from '../models/LineartImageProcessorInvocation'; import type { LoadImageInvocation } from '../models/LoadImageInvocation'; import type { MaskFromAlphaInvocation } from '../models/MaskFromAlphaInvocation'; +import type { MediapipeFaceProcessorInvocation } from '../models/MediapipeFaceProcessorInvocation'; +import type { MidasDepthImageProcessorInvocation } from '../models/MidasDepthImageProcessorInvocation'; +import type { MlsdImageProcessorInvocation } from '../models/MlsdImageProcessorInvocation'; import type { MultiplyInvocation } from '../models/MultiplyInvocation'; import type { NoiseInvocation } from '../models/NoiseInvocation'; +import type { NormalbaeImageProcessorInvocation } from '../models/NormalbaeImageProcessorInvocation'; +import type { OpenposeImageProcessorInvocation } from '../models/OpenposeImageProcessorInvocation'; import type { PaginatedResults_GraphExecutionState_ } from '../models/PaginatedResults_GraphExecutionState_'; +import type { ParamFloatInvocation } from '../models/ParamFloatInvocation'; import type { ParamIntInvocation } from '../models/ParamIntInvocation'; 
-import type { PasteImageInvocation } from '../models/PasteImageInvocation'; +import type { PidiImageProcessorInvocation } from '../models/PidiImageProcessorInvocation'; import type { RandomIntInvocation } from '../models/RandomIntInvocation'; import type { RandomRangeInvocation } from '../models/RandomRangeInvocation'; import type { RangeInvocation } from '../models/RangeInvocation'; @@ -42,6 +61,7 @@ import type { SubtractInvocation } from '../models/SubtractInvocation'; import type { TextToImageInvocation } from '../models/TextToImageInvocation'; import type { TextToLatentsInvocation } from '../models/TextToLatentsInvocation'; import type { UpscaleInvocation } from '../models/UpscaleInvocation'; +import type { ZoeDepthImageProcessorInvocation } from '../models/ZoeDepthImageProcessorInvocation'; import type { CancelablePromise } from '../core/CancelablePromise'; import { OpenAPI } from '../core/OpenAPI'; @@ -151,7 +171,7 @@ export class SessionsService { * The id of the session */ sessionId: string, - requestBody: (LoadImageInvocation | ShowImageInvocation | CropImageInvocation | PasteImageInvocation | MaskFromAlphaInvocation | BlurInvocation | LerpInvocation | InverseLerpInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation), + requestBody: (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | 
ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageprocessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation), }): CancelablePromise { return __request(OpenAPI, { method: 'POST', @@ -188,7 +208,7 @@ export class SessionsService { * The path to the node in the graph */ nodePath: string, - requestBody: (LoadImageInvocation | ShowImageInvocation | CropImageInvocation | PasteImageInvocation | MaskFromAlphaInvocation | BlurInvocation | LerpInvocation | InverseLerpInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | 
CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation), + requestBody: (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageprocessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation), }): CancelablePromise { return __request(OpenAPI, { method: 'PUT', diff --git a/invokeai/frontend/web/src/services/events/actions.ts 
b/invokeai/frontend/web/src/services/events/actions.ts index 76bffeaa49..5832cb24b1 100644 --- a/invokeai/frontend/web/src/services/events/actions.ts +++ b/invokeai/frontend/web/src/services/events/actions.ts @@ -12,46 +12,153 @@ type BaseSocketPayload = { timestamp: string; }; -// Create actions for each socket event +// Create actions for each socket // Middleware and redux can then respond to them as needed +/** + * Socket.IO Connected + * + * Do not use. Only for use in middleware. + */ export const socketConnected = createAction( 'socket/socketConnected' ); +/** + * App-level Socket.IO Connected + */ +export const appSocketConnected = createAction( + 'socket/appSocketConnected' +); + +/** + * Socket.IO Disconnect + * + * Do not use. Only for use in middleware. + */ export const socketDisconnected = createAction( 'socket/socketDisconnected' ); +/** + * App-level Socket.IO Disconnected + */ +export const appSocketDisconnected = createAction( + 'socket/appSocketDisconnected' +); + +/** + * Socket.IO Subscribed + * + * Do not use. Only for use in middleware. + */ export const socketSubscribed = createAction< BaseSocketPayload & { sessionId: string } >('socket/socketSubscribed'); +/** + * App-level Socket.IO Subscribed + */ +export const appSocketSubscribed = createAction< + BaseSocketPayload & { sessionId: string } +>('socket/appSocketSubscribed'); + +/** + * Socket.IO Unsubscribed + * + * Do not use. Only for use in middleware. 
+ */ export const socketUnsubscribed = createAction< BaseSocketPayload & { sessionId: string } >('socket/socketUnsubscribed'); -export const invocationStarted = createAction< - BaseSocketPayload & { data: InvocationStartedEvent } ->('socket/invocationStarted'); +/** + * App-level Socket.IO Unsubscribed + */ +export const appSocketUnsubscribed = createAction< + BaseSocketPayload & { sessionId: string } +>('socket/appSocketUnsubscribed'); -export const invocationComplete = createAction< +/** + * Socket.IO Invocation Started + * + * Do not use. Only for use in middleware. + */ +export const socketInvocationStarted = createAction< + BaseSocketPayload & { data: InvocationStartedEvent } +>('socket/socketInvocationStarted'); + +/** + * App-level Socket.IO Invocation Started + */ +export const appSocketInvocationStarted = createAction< + BaseSocketPayload & { data: InvocationStartedEvent } +>('socket/appSocketInvocationStarted'); + +/** + * Socket.IO Invocation Complete + * + * Do not use. Only for use in middleware. + */ +export const socketInvocationComplete = createAction< BaseSocketPayload & { data: InvocationCompleteEvent; } ->('socket/invocationComplete'); +>('socket/socketInvocationComplete'); -export const invocationError = createAction< +/** + * App-level Socket.IO Invocation Complete + */ +export const appSocketInvocationComplete = createAction< + BaseSocketPayload & { + data: InvocationCompleteEvent; + } +>('socket/appSocketInvocationComplete'); + +/** + * Socket.IO Invocation Error + * + * Do not use. Only for use in middleware. 
+ */ +export const socketInvocationError = createAction< BaseSocketPayload & { data: InvocationErrorEvent } ->('socket/invocationError'); +>('socket/socketInvocationError'); -export const graphExecutionStateComplete = createAction< +/** + * App-level Socket.IO Invocation Error + */ +export const appSocketInvocationError = createAction< + BaseSocketPayload & { data: InvocationErrorEvent } +>('socket/appSocketInvocationError'); + +/** + * Socket.IO Graph Execution State Complete + * + * Do not use. Only for use in middleware. + */ +export const socketGraphExecutionStateComplete = createAction< BaseSocketPayload & { data: GraphExecutionStateCompleteEvent } ->('socket/graphExecutionStateComplete'); +>('socket/socketGraphExecutionStateComplete'); -export const generatorProgress = createAction< +/** + * App-level Socket.IO Graph Execution State Complete + */ +export const appSocketGraphExecutionStateComplete = createAction< + BaseSocketPayload & { data: GraphExecutionStateCompleteEvent } +>('socket/appSocketGraphExecutionStateComplete'); + +/** + * Socket.IO Generator Progress + * + * Do not use. Only for use in middleware. 
+ */ +export const socketGeneratorProgress = createAction< BaseSocketPayload & { data: GeneratorProgressEvent } ->('socket/generatorProgress'); +>('socket/socketGeneratorProgress'); -// dispatch this when we need to fully reset the socket connection -export const socketReset = createAction('socket/socketReset'); +/** + * App-level Socket.IO Generator Progress + */ +export const appSocketGeneratorProgress = createAction< + BaseSocketPayload & { data: GeneratorProgressEvent } +>('socket/appSocketGeneratorProgress'); diff --git a/invokeai/frontend/web/src/services/events/middleware.ts b/invokeai/frontend/web/src/services/events/middleware.ts index bd1d60099a..f1eb844f2c 100644 --- a/invokeai/frontend/web/src/services/events/middleware.ts +++ b/invokeai/frontend/web/src/services/events/middleware.ts @@ -8,7 +8,7 @@ import { import { socketSubscribed, socketUnsubscribed } from './actions'; import { AppThunkDispatch, RootState } from 'app/store/store'; import { getTimestamp } from 'common/util/getTimestamp'; -import { sessionInvoked, sessionCreated } from 'services/thunks/session'; +import { sessionCreated } from 'services/thunks/session'; import { OpenAPI } from 'services/api'; import { setEventListeners } from 'services/events/util/setEventListeners'; import { log } from 'app/logging/useLogger'; @@ -64,15 +64,9 @@ export const socketMiddleware = () => { if (sessionCreated.fulfilled.match(action)) { const sessionId = action.payload.id; - const sessionLog = socketioLog.child({ sessionId }); const oldSessionId = getState().system.sessionId; if (oldSessionId) { - sessionLog.debug( - { oldSessionId }, - `Unsubscribed from old session (${oldSessionId})` - ); - socket.emit('unsubscribe', { session: oldSessionId, }); @@ -85,8 +79,6 @@ export const socketMiddleware = () => { ); } - sessionLog.debug(`Subscribe to new session (${sessionId})`); - socket.emit('subscribe', { session: sessionId }); dispatch( @@ -95,9 +87,6 @@ export const socketMiddleware = () => { timestamp: 
getTimestamp(), }) ); - - // Finally we actually invoke the session, starting processing - dispatch(sessionInvoked({ sessionId })); } next(action); diff --git a/invokeai/frontend/web/src/services/events/util/setEventListeners.ts b/invokeai/frontend/web/src/services/events/util/setEventListeners.ts index 4431a9fd8b..2c4cba510a 100644 --- a/invokeai/frontend/web/src/services/events/util/setEventListeners.ts +++ b/invokeai/frontend/web/src/services/events/util/setEventListeners.ts @@ -1,14 +1,13 @@ import { MiddlewareAPI } from '@reduxjs/toolkit'; import { AppDispatch, RootState } from 'app/store/store'; import { getTimestamp } from 'common/util/getTimestamp'; -import { sessionCanceled } from 'services/thunks/session'; import { Socket } from 'socket.io-client'; import { - generatorProgress, - graphExecutionStateComplete, - invocationComplete, - invocationError, - invocationStarted, + socketGeneratorProgress, + socketGraphExecutionStateComplete, + socketInvocationComplete, + socketInvocationError, + socketInvocationStarted, socketConnected, socketDisconnected, socketSubscribed, @@ -16,12 +15,6 @@ import { import { ClientToServerEvents, ServerToClientEvents } from '../types'; import { Logger } from 'roarr'; import { JsonObject } from 'roarr/dist/types'; -import { - receivedResultImagesPage, - receivedUploadImagesPage, -} from 'services/thunks/gallery'; -import { receivedModels } from 'services/thunks/model'; -import { receivedOpenAPISchema } from 'services/thunks/schema'; import { makeToast } from '../../../app/components/Toaster'; import { addToast } from '../../../features/system/store/systemSlice'; @@ -43,37 +36,13 @@ export const setEventListeners = (arg: SetEventListenersArg) => { dispatch(socketConnected({ timestamp: getTimestamp() })); - const { results, uploads, models, nodes, config, system } = getState(); + const { sessionId } = getState().system; - const { disabledTabs } = config; - - // These thunks need to be dispatch in middleware; cannot handle in a 
reducer - if (!results.ids.length) { - dispatch(receivedResultImagesPage()); - } - - if (!uploads.ids.length) { - dispatch(receivedUploadImagesPage()); - } - - if (!models.ids.length) { - dispatch(receivedModels()); - } - - if (!nodes.schema && !disabledTabs.includes('nodes')) { - dispatch(receivedOpenAPISchema()); - } - - if (system.sessionId) { - log.debug( - { sessionId: system.sessionId }, - `Subscribed to existing session (${system.sessionId})` - ); - - socket.emit('subscribe', { session: system.sessionId }); + if (sessionId) { + socket.emit('subscribe', { session: sessionId }); dispatch( socketSubscribed({ - sessionId: system.sessionId, + sessionId, timestamp: getTimestamp(), }) ); @@ -101,7 +70,6 @@ export const setEventListeners = (arg: SetEventListenersArg) => { * Disconnect */ socket.on('disconnect', () => { - log.debug('Disconnected'); dispatch(socketDisconnected({ timestamp: getTimestamp() })); }); @@ -109,70 +77,29 @@ export const setEventListeners = (arg: SetEventListenersArg) => { * Invocation started */ socket.on('invocation_started', (data) => { - if (getState().system.canceledSession === data.graph_execution_state_id) { - log.trace( - { data, sessionId: data.graph_execution_state_id }, - `Ignored invocation started (${data.node.type}) for canceled session (${data.graph_execution_state_id})` - ); - return; - } - - log.info( - { data, sessionId: data.graph_execution_state_id }, - `Invocation started (${data.node.type})` - ); - dispatch(invocationStarted({ data, timestamp: getTimestamp() })); + dispatch(socketInvocationStarted({ data, timestamp: getTimestamp() })); }); /** * Generator progress */ socket.on('generator_progress', (data) => { - if (getState().system.canceledSession === data.graph_execution_state_id) { - log.trace( - { data, sessionId: data.graph_execution_state_id }, - `Ignored generator progress (${data.node.type}) for canceled session (${data.graph_execution_state_id})` - ); - return; - } - - log.trace( - { data, sessionId: 
data.graph_execution_state_id }, - `Generator progress (${data.node.type})` - ); - dispatch(generatorProgress({ data, timestamp: getTimestamp() })); + dispatch(socketGeneratorProgress({ data, timestamp: getTimestamp() })); }); /** * Invocation error */ socket.on('invocation_error', (data) => { - log.error( - { data, sessionId: data.graph_execution_state_id }, - `Invocation error (${data.node.type})` - ); - dispatch(invocationError({ data, timestamp: getTimestamp() })); + dispatch(socketInvocationError({ data, timestamp: getTimestamp() })); }); /** * Invocation complete */ socket.on('invocation_complete', (data) => { - log.info( - { data, sessionId: data.graph_execution_state_id }, - `Invocation complete (${data.node.type})` - ); - const sessionId = data.graph_execution_state_id; - - const { cancelType, isCancelScheduled } = getState().system; - - // Handle scheduled cancelation - if (cancelType === 'scheduled' && isCancelScheduled) { - dispatch(sessionCanceled({ sessionId })); - } - dispatch( - invocationComplete({ + socketInvocationComplete({ data, timestamp: getTimestamp(), }) @@ -183,10 +110,11 @@ export const setEventListeners = (arg: SetEventListenersArg) => { * Graph complete */ socket.on('graph_execution_state_complete', (data) => { - log.info( - { data, sessionId: data.graph_execution_state_id }, - `Graph execution state complete (${data.graph_execution_state_id})` + dispatch( + socketGraphExecutionStateComplete({ + data, + timestamp: getTimestamp(), + }) ); - dispatch(graphExecutionStateComplete({ data, timestamp: getTimestamp() })); }); }; diff --git a/invokeai/frontend/web/src/services/thunks/gallery.ts b/invokeai/frontend/web/src/services/thunks/gallery.ts deleted file mode 100644 index 01e8a986b2..0000000000 --- a/invokeai/frontend/web/src/services/thunks/gallery.ts +++ /dev/null @@ -1,45 +0,0 @@ -import { log } from 'app/logging/useLogger'; -import { createAppAsyncThunk } from 'app/store/storeUtils'; -import { ImagesService } from 'services/api'; - 
-export const IMAGES_PER_PAGE = 20; - -const galleryLog = log.child({ namespace: 'gallery' }); - -export const receivedResultImagesPage = createAppAsyncThunk( - 'results/receivedResultImagesPage', - async (_arg, { getState, rejectWithValue }) => { - const { page, pages, nextPage } = getState().results; - - if (nextPage === page) { - rejectWithValue([]); - } - - const response = await ImagesService.listImagesWithMetadata({ - imageType: 'results', - imageCategory: 'general', - page: getState().results.nextPage, - perPage: IMAGES_PER_PAGE, - }); - - galleryLog.info({ response }, `Received ${response.items.length} results`); - - return response; - } -); - -export const receivedUploadImagesPage = createAppAsyncThunk( - 'uploads/receivedUploadImagesPage', - async (_arg, { getState }) => { - const response = await ImagesService.listImagesWithMetadata({ - imageType: 'uploads', - imageCategory: 'general', - page: getState().uploads.nextPage, - perPage: IMAGES_PER_PAGE, - }); - - galleryLog.info({ response }, `Received ${response.items.length} uploads`); - - return response; - } -); diff --git a/invokeai/frontend/web/src/services/thunks/image.ts b/invokeai/frontend/web/src/services/thunks/image.ts index 6831eb647d..87832c6b1e 100644 --- a/invokeai/frontend/web/src/services/thunks/image.ts +++ b/invokeai/frontend/web/src/services/thunks/image.ts @@ -1,10 +1,6 @@ -import { log } from 'app/logging/useLogger'; import { createAppAsyncThunk } from 'app/store/storeUtils'; -import { InvokeTabName } from 'features/ui/store/tabMap'; +import { selectImagesAll } from 'features/gallery/store/imagesSlice'; import { ImagesService } from 'services/api'; -import { getHeaders } from 'services/util/getHeaders'; - -const imagesLog = log.child({ namespace: 'image' }); type imageUrlsReceivedArg = Parameters< (typeof ImagesService)['getImageUrls'] @@ -17,7 +13,6 @@ export const imageUrlsReceived = createAppAsyncThunk( 'api/imageUrlsReceived', async (arg: imageUrlsReceivedArg) => { const response = 
await ImagesService.getImageUrls(arg); - imagesLog.info({ arg, response }, 'Received image urls'); return response; } ); @@ -33,16 +28,11 @@ export const imageMetadataReceived = createAppAsyncThunk( 'api/imageMetadataReceived', async (arg: imageMetadataReceivedArg) => { const response = await ImagesService.getImageMetadata(arg); - imagesLog.info({ arg, response }, 'Received image record'); return response; } ); -type ImageUploadedArg = Parameters<(typeof ImagesService)['uploadImage']>[0] & { - // extra arg to determine post-upload actions - we check for this when the image is uploaded - // to determine if we should set the init image - activeTabName?: InvokeTabName; -}; +type ImageUploadedArg = Parameters<(typeof ImagesService)['uploadImage']>[0]; /** * `ImagesService.uploadImage()` thunk @@ -51,13 +41,8 @@ export const imageUploaded = createAppAsyncThunk( 'api/imageUploaded', async (arg: ImageUploadedArg) => { // strip out `activeTabName` from arg - the route does not need it - const { activeTabName, ...rest } = arg; - const response = await ImagesService.uploadImage(rest); - const { location } = getHeaders(response); - - imagesLog.debug({ arg: '', response, location }, 'Image uploaded'); - - return { response, location }; + const response = await ImagesService.uploadImage(arg); + return response; } ); @@ -70,9 +55,48 @@ export const imageDeleted = createAppAsyncThunk( 'api/imageDeleted', async (arg: ImageDeletedArg) => { const response = await ImagesService.deleteImage(arg); - - imagesLog.debug({ arg, response }, 'Image deleted'); - + return response; + } +); + +type ImageUpdatedArg = Parameters<(typeof ImagesService)['updateImage']>[0]; + +/** + * `ImagesService.updateImage()` thunk + */ +export const imageUpdated = createAppAsyncThunk( + 'api/imageUpdated', + async (arg: ImageUpdatedArg) => { + const response = await ImagesService.updateImage(arg); + return response; + } +); + +type ImagesListedArg = Parameters< + (typeof 
ImagesService)['listImagesWithMetadata'] +>[0]; + +export const IMAGES_PER_PAGE = 20; + +/** + * `ImagesService.listImagesWithMetadata()` thunk + */ +export const receivedPageOfImages = createAppAsyncThunk( + 'api/receivedPageOfImages', + async (_, { getState }) => { + const state = getState(); + const { categories } = state.images; + + const totalImagesInFilter = selectImagesAll(state).filter((i) => + categories.includes(i.image_category) + ).length; + + const response = await ImagesService.listImagesWithMetadata({ + categories, + isIntermediate: false, + offset: totalImagesInFilter, + limit: IMAGES_PER_PAGE, + }); return response; } ); diff --git a/invokeai/frontend/web/src/services/thunks/session.ts b/invokeai/frontend/web/src/services/thunks/session.ts index dca4134886..cf87fb30f5 100644 --- a/invokeai/frontend/web/src/services/thunks/session.ts +++ b/invokeai/frontend/web/src/services/thunks/session.ts @@ -1,7 +1,7 @@ import { createAppAsyncThunk } from 'app/store/storeUtils'; -import { SessionsService } from 'services/api'; +import { GraphExecutionState, SessionsService } from 'services/api'; import { log } from 'app/logging/useLogger'; -import { serializeError } from 'serialize-error'; +import { isObject } from 'lodash-es'; const sessionLog = log.child({ namespace: 'session' }); @@ -11,99 +11,89 @@ type SessionCreatedArg = { >[0]['requestBody']; }; +type SessionCreatedThunkConfig = { + rejectValue: { arg: SessionCreatedArg; error: unknown }; +}; + /** * `SessionsService.createSession()` thunk */ -export const sessionCreated = createAppAsyncThunk( - 'api/sessionCreated', - async (arg: SessionCreatedArg, { rejectWithValue }) => { - try { - const response = await SessionsService.createSession({ - requestBody: arg.graph, - }); - sessionLog.info({ arg, response }, `Session created (${response.id})`); - return response; - } catch (err: any) { - sessionLog.error( - { - error: serializeError(err), - }, - 'Problem creating session' - ); - return 
rejectWithValue(err.message); - } - } -); - -type NodeAddedArg = Parameters<(typeof SessionsService)['addNode']>[0]; - -/** - * `SessionsService.addNode()` thunk - */ -export const nodeAdded = createAppAsyncThunk( - 'api/nodeAdded', - async ( - arg: { node: NodeAddedArg['requestBody']; sessionId: string }, - _thunkApi - ) => { - const response = await SessionsService.addNode({ - requestBody: arg.node, - sessionId: arg.sessionId, +export const sessionCreated = createAppAsyncThunk< + GraphExecutionState, + SessionCreatedArg, + SessionCreatedThunkConfig +>('api/sessionCreated', async (arg, { rejectWithValue }) => { + try { + const response = await SessionsService.createSession({ + requestBody: arg.graph, }); - - sessionLog.info({ arg, response }, `Node added (${response})`); - return response; + } catch (error) { + return rejectWithValue({ arg, error }); } -); +}); + +type SessionInvokedArg = { sessionId: string }; + +type SessionInvokedThunkConfig = { + rejectValue: { + arg: SessionInvokedArg; + error: unknown; + }; +}; + +const isErrorWithStatus = (error: unknown): error is { status: number } => + isObject(error) && 'status' in error; /** * `SessionsService.invokeSession()` thunk */ -export const sessionInvoked = createAppAsyncThunk( - 'api/sessionInvoked', - async (arg: { sessionId: string }, { rejectWithValue }) => { - const { sessionId } = arg; +export const sessionInvoked = createAppAsyncThunk< + void, + SessionInvokedArg, + SessionInvokedThunkConfig +>('api/sessionInvoked', async (arg, { rejectWithValue }) => { + const { sessionId } = arg; - try { - const response = await SessionsService.invokeSession({ - sessionId, - all: true, - }); - sessionLog.info({ arg, response }, `Session invoked (${sessionId})`); - - return response; - } catch (error) { - const err = error as any; - if (err.status === 403) { - return rejectWithValue(err.body.detail); - } - throw error; + try { + const response = await SessionsService.invokeSession({ + sessionId, + all: true, + }); + 
return response; + } catch (error) { + if (isErrorWithStatus(error) && error.status === 403) { + return rejectWithValue({ arg, error: (error as any).body.detail }); } + return rejectWithValue({ arg, error }); } -); +}); type SessionCanceledArg = Parameters< (typeof SessionsService)['cancelSessionInvoke'] >[0]; - +type SessionCanceledThunkConfig = { + rejectValue: { + arg: SessionCanceledArg; + error: unknown; + }; +}; /** * `SessionsService.cancelSession()` thunk */ -export const sessionCanceled = createAppAsyncThunk( - 'api/sessionCanceled', - async (arg: SessionCanceledArg, _thunkApi) => { - const { sessionId } = arg; +export const sessionCanceled = createAppAsyncThunk< + void, + SessionCanceledArg, + SessionCanceledThunkConfig +>('api/sessionCanceled', async (arg: SessionCanceledArg, _thunkApi) => { + const { sessionId } = arg; - const response = await SessionsService.cancelSessionInvoke({ - sessionId, - }); + const response = await SessionsService.cancelSessionInvoke({ + sessionId, + }); - sessionLog.info({ arg, response }, `Session canceled (${sessionId})`); - - return response; - } -); + return response; +}); type SessionsListedArg = Parameters< (typeof SessionsService)['listSessions'] diff --git a/invokeai/frontend/web/src/services/types/guards.ts b/invokeai/frontend/web/src/services/types/guards.ts index 266e991f4d..4d33cfa246 100644 --- a/invokeai/frontend/web/src/services/types/guards.ts +++ b/invokeai/frontend/web/src/services/types/guards.ts @@ -1,5 +1,3 @@ -import { ResultsImageDTO } from 'features/gallery/store/resultsSlice'; -import { UploadsImageDTO } from 'features/gallery/store/uploadsSlice'; import { get, isObject, isString } from 'lodash-es'; import { GraphExecutionState, @@ -9,18 +7,11 @@ import { PromptOutput, IterateInvocationOutput, CollectInvocationOutput, - ImageType, ImageField, LatentsOutput, - ImageDTO, + ResourceOrigin, } from 'services/api'; -export const isUploadsImageDTO = (image: ImageDTO): image is UploadsImageDTO => - 
image.image_type === 'uploads'; - -export const isResultsImageDTO = (image: ImageDTO): image is ResultsImageDTO => - image.image_type === 'results'; - export const isImageOutput = ( output: GraphExecutionState['results'][string] ): output is ImageOutput => output.type === 'image_output'; @@ -49,10 +40,10 @@ export const isCollectOutput = ( output: GraphExecutionState['results'][string] ): output is CollectInvocationOutput => output.type === 'collect_output'; -export const isImageType = (t: unknown): t is ImageType => - isString(t) && ['results', 'uploads', 'intermediates'].includes(t); +export const isResourceOrigin = (t: unknown): t is ResourceOrigin => + isString(t) && ['internal', 'external'].includes(t); export const isImageField = (imageField: unknown): imageField is ImageField => isObject(imageField) && isString(get(imageField, 'image_name')) && - isImageType(get(imageField, 'image_type')); + isResourceOrigin(get(imageField, 'image_origin')); diff --git a/invokeai/frontend/web/stats.html b/invokeai/frontend/web/stats.html index dba25766e4..e9c4381206 100644 --- a/invokeai/frontend/web/stats.html +++ b/invokeai/frontend/web/stats.html @@ -6157,7 +6157,7 @@ var drawChart = (function (exports) {